Merge pull request #72936 from 372046933:range_dtype
PiperOrigin-RevId: 662422313
diff --git a/.bazelrc b/.bazelrc
index 11783a8..edb4659 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -573,6 +573,9 @@
build:rbe_win_clang --linkopt=/FORCE:MULTIPLE
build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE
+# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits.
+build:rbe_windows_x86_cpu --config=rbe_win_clang
+
# END TF REMOTE BUILD EXECUTION OPTIONS
# TFLite build configs for generic embedded Linux
@@ -787,17 +790,19 @@
# PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over
# the whole TF code base. These are usually run continuously or upon presubmit.
-# CPU PYCPP:
+# LINUX CPU PYCPP:
test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# CUDA PYCPP:
+
+# LINUX CUDA PYCPP:
test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# ARM64 PYCPP
+
+# LINUX ARM64 PYCPP
# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on
# Linux x86 so that we can use RBE. Since tests still need to run on the single
# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit.
@@ -830,6 +835,13 @@
# CROSS-COMPILE MACOS X86 PYCPP
build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test
build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test
+# WINDOWS X86-64 CPU PYCPP
+test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600"
+test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only
+test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/...
+
# END TF TEST SUITE OPTIONS
# START CROSS-COMPILE CONFIGS
diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml
index 60af559..2013b8c 100644
--- a/.github/workflows/osv-scanner-scheduled.yml
+++ b/.github/workflows/osv-scanner-scheduled.yml
@@ -28,7 +28,7 @@
jobs:
scan-scheduled:
if: github.repository == 'tensorflow/tensorflow'
- uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.1"
+ uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.2"
with:
scan-args: |-
--lockfile=requirements.txt:./requirements_lock_3_9.txt
diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml
index f457db5..8b1a034 100644
--- a/.github/workflows/pylint-presubmit.yml
+++ b/.github/workflows/pylint-presubmit.yml
@@ -38,7 +38,7 @@
run: |
echo Changed files: ${{ steps.get_file_changes.outputs.files }}
- name: Set up Python 3.9
- uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0
+ uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
with:
python-version: "3.9"
- name: Install Python dependencies
diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml
index e72eab8..ceb213e 100644
--- a/.github/workflows/scorecards-analysis.yml
+++ b/.github/workflows/scorecards-analysis.yml
@@ -46,7 +46,7 @@
persist-credentials: false
- name: "Run analysis"
- uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3
+ uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0
with:
results_file: results.sarif
results_format: sarif
@@ -55,7 +55,7 @@
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
- uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
+ uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4
with:
name: SARIF file
path: results.sarif
@@ -64,6 +64,6 @@
# Upload the results to GitHub's code scanning dashboard (optional).
# Commenting out will disable upload of results to your repo's Code Scanning dashboard
- name: "Upload to code-scanning"
- uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11
+ uses: github/codeql-action/upload-sarif@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15
with:
sarif_file: results.sarif
diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml
index 2c81873..c72cc98 100644
--- a/.github/workflows/sigbuild-docker-branch.yml
+++ b/.github/workflows/sigbuild-docker-branch.yml
@@ -43,16 +43,16 @@
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
-
name: Set up Docker Buildx
- uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0
+ uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1
-
name: Login to DockerHub
- uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+ uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Login to GCR
- uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+ uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
with:
registry: gcr.io
username: _json_key
@@ -67,7 +67,7 @@
-
name: Build and push
id: docker_build
- uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0
+ uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0
with:
push: true
context: ./tensorflow/tools/tf_sig_build_dockerfiles
diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml
index 7de12e1..e21ddb0 100644
--- a/.github/workflows/sigbuild-docker-presubmit.yml
+++ b/.github/workflows/sigbuild-docker-presubmit.yml
@@ -47,16 +47,25 @@
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
-
name: Set up Docker Buildx
- uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0
+ uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1
-
name: Login to GCR
if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging')
- uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+ uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
with:
registry: gcr.io
username: _json_key
password: ${{ secrets.GCP_CREDS }}
-
+ name: Login to AR
+ # Once this is verified, change the label's name. For now, we will piggyback on gcr.io actions.
+ if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging')
+ uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+ with:
+ registry: us-central1-docker.pkg.dev
+ username: _json_key
+ password: ${{ secrets.GCP_CREDS }}
+ -
name: Grab the date to do cache busting (assumes same day OK to keep)
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
@@ -64,7 +73,7 @@
-
name: Build containers, and push to GCR only if the 'build and push to gcr.io for staging' label is applied
id: docker_build
- uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0
+ uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0
with:
push: ${{ contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') }}
context: ./tensorflow/tools/tf_sig_build_dockerfiles
@@ -74,6 +83,7 @@
CACHEBUSTER=${{ steps.date.outputs.DATE }}
tags: |
gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }}
+ us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ github.event.number }}-${{ matrix.python-version }}
cache-from: |
type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }}
type=registry,ref=gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }}
diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml
index 062338e..78e7fd7 100644
--- a/.github/workflows/sigbuild-docker.yml
+++ b/.github/workflows/sigbuild-docker.yml
@@ -46,21 +46,29 @@
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
-
name: Set up Docker Buildx
- uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0
+ uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1
-
name: Login to DockerHub
- uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+ uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Login to GCR
- uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+ uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
with:
registry: gcr.io
username: _json_key
password: ${{ secrets.GCP_CREDS }}
-
+ name: Login to AR
+ # Once this is verified, remove the gcr.io actions.
+ uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
+ with:
+ registry: us-central1-docker.pkg.dev
+ username: _json_key
+ password: ${{ secrets.GCP_CREDS }}
+ -
name: Grab the upcoming TF version to tag this container
run: |
# [[:digit:]] searches for numbers and \+ joins them together
@@ -74,7 +82,7 @@
-
name: Build and push
id: docker_build
- uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0
+ uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0
with:
push: true
context: ./tensorflow/tools/tf_sig_build_dockerfiles
@@ -87,6 +95,8 @@
tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }}
gcr.io/tensorflow-sigs/build:latest-${{ matrix.python-version }}
gcr.io/tensorflow-sigs/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }}
+ us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:latest-${{ matrix.python-version }}
+ us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }}
cache-from: type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }}
cache-to: type=inline
-
diff --git a/RELEASE.md b/RELEASE.md
index 51b3c60..eb2d939 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -35,6 +35,11 @@
should run synchronously, as opposed to be parallelizable when
`options.experimental_optimization.map_parallelization=True`. This saves
memory compared to setting `num_parallel_calls=1`.
+ * Add optional `use_unbounded_threadpool` argument to `map`, to specify that
+ the `map` should use an unbounded threadpool instead of the default pool
+ that is based on the number of cores on the machine. This can improve
+ throughput for map functions which perform IO or otherwise release the
+ CPU.
* `tf.lite`
* `Dequantize` op supports `TensorType_INT4`.
* This change includes per-channel dequantization.
diff --git a/ci/devinfra/docker_windows/Dockerfile b/ci/devinfra/docker_windows/Dockerfile
deleted file mode 100644
index 540f82a..0000000
--- a/ci/devinfra/docker_windows/Dockerfile
+++ /dev/null
@@ -1,256 +0,0 @@
-FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019@sha256:c1b2be17aa0c1a5d9493a306395a6f07141aae8d7897f7ba319183f28719c990
-
-# Set default powershell policy for this script (ProgressPreference='SilentlyContinue' makes
-# downloads with Invoke-WebRequest not show the progress bar and is MUCH faster).
-SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue'; $VerbosePreference = 'Continue';"]
-
-# Workaround for networking (b/112379377) was closed as won't fix for MTU setting.
-# Remaining lines handle making the metadata server on the VM accessible inside docker.
-RUN Get-NetAdapter | Where-Object Name -like "*Ethernet*" | ForEach-Object { \
- & netsh interface ipv4 set subinterface $_.InterfaceIndex mtu=1460 store=persistent }; \
- $gateway = (Get-NetRoute | Where { $_.DestinationPrefix -eq \"0.0.0.0/0\" } | Sort-Object RouteMetric \
- | Select NextHop).NextHop; \
- $ifIndex = (Get-NetAdapter -InterfaceDescription \"Hyper-V Virtual Ethernet*\" | Sort-Object \
- | Select ifIndex).ifIndex; \
- New-NetRoute -DestinationPrefix 169.254.169.254/32 -InterfaceIndex $ifIndex -NextHop $gateway
-
-# Enable Long Paths for Win32 File/Folder APIs.
-RUN New-ItemProperty -Path HKLM:\SYSTEM\CurrentControlSet\Control\FileSystem \
- -Name LongPathsEnabled -Value 1 -PropertyType DWORD -Force
-
-# Install Visual C++ Redistributable for Visual Studio 2015-2022.
-RUN New-Item -Path "C:/" -Name "TEMP" -ItemType "directory"; \
- Invoke-WebRequest "https://aka.ms/vs/17/release/vc_redist.x64.exe" \
- -OutFile C:/TEMP/vc_redist.x64.exe -UseBasicParsing; \
- Start-Process -filepath C:/TEMP/vc_redist.x64.exe -ArgumentList '/install', '/passive', '/norestart' -Wait; \
- Remove-Item C:/TEMP/vc_redist.x64.exe
-
-# Install Visual Studio 2022 Build Tools. Install ManagedDesktopBuildTools separately to ensure all Optional workloads are installed too.
-RUN Invoke-WebRequest "https://aka.ms/vs/17/release/vs_buildtools.exe" \
- -OutFile C:/TEMP/vs_buildtools.exe -UseBasicParsing; \
- Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \
- "--quiet", "--wait", "--nocache", \
- "--add", "Microsoft.VisualStudio.Workload.VCTools", \
- "--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", \
- "--add", "Microsoft.VisualStudio.Component.Windows10SDK.19041" -Wait; \
- Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \
- "--quiet", "--wait", "--nocache", "--includeOptional", \
- "--add", "Microsoft.VisualStudio.Workload.ManagedDesktopBuildTools" -Wait; \
- Remove-Item C:/TEMP/vs_buildtools.exe; \
- [Environment]::SetEnvironmentVariable(\"BAZEL_VC\", \"C:\VS\VC\", \"Machine\"); \
- $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\VS\VC\Tools\MSVC\14.33.31629\bin\Hostx64\x64;C:\VS\Common7\Tools;C:\VS\MSBuild\Current\Bin\", \"Machine\");
-
-# Add signtool.exe to the PATH. Note this path may need to be edited if updates
-# are made to the Windows 10 SDK.
-RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files (x86)\Windows Kits\10\App Certification Kit\", \"Machine\");
-
-# Install WiX toolset (v4) - Necessary for MSI Installer/Signing builds
-RUN dotnet tool install --global wix
-
-# Install msys2, packages and add to path.
-RUN [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; \
- Invoke-WebRequest "https://repo.msys2.org/distrib/x86_64/msys2-base-x86_64-20220319.sfx.exe" \
- -OutFile msys2_install.exe -UseBasicParsing; \
- .\msys2_install.exe -y -oC:\; \
- Remove-Item msys2_install.exe; \
- function msys() { C:\msys64\usr\bin\bash.exe @('-lc') + @Args; } \
- msys ' '; \
- msys 'pacman --noconfirm -Syy bsdcpio bsdtar bzip2'; \
- msys 'pacman --noconfirm -Syy coreutils curl dash file filesystem findutils'; \
- msys 'pacman --noconfirm -Syy flex gawk gcc-libs grep gzip inetutils info'; \
- msys 'pacman --noconfirm -Syy less lndir mintty ncurses pactoys-git patch'; \
- msys 'pacman --noconfirm -Syy pax-git pkgfile rebase sed tar tftp-hpa time tzcode util-linux which'; \
- $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\msys64;C:\msys64\usr\bin\", \"Machine\");
-
-# Install Go 1.19.1
-RUN Invoke-WebRequest "https://go.dev/dl/go1.19.1.windows-amd64.msi" \
- -OutFile C:/TEMP/go_install.msi -UseBasicParsing; \
- Start-Process C:/TEMP/go_install.msi -ArgumentList "/quiet", "/log", "C:/TEMP/go_install_log.txt", \
- "InstallAllUsers=1", "PrependPath=1" -wait; \
- Remove-Item C:/TEMP/go_install.msi; \
- Remove-Item C:/TEMP/go_install_log.txt
-
-# Install Python 3.
-RUN Invoke-WebRequest "https://www.python.org/ftp/python/3.10.4/python-3.10.4-amd64.exe" \
- -OutFile C:/TEMP/python_install.exe -UseBasicParsing; \
- Start-Process C:/TEMP/python_install.exe -ArgumentList "/quiet", "/log", "C:/TEMP/python_install_log.txt", \
- "InstallAllUsers=1", "PrependPath=1" -wait; \
- Remove-Item C:/TEMP/python_install.exe; \
- Remove-Item C:/TEMP/python_install_log.txt
-
-# Install JDK 17
-RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \
- $zulu_url = \"https://cdn.azul.com/zulu/bin/zulu17.32.13-ca-jdk17.0.2-win_x64.zip\"; \
- $zulu_zip = \"c:/temp/jdk_install.zip\"; \
- $zulu_extracted_path = \"c:/temp/\" + [IO.Path]::GetFileNameWithoutExtension($zulu_url); \
- $zulu_root = \"c:/openjdk\"; \
- (New-Object Net.WebClient).DownloadFile($zulu_url, $zulu_zip); \
- [System.IO.Compression.ZipFile]::ExtractToDirectory($zulu_zip, \"c:/temp\"); \
- Move-Item $zulu_extracted_path -Destination $zulu_root; \
- Remove-Item $zulu_zip; \
- $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${zulu_root}\bin\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"JAVA_HOME\", $zulu_root, \"Machine\")
-
-# Install gcloud (install.bat installs directly into bin folder of extracted zip contents)
-# Install needed gcloud components
-RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \
- $pkg_url = \"https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-396.0.0-windows-x86_64.zip\"; \
- $pkg_zip = \"c:/temp/gcloud.zip\"; \
- $pkg_extracted_path = \"c:/google-cloud-sdk\"; \
- (New-Object Net.WebClient).DownloadFile($pkg_url, $pkg_zip); \
- [System.IO.Compression.ZipFile]::ExtractToDirectory($pkg_zip, \"c:/\"); \
- Start-Process cmd.exe -ArgumentList "/c", "/s", "$pkg_extracted_path/install.bat", "-q" -wait; \
- Remove-Item $pkg_zip; \
- $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${pkg_extracted_path}\bin\", \"Machine\"); \
- $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \
- gcloud components install docker-credential-gcr kubectl gsutil;
-
-# Install cygwin and packages
-# Running a seperate ps1 file since when running inside a Dockerfile, it does
-# not work.
-COPY install/install_cygwin.ps1 c:/
-RUN c:/install_cygwin.ps1; \
- $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Cygwin64\bin\", \"Machine\");
-RUN Remove-Item c:/install_cygwin.ps1
-
-# Install Chocolatey and packages
-RUN Invoke-Expression ((New-Object Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')); \
- $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \
- choco feature enable -n allowGlobalConfirmation; \
- choco install 7zip; \
- choco install 7zip.install; \
- choco install 7zip.portable; \
- choco install anaconda2 --version 5.0.1; \
- choco install anaconda3 --version 5.0.1; \
- choco install android-sdk --version 25.2.3.1; \
- choco install AndroidStudio --version 3.0.1.0; \
- choco install ant --version 1.10.1; \
- choco install ccleaner; \
- choco install chocolatey; \
- choco install chocolatey-core.extension; \
- choco install chocolatey-visualstudio.extension; \
- choco install chocolatey-windowsupdate.extension; \
- choco install cmake.install; \
- choco install dotnetcore-sdk; \
- choco install git; \
- choco install git.install; \
- choco install GoogleChrome; \
- choco install gradle --version 4.4.1; \
- choco install jdk8; \
- choco install KB2533623; \
- choco install KB2919355; \
- choco install KB2919442; \
- choco install KB2999226; \
- choco install KB3033929; \
- choco install KB3035131; \
- choco install maven; \
- choco install ninja; \
- choco install nodejs --version 9.3.0; \
- choco install nodejs.install --version 9.3.0; \
- choco install nuget.commandline; \
- choco install openjdk11; \
- choco install peazip; \
- choco install peazip.install; \
- choco install peazip.portable; \
- choco install php --version 7.2.0; \
- choco install protoc --version 3.2.0; \
- choco install ruby --version 2.5.0.1; \
- choco install swig --version 3.0.9; \
- choco install sysinternals; \
- choco install unrar; \
- choco install unzip; \
- choco install vcredist140; \
- choco install vcredist2015; \
- choco install vim; \
- choco install winrar; \
- choco install zip; \
- choco install Firefox; \
- choco install iisexpress;
-
-RUN cmd /c 'mklink /J c:\Anaconda c:\tools\anaconda2';
-RUN cmd /c 'mklink c:\programdata\chocolatey\bin\rar.exe \"c:\program files\winrar\rar.exe\"';
-
-# Installing pip packages
-RUN pip install --upgrade setuptools; \
- pip install altgraph appdirs cachetools certifi cffi chardet colorama \
- cryptography cycler Cython decorator google-api-python-client \
- google-auth google-auth-httplib2 grpcio httplib2 idna ipython-genutils \
- kiwisolver macholib matplotlib nose numpy packaging pandas pickleshare pip \
- prompt-toolkit protobuf psutil pyasn1 pyasn1-modules pycparser Pygments \
- pyparsing pyreadline python-dateutil pytz pywin32 requests rsa setuptools \
- simplegeneric six Tempita traitlets uritemplate urllib3 virtualenv wcwidth \
- wheel win-unicode-console;
-
-# Hardcoding Android license since I did not find any solution on accepting it
-# through the docker build command. If the licensing agreement changes, this
-# will need to be updated as well.
-RUN New-Item -ItemType Directory -Path C:\Android\android-sdk\licenses; \
- Set-Content -Path .\Android\android-sdk\licenses\android-sdk-license -Value "`n24333f8a63b6825ea9c5514f83c2829b004d1fee" -NoNewLine;
-
-# Add sdkmanager to PATH
-RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Android\android-sdk\tools\bin\", \"Machine\");
-
-# Install android packages
-RUN $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \
- New-Item C:\Users\ContainerAdministrator\.android\repositories.cfg; \
- sdkmanager 'ndk-bundle'; \
- sdkmanager 'platforms;android-33'; \
- sdkmanager 'add-ons;addon-google_apis-google-24'; \
- sdkmanager 'cmake;3.10.2.4988404'; \
- sdkmanager 'cmake;3.18.1'; \
- sdkmanager 'cmake;3.22.1'; \
- sdkmanager 'cmake;3.6.4111459'; \
- sdkmanager 'emulator'; \
- sdkmanager 'system-images;android-27;google_apis;x86'; \
- sdkmanager 'sources;android-27'; \
- sdkmanager 'extras;google;Android_Emulator_Hypervisor_Driver'; \
- sdkmanager 'extras;google;auto'; \
- sdkmanager 'extras;google;google_play_services'; \
- sdkmanager 'extras;google;instantapps'; \
- sdkmanager 'extras;google;m2repository'; \
- sdkmanager 'extras;google;market_apk_expansion'; \
- sdkmanager 'extras;google;market_licensing'; \
- sdkmanager 'extras;google;simulators'; \
- sdkmanager 'extras;google;usb_driver'; \
- sdkmanager 'extras;google;webdriver'; \
- sdkmanager 'extras;android;m2repository'; \
- sdkmanager 'extras;intel;Hardware_Accelerated_Execution_Manager'; \
- sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout;1.0.0'; \
- sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout-solver;1.0.2'; \
- sdkmanager 'patcher;v4'; \
- sdkmanager 'ndk;25.1.8937393'; \
- sdkmanager 'build-tools;27.0.3';
-
-# Install Scoop and packages
-RUN iex \"& {$(irm get.scoop.sh)} -RunAsAdmin\"; \
- scoop install perl; \
- scoop install bazel; \
- scoop install cuda; \
- scoop install azure-functions-core-tools; \
- scoop install azure-cli;
-
-# Setting environment variables
-RUN [Environment]::SetEnvironmentVariable('CYGWIN', 'winsymlinks:native', 'Machine'); \
- [Environment]::SetEnvironmentVariable('HOME', 'C:\Users\ContainerAdministrator\', 'Machine'); \
- [Environment]::SetEnvironmentVariable('HOMEDRIVE', 'C:', 'Machine'); \
- [Environment]::SetEnvironmentVariable('HOMEPATH', '\Users\ContainerAdministrator\', 'Machine'); \
- [Environment]::SetEnvironmentVariable('GOROOT', 'C:\Program Files\Go\', 'Machine'); \
- [Environment]::SetEnvironmentVariable('KOKORO_POSIX_ROOT', '/tmpfs', 'Machine'); \
- [Environment]::SetEnvironmentVariable('KOKORO_ROOT', 'T:\', 'Machine'); \
- [Environment]::SetEnvironmentVariable('SHELL', '/bin/bash', 'Machine'); \
- $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
- [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files\CMake\bin\", \"Machine\");
-
-
-# Restore default shell for Windows containers.
-SHELL ["cmd.exe", "/s", "/c"]
-
-# Default to PowerShell if no other command specified.
-CMD ["powershell.exe", "-NoLogo", "-ExecutionPolicy", "Bypass"]
diff --git a/ci/official/containers/linux_arm64/build.sh b/ci/official/containers/linux_arm64/build.sh
index 611d5f4..ffead7f 100755
--- a/ci/official/containers/linux_arm64/build.sh
+++ b/ci/official/containers/linux_arm64/build.sh
@@ -16,8 +16,8 @@
# Builds the following Docker images for Linux ARM64. See the accompanying
# Dockerfile for more details:
-# - gcr.io/tensorflow-sigs/build-arm64:jax-latest-multi-python
-# - gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
+# - us-central1-docker.pkg.dev/tensorflow-sigs/build-arm64:jax-latest-multi-python
+# - us-central1-docker.pkg.dev/tensorflow-sigs/build-arm64:tf-latest-multi-python
set -exo pipefail
@@ -40,16 +40,14 @@
fi
fi
-# TODO(b/341050361): When these steps are verified, removed the GCR image code.
AR_IMAGE_PATH="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64"
# Build for both JAX and TF usage. We do these in one place because they share
# almost all of the same cache layers
export DOCKER_BUILDKIT=1
for target in jax tf; do
- IMAGE="gcr.io/tensorflow-sigs/build-arm64:$target-$TAG"
AR_IMAGE="$AR_IMAGE_PATH:$target-$TAG"
- docker pull "$IMAGE" || true
+ docker pull "$AR_IMAGE" || true
# Due to some flakiness of resources pulled in the build, allow the docker
# command to reattempt build a few times in the case of failure (b/302558736)
set +e
@@ -58,8 +56,8 @@
docker build \
--build-arg REQUIREMENTS_FILE=jax.requirements.txt \
--target=$target \
- --cache-from "$IMAGE" \
- -t "$IMAGE" -t "$AR_IMAGE" . && break
+ --cache-from "$AR_IMAGE" \
+ -t "$AR_IMAGE" . && break
done
final=$?
if [ $final -ne 0 ]; then
@@ -68,8 +66,6 @@
set -e
if [[ -n "$KOKORO_BUILD_ID" ]]; then
- gcloud auth configure-docker
- docker push "$IMAGE"
gcloud auth configure-docker us-central1-docker.pkg.dev
docker push "$AR_IMAGE"
fi
diff --git a/ci/official/envs/rbe b/ci/official/envs/rbe
index 12cc600..35f8173 100644
--- a/ci/official/envs/rbe
+++ b/ci/official/envs/rbe
@@ -33,7 +33,17 @@
fi
TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config rbe_$TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX"
-# These flags share the user's gcloud credentials with the container, so that bazel
-# inside the container can authenticate. Note: TF's CI does not have any credential
-# stored here.
-TFCI_DOCKER_ARGS="$TFCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud"
+if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+ # Docker on Windows doesn't support the `host` networking mode, and so
+ # port-forwarding is required for the container to detect it's running on GCE.
+ export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress")
+ netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80
+ # A local firewall rule for the container is added in
+ # ci/official/utilities/setup_docker.sh.
+else
+ # The volume mapping flag below shares the user's gcloud credentials, if any,
+ # with the container, in case the user has credentials stored there.
+ # This would allow Bazel to authenticate for RBE.
+ # Note: TF's CI does not have any credentials stored there.
+ TFCI_DOCKER_ARGS="$TFCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud"
+fi
diff --git a/ci/official/envs/windows_x86 b/ci/official/envs/windows_x86
new file mode 100644
index 0000000..568a47f
--- /dev/null
+++ b/ci/official/envs/windows_x86
@@ -0,0 +1,20 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+TFCI_DOCKER_ENABLE=1
+TFCI_DOCKER_PULL_ENABLE=1
+TFCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc"
+TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION"
+TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=windows_x86_cpu
+TFCI_OUTPUT_DIR=build_output
diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh
index cf34600..f6f2090 100755
--- a/ci/official/pycpp.sh
+++ b/ci/official/pycpp.sh
@@ -15,12 +15,19 @@
# ==============================================================================
source "${BASH_SOURCE%/*}/utilities/setup.sh"
-if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then
- tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
+if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+ PROFILE_JSON_PATH=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR")
+ PROFILE_JSON_PATH="$PROFILE_JSON_PATH/profile.json.gz"
else
- tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
+ PROFILE_JSON_PATH="$TFCI_OUTPUT_DIR/profile.json.gz"
+fi
+
+if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then
+ tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
+else
+ tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
fi
# Note: the profile can be viewed by visiting chrome://tracing in a Chrome browser.
# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling
-tfrun bazel analyze-profile "$TFCI_OUTPUT_DIR/profile.json.gz"
+tfrun bazel analyze-profile "$PROFILE_JSON_PATH"
diff --git a/ci/official/utilities/cleanup_summary.sh b/ci/official/utilities/cleanup_summary.sh
index dbe2203..6b6fdfa 100755
--- a/ci/official/utilities/cleanup_summary.sh
+++ b/ci/official/utilities/cleanup_summary.sh
@@ -23,8 +23,9 @@
can view more detailed results that are probably easier to read than this log.
Try the links below:
EOF
- # Find any "Streaming build results to" line, then print the last word in it,
- # and don't print duplicates
+ # Find any "Streaming build results to" lines,
+ # de-duplicate,
+ # and print the last word from each
awk '/Streaming build results to/ {print $NF}' "$TFCI_OUTPUT_DIR/script.log" | uniq
}
@@ -32,14 +33,15 @@
# Each failed target there will have its own representation, making failures
# easier to find and read.
function resultstore_extract {
- local \
- XML_PATH="$TFCI_OUTPUT_DIR/Bazel_Test_and_Build_Results/sponge_log.xml"
+ local PYTHON_BIN XML_PATH
+ PYTHON_BIN=$(which python3 2>/dev/null || which python)
+ XML_PATH="$TFCI_OUTPUT_DIR/Bazel_Test_and_Build_Results/sponge_log.xml"
- python3 \
+ "$PYTHON_BIN" \
"$TFCI_GIT_DIR/ci/official/utilities/extract_resultstore_links.py" \
"$TFCI_OUTPUT_DIR/script.log" \
--print \
- --xml-out-path "$XML_PATH" || resultstore_extract_fallback
+ --xml-out-path "$XML_PATH"
}
if grep -q "Streaming build results to" "$TFCI_OUTPUT_DIR/script.log"; then
diff --git a/ci/official/utilities/convert_msys_paths_to_win_paths.py b/ci/official/utilities/convert_msys_paths_to_win_paths.py
new file mode 100644
index 0000000..ed1dd3b
--- /dev/null
+++ b/ci/official/utilities/convert_msys_paths_to_win_paths.py
@@ -0,0 +1,76 @@
+#!/usr/bin/python3
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Converts MSYS Linux-like paths stored in env variables to Windows paths.
+
+This is necessary on Windows because some applications, for example Docker,
+do not understand or handle the Linux-like paths that MSYS uses.
+"""
+
+import argparse
+import os
+
+
+def should_convert(var_name: str,
+ blacklist: list[str] | None,
+ whitelist_prefix: list[str] | None):
+ """Check the variable name against white/black lists."""
+ if blacklist and var_name in blacklist:
+ return False
+ if not whitelist_prefix:
+ return True
+
+ for prefix in whitelist_prefix:
+ if var_name.startswith(prefix):
+ return True
+ return False
+
+
+def main(parsed_args: argparse.Namespace):
+ converted_vars = {}
+
+ for var, value in os.environ.items():
+ if not value or not should_convert(var,
+ parsed_args.blacklist,
+ parsed_args.whitelist_prefix):
+ continue
+
+ # In Python, MSYS Linux-like paths are automatically read as Windows paths
+ # with forward slashes, e.g. 'C:/Program Files', instead of
+ # '/c/Program Files', thus becoming converted simply by virtue of having
+ # been read.
+ converted_vars[var] = value
+
+ var_str = '\n'.join(f'{k}="{v}"'
+ for k, v in converted_vars.items())
+ # The string can then be piped into `source`, to re-set the
+ # 'converted' variables.
+ print(var_str)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description=(
+ 'Convert MSYS paths in environment variables to Windows paths.'))
+ parser.add_argument('--blacklist',
+ nargs='*',
+ help='List of variables to ignore')
+ parser.add_argument('--whitelist-prefix',
+ nargs='*',
+ help='Prefix for variables to include')
+ args = parser.parse_args()
+
+ main(args)
diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh
index 2277b75..55d0e07 100755
--- a/ci/official/utilities/setup.sh
+++ b/ci/official/utilities/setup.sh
@@ -118,6 +118,12 @@
# functionality instead.
tfrun() { "$@"; }
+if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+ source ./ci/official/utilities/windows.sh
+ echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)'
+ source <(python ./ci/official/utilities/convert_msys_paths_to_win_paths.py --whitelist-prefix TFCI_)
+fi
+
# Run all "tfrun" commands under Docker. See setup_docker.sh for details
if [[ "$TFCI_DOCKER_ENABLE" == 1 ]]; then
source ./ci/official/utilities/setup_docker.sh
diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh
index 91618c7..61db7c2 100755
--- a/ci/official/utilities/setup_docker.sh
+++ b/ci/official/utilities/setup_docker.sh
@@ -37,10 +37,30 @@
# Pass all existing TFCI_ variables into the Docker container
env_file=$(mktemp)
env | grep ^TFCI_ > "$env_file"
- docker run $TFCI_DOCKER_ARGS --name tf -w "$TFCI_GIT_DIR" -itd --rm \
- -v "$TFCI_GIT_DIR:$TFCI_GIT_DIR" \
+
+ WORKING_DIR="$TFCI_GIT_DIR"
+ if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+ env_file=$(cygpath -m $env_file)
+ # Host dirs can only be mapped to an existing drive inside the container, so
+ # T:\ is replaced with C:\.
+ _TFCI_OUTPUT_DIR_WIN=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR")
+ sed -iE 's|^TFCI_OUTPUT_DIR=.*|TFCI_OUTPUT_DIR='"$_TFCI_OUTPUT_DIR_WIN"'|g' $env_file
+ WORKING_DIR=$(replace_drive_letter_with_c "$TFCI_GIT_DIR")
+ echo "GCE_METADATA_HOST=$IP_ADDR" > $env_file
+ fi
+
+ docker run $TFCI_DOCKER_ARGS --name tf -w "$WORKING_DIR" -itd --rm \
+ -v "$TFCI_GIT_DIR:$WORKING_DIR" \
--env-file "$env_file" \
"$TFCI_DOCKER_IMAGE" \
bash
+
+ if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+ # Allow requests from the container.
+ # Additional setup is contained in ci/official/envs/rbe.
+ CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' tf)
+ netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP_ADDR"
+ fi
+
fi
tfrun() { docker exec tf "$@"; }
diff --git a/ci/official/utilities/windows.sh b/ci/official/utilities/windows.sh
new file mode 100644
index 0000000..1ab2d89
--- /dev/null
+++ b/ci/official/utilities/windows.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Windows-specific utilities.
+#
+
+# Docker on Windows has difficulty using volumes other than C:\, when it comes
+# to setting up volume mappings.
+# Thus, the drive letter is replaced with C:\, in case it's
+# something else (ex. T:), which is frequently the case inside Kokoro jobs.
+function replace_drive_letter_with_c () {
+ sed -E "s|^[a-zA-Z]:|C:|g" <<< $1
+}
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 3556df1..c96cd8c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -1445,7 +1445,7 @@
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:reference_ops",
- "//tensorflow/lite/schema:schema_fbs",
+ "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/lite/toco/logging:conversion_log_util",
"//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc",
"//tensorflow/lite/toco:model_flags_proto_cc",
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc
index 86b201f..1123ccb 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc
@@ -24,6 +24,7 @@
#include "tensorflow/c/tf_buffer.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
@@ -32,7 +33,6 @@
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace parallel_device {
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index ec2ce95..88ef5c1 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -27,8 +27,8 @@
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc
index 57f1a65..357432e 100644
--- a/tensorflow/c/experimental/grappler/grappler_test.cc
+++ b/tensorflow/c/experimental/grappler/grappler_test.cc
@@ -19,6 +19,7 @@
#include "tensorflow/c/tf_buffer.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/c/tf_status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -33,7 +34,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/protobuf/error_codes.pb.h"
namespace tensorflow {
diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD
index 56586f7..45c55c3 100644
--- a/tensorflow/c/experimental/next_pluggable_device/BUILD
+++ b/tensorflow/c/experimental/next_pluggable_device/BUILD
@@ -87,7 +87,6 @@
"//tensorflow/core/tfrt/common:pjrt_util",
"@com_google_absl//absl/log:check",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
@@ -98,5 +97,6 @@
"@local_xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl",
"@local_xla//xla/pjrt/cpu:cpu_client",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
index 7f45fd9..1952364 100644
--- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
+++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
@@ -30,10 +30,10 @@
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/tfrt/common/async_value_tensor.h"
#include "tensorflow/core/tfrt/common/pjrt_util.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD
index 7284261..76f1db6 100644
--- a/tensorflow/c/experimental/ops/BUILD
+++ b/tensorflow/c/experimental/ops/BUILD
@@ -22,11 +22,11 @@
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
- "//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:tracing_utils",
- "//tensorflow/core:framework",
- "//tensorflow/core/platform:errors",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/platform:status",
"@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
],
)
diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc
index db9464d..23deef1 100644
--- a/tensorflow/c/experimental/ops/array_ops.cc
+++ b/tensorflow/c/experimental/ops/array_ops.cc
@@ -17,11 +17,14 @@
#include "tensorflow/c/experimental/ops/array_ops.h"
+#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tracing_utils.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
+#include "tsl/platform/errors.h"
using tensorflow::tracing::MaybeSetOpName;
diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h
index f4d170a..466c36f 100644
--- a/tensorflow/c/experimental/ops/array_ops.h
+++ b/tensorflow/c/experimental/ops/array_ops.h
@@ -18,8 +18,11 @@
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
+#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace ops {
diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc
index 3a6de51..463f64c 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc
+++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc
@@ -19,6 +19,7 @@
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/cc/saved_model/constants.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -26,7 +27,6 @@
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
index d06608f..18d7498 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
+++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
@@ -20,12 +20,12 @@
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc
index 94441a7..ece9057 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc
@@ -230,7 +230,7 @@
return std::make_unique<HostMemoryAllocation>(buffer, size, this);
}
- void HostMemoryDeallocate(void* mem, uint64_t size) override {
+ void HostMemoryDeallocate(void* mem) override {
stream_executor_->host_memory_deallocate(&device_, mem);
}
@@ -432,7 +432,6 @@
name_(platform.name) {}
CPlatform::~CPlatform() {
- executor_cache_.DestroyAllExecutors();
platform_fns_.destroy_device_fns(&platform_, &device_fns_);
platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
@@ -454,6 +453,11 @@
config.ordinal = ordinal;
return GetExecutor(config);
}
+absl::StatusOr<StreamExecutor*> CPlatform::FindExisting(int ordinal) {
+ stream_executor::StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ return executor_cache_.Get(config);
+}
absl::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
const StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate(
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
index e3e025c..d87794a 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
@@ -19,6 +19,7 @@
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#include <cstdint>
+#include <memory>
#include <string>
#include <utility>
@@ -99,12 +100,15 @@
absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
- absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
- const StreamExecutorConfig& config) override;
-
- void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
+ absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
private:
+ // Returns a device constructed with the options specified in "config" without
+ // looking in or storing to the Platform's executor cache.
+ // Ownership IS transferred to the caller.
+ absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ const StreamExecutorConfig& config);
+
SP_Platform platform_;
void (*destroy_platform_)(SP_Platform*);
SP_PlatformFns platform_fns_;
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
index 0082c65..20aded8 100644
--- a/tensorflow/c/kernels.cc
+++ b/tensorflow/c/kernels.cc
@@ -794,10 +794,7 @@
#else
const auto* device = reinterpret_cast<const tensorflow::Device*>(
device_base->UnderlyingDevice());
- const absl::StatusOr<int> id = tsl::GetDeviceIdFromDeviceParsedName(
- device->parsed_name(), tensorflow::DeviceType(device->device_type()));
- if (!id.ok()) return -1;
- return *id;
+ return tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name());
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index da27e61..02fd678 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -588,12 +588,12 @@
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/cc/saved_model/bundle_v2_test.cc b/tensorflow/cc/saved_model/bundle_v2_test.cc
index a0bbb82..1380282 100644
--- a/tensorflow/cc/saved_model/bundle_v2_test.cc
+++ b/tensorflow/cc/saved_model/bundle_v2_test.cc
@@ -28,10 +28,10 @@
#include "json/reader.h"
#include "json/value.h"
#include "tensorflow/cc/saved_model/metrics.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
namespace tensorflow {
diff --git a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc
index 1f6b0e1..3182afc 100644
--- a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc
+++ b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc
@@ -25,6 +25,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -32,7 +33,6 @@
#include "tensorflow/tools/proto_splitter/cc/util.h"
#include "tensorflow/tools/proto_splitter/chunk.pb.h"
#include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index d74c0cd..eb4ef40 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -23,6 +23,7 @@
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/state_ops.h"
#include "tensorflow/cc/saved_model/loader.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
@@ -35,7 +36,6 @@
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
namespace tensorflow {
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc
index e3f9f78..211fd1e 100644
--- a/tensorflow/cc/training/coordinator_test.cc
+++ b/tensorflow/cc/training/coordinator_test.cc
@@ -16,6 +16,7 @@
#include "tensorflow/cc/training/coordinator.h"
#include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
@@ -23,7 +24,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/protobuf/error_codes.pb.h"
namespace tensorflow {
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 9a5f612..f4de69b 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -26,6 +26,7 @@
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/state_ops.h"
#include "tensorflow/cc/training/coordinator.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
@@ -40,7 +41,6 @@
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index eddd237..6efe665 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -518,7 +518,6 @@
":internal",
# We reuse VariableInfo in TFRT's implementation of TpuExecuteOp.
"//learning/brain/tfrt/tf_tpu:__pkg__",
- "//learning/brain/tfrt/tpu_plugin:__pkg__",
"//learning/brain/tfrt/tpu_common:__pkg__",
"//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
],
@@ -539,9 +538,6 @@
":internal",
# We reuse VariableInfo in TFRT's implementation of TpuExecuteOp.
"//learning/brain/tfrt/tf_tpu:__pkg__",
- "//learning/brain/tfrt/tpu_plugin:__pkg__",
- "//learning/brain/tfrt/tpu_common:__pkg__",
- "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
],
deps = [
":variable_info",
@@ -612,8 +608,6 @@
# We reuse VariableInfo in TFRT's implementation of TpuExecuteOp.
"//learning/brain/tfrt/tf_tpu:__pkg__",
"//learning/brain/tfrt/tpu_plugin:__pkg__",
- "//learning/brain/tfrt/tpu_common:__pkg__",
- "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
"//tensorflow/core/tfrt/gpu/kernel:__pkg__",
],
deps = [
@@ -678,7 +672,6 @@
"//tensorflow/core/tfrt/common:pjrt_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/pjrt:pjrt_client",
@@ -686,6 +679,7 @@
"@local_xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@local_xla//xla/tests:literal_test_util",
"@local_xla//xla/tsl/framework:device_id_utils",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -712,11 +706,11 @@
"//tensorflow/core/tfrt/common:pjrt_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/pjrt:pjrt_client",
"@local_xla//xla/tests:literal_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -726,7 +720,6 @@
hdrs = ["xla_compile_util.h"],
visibility = [
":internal",
- "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
"//tensorflow/core/tfrt/gpu/kernel:__pkg__",
],
deps = [
@@ -770,10 +763,7 @@
name = "device_compiler",
hdrs = ["device_compiler.h"],
copts = tf_copts(),
- visibility = [
- ":internal",
- "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
- ],
+ visibility = [":internal"],
deps = [
":device_compilation_cache",
":device_compilation_cluster_signature",
@@ -1118,7 +1108,6 @@
],
visibility = [
":internal",
- "//tensorflow/core/tfrt/utils:__pkg__",
"//third_party/cloud_tpu/inference_converter:__pkg__",
"//waymo/onboard/ml/chauffeur_net:__pkg__",
],
@@ -1564,10 +1553,7 @@
name = "device_compiler_client",
srcs = ["device_compiler_client.cc"],
hdrs = ["device_compiler_client.h"],
- visibility = [
- ":internal",
- "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
- ],
+ visibility = [":internal"],
deps = [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core/util:determinism",
@@ -1596,6 +1582,7 @@
cc_library(
name = "device_executable_persistor",
+ srcs = ["device_executable_persistor.cc"],
hdrs = ["device_executable_persistor.h"],
deps = [
":xla_compilation_cache_proto_cc",
@@ -1608,6 +1595,8 @@
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:statusor",
"@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla:util",
"@local_xla//xla/pjrt:pjrt_client",
diff --git a/tensorflow/compiler/jit/device_context_test.cc b/tensorflow/compiler/jit/device_context_test.cc
index be85ff9..d02337d 100644
--- a/tensorflow/compiler/jit/device_context_test.cc
+++ b/tensorflow/compiler/jit/device_context_test.cc
@@ -21,8 +21,8 @@
#include <gtest/gtest.h>
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/jit/device_executable_persistor.cc b/tensorflow/compiler/jit/device_executable_persistor.cc
new file mode 100644
index 0000000..b673af7
--- /dev/null
+++ b/tensorflow/compiler/jit/device_executable_persistor.cc
@@ -0,0 +1,37 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/device_executable_persistor.h"
+
+#include <string>
+
+#include "absl/strings/str_cat.h"
+
+namespace tensorflow {
+
+std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key) {
+ static constexpr char kXlaSerializedCacheKeySeparator[] = "__";
+ return absl::StrCat(
+ key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator,
+ key.signature_fingerprint(), kXlaSerializedCacheKeySeparator,
+ key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator,
+ key.device_type(),
+ key.compiled_using_pjrt()
+ ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt")
+ : "",
+ ".pb");
+}
+
+} // namespace tensorflow
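Editor's sketch (not part of the patch): with the helper now a free function, the on-disk name of a cache entry can be computed without instantiating DeviceExecutablePersistor. `key` and `cache_dir` below are assumed to be a populated XlaSerializedCacheKey and a directory string already in scope.

// Sketch only: yields "<prefix>__<signature>__<cluster>__<device>[__pjrt].pb";
// the prefix and its separator are dropped when key.prefix() is empty.
const std::string file_name = tensorflow::XlaSerializedCacheKeyToFileName(key);
const std::string full_path = tensorflow::io::JoinPath(cache_dir, file_name);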
diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h
index 78d2089..0f546c0 100644
--- a/tensorflow/compiler/jit/device_executable_persistor.h
+++ b/tensorflow/compiler/jit/device_executable_persistor.h
@@ -20,6 +20,7 @@
#include <string>
#include "absl/log/log.h"
+#include "absl/status/status.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
#include "tensorflow/compiler/jit/xla_device_compiler_client.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -35,6 +36,9 @@
namespace tensorflow {
+// Returns the persisted compilation cache file name for the given key.
+std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key);
+
// Offers a way to persist and/or load compiled `ExecutableType`s along with the
// corresponding HLO (`CompilationResult`) to/from `persistent_cache_directory`
// (if one was provided during construction) on disk using `ClientType`.
@@ -142,8 +146,6 @@
const xla::HloModuleProto& hlo_module,
const XlaSerializedCacheEntry& entry) const;
- std::string XlaSerializedCacheKeyToString(
- const XlaSerializedCacheKey& key) const;
std::string GetFilePath(const XlaSerializedCacheKey& key) const;
const DeviceType device_type_;
@@ -173,24 +175,9 @@
config.persistent_cache_directory_read_only) {}
template <typename ExecutableType, typename ClientType>
-std::string DeviceExecutablePersistor<ExecutableType, ClientType>::
- XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) const {
- static constexpr char kXlaSerializedCacheKeySeparator[] = "__";
- return absl::StrCat(
- key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator,
- key.signature_fingerprint(), kXlaSerializedCacheKeySeparator,
- key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator,
- key.device_type(),
- key.compiled_using_pjrt()
- ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt")
- : "");
-}
-
-template <typename ExecutableType, typename ClientType>
std::string DeviceExecutablePersistor<ExecutableType, ClientType>::GetFilePath(
const XlaSerializedCacheKey& key) const {
- const std::string file_name =
- absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb");
+ const std::string file_name = XlaSerializedCacheKeyToFileName(key);
return io::JoinPath(persistent_cache_directory_, file_name);
}
@@ -299,9 +286,10 @@
// Write to temp location, then when that completes, atomically move into the
// final location.
- std::string temp_path = io::JoinPath(
- persistent_cache_directory_, XlaSerializedCacheKeyToString(entry.key()));
- if (!env->CreateUniqueFileName(&temp_path, ".pb.tmp")) {
+ std::string temp_path =
+ io::JoinPath(persistent_cache_directory_,
+ XlaSerializedCacheKeyToFileName(entry.key()));
+ if (!env->CreateUniqueFileName(&temp_path, ".tmp")) {
return absl::UnavailableError(absl::StrCat(
"Could not create a unique file inside ", persistent_cache_directory_));
}
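In short, the write path now derives both names from the same helper. A condensed sketch of the flow in the hunk above (`env`, `entry`, and the member fields are as in the class; the exact unique-suffix format is an Env implementation detail):

// Final destination: <persistent_cache_directory_>/<key-derived name>.pb
const std::string final_path = GetFilePath(entry.key());
// Temporary file: same ".pb" base name, made unique by Env with a ".tmp" suffix.
std::string temp_path = io::JoinPath(persistent_cache_directory_,
                                     XlaSerializedCacheKeyToFileName(entry.key()));
if (!env->CreateUniqueFileName(&temp_path, ".tmp")) {
  return absl::UnavailableError("Could not create a unique file");
}
// Write `entry` to temp_path, then atomically rename temp_path -> final_path.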
diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc
index 51b6e57..794f32d 100644
--- a/tensorflow/compiler/jit/pjrt_device_context.cc
+++ b/tensorflow/compiler/jit/pjrt_device_context.cc
@@ -52,12 +52,10 @@
cpu_tensor->shape(), cpu_tensor->dtype(),
/*fast_mem=*/false, layout_preference));
const xla::Layout* device_layout = &(shape.layout());
- // The device id should match the local_hardware_id in
+ // The device id should match the local_device_id in
// tensorflow/compiler/xla/pjrt/pjrt_client.h.
- TF_ASSIGN_OR_RETURN(
- const int pjrt_device_id,
- tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name(),
- DeviceType(device->device_type())));
+ const int pjrt_device_id =
+ tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name());
TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device,
pjrt_client->LookupAddressableDevice(
xla::PjRtLocalDeviceId(pjrt_device_id)));
@@ -260,12 +258,10 @@
xla::PjRtBuffer* src_device_buffer =
tensorflow::AsyncValueTensor::FromTensor(input)->GetBuffer().get();
- // The device id should match the local_hardware_id in
+ // The device id should match the local_device_id in
// tensorflow/compiler/xla/pjrt/pjrt_client.h.
const int pjrt_dst_device_id =
- tsl::GetDeviceIdFromDeviceParsedName(dst->parsed_name(),
- DeviceType(dst->device_type()))
- .value();
+ tsl::GetDeviceIdFromDeviceParsedName(dst->parsed_name());
xla::PjRtDevice* pjrt_dst_device =
(*pjrt_dst_client)
->LookupAddressableDevice(xla::PjRtLocalDeviceId(pjrt_dst_device_id))
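These hunks (and the matching ones in xla_launch_util below) converge on the same simplified lookup. A sketch, assuming `device` is a tensorflow Device* and `pjrt_client` an xla::PjRtClient*, and that tsl::GetDeviceIdFromDeviceParsedName now returns the id directly rather than a StatusOr:

const int pjrt_device_id =
    tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name());
TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device,
                    pjrt_client->LookupAddressableDevice(
                        xla::PjRtLocalDeviceId(pjrt_device_id)));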
diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc
index 3f96101..eaabf18 100644
--- a/tensorflow/compiler/jit/shape_inference_test.cc
+++ b/tensorflow/compiler/jit/shape_inference_test.cc
@@ -29,6 +29,7 @@
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -36,7 +37,6 @@
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
index 62da04c..bec124f 100644
--- a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
+++ b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
@@ -26,8 +26,8 @@
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index ebeeaef..27a8f16 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -853,9 +853,8 @@
->use_pjrt_tensor_buffer;
const DeviceType& device_type = GetDeviceType(ctx);
- TF_ASSIGN_OR_RETURN(const int pjrt_device_id,
- tsl::GetDeviceIdFromDeviceParsedName(
- ctx->device()->parsed_name(), device_type));
+ const int pjrt_device_id =
+ tsl::GetDeviceIdFromDeviceParsedName(ctx->device()->parsed_name());
TF_ASSIGN_OR_RETURN(xla::PjRtDevice * device,
pjrt_client->LookupAddressableDevice(
xla::PjRtLocalDeviceId(pjrt_device_id)));
diff --git a/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc b/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc
index 0ba66c2..563e75c 100644
--- a/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc
+++ b/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc
@@ -39,6 +39,7 @@
#include "xla/pjrt/pjrt_client.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/framework/allocator.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/fake_input.h"
@@ -50,7 +51,6 @@
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h"
#include "tensorflow/core/tfrt/common/pjrt_util.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc
index d19e4fc..443fdf3 100644
--- a/tensorflow/compiler/jit/xla_launch_util_test.cc
+++ b/tensorflow/compiler/jit/xla_launch_util_test.cc
@@ -34,6 +34,7 @@
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/framework/allocator.h"
#include "xla/tsl/framework/device_id_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/fake_input.h"
@@ -45,7 +46,6 @@
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h"
#include "tensorflow/core/tfrt/common/pjrt_util.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
@@ -674,9 +674,8 @@
->tensorflow_accelerator_device_info()
->use_pjrt_tensor_buffer;
const DeviceType& device_type = GetDeviceType(context_.get());
- TF_ASSERT_OK_AND_ASSIGN(const int pjrt_device_id,
- tsl::GetDeviceIdFromDeviceParsedName(
- context_->device()->parsed_name(), device_type));
+ const int pjrt_device_id =
+ tsl::GetDeviceIdFromDeviceParsedName(context_->device()->parsed_name());
TF_ASSERT_OK_AND_ASSIGN(xla::PjRtDevice * pjrt_device,
pjrt_client_->LookupAddressableDevice(
xla::PjRtLocalDeviceId(pjrt_device_id)));
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index a9b1e1b..75ca89d 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -318,6 +318,13 @@
],
)
+cc_library(
+ name = "stateful_error_reporter",
+ hdrs = ["stateful_error_reporter.h"],
+ compatible_with = get_compatible_with_portable(),
+ deps = ["//tensorflow/compiler/mlir/lite/core/api:error_reporter"],
+)
+
gentbl_cc_library(
name = "tensorflow_lite_canonicalize_inc_gen",
compatible_with = get_compatible_with_portable(),
@@ -1095,8 +1102,8 @@
name = "flatbuffer_to_string",
srcs = ["flatbuffer_to_string.cc"],
deps = [
+ "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_reflection",
- "//tensorflow/lite/core:model_builder",
"@flatbuffers",
],
)
@@ -1146,7 +1153,6 @@
"//tensorflow/core:framework",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/lite/core:framework",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/tools/versioning",
"//tensorflow/lite/tools/versioning:gpu_compatibility",
@@ -1189,8 +1195,10 @@
":size_utils",
":tensorflow_lite",
"//tensorflow/compiler/mlir/lite:control_edges",
+ "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
"//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
"//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
+ "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable",
"//tensorflow/compiler/mlir/lite/schema:schema_utils",
"//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom",
@@ -1204,8 +1212,6 @@
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
- "//tensorflow/core/platform:status",
- "//tensorflow/lite:model_builder",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
@@ -1357,7 +1363,6 @@
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
- "//tensorflow/lite:framework",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/mlir/lite/core/BUILD b/tensorflow/compiler/mlir/lite/core/BUILD
index 04e37c6..4816761 100644
--- a/tensorflow/compiler/mlir/lite/core/BUILD
+++ b/tensorflow/compiler/mlir/lite/core/BUILD
@@ -1,13 +1,59 @@
load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts_warnings")
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
+exports_files(
+ [
+ "model_builder_base.h",
+ ],
+ visibility = ["//tensorflow/lite/core:__pkg__"],
+)
+
cc_library(
name = "macros",
hdrs = ["macros.h"],
compatible_with = get_compatible_with_portable(),
visibility = ["//visibility:public"],
)
+
+cc_library(
+ name = "model_builder_base",
+ srcs = ["model_builder_base.cc"],
+ hdrs = ["model_builder_base.h"],
+ compatible_with = get_compatible_with_portable(),
+ copts = tflite_copts_warnings(),
+ visibility = [
+ "//tensorflow/compiler/mlir/lite:__subpackages__",
+ "//tensorflow/lite/core:__pkg__",
+ ],
+ deps = [
+ ":macros",
+ "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+ "//tensorflow/compiler/mlir/lite/core/api:verifier",
+ "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
+ "//tensorflow/lite:allocation",
+ "@com_google_absl//absl/strings",
+ "@flatbuffers",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "absl_error_model_builder",
+ srcs = ["absl_error_model_builder.cc"],
+ hdrs = ["absl_error_model_builder.h"],
+ compatible_with = get_compatible_with_portable(),
+ copts = tflite_copts_warnings(),
+ visibility = [
+ "//tensorflow/compiler/mlir/lite:__subpackages__",
+ ],
+ deps = [
+ ":model_builder_base",
+ "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+ "@com_google_absl//absl/log",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc
new file mode 100644
index 0000000..269d81e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc
@@ -0,0 +1,40 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
+
+#include <cstdarg>
+#include <cstdio>
+
+#include "absl/log/log.h"
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+namespace mlir::TFL {
+
+int AbslErrorReporter::Report(const char* format, va_list args) {
+ char buffer[1024];
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wformat-nonliteral"
+ vsnprintf(buffer, sizeof(buffer), format, args);
+#pragma clang diagnostic pop
+ LOG(ERROR) << buffer;
+ return 0;
+}
+
+tflite::ErrorReporter* GetAbslErrorReporter() {
+ static AbslErrorReporter* error_reporter = new AbslErrorReporter;
+ return error_reporter;
+}
+
+} // namespace mlir::TFL
diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h
new file mode 100644
index 0000000..c3d76e2
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h
@@ -0,0 +1,47 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_
+
+#include <cstdarg>
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h"
+
+namespace mlir::TFL {
+
+// An error reporter that uses absl logging.
+class AbslErrorReporter : public tflite::ErrorReporter {
+ int Report(const char* format, va_list args) override;
+};
+
+tflite::ErrorReporter* GetAbslErrorReporter();
+
+class FlatBufferModelAbslError
+ : public tflite::impl::FlatBufferModelBase<FlatBufferModelAbslError> {
+ public:
+ // Use the absl error reporter as the default error reporter.
+ static tflite::ErrorReporter* GetDefaultErrorReporter() {
+ return GetAbslErrorReporter();
+ }
+
+ // Inherit all constructors from FlatBufferModelBase since inherited factory
+ // methods refer to them.
+ using FlatBufferModelBase<FlatBufferModelAbslError>::FlatBufferModelBase;
+};
+
+} // namespace mlir::TFL
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_
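A hedged usage sketch (the file path is a placeholder): FlatBufferModelAbslError inherits the FlatBufferModelBase factory methods, so loading a model routes error messages through LOG(ERROR) instead of stderr.

#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"

// Sketch: "model.tflite" is a hypothetical path.
std::unique_ptr<mlir::TFL::FlatBufferModelAbslError> model =
    mlir::TFL::FlatBufferModelAbslError::BuildFromFile("model.tflite");
if (model == nullptr) {
  // Any failure details were already reported via the absl error reporter.
}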
diff --git a/tensorflow/compiler/mlir/lite/core/api/BUILD b/tensorflow/compiler/mlir/lite/core/api/BUILD
new file mode 100644
index 0000000..cc2d519
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/BUILD
@@ -0,0 +1,44 @@
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = [
+ "//tensorflow/compiler/mlir/lite:__subpackages__",
+ "//tensorflow/lite:__subpackages__",
+ ],
+ licenses = ["notice"],
+)
+
+exports_files(["error_reporter.h"])
+
+cc_library(
+ name = "error_reporter",
+ srcs = ["error_reporter.cc"],
+ hdrs = ["error_reporter.h"],
+ compatible_with = get_compatible_with_portable(),
+ copts = tflite_copts(),
+ deps = [],
+)
+
+exports_files(["verifier.h"])
+
+cc_library(
+ name = "verifier",
+ hdrs = ["verifier.h"],
+ compatible_with = get_compatible_with_portable(),
+ copts = tflite_copts(),
+ visibility = ["//visibility:public"],
+ deps = [":error_reporter"],
+)
+
+tf_cc_test(
+ name = "error_reporter_test",
+ size = "small",
+ srcs = ["error_reporter_test.cc"],
+ deps = [
+ ":error_reporter",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc
new file mode 100644
index 0000000..96f7561
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+#include <cstdarg>
+
+namespace tflite {
+
+int ErrorReporter::Report(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+} // namespace tflite
diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter.h b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h
new file mode 100644
index 0000000..79c9fc9
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+/// A functor that reports errors to a supporting system. Invoked similarly to
+/// printf.
+///
+/// Usage:
+/// ErrorReporter foo;
+/// foo.Report("test %d", 5);
+/// or
+/// va_list args;
+/// foo.Report("test %d", args); // where args is va_list
+///
+/// Subclass ErrorReporter to provide another reporting destination.
+/// For example, if you have a GUI program, you might redirect to a buffer
+/// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+ virtual ~ErrorReporter() = default;
+ /// Converts `args` to character equivalents according to the `format` string,
+ /// constructs the error string, and reports it.
+ /// Returns the number of characters written (or zero) on success, and a
+ /// negative number on error.
+ virtual int Report(const char* format, va_list args) = 0;
+
+ /// Converts the arguments to character equivalents according to the `format`
+ /// string, constructs the error string, and reports it.
+ /// Returns the number of characters written (or zero) on success, and a
+ /// negative number on error.
+ int Report(const char* format, ...);
+
+ /// Equivalent to `Report` above. The additional `void*` parameter is unused.
+ /// This method exists for compatibility with macros that take a
+ /// `TfLiteContext`, such as TF_LITE_ENSURE and related macros.
+ int ReportError(void*, const char* format, ...);
+};
+
+} // namespace tflite
+
+// Do not make bare calls to the error reporter; instead, use the
+// TF_LITE_REPORT_ERROR macro, since this allows message strings to be
+// stripped when the binary size has to be optimized. If you are looking to
+// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling, and
+// every call will be stubbed out, taking no memory.
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+#define TF_LITE_REPORT_ERROR(reporter, ...) \
+ do { \
+ static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \
+ } while (false)
+#else // TF_LITE_STRIP_ERROR_STRINGS
+#define TF_LITE_REPORT_ERROR(reporter, ...)
+#endif // TF_LITE_STRIP_ERROR_STRINGS
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_
diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc
new file mode 100644
index 0000000..ca7c4a2
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() { buffer_[0] = 0; }
+ int Report(const char* format, va_list args) override {
+ vsnprintf(buffer_, kBufferSize, format, args);
+ return 0;
+ }
+ char* GetBuffer() { return buffer_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+};
+
+TEST(ErrorReporter, TestReport) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ reporter->Report("Error: %d", 23);
+ EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+}
+
+TEST(ErrorReporter, TestReportMacro) {
+ MockErrorReporter mock_reporter;
+ // Only define the reporter if it's used, to avoid warnings.
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+ ErrorReporter* reporter = &mock_reporter;
+#endif // TF_LITE_STRIP_ERROR_STRINGS
+
+ TF_LITE_REPORT_ERROR(reporter, "Error: %d", 23);
+
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+ EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+#else // TF_LITE_STRIP_ERROR_STRINGS
+ EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), ""));
+#endif // TF_LITE_STRIP_ERROR_STRINGS
+}
+
+} // namespace tflite
diff --git a/tensorflow/compiler/mlir/lite/core/api/verifier.h b/tensorflow/compiler/mlir/lite/core/api/verifier.h
new file mode 100644
index 0000000..2e24347
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/verifier.h
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// \file
+///
+/// Abstract interface for verifying a model.
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+namespace tflite {
+
+/// Abstract interface that verifies whether a given model is legit.
+/// It facilitates the use case of verifying and building a model without
+/// loading it twice.
+/// (See also "tensorflow/lite/tools/verifier.h".)
+class TfLiteVerifier {
+ public:
+ /// Returns true if the model is legit.
+ virtual bool Verify(const char* data, int length,
+ ErrorReporter* reporter) = 0;
+ virtual ~TfLiteVerifier() {}
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_
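An illustrative extra verifier (not from the patch) showing how this interface composes with the VerifyAndBuild* factories in model_builder_base.h; the size cap is an arbitrary example.

class SizeLimitVerifier : public tflite::TfLiteVerifier {
 public:
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    if (length > 64 * 1024 * 1024) {  // arbitrary 64 MiB cap, for illustration
      TF_LITE_REPORT_ERROR(reporter, "Model of %d bytes exceeds the cap", length);
      return false;
    }
    return true;  // defer to the default flatbuffer verification
  }
};
// Passed as the extra_verifier argument of, e.g., VerifyAndBuildFromBuffer.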
diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.cc b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc
new file mode 100644
index 0000000..1537b90
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/allocation.h"
+
+namespace tflite {
+
+#ifndef TFLITE_MCU
+// Loads a model from `filename`. Uses mmap when supported; otherwise makes a
+// copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(
+ const char* filename, ErrorReporter* error_reporter) {
+ std::unique_ptr<Allocation> allocation;
+ if (MMAPAllocation::IsSupported()) {
+ allocation = std::make_unique<MMAPAllocation>(filename, error_reporter);
+ } else {
+ allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter);
+ }
+ return allocation;
+}
+
+// Loads a model from the file descriptor `fd`. Uses mmap when supported;
+// otherwise makes a copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(
+ int fd, ErrorReporter* error_reporter) {
+ std::unique_ptr<Allocation> allocation;
+ if (MMAPAllocation::IsSupported()) {
+ allocation = std::make_unique<MMAPAllocation>(fd, error_reporter);
+ } else {
+ allocation = std::make_unique<FileCopyAllocation>(
+ absl::StrCat("/proc/self/fd/", fd).c_str(), error_reporter);
+ }
+ return allocation;
+}
+
+#endif
+
+} // namespace tflite
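Callers never pick between the two allocation strategies themselves; a sketch (path hypothetical) of what the FlatBufferModelBase factories do internally:

std::unique_ptr<tflite::Allocation> allocation = tflite::GetAllocationFromFile(
    "model.tflite", mlir::TFL::GetAbslErrorReporter());
// MMAPAllocation where mmap is supported, FileCopyAllocation otherwise.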
diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/tensorflow/compiler/mlir/lite/core/model_builder_base.h
new file mode 100644
index 0000000..4e40e61
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.h
@@ -0,0 +1,614 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// \file
+///
+/// Deserialization infrastructure for tflite. Provides functionality
+/// to go from a serialized tflite model in flatbuffer format to an
+/// in-memory representation of the model.
+///
+/// WARNING: Users of TensorFlow Lite should not include this file directly,
+/// but should instead include "third_party/tensorflow/lite/model_builder.h".
+/// Only the TensorFlow Lite implementation itself should include this
+/// file directly.
+// IWYU pragma: private, include "third_party/tensorflow/lite/model_builder.h"
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "flatbuffers/base.h" // from @flatbuffers
+#include "flatbuffers/buffer.h" // from @flatbuffers
+#include "flatbuffers/vector.h" // from @flatbuffers
+#include "flatbuffers/verifier.h" // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+#include "tensorflow/compiler/mlir/lite/core/api/verifier.h"
+#include "tensorflow/compiler/mlir/lite/core/macros.h"
+#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
+#include "tensorflow/lite/allocation.h"
+
+namespace tflite {
+
+std::unique_ptr<Allocation> GetAllocationFromFile(
+ const char* filename, ErrorReporter* error_reporter);
+
+std::unique_ptr<Allocation> GetAllocationFromFile(
+ int fd, ErrorReporter* error_reporter);
+
+/// An RAII object that represents a read-only tflite model, copied from disk,
+/// or mmapped. This uses flatbuffers as the serialization format.
+///
+/// NOTE: The current API requires that a FlatBufferModelBase instance be kept
+/// alive by the client as long as it is in use by any dependent Interpreter
+/// instances. As the FlatBufferModelBase instance is effectively immutable
+/// after creation, the client may safely use a single model with multiple
+/// dependent Interpreter instances, even across multiple threads (though note
+/// that each Interpreter instance is *not* thread-safe).
+///
+/// <pre><code>
+/// using namespace tflite;
+/// StderrReporter error_reporter;
+/// auto model = FlatBufferModelBase::BuildFromFile("interesting_model.tflite",
+/// &error_reporter);
+/// MyOpResolver resolver; // You need to subclass OpResolver to provide
+/// // implementations.
+/// InterpreterBuilder builder(*model, resolver);
+/// std::unique_ptr<Interpreter> interpreter;
+/// if(builder(&interpreter) == kTfLiteOk) {
+/// .. run model inference with interpreter
+/// }
+/// </code></pre>
+///
+/// OpResolver must be defined to provide your kernel implementations to the
+/// interpreter. This is environment specific and may consist of just the
+/// builtin ops, or some custom operators you defined to extend tflite.
+namespace impl {
+
+template <typename T>
+class FlatBufferModelBase {
+ public:
+ /// Builds a model based on a file.
+ /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+ /// is longer than the FlatBufferModelBase instance.
+ /// Returns a nullptr in case of failure.
+ static std::unique_ptr<T> BuildFromFile(
+ const char* filename,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+ std::unique_ptr<T> model = BuildFromAllocation(
+ GetAllocationFromFile(filename, error_reporter), error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+ return model;
+#else
+ return ByteConvertModel(std::move(model), error_reporter);
+#endif
+ }
+
+ /// Verifies whether the content of the file is legit, then builds a model
+ /// based on the file.
+ /// The extra_verifier argument is an additional optional verifier for the
+ /// file contents. By default, we always check with tflite::VerifyModelBuffer.
+ /// If extra_verifier is supplied, the file contents are also checked against
+ /// the extra_verifier after the check against tflite::VerifyModelBuffer.
+ /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+ /// is longer than the FlatBufferModelBase instance.
+ /// Returns a nullptr in case of failure.
+ static std::unique_ptr<T> VerifyAndBuildFromFile(
+ const char* filename, TfLiteVerifier* extra_verifier = nullptr,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+ std::unique_ptr<T> model = VerifyAndBuildFromAllocation(
+ GetAllocationFromFile(filename, error_reporter), extra_verifier,
+ error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+ return model;
+#else
+ return ByteConvertModel(std::move(model), error_reporter);
+#endif
+ }
+
+ /// Builds a model based on a file descriptor.
+ /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+ /// is longer than the FlatBufferModelBase instance. Caller retains ownership
+ /// of `fd` and must ensure it is closed after BuildFromFile returns. Returns
+ /// a nullptr in case of failure.
+ static std::unique_ptr<T> BuildFromFileDescriptor(
+ int fd, ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+ std::unique_ptr<T> model = BuildFromAllocation(
+ GetAllocationFromFile(fd, error_reporter), error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+ return model;
+#else
+ return ByteConvertModel(std::move(model), error_reporter);
+#endif
+ }
+
+ /// Verifies whether the content of the file descriptor is legit, then builds
+ /// a model based on the file.
+ /// The extra_verifier argument is an additional optional verifier for the
+ /// file contents. By default, we always check with tflite::VerifyModelBuffer.
+ /// If extra_verifier is supplied, the file contents are also checked against
+ /// the extra_verifier after the check against tflite::VerifyModelBuffer.
+ /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+ /// is longer than the FlatBufferModelBase instance.
+ /// Returns a nullptr in case of failure.
+ static std::unique_ptr<T> VerifyAndBuildFromFileDescriptor(
+ int fd, TfLiteVerifier* extra_verifier = nullptr,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+ std::unique_ptr<T> model =
+ VerifyAndBuildFromAllocation(GetAllocationFromFile(fd, error_reporter),
+ extra_verifier, error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+ return model;
+#else
+ return ByteConvertModel(std::move(model), error_reporter);
+#endif
+ }
+
+ /// Builds a model based on a pre-loaded flatbuffer.
+ /// Caller retains ownership of the buffer and should keep it alive until
+ /// the returned object is destroyed. Caller also retains ownership of
+ /// `error_reporter` and must ensure its lifetime is longer than the
+ /// FlatBufferModelBase instance.
+ /// Returns a nullptr in case of failure.
+ /// NOTE: this does NOT validate the buffer, so it should NOT be called on
+ /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case.
+ static std::unique_ptr<T> BuildFromBuffer(
+ const char* caller_owned_buffer, size_t buffer_size,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+ std::unique_ptr<Allocation> allocation(
+ new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
+ return BuildFromAllocation(std::move(allocation), error_reporter);
+ }
+
+ /// Verifies whether the content of the buffer is legit, then builds a model
+ /// based on the pre-loaded flatbuffer.
+ /// The extra_verifier argument is an additional optional verifier for the
+ /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
+ /// extra_verifier is supplied, the buffer is checked against the
+ /// extra_verifier after the check against tflite::VerifyModelBuffer. The
+ /// caller retains ownership of the buffer and should keep it alive until the
+ /// returned object is destroyed. Caller retains ownership of `error_reporter`
+ /// and must ensure its lifetime is longer than the FlatBufferModelBase
+ /// instance. Returns a nullptr in case of failure.
+ static std::unique_ptr<T> VerifyAndBuildFromBuffer(
+ const char* caller_owned_buffer, size_t buffer_size,
+ TfLiteVerifier* extra_verifier = nullptr,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+ std::unique_ptr<Allocation> allocation(
+ new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
+ return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
+ error_reporter);
+ }
+
+#if FLATBUFFERS_LITTLEENDIAN == 0
+
+ void ByteSwapSerializedModel(std::string* serialized_model,
+ bool from_big_endian) {
+ const uint8_t* buffer =
+ reinterpret_cast<const uint8_t*>(serialized_model->c_str());
+ const tflite::Model* input_model = tflite::GetModel(buffer);
+ ByteSwapTFLiteModel(input_model, from_big_endian);
+ }
+
+ static void ByteSwapBuffer(int8_t tensor_type, size_t buffer_size, uint8_t* buffer,
+ bool from_big_endian) {
+ switch (tensor_type) {
+ case tflite::TensorType_STRING: {
+ auto bp = reinterpret_cast<int32_t*>(buffer);
+ int num_of_strings =
+ from_big_endian ? bp[0] : flatbuffers::EndianSwap(bp[0]);
+ for (int i = 0; i < num_of_strings + 2; i++)
+ bp[i] = flatbuffers::EndianSwap(bp[i]);
+ break;
+ }
+ // 16-bit types
+ case tflite::TensorType_FLOAT16:
+ case tflite::TensorType_INT16:
+ case tflite::TensorType_UINT16: {
+ auto bp = reinterpret_cast<uint16_t*>(buffer);
+ for (int i = 0; i < buffer_size / 2; i++)
+ bp[i] = flatbuffers::EndianSwap(bp[i]);
+ break;
+ }
+ // 32-bit types
+ case tflite::TensorType_FLOAT32:
+ case tflite::TensorType_INT32:
+ case tflite::TensorType_UINT32:
+ case tflite::TensorType_COMPLEX64: {
+ auto bp = reinterpret_cast<uint32_t*>(buffer);
+ for (int i = 0; i < buffer_size / 4; i++)
+ bp[i] = flatbuffers::EndianSwap(bp[i]);
+ break;
+ }
+ // 64-bit types
+ case tflite::TensorType_INT64:
+ case tflite::TensorType_FLOAT64:
+ case tflite::TensorType_UINT64:
+ case tflite::TensorType_COMPLEX128: {
+ auto bp = reinterpret_cast<uint64_t*>(buffer);
+ for (int i = 0; i < buffer_size / 8; i++)
+ bp[i] = flatbuffers::EndianSwap(bp[i]);
+ break;
+ }
+ default:
+ break;
+ }
+ }
+
+ void ByteSwapTFLiteModel(const tflite::Model* tfl_model,
+ bool from_big_endian) {
+ bool buffer_swapped[tfl_model->buffers()->size()] = {};
+ for (size_t subgraph_idx = 0; subgraph_idx < tfl_model->subgraphs()->size();
+ subgraph_idx++) {
+ const tflite::SubGraph* subgraph =
+ tfl_model->subgraphs()->Get(subgraph_idx);
+ for (size_t ts_idx = 0; ts_idx < subgraph->tensors()->size(); ts_idx++) {
+ const tflite::Tensor* tensor = subgraph->tensors()->Get(ts_idx);
+ if (tensor->buffer() > 0 &&
+ tensor->buffer() < tfl_model->buffers()->size() &&
+ !buffer_swapped[tensor->buffer()]) {
+ const tflite::Buffer* buffer_ =
+ (*tfl_model->buffers())[tensor->buffer()];
+ if (!buffer_ || !buffer_->data()) continue;
+ auto* buffer = buffer_->data();
+ uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
+ ByteSwapBuffer(tensor->type(), buffer->size(), buff_,
+ from_big_endian);
+ buffer_swapped[tensor->buffer()] = true;
+ }
+ }
+ }
+ }
+
+ static std::unique_ptr<T> ByteConvertModel(std::unique_ptr<T> model,
+ ErrorReporter* error_reporter,
+ bool from_big_endian = false) {
+ if (model == nullptr) return model;
+ auto tfl_model = model->GetModel();
+ if (tfl_model->subgraphs()->size() == 0) return model;
+ if (tfl_model->subgraphs()->Get(0)->tensors()->size() == 0) return model;
+ if (tfl_model->buffers()->size() < 2) return model;
+ return ByteSwapFlatBufferModelBase<T>(std::move(model), error_reporter,
+ from_big_endian);
+ }
+
+ static std::unique_ptr<T> ByteSwapFlatBufferModelBase(std::unique_ptr<T> model,
+ ErrorReporter* error_reporter,
+ bool from_big_endian) {
+ FlatBufferModelBase<T>* modelp = model.release();
+ auto tflite_model = modelp->GetModel();
+ auto copied_model = std::make_unique<tflite::ModelT>();
+ tflite_model->UnPackTo(copied_model.get(), nullptr);
+ ByteSwapTFLiteModelT(copied_model.get(), from_big_endian);
+ std::unique_ptr<flatbuffers::FlatBufferBuilder> builder(
+ new flatbuffers::FlatBufferBuilder());
+ auto packed_model = tflite::Model::Pack(*builder, copied_model.get());
+ tflite::FinishModelBuffer(*builder, packed_model);
+ flatbuffers::FlatBufferBuilder* builder_ = builder.release();
+ return BuildFromBuffer(
+ reinterpret_cast<const char*>(builder_->GetBufferPointer()),
+ builder_->GetSize(), error_reporter);
+ }
+
+ static void ByteSwapTFLiteModelT(tflite::ModelT* tfl_modelt, bool from_big_endian) {
+ size_t bytes_per_elem = 0;
+ bool buffer_swapped[tfl_modelt->buffers.size()] = {};
+ for (size_t subgraph_idx = 0; subgraph_idx < tfl_modelt->subgraphs.size();
+ subgraph_idx++) {
+ tflite::SubGraphT* subgraph =
+ tfl_modelt->subgraphs.at(subgraph_idx).get();
+ for (size_t ts_idx = 0; ts_idx < subgraph->tensors.size(); ts_idx++) {
+ tflite::TensorT* tensor = subgraph->tensors[ts_idx].get();
+ if (tensor->buffer > 0 && tensor->buffer < tfl_modelt->buffers.size() &&
+ !buffer_swapped[tensor->buffer]) {
+ const auto* buffer =
+ &(tfl_modelt->buffers[tensor->buffer].get()->data);
+ if (buffer && buffer->data()) {
+ uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
+ ByteSwapBuffer(tensor->type, buffer->size(), buff_,
+ from_big_endian);
+ buffer_swapped[tensor->buffer] = true;
+ }
+ }
+ }
+ }
+ }
+
+#endif
+
+ /// Builds a model directly from an allocation.
+ /// Ownership of the allocation is passed to the model, but the caller
+ /// retains ownership of `error_reporter` and must ensure its lifetime is
+ /// longer than the FlatBufferModelBase instance.
+ /// Returns a nullptr in case of failure (e.g., the allocation is invalid).
+ static std::unique_ptr<T> BuildFromAllocation(
+ std::unique_ptr<Allocation> allocation,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ std::unique_ptr<T> model(
+ new T(std::move(allocation), ValidateErrorReporter(error_reporter)));
+ if (!model->initialized()) {
+ model.reset();
+ } else {
+ model->ValidateModelBuffers(error_reporter);
+ }
+ return model;
+ }
+
+ /// Verifies whether the content of the allocation is legit, then builds a
+ /// model based on the provided allocation.
+ /// The extra_verifier argument is an additional optional verifier for the
+ /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
+ /// extra_verifier is supplied, the buffer is checked against the
+ /// extra_verifier after the check against tflite::VerifyModelBuffer.
+ /// Ownership of the allocation is passed to the model, but the caller
+ /// retains ownership of `error_reporter` and must ensure its lifetime is
+ /// longer than the FlatBufferModelBase instance.
+ /// Returns a nullptr in case of failure.
+ static std::unique_ptr<T> VerifyAndBuildFromAllocation(
+ std::unique_ptr<Allocation> allocation,
+ TfLiteVerifier* extra_verifier = nullptr,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+ if (!allocation || !allocation->valid()) {
+ TF_LITE_REPORT_ERROR(error_reporter,
+ "The model allocation is null/empty");
+ return nullptr;
+ }
+
+ {
+ // Flatbuffers must be smaller than 2GB. The file format may append extra
+ // data after the actual flatbuffer, so we truncate the allocation size to
+ // 2GB so that the verifier doesn't exit early on us.
+ size_t allocation_size =
+ std::min(allocation->bytes(),
+ static_cast<size_t>(FLATBUFFERS_MAX_BUFFER_SIZE - 1));
+ flatbuffers::Verifier base_verifier(
+ reinterpret_cast<const uint8_t*>(allocation->base()),
+ allocation_size);
+ if (!VerifyModelBuffer(base_verifier)) {
+ TF_LITE_REPORT_ERROR(error_reporter,
+ "The model is not a valid Flatbuffer buffer");
+ return nullptr;
+ }
+
+ if (extra_verifier &&
+ !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
+ allocation_size, error_reporter)) {
+ // The verifier will have already logged an appropriate error message.
+ return nullptr;
+ }
+ }
+
+ return BuildFromAllocation(std::move(allocation), error_reporter);
+ }
+
+ /// Builds a model directly from a flatbuffer pointer.
+ /// Caller retains ownership of the buffer and should keep it alive until the
+ /// returned object is destroyed. Caller retains ownership of `error_reporter`
+ /// and must ensure its lifetime is longer than the FlatBufferModelBase
+ /// instance. Returns a nullptr in case of failure.
+ static std::unique_ptr<T> BuildFromModel(
+ const tflite::Model* caller_owned_model_spec,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+
+ if (CheckBufferOutsideModel(caller_owned_model_spec)) {
+ TF_LITE_REPORT_ERROR(error_reporter,
+ "The model contains weights not accessible from "
+ "tflite::Model *, please use other api");
+ return nullptr;
+ }
+
+ std::unique_ptr<T> model(new T(caller_owned_model_spec, error_reporter));
+ if (!model->initialized()) {
+ model.reset();
+ } else {
+ model->ValidateModelBuffers(error_reporter);
+ }
+ return model;
+ }
+
+ // Releases memory or unmaps mmapped memory.
+ ~FlatBufferModelBase() = default;
+
+ // Copying or assignment is disallowed to simplify ownership semantics.
+ FlatBufferModelBase(const FlatBufferModelBase&) = delete;
+ FlatBufferModelBase& operator=(const FlatBufferModelBase&) = delete;
+
+ bool initialized() const { return model_ != nullptr; }
+ const tflite::Model* operator->() const { return model_; }
+ const tflite::Model* GetModel() const { return model_; }
+ ErrorReporter* error_reporter() const { return error_reporter_; }
+ const Allocation* allocation() const { return allocation_.get(); }
+
+ // Returns the minimum runtime version from the flatbuffer. This runtime
+ // version encodes the minimum required interpreter version to run the
+ // flatbuffer model. If the minimum version can't be determined, an empty
+ // string will be returned.
+ // Note that the returned minimum version is a lower-bound but not a strict
+ // lower-bound; ops in the graph may not have an associated runtime version,
+ // in which case the actual required runtime might be greater than the
+ // reported minimum.
+ std::string GetMinimumRuntime() const {
+ if (!model_ || !model_->metadata()) return "";
+
+ for (int i = 0; i < model_->metadata()->size(); ++i) {
+ auto metadata = model_->metadata()->Get(i);
+ if (metadata->name()->str() == tflite_metadata_min_runtime_version) {
+ auto buf = metadata->buffer();
+ auto* buffer = (*model_->buffers())[buf];
+ auto* array = buffer->data();
+ // Get the real length of the runtime string, since there might be
+ // trailing '\0's in the buffer.
+ for (int len = 0; len < array->size(); ++len) {
+ if (array->data()[len] == '\0') {
+ return std::string(reinterpret_cast<const char*>(array->data()),
+ len);
+ }
+ }
+ // If there is no '\0' in the buffer, this indicates that the flatbuffer
+ // is malformed.
+ TF_LITE_REPORT_ERROR(
+ error_reporter_,
+ "Min_runtime_version in model metadata is malformed");
+ break;
+ }
+ }
+ return "";
+ }
+
+ // Return model metadata as a mapping of name & buffer strings.
+ // See Metadata table in TFLite schema.
+ std::map<std::string, std::string> ReadAllMetadata() const {
+ return ReadAllMetadata(model_);
+ }
+
+ // Return model metadata as a mapping of name & buffer strings.
+ // See Metadata table in TFLite schema.
+ static std::map<std::string, std::string> ReadAllMetadata(
+ const ::tflite::Model* model) {
+ std::map<std::string, std::string> keys_values;
+ if (!model || !model->metadata() || !model->buffers()) return keys_values;
+
+ for (int i = 0; i < model->metadata()->size(); ++i) {
+ auto metadata = model->metadata()->Get(i);
+ auto buf = metadata->buffer();
+ if (buf >= model->buffers()->size()) continue;
+ const tflite::Buffer* buffer = (*model->buffers())[buf];
+ if (!buffer || !buffer->data()) continue;
+ const flatbuffers::Vector<uint8_t>* array = buffer->data();
+ if (!array) continue;
+ std::string val = std::string(
+ reinterpret_cast<const char*>(array->data()), array->size());
+ // Skip if key or value of metadata is empty.
+ if (!metadata->name() || val.empty()) continue;
+ keys_values[metadata->name()->str()] = val;
+ }
+ return keys_values;
+ }
+
+ // Validates if the FlatBufferModelBase's buffer is well-formed. Specifically,
+ // it checks if the 0th entry of the model buffers is an empty buffer
+ // (sentinel). This is a convention so that tensors without a buffer can
+ // provide 0 as their buffer. NOTE: The function doesn't explicitly fail for
+ // backward compatibility reasons; it just provides a warning in case of
+ // failures.
+ void ValidateModelBuffers(ErrorReporter* error_reporter) {
+ auto buffers = model_->buffers();
+ if (buffers && buffers->size() > 0) {
+ auto first_buffer = buffers->Get(0);
+ if (first_buffer && first_buffer->size() != 0) {
+ // Note the 0th entry of this array must be an empty buffer (sentinel).
+ // This is a convention so that tensors without a buffer can provide 0
+ // as their buffer.
+ TF_LITE_REPORT_ERROR(
+ error_reporter,
+ "The 0th entry of the model buffer must be an empty buffer.");
+ }
+ }
+ }
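+
+  // Expected layout, for reference (hedged sketch using the flatbuffer object
+  // API; not part of this change):
+  //
+  //   tflite::ModelT model_t;
+  //   model_t.buffers.push_back(std::make_unique<tflite::BufferT>());  // empty sentinel at index 0
+  //   model_t.buffers.push_back(std::make_unique<tflite::BufferT>());  // real tensor data
+  //   // Tensors without data can then reference buffer index 0.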
+
+ /// Returns true if the model identifier is correct (otherwise false and
+ /// reports an error).
+ bool CheckModelIdentifier() const {
+ if (allocation_->bytes() < 7) {
+ TF_LITE_REPORT_ERROR(
+ error_reporter_,
+ "Model provided must have at least 7 bytes to hold identifier.\n");
+ return false;
+ }
+ if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
+ const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
+      // `ident` is only consumed by TF_LITE_REPORT_ERROR below, which may
+      // compile to a no-op; cast to void to suppress the unused-variable
+      // warning in that configuration.
+      (void)ident;
+ TF_LITE_REPORT_ERROR(
+ error_reporter_,
+ "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
+ ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
+ return false;
+ }
+ return true;
+ }
+
+  /// Checks whether the tensor buffers are stored outside of the flatbuffer,
+  /// as indicated by the buffer-location metadata entry.
+  /// Returns false if the buffers are part of the flatbuffer itself.
+ static bool CheckBufferOutsideModel(const tflite::Model* model) {
+ if (!model || !model->metadata()) return false;
+
+ for (int i = 0; i < model->metadata()->size(); ++i) {
+ auto metadata = model->metadata()->Get(i);
+ if (metadata->name()->str() == tflite_metadata_buffer_location) {
+ return true;
+ }
+ }
+ return false;
+ }
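+
+  // Illustrative check (hedged sketch, not part of this change):
+  //
+  //   if (CheckBufferOutsideModel(model->GetModel())) {
+  //     // Tensor data lives outside the flatbuffer (e.g. appended after it),
+  //     // so callers must keep that external region alive alongside the
+  //     // model.
+  //   }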
+
+ protected:
+  /// Loads a model from a given allocation. FlatBufferModelBase takes
+  /// ownership of `allocation` and deletes it in its destructor. Ownership of
+  /// `error_reporter` remains with the caller, and the reporter must outlive
+  /// the FlatBufferModelBase. This allows multiple models to share the same
+  /// ErrorReporter instance.
+ explicit FlatBufferModelBase(
+ std::unique_ptr<Allocation> allocation,
+ ErrorReporter* error_reporter = T::GetDefaultErrorReporter())
+ : error_reporter_(ValidateErrorReporter(error_reporter)),
+ allocation_(std::move(allocation)) {
+ if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
+ return;
+ }
+
+ model_ = ::tflite::GetModel(allocation_->base());
+ }
+
+ /// Loads a model from Model flatbuffer. The `model` has to remain alive and
+ /// unchanged until the end of this flatbuffer model's lifetime.
+ FlatBufferModelBase(const Model* model, ErrorReporter* error_reporter)
+ : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
+
+ static ErrorReporter* ValidateErrorReporter(ErrorReporter* error_reporter) {
+ return error_reporter ? error_reporter : T::GetDefaultErrorReporter();
+ }
+
+  /// Flatbuffer traverser pointer. (When loaded from an allocation, `model_`
+  /// points into memory owned by `allocation_`.)
+  const tflite::Model* model_ = nullptr;
+  /// The error reporter to use for model errors and for subsequent errors
+  /// when the interpreter is created.
+  ErrorReporter* error_reporter_;
+  /// The allocation holding the model's memory. Note that this will be null
+  /// if the client provides a tflite::Model directly.
+ std::unique_ptr<Allocation> allocation_;
+};
+
+} // namespace impl
+
+} // namespace tflite
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_
diff --git a/tensorflow/compiler/mlir/lite/debug/BUILD b/tensorflow/compiler/mlir/lite/debug/BUILD
index 2516cb4..5c179eb 100644
--- a/tensorflow/compiler/mlir/lite/debug/BUILD
+++ b/tensorflow/compiler/mlir/lite/debug/BUILD
@@ -59,8 +59,8 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:path",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc
index 371e318..5d1ed84 100644
--- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc
+++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc
@@ -46,8 +46,8 @@
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/path.h"
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index 6164222..269f5cd 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -113,7 +113,6 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/tstring.h"
-#include "tensorflow/lite/core/interpreter.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
@@ -159,6 +158,11 @@
ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
+// LINT.IfChange(optional_tensor)
+// Taken from third_party/tensorflow/lite/core/c/common.h
+constexpr int kTfLiteMigrationOptionalTensor = -1;
+// LINT.ThenChange(//tensorflow/lite/core/c/common.h:optional_tensor)
+
// Use initial buffer size in flatbuffer builder to be same as the initial size
// used by the TOCO export. (It does not explain rationale for this choice.)
constexpr size_t kInitialBufferSize = 10240;
@@ -3024,7 +3028,7 @@
operands.reserve(real_inst->getNumOperands());
for (auto operand : real_inst->getOperands()) {
if (mlir::isa<NoneType>(operand.getType()))
- operands.push_back(kTfLiteOptionalTensor);
+ operands.push_back(kTfLiteMigrationOptionalTensor);
else if (auto stats_op =
llvm::dyn_cast_or_null<mlir::quantfork::StatisticsOp>(
operand.getDefiningOp()))
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index a03e988..a289126 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -44,6 +44,7 @@
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project
@@ -70,16 +71,17 @@
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
-#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "stablehlo/dialect/VhloOps.h" // from @stablehlo
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/offset_buffer.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h"
+#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h"
@@ -97,8 +99,6 @@
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/status.h"
-#include "tensorflow/lite/model_builder.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
@@ -625,7 +625,7 @@
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
OpBuilder builder,
- const std::unique_ptr<tflite::FlatBufferModel>& model_ptr) {
+ const std::unique_ptr<tfl::FlatBufferModelAbslError>& model_ptr) {
llvm::SmallVector<Value, 4> operands;
llvm::SmallVector<mlir::Type, 2> outputTypes;
@@ -1116,7 +1116,7 @@
bool experimental_prune_unreachable_nodes_unconditionally,
const tflite::SignatureDefT* signature,
const tflite::ControlEdges& control_edges,
- const std::unique_ptr<tflite::FlatBufferModel>& model_ptr,
+ const std::unique_ptr<tfl::FlatBufferModelAbslError>& model_ptr,
bool use_stablehlo_constant) {
// Populate from metadata.
ControlNodes control_nodes;
@@ -1518,8 +1518,8 @@
mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect,
mlir::stablehlo::StablehloDialect, mlir::vhlo::VhloDialect>();
- auto model_ptr =
- FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
+ auto model_ptr = tfl::FlatBufferModelAbslError::VerifyAndBuildFromBuffer(
+ buffer.data(), buffer.length());
if (nullptr == model_ptr) {
return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
}
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc
index df28f50..b393a88 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc
@@ -27,7 +27,7 @@
#include "flatbuffers/minireflect.h" // from @flatbuffers
#include "tensorflow/compiler/mlir/lite/schema/reflection/schema_generated.h"
#if FLATBUFFERS_LITTLEENDIAN == 0
-#include "tensorflow/lite/core/model_builder.h"
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
#endif
namespace tflite {
@@ -144,7 +144,8 @@
// If the flatbuffer model comes from stdin, convert its tensor content from
// BE to LE to ensure the output text string is the same as on LE platforms.
if (std::string(argv[1]) == "-")
- tflite::FlatBufferModel::ByteSwapSerializedModel(&serialized_model, true);
+ mlir::TFL::FlatBufferModelAbslError::ByteSwapSerializedModel(
+ &serialized_model, true);
#endif
tflite::ToString(serialized_model);
return 0;
diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD
index 6d2de49..299bb9e 100644
--- a/tensorflow/compiler/mlir/lite/python/BUILD
+++ b/tensorflow/compiler/mlir/lite/python/BUILD
@@ -195,8 +195,10 @@
":saved_model_to_tfl_flatbuffer",
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status_headers",
+ "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
"//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc",
"//tensorflow/compiler/mlir/lite/metrics:error_collector",
+ "//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_error_reporter",
"//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_utils",
"//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
@@ -205,8 +207,6 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/lite:model_builder",
- "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
"//tensorflow/lite/toco:model",
"//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_convert",
diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc
index 31ec151..351846c 100644
--- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc
+++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc
@@ -31,10 +31,12 @@
#include "google/protobuf/text_format.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h"
#include "tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
+#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
@@ -46,8 +48,6 @@
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/status.h"
-#include "tensorflow/lite/model_builder.h"
-#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/toco/logging/conversion_log_util.h"
#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
#include "tensorflow/lite/toco/model.h"
@@ -309,7 +309,7 @@
bool enable_variable_quantization,
bool disable_per_channel_for_dense_layers,
PyObject* debug_options_proto_txt_raw) {
- using tflite::interpreter_wrapper::PythonErrorReporter;
+ using tflite_migration::interpreter_wrapper::PythonErrorReporter;
char* buf = nullptr;
Py_ssize_t length;
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@@ -362,9 +362,9 @@
return nullptr;
}
- std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buf, length,
- error_reporter.get());
+ std::unique_ptr<mlir::TFL::FlatBufferModelAbslError> model =
+ mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
+ buf, length, error_reporter.get());
if (!model) {
PyErr_Format(PyExc_ValueError, "Invalid model");
return nullptr;
@@ -399,7 +399,7 @@
}
PyObject* MlirSparsifyModel(PyObject* data) {
- using tflite::interpreter_wrapper::PythonErrorReporter;
+ using tflite_migration::interpreter_wrapper::PythonErrorReporter;
char* buf = nullptr;
Py_ssize_t length;
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@@ -408,9 +408,9 @@
PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
return nullptr;
}
- std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buf, length,
- error_reporter.get());
+ std::unique_ptr<mlir::TFL::FlatBufferModelAbslError> model =
+ mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
+ buf, length, error_reporter.get());
if (!model) {
PyErr_Format(PyExc_ValueError, "Invalid model");
return nullptr;
diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD
index 8d2cb7a..9268de7 100644
--- a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD
@@ -15,3 +15,14 @@
"//third_party/python_runtime:headers", # buildcleaner: keep
],
)
+
+cc_library(
+ name = "python_error_reporter",
+ srcs = ["python_error_reporter.cc"],
+ hdrs = ["python_error_reporter.h"],
+ compatible_with = get_compatible_with_portable(),
+ deps = [
+ "//tensorflow/compiler/mlir/lite:stateful_error_reporter",
+ "//third_party/python_runtime:headers", # buildcleaner: keep
+ ],
+)
diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc
new file mode 100644
index 0000000..75f9222
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc
@@ -0,0 +1,47 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h"
+
+#include <cstdarg>
+#include <cstdio>
+#include <string>
+
+namespace tflite_migration {
+namespace interpreter_wrapper {
+
+// Report an error message
+int PythonErrorReporter::Report(const char* format, va_list args) {
+ char buf[1024];
+ int formatted = vsnprintf(buf, sizeof(buf), format, args);
+ buffer_ << buf;
+ return formatted;
+}
+
+// Sets a Python runtime exception with the last error.
+PyObject* PythonErrorReporter::exception() {
+ std::string last_message = message();
+ PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
+ return nullptr;
+}
+
+// Gets the last error message and clears the buffer.
+std::string PythonErrorReporter::message() {
+  std::string value = buffer_.str();
+  // std::stringstream::clear() only resets the stream's error flags; reset
+  // the underlying string as well so the buffer is actually emptied.
+  buffer_.str("");
+  buffer_.clear();
+  return value;
+}
+} // namespace interpreter_wrapper
+} // namespace tflite_migration
diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h
new file mode 100644
index 0000000..f98a352
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h
@@ -0,0 +1,50 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
+
+#include <Python.h>
+
+#include <cstdarg>
+#include <sstream>
+#include <string>
+
+#include "tensorflow/compiler/mlir/lite/stateful_error_reporter.h"
+
+namespace tflite_migration {
+namespace interpreter_wrapper {
+
+class PythonErrorReporter : public tflite_migration::StatefulErrorReporter {
+ public:
+ PythonErrorReporter() = default;
+
+ // Report an error message
+ int Report(const char* format, va_list args) override;
+
+ // Sets a Python runtime exception with the last error and
+ // clears the error message buffer.
+ PyObject* exception();
+
+ // Gets the last error message and clears the buffer.
+ std::string message() override;
+
+ private:
+ std::stringstream buffer_;
+};
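+
+// Illustrative usage (hedged sketch, not part of this change):
+//
+//   PythonErrorReporter reporter;
+//   auto model = mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
+//       buf, length, &reporter);
+//   if (!model) {
+//     return reporter.exception();  // sets a RuntimeError with the message
+//   }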
+
+} // namespace interpreter_wrapper
+} // namespace tflite_migration
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
index e57c2b3..c269b41 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -168,16 +168,16 @@
deps = [
":quantize_model",
":test_util",
+ "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/lite/schema:schema_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
- "//tensorflow/lite:framework",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
"@flatbuffers",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -205,11 +205,11 @@
deps = [
":quantize_weights",
":test_util",
+ "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/lite/schema:schema_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
- "//tensorflow/lite:framework",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
"@flatbuffers",
@@ -223,9 +223,7 @@
srcs = ["test_util.cc"],
hdrs = ["test_util.h"],
deps = [
- "//tensorflow/lite:framework",
- "//tensorflow/lite/core/api",
+ "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
"@com_google_googletest//:gtest",
- "@flatbuffers",
],
)
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
index 1e7cdcd..371f452 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
@@ -30,15 +30,15 @@
#include "absl/status/status.h"
#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers
#include "flatbuffers/vector.h" // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
-#include "tensorflow/lite/model_builder.h"
-#include "tsl/lib/core/status_test_util.h"
// Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -50,6 +50,7 @@
namespace optimize {
namespace {
+using mlir::TFL::FlatBufferModelAbslError;
using testing::Eq;
using testing::FloatEq;
using testing::FloatNear;
@@ -100,7 +101,7 @@
return status;
}
- auto flatbuffer_model = FlatBufferModel::BuildFromBuffer(
+ auto flatbuffer_model = FlatBufferModelAbslError::BuildFromBuffer(
output_buffer.data(), output_buffer.size());
*model = UnPackFlatBufferModel(*flatbuffer_model->GetModel());
return absl::OkStatus();
@@ -157,9 +158,10 @@
disable_per_channel_for_dense_layers);
}
-std::unique_ptr<FlatBufferModel> ReadModel(const std::string& model_name) {
+std::unique_ptr<FlatBufferModelAbslError> ReadModel(
+ const std::string& model_name) {
auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name);
- return FlatBufferModel::BuildFromFile(model_path.c_str());
+ return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
}
template <typename T>
@@ -198,7 +200,7 @@
model_ = UnPackFlatBufferModel(*readonly_model_);
}
- std::unique_ptr<FlatBufferModel> input_model_;
+ std::unique_ptr<FlatBufferModelAbslError> input_model_;
const Model* readonly_model_;
tflite::ModelT model_;
std::string output_buffer_; // Raw buffer for quantized output model.
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
index 7a42e74..db124c8 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
@@ -27,6 +27,7 @@
#include "flatbuffers/buffer.h" // from @flatbuffers
#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers
#include "flatbuffers/vector.h" // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
@@ -34,7 +35,6 @@
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
-#include "tensorflow/lite/model_builder.h"
#include "tsl/platform/logging.h"
// Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc
@@ -50,6 +50,7 @@
using mlir::lite::BufferType;
using mlir::lite::CustomOpMap;
using mlir::lite::QuantizeWeights;
+using mlir::TFL::FlatBufferModelAbslError;
constexpr bool kUseUpdatedHybridSchemeDefault = true;
std::unique_ptr<ModelT> CreateMutableModelFromFile(const Model* input_model) {
@@ -58,28 +59,28 @@
return copied_model;
}
-std::unique_ptr<FlatBufferModel> ReadTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadTestModel() {
auto model_path = tensorflow::io::JoinPath(
*g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights);
- return FlatBufferModel::BuildFromFile(model_path.c_str());
+ return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
}
-std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadSharedWeightsTestModel() {
auto model_path = tensorflow::io::JoinPath(
*g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights);
- return FlatBufferModel::BuildFromFile(model_path.c_str());
+ return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
}
-std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadGatherTestModel() {
auto model_path = tensorflow::io::JoinPath(
*g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather);
- return FlatBufferModel::BuildFromFile(model_path.c_str());
+ return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
}
-std::unique_ptr<FlatBufferModel> ReadCustomOpTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadCustomOpTestModel() {
auto model_path = tensorflow::io::JoinPath(
*g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp);
- return FlatBufferModel::BuildFromFile(model_path.c_str());
+ return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
}
template <typename T>
@@ -111,7 +112,7 @@
model_ = input_model_->GetModel();
}
- std::unique_ptr<FlatBufferModel> input_model_;
+ std::unique_ptr<FlatBufferModelAbslError> input_model_;
const Model* model_;
bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) {
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
index e096868..66c1ade 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
@@ -14,6 +14,9 @@
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
+#include <cstdarg>
+#include <cstdio>
+
#include <gtest/gtest.h>
namespace mlir {
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
index b4e317c1..8953a38 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
@@ -15,7 +15,9 @@
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_
-#include "tensorflow/lite/core/api/error_reporter.h"
+#include <cstdarg>
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
namespace mlir {
namespace lite {
diff --git a/tensorflow/compiler/mlir/lite/schema/BUILD b/tensorflow/compiler/mlir/lite/schema/BUILD
index 7dd8eec..14e80f4 100644
--- a/tensorflow/compiler/mlir/lite/schema/BUILD
+++ b/tensorflow/compiler/mlir/lite/schema/BUILD
@@ -11,7 +11,15 @@
)
exports_files(
- srcs = ["schema.fbs"],
+ srcs = [
+ "schema.fbs",
+ "schema_v0.fbs",
+ "schema_v1.fbs",
+ "schema_v2.fbs",
+ "schema_v3.fbs",
+ "schema_v3a.fbs",
+ "schema_v3b.fbs",
+ ],
)
filegroup(
diff --git a/tensorflow/lite/schema/schema_v0.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v0.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v0.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v0.fbs
diff --git a/tensorflow/lite/schema/schema_v1.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v1.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v1.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v1.fbs
diff --git a/tensorflow/lite/schema/schema_v2.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v2.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v2.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v2.fbs
diff --git a/tensorflow/lite/schema/schema_v3.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v3.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v3.fbs
diff --git a/tensorflow/lite/schema/schema_v3a.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3a.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v3a.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v3a.fbs
diff --git a/tensorflow/lite/schema/schema_v3c.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3c.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v3c.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v3c.fbs
diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD
index 92e6aaa..7c15ac2 100644
--- a/tensorflow/compiler/mlir/lite/sparsity/BUILD
+++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD
@@ -54,9 +54,9 @@
],
deps = [
":sparsify_model",
+ "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata",
- "//tensorflow/lite/core:model_builder",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc
index 0d1339d..cc557b5 100644
--- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc
+++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc
@@ -27,9 +27,9 @@
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h"
-#include "tensorflow/lite/core/model_builder.h"
namespace mlir {
namespace lite {
@@ -41,7 +41,7 @@
std::string expected_value = "test_data";
// Load input model
- auto input_fbm = tflite::FlatBufferModel::BuildFromFile(
+ auto input_fbm = mlir::TFL::FlatBufferModelAbslError::BuildFromFile(
"tensorflow/compiler/mlir/lite/sparsity/testdata/"
"sparse_tensor.bin");
tflite::ModelT input_model;
@@ -60,7 +60,7 @@
// Sparsify and create output model
flatbuffers::FlatBufferBuilder output_builder;
ASSERT_TRUE(SparsifyModel(input_model, &output_builder).ok());
- auto output_fbm = tflite::FlatBufferModel::BuildFromBuffer(
+ auto output_fbm = mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
reinterpret_cast<const char*>(output_builder.GetCurrentBufferPointer()),
output_builder.GetSize());
tflite::ModelT output_model;
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD
index 2ba51a7..56c3d89 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD
+++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD
@@ -643,8 +643,12 @@
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:gather",
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad",
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce",
+ "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce_window",
+ "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:slice",
+ "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:sort",
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:util",
"//tensorflow/compiler/mlir/tensorflow",
+ "@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
@@ -691,7 +695,8 @@
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv",
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv_util",
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad_util",
- "@llvm-project//llvm:Support",
+ "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce_window",
+ "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:slice",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir
index 6e69c40..db09eca 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir
@@ -164,6 +164,31 @@
// CHECK-SAME: [0, 0], [0, 0]
// CHECK-SAME: (tensor<1x256x256x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32>
+// -----
+
+// CHECK-LABEL: conv2d_nhwc_ohwi_nhwc_asymmetric_padded
+func.func @conv2d_nhwc_ohwi_nhwc_asymmetric_padded(%input: tensor<1x255x255x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> {
+ %0 = "mhlo.convolution"(%input, %filter) {
+ dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>,
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ window_strides = dense<1> : tensor<2xi64>,
+ padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
+ rhs_dilation = dense<[1, 1]> : tensor<2xi64>,
+ lhs_dilation = dense<[1, 1]> : tensor<2xi64>
+ } : (tensor<1x255x255x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32>
+ func.return %0 : tensor<1x256x256x2xf32>
+}
+
+// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad"
+// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
+// CHECK-SAME: edge_padding_low = dense<0>
+// CHECK-SAME: interior_padding = dense<0>
+// CHECK: mhlo.convolution(%[[PADDED_LHS]]
+// CHECK-SAME: pad
+// CHECK-SAME: [0, 0], [0, 0]
+// CHECK-SAME: (tensor<1x256x256x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32>
+
// -----
@@ -417,7 +442,6 @@
// 1D
//=--
-// TODO: b/351437662 - Add support for conv1d.
// CHECK-LABEL: conv1d_nsc_osi_nsc
func.func @conv1d_nsc_osi_nsc(%arg0: tensor<16x32x256xf32>, %arg1: tensor<256x1x256xf32>) -> tensor<16x32x256xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
@@ -546,3 +570,106 @@
// CHECK-SAME: edge_padding_low = dense<[0, 1, 0]>
// CHECK-SAME: (tensor<2x2x3xf32>, tensor<f32>) -> tensor<3x3x3xf32>
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.reduce_window
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: reduce_window_valid_channel_first
+func.func @reduce_window_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
+ // "0xFF800000" represents -INF for f32.
+ %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+ %1 = "mhlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+ mhlo.return %2 : tensor<f32>
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<0> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>,
+ window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor<f32>) -> tensor<4x3x7x7xf32>
+ func.return %1 : tensor<4x3x7x7xf32>
+}
+
+// CHECK: %[[INIT_CST:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK: %[[TPOSE_IN:.*]] = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+// CHECK: %[[RW:.*]] = "mhlo.reduce_window"(%[[TPOSE_IN]], %[[INIT_CST]])
+// CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]>
+// CHECK-SAME: window_strides = dense<[1, 2, 2, 1]>
+// CHECK: %3 = "mhlo.transpose"(%[[RW]]) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32>
+
+// -----
+
+// CHECK-LABEL: reduce_window_same_channel_first
+func.func @reduce_window_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> {
+ // "0xFF800000" represents -INF for f32.
+ %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+ %1 = "mhlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 0], [0, 1], [0, 1]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>,
+ window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor<f32>) -> tensor<4x3x8x8xf32>
+ func.return %1 : tensor<4x3x8x8xf32>
+}
+
+// CHECK: %[[INIT_CST:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK: %[[TPOSE_IN:.*]] = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+// CHECK: %[[RW:.*]] = "mhlo.reduce_window"(%[[TPOSE_IN]], %[[INIT_CST]])
+// CHECK-SAME: padding
+// CHECK-SAME: [0, 0], [0, 1], [0, 1], [0, 0]
+// CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]>
+// CHECK-SAME: window_strides = dense<[1, 2, 2, 1]>
+// CHECK: %3 = "mhlo.transpose"(%[[RW]]) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.dynamic_slice
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: dynamic_slice
+func.func @dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<4x2xf32> {
+ %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+ func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK: mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_ui32
+func.func @dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor<ui32>, %arg2: tensor<ui32>) -> tensor<4x2xf32> {
+ %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<ui32>, tensor<ui32>) -> tensor<4x2xf32>
+ func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK: mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+
+// CHECK-LABEL: dynamic_slice_ui64
+func.func @dynamic_slice_ui64(%arg0: tensor<7x3xf32>, %arg1: tensor<ui64>, %arg2: tensor<ui64>) -> tensor<4x2xf32> {
+ %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<ui64>, tensor<ui64>) -> tensor<4x2xf32>
+ func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK: mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_i64
+func.func @dynamic_slice_i64(%arg0: tensor<7x3xf32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<4x2xf32> {
+ %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
+ func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK: mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
index 8ca6f44..55f2b65 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
@@ -72,13 +72,13 @@
func.return %0 : tensor<3x5x1x4xf32>
}
-// CHECK: %[[TRANSPOSED_0:.*]] = "tfl.transpose"
-// CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose"
-// CHECK-NEXT: %[[RESHAPED_0:.*]] = mhlo.reshape %[[TRANSPOSED_0]]
-// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %[[TRANSPOSED_1]]
-// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
-// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]]
-// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32>
+// CHECK: %[[TRANSPOSED_0:.*]] = "tfl.transpose"
+// CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose"
+// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%[[TRANSPOSED_0]]
+// CHECK: %[[RESHAPED_1:.*]] = "tfl.reshape"(%[[TRANSPOSED_1]]
+// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
+// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]]
+// CHECK: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32>
// -----
@@ -96,11 +96,10 @@
func.return %0 : tensor<1x1x1024xf32>
}
-// CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0
-// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1
-// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32>
-// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]]
-// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32>
+// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0
+// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32>
+// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]]
+// CHECK: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32>
// -----
@@ -115,11 +114,10 @@
func.return %0 : tensor<8xi32>
}
-// CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0
-// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1
-// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32>
-// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]]
-// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<8xi32>
+// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0
+// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32>
+// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]]
+// CHECK: return %[[RESHAPED_BMM]] : tensor<8xi32>
// -----
@@ -135,29 +133,30 @@
func.return %0 : tensor<4x4x?xf32>
}
-// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32>
-// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
-// CHECK-NEXT: %3 = mhlo.reshape %arg0 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32>
-// CHECK-NEXT: %4 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %7 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%4, %5, %7) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %9 = "tfl.unsorted_segment_prod"(%4, %6, %7) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %9, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %12 = mhlo.dynamic_reshape %2, %11 : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
-// CHECK-NEXT: %13 = "tfl.batch_matmul"(%3, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32>
-// CHECK-NEXT: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
-// CHECK-NEXT: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32>
-// CHECK-NEXT: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %21 = mhlo.dynamic_reshape %13, %20 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
-// CHECK-NEXT: return %21 : tensor<4x4x?xf32>
+// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32>
+// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
+// CHECK: %3 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
+// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %10 = "tfl.concatenation"(%9, %8, %7) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK: %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK: %12 = "tfl.reshape"(%2, %11) : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
+// CHECK: %13 = "tfl.batch_matmul"(%arg0, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32>
+// CHECK: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32>
+// CHECK: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
+// CHECK: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK: %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK: %22 = "tfl.reshape"(%13, %21) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
+// CHECK: return %22 : tensor<4x4x?xf32>
// -----
@@ -173,43 +172,45 @@
func.return %0 : tensor<2x?x2x4xf32>
}
-// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
-// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
-// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
-// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
-// CHECK-NEXT: %12 = mhlo.dynamic_reshape %arg0, %11 : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32>
-// CHECK-NEXT: %13 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %19 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %20 = "tfl.gather"(%13, %19) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %21 = "tfl.concatenation"(%20, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
-// CHECK-NEXT: %22 = mhlo.dynamic_reshape %2, %21 : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
-// CHECK-NEXT: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32>
-// CHECK-NEXT: %24 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %25 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK-NEXT: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
-// CHECK-NEXT: %28 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
-// CHECK-NEXT: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32>
-// CHECK-NEXT: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32>
-// CHECK-NEXT: %31 = mhlo.dynamic_reshape %23, %30 : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32>
-// CHECK-NEXT: return %31 : tensor<2x?x2x4xf32>
+// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
+// CHECK: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
+// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
+// CHECK: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK: %12 = "tfl.cast"(%11) : (tensor<4xi32>) -> tensor<4xi32>
+// CHECK: %13 = "tfl.reshape"(%arg0, %12) : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32>
+// CHECK: %14 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %17 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %18 = "tfl.unsorted_segment_prod"(%14, %15, %17) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %19 = "tfl.unsorted_segment_prod"(%14, %16, %17) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %20 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK: %21 = "tfl.gather"(%14, %20) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %22 = "tfl.concatenation"(%21, %19, %18) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK: %23 = "tfl.cast"(%22) : (tensor<4xi32>) -> tensor<4xi32>
+// CHECK: %24 = "tfl.reshape"(%2, %23) : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
+// CHECK: %25 = "tfl.batch_matmul"(%13, %24) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32>
+// CHECK: %26 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
+// CHECK: %27 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
+// CHECK: %28 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK: %29 = "tfl.gather"(%26, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
+// CHECK: %30 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK: %31 = "tfl.gather"(%27, %30) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK: %32 = "tfl.concatenation"(%29, %31) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK: %33 = "tfl.cast"(%32) : (tensor<4xi32>) -> tensor<4xi32>
+// CHECK: %34 = "tfl.reshape"(%25, %33) : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32>
+// CHECK: return %34 : tensor<2x?x2x4xf32>
// -----
-
// CHECK-LABEL: dot_general_dynamic_lhs_rhs_out_dims
func.func @dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf32>, %arg1: tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32> {
%0 = "mhlo.dot_general"(%arg0, %arg1) {
@@ -222,37 +223,40 @@
func.return %0 : tensor<2x2x?x4x?xf32>
}
-// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
-// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
-// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32>
-// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %11 = mhlo.dynamic_reshape %arg0, %10 : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32>
-// CHECK-NEXT: %12 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %13 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %16 = "tfl.unsorted_segment_prod"(%12, %13, %15) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%12, %14, %15) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %19 = "tfl.concatenation"(%18, %17, %16) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %20 = mhlo.dynamic_reshape %2, %19 : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
-// CHECK-NEXT: %21 = "tfl.batch_matmul"(%11, %20) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32>
-// CHECK-NEXT: %22 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %23 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %24 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK-NEXT: %25 = "tfl.gather"(%22, %24) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
-// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %27 = "tfl.gather"(%23, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %28 = "tfl.concatenation"(%25, %27) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
-// CHECK-NEXT: %29 = mhlo.dynamic_reshape %21, %28 : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32>
-// CHECK-NEXT: return %29 : tensor<2x2x?x4x?xf32>
+// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
+// CHECK: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
+// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32>
+// CHECK: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK: %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK: %12 = "tfl.reshape"(%arg0, %11) : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32>
+// CHECK: %13 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %14 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %19 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %20 = "tfl.concatenation"(%19, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK: %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK: %22 = "tfl.reshape"(%2, %21) : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
+// CHECK: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32>
+// CHECK: %24 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
+// CHECK: %25 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
+// CHECK: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
+// CHECK: %28 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
+// CHECK: %31 = "tfl.cast"(%30) : (tensor<5xi32>) -> tensor<5xi32>
+// CHECK: %32 = "tfl.reshape"(%23, %31) : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32>
+// CHECK: return %32 : tensor<2x2x?x4x?xf32>
// -----
@@ -268,27 +272,28 @@
func.return %0 : tensor<4x4x256xf32>
}
-// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
-// CHECK-NEXT: %9 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %11 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %12 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %13 = "tfl.unsorted_segment_prod"(%9, %10, %12) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %14 = "tfl.unsorted_segment_prod"(%9, %11, %12) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %16 = "tfl.concatenation"(%15, %14, %13) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %17 = mhlo.dynamic_reshape %arg1, %16 : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32>
-// CHECK-NEXT: %18 = "tfl.batch_matmul"(%8, %17) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
-// CHECK-NEXT: %19 = mhlo.reshape %18 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32>
-// CHECK-NEXT: return %19 : tensor<4x4x256xf32>
+// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32>
+// CHECK-DAG: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK: %8 = "tfl.cast"(%7) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK: %9 = "tfl.reshape"(%arg0, %8) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
+// CHECK: %10 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
+// CHECK-DAG: %11 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %12 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %13 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %14 = "tfl.unsorted_segment_prod"(%10, %11, %13) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %15 = "tfl.unsorted_segment_prod"(%10, %12, %13) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %16 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %17 = "tfl.concatenation"(%16, %15, %14) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK: %18 = "tfl.cast"(%17) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK: %19 = "tfl.reshape"(%arg1, %18) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32>
+// CHECK: %20 = "tfl.batch_matmul"(%9, %19) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
+// CHECK: return %20 : tensor<4x4x256xf32>
// -----
@@ -568,13 +573,15 @@
}
// CHECK-DAG: %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
-// CHECK: %1 = mhlo.constant dense<0> : tensor<i32>
+// CHECK-DAG: %1 = mhlo.constant dense<0> : tensor<i32>
// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32>
-// CHECK: %3 = mhlo.reshape %2 : (tensor<32xi32>) -> tensor<1x32x1xi32>
-// CHECK: %cst = arith.constant dense<1> : tensor<1xi32>
-// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32>
-// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32>
-// CHECK: return %4, %5 : tensor<1x1xf32>, tensor<1x1xi32>
+// CHECK: %cst = arith.constant dense<[1, 32, 1]> : tensor<3xi64>
+// CHECK: %3 = "tfl.cast"(%cst) : (tensor<3xi64>) -> tensor<3xi32>
+// CHECK: %4 = "tfl.reshape"(%2, %3) : (tensor<32xi32>, tensor<3xi32>) -> tensor<1x32x1xi32>
+// CHECK: %cst_0 = arith.constant dense<1> : tensor<1xi32>
+// CHECK: %5 = "tfl.reduce_max"(%arg0, %cst_0) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32>
+// CHECK: %6 = "tfl.arg_max"(%arg0, %cst_0) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK: return %5, %6 : tensor<1x1xf32>, tensor<1x1xi32>
// -----
@@ -597,14 +604,16 @@
func.return %4#1 : tensor<1xi32>
}
-// CHECK: %0 = mhlo.constant dense<0> : tensor<i32>
+// CHECK-DAG: %0 = mhlo.constant dense<0> : tensor<i32>
// CHECK-DAG: %1 = mhlo.constant dense<-2147483648> : tensor<i32>
// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32>
-// CHECK: %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32>
-// CHECK: %cst = arith.constant dense<1> : tensor<1xi32>
-// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
-// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
-// CHECK: return %5 : tensor<1xi32>
+// CHECK: %cst = arith.constant dense<[1, 9]> : tensor<2xi64>
+// CHECK: %3 = "tfl.cast"(%cst) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %4 = "tfl.reshape"(%2, %3) : (tensor<9xi32>, tensor<2xi32>) -> tensor<1x9xi32>
+// CHECK: %cst_0 = arith.constant dense<1> : tensor<1xi32>
+// CHECK: %5 = "tfl.reduce_max"(%arg0, %cst_0) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK: %6 = "tfl.arg_max"(%arg0, %cst_0) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK: return %6 : tensor<1xi32>
// -----
@@ -618,11 +627,11 @@
func.return %0 : tensor<1x32x1xf32>
}
-// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
-// CHECK: %cst_0 = arith.constant dense<3.000000e+00> : tensor<f32>
-// CHECK: %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor<f32>
-// CHECK: %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor<f32>) -> tensor<1x32x1xf32>
-// CHECK: return %1 : tensor<1x32x1xf32>
+// CHECK-DAG: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK-DAG: %cst_0 = arith.constant dense<3.000000e+00> : tensor<f32>
+// CHECK: %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor<f32>
+// CHECK: %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor<f32>) -> tensor<1x32x1xf32>
+// CHECK: return %1 : tensor<1x32x1xf32>
// -----
@@ -637,6 +646,44 @@
// -----
//===----------------------------------------------------------------------===//
+// mhlo.(dynamic)reshape
+//===----------------------------------------------------------------------===//
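+// The tests below cover static mhlo.reshape as well as mhlo.dynamic_reshape with
+// i32 and i64 shape operands; all three are expected to lower to a tfl.cast of the
+// shape operand followed by tfl.reshape.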
+
+// CHECK-LABEL: reshape
+func.func @reshape(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
+ %0 = "mhlo.reshape"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2xf32>
+ func.return %0 : tensor<3x2xf32>
+}
+
+// CHECK: %cst = arith.constant dense<[3, 2]> : tensor<2xi64>
+// CHECK: %0 = "tfl.cast"(%cst) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_reshape_i32
+func.func @dynamic_reshape_i32(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
+ %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+ func.return %0 : tensor<?x?xf32>
+}
+
+// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_reshape_i64
+func.func @dynamic_reshape_i64(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi64>) -> tensor<?x?xf32> {
+ %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<?x?xf32>
+ func.return %0 : tensor<?x?xf32>
+}
+
+// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
// mhlo.convolution
//===----------------------------------------------------------------------===//
@@ -1473,7 +1520,7 @@
func.return %0 : tensor<4x64x128xf32>
}
-// CHECK: %[[VAL_0:.*]] = mhlo.reshape %arg1 : (tensor<4x64xi32>) -> tensor<4x64x1xi32>
+// CHECK: %[[VAL_0:.*]] = "tfl.reshape"(%arg1, %0) : (tensor<4x64xi32>, tensor<3xi32>) -> tensor<4x64x1xi32
// CHECK: %[[VAL_1:.*]] = "tfl.gather_nd"(%arg0, %[[VAL_0]]) : (tensor<98x128xf32>, tensor<4x64x1xi32>) -> tensor<4x64x128xf32>
// -----
@@ -1582,3 +1629,1209 @@
}
// CHECK: %0 = "tfl.gather_nd"(%arg0, %arg1) : (tensor<256000xf32>, tensor<?x?x1xi32>) -> tensor<?x?xf32>
+
+// -----
+
+//===------------------------------------------------------------------------===
+// mhlo.reduce_window -> avg pool
+//===------------------------------------------------------------------------===
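+// These tests exercise the avg-pool pattern: a sum mhlo.reduce_window divided by a
+// matching window-count divisor (a constant, a broadcast scalar, or a second
+// reduce_window over ones), which should lower to tfl.average_pool_2d. Channel-first
+// variants are wrapped in transposes to and from NHWC.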
+
+// CHECK-LABEL: avgpool_same_channel_first
+func.func @avgpool_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> {
+ %0 = mhlo.constant dense<1.000000e+00> : tensor<4x16x16x3xf32>
+ %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+ %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+ %3 = "mhlo.reduce_window"(%2, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %8 = mhlo.add %arg1, %arg2 : tensor<f32>
+ mhlo.return %8 : tensor<f32>
+ }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x8x8x3xf32>
+ %4 = "mhlo.transpose"(%3) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+ %5 = "mhlo.reduce_window"(%0, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %8 = mhlo.add %arg1, %arg2 : tensor<f32>
+ mhlo.return %8 : tensor<f32>
+ }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x8x8x3xf32>
+ %6 = "mhlo.transpose"(%5) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+ %7 = mhlo.divide %4, %6 : tensor<4x3x8x8xf32>
+ return %7 : tensor<4x3x8x8xf32>
+}
+
+// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK-SAME: (tensor<4x3x16x16xf32>, tensor<4xi32>) -> tensor<4x16x16x3xf32>
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x8x8x3xf32>
+// CHECK: %[[TPOSED_OUT:.*]] = "tfl.transpose"(%[[POOL_OUT]]
+// CHECK-SAME: (tensor<4x8x8x3xf32>, tensor<4xi32>) -> tensor<4x3x8x8xf32>
+// CHECK: return %[[TPOSED_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_channel_first
+func.func @avgpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
+ %0 = mhlo.constant dense<9.000000e+00> : tensor<4x3x7x7xf32>
+ %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+ %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+ %3 = "mhlo.reduce_window"(%2, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+ mhlo.return %6 : tensor<f32>
+ }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x7x7x3xf32>
+ %4 = "mhlo.transpose"(%3) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32>
+ %5 = mhlo.divide %4, %0 : tensor<4x3x7x7xf32>
+ return %5 : tensor<4x3x7x7xf32>
+}
+
+// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK-SAME: (tensor<4x3x16x16xf32>, tensor<4xi32>) -> tensor<4x16x16x3xf32>
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x7x7x3xf32>
+// CHECK: %[[TPOSED_OUT:.*]] = "tfl.transpose"(%[[POOL_OUT]]
+// CHECK-SAME: (tensor<4x7x7x3xf32>, tensor<4xi32>) -> tensor<4x3x7x7xf32>
+// CHECK: return %[[TPOSED_OUT]]
+
+// -----
+
+func.func @avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+ %0 = mhlo.constant dense<0.0> : tensor<f32>
+ %1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
+ %2 = "mhlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%5) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<0> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+ %3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32>
+ func.return %3 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_broadcasted_divisor
+func.func @avgpool_valid_broadcasted_divisor(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+ %0 = mhlo.constant dense<0.0> : tensor<f32>
+ %1 = mhlo.constant dense<9.0> : tensor<f32>
+ %2 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<4x7x7x8xf32>
+ %3 = "mhlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%5) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<0> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+ %4 = mhlo.divide %3, %2 : tensor<4x7x7x8xf32>
+ func.return %4 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_rw
+func.func @avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+ %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+ %1 = mhlo.constant dense<0.0> : tensor<f32>
+ %2 = "mhlo.reduce_window"(%arg0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+ %3 = "mhlo.reduce_window"(%0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+ %4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32>
+ func.return %4 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_rw_broadcasted_const_lhs
+func.func @avgpool_valid_rw_broadcasted_const_lhs(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+ %0 = mhlo.constant dense<1.0> : tensor<f32>
+ %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<4x16x16x8xf32>
+ %2 = mhlo.constant dense<0.0> : tensor<f32>
+ %3 = "mhlo.reduce_window"(%arg0, %2) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+ %4 = "mhlo.reduce_window"(%1, %2) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+ %5 = mhlo.divide %3, %4 : tensor<4x7x7x8xf32>
+ func.return %5 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_same
+func.func @avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+ %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+ %1 = mhlo.constant dense<0.0> : tensor<f32>
+ %2 = "mhlo.reduce_window"(%arg0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+ %3 = "mhlo.reduce_window"(%0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+ %4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32>
+ func.return %4 : tensor<4x8x8x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_reshape_broadcast
+func.func @avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+ %0 = mhlo.constant dense<1.000000e+00> : tensor<1x16x16x1xf32>
+ %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+ %2 = "mhlo.reduce_window"(%arg0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %7 = mhlo.add %arg1, %arg2 : tensor<f32>
+ mhlo.return %7 : tensor<f32>
+ }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+ %3 = "mhlo.reduce_window"(%0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %7 = mhlo.add %arg1, %arg2 : tensor<f32>
+ mhlo.return %7 : tensor<f32>
+ }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x16x16x1xf32>, tensor<f32>) -> tensor<1x8x8x1xf32>
+ %4 = mhlo.reshape %3 : (tensor<1x8x8x1xf32>) -> tensor<8x8xf32>
+ %5 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<8x8xf32>) -> tensor<4x8x8x8xf32>
+ %6 = mhlo.divide %2, %5 : tensor<4x8x8x8xf32>
+ return %6 : tensor<4x8x8x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+//===------------------------------------------------------------------------===
+// mhlo.reduce_window -> max pool
+//===------------------------------------------------------------------------===
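+// A max mhlo.reduce_window initialized with -INF lowers to tfl.max_pool_2d; as with
+// avg pool, channel-first inputs are transposed to NHWC and back.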
+
+// CHECK-LABEL: maxpool_same
+func.func @maxpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+ // "0xFF800000" represents -INF for f32.
+ %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+ %1 = "mhlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+ func.return %1 : tensor<4x8x8x8xf32>
+}
+
+// CHECK: %1 = "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+
+// -----
+
+// CHECK-LABEL: maxpool_valid
+func.func @maxpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+ // "0xFF800000" represents -INF for f32.
+ %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+ %1 = "mhlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %6 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%6) : (tensor<f32>) -> ()
+ }) {
+ base_dilations = dense<1> : tensor<4xi64>,
+ padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+ window_dilations = dense<1> : tensor<4xi64>,
+ window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+ func.return %1 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %1 = "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+
+// -----
+
+// CHECK-LABEL: maxpool_valid_channel_first
+func.func @maxpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
+ // "0xFF800000" represents -INF for f32.
+ %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+ %1 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+ %2 = "mhlo.reduce_window"(%1, %0) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %4 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+ mhlo.return %4 : tensor<f32>
+ }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x7x7x3xf32>
+ %3 = "mhlo.transpose"(%2) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32>
+ return %3 : tensor<4x3x7x7xf32>
+}
+
+// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK: "tfl.max_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x7x7x3xf32>
+// CHECK: return
+// CHECK-SAME: tensor<4x3x7x7xf32>
+
+// -----
+
+// CHECK-LABEL: maxpool_same_channel_first
+func.func @maxpool_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> {
+ // "0xFF800000" represents -INF for f32.
+ %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+ %1 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+ %2 = "mhlo.reduce_window"(%1, %0) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %4 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+ mhlo.return %4 : tensor<f32>
+ }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x8x8x3xf32>
+ %3 = "mhlo.transpose"(%2) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+ return %3 : tensor<4x3x8x8xf32>
+}
+
+// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK: "tfl.max_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x8x8x3xf32>
+// CHECK: return
+// CHECK-SAME: tensor<4x3x8x8xf32>
+
+// -----
+
+//===------------------------------------------------------------------------===
+// mhlo.reduce_window -> tfl.cumsum
+//===------------------------------------------------------------------------===
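+// A sum mhlo.reduce_window whose window covers an entire dimension, with
+// window-size-minus-one padding on the leading side, lowers to tfl.cumsum along
+// that axis.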
+
+// CHECK-LABEL: reduce_window_sum
+func.func @reduce_window_sum(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
+ %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+ %1 = "mhlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %2 = mhlo.add %arg1, %arg2 : tensor<f32>
+ "mhlo.return"(%2) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[3, 0], [0, 0]]> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[4, 1]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
+ func.return %1 : tensor<4x12xf32>
+}
+
+// CHECK: %[[AXIS:.*]] = arith.constant dense<0> : tensor<i32>
+// CHECK: "tfl.cumsum"(%arg0, %[[AXIS]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i32>) -> tensor<4x12xf32>
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.slice
+//===----------------------------------------------------------------------===//
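+// Static mhlo.slice lowers to tfl.strided_slice, with the start/limit/stride
+// constants cast to i32.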
+
+// CHECK-LABEL: slice
+func.func @slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
+ %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x4672xf32>) -> tensor<1x519xf32>
+ func.return %0 : tensor<1x519xf32>
+}
+
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 4153]> : tensor<2xi64>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<[1, 4672]> : tensor<2xi64>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<1> : tensor<2xi64>
+// CHECK: %[[VAL_0:.*]] = "tfl.cast"(%[[CST]]) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %[[VAL_1:.*]] = "tfl.cast"(%[[CST_0]]) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %[[VAL_2:.*]] = "tfl.cast"(%[[CST_1]]) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %[[VAL_3:.*]] = "tfl.strided_slice"(%arg0, %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x4672xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x519xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.sort
+//===----------------------------------------------------------------------===//
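+// A two-operand descending mhlo.sort whose index operand is an iota (materialized
+// directly, via broadcast, or as a constant) is recognized as a full-width
+// tfl.topk_v2.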
+
+// CHECK-LABEL: sort_to_topk_iota_broadcast
+func.func @sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
+ %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<6xi32>
+ %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32>
+ %2:2 = "mhlo.sort"(%arg0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
+ %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%3) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+ func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32>
+}
+
+// CHECK: %cst = arith.constant dense<6> : tensor<i32>
+// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+
+// -----
+
+// CHECK-LABEL: sort_to_topk_iota_cst_broadcast
+func.func @sort_to_topk_iota_cst_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
+ %0 = mhlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>
+ %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32>
+ %2:2 = "mhlo.sort"(%arg0, %1) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
+ %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%3) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+ func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32>
+}
+
+// CHECK: %cst = arith.constant dense<6> : tensor<i32>
+// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+
+// -----
+
+// CHECK-LABEL: sort_to_topk_const
+func.func @sort_to_topk_const(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
+ %0 = mhlo.constant dense<[[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]> : tensor<3x6xi32>
+ %1:2 = "mhlo.sort"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
+ %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%3) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+ func.return %1#0, %1#1 : tensor<3x6xf32>, tensor<3x6xi32>
+}
+
+// CHECK: %cst = arith.constant dense<6> : tensor<i32>
+// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.dynamic_slice
+//===----------------------------------------------------------------------===//
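+// mhlo.dynamic_slice lowers to tfl.slice: each start index is clamped to
+// [0, dim - slice_size] with tfl.maximum/tfl.minimum, and the clamped indices are
+// packed into the begin operand.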
+
+// CHECK-LABEL: dynamic_slice
+func.func @dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<4x2xf32> {
+ %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+ func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor<i32>
+// CHECK-DAG: %[[CST_IS_3:.*]] = arith.constant dense<3> : tensor<i32>
+// CHECK: %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_3]], %[[MAX_1]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor<i32>
+// CHECK: %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[SLICE_SIZE:.*]] = arith.constant dense<[4, 2]> : tensor<2xi64>
+// CHECK: "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_i64
+func.func @dynamic_slice_i64(%arg0: tensor<7x3xf32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<4x2xf32> {
+ %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
+ func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor<i64>
+// CHECK-DAG: %[[CST_IS_3:.*]] = arith.constant dense<3> : tensor<i64>
+// CHECK: %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK: %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_3]], %[[MAX_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK: %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor<i64>
+// CHECK: %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK: %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK: %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<i64>, tensor<i64>) -> tensor<2xi64>
+// CHECK: %[[SLICE_SIZE:.*]] = arith.constant dense<[4, 2]> : tensor<2xi64>
+// CHECK: "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_splat_sizes
+func.func @dynamic_slice_splat_sizes(%arg0: tensor<7x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<2x2xf32> {
+ %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<2> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<2x2xf32>
+ func.return %0 : tensor<2x2xf32>
+}
+
+// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor<i32>
+// CHECK-DAG: %[[CST_IS_5:.*]] = arith.constant dense<5> : tensor<i32>
+// CHECK: %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_5]], %[[MAX_1]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor<i32>
+// CHECK: %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[SLICE_SIZE:.*]] = arith.constant dense<2> : tensor<2xi64>
+// CHECK: "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<2x2xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// rounding
+//===----------------------------------------------------------------------===//
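+// The floor/compare/select expansions below are matched back to single TFLite ops:
+// the round-to-nearest expansion becomes tfl.round, and the floor-mod / floor-div
+// expansions become tfl.floor_mod and tfl.floor_div.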
+
+// CHECK-LABEL: round
+func.func @round(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
+ %0 = mhlo.constant dense<2.000000e+00> : tensor<8x128xf32>
+ %1 = mhlo.constant dense<5.000000e-01> : tensor<8x128xf32>
+ %2 = mhlo.constant dense<1.000000e+00> : tensor<8x128xf32>
+ %3 = "mhlo.floor"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+ %4 = mhlo.subtract %arg0, %3 : tensor<8x128xf32>
+ %5 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1>
+ %6 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1>
+ %7 = mhlo.multiply %arg0, %1 : tensor<8x128xf32>
+ %8 = "mhlo.floor"(%7) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+ %9 = mhlo.multiply %8, %0 : tensor<8x128xf32>
+ %10 = mhlo.subtract %3, %9 : tensor<8x128xf32>
+ %11 = "mhlo.compare"(%10, %2) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1>
+ %12 = mhlo.and %6, %11 : tensor<8x128xi1>
+ %13 = mhlo.or %5, %12 : tensor<8x128xi1>
+ %14 = mhlo.add %3, %2 : tensor<8x128xf32>
+ %15 = "mhlo.select"(%13, %14, %3) : (tensor<8x128xi1>, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32>
+ func.return %15 : tensor<8x128xf32>
+}
+
+// CHECK: "tfl.round"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_float
+func.func @floor_mod_float(%arg0: tensor<192x8xf32>, %arg1: tensor<192x8xf32>) -> tensor<192x8xf32> {
+ %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32>
+ %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xf32>
+ %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+ %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+ %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1>
+ %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+ %6 = mhlo.and %4, %5 : tensor<192x8xi1>
+ %7 = mhlo.add %1, %arg1 : tensor<192x8xf32>
+ %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+ func.return %8 : tensor<192x8xf32>
+}
+
+// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_int
+func.func @floor_mod_int(%arg0: tensor<192x8xi32>, %arg1: tensor<192x8xi32>) -> tensor<192x8xi32> {
+ %0 = mhlo.constant dense<0> : tensor<192x8xi32>
+ %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xi32>
+ %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+ %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+ %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1>
+ %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+ %6 = mhlo.and %4, %5 : tensor<192x8xi1>
+ %7 = mhlo.add %1, %arg1 : tensor<192x8xi32>
+ %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+ func.return %8 : tensor<192x8xi32>
+}
+
+// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_float_cst
+func.func @floor_mod_float_cst(%arg0: tensor<192x8xf32>) -> tensor<192x8xf32> {
+ %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32>
+ %1 = mhlo.constant dense<2.000000e+00> : tensor<192x8xf32>
+ %2 = mhlo.remainder %arg0, %1 : tensor<192x8xf32>
+ %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+ %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+ %5 = mhlo.and %3, %4 : tensor<192x8xi1>
+ %6 = mhlo.add %2, %1 : tensor<192x8xf32>
+ %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+ func.return %7 : tensor<192x8xf32>
+}
+
+// CHECK: %cst = arith.constant dense<2.000000e+00> : tensor<192x8xf32>
+// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_int_cst
+func.func @floor_mod_int_cst(%arg0: tensor<192x8xi32>) -> tensor<192x8xi32> {
+ %0 = mhlo.constant dense<0> : tensor<192x8xi32>
+ %1 = mhlo.constant dense<2> : tensor<192x8xi32>
+ %2 = mhlo.remainder %arg0, %1 : tensor<192x8xi32>
+ %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+ %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+ %5 = mhlo.and %3, %4 : tensor<192x8xi1>
+ %6 = mhlo.add %2, %1 : tensor<192x8xi32>
+ %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+ func.return %7 : tensor<192x8xi32>
+}
+
+// CHECK: %cst = arith.constant dense<2> : tensor<192x8xi32>
+// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+
+// -----
+
+// CHECK-LABEL: floor_div
+func.func @floor_div(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>) -> tensor<10x10xf32> {
+ %0 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+ %1 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32>
+ %2 = mhlo.remainder %arg0, %arg1 : tensor<10x10xf32>
+ %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+ %4 = "mhlo.sign"(%arg1) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ %5 = "mhlo.sign"(%2) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ %6 = "mhlo.compare"(%4, %5) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+ %7 = mhlo.and %3, %6 : tensor<10x10xi1>
+ %8 = mhlo.subtract %arg0, %2 : tensor<10x10xf32>
+ %9 = mhlo.divide %8, %arg1 : tensor<10x10xf32>
+ %10 = mhlo.add %9, %1 : tensor<10x10xf32>
+ %11 = "mhlo.select"(%7, %10, %9) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+ %12 = "mhlo.round_nearest_afz"(%11) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ %13 = "mhlo.tuple"(%12) : (tensor<10x10xf32>) -> tuple<tensor<10x10xf32>>
+ func.return %12 : tensor<10x10xf32>
+}
+
+// CHECK: tfl.floor_div %arg0, %arg1 : tensor<10x10xf32>
+
+// -----
+
+// CHECK-LABEL: floor_div_cst
+func.func @floor_div_cst(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+ %0 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+ %1 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+ %2 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32>
+ %3 = mhlo.constant dense<5.000000e-01> : tensor<10x10xf32>
+ %4 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32>
+ %5 = mhlo.remainder %arg0, %0 : tensor<10x10xf32>
+ %6 = "mhlo.compare"(%5, %1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+ %7 = "mhlo.sign"(%5) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ %8 = "mhlo.compare"(%2, %7) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+ %9 = mhlo.and %6, %8 : tensor<10x10xi1>
+ %10 = mhlo.subtract %arg0, %5 : tensor<10x10xf32>
+ %11 = mhlo.multiply %10, %3 : tensor<10x10xf32>
+ %12 = mhlo.add %11, %4 : tensor<10x10xf32>
+ %13 = "mhlo.select"(%9, %12, %11) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+ %14 = "mhlo.round_nearest_afz"(%13) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ %15 = "mhlo.tuple"(%14) : (tensor<10x10xf32>) -> tuple<tensor<10x10xf32>>
+ func.return %14 : tensor<10x10xf32>
+}
+
+// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32>
+
+// -----
+
+// CHECK-LABEL: floor_div_cst2
+func.func @floor_div_cst2(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+ %0 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32>
+ %1 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+ %2 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+ %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32>
+ %4 = mhlo.remainder %arg0, %1 : tensor<10x10xf32>
+ %5 = "mhlo.compare"(%4, %2) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+ %6 = "mhlo.sign"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ %7 = "mhlo.compare"(%0, %6) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+ %8 = mhlo.and %5, %7 : tensor<10x10xi1>
+ %9 = mhlo.subtract %arg0, %4 : tensor<10x10xf32>
+ %10 = mhlo.divide %9, %1 : tensor<10x10xf32>
+ %11 = mhlo.add %10, %3 : tensor<10x10xf32>
+ %12 = "mhlo.select"(%8, %11, %10) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+ %13 = "mhlo.round_nearest_afz"(%12) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ %14 = "mhlo.tuple"(%13) : (tensor<10x10xf32>) -> tuple<tensor<10x10xf32>>
+ func.return %13 : tensor<10x10xf32>
+}
+
+// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32>
+
+// -----
+
+// CHECK-LABEL: floor_div_broadcast_cst
+func.func @floor_div_broadcast_cst(%arg0: tensor<10x8xf32>) -> tensor<10x8xf32> {
+ %0 = mhlo.constant dense<1.000000e+00> : tensor<10x8xf32>
+ %1 = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00, 1.600000e+01, 3.200000e+01, 6.400000e+01, 1.280000e+02]> : tensor<8xf32>
+ %2 = mhlo.constant dense<0.000000e+00> : tensor<10x8xf32>
+ %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x8xf32>
+ %5 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>) -> tensor<10x8xf32>
+ %6 = mhlo.remainder %arg0, %5 : tensor<10x8xf32>
+ %7 = "mhlo.compare"(%6, %2) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1>
+ %8 = "mhlo.sign"(%6) : (tensor<10x8xf32>) -> tensor<10x8xf32>
+ %9 = "mhlo.compare"(%0, %8) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1>
+ %10 = mhlo.and %7, %9 : tensor<10x8xi1>
+ %11 = mhlo.subtract %arg0, %6 : tensor<10x8xf32>
+ %12 = mhlo.divide %11, %5 : tensor<10x8xf32>
+ %13 = mhlo.add %12, %3 : tensor<10x8xf32>
+ %14 = "mhlo.select"(%10, %13, %12) : (tensor<10x8xi1>, tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xf32>
+ %15 = "mhlo.round_nearest_afz"(%14) : (tensor<10x8xf32>) -> tensor<10x8xf32>
+ %16 = "mhlo.tuple"(%15) : (tensor<10x8xf32>) -> tuple<tensor<10x8xf32>>
+ func.return %15 : tensor<10x8xf32>
+}
+
+// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%1)
+// CHECK: tfl.floor_div %arg0, %[[BCAST]] : tensor<10x8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// unary elementwise
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: convert_i32_f32
+func.func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
+ %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.cast
+
+// -----
+
+// CHECK-LABEL: abs
+func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.abs
+
+// -----
+
+// CHECK-LABEL: abs_dynamic
+func.func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.abs
+
+// -----
+
+// CHECK-LABEL: ceil
+func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.ceil
+
+// -----
+
+// CHECK-LABEL: ceil_dynamic
+func.func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.ceil
+
+// -----
+
+// CHECK-LABEL: complex_abs
+func.func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
+ %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK-NOT: tfl
+
+// -----
+
+func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
+ %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2xf32>
+// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
+// CHECK: return %1 : tensor<2xi1>
+
+// -----
+
+func.func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
+ %0 = "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
+ func.return %0 : tensor<?xi1>
+}
+
+// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<?xf32>
+// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
+
+// -----
+
+// CHECK-LABEL: cos
+func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.cos
+
+// -----
+
+// CHECK-LABEL: cos_dynamic
+func.func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.cos
+
+// -----
+
+// CHECK-LABEL: logistic
+func.func @logistic(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.logistic"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.logistic
+
+// -----
+
+// CHECK-LABEL: exp
+func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.exp
+
+// -----
+
+// CHECK-LABEL: exp_dynamic
+func.func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.exp
+
+// -----
+
+// CHECK-LABEL: expm1
+func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: %0 = "tfl.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %1 = tfl.sub(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
+
+// -----
+
+// CHECK-LABEL: floor
+func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.floor
+
+// -----
+
+// CHECK-LABEL: floor_dynamic
+func.func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.floor
+
+// -----
+
+// CHECK-LABEL: log
+func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.log
+
+// -----
+
+// CHECK-LABEL: log_dynamic
+func.func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.log
+
+// -----
+
+// CHECK-LABEL: log1p
+func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
+// CHECK: %1 = "tfl.log"(%0) : (tensor<2xf32>) -> tensor<2xf32>
+
+// -----
+
+// CHECK-LABEL: log1p_dynamic
+func.func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
+// CHECK: %1 = "tfl.log"(%0) : (tensor<?xf32>) -> tensor<?xf32>
+
+// -----
+
+// CHECK-LABEL: neg
+func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.neg
+
+// -----
+
+// CHECK-LABEL: neg_dynamic
+func.func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.neg
+
+// -----
+
+// CHECK-LABEL: sin
+func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.sin
+
+// -----
+
+// CHECK-LABEL: sin_dynamic
+func.func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.sin
+
+// -----
+
+// CHECK-LABEL: rsqrt
+func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.rsqrt
+
+// -----
+
+// CHECK-LABEL: rsqrt_dynamic
+func.func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.rsqrt
+
+// -----
+
+// CHECK-LABEL: @sqrt
+func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.sqrt
+
+// -----
+
+// CHECK-LABEL: sqrt_dynamic
+func.func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.sqrt
+
+// -----
+
+// CHECK-LABEL: tanh
+func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.tanh
+
+// -----
+
+// CHECK-LABEL: tanh_dynamic
+func.func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.tanh
+
+// -----
+
+// CHECK-LABEL: bitcast
+func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.bitcast
+
+// -----
+
+// CHECK-LABEL: bitcast_dynamic
+func.func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.bitcast
+
+// -----
+
+// CHECK-LABEL: bitcast_same_widths
+func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
+ %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
+ func.return %0 : tensor<2xi32>
+}
+
+// CHECK: tfl.bitcast
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// logical and bitwise ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: not
+func.func @not(%arg0: tensor<5x3x1xi1>) -> tensor<5x3x1xi1> {
+ %0 = "mhlo.not"(%arg0): (tensor<5x3x1xi1>) -> (tensor<5x3x1xi1>)
+ func.return %0 : tensor<5x3x1xi1>
+}
+
+// CHECK: %0 = "tfl.logical_not"(%arg0) : (tensor<5x3x1xi1>) -> tensor<5x3x1xi1>
+
+// -----
+
+// CHECK-LABEL: not_i8
+func.func @not_i8(%arg0: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> {
+ %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi8>) -> (tensor<7x9x11xi8>)
+ func.return %0 : tensor<7x9x11xi8>
+}
+
+// CHECK: %cst = arith.constant dense<-1> : tensor<i8>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi8>, tensor<i8>) -> tensor<7x9x11xi8>
+
+// -----
+
+// CHECK-LABEL: not_i16
+func.func @not_i16(%arg0: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> {
+ %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi16>) -> (tensor<7x9x11xi16>)
+ func.return %0 : tensor<7x9x11xi16>
+}
+
+// CHECK: %cst = arith.constant dense<-1> : tensor<i16>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi16>, tensor<i16>) -> tensor<7x9x11xi16>
+
+// -----
+
+// CHECK-LABEL: not_i32
+func.func @not_i32(%arg0: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> {
+ %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi32>) -> (tensor<7x9x11xi32>)
+ func.return %0 : tensor<7x9x11xi32>
+}
+
+// CHECK: %cst = arith.constant dense<-1> : tensor<i32>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi32>, tensor<i32>) -> tensor<7x9x11xi32>
+
+// -----
+
+// CHECK-LABEL: not_ui8
+func.func @not_ui8(%arg0: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> {
+ %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui8>) -> (tensor<7x9x11xui8>)
+ func.return %0 : tensor<7x9x11xui8>
+}
+
+// CHECK: %cst = arith.constant dense<255> : tensor<ui8>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui8>, tensor<ui8>) -> tensor<7x9x11xui8>
+
+// -----
+
+// CHECK-LABEL: not_ui16
+func.func @not_ui16(%arg0: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> {
+ %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui16>) -> (tensor<7x9x11xui16>)
+ func.return %0 : tensor<7x9x11xui16>
+}
+
+// CHECK: %cst = arith.constant dense<65535> : tensor<ui16>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui16>, tensor<ui16>) -> tensor<7x9x11xui16>
+
+// -----
+
+// CHECK-LABEL: not_ui32
+func.func @not_ui32(%arg0: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> {
+ %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui32>) -> (tensor<7x9x11xui32>)
+ func.return %0 : tensor<7x9x11xui32>
+}
+
+// CHECK: %cst = arith.constant dense<4294967295> : tensor<ui32>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui32>, tensor<ui32>) -> tensor<7x9x11xui32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// binary ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: remainder
+func.func @remainder(%arg0: tensor<10x8xi32>, %arg1: tensor<10x8xi32>) -> tensor<10x8xi32> {
+ %0 = mhlo.remainder %arg0, %arg1 : tensor<10x8xi32>
+ func.return %0 : tensor<10x8xi32>
+}
+
+// CHECK: %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<10x8xi32>, tensor<10x8xi32>) -> tensor<10x8xi32>
+
+// -----
+
+// CHECK-LABEL: shift_right_arith
+func.func @shift_right_arith(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+ %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
+ func.return %0 : tensor<4xi32>
+}
+
+// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+
+// -----
+
+// CHECK-LABEL: shift_right_logical
+func.func @shift_right_logical(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+ %0 = mhlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
+ func.return %0 : tensor<4xi32>
+}
+
+// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.compare
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: greater_unsupported_compare_type
+func.func @greater_unsupported_compare_type(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xi1> {
+ %0 = "mhlo.compare"(%arg0, %arg1) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK-NOT: tfl
+// CHECK: mhlo.compare
+
+// -----
+
+// CHECK-LABEL: equal
+func.func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+ %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.equal
+
+// -----
+
+// CHECK-LABEL: notequal
+func.func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+ %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.not_equal
+
+// -----
+
+// CHECK-LABEL: greater
+func.func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+ %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.greater
+
+// -----
+
+// CHECK-LABEL: greater_equal
+func.func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+ %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.greater_equal
+
+// -----
+
+// CHECK-LABEL: less
+func.func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+ %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.less
+
+// -----
+
+// CHECK-LABEL: less_equal
+func.func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+ %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+ func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.less_equal
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD
index 19d9863..a277c0b 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD
@@ -21,7 +21,6 @@
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
@@ -140,6 +139,7 @@
srcs = ["conv_util.cc"],
hdrs = ["conv_util.h"],
deps = [
+ ":op_util_common",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
@@ -153,6 +153,7 @@
srcs = ["pad.cc"],
hdrs = ["pad.h"],
deps = [
+ ":op_util_common",
":pad_util",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@llvm-project//llvm:Support",
@@ -169,6 +170,7 @@
srcs = ["pad_util.cc"],
hdrs = ["pad_util.h"],
deps = [
+ ":op_util_common",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
@@ -191,3 +193,77 @@
"@local_xla//xla/mlir_hlo",
],
)
+
+cc_library(
+ name = "reduce_window",
+ srcs = ["reduce_window.cc"],
+ hdrs = ["reduce_window.h"],
+ deps = [
+ ":op_util_common",
+ ":reduce_window_util",
+ ":util",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ "@local_xla//xla/mlir_hlo",
+ ],
+)
+
+cc_library(
+ name = "op_util_common",
+ srcs = ["op_util_common.cc"],
+ hdrs = ["op_util_common.h"],
+ deps = [
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ ],
+)
+
+cc_library(
+ name = "reduce_window_util",
+ srcs = ["reduce_window_util.cc"],
+ hdrs = ["reduce_window_util.h"],
+ deps = [
+ ":op_util_common",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ "@local_xla//xla/mlir_hlo",
+ ],
+)
+
+cc_library(
+ name = "slice",
+ srcs = ["slice.cc"],
+ hdrs = ["slice.h"],
+ deps = [
+ ":op_util_common",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ "@local_xla//xla/mlir_hlo",
+ ],
+)
+
+cc_library(
+ name = "sort",
+ srcs = ["sort.cc"],
+ hdrs = ["sort.h"],
+ deps = [
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "//tensorflow/compiler/mlir/lite/stablehlo:hlo_matchers",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ "@local_xla//xla/mlir_hlo",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc
index 33b62b5..87a429d 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc
@@ -44,27 +44,27 @@
return llvm::all_of(shape, [](int64_t d) { return d >= 0; });
}
-bool AreShapesSupported(const ConvData& data) {
+bool AreShapesSupported(const ConvView& data) {
return IsShapeFullyStatic(data.InputShape()) &&
IsShapeFullyStatic(data.KernelShape()) &&
IsShapeFullyStatic(data.OutputShape());
}
-bool IsPaddingSupported(const ConvData& data) {
+bool IsPaddingSupported(const ConvView& data) {
return llvm::all_of(data.Padding(), [](const DimPadding& p) {
return p.Hi() == 0 && p.Lo() == 0;
});
}
-bool IsInputDilationSupported(const ConvData& data) {
+bool IsInputDilationSupported(const ConvView& data) {
return llvm::all_of(data.InputDilations(), [](int64_t v) { return v == 1; });
}
-bool IsBatchGroupSupported(const ConvData& data) {
+bool IsBatchGroupSupported(const ConvView& data) {
return data.BatchGroupCount() == 1;
}
-bool IsWindowReversalSupported(const ConvData& data) {
+bool IsWindowReversalSupported(const ConvView& data) {
return llvm::all_of(data.WindowReversal(), [](bool b) { return !b; });
}
@@ -72,7 +72,7 @@
// Used externally to setup a ConversionTarget with dynamically legal
// mhlo.convolution. Doubles as matching predicate during legalization.
bool IsConvLegal(mhlo::ConvolutionOp op) {
- const ConvData data(op);
+ const ConvView data(op);
const bool supported_conv_type =
IsStandardConv(data) || IsDepthwiseConv(data);
@@ -89,7 +89,7 @@
// Bias is a zero tensor of shape [output_channels].
arith::ConstantOp BuildEmptyBias(OpBuilder& b, Location loc,
- const ConvData& data) {
+ const ConvView& data) {
auto bias_type = RankedTensorType::get(
{data.OutputLayout().SpecialDim2(data.OutputShape())},
data.ElementType());
@@ -109,7 +109,7 @@
mhlo::ConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
// Parse mhlo.convolution attrs into cc types.
- const ConvData data(op);
+ const ConvView data(op);
if (IsConvLegal(op) || !IsStandardConv(data) ||
data.InputLayout().Rank() != 4) {
@@ -168,7 +168,7 @@
mhlo::ConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
// Parse mhlo.convolution attrs into cc types.
- const ConvData data(op);
+ const ConvView data(op);
if (IsConvLegal(op) || !IsDepthwiseConv(data)) {
return failure();
@@ -236,7 +236,7 @@
mhlo::ConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
// Parse mhlo.convolution attrs into cc types.
- const ConvData data(op);
+ const ConvView data(op);
if (IsConvLegal(op) || !IsStandardConv(data) ||
data.InputLayout().Rank() != 5) {
@@ -324,7 +324,7 @@
LogicalResult Conv1DToConv2D::matchAndRewrite(mhlo::ConvolutionOp op,
PatternRewriter& rewriter) const {
- const ConvData view(op);
+ const ConvView view(op);
if (view.InputLayout().Rank() != 3) {
return rewriter.notifyMatchFailure(op, "Not 1D conv.");
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc
index 7d8c85a..e2dff5f 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc
@@ -18,92 +18,16 @@
#include <cstdint>
#include <optional>
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
namespace mlir::odml {
-llvm::SmallVector<int64_t, 4> Layout::GetPermForReLayout(
- const Layout& to_layout) const {
- llvm::SmallVector<int64_t, 4> perm(to_layout.Rank());
- perm[to_layout.SpecialDim1()] = SpecialDim1();
- perm[to_layout.SpecialDim2()] = SpecialDim2();
- for (const auto [to_spatial, from_spatial] :
- llvm::zip(to_layout.Spatials(), Spatials())) {
- perm[to_spatial] = from_spatial;
- }
- return perm;
-}
-
-llvm::SmallVector<int64_t, 4> Layout::PermuteShape(
- const Layout& to_layout, llvm::ArrayRef<int64_t> shape) const {
- llvm::SmallVector<int64_t, 4> new_shape(to_layout.Rank());
- const auto perm = GetPermForReLayout(to_layout);
- for (const auto [ind, val] : llvm::enumerate(perm)) {
- new_shape[ind] = shape[val];
- }
- return new_shape;
-}
-
-bool Layout::HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const {
- return SpecialDim1() == special_dim1 && SpecialDim2() == special_dim2;
-}
-
-bool Layout::AreSpatialsIota() const {
- llvm::ArrayRef<int64_t> spatials = Spatials();
- return llvm::all_of(llvm::enumerate(spatials), [&](const auto& it) {
- return it.index() == 0 || (it.value() == spatials[it.index() - 1] + 1);
- });
-}
-
-llvm::SmallVector<int64_t, 2> ResolveStridesOrDilations(
- const int64_t num_spatials,
- std::optional<mlir::DenseIntElementsAttr> opt_attr) {
- if (!opt_attr.has_value()) {
- return llvm::SmallVector<int64_t, 2>(num_spatials, 1);
- }
- auto attr = opt_attr.value();
- if (attr.isSplat()) {
- return llvm::SmallVector<int64_t, 2>(num_spatials,
- attr.getSplatValue<int64_t>());
- }
- return llvm::SmallVector<int64_t, 2>(attr.getValues<int64_t>());
-}
-
-llvm::SmallVector<DimPadding, 2> ResolvePadding(
- const int64_t num_spatials,
- std::optional<mlir::DenseIntElementsAttr> opt_padding) {
- llvm::SmallVector<DimPadding, 2> res;
- if (!opt_padding.has_value()) {
- for (int i = 0; i < num_spatials; ++i) {
- res.push_back(DimPadding(0, 0));
- }
- return res;
- }
- auto padding = opt_padding.value();
- if (padding.isSplat()) {
- const int64_t val = padding.getSplatValue<int64_t>();
- for (int i = 0; i < num_spatials; ++i) {
- res.push_back(DimPadding(val, val));
- }
- return res;
- }
- int64_t prev;
- for (const auto [ind, val] : llvm::enumerate(padding.getValues<int64_t>())) {
- const int64_t side = ind % 2;
- if (side == 1) {
- res.push_back(DimPadding(prev, val));
- }
- prev = val;
- }
- return res;
-}
-
llvm::SmallVector<bool, 2> ResolveWindowReversal(
const int64_t num_spatials,
std::optional<mlir::DenseElementsAttr> opt_reversals) {
@@ -118,7 +42,7 @@
return llvm::SmallVector<bool, 2>(reversals.getValues<bool>());
}
-ConvData::ConvData(mhlo::ConvolutionOp op)
+ConvView::ConvView(mhlo::ConvolutionOp op)
: input_layout_(
Layout{op.getDimensionNumbers().getInputBatchDimension(),
op.getDimensionNumbers().getInputFeatureDimension(),
@@ -156,7 +80,7 @@
}
Value CreatePadOpFromConvPadding(OpBuilder& b, mhlo::ConvolutionOp op) {
- const ConvData data(op);
+ const ConvView data(op);
const auto rank = data.InputLayout().Rank();
auto input_spatials = data.InputLayout().Spatials();
@@ -185,8 +109,8 @@
b.create<arith::ConstantOp>(op->getLoc(), padding_value_attr);
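+ // Note: mhlo.pad takes edge_padding_low before edge_padding_high (and then
+ // interior_padding), hence the argument order below.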
auto pad_op = b.create<mhlo::PadOp>(padding_value_op->getLoc(), op.getLhs(),
- padding_value_op, hi_padding_attr,
- lo_padding_attr, interior_padding_attr);
+ padding_value_op, lo_padding_attr,
+ hi_padding_attr, interior_padding_attr);
return pad_op;
}
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h
index 63ebbce..d20ad08 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
// Helpers for working with mhlo.convolution attrs in the mlir api as
@@ -26,91 +27,7 @@
namespace mlir::odml {
-// Generic class that wraps the "layout" of a convolution parameter.
-// Both kernel (e.g. [o, 0, 1, i]) and input/output (e.g. [b, 0, 1, f])
-// share the same structure just with different terminology for the
-// batch/feature/input_feature/output_feature dims.
-class Layout {
- public:
- llvm::ArrayRef<int64_t> Spatials() const { return spatials_; }
-
- int64_t NumSpatials() const { return spatials_.size(); }
-
- int64_t Rank() const { return NumSpatials() + 2; }
-
- Layout(int64_t special_dim1, int64_t special_dim2, ArrayRef<int64_t> spatials)
- : special_dim1_(special_dim1),
- special_dim2_(special_dim2),
- spatials_(spatials) {}
-
- // Gets index of first special dim. The batch dim for input and outputs,
- // or the output feature dim for the kernel.
- int64_t SpecialDim1() const { return special_dim1_; }
-
- // Conveniance accesor for getting the dimension size of the first
- // special dimension from a shape.
- int64_t SpecialDim1(llvm::ArrayRef<int64_t> shape) const {
- return shape[special_dim1_];
- }
-
- // Gets index of second special dim. The feature dim for input and outputs,
- // or the input feature dim for the kernel.
- int64_t SpecialDim2() const { return special_dim2_; }
-
- // Convenience accesor for getting the dimension size of the second
- // special dimension from a shape.
- int64_t SpecialDim2(llvm::ArrayRef<int64_t> shape) const {
- return shape[special_dim2_];
- }
-
- // Conveniance method for equality checking special dims.
- bool HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const;
-
- // Determines if the spatial dimensions are all adjacent and in
- // ascending order (HWD).
- bool AreSpatialsIota() const;
-
- // Gets a "permutation array" to be used for transposing a tensor
- // of "this" layout to the given layout. A permutation array is some
- // permutation of [0, 1, i...] for i < rank(layout). Assumes
- // "this" and given layout have the same rank.
- llvm::SmallVector<int64_t, 4> GetPermForReLayout(
- const Layout& to_layout) const;
-
- // Permutes given shape based on the permutaion implied to take this Layout to
- // the given one.
- llvm::SmallVector<int64_t, 4> PermuteShape(const Layout& to_layout,
- ArrayRef<int64_t> shape) const;
-
- bool operator==(const Layout& other) const {
- return SpecialDim1() == other.SpecialDim1() &&
- SpecialDim2() == other.SpecialDim2() &&
- Spatials() == other.Spatials();
- }
-
- bool operator!=(const Layout& other) const { return !(*this == other); }
-
- private:
- int64_t special_dim1_;
- int64_t special_dim2_;
- llvm::SmallVector<int64_t> spatials_;
-};
-
-// Wrapper for the padding attrs along a single dimension.
-class DimPadding {
- public:
- int64_t Hi() const { return hi_; }
-
- int64_t Lo() const { return lo_; }
-
- DimPadding(int64_t hi, int64_t lo) : hi_(hi), lo_(lo) {}
-
- private:
- int64_t hi_;
- int64_t lo_;
-};
-
-class ConvData {
+class ConvView {
public:
// int for each spatial dim. Default 1.
llvm::ArrayRef<int64_t> Strides() const { return strides_; }
@@ -145,7 +62,7 @@
mlir::Type ElementType() const { return element_type_; }
- explicit ConvData(mhlo::ConvolutionOp op);
+ explicit ConvView(mhlo::ConvolutionOp op);
private:
llvm::SmallVector<int64_t, 2> strides_;
@@ -171,11 +88,11 @@
mlir::Type element_type_;
};
-inline bool HasSupportedRank(const ConvData& data) {
+inline bool HasSupportedRank(const ConvView& data) {
return data.InputLayout().Rank() == 4 || data.InputLayout().Rank() == 5;
}
-inline bool HasSupportedOutFeatureDims(const ConvData& data) {
+inline bool HasSupportedOutFeatureDims(const ConvView& data) {
const int64_t kernel_out_features =
data.KernelLayout().SpecialDim2(data.KernelShape());
const int64_t out_features =
@@ -183,7 +100,7 @@
return kernel_out_features == out_features;
}
-inline bool IsNonTrivialConv(const ConvData& data) {
+inline bool IsNonTrivialConv(const ConvView& data) {
return llvm::all_of(data.InputDilations(), [](auto d) { return d == 1; });
}
@@ -191,7 +108,7 @@
// Standard conv predicates
//=-----
-inline bool HasStandardConvInFeatureDims(const ConvData& data) {
+inline bool HasStandardConvInFeatureDims(const ConvView& data) {
// kernel_in_features * feature_groups = input_features by definition.
const int64_t input_features =
data.InputLayout().SpecialDim2(data.InputShape());
@@ -204,7 +121,7 @@
return !trivial_kernel_in_features && (!is_grouped_conv || rank == 4);
}
-inline bool IsStandardConv(const ConvData& data) {
+inline bool IsStandardConv(const ConvView& data) {
return HasSupportedRank(data) && IsNonTrivialConv(data) &&
HasStandardConvInFeatureDims(data) && HasSupportedOutFeatureDims(data);
}
@@ -212,7 +129,7 @@
// Does this convolution map to a standard conv_2d or conv_3d
// (not depthwise or tranpose conv)?
inline bool IsStandardConv(mhlo::ConvolutionOp op) {
- const ConvData data(op);
+ const ConvView data(op);
return IsStandardConv(data);
}
@@ -220,7 +137,7 @@
// Depthwise conv predicates
//=-----
-inline bool IsDepthwiseConv(const ConvData& data) {
+inline bool IsDepthwiseConv(const ConvView& data) {
const bool valid_rank = data.InputLayout().Rank() == 4;
if (!valid_rank || !HasSupportedOutFeatureDims(data) ||
!IsNonTrivialConv(data)) {
@@ -233,7 +150,7 @@
// Does this convolution map to depthwise conv?
inline bool IsDepthwiseConv(mhlo::ConvolutionOp op) {
- const ConvData data(op);
+ const ConvView data(op);
return IsDepthwiseConv(data);
}
@@ -273,7 +190,7 @@
return GetTFLNativeStandardConvKernelLayout(DnumRank(dnums));
}
-inline bool IsTFLNativeLayout(const ConvData& data) {
+inline bool IsTFLNativeLayout(const ConvView& data) {
const int64_t rank = data.KernelLayout().Rank();
const auto native_io_layout = GetTFLNativeInputOrOutputLayout(rank);
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc
new file mode 100644
index 0000000..3d67bbf
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc
@@ -0,0 +1,111 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+
+#include <cstdint>
+#include <optional>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+
+namespace mlir::odml {
+
+llvm::SmallVector<int64_t, 4> Layout::GetPermForReLayout(
+ const Layout& to_layout) const {
+ llvm::SmallVector<int64_t, 4> perm(to_layout.Rank());
+ perm[to_layout.SpecialDim1()] = SpecialDim1();
+ perm[to_layout.SpecialDim2()] = SpecialDim2();
+ for (const auto [to_spatial, from_spatial] :
+ llvm::zip(to_layout.Spatials(), Spatials())) {
+ perm[to_spatial] = from_spatial;
+ }
+ return perm;
+}
+
+llvm::SmallVector<int64_t, 4> Layout::PermuteShape(
+ const Layout& to_layout, llvm::ArrayRef<int64_t> shape) const {
+ llvm::SmallVector<int64_t, 4> new_shape(to_layout.Rank());
+ const auto perm = GetPermForReLayout(to_layout);
+ for (const auto [ind, val] : llvm::enumerate(perm)) {
+ new_shape[ind] = shape[val];
+ }
+ return new_shape;
+}
+
+bool Layout::HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const {
+ return SpecialDim1() == special_dim1 && SpecialDim2() == special_dim2;
+}
+
+bool Layout::AreSpatialsIota() const {
+ llvm::ArrayRef<int64_t> spatials = Spatials();
+ return llvm::all_of(llvm::enumerate(spatials), [&](const auto& it) {
+ return it.index() == 0 || (it.value() == spatials[it.index() - 1] + 1);
+ });
+}
+
+llvm::SmallVector<int64_t, 4> ResolveStridesOrDilations(
+ int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_attr) {
+ if (!opt_attr.has_value()) {
+ return llvm::SmallVector<int64_t, 4>(rank, 1);
+ }
+ auto attr = opt_attr.value();
+ if (attr.isSplat()) {
+ return llvm::SmallVector<int64_t, 4>(rank, attr.getSplatValue<int64_t>());
+ }
+ return llvm::SmallVector<int64_t, 4>(attr.getValues<int64_t>());
+}
+
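+// The padding attr is either a splat or a flat interleaved list
+// [lo0, hi0, lo1, hi1, ...]; either way, emit one DimPadding per dim.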
+llvm::SmallVector<DimPadding, 2> ResolvePadding(
+ int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_padding) {
+ llvm::SmallVector<DimPadding, 4> res;
+ if (!opt_padding.has_value()) {
+ for (int i = 0; i < rank; ++i) {
+ res.push_back(DimPadding(0, 0));
+ }
+ return res;
+ }
+ auto padding = opt_padding.value();
+ if (padding.isSplat()) {
+ const int64_t val = padding.getSplatValue<int64_t>();
+ for (int i = 0; i < rank; ++i) {
+ res.push_back(DimPadding(val, val));
+ }
+ return res;
+ }
+ int64_t prev;
+ for (const auto [ind, val] : llvm::enumerate(padding.getValues<int64_t>())) {
+ const int64_t side = ind % 2;
+ if (side == 1) {
+ res.push_back(DimPadding(prev, val));
+ }
+ prev = val;
+ }
+ return res;
+}
+
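+// "SAME" padding keeps ceil(in / stride) outputs: the lo/hi pads must make
+// the dilated window tile the input exactly, with at most one extra pad
+// element, placed on the high side.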
+bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k,
+ const DimPadding& pad) {
+ const int64_t pad_diff = pad.Hi() - pad.Lo();
+ if (pad_diff > 1 || pad_diff < 0) {
+ return false;
+ }
+ const int64_t pad_total = pad.Lo() + pad.Hi();
+ const int64_t out = (in + stride - 1) / stride;
+ const int effective_filter = (k - 1) * dilate + 1;
+ return ((out - 1) * stride + effective_filter) == in + pad_total;
+}
+
+} // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h
new file mode 100644
index 0000000..e3f8941
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h
@@ -0,0 +1,139 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_
+
+#include <cstdint>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+
+namespace mlir::odml {
+
+// Class that encodes the "layout" of a tensor. Generically, a layout is a
+// naming of the dimensions of a tensor. In all cases, 2 dimensions
+// are "special" (e.g. batch / feature) and the rest are referred to as "spatial
+// dims". When the special dims are batch and feature, batch is special dim 1
+// and feature is special dim 2. When special dims are input and output features
+// (conv filter), input features is special dim 1 and output features is special
+// dim 2.
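+//
+// For example, an NHWC activation tensor corresponds to
+// Layout(/*special_dim1=*/0, /*special_dim2=*/3, /*spatials=*/{1, 2}), while
+// an NCHW tensor corresponds to Layout(0, 1, {2, 3}).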
+class Layout {
+ public:
+ llvm::ArrayRef<int64_t> Spatials() const { return spatials_; }
+
+ int64_t NumSpatials() const { return spatials_.size(); }
+
+ int64_t Rank() const { return NumSpatials() + 2; }
+
+ Layout(int64_t special_dim1, int64_t special_dim2, ArrayRef<int64_t> spatials)
+ : special_dim1_(special_dim1),
+ special_dim2_(special_dim2),
+ spatials_(spatials) {}
+
+ // TODO: b/351437662 - Consider just using 2 arrays for the case where
+ // there are more than 2 special dims.
+ int64_t SpecialDim1() const { return special_dim1_; }
+
+ // Convenience accessor for getting the dimension size of the first
+ // special dimension from a shape.
+ int64_t SpecialDim1(llvm::ArrayRef<int64_t> shape) const {
+ return shape[special_dim1_];
+ }
+
+ int64_t SpecialDim2() const { return special_dim2_; }
+
+ // Convenience accessor for getting the dimension size of the second
+ // special dimension from a shape.
+ int64_t SpecialDim2(llvm::ArrayRef<int64_t> shape) const {
+ return shape[special_dim2_];
+ }
+
+ // Convenience method for checking equality of the special dims.
+ bool HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const;
+
+ // Determines if the spatial dimensions are all adjacent and in
+ // ascending order.
+ bool AreSpatialsIota() const;
+
+ // Gets a "permutation array" to be used for transposing a tensor
+ // of "this" layout to the given layout. A permutation array is some
+ // permutation of [0, 1, i...] for i < rank(layout). Assumes
+ // "this" and given layout have the same rank.
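+ // For example, relayouting NCHW (Layout(0, 1, {2, 3})) to NHWC
+ // (Layout(0, 3, {1, 2})) yields the permutation [0, 2, 3, 1].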
+ llvm::SmallVector<int64_t, 4> GetPermForReLayout(
+ const Layout& to_layout) const;
+
+ // Permutes the given shape based on the permutation implied to take this
+ // Layout to the given one.
+ llvm::SmallVector<int64_t, 4> PermuteShape(const Layout& to_layout,
+ ArrayRef<int64_t> shape) const;
+
+ bool operator==(const Layout& other) const {
+ return SpecialDim1() == other.SpecialDim1() &&
+ SpecialDim2() == other.SpecialDim2() &&
+ Spatials() == other.Spatials();
+ }
+
+ bool operator!=(const Layout& other) const { return !(*this == other); }
+
+ private:
+ int64_t special_dim1_;
+ int64_t special_dim2_;
+ llvm::SmallVector<int64_t> spatials_;
+};
+
+// Wrapper for the padding attrs along a single dimension.
+class DimPadding {
+ public:
+ int64_t Hi() const { return hi_; }
+
+ int64_t Lo() const { return lo_; }
+
+ bool Trivial() const { return Hi() == 0 && Lo() == 0; }
+
+ DimPadding(int64_t lo, int64_t hi) : lo_(lo), hi_(hi) {}
+
+ private:
+ int64_t lo_;
+ int64_t hi_;
+};
+
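+// Expands a (possibly splat) i64 elements attr into one value per element.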
+inline llvm::SmallVector<int64_t> UnrollI64Splat(DenseElementsAttr data) {
+ if (!data.isSplat()) {
+ return llvm::SmallVector<int64_t>(data.getValues<int64_t>());
+ }
+ return llvm::SmallVector<int64_t>(data.getType().getNumElements(),
+ data.getSplatValue<int64_t>());
+}
+
+// Resolves optional strides or dilations attributes. If not present,
+// returns a trivial all-1's vector.
+llvm::SmallVector<int64_t, 4> ResolveStridesOrDilations(
+ int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_attr);
+
+// Resolves optional paddings attributes. If not present, returns
+// trivial [0, 0] padding on each dim.
+llvm::SmallVector<DimPadding, 2> ResolvePadding(
+ int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_padding);
+
+// Does the padding correspond to "SAME" on given dimension configuration.
+// Assumes given dimension configuration is well formed.
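+// For example, in = 5, dilate = 1, stride = 2, k = 3 yields 3 outputs with an
+// effective filter size of 3, so DimPadding(/*lo=*/1, /*hi=*/1) is "SAME".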
+bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k,
+ const DimPadding& pad);
+
+} // namespace mlir::odml
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc
index ac5e3ff..c25f27a 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc
@@ -28,6 +28,7 @@
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
@@ -49,8 +50,8 @@
}
DenseIntElementsAttr BuildTFLPaddingAttr(OpBuilder& b, mhlo::PadOp op) {
- auto lows = UnrollSplat(op.getEdgePaddingLow());
- auto highs = UnrollSplat(op.getEdgePaddingHigh());
+ auto lows = UnrollI64Splat(op.getEdgePaddingLow());
+ auto highs = UnrollI64Splat(op.getEdgePaddingHigh());
llvm::SmallVector<int64_t> res;
for (auto [l, h] : llvm::zip(lows, highs)) {
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc
index 859cdfe..cb004d3b 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc
@@ -20,6 +20,7 @@
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
namespace mlir::odml {
@@ -28,16 +29,8 @@
return op.getEdgePaddingLow().getType();
}
-llvm::SmallVector<int64_t> UnrollSplat(DenseElementsAttr data) {
- if (!data.isSplat()) {
- return llvm::SmallVector<int64_t>(data.getValues<int64_t>());
- }
- return llvm::SmallVector<int64_t>(data.getType().getNumElements(),
- data.getSplatValue<int64_t>());
-}
-
DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op) {
- auto vals = UnrollSplat(op.getEdgePaddingLow());
+ auto vals = UnrollI64Splat(op.getEdgePaddingLow());
auto starts = llvm::map_range(
vals, [](auto v) -> int64_t { return (v >= 0) ? 0 : -1 * v; });
return DenseIntElementsAttr::get(GetPaddingAttrType(op),
@@ -45,7 +38,7 @@
}
DenseIntElementsAttr SliceEndFromNegPadHighs(mhlo::PadOp op) {
- auto vals = UnrollSplat(op.getEdgePaddingHigh());
+ auto vals = UnrollI64Splat(op.getEdgePaddingHigh());
auto zip = llvm::zip(vals, op.getOperand().getType().getShape());
auto ends = llvm::map_range(zip, [](auto it) -> int64_t {
return (std::get<0>(it) >= 0) ? std::get<1>(it)
@@ -56,7 +49,7 @@
}
DenseIntElementsAttr ReplaceNegsWithZero(DenseElementsAttr data) {
- auto vals = UnrollSplat(data);
+ auto vals = UnrollI64Splat(data);
auto res =
llvm::map_range(vals, [](auto v) -> int64_t { return (v < 0) ? 0 : v; });
return DenseIntElementsAttr::get(data.getType(), llvm::to_vector(res));
@@ -64,8 +57,8 @@
bool AnyNegativePads(mhlo::PadOp op) {
auto is_neg = [](int64_t v) { return v < 0; };
- auto lows_data = UnrollSplat(op.getEdgePaddingLow());
- auto highs_data = UnrollSplat(op.getEdgePaddingHigh());
+ auto lows_data = UnrollI64Splat(op.getEdgePaddingLow());
+ auto highs_data = UnrollI64Splat(op.getEdgePaddingHigh());
return llvm::any_of(lows_data, is_neg) || llvm::any_of(highs_data, is_neg);
}
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h
index aa0428f..5041903 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h
@@ -21,8 +21,6 @@
namespace mlir::odml {
-llvm::SmallVector<int64_t> UnrollSplat(DenseElementsAttr data);
-
// Gets elements corresponding to slice starts from negative padding
// values.
DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op);
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc
new file mode 100644
index 0000000..a00ee33
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc
@@ -0,0 +1,742 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h"
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <tuple>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
+#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/IR/IRMapping.h" // from @llvm-project
+#include "mlir/IR/Matchers.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+namespace {
+
+// filter_h, filter_w, stride_h, stride_w, padding, faf.
+using TFLPoolAttrsT = std::tuple<IntegerAttr, IntegerAttr, IntegerAttr,
+ IntegerAttr, StringAttr, StringAttr>;
+
+bool AreDilationsSupported(const ReduceWindowView& op) {
+ auto is_one = [](int64_t v) { return v == 1; };
+ return llvm::all_of(op.BaseDilations(), is_one) &&
+ llvm::all_of(op.WindowDilations(), is_one);
+}
+
+bool IsRankSupported(const ReduceWindowView& op) { return op.Rank() == 4; }
+
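+// Wraps the reduce_window attrs in a view and guesses its layout. Returns
+// nullopt on unsupported rank or dilations, a layout that can't be guessed,
+// or non-trivial padding on the batch or feature dim.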
+std::optional<std::tuple<ReduceWindowView, Layout>> GetViewIfAttrsSupported(
+ mhlo::ReduceWindowOp op) {
+ const ReduceWindowView view(op);
+
+ if (!IsRankSupported(view)) {
+ return std::nullopt;
+ }
+
+ if (!AreDilationsSupported(view)) {
+ return std::nullopt;
+ }
+
+ auto opt_layout = view.GuessLayout();
+ if (!opt_layout.has_value()) {
+ return std::nullopt;
+ }
+ auto layout = opt_layout.value();
+
+ const int64_t batch = layout.SpecialDim1();
+ if (!view.Paddings()[batch].Trivial()) {
+ return std::nullopt;
+ }
+
+ const int64_t chan = layout.SpecialDim2();
+ if (!view.Paddings()[chan].Trivial()) {
+ return std::nullopt;
+ }
+
+ return std::tuple(view, layout);
+}
+
+std::optional<bool> IsReduceWindowLegal(mhlo::ReduceWindowOp op) {
+ return std::nullopt;
+}
+
+std::optional<bool> IsDivideLegal(mhlo::DivOp op) { return std::nullopt; }
+
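+// TFL pooling ops are channel-last: batch first, features last, spatials in
+// between (e.g. rank 4 -> Layout(0, 3, {1, 2}), i.e. NHWC).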
+Layout TFLNativePoolingLayout(int64_t rank) {
+ return Layout(0, rank - 1, llvm::to_vector(llvm::seq<int64_t>(1, rank - 1)));
+}
+
+bool IsCstFloatZero(Value val) {
+ DenseFPElementsAttr initial_value;
+ return matchPattern(val, m_Constant(&initial_value)) &&
+ initial_value.getNumElements() == 1 &&
+ initial_value.getValues<APFloat>()[0].isZero();
+}
+
+bool IsCstIntZero(Value val) {
+ DenseIntElementsAttr initial_value;
+ return matchPattern(val, m_Constant(&initial_value)) &&
+ initial_value.getNumElements() == 1 &&
+ initial_value.getValues<APInt>()[0].isZero();
+}
+
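+// Gathers `data` by `perm`, i.e. res[i] = data[perm[i]].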
+llvm::SmallVector<int64_t> Permute(llvm::ArrayRef<int64_t> data,
+ llvm::ArrayRef<int64_t> perm) {
+ llvm::SmallVector<int64_t> res(data.size());
+ for (int i = 0; i < data.size(); ++i) {
+ res[i] = data[perm[i]];
+ }
+ return res;
+}
+
+Value TransposeTensor(OpBuilder& b, Value tensor,
+ llvm::SmallVector<int64_t> perm) {
+ const int64_t perm_size = perm.size();
+ auto perm_attr_type = RankedTensorType::get({perm_size}, b.getI64Type());
+ auto perm_attr = DenseIntElementsAttr::get(perm_attr_type, perm);
+ return b.create<mhlo::TransposeOp>(tensor.getLoc(), tensor, perm_attr);
+}
+
+DenseIntElementsAttr BuildDenseI64(OpBuilder& b, ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> data) {
+ return DenseIntElementsAttr::get(RankedTensorType::get(shape, b.getI64Type()),
+ data);
+}
+
+DenseIntElementsAttr BuildDenseI64(OpBuilder& b, ArrayRef<int64_t> data) {
+ const int64_t dim = data.size();
+ return BuildDenseI64(b, {dim}, data);
+}
+
+std::optional<std::tuple<Value, Value>> GetInputAndInitIfValid(
+ mhlo::ReduceWindowOp op) {
+ if (op->getNumResults() != 1) {
+ return std::nullopt;
+ }
+ if (op.getInputs().size() > 1) {
+ return std::nullopt;
+ }
+ if (op.getInitValues().size() > 1) {
+ return std::nullopt;
+ }
+ auto init_val = op.getInitValues().front();
+ if (llvm::dyn_cast<ShapedType>(init_val.getType()).getNumElements() != 1) {
+ return std::nullopt;
+ }
+ return std::tuple(op.getInputs().front(), op.getInitValues().front());
+}
+
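+// Maps the spatial paddings to a TFL padding mode: "VALID" when all spatial
+// pads are trivial, "SAME" when every padded spatial dim satisfies SAME
+// semantics, and nullopt otherwise.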
+std::optional<std::string> GetTFLPadding(ArrayRef<DimPadding> paddings,
+ ArrayRef<int64_t> window_strides,
+ ArrayRef<int64_t> in_shape,
+ ArrayRef<int64_t> window_dims) {
+ const int64_t rank = paddings.size();
+ std::string tfl_padding = "VALID";
+ for (int i = 1; i < rank - 1; ++i) {
+ const auto& dim_pad = paddings[i];
+ if (dim_pad.Trivial()) {
+ continue;
+ }
+ if (!IsSamePaddingOnDim(in_shape[i], 1, window_strides[i], window_dims[i],
+ dim_pad)) {
+ return std::nullopt;
+ }
+ tfl_padding = "SAME";
+ }
+ return tfl_padding;
+}
+
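+// Builds the tfl pooling attributes (filter_h/w, stride_h/w, padding, fused
+// activation) from an already channel-last reduce_window view.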
+TFLPoolAttrsT BuildTFLPoolAttrs(OpBuilder& b, const ReduceWindowView& view,
+ StringRef padding) {
+ const int32_t filter_h = view.WindowDims()[1];
+ auto filter_h_attr = b.getI32IntegerAttr(filter_h);
+
+ const int32_t filter_w = view.WindowDims()[2];
+ auto filter_w_attr = b.getI32IntegerAttr(filter_w);
+
+ const int32_t stride_h = view.WindowStrides()[1];
+ auto stride_h_attr = b.getI32IntegerAttr(stride_h);
+
+ const int32_t stride_w = view.WindowStrides()[2];
+ auto stride_w_attr = b.getI32IntegerAttr(stride_w);
+
+ auto padding_attr = b.getStringAttr(padding);
+ auto faf_attr = b.getStringAttr("NONE");
+
+ return std::tuple(filter_h_attr, filter_w_attr, stride_h_attr, stride_w_attr,
+ padding_attr, faf_attr);
+}
+
+//===------------------------------------------------------------------------===
+// relayout reduce_window to channel last
+//===------------------------------------------------------------------------===
+
+class RelayoutReduceWindow : public OpRewritePattern<mhlo::ReduceWindowOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(mhlo::ReduceWindowOp op,
+ PatternRewriter& rewriter) const final;
+};
+
+LogicalResult RelayoutReduceWindow::matchAndRewrite(
+ mhlo::ReduceWindowOp op, PatternRewriter& rewriter) const {
+ //
+ // check and parse attributes
+ //=-----
+
+ auto opt_view = GetViewIfAttrsSupported(op);
+ if (!opt_view.has_value()) {
+ return rewriter.notifyMatchFailure(
+ op, "Reduce window attributes not supported.");
+ }
+ const auto [view, layout] = opt_view.value();
+
+ //
+ // get the input and init value if there is exactly one of each
+ //=-----
+
+ auto opt_input_and_init = GetInputAndInitIfValid(op);
+ if (!opt_input_and_init.has_value()) {
+ return rewriter.notifyMatchFailure(
+ op, "Reduce window has wrong number of inputs or init values.");
+ }
+ auto [input, init_val] = opt_input_and_init.value();
+
+ //
+ // figure out permutations for layout change
+ //=-----
+
+ const auto target_layout = TFLNativePoolingLayout(view.Rank());
+ if (layout == target_layout) {
+ return rewriter.notifyMatchFailure(
+ op, "Reduce window does not need layout change");
+ }
+
+ llvm::SmallVector<int64_t> perm_for_inputs =
+ layout.GetPermForReLayout(target_layout);
+
+ //
+ // permute layout sensitive attrs
+ //=-----
+
+ // permute paddings
+ auto paddings = view.Paddings();
+ llvm::SmallVector<int64_t> new_paddings(paddings.size() * 2);
+ for (int i = 0; i < new_paddings.size() / 2; ++i) {
+ const auto& dim_pad = paddings[perm_for_inputs[i]];
+ new_paddings[2 * i] = dim_pad.Lo();
+ new_paddings[2 * i + 1] = dim_pad.Hi();
+ }
+ const int64_t new_paddings_size = paddings.size();
+ auto new_paddings_type =
+ RankedTensorType::get({new_paddings_size, 2}, rewriter.getI64Type());
+ auto new_paddings_attr =
+ DenseIntElementsAttr::get(new_paddings_type, new_paddings);
+
+ // permute window dims
+ llvm::SmallVector<int64_t> new_window_dims =
+ Permute(view.WindowDims(), perm_for_inputs);
+ auto new_window_dims_attr = BuildDenseI64(rewriter, new_window_dims);
+
+ // permute window strides
+ llvm::SmallVector<int64_t> new_window_strides =
+ Permute(view.WindowStrides(), perm_for_inputs);
+ auto new_window_strides_attr = BuildDenseI64(rewriter, new_window_strides);
+
+ //
+ // permute params and build new op
+ //=-----
+
+ // figure out permuted result type
+ llvm::SmallVector<int64_t> perm_for_outputs =
+ target_layout.GetPermForReLayout(layout);
+ auto cur_out_type = llvm::dyn_cast<ShapedType>(op.getResult(0).getType());
+ llvm::SmallVector<int64_t> new_rw_out_shape =
+ layout.PermuteShape(target_layout, cur_out_type.getShape());
+ auto new_out_type = cur_out_type.clone(new_rw_out_shape);
+
+ // transpose input and build new reduce_window
+ auto new_input = TransposeTensor(rewriter, input, perm_for_inputs);
+ auto new_rw = rewriter.create<mhlo::ReduceWindowOp>(
+ op.getLoc(), new_out_type, new_input, init_val, new_window_dims_attr,
+ new_window_strides_attr, BuildDenseI64(rewriter, view.BaseDilations()),
+ BuildDenseI64(rewriter, view.WindowDilations()), new_paddings_attr);
+ IRMapping ir_map;
+ op.getBody().cloneInto(&new_rw.getBody(), ir_map);
+
+ // transpose output and update ir
+ auto new_output =
+ TransposeTensor(rewriter, new_rw.getResult(0), perm_for_outputs);
+ rewriter.replaceOp(op, new_output);
+
+ return success();
+}
+
+//===------------------------------------------------------------------------===
+// mhlo.reduce_window -> tfl.cum_sum
+//===------------------------------------------------------------------------===
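+
+// Illustrative example of the match below: a sum reduce_window over a 1-D
+// input of size N with window_dimensions = [N], unit strides and dilations,
+// and padding (lo, hi) = (N - 1, 0) computes prefix sums, which is exactly
+// tfl.cumsum along that axis with exclusive = false and reverse = false.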
+
+class LegalizeCumSum : public OpConversionPattern<mhlo::ReduceWindowOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeCumSum::matchAndRewrite(
+ mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const {
+ //
+ // check singular params and trivial attrs
+ //=-----
+
+ auto opt_input_init = GetInputAndInitIfValid(op);
+ if (!opt_input_init.has_value()) {
+ return rewriter.notifyMatchFailure(op,
+ "Must have 1 input, init and result.");
+ }
+ auto [input, init] = opt_input_init.value();
+
+ if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(op.getBody()))) {
+ return rewriter.notifyMatchFailure(op, "Requires scalar add in region.");
+ }
+
+ if (!IsCstFloatZero(init) && !IsCstIntZero(init)) {
+ return rewriter.notifyMatchFailure(op, "Requires 0 for init value.");
+ }
+
+ const ReduceWindowView view(op);
+
+ auto trivial = [](int64_t v) { return v == 1; };
+ const bool trivial_window_dilate =
+ llvm::all_of(view.WindowDilations(), trivial);
+ const bool trivial_base_dilate = llvm::all_of(view.BaseDilations(), trivial);
+ const bool trivial_stride = llvm::all_of(view.WindowStrides(), trivial);
+ if (!trivial_window_dilate || !trivial_stride || !trivial_base_dilate) {
+ return rewriter.notifyMatchFailure(
+ op, "Requires trivial strides and dilations attributes.");
+ }
+
+ //
+ // figure out the implicit axis of reduction
+ //=-----
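+
+  // Illustrative example: for an input of shape [2, 8] with
+  // window_dimensions = [1, 8], dim 1 is the only non-1 window dim and it
+  // spans the full input extent, so the implicit reduction axis is 1.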
+
+ auto input_type = llvm::cast<ShapedType>(input.getType());
+ if (view.WindowDims().size() != input_type.getRank()) {
+ return rewriter.notifyMatchFailure(op, "Splat window dims not supported.");
+ }
+ int64_t axis = -1;
+ for (auto [ind, val] : llvm::enumerate(view.WindowDims())) {
+ if (val == 1) {
+ continue;
+ }
+
+ if (axis != -1) {
+      return rewriter.notifyMatchFailure(op, "Multiple non-1 window dimensions.");
+ }
+
+    if (val != input_type.getShape()[ind]) {
+      return rewriter.notifyMatchFailure(
+          op, "Axis window dimension must match the input dimension size.");
+    }
+ axis = ind;
+ }
+
+ if (axis == -1) {
+ return rewriter.notifyMatchFailure(op, "Could not identify axis.");
+ }
+
+ const int64_t axis_size = input_type.getShape()[axis];
+
+ //
+ // validate padding is [N-1, 0] on axis and zero elsewhere
+ //=-----
+
+ for (const auto& [ind, dim_pad] : llvm::enumerate(view.Paddings())) {
+ if (dim_pad.Hi() != 0) {
+      return rewriter.notifyMatchFailure(op, "Has non-trivial high padding.");
+ }
+
+ if (ind != axis) {
+ if (!dim_pad.Trivial()) {
+        return rewriter.notifyMatchFailure(
+            op, "Has non-trivial padding on non-axis dim.");
+ }
+ } else {
+ if (dim_pad.Lo() != axis_size - 1) {
+ return rewriter.notifyMatchFailure(
+ op, "Requires low padding on axis dim to be N - 1.");
+ }
+ }
+ }
+
+ //
+ // build axis constant and tfl op
+ //=-----
+
+ auto axis_cst_attr = DenseIntElementsAttr::get(
+ RankedTensorType::get({}, rewriter.getI32Type()),
+ static_cast<int32_t>(axis));
+ auto axis_cst =
+ rewriter.create<arith::ConstantOp>(op->getLoc(), axis_cst_attr);
+
+ auto tfl_exclusive_attr = rewriter.getBoolAttr(false);
+ auto tfl_reverse_attr = rewriter.getBoolAttr(false);
+
+ rewriter.replaceOpWithNewOp<TFL::CumsumOp>(op, op->getResultTypes()[0], input,
+ axis_cst, tfl_exclusive_attr,
+ tfl_reverse_attr);
+
+ return success();
+}
+
+//===------------------------------------------------------------------------===
+// mhlo.reduce_window -> tfl.max_pool
+//===------------------------------------------------------------------------===
+
+bool isFloatMinusInfinity(Value value) {
+ DenseFPElementsAttr float_value;
+ if (!matchPattern(value, m_Constant(&float_value))) {
+ return false;
+ }
+ if (float_value.getNumElements() != 1) {
+ return false;
+ }
+ APFloat element = float_value.getValues<APFloat>()[0];
+ return element.isInfinity() && element.isNegative();
+}
+
+class LegalizeMaxPool : public OpConversionPattern<mhlo::ReduceWindowOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeMaxPool::matchAndRewrite(
+ mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const {
+ //
+ // parse and validate lhs reduce window
+ //=-----
+
+ const auto opt_view = GetViewIfAttrsSupported(op);
+ if (!opt_view.has_value()) {
+ return rewriter.notifyMatchFailure(op, "Reduce window is not valid.");
+ }
+ const auto [view, layout] = opt_view.value();
+ if (layout != TFLNativePoolingLayout(layout.Rank())) {
+ return rewriter.notifyMatchFailure(op, "Not tfl standard layout.");
+ }
+
+ // Check that the reduce-window is a max-reduce-window.
+ if (failed(MatchBinaryReduceFunction<mhlo::MaxOp>(op.getBody()))) {
+ return rewriter.notifyMatchFailure(op, "Must be a max pool.");
+ }
+
+ auto type = mlir::dyn_cast<ShapedType>(op.getResult(0).getType());
+ if (!mlir::isa<FloatType>(type.getElementType())) {
+ return rewriter.notifyMatchFailure(op, "Not a floating point pool.");
+ }
+
+ //
+ // validate inputs and init
+ //=-----
+
+ auto opt_inputs_and_init = GetInputAndInitIfValid(op);
+ if (!opt_inputs_and_init.has_value()) {
+ return rewriter.notifyMatchFailure(op, "Too many inputs or inits.");
+ }
+ auto [input, init] = opt_inputs_and_init.value();
+ auto input_type = llvm::dyn_cast<ShapedType>(input.getType());
+
+ if (!isFloatMinusInfinity(init)) {
+ return rewriter.notifyMatchFailure(op, "Init not minus infinity.");
+ }
+
+ //
+ // build tfl
+ //=-----
+
+ auto opt_tfl_padding =
+ GetTFLPadding(view.Paddings(), view.WindowStrides(),
+ input_type.getShape(), view.WindowDims());
+ if (!opt_tfl_padding.has_value()) {
+ return rewriter.notifyMatchFailure(op, "Padding not SAME or VALID.");
+ }
+ const auto& tfl_padding = opt_tfl_padding.value();
+
+ auto [fh, fw, sh, sw, p, faf] =
+ BuildTFLPoolAttrs(rewriter, view, tfl_padding);
+ rewriter.replaceOpWithNewOp<TFL::MaxPool2DOp>(op, type, input, p, sw, sh, fw,
+ fh, faf);
+
+ return success();
+}
+
+//===------------------------------------------------------------------------===
+// mhlo.div(mhlo.reduce_window, cst | mhlo.reduce_window) -> tfl.avg_pool
+//===------------------------------------------------------------------------===
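+
+// The legalization in this section recognizes two divisor shapes (see the two
+// cases in matchAndRewrite): a splat constant equal to the window size,
+// accepted only with VALID padding, or a second sum reduce_window over an
+// all-ones tensor with the same window configuration, which also covers SAME
+// padding.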
+
+void ReplaceWithAvgPool(mhlo::DivOp op, Value rw_lhs_input,
+ const ReduceWindowView& lhs_view,
+ llvm::StringRef padding, PatternRewriter& rewriter,
+ mhlo::TransposeOp opt_final_tpose) {
+ Type out_type =
+ opt_final_tpose ? opt_final_tpose.getOperand().getType() : op.getType();
+
+ auto [fh, fw, sh, sw, p, faf] =
+ BuildTFLPoolAttrs(rewriter, lhs_view, padding);
+ Value final_op = rewriter.create<TFL::AveragePool2DOp>(
+ op->getLoc(), out_type, rw_lhs_input, fh, fw, p, sh, sw, faf);
+
+ if (opt_final_tpose) {
+ final_op = rewriter
+ .create<mhlo::TransposeOp>(final_op.getLoc(), final_op,
+ opt_final_tpose.getPermutation())
+ .getResult();
+ }
+
+ rewriter.replaceOp(op, final_op);
+}
+
+// Walks up the value's defining ops, skipping all preceding ops of type Tys.
+// Returns the first value whose defining op is not one of Tys (or a block
+// argument).
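+// For example (illustrative), RecursivelyWalkUp<mhlo::TransposeOp>(v) with
+// v = transpose(transpose(x)) returns x.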
+template <typename... Tys>
+Value RecursivelyWalkUp(Value op) {
+ while (llvm::isa_and_nonnull<Tys...>(op.getDefiningOp())) {
+ Operation* producer = op.getDefiningOp();
+ op = producer->getOperand(/*idx=*/0);
+ }
+
+ return op;
+}
+
+class LegalizeAvgPool : public OpConversionPattern<mhlo::DivOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mhlo::DivOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeAvgPool::matchAndRewrite(
+ mhlo::DivOp div_op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const {
+ //
+ // parse and validate lhs reduce window
+ //=-----
+
+ auto div_lhs = div_op.getLhs();
+ // If div's input is transposed, save it to chain on the new pool op.
+ mhlo::TransposeOp opt_final_tpose;
+ if (auto div_lhs_op = div_lhs.getDefiningOp()) {
+ opt_final_tpose = llvm::dyn_cast_or_null<mhlo::TransposeOp>(div_lhs_op);
+ }
+
+ auto rw_lhs_val = RecursivelyWalkUp<mhlo::TransposeOp>(div_lhs);
+ auto rw_lhs =
+ llvm::dyn_cast_or_null<mhlo::ReduceWindowOp>(rw_lhs_val.getDefiningOp());
+ if (!rw_lhs) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Could not match lhs of div on reduce window.");
+ }
+
+ const auto opt_rw_lhs_view = GetViewIfAttrsSupported(rw_lhs);
+ if (!opt_rw_lhs_view.has_value()) {
+ return rewriter.notifyMatchFailure(div_op, "Lhs rw is not valid.");
+ }
+ const auto [rw_lhs_view, rw_lhs_layout] = opt_rw_lhs_view.value();
+ if (rw_lhs_layout != TFLNativePoolingLayout(rw_lhs_layout.Rank())) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Lhs reduce window not tfl standard layout.");
+ }
+
+ // Check that the reduce-window is a sum-reduce-window.
+ if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_lhs.getBody()))) {
+ return rewriter.notifyMatchFailure(div_op,
+ "Failed to match rw lhs binary func.");
+ }
+
+ //
+ // validate inputs and init val
+ //=-----
+
+ auto opt_rw_lhs_input_and_init = GetInputAndInitIfValid(rw_lhs);
+ if (!opt_rw_lhs_input_and_init.has_value()) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Lhs reduce window has wrong number of inputs or init values.");
+ }
+ auto [rw_lhs_input, rw_lhs_init_val] = opt_rw_lhs_input_and_init.value();
+ auto rw_lhs_input_type = llvm::dyn_cast<ShapedType>(rw_lhs_input.getType());
+
+ auto rw_lhs_type =
+ mlir::dyn_cast<RankedTensorType>(rw_lhs.getResult(0).getType());
+ if (!mlir::isa<FloatType>(rw_lhs_type.getElementType())) {
+    return rewriter.notifyMatchFailure(div_op,
+                                       "Reduce window lhs must be float type.");
+ }
+
+ // If the init value isn't zero then it can't be an average pool.
+ if (!IsCstFloatZero(rw_lhs_init_val)) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Reduce window lhs init value is not zero.");
+ }
+
+ //
+ // case 1: rhs is splat const with val == window_size
+ //=-----
+
+ auto opt_tfl_padding =
+ GetTFLPadding(rw_lhs_view.Paddings(), rw_lhs_view.WindowStrides(),
+ rw_lhs_input_type.getShape(), rw_lhs_view.WindowDims());
+ if (!opt_tfl_padding.has_value()) {
+ return rewriter.notifyMatchFailure(div_op,
+ "Padding must be VALID or SAME.");
+ }
+ const auto& tfl_padding = opt_tfl_padding.value();
+
+ {
+ DenseFPElementsAttr divisor;
+ auto div_rhs = RecursivelyWalkUp<mhlo::BroadcastInDimOp, mhlo::TransposeOp>(
+ div_op.getRhs());
+ if (matchPattern(div_rhs, m_Constant(&divisor))) {
+ if (!divisor.isSplat()) {
+ return failure();
+ }
+
+ if (!divisor.getSplatValue<APFloat>().isExactlyValue(
+ rw_lhs_view.WindowSize())) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Rhs splat const is not equal to window size.");
+ }
+
+ if (tfl_padding != "VALID") {
+ return rewriter.notifyMatchFailure(div_op,
+ "Matching on rhs splat const where "
+ "rw lhs has non-trivial padding.");
+ }
+
+ ReplaceWithAvgPool(div_op, rw_lhs_input, rw_lhs_view, tfl_padding,
+ rewriter, opt_final_tpose);
+ return success();
+ }
+ }
+
+ //
+ // case 2: rhs is another reduce window over 1's with same config as lhs
+ //=-----
+
+ {
+ Value divisor = RecursivelyWalkUp<mhlo::BroadcastInDimOp, mhlo::ReshapeOp,
+ mhlo::TransposeOp>(div_op.getRhs());
+ auto rw_rhs =
+ dyn_cast_or_null<mhlo::ReduceWindowOp>(divisor.getDefiningOp());
+ if (!rw_rhs) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Rhs of div op is not a reduce window.");
+ }
+
+ const auto opt_rw_rhs_view = GetViewIfAttrsSupported(rw_rhs);
+ if (!opt_rw_rhs_view.has_value()) {
+ return rewriter.notifyMatchFailure(div_op, "Rhs rw is not valid.");
+ }
+ const auto [rw_rhs_view, rw_rhs_layout] = opt_rw_rhs_view.value();
+ if (rw_rhs_layout != TFLNativePoolingLayout(rw_rhs_layout.Rank())) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Rhs reduce window not tfl standard layout.");
+ }
+
+ // Check that RHS is a sum-reduce-window.
+ if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.getBody()))) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Rhs rw body function is not an add op.");
+ }
+
+ auto opt_rw_rhs_input_and_init = GetInputAndInitIfValid(rw_rhs);
+ if (!opt_rw_rhs_input_and_init.has_value()) {
+ return rewriter.notifyMatchFailure(
+ div_op,
+ "Rhs reduce window has wrong number of inputs or init values.");
+ }
+ auto [rw_rhs_input, rw_rhs_init_val] = opt_rw_rhs_input_and_init.value();
+
+ if (!IsCstFloatZero(rw_rhs_init_val)) {
+      return rewriter.notifyMatchFailure(div_op,
+                                         "Rhs rw init val is not zero.");
+ }
+
+ rw_rhs_input = RecursivelyWalkUp<mhlo::BroadcastInDimOp, mhlo::TransposeOp>(
+ rw_rhs_input);
+ DenseFPElementsAttr rhs_input_data;
+ if (!matchPattern(rw_rhs_input, m_Constant(&rhs_input_data)) ||
+ !rhs_input_data.isSplat() ||
+ !rhs_input_data.getSplatValue<APFloat>().isExactlyValue(1.0)) {
+ return rewriter.notifyMatchFailure(div_op,
+ "Rw rhs input is not splat of 1.0.");
+ }
+
+ // Check that the two reduce window have the same window configuration.
+ if (rw_lhs.getWindowDimensions() != rw_rhs.getWindowDimensions() ||
+ rw_lhs.getWindowStrides() != rw_rhs.getWindowStrides() ||
+ rw_lhs.getPadding() != rw_rhs.getPadding()) {
+ return rewriter.notifyMatchFailure(
+ div_op, "Lhs rw and Rhs rw do not have the same config.");
+ }
+
+ ReplaceWithAvgPool(div_op, rw_lhs_input, rw_lhs_view, tfl_padding, rewriter,
+ opt_final_tpose);
+ return success();
+ }
+
+ return failure();
+}
+
+} // namespace
+
+void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns,
+ ConversionTarget& target) {
+ patterns.add<LegalizeAvgPool, LegalizeMaxPool, LegalizeCumSum>(ctx);
+ target.addDynamicallyLegalOp<mhlo::ReduceWindowOp>(IsReduceWindowLegal);
+ target.addDynamicallyLegalOp<mhlo::DivOp>(IsDivideLegal);
+}
+
+void PopulatePrepareReduceWindowPatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns) {
+ patterns.add<RelayoutReduceWindow>(ctx);
+}
+
+} // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h
new file mode 100644
index 0000000..ccc9c27f
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h
@@ -0,0 +1,45 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_
+
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+
+namespace mlir::odml {
+
+// Patterns to legalize mhlo.reduce_window to TFL.
+//
+// Maps the following representations of AvgPool in MHLO into a tfl.avg_pool
+// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
+// padding:
+// * div(reduce_sum_window(x), constant(sizeof(window)))
+// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
+//
+// Emits: tfl.average_pool2d
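+//
+// Illustrative pseudo-IR for the first form (names and shapes hypothetical):
+//   %sum = mhlo.reduce_window(%x, %zero) { window_dimensions = [1, 3, 3, 1] }
+//   %avg = mhlo.divide(%sum, mhlo.constant dense<9.0>)
+// lowers to roughly:
+//   %avg = tfl.average_pool_2d(%x) { filter_height = 3, filter_width = 3,
+//                                    padding = "VALID" }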
+void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns,
+ ConversionTarget& target);
+
+// Patterns to prepare mhlo.reduce_window for legalization.
+// Transposes reduce_windows to be NHWC.
+//
+// Emits: tfl.transpose
+void PopulatePrepareReduceWindowPatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns);
+
+} // namespace mlir::odml
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc
new file mode 100644
index 0000000..67e4db6
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc
@@ -0,0 +1,72 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h"
+
+#include <cstdint>
+#include <optional>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+
+ReduceWindowView::ReduceWindowView(mhlo::ReduceWindowOp op) {
+ rank_ = op.getWindowDimensions().size();
+ window_dims_ =
+ SmallVector<int64_t, 4>(op.getWindowDimensions().getValues<int64_t>());
+ window_strides_ = ResolveStridesOrDilations(rank_, op.getWindowStrides());
+ window_dilations_ = ResolveStridesOrDilations(rank_, op.getWindowDilations());
+ base_dilations_ = ResolveStridesOrDilations(rank_, op.getBaseDilations());
+ paddings_ = ResolvePadding(rank_, op.getPadding());
+ window_size_ = 1;
+ for (auto d : window_dims_) {
+ window_size_ *= d;
+ }
+}
+
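+// Worked example (illustrative): for window_dims = [1, 3, 3, 1] with unit
+// strides, dims 0 and 3 have window size and stride 1, so the guessed layout
+// is channel-last (NHWC-like): batch = 0, channel = rank - 1 = 3,
+// spatials = [1, 2].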
+std::optional<Layout> ReduceWindowView::GuessLayout() const {
+ auto zip_dims_strides = llvm::zip(WindowDims(), WindowStrides());
+ auto simple_window_dims =
+ llvm::to_vector(llvm::map_range(zip_dims_strides, [](auto it) {
+ return std::get<0>(it) == 1 && std::get<1>(it) == 1;
+ }));
+
+ if (llvm::count(simple_window_dims, 1) < 2) {
+ return std::nullopt;
+ }
+
+ const bool is_channel_last =
+ simple_window_dims[0] && simple_window_dims[Rank() - 1];
+ if (is_channel_last) {
+ return Layout(0, Rank() - 1,
+ llvm::to_vector(llvm::seq<int64_t>(1, Rank() - 1)));
+ }
+
+ const bool is_channel_first = simple_window_dims[0] && simple_window_dims[1];
+ if (is_channel_first) {
+ return Layout(0, 1, llvm::to_vector(llvm::seq<int64_t>(2, Rank())));
+ }
+
+  // In theory, we could support any layout with at least two 1's in
+  // `simple_window_dims` by permuting layouts so that the 1's land in the
+  // first and last positions. It is unclear whether such a case ever comes up.
+ return std::nullopt;
+}
+
+} // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h
new file mode 100644
index 0000000..512389b
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h
@@ -0,0 +1,62 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_
+
+#include <optional>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/IR/Types.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+// Helpers for working with mhlo.reduce_window attrs in the MLIR API as
+// native C++ types.
+
+namespace mlir::odml {
+
+class ReduceWindowView {
+ public:
+ explicit ReduceWindowView(mhlo::ReduceWindowOp op);
+
+ llvm::ArrayRef<int64_t> WindowDims() const { return window_dims_; }
+ int64_t WindowSize() const { return window_size_; }
+ llvm::ArrayRef<int64_t> WindowStrides() const { return window_strides_; }
+ llvm::ArrayRef<DimPadding> Paddings() const { return paddings_; }
+ llvm::ArrayRef<int64_t> WindowDilations() const { return window_dilations_; }
+ llvm::ArrayRef<int64_t> BaseDilations() const { return base_dilations_; }
+ int64_t Rank() const { return rank_; }
+
+ std::optional<Layout> GuessLayout() const;
+
+ private:
+ int64_t rank_;
+
+ llvm::SmallVector<int64_t, 4> window_dims_;
+ llvm::SmallVector<int64_t, 4> window_strides_;
+ llvm::SmallVector<int64_t, 4> window_dilations_;
+
+ llvm::SmallVector<DimPadding, 4> paddings_;
+
+ llvm::SmallVector<int64_t, 4> base_dilations_;
+
+ int64_t window_size_;
+};
+
+} // namespace mlir::odml
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc
new file mode 100644
index 0000000..76c843e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc
@@ -0,0 +1,214 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h"
+
+#include <cstdint>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
+#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+namespace {
+
+//===----------------------------------------------------------------------===//
+// mhlo.slice
+//===----------------------------------------------------------------------===//
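+
+// Sketch (illustrative): mhlo.slice carries i64 start/limit/stride attributes;
+// the pattern below materializes them as constants, casts them to i32 tensors,
+// and emits tfl.strided_slice with all masks set to 0.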
+
+// Cast the value to i32.
+Value BuildTFLCastOp(OpBuilder& b, Value value) {
+ return b.create<TFL::CastOp>(
+ value.getLoc(),
+ RankedTensorType::get(llvm::cast<ShapedType>(value.getType()).getShape(),
+ b.getI32Type()),
+ value);
+}
+
+class LegalizeSliceOp : public OpConversionPattern<mhlo::SliceOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::SliceOp slice_op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const final {
+ auto begin = rewriter.create<arith::ConstantOp>(slice_op.getLoc(),
+ slice_op.getStartIndices());
+ auto end = rewriter.create<arith::ConstantOp>(slice_op.getLoc(),
+ slice_op.getLimitIndices());
+ auto strides = rewriter.create<arith::ConstantOp>(slice_op.getLoc(),
+ slice_op.getStrides());
+ auto zero = rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
+ auto no_offset = rewriter.getBoolAttr(false);
+
+ rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
+ slice_op, slice_op.getType(), slice_op.getOperand(),
+ BuildTFLCastOp(rewriter, begin), BuildTFLCastOp(rewriter, end),
+ BuildTFLCastOp(rewriter, strides), zero, zero, zero, zero, zero,
+ no_offset);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// mhlo.dynamic_slice
+//===----------------------------------------------------------------------===//
+
+class CastSliceIndicesToSignless
+ : public OpRewritePattern<mhlo::DynamicSliceOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mhlo::DynamicSliceOp op,
+ PatternRewriter& rewriter) const final;
+};
+
+LogicalResult CastSliceIndicesToSignless::matchAndRewrite(
+ mhlo::DynamicSliceOp op, PatternRewriter& rewriter) const {
+ // All start inds have the same element type.
+ auto start_type =
+ llvm::cast<ShapedType>(op.getStartIndices().front().getType());
+ auto start_e_type = start_type.getElementType();
+
+ if (start_e_type.isSignlessIntOrFloat()) {
+ return rewriter.notifyMatchFailure(op, "Already signless.");
+ }
+ auto new_start_e_type =
+ rewriter.getIntegerType(start_e_type.getIntOrFloatBitWidth());
+
+ llvm::SmallVector<Value> casted_start_inds;
+ for (auto start_ind_opr : op.getStartIndices()) {
+ auto casted_start_ind_opr = rewriter.create<mhlo::ConvertOp>(
+ start_ind_opr.getLoc(), start_ind_opr, new_start_e_type);
+ casted_start_inds.push_back(casted_start_ind_opr.getResult());
+ }
+
+ rewriter.replaceOpWithNewOp<mhlo::DynamicSliceOp>(
+ op, op.getOperand(), casted_start_inds, op.getSliceSizes());
+
+ return success();
+}
+
+bool IsDynamicSliceLegal(mhlo::DynamicSliceOp op) {
+ return !llvm::cast<ShapedType>(op.getOperand().getType()).hasStaticShape();
+}
+
+class LegalizeDynamicSliceOp
+ : public OpConversionPattern<mhlo::DynamicSliceOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::DynamicSliceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeDynamicSliceOp::matchAndRewrite(
+ mhlo::DynamicSliceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const {
+ auto start_type =
+ llvm::cast<ShapedType>(op.getStartIndices().front().getType());
+ auto start_e_type = start_type.getElementType();
+ if (!start_e_type.isSignlessIntOrFloat()) {
+ return rewriter.notifyMatchFailure(
+ op, "Must be signless integer for start indices.");
+ }
+
+ auto input_type = llvm::cast<ShapedType>(op.getOperand().getType());
+ if (!input_type.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(op, "Input must be statically shaped.");
+ }
+
+ //
+ // clamp start indices between zero and shape(operand) - slice_sizes
+ //=-----
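+
+  // Worked example (illustrative): for an operand dim of size 10 and a slice
+  // size of 4, a start index of 9 is clamped into [0, 10 - 4], i.e. to 6,
+  // matching HLO dynamic_slice semantics.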
+
+ Value clamp_left_cst = rewriter.create<arith::ConstantOp>(
+ op->getLoc(), rewriter.getZeroAttr(start_type));
+
+ llvm::SmallVector<Value> new_start_indices;
+ const auto stride_sizes = UnrollI64Splat(op.getSliceSizes());
+
+ for (auto [dim_size, start_ind_opr, stride_size] :
+ llvm::zip(input_type.getShape(), op.getStartIndices(), stride_sizes)) {
+ const int64_t clamp_right_val = dim_size - stride_size;
+ auto clamp_right_cst = rewriter.create<arith::ConstantOp>(
+ op->getLoc(),
+ DenseElementsAttr::get(start_type, rewriter.getIntegerAttr(
+ start_e_type, clamp_right_val)));
+
+ Value new_start_ind = rewriter.create<TFL::MaximumOp>(
+ op->getLoc(), start_type, clamp_left_cst, start_ind_opr);
+ new_start_ind = rewriter.create<TFL::MinimumOp>(
+ op->getLoc(), start_type, clamp_right_cst, new_start_ind);
+
+ new_start_indices.push_back(new_start_ind);
+ }
+
+ //
+ // pack variadic scalar start indices into one tensor
+ //=-----
+
+ const int64_t packed_start_indices_size = new_start_indices.size();
+ auto packed_start_indices_type =
+ RankedTensorType::get({packed_start_indices_size}, start_e_type);
+
+ auto values_count_attr =
+ rewriter.getI32IntegerAttr(packed_start_indices_size);
+ auto pack_axis_attr = rewriter.getI32IntegerAttr(0);
+
+ auto packed_start_inds = rewriter.create<TFL::PackOp>(
+ op->getLoc(), packed_start_indices_type, new_start_indices,
+ values_count_attr, pack_axis_attr);
+
+ //
+ // build tfl
+ //=-----
+
+ auto slice_sizes_cst =
+ rewriter.create<arith::ConstantOp>(op->getLoc(), op.getSliceSizes());
+
+ rewriter.replaceOpWithNewOp<TFL::SliceOp>(op, op.getType(), op.getOperand(),
+ packed_start_inds, slice_sizes_cst);
+
+ return success();
+}
+
+} // namespace
+
+void PopulateLegalizeSlicePatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns,
+ ConversionTarget& target) {
+ patterns.add<LegalizeSliceOp, LegalizeDynamicSliceOp>(ctx);
+
+ target.addIllegalOp<mhlo::SliceOp>();
+ target.addDynamicallyLegalOp<mhlo::DynamicSliceOp>(IsDynamicSliceLegal);
+}
+
+void PopulatePrepareSlicePatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns) {
+ patterns.add<CastSliceIndicesToSignless>(ctx);
+}
+
+} // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h
new file mode 100644
index 0000000..024cbb4
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h
@@ -0,0 +1,34 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_
+
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+
+// Patterns to legalize mhlo.slice to TFL.
+void PopulateLegalizeSlicePatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns,
+ ConversionTarget& target);
+
+void PopulatePrepareSlicePatterns(MLIRContext* ctx,
+ RewritePatternSet& patterns);
+
+} // namespace mlir::odml
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc
new file mode 100644
index 0000000..43477cc
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc
@@ -0,0 +1,146 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h"
+
+#include <cstdint>
+
+#include "llvm/ADT/ilist.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
+#include "mlir/IR/Block.h" // from @llvm-project
+#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/IR/Region.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+namespace {
+
+using OpListType = llvm::iplist<Operation>;
+
+template <typename ReturnOpType>
+bool MatchTopKComparator(Region& comparator) {
+ if (!comparator.hasOneBlock()) return false;
+ Block& comparator_blk = comparator.front();
+
+ OpListType& operations = comparator_blk.getOperations();
+ if (operations.size() != 2) return false;
+
+ auto compare_op =
+ llvm::dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
+ auto return_op = llvm::dyn_cast_or_null<ReturnOpType>(&operations.back());
+ if (!compare_op || !return_op) return false;
+
+ if (compare_op.getComparisonDirection() != mhlo::ComparisonDirection::GT) {
+ return false;
+ }
+
+ if (compare_op.getOperands()[0] != comparator_blk.getArgument(0) ||
+ compare_op.getOperands()[1] != comparator_blk.getArgument(1)) {
+ return false;
+ }
+
+ return return_op.getOperands().front() == compare_op.getResult();
+}
+
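+// Returns true unless the sort can lower to tfl.topk_v2: exactly two operands
+// (values and an iota of i32 indices along the sort dimension), statically
+// shaped, sorting on the last dimension with a greater-than comparator on the
+// values.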
+bool IsSortOpNotTopK(mhlo::SortOp op) {
+ if (op->getNumOperands() != 2) {
+ return true;
+ }
+
+ auto keys_opr = op.getInputs().front();
+ auto keys_type = llvm::cast<ShapedType>(keys_opr.getType());
+
+ if (!keys_type.hasStaticShape() ||
+ !keys_type.getElementType().isIntOrFloat()) {
+ return true;
+ }
+
+ auto indices_opr = op.getInputs().back();
+ auto indices_type = llvm::cast<ShapedType>(indices_opr.getType());
+
+ if (!indices_type.hasStaticShape() ||
+ !indices_type.getElementType().isInteger(32)) {
+ return true;
+ }
+
+ const int64_t sort_dim = op.getDimension();
+ const auto k = indices_type.getDimSize(sort_dim);
+ const auto rank = keys_type.getRank();
+
+ if (sort_dim != rank - 1 || k < 1) {
+ return true;
+ }
+
+ OpBuilder b(op->getContext());
+ if (!MatchIota(b.getI64TensorAttr({sort_dim}), indices_opr)) {
+ return true;
+ }
+
+ if (!MatchTopKComparator<mhlo::ReturnOp>(op.getComparator())) {
+ return true;
+ }
+
+ return false;
+}
+
+class LegalizeSortOp : public OpConversionPattern<mhlo::SortOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::SortOp sort_op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeSortOp::matchAndRewrite(
+ mhlo::SortOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const {
+ if (IsSortOpNotTopK(op)) {
+ return failure();
+ }
+
+ auto keys = op.getInputs().front();
+ auto indices = op.getInputs().back();
+ auto indices_type = llvm::cast<ShapedType>(indices.getType());
+
+ const int32_t k = indices_type.getShape().back();
+ auto k_cst_attr = DenseIntElementsAttr::get(
+ RankedTensorType::get({}, rewriter.getI32Type()), k);
+ auto k_cst = rewriter.create<arith::ConstantOp>(op->getLoc(), k_cst_attr);
+
+ rewriter.replaceOpWithNewOp<TFL::TopKV2Op>(op, keys.getType(),
+ indices.getType(), keys, k_cst);
+
+ return success();
+}
+
+} // namespace
+
+void PopulateSortPatterns(MLIRContext* ctx, RewritePatternSet& patterns,
+ ConversionTarget& target) {
+ patterns.add<LegalizeSortOp>(ctx);
+ target.addDynamicallyLegalOp<mhlo::SortOp>(IsSortOpNotTopK);
+}
+
+} // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h
new file mode 100644
index 0000000..9bbb1f3
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h
@@ -0,0 +1,28 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_
+
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+
+namespace mlir::odml {
+
+void PopulateSortPatterns(MLIRContext* ctx, RewritePatternSet& patterns,
+ ConversionTarget& target);
+
+} // namespace mlir::odml
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc
index 5a7b976..9ff5a6e 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc
@@ -17,7 +17,6 @@
#include <utility>
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep
-#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@@ -29,6 +28,8 @@
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h" // IWYU pragma: keep
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h" // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep
namespace mlir {
@@ -54,6 +55,8 @@
populateWithGenerated(patterns);
PopulatePrepareConvPatterns(context, patterns);
+ PopulatePrepareReduceWindowPatterns(context, patterns);
+ PopulatePrepareSlicePatterns(context, patterns);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
signalPassFailure();
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc
index 8e22b34..e8a2bc8 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc
@@ -148,6 +148,7 @@
case flexbuffers::FBT_VECTOR_INT: {
const auto& vector = value.AsTypedVector();
std::vector<int64_t> vec;
+ vec.reserve(vector.size());
for (size_t i = 0; i < vector.size(); i++) {
vec.push_back(vector[i].AsInt64());
}
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc
index 7aff6c1..5395af5 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc
@@ -14,12 +14,16 @@
==============================================================================*/
// The kept headers are provided for the included file `passes.h.inc`.
+#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
@@ -37,6 +41,9 @@
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" // IWYU pragma: keep
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep
@@ -47,9 +54,137 @@
namespace odml {
namespace {
+// Returns the shape of the given value as an arith.constant op.
+arith::ConstantOp ShapeToConst(PatternRewriter& rewriter, Value value) {
+ ArrayRef<int64_t> shape = mlir::cast<ShapedType>(value.getType()).getShape();
+ auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
+ rewriter.getIntegerType(64));
+ auto attr = DenseElementsAttr::get(attr_type, shape);
+ return rewriter.create<arith::ConstantOp>(value.getLoc(), attr_type, attr);
+}
+
+bool IsSign(APInt a, APInt sign) {
+ if (a.isZero()) return a == sign;
+ if (a.isNegative()) return sign == -1;
+ return sign == 1;
+}
+
+bool IsSign(APFloat a, APFloat sign) {
+ if (a.isNaN() || a.isZero()) return a == sign;
+ if (a.isNegative()) return sign.isExactlyValue(-1.0);
+ return sign.isExactlyValue(1.0);
+}
+
+bool IsDenseSplatIntAttr(ElementsAttr float_or_int) {
+ return mlir::isa<SplatElementsAttr>(float_or_int) &&
+ mlir::isa<DenseIntElementsAttr>(float_or_int);
+}
+
+bool IsDenseSplatFloatAttr(ElementsAttr float_or_int) {
+ return mlir::isa<SplatElementsAttr>(float_or_int) &&
+ mlir::isa<DenseFPElementsAttr>(float_or_int);
+}
+
+bool ValueEquals(ElementsAttr float_or_int, double rhs) {
+ if (IsDenseSplatFloatAttr(float_or_int)) {
+ return mlir::cast<SplatElementsAttr>(float_or_int)
+ .getSplatValue<APFloat>()
+ .isExactlyValue(rhs);
+ } else if (IsDenseSplatIntAttr(float_or_int)) {
+ return mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APInt>() ==
+ static_cast<int>(rhs);
+ }
+ return false;
+}
+
+// Returns whether the splat constant is the sign of the int or float Tensor.
+bool TensorIsSign(PatternRewriter& rewriter, ElementsAttr float_or_int,
+ ElementsAttr sgn_cst) {
+ auto sgn_splat = llvm::dyn_cast<SplatElementsAttr>(sgn_cst);
+ if (!sgn_splat) return false;
+
+ auto splat = dyn_cast<SplatElementsAttr>(float_or_int);
+ if (auto float_spl = llvm::dyn_cast_if_present<FloatAttr>(splat),
+ sgn_cst_spl = llvm::dyn_cast_if_present<FloatAttr>(sgn_splat);
+ float_spl && sgn_cst_spl) {
+ return IsSign(float_spl.getValue(), sgn_cst_spl.getValue());
+ }
+ if (auto int_spl = llvm::dyn_cast_if_present<IntegerAttr>(splat),
+ sgn_cst_spl = llvm::dyn_cast_if_present<IntegerAttr>(sgn_splat);
+ int_spl && sgn_cst_spl) {
+ return IsSign(int_spl.getValue(), sgn_cst_spl.getValue());
+ }
+ if (mlir::isa<DenseFPElementsAttr>(float_or_int)) {
+ auto sgn_splat_value = sgn_splat.getSplatValue<APFloat>();
+ return llvm::all_of(float_or_int.getValues<APFloat>(), [&](APFloat value) {
+ return IsSign(value, sgn_splat_value);
+ });
+ }
+ if (mlir::isa<DenseIntElementsAttr>(float_or_int)) {
+ auto sgn_splat_value = sgn_splat.getSplatValue<APInt>();
+ return llvm::all_of(float_or_int.getValues<APInt>(), [&](APInt value) {
+ return IsSign(value, sgn_splat_value);
+ });
+ }
+ return false;
+}
+
+bool SameTypeOrDefaultCompare(mhlo::ComparisonTypeAttr comparison_type_attr,
+ ElementsAttr cst) {
+ if (!comparison_type_attr) return true;
+ auto comparison_type_attr_value = comparison_type_attr.getValue();
+ if (comparison_type_attr_value == mhlo::ComparisonType::FLOAT &&
+ IsDenseSplatFloatAttr(cst)) {
+ return true;
+ }
+ if ((comparison_type_attr_value == mhlo::ComparisonType::SIGNED ||
+ comparison_type_attr_value == mhlo::ComparisonType::UNSIGNED) &&
+ IsDenseSplatIntAttr(cst)) {
+ return true;
+ }
+ return false;
+}
+
+bool ValueIsReciprocal(ElementsAttr float_or_int, ElementsAttr rhs) {
+  if (IsDenseSplatFloatAttr(float_or_int) &&
+      IsDenseSplatFloatAttr(rhs)) {
+ return (mlir::cast<SplatElementsAttr>(float_or_int)
+ .getSplatValue<APFloat>() *
+ mlir::cast<SplatElementsAttr>(rhs).getSplatValue<APFloat>())
+ .isExactlyValue(1.0);
+  } else if (IsDenseSplatIntAttr(float_or_int) &&
+             IsDenseSplatIntAttr(rhs)) {
+ return (mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APInt>() *
+ mlir::cast<SplatElementsAttr>(rhs).getSplatValue<APInt>()) == 1;
+ }
+ return false;
+}
+
+bool ValueGreaterThanZero(ElementsAttr float_or_int) {
+ if (IsDenseSplatIntAttr(float_or_int)) {
+ auto value =
+ mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APInt>();
+ return !value.isNegative() && !value.isZero();
+ } else if (IsDenseSplatFloatAttr(float_or_int)) {
+ auto value =
+ mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APFloat>();
+ return !value.isNaN() && !value.isNegative() && !value.isZero();
+ }
+ return false;
+}
+
#define GEN_PASS_DEF_LEGALIZEHLOTOTFLITEPASS
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc"
+bool SupportedComparisonType(mhlo::ComparisonTypeAttr comp_type) {
+ if (!comp_type) return true;
+ auto c_ty = comp_type.getValue();
+ return c_ty == mhlo::ComparisonType::FLOAT ||
+ c_ty == mhlo::ComparisonType::SIGNED ||
+ c_ty == mhlo::ComparisonType::UNSIGNED ||
+ c_ty == mhlo::ComparisonType::NOTYPE;
+}
+
class LegalizeHloToTfLitePass
: public impl::LegalizeHloToTfLitePassBase<LegalizeHloToTfLitePass> {
public:
@@ -62,10 +197,43 @@
return !op.getType().getElementType().isF32();
}
+bool IsNotOpLegal(mhlo::NotOp op) {
+ return op.getType().getElementType().isInteger(64);
+}
+
+// Mark possible target ops from rounding patterns as having "unknown"
+// legality. This is required to schedule patterns on these ops even
+// though MhloDialect is explicitly marked legal (which cannot be changed
+// easily).
+void AddRoundingOpsAsUnknown(ConversionTarget& target) {
+ target.addDynamicallyLegalOp<
+ mhlo::FloorOp, mhlo::SubtractOp, mhlo::AndOp, mhlo::SelectOp, mhlo::RemOp,
+ mhlo::AddOp, mhlo::SignOp, mhlo::MulOp, mhlo::DivOp, mhlo::OrOp,
+ mhlo::BroadcastInDimOp, mhlo::ConstantOp, mhlo::RoundOp, mhlo::TupleOp>(
+ [](Operation* op) { return std::nullopt; });
+}
+
+bool IsCompareLegal(mhlo::CompareOp op) {
+ return !SupportedComparisonType(op.getCompareTypeAttr());
+}
+
+void SetUnaryOpLegal(ConversionTarget& target) {
+ auto is_legal = [](Operation* op) {
+ return !llvm::cast<ShapedType>(op->getOperand(0).getType())
+ .getElementType()
+ .isIntOrFloat();
+ };
+ target.addDynamicallyLegalOp<
+ mhlo::AbsOp, mhlo::BitcastConvertOp, mhlo::CeilOp, mhlo::IsFiniteOp,
+ mhlo::CosineOp, mhlo::ExpOp, mhlo::Expm1Op, mhlo::FloorOp, mhlo::ImagOp,
+ mhlo::LogOp, mhlo::NegOp, mhlo::RealOp, mhlo::Log1pOp, mhlo::RsqrtOp,
+ mhlo::SineOp, mhlo::LogisticOp, mhlo::SignOp, mhlo::SqrtOp, mhlo::TanhOp,
+ mhlo::ConvertOp>(is_legal);
+}
+
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_tflite_legalize_hlo.inc"
void LegalizeHloToTfLitePass::runOnOperation() {
MLIRContext* context = &getContext();
-
RewritePatternSet patterns(context);
patterns.add<odml::ConvertCustomCallOp, odml::LowerDotGeneralOp>(context);
populateWithGenerated(patterns);
@@ -73,14 +241,25 @@
ConversionTarget target(*context);
target.addLegalDialect<TFL::TensorFlowLiteDialect, mhlo::MhloDialect>();
target.addLegalOp<func::CallOp, func::ConstantOp, arith::ConstantOp>();
+
target.addDynamicallyLegalOp<mhlo::CustomCallOp>(IsCustomCallLegal);
target.addDynamicallyLegalOp<mhlo::CbrtOp>(IsCbrtLegal);
- target.addIllegalOp<mhlo::DotGeneralOp, mhlo::DotOp, mhlo::TransposeOp>();
+ target.addIllegalOp<mhlo::DotGeneralOp, mhlo::DotOp, mhlo::TransposeOp,
+ mhlo::ShiftRightArithmeticOp, mhlo::ShiftRightLogicalOp,
+ mhlo::RemOp, mhlo::ReshapeOp, mhlo::DynamicReshapeOp>();
+ target.addDynamicallyLegalOp<mhlo::NotOp>(IsNotOpLegal);
+ target.addDynamicallyLegalOp<mhlo::CompareOp>(IsCompareLegal);
+
+ AddRoundingOpsAsUnknown(target);
+ SetUnaryOpLegal(target);
PopulatePadPatterns(context, patterns, target);
PopulateReducePatterns(context, patterns, target);
+ PopulateLegalizeReduceWindowPatterns(context, patterns, target);
PopulateGatherPatterns(context, patterns, target);
PopulateLegalizeConvPatterns(context, patterns, target);
+ PopulateLegalizeSlicePatterns(context, patterns, target);
+ PopulateSortPatterns(context, patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td
index 671115d..7d41cbe 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td
@@ -17,10 +17,15 @@
include "mlir/Dialect/Func/IR/FuncOps.td"
include "mhlo/IR/hlo_ops.td"
include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
+
+def ShapeToConst : NativeCodeCall<"ShapeToConst($_builder, $0)">;
+
def CreateTFLCastToInt32Op : NativeCodeCall<
"CreateCastToInt32($0, $_loc, $_builder)">;
@@ -28,6 +33,95 @@
(TFL_TransposeOp $arg,
(CreateTFLCastToInt32Op (TFL_ConstOp $perm)))>;
+
+def : Pat<(MHLO_ShiftRightArithmeticOp $l, $r), (TFL_RightShiftOp $l, $r)>;
+def : Pat<(MHLO_ShiftRightLogicalOp $l, $r), (TFL_RightShiftOp $l, $r)>;
+def : Pat<(MHLO_RemOp $l, $r), (TFL_FloorModOp $l, $r)>;
+
+def LegalizeReshape : Pat<(MHLO_ReshapeOp:$output $input),
+ (TFL_ReshapeOp $input,
+ (CreateTFLCastToInt32Op (ShapeToConst $output)))>;
+
+def LegalizeDynamicReshape : Pat<(MHLO_DynamicReshapeOp $input, $shape),
+ (TFL_ReshapeOp $input, (CreateTFLCastToInt32Op $shape))>;
+
+//===----------------------------------------------------------------------===//
+// logical and bitwise ops
+//===----------------------------------------------------------------------===//
+
+class GetRankedScalarAttr<string prefix, int width, string signed, string value> :
+ NativeCodeCall<"DenseElementsAttr::get<" # prefix # "int" # width # "_t>("
+ "RankedTensorType::get({}, $_builder.getIntegerType("
+ # width # signed # "))," # value # ")">;
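+
+// For instance (illustrative), GetRankedScalarAttr<"", 8, "", "-1"> expands to
+//   DenseElementsAttr::get<int8_t>(
+//       RankedTensorType::get({}, $_builder.getIntegerType(8)), -1).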
+
+def : Pat<(MHLO_NotOp I1Tensor:$input), (TFL_LogicalNotOp $input)>;
+
+// TFL does not support bitwise negation. not(x) is equivalent to xor(x, y) if
+// y has a 1 in every bit position (xor(1, 1) = 0 and xor(0, 1) = 1).
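+// For example (illustrative), with 8-bit x = 0b00001010:
+//   xor(0b00001010, 0b11111111) = 0b11110101 = not(x).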
+
+// Signed: The 2s complement of -1 has a 1 in every bit position.
+def : Pat<(MHLO_NotOp I8Tensor:$input),
+ (TFL_BitwiseXorOp $input,
+ (Arith_ConstantOp
+ (GetRankedScalarAttr<"", 8, "", "-1">)))>;
+
+def : Pat<(MHLO_NotOp I16Tensor:$input),
+ (TFL_BitwiseXorOp $input,
+ (Arith_ConstantOp
+ (GetRankedScalarAttr<"", 16, "", "-1">)))>;
+
+def : Pat<(MHLO_NotOp I32Tensor:$input),
+ (TFL_BitwiseXorOp $input,
+ (Arith_ConstantOp
+ (GetRankedScalarAttr<"", 32, "", "-1">)))>;
+
+
+// Unsigned: 0xFFF... has a 1 in every bit position.
+def : Pat<(MHLO_NotOp TensorOf<[UI8]>:$input),
+ (TFL_BitwiseXorOp $input,
+ (Arith_ConstantOp
+ (GetRankedScalarAttr<"u", 8, ", false", "0xFFU">)))>;
+
+def : Pat<(MHLO_NotOp TensorOf<[UI16]>:$input),
+ (TFL_BitwiseXorOp $input,
+ (Arith_ConstantOp
+ (GetRankedScalarAttr<"u", 16, ", false", "0xFFFFU">)))>;
+
+def : Pat<(MHLO_NotOp TensorOf<[UI32]>:$input),
+ (TFL_BitwiseXorOp $input,
+ (Arith_ConstantOp
+ (GetRankedScalarAttr<"u", 32, ", false", "0xFFFFFFFFUL">)))>;
+
+//===----------------------------------------------------------------------===//
+// comparison ops
+//===----------------------------------------------------------------------===//
+
+// Check implicit bool cast of `$_self` to ensure Attribute is non-null before
+// casting.
+def HasSupportedComparisonType : AttrConstraint<
+ CPred<"!$_self || SupportedComparisonType($_self.cast<mhlo::ComparisonTypeAttr>())">>;
+
+class MHLO_ComparisonDirectionValue<string enumStr> :
+ ConstantAttr<MHLO_ComparisonDirectionAttr,
+ "::mlir::mhlo::ComparisonDirection::" # enumStr>;
+
+foreach p = [
+ [TFL_EqualOp, MHLO_ComparisonDirectionValue<"EQ">],
+ [TFL_NotEqualOp, MHLO_ComparisonDirectionValue<"NE">],
+ [TFL_GreaterEqualOp, MHLO_ComparisonDirectionValue<"GE">],
+ [TFL_LessEqualOp, MHLO_ComparisonDirectionValue<"LE">],
+ [TFL_GreaterOp, MHLO_ComparisonDirectionValue<"GT">],
+ [TFL_LessOp, MHLO_ComparisonDirectionValue<"LT">]]
+in {
+ def : Pat<
+ (MHLO_CompareOp $l, $r, p[1], HasSupportedComparisonType),
+ (p[0] $l, $r)>;
+}
+
+//===----------------------------------------------------------------------===//
+// unary element-wise op
+//===----------------------------------------------------------------------===//
+
def LowerCbrt : Pat<(MHLO_CbrtOp $opr),
(TFL_PowOp $opr,
(TFL_DivOp
@@ -35,3 +129,429 @@
(Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
TFL_AF_None)),
[(F32Tensor $opr)]>;
+
+
+foreach pair = [
+ [MHLO_AbsOp, TFL_AbsOp],
+ [MHLO_BitcastConvertOp, TFL_BitcastOp],
+ [MHLO_CeilOp, TFL_CeilOp],
+ [MHLO_CosineOp, TFL_CosOp],
+ [MHLO_ExpOp, TFL_ExpOp],
+ [MHLO_FloorOp, TFL_FloorOp],
+ [MHLO_ImagOp, TFL_ImagOp],
+ [MHLO_LogOp, TFL_LogOp],
+ [MHLO_LogisticOp, TFL_LogisticOp],
+ [MHLO_NegOp, TFL_NegOp],
+ [MHLO_RealOp, TFL_RealOp],
+ [MHLO_RsqrtOp, TFL_RsqrtOp],
+ [MHLO_SineOp, TFL_SinOp],
+ [MHLO_SignOp, TFL_SignOp],
+ [MHLO_SqrtOp, TFL_SqrtOp],
+ [MHLO_TanhOp, TFL_TanhOp]
+] in {
+ def : Pat<
+ (pair[0] $input),
+ (pair[1] $input)>;
+}
+
+def : Pat<
+ (MHLO_ConvertOp $input),
+ (TFL_CastOp $input)>;
+
+def : Pat<
+ (MHLO_Expm1Op F32Tensor:$x),
+ (TFL_SubOp
+ (TFL_ExpOp $x),
+ (Arith_ConstantOp
+ ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
+ TFL_AF_None)>;
+
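+// is_finite(x) is checked via x - x == 0: for finite x the difference is
+// exactly 0.0, while for +/-inf or NaN it is NaN, and NaN == 0.0 is false, so
+// the comparison below is true exactly when x is finite.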
+def : Pat<
+ (MHLO_IsFiniteOp F32Tensor:$x),
+ (TFL_EqualOp
+ (TFL_SubOp $x, $x, TFL_AF_None),
+ (Arith_ConstantOp
+ ConstantAttr<RankedF32ElementsAttr<[]>, "0.0f">))>;
+
+def : Pat<
+ (MHLO_Log1pOp F32Tensor:$x),
+ (TFL_LogOp
+ (TFL_AddOp
+ $x,
+ (Arith_ConstantOp
+ ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
+ TFL_AF_None))>;
+
+//===----------------------------------------------------------------------===//
+// rounding
+//===----------------------------------------------------------------------===//
+
+class ValueEquals<string val> :
+ Constraint<CPred<"ValueEquals($0, " # val # ")">>;
+
+def SameValue :
+ Constraint<CPred<"$0 == $1">>;
+
+def FloatOrDefaultCompare :
+ Constraint<CPred<"!$0 || $0.getValue() == ::mlir::mhlo::ComparisonType::FLOAT">>;
+
+def SameTypeOrDefaultCompare :
+ Constraint<CPred<"SameTypeOrDefaultCompare($0, $1)">>;
+
+def ValueIsReciprocal :
+ Constraint<CPred<"ValueIsReciprocal($0, $1)">>;
+
+def TensorIsSign :
+ Constraint<CPred<"TensorIsSign($_builder, $0, $1)">>;
+
+def ValueGreaterThanZero :
+ Constraint<CPred<"ValueGreaterThanZero($0)">>;
+
+
+// Converts a dag of HLOs representing banker's rounding (round x.5 to nearest
+// even) to tfl.round. This only supports float types because mhlo.floor only
+// supports float types. tf.round with integer input type will become an
+// identity op, so we will never face an mhlo.floor with an integer input type.
+// The pattern matched executes the following computation:
+// frac = x - floor(x)
+// to_even = (floor(x) - 2 * floor(0.5 * x)) == 1
+// if frac > 0.5 || (frac == 0.5 && to_even)
+// return floor(x) + 1
+// else
+// return floor(x)
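+//
+// Worked example: for x = 2.5, floor(x) = 2, frac = 0.5 and
+// to_even = (2 - 2 * floor(1.25)) == 1 is false, so the result is 2; for
+// x = 3.5, floor(x) = 3, frac = 0.5 and to_even = (3 - 2 * floor(1.75)) == 1
+// is true, so the result is 3 + 1 = 4, i.e. halves round to the nearest even.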
+def Round : Pat<(MHLO_SelectOp
+ (MHLO_OrOp
+ (MHLO_CompareOp (MHLO_SubtractOp:$frac
+ $input,
+ (MHLO_FloorOp:$floor $input)),
+ (MHLO_ConstantOp $half),
+ MHLO_ComparisonDirectionValue<"GT">,
+ $compare_type0),
+ (MHLO_AndOp
+ (MHLO_CompareOp
+ $frac1,
+ (MHLO_ConstantOp $half1),
+ MHLO_ComparisonDirectionValue<"EQ">,
+ $compare_type1),
+ (MHLO_CompareOp
+ (MHLO_SubtractOp
+ $floor1,
+ (MHLO_MulOp
+ (MHLO_FloorOp (MHLO_MulOp $input, (MHLO_ConstantOp $half2))),
+ (MHLO_ConstantOp $two))),
+ (MHLO_ConstantOp $one1),
+ MHLO_ComparisonDirectionValue<"EQ">,
+ $compare_type2))),
+ (MHLO_AddOp $floor2, (MHLO_ConstantOp $one)),
+ $floor3),
+ (TFL_RoundOp $input),
+ [(ValueEquals<"1.0"> $one),
+ (ValueEquals<"1.0"> $one1),
+ (ValueEquals<"2.0"> $two),
+ (ValueEquals<"0.5"> $half),
+ (ValueEquals<"0.5"> $half1),
+ (ValueEquals<"0.5"> $half2),
+ (SameValue $floor, $floor1),
+ (SameValue $floor, $floor2),
+ (SameValue $floor, $floor3),
+ (SameValue $frac, $frac1),
+ (FloatOrDefaultCompare $compare_type0),
+ (FloatOrDefaultCompare $compare_type1),
+ (FloatOrDefaultCompare $compare_type2)]>;
+
+// Converts a dag of HLOs representing floor_mod to tfl.floor_mod.
+// The pattern matched executes the following computation:
+//
+// rem = remainder(arg0, arg1)
+// for i in 0 to len(arg1):
+//   if ((rem[i] < 0) != (arg1[i] < 0) && rem[i] != 0)
+// rem[i] += arg1[i]
+// return rem
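+//
+// For example, floor_mod(-7, 3): rem = remainder(-7, 3) = -1; since
+// (-1 < 0) != (3 < 0) and -1 != 0, the result is -1 + 3 = 2, which matches
+// floor semantics: -7 - floor(-7 / 3) * 3 = -7 - (-3) * 3 = 2.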
+def : Pat<(MHLO_SelectOp
+ (MHLO_AndOp
+ (MHLO_CompareOp
+ (MHLO_CompareOp:$rltz
+ (MHLO_RemOp:$rem $arg, $arg1),
+ (MHLO_ConstantOp $cst),
+ MHLO_ComparisonDirectionValue<"LT">,
+ $compare_type),
+ (MHLO_CompareOp:$arg1ltz $arg1, (MHLO_ConstantOp $cst1), MHLO_ComparisonDirectionValue<"LT">, $compare_type1),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type2),
+ (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)),
+ (MHLO_AddOp $rem2, $arg1),
+ $rem3),
+ (TFL_FloorModOp $arg, $arg1),
+ [(ValueEquals<"0.0"> $cst),
+ (ValueEquals<"0.0"> $cst1),
+ (ValueEquals<"0.0"> $cst2),
+ (SameValue $rem, $rem1),
+ (SameValue $rem, $rem2),
+ (SameValue $rem, $rem3),
+ (SameTypeOrDefaultCompare $compare_type, $cst),
+ (SameTypeOrDefaultCompare $compare_type1, $cst1)]>;
+
+// Converts a dag of HLOs representing floor_mod with a constant to
+// tfl.floor_mod. The pattern matched executes the following computation:
+//
+// cst = value that is > 0
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg0):
+// if (rem[i] < 0 && rem[i] != 0)
+// rem[i] += cst
+// return rem
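+//
+// For example, with cst = 3: floor_mod(-7, 3) yields rem = -1 < 0, so the
+// result is -1 + 3 = 2, while floor_mod(7, 3) yields rem = 1, which is
+// returned unchanged.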
+def : Pat<(MHLO_SelectOp
+ (MHLO_AndOp
+ (MHLO_CompareOp:$rltz
+ (MHLO_RemOp:$rem $arg, (MHLO_ConstantOp $cst)),
+ (MHLO_ConstantOp $cst1),
+ MHLO_ComparisonDirectionValue<"LT">,
+ $compare_type),
+ (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)),
+ (MHLO_AddOp $rem2, (MHLO_ConstantOp $cst3)),
+ $rem3),
+ (TFL_FloorModOp $arg, (Arith_ConstantOp $cst3)),
+ [(ValueGreaterThanZero $cst),
+ (ValueEquals<"0.0"> $cst1),
+ (ValueEquals<"0.0"> $cst2),
+ (SameValue $cst, $cst3),
+ (SameValue $rem, $rem1),
+ (SameValue $rem, $rem2),
+ (SameValue $rem, $rem3),
+ (SameTypeOrDefaultCompare $compare_type, $cst1),
+ (SameTypeOrDefaultCompare $compare_type3, $cst2)]>;
+
+// Converts a dag of HLOs representing floor_div to tfl.floor_div.
+// The pattern matched executes the following computation:
+//
+// rem = remainder(arg0, arg1)
+// for i in 0 to len(arg1):
+//   div[i] = (arg0[i] - rem[i]) / arg1[i]
+//   if (rem[i] != 0 && sign(arg1[i]) != sign(rem[i]))
+//     div[i] -= 1.0
+// return round_nearest_afz(div)
+// As a dag this looks like the following:
+// round
+// |
+// -------- select
+// | | \
+// && + div
+// / | / \
+// != != div -1
+// / | / | / |
+// rem 0.0 sn sn1 - $1
+// / | | | / |
+// $0 $1 $1 rem $0 rem
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a reuse of the value.
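+//
+// For example, floor_div(-7.0, 2.0): rem = remainder(-7, 2) = -1, so
+// div = (-7 - (-1)) / 2 = -3; since rem != 0 and sign(2) != sign(-1), the
+// selected value is -3 + (-1) = -4 = floor(-7 / 2).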
+def : Pat<(MHLO_RoundOp
+ (MHLO_SelectOp
+ (MHLO_AndOp
+ (MHLO_CompareOp
+ (MHLO_RemOp:$rem $arg0, $arg1),
+ (MHLO_ConstantOp $cst),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type),
+ (MHLO_CompareOp
+ (MHLO_SignOp $arg1),
+ (MHLO_SignOp $rem1),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type1)),
+ (MHLO_AddOp
+ (MHLO_DivOp:$div
+ (MHLO_SubtractOp $arg0, $rem2),
+ $arg1b),
+ (MHLO_ConstantOp $cst_neg1)),
+ $div1)),
+ (TFL_FloorDivOp $arg0, $arg1),
+ [(ValueEquals<"0.0"> $cst),
+ (ValueEquals<"-1.0"> $cst_neg1),
+ (SameValue $div, $div1),
+ (SameValue $rem, $rem1),
+ (SameValue $rem, $rem2),
+ (FloatOrDefaultCompare $compare_type, $cst),
+ (FloatOrDefaultCompare $compare_type1, $cst)]>;
+
+// Converts a dag of HLOs representing floor_div with a splat constant to
+// tfl.floor_div. This particular pattern matches multiplication by the
+// reciprocal of the constant instead of division by the constant. The
+// pattern matched executes the following computation:
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg0):
+//   mul[i] = (arg0[i] - rem[i]) * (1 / cst)
+//   if (rem[i] != 0 && sign(cst) != sign(rem[i]))
+//     mul[i] += -1.0
+// return round_nearest_afz(mul)
+// As a dag this looks like the following:
+// round
+// |
+// -------- select
+// | | \
+// && + mul
+// / | / \
+// != != mul -1
+// / | / | / |
+// rem 0.0 cs1 sn1 - cs2
+// / | | / |
+// $0 cst rem $0 rem
+// cs1 == sign(cst)
+// cs2 = 1 / cst i.e. the reciprocal
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a reuse of the value.
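+//
+// The computation is the same as above except that the divisor is a splat
+// constant and the graph multiplies by its constant-folded reciprocal, e.g.
+// for cst = 2.0 the matched subgraph contains cs2 = 0.5 and cs1 = sign(2.0)
+// = 1.0 instead of a division by 2.0.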
+def : Pat<(MHLO_RoundOp
+ (MHLO_SelectOp
+ (MHLO_AndOp
+ (MHLO_CompareOp
+ (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)),
+ (MHLO_ConstantOp $cst_zero),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type),
+ (MHLO_CompareOp
+ (MHLO_ConstantOp $cst_sgn),
+ (MHLO_SignOp $rem1),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type1)),
+ (MHLO_AddOp
+ (MHLO_MulOp:$mul
+ (MHLO_SubtractOp $arg0, $rem2),
+ (MHLO_ConstantOp $cst_recip)),
+ (MHLO_ConstantOp $cst_neg1)),
+ $mul1)),
+ (TFL_FloorDivOp $arg0, $cst),
+ [(ValueEquals<"0.0"> $cst_zero),
+ (ValueEquals<"-1.0"> $cst_neg1),
+ (TensorIsSign $cstv, $cst_sgn),
+ (ValueIsReciprocal $cstv, $cst_recip),
+ (SameValue $mul, $mul1),
+ (SameValue $rem, $rem1),
+ (SameValue $rem, $rem2),
+ (FloatOrDefaultCompare $compare_type, $cst_zero),
+ (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>;
+
+// Converts a dag of HLOs representing floor_div with a splat constant to
+// tfl.floor_div. This particular pattern matches division by the constant
+// directly. The pattern matched executes the following computation:
+//
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg0):
+//   div[i] = (arg0[i] - rem[i]) / cst
+//   if (rem[i] != 0 && sign(cst) != sign(rem[i]))
+//     div[i] -= 1.0
+// return round_nearest_afz(div)
+// As a dag this looks like the following:
+// round
+// |
+// -------- select
+// | | \
+// && + div
+// / | / \
+// != != div -1
+// / | / | / |
+// rem 0.0 cs1 sn1 - cs2
+// / | | / |
+// $0 cst rem $0 rem
+// cs1 == sign(cst)
+// cs2 == cst (the same splat constant used as the divisor)
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a reuse of the value.
+def : Pat<(MHLO_RoundOp
+ (MHLO_SelectOp
+ (MHLO_AndOp
+ (MHLO_CompareOp
+ (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)),
+ (MHLO_ConstantOp $cst_zero),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type),
+ (MHLO_CompareOp
+ (MHLO_ConstantOp $cst_sgn),
+ (MHLO_SignOp $rem1),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type1)),
+ (MHLO_AddOp
+ (MHLO_DivOp:$div
+ (MHLO_SubtractOp $arg0, $rem2),
+ (MHLO_ConstantOp $cstv1)),
+ (MHLO_ConstantOp $cst_neg1)),
+ $div1)),
+ (TFL_FloorDivOp $arg0, $cst),
+ [(ValueEquals<"0.0"> $cst_zero),
+ (ValueEquals<"-1.0"> $cst_neg1),
+ (TensorIsSign $cstv, $cst_sgn),
+ (SameValue $div, $div1),
+ (SameValue $rem, $rem1),
+ (SameValue $rem, $rem2),
+ (SameValue $cstv1, $cstv),
+ (FloatOrDefaultCompare $compare_type, $cst_zero),
+ (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>;
+
+// Converts a dag of HLOs representing floor_div with a broadcasted vector
+// constant to tfl.floor_div. The pattern matched executes the following
+// computation:
+// scs = sign(cst)
+// bcst = broadcast(cst)
+// rem = remainder(arg0, bcst)
+// for i in 0 to len(arg0):
+//   div[i] = (arg0[i] - rem[i]) / bcst[i]
+//   if (rem[i] != 0 && scs != sign(rem[i]))
+//     div[i] -= 1.0
+// return round_nearest_afz(div)
+// Here scs is the constant-folded sign of the unbroadcasted constant tensor.
+//
+// As a dag this looks like the following:
+// round
+// |
+// -------- select
+// | | \
+// && + div
+// / | / \
+// != != div -1
+// / | / | / |
+// rem 0.0 scs sn1 - bcst
+// / | | / |
+// $0 bcst rem $0 rem
+// |
+// cst
+// scs == sign(cst) == sign(bcst)
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a reuse of the value.
+def : Pat<(MHLO_RoundOp
+ (MHLO_SelectOp
+ (MHLO_AndOp
+ (MHLO_CompareOp
+ (MHLO_RemOp:$rem $arg0,
+ (MHLO_BroadcastInDimOp:$bcst
+ (MHLO_ConstantOp $cstv),
+ $broadcast_dimension)),
+ (MHLO_ConstantOp $cst_zero),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type),
+ (MHLO_CompareOp
+ (MHLO_ConstantOp $cst_sgn),
+ (MHLO_SignOp $rem1),
+ MHLO_ComparisonDirectionValue<"NE">,
+ $compare_type1)),
+ (MHLO_AddOp
+ (MHLO_DivOp:$div
+ (MHLO_SubtractOp $arg0, $rem2),
+ $bcst1),
+ (MHLO_ConstantOp $cst_neg1)),
+ $div1)),
+ (TFL_FloorDivOp $arg0, $bcst),
+ [(ValueEquals<"0.0"> $cst_zero),
+ (ValueEquals<"-1.0"> $cst_neg1),
+ (TensorIsSign $cstv, $cst_sgn),
+ (SameValue $bcst, $bcst1),
+ (SameValue $div, $div1),
+ (SameValue $rem, $rem1),
+ (SameValue $rem, $rem2),
+ (FloatOrDefaultCompare $compare_type, $cst_zero),
+ (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>;
+
diff --git a/tensorflow/compiler/mlir/lite/stateful_error_reporter.h b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h
new file mode 100644
index 0000000..fbb82d3
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h
@@ -0,0 +1,36 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_
+
+// LINT.IfChange
+#include <string>
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+namespace tflite_migration {
+
+// Similar to tflite::ErrorReporter, except that it allows callers to get the
+// last error message.
+class StatefulErrorReporter : public tflite::ErrorReporter {
+ public:
+ // Returns last error message. Returns empty string if no error is reported.
+ virtual std::string message() = 0;
+};
+
+} // namespace tflite_migration
+// LINT.ThenChange(//tensorflow/lite/stateful_error_reporter.h)
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD
index 6cc9a62..4e7fa53 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD
@@ -56,9 +56,9 @@
"importer_test_min_max.cc",
],
deps = [
+ "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/lite/schema:schema_utils",
- "//tensorflow/lite:framework",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
],
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc
index 30890fb..6231088 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc
@@ -16,17 +16,15 @@
#include <iostream>
#include <memory>
#include <optional>
-#include <system_error>
-#include "absl/strings/string_view.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
-#include "tensorflow/lite/model.h"
using llvm::cl::opt;
@@ -52,7 +50,7 @@
namespace {
std::optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
llvm::StringRef buffer) {
- auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
+ auto model_ptr = TFL::FlatBufferModelAbslError::VerifyAndBuildFromBuffer(
buffer.data(), buffer.size());
if (nullptr == model_ptr) {
return std::nullopt;
diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
index 853606c..02c4ac4 100644
--- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -107,7 +107,6 @@
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningOpRef;
-using ::stablehlo::quantization::QuantizationConfig;
using ::tensorflow::quantization::PyFunctionLibrary;
bool IsControlFlowV1Op(Operation* op) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index d3830d3..2b5b753 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -81,25 +81,6 @@
namespace {
-// TODO(b/355062942): This a temporary solution to unblock LLVM intergration.
-// https://github.com/llvm/llvm-project/commit/bbd4af5da2b741672a8e6f625eb12ea5c2d6220f
-// changed the behavior of `applySignatureConversion`. Before, an op adaptor
-// would have the new block arguments directly as operands. Now, there is an
-// `UnrealizedConversionCastOp` inserts from the new type to the old type.
-// The new behaviour is correct, but passes in this file depended on the old
-// bahavior and worked by coincidence.
-llvm::SmallVector<Value, 4> GetOperandsAndSkipUnrealizedConversionCasts(
- ValueRange operands) {
- llvm::SmallVector<Value, 4> result;
- for (Value operand : operands) {
- if (auto cast = operand.getDefiningOp<UnrealizedConversionCastOp>()) {
- operand = cast.getInputs().front();
- }
- result.push_back(operand);
- }
- return result;
-}
-
/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
: public impl::LowerStaticTensorListPassBase<LowerStaticTensorListPass> {
@@ -371,9 +352,7 @@
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
-
+ auto operands = adaptor.getOperands();
Value input = operands[0];
Value index = operands[1];
Value item = operands[2];
@@ -433,8 +412,7 @@
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+ auto operands = adaptor.getOperands();
Value input = operands[0];
Value index = operands[1];
Value item = operands[2];
@@ -721,8 +699,7 @@
LogicalResult matchAndRewrite(
TF::TensorListPushBackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+ auto operands = adaptor.getOperands();
Value input_handle = operands[0];
Value item = operands[1];
@@ -764,8 +741,7 @@
LogicalResult matchAndRewrite(
TF::TensorListResizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+ auto operands = adaptor.getOperands();
Value input_handle = operands[0];
Value size = operands[1];
@@ -929,9 +905,7 @@
LogicalResult matchAndRewrite(
TF::TensorListGetItemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
-
+ auto operands = adaptor.getOperands();
Value input = operands[0];
Value index = operands[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
@@ -948,8 +922,7 @@
TF::TensorListLengthOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value input_handle =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands())[0];
+ Value input_handle = adaptor.getOperands()[0];
BoolAttr true_attr = rewriter.getBoolAttr(true);
auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
@@ -970,8 +943,7 @@
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+ auto operands = adaptor.getOperands();
Value input = operands[0];
Value element_shape = operands[1];
@@ -1021,8 +993,7 @@
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+ auto operands = adaptor.getOperands();
Value input = operands[0];
Value element_shape = operands[1];
@@ -1084,8 +1055,7 @@
LogicalResult matchAndRewrite(
TF::IdentityOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Value input =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands())[0];
+ Value input = adaptor.getOperands()[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), input,
op->getAttrs());
return success();
@@ -1098,9 +1068,7 @@
LogicalResult matchAndRewrite(
func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
-
+ auto operands = adaptor.getOperands();
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, ValueRange{}, operands,
op->getAttrs());
return success();
@@ -1113,8 +1081,7 @@
LogicalResult matchAndRewrite(
TF::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto operands =
- GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+ auto operands = adaptor.getOperands();
rewriter.replaceOpWithNewOp<TF::YieldOp>(op, operands);
return success();
}
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
index 6f4c2dd..1bbf673 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
@@ -129,7 +129,7 @@
void AddStablehloQuantToIntPasses(OpPassManager& pm) {
pm.addNestedPass<func::FuncOp>(
- mlir::stablehlo::createStablehloLegalizeQuantToIntPass());
+ mlir::stablehlo::createStablehloLegalizeQuantToMathPass());
// StableHLO -> MHLO legalization.
pm.addPass(mhlo::createStablehloLegalizeToHloPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc
index 318ccbd..e71af87 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc
@@ -27,9 +27,9 @@
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
#include "tensorflow/compiler/mlir/register_common_dialects.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
namespace mlir::quant::stablehlo {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc
index 8cbd48d..2e9d7a8 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc
@@ -25,12 +25,12 @@
#include "xla/shape.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
-#include "tsl/lib/core/status_test_util.h"
namespace mlir::quant::stablehlo {
namespace {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc
index e96ab83..d8878ff 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc
@@ -31,7 +31,7 @@
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
pm.addPass(mhlo::createHloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
- mlir::stablehlo::createStablehloLegalizeQuantToIntPass());
+ mlir::stablehlo::createStablehloLegalizeQuantToMathPass());
pm.addPass(mhlo::createStablehloLegalizeToHloPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::func::FuncOp>(CreateVerifyQuantLegalizationPass());
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD
index b465fe1..80167b4 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD
@@ -51,11 +51,11 @@
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@stablehlo//:stablehlo_ops",
],
)
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc
index c3034f4..be49ddb 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc
@@ -32,7 +32,7 @@
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/path.h"
#include "tsl/platform/test.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 7b90a43..cbde974 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1137,7 +1137,7 @@
DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$low_priority_allowed_batch_sizes,
DefaultValuedOptionalAttr<I64Attr, "0">:$low_priority_max_enqueued_batches,
DefaultValuedOptionalAttr<TF_AnyStrAttrOf<["low_priority_padding_with_max_batch_size", "low_priority_padding_with_next_allowed_batch_size", "priority_isolation"]>, "\"low_priority_padding_with_max_batch_size\"">:$mixed_priority_policy,
- DefaultValuedOptionalAttr<TF_AnyStrAttrOf<["PAD_UP"]>, "\"PAD_UP\"">:$batch_padding_policy,
+ DefaultValuedOptionalAttr<TF_AnyStrAttrOf<["PAD_UP", "BATCH_DOWN", "MINIMIZE_TPU_COST_PER_REQUEST"]>, "\"PAD_UP\"">:$batch_padding_policy,
DefaultValuedOptionalAttr<BoolAttr, "false">:$enable_large_batch_splitting
);
@@ -11293,6 +11293,7 @@
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_inter_op_parallelism,
DefaultValuedOptionalAttr<StrAttr, "\"default\"">:$deterministic,
DefaultValuedOptionalAttr<BoolAttr, "false">:$preserve_cardinality,
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$use_unbounded_threadpool,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$metadata
);
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
index 4e82022..a458a20 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
@@ -424,7 +424,7 @@
}
// CHECK: func private @f_callee(%[[ARG0:.*]]: tensor<0xf32>) -> tensor<0xf32>
- // CHECK-SAME: tf._input_shapes = [#tf_type.shape<00>]
+ // CHECK-SAME: tf._input_shapes = [#tf_type.shape<0>]
func.func private @f_callee(%arg0: tensor<0xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<0xf32> attributes {tf._input_shapes = [#tf_type.shape<0>, #tf_type.shape<>]} {
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor<0xf32>
%1 = "tf.AddV2"(%arg0, %0) : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
index 1789b1d..4a31291 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
@@ -289,7 +289,7 @@
%1 = "tf.AddV2"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
func.return %1 : tensor<f32>
}
- // CHECK: func.func private @f_callee(%arg0: tensor<f32>) -> tensor<f32> attributes {tf._input_shapes = [#tf_type.shape<00>]} {
+ // CHECK: func.func private @f_callee(%arg0: tensor<f32>) -> tensor<f32> attributes {tf._input_shapes = [#tf_type.shape<0>]} {
// CHECK: %cst = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %0 = "tf.AddV2"(%arg0, %cst) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %0 : tensor<f32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
index a7ee2b1..f400424 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
@@ -829,6 +829,7 @@
"@local_xla//xla/service:shape_inference",
"@local_xla//xla/translate/hlo_to_mhlo:hlo_utils",
"@local_xla//xla/translate/mhlo_to_hlo:type_to_shape",
+ "@local_xla//xla/tsl/util:env_var",
],
)
@@ -1094,3 +1095,14 @@
"@local_tsl//tsl/platform:path",
],
)
+
+tf_cc_test(
+ name = "shape_inference_test",
+ srcs = ["shape_inference_test.cc"],
+ deps = [
+ ":shape_inference_pass",
+ "@com_google_googletest//:gtest_main",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:env",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
index 54fb7cf..77f9361 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
@@ -36,13 +36,13 @@
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "xla/tsl/framework/device_type.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/debug_data_dumper.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace tfrt_compiler {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 4bcd7e0..d9110a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -87,6 +87,7 @@
#include "xla/shape.h"
#include "xla/translate/hlo_to_mhlo/hlo_utils.h"
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"
+#include "xla/tsl/util/env_var.h"
#include "xla/window_util.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -110,6 +111,20 @@
namespace TF {
namespace {
+MLIRContext::Threading GetMlirContextThreading() {
+ bool enable_single_thread_mlir_context = []() {
+ bool result = false;
+ if (auto status = tsl::ReadBoolFromEnvVar(kMLIRContextSingleThreadVar,
+ /*default_val=*/false, &result);
+ status.ok()) {
+ return result;
+ }
+ return false;
+ }();
+ return enable_single_thread_mlir_context ? MLIRContext::Threading::DISABLED
+ : MLIRContext::Threading::ENABLED;
+}
+
// Compute a refined type between two types `lhs` and `rhs`, the result type
// is always more refined (i.e. has more static information) than `lhs`
// This method will actually merge the information contained in the
@@ -443,6 +458,11 @@
}
} // namespace
+// Create a MLIRContext based on the threading setup in the env var.
+std::unique_ptr<MLIRContext> MakeMLIRContextWithThreading() {
+ return std::make_unique<MLIRContext>(GetMlirContextThreading());
+}
+
// Returns whether type can be further refined.
bool CanBeRefined(Type type) {
auto shape_type = mlir::dyn_cast<ShapedType>(type);
@@ -1024,7 +1044,7 @@
// each `XlaCallModule` op. Uses its own MLIRContext since the loader needs to
// load additional dialects, which is not allowed for the main context since
// shape inference may be called from a pass.
- MLIRContext xla_call_module_context_;
+ std::unique_ptr<MLIRContext> xla_call_module_context_;
DenseMap<XlaCallModuleOp, std::unique_ptr<tensorflow::XlaCallModuleLoader>>
xla_call_module_loaders_;
};
@@ -1036,6 +1056,7 @@
symbol_users_(symbol_table_, module),
graph_version_(graph_version),
propagate_caller_callee_constants_(propagate_caller_callee_constants) {
+ xla_call_module_context_ = MakeMLIRContextWithThreading();
for (const auto& op_type : ops_to_skip) {
ops_to_skip_.insert(op_type);
}
@@ -1242,10 +1263,10 @@
mlir::DialectRegistry registry;
registry.insert<mlir::func::FuncDialect>();
mlir::func::registerAllExtensions(registry);
- xla_call_module_context_.appendDialectRegistry(registry);
+ xla_call_module_context_->appendDialectRegistry(registry);
auto l = tensorflow::XlaCallModuleLoader::Create(
- &xla_call_module_context_, op.getVersion(), op.getModule().str(),
+ xla_call_module_context_.get(), op.getVersion(), op.getModule().str(),
std::move(disabled_checks), std::move(platforms),
/*num_invocation_args=*/op.getArgs().size(),
op.getHasTokenInputOutput());
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
index 46c1bc9..9075754 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
@@ -17,6 +17,7 @@
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SHAPE_INFERENCE_H_
#include <cstdint>
+#include <memory>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@@ -31,6 +32,9 @@
namespace mlir {
namespace TF {
+inline constexpr char kMLIRContextSingleThreadVar[] =
+ "TF_USE_SINGLE_THREAD_MLIR_CONTEXT";
+
// Returns whether type can be further refined.
bool CanBeRefined(Type type);
@@ -71,6 +75,9 @@
int64_t max_iterations = 10,
ArrayRef<TypeID> ops_to_skip = {});
+// Create a MLIRContext based on the threading setup in the env var.
+std::unique_ptr<MLIRContext> MakeMLIRContextWithThreading();
+
} // namespace TF
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc
new file mode 100644
index 0000000..416807c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc
@@ -0,0 +1,39 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+
+namespace mlir {
+namespace TF {
+namespace {
+
+TEST(ShapeInferenceTest, CreateMultiThreadedMLIRContext) {
+ std::unique_ptr<MLIRContext> ctx = MakeMLIRContextWithThreading();
+ EXPECT_TRUE(ctx->isMultithreadingEnabled());
+}
+
+TEST(ShapeInferenceTest, CreateSingleThreadedMLIRContext) {
+ setenv(kMLIRContextSingleThreadVar, "true", 1);
+ std::unique_ptr<MLIRContext> ctx = MakeMLIRContextWithThreading();
+ EXPECT_FALSE(ctx->isMultithreadingEnabled());
+}
+
+} // namespace
+} // namespace TF
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc
index 9262f87..3720a09 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc
@@ -22,8 +22,8 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
namespace mlir::TF {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
index 20f771d..4763746 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
@@ -97,12 +97,12 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla:shape_util",
"@local_xla//xla/client:xla_builder",
"@local_xla//xla/client:xla_computation",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -179,7 +179,6 @@
"//tensorflow/core/tpu/kernels:tpu_compile_op_support",
"//tensorflow/core/tpu/kernels/xla:host_compute_ops",
"@com_google_googletest//:gtest",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/lib/monitoring:test_utils",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
@@ -187,6 +186,7 @@
"@local_xla//xla/client:client_library",
"@local_xla//xla/stream_executor:platform_manager",
"@local_xla//xla/translate/mhlo_to_hlo:type_to_shape",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -247,8 +247,8 @@
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -299,7 +299,7 @@
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
index 916b568..bce19ea 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
@@ -28,9 +28,9 @@
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/register_common_dialects.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc
index 71640e8..57769d2 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc
@@ -39,6 +39,7 @@
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/shape_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
@@ -49,7 +50,6 @@
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc
index 06208be..c3598e6 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc
@@ -29,11 +29,11 @@
#include "xla/shape.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/monitoring/test_utils.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
index cad1edf..1da6d58c 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
@@ -27,9 +27,9 @@
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/register_common_dialects.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
index 74219b2..49b12fa 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
@@ -168,8 +168,8 @@
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -227,7 +227,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
index 233b112..20da430 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
@@ -33,9 +33,9 @@
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
index 3f58643..5830676 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
@@ -29,6 +29,7 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "xla/client/client_library.h"
#include "xla/stream_executor/platform_manager.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/lib/monitoring/test_utils.h"
#include "tensorflow/core/platform/env.h"
@@ -37,7 +38,6 @@
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/util/debug_data_dumper.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/monitoring/test_utils.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc
index 4e5199b..22d36ea 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc
@@ -32,9 +32,9 @@
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/register_common_dialects.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace tf2xla {
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD
index ea2bf17..e635213 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD
@@ -244,8 +244,8 @@
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -296,7 +296,7 @@
"//tensorflow/core:testlib",
"//tensorflow/core/platform:enable_tf2_utils",
# "//tensorflow/core/platform:resource_loader",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc
index 840d4c9..3365a85 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc
@@ -29,10 +29,10 @@
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/register_common_dialects.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/file_statistics.h"
#include "tsl/platform/status.h"
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc
index 6cbc67d..78d027a 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc
@@ -30,6 +30,7 @@
#include "tensorflow/cc/ops/tpu_functional_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
@@ -38,7 +39,6 @@
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
index 6ef95e7..8be5db9 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
@@ -129,8 +129,8 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Pass",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:statusor",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -390,7 +390,6 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
@@ -399,6 +398,7 @@
"@local_xla//xla/client:xla_builder",
"@local_xla//xla/client:xla_computation",
"@local_xla//xla/mlir_hlo",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -507,8 +507,8 @@
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
index 9021095..a17ef43 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
@@ -34,7 +34,7 @@
#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc
index aecf9db..c9d0680 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc
@@ -42,9 +42,9 @@
#include "xla/client/xla_computation.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/shape_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
index dbb7773..730d096 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
@@ -11,6 +11,7 @@
//
// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>)
// CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 1 "
+// CHECK-SAME: device_assignment = []
// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
// CHECK: return
@@ -38,7 +39,8 @@
// CHECK: return
//
// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>)
-// CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true "
+// CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 use_spmd_for_xla_partitioning: true "
+// CHECK-SAME: device_assignment = [0, 0, 0, 0, 0, 0, 0, 1]
// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
// CHECK: return
@@ -70,6 +72,7 @@
// CHECK: return
//
// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32>
+// CHECK-SAME: device_assignment = [0, 0, 0, 0, 0, 0, 0, 1]
// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
// CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1)
// CHECK: return
@@ -102,6 +105,7 @@
// CHECK: return
//
// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32>
+// CHECK-SAME: device_assignment = []
// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
// CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1)
// CHECK: return
diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir
new file mode 100644
index 0000000..02afa96
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir
@@ -0,0 +1,8 @@
+// RUN: tf-tfrt-opt %s -tf-device-cleanup | FileCheck %s
+
+// CHECK-LABEL: func @ops_with_device
+func.func @ops_with_device() {
+ %0 = "tf.VarHandleOp"() {container = "", shared_name = "var", device = "/device/..."} : () -> tensor<!tf_type.resource<tensor<1xf32>>>
+ // CHECK-NOT: device = "/device/..."
+ func.return
+}
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD
index 80969fe..2ec0fdd 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD
@@ -69,6 +69,7 @@
"lower_to_ifrt_restore_variable.cc",
"rewrite_cluster_to_ifrt_call.cc",
"sink_variable_as_named_array.cc",
+ "tf_device_cleanup.cc",
"tf_identity_propagation.cc",
"tf_ifrt_passes.cc",
"tf_restore_merging.cc",
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
index 6da9eda..3b190d3 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
@@ -31,6 +31,7 @@
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/test_util.h"
#include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
@@ -38,7 +39,6 @@
#include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool.h"
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td
index 7cdc557..9c37c58 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td
@@ -129,3 +129,14 @@
let constructor = "CreateTfIdentityPropagationPass()";
}
+def TfDeviceCleanupPass : Pass<"tf-device-cleanup", "mlir::func::FuncOp"> {
+  let summary = "Cleans up device attributes from all ops";
+
+  let description = [{
+    This pass removes `device` attributes from all TF ops. Some serving
+    pipelines do not rely on `device` attributes from the SavedModel.
+  }];
+
+  let constructor = "CreateTfDeviceCleanupPass()";
+}
+
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
index aba37ac..2fc2c17 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
@@ -27,6 +27,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@@ -128,39 +129,9 @@
<< " is missing";
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
- std::optional<xla::DeviceAssignmentProto> xla_device_assignment;
- auto topology_attr = cluster_func->getAttrOfType<mlir::StringAttr>(
- tensorflow::kTopologyAttr);
- // Get device assignment.
- auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
- tensorflow::kDeviceAssignmentAttr);
- if (topology_attr && device_assignment_attr && !topology_attr.empty() &&
- !device_assignment_attr.empty()) {
- auto device_coordinates =
- tensorflow::GetDeviceCoordinates(device_assignment_attr);
- if (!device_coordinates.ok())
- return cluster_func.emitError()
- << "error in parsing tpu device coordinates: "
- << device_coordinates.status().message();
-
- absl::StatusOr<xla::DeviceAssignmentProto>
- xla_device_assignment_from_device_assignment_attr =
- tensorflow::GetXlaDeviceAssignmentProto(
- topology_attr.getValue(), num_replicas, num_cores_per_replica,
- *device_coordinates);
- if (!xla_device_assignment_from_device_assignment_attr.ok()) {
- return cluster_func.emitError()
- << "error in getting xla device assignment: "
- << xla_device_assignment_from_device_assignment_attr.status()
- .message();
- }
- xla_device_assignment =
- *xla_device_assignment_from_device_assignment_attr;
- }
-
return mlir::TFTPU::SetMetadataProtoFromClusterFuncOp(
cluster_func, num_replicas, num_cores_per_replica,
- std::move(xla_device_assignment), metadata);
+ /*xla_device_assignment=*/std::nullopt, metadata);
}
void Rewrite(mlir::SymbolTable &symbol_table,
@@ -194,10 +165,16 @@
auto metadata_attr =
ifrt_program->getAttrOfType<mlir::StringAttr>(kMetadataTextAttrName);
- if (!metadata_attr) {
+ auto device_assignment_attr =
+ ifrt_program->getAttrOfType<mlir::ArrayAttr>(kDeviceAssignmentAttr);
+ if (!metadata_attr || !device_assignment_attr) {
return signalPassFailure();
}
+
+ // For better debuggability, attach attributes such as
+ // tpu_compile_metadata and device_assignment to IfrtCallOp.
ifrt_call_op->setAttr(kMetadataTextAttrName, metadata_attr);
+ ifrt_call_op->setAttr(kDeviceAssignmentAttr, device_assignment_attr);
// TODO(b/304839793): populate variable names after adding a variable
// hoisting pass.
@@ -230,6 +207,13 @@
cloned_ifrt_program->setAttr(kMetadataTextAttrName,
builder.getStringAttr(serialized_metadata));
+ auto device_assignment_attr =
+ cluster_func->getAttrOfType<mlir::ArrayAttr>(kDeviceAssignmentAttr);
+ if (!device_assignment_attr) {
+ device_assignment_attr = builder.getArrayAttr({});
+ }
+ cloned_ifrt_program->setAttr(kDeviceAssignmentAttr, device_assignment_attr);
+
cloned_ifrt_program.setName(ifrt_program_name);
int64_t program_id = NewProgramId();
@@ -250,10 +234,11 @@
// hoisting pass.
ifrt_call_op.setVariableArgIndicesAttr(builder.getI32ArrayAttr({}));
ifrt_call_op.setProgramId(program_id);
- // Additionally attach tpu_compile_metadata to IfrtCallOp. Some subsequent
- // pass such as SinkVariableAsNamedArrayPass relies on this attribute.
+ // For better debuggability, attach attributes such as tpu_compile_metadata
+ // and device_assignment to IfrtCallOp.
ifrt_call_op->setAttr(kMetadataTextAttrName,
builder.getStringAttr(serialized_metadata));
+ ifrt_call_op->setAttr(kDeviceAssignmentAttr, device_assignment_attr);
cluster_func->replaceAllUsesWith(ifrt_call_op.getResults());
cluster_func->erase();
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc
index b201370..1de61ab 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc
@@ -37,9 +37,9 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc
new file mode 100644
index 0000000..b40c94e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+namespace {
+
+#define GEN_PASS_DEF_TFDEVICECLEANUPPASS
+#define GEN_PASS_DECL_TFDEVICECLEANUPPASS
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep
+
+class TfDeviceCleanupPass
+ : public impl::TfDeviceCleanupPassBase<TfDeviceCleanupPass> {
+ public:
+ void runOnOperation() override {
+ mlir::func::FuncOp func = getOperation();
+ func.walk([](mlir::Operation* op) {
+ if (llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) {
+ op->removeAttr("device");
+ }
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateTfDeviceCleanupPass() {
+ return std::make_unique<TfDeviceCleanupPass>();
+}
+
+} // namespace ifrt_serving
+} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
index 2802cb5..6d49f9a 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
@@ -81,6 +81,10 @@
pm.addPass(CreateRewriteClusterToIfrtCallPass());
+  // After the device program is extracted, we can clean up device attributes
+  // from all ops.
+ pm.addNestedPass<mlir::func::FuncOp>(CreateTfDeviceCleanupPass());
+
// Sink VarHandle with ReadVariableOp: subsequent SinkVariableAsNamedArrayPass
// rely on the co-existence of VarHandle and ReadVariable in the same
// function.
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
index 93713fb..92d9b06 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
@@ -57,6 +57,10 @@
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerToIfrtRestoreVariablePass();
+// Creates a pass that cleans up device attributes from all ops.
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateTfDeviceCleanupPass();
+
#define GEN_PASS_REGISTRATION
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD
index cb517d1..83b70c2 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD
@@ -20,11 +20,15 @@
deps = [
"//tensorflow/core/tfrt/mlrt/bytecode",
"//tensorflow/core/tfrt/mlrt/bytecode:executable",
+ "//tensorflow/core/tfrt/mlrt/bytecode:function",
+ "//tensorflow/core/tfrt/mlrt/bytecode:kernel",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
@@ -38,10 +42,15 @@
data = glob(["testdata/**"]),
deps = [
":mlir_to_bytecode",
+ "//tensorflow/core/tfrt/mlrt/bytecode",
"//tensorflow/core/tfrt/mlrt/bytecode:executable",
"//tensorflow/core/tfrt/mlrt/interpreter:attribute_span",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:resource_loader",
@@ -57,10 +66,15 @@
hdrs = ["test_utils.h"],
deps = [
# copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/stubs:tfrt_native_lowering_impl",
+ "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/tfrt/graph_executor:sync_resource_state",
"//tensorflow/core/tfrt/mlrt/attribute",
"//tensorflow/core/tfrt/mlrt/bytecode",
"//tensorflow/core/tfrt/mlrt/bytecode:kernel",
@@ -70,7 +84,9 @@
"//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub",
"//tensorflow/core/tfrt/utils:tensor_util",
"@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
"@tf_runtime//:hostcontext",
+ "@tf_runtime//:support",
"@tf_runtime//:tensor",
],
)
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc
index d3b19eb..52b1826 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc
@@ -25,14 +25,26 @@
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
+#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/Region.h" // from @llvm-project
+#include "mlir/IR/Types.h" // from @llvm-project
+#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/function.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h"
namespace mlrt {
namespace {
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h
index 7f5416d..9508656 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h
@@ -22,9 +22,16 @@
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
namespace mlrt {
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc
index 9f02f1d..d7d3065d 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc
@@ -19,9 +19,20 @@
#include <vector>
#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
+#include "mlir/IR/DialectRegistry.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h"
#include "tsl/platform/resource_loader.h"
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc
index b5a3cb9..e4f9e6f 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc
@@ -22,10 +22,19 @@
#include <utility>
#include <vector>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/context.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
namespace mlrt {
namespace testing {
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h
index d569f32..6140c71 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h
@@ -21,10 +21,15 @@
#include <utility>
#include <vector>
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h"
#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h"
@@ -34,10 +39,13 @@
#include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
#include "tfrt/host_context/execution_context.h" // from @tf_runtime
#include "tfrt/host_context/host_allocator.h" // from @tf_runtime
#include "tfrt/host_context/host_context.h" // from @tf_runtime
+#include "tfrt/support/string_util.h" // from @tf_runtime
+#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime
#include "tfrt/tensor/dense_tensor_utils.h" // from @tf_runtime
namespace mlrt {
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index fc31b8f..0972f67 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1,6 +1,7 @@
load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test")
load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+load("//tensorflow/compiler/tests:build_combined_defs.bzl", "tf_xla_combined_py_test")
load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", "tf_xla_py_strict_test")
load(
"//tensorflow/core/platform:build_config_root.bzl",
@@ -88,15 +89,20 @@
],
)
-tf_xla_py_strict_test(
- name = "adadelta_test",
+tf_xla_combined_py_test(
+ name = "ops_test_mlir_false",
size = "medium",
- srcs = ["adadelta_test.py"],
enable_mlir_bridge = False,
+ package = "tensorflow.compiler.tests",
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
+ test_files = [
+ # go/keep-sorted start
+ "adadelta_test.py",
+ # go/keep-sorted end
+ ],
deps = [
":xla_test",
"//tensorflow/python/framework:constant_op",
diff --git a/tensorflow/compiler/tests/build_combined_defs.bzl b/tensorflow/compiler/tests/build_combined_defs.bzl
new file mode 100644
index 0000000..0463fe1
--- /dev/null
+++ b/tensorflow/compiler/tests/build_combined_defs.bzl
@@ -0,0 +1,46 @@
+"""Build rule for combining Tensorflow/XLA tests."""
+
+load("//tensorflow:strict.default.bzl", "py_strict_test")
+load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
+
+def tf_xla_combined_py_test(name = "", package = None, test_files = [], **kwargs):
+    """Generates combined tf_xla_py_test targets, one per XLA backend.
+
+    All tests found in test_files are combined into one new test, which is then
+    passed to tf_xla_py_test; that rule in turn creates a new target per XLA
+    backend.
+
+    Args:
+      name: Name of the target.
+      package: The package that all tests in test_files belong to.
+      test_files: The test files to be combined and tested.
+      **kwargs: keyword arguments passed on to the tf_xla_py_test rule.
+    """
+
+    test_file = name + ".py"
+
+    # Run the generator to create the combined test file containing all the
+    # tests in test_files, redirecting the generator's output to test_file.
+ native.genrule(
+ name = name + "_gen",
+ testonly = 1,
+ srcs = test_files,
+ outs = [test_file],
+ cmd = """
+mkdir -p $(@D) && cat > $@ << EOF
+from tensorflow.python.platform import test
+%s
+
+if __name__ == "__main__":
+ test.main()
+EOF
+ """ % "\n".join(["from %s.%s import *" % (package, test[:-3]) for test in test_files]),
+ tools = [],
+ tags = ["generated_python_test=%s.%s" % (package, name)],
+ )
+
+ tf_xla_py_test(
+ name = name,
+ test_rule = py_strict_test,
+ srcs = [test_file] + test_files,
+ **kwargs
+ )
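For illustration, the combined test source that the genrule above emits for the ops_test_mlir_false target in tensorflow/compiler/tests/BUILD (package = "tensorflow.compiler.tests", test_files = ["adadelta_test.py"]) would look roughly like the following sketch; the real file is produced at build time from the heredoc in the cmd above and is not checked in:

from tensorflow.python.platform import test
# Star-import pulls the module's test classes into this combined module.
from tensorflow.compiler.tests.adadelta_test import *

if __name__ == "__main__":
  test.main()  # Discovers and runs every imported test class.

One star-import line is generated per entry in test_files, so a single generated module aggregates the test classes from all listed files, and tf_xla_py_test then builds one combined target per XLA backend.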
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 1a9215b..ffab429 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -68,6 +68,7 @@
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
@@ -97,7 +98,6 @@
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
diff --git a/tensorflow/compiler/tests/unary_ops_composition_test.cc b/tensorflow/compiler/tests/unary_ops_composition_test.cc
index 40a5300..58878b8 100644
--- a/tensorflow/compiler/tests/unary_ops_composition_test.cc
+++ b/tensorflow/compiler/tests/unary_ops_composition_test.cc
@@ -21,6 +21,7 @@
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/device_factory.h"
@@ -35,7 +36,6 @@
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/port.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index c1ac5b1..75e1eb4 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -668,6 +668,10 @@
hdrs = [
"xla_resource.h",
],
+ visibility = [
+ ":internal",
+ "//learning/deepmind/tensorflow/tpufunc:__pkg__",
+ ],
deps = [
":common",
":sharding_util",
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 9ba3ded..b65ac38 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -211,7 +211,6 @@
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:fingerprint",
"@local_tsl//tsl/platform:statusor",
@@ -226,6 +225,7 @@
"@local_xla//xla/stream_executor/gpu:gpu_executor_header",
"@local_xla//xla/stream_executor/gpu:gpu_stream_header",
"@local_xla//xla/stream_executor/gpu:gpu_types_header",
+ "@local_xla//xla/tsl/lib/strings:proto_serialization",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc
index 4971fd0..6697b06 100644
--- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc
+++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc
@@ -56,6 +56,7 @@
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/util.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/process_state.h"
@@ -73,7 +74,6 @@
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/fingerprint.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index 39dd10d..31a6b68 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -138,6 +138,7 @@
std::vector<bool> dims_are_dynamic;
const auto& dims = shape.dims();
dims_are_dynamic.reserve(dims);
+ output_dim_sizes.reserve(dims);
for (int64_t i = 0; i < dims; ++i) {
output_dim_sizes.push_back(
xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {}));
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 26d3cff..09d6898 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -405,6 +405,7 @@
std::vector<xla::XlaOp> dynamic_dims;
const xla::Shape& shape = list_shape.tuple_shapes(i);
auto sub_element = xla::GetTupleElement(list, i);
+ dynamic_dims.reserve(shape.dimensions_size());
for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) {
dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc
index 73c6c34..1845b9b 100644
--- a/tensorflow/compiler/tf2xla/kernels/where_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc
@@ -275,6 +275,7 @@
// and then scatter iotas[out_idxs] into the output.
std::vector<XlaOp> iotas_to_concat;
auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions());
+ iotas_to_concat.reserve(iota_shape.rank());
for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) {
iotas_to_concat.push_back(
xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}));
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
index 787d676..48f0622 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
@@ -336,11 +336,6 @@
return nullptr;
}
- absl::StatusOr<std::unique_ptr<se::StreamExecutor>> GetUncachedExecutor(
- const se::StreamExecutorConfig& config) override {
- return std::unique_ptr<se::StreamExecutor>(nullptr);
- }
-
private:
string name_;
};
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index fc32bda..0877e8a 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -3427,10 +3427,10 @@
"//tensorflow/core/framework:function_testlib",
"//tensorflow/core/framework:optimized_function_graph_proto_cc",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:test",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -3477,9 +3477,9 @@
"//tensorflow/core/kernels:function_ops",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:test",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -3497,7 +3497,7 @@
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/common_runtime/arg_ret_placement_test.cc b/tensorflow/core/common_runtime/arg_ret_placement_test.cc
index 8aea657..11b8bdb 100644
--- a/tensorflow/core/common_runtime/arg_ret_placement_test.cc
+++ b/tensorflow/core/common_runtime/arg_ret_placement_test.cc
@@ -20,6 +20,7 @@
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/scope.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/function.h"
@@ -29,7 +30,6 @@
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc
index a16b90c..6e78f32 100644
--- a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc
+++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc
@@ -19,6 +19,7 @@
#include <string>
#include "tensorflow/cc/framework/scope.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/config/flag_defs.h"
@@ -29,7 +30,6 @@
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index a4712a5..675ffc6 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -206,7 +206,7 @@
const Tensor* input, Tensor* output,
int dev_to_dev_stream_index, StatusCallback done,
bool sync_dst_compute) {
- profiler::ScopedAnnotation annotation(
+ tsl::profiler::ScopedAnnotation annotation(
[&] { return absl::StrCat("#edge_name=", edge_name, "#"); });
VLOG(4) << "Copy " << edge_name;
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
index d2b51ee..1f8c8cd 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -136,49 +136,6 @@
std::string key_prefix_;
};
-// Remove LocalDeviceState objects from
-// info->local_device_states that have unique hardware IDs
-// (i.e. ignore duplicate virtual devices) and return them in a map.
-static std::map<int, std::unique_ptr<xla::LocalDeviceState>>
-GetUniqueDeviceStates(PjRtGpuClientCreationInfo* info) {
- // Only consider each hardware device once. In test environments, one
- // physical GPU (e.g. hardware_id 0) might be shared as virtual GPUs (e.g.
- // local_id 0 and 1) by multiple workers (multiple processes on the same
- // computer). If there is a need to not ignore these for an actual case, a
- // possible solution is to add a flag to only enable the use of
- // hardware_id_to_local_id for tests.
-
- auto input_states = std::move(info->local_device_states);
-
- absl::flat_hash_map<int, int> hardware_id_to_local_id;
- for (const auto& id_state : input_states) {
- int local_id = id_state.second->local_device_id().value();
- int hardware_id = id_state.second->local_hardware_id().value();
- if (hardware_id_to_local_id.contains(hardware_id)) {
- if (hardware_id_to_local_id[hardware_id] > local_id) {
- // Use the device with the smallest local_id, ignore others.
- hardware_id_to_local_id[hardware_id] = local_id;
- }
- } else {
- hardware_id_to_local_id[hardware_id] = local_id;
- }
- }
- std::map<int, std::unique_ptr<xla::LocalDeviceState>> local_device_states;
- for (auto& id_state : input_states) {
- int local_id = id_state.second->local_device_id().value();
- int hardware_id = id_state.second->local_hardware_id().value();
- if (hardware_id_to_local_id[hardware_id] != local_id) {
- VLOG(1) << "For hardware_id=" << hardware_id
- << ", ignoring redundant local_id=" << local_id
- << ". local_id=" << hardware_id_to_local_id[hardware_id]
- << " will be used instead.";
- continue;
- }
- local_device_states.emplace(id_state.first, std::move(id_state.second));
- }
- return local_device_states;
-}
-
// Coordinate creation of a PjRt GPU client with distributed devices when there
// are multiple threads (which typically occurs in test environments that use
// multiple threads to simulate multiple workers).
@@ -319,10 +276,9 @@
auto kv_store =
std::make_shared<XlaKeyValueStore>(coordination_service_agent);
- std::map<int, std::unique_ptr<xla::LocalDeviceState>>
- unique_local_device_states;
+ std::map<int, std::unique_ptr<xla::LocalDeviceState>> local_device_states;
if (use_creation_info) {
- unique_local_device_states = GetUniqueDeviceStates(info);
+ local_device_states = std::move(info->local_device_states);
}
if (use_creation_info) {
// Tell any other threads are waiting to call BuildDistributedDevices to
@@ -330,7 +286,7 @@
creation_state->SetReady();
}
auto device_topology_pair = BuildDistributedDevices(
- platform_name, std::move(unique_local_device_states), node_id, num_nodes,
+ platform_name, std::move(local_device_states), node_id, num_nodes,
gpu_run_options.get(), kv_store, /*enable_mock_nccl=*/false);
if (!device_topology_pair.ok()) {
if (use_creation_info) {
diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc
index 1758e31..4c50c62 100644
--- a/tensorflow/core/common_runtime/eager/context_test.cc
+++ b/tensorflow/core/common_runtime/eager/context_test.cc
@@ -26,6 +26,7 @@
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
@@ -39,7 +40,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/eager/eager_executor_test.cc b/tensorflow/core/common_runtime/eager/eager_executor_test.cc
index e933a04..1650dbf 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor_test.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor_test.cc
@@ -17,10 +17,10 @@
#include <memory>
#include <utility>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/core/common_runtime/eager/placement_utils_test.cc b/tensorflow/core/common_runtime/eager/placement_utils_test.cc
index 8803d74..6220cc9 100644
--- a/tensorflow/core/common_runtime/eager/placement_utils_test.cc
+++ b/tensorflow/core/common_runtime/eager/placement_utils_test.cc
@@ -21,11 +21,11 @@
#include <gtest/gtest.h>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#define DEVICE_CPU0 "/job:localhost/replica:0/task:0/device:CPU:0"
#define DEVICE_CPU0_TASK1 "/job:localhost/replica:0/task:1/device:CPU:0"
diff --git a/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc b/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc
index 1038ba9..efa597f 100644
--- a/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc
+++ b/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc
@@ -19,12 +19,12 @@
#include <string>
#include <vector>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 3f2665e..6aa62de 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -25,13 +25,13 @@
#include "xla/stream_executor/gpu/gpu_init.h"
#include "xla/tests/test_macros.h"
#include "xla/tsl/framework/device_id.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#ifdef TF_GPU_USE_PJRT
#include "xla/pjrt/pjrt_client.h"
diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc
index 66109ae..78c2713 100644
--- a/tensorflow/core/common_runtime/graph_constructor.cc
+++ b/tensorflow/core/common_runtime/graph_constructor.cc
@@ -112,7 +112,7 @@
: allow_internal_ops(false),
expect_device_spec(false),
propagate_device_spec(in.propagate_device_spec),
- prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
+ prefix(in.prefix.empty() || absl::EndsWith(in.prefix, "/")
? in.prefix
: in.prefix + "/"),
uniquify_names(in.uniquify_names),
diff --git a/tensorflow/core/common_runtime/int32_fulltype_test.cc b/tensorflow/core/common_runtime/int32_fulltype_test.cc
index e6ead59..5d2c0e0 100644
--- a/tensorflow/core/common_runtime/int32_fulltype_test.cc
+++ b/tensorflow/core/common_runtime/int32_fulltype_test.cc
@@ -18,6 +18,7 @@
#include <string>
#include <unordered_map>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/full_type.pb.h"
@@ -30,7 +31,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc
index b57145c..31c1e40 100644
--- a/tensorflow/core/common_runtime/lower_while_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_while_op_test.cc
@@ -25,6 +25,7 @@
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
@@ -39,7 +40,6 @@
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD
index 74f843c..4a36e8f 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD
+++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD
@@ -329,7 +329,6 @@
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/time",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:env_impl",
"@local_tsl//tsl/platform:errors",
@@ -341,6 +340,7 @@
"@local_xla//xla/tsl/distributed_runtime:call_options",
"@local_xla//xla/tsl/distributed_runtime/coordination:coordination_client",
"@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD
index 7feb974..7862391 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD
+++ b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD
@@ -169,8 +169,8 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc
index b5a37df..02ea581 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc
+++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc
@@ -26,6 +26,7 @@
#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "xla/tsl/framework/allocator.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.h"
#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.h"
@@ -45,7 +46,6 @@
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc
index f61be2e..5d62f8c 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc
+++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc
@@ -27,8 +27,8 @@
#include "xla/tsl/distributed_runtime/call_options.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_client.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/test.h"
#include "tsl/protobuf/coordination_config.pb.h"
diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc
index b2cf3d1..52925d9 100644
--- a/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc
+++ b/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc
@@ -21,6 +21,7 @@
#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_set.h"
@@ -29,7 +30,6 @@
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/status.h"
diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info.h b/tensorflow/core/common_runtime/optimized_function_graph_info.h
index dd05b02..b15790d 100644
--- a/tensorflow/core/common_runtime/optimized_function_graph_info.h
+++ b/tensorflow/core/common_runtime/optimized_function_graph_info.h
@@ -71,10 +71,10 @@
OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo& info) = delete;
OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo& info) =
delete;
- OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) noexcept =
- default;
- OptimizedFunctionGraphInfo& operator=(
- OptimizedFunctionGraphInfo&& info) noexcept = default;
+ OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) =
+ default; // NOLINT
+ OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo&& info) =
+ default; // NOLINT
// Converts from the struct to OptimizedFunctionGraph proto.
static OptimizedFunctionGraph ToProto(const OptimizedFunctionGraphInfo& info);
diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc b/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc
index 800da5f..cab15e6 100644
--- a/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc
+++ b/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc
@@ -23,11 +23,11 @@
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "third_party/protobuf/text_format.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/optimized_function_graph.pb.h"
#include "tensorflow/core/graph/node_builder.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/common_runtime/replicate_constants_pass_test.cc b/tensorflow/core/common_runtime/replicate_constants_pass_test.cc
index 346ba17..bf335df 100644
--- a/tensorflow/core/common_runtime/replicate_constants_pass_test.cc
+++ b/tensorflow/core/common_runtime/replicate_constants_pass_test.cc
@@ -22,6 +22,7 @@
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/config/flag_defs.h"
#include "tensorflow/core/config/flags.h"
@@ -29,7 +30,6 @@
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/test.h"
diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h
index a773fbb..23e9989 100644
--- a/tensorflow/core/config/flag_defs.h
+++ b/tensorflow/core/config/flag_defs.h
@@ -64,6 +64,9 @@
// TODO(b/341325107): Make this behavior the default and remove the flag.
TF_DECLARE_FLAG(enable_function_pruning_before_inlining, false,
"If true, functions will be pruned before inlining.")
+ TF_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs, false,
+ "If true, TF2XLA encapsulation will be skipped for non-TPU "
+ "graphs.")
// LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc)
};
diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc
index 096d48c..060ede3 100644
--- a/tensorflow/core/config/flags_api_wrapper.cc
+++ b/tensorflow/core/config/flags_api_wrapper.cc
@@ -55,5 +55,6 @@
TF_PY_DECLARE_FLAG(enable_colocation_key_propagation_in_while_op_lowering);
TF_PY_DECLARE_FLAG(enable_tf2min_ici_weight)
TF_PY_DECLARE_FLAG(enable_function_pruning_before_inlining)
+ TF_PY_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs)
// LINT.ThenChange(//tensorflow/core/config/flag_defs.h)
};
diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD
index 1ec7a6f..748dfc1 100644
--- a/tensorflow/core/data/BUILD
+++ b/tensorflow/core/data/BUILD
@@ -218,7 +218,9 @@
"//tensorflow/core:framework",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc
index 6ddf987..2206cb0 100644
--- a/tensorflow/core/data/captured_function.cc
+++ b/tensorflow/core/data/captured_function.cc
@@ -943,8 +943,10 @@
}
void InstantiatedCapturedFunction::RunAsync(
- IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done,
+ std::function<void(std::function<void()>)> runner,
+ CancellationManager* parent_cancellation_manager,
+ CollectiveExecutor* collective_executor, std::vector<Tensor>&& args,
+ std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done,
const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
@@ -952,7 +954,7 @@
// potentially do a non-trivial amount of (e.g. copying) work, and we may
// want to run that concurrently with the next invocation.
Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
- (*ctx->runner())(
+ runner(
std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
std::move(done)));
return;
@@ -971,18 +973,18 @@
resource_mgr->Cleanup(name).IgnoreError();
});
f_opts.step_container = step_container;
- f_opts.runner = ctx->runner();
+ f_opts.runner = &runner;
f_opts.create_rendezvous = ShouldCreateRendezvous();
auto cancellation_manager =
- std::make_unique<CancellationManager>(ctx->cancellation_manager());
+ std::make_unique<CancellationManager>(parent_cancellation_manager);
f_opts.cancellation_manager = cancellation_manager.get();
- f_opts.collective_executor = ctx->collective_executor();
+ f_opts.collective_executor = collective_executor;
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
- if (node || ctx->stats_aggregator()) {
+ if (node) {
stats_collector = std::make_shared<SimpleStepStatsCollector>();
}
- const bool collect_usage = node && ctx->model();
+ const bool collect_usage = node != nullptr;
f_opts.stats_collector = stats_collector.get();
// Transfer ownership of the cancellation manager to `callback`.
@@ -992,7 +994,6 @@
[this, rets, step_container, raw_cancellation_manager, frame, node,
collect_usage](
const FunctionLibraryRuntime::DoneCallback& done,
- IteratorContext* ctx,
const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
// Begin unbound arguments.
Status s) {
@@ -1003,18 +1004,6 @@
}
delete frame;
if (node) {
- // TODO(b/129085499) Utilize the `node_name` which would be unique
- // than the prefix for the function execution time statistics.
- // prefix_with_func_name would then be node_name + func_name.
- if (ctx->stats_aggregator()) {
- string prefix_with_func_name =
- strings::StrCat(node->name(), stats_utils::kDelimiter,
- captured_func_->func().name());
- ctx->stats_aggregator()->AddToHistogram(
- stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
- {static_cast<float>(stats_collector->processing_time())},
- node->num_elements());
- }
node->add_processing_time(stats_collector->processing_time());
}
if (collect_usage) {
@@ -1025,7 +1014,7 @@
node->record_stop(EnvTime::NowNanos());
}
},
- std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);
+ std::move(done), std::move(stats_collector), std::placeholders::_1);
tsl::profiler::TraceMe activity(
[&] {
diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h
index e415c54..854d9fc 100644
--- a/tensorflow/core/data/captured_function.h
+++ b/tensorflow/core/data/captured_function.h
@@ -288,6 +288,18 @@
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
+ const std::shared_ptr<model::Node>& node) const {
+ RunAsync(*(ctx->runner()), ctx->cancellation_manager(),
+ ctx->collective_executor(), std::move(args), rets, done, node);
+ }
+
+ // A version of `RunAsync` that does not take an `IteratorContext` but a
+ // runner, a cancellation manager, and a collective executor.
+ void RunAsync(std::function<void(std::function<void()>)> runner,
+ CancellationManager* parent_cancellation_manager,
+ CollectiveExecutor* collective_executor,
+ std::vector<Tensor>&& args, std::vector<Tensor>* rets,
+ FunctionLibraryRuntime::DoneCallback done,
const std::shared_ptr<model::Node>& node) const;
std::string func_name() const { return captured_func_->func().name(); }
diff --git a/tensorflow/core/data/dataset_test_base.cc b/tensorflow/core/data/dataset_test_base.cc
index 7e295e3..e770b4f 100644
--- a/tensorflow/core/data/dataset_test_base.cc
+++ b/tensorflow/core/data/dataset_test_base.cc
@@ -348,7 +348,7 @@
Status DatasetOpsTestBase::CreateDatasetContext(
OpKernel* const dateset_kernel,
- gtl::InlinedVector<TensorValue, 4>* const inputs,
+ absl::InlinedVector<TensorValue, 4>* const inputs,
std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
std::unique_ptr<OpKernelContext>* dataset_context) {
Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
@@ -515,13 +515,13 @@
}
Status DatasetOpsTestBase::CreateOpKernelContext(
- OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
+ OpKernel* kernel, absl::InlinedVector<TensorValue, 4>* inputs,
std::unique_ptr<OpKernelContext>* context) {
return CreateOpKernelContext(kernel, inputs, ¶ms_, context);
}
Status DatasetOpsTestBase::CreateOpKernelContext(
- OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
+ OpKernel* kernel, absl::InlinedVector<TensorValue, 4>* inputs,
std::unique_ptr<OpKernelContext::Params>* context_params,
std::unique_ptr<OpKernelContext>* context) {
auto params = std::make_unique<OpKernelContext::Params>();
@@ -565,7 +565,7 @@
}
Status DatasetOpsTestBase::CheckOpKernelInput(
- const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
+ const OpKernel& kernel, const absl::InlinedVector<TensorValue, 4>& inputs) {
if (kernel.num_inputs() != inputs.size()) {
return errors::InvalidArgument("The number of input elements should be ",
kernel.num_inputs(),
@@ -575,7 +575,7 @@
}
Status DatasetOpsTestBase::AddDatasetInput(
- gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
+ absl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
DataType dtype, const TensorShape& shape) {
if (input_types.size() < inputs->size()) {
return errors::InvalidArgument("Adding more inputs than types: ",
@@ -862,7 +862,7 @@
input_datasets.push_back(t.get());
created_tensors->push_back(std::move(t));
}
- gtl::InlinedVector<TensorValue, 4> inputs;
+ absl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(input_datasets.size());
for (auto input_dataset : input_datasets) {
inputs.emplace_back(TensorValue(input_dataset));
@@ -985,7 +985,7 @@
TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
auto input_tensors = dataset_params.GetInputTensors();
- gtl::InlinedVector<TensorValue, 4> inputs;
+ absl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(input_datasets.size() + input_tensors.size());
for (auto input_dataset : input_datasets) {
inputs.emplace_back(TensorValue(input_dataset));
@@ -1165,7 +1165,7 @@
const std::vector<Tensor>& input_components) {
std::vector<PartialTensorShape> shapes;
for (const auto& component : input_components) {
- gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
+ absl::InlinedVector<int64_t, 4> partial_dim_sizes;
for (int i = 1; i < component.dims(); ++i) {
partial_dim_sizes.push_back(component.dim_size(i));
}
diff --git a/tensorflow/core/data/dataset_test_base.h b/tensorflow/core/data/dataset_test_base.h
index ec9805b..e727823 100644
--- a/tensorflow/core/data/dataset_test_base.h
+++ b/tensorflow/core/data/dataset_test_base.h
@@ -766,7 +766,7 @@
// Creates a new op kernel context.
Status CreateDatasetContext(
- OpKernel* dateset_kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
+ OpKernel* dateset_kernel, absl::InlinedVector<TensorValue, 4>* inputs,
std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
std::unique_ptr<OpKernelContext>* dataset_context);
@@ -798,16 +798,16 @@
// Checks that the size of `inputs` matches the requirement of the op kernel.
Status CheckOpKernelInput(const OpKernel& kernel,
- const gtl::InlinedVector<TensorValue, 4>& inputs);
+ const absl::InlinedVector<TensorValue, 4>& inputs);
// Creates a new context for running the dataset operation.
Status CreateOpKernelContext(OpKernel* kernel,
- gtl::InlinedVector<TensorValue, 4>* inputs,
+ absl::InlinedVector<TensorValue, 4>* inputs,
std::unique_ptr<OpKernelContext>* context);
// Creates a new context for running the dataset operation.
Status CreateOpKernelContext(OpKernel* kernel,
- gtl::InlinedVector<TensorValue, 4>* inputs,
+ absl::InlinedVector<TensorValue, 4>* inputs,
std::unique_ptr<OpKernelContext::Params>* params,
std::unique_ptr<OpKernelContext>* context);
@@ -856,7 +856,7 @@
// Adds an empty tensor with the specified dtype and shape to the input
// vector.
- Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
+ Status AddDatasetInput(absl::InlinedVector<TensorValue, 4>* inputs,
DataTypeVector input_types, DataType dtype,
const TensorShape& shape);
diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc
index cc7ed17..1934599 100644
--- a/tensorflow/core/data/dataset_utils.cc
+++ b/tensorflow/core/data/dataset_utils.cc
@@ -1018,7 +1018,7 @@
AllTasks);
REGISTER_DATASET_EXPERIMENT("no_compression", RandomJobSamplePercentage<0>,
AllTasks);
-REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<50>,
+REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<0>,
AllTasks);
REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>,
AllTasks);
diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc
index e581f6e..2e107eb 100644
--- a/tensorflow/core/data/dataset_utils_test.cc
+++ b/tensorflow/core/data/dataset_utils_test.cc
@@ -359,11 +359,10 @@
auto opt_ins = test_case.opt_ins;
auto opt_outs = test_case.opt_outs;
if (!opt_ins.empty()) {
- setenv("TF_DATA_EXPERIMENT_OPT_IN", str_util::Join(opt_ins, ",").c_str(),
- 1);
+ setenv("TF_DATA_EXPERIMENT_OPT_IN", absl::StrJoin(opt_ins, ",").c_str(), 1);
}
if (!opt_outs.empty()) {
- setenv("TF_DATA_EXPERIMENT_OPT_OUT", str_util::Join(opt_outs, ",").c_str(),
+ setenv("TF_DATA_EXPERIMENT_OPT_OUT", absl::StrJoin(opt_outs, ",").c_str(),
1);
}
const std::string job_name = "job";
@@ -376,14 +375,14 @@
for (const auto& experiment : test_case.expected_in) {
EXPECT_TRUE(experiment_set.find(experiment) != experiment_set.end())
<< "experiment=" << experiment << " opt_ins={"
- << str_util::Join(opt_ins, ",") << "} opt_outs={"
- << str_util::Join(opt_outs, ",") << "}";
+ << absl::StrJoin(opt_ins, ",") << "} opt_outs={"
+ << absl::StrJoin(opt_outs, ",") << "}";
}
for (const auto& experiment : test_case.expected_out) {
EXPECT_TRUE(experiment_set.find(experiment) == experiment_set.end())
<< "experiment=" << experiment << " opt_ins={"
- << str_util::Join(opt_ins, ",") << "} opt_outs={"
- << str_util::Join(opt_outs, ",") << "}";
+ << absl::StrJoin(opt_ins, ",") << "} opt_outs={"
+ << absl::StrJoin(opt_outs, ",") << "}";
}
if (!opt_ins.empty()) {
diff --git a/tensorflow/core/data/global_shuffle_utils.cc b/tensorflow/core/data/global_shuffle_utils.cc
index 132a35f..dc42563 100644
--- a/tensorflow/core/data/global_shuffle_utils.cc
+++ b/tensorflow/core/data/global_shuffle_utils.cc
@@ -16,10 +16,13 @@
#include <cstdint>
#include <optional>
+#include <string>
#include <vector>
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor.h"
@@ -29,6 +32,13 @@
namespace tensorflow {
namespace data {
+namespace {
+
+constexpr absl::string_view kGlobalShuffleIteratorNextIndex =
+ "global_shuffle_iterator_next_index";
+
+}  // namespace
+
IteratorContextWithIndexMapper::IteratorContextWithIndexMapper(
IteratorContext* ctx, const IteratorBase* iterator)
: ctx_(ctx) {
@@ -60,10 +70,22 @@
}
absl::MutexLock l(&mu_);
- TF_ASSIGN_OR_RETURN(int64_t output_index,
- ctx->index_mapper()(element_count_++));
+ absl::StatusOr<int64_t> shuffled_index =
+ absl::NotFoundError("Default not found");
+
+ while (absl::IsNotFound(shuffled_index.status())) {
+ shuffled_index = ctx->index_mapper()(element_count_++);
+ }
+
+ if (absl::IsOutOfRange(shuffled_index.status())) {
+ *end_of_sequence = true;
+ return absl::OkStatus();
+ }
+
+ TF_RETURN_IF_ERROR(shuffled_index.status());
+
absl::Status status =
- dataset_->Get(AnyContext(ctx), output_index, out_tensors);
+ dataset_->Get(AnyContext(ctx), shuffled_index.value(), out_tensors);
if (absl::IsOutOfRange(status)) {
*end_of_sequence = true;
return absl::OkStatus();
@@ -73,7 +95,18 @@
return absl::OkStatus();
}
-absl::Status GlobalShuffleIterator::Restore(IteratorContext* ctx) {
+absl::Status GlobalShuffleIterator::Save(
+ const std::string& parent_iterator_prefix, SerializationContext* ctx,
+ IteratorStateWriter* writer) {
+ absl::MutexLock l(&mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ parent_iterator_prefix, kGlobalShuffleIteratorNextIndex, element_count_));
+ return absl::OkStatus();
+}
+
+absl::Status GlobalShuffleIterator::Restore(
+ const std::string& parent_iterator_prefix, IteratorContext* ctx,
+ IteratorStateReader* reader) {
if (!ctx->restored_element_count().has_value()) {
return absl::FailedPreconditionError(absl::StrCat(
"Trying to restore random element count for dataset ",
@@ -81,7 +114,9 @@
}
absl::MutexLock l(&mu_);
- element_count_ = *(ctx->restored_element_count());
+ TF_RETURN_IF_ERROR(reader->ReadScalar(parent_iterator_prefix,
+ kGlobalShuffleIteratorNextIndex,
+ &element_count_));
return absl::OkStatus();
}
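The new GetNext() loop above leans on the index-mapper contract: `absl::NotFoundError` means "skip this element count" and `absl::OutOfRangeError` means "past the end of the dataset". Below is a minimal, self-contained sketch of that consumer loop; it is not TensorFlow code, and the names `ToyIndexMapper` and `NextShuffledIndex` are illustrative only.

// Sketch of the index-mapper contract used by GetNext() above:
// NotFound => skip this element count, OutOfRange => end of sequence.
#include <cstdint>
#include <functional>
#include <iostream>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

using IndexMapper = std::function<absl::StatusOr<int64_t>(int64_t)>;

// Pretend mapper: skips odd element counts, ends after 10 counts, and maps
// the remaining counts to a "shuffled" position.
absl::StatusOr<int64_t> ToyIndexMapper(int64_t element_count) {
  if (element_count >= 10) return absl::OutOfRangeError("end of dataset");
  if (element_count % 2 == 1) return absl::NotFoundError("skip this count");
  return 9 - element_count;  // some permutation of the usable positions
}

// Mirrors the loop added to GlobalShuffleIterator::GetNext: advance past
// NotFound, translate OutOfRange into end_of_sequence, forward other errors.
absl::StatusOr<int64_t> NextShuffledIndex(const IndexMapper& mapper,
                                          int64_t& element_count,
                                          bool& end_of_sequence) {
  absl::StatusOr<int64_t> shuffled = absl::NotFoundError("init");
  while (absl::IsNotFound(shuffled.status())) {
    shuffled = mapper(element_count++);
  }
  if (absl::IsOutOfRange(shuffled.status())) {
    end_of_sequence = true;
  }
  return shuffled;
}

int main() {
  int64_t element_count = 0;
  bool end_of_sequence = false;
  while (!end_of_sequence) {
    absl::StatusOr<int64_t> index =
        NextShuffledIndex(ToyIndexMapper, element_count, end_of_sequence);
    if (index.ok()) std::cout << "shuffled index: " << *index << "\n";
  }
  return 0;
}

Running the sketch prints the mapped positions for the even element counts and stops as soon as the mapper reports OutOfRange, which is exactly how the real iterator decides to set `*end_of_sequence`.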
diff --git a/tensorflow/core/data/global_shuffle_utils.h b/tensorflow/core/data/global_shuffle_utils.h
index 91b4fa0..c7513a0 100644
--- a/tensorflow/core/data/global_shuffle_utils.h
+++ b/tensorflow/core/data/global_shuffle_utils.h
@@ -75,9 +75,13 @@
absl::Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence);
+ absl::Status Save(const std::string& parent_iterator_prefix,
+ SerializationContext* ctx, IteratorStateWriter* writer);
+
// Restores the element count.
// REQUIRES: ctx->restored_element_count() != nullopt.
- absl::Status Restore(IteratorContext* ctx);
+ absl::Status Restore(const std::string& parent_iterator_prefix,
+ IteratorContext* ctx, IteratorStateReader* reader);
private:
const DatasetBase* const dataset_;
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index fe265ba..3bfbac2 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -680,8 +680,6 @@
# copybara:uncomment copts = ["-Wthread-safety-analysis"],
deps = [
":credentials_factory",
- "//tensorflow/core:framework",
- "//tensorflow/core/data:dataset_utils",
],
)
@@ -758,8 +756,8 @@
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:statusor",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/data/service/auto_scaler_test.cc b/tensorflow/core/data/service/auto_scaler_test.cc
index c04ea49..299715d 100644
--- a/tensorflow/core/data/service/auto_scaler_test.cc
+++ b/tensorflow/core/data/service/auto_scaler_test.cc
@@ -18,9 +18,9 @@
#include <optional>
#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/lib/monitoring/cell_reader.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
namespace tensorflow {
diff --git a/tensorflow/core/data/service/client/BUILD b/tensorflow/core/data/service/client/BUILD
index 16bd0ef..34b0b7a 100644
--- a/tensorflow/core/data/service/client/BUILD
+++ b/tensorflow/core/data/service/client/BUILD
@@ -120,10 +120,10 @@
"//tensorflow/core/data/service:dispatcher_client",
"//tensorflow/core/data/service:test_cluster",
"//tensorflow/core/data/service:test_util",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/protobuf:protos_all_cc",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(),
)
diff --git a/tensorflow/core/data/service/client/data_service_client_test.cc b/tensorflow/core/data/service/client/data_service_client_test.cc
index 8ec654b..09b1ede 100644
--- a/tensorflow/core/data/service/client/data_service_client_test.cc
+++ b/tensorflow/core/data/service/client/data_service_client_test.cc
@@ -21,6 +21,7 @@
#include <vector>
#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/client/common.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/test_cluster.h"
@@ -33,7 +34,6 @@
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/data/service/client/utils_test.cc b/tensorflow/core/data/service/client/utils_test.cc
index c3d9451..8729bff 100644
--- a/tensorflow/core/data/service/client/utils_test.cc
+++ b/tensorflow/core/data/service/client/utils_test.cc
@@ -17,13 +17,13 @@
#include <optional>
#include <string>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/dispatcher_client.h"
#include "tensorflow/core/data/service/test_cluster.h"
#include "tensorflow/core/data/service/test_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc
index b1341be..e561ecb 100644
--- a/tensorflow/core/data/service/dispatcher_state_test.cc
+++ b/tensorflow/core/data/service/dispatcher_state_test.cc
@@ -21,13 +21,13 @@
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/journal.pb.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
namespace tensorflow {
diff --git a/tensorflow/core/data/service/py_utils.cc b/tensorflow/core/data/service/py_utils.cc
index d14e1c9..be5308d 100644
--- a/tensorflow/core/data/service/py_utils.cc
+++ b/tensorflow/core/data/service/py_utils.cc
@@ -17,9 +17,7 @@
#include <string>
-#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/service/credentials_factory.h"
-#include "tensorflow/core/framework/metrics.h"
namespace tensorflow {
namespace data {
@@ -39,17 +37,5 @@
return "grpc";
}
-bool DisableCompressionAtRegistrationTime() {
-#if defined(PLATFORM_GOOGLE)
- if (!GetExperiments().contains("no_compression_v2")) {
- return false;
- }
- metrics::RecordTFDataServiceCompressionAction(
- "disabled_at_registration_time");
- return true;
-#endif // PLATFORM_GOOGLE
- return false;
-}
-
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/data/service/py_utils.h b/tensorflow/core/data/service/py_utils.h
index 010c155..b0ea892 100644
--- a/tensorflow/core/data/service/py_utils.h
+++ b/tensorflow/core/data/service/py_utils.h
@@ -27,10 +27,6 @@
// Returns the default protocol to use for tf.data service control flow.
std::string DefaultProtocol();
-// Returns `true` if tf.data service compression is to be disabled at
-// registration time.
-bool DisableCompressionAtRegistrationTime();
-
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD
index 40b5cba..35d28b0 100644
--- a/tensorflow/core/data/service/snapshot/BUILD
+++ b/tensorflow/core/data/service/snapshot/BUILD
@@ -150,13 +150,13 @@
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc
index 71e9bfc..8974964 100644
--- a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc
+++ b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc
@@ -20,6 +20,7 @@
#include "absl/status/status.h"
#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/dispatcher_client.h"
#include "tensorflow/core/data/service/snapshot/path_utils.h"
#include "tensorflow/core/data/service/snapshot/test_utils.h"
@@ -28,7 +29,6 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/snapshot.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/compression.h"
#include "tsl/platform/env.h"
#include "tsl/platform/path.h"
diff --git a/tensorflow/core/data/service/snapshot/file_utils_test.cc b/tensorflow/core/data/service/snapshot/file_utils_test.cc
index 9bf1e52..9582cab 100644
--- a/tensorflow/core/data/service/snapshot/file_utils_test.cc
+++ b/tensorflow/core/data/service/snapshot/file_utils_test.cc
@@ -18,13 +18,13 @@
#include <string>
#include <vector>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/data/service/test_util.h"
#include "tensorflow/core/data/snapshot_utils.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/compression.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc
index d55e7d1..43944c6 100644
--- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc
+++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc
@@ -30,10 +30,10 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/byte_size.h"
#include "tensorflow/core/data/snapshot_utils.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/compression.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc
index 0f0c9f9..1e019a1 100644
--- a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc
+++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc
@@ -31,13 +31,13 @@
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/test_util.h"
#include "tensorflow/core/data/snapshot_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/compression.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
index 4dc6f23..e40fd0a 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
@@ -26,13 +26,13 @@
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/data/service/snapshot/file_utils.h"
#include "tensorflow/core/data/service/snapshot/path_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/path.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc
index 9e11653..65b3c59 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc
@@ -17,13 +17,13 @@
#include <memory>
#include <string>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/snapshot/path_utils.h"
#include "tensorflow/core/data/service/test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc
index 96c7c3e..5a6f820 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc
@@ -49,8 +49,9 @@
constexpr char kRepetitionIndex[] = "repetition_index";
absl::StatusOr<int64_t> GetRepetitionIndex(const std::string& split_file) {
- tsl::StringPiece repetition_dir_path = tsl::io::Dirname(split_file);
- tsl::StringPiece repetition_dir_name = tsl::io::Basename(repetition_dir_path);
+ absl::string_view repetition_dir_path = tsl::io::Dirname(split_file);
+ absl::string_view repetition_dir_name =
+ tsl::io::Basename(repetition_dir_path);
return ParseRepetitionDirectoryName(repetition_dir_name);
}
} // namespace
diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc
index 2d6a9fd..b9b9f3d 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc
@@ -23,6 +23,7 @@
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/dispatcher_client.h"
@@ -34,7 +35,6 @@
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/protobuf/snapshot.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/compression.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc
index 85d401b..071c7e1 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc
@@ -21,6 +21,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/byte_size.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/snapshot/path_utils.h"
@@ -28,7 +29,6 @@
#include "tensorflow/core/data/service/snapshot/test_utils.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/test_util.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/compression.h"
#include "tsl/platform/env.h"
#include "tsl/platform/random.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc
index 84243d8..f918341 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc
@@ -25,6 +25,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/byte_size.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/snapshot/file_utils.h"
@@ -35,7 +36,6 @@
#include "tensorflow/core/data/snapshot_utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/compression.h"
#include "tsl/lib/monitoring/cell_reader.h"
#include "tsl/platform/env.h"
diff --git a/tensorflow/core/data/service/split_provider_test.cc b/tensorflow/core/data/service/split_provider_test.cc
index 08adc90..d311db2 100644
--- a/tensorflow/core/data/service/split_provider_test.cc
+++ b/tensorflow/core/data/service/split_provider_test.cc
@@ -21,10 +21,10 @@
#include <vector>
#include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/test_util.h"
#include "tensorflow/core/framework/dataset.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
diff --git a/tensorflow/core/data/standalone_save_restore_test.cc b/tensorflow/core/data/standalone_save_restore_test.cc
index fd16369..9798021 100644
--- a/tensorflow/core/data/standalone_save_restore_test.cc
+++ b/tensorflow/core/data/standalone_save_restore_test.cc
@@ -17,11 +17,11 @@
#include <utility>
#include <vector>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/test_util.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/data/standalone_test.cc b/tensorflow/core/data/standalone_test.cc
index 54f438b..fac2a9e 100644
--- a/tensorflow/core/data/standalone_test.cc
+++ b/tensorflow/core/data/standalone_test.cc
@@ -19,10 +19,10 @@
#include <optional>
#include <vector>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 6c82115..fb01f41 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -182,11 +182,8 @@
":debug_grpc_testlib",
":debug_io_utils",
":debug_node_key",
- ":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
- "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -194,7 +191,6 @@
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- "//tensorflow/core/platform/default/build_config:platformlib",
],
)
@@ -260,7 +256,6 @@
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
index 3eaf365..87aea15 100644
--- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
@@ -13,6 +13,8 @@
limitations under the License.
==============================================================================*/
+#include <memory>
+
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/debug/debug_grpc_testlib.h"
#include "tensorflow/core/debug/debug_io_utils.h"
@@ -47,10 +49,10 @@
int64_t server_start_delay_micros) {
server_data->port = testing::PickUnusedPortOrDie();
server_data->url = strings::StrCat("grpc://localhost:", server_data->port);
- server_data->server.reset(new test::TestEventListenerImpl());
+ server_data->server = std::make_unique<test::TestEventListenerImpl>();
- server_data->thread_pool.reset(
- new thread::ThreadPool(Env::Default(), "test_server", 1));
+ server_data->thread_pool =
+ std::make_unique<thread::ThreadPool>(Env::Default(), "test_server", 1);
server_data->thread_pool->Schedule(
[server_data, server_start_delay_micros]() {
Env::Default()->SleepForMicroseconds(server_start_delay_micros);
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 18009a3..2a57df8 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -39,7 +39,7 @@
::grpc::Status SendEvents(
::grpc::ServerContext* context,
::grpc::ServerReaderWriter< ::tensorflow::EventReply,
- ::tensorflow::Event>* stream);
+ ::tensorflow::Event>* stream) override;
// Clear debug data (e.g., Tensors) received so far.
void ClearReceivedDebugData();
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index dad4360..74d5758 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -13,11 +13,12 @@
limitations under the License.
==============================================================================*/
-#include <cstdlib>
-#include <unordered_set>
-
#include "tensorflow/core/debug/debug_io_utils.h"
+#include <cstdlib>
+#include <memory>
+#include <unordered_set>
+
#include "tensorflow/core/debug/debug_callback_registry.h"
#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
@@ -40,7 +41,7 @@
void Initialize() {
env_ = Env::Default();
- tensor_a_.reset(new Tensor(DT_FLOAT, TensorShape({2, 2})));
+ tensor_a_ = std::make_unique<Tensor>(DT_FLOAT, TensorShape({2, 2}));
tensor_a_->flat<float>()(0) = 5.0;
tensor_a_->flat<float>()(1) = 3.0;
tensor_a_->flat<float>()(2) = -1.0;
diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h
index 9de9bdc..4114d68 100644
--- a/tensorflow/core/debug/debugger_state_impl.h
+++ b/tensorflow/core/debug/debugger_state_impl.h
@@ -26,7 +26,7 @@
class DebuggerState : public DebuggerStateInterface {
public:
DebuggerState(const DebugOptions& debug_options);
- virtual ~DebuggerState();
+ ~DebuggerState() override;
// Publish metadata about the debugged Session::Run() call.
//
@@ -47,7 +47,7 @@
public:
DebugGraphDecorator(const DebugOptions& debug_options)
: debug_options_(debug_options) {}
- virtual ~DebugGraphDecorator() {}
+ ~DebugGraphDecorator() override {}
Status DecorateGraph(Graph* graph, Device* device) override;
Status PublishGraph(const Graph& graph, const string& device_name) override;
diff --git a/tensorflow/core/distributed_runtime/integration_test/BUILD b/tensorflow/core/distributed_runtime/integration_test/BUILD
index 4927d6f..7408bcb 100644
--- a/tensorflow/core/distributed_runtime/integration_test/BUILD
+++ b/tensorflow/core/distributed_runtime/integration_test/BUILD
@@ -52,7 +52,7 @@
"//tensorflow/core/platform:blocking_counter",
"//tensorflow/core/platform:env",
"@com_google_absl//absl/time",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -168,6 +168,6 @@
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/platform:env",
"@com_google_absl//absl/time",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc
index 4e0bd6f..356f0a0 100644
--- a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc
+++ b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc
@@ -23,6 +23,7 @@
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/blocking_counter.h"
@@ -31,7 +32,6 @@
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/protobuf/coordination_config.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc
index ba48750..cffe93d 100644
--- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc
+++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc
@@ -185,8 +185,7 @@
if (worker_id == 0) {
TFE_TensorHandle* in = TestMatrixTensorHandle(ctx);
const std::string& op_name =
- tensorflow::str_util::StrContains(send_device, "GPU") ? "Send"
- : "_HostSend";
+ absl::StrContains(send_device, "GPU") ? "Send" : "_HostSend";
TFE_Op* sendop = SendOp(ctx, in, op_name, send_device, recv_device,
send_device_incarnation);
TFE_TensorHandle* retvals[1];
@@ -197,8 +196,7 @@
TFE_DeleteTensorHandle(in);
} else {
const std::string& op_name =
- tensorflow::str_util::StrContains(send_device, "GPU") ? "Recv"
- : "_HostRecv";
+ absl::StrContains(send_device, "GPU") ? "Recv" : "_HostRecv";
TFE_Op* recvop = RecvOp(ctx, op_name, send_device, recv_device,
send_device_incarnation);
TFE_TensorHandle* retvals[1];
diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc
index dde7f65..3d9ff3c 100644
--- a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc
+++ b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc
@@ -24,6 +24,7 @@
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/strcat.h"
@@ -31,7 +32,6 @@
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/protobuf/coordination_config.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 7777bc6..2fcbc72 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1868,14 +1868,17 @@
DebuggerStateRegistry::CreateState(debug_options, debugger_state));
std::vector<string> input_names;
+ input_names.reserve(req.num_feeds());
for (size_t i = 0; i < req.num_feeds(); ++i) {
input_names.push_back(req.feed_name(i));
}
std::vector<string> output_names;
+ output_names.reserve(req.num_fetches());
for (size_t i = 0; i < req.num_fetches(); ++i) {
output_names.push_back(req.fetch_name(i));
}
std::vector<string> target_names;
+ target_names.reserve(req.num_targets());
for (size_t i = 0; i < req.num_targets(); ++i) {
target_names.push_back(req.target_name(i));
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 56eb09e..c1026dc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -17,6 +17,7 @@
#include <string>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -33,7 +34,6 @@
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/port.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index a0b39c7..526085c 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -1948,7 +1948,7 @@
deps = [
"//tensorflow/core:framework",
"//tensorflow/security/fuzzing/cc/core/framework:tensor_shape_domains",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -1962,7 +1962,7 @@
"//tensorflow/security/fuzzing/cc/core/framework:datatype_domains",
"//tensorflow/security/fuzzing/cc/core/framework:tensor_domains",
"//tensorflow/security/fuzzing/cc/core/framework:tensor_shape_domains",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 9ebcd90..ca9e5a6 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -91,6 +91,14 @@
// Maps the index of dataset elements to a globally shuffled index. See the
// comment for IteratorContext::Params::index_mapper for more details.
+// Notes:
+// * `absl::OutOfRangeError` indicates the input index argument exceeds
+// the cardinality of the dataset.
+// * `absl::NotFoundError` indicates we should skip this element.
+//   This happens when we mix multiple datasets into one, for example,
+// `dataset1.concatenate(dataset2)`.
+// See go/tf-data-random-access-iterator and
+// go/tf-data-random-access-iterator-for-concatenate for more info.
using IndexMapperFn = std::function<absl::StatusOr<size_t>(size_t)>;
constexpr char kTFDataFunction[] = "_tf_data_function";
@@ -905,6 +913,10 @@
IndexMapperFn index_mapper() const { return params_.index_mapper; }
+ void set_restored_element_count(size_t element_count) {
+ params_.restored_element_count.emplace(element_count);
+ }
+
std::optional<int64_t> restored_element_count() const {
return params_.restored_element_count;
}
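The `absl::NotFoundError` convention documented in the comment above is easiest to see with `dataset1.concatenate(dataset2)`: an index mapper built for one branch can decline global positions that belong to the other branch, and the caller simply advances to the next element count. The sketch below is a hedged illustration, not the actual tf.data implementation; `MakeBranchIndexMapper` and its parameters are invented names.

// Illustrative only: an index mapper for one branch of a concatenation.
// Global positions outside [branch_begin, branch_end) are "not ours", so the
// mapper reports NotFound and the caller skips them; the wrapped global
// mapper still reports OutOfRange past the combined cardinality.
#include <cstdint>
#include <functional>
#include <iostream>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

using IndexMapperFn = std::function<absl::StatusOr<int64_t>(int64_t)>;

IndexMapperFn MakeBranchIndexMapper(IndexMapperFn global_mapper,
                                    int64_t branch_begin, int64_t branch_end) {
  return [=](int64_t element_count) -> absl::StatusOr<int64_t> {
    absl::StatusOr<int64_t> global_index = global_mapper(element_count);
    if (!global_index.ok()) return global_index;  // e.g. OutOfRange at the end
    if (*global_index < branch_begin || *global_index >= branch_end) {
      return absl::NotFoundError("element belongs to the other branch");
    }
    return *global_index - branch_begin;  // local index within this branch
  };
}

int main() {
  // Identity "global shuffle" over 5 elements; OutOfRange from count 5 on.
  IndexMapperFn global = [](int64_t i) -> absl::StatusOr<int64_t> {
    if (i >= 5) return absl::OutOfRangeError("done");
    return i;
  };
  // The first branch covers global positions [0, 3).
  IndexMapperFn branch1 = MakeBranchIndexMapper(global, 0, 3);
  for (int64_t count = 0; count < 6; ++count) {
    absl::StatusOr<int64_t> local = branch1(count);
    std::cout << count << " -> " << local.status().ToString();
    if (local.ok()) std::cout << " (local " << *local << ")";
    std::cout << "\n";
  }
  return 0;
}

In the demo, counts 3 and 4 map into the second branch and come back as NotFound, while count 5 is OutOfRange; the GetNext() loop shown earlier turns the former into skips and the latter into end-of-sequence.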
diff --git a/tensorflow/core/framework/dataset_test.cc b/tensorflow/core/framework/dataset_test.cc
index a632551..66213ea 100644
--- a/tensorflow/core/framework/dataset_test.cc
+++ b/tensorflow/core/framework/dataset_test.cc
@@ -21,11 +21,11 @@
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/framework/partial_tensor_shape_test.cc b/tensorflow/core/framework/partial_tensor_shape_test.cc
index e20a585..581989c 100644
--- a/tensorflow/core/framework/partial_tensor_shape_test.cc
+++ b/tensorflow/core/framework/partial_tensor_shape_test.cc
@@ -17,13 +17,13 @@
#include <limits>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index d743669..71d856e 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -1288,6 +1288,10 @@
bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+ CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+ CHECK_LT(idx, output_handle_shapes_and_types_.size())
+ << "Got idx: " << idx << " but only "
+      << output_handle_shapes_and_types_.size() << " outputs.";
if (output_handle_shapes_and_types_[idx] == nullptr) {
output_handle_shapes_and_types_[idx].reset(
new std::vector<ShapeAndType>(shapes_and_types));
@@ -1299,6 +1303,10 @@
bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+ CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+ CHECK_LT(idx, input_handle_shapes_and_types_.size())
+ << "Got idx: " << idx << " but only "
+ << input_handle_shapes_and_types_.size() << " inputs.";
if (input_handle_shapes_and_types_[idx] == nullptr) {
input_handle_shapes_and_types_[idx].reset(
new std::vector<ShapeAndType>(shapes_and_types));
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index f00dac8..6ed932e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -683,6 +683,10 @@
void set_input_handle_shapes_and_types(
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+ CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+ CHECK_LT(idx, input_handle_shapes_and_types_.size())
+ << "Got idx: " << idx << " but only "
+ << input_handle_shapes_and_types_.size() << " inputs.";
input_handle_shapes_and_types_[idx] =
absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
}
@@ -690,17 +694,29 @@
// Returns the output handle shapes and types, for the resource tensor output
// at index <idx>. Returns NULL if the shape and types were never set.
const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
+ CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+ CHECK_LT(idx, output_handle_shapes_and_types_.size())
+ << "Got idx: " << idx << " but only "
+ << output_handle_shapes_and_types_.size() << " outputs.";
return output_handle_shapes_and_types_[idx].get();
}
// Returns the inputs handle shapes and types, for the resource tensor input
// at index <idx>. Returns NULL if the shape and types were not available.
const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
+ CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+ CHECK_LT(idx, input_handle_shapes_and_types_.size())
+ << "Got idx: " << idx << " but only "
+ << input_handle_shapes_and_types_.size() << " inputs.";
return input_handle_shapes_and_types_[idx].get();
}
void set_output_handle_shapes_and_types(
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+ CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+ CHECK_LT(idx, output_handle_shapes_and_types_.size())
+ << "Got idx: " << idx << " but only "
+      << output_handle_shapes_and_types_.size() << " outputs.";
output_handle_shapes_and_types_[idx] =
absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
}
diff --git a/tensorflow/core/framework/tensor_fuzz.cc b/tensorflow/core/framework/tensor_fuzz.cc
index 5665185..49f91b0 100644
--- a/tensorflow/core/framework/tensor_fuzz.cc
+++ b/tensorflow/core/framework/tensor_fuzz.cc
@@ -13,13 +13,13 @@
limitations under the License.
==============================================================================*/
#include "fuzztest/fuzztest.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/security/fuzzing/cc/core/framework/datatype_domains.h"
#include "tensorflow/security/fuzzing/cc/core/framework/tensor_domains.h"
#include "tensorflow/security/fuzzing/cc/core/framework/tensor_shape_domains.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow::fuzzing {
namespace {
diff --git a/tensorflow/core/framework/tensor_shape_fuzz.cc b/tensorflow/core/framework/tensor_shape_fuzz.cc
index 7a0351a..d14284e 100644
--- a/tensorflow/core/framework/tensor_shape_fuzz.cc
+++ b/tensorflow/core/framework/tensor_shape_fuzz.cc
@@ -18,9 +18,9 @@
#include <vector>
#include "fuzztest/fuzztest.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/security/fuzzing/cc/core/framework/tensor_shape_domains.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace fuzzing {
diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc
index e55cefa..c13a16f 100644
--- a/tensorflow/core/framework/tensor_shape_test.cc
+++ b/tensorflow/core/framework/tensor_shape_test.cc
@@ -18,6 +18,7 @@
#include <cstdint>
#include <limits>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
@@ -28,7 +29,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
class TensorShapeTestHelper {
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e0981fe..2bf4de1 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -444,7 +444,7 @@
}
bool IsQueue(const NodeDef& node) {
- return str_util::EndsWith(node.op(), "QueueV2");
+ return absl::EndsWith(node.op(), "QueueV2");
}
bool IsRandomShuffle(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index bfed969..e23e1c0 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -1022,7 +1022,7 @@
"//tensorflow/core/platform:status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
index f9d4f06..a212e25 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -178,7 +178,8 @@
NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name,
StringPiece num_parallel_calls_node_name,
StringPiece function_name,
- StringPiece deterministic) {
+ StringPiece deterministic,
+ bool use_unbounded_threadpool) {
return test::function::NDef(
name, "ParallelMapDatasetV2",
{string(input_node_name), string(num_parallel_calls_node_name)},
@@ -188,6 +189,7 @@
{"output_shapes", absl::Span<const TensorShape>{}},
{"output_types", absl::Span<const DataType>{}},
{"deterministic", string(deterministic)},
+ {"use_unbounded_threadpool", use_unbounded_threadpool},
});
}
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
index 7341329..c5823d1 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -89,7 +89,8 @@
NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name,
StringPiece num_parallel_calls_node_name,
StringPiece function_name,
- StringPiece deterministic);
+ StringPiece deterministic,
+ bool use_unbounded_threadpool);
// Creates a test NodeDef for ParseExampleDataset.
NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name,
diff --git a/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc
index 54bf5fe..5cb93fa 100644
--- a/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc
@@ -18,6 +18,7 @@
#include <string>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_testlib.h"
@@ -27,7 +28,6 @@
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace grappler {
@@ -101,7 +101,7 @@
{{"value", 1}, {"dtype", DT_INT32}}),
graph_tests_utils::MakeParallelMapV2Node(
"map_1", "io_1", "num_parallel_calls_1", "noop_1",
- /*deterministic=*/"default"),
+ /*deterministic=*/"default", /*use_unbounded_threadpool=*/false),
NDef("files_2", "Const", {},
{{"value", "file1file2"}, {"dtype", DT_STRING}}),
@@ -114,7 +114,7 @@
{{"value", 1}, {"dtype", DT_INT32}}),
graph_tests_utils::MakeParallelMapV2Node(
"map_2", "io_2", "num_parallel_calls_2", "noop_2",
- /*deterministic=*/"default"),
+ /*deterministic=*/"default", /*use_unbounded_threadpool=*/false),
NDef("zip", "ZipDataset", {"map_1", "map_2"}, {}),
NDef("Sink", "Identity", {"zip"}, {})},
diff --git a/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc b/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc
index 1b76fee..1ff66f3 100644
--- a/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc
@@ -85,7 +85,7 @@
} else {
orig_map_node_def = graph_tests_utils::MakeParallelMapV2Node(
"map", "range", "num_parallel_calls", "MyFunction",
- deterministic ? "true" : "false");
+ deterministic ? "true" : "false", /*use_unbounded_threadpool=*/false);
}
orig_map_node_def.add_input("^start");
AttrValue* attr_val = &(*orig_map_node_def.mutable_attr())["Targuments"];
@@ -321,7 +321,8 @@
{{"value", 1}, {"dtype", DT_INT32}}),
graph_tests_utils::MakeParallelMapV2Node(
"map", "range", "num_parallel_calls", func_name,
- deterministic ? "true" : "false")},
+ deterministic ? "true" : "false",
+ /*use_unbounded_threadpool=*/false)},
// FunctionLib
{test::function::XTimesTwo(), OuterXTimesTwo()});
@@ -387,7 +388,8 @@
{{"value", Tensor(int64_t{1})}, {"dtype", DT_INT64}}),
graph_tests_utils::MakeParallelMapV2Node(
"map", "range", "num_parallel_calls", func_name,
- deterministic ? "true" : "false"),
+ deterministic ? "true" : "false",
+ /*use_unbounded_threadpool=*/false),
graph_tests_utils::MakePrefetchNode("prefetch", "map", "buffer_size")},
// FunctionLib
{test::function::RandomUniform(), OuterRandomUniform()});
@@ -485,7 +487,7 @@
NodeDef map_node_def = graph_tests_utils::MakeParallelMapV2Node(
"map", "range", "num_parallel_calls", func_name,
- deterministic ? "true" : "false");
+ deterministic ? "true" : "false", /*use_unbounded_threadpool=*/false);
map_node_def.add_input("^start");
// Rewrite occurs due to parallelism in map function
@@ -587,7 +589,8 @@
{{"value", Tensor(int64_t{1})}, {"dtype", DT_INT64}}),
graph_tests_utils::MakeParallelMapV2Node(
"map", "range", "num_parallel_calls", func_name,
- deterministic ? "true" : "false"),
+ deterministic ? "true" : "false",
+ /*use_unbounded_threadpool=*/false),
graph_tests_utils::MakePrefetchNode("prefetch", "map", "buffer_size")},
// FunctionLib
{test::function::ReadResourceVariable(), OuterReadResourceVariable()});
diff --git a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
index 207a9cd..bf15420 100644
--- a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
@@ -21,7 +21,6 @@
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -147,7 +146,7 @@
{{"value", 1}, {"dtype", DT_INT32}}),
graph_tests_utils::MakeParallelMapV2Node(
"map", "range", "num_parallel_calls", "XTimesTwo",
- /*deterministic=*/"default")},
+ /*deterministic=*/"default", /*use_unbounded_threadpool=*/false)},
// FunctionLib
{
test::function::XTimesTwo(),
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 69943e8..091e94d 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -124,6 +124,11 @@
if (node2->op() != "MapDataset" && !IsParallelMap(*node2)) {
continue;
}
+ // Do not fuse ParallelMap node that uses the unbounded thread pool.
+ if (node2->attr().find("use_unbounded_threadpool") != node2->attr().end() &&
+ node2->attr().at("use_unbounded_threadpool").b()) {
+ continue;
+ }
// Use a more descriptive variable name now that we know the node type.
NodeDef* map_node = node2;
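The guard above skips fusion whenever a ParallelMap node carries `use_unbounded_threadpool = true`. As a stand-alone illustration of that attribute-lookup pattern, the snippet below uses a plain `std::map` in place of the NodeDef attribute map, so the names here are illustrative rather than Grappler API.

// Sketch of the "skip fusion when use_unbounded_threadpool is set" guard,
// with a plain map standing in for the NodeDef attribute map.
#include <iostream>
#include <map>
#include <string>

bool UsesUnboundedThreadpool(const std::map<std::string, bool>& attrs) {
  auto it = attrs.find("use_unbounded_threadpool");
  return it != attrs.end() && it->second;
}

int main() {
  std::map<std::string, bool> fusable = {{"deterministic", true}};
  std::map<std::string, bool> not_fusable = {{"use_unbounded_threadpool", true}};
  std::cout << UsesUnboundedThreadpool(fusable) << "\n";      // 0: eligible for fusion
  std::cout << UsesUnboundedThreadpool(not_fusable) << "\n";  // 1: skip fusion
  return 0;
}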
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
index 74947cb..077123e 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -402,6 +402,71 @@
EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
}
+TEST(MapAndBatchFusionTest, NoChange_UnboundedThreadpoolParallelMap) {
+ GrapplerItem item;
+ MutableGraphView graph(&item.graph);
+ NodeDef *start_node = graph_utils::AddScalarConstNode<int64_t>(0, &graph);
+ NodeDef *stop_node = graph_utils::AddScalarConstNode<int64_t>(10, &graph);
+ NodeDef *step_node = graph_utils::AddScalarConstNode<int64_t>(1, &graph);
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ std::vector<std::pair<string, AttrValue>> range_attrs;
+ NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
+ range_attrs, &graph);
+ NodeDef *captured_input_node =
+ graph_utils::AddScalarConstNode<StringPiece>("hello", &graph);
+ NodeDef *num_parallel_calls_node =
+ graph_utils::AddScalarConstNode<int>(2, &graph);
+
+ NodeDef *map_node;
+ {
+ std::vector<string> map_inputs(3);
+ map_inputs[0] = range_node->name();
+ map_inputs[1] = captured_input_node->name();
+ map_inputs[2] = num_parallel_calls_node->name();
+ std::vector<std::pair<string, AttrValue>> map_attrs(3);
+ AttrValue f_attr;
+ SetAttrValue("f", &f_attr);
+ map_attrs[0] = std::make_pair("f", f_attr);
+ AttrValue args_attr;
+ SetAttrValue("Targuments", &args_attr);
+ map_attrs[1] = std::make_pair("Targuments", args_attr);
+ AttrValue use_unbounded_threadpool_attr;
+ SetAttrValue(true, &use_unbounded_threadpool_attr);
+ map_attrs[2] = std::make_pair("use_unbounded_threadpool",
+ use_unbounded_threadpool_attr);
+ map_node = graph_utils::AddNode("", "ParallelMapDataset", map_inputs,
+ map_attrs, &graph);
+ }
+
+ NodeDef *batch_size_node =
+ graph_utils::AddScalarConstNode<int64_t>(5, &graph);
+ NodeDef *batch_node;
+ {
+ std::vector<string> batch_inputs(2);
+ batch_inputs[0] = map_node->name();
+ batch_inputs[1] = batch_size_node->name();
+ std::vector<std::pair<string, AttrValue>> batch_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ batch_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ batch_attrs[1] = std::make_pair("output_types", types_attr);
+ batch_node = graph_utils::AddNode("", "BatchDataset", batch_inputs,
+ batch_attrs, &graph);
+ }
+
+ MapAndBatchFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index d2bf6a3..78e9eba 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -216,10 +216,22 @@
for (const NodeDef& node : sorted_old_graph.node()) {
const NodeDef* map_node = get_map_node(node);
if (!map_node) continue;
+ // Do not fuse ParallelMap node that uses the unbounded thread pool.
+ if (map_node->attr().find("use_unbounded_threadpool") !=
+ map_node->attr().end() &&
+ map_node->attr().at("use_unbounded_threadpool").b()) {
+ continue;
+ }
const NodeDef* parent_map_node =
get_map_node(*graph_utils::GetInputNode(*map_node, graph));
if (!parent_map_node) continue;
+ // Do not fuse ParallelMap node that uses the unbounded thread pool.
+ if (parent_map_node->attr().find("use_unbounded_threadpool") !=
+ parent_map_node->attr().end() &&
+ parent_map_node->attr().at("use_unbounded_threadpool").b()) {
+ continue;
+ }
// TODO(b/148614504): Support fusing different types of map operations.
if (parent_map_node->op() != map_node->op()) continue;
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index c81191e..a773d9b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -20,6 +20,7 @@
#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -28,7 +29,6 @@
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
namespace tensorflow {
@@ -88,9 +88,11 @@
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
num_parallel_calls_node,
MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
- "XTimesTwo", "default"),
+ "XTimesTwo", "default",
+ /*use_unbounded_threadpool=*/false),
MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
- "XTimesTwo", "default")},
+ "XTimesTwo", "default",
+ /*use_unbounded_threadpool=*/false)},
// FunctionLib
{
test::function::XTimesTwo(),
@@ -171,9 +173,11 @@
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
num_parallel_calls_node,
MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
- "XTimesTwo", "default"),
+ "XTimesTwo", "default",
+ /*use_unbounded_threadpool=*/false),
MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
- "XTimesTwo", "default")},
+ "XTimesTwo", "default",
+ /*use_unbounded_threadpool=*/false)},
// FunctionLib
{
test::function::XTimesTwo(),
@@ -187,6 +191,36 @@
EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
}
+TEST(MapFusionTest, NoChange_UnboundedThreadpoolParallelMap) {
+ using test::function::NDef;
+ GrapplerItem item;
+ NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper(
+ "num_parallel_calls", DT_INT64,
+ [](TensorProto* proto) { proto->add_int64_val(-1); });
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ num_parallel_calls_node,
+ MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
+ "XTimesTwo", "default",
+ /*use_unbounded_threadpool=*/true),
+ MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
+ "XTimesTwo", "default",
+ /*use_unbounded_threadpool=*/false)},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(OptimizeWithMapFusion(item, &output, true));
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
TEST(MapFusionTest, FusedNodesAndFunctionsAreNamedAfterOldNodesAndFunctions) {
using test::function::NDef;
NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper(
@@ -209,10 +243,11 @@
num_parallel_calls_node,
MakeParallelMapV2Node(parent_map_node_name, "range",
num_parallel_calls_node.name(),
- parent_function_name, "default"),
+ parent_function_name, "default",
+ /*use_unbounded_threadpool=*/false),
MakeParallelMapV2Node(map_node_name, parent_map_node_name,
num_parallel_calls_node.name(), function_name,
- "default")},
+ "default", /*use_unbounded_threadpool=*/false)},
// FunctionLib
{parent_fn, fn});
};
diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc
index 25d86a1..2060b0e 100644
--- a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc
@@ -17,6 +17,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -25,7 +26,6 @@
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/platform/status_matchers.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/protobuf/error_codes.pb.h"
namespace tensorflow {
@@ -74,7 +74,8 @@
/*input_node_name=*/"RangeDataset/_3",
/*num_parallel_calls_node_name=*/"Const/_4",
/*function_name=*/"__inference_Dataset_map_lambda_10",
- /*deterministic=*/"default"),
+ /*deterministic=*/"default",
+ /*use_unbounded_threadpool=*/false),
NDef("dataset", // name
"_Retval", // op
diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc
index 076357a..eba3fce 100644
--- a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc
@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -32,7 +33,6 @@
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/platform/status.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace grappler {
diff --git a/tensorflow/core/grappler/utils/pattern_utils_test.cc b/tensorflow/core/grappler/utils/pattern_utils_test.cc
index 6b8f689..22fe41b 100644
--- a/tensorflow/core/grappler/utils/pattern_utils_test.cc
+++ b/tensorflow/core/grappler/utils/pattern_utils_test.cc
@@ -184,7 +184,7 @@
bool all_indices_matched = true;
for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
it++) {
- auto label = str_util::StripPrefix(it->first, "my_");
+ auto label = absl::StripPrefix(it->first, "my_");
int matched_node_idx = it->second;
int expected_node_idx = graph_view.GetNode(label)->node_index();
if (matched_node_idx != expected_node_idx) {
@@ -268,7 +268,7 @@
bool all_indices_matched = true;
for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
it++) {
- auto label = str_util::StripPrefix(it->first, "my_");
+ auto label = absl::StripPrefix(it->first, "my_");
int matched_node_idx = it->second;
int expected_node_idx = graph_view.GetNode(label)->node_index();
if (matched_node_idx != expected_node_idx) {
@@ -387,7 +387,7 @@
bool all_indices_matched = true;
for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
it++) {
- auto label = str_util::StripPrefix(it->first, "my_");
+ auto label = absl::StripPrefix(it->first, "my_");
int matched_node_idx = it->second;
int expected_node_idx = graph_view.GetNode(label)->node_index();
if (matched_node_idx != expected_node_idx) {
@@ -561,7 +561,7 @@
bool all_indices_matched = true;
for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
it++) {
- auto label = str_util::StripPrefix(it->first, "my_");
+ auto label = absl::StripPrefix(it->first, "my_");
int matched_node_idx = it->second;
int expected_node_idx = graph_view.GetNode(label)->node_index();
if (matched_node_idx != expected_node_idx) {
diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc
index 481c9f0..9805e17 100644
--- a/tensorflow/core/ir/types/dialect.cc
+++ b/tensorflow/core/ir/types/dialect.cc
@@ -32,6 +32,7 @@
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
@@ -366,17 +367,10 @@
os << "<";
if (hasRank()) {
auto print_dim = [&](int64_t dim) {
- if (dim != ShapedType::kDynamic) {
- if (dim == 0) {
- // In order to avoid the parseInteger below from confusing a dimension
- // list with '0x' as hex integer, we use 00 for a 0 sized dimension.
- os << "00";
- } else {
- os << dim;
- }
- } else {
+ if (dim != ShapedType::kDynamic)
+ os << dim;
+ else
os << "?";
- }
};
llvm::interleave(getShape(), os, print_dim, "x");
} else {
@@ -405,7 +399,7 @@
llvm::SMLoc loc = parser.getCurrentLocation();
if (succeeded(parser.parseOptionalQuestion())) {
shape.back() = ShapedType::kDynamic;
- } else if (failed(parser.parseInteger(shape.back()))) {
+ } else if (failed(parser.parseDecimalInteger(shape.back()))) {
parser.emitError(loc)
<< "expected an integer or `?` when parsing a tf.shape attribute";
return failure();
diff --git a/tensorflow/core/ir/types/dialect_test.cc b/tensorflow/core/ir/types/dialect_test.cc
index 84a301a..4fb014d 100644
--- a/tensorflow/core/ir/types/dialect_test.cc
+++ b/tensorflow/core/ir/types/dialect_test.cc
@@ -62,7 +62,7 @@
TEST(TFTypesDialect, ParsesDimensionListWithZero) {
// Test that a dimension list with zero can be parsed.
const char *const code = R"mlir(
- "test.op"() {shape = #tf_type.shape<00x128>} : () -> ()
+ "test.op"() {shape = #tf_type.shape<0x128>} : () -> ()
)mlir";
MLIRContext context;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9dd9d5e..213c83a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -701,6 +701,7 @@
"//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
"//tensorflow/core/kernels/batching_util:batch_resource_base",
"//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
+ "//tensorflow/core/kernels/batching_util:batch_scheduler_utils",
"//tensorflow/core/kernels/batching_util:bounded_executor",
"//tensorflow/core/kernels/batching_util:concat_split_util",
"//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
@@ -1647,7 +1648,7 @@
tf_cc_test(
name = "batch_kernels_test",
- size = "medium",
+ size = "small",
srcs = ["batch_kernels_test.cc"],
features = ["-layering_check"],
deps = [
@@ -4672,7 +4673,7 @@
"spacetobatch_functor.h",
"spacetobatch_functor_gpu.cu.cc",
],
- visibility = [":friends"],
+ visibility = ["//visibility:private"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -4715,7 +4716,7 @@
"spacetodepth_op.h",
"spacetodepth_op_gpu.cu.cc",
],
- visibility = [":friends"],
+ visibility = ["//visibility:private"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 8e1e97d..bd93c1e 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -38,6 +38,7 @@
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
#include "tensorflow/core/kernels/batching_util/bounded_executor.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
@@ -177,7 +178,8 @@
/*mixed_priority_batching_policy=*/
serving::MixedPriorityBatchingPolicy::
kLowPriorityPaddingWithMaxBatchSize,
- enable_large_batch_splitting, resource);
+ enable_large_batch_splitting,
+ /*batch_padding_policy=*/"PAD_UP", resource);
}
static Status Create(
@@ -190,7 +192,7 @@
int32_t low_priority_max_enqueued_batches,
const std::vector<int32>& low_priority_allowed_batch_sizes,
serving::MixedPriorityBatchingPolicy mixed_priority_batching_policy,
- bool enable_large_batch_splitting,
+ bool enable_large_batch_splitting, absl::string_view batch_padding_policy,
std::unique_ptr<BatchResource>* resource) {
BatcherT::Options batcher_options;
batcher_options.num_batch_threads = num_batch_threads;
@@ -203,8 +205,8 @@
num_batch_threads, max_execution_batch_size, batch_timeout_micros,
max_enqueued_batches, allowed_batch_sizes,
enable_large_batch_splitting,
- /*disable_padding=*/false, low_priority_max_batch_size,
- low_priority_batch_timeout_micros,
+ /*disable_padding=*/false, batch_padding_policy,
+ low_priority_max_batch_size, low_priority_batch_timeout_micros,
low_priority_max_enqueued_batches, low_priority_allowed_batch_sizes,
mixed_priority_batching_policy),
allowed_batch_sizes));
@@ -439,7 +441,7 @@
low_priority_batch_timeout_micros_,
low_priority_max_enqueued_batches_, low_priority_allowed_batch_sizes_,
mixed_priority_batching_policy, enable_large_batch_splitting_,
- &new_resource));
+ batch_padding_policy_, &new_resource));
if (session_metadata) {
new_resource->set_session_metadata(*session_metadata);
}
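With this change the kernel forwards its `batch_padding_policy` attribute into the batch resource, while the non-attribute path above keeps "PAD_UP" as the default. A hedged sketch of how the attribute is supplied when building a BatchFunction node, closely mirroring the test changes later in this patch (the node name, function name, and attribute values are illustrative only):

#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"

// Sketch only: builds an example BatchFunction NodeDef with an explicit
// batch_padding_policy. "func_to_batch" is a hypothetical function name that
// would have to exist in the surrounding function library.
tensorflow::Status BuildExampleBatchFunctionNode(tensorflow::NodeDef* out) {
  using tensorflow::DataType;
  using tensorflow::NodeDefBuilder;
  tensorflow::NameAttrList f;
  f.set_name("func_to_batch");
  std::vector<NodeDefBuilder::NodeOut> inputs(
      {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})});
  return NodeDefBuilder("example_batch_node", "BatchFunction")
      .Attr("max_batch_size", 8)
      .Attr("num_batch_threads", 8)
      .Attr("allowed_batch_sizes", {4, 8})
      .Attr("batch_timeout_micros", 1000000)
      .Attr("max_enqueued_batches", 10)
      .Attr("enable_large_batch_splitting", true)
      // One of "PAD_UP", "BATCH_DOWN", or "MINIMIZE_TPU_COST_PER_REQUEST".
      .Attr("batch_padding_policy", "BATCH_DOWN")
      .Attr("Tin", {DataType::DT_INT64})
      .Input(inputs)
      .Attr("Tcaptured", std::vector<DataType>{})
      .Input(std::vector<NodeDefBuilder::NodeOut>{})
      .Attr("Tout", std::vector<DataType>{DataType::DT_INT64})
      .Attr("f", f)
      .Finalize(out);
}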
diff --git a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc
index e3601cf..7e66f9b 100644
--- a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc
+++ b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc
@@ -13,8 +13,6 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/kernels/batch_kernels.h"
-
#include <cstdint>
#include <memory>
#include <utility>
@@ -22,6 +20,7 @@
#include <gtest/gtest.h>
#include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/function.h"
@@ -30,6 +29,7 @@
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/batch_kernel_test_util.h"
+#include "tensorflow/core/kernels/batch_kernels.h"
#include "tensorflow/core/kernels/batching_util/warmup.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/env.h"
@@ -37,7 +37,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/refcount.h"
diff --git a/tensorflow/core/kernels/batch_kernels_env_test.cc b/tensorflow/core/kernels/batch_kernels_env_test.cc
index 508c0e8..5a8bfec 100644
--- a/tensorflow/core/kernels/batch_kernels_env_test.cc
+++ b/tensorflow/core/kernels/batch_kernels_env_test.cc
@@ -14,11 +14,11 @@
==============================================================================*/
#include <gmock/gmock.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/kernels/batch_kernel_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc
index 62666c0..9aaeb5a 100644
--- a/tensorflow/core/kernels/batch_kernels_test.cc
+++ b/tensorflow/core/kernels/batch_kernels_test.cc
@@ -17,12 +17,15 @@
#include <cstdint>
#include <memory>
+#include <string>
#include <utility>
#include <vector>
#include <gtest/gtest.h>
#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/function.h"
@@ -39,12 +42,13 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/criticality.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/refcount.h"
#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace {
@@ -84,19 +88,13 @@
return absl::OkStatus();
}});
}
-};
-class BatchFunctionTestState : public SharedBatchFunctionTestState {
- public:
- // Init test fixture with a batch kernel instance. The caller guarantees that
- // the device pointer is valid throughout the life of this class.
- absl::Status Init(Device *device, bool enable_low_priority_queue,
- absl::string_view mixed_priority_policy,
- int64_t expected_batch_size) {
- // Override the per-test/per-op device with a given device so that it can
- // be shared between ops.
- device_ = device;
-
+ protected:
+ // Creates a builder for the common BatchFunction op used in these tests.
+ absl::StatusOr<NodeDefBuilder> CreateBatchFunctionBuilder(
+ const std::vector<int> &allowed_batch_sizes, int max_batch_size,
+ absl::string_view padding_policy,
+ const TensorShape &expected_output_shape) {
NameAttrList f;
f.set_name("ShapeEnforcingFunction");
FunctionDef func = FunctionDefHelper::Create(
@@ -112,8 +110,7 @@
{{{"o"},
"EnsureShape",
{"x"},
- {{"T", DataType::DT_INT64},
- {"shape", TensorShape({expected_batch_size, 2})}}}},
+ {{"T", DataType::DT_INT64}, {"shape", expected_output_shape}}}},
// ret_def
{{"o", "o:output"}});
TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func));
@@ -121,13 +118,40 @@
std::vector<NodeDefBuilder::NodeOut> inputs(
{NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})});
- TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction")
- .Attr("max_batch_size", 8)
- .Attr("num_batch_threads", 8)
- .Attr("allowed_batch_sizes", {4, 8})
- .Attr("batch_timeout_micros", 1000000)
- .Attr("max_enqueued_batches", 10)
- .Attr("enable_large_batch_splitting", true)
+ return NodeDefBuilder(absl::StrCat("BatchTPUInput", padding_policy),
+ "BatchFunction")
+ .Attr("max_batch_size", max_batch_size)
+ .Attr("num_batch_threads", 8)
+ .Attr("allowed_batch_sizes", allowed_batch_sizes)
+ .Attr("batch_timeout_micros", 1000000)
+ .Attr("max_enqueued_batches", 10)
+ .Attr("enable_large_batch_splitting", true)
+ .Attr("batch_padding_policy", padding_policy)
+ .Attr("Tin", {DataType::DT_INT64})
+ .Input(inputs)
+ .Attr("Tcaptured", std::vector<DataType>{})
+ .Input(std::vector<NodeDefBuilder::NodeOut>{})
+ .Attr("Tout", std::vector<DataType>{DT_INT64})
+ .Attr("f", f);
+ }
+};
+
+class BatchFunctionTestState : public SharedBatchFunctionTestState {
+ public:
+ // Init test fixture with a batch kernel instance. The caller guarantees that
+ // the device pointer is valid throughout the life of this class.
+ absl::Status Init(Device *device, bool enable_low_priority_queue,
+ absl::string_view mixed_priority_policy,
+ int64_t expected_batch_size) {
+ // Override the per-test/per-op device with a given device so that it can
+ // be shared between ops.
+ device_ = device;
+
+ const TensorShape expected_output_shape({expected_batch_size, 2});
+ TF_ASSIGN_OR_RETURN(
+ NodeDefBuilder builder,
+ CreateBatchFunctionBuilder({4, 8}, 8, "PAD_UP", expected_output_shape));
+ TF_RETURN_IF_ERROR(builder
.Attr("low_priority_max_batch_size",
enable_low_priority_queue ? 8 : 0)
.Attr("low_priority_batch_timeout_micros",
@@ -139,14 +163,8 @@
.Attr("low_priority_max_enqueued_batches",
enable_low_priority_queue ? 2 : 0)
.Attr("mixed_priority_policy", mixed_priority_policy)
- .Attr("batch_padding_policy", "PAD_UP")
- .Attr("Tin", {DataType::DT_INT64})
- .Input(inputs)
- .Attr("Tcaptured", std::vector<DataType>{})
- .Input(std::vector<NodeDefBuilder::NodeOut>{})
- .Attr("Tout", std::vector<DataType>{DT_INT64})
- .Attr("f", f)
.Finalize(node_def()));
+
return OpsTestBase::InitOp();
}
@@ -576,48 +594,13 @@
// be shared between ops.
device_ = cpu_device;
- NameAttrList f;
- f.set_name("BatchFunctionKernelParallelWarmupTestStateFunc");
- FunctionDef func = FunctionDefHelper::Create(
- // function_name
- f.name(),
- // in_def
- {"x:int64"},
- // out_def
- {"o:int64"},
- // attr_def
- {},
- // node_def
- {{{"o"},
- "EnsureShape",
- {"x"},
- {{"T", DataType::DT_INT64}, {"shape", TensorShape({2})}}}},
- // ret_def
- {{"o", "o:output"}});
- TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func));
- SharedBatchFunctionTestState::CreateFunctionLibraryRuntime();
+ const TensorShape expected_output_shape({2});
+ TF_ASSIGN_OR_RETURN(
+ NodeDefBuilder builder,
+ CreateBatchFunctionBuilder({2, 4, 8}, enable_splitting ? 16 : 8,
+ "PAD_UP", expected_output_shape));
+ TF_RETURN_IF_ERROR(builder.Finalize(node_def()));
- std::vector<NodeDefBuilder::NodeOut> inputs(
- {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})});
- TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction")
- .Attr("max_batch_size", enable_splitting ? 16 : 8)
- .Attr("num_batch_threads", 8)
- .Attr("allowed_batch_sizes", {2, 4, 8})
- .Attr("batch_timeout_micros", 1000000)
- .Attr("max_enqueued_batches", 10)
- .Attr("enable_large_batch_splitting", true)
- .Attr("low_priority_max_batch_size", 64)
- .Attr("low_priority_batch_timeout_micros", 8000)
- .Attr("low_priority_allowed_batch_sizes", {32, 64})
- .Attr("low_priority_max_enqueued_batches", 1000)
- .Attr("batch_padding_policy", "PAD_UP")
- .Attr("Tin", {DataType::DT_INT64})
- .Input(inputs)
- .Attr("Tcaptured", std::vector<DataType>{})
- .Input(std::vector<NodeDefBuilder::NodeOut>{})
- .Attr("Tout", std::vector<DataType>{DT_INT64})
- .Attr("f", f)
- .Finalize(node_def()));
return OpsTestBase::InitOp();
}
@@ -688,5 +671,80 @@
BatchFunctionKernelParallelWarmupTest,
::testing::Bool());
+class BatchFunctionKernelPaddingTestState
+ : public SharedBatchFunctionTestState {
+ public:
+ // Init test fixture with a batch kernel instance.
+ absl::Status Init(absl::string_view padding_policy, int expected_batch_size) {
+ static auto *const cpu_device = []() {
+ auto device =
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
+ return device.release();
+ }();
+
+ // Override the per-test/per-op device with a global device so that it can
+ // be shared between ops.
+ device_ = cpu_device;
+
+ const TensorShape expected_output_shape({expected_batch_size, 2});
+ TF_RETURN_IF_ERROR(CreateBatchFunctionBuilder({4, 8}, 8, padding_policy,
+ expected_output_shape)
+ ->Finalize(node_def()));
+
+ return OpsTestBase::InitOp();
+ }
+
+ void TestBody() override {}
+};
+
+class BatchFunctionKernelPaddingTest
+ : public ::testing::TestWithParam<std::string> {};
+
+TEST_P(BatchFunctionKernelPaddingTest, PadUp) {
+ SessionMetadata session_metadata;
+ session_metadata.set_name("test_model");
+ session_metadata.set_version(123);
+
+ // Send 5 requests in parallel and check that the given batch padding
+ // policy behaves as expected.
+ int64_t num_requests = 5;
+ int64_t expected_batch_size = 0;
+ std::string padding_policy = GetParam();
+ if (padding_policy == "PAD_UP") {
+ expected_batch_size = 8;
+ } else if (padding_policy == "BATCH_DOWN") {
+ expected_batch_size = 4;
+ } else if (padding_policy == "MINIMIZE_TPU_COST_PER_REQUEST") {
+ expected_batch_size = 8;
+ } else {
+ FAIL() << "Unsupported padding policy: " << padding_policy;
+ }
+
+ {
+ tsl::BlockingCounter blocking_counter(num_requests);
+ for (int i = 0; i < num_requests; ++i) {
+ Env::Default()->SchedClosure([&]() {
+ BatchFunctionKernelPaddingTestState test_state;
+ test_state.set_session_metadata(session_metadata);
+ TF_CHECK_OK(test_state.Init(padding_policy, expected_batch_size));
+ test_state.AddInputFromList<int64_t>(TensorShape({1, 2}), {123, 456});
+ TF_EXPECT_OK(test_state.RunOpKernel());
+
+ test::ExpectTensorEqual<int64_t>(
+ *test_state.GetOutput(0),
+ test::AsTensor<int64_t>({123, 456}, TensorShape({1, 2})));
+ blocking_counter.DecrementCount();
+ });
+ }
+
+ blocking_counter.Wait();
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelPaddingTestSuite,
+ BatchFunctionKernelPaddingTest,
+ ::testing::Values("PAD_UP", "BATCH_DOWN",
+ "MINIMIZE_TPU_COST_PER_REQUEST"));
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 6cf3ef0..4173632 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -121,10 +121,10 @@
"//tensorflow/core/lib/core:notification",
"//tensorflow/core/lib/core:status",
"//tensorflow/core/platform:thread_annotations",
- "//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:criticality",
+ "@local_tsl//tsl/profiler/lib:traceme",
],
)
@@ -134,12 +134,12 @@
hdrs = ["batch_scheduler.h"],
deps = [
"//tensorflow/core:lib",
- "//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:criticality",
+ "@local_tsl//tsl/profiler/lib:traceme",
],
)
@@ -148,8 +148,12 @@
srcs = ["batch_scheduler_utils.cc"],
hdrs = ["batch_scheduler_utils.h"],
deps = [
+ ":batch_scheduler_hdrs",
+ ":batch_stats",
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/time",
],
)
@@ -183,7 +187,10 @@
name = "batch_scheduler_utils_test",
srcs = ["batch_scheduler_utils_test.cc"],
deps = [
+ ":batch_scheduler_hdrs",
":batch_scheduler_utils",
+ ":batch_stats",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
],
)
@@ -195,6 +202,7 @@
":batch_input_task",
":batch_scheduler_hdrs",
":batch_scheduler_utils",
+ ":batch_stats",
":periodic_function_dynamic",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
@@ -209,7 +217,6 @@
"//tensorflow/core/profiler/lib:context_types_hdrs",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/lib:traceme_encode",
- "@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
@@ -225,13 +232,13 @@
":batch_input_task",
":batch_scheduler",
":batch_scheduler_utils",
+ ":batch_stats",
":periodic_function_dynamic",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/lib:context_types_hdrs",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/lib:traceme_encode",
- "@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
@@ -246,6 +253,7 @@
srcs = ["shared_batch_scheduler_test.cc"],
deps = [
":batch_scheduler",
+ ":batch_scheduler_utils",
":fake_clock_env",
":shared_batch_scheduler",
"//tensorflow/core:lib",
@@ -481,18 +489,30 @@
srcs = ["batch_resource_base_test.cc"],
deps = [
":batch_resource_base",
+ ":batch_scheduler_hdrs",
+ ":batch_scheduler_utils",
":batch_stats",
+ ":shared_batch_scheduler",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:portable_gif_internal",
"//tensorflow/core/common_runtime:cost_constants",
"//tensorflow/core/common_runtime:cost_measurement",
"//tensorflow/core/common_runtime:cost_measurement_registry",
"//tensorflow/core/common_runtime:no_op_cost_measurement",
"//tensorflow/core/common_runtime:request_cost",
"//tensorflow/core/framework:types_proto_cc",
+ "//tensorflow/core/kernels:batch_kernels",
+ "//tensorflow/core/lib/monitoring:cell_reader",
+ "//tensorflow/core/platform:notification",
+ "@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:criticality",
+ "@local_tsl//tsl/platform:status",
],
)
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
index 630ff22..81cbe41 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
@@ -233,6 +233,16 @@
cell->GetCell(model_name, op_name)->Set(max_batch_size);
}
+void RecordBatchParamPaddingPolicy(const string& batch_padding_policy,
+ const string& model_name,
+ const string& op_name) {
+ static auto* cell = monitoring::Gauge<string, 2>::New(
+ "/tensorflow/serving/batching/configured_batch_padding_policy",
+ "The value of BatchFunction.batch_padding_policy attribute.",
+ "model_name", "op_name");
+ cell->GetCell(model_name, op_name)->Set(batch_padding_policy);
+}
+
void RecordBatchParamMaxEnqueuedBatches(int64_t max_enqueued_batches,
const string& model_name,
const string& op_name) {
@@ -406,6 +416,9 @@
RecordBatchParamMaxEnqueuedBatches(
batcher_queue_options_.max_enqueued_batches, GetModelName(context),
context->op_kernel().name());
+ RecordBatchParamPaddingPolicy(
+ this->batcher_queue_options_.batch_padding_policy,
+ GetModelName(context), context->op_kernel().name());
} else if (adaptive_batcher_) {
RecordBatchParamBatchTimeoutMicros(
adaptive_batcher_queue_options_.batch_timeout_micros,
@@ -472,8 +485,10 @@
}
BatcherQueueT* batcher_queue;
- TF_RETURN_IF_ERROR(
- LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
+ TF_RETURN_IF_ERROR(LookupOrCreateBatcherQueue(
+ /* queue_name= */ batcher_queue_name,
+ /* model_name= */ GetModelName(context),
+ /* op_name= */ context->op_kernel().name(), /* queue= */ &batcher_queue));
if (!session_metadata().name().empty()) {
absl::MutexLock lock(&outstanding_batch_mu_);
@@ -500,7 +515,9 @@
return GetBatcherQueueOptions(
num_batch_threads, max_batch_size, batch_timeout_micros,
max_enqueued_batches, allowed_batch_sizes, enable_large_batch_splitting,
- disable_padding, /*low_priority_max_batch_size=*/0,
+ disable_padding,
+ /*batch_padding_policy=*/kPadUpPolicy,
+ /*low_priority_max_batch_size=*/0,
/*low_priority_batch_timeout_micros=*/0,
/*low_priority_max_enqueued_batches=*/0,
/*low_priority_allowed_batch_sizes=*/{},
@@ -514,7 +531,7 @@
int32_t batch_timeout_micros, int32_t max_enqueued_batches,
const std::vector<int32>& allowed_batch_sizes,
bool enable_large_batch_splitting, bool disable_padding,
- int32_t low_priority_max_batch_size,
+ absl::string_view batch_padding_policy, int32_t low_priority_max_batch_size,
int32_t low_priority_batch_timeout_micros,
int32_t low_priority_max_enqueued_batches,
const std::vector<int32>& low_priority_allowed_batch_sizes,
@@ -523,6 +540,8 @@
batcher_queue_options.input_batch_size_limit = max_batch_size;
batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
+ batcher_queue_options.batch_padding_policy =
+ std::string(batch_padding_policy);
if (low_priority_max_batch_size > 0) {
batcher_queue_options.enable_priority_queue = true;
}
@@ -1172,9 +1191,9 @@
}
}
-// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
-// creates it.
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
+ const string& model_name,
+ const string& op_name,
BatcherQueueT** queue) {
mutex_lock l(batcher_queues_mu_);
@@ -1186,8 +1205,12 @@
std::unique_ptr<BatcherQueueT> new_queue;
if (batcher_) {
+ BatcherT::QueueOptions batcher_queue_options = batcher_queue_options_;
+ batcher_queue_options.model_batch_stats = &GlobalBatchStatsRegistry().model(
+ /* model_name= */ model_name, /* op_name= */ op_name);
+
TF_RETURN_IF_ERROR(batcher_->AddQueue(
- batcher_queue_options_,
+ batcher_queue_options,
absl::bind_front(&BatchResourceBase::ProcessBatchCallBack, this),
&new_queue));
} else if (adaptive_batcher_) {
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index e8b3926..c50b29f 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -25,6 +25,7 @@
#include <vector>
#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
#include "absl/synchronization/blocking_counter.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/request_cost.h"
@@ -34,6 +35,7 @@
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
#include "tensorflow/core/platform/context.h"
@@ -52,6 +54,7 @@
int32_t batch_timeout_micros;
int32_t max_enqueued_batches;
std::vector<int32_t> allowed_batch_sizes;
+ std::string batch_padding_policy{kPadUpPolicy};
int32_t low_priority_max_batch_size;
int32_t low_priority_batch_timeout_micros;
int32_t low_priority_max_enqueued_batches;
@@ -213,6 +216,7 @@
int32_t batch_timeout_micros, int32_t max_enqueued_batches,
const std::vector<int32>& allowed_batch_sizes,
bool enable_large_batch_splitting, bool disable_padding,
+ absl::string_view batch_padding_policy,
int32_t low_priority_max_batch_size,
int32_t low_priority_batch_timeout_micros,
int32_t low_priority_max_enqueued_batches,
@@ -332,9 +336,14 @@
static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch,
int output_index);
- // Looks up the batcher queue for 'queue_name'. If it did't previously exist,
+ // Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
// creates it.
+ //
+ // The model_name and op_name are the names of the current model and
+ // operation, respectively.
Status LookupOrCreateBatcherQueue(const string& queue_name,
+ const string& model_name,
+ const string& op_name,
BatcherQueueT** queue);
SessionMetadata session_metadata_;
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
index 3900711..fa4fad9 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
@@ -16,22 +16,40 @@
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include <cstdint>
+#include <functional>
#include <memory>
+#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
+#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/cost_constants.h"
#include "tensorflow/core/common_runtime/cost_measurement.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/request_cost.h"
+#include "tensorflow/core/framework/device.h"
+#include "tensorflow/core/framework/device_factory.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
#include "tensorflow/core/kernels/batching_util/batch_stats.h"
+#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
+#include "tensorflow/core/lib/monitoring/cell_reader.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
#include "tsl/platform/criticality.h"
+#include "tsl/platform/status.h"
namespace tensorflow {
namespace serving {
@@ -400,6 +418,210 @@
original_cumulative_processed_size + 4);
}
+class BatchResourceBaseTest : public ::testing::Test {
+ protected:
+ // Like BatchResourceBase but overrides abstract methods, one of which
+ // signals the exposed process_func_batch_called() notification.
+ class MyBatchResource : public BatchResourceBase {
+ public:
+ using BatchResourceBase::BatchResourceBase;
+
+ std::string DebugString() const override { return ""; }
+
+ void ProcessFuncBatchImpl(
+ const BatchResourceBase::BatchTask& /* last_task */,
+ absl::Span<const Tensor> /* inputs */,
+ std::vector<Tensor>* /* combined_outputs */,
+ std::function<void(const absl::Status&)> /* done */) const override {
+ process_func_batch_called_.Notify();
+ }
+
+ Notification& process_func_batch_called() {
+ return process_func_batch_called_;
+ }
+
+ private:
+ mutable Notification process_func_batch_called_;
+ };
+
+ BatchResourceBaseTest() {
+ // The whole point of this test fixture is to create a usable batch function
+ // context, context_.
+
+ // Create device_.
+ device_ = DeviceFactory::NewDevice("CPU", SessionOptions{},
+ "/job:a/replica:0/task:0");
+
+ // Create batch_kernel_node_def.
+ NodeDefBuilder batch_function_builder("my_batch_node", "BatchFunction");
+ batch_function_builder.Attr("max_batch_size", 128);
+ batch_function_builder.Attr("num_batch_threads", 8);
+ batch_function_builder.Attr("allowed_batch_sizes", {2, 4, 8});
+ batch_function_builder.Attr("batch_timeout_micros", 100);
+ batch_function_builder.Attr("max_enqueued_batches", 100);
+ batch_function_builder.Attr("enable_large_batch_splitting", true);
+ std::vector<DataType> input_dtypes = {DataType::DT_INT64,
+ DataType::DT_INT64};
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ inputs.push_back(NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64}));
+ inputs.push_back(NodeDefBuilder::NodeOut({"n2", 1, DataType::DT_INT64}));
+ batch_function_builder.Attr("Tin", input_dtypes);
+ batch_function_builder.Input(inputs);
+ batch_function_builder.Attr("Tcaptured", {DataType::DT_INT64});
+ batch_function_builder.Input(std::vector<NodeDefBuilder::NodeOut>{
+ NodeDefBuilder::NodeOut({"n3", 1, DataType::DT_INT64})});
+ batch_function_builder.Attr("Tout", {DataType::DT_INT64});
+ NameAttrList f;
+ f.set_name("func_to_batch");
+ batch_function_builder.Attr("f", f);
+ NodeDef batch_kernel_node_def;
+ TF_CHECK_OK(batch_function_builder.Finalize(&batch_kernel_node_def));
+
+ // Create batch_kernel_.
+ absl::Status op_kernel_creation_status;
+ batch_kernel_ =
+ CreateOpKernel(DEVICE_CPU, device_.get(), device_->GetAllocator({}),
+ batch_kernel_node_def, TF_GRAPH_DEF_VERSION,
+ &op_kernel_creation_status);
+ TF_CHECK_OK(op_kernel_creation_status);
+ CHECK(batch_kernel_ != nullptr);
+
+ // Create input tensors.
+ input_tensor_ = Tensor(DataType::DT_INT64, TensorShape({5, 2, 1}));
+ input_tensor_values_ = {
+ TensorValue(&input_tensor_),
+ TensorValue(&input_tensor_),
+ TensorValue(&input_tensor_),
+ };
+
+ // Fill-in session_metadata_.
+ session_metadata_.set_name("my_model_name");
+
+ // Fill-in params_.
+ params_.device = device_.get();
+ params_.op_kernel = batch_kernel_.get();
+ params_.inputs = input_tensor_values_;
+ params_.session_metadata = &session_metadata_;
+
+ // Create context_.
+ context_ = std::make_unique<OpKernelContext>(¶ms_);
+ }
+
+ std::unique_ptr<Device> device_;
+
+ std::unique_ptr<OpKernel> batch_kernel_;
+
+ Tensor input_tensor_;
+ std::vector<TensorValue> input_tensor_values_;
+
+ SessionMetadata session_metadata_;
+
+ OpKernelContext::Params params_;
+
+ std::unique_ptr<OpKernelContext> context_;
+};
+
+TEST_F(BatchResourceBaseTest, PassesCorrectModelBatchStatsToSbs) {
+ using BatchTask = BatchResourceBase::BatchTask;
+ using SharedBatchScheduler = SharedBatchScheduler<BatchTask>;
+
+ // Like SharedBatchScheduler but exposes the last QueueOptions passed to
+ // AddQueue as queue_options().
+ class MySharedBatchScheduler : public SharedBatchScheduler {
+ public:
+ MySharedBatchScheduler() : SharedBatchScheduler::SharedBatchScheduler({}) {}
+
+ absl::Status AddQueue(
+ const QueueOptions& options,
+ ProcessBatchCallback process_batch_callback,
+ std::unique_ptr<BatchScheduler<BatchTask>>* queue) override {
+ queue_options_ = options;
+ return SharedBatchScheduler::AddQueue(options, process_batch_callback,
+ queue);
+ }
+
+ const QueueOptions& queue_options() const { return queue_options_; }
+
+ private:
+ QueueOptions queue_options_;
+ };
+
+ auto batcher = std::make_shared<MySharedBatchScheduler>();
+
+ MyBatchResource* my_batch_resource = new MyBatchResource(
+ /* has_process_batch_function */ true,
+ /* batcher= */ batcher,
+ /* batcher_queue_options */ {},
+ /* allowed_batch_sizes */ {});
+
+ TF_CHECK_OK(my_batch_resource->RegisterInput(
+ /* guid= */
+ 0,
+ /* context= */ context_.get(),
+ /* batcher_queue_name= */ "batcher_queue_name",
+ /* create_batch_task_fn= */
+ []() -> absl::StatusOr<std::unique_ptr<BatchResourceBase::BatchTask>> {
+ return std::make_unique<BatchResourceBase::BatchTask>();
+ },
+ /* done_callback= */ [] {}, /* forced_warmup_batch_size= */ 0));
+
+ EXPECT_EQ(batcher->queue_options().model_batch_stats,
+ &GlobalBatchStatsRegistry().model(/* model_name= */ "my_model_name",
+ /* op_name= */ "my_batch_node"));
+
+ // Wait for the batch timeout to expire and the scheduler to dump the only
+ // scheduled task back to the batch resource. If we don't do this, the
+ // scheduler will do this itself on destruction, when the resource has already
+ // been destroyed.
+ my_batch_resource->process_func_batch_called().WaitForNotificationWithTimeout(
+ absl::Seconds(1));
+
+ // This is how we have to destroy the BatchResource.
+ my_batch_resource->Unref();
+}
+
+TEST_F(BatchResourceBaseTest, ConfiguredBatchPaddingPolicyMetric) {
+ tensorflow::monitoring::testing::CellReader<std::string> metric(
+ "/tensorflow/serving/batching/configured_batch_padding_policy");
+
+ std::shared_ptr<SharedBatchScheduler<BatchResourceBase::BatchTask>> batcher;
+ TF_CHECK_OK(
+ SharedBatchScheduler<BatchResourceBase::BatchTask>::Create({}, &batcher));
+
+ MyBatchResource* my_batch_resource = new MyBatchResource(
+ /* has_process_batch_function */ true,
+ /* batcher= */ batcher,
+ /* batcher_queue_options */
+ MyBatchResource::BatcherT::QueueOptions{
+ .batch_padding_policy{kMinimizeTpuCostPerRequestPolicy},
+ },
+ /* allowed_batch_sizes */ {});
+
+ TF_CHECK_OK(my_batch_resource->RegisterInput(
+ /* guid= */
+ 0, /* context= */ context_.get(),
+ /* batcher_queue_name= */ "batcher_queue_name",
+ /* create_batch_task_fn= */
+ []() -> absl::StatusOr<std::unique_ptr<BatchResourceBase::BatchTask>> {
+ return std::make_unique<BatchResourceBase::BatchTask>();
+ },
+ /* done_callback= */ [] {}, /* forced_warmup_batch_size= */ 0));
+
+ EXPECT_EQ(metric.Read(/* model_name= */ "my_model_name",
+ /* op_name= */ "my_batch_node"),
+ kMinimizeTpuCostPerRequestPolicy);
+
+ // Wait for the batch timeout to expire and the scheduler to dump the only
+ // scheduled task back to the batch resource. If we don't do this, the
+ // scheduler will do this itself on destruction, when the resource has already
+ // been destroyed.
+ my_batch_resource->process_func_batch_called().WaitForNotificationWithTimeout(
+ absl::Seconds(1));
+
+ // This is how we have to destroy the BatchResource.
+ my_batch_resource->Unref();
+}
+
} // namespace
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler.h b/tensorflow/core/kernels/batching_util/batch_scheduler.h
index c70972c..1c50c55 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler.h
@@ -32,7 +32,7 @@
#include <atomic>
#include <cstddef>
#include <deque>
-#include <functional>
+#include <iterator>
#include <memory>
#include <optional>
#include <utility>
@@ -43,12 +43,11 @@
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/lib/traceme.h"
#include "tsl/platform/criticality.h"
+#include "tsl/profiler/lib/traceme.h"
namespace tensorflow {
namespace serving {
@@ -252,7 +251,7 @@
// accept new tasks; a closed one cannot. A batch is monotonic: initially it is
// open and tasks can be added to it; then it is closed and its set of tasks
// remains fixed for the remainder of its life. A closed batch cannot be re-
-// opened. Tasks can never be removed from a batch.
+// opened.
//
// Type parameter TaskType must be a subclass of BatchTask.
template <typename TaskType>
@@ -304,6 +303,15 @@
// Returns the TraceMe context id of this batch.
uint64 traceme_context_id() const;
+ // Attempts to trim this batch to a new, smaller size (not to be confused with
+ // the number of tasks in the batch). On success, the trimmed tasks go into
+ // 'out_trimmed_tasks' in the same order the tasks were in this batch.
+ //
+ // The method might not succeed if it needs to split a large task to hit the
+ // correct size.
+ void TryTrimToNewSize(
+ int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks);
+
private:
mutable mutex mu_;
@@ -505,6 +513,45 @@
return traceme_context_id_;
}
+template <typename TaskType>
+void Batch<TaskType>::TryTrimToNewSize(
+ int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
+ mutex_lock l(mu_);
+ DCHECK_GT(new_size, 0);
+ DCHECK_LT(new_size, size_);
+ DCHECK(out_trimmed_tasks.empty());
+
+ // Index of the first task to trim away. It is possible that it is the index
+ // of a task of size larger than 1 that will have to be split in order to get
+ // to the target new_size.
+ int32 first_task_to_move = 0;
+ // The sum of sizes of tasks i, where i < first_task_to_move.
+ int32 size_of_previous_tasks = 0;
+ while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <=
+ new_size) {
+ size_of_previous_tasks += tasks_[first_task_to_move]->size();
+ first_task_to_move++;
+ // The loop must always stop before this check is tripped because new_size
+ // must be strictly smaller than the size of the batch.
+ DCHECK_LT(first_task_to_move, tasks_.size());
+ }
+
+ // Check whether task 'first_task_to_move' will have to be split.
+ if (size_of_previous_tasks < new_size) {
+ // TODO: b/325954758 - Consider supporting splitting large tasks and then
+ // drop 'Try' from the method name.
+ return;
+ }
+ DCHECK_EQ(size_of_previous_tasks, new_size);
+
+ // Actually trim.
+ out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move);
+ std::move(tasks_.begin() + first_task_to_move, tasks_.end(),
+ std::back_inserter(out_trimmed_tasks));
+ tasks_.resize(first_task_to_move);
+ size_ = new_size;
+}
+
} // namespace serving
} // namespace tensorflow
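A worked trace (not part of the patch) of the trimming loop above, for a batch with task sizes {3, 5, 7, 9}:

// TryTrimToNewSize with new_size = 8:
//   0 + 3 <= 8 -> consume task 0 (size_of_previous_tasks = 3)
//   3 + 5 <= 8 -> consume task 1 (size_of_previous_tasks = 8)
//   8 + 7 >  8 -> stop with first_task_to_move = 2
//   size_of_previous_tasks == new_size, so tasks 2 and 3 move into
//   out_trimmed_tasks and the batch is resized to 8.
// With new_size = 4 the loop stops at size_of_previous_tasks = 3 < 4: the
// size-5 task would have to be split, so trimming is aborted and the batch
// keeps its original size. This matches the two new tests added to
// batch_scheduler_test.cc below.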
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
index e159c437..2f9c903 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
@@ -21,12 +21,13 @@
#include <optional>
#include <string>
#include <tuple>
+#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tsl/platform/criticality.h"
@@ -37,6 +38,7 @@
using ::testing::ElementsAre;
using ::testing::Eq;
+using ::testing::Pointer;
using ::testing::Property;
TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) {
@@ -386,6 +388,53 @@
EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty()); // third call
}
+TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) {
+ Batch<FakeTask> batch;
+
+ auto task0 = new FakeTask(3);
+ batch.AddTask(std::unique_ptr<FakeTask>(task0));
+
+ auto task1 = new FakeTask(5);
+ batch.AddTask(std::unique_ptr<FakeTask>(task1));
+
+ auto task2 = new FakeTask(7);
+ batch.AddTask(std::unique_ptr<FakeTask>(task2));
+
+ auto task3 = new FakeTask(9);
+ batch.AddTask(std::unique_ptr<FakeTask>(task3));
+
+ std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
+ batch.TryTrimToNewSize(/* new_size= */ 8,
+ /* out_trimmed_tasks= */ trimmed_tasks);
+
+ EXPECT_EQ(batch.size(), 8);
+ EXPECT_EQ(batch.num_tasks(), 2);
+
+ EXPECT_THAT(trimmed_tasks, ElementsAre(Pointer(task2), Pointer(task3)));
+
+ batch.Close(); // Batch::~Batch blocks until the batch is closed.
+}
+
+TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) {
+ Batch<FakeTask> batch;
+
+ auto task0 = new FakeTask(3);
+ batch.AddTask(std::unique_ptr<FakeTask>(task0));
+
+ auto task1 = new FakeTask(5);
+ batch.AddTask(std::unique_ptr<FakeTask>(task1));
+
+ std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
+ batch.TryTrimToNewSize(/* new_size= */ 4,
+ /* out_trimmed_tasks= */ trimmed_tasks);
+
+ EXPECT_EQ(batch.size(), 8);
+ EXPECT_EQ(batch.num_tasks(), 2);
+ EXPECT_TRUE(trimmed_tasks.empty());
+
+ batch.Close(); // Batch::~Batch blocks until the batch is closed.
+}
+
} // namespace
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
index 7e4382a..9a6deb1 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
@@ -16,8 +16,15 @@
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_
+#include <memory>
+#include <optional>
#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -36,6 +43,114 @@
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding);
+// Constants containing possible values for the batch_padding_policy argument
+// of MaybeBatchDown. This argument specifies the policy that a batch scheduler
+// is using when deciding what to do when, say, 18 requests need to be batched,
+// but only 16 and 32 batch sizes are allowed. The following options are
+// available.
+//
+// - PAD_UP: pad to size 32.
+// - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the
+// batch buffer.
+// - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses
+// to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per
+// real request. In this case, it would compare (batch_16_cost / 16) and
+// (batch_32_cost / 18).
+//
+inline constexpr absl::string_view kBatchDownPolicy = "BATCH_DOWN";
+inline constexpr absl::string_view kPadUpPolicy = "PAD_UP";
+inline constexpr absl::string_view kMinimizeTpuCostPerRequestPolicy =
+ "MINIMIZE_TPU_COST_PER_REQUEST";
+
+// Trims the batch to the next allowed batch size when possible and when
+// configured by batch_padding_policy.
+//
+// When trimming, this function puts the trimmed tasks into the
+// out_trimmed_tasks vector in the same order as they were in the batch.
+template <typename TaskType>
+void MaybeBatchDown(Batch<TaskType>& batch,
+ const std::vector<int32>& allowed_batch_sizes,
+ bool disable_padding,
+ absl::string_view batch_padding_policy,
+ ModelBatchStats* model_batch_stats,
+ std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
+ if (batch_padding_policy == kPadUpPolicy) {
+ // This is the default behavior of batch resource when it is given a batch
+ // size that doesn't match any of the allowed batch sizes.
+ return;
+ }
+ bool minimize_tpu_cost_per_request;
+ if (batch_padding_policy == kBatchDownPolicy) {
+ minimize_tpu_cost_per_request = false;
+ } else if (batch_padding_policy == kMinimizeTpuCostPerRequestPolicy) {
+ if (model_batch_stats == nullptr) {
+ LOG_FIRST_N(ERROR, 1)
+ << kMinimizeTpuCostPerRequestPolicy
+ << " batch padding policy has been chosen "
+ "but no ModelBatchStats passed to the batch scheduler; will "
+ "fall back on the "
+ << kPadUpPolicy << " policy.";
+ return;
+ }
+ minimize_tpu_cost_per_request = true;
+ } else {
+ LOG_FIRST_N(ERROR, 1) << "Unsupported batch_padding_policy: "
+ << batch_padding_policy << ", falling back on the "
+ << kPadUpPolicy << " policy.";
+ return;
+ }
+
+ int32 batch_size = batch.size();
+
+ int32 pad_up_size =
+ GetNextAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
+ if (pad_up_size == batch_size) {
+ return; // Good, no padding is necessary.
+ }
+
+ int32 batch_down_size =
+ GetPrevAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
+ if (batch_down_size == batch_size) {
+ return; // Can't batch down (e.g. no smaller batch size available).
+ }
+
+ if (minimize_tpu_cost_per_request) {
+ // TODO: b/325954758 - Consider logging a warning here or elsewhere if
+ // a larger batch is not meaningfully cheaper than a smaller batch.
+ // TODO: b/325954758 - Consider logging a warning here or elsewhere if a
+ // smaller batch is unreasonably cheaper than a larger one (assuming
+ // a batch cost model = constant_cost + batch_size * per_element_cost).
+ // TODO: b/325954758 - Consider occasionally picking either batch size so
+ // that we learn fresh costs of each batch size. For this code, it is not a
+ // large priority though because if we are in between two allowed batch
+ // sizes (say, 16 and 32), chances are that we will occasionally organically
+ // get batches of exact sizes 16 and 32 (and then we pick those
+ // unconditionally). But if we explicitly occasionally explored other batch
+ // sizes, we wouldn't have to rely on this "chances are". For other
+ // applications of batch costs, we might also want to occasionally explore
+ // all allowed batch sizes and not just 16 and 32 from this example.
+ std::optional<absl::Duration> down_batch_cost =
+ model_batch_stats->batch_size(batch_down_size).tpu_cost().mean();
+ std::optional<absl::Duration> up_batch_cost =
+ model_batch_stats->batch_size(pad_up_size).tpu_cost().mean();
+ if (!down_batch_cost.has_value() || !up_batch_cost.has_value()) {
+ // We have no data about batch costs, let's just do nothing.
+ return;
+ }
+
+ auto batch_down_cost_per_request = *down_batch_cost / batch_down_size;
+ auto pad_up_cost_per_request = *up_batch_cost / batch_size;
+
+ if (pad_up_cost_per_request < batch_down_cost_per_request) {
+ // Abort batching down because it's cheaper to pad up.
+ return;
+ }
+ }
+
+ // Batch down.
+ batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks);
+}
+
} // namespace serving
} // namespace tensorflow
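To make the MINIMIZE_TPU_COST_PER_REQUEST decision concrete, here is the arithmetic for the 18-request example from the comment above, with assumed mean TPU costs (illustrative numbers, echoing the unit tests below):

// Allowed sizes {16, 32}, batch.size() = 18.
//   batch_down_cost_per_request = cost(16) / 16 = 2.0 ms / 16 = 0.125 ms
//   pad_up_cost_per_request     = cost(32) / 18 = 3.1 ms / 18 ~ 0.172 ms
// Padding up is more expensive per real request, so the batch is trimmed to
// 16 and the remaining 2 requests are handed back via out_trimmed_tasks
// (i.e. left in the batch buffer). If cost(32) were 2.0 ms instead, pad-up
// (~0.111 ms per request) would win and the batch would be left unchanged
// for the batch resource to pad.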
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc
index 2bff515..e45cb46 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc
@@ -15,7 +15,14 @@
#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
+#include <cstddef>
+#include <memory>
+#include <vector>
+
#include <gtest/gtest.h>
+#include "absl/time/time.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
namespace tensorflow {
namespace serving {
@@ -66,6 +73,208 @@
EXPECT_EQ(GetPrevAllowedBatchSize(10, {2, 4, 8}, false), 8);
}
+class FakeTask : public BatchTask {
+ public:
+ explicit FakeTask(size_t size) : size_(size) {}
+
+ size_t size() const override { return size_; }
+
+ private:
+ const size_t size_;
+};
+
+TEST(MaybeBatchDownTest, PadUp) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kPadUpPolicy,
+ /* model_batch_stats= */ nullptr,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ // The batch must stay unchanged (for the batch resource to then pad it to the
+ // next allowed batch size, thus ending up with pad-up behavior).
+ EXPECT_EQ(batch.size(), 3);
+}
+
+TEST(MaybeBatchDownTest, BatchDown) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kBatchDownPolicy,
+ /* model_batch_stats= */ nullptr,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ // The scheduler should trim the batch to a smaller allowed size that requires
+ // no padding.
+ EXPECT_EQ(batch.size(), 2);
+ // The trimmed part.
+ EXPECT_EQ(out_trimmed_tasks.size(), 1);
+}
+
+TEST(MaybeBatchDownTest, BatchDownDoesNotSplitTasks) {
+ // Add tasks for size 3, but the second task is large and will have to be
+ // split if doing batch-down.
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(2));
+ batch.Close();
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kBatchDownPolicy,
+ /* model_batch_stats= */ nullptr,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ // The batch must stay unchanged because the current implementation doesn't
+ // support splitting large tasks.
+ EXPECT_EQ(batch.size(), 3);
+}
+
+TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenTheBatchSizeIsAlreadyAllowed) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kBatchDownPolicy,
+ /* model_batch_stats= */ nullptr,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ // The batch should stay unchanged because it's already of an allowed size.
+ EXPECT_EQ(batch.size(), 4);
+}
+
+TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenNoSmallerAllowedSize) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {4, 8},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kBatchDownPolicy,
+ /* model_batch_stats= */ nullptr,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ // Can't batch down because there is no smaller allowed size.
+ EXPECT_EQ(batch.size(), 3);
+}
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestPicksBatchDown) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ ModelBatchStats model_batch_stats;
+ model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2));
+ model_batch_stats.batch_size(4).tpu_cost().Register(absl::Seconds(3.1));
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+ /* model_batch_stats= */ &model_batch_stats,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ EXPECT_EQ(batch.size(), 2);
+}
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestPicksPadUp) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ ModelBatchStats model_batch_stats;
+ model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2));
+ model_batch_stats.batch_size(4).tpu_cost().Register(absl::Seconds(2.9));
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+ /* model_batch_stats= */ &model_batch_stats,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ EXPECT_EQ(batch.size(), 3);
+}
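+
For reference, the arithmetic these two tests rely on: the batch holds 3 requests, so padding up to size 4 costs 3.1s / 3 ≈ 1.03s per request versus 2s / 2 = 1.0s per request when batching down to size 2, and the batch is trimmed; with the size-4 cost lowered to 2.9s, 2.9s / 3 ≈ 0.97s falls below 1.0s, so the batch is left unchanged to be padded up.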
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestIsOkWithMissingCosts) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ ModelBatchStats model_batch_stats;
+ model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2));
+ // Not adding costs for batch 4.
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+ /* model_batch_stats= */ &model_batch_stats,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ // No expectations on the resulting batch size; we only care that this
+ // doesn't crash.
+}
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestDoesPadUpWhenNoModelStats) {
+ Batch<FakeTask> batch;
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.AddTask(std::make_unique<FakeTask>(1));
+ batch.Close();
+
+ std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+ MaybeBatchDown(
+ /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+ /* disable_padding= */ false,
+ /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+ /* model_batch_stats= */ nullptr,
+ /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+ EXPECT_EQ(batch.size(), 3);
+}
+
} // namespace
} // namespace serving
diff --git a/tensorflow/core/kernels/batching_util/batch_stats.h b/tensorflow/core/kernels/batching_util/batch_stats.h
index b430663..87c36fc 100644
--- a/tensorflow/core/kernels/batching_util/batch_stats.h
+++ b/tensorflow/core/kernels/batching_util/batch_stats.h
@@ -66,6 +66,10 @@
namespace tensorflow::serving {
+// Default values for when there is no recorded statistic in ModelBatchStats.
+constexpr int64_t kNumBatchThreadsUnknown = -1;
+constexpr int64_t kBatchTimeoutMicrosUnknown = -1;
+
// Tracks the average cost of registered samples.
//
// Thread-safe.
@@ -167,6 +171,23 @@
return result;
}
+ void SetNumBatchThreads(int64_t num_batch_threads) {
+ num_batch_threads_.store(num_batch_threads, std::memory_order_relaxed);
+ }
+
+ int64_t num_batch_threads() const {
+ return num_batch_threads_.load(std::memory_order_relaxed);
+ }
+
+ void SetBatchTimeoutMicros(int64_t batch_timeout_micros) {
+ batch_timeout_micros_.store(batch_timeout_micros,
+ std::memory_order_relaxed);
+ }
+
+ int64_t batch_timeout_micros() const {
+ return batch_timeout_micros_.load(std::memory_order_relaxed);
+ }
+
private:
mutable mutex mu_;
@@ -184,6 +205,13 @@
// Can be used to generate an internal load metric per model. See
// RegisterQuerySize for more details.
std::atomic<int64_t> cumulative_processed_size_ = 0;
+
+ // The number of batch threads assigned to this model.
+ std::atomic<int64_t> num_batch_threads_ = kNumBatchThreadsUnknown;
+
+ // The timeout in microseconds for this model (after which the current batch
+ // is sent to be processed by the TPU).
+ std::atomic<int64_t> batch_timeout_micros_ = kBatchTimeoutMicrosUnknown;
};
// Tracks batch statistics for all models.
diff --git a/tensorflow/core/kernels/batching_util/batch_stats_test.cc b/tensorflow/core/kernels/batching_util/batch_stats_test.cc
index 223cde6..5f5168c 100644
--- a/tensorflow/core/kernels/batching_util/batch_stats_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_stats_test.cc
@@ -126,6 +126,28 @@
ASSERT_THAT(stats.BatchSizes(), UnorderedElementsAre(1, 2, 4));
}
+TEST(BatchStatsTest, BatchTimeoutIsCorrect) {
+ ModelBatchStats stats;
+
+ // The batch timeout defaults to -1 while unassigned.
+ ASSERT_EQ(stats.batch_timeout_micros(), -1);
+
+ // Assign a batch timeout of 100 microseconds.
+ stats.SetBatchTimeoutMicros(100);
+ ASSERT_EQ(stats.batch_timeout_micros(), 100);
+}
+
+TEST(BatchStatsTest, NumBatchThreadsIsCorrect) {
+ ModelBatchStats stats;
+
+ // The number of batch threads defaults to -1 while unassigned.
+ ASSERT_EQ(stats.num_batch_threads(), -1);
+
+ // Assign a number of per-model batch threads.
+ stats.SetNumBatchThreads(16);
+ ASSERT_EQ(stats.num_batch_threads(), 16);
+}
+
} // namespace
} // namespace tensorflow::serving
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
index 93ac0c9..acea649 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
@@ -29,13 +29,13 @@
#include <variant>
#include <vector>
-#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/time/clock.h"
#include "tensorflow/core/kernels/batching_util/batch_input_task.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -43,6 +43,7 @@
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -149,7 +150,7 @@
const Options& options,
std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler);
- ~SharedBatchScheduler();
+ virtual ~SharedBatchScheduler();
// Adds a queue to which tasks may be submitted. The returned queue implements
// the BatchScheduler API. Each queue has its own set of scheduling options,
@@ -240,6 +241,18 @@
// If true, the padding will not be appended.
bool disable_padding = false;
+ // The padding policy to use.
+ //
+ // See the documentation for kPadUpPolicy for details.
+ string batch_padding_policy = string(kPadUpPolicy);
+
+ // A pointer to a ModelBatchStats instance for this model. To be used for
+ // cost-based padding policy selection.
+ //
+ // If null and a cost-based policy is requested, a different padding policy
+ // is used instead.
+ ModelBatchStats* model_batch_stats = nullptr;
+
// If true, queue implementation would split high priority and low priority
// inputs into two sub queues.
bool enable_priority_queue = false;
@@ -270,13 +283,15 @@
MixedPriorityBatchingPolicy mixed_priority_batching_policy =
MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize;
};
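For context, a minimal sketch of how a caller might opt into the new cost-based behavior when adding a queue; `MyTask`, `scheduler`, `process_batch_callback`, and `batch_stats` are placeholders assumed for this example, not names introduced by the change:

  SharedBatchScheduler<MyTask>::QueueOptions queue_options;
  queue_options.allowed_batch_sizes = {1, 2, 4, 8};
  // Opt into cost-based padding-policy selection instead of the default pad-up.
  queue_options.batch_padding_policy =
      std::string(kMinimizeTpuCostPerRequestPolicy);
  // Caller-owned stats; if left null, a non-cost-based policy is used instead.
  queue_options.model_batch_stats = &batch_stats;
  std::unique_ptr<BatchScheduler<MyTask>> queue;
  TF_CHECK_OK(scheduler->AddQueue(queue_options, process_batch_callback, &queue));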
- Status AddQueue(const QueueOptions& options,
- ProcessBatchCallback process_batch_callback,
- std::unique_ptr<BatchScheduler<TaskType>>* queue);
+ // This method is marked virtual for testing purposes only.
+ virtual Status AddQueue(const QueueOptions& options,
+ ProcessBatchCallback process_batch_callback,
+ std::unique_ptr<BatchScheduler<TaskType>>* queue);
- private:
+ protected:
explicit SharedBatchScheduler(const Options& options);
+ private:
void GetNextWorkItem_Locked(internal::Queue<TaskType>** queue_for_batch_out,
BatchUniquePtr* batch_to_process_out)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
@@ -588,6 +603,9 @@
// The time at which the first task was added to the open (back-most) batch
// in 'high_priority_batches_'. Valid iff that batch contains at least one
// task.
+ //
+ // Note that when using a batch padding policy other than PAD_UP, this field
+ // might contain an approximate value (see ScheduleBatchWithEagerSplit).
uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);
// Whether this queue contains a batch that is eligible to be scheduled.
@@ -920,7 +938,7 @@
template <typename TaskType>
Status Queue<TaskType>::ScheduleWithLazySplit(std::unique_ptr<TaskType>* task) {
- profiler::TraceMe trace_me([task] {
+ tsl::profiler::TraceMe trace_me([task] {
return profiler::TraceMeEncode(
"ScheduleWithLazySplit",
{{"batching_input_task_size", (*task)->size()}});
@@ -1055,7 +1073,7 @@
Status Queue<TaskType>::ScheduleWithoutOrEagerSplit(
std::unique_ptr<TaskType>* task) {
const bool large_batch_splitting = options_.enable_large_batch_splitting;
- profiler::TraceMe trace_me([task, large_batch_splitting] {
+ tsl::profiler::TraceMe trace_me([task, large_batch_splitting] {
return profiler::TraceMeEncode(
large_batch_splitting ? "ScheduleWithEagerSplit"
: "ScheduleWithoutSplit",
@@ -1223,7 +1241,37 @@
std::deque<std::unique_ptr<Batch<TaskType>>>& batches = GetBatches();
// Consider closing the open batch at this time, to schedule it.
if (batches.size() == 1 && IsOpenBatchSchedulable()) {
+ // Support BatchPaddingPolicy::kBatchDown and
+ // BatchPaddingPolicy::kMinimizeTpuCostPerRequest. We do this before
+ // starting a new batch because starting a new batch will close the old
+ // batch, making it read-only.
+ std::vector<std::unique_ptr<TaskType>> trimmed_tasks;
+ MaybeBatchDown(
+ /* batch= */ *batches[0],
+ /* allowed_batch_sizes= */ options_.allowed_batch_sizes,
+ /* disable_padding= */ options_.disable_padding,
+ /* batch_padding_policy= */ options_.batch_padding_policy,
+ /* model_batch_stats= */ options_.model_batch_stats,
+ /* out_trimmed_tasks= */ trimmed_tasks);
+
StartNewBatch();
+
+ // Move the trimmed tasks, if any, into the new batch.
+ Batch<TaskType>& new_batch = *batches[1];
+ for (std::unique_ptr<TaskType>& task : trimmed_tasks) {
+ new_batch.AddTask(std::move(task));
+ }
+ if (!new_batch.empty()) {
+ // TODO - b/325954758: Reconsider the starting time of a trimmed batch.
+ //
+ // Ideally, we'd set open_batch_start_time_micros_ to the time we received
+ // the first task, but we don't have this information here, so we're
+ // using NOW as the timestamp. An alternative solution that doesn't
+ // require adding time to each task would be to assume that requests
+ // arrived at a steady rate and therefore use a point between the old
+ // value of open_batch_start_time_micros_ and NOW.
+ open_batch_start_time_micros_ = env_->NowMicros();
+ }
}
if (batches.size() >= 2) {
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc
index 1c4073a..2a5afae 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc
@@ -27,11 +27,14 @@
#include "absl/container/fixed_array.h"
#include "absl/status/status.h"
#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
@@ -39,7 +42,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/criticality.h"
namespace tensorflow {
@@ -1052,6 +1054,79 @@
}
}
+TEST_P(SharedBatchSchedulerTest, BatchPaddingPolicyBatchDown) {
+ if (enable_lazy_split()) {
+ GTEST_SKIP()
+ << "BatchPaddingPolicy::kBatchDown is not supported for lazy split.";
+ }
+
+ // Set up a fake clock, which only advances when we explicitly tell it to.
+ test_util::FakeClockEnv env(Env::Default());
+ Notification start_teardown, stop_teardown;
+ std::unique_ptr<Thread> teardown_thread =
+ CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+
+ {
+ Notification first_batch_processed;
+ Notification second_batch_processed;
+ auto callback = [&](std::unique_ptr<Batch<FakeTask>> batch) {
+ if (!first_batch_processed.HasBeenNotified()) {
+ // This is the main expectation of the test.
+ //
+ // The scheduler should have trimmed the batch to a smaller allowed
+ // size which requires no padding.
+ EXPECT_EQ(batch->size(), 2);
+
+ first_batch_processed.Notify();
+ return;
+ }
+
+ if (!second_batch_processed.HasBeenNotified()) {
+ // Leftovers after the first batch.
+ EXPECT_EQ(batch->size(), 1);
+
+ second_batch_processed.Notify();
+ return;
+ }
+
+ ADD_FAILURE() << "Batch callback must not be invoked more than expected";
+ };
+
+ auto scheduler = CreateSharedBatchScheduler(1, &env);
+
+ QueueOptions options =
+ CreateQueueOptions(/* max_execution_batch_size= */ 10,
+ /* input_batch_size_limit= */ 10,
+ /* batch_timeout_micros= */ 10,
+ /* max_enqueued_batches= */ 10);
+
+ // The most interesting option for this test.
+ options.allowed_batch_sizes = {1, 2, 4, 8};
+ options.batch_padding_policy = kBatchDownPolicy;
+
+ auto queue = CreateQueue(scheduler, options, callback);
+
+ // Schedule some tasks and ensure the scheduler calls the callback after the
+ // batch timeout has expired.
+ TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+ TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+ TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+ env.AdvanceByMicroseconds(options.batch_timeout_micros);
+ first_batch_processed.WaitForNotification();
+
+ // Ensure the scheduler correctly updates the starting time of the new
+ // batch.
+ env.AdvanceByMicroseconds(options.batch_timeout_micros - 1);
+ EXPECT_FALSE(second_batch_processed.WaitForNotificationWithTimeout(
+ absl::Milliseconds(10)));
+ env.AdvanceByMicroseconds(1);
+ second_batch_processed.WaitForNotification();
+
+ start_teardown.Notify();
+ }
+ stop_teardown.Notify();
+}
+
// TODO(b/161857471):
// Add test coverage when input-split and no-split returns differently.
INSTANTIATE_TEST_SUITE_P(
diff --git a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc
index 250f843..cb39718 100644
--- a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc
+++ b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc
@@ -19,6 +19,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
@@ -26,7 +27,6 @@
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace checkpoint {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index a5df473..c447763 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -121,6 +121,10 @@
"//tensorflow/core:lib_internal",
"//tensorflow/core/data:name_utils",
"//tensorflow/core/data:split_utils",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@local_tsl//tsl/platform:mutex",
],
)
@@ -851,6 +855,7 @@
"//tensorflow/core/data:dataset_utils",
"//tensorflow/core/data:name_utils",
"//tensorflow/core/data:stats_utils",
+ "//tensorflow/core/data:unbounded_thread_pool",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/lib:traceme_encode",
"@com_google_absl//absl/base",
@@ -1480,7 +1485,7 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 2387067..cd8acd1 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -288,17 +288,24 @@
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
+ int64_t input_empty;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
+
if (ctx->restored_element_count().has_value()) {
IteratorContext::Params params(ctx);
params.restored_element_count =
*ctx->restored_element_count() * dataset()->batch_size_;
IteratorContext ctx_copy(params);
- return RestoreInput(&ctx_copy, reader, input_impl_);
+ if (!static_cast<bool>(input_empty)) {
+ TF_RETURN_IF_ERROR(RestoreInput(&ctx_copy, reader, input_impl_));
+ ctx->MergeCheckpoint(ctx_copy.checkpoint());
+ } else {
+ input_impl_.reset();
+ }
+ return absl::OkStatus();
}
- int64_t input_empty;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
if (!static_cast<bool>(input_empty)) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
} else {
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index de8f50d..b77af19 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -880,13 +880,14 @@
TF_RETURN_IF_ERROR(
WriteElementsToCheckpoint(writer, prefix(), cache_->data()));
}
+ TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
return SaveInput(ctx, writer, iterator_);
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
if (ctx->restored_element_count().has_value()) {
- return global_shuffle_iterator_.Restore(ctx);
+ return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
}
mutex_lock l(mu_);
iterator_.reset();
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index 8f380a2..77ed1ce 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -14,13 +14,21 @@
==============================================================================*/
#include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
-#include <string>
+#include <algorithm>
+#include <cstddef>
#include <utility>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
-#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/mutex.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/thread_annotations.h"
namespace tensorflow {
namespace data {
@@ -36,6 +44,30 @@
constexpr char kIndex[] = "i";
constexpr char kInputImplUninitialized[] = "input_impl_uninitialized";
+constexpr char kElementCount[] = "element_count";
+
+namespace {
+
+// Gets the next shuffled index by iterating through the `index_mapper`,
+// skipping `NotFoundError` results, until it yields one of:
+// 1. a valid index,
+// 2. an `OutOfRangeError` (the shuffled sequence is exhausted), or
+// 3. any other error.
+absl::StatusOr<size_t> GetNextShuffledIndex(const IndexMapperFn& index_mapper,
+ size_t& element_count) {
+ absl::StatusOr<size_t> shuffled_index = absl::NotFoundError("default");
+
+ while (absl::IsNotFound(shuffled_index.status())) {
+ shuffled_index = index_mapper(element_count++);
+ if (absl::IsOutOfRange(shuffled_index.status())) {
+ return shuffled_index.status();
+ }
+ if (!absl::IsNotFound(shuffled_index.status()) && !shuffled_index.ok()) {
+ return shuffled_index.status();
+ }
+ }
+ return shuffled_index;
+}
+} // namespace
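To make the skipping behavior concrete, here is a small sketch (the mapper and its numbers are illustrative only, not part of this change) of the contract the helper relies on: NotFoundError results are skipped, OutOfRangeError ends the sequence, and anything else is returned as-is.

  // Toy mapper over a global sequence of 5 elements where odd global indices
  // belong to the other input and are reported as NotFoundError.
  IndexMapperFn toy_mapper = [](size_t idx) -> absl::StatusOr<size_t> {
    if (idx >= 5) return absl::OutOfRangeError("no more elements");
    if (idx % 2 == 1) return absl::NotFoundError("belongs to the other input");
    return idx / 2;  // Local index within this input.
  };
  size_t element_count = 0;
  // Successive calls return local indices 0, 1, 2 and then OutOfRangeError,
  // with element_count advancing past the skipped odd global indices.
  absl::StatusOr<size_t> next = GetNextShuffledIndex(toy_mapper, element_count);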
class ConcatenateDatasetOp::Dataset : public DatasetBase {
public:
@@ -58,6 +90,12 @@
&output_tensorshape));
output_shapes_.push_back(output_tensorshape);
}
+ if (input_ != nullptr && !input_->RandomIndexingCompatible().ok()) {
+ random_indexing_compatible_ = input_->RandomIndexingCompatible();
+ } else if (to_concatenate_ != nullptr &&
+ !to_concatenate_->RandomIndexingCompatible().ok()) {
+ random_indexing_compatible_ = to_concatenate_->RandomIndexingCompatible();
+ }
}
~Dataset() override {
input_->Unref();
@@ -126,6 +164,10 @@
return absl::OkStatus();
}
+ absl::Status RandomIndexingCompatible() const override {
+ return random_indexing_compatible_;
+ }
+
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
@@ -149,11 +191,15 @@
bool SymbolicCheckpointCompatible() const override { return true; }
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ input_impls_.resize(2);
+
TF_ASSIGN_OR_RETURN(input_contexts_,
CreateInputIteratorContexts(ctx, dataset()));
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
&input_contexts_[0], this, strings::StrCat(prefix(), "[0]"),
- &input_impl_));
+ &input_impls_[0]));
+
ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
return absl::OkStatus();
}
@@ -162,25 +208,115 @@
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- if (!input_impl_) {
+ if (!input_impls_[0] && !input_impls_[1]) {
*end_of_sequence = true;
return absl::OkStatus();
}
- while (i_ < 2) {
- TF_RETURN_IF_ERROR(input_impl_->GetNext(&input_contexts_[i_],
- out_tensors, end_of_sequence));
+ // Global shuffling
+ if (ctx->index_mapper()) {
+ if (input_impls_[1] == nullptr) {
+ // Creates the second iterator immediately in the case of
+ // global random shuffling.
+ TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+ &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+ &input_impls_[1]));
+ ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+ }
+
+ if (input_contexts_[0].index_mapper() == nullptr) {
+ IndexMapperFn left_index_mapper =
+ [index_mapper = ctx->index_mapper(),
+ left_cardinality = dataset()->input_cardinality_,
+ right_cardinality = dataset()->to_concatenate_cardinality_](
+ size_t to_idx) -> absl::StatusOr<size_t> {
+ TF_ASSIGN_OR_RETURN(size_t from_idx, index_mapper(to_idx));
+
+ if (from_idx >= left_cardinality + right_cardinality) {
+ return absl::OutOfRangeError("Running out of elements.");
+ }
+ if (from_idx >= left_cardinality) {
+ // Returns NotFoundError rather than OutOfRangeError so that the upstream
+ // global shuffle iterator does not treat this as an end of sequence.
+ return absl::NotFoundError("Skipping this element.");
+ }
+ return from_idx;
+ };
+
+ IndexMapperFn right_index_mapper =
+ [index_mapper = ctx->index_mapper(),
+ left_cardinality = dataset()->input_cardinality_,
+ right_cardinality = dataset()->to_concatenate_cardinality_](
+ size_t to_idx) -> absl::StatusOr<size_t> {
+ TF_ASSIGN_OR_RETURN(size_t from_idx, index_mapper(to_idx));
+
+ if (from_idx >= left_cardinality + right_cardinality) {
+ return absl::OutOfRangeError("Running out of elements.");
+ }
+ if (from_idx < left_cardinality) {
+ // Returns NotFoundError rather than OutOfRangeError so that the upstream
+ // global shuffle iterator does not treat this as an end of sequence.
+ return absl::NotFoundError("Skipping this element.");
+ }
+ return from_idx - left_cardinality;
+ };
+
+ input_contexts_[0].SetIndexMapper(left_index_mapper);
+ input_contexts_[1].SetIndexMapper(right_index_mapper);
+ }
+
+ // Materializes the shuffled index because we need this information
+ // to determine which iterator we need to call later.
+
+ absl::StatusOr<size_t> shuffled_index =
+ GetNextShuffledIndex(ctx->index_mapper(), element_count_);
+
+ if (absl::IsOutOfRange(shuffled_index.status())) {
+ *end_of_sequence = true;
+ return absl::OkStatus();
+ }
+
+ TF_RETURN_IF_ERROR(shuffled_index.status());
+
+ // Routes the shuffled index to the correct input iterator.
+ bool temp_end_of_sequence = false;
+ absl::Status status = absl::OkStatus();
+ if (shuffled_index.value() < dataset()->input_cardinality_) {
+ status = input_impls_[0]->GetNext(&input_contexts_[0], out_tensors,
+ &temp_end_of_sequence);
+ ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
+ } else {
+ status = input_impls_[1]->GetNext(&input_contexts_[1], out_tensors,
+ &temp_end_of_sequence);
+ ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+ }
+ TF_RETURN_IF_ERROR(status);
+
+ if (temp_end_of_sequence) {
+ *end_of_sequence = temp_end_of_sequence;
+ return absl::OkStatus();
+ }
+ return absl::OkStatus();
+ }
+
+ for (; i_ < 2; ++i_) {
+ TF_RETURN_IF_ERROR(input_impls_[i_]->GetNext(
+ &input_contexts_[i_], out_tensors, end_of_sequence));
ctx->MergeCheckpoint(input_contexts_[i_].checkpoint());
if (!*end_of_sequence) {
return absl::OkStatus();
}
- if (++i_ < 2) {
+ if (i_ == 0) {
+ // Creates the second iterator only when the first iterator is
+ // exhausted, to save memory.
TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
- &input_contexts_[i_], this, strings::StrCat(prefix(), "[1]"),
- &input_impl_));
+ &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+ &input_impls_[1]));
+ ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
}
}
*end_of_sequence = true;
- input_impl_.reset();
+ input_impls_[0].reset();
+ input_impls_[1].reset();
return absl::OkStatus();
}
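As a concrete illustration of the routing above (numbers are hypothetical): with input_cardinality_ = 3 and to_concatenate_cardinality_ = 2, a shuffled global index of 4 is served by the second input at local index 4 - 3 = 1, a global index of 1 is served by the first input at index 1, and any global index >= 5 surfaces as OutOfRangeError, which GetNextInternal translates into end_of_sequence.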
@@ -196,10 +332,18 @@
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kIndex, i_));
TF_RETURN_IF_ERROR(
- writer->WriteScalar(prefix(), kInputImplUninitialized,
- static_cast<int64_t>(!input_impl_)));
- if (input_impl_) {
- TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
+ writer->WriteScalar(prefix(), kElementCount, element_count_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 0),
+ static_cast<int64_t>(!input_impls_[0])));
+ if (input_impls_[0]) {
+ TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impls_[0]));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 1),
+ static_cast<int64_t>(!input_impls_[1])));
+ if (input_impls_[1]) {
+ TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impls_[1]));
}
return absl::OkStatus();
}
@@ -207,33 +351,96 @@
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kIndex, &i_));
- int64_t input_uninitialized;
- TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kInputImplUninitialized,
- &input_uninitialized));
- if (static_cast<bool>(input_uninitialized)) {
- input_impl_.reset();
+
+ int64_t input_uninitialized[2];
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 0),
+ &input_uninitialized[0]));
+ if (static_cast<bool>(input_uninitialized[0])) {
+ input_impls_[0].reset();
+ }
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 1),
+ &input_uninitialized[1]));
+ if (static_cast<bool>(input_uninitialized[1])) {
+ input_impls_[1].reset();
+ }
+
+ if (ctx->restored_element_count()) {
+ if (input_impls_.size() != 2) {
+ return absl::FailedPreconditionError(
+ "`Initialize` should be called before restoring from the "
+ "checkpoint.");
+ }
+ {
+ int64_t tmp_element_count;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(prefix(), kElementCount, &tmp_element_count));
+ if (tmp_element_count < 0) {
+ return absl::FailedPreconditionError(absl::StrFormat(
+ "element_count should be >= 0. Got %d", tmp_element_count));
+ }
+ element_count_ = static_cast<size_t>(tmp_element_count);
+ }
+
+ if (!static_cast<bool>(input_uninitialized[0])) {
+ if (!input_impls_[0]) {
+ return absl::FailedPreconditionError(
+ "Something went wrong internally. The first iterator should "
+ "exist because of `Initialize`.");
+ }
+ input_contexts_[0].set_restored_element_count(
+ *ctx->restored_element_count());
+ TF_RETURN_IF_ERROR(
+ RestoreInput(&input_contexts_[0], reader, input_impls_[0]));
+ ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
+ }
+
+ if (!static_cast<bool>(input_uninitialized[1])) {
+ TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+ &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+ &input_impls_[1]));
+
+ input_contexts_[1].set_restored_element_count(
+ *ctx->restored_element_count());
+
+ TF_RETURN_IF_ERROR(
+ RestoreInput(&input_contexts_[1], reader, input_impls_[1]));
+ ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+ }
return absl::OkStatus();
}
+
+ TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kIndex, &i_));
+
if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
return errors::InvalidArgument("i_ must be in range [0, 2].");
- if (i_ == 1) {
+
+ if (!static_cast<bool>(input_uninitialized[0])) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impls_[0]));
+ }
+ if (!static_cast<bool>(input_uninitialized[1])) {
TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
- ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
- } else if (i_ == 2) {
- input_impl_.reset();
+ &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+ &input_impls_[1]));
+ ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+
+ TF_RETURN_IF_ERROR(
+ RestoreInput(&input_contexts_[1], reader, input_impls_[1]));
+ ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
}
- if (input_impl_) {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- }
+
return absl::OkStatus();
}
private:
mutex mu_;
int64_t i_ TF_GUARDED_BY(mu_);
- std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
- std::vector<IteratorContext> input_contexts_;
+ std::vector<std::unique_ptr<IteratorBase>> input_impls_ TF_GUARDED_BY(mu_);
+ std::vector<IteratorContext> input_contexts_ TF_GUARDED_BY(mu_);
+ // Indicates that `ctx->index_mapper()(element_count_)` is the next
+ // shuffled index.
+ size_t element_count_ TF_GUARDED_BY(mu_) = 0;
};
Status MostSpecificCompatibleShape(const PartialTensorShape& ts1,
@@ -257,6 +464,7 @@
const int64_t input_cardinality_;
const int64_t to_concatenate_cardinality_;
std::vector<PartialTensorShape> output_shapes_;
+ absl::Status random_indexing_compatible_ = absl::OkStatus();
};
ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc
index f15bd39..936ad1f 100644
--- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc
@@ -173,17 +173,8 @@
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- if (ctx->restored_element_count().has_value()) {
- num_elements_ = *(ctx->restored_element_count());
- // If the dataset has reached the end of sequence, the restored element
- // count could be cardinality + 1.
- if (num_elements_ > dataset()->Cardinality()) {
- num_elements_ = dataset()->Cardinality();
- }
- } else {
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("num_elements"), &num_elements_));
- }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("num_elements"), &num_elements_));
return RestoreInput(ctx, reader, input_impl_);
}
diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
index 8bfc9ad..91b596c 100644
--- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
@@ -251,7 +251,7 @@
const auto& t_flat = t.flat<T>();
// TODO(mrry): Replace with a memcpy or something more
// efficient. (Maybe an Eigen assign op?)
- gtl::InlinedVector<int64_t, 4> strides(row_ndims);
+ absl::InlinedVector<int64_t, 4UL> strides(row_ndims);
if (!strides.empty()) {
strides[row_ndims - 1] = 1;
for (int64_t row_dim = strides.size() - 2; row_dim >= 0;
diff --git a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc
index c937f61..20d65e8 100644
--- a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc
@@ -235,15 +235,8 @@
TF_ASSIGN_OR_RETURN(element_position,
parent_index_mapper(element_position));
}
- // This could happen if the source dataset generates more elements than
- // needed by the intermediate transformations. For example, when shuffling
- // `range(10).batch(3, drop_remainder=True)`, the last element of `range`
- // has index 9, which maps to the 4th batched element. However, since
- // `batch` drops remainders, the cardinality is 3. In this case, the
- // element position exceeds the max index. The caller is responsible to
- // handle this case properly.
if (element_position > max_index) {
- return element_position;
+ return absl::OutOfRangeError("Out of range");
}
if (max_index == 0) {
return 0;
@@ -265,6 +258,7 @@
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed, seed_));
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed2, seed2_));
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed3, seed3_));
+ TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
return absl::OkStatus();
}
diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc
index c93ccce8..2852b44 100644
--- a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc
@@ -186,14 +186,16 @@
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
- return split_provider_->Save(
- [this](const std::string& key) { return full_name(key); }, writer);
+ TF_RETURN_IF_ERROR(split_provider_->Save(
+ [this](const std::string& key) { return full_name(key); }, writer));
+ TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
+ return absl::OkStatus();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
if (ctx->restored_element_count().has_value()) {
- return global_shuffle_iterator_.Restore(ctx);
+ return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
}
return split_provider_->Restore(
[this](const std::string& key) { return full_name(key); }, reader);
diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc
index 86f4b00..44e25cd 100644
--- a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc
@@ -96,7 +96,7 @@
const std::vector<std::vector<Tensor>>& input_elements) {
std::vector<PartialTensorShape> output_shapes;
for (const auto& tensor : input_elements.front()) {
- gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
+ absl::InlinedVector<int64_t, 4UL> partial_dim_sizes;
partial_dim_sizes.reserve(tensor.dims());
for (int i = 0; i < tensor.dims(); ++i) {
partial_dim_sizes.push_back(tensor.dim_size(i));
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index 7021a23..cd9bb2a 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -164,7 +164,7 @@
DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<std::pair<size_t, Node*>> inputs;
- std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
+ std::vector<std::pair<size_t, absl::Span<Node* const>>> list_inputs;
int input_index = 0;
Node* input_node;
diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
index 3cc3320..5682d19 100644
--- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
@@ -72,7 +72,7 @@
if (batch_size_ < 0 && shape.dim_size(0) >= 0) {
batch_size_ = shape.dim_size(0);
}
- gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
+ absl::InlinedVector<int64_t, 4UL> partial_dim_sizes;
for (int i = 1; i < shape.dims(); ++i) {
partial_dim_sizes.push_back(shape.dim_size(i));
}
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index bdd95f0..cd03a09 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -141,7 +141,12 @@
}
absl::Status RandomIndexingCompatible() const override {
- return random_indexing_compatible_;
+ return absl::UnimplementedError(
+ "Please consider applying the maps to each dataset, concatenating them "
+ "into one dataset, and applying the global shuffle dataset op to that "
+ "dataset to achieve the same result as flat map with global "
+ "shuffling.");
}
protected:
@@ -358,7 +363,10 @@
return absl::OkStatus();
}
- // TODO(b/325112575): Refactor and reuse this code from weighted flat map.
+ // TODO: b/355241367 - This implementation is incorrect because IndexMapper
+ // should be stateless; otherwise it is not compatible with the batch
+ // dataset op.
+ // See go/tf-data-random-access-iterator-for-concatenate for more info.
IndexMapperFn GetFlatMapIndexMapper(IndexMapperFn parent_index_mapper,
size_t input_dataset_index)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
diff --git a/tensorflow/core/kernels/data/map_defun_op_test.cc b/tensorflow/core/kernels/data/map_defun_op_test.cc
index d48650c..aaf292e 100644
--- a/tensorflow/core/kernels/data/map_defun_op_test.cc
+++ b/tensorflow/core/kernels/data/map_defun_op_test.cc
@@ -104,9 +104,10 @@
}
// Creates a new `MapDefun` op kernel context.
- Status CreateMapDefunContext(OpKernel* const op_kernel,
- gtl::InlinedVector<TensorValue, 4>* const inputs,
- std::unique_ptr<OpKernelContext>* context) {
+ Status CreateMapDefunContext(
+ OpKernel* const op_kernel,
+ absl::InlinedVector<TensorValue, 4UL>* const inputs,
+ std::unique_ptr<OpKernelContext>* context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return absl::OkStatus();
@@ -243,7 +244,7 @@
TestCase test_case = GetParam();
TF_ASSERT_OK(InitializeRuntime(test_case.map_defun_op_params));
auto input_tensors = test_case.map_defun_op_params.GetInputTensors();
- gtl::InlinedVector<TensorValue, 4> input_values;
+ absl::InlinedVector<TensorValue, 4UL> input_values;
for (auto& input : input_tensors) {
input_values.push_back(TensorValue(&input));
}
@@ -272,7 +273,7 @@
for (auto& test_case : test_cases) {
TF_ASSERT_OK(InitializeRuntime(test_case.map_defun_op_params));
auto input_tensors = test_case.map_defun_op_params.GetInputTensors();
- gtl::InlinedVector<TensorValue, 4> input_values;
+ absl::InlinedVector<TensorValue, 4UL> input_values;
for (auto& input : input_tensors) {
input_values.push_back(TensorValue(&input));
}
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index fcb0c11..72b230c 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -17,6 +17,7 @@
#include <cstddef>
#include <deque>
#include <functional>
+#include <limits>
#include <memory>
#include <optional>
#include <string>
@@ -31,6 +32,7 @@
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
+#include "tensorflow/core/data/unbounded_thread_pool.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
@@ -92,17 +94,19 @@
const std::vector<PartialTensorShape>& output_shapes,
DeterminismPolicy deterministic,
std::unique_ptr<CapturedFunction> captured_func,
- bool preserve_cardinality, int op_version)
+ bool preserve_cardinality, bool use_unbounded_threadpool,
+ int op_version)
: Dataset(DatasetContext(ctx), input, num_parallel_calls, output_types,
output_shapes, deterministic, std::move(captured_func),
- preserve_cardinality, op_version) {}
+ preserve_cardinality, use_unbounded_threadpool, op_version) {}
Dataset(DatasetContext dataset_context, const DatasetBase* input,
int64_t num_parallel_calls, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
DeterminismPolicy deterministic,
std::unique_ptr<CapturedFunction> captured_func,
- bool preserve_cardinality, int op_version)
+ bool preserve_cardinality, bool use_unbounded_threadpool,
+ int op_version)
: DatasetBase(std::move(dataset_context)),
input_(input),
num_parallel_calls_(num_parallel_calls),
@@ -110,6 +114,7 @@
output_shapes_(output_shapes),
deterministic_(deterministic),
preserve_cardinality_(preserve_cardinality),
+ use_unbounded_threadpool_(use_unbounded_threadpool),
captured_func_(std::move(captured_func)),
op_version_(op_version) {
input_->Ref();
@@ -235,6 +240,12 @@
b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
attrs.emplace_back(kPreserveCardinality, preserve_cardinality_attr);
+ // Attr: use_unbounded_threadpool
+ AttrValue use_unbounded_threadpool_attr;
+ b->BuildAttrValue(use_unbounded_threadpool_,
+ &use_unbounded_threadpool_attr);
+ attrs.emplace_back(kUseUnboundedThreadpool, use_unbounded_threadpool_attr);
+
TF_RETURN_IF_ERROR(b->AddDataset(
this,
{std::make_pair(0, input_graph_node),
@@ -256,6 +267,7 @@
deterministic_(params.dataset->deterministic_.IsDeterministic() ||
params.dataset->deterministic_.IsDefault()),
preserve_cardinality_(params.dataset->preserve_cardinality_),
+ use_unbounded_threadpool_(params.dataset->use_unbounded_threadpool_),
autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {}
~Iterator() override {
@@ -271,7 +283,10 @@
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
interleave_depth_ = ctx->interleave_depth();
-
+ if (use_unbounded_threadpool_) {
+ unbounded_thread_pool_ = std::make_unique<UnboundedThreadPool>(
+ ctx->env(), "tf_data_map_unbounded_thread_pool");
+ }
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx);
}
@@ -323,11 +338,17 @@
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
std::shared_ptr<model::Parameter> parameter;
+ // If the unbounded threadpool is used, sets the max of `num_parallel_calls`
+ // to infinity and lets autotune find the right value under the RAM budget.
+ double max_parallelism_value = use_unbounded_threadpool_
+ ? std::numeric_limits<double>::max()
+ : ctx->runner_threadpool_size();
if (num_parallel_calls_ &&
dataset()->num_parallel_calls_ == model::kAutotune) {
parameter = model::MakeParameter(
"parallelism", num_parallel_calls_, /*min=*/1,
- /*max=*/ctx->runner_threadpool_size(),
+ /*max=*/max_parallelism_value,
// This is to ensure before this op has seen its first element,
// `MaximumBufferedBytes()` can use the correct `parameter->value`
// to estimate the maximum buffer bytes.
@@ -335,7 +356,7 @@
} else {
parameter =
model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
- /*max=*/ctx->runner_threadpool_size());
+ /*max=*/max_parallelism_value);
}
std::optional<int64_t> estimated_element_size =
dataset()->GetEstimatedElementSize();
@@ -394,10 +415,6 @@
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- if (ctx->restored_element_count().has_value()) {
- return RestoreInput(ctx, reader, input_impl_);
- }
-
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
DCHECK(invocation_results_.empty());
@@ -456,6 +473,9 @@
std::make_pair("autotune", autotune_ ? "true" : "false"));
result.push_back(
std::make_pair("deterministic", deterministic_ ? "true" : "false"));
+ result.push_back(
+ std::make_pair("use_unbounded_threadpool",
+ use_unbounded_threadpool_ ? "true" : "false"));
result.push_back(std::make_pair(
"parallelism",
parallelism == -1
@@ -543,7 +563,15 @@
// Apply the map function on `input_element`, storing the result in
// `result->return_values`, and invoking `done` when finished.
- if (dataset()->captured_func_->use_inter_op_parallelism()) {
+ if (use_unbounded_threadpool_) {
+ auto runner_fn = [this](std::function<void()> fn) {
+ this->unbounded_thread_pool_->Schedule(fn);
+ };
+ instantiated_captured_func_->RunAsync(
+ runner_fn, ctx->cancellation_manager(), ctx->collective_executor(),
+ std::move(input_element), &result->return_values, done,
+ model_node());
+ } else if (dataset()->captured_func_->use_inter_op_parallelism()) {
instantiated_captured_func_->RunAsync(
ctx.get(), std::move(input_element), &result->return_values,
std::move(done), model_node());
@@ -751,6 +779,7 @@
const std::shared_ptr<model::SharedState> num_parallel_calls_;
const bool deterministic_;
const bool preserve_cardinality_;
+ const bool use_unbounded_threadpool_;
const bool autotune_;
// Counts the number of outstanding calls.
int64_t num_calls_ TF_GUARDED_BY(*mu_) = 0;
@@ -767,6 +796,7 @@
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
+ std::unique_ptr<UnboundedThreadPool> unbounded_thread_pool_;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;
@@ -784,6 +814,7 @@
const std::vector<PartialTensorShape> output_shapes_;
const DeterminismPolicy deterministic_;
const bool preserve_cardinality_;
+ const bool use_unbounded_threadpool_;
const std::unique_ptr<CapturedFunction> captured_func_;
const int op_version_;
// This is used for random access provided by Get().
@@ -812,12 +843,15 @@
} else {
deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
}
+ use_unbounded_threadpool_ = false;
}
if (op_version_ == 2) {
std::string deterministic;
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
OP_REQUIRES_OK(
ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr(kUseUnboundedThreadpool, &use_unbounded_threadpool_));
}
OP_REQUIRES_OK(ctx,
ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
@@ -849,10 +883,10 @@
metrics::RecordTFDataAutotune(kDatasetType);
}
- *output =
- new Dataset(ctx, input, num_parallel_calls, output_types_, output_shapes_,
- deterministic_, std::move(captured_func),
- preserve_cardinality_, op_version_);
+ *output = new Dataset(ctx, input, num_parallel_calls, output_types_,
+ output_shapes_, deterministic_,
+ std::move(captured_func), preserve_cardinality_,
+ use_unbounded_threadpool_, op_version_);
}
std::unique_ptr<DatasetBase> MakeDataServiceUncompressDataset(
@@ -867,7 +901,8 @@
/*num_parallel_calls=*/model::kAutotune, output_types, output_shapes,
DeterminismPolicy(DeterminismPolicy::Type::kDefault),
std::move(captured_function),
- /*preserve_cardinality=*/true, /*op_version=*/2);
+ /*preserve_cardinality=*/true,
+ /*use_unbounded_threadpool=*/false, /*op_version=*/2);
}
namespace {
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.h b/tensorflow/core/kernels/data/parallel_map_dataset_op.h
index 4e1e564..efdf633 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.h
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.h
@@ -38,6 +38,8 @@
static constexpr const char* const kSloppy = "sloppy";
static constexpr const char* const kPreserveCardinality =
"preserve_cardinality";
+ static constexpr const char* const kUseUnboundedThreadpool =
+ "use_unbounded_threadpool";
explicit ParallelMapDatasetOp(OpKernelConstruction* ctx);
@@ -54,6 +56,7 @@
bool sloppy_;
bool preserve_cardinality_;
DeterminismPolicy deterministic_;
+ bool use_unbounded_threadpool_;
friend std::unique_ptr<DatasetBase> MakeDataServiceUncompressDataset(
DatasetBase* input, std::unique_ptr<CapturedFunction> captured_function,
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
index 357e279..cedc0e8 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
@@ -12,10 +12,10 @@
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 0a7779c..c6238a0 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -297,11 +297,6 @@
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- if (ctx->restored_element_count().has_value()) {
- tsl::mutex_lock l(input_mu_);
- return RestoreInput(ctx, reader, input_impl_);
- }
-
mutex_lock input_l(input_mu_);
mutex_lock l(*mu_);
DCHECK(!prefetch_thread_);
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index bf71d67..5834494 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -332,13 +332,14 @@
TF_RETURN_IF_ERROR(
writer->WriteScalar(prefix(), kNext, counter_->Peek()));
}
+ TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
return absl::OkStatus();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
if (ctx->restored_element_count().has_value()) {
- return global_shuffle_iterator_.Restore(ctx);
+ return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
}
if (reader->Contains(prefix(), kHasSplitProvider)) {
TF_RETURN_IF_ERROR(split_provider_->Restore(
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 820fc9b..555bdba 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -319,7 +319,7 @@
if (element_position >= input_cardinality) {
// The input element position is out-of-range. The caller is
// responsible for handle this case (e.g.: returning end_of_sequence).
- return element_position;
+ return absl::OutOfRangeError("Finite repeat is out of range");
}
// First, maps the input indices from
@@ -356,28 +356,37 @@
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
+ int64_t input_empty;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIteration, &i_));
+
if (ctx->restored_element_count().has_value()) {
CardinalityOptions options;
options.set_compute_level(
CardinalityOptions::CARDINALITY_COMPUTE_MODERATE);
const int64_t input_cardinality =
dataset()->input_->Cardinality(std::move(options));
- i_ = *ctx->restored_element_count() / input_cardinality;
// For upstream iterators, the restored element count should be the
// element count within the current repetition.
IteratorContext::Params params(ctx);
params.restored_element_count =
- *ctx->restored_element_count() % input_cardinality;
+ *ctx->restored_element_count() % (input_cardinality);
params.index_mapper = GetIndexMapper(ctx->index_mapper());
IteratorContext ctx_with_restored_element_count(params);
- return RestoreInput(&ctx_with_restored_element_count, reader,
- input_impl_);
+ if (!input_empty) {
+ // Needs to re-`MakeIterator` because `i_` might have changed.
+ TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
+ ctx, this, nested_prefix(prefix(), i_), &input_impl_));
+ TF_RETURN_IF_ERROR(RestoreInput(&ctx_with_restored_element_count,
+ reader, input_impl_));
+ ctx->MergeCheckpoint(ctx_with_restored_element_count.checkpoint());
+ } else {
+ input_impl_.reset();
+ }
+ return absl::OkStatus();
}
- TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIteration, &i_));
- int64_t input_empty;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
if (static_cast<bool>(!input_empty)) {
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
ctx, this, nested_prefix(prefix(), i_), &input_impl_));
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index 695263e..950be9b 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -280,7 +280,7 @@
"is not currently supported."));
previous_batch_index = next_batch_index;
}
- gtl::InlinedVector<int64_t, 8> std_order(dense_shape->NumElements(), 0);
+ absl::InlinedVector<int64_t, 8UL> std_order(dense_shape->NumElements(), 0);
TensorShape shape;
OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(
dense_shape->vec<int64_t>(), &shape));
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index a08890e..d910271 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -187,12 +187,6 @@
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- if (ctx->restored_element_count().has_value()) {
- mutex_lock l(mu_);
- i_ = *ctx->restored_element_count();
- return RestoreInput(ctx, reader, input_impl_);
- }
-
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIndex, &i_));
int64_t input_empty;
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index 02736e5..fe2d564 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -184,13 +184,14 @@
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kProduced,
static_cast<int64_t>(produced_)));
+ TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
return absl::OkStatus();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
if (ctx->restored_element_count().has_value()) {
- return global_shuffle_iterator_.Restore(ctx);
+ return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
}
mutex_lock l(mu_);
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index dad1e8c..3e2374b 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -55,7 +55,7 @@
replicate_on_split_(replicate_on_split) {
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
- gtl::InlinedVector<int64_t, 4> element_dim_sizes;
+ absl::InlinedVector<int64_t, 4UL> element_dim_sizes;
// Handle scalar here. Check that everyone matches here? Or fail
// at runtime?
for (int i = 1; i < t.dims(); ++i) {
@@ -206,14 +206,16 @@
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
- return split_provider_->Save(
- [this](const std::string& key) { return full_name(key); }, writer);
+ TF_RETURN_IF_ERROR(split_provider_->Save(
+ [this](const std::string& key) { return full_name(key); }, writer));
+ TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
+ return absl::OkStatus();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
if (ctx->restored_element_count().has_value()) {
- return global_shuffle_iterator_.Restore(ctx);
+ return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
}
return split_provider_->Restore(
[this](const std::string& key) { return full_name(key); }, reader);
diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
index 3c5ed02..99a29eb 100644
--- a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
@@ -20,6 +20,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
@@ -34,7 +35,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index a5c0e49..c9d2a5a 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -177,6 +177,7 @@
std::vector<std::vector<Tensor>> elements;
for (size_t i = 0; i < num_elements; ++i) {
std::vector<Tensor> element;
+ element.reserve(element_size);
for (size_t j = 0; j < element_size; ++j) {
element.push_back(std::move(inputs[i * element_size + j]));
}
diff --git a/tensorflow/core/kernels/data/window_dataset_op_test.cc b/tensorflow/core/kernels/data/window_dataset_op_test.cc
index 3252f52..6dfc803 100644
--- a/tensorflow/core/kernels/data/window_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op_test.cc
@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
@@ -27,7 +28,6 @@
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index e23c181..f2891dd 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -272,6 +272,9 @@
mutex_lock l(mu_);
// Note: When restoring, `SaveInternal` would not have been called
// if there is a global_shuffle_dataset_op.cc above this op.
+ int64_t inputs_empty;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(prefix(), kInputImplsEmpty, &inputs_empty));
if (ctx->restored_element_count()) {
if (input_impls_.size() != dataset()->inputs_.size()) {
return absl::FailedPreconditionError(
@@ -283,14 +286,19 @@
"ctx->index_mapper() should be provided along with "
"ctx->restored_element_count() when restoring.");
}
- for (const auto& input_impl : input_impls_) {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl));
+ if (static_cast<bool>(inputs_empty)) {
+ input_impls_.clear();
+ } else {
+ for (int i = 0; i < input_impls_.size(); ++i) {
+ input_contexts_[i].set_restored_element_count(
+ ctx->restored_element_count().value());
+ TF_RETURN_IF_ERROR(
+ RestoreInput(&input_contexts_[i], reader, input_impls_[i]));
+ ctx->MergeCheckpoint(input_contexts_[i].checkpoint());
+ }
}
return absl::OkStatus();
}
- int64_t inputs_empty;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(prefix(), kInputImplsEmpty, &inputs_empty));
if (static_cast<bool>(inputs_empty)) {
input_impls_.clear();
} else {
diff --git a/tensorflow/core/kernels/data/zip_dataset_op_test.cc b/tensorflow/core/kernels/data/zip_dataset_op_test.cc
index c864287..9d8eed6 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op_test.cc
@@ -17,6 +17,7 @@
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/tensor.h"
@@ -25,7 +26,6 @@
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index d7c0c76..15ff88c 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -190,7 +190,7 @@
LOG(ERROR) << "Debug node of watch key "
<< debug_watch_key_->debug_node_name
<< " failed to publish debug tensor data to all URLs "
- << str_util::Join(debug_urls_, ", ")
+ << absl::StrJoin(debug_urls_, ", ")
<< ", due to: " << status.message();
}
return status;
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 7c865b6..0871890 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -255,7 +255,7 @@
args.push_back(ctx->input(i));
}
std::vector<Tensor>* rets = new std::vector<Tensor>;
- profiler::TraceMe trace_me("SymbolicGradientOp");
+ tsl::profiler::TraceMe trace_me("SymbolicGradientOp");
lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
if (!status.ok()) {
ctx->SetStatus(status);
@@ -319,12 +319,12 @@
handle = cached_entry->second;
} else {
VLOG(1) << "Instantiating " << func_name << " on " << target_device;
- profiler::TraceMe activity(
+ tsl::profiler::TraceMe activity(
[&] {
return strings::StrCat("RemoteCall: Instantiate: ", func_name,
" on ", target_device);
},
- profiler::TraceMeLevel::kInfo);
+ tsl::profiler::TraceMeLevel::kInfo);
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
const auto* config = (ctx->function_library())
? ctx->function_library()->config_proto()
@@ -398,24 +398,24 @@
auto* rets = new std::vector<Tensor>;
VLOG(1) << "Running " << func_name << " on " << target_device
<< " with handle: " << handle;
- profiler::TraceMe trace_me(
+ tsl::profiler::TraceMe trace_me(
[&] {
return profiler::TraceMeEncode(
"RemoteCallOp",
{{"func_name", func_name}, {"device", target_device}});
},
- profiler::TraceMeLevel::kInfo);
+ tsl::profiler::TraceMeLevel::kInfo);
lib->Run(
opts, handle, args, rets,
[rets, done = std::move(done), func_name, ctx, cancel_mgr,
target_device = std::move(function_target.first)](const Status& status) {
- profiler::TraceMe activity(
+ tsl::profiler::TraceMe activity(
[&] {
return profiler::TraceMeEncode(
"RemoteCallOpDone",
{{"func_name", func_name}, {"device", target_device}});
},
- profiler::TraceMeLevel::kInfo);
+ tsl::profiler::TraceMeLevel::kInfo);
if (!status.ok()) {
ctx->SetStatus(status);
} else {
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 79c393f..7cf465b 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -199,7 +199,7 @@
void Start() {
FHandle handle = cond_ ? then_handle_ : else_handle_;
rets_.clear();
- profiler::TraceMe trace_me("IfOp");
+ tsl::profiler::TraceMe trace_me("IfOp");
lib_->Run(
// Evaluate one of the branch.
opts_, handle, args_, &rets_,
@@ -378,7 +378,7 @@
branch = branch_handles_.size() - 1;
}
rets_.clear();
- profiler::TraceMe trace_me("CaseOp");
+ tsl::profiler::TraceMe trace_me("CaseOp");
lib_->Run(
// Evaluate one of the branch.
opts_, branch_handles_[branch], args_, &rets_,
@@ -633,7 +633,7 @@
std::unique_ptr<BodyFuncCallFrame> body_frame_;
void EvalCond() {
- profiler::TraceMe trace_me("WhileOp-EvalCond");
+ tsl::profiler::TraceMe trace_me("WhileOp-EvalCond");
lib_->Run(
// Evaluate the condition.
opts_, cond_handle_, args_, &rets_,
@@ -669,7 +669,7 @@
}
rets_.clear();
rets_.resize(args_.size());
- profiler::TraceMe trace_me("WhileOp-StartBody");
+ tsl::profiler::TraceMe trace_me("WhileOp-StartBody");
lib_->Run(
// Evaluate the body.
opts_, body_handle_, body_frame_.get(),
@@ -724,7 +724,7 @@
do {
// Evaluate the cond function on the current loop variables.
{
- profiler::TraceMe trace_me("WhileOp-EvalCond");
+ tsl::profiler::TraceMe trace_me("WhileOp-EvalCond");
TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
}
if (cond_rets.size() != 1) {
@@ -745,7 +745,7 @@
// Evaluate the body function on the current loop variables, to get an
// updated vector of loop variables.
{
- profiler::TraceMe trace_me("WhileOp-StartBody");
+ tsl::profiler::TraceMe trace_me("WhileOp-StartBody");
body_rets.resize(num_loop_vars);
BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types);
TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame));
@@ -982,7 +982,7 @@
args_[1 + i] = std::move(rets_[i]);
}
rets_.clear();
- profiler::TraceMe trace_me("ForOp");
+ tsl::profiler::TraceMe trace_me("ForOp");
lib_->Run(opts_, body_handle_, args_, &rets_, [this](const Status& s) {
if (s.ok()) {
*iter_ += delta_;
diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc
index 3212068..2758fbb 100644
--- a/tensorflow/core/kernels/gather_nd_op_test.cc
+++ b/tensorflow/core/kernels/gather_nd_op_test.cc
@@ -18,6 +18,7 @@
#include <vector>
#include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
@@ -37,7 +38,6 @@
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc
index 77f1725..209dbbd 100644
--- a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc
+++ b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc
@@ -92,7 +92,13 @@
std::array<float, 4>* weights,
std::array<int64_t, 4>* indices) {
const int64_t in_loc = scale * out_loc;
- const float delta = scale * out_loc - in_loc;
+ // Keep the following calculation in a float to match the rounding done in
+ // the optimised case. Merging it with the next line keeps an intermediate
+ // value at higher precision, which leads to a divergence in the result, so
+ // keep the two lines separate to ensure the calculation is rounded as
+ // expected.
+ const float in_loc_float = scale * out_loc;
+ const float delta = in_loc_float - in_loc;
const int64_t offset = lrintf(delta * kTableSize);
const float* coeffs_tab = GetCoeffsTable();
*weights = {{coeffs_tab[offset * 2 + 1], coeffs_tab[offset * 2],
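
The comment in this hunk is about intermediate precision: if the multiply and the
subtraction are merged into one expression, the compiler may keep the product at
higher precision (for example by contracting it into a fused multiply-add), so
delta can differ in its last bit and shift the lrintf(delta * kTableSize) lookup.
Below is a minimal sketch of the two rounding behaviours; the inputs are
hypothetical, and scale, out_loc, and in_loc only mirror the roles they have in
the code above.

#include <cmath>
#include <cstdint>

// Rounds the product to float before subtracting, matching the two-line form
// used in the test above.
float DeltaRoundedProduct(float scale, int64_t out_loc, int64_t in_loc) {
  const float in_loc_float = scale * static_cast<float>(out_loc);
  return in_loc_float - static_cast<float>(in_loc);
}

// Multiplies and subtracts with a single final rounding, which is what a
// contracted (fused) multiply-add would do.
float DeltaFused(float scale, int64_t out_loc, int64_t in_loc) {
  return std::fmaf(scale, static_cast<float>(out_loc),
                   -static_cast<float>(in_loc));
}

// The two results can differ by one ulp, which is enough to change the
// coefficient-table offset computed from lrintf(delta * kTableSize).
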
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 7619201..d07b4b9 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -257,7 +257,7 @@
std::vector<Tensor>* rets = new std::vector<Tensor>;
const string& func_name = func_->name();
- profiler::TraceMe trace_me("PartitionedCallOp");
+ tsl::profiler::TraceMe trace_me("PartitionedCallOp");
lib->Run(run_opts, handle, inputs, rets,
[rets, done = std::move(done), ctx, func_name,
step_container](const Status& status) {
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
index 2e88088..02fa44f 100644
--- a/tensorflow/core/kernels/scatter_nd_op_test.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -20,6 +20,7 @@
#include "absl/status/status.h"
#include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
@@ -38,7 +39,6 @@
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc
index 7f76cb4..d15cc7f 100644
--- a/tensorflow/core/kernels/sendrecv_ops.cc
+++ b/tensorflow/core/kernels/sendrecv_ops.cc
@@ -120,8 +120,8 @@
auto dst_it = attr.find("_dst");
const string& src = src_it != attr.end() ? src_it->second.s() : "";
const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
- string op = profiler::TraceMeOp(name_view(), type_string_view());
- return profiler::TraceMeEncode(
+ string op = tsl::profiler::TraceMeOp(name_view(), type_string_view());
+ return tsl::profiler::TraceMeEncode(
std::move(op),
{{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}});
}
@@ -166,8 +166,8 @@
auto dst_it = attr.find("_dst");
const string& src = src_it != attr.end() ? src_it->second.s() : "";
const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
- string op = profiler::TraceMeOp(name_view(), type_string_view());
- return profiler::TraceMeEncode(
+ string op = tsl::profiler::TraceMeOp(name_view(), type_string_view());
+ return tsl::profiler::TraceMeEncode(
std::move(op),
{{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}});
}
diff --git a/tensorflow/core/kernels/spectrogram_op_test.cc b/tensorflow/core/kernels/spectrogram_op_test.cc
index 791fdda..650024b 100644
--- a/tensorflow/core/kernels/spectrogram_op_test.cc
+++ b/tensorflow/core/kernels/spectrogram_op_test.cc
@@ -25,13 +25,13 @@
#include "tensorflow/cc/ops/audio_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/kernels/stochastic_cast_op_test.cc b/tensorflow/core/kernels/stochastic_cast_op_test.cc
index b0a5835..10d9eae1 100644
--- a/tensorflow/core/kernels/stochastic_cast_op_test.cc
+++ b/tensorflow/core/kernels/stochastic_cast_op_test.cc
@@ -21,6 +21,7 @@
#include <gtest/gtest.h>
#include "Eigen/Core" // from @eigen_archive
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -30,7 +31,6 @@
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/lib/random/philox_random.h"
namespace Eigen {
diff --git a/tensorflow/core/kernels/uniform_quant_ops/BUILD b/tensorflow/core/kernels/uniform_quant_ops/BUILD
index 5c158fe..507a857 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/BUILD
+++ b/tensorflow/core/kernels/uniform_quant_ops/BUILD
@@ -191,7 +191,7 @@
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:test",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -205,6 +205,6 @@
"//tensorflow/core/platform:test",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc b/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc
index 4d2e869..a331b8b 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc
+++ b/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc
@@ -17,10 +17,10 @@
#include <limits>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc
index f3d2f2c..c4f0ea5 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc
+++ b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc
@@ -19,9 +19,9 @@
#include <gtest/gtest.h>
#include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc
index 7ffc629..73121de 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc
+++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc
@@ -17,6 +17,7 @@
#include <vector>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/numeric_types.h"
@@ -27,7 +28,6 @@
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h
index ef333ef..3c604ee 100644
--- a/tensorflow/core/lib/core/status_test_util.h
+++ b/tensorflow/core/lib/core/status_test_util.h
@@ -16,7 +16,7 @@
#ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
#define TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#endif // TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc
index 39ea38e..0a0042f 100644
--- a/tensorflow/core/lib/db/sqlite_test.cc
+++ b/tensorflow/core/lib/db/sqlite_test.cc
@@ -17,11 +17,11 @@
#include <array>
#include <climits>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h
index 9481ba7..818ec69 100644
--- a/tensorflow/core/lib/gtl/edit_distance.h
+++ b/tensorflow/core/lib/gtl/edit_distance.h
@@ -59,7 +59,7 @@
if (s == t) return 0;
// Create work vector
- gtl::InlinedVector<int64_t, 32> scratch_holder(t_size);
+ absl::InlinedVector<int64_t, 32UL> scratch_holder(t_size);
int64_t* scratch = scratch_holder.data();
diff --git a/tensorflow/core/lib/strings/BUILD b/tensorflow/core/lib/strings/BUILD
index 72eb0a6..d8f4e6d 100644
--- a/tensorflow/core/lib/strings/BUILD
+++ b/tensorflow/core/lib/strings/BUILD
@@ -51,7 +51,7 @@
name = "proto_serialization",
hdrs = ["proto_serialization.h"],
deps = [
- "@local_tsl//tsl/lib/strings:proto_serialization",
+ "@local_xla//xla/tsl/lib/strings:proto_serialization",
],
)
@@ -116,7 +116,7 @@
"ordered_code.cc",
"ordered_code.h",
"proto_serialization.h",
- "@local_tsl//tsl/lib/strings:mobile_srcs_only_runtime",
+ "@local_xla//xla/tsl/lib/strings:mobile_srcs_only_runtime",
],
visibility = ["//tensorflow/core:__pkg__"],
)
@@ -133,7 +133,7 @@
"str_util.h",
"strcat.h",
"stringprintf.h",
- "@local_tsl//tsl/lib/strings:legacy_lib_strings_all_headers",
+ "@local_xla//xla/tsl/lib/strings:legacy_lib_strings_all_headers",
],
visibility = ["//tensorflow/core:__pkg__"],
)
@@ -165,7 +165,7 @@
"str_util.h",
"strcat.h",
"stringprintf.h",
- "@local_tsl//tsl/lib/strings:legacy_lib_string_headers",
+ "@local_xla//xla/tsl/lib/strings:legacy_lib_string_headers",
],
visibility = ["//tensorflow/core:__pkg__"],
)
@@ -178,7 +178,7 @@
"proto_serialization.h",
"proto_text_util.h",
"scanner.h",
- "@local_tsl//tsl/lib/strings:legacy_lib_internal_public_string_headers",
+ "@local_xla//xla/tsl/lib/strings:legacy_lib_internal_public_string_headers",
],
visibility = ["//tensorflow/core:__pkg__"],
)
diff --git a/tensorflow/core/lib/strings/proto_serialization.h b/tensorflow/core/lib/strings/proto_serialization.h
index 0c01708..e0c253f 100644
--- a/tensorflow/core/lib/strings/proto_serialization.h
+++ b/tensorflow/core/lib/strings/proto_serialization.h
@@ -15,7 +15,7 @@
#ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_
#define TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
namespace tensorflow {
// NOLINTBEGIN(misc-unused-using-decls)
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index ebaade2..b05c412 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -157,7 +157,6 @@
offset_i.push_back(strings::StrCat("offset:offset:", i));
dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
}
- DataTypeVector dtype_list(N, T);
// ConcatGrad(dim, x, dy):
// for i in range(N):
diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc
index 99d4551..6d21ee4 100644
--- a/tensorflow/core/ops/batch_ops.cc
+++ b/tensorflow/core/ops/batch_ops.cc
@@ -76,9 +76,17 @@
// allowed. The following options are available.
//
// - PAD_UP: pad to size 32.
+ // - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the
+ // batch buffer.
+ // - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses
+ // to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per
+ // real request. In this case, it would compare (batch_16_cost / 16) and
+ // (batch_32_cost / 18).
+ //
+ // WARNING: Some batch schedulers may not support this attribute.
.Attr(
"batch_padding_policy: "
- "{'PAD_UP'} = 'PAD_UP'")
+ "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'")
.Attr("Tin: list(type)")
.Attr("Tcaptured: list(type) >= 0")
.Attr("Tout: list(type)")
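
A minimal sketch of the per-request cost comparison that the
MINIMIZE_TPU_COST_PER_REQUEST comment above describes, using its 18-request,
allowed-sizes-16-and-32 example; the function and the cost inputs are
hypothetical and not part of the batching library.

#include <string>

// Chooses between padding up to the next allowed batch size and batching down
// to the previous one by comparing TPU cost per real request.
std::string ChoosePolicy(double batch_16_cost, double batch_32_cost,
                         int num_real_requests /* e.g. 18 */) {
  const double pad_up_cost_per_request = batch_32_cost / num_real_requests;
  const double batch_down_cost_per_request = batch_16_cost / 16;  // 16 scheduled now.
  return batch_down_cost_per_request < pad_up_cost_per_request ? "BATCH_DOWN"
                                                                : "PAD_UP";
}
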
diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt
index 8fecdf6..d743b8e 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt
@@ -802,3 +802,152 @@
}
is_distributed_communication: true
}
+op {
+ name: "BatchFunction"
+ input_arg {
+ name: "in_tensors"
+ type_list_attr: "Tin"
+ }
+ input_arg {
+ name: "captured_tensors"
+ type_list_attr: "Tcaptured"
+ }
+ output_arg {
+ name: "out_tensors"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "num_batch_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_batch_size"
+ type: "int"
+ }
+ attr {
+ name: "batch_timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "max_enqueued_batches"
+ type: "int"
+ default_value {
+ i: 10
+ }
+ }
+ attr {
+ name: "allowed_batch_sizes"
+ type: "list(int)"
+ default_value {
+ list {
+ }
+ }
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "batching_queue"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "low_priority_max_batch_size"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "low_priority_batch_timeout_micros"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "low_priority_allowed_batch_sizes"
+ type: "list(int)"
+ default_value {
+ list {
+ }
+ }
+ }
+ attr {
+ name: "low_priority_max_enqueued_batches"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "mixed_priority_policy"
+ type: "string"
+ default_value {
+ s: "low_priority_padding_with_max_batch_size"
+ }
+ allowed_values {
+ list {
+ s: "low_priority_padding_with_max_batch_size"
+ s: "low_priority_padding_with_next_allowed_batch_size"
+ s: "priority_isolation"
+ }
+ }
+ }
+ attr {
+ name: "batch_padding_policy"
+ type: "string"
+ default_value {
+ s: "PAD_UP"
+ }
+ allowed_values {
+ list {
+ s: "PAD_UP"
+ s: "BATCH_DOWN"
+ s: "MINIMIZE_TPU_COST_PER_REQUEST"
+ }
+ }
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "enable_large_batch_splitting"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_distributed_communication: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt
index 55e73b7..1f3016b 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt
@@ -290,3 +290,98 @@
}
}
}
+op {
+ name: "ParallelMapDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ experimental_full_type {
+ type_id: TFT_DATASET
+ args {
+ type_id: TFT_FOR_EACH
+ args {
+ type_id: TFT_PRODUCT
+ }
+ args {
+ type_id: TFT_TENSOR
+ args {
+ type_id: TFT_VAR
+ s: "output_types"
+ }
+ }
+ args {
+ type_id: TFT_VAR
+ s: "output_types"
+ }
+ }
+ }
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "deterministic"
+ type: "string"
+ default_value {
+ s: "default"
+ }
+ }
+ attr {
+ name: "preserve_cardinality"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "use_unbounded_threadpool"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "metadata"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+}
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index d347d0c..7e81212 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -206,6 +206,7 @@
// "true", "false", or "default".
.Attr("deterministic: string = 'default'")
.Attr("preserve_cardinality: bool = false")
+ .Attr("use_unbounded_threadpool: bool = false")
.Attr("metadata: string = ''")
.SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
"output_types"))
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d05df09..dcf9e2f 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4112,6 +4112,8 @@
allowed_values {
list {
s: "PAD_UP"
+ s: "BATCH_DOWN"
+ s: "MINIMIZE_TPU_COST_PER_REQUEST"
}
}
}
@@ -33011,6 +33013,13 @@
}
}
attr {
+ name: "use_unbounded_threadpool"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
name: "metadata"
type: "string"
default_value {
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 2f179f7..1dc0e20 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -317,7 +317,7 @@
":test_main",
"//tensorflow/core:protos_all_cc",
"@eigen_archive//:eigen3",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 178a68c..2059e3b 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -19,6 +19,7 @@
#include <memory>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/cord.h"
@@ -29,7 +30,6 @@
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tsl {
diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h
index 916b530..66de83f 100644
--- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h
+++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h
@@ -541,7 +541,7 @@
separator.Add();
output->Append(R"({"args":{"name":)", JsonEscape(device.name()),
R"(},"name":"process_name","ph":"M","pid":)", device_id,
- "}");
+ R"(,"thread_count":)", device.resources_size(), "}");
}
separator.Add();
output->Append(R"({"args":{"sort_index":)", device_id,
diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc
index 3c06bd3..6c80add 100644
--- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc
@@ -24,6 +24,7 @@
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/statusor.h"
@@ -31,7 +32,6 @@
#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
index 5609fd7..8b22847 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
@@ -227,8 +227,6 @@
XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace);
using OpMetricBySymbol =
absl::flat_hash_map</*symbol_id=*/uint64_t, OpMetrics>;
- absl::flat_hash_map</*program_id=*/uint64_t, OpMetricBySymbol> flat_op_metric;
-
XEventsOpMetricsDbBuilder builder;
plane.ForEachLine([&](const XLineVisitor& line) {
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index 70d5efc..d6efbd1 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -54,6 +54,7 @@
using tsl::profiler::kPythonTracerPlaneName; // NOLINT
using tsl::profiler::kRoctracerApiPlaneName; // NOLINT
using tsl::profiler::kSourceLineName; // NOLINT
+using tsl::profiler::kSparseCorePlaneRegex; // NOLINT
using tsl::profiler::kStepLineName; // NOLINT
using tsl::profiler::kTensorFlowNameScopeLineName; // NOLINT
using tsl::profiler::kTensorFlowOpLineName; // NOLINT
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index cd5d889..fbf446c 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 1940 // Updated: 2024/7/31
+#define TF_GRAPH_DEF_VERSION 1952 // Updated: 2024/8/12
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc
index bdb6f9e..5b88167 100644
--- a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc
+++ b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc
@@ -16,10 +16,10 @@
#include <vector>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tfrt/core_runtime/op_attr_type.h" // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime
#include "tfrt/support/forward_decls.h" // from @tf_runtime
diff --git a/tensorflow/core/runtime_fallback/runtime/BUILD b/tensorflow/core/runtime_fallback/runtime/BUILD
index ece9e5d..45f433d 100644
--- a/tensorflow/core/runtime_fallback/runtime/BUILD
+++ b/tensorflow/core/runtime_fallback/runtime/BUILD
@@ -195,6 +195,7 @@
"//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
"//tensorflow/core/kernels/batching_util:batch_resource_base",
"//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
+ "//tensorflow/core/kernels/batching_util:batch_stats",
"//tensorflow/core/kernels/batching_util:bounded_executor",
"//tensorflow/core/kernels/batching_util:warmup",
"//tensorflow/core/lib/core:refcount",
@@ -205,7 +206,6 @@
"//tensorflow/core/tfrt/fallback:op_kernel_runner",
"//tensorflow/core/tfrt/utils:error_util",
"//tensorflow/core/tfrt/utils:fallback_tensor",
- "@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status",
diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h
index dae6eb3..86772a2 100644
--- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h
+++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h
@@ -20,7 +20,6 @@
#include <string>
#include <vector>
-#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -30,10 +29,12 @@
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
#include "tensorflow/core/kernels/batching_util/warmup.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/random.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
@@ -215,6 +216,7 @@
batch_resource_options.batch_timeout_micros = batch_timeout_micros_;
batch_resource_options.max_enqueued_batches = max_enqueued_batches_;
batch_resource_options.allowed_batch_sizes = allowed_batch_sizes_;
+ batch_resource_options.batch_padding_policy = batch_padding_policy_;
batch_resource_options.low_priority_max_batch_size =
low_priority_max_batch_size_;
batch_resource_options.low_priority_batch_timeout_micros =
@@ -224,6 +226,13 @@
batch_resource_options.low_priority_allowed_batch_sizes =
low_priority_allowed_batch_sizes_;
+ serving::ModelBatchStats& model_batch_stats =
+ serving::GlobalBatchStatsRegistry().model(
+ /* model_name= */ std::string(GetModelName(c)),
+ /* op_name= */ c->op_kernel().name());
+ model_batch_stats.SetBatchTimeoutMicros(batch_timeout_micros_);
+ model_batch_stats.SetNumBatchThreads(num_batch_threads_);
+
std::unique_ptr<BatchResourceType> new_resource;
auto status = BatchResourceType::Create(
c, batch_resource_options, batch_function_,
diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc
index e953c70..38b8fc3 100644
--- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc
+++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc
@@ -137,7 +137,8 @@
options.num_batch_threads, options.max_batch_size,
options.batch_timeout_micros, options.max_enqueued_batches,
options.allowed_batch_sizes, enable_large_batch_splitting,
- disable_padding, options.low_priority_max_batch_size,
+ disable_padding, options.batch_padding_policy,
+ options.low_priority_max_batch_size,
options.low_priority_batch_timeout_micros,
options.low_priority_max_enqueued_batches,
options.low_priority_allowed_batch_sizes,
@@ -437,7 +438,7 @@
// BatchFunction in core/ops/batch_ops.cc.
.Attr(
"batch_padding_policy: "
- "{'PAD_UP'} = 'PAD_UP'")
+ "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'")
.Attr("Tin: list(type)")
.Attr("Tcaptured: list(type) >= 0")
.Attr("Tout: list(type)")
diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD
index ba79007..eb9724f 100644
--- a/tensorflow/core/tfrt/common/BUILD
+++ b/tensorflow/core/tfrt/common/BUILD
@@ -188,13 +188,13 @@
":pjrt_state",
":pjrt_util",
"//tensorflow/core:framework",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
"@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
"@local_xla//xla/pjrt/cpu:cpu_client",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -214,13 +214,13 @@
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/core:framework",
"@com_google_absl//absl/strings",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:test_main",
"@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
"@local_xla//xla/pjrt:pjrt_client",
"@local_xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@local_xla//xla/service:gpu_plugin",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/tfrt/common/pjrt_state_test.cc b/tensorflow/core/tfrt/common/pjrt_state_test.cc
index fddd72e..03dcdb7 100644
--- a/tensorflow/core/tfrt/common/pjrt_state_test.cc
+++ b/tensorflow/core/tfrt/common/pjrt_state_test.cc
@@ -21,10 +21,10 @@
#include <gtest/gtest.h>
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/core/tfrt/common/pjrt_util_test.cc b/tensorflow/core/tfrt/common/pjrt_util_test.cc
index 1361b72..48f7743 100644
--- a/tensorflow/core/tfrt/common/pjrt_util_test.cc
+++ b/tensorflow/core/tfrt/common/pjrt_util_test.cc
@@ -18,9 +18,9 @@
#include <utility>
#include "xla/pjrt/cpu/cpu_client.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/tfrt/common/pjrt_state.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
diff --git a/tensorflow/core/tfrt/fallback/fallback_state_test.cc b/tensorflow/core/tfrt/fallback/fallback_state_test.cc
index d7d5531..2111171 100644
--- a/tensorflow/core/tfrt/fallback/fallback_state_test.cc
+++ b/tensorflow/core/tfrt/fallback/fallback_state_test.cc
@@ -19,10 +19,10 @@
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/tfrt/gpu/kernel/BUILD b/tensorflow/core/tfrt/gpu/kernel/BUILD
index bd4f861..fef0e58 100644
--- a/tensorflow/core/tfrt/gpu/kernel/BUILD
+++ b/tensorflow/core/tfrt/gpu/kernel/BUILD
@@ -13,20 +13,16 @@
deps = [
":gpu_runner",
"//tensorflow/core:framework",
- "//tensorflow/core/common_runtime:copy_tensor",
"//tensorflow/core/framework:tensor",
"//tensorflow/core/platform:status",
"//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
- "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_utils",
"//tensorflow/core/runtime_fallback/kernel:tensor_util",
"//tensorflow/core/tfrt/utils:fallback_tensor",
"//tensorflow/core/tfrt/utils:gpu_variables_table",
- "//tensorflow/core/tfrt/utils:tensor_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
- "@tf_runtime//:core_runtime",
"@tf_runtime//:hostcontext",
"@tf_runtime//:support",
"@tf_runtime//:tensor_alwayslink",
@@ -47,9 +43,11 @@
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
+ "//tensorflow/core/framework:attr_value_proto_cc",
+ "//tensorflow/core/framework:function_proto_cc",
+ "//tensorflow/core/framework:types_proto_cc",
"//tensorflow/core/platform:notification",
"//tensorflow/core/platform:status",
- "//tensorflow/core/platform:statusor",
"//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
"//tensorflow/core/tfrt/common:global_state",
"//tensorflow/core/tfrt/utils:fallback_tensor",
@@ -59,6 +57,7 @@
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
@@ -122,6 +121,7 @@
"//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector",
"//tensorflow/core/platform:status",
"//tensorflow/core/tfrt/runtime",
+ "@com_google_absl//absl/status",
"@local_xla//xla/tsl/framework:serving_device_selector_policies",
"@tf_runtime//:hostcontext",
],
diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc
index 3143b8b..d4047d4 100644
--- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc
+++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc
@@ -41,15 +41,18 @@
#include "xla/tsl/framework/device_id_manager.h"
#include "xla/tsl/framework/serving_device_selector.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/tfrt/common/global_state.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
@@ -231,7 +234,7 @@
int device_idx, const llvm::SmallVector<tfrt_stub::FallbackTensor>& args,
tfrt::ArrayRef<int64_t> resource_indices, Device* cpu_device,
absl::flat_hash_map<int, Device*> gpu_devices,
- tfrt::gpu::GpuVariablesTable& vars_table,
+ tfrt::gpu::GpuVariablesTable& vars_table, bool variables_are_shared,
const tfrt::ExecutionContext& exec_ctx) {
llvm::SmallVector<tfrt::AsyncValueRef<tfrt_stub::FallbackTensor>> results;
@@ -244,35 +247,51 @@
TF_ASSIGN_OR_RETURN(const std::vector<tsl::TfDeviceId> devices_on_platform,
tsl::DeviceIdManager::GetTfDevicesOnPlatform(
device_type, platform_device_id));
- const int platform_idx = platform_device_id.value();
absl::flat_hash_set<int64_t> resource_indices_set(resource_indices.begin(),
resource_indices.end());
+ // If variables are shared, there is only one copy of variables for all
+ // logical devices on the same physical GPU device; otherwise, each logical
+ // device has its own copy of variables.
+ const int cache_copy_idx =
+ variables_are_shared ? platform_device_id.value() : device_idx;
+
for (int i = 0, resource_idx = 0; i < args.size(); ++i) {
if (resource_indices_set.contains(i)) {
// Transfer resources.
+ VLOG(2) << "Transfer resource arg[" << i << "].";
tfrt::AsyncValueRef<tfrt_stub::FallbackTensor> device_tensor;
auto cached_device_variable =
- vars_table.GetDeviceVariable(args[i], platform_idx);
+ vars_table.GetDeviceVariable(args[i], cache_copy_idx);
if (cached_device_variable) {
- VLOG(2) << "Cache hit for resource arg[" << i << "]";
+ VLOG(2) << "Cache hit for resource arg[" << i << "].";
device_tensor = cached_device_variable.CopyRef();
} else {
- VLOG(2) << "Cache miss for resource arg[" << i << "]";
- // Distribute variables on virtual devices on the same GPU.
- const int idx = resource_idx % devices_on_platform.size();
- const int gpu_device_idx = devices_on_platform[idx].value();
+ VLOG(2) << "Cache miss for resource arg[" << i << "].";
+
+ int gpu_device_idx;
+ if (variables_are_shared) {
+ // Distribute variables on logical devices on the same GPU.
+ const int idx = resource_idx % devices_on_platform.size();
+ gpu_device_idx = devices_on_platform[idx].value();
+ } else {
+ gpu_device_idx = device_idx;
+ }
+
+ VLOG(2) << "Transfer the resource arg[" << i << "] to device "
+ << gpu_device_idx << ".";
device_tensor = TransferTensorToDevice(exec_ctx, args[i],
gpu_devices.at(gpu_device_idx));
- vars_table.AddOrUpdateDeviceVariable(args[i], platform_idx,
+ vars_table.AddOrUpdateDeviceVariable(args[i], cache_copy_idx,
std::move(device_tensor));
device_tensor =
- vars_table.GetDeviceVariable(args[i], platform_idx).CopyRef();
+ vars_table.GetDeviceVariable(args[i], cache_copy_idx).CopyRef();
}
results.push_back(device_tensor);
++resource_idx;
} else {
// Transfer inputs.
+ VLOG(2) << "Transfer input arg[" << i << "].";
tfrt::AsyncValueRef<tfrt_stub::FallbackTensor> device_tensor =
TransferTensorToDevice(exec_ctx, args[i], gpu_devices.at(device_idx));
results.push_back(device_tensor);
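
To make the variable-caching key above concrete: suppose TF logical devices 2 and
3 are both backed by physical GPU 1. When variables are shared, both use the
platform device id (1) as the cache key and share one device copy of each
variable; otherwise each uses its own device_idx (2 or 3) and keeps a private
copy. A small sketch of that selection with hypothetical values, not part of the
patch.

#include <initializer_list>
#include <iostream>

int main() {
  const int platform_device_id = 1;        // Physical GPU backing both logical devices.
  const bool variables_are_shared = true;  // Flip to false for per-device copies.
  for (const int device_idx : {2, 3}) {
    const int cache_copy_idx =
        variables_are_shared ? platform_device_id : device_idx;
    std::cout << "logical device " << device_idx
              << " -> variable cache key " << cache_copy_idx << "\n";
  }
  return 0;
}
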
@@ -356,6 +375,7 @@
tsl::DeviceReservation device_reservation =
serving_device_selector_->ReserveDevice(absl::StrCat(fingerprint));
const int device_idx = device_reservation.device_index();
+ VLOG(1) << "GpuRunner selected device " << device_idx << ".";
// Compile the program.
const XlaCompiler::CompilationResult* compilation_result;
@@ -368,10 +388,10 @@
TF_ASSIGN_OR_RETURN(
llvm::SmallVector<tfrt::AsyncValueRef<tfrt_stub::FallbackTensor>>
transferred_args,
- TransferVariablesAndInputs(device_idx, *run_inputs.args,
- run_inputs.resource_indices,
- run_inputs.cpu_device, *run_inputs.gpu_devices,
- vars_table_, *run_inputs.exec_ctx));
+ TransferVariablesAndInputs(
+ device_idx, *run_inputs.args, run_inputs.resource_indices,
+ run_inputs.cpu_device, *run_inputs.gpu_devices, vars_table_,
+ /*variables_are_shared=*/false, *run_inputs.exec_ctx));
llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 4>
transferred_args_to_wait;
diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h
index fc61eff..d292fed 100644
--- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h
+++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h
@@ -18,7 +18,10 @@
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "llvm/ADT/SmallVector.h"
#include "xla/tsl/framework/serving_device_selector.h"
+#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/status.h"
@@ -27,6 +30,7 @@
#include "tensorflow/core/tfrt/utils/gpu_variables_table.h"
#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
#include "tfrt/host_context/execution_context.h" // from @tf_runtime
+#include "tfrt/support/forward_decls.h" // from @tf_runtime
namespace tensorflow {
namespace gpu {
diff --git a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc
index 8cc6a62..43cb013 100644
--- a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc
+++ b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc
@@ -21,19 +21,15 @@
#include "absl/strings/str_cat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include "tensorflow/core/common_runtime/copy_tensor.h"
-#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
-#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h"
#include "tensorflow/core/runtime_fallback/kernel/tensor_util.h"
#include "tensorflow/core/tfrt/gpu/kernel/gpu_runner.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/gpu_variables_table.h"
-#include "tensorflow/core/tfrt/utils/tensor_util.h"
-#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h" // from @tf_runtime
#include "tfrt/host_context/execution_context.h" // from @tf_runtime
diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc
index 94e52ad..48f3160 100644
--- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc
+++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc
@@ -17,6 +17,7 @@
#include <memory>
#include <utility>
+#include "absl/status/status.h"
#include "xla/tsl/framework/serving_device_selector_policies.h"
#include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h"
#include "tensorflow/core/platform/status.h"
diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h
index bb99022..452ccdd 100644
--- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h
+++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h
@@ -15,6 +15,7 @@
#ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_
#define TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_
#include "xla/tsl/framework/serving_device_selector_policies.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
namespace tensorflow {
diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD
index 3a5d945..61d869f 100644
--- a/tensorflow/core/tfrt/graph_executor/BUILD
+++ b/tensorflow/core/tfrt/graph_executor/BUILD
@@ -246,9 +246,9 @@
":test_config_proto_cc",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/tfrt/graph_executor/config_test.cc b/tensorflow/core/tfrt/graph_executor/config_test.cc
index bc3d186..fc1b54f 100644
--- a/tensorflow/core/tfrt/graph_executor/config_test.cc
+++ b/tensorflow/core/tfrt/graph_executor/config_test.cc
@@ -17,9 +17,9 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/tfrt/graph_executor/config.pb.h"
#include "tensorflow/core/tfrt/graph_executor/test_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
index 0a0f073..c0e07b3 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
@@ -31,6 +31,7 @@
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -50,7 +51,6 @@
#include "tensorflow/core/tfrt/mlrt/interpreter/value.h"
#include "tensorflow/core/tfrt/mlrt/kernel/kernel.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tfrt/cpp_tests/test_util.h" // from @tf_runtime
#include "tfrt/host_context/resource_context.h" // from @tf_runtime
diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD
index dcc4c9d..94ca5ab 100644
--- a/tensorflow/core/tfrt/ifrt/BUILD
+++ b/tensorflow/core/tfrt/ifrt/BUILD
@@ -114,6 +114,7 @@
hdrs = ["ifrt_serving_executable.h"],
deps = [
":ifrt_config_proto_cc",
+ ":ifrt_device_utils",
":ifrt_loaded_variable_registry",
":ifrt_loaded_variable_utils",
":ifrt_restore_tensor_registry",
@@ -143,7 +144,10 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
@@ -216,12 +220,22 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@local_tsl//tsl/platform:statusor",
+ "@local_xla//xla/hlo/ir:hlo",
"@local_xla//xla/python/ifrt",
"@local_xla//xla/tsl/concurrency:ref_count",
],
)
cc_library(
+ name = "ifrt_model_restore_context",
+ hdrs = ["ifrt_model_restore_context.h"],
+ deps = [
+ ":checkpoint_loader",
+ "@com_google_absl//absl/strings:string_view",
+ ],
+)
+
+cc_library(
name = "ifrt_model_context",
srcs = ["ifrt_model_context.cc"],
hdrs = ["ifrt_model_context.h"],
@@ -406,10 +420,10 @@
"//tensorflow/core/framework:types_proto_cc",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/python/ifrt",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -429,16 +443,17 @@
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_xla//xla:xla_data_proto_cc",
+ "@local_xla//xla/hlo/ir:hlo",
"@local_xla//xla/python/ifrt",
"@local_xla//xla/python/ifrt:test_util",
"@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib",
"@local_xla//xla/tsl/concurrency:ref_count",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@tf_runtime//:hostcontext",
],
)
@@ -593,3 +608,37 @@
"@tf_runtime//backends/cpu:tf_ops_alwayslink",
],
)
+
+cc_library(
+ name = "checkpoint_loader",
+ srcs = ["checkpoint_loader.cc"],
+ hdrs = ["checkpoint_loader.h"],
+ deps = [
+ ":ifrt_loaded_variable_utils",
+ ":ifrt_restore_tensor_registry",
+ "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/common_runtime:function",
+ "//tensorflow/core/framework:attr_value_proto_cc",
+ "//tensorflow/core/framework:node_def_util",
+ "//tensorflow/core/framework:tensor",
+ "//tensorflow/core/tfrt/fallback:op_kernel_runner",
+ "//tensorflow/core/tfrt/mlrt/bytecode",
+ "//tensorflow/core/tfrt/mlrt/kernel:context",
+ "//tensorflow/core/tfrt/mlrt/kernel:kernel_runner_utils",
+ "//tensorflow/core/tfrt/mlrt/kernel:shard_restore_util",
+ "//tensorflow/core/tfrt/utils:fallback_tensor",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:tstring",
+ "@local_xla//xla/python/ifrt",
+ "@tf_runtime//:hostcontext",
+ ],
+)
diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc
new file mode 100644
index 0000000..7085efd
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc
@@ -0,0 +1,358 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/OwningOpRef.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
+#include "xla/python/ifrt/future.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_handle.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h"
+#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/tstring.h"
+#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+namespace {
+
+static constexpr int kNumRestoreClusters = 4;
+
+// A shard of variables to be restored.
+struct RestoreVariableShard {
+ tensorflow::Tensor prefix;
+ tensorflow::Tensor tensor_names;
+ tensorflow::Tensor shape_and_slices;
+ std::vector<tensorflow::tfrt_stub::FallbackTensor> var_handles;
+ tensorflow::AttrValue dtypes_attr_value;
+ std::vector<tensorflow::DataType> restored_dtypes;
+ std::vector<bool> truncate_in_cast;
+};
+
+struct AsyncState {
+ explicit AsyncState(
+ const std::vector<tensorflow::TensorValue>& input_tf_tensor_values,
+ const OpKernelContext::Params& params, int num_outputs,
+ const tensorflow::DeviceMgr& device_manager,
+ const tensorflow::ProcessFunctionLibraryRuntime&
+ process_function_library_runtime)
+ : run_state(input_tf_tensor_values, params),
+ context(&run_state.params, num_outputs),
+ device_manager(device_manager),
+ process_function_library_runtime(process_function_library_runtime) {}
+
+ tfrt_stub::OpKernelRunState run_state;
+ OpKernelContext context;
+ const tensorflow::DeviceMgr& device_manager;
+ const tensorflow::ProcessFunctionLibraryRuntime&
+ process_function_library_runtime;
+
+ std::vector<xla::ifrt::Promise<tensorflow::Tensor>> results;
+};
+
+// Returns the cast tensor if successful.
+absl::StatusOr<tensorflow::Tensor> Cast(
+ tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype,
+ tensorflow::DataType cast_dtype, bool truncate_in_cast,
+ const tensorflow::DeviceMgr& device_manager,
+ const tensorflow::ProcessFunctionLibraryRuntime&
+ process_function_library_runtime,
+ OpKernelContext::Params& params) {
+ auto runner =
+ tfrt_stub::OpKernelRunner::Create(
+ /*op_name=*/
+ "Cast", /*node_name=*/"Cast", params.device->name(),
+ /*num_args=*/1,
+ [&](tensorflow::AttrValueMap* attr_value_map) {
+ tensorflow::AttrValue restored_dtype_attr_value;
+ restored_dtype_attr_value.set_type(restored_dtype);
+ attr_value_map->insert({"SrcT", restored_dtype_attr_value});
+
+ tensorflow::AttrValue cast_dtype_attr_value;
+ cast_dtype_attr_value.set_type(cast_dtype);
+ attr_value_map->insert({"DstT", cast_dtype_attr_value});
+
+ tensorflow::AttrValue truncate_attr_value;
+ truncate_attr_value.set_b(truncate_in_cast);
+ attr_value_map->insert({"Truncate", truncate_attr_value});
+ return absl::OkStatus();
+ },
+ device_manager, process_function_library_runtime)
+ .value();
+
+ std::vector<tensorflow::TensorValue> input_tf_tensor_values;
+ input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor));
+
+ tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params);
+ // Use the persistent device instead of the per-request device.
+
+ OpKernelContext op_kernel_context(&params, /*num_outputs=*/1);
+
+ runner.Run(&op_kernel_context);
+
+ if (!op_kernel_context.status().ok()) {
+ return op_kernel_context.status();
+ }
+ DCHECK_EQ(op_kernel_context.num_outputs(), 1);
+ return *(op_kernel_context.mutable_output(0));
+}
+
+absl::Status RunShard(RestoreVariableShard shard,
+ IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry,
+ tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue,
+ tf_mlrt::Context& context) {
+ if (!ifrt_restore_tensor_registry) {
+ return absl::InternalError("ifrt_restore_tensor_registry must not be null");
+ }
+ if (!checkpoint_loader_work_queue) {
+ return absl::InternalError("checkpoint_loader_work_queue must not be null");
+ }
+ const int num_outputs = shard.var_handles.size();
+ DCHECK_EQ(num_outputs, shard.tensor_names.NumElements());
+ auto& fallback_request_state = context.fallback_request_state();
+
+ // Use `tf.RestoreV2` to restore tensors. This will also populate
+ // tensorflow::ResourceManager.
+ // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the
+ // variable is only used by device/IFRT.
+ // TODO(b/319045348): consider directly calling restore function such as that
+ // in /tensorflow/core/kernels/save_restore_v2_ops.cc
+ auto runner =
+ tfrt_stub::OpKernelRunner::Create(
+ /*op_name=*/
+ "RestoreV2", /*node_name=*/"RestoreV2",
+ context.params().device->name(),
+ /*num_args=*/3,
+ [&](tensorflow::AttrValueMap* attr_value_map) {
+ attr_value_map->insert({"dtypes", shard.dtypes_attr_value});
+ return absl::OkStatus();
+ },
+ fallback_request_state.device_manager(),
+ fallback_request_state.process_function_library_runtime())
+ .value();
+
+ // Prepare the input tensors.
+ std::vector<tensorflow::TensorValue> input_tf_tensor_values;
+ static constexpr int kNumInputArgs = 3;
+ input_tf_tensor_values.resize(kNumInputArgs);
+ // We need to keep these tensors alive.
+ input_tf_tensor_values[0].tensor = &shard.prefix;
+ input_tf_tensor_values[1].tensor = &shard.tensor_names;
+ input_tf_tensor_values[2].tensor = &shard.shape_and_slices;
+
+ auto& params = context.params();
+ tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params);
+ // Use the persistent device instead of the per-request device.
+ params.device = context.fallback_request_state().device_manager().HostCPU();
+
+ auto async_state = std::make_unique<AsyncState>(
+ input_tf_tensor_values, params, num_outputs,
+ fallback_request_state.device_manager(),
+ fallback_request_state.process_function_library_runtime());
+
+ for (int i = 0; i < num_outputs; ++i) {
+ auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
+ auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
+ const ResourceHandle& var_handle =
+ shard.var_handles[i].tensor().scalar<tensorflow::ResourceHandle>()();
+
+ TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape,
+ ifrt_serving::GetDtypeAndShape(var_handle));
+
+ std::string runtime_name =
+ ifrt_serving::GetRuntimeNameFromVarHandle(var_handle);
+
+ ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo
+ restored_tensor_info = {false, std::move(dtype_and_shape),
+ std::move(future)};
+ if (auto status = ifrt_restore_tensor_registry->TryRegister(
+ runtime_name, restored_tensor_info);
+ !status.ok()) {
+ // Propagate errors so that if already-registered futures are being waited
+ // on, they can be unblocked.
+ for (auto& result : async_state->results) {
+ std::move(result).Set(status);
+ }
+ return status;
+ }
+ async_state->results.push_back(std::move(promise));
+ }
+
+ // Use dedicated work queue for restore operation.
+ checkpoint_loader_work_queue->AddTask([runner = std::move(runner),
+ async_state = std::move(async_state),
+ shard = std::move(shard)]() {
+ // Keep input tensor alive in `shard`.
+ auto* op_kernel_context_ptr = &async_state->context;
+ runner.Run(op_kernel_context_ptr);
+
+ auto& op_kernel_context = async_state->context;
+ if (!op_kernel_context.status().ok()) {
+ for (auto& result : async_state->results) {
+ std::move(result).Set(op_kernel_context.status());
+ }
+ return;
+ }
+ DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs());
+ DCHECK_EQ(shard.truncate_in_cast.size(), op_kernel_context.num_outputs());
+
+ // TODO(b/343964091): consider running multiple casts in parallel.
+ for (int i = 0; i < op_kernel_context.num_outputs(); ++i) {
+ DCHECK(op_kernel_context.mutable_output(i));
+
+ if (op_kernel_context.mutable_output(i)->dtype() !=
+ shard.restored_dtypes[i]) {
+ std::move(async_state->results[i])
+ .Set(absl::InvalidArgumentError(absl::StrCat(
+ "The restored tensor has a different dtype than the "
+ "variable handle: ",
+ op_kernel_context.mutable_output(i)->dtype(), " vs. ",
+ shard.restored_dtypes[i])));
+ return;
+ }
+ const ResourceHandle& var_handle =
+ shard.var_handles[i].tensor().scalar<tensorflow::ResourceHandle>()();
+
+ if (shard.restored_dtypes[i] == var_handle.dtypes_and_shapes()[0].dtype) {
+ std::move(async_state->results[i])
+ .Set(*std::move(op_kernel_context.mutable_output(i)));
+ } else {
+ absl::StatusOr<tensorflow::Tensor> cast_output =
+ Cast(*op_kernel_context.mutable_output(i), shard.restored_dtypes[i],
+ var_handle.dtypes_and_shapes()[0].dtype,
+ shard.truncate_in_cast[i], async_state->device_manager,
+ async_state->process_function_library_runtime,
+ async_state->run_state.params);
+ if (!cast_output.ok()) {
+ std::move(async_state->results[i]).Set(cast_output.status());
+ } else {
+ std::move(async_state->results[i]).Set(*std::move(cast_output));
+ }
+ }
+ }
+ });
+ return absl::OkStatus();
+}
+
+int64_t GetSizeFromVarHandle(const ResourceHandle& handle) {
+ int64_t size = 0;
+ for (auto& dtype_and_shape : handle.dtypes_and_shapes()) {
+ size += DataTypeSize(dtype_and_shape.dtype) *
+ dtype_and_shape.shape.num_elements();
+ }
+ return size;
+}
+
+} // namespace
+
+absl::Status CheckpointLoader::PrepareRestore(
+ mlir::OwningOpRef<mlir::ModuleOp> module) {
+ VLOG(1) << "Skip CheckpointLoader::PrepareRestore";
+ return absl::OkStatus();
+}
+
+absl::Status CheckpointLoader::Load(
+ const tensorflow::tfrt_stub::FallbackTensor& prefix,
+ const std::vector<tensorflow::tfrt_stub::FallbackTensor>& var_handles,
+ const tensorflow::tfrt_stub::FallbackTensor& tensor_names,
+ const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices,
+ const mlrt::bc::Vector<tensorflow::DataType>& restored_dtypes,
+ const mlrt::bc::Vector<bool>& truncate_in_cast, tf_mlrt::Context& context) {
+ std::vector<int64_t> variable_sizes;
+ variable_sizes.reserve(var_handles.size());
+ for (auto& handle : var_handles) {
+ variable_sizes.push_back(GetSizeFromVarHandle(
+ handle.tensor().scalar<tensorflow::ResourceHandle>()()));
+ }
+
+ std::vector<std::vector<int>> sharded_indices = tf_mlrt::ShardVariables(
+ kNumRestoreClusters, absl::MakeSpan(variable_sizes));
+
+ // Converts the names and slices back into string tensors.
+ auto vector_to_tensor = [](const std::vector<tsl::tstring>& vec) {
+ tensorflow::Tensor tensor(tensorflow::DT_STRING,
+ TensorShape({static_cast<int>(vec.size())}));
+ for (int i = 0; i < vec.size(); ++i) {
+ tensor.flat<tsl::tstring>()(i) = vec[i];
+ }
+ return tensor;
+ };
+
+ const auto& tensor_names_flat = tensor_names.tensor().flat<tsl::tstring>();
+ const auto& shape_and_slices_flat =
+ shape_and_slices.tensor().flat<tsl::tstring>();
+
+ std::vector<RestoreVariableShard> shards;
+ shards.reserve(sharded_indices.size());
+ for (auto& sharded_index : sharded_indices) {
+ RestoreVariableShard shard;
+ shard.var_handles.reserve(sharded_index.size());
+ shard.truncate_in_cast.reserve(sharded_index.size());
+ shard.restored_dtypes.reserve(sharded_index.size());
+ std::vector<tsl::tstring> tensor_names;
+ std::vector<tsl::tstring> shape_and_slices;
+ shape_and_slices.reserve(sharded_index.size());
+ tensor_names.reserve(sharded_index.size());
+ for (int index : sharded_index) {
+ tensor_names.push_back(tensor_names_flat(index));
+ shape_and_slices.push_back(shape_and_slices_flat(index));
+ shard.dtypes_attr_value.mutable_list()->add_type(restored_dtypes[index]);
+ shard.var_handles.push_back(var_handles[index]);
+ shard.restored_dtypes.push_back(restored_dtypes[index]);
+ shard.truncate_in_cast.push_back(truncate_in_cast[index]);
+ }
+ shard.prefix = prefix.tensor();
+ shard.tensor_names = vector_to_tensor(tensor_names);
+ shard.shape_and_slices = vector_to_tensor(shape_and_slices);
+ shards.push_back(std::move(shard));
+ }
+ for (const auto& shard : shards) {
+ TF_RETURN_IF_ERROR(RunShard(shard, ifrt_restore_tensor_registry_,
+ checkpoint_loader_work_queue_, context));
+ }
+ return absl::OkStatus();
+}
+
+} // namespace ifrt_serving
+} // namespace tensorflow
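The new `CheckpointLoader::Load` above computes a byte size per variable handle via `GetSizeFromVarHandle`, lets `tf_mlrt::ShardVariables` split the handles into `kNumRestoreClusters` shards, and then restores each shard asynchronously on the dedicated work queue. The standalone C++ sketch below models one plausible size-balancing strategy (greedy largest-first into the lightest shard); it is an illustration only, not the actual `ShardVariables` implementation from `shard_restore_util`, whose policy may differ.

// Standalone sketch (not TF code): a size-balanced partitioning in the spirit
// of what CheckpointLoader::Load delegates to tf_mlrt::ShardVariables.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Greedily assigns each variable (largest first) to the currently lightest
// shard, returning per-shard index lists into `variable_sizes`.
std::vector<std::vector<int>> ShardBySize(
    int num_shards, const std::vector<int64_t>& variable_sizes) {
  std::vector<int> order(variable_sizes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](int a, int b) {
    return variable_sizes[a] > variable_sizes[b];
  });

  std::vector<std::vector<int>> shards(num_shards);
  std::vector<int64_t> shard_bytes(num_shards, 0);
  for (int index : order) {
    int lightest = std::min_element(shard_bytes.begin(), shard_bytes.end()) -
                   shard_bytes.begin();
    shards[lightest].push_back(index);
    shard_bytes[lightest] += variable_sizes[index];
  }
  return shards;
}

int main() {
  // Sizes in bytes, analogous to GetSizeFromVarHandle over dtypes_and_shapes.
  std::vector<int64_t> sizes = {4096, 1024, 65536, 512, 32768, 2048};
  auto shards = ShardBySize(/*num_shards=*/4, sizes);
  for (size_t s = 0; s < shards.size(); ++s) {
    std::cout << "shard " << s << ":";
    for (int i : shards[s]) std::cout << ' ' << i;
    std::cout << '\n';
  }
}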
diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.h b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h
new file mode 100644
index 0000000..cd835d4
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h
@@ -0,0 +1,66 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_
+#define TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_
+
+#include <vector>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/OwningOpRef.h" // from @llvm-project
+#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
+#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
+#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+// TODO(b/352551302): Move the restore unit tests in ifrt_ops_kernel over to
+// exercise this class's APIs directly.
+// Implements the `CheckpointLoaderInterface` using `tf.RestoreV2`.
+class CheckpointLoader {
+ public:
+ explicit CheckpointLoader(
+ IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry,
+ tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue)
+ : ifrt_restore_tensor_registry_(ifrt_restore_tensor_registry),
+ checkpoint_loader_work_queue_(checkpoint_loader_work_queue) {}
+ virtual ~CheckpointLoader() = default;
+
+ // Called before `Load` to do some preparation work.
+ virtual absl::Status PrepareRestore(mlir::OwningOpRef<mlir::ModuleOp> module);
+
+ // Load the checkpoint. This API is designed to be compatible with the
+ // `tf_mlrt.ifrt_restore_variable` kernel.
+ virtual absl::Status Load(
+ const tensorflow::tfrt_stub::FallbackTensor& prefix,
+ const std::vector<tensorflow::tfrt_stub::FallbackTensor>& var_handles,
+ const tensorflow::tfrt_stub::FallbackTensor& tensor_names,
+ const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices,
+ const mlrt::bc::Vector<tensorflow::DataType>& restored_dtypes,
+ const mlrt::bc::Vector<bool>& truncate_in_cast,
+ tf_mlrt::Context& context);
+
+ IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry_;
+ tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue_;
+};
+
+} // namespace ifrt_serving
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_
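`PrepareRestore` and `Load` are virtual so a serving stack can plug in a different restore strategy while the default path goes through `tf.RestoreV2`. The standalone sketch below models that override pattern in plain C++; `BaseLoader`, `PrewarmingLoader`, and the simplified `Status` are hypothetical stand-ins, not TensorFlow types.

// Standalone sketch (not TF code) of the extension pattern the header enables:
// a base loader with virtual hooks that a serving stack can override.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Status { bool ok = true; std::string message; };

class BaseLoader {
 public:
  virtual ~BaseLoader() = default;
  // Hook invoked before Load; the default is a no-op, as in CheckpointLoader.
  virtual Status PrepareRestore() { return {}; }
  virtual Status Load(const std::vector<std::string>& tensor_names) {
    std::cout << "restoring " << tensor_names.size()
              << " tensors via a RestoreV2-like path\n";
    return {};
  }
};

// A hypothetical specialization that pre-warms something before loading.
class PrewarmingLoader : public BaseLoader {
 public:
  Status PrepareRestore() override {
    std::cout << "pre-warming restore pipeline\n";
    return {};
  }
};

int main() {
  std::unique_ptr<BaseLoader> loader = std::make_unique<PrewarmingLoader>();
  Status s = loader->PrepareRestore();
  if (!s.ok) return 1;
  loader->Load({"var_x", "var_y"});
}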
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h
index e799c57..2c6a566 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h
@@ -17,6 +17,7 @@
#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_
#include <string>
+#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
@@ -25,6 +26,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/future.h"
#include "xla/tsl/concurrency/ref_count.h"
@@ -38,19 +40,29 @@
// The key is per variable tensor per device assignment. For single -device
// program, variables can be loaded on multiple devices with core selection.
// For SPMD program, we currently assume all devices will be used, so we use
- // set to make it compatible with SPMD.
+ // vector to make it compatible with SPMD.
struct Key {
- // We use a set to make it compatible with SPMD.
- absl::flat_hash_set<int> device_ids;
+ // We use a vector to make it compatible with SPMD because the order of the
+ // devices used for sharding must match the order of the devices used for
+ // XLA compilation.
+ std::vector<int> device_ids;
std::string input_name;
+ xla::HloSharding hlo_sharding;
template <typename H>
friend H AbslHashValue(H h, const Key& key) {
- h = H::combine(std::move(h), key.input_name, key.device_ids);
+ h = H::combine(std::move(h), key.input_name, key.device_ids,
+ key.hlo_sharding);
return h;
}
friend bool operator==(const Key& x, const Key& y) {
- return x.input_name == y.input_name && x.device_ids == y.device_ids;
+ return x.input_name == y.input_name && x.device_ids == y.device_ids &&
+ x.hlo_sharding == y.hlo_sharding;
+ }
+
+ std::string ToString() const {
+ return absl::StrCat(input_name, ":", absl::StrJoin(device_ids, ","), ":",
+ hlo_sharding.ToString());
}
};
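With this change the registry key is the triple of ordered device ids, input name, and `xla::HloSharding`, so the same variable loaded under a different sharding, or with devices in a different order, caches as a distinct array. The standalone sketch below models that composite-key behaviour with `std::unordered_map`, using a plain string where the real key holds an `xla::HloSharding`; all names in it are illustrative, not TF APIs.

// Standalone sketch (not TF code): why the registry key includes the sharding.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Key {
  std::vector<int> device_ids;  // order matters: must match compilation order
  std::string input_name;
  std::string sharding;         // stand-in for xla::HloSharding::ToString()

  bool operator==(const Key& o) const {
    return device_ids == o.device_ids && input_name == o.input_name &&
           sharding == o.sharding;
  }
};

struct KeyHash {
  std::size_t operator()(const Key& k) const {
    std::size_t h = std::hash<std::string>{}(k.input_name);
    for (int id : k.device_ids) h = h * 31 + std::hash<int>{}(id);
    return h * 31 + std::hash<std::string>{}(k.sharding);
  }
};

int main() {
  std::unordered_map<Key, std::string, KeyHash> registry;
  registry[{{0, 1}, "var_x", "replicated"}] = "array A";
  registry[{{0, 1}, "var_x", "devices=[2,1]0,1"}] = "array B";  // distinct entry
  std::cout << registry.size() << " cached arrays\n";           // prints 2
}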
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc
index 7a17a7e..0af5363 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc
@@ -53,14 +53,10 @@
std::shared_ptr<xla::ifrt::Client> ifrt_client,
const tsl::thread::ThreadPool& thread_pool,
const tensorflow::Tensor& variable,
- const VariableDeviceShardingConfigProto& sharding_config) {
- std::vector<int> device_ids{sharding_config.device_ids().begin(),
- sharding_config.device_ids().end()};
- TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
- xla::HloSharding::FromProto(sharding_config.sharding()));
+ const VariableDeviceShardingConfig& sharding_config) {
return tensorflow::ifrt_serving::MakeArrayFromTensor(
- *ifrt_client, variable, sharding_config.device_ids(), hlo_sharding,
- thread_pool);
+ *ifrt_client, variable, sharding_config.device_ids,
+ sharding_config.hlo_sharding, thread_pool);
}
} // namespace
@@ -97,12 +93,11 @@
const ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry,
ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry,
tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
- const VariableDeviceShardingConfigProto& sharding_config) {
- absl::flat_hash_set<int> device_ids{sharding_config.device_ids().begin(),
- sharding_config.device_ids().end()};
+ const VariableDeviceShardingConfig& sharding_config) {
IfrtLoadedVariableRegistry::Key loaded_variable_key{
- .device_ids = std::move(device_ids),
+ .device_ids = sharding_config.device_ids,
.input_name = std::string(runtime_name),
+ .hlo_sharding = sharding_config.hlo_sharding,
};
if (ifrt_loaded_variable_registry.GetLoadedVariable(loaded_variable_key)
.ok()) {
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h
index 4d07d1a..6fea3a5 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h
@@ -18,11 +18,13 @@
#include <memory>
#include <string>
+#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
+#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/python/ifrt/client.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
@@ -37,6 +39,12 @@
// An index to indicate a non per-core executable bundle cache.
inline constexpr int kNoCoreSelectedIndex = -1;
+// TODO(b/352551302): Delete VariableDeviceShardingConfigProto.
+struct VariableDeviceShardingConfig {
+ std::vector<int> device_ids;
+ xla::HloSharding hlo_sharding;
+};
+
absl::StatusOr<ifrt_serving::DtypeAndShape> GetDtypeAndShape(
const ResourceHandle& resource_handle);
@@ -57,7 +65,7 @@
const ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry,
ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry,
tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
- const VariableDeviceShardingConfigProto& sharding_config);
+ const VariableDeviceShardingConfig& sharding_config);
} // namespace ifrt_serving
} // namespace tensorflow
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc
index 4777d0a3c..fe8e988 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc
@@ -23,12 +23,14 @@
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/future.h"
#include "xla/python/ifrt/test_util.h"
#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/tensor.h"
@@ -39,7 +41,6 @@
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
@@ -77,8 +78,10 @@
auto restore_work_queue = tfrt::CreateMultiThreadedWorkQueue(
/*num_threads=*/4, /*num_blocking_threads=*/4);
- VariableDeviceShardingConfigProto sharding_config;
- sharding_config.add_device_ids(0);
+ VariableDeviceShardingConfig sharding_config = {
+ .device_ids = {0},
+ .hlo_sharding = xla::HloSharding::Replicate(),
+ };
auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
@@ -120,8 +123,10 @@
auto restore_work_queue = tfrt::CreateMultiThreadedWorkQueue(
/*num_threads=*/4, /*num_blocking_threads=*/4);
- VariableDeviceShardingConfigProto sharding_config;
- sharding_config.add_device_ids(0);
+ VariableDeviceShardingConfig sharding_config{
+ .device_ids = {0},
+ .hlo_sharding = xla::HloSharding::Replicate(),
+ };
auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
@@ -140,6 +145,7 @@
IfrtLoadedVariableRegistry::Key key{
.device_ids = {0},
.input_name = "var_x",
+ .hlo_sharding = sharding_config.hlo_sharding,
};
TF_ASSERT_OK_AND_ASSIGN(auto v,
loaded_variable_registry.GetLoadedVariable(key));
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h
new file mode 100644
index 0000000..da9528e
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h
@@ -0,0 +1,50 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_
+#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_
+
+#include <memory>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+inline constexpr absl::string_view kIfrtModelRestoreContextName =
+ "IfrtModelRestoreContext";
+
+// A resource context that holds the `CheckpointLoader` for a model. We need a
+// different context than `IfrtModelContext` because `IfrtModelContext` is too
+// large to be a dependency of other libraries.
+class IfrtModelRestoreContext {
+ public:
+ explicit IfrtModelRestoreContext(
+ std::unique_ptr<CheckpointLoader> checkpoint_loader)
+ : checkpoint_loader_(std::move(checkpoint_loader)) {}
+
+ CheckpointLoader* checkpoint_loader() const {
+ return checkpoint_loader_.get();
+ }
+
+ private:
+ std::unique_ptr<CheckpointLoader> checkpoint_loader_;
+};
+
+} // namespace ifrt_serving
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_
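`IfrtModelRestoreContext` is meant to be registered in a resource context under `kIfrtModelRestoreContextName`, so kernels can reach the `CheckpointLoader` without depending on the much larger `IfrtModelContext`. The standalone sketch below models that name-keyed lookup pattern; `ResourceContext`, `ModelRestoreContext`, and `FakeCheckpointLoader` are simplified stand-ins, not the TFRT resource-context API.

// Standalone sketch (not TF code) of a name-keyed resource map from which a
// kernel fetches a restore context and reaches the loader through it.
#include <any>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>

class FakeCheckpointLoader {
 public:
  void Load() { std::cout << "loading checkpoint\n"; }
};

class ModelRestoreContext {
 public:
  explicit ModelRestoreContext(std::unique_ptr<FakeCheckpointLoader> loader)
      : loader_(std::move(loader)) {}
  FakeCheckpointLoader* checkpoint_loader() const { return loader_.get(); }

 private:
  std::unique_ptr<FakeCheckpointLoader> loader_;
};

// Minimal stand-in for a resource context keyed by string name.
class ResourceContext {
 public:
  template <typename T>
  void Create(const std::string& name, std::shared_ptr<T> resource) {
    resources_[name] = std::move(resource);
  }
  template <typename T>
  std::optional<T*> Get(const std::string& name) const {
    auto it = resources_.find(name);
    if (it == resources_.end()) return std::nullopt;
    return std::any_cast<std::shared_ptr<T>>(it->second).get();
  }

 private:
  std::unordered_map<std::string, std::any> resources_;
};

int main() {
  ResourceContext ctx;
  ctx.Create("IfrtModelRestoreContext",
             std::make_shared<ModelRestoreContext>(
                 std::make_unique<FakeCheckpointLoader>()));
  if (auto restore = ctx.Get<ModelRestoreContext>("IfrtModelRestoreContext")) {
    (*restore)->checkpoint_loader()->Load();
  }
}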
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc
index de0a27a..3251962 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc
@@ -21,11 +21,11 @@
#include "absl/status/status.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
#include "xla/python/ifrt/future.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
index b5e0770..a6d35bf 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
@@ -34,8 +34,12 @@
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h"
@@ -63,6 +67,7 @@
#include "xla/tsl/framework/serving_device_selector.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -70,6 +75,7 @@
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_device_utils.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
@@ -119,39 +125,61 @@
return dtypes_and_shapes;
}
-absl::StatusOr<xla::DeviceAssignment> GetXlaDeviceAssignment(
- const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) {
- if (!compile_metadata.has_device_assignment()) {
- return absl::InternalError("No device assignment found.");
+// Returns the device assignment from the given IFRT devices list.
+absl::StatusOr<xla::DeviceAssignment> GetRuntimeXlaDeviceAssignment(
+ const xla::ifrt::DeviceList& devices, int num_replicas,
+ int num_cores_per_replica) {
+ const int num_devices = num_replicas * num_cores_per_replica;
+ if (devices.size() != num_devices) {
+ return absl::InternalError(
+ absl::StrCat("Device assignment has ", devices.size(),
+ " devices, but expected ", num_devices));
}
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<xla::DeviceAssignment> da,
- xla::DeviceAssignment::Deserialize(compile_metadata.device_assignment()));
- return *da;
-}
-
-absl::StatusOr<std::vector<xla::ifrt::Device*>> GetAssignedDevices(
- const xla::ifrt::Client& ifrt_client,
- const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) {
- TF_ASSIGN_OR_RETURN(auto device_assignment,
- GetXlaDeviceAssignment(compile_metadata));
- const int num_devices =
- device_assignment.replica_count() * device_assignment.computation_count();
- std::vector<xla::ifrt::Device*> devices;
- devices.reserve(num_devices);
- for (int replica_idx = 0; replica_idx < device_assignment.replica_count();
- replica_idx++) {
- for (int computation_idx = 0;
- computation_idx < device_assignment.computation_count();
- computation_idx++) {
- auto device_id = device_assignment(replica_idx, computation_idx);
- TF_ASSIGN_OR_RETURN(
- xla::ifrt::Device * device,
- ifrt_client.LookupDevice(xla::ifrt::DeviceId(device_id)));
- devices.push_back(device);
+ xla::DeviceAssignment da(num_replicas, num_cores_per_replica);
+ int device_index = 0;
+ for (int replica_idx = 0; replica_idx < num_replicas; replica_idx++) {
+ for (int core_idx = 0; core_idx < num_cores_per_replica;
+ core_idx++, device_index++) {
+ da(replica_idx, core_idx) = devices[device_index]->Id().value();
+ VLOG(3) << "Added IFRT device id: " << da(replica_idx, core_idx);
}
}
- return devices;
+ return da;
+}
+
+static constexpr absl::string_view kDeviceAssignmentAttr = "device_assignment";
+static constexpr absl::string_view kEntryFuncName = "main";
+
+absl::StatusOr<std::vector<xla::ifrt::Device*>> GetAssignedDevices(
+ mlir::ModuleOp module, const xla::ifrt::Client& ifrt_client,
+ int num_replicas, int num_cores_per_replica) {
+ auto op = module.lookupSymbol<mlir::func::FuncOp>(kEntryFuncName);
+ if (!op) {
+ return absl::InternalError("Could not find entry function in MLIR Module.");
+ }
+
+ auto device_assignment_attr =
+ op->getAttrOfType<mlir::ArrayAttr>(kDeviceAssignmentAttr);
+ std::optional<std::vector<int>> device_assignment_attr_val;
+
+ if (device_assignment_attr && !device_assignment_attr.getValue().empty()) {
+ std::vector<int> coords;
+ coords.reserve(num_replicas * num_cores_per_replica);
+ for (auto coord_attr : device_assignment_attr.getValue()) {
+ auto coord_attr_val = mlir::dyn_cast<mlir::IntegerAttr>(coord_attr);
+ if (!coord_attr_val) {
+ return absl::InternalError(
+ llvm::formatv("Device assignment attribute is not an integer: {0}",
+ device_assignment_attr)
+ .str());
+ }
+ coords.push_back(coord_attr_val.getInt());
+ }
+ device_assignment_attr_val = std::move(coords);
+ }
+ return GetAssignedIfrtDevices(ifrt_client, num_replicas,
+ num_cores_per_replica,
+ device_assignment_attr_val);
}
} // namespace
@@ -173,12 +201,21 @@
tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata,
GetCompileMetadata(*module, *client));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<xla::ifrt::Device*> assigned_devices,
+ GetAssignedDevices(*module, *client,
+ original_compile_metadata.num_replicas(),
+ original_compile_metadata.num_cores_per_replica()));
+
auto executable = absl::WrapUnique(new IfrtServingExecutable(
program_id, model_name, signature_name, std::move(module),
std::move(client), thread_pool, ifrt_loaded_variable_registry,
ifrt_restore, checkpoint_loader_queue, device_mgr,
std::move(shape_representation_fn), ifrt_serving_core_selector,
- std::move(original_compile_metadata), compilation_environement_proto));
+ std::move(original_compile_metadata),
+ xla::ifrt::DeviceList(xla::ifrt::DeviceList::Devices(
+ assigned_devices.begin(), assigned_devices.end())),
+ compilation_environement_proto));
return executable;
}
@@ -367,14 +404,17 @@
xla_compile_options.executable_build_options.set_num_partitions(
num_partitions);
- xla_compile_options.executable_build_options.set_use_spmd_partitioning(true);
+ xla_compile_options.executable_build_options.set_use_spmd_partitioning(
+ original_compile_metadata_.use_spmd_for_xla_partitioning());
xla_compile_options.parameter_is_tupled_arguments = false;
// Use portable execution for single device + core selection.
if (UsePortableExecution(compile_metadata)) {
xla_compile_options.compile_portable_executable = true;
} else {
- TF_ASSIGN_OR_RETURN(xla::DeviceAssignment da,
- GetXlaDeviceAssignment(tf2hlo_result.compile_metadata));
+ TF_ASSIGN_OR_RETURN(
+ xla::DeviceAssignment da,
+ GetRuntimeXlaDeviceAssignment(assigned_device_list_, num_replicas,
+ num_partitions));
VLOG(2) << "Device assignment :" << da.ToString();
xla_compile_options.executable_build_options.set_device_assignment(da);
}
@@ -516,7 +556,7 @@
// `device_reservation` should be alive before the end of the execution.
tsl::DeviceReservation device_reservation(kNoCoreSelectedIndex, nullptr);
- std::vector<xla ::ifrt::Device*> devices;
+ xla::ifrt::DeviceList device_list;
if (UsePortableExecution(compile_metadata)) {
device_reservation =
ifrt_serving_core_selector_->ReserveDevice(program_id_);
@@ -526,19 +566,16 @@
TF_ASSIGN_OR_RETURN(xla::ifrt::Device * device,
ifrt_client_->LookupDevice(xla::ifrt::DeviceId(
device_reservation.device_index())));
- devices.push_back(device);
+ device_list =
+ xla::ifrt::DeviceList(xla::ifrt::DeviceList::Devices({device}));
} else {
- TF_ASSIGN_OR_RETURN(devices,
- GetAssignedDevices(*ifrt_client_, compile_metadata));
+ device_list = assigned_device_list_;
}
TF_ASSIGN_OR_RETURN(SharedCachedExecutableBundle executable_bundle,
LookUpOrCreateExecutable(
compile_metadata, absl::MakeSpan(dtypes_and_shapes))
.Await());
- xla::ifrt::DeviceList device_list(
- xla::ifrt::DeviceList::Devices(devices.begin(), devices.end()));
-
if (executable_bundle->compile_metadata.args().size() !=
dtypes_and_shapes.size()) {
return absl::InternalError(absl::StrCat(
@@ -548,7 +585,7 @@
// Asynchronously load the restored variable tensors to Ifrt array.
TF_RETURN_IF_ERROR(AsyncLoadIfrtArray(inputs, variable_arg_indices,
- *executable_bundle, devices));
+ *executable_bundle, device_list));
std::vector<tsl::RCReference<xla::ifrt::Array>> args;
args.reserve(inputs.size());
@@ -556,13 +593,19 @@
for (int i = 0; i < inputs.size(); i++) {
if (variable_index < variable_arg_indices.size() &&
i == variable_arg_indices[variable_index]) {
- absl::flat_hash_set<int> device_ids;
- for (const auto& device : devices) {
- device_ids.insert(device->Id().value());
+ std::vector<int> device_ids;
+ device_ids.reserve(device_list.size());
+ for (const auto& device : device_list) {
+ device_ids.push_back(device->Id().value());
}
+ TF_ASSIGN_OR_RETURN(
+ xla::HloSharding hlo_sharding,
+ xla::HloSharding::FromProto(
+ executable_bundle->compile_metadata.args()[i].sharding()));
IfrtLoadedVariableRegistry::Key key{
.device_ids = std::move(device_ids),
.input_name = inputs[i].scalar<tsl::tstring>()(),
+ .hlo_sharding = std::move(hlo_sharding),
};
TF_ASSIGN_OR_RETURN(
auto loaded_variable,
@@ -640,7 +683,7 @@
absl::Span<const tensorflow::Tensor> inputs,
absl::Span<const int> variable_arg_indices,
const CachedExecutableBundle& executable_bundle,
- const std::vector<xla::ifrt::Device*>& devices) {
+ const xla::ifrt::DeviceList& devices) {
for (const int i : variable_arg_indices) {
if (inputs[i].dtype() != tensorflow::DT_STRING ||
!tensorflow::TensorShapeUtils::IsScalar(inputs[i].shape())) {
@@ -652,11 +695,15 @@
}
std::string runtime_name = inputs[i].scalar<tsl::tstring>()();
// TODO(b/339521818): Add test cases for OpSharding on variables.
- VariableDeviceShardingConfigProto sharding_config;
- *sharding_config.mutable_sharding() =
- executable_bundle.compile_metadata.args()[i].sharding();
+ TF_ASSIGN_OR_RETURN(
+ xla::HloSharding hlo_sharding,
+ xla::HloSharding::FromProto(
+ executable_bundle.compile_metadata.args()[i].sharding()));
+ VariableDeviceShardingConfig sharding_config{
+ .hlo_sharding = std::move(hlo_sharding),
+ };
for (const auto& device : devices) {
- sharding_config.add_device_ids(device->Id().value());
+ sharding_config.device_ids.push_back(device->Id().value());
}
TF_RETURN_IF_ERROR(
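`GetRuntimeXlaDeviceAssignment` above lays an already-ordered IFRT device list out replica-major into a num_replicas x num_cores_per_replica `xla::DeviceAssignment`, and fails if the device count does not match. The standalone sketch below reproduces that fill with plain vectors in place of the XLA and IFRT types; it illustrates the indexing only and is not the production code.

// Standalone sketch (not TF code): replica-major fill of a device assignment
// matrix from a flat, already-ordered list of device ids.
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<std::vector<int>> BuildDeviceAssignment(
    const std::vector<int>& device_ids, int num_replicas,
    int num_cores_per_replica) {
  if (device_ids.size() !=
      static_cast<size_t>(num_replicas) * num_cores_per_replica) {
    throw std::runtime_error("device count does not match replicas * cores");
  }
  std::vector<std::vector<int>> assignment(
      num_replicas, std::vector<int>(num_cores_per_replica));
  int next = 0;
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int core = 0; core < num_cores_per_replica; ++core) {
      assignment[replica][core] = device_ids[next++];
    }
  }
  return assignment;
}

int main() {
  // e.g. 2 replicas, 2 cores per replica, device ids in runtime order.
  auto da = BuildDeviceAssignment({0, 1, 2, 3}, /*num_replicas=*/2,
                                  /*num_cores_per_replica=*/2);
  for (const auto& row : da) {
    for (int id : row) std::cout << id << ' ';
    std::cout << '\n';
  }
}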
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
index f3983b6..9dfc225 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
+++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
@@ -19,7 +19,6 @@
#include <algorithm>
#include <cstdint>
#include <memory>
-#include <optional>
#include <string>
#include <utility>
#include <vector>
@@ -144,12 +143,14 @@
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
IfrtServingCoreSelector* ifrt_serving_core_selector,
tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata,
+ xla::ifrt::DeviceList assigned_device_list,
tsl::protobuf::Message* compilation_environment_proto)
: program_id_(program_id),
model_name_(std::string(model_name)),
signature_name_(std::string(signature_name)),
module_(std::move(module)),
original_compile_metadata_(std::move(original_compile_metadata)),
+ assigned_device_list_(std::move(assigned_device_list)),
ifrt_client_(std::move(client)),
thread_pool_(*thread_pool),
ifrt_loaded_variable_registry_(*ifrt_loaded_variable_registry),
@@ -168,9 +169,10 @@
mlir::OwningOpRef<mlir::ModuleOp> module_ ABSL_GUARDED_BY(mutex_);
// The original compile metadata. We need to keep it around to be able to
- // test portable execution condition even if the Module itsel is already
+ // test portable execution condition even if the Module itself is already
// released.
tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata_;
+ const xla::ifrt::DeviceList assigned_device_list_;
std::shared_ptr<xla::ifrt::Client> ifrt_client_;
tsl::thread::ThreadPool& thread_pool_;
@@ -196,7 +198,7 @@
absl::Span<const tensorflow::Tensor> inputs,
absl::Span<const int> variable_arg_indices,
const CachedExecutableBundle& executable_bundle,
- const std::vector<xla::ifrt::Device*>& devices);
+ const xla::ifrt::DeviceList& devices);
absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> ConvertTensorToArray(
const tensorflow::Tensor& tensor,
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc
index c0bdef1..28d3efb 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc
@@ -33,6 +33,7 @@
#include "xla/python/ifrt/test_util.h"
#include "xla/tsl/framework/serving_device_selector.h"
#include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_matcher.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -41,7 +42,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/tstring.h"
diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc
index 7ec27bb..77043fc 100644
--- a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc
+++ b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc
@@ -34,13 +34,13 @@
#include "xla/python/ifrt/test_util.h"
#include "xla/python/pjrt_ifrt/xla_sharding.h"
#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_matcher.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/ml_dtypes.h"
#include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/tfrt/kernels/BUILD b/tensorflow/core/tfrt/kernels/BUILD
index 9716f4b..817b848 100644
--- a/tensorflow/core/tfrt/kernels/BUILD
+++ b/tensorflow/core/tfrt/kernels/BUILD
@@ -62,7 +62,6 @@
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/pjrt/cpu:cpu_client",
@@ -72,6 +71,7 @@
"@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib",
"@local_xla//xla/tsl/framework:serving_device_selector",
"@local_xla//xla/tsl/framework/test_util:mock_serving_device_selector",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc
index 3ae4d09..cd29511 100644
--- a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc
+++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc
@@ -32,6 +32,7 @@
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "xla/tsl/framework/serving_device_selector.h"
#include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
@@ -45,7 +46,6 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/core/tfrt/mlrt/interpreter/BUILD b/tensorflow/core/tfrt/mlrt/interpreter/BUILD
index 552959b..0b1eee7 100644
--- a/tensorflow/core/tfrt/mlrt/interpreter/BUILD
+++ b/tensorflow/core/tfrt/mlrt/interpreter/BUILD
@@ -189,9 +189,9 @@
"@com_google_absl//absl/types:span",
"@com_google_benchmark//:benchmark",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:test_benchmark",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@tf_runtime//:hostcontext",
],
)
diff --git a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc
index 0434f3c..97982e7 100644
--- a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc
+++ b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc
@@ -29,6 +29,7 @@
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "benchmark/benchmark.h" // from @com_google_benchmark
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/async_handle.h"
@@ -39,7 +40,6 @@
#include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/register_span.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/value.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/test_benchmark.h"
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD
index 9da1e45..5e377f8 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/BUILD
+++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD
@@ -10,6 +10,7 @@
# copybara:uncomment "//learning/brain/tfrt:__subpackages__",
# copybara:uncomment "//learning/serving/servables/tfrt:__subpackages__",
"//tensorflow/core/tfrt/graph_executor:__subpackages__",
+ "//tensorflow/core/tfrt/ifrt:__subpackages__",
"//tensorflow/core/tfrt/saved_model:__subpackages__",
"//tensorflow/core/tfrt/tfrt_session:__subpackages__",
],
@@ -67,19 +68,16 @@
deps = [
":context",
":kernel",
- ":kernel_runner_utils",
- ":shard_restore_util",
- "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
- "//tensorflow/core/common_runtime:function",
"//tensorflow/core/framework:attr_value_proto_cc",
"//tensorflow/core/framework:types_proto_cc",
"//tensorflow/core/platform:protobuf",
- "//tensorflow/core/tfrt/fallback:op_kernel_runner",
+ "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
"//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc",
"//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_utils",
"//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+ "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
"//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry",
"//tensorflow/core/tfrt/mlrt/bytecode",
"//tensorflow/core/tfrt/mlrt/interpreter:context",
@@ -89,13 +87,10 @@
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:tstring",
"@local_xla//xla:xla_data_proto_cc",
"@local_xla//xla/python/ifrt",
- "@tf_runtime//:hostcontext",
],
alwayslink = 1,
)
@@ -178,8 +173,8 @@
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status_matchers",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@tf_runtime//:hostcontext",
"@tf_runtime//:ref_count",
],
@@ -210,9 +205,11 @@
"//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
"//tensorflow/core/tfrt/fallback:fallback_state",
"//tensorflow/core/tfrt/fallback:op_kernel_runner",
+ "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
"//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc",
"//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_registry",
"//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+ "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
"//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry",
"//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector",
"//tensorflow/core/tfrt/mlrt/bytecode",
@@ -230,7 +227,7 @@
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@eigen_archive//:eigen3",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:refcount",
"@local_tsl//tsl/platform:status",
@@ -242,6 +239,7 @@
"@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib",
"@local_xla//xla/tsl/framework:serving_device_selector",
"@local_xla//xla/tsl/framework/test_util:mock_serving_device_selector",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@tf_runtime//:hostcontext",
],
)
diff --git a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
index 9d30f31..5a58488 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
@@ -234,7 +234,9 @@
options.num_batch_threads, options.max_batch_size,
options.batch_timeout_micros, options.max_enqueued_batches,
options.allowed_batch_sizes, enable_large_batch_splitting,
- disable_padding, options.low_priority_max_batch_size,
+ disable_padding,
+ /* batch_padding_policy= */ options.batch_padding_policy,
+ options.low_priority_max_batch_size,
options.low_priority_batch_timeout_micros,
options.low_priority_max_enqueued_batches,
options.low_priority_allowed_batch_sizes,
diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc
index ca9dd22..e5c7dbd 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc
@@ -25,39 +25,31 @@
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
#include "xla/python/ifrt/future.h"
#include "xla/xla_data.pb.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
-#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep
-#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/context.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/future.h"
#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
#include "tensorflow/core/tfrt/mlrt/kernel/kernel.h"
-#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h"
-#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
#include "tsl/platform/tstring.h"
-#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
using tensorflow::ifrt_serving::IfrtModelContext;
@@ -65,14 +57,6 @@
namespace tf_mlrt {
namespace {
-int64_t GetSizeFromVarHandle(const ResourceHandle& handle) {
- int size = 0;
- for (auto& dtype_and_shape : handle.dtypes_and_shapes()) {
- size += DataTypeSize(dtype_and_shape.dtype) *
- dtype_and_shape.shape.num_elements();
- }
- return size;
-}
struct MlrtIfrtRestoreVariableKernel : mlrt::KernelFrame {
using KernelFrame::KernelFrame;
@@ -119,20 +103,8 @@
// dynamically decide it based on the size of the variables.
static constexpr int kNumRestoreClusters = 4;
- // A shard of variables to be restored.
- struct RestoreVariableShard {
- tensorflow::Tensor prefix;
- tensorflow::Tensor tensor_names;
- tensorflow::Tensor shape_and_slices;
- std::vector<tensorflow::tfrt_stub::FallbackTensor> var_handles;
- tensorflow::AttrValue dtypes_attr_value;
- std::vector<tensorflow::DataType> restored_dtypes;
- std::vector<bool> truncate_in_cast;
- };
-
absl::Status InvokeHelper();
- absl::Status RunShard(RestoreVariableShard shard);
absl::Status ValidateInput();
};
@@ -144,218 +116,6 @@
}
}
-// Returns a casted tensor if successful.
-absl::StatusOr<tensorflow::Tensor> Cast(
- tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype,
- tensorflow::DataType cast_dtype, bool truncate_in_cast,
- const tensorflow::DeviceMgr& device_manager,
- const tensorflow::ProcessFunctionLibraryRuntime&
- process_function_library_runtime,
- OpKernelContext::Params& params) {
- auto runner =
- tfrt_stub::OpKernelRunner::Create(
- /*op_name=*/
- "Cast", /*node_name=*/"Cast", params.device->name(),
- /*num_args=*/1,
- [&](tensorflow::AttrValueMap* attr_value_map) {
- tensorflow::AttrValue restored_dtype_attr_value;
- restored_dtype_attr_value.set_type(restored_dtype);
- attr_value_map->insert({"SrcT", restored_dtype_attr_value});
-
- tensorflow::AttrValue cast_dtype_attr_value;
- cast_dtype_attr_value.set_type(cast_dtype);
- attr_value_map->insert({"DstT", cast_dtype_attr_value});
-
- tensorflow::AttrValue truncate_attr_value;
- truncate_attr_value.set_b(truncate_in_cast);
- attr_value_map->insert({"Truncate", truncate_attr_value});
- return absl::OkStatus();
- },
- device_manager, process_function_library_runtime)
- .value();
-
- std::vector<tensorflow::TensorValue> input_tf_tensor_values;
- input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor));
-
- SetUpParams(runner, input_tf_tensor_values, params);
- // Use persistent device instead of the per request device.
-
- OpKernelContext op_kernel_context(&params, /*num_outputs=*/1);
-
- runner.Run(&op_kernel_context);
-
- if (!op_kernel_context.status().ok()) {
- return op_kernel_context.status();
- }
- DCHECK_EQ(op_kernel_context.num_outputs(), 1);
- return *(op_kernel_context.mutable_output(0));
-}
-
-absl::Status MlrtIfrtRestoreVariableKernel::RunShard(
- RestoreVariableShard shard) {
- std::optional<IfrtModelContext*> ifrt_model_context =
- context().resource_context().GetResource<IfrtModelContext>(
- "IfrtModelContext");
- if (!ifrt_model_context.has_value()) {
- return absl::FailedPreconditionError(
- "RestoreVariableOp: failed to fetch IfrtModelContext");
- }
- const int num_outputs = shard.var_handles.size();
- DCHECK_EQ(num_outputs, shard.tensor_names.NumElements());
- auto& fallback_request_state = context().fallback_request_state();
-
- // Use `tf.RestoreV2` to restore tensor. This will also populate
- // tensorflow::ResourceManager.
- // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the
- // variable is only used by device/IFRT.
- // TODO(b/319045348): consider directly calling restore function such as that
- // in /tensorflow/core/kernels/save_restore_v2_ops.cc
- auto runner =
- tfrt_stub::OpKernelRunner::Create(
- /*op_name=*/
- "RestoreV2", /*node_name=*/"RestoreV2",
- context().params().device->name(),
- /*num_args=*/3,
- [&](tensorflow::AttrValueMap* attr_value_map) {
- attr_value_map->insert({"dtypes", shard.dtypes_attr_value});
- return absl::OkStatus();
- },
- fallback_request_state.device_manager(),
- fallback_request_state.process_function_library_runtime())
- .value();
-
- // Prepare the input tensors.
- std::vector<tensorflow::TensorValue> input_tf_tensor_values;
- static constexpr int kNumInputArgs = 3;
- input_tf_tensor_values.resize(kNumInputArgs);
- // We need to keep these tensor alive
- input_tf_tensor_values[0].tensor = &shard.prefix;
- input_tf_tensor_values[1].tensor = &shard.tensor_names;
- input_tf_tensor_values[2].tensor = &shard.shape_and_slices;
-
- auto& params = context().params();
- SetUpParams(runner, input_tf_tensor_values, params);
- // Use persistent device instead of the per request device.
- params.device = context().fallback_request_state().device_manager().HostCPU();
-
- struct AsyncState {
- explicit AsyncState(
- const std::vector<tensorflow::TensorValue>& input_tf_tensor_values,
- const OpKernelContext::Params& params, int num_outputs,
- const tensorflow::DeviceMgr& device_manager,
- const tensorflow::ProcessFunctionLibraryRuntime&
- process_function_library_runtime)
- : run_state(input_tf_tensor_values, params),
- context(&run_state.params, num_outputs),
- device_manager(device_manager),
- process_function_library_runtime(process_function_library_runtime) {}
-
- tfrt_stub::OpKernelRunState run_state;
- OpKernelContext context;
- const tensorflow::DeviceMgr& device_manager;
- const tensorflow::ProcessFunctionLibraryRuntime&
- process_function_library_runtime;
-
- std::vector<xla::ifrt::Promise<tensorflow::Tensor>> results;
- };
- auto async_state = std::make_unique<AsyncState>(
- input_tf_tensor_values, params, num_outputs,
- fallback_request_state.device_manager(),
- fallback_request_state.process_function_library_runtime());
-
- ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry =
- (*ifrt_model_context)->GetRestoreTensorRegistry();
- for (int i = 0; i < num_outputs; ++i) {
- auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
- auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
- const ResourceHandle& var_handle =
- shard.var_handles[i].tensor().scalar<tensorflow::ResourceHandle>()();
-
- TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape,
- ifrt_serving::GetDtypeAndShape(var_handle));
-
- std::string runtime_name =
- ifrt_serving::GetRuntimeNameFromVarHandle(var_handle);
-
- ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo
- restored_tensor_info = {false, std::move(dtype_and_shape),
- std::move(future)};
- if (auto status = ifrt_restore_tensor_registry.TryRegister(
- runtime_name, restored_tensor_info);
- !status.ok()) {
- // Propagate errors so that if already-registered futures are being waited
- // on, they can be unblocked.
- for (auto& result : async_state->results) {
- std::move(result).Set(status);
- };
- return status;
- }
- async_state->results.push_back(std::move(promise));
- }
-
- // Use dedicated work queue for restore operation.
- DCHECK((*ifrt_model_context)->checkpoint_loader_queue() != nullptr);
- (*ifrt_model_context)
- ->checkpoint_loader_queue()
- ->AddTask([runner = std::move(runner),
- async_state = std::move(async_state),
- shard = std::move(shard)]() {
- // Keep input tensor alive in `shard`.
- auto* op_kernel_context_ptr = &async_state->context;
- runner.Run(op_kernel_context_ptr);
-
- auto& op_kernel_context = async_state->context;
- if (!op_kernel_context.status().ok()) {
- for (auto& result : async_state->results) {
- std::move(result).Set(op_kernel_context.status());
- }
- return;
- }
- DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs());
- DCHECK_EQ(shard.truncate_in_cast.size(),
- op_kernel_context.num_outputs());
-
- // TODO(b/343964091): consider to run multiple casts in parallel.
- for (int i = 0; i < op_kernel_context.num_outputs(); ++i) {
- DCHECK(op_kernel_context.mutable_output(i));
-
- if (op_kernel_context.mutable_output(i)->dtype() !=
- shard.restored_dtypes[i]) {
- std::move(async_state->results[i])
- .Set(absl::InvalidArgumentError(absl::StrCat(
- "The restored tensor has a different dtype than the "
- "variable handle: ",
- op_kernel_context.mutable_output(i)->dtype(), " vs. ",
- shard.restored_dtypes[i])));
- return;
- }
- const ResourceHandle& var_handle =
- shard.var_handles[i]
- .tensor()
- .scalar<tensorflow::ResourceHandle>()();
-
- if (shard.restored_dtypes[i] ==
- var_handle.dtypes_and_shapes()[0].dtype) {
- std::move(async_state->results[i])
- .Set(*std::move(op_kernel_context.mutable_output(i)));
- } else {
- absl::StatusOr<tensorflow::Tensor> cast_output = Cast(
- *op_kernel_context.mutable_output(i), shard.restored_dtypes[i],
- var_handle.dtypes_and_shapes()[0].dtype,
- shard.truncate_in_cast[i], async_state->device_manager,
- async_state->process_function_library_runtime,
- async_state->run_state.params);
- if (!cast_output.ok()) {
- std::move(async_state->results[i]).Set(cast_output.status());
- } else {
- std::move(async_state->results[i]).Set(*std::move(cast_output));
- }
- }
- }
- });
- return absl::OkStatus();
-}
-
absl::Status MlrtIfrtRestoreVariableKernel::ValidateInput() {
if (prefix().tensor().NumElements() != 1) {
return absl::InvalidArgumentError(
@@ -398,65 +158,26 @@
}
absl::Status MlrtIfrtRestoreVariableKernel::InvokeHelper() {
- TF_RETURN_IF_ERROR(ValidateInput());
-
- std::vector<int64_t> variable_sizes;
- variable_sizes.reserve(var_handles().size());
- for (auto& handle : var_handles()) {
- variable_sizes.push_back(GetSizeFromVarHandle(
- handle.tensor().scalar<tensorflow::ResourceHandle>()()));
+ std::optional<ifrt_serving::IfrtModelRestoreContext*> model_restore_context =
+ context()
+ .resource_context()
+ .GetResource<ifrt_serving::IfrtModelRestoreContext>(
+ ifrt_serving::kIfrtModelRestoreContextName);
+ if (!model_restore_context.has_value()) {
+ return absl::InternalError(
+ "Did not find IfrtModelRestoreContext resource.");
}
-
- std::vector<std::vector<int>> sharded_indices =
- ShardVariables(kNumRestoreClusters, absl::MakeSpan(variable_sizes));
-
- // Converts the names and slices back to the tensor.
- auto vector_to_tensor = [](const std::vector<tsl::tstring>& vec) {
- tensorflow::Tensor tensor(tensorflow::DT_STRING,
- TensorShape({static_cast<int>(vec.size())}));
- for (int i = 0; i < vec.size(); ++i) {
- tensor.flat<tsl::tstring>()(i) = vec[i];
- }
- return tensor;
- };
-
- const auto& tensor_names_flat = tensor_names().tensor().flat<tsl::tstring>();
- const auto& shape_and_slices_flat =
- shape_and_slices().tensor().flat<tsl::tstring>();
-
- std::vector<RestoreVariableShard> shards;
- shards.reserve(sharded_indices.size());
- for (auto& sharded_index : sharded_indices) {
- RestoreVariableShard shard;
- shard.var_handles.reserve(sharded_index.size());
- shard.truncate_in_cast.reserve(sharded_index.size());
- shard.restored_dtypes.reserve(sharded_index.size());
-
- std::vector<tsl::tstring> tensor_names;
- std::vector<tsl::tstring> shape_and_slices;
- shape_and_slices.reserve(sharded_index.size());
- tensor_names.reserve(sharded_index.size());
- for (int index : sharded_index) {
- tensor_names.push_back(tensor_names_flat(index));
- shape_and_slices.push_back(shape_and_slices_flat(index));
- shard.dtypes_attr_value.mutable_list()->add_type(
- restored_dtypes()[index]);
-
- shard.var_handles.push_back(var_handles()[index]);
- shard.restored_dtypes.push_back(restored_dtypes()[index]);
- shard.truncate_in_cast.push_back(truncate_in_cast()[index]);
- }
-
- shard.prefix = prefix().tensor();
- shard.tensor_names = vector_to_tensor(tensor_names);
- shard.shape_and_slices = vector_to_tensor(shape_and_slices);
- shards.push_back(std::move(shard));
+ if (*model_restore_context == nullptr) {
+ return absl::InternalError("IfrtModelRestoreContext must not be null.");
}
-
- for (const auto& shard : shards) {
- TF_RETURN_IF_ERROR(RunShard(shard));
+ ifrt_serving::CheckpointLoader* checkpoint_loader =
+ (*model_restore_context)->checkpoint_loader();
+ if (!checkpoint_loader) {
+ return absl::InternalError("CheckpointLoader must not be null.");
}
- return absl::OkStatus();
+ return checkpoint_loader->Load(prefix(), var_handles(), tensor_names(),
+ shape_and_slices(), restored_dtypes(),
+ truncate_in_cast(), context());
}
class MlrtIfrtLoadVariableKernel : public mlrt::KernelFrame {
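For orientation, here is a minimal sketch of how the restore path is wired after this refactor, assembled only from the call sites visible in this diff; it is not a compilable unit on its own, and anything not shown at those call sites is an assumption:

// At model load time, a CheckpointLoader is registered once per model under
// kIfrtModelRestoreContextName (see the test setup and saved_model.cc changes
// further down in this diff).
resource_context.CreateResource<ifrt_serving::IfrtModelRestoreContext>(
    ifrt_serving::kIfrtModelRestoreContextName,
    std::make_unique<ifrt_serving::CheckpointLoader>(
        &ifrt_model_context->GetRestoreTensorRegistry(),
        ifrt_model_context->checkpoint_loader_queue()));

// At kernel execution time, the restore kernel no longer shards the variables
// and runs RestoreV2 itself; it fetches the registered loader and delegates:
return checkpoint_loader->Load(prefix(), var_handles(), tensor_names(),
                               shape_and_slices(), restored_dtypes(),
                               truncate_in_cast(), context());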
diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc
index 83b5876..07fb83b 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc
@@ -33,6 +33,7 @@
#include "xla/python/ifrt/future.h"
#include "xla/python/ifrt/test_util.h"
#include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_matcher.h"
@@ -43,8 +44,10 @@
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
@@ -57,7 +60,6 @@
#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
#include "tensorflow/core/tfrt/mlrt/kernel/kernel.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/refcount.h"
#include "tsl/platform/status.h"
@@ -403,6 +405,13 @@
.value();
ifrt_model_context_->set_checkpoint_loader_queue(restore_work_queue_.get());
+ resource_context_
+ .CreateResource<tensorflow::ifrt_serving::IfrtModelRestoreContext>(
+ ifrt_serving::kIfrtModelRestoreContextName,
+ std::make_unique<tensorflow::ifrt_serving::CheckpointLoader>(
+ &ifrt_model_context_->GetRestoreTensorRegistry(),
+ ifrt_model_context_->checkpoint_loader_queue()));
+
serving_device_selector_ =
std::make_unique<tsl::test_util::MockServingDeviceSelector>();
ifrt_core_selector_ =
diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
index f996695..a9aa89a 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
@@ -29,6 +29,7 @@
#include "absl/strings/substitute.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/tfrt/fallback/device_with_custom_allocator.h"
@@ -39,7 +40,6 @@
#include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h"
#include "tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h"
#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
#include "tfrt/concurrency/ref_count.h" // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
diff --git a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc
index cd3f49f..16293c2 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc
@@ -66,7 +66,7 @@
};
std::priority_queue<RestoreVariableCluster,
std::vector<RestoreVariableCluster>, decltype(cmp)>
- min_heap;
+ min_heap(cmp);
for (int i = 0; i < num_shards; ++i) {
min_heap.push(RestoreVariableCluster());
}
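The change above matters because `decltype(cmp)` names a lambda closure type: passing the comparator instance to the constructor keeps the declaration well-formed even when that closure type cannot be default-constructed (it never can once the lambda captures state, and under C++17 not even without captures). A self-contained sketch of the pattern, independent of the TensorFlow sources:

#include <cstdio>
#include <queue>
#include <vector>

int main() {
  // A greater-than comparator turns std::priority_queue into a min-heap.
  auto cmp = [](int a, int b) { return a > b; };
  // Pass `cmp` to the constructor instead of relying on default construction
  // of decltype(cmp), mirroring the min_heap(cmp) change above.
  std::priority_queue<int, std::vector<int>, decltype(cmp)> min_heap(cmp);
  for (int v : {5, 1, 3}) min_heap.push(v);
  while (!min_heap.empty()) {
    std::printf("%d ", min_heap.top());  // Prints: 1 3 5
    min_heap.pop();
  }
  std::printf("\n");
  return 0;
}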
diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD
index 9e85c14..5261546 100644
--- a/tensorflow/core/tfrt/saved_model/BUILD
+++ b/tensorflow/core/tfrt/saved_model/BUILD
@@ -135,6 +135,8 @@
"//tensorflow/core/tfrt/graph_executor",
"//tensorflow/core/tfrt/graph_executor:export_mlir",
"//tensorflow/core/tfrt/graph_executor:graph_execution_options",
+ "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
+ "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
"//tensorflow/core/tfrt/mlrt/bytecode",
"//tensorflow/core/tfrt/mlrt/bytecode:executable",
"//tensorflow/core/tfrt/mlrt/interpreter:context",
diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc
index 62ad655..84fbbff 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model.cc
@@ -16,13 +16,11 @@
#include <algorithm>
#include <cstddef>
-#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
-#include <unordered_set>
#include <utility>
#include <vector>
@@ -70,6 +68,8 @@
#include "tensorflow/core/tfrt/graph_executor/export_mlir.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/context.h"
@@ -134,6 +134,34 @@
"/tensorflow/tfrt/saved_model/input_spec_validation_failure",
"Record the models that failed input spec validation.", "model_name");
+absl::Status PrepareRestore(mlir::MLIRContext* context,
+ ModelRuntimeContext* model_runtime_context,
+ const tensorflow::MetaGraphDef& meta_graph_def,
+ FallbackState& fallback_state,
+ const std::string& saved_model_dir,
+ const SavedModel::Options& options,
+ ifrt_serving::CheckpointLoader* checkpoint_loader) {
+ // Import the global MLIR with `import_user_signatures` set to true so that
+ // we can analyze the global MLIR to retrieve the data needed for restore.
+ mlir::OwningOpRef<mlir::ModuleOp> mlir_module_restore_analysis;
+ ASSIGN_OR_RETURN_IN_IMPORT(
+ mlir_module_restore_analysis,
+ ImportSavedModel(
+ context, meta_graph_def, fallback_state, saved_model_dir,
+ /*import_user_signatures=*/true,
+ options.graph_execution_options.run_placer_grappler_on_functions));
+
+ if (!checkpoint_loader) {
+ return absl::InternalError("Missing checkpoint loader.");
+ }
+
+ TF_RETURN_IF_ERROR(checkpoint_loader->PrepareRestore(
+ std::move(mlir_module_restore_analysis)));
+
+ LOG(INFO) << "Complete set restore metadata.";
+ return absl::OkStatus();
+}
+
tensorflow::Status RunBytecodeInitializers(
const GraphExecutionOptions& options,
const InitializersAndSignatures& initializers_and_signatures,
@@ -596,6 +624,25 @@
model_context.set_callable_options(nullptr);
}
+ if (options.graph_execution_options.use_ifrt) {
+ std::optional<ifrt_serving::IfrtModelRestoreContext*>
+ model_restore_context =
+ model_context.resource_context()
+ .GetResource<ifrt_serving::IfrtModelRestoreContext>(
+ ifrt_serving::kIfrtModelRestoreContextName);
+ if (!model_restore_context.has_value()) {
+ return absl::InternalError(
+ "Did not find IfrtModelRestoreContext resource.");
+ }
+ if (*model_restore_context == nullptr) {
+ return absl::InternalError("IfrtModelRestoreContexts must not be null.");
+ }
+ TF_RETURN_IF_ERROR(
+ PrepareRestore(&context, &model_context, meta_graph_def,
+ *fallback_state, std::string(saved_model_dir), options,
+ (*model_restore_context)->checkpoint_loader()));
+ }
+
GetDefaultInputValue(meta_graph_def.signature_def(), model_context,
initializers_and_signatures.signature_map);
diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD
index 3dfc07d..c026800 100644
--- a/tensorflow/core/tfrt/saved_model/tests/BUILD
+++ b/tensorflow/core/tfrt/saved_model/tests/BUILD
@@ -649,7 +649,9 @@
"//tensorflow/core/platform:resource_loader",
"//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink",
"//tensorflow/core/tfrt:ifrt_program_ops_op_lib",
+ "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
"//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+ "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
"//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector",
"//tensorflow/core/tfrt/mlrt/kernel:ifrt_ops_kernel",
"//tensorflow/core/tfrt/runtime",
diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc
index 8eaf225..4f4caf0 100644
--- a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc
+++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc
@@ -26,14 +26,16 @@
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/test_util.h"
#include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/resource_loader.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool.h"
@@ -77,10 +79,17 @@
"IfrtModelContext", client, &core_selector, &GetThreadPool(),
/*compilation_environment_proto=*/nullptr);
- (*model_context.resource_context()
- .GetResource<tensorflow::ifrt_serving::IfrtModelContext>(
- "IfrtModelContext"))
- ->set_checkpoint_loader_queue(work_queue.get());
+ tensorflow::ifrt_serving::IfrtModelContext* ifrt_model_context =
+ (*model_context.resource_context()
+ .GetResource<tensorflow::ifrt_serving::IfrtModelContext>(
+ "IfrtModelContext"));
+ ifrt_model_context->set_checkpoint_loader_queue(work_queue.get());
+ model_context.resource_context()
+ .CreateResource<tensorflow::ifrt_serving::IfrtModelRestoreContext>(
+ ifrt_serving::kIfrtModelRestoreContextName,
+ std::make_unique<tensorflow::ifrt_serving::CheckpointLoader>(
+ &ifrt_model_context->GetRestoreTensorRegistry(),
+ ifrt_model_context->checkpoint_loader_queue()));
return absl::OkStatus();
});
diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc
index 183b418..605e441 100644
--- a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc
+++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc
@@ -34,6 +34,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -45,7 +46,6 @@
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_util.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
diff --git a/tensorflow/core/tfrt/saved_model/utils/BUILD b/tensorflow/core/tfrt/saved_model/utils/BUILD
index b76aa6f..008fea2 100644
--- a/tensorflow/core/tfrt/saved_model/utils/BUILD
+++ b/tensorflow/core/tfrt/saved_model/utils/BUILD
@@ -60,8 +60,8 @@
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@tf_runtime//:bef",
],
)
diff --git a/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
index deaf171..2fc9c28 100644
--- a/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
+++ b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
@@ -26,6 +26,7 @@
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
@@ -33,7 +34,6 @@
#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_util.h"
#include "tensorflow/core/tfrt/utils/utils.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tfrt/bef/bef_buffer.h" // from @tf_runtime
diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc
index 2b63268..b63fc76 100644
--- a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc
+++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc
@@ -29,6 +29,7 @@
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/saved_model/reader.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -44,7 +45,6 @@
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
#include "tensorflow/core/tfrt/utils/thread_pool.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/protobuf.h"
namespace tensorflow {
diff --git a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc
index 330adde..f5af931 100644
--- a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc
+++ b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc
@@ -34,6 +34,7 @@
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/saved_model/reader.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
@@ -45,7 +46,6 @@
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/path.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
index 498d07a..c18230e 100644
--- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
+++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
@@ -250,9 +250,9 @@
DumpGraphDefToFile("before_pruning", graph_def);
}
- TF_ASSIGN_OR_RETURN(
- result.graph,
- CreatePrunedGraph(graph_def, build_graph_options.callable_options));
+ TF_ASSIGN_OR_RETURN(result.graph,
+ CreatePrunedGraph(std::move(graph_def),
+ build_graph_options.callable_options));
DCHECK(result.graph);
if (VLOG_IS_ON(1)) {
diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD
index 273c822..73fbacd 100644
--- a/tensorflow/core/tpu/graph_rewrite/BUILD
+++ b/tensorflow/core/tpu/graph_rewrite/BUILD
@@ -1,6 +1,10 @@
# Contains graph rewrites for TPU runtimes and optimizations.
load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+load(
"//tensorflow/core/platform:build_config_root.bzl",
"if_static",
)
@@ -119,6 +123,7 @@
"//tensorflow/core:session_options",
"//tensorflow/core/common_runtime:function_body",
"//tensorflow/core/common_runtime:function_utils",
+ "//tensorflow/core/config:flag_defs",
"//tensorflow/core/tpu:tpu_compile_interface",
"//tensorflow/core/tpu:tpu_defs",
"@com_google_absl//absl/container:flat_hash_map",
@@ -131,6 +136,7 @@
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:status",
"@local_xla//xla:status_macros",
+ "@local_xla//xla/tsl/util:env_var",
] + if_static(
[
"//tensorflow/core/common_runtime:function",
@@ -140,6 +146,26 @@
),
)
+tf_cc_test(
+ name = "encapsulate_tpu_computations_pass_test",
+ srcs = ["encapsulate_tpu_computations_pass_test.cc"],
+ deps = [
+ ":encapsulate_tpu_computations_pass",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/common_runtime:optimization_registry",
+ "//tensorflow/core/config:flag_defs",
+ ],
+)
+
cc_library(
name = "distributed_tpu_rewrite_pass_internal",
srcs = ["distributed_tpu_rewrite_pass_internal.cc"],
diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
index 62a4c45..9370cd6 100644
--- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
@@ -48,6 +48,7 @@
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/config/flag_defs.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
@@ -2481,10 +2482,32 @@
return absl::OkStatus();
}
+// TODO(b/355263902): Encapsulation fails for some non-TPU graphs that are
+// missing full variable shape information. Remove this path once the
+// underlying issue is fixed.
+bool ShouldSkipEncapsulationForNonTPUGraph() {
+ return flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.value();
+}
+
} // namespace
/*static*/ Status EncapsulateTPUComputationsPass::Encapsulate(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ // If the graph does not contain any TPU computations, there is nothing to do.
+ if (ShouldSkipEncapsulationForNonTPUGraph()) {
+ bool found_tpu_replicate = false;
+ for (const Node* n : (*graph)->nodes()) {
+ if (n->attrs().Find(kTPUReplicateAttr) != nullptr) {
+ found_tpu_replicate = true;
+ break;
+ }
+ }
+ if (!found_tpu_replicate) {
+ VLOG(1) << "No TPU replicate found, skipping encapsulation";
+ return absl::OkStatus();
+ }
+ }
+
// Check for undeclared outputs before Encapsulation, so we can give a better
// error message.
// TODO(phawkins): merge this with the encapsulation code to avoid the extra
diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc
new file mode 100644
index 0000000..a21cdae
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
+
+#include <memory>
+
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/config/flag_defs.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+std::unique_ptr<Graph> CreateGraph() {
+ // c = a + b
+ auto g = std::make_unique<Graph>(OpRegistry::Global());
+ auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT);
+ auto tmp = test::graph::Add(g.get(), in0, in1);
+ auto ret = test::graph::Retval(g.get(), 0, tmp);
+ g->AddControlEdge(in1, ret);
+ FixupSourceAndSinkEdges(g.get());
+ return g;
+}
+
+TEST(EncapsulateTPUComputationsPassTest, NonTPUGraph) {
+ auto g = CreateGraph();
+ GraphOptimizationPassOptions options;
+ options.graph = &g;
+ options.flib_def = g->mutable_flib_def();
+
+ EncapsulateTPUComputationsPass pass;
+ TF_ASSERT_OK(pass.Run(options));
+
+ int nodes_meeting_expectations = 0;
+
+ for (const auto* node : g->nodes()) {
+ if (!IsSource(node) && !IsSink(node)) {
+ ASSERT_TRUE(node->attrs().Find("_xla_inferred_shapes"));
+ ++nodes_meeting_expectations;
+ }
+ }
+ EXPECT_EQ(nodes_meeting_expectations, 4);
+}
+
+TEST(EncapsulateTPUComputationsPassTest, SkipEncapsulationForNonTPUGraph) {
+ flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(true);
+ auto g = CreateGraph();
+ GraphOptimizationPassOptions options;
+ options.graph = &g;
+ options.flib_def = g->mutable_flib_def();
+
+ EncapsulateTPUComputationsPass pass;
+ TF_ASSERT_OK(pass.Run(options));
+
+ int nodes_meeting_expectations = 0;
+
+ for (const auto* node : g->nodes()) {
+ if (!IsSource(node) && !IsSink(node)) {
+ ASSERT_FALSE(node->attrs().Find("_xla_inferred_shapes"));
+ ++nodes_meeting_expectations;
+ }
+ }
+ EXPECT_EQ(nodes_meeting_expectations, 4);
+
+ flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(false);
+}
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc
index e62e9d7..084802a 100644
--- a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc
+++ b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc
@@ -25,6 +25,7 @@
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
@@ -39,7 +40,6 @@
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/core/tpu/kernels/sharding_utils_test.cc b/tensorflow/core/tpu/kernels/sharding_utils_test.cc
index cd583df..552a637 100644
--- a/tensorflow/core/tpu/kernels/sharding_utils_test.cc
+++ b/tensorflow/core/tpu/kernels/sharding_utils_test.cc
@@ -26,11 +26,11 @@
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/status.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.cc b/tensorflow/core/tpu/kernels/sparse_core_layout.cc
index 2f4a945b..bc4c416 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout.cc
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout.cc
@@ -72,10 +72,10 @@
activation_mem_bytes_limit_(GetXlaSparseCoreStackingMemLimit()),
variable_shard_bytes_limit_(GetXlaSparseCoreStackingTableShardLimit()) {}
-absl::Status SparseCoreLayoutStacker::AddTable(tsl::StringPiece table_name,
+absl::Status SparseCoreLayoutStacker::AddTable(absl::string_view table_name,
int64_t table_height,
int64_t table_width,
- tsl::StringPiece group,
+ absl::string_view group,
int64_t output_samples) {
if (stacks_by_group_.empty()) { // First call?
VLOG(1) << "Stacking parameters: stacking_enabled_ = " << stacking_enabled_
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.h b/tensorflow/core/tpu/kernels/sparse_core_layout.h
index c1d22f3..9f4697c 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout.h
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout.h
@@ -84,8 +84,8 @@
//
// Be sure you call AddTable in a deterministic order; the details of the
// stacking will depend on the order you call AddTable.
- absl::Status AddTable(tsl::StringPiece table_name, int64_t table_height,
- int64_t table_width, tsl::StringPiece group,
+ absl::Status AddTable(absl::string_view table_name, int64_t table_height,
+ int64_t table_width, absl::string_view group,
int64_t output_samples);
// Get the information about each table out.
diff --git a/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc b/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc
index ca2b1d5..d854eab 100644
--- a/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc
+++ b/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc
@@ -21,11 +21,11 @@
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/protobuf.h" // IWYU pragma: keep
#include "tsl/platform/test.h"
diff --git a/tensorflow/core/tpu/tpu_embedding_errors_test.cc b/tensorflow/core/tpu/tpu_embedding_errors_test.cc
index 3dbb182..f0a8d86 100644
--- a/tensorflow/core/tpu/tpu_embedding_errors_test.cc
+++ b/tensorflow/core/tpu/tpu_embedding_errors_test.cc
@@ -21,8 +21,8 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD
index 8b89487..990edbe 100644
--- a/tensorflow/core/util/autotune_maps/BUILD
+++ b/tensorflow/core/util/autotune_maps/BUILD
@@ -52,8 +52,8 @@
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/protobuf:dnn_proto_cc",
+ "@local_xla//xla/tsl/lib/strings:proto_serialization",
],
)
@@ -118,7 +118,7 @@
"conv_parameters.h",
],
cuda_deps = [
- "@local_tsl//tsl/lib/strings:proto_serialization",
+ "@local_xla//xla/tsl/lib/strings:proto_serialization",
],
deps = [
":conv_parameters_proto_cc",
@@ -182,12 +182,12 @@
"//tensorflow/core:framework",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:str_util",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/protobuf:dnn_proto_cc",
"@local_xla//xla:status_macros",
"@local_xla//xla/stream_executor:dnn",
"@local_xla//xla/stream_executor:platform_manager",
"@local_xla//xla/stream_executor/gpu:gpu_init",
+ "@local_xla//xla/tsl/lib/strings:proto_serialization",
],
)
diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.cc b/tensorflow/core/util/autotune_maps/autotune_serialize.cc
index 63470c0..c601502 100644
--- a/tensorflow/core/util/autotune_maps/autotune_serialize.cc
+++ b/tensorflow/core/util/autotune_maps/autotune_serialize.cc
@@ -25,13 +25,13 @@
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/gpu/gpu_init.h"
#include "xla/stream_executor/platform_manager.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/util/activation_mode.h"
#include "tensorflow/core/util/autotune_maps/autotune_map.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/protobuf/dnn.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc
index baa68aa..0bd1122 100644
--- a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc
+++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc
@@ -19,9 +19,9 @@
#include "absl/log/check.h"
#include "absl/status/status.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "tensorflow/core/util/autotune_maps/autotune_map.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/protobuf/dnn.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/util/autotune_maps/conv_parameters.cc b/tensorflow/core/util/autotune_maps/conv_parameters.cc
index 6343693..a620e39 100644
--- a/tensorflow/core/util/autotune_maps/conv_parameters.cc
+++ b/tensorflow/core/util/autotune_maps/conv_parameters.cc
@@ -19,9 +19,9 @@
#include <vector>
#include "absl/strings/str_format.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
namespace tensorflow {
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
index 5c8a5db..61d1fb5 100644
--- a/tensorflow/core/util/bcast.h
+++ b/tensorflow/core/util/bcast.h
@@ -199,7 +199,6 @@
prev_is_one[i] = false;
current_is_one[i] = false;
}
- Vec output;
bool output_dim_set = false;
int64_t output_dim = -1;
bool none_is_one = true;
diff --git a/tensorflow/core/util/dump_graph_test.cc b/tensorflow/core/util/dump_graph_test.cc
index d24eccf..935ca41 100644
--- a/tensorflow/core/util/dump_graph_test.cc
+++ b/tensorflow/core/util/dump_graph_test.cc
@@ -16,6 +16,7 @@
#include "tensorflow/core/util/dump_graph.h"
#include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -23,7 +24,6 @@
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
namespace tensorflow {
diff --git a/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc b/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc
index fe27a6c..ad28eeb 100644
--- a/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc
+++ b/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc
@@ -16,9 +16,9 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/core/util/strided_slice_op_test.cc b/tensorflow/core/util/strided_slice_op_test.cc
index cbe0976..6eb961c 100644
--- a/tensorflow/core/util/strided_slice_op_test.cc
+++ b/tensorflow/core/util/strided_slice_op_test.cc
@@ -22,12 +22,12 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc
index 6e9c20d..05f5d0f 100644
--- a/tensorflow/core/util/util.cc
+++ b/tensorflow/core/util/util.cc
@@ -151,10 +151,12 @@
} else if (dt == DT_HALF) {
// Float16 is not supported in oneDNN v2.x
#ifdef ENABLE_ONEDNN_V3
- result = (TestCPUFeature(port::CPUFeature::AVX512BW) &&
- (TestCPUFeature(port::CPUFeature::AVX512_FP16) ||
- TestCPUFeature(port::CPUFeature::AMX_FP16) ||
- TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT)));
+ // Some CPUs that don't support AVX-512 use AVX-NE-CONVERT to cast to and
+ // from FP32.
+ result = ((TestCPUFeature(port::CPUFeature::AVX512BW) &&
+ (TestCPUFeature(port::CPUFeature::AVX512_FP16) ||
+ TestCPUFeature(port::CPUFeature::AMX_FP16))) ||
+ TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT));
if (result) VLOG(2) << "CPU supports " << DataType_Name(dt);
#endif // ENABLE_ONEDNN_V3
} else {
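The regrouping above changes which CPUs qualify for oneDNN FP16: AVX-NE-CONVERT is now sufficient on its own, whereas previously AVX512BW was required unconditionally. A standalone sketch of the new predicate (the boolean parameters are hypothetical stand-ins for the TestCPUFeature results):

#include <cstdio>

// Mirrors the grouping introduced above: either the AVX-512 path
// (AVX512BW plus AVX512-FP16 or AMX-FP16), or AVX-NE-CONVERT alone.
bool SupportsOneDnnFp16(bool avx512bw, bool avx512_fp16, bool amx_fp16,
                        bool avx_ne_convert) {
  return (avx512bw && (avx512_fp16 || amx_fp16)) || avx_ne_convert;
}

int main() {
  // An AVX-NE-CONVERT-only CPU now qualifies; under the old grouping it did
  // not, because AVX512BW was required in every case.
  std::printf("%d\n", SupportsOneDnnFp16(false, false, false, true));  // 1
  return 0;
}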
diff --git a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc
index 802d46f..2e24e5d 100644
--- a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc
+++ b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc
@@ -165,7 +165,7 @@
mesh_name = "";
}
const std::vector<int>& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name];
- VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", ");
+ VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", ");
xla::DeviceAssignmentProto device_assignment;
device_assignment.set_replica_count(1);
@@ -223,7 +223,7 @@
mesh_name = "";
}
const std::vector<int>& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name];
- VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", ");
+ VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", ");
xla::DeviceAssignmentProto device_assignment;
device_assignment.set_replica_count(num_replicas);
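The two VLOG changes above swap the TensorFlow-internal `str_util::Join` for `absl::StrJoin`, which joins any iterable of stringifiable elements with a separator. A minimal standalone usage example:

#include <cstdio>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"

int main() {
  std::vector<int> tpu_core_ids = {0, 1, 2, 3};
  const std::string joined = absl::StrJoin(tpu_core_ids, ", ");
  std::printf("tpu_core_ids: %s\n", joined.c_str());  // tpu_core_ids: 0, 1, 2, 3
  return 0;
}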
diff --git a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc
index ff29775..475e08c 100644
--- a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc
+++ b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc
@@ -22,11 +22,11 @@
#include "absl/strings/str_cat.h"
#include "benchmark/benchmark.h" // from @com_google_benchmark
#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
namespace tensorflow {
diff --git a/tensorflow/examples/speech_commands/accuracy_utils_test.cc b/tensorflow/examples/speech_commands/accuracy_utils_test.cc
index cf4f5ba..7edd1b4 100644
--- a/tensorflow/examples/speech_commands/accuracy_utils_test.cc
+++ b/tensorflow/examples/speech_commands/accuracy_utils_test.cc
@@ -15,11 +15,11 @@
#include "tensorflow/examples/speech_commands/accuracy_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/examples/speech_commands/recognize_commands_test.cc b/tensorflow/examples/speech_commands/recognize_commands_test.cc
index 1730d06..1f13e24 100644
--- a/tensorflow/examples/speech_commands/recognize_commands_test.cc
+++ b/tensorflow/examples/speech_commands/recognize_commands_test.cc
@@ -15,12 +15,12 @@
#include "tensorflow/examples/speech_commands/recognize_commands.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 4a1d5bd..eb41fb0 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -830,7 +830,6 @@
deps = [
":minimal_logging",
"//tensorflow/lite/core/api:error_reporter",
- "//tensorflow/lite/core/c:common",
],
)
diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt
index eb93434..6c288ad 100644
--- a/tensorflow/lite/CMakeLists.txt
+++ b/tensorflow/lite/CMakeLists.txt
@@ -373,7 +373,9 @@
list(APPEND TFLITE_DELEGATES_GPU_SRCS
${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc
${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc
+ ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.h
${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.cc
+ ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.h
${TFLITE_SOURCE_DIR}/delegates/gpu/tflite_profile.cc
${TFLITE_SOURCE_DIR}/experimental/acceleration/compatibility/android_info.cc
${TFLITE_DELEGATES_GPU_CL_SRCS}
@@ -681,6 +683,10 @@
${TF_SOURCE_DIR}/compiler/mlir/lite/utils/string_utils.h
${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.h
${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.cc
+ ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.h
+ ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.cc
+ ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.h
+ ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.cc
${TFLITE_SOURCE_DIR}/schema/schema_generated.h
)
add_library(tensorflow-lite
diff --git a/tensorflow/lite/allocation.cc b/tensorflow/lite/allocation.cc
index b187ef0..bbc41fa 100644
--- a/tensorflow/lite/allocation.cc
+++ b/tensorflow/lite/allocation.cc
@@ -21,6 +21,8 @@
#include <cstdint>
#include <cstdio>
+#include <cstdlib>
+#include <cstring>
#include <memory>
#include "tensorflow/lite/core/api/error_reporter.h"
@@ -100,11 +102,37 @@
}
#endif // __arm__
+// `android_local_test` doesn't support zipalign (b/356640509), so we need
+// this workaround to keep our clients working.
+// TODO: b/356413060 - Remove the workaround once b/356640509 is fixed.
+#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+ if ((reinterpret_cast<uintptr_t>(ptr) & 0x3) != 0) {
+ aligned_ptr_ = ::aligned_alloc(4, num_bytes);
+ if (aligned_ptr_ == nullptr) {
+ TF_LITE_REPORT_ERROR(error_reporter, "Failed to allocate aligned buffer");
+ buffer_ = nullptr;
+ buffer_size_bytes_ = 0;
+ return;
+ }
+ memcpy(aligned_ptr_, ptr, num_bytes);
+ buffer_ = aligned_ptr_;
+ } else {
+ buffer_ = ptr;
+ }
+#else // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
buffer_ = ptr;
+#endif // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+
buffer_size_bytes_ = num_bytes;
}
-MemoryAllocation::~MemoryAllocation() {}
+MemoryAllocation::~MemoryAllocation() {
+#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+ if (aligned_ptr_) {
+ free(aligned_ptr_);
+ }
+#endif
+}
const void* MemoryAllocation::base() const { return buffer_; }
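The workaround above copies a misaligned buffer into a 4-byte-aligned allocation so that UBSan-instrumented x86-64 builds do not trip on unaligned access. A self-contained sketch of the same check-then-copy pattern (not the TFLite code itself; note that C11 aligned_alloc expects the size to be a multiple of the alignment, which glibc tolerates omitting but portable code should round up):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Returns `src` if it is already 4-byte aligned; otherwise returns an aligned
// copy and passes ownership back through `owned` so the caller can free it.
const void* AlignTo4IfNeeded(const void* src, std::size_t num_bytes,
                             void** owned) {
  *owned = nullptr;
  if ((reinterpret_cast<std::uintptr_t>(src) & 0x3) == 0) return src;
  const std::size_t rounded = (num_bytes + 3) & ~static_cast<std::size_t>(3);
  void* buf = std::aligned_alloc(4, rounded);
  if (buf == nullptr) return nullptr;
  std::memcpy(buf, src, num_bytes);
  *owned = buf;
  return buf;
}

int main() {
  alignas(8) unsigned char raw[16] = {1, 2, 3, 4, 5, 6};
  void* owned = nullptr;
  // Deliberately misaligned view into the buffer.
  const void* aligned = AlignTo4IfNeeded(raw + 1, 5, &owned);
  std::printf("4-byte aligned: %d\n",
              static_cast<int>(
                  (reinterpret_cast<std::uintptr_t>(aligned) & 0x3) == 0));
  std::free(owned);
  return 0;
}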
diff --git a/tensorflow/lite/allocation.h b/tensorflow/lite/allocation.h
index 6840646..f007b3c 100644
--- a/tensorflow/lite/allocation.h
+++ b/tensorflow/lite/allocation.h
@@ -145,6 +145,9 @@
private:
const void* buffer_;
+#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+ void* aligned_ptr_ = nullptr;
+#endif
size_t buffer_size_bytes_ = 0;
};
diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD
index d18efae..4309e28 100644
--- a/tensorflow/lite/core/BUILD
+++ b/tensorflow/lite/core/BUILD
@@ -43,13 +43,14 @@
],
compatible_with = get_compatible_with_portable(),
copts = tflite_copts() + tflite_copts_warnings(),
- visibility = [
- "//tensorflow/lite:__subpackages__",
- ],
+ visibility = ["//tensorflow/lite:__subpackages__"],
deps = [
":cc_api_stable",
":signature_runner",
+ "//tensorflow/compiler/mlir/lite/core:macros",
+ "//tensorflow/compiler/mlir/lite/core:model_builder_base",
"//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
+ "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/lite:allocation",
"//tensorflow/lite:array",
"//tensorflow/lite:external_cpu_backend_context",
@@ -120,6 +121,7 @@
":cc_api_stable",
":model_builder",
":signature_runner",
+ "//tensorflow/compiler/mlir/lite/core:model_builder_base",
"//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
"//tensorflow/lite:allocation",
"//tensorflow/lite:array",
@@ -135,7 +137,6 @@
"//tensorflow/lite:util",
"//tensorflow/lite/c:common_internal",
"//tensorflow/lite/core/api",
- "//tensorflow/lite/core/api:verifier",
"//tensorflow/lite/core/async:async_signature_runner",
"//tensorflow/lite/core/c:common",
"//tensorflow/lite/experimental/resource",
@@ -174,7 +175,10 @@
":model_builder",
":signature_runner",
":subgraph",
+ "//tensorflow/compiler/mlir/lite/core:model_builder_base",
"//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
+ "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
+ "//tensorflow/compiler/mlir/lite/schema:schema_utils",
"//tensorflow/lite:allocation",
"//tensorflow/lite:array",
"//tensorflow/lite:external_cpu_backend_context",
@@ -253,6 +257,7 @@
deps = [
":cc_api_stable",
":signature_runner",
+ "//tensorflow/compiler/mlir/lite/core:model_builder_base",
"//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
"//tensorflow/lite:allocation",
"//tensorflow/lite:array",
@@ -268,7 +273,6 @@
"//tensorflow/lite:util",
"//tensorflow/lite/c:common_internal",
"//tensorflow/lite/core/api",
- "//tensorflow/lite/core/api:verifier",
"//tensorflow/lite/core/async:async_signature_runner",
"//tensorflow/lite/core/c:c_api_types",
"//tensorflow/lite/core/c:common",
@@ -286,22 +290,14 @@
cc_library(
name = "model_builder",
- srcs = ["model_builder.cc"],
hdrs = ["model_builder.h"],
compatible_with = get_compatible_with_portable(),
copts = tflite_copts_warnings(),
visibility = internal_visibility_allowlist(),
deps = [
- ":macros",
- "//tensorflow/compiler/mlir/lite/core:macros",
- "//tensorflow/lite:allocation",
+ "//tensorflow/compiler/mlir/lite/core:model_builder_base",
"//tensorflow/lite:stderr_reporter",
- "//tensorflow/lite:string",
"//tensorflow/lite/core/api:error_reporter",
- "//tensorflow/lite/core/api:verifier",
- "//tensorflow/lite/schema:schema_fbs",
- "@com_google_absl//absl/strings",
- "@flatbuffers",
],
alwayslink = 1,
)
@@ -336,6 +332,7 @@
deps = [
":framework",
":signature_runner",
+ "//tensorflow/compiler/mlir/lite/core:model_builder_base",
"//tensorflow/lite:model_builder",
"//tensorflow/lite/core/kernels:builtin_ops",
"//tensorflow/lite/testing:util",
@@ -366,6 +363,7 @@
],
tags = [
"no_windows", # TODO(b/194459105): the test is flaky.
+ "noasan",
"tflite_not_portable",
"tflite_smoke_test",
],
diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD
index 08ac033..ac1bb24 100644
--- a/tensorflow/lite/core/api/BUILD
+++ b/tensorflow/lite/core/api/BUILD
@@ -11,7 +11,6 @@
filegroup(
name = "tflite_internal_cc_3p_api_deps_src",
srcs = [
- ":error_reporter.cc",
":error_reporter.h",
":op_resolver.cc",
":op_resolver.h",
@@ -77,23 +76,33 @@
cc_library(
name = "error_reporter",
- srcs = ["error_reporter.cc"],
- hdrs = ["error_reporter.h"],
+ hdrs = [
+ "error_reporter.h",
+ "//tensorflow/compiler/mlir/lite/core/api:error_reporter.h",
+ ],
compatible_with = get_compatible_with_portable(),
copts = tflite_copts(),
visibility = [
"//visibility:public",
],
- deps = [],
+ deps = [
+ "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+ ],
)
cc_library(
name = "verifier",
- hdrs = ["verifier.h"],
+ hdrs = [
+ "verifier.h",
+ "//tensorflow/compiler/mlir/lite/core/api:verifier.h",
+ ],
compatible_with = get_compatible_with_portable(),
copts = tflite_copts(),
visibility = ["//visibility:public"],
- deps = [":error_reporter"],
+ deps = [
+ "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+ "//tensorflow/compiler/mlir/lite/core/api:verifier",
+ ],
)
cc_library(
@@ -109,16 +118,6 @@
)
cc_test(
- name = "error_reporter_test",
- size = "small",
- srcs = ["error_reporter_test.cc"],
- deps = [
- ":api",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_test(
name = "op_resolver_test",
size = "small",
srcs = ["op_resolver_test.cc"],
diff --git a/tensorflow/lite/core/api/error_reporter.cc b/tensorflow/lite/core/api/error_reporter.cc
deleted file mode 100644
index 7070eaa..0000000
--- a/tensorflow/lite/core/api/error_reporter.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/lite/core/api/error_reporter.h"
-#include <cstdarg>
-
-namespace tflite {
-
-int ErrorReporter::Report(const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
-// TODO(aselle): Make the name of ReportError on context the same, so
-// we can use the ensure functions w/o a context and w/ a reporter.
-int ErrorReporter::ReportError(void*, const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
-} // namespace tflite
diff --git a/tensorflow/lite/core/api/error_reporter.h b/tensorflow/lite/core/api/error_reporter.h
index 1e0ef7d..f910604 100644
--- a/tensorflow/lite/core/api/error_reporter.h
+++ b/tensorflow/lite/core/api/error_reporter.h
@@ -15,58 +15,6 @@
#ifndef TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
#define TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
-#include <cstdarg>
-
-namespace tflite {
-
-/// A functor that reports error to supporting system. Invoked similar to
-/// printf.
-///
-/// Usage:
-/// ErrorReporter foo;
-/// foo.Report("test %d", 5);
-/// or
-/// va_list args;
-/// foo.Report("test %d", args); // where args is va_list
-///
-/// Subclass ErrorReporter to provide another reporting destination.
-/// For example, if you have a GUI program, you might redirect to a buffer
-/// that drives a GUI error log box.
-class ErrorReporter {
- public:
- virtual ~ErrorReporter() = default;
- /// Converts `args` to character equivalents according to `format` string,
- /// constructs the error string and report it.
- /// Returns number of characters written or zero on success, and negative
- /// number on error.
- virtual int Report(const char* format, va_list args) = 0;
-
- /// Converts arguments to character equivalents according to `format` string,
- /// constructs the error string and report it.
- /// Returns number of characters written or zero on success, and negative
- /// number on error.
- int Report(const char* format, ...);
-
- /// Equivalent to `Report` above. The additional `void*` parameter is unused.
- /// This method is for compatibility with macros that takes `TfLiteContext`,
- /// like TF_LITE_ENSURE and related macros.
- int ReportError(void*, const char* format, ...);
-};
-
-} // namespace tflite
-
-// You should not make bare calls to the error reporter, instead use the
-// TF_LITE_REPORT_ERROR macro, since this allows message strings to be
-// stripped when the binary size has to be optimized. If you are looking to
-// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and
-// every call will be stubbed out, taking no memory.
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
-#define TF_LITE_REPORT_ERROR(reporter, ...) \
- do { \
- static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \
- } while (false)
-#else // TF_LITE_STRIP_ERROR_STRINGS
-#define TF_LITE_REPORT_ERROR(reporter, ...)
-#endif // TF_LITE_STRIP_ERROR_STRINGS
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" // IWYU pragma: export
#endif // TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
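The forwarding header above keeps tflite::ErrorReporter reachable at its old include path while the definition now lives under tensorflow/compiler/mlir/lite/core/api. As a reminder of the interface being re-exported, here is a minimal sketch of a custom reporter; PrefixedReporter is a hypothetical name, and per the removed documentation the only member a subclass must override is the va_list overload of Report.

    #include <cstdarg>
    #include <cstdio>

    #include "tensorflow/lite/core/api/error_reporter.h"

    // Hypothetical reporter: forwards formatted messages to stderr with a prefix.
    class PrefixedReporter : public tflite::ErrorReporter {
     public:
      int Report(const char* format, va_list args) override {
        std::fputs("[tflite] ", stderr);
        int written = std::vfprintf(stderr, format, args);
        std::fputc('\n', stderr);
        return written;
      }
    };

    int main() {
      PrefixedReporter reporter;
      // Prefer the macro to a bare Report() call so messages can be stripped by
      // defining TF_LITE_STRIP_ERROR_STRINGS.
      TF_LITE_REPORT_ERROR(&reporter, "tensor %d has no data", 3);
      return 0;
    }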
diff --git a/tensorflow/lite/core/api/error_reporter_test.cc b/tensorflow/lite/core/api/error_reporter_test.cc
deleted file mode 100644
index 03d6da73..0000000
--- a/tensorflow/lite/core/api/error_reporter_test.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/core/api/error_reporter.h"
-
-#include <cstdio>
-
-#include <gtest/gtest.h>
-
-namespace tflite {
-
-class MockErrorReporter : public ErrorReporter {
- public:
- MockErrorReporter() { buffer_[0] = 0; }
- int Report(const char* format, va_list args) override {
- vsnprintf(buffer_, kBufferSize, format, args);
- return 0;
- }
- char* GetBuffer() { return buffer_; }
-
- private:
- static constexpr int kBufferSize = 256;
- char buffer_[kBufferSize];
-};
-
-TEST(ErrorReporter, TestReport) {
- MockErrorReporter mock_reporter;
- ErrorReporter* reporter = &mock_reporter;
- reporter->Report("Error: %d", 23);
- EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
-}
-
-TEST(ErrorReporter, TestReportMacro) {
- MockErrorReporter mock_reporter;
- // Only define the reporter if it's used, to avoid warnings.
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
- ErrorReporter* reporter = &mock_reporter;
-#endif // TFLITE_STRIP_ERROR_STRINGS
-
- TF_LITE_REPORT_ERROR(reporter, "Error: %d", 23);
-
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
- EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
-#else // TF_LITE_STRIP_ERROR_STRINGS
- EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), ""));
-#endif // TF_LITE_STRIP_ERROR_STRINGS
-}
-
-} // namespace tflite
diff --git a/tensorflow/lite/core/api/verifier.h b/tensorflow/lite/core/api/verifier.h
index 8128ff3..dcb1d02 100644
--- a/tensorflow/lite/core/api/verifier.h
+++ b/tensorflow/lite/core/api/verifier.h
@@ -18,22 +18,6 @@
#ifndef TENSORFLOW_LITE_CORE_API_VERIFIER_H_
#define TENSORFLOW_LITE_CORE_API_VERIFIER_H_
-#include "tensorflow/lite/core/api/error_reporter.h"
-
-namespace tflite {
-
-/// Abstract interface that verifies whether a given model is legit.
-/// It facilitates the use-case to verify and build a model without loading it
-/// twice.
-/// (See also "tensorflow/lite/tools/verifier.h".)
-class TfLiteVerifier {
- public:
- /// Returns true if the model is legit.
- virtual bool Verify(const char* data, int length,
- ErrorReporter* reporter) = 0;
- virtual ~TfLiteVerifier() {}
-};
-
-} // namespace tflite
+#include "tensorflow/compiler/mlir/lite/core/api/verifier.h" // IWYU pragma: export
#endif // TENSORFLOW_LITE_CORE_API_VERIFIER_H_
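verifier.h is likewise reduced to a re-export. Assuming TfLiteVerifier keeps the Verify(const char*, int, ErrorReporter*) signature shown in the removed lines, a minimal extra verifier might look like the sketch below; MinSizeVerifier and its 8-byte threshold are illustrative only (a flatbuffer file identifier occupies bytes 4 through 7 of the buffer).

    #include "tensorflow/lite/core/api/error_reporter.h"
    #include "tensorflow/lite/core/api/verifier.h"

    // Hypothetical verifier that rejects buffers too small to hold the flatbuffer
    // root offset and file identifier before the full model check runs.
    class MinSizeVerifier : public tflite::TfLiteVerifier {
     public:
      bool Verify(const char* data, int length,
                  tflite::ErrorReporter* reporter) override {
        if (data == nullptr || length < 8) {
          TF_LITE_REPORT_ERROR(reporter, "Model buffer too small: %d bytes", length);
          return false;
        }
        return true;
      }
    };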
diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h
index 96f19f1..648b862 100644
--- a/tensorflow/lite/core/c/common.h
+++ b/tensorflow/lite/core/c/common.h
@@ -100,7 +100,9 @@
TfLiteStatus (*Refresh)(struct TfLiteContext* context);
} TfLiteExternalContext;
+// LINT.IfChange(optional_tensor)
#define kTfLiteOptionalTensor (-1)
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/flatbuffer_export.cc:optional_tensor)
/// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
/// indices
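The LINT.IfChange/ThenChange pair ties this sentinel to the copy consumed by flatbuffer_export.cc so the two cannot drift apart. For context, kernels compare input slots against the sentinel to detect omitted optional tensors; the helper below is an illustrative sketch, not code from this change.

    #include "tensorflow/lite/core/c/common.h"

    // Counts the inputs of a node that are actually wired up; a slot holding
    // kTfLiteOptionalTensor (-1) means the optional input was omitted.
    int CountPresentInputs(const TfLiteNode* node) {
      int present = 0;
      for (int i = 0; i < node->inputs->size; ++i) {
        if (node->inputs->data[i] != kTfLiteOptionalTensor) ++present;
      }
      return present;
    }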
diff --git a/tensorflow/lite/core/model_builder.cc b/tensorflow/lite/core/model_builder.cc
deleted file mode 100644
index afa8513..0000000
--- a/tensorflow/lite/core/model_builder.cc
+++ /dev/null
@@ -1,474 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/lite/core/model_builder.h"
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include <algorithm>
-#include <cstring>
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "absl/strings/str_cat.h"
-#include "flatbuffers/base.h" // from @flatbuffers
-#include "flatbuffers/buffer.h" // from @flatbuffers
-#include "flatbuffers/vector.h" // from @flatbuffers
-#include "flatbuffers/verifier.h" // from @flatbuffers
-#include "tensorflow/compiler/mlir/lite/core/macros.h"
-#include "tensorflow/lite/allocation.h"
-#include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/core/api/verifier.h"
-#include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/stderr_reporter.h"
-#include "tensorflow/lite/string_type.h"
-
-namespace tflite {
-
-namespace {
-
-// Ensure that ErrorReporter is non-null.
-ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
- return e ? e : DefaultErrorReporter();
-}
-
-} // namespace
-
-#ifndef TFLITE_MCU
-// Loads a model from `filename`. If `mmap_file` is true then use mmap,
-// otherwise make a copy of the model in a buffer.
-std::unique_ptr<Allocation> GetAllocationFromFile(
- const char* filename, ErrorReporter* error_reporter) {
- std::unique_ptr<Allocation> allocation;
- if (MMAPAllocation::IsSupported()) {
- allocation = std::make_unique<MMAPAllocation>(filename, error_reporter);
- } else {
- allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter);
- }
- return allocation;
-}
-
-// Loads a model from `fd`. If `mmap_file` is true then use mmap,
-// otherwise make a copy of the model in a buffer.
-std::unique_ptr<Allocation> GetAllocationFromFile(
- int fd, ErrorReporter* error_reporter) {
- std::unique_ptr<Allocation> allocation;
- if (MMAPAllocation::IsSupported()) {
- allocation = std::make_unique<MMAPAllocation>(fd, error_reporter);
- } else {
- allocation = std::make_unique<FileCopyAllocation>(
- absl::StrCat("/proc/self/fd/", fd).c_str(), error_reporter);
- }
- return allocation;
-}
-
-namespace impl {
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
- const char* filename, ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
- std::unique_ptr<FlatBufferModel> model = BuildFromAllocation(
- GetAllocationFromFile(filename, error_reporter), error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
- return model;
-#else
- return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
- const char* filename, TfLiteVerifier* extra_verifier,
- ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
- std::unique_ptr<FlatBufferModel> model = VerifyAndBuildFromAllocation(
- GetAllocationFromFile(filename, error_reporter), extra_verifier,
- error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
- return model;
-#else
- return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFileDescriptor(
- int fd, ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
- std::unique_ptr<FlatBufferModel> model = BuildFromAllocation(
- GetAllocationFromFile(fd, error_reporter), error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
- return model;
-#else
- return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-std::unique_ptr<FlatBufferModel>
-FlatBufferModel::VerifyAndBuildFromFileDescriptor(
- int fd, TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
- std::unique_ptr<FlatBufferModel> model =
- VerifyAndBuildFromAllocation(GetAllocationFromFile(fd, error_reporter),
- extra_verifier, error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
- return model;
-#else
- return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-} // namespace impl
-
-#endif
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
- const char* caller_owned_buffer, size_t buffer_size,
- ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
- std::unique_ptr<Allocation> allocation(
- new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
- return BuildFromAllocation(std::move(allocation), error_reporter);
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
- const char* caller_owned_buffer, size_t buffer_size,
- TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
- std::unique_ptr<Allocation> allocation(
- new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
- return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
- error_reporter);
-}
-
-#if FLATBUFFERS_LITTLEENDIAN == 0
-
-void FlatBufferModel::ByteSwapSerializedModel(std::string* serialized_model,
- bool from_big_endian) {
- const uint8_t* buffer =
- reinterpret_cast<const uint8_t*>(serialized_model->c_str());
- const tflite::Model* input_model = tflite::GetModel(buffer);
- ByteSwapTFLiteModel(input_model, from_big_endian);
-}
-
-void FlatBufferModel::ByteSwapBuffer(int8_t tensor_type, size_t buffer_size,
- uint8_t* buffer, bool from_big_endian) {
- switch (tensor_type) {
- case tflite::TensorType_STRING: {
- auto bp = reinterpret_cast<int32_t*>(buffer);
- int num_of_strings =
- from_big_endian ? bp[0] : flatbuffers::EndianSwap(bp[0]);
- for (int i = 0; i < num_of_strings + 2; i++)
- bp[i] = flatbuffers::EndianSwap(bp[i]);
- break;
- }
- // 16-bit types
- case tflite::TensorType_FLOAT16:
- case tflite::TensorType_INT16:
- case tflite::TensorType_UINT16: {
- auto bp = reinterpret_cast<uint16_t*>(buffer);
- for (int i = 0; i < buffer_size / 2; i++)
- bp[i] = flatbuffers::EndianSwap(bp[i]);
- break;
- }
- // 32-bit types
- case tflite::TensorType_FLOAT32:
- case tflite::TensorType_INT32:
- case tflite::TensorType_UINT32:
- case tflite::TensorType_COMPLEX64: {
- auto bp = reinterpret_cast<uint32_t*>(buffer);
- for (int i = 0; i < buffer_size / 4; i++)
- bp[i] = flatbuffers::EndianSwap(bp[i]);
- break;
- }
- // 64-bit types
- case tflite::TensorType_INT64:
- case tflite::TensorType_FLOAT64:
- case tflite::TensorType_UINT64:
- case tflite::TensorType_COMPLEX128: {
- auto bp = reinterpret_cast<uint64_t*>(buffer);
- for (int i = 0; i < buffer_size / 8; i++)
- bp[i] = flatbuffers::EndianSwap(bp[i]);
- break;
- }
- default:
- break;
- }
-}
-
-void FlatBufferModel::ByteSwapTFLiteModel(const tflite::Model* tfl_model,
- bool from_big_endian) {
- bool buffer_swapped[tfl_model->buffers()->size()] = {};
- for (size_t subgraph_idx = 0; subgraph_idx < tfl_model->subgraphs()->size();
- subgraph_idx++) {
- const tflite::SubGraph* subgraph =
- tfl_model->subgraphs()->Get(subgraph_idx);
- for (size_t ts_idx = 0; ts_idx < subgraph->tensors()->size(); ts_idx++) {
- const tflite::Tensor* tensor = subgraph->tensors()->Get(ts_idx);
- if (tensor->buffer() > 0 &&
- tensor->buffer() < tfl_model->buffers()->size() &&
- !buffer_swapped[tensor->buffer()]) {
- const tflite::Buffer* buffer_ =
- (*tfl_model->buffers())[tensor->buffer()];
- if (!buffer_ || !buffer_->data()) continue;
- auto* buffer = buffer_->data();
- uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
- ByteSwapBuffer(tensor->type(), buffer->size(), buff_, from_big_endian);
- buffer_swapped[tensor->buffer()] = true;
- }
- }
- }
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::ByteConvertModel(
- std::unique_ptr<FlatBufferModel> model, ErrorReporter* error_reporter,
- bool from_big_endian) {
- if (model == nullptr) return model;
- auto tfl_model = model->GetModel();
- if (tfl_model->subgraphs()->size() == 0) return model;
- if (tfl_model->subgraphs()->Get(0)->tensors()->size() == 0) return model;
- if (tfl_model->buffers()->size() < 2) return model;
- return ByteSwapFlatBufferModel(std::move(model), error_reporter,
- from_big_endian);
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::ByteSwapFlatBufferModel(
- std::unique_ptr<FlatBufferModel> model, ErrorReporter* error_reporter,
- bool from_big_endian) {
- FlatBufferModel* modelp = model.release();
- auto tflite_model = modelp->GetModel();
- auto copied_model = std::make_unique<tflite::ModelT>();
- tflite_model->UnPackTo(copied_model.get(), nullptr);
- ByteSwapTFLiteModelT(copied_model.get(), from_big_endian);
- std::unique_ptr<flatbuffers::FlatBufferBuilder> builder(
- new flatbuffers::FlatBufferBuilder());
- auto packed_model = tflite::Model::Pack(*builder, copied_model.get());
- tflite::FinishModelBuffer(*builder, packed_model);
- flatbuffers::FlatBufferBuilder* builder_ = builder.release();
- return BuildFromBuffer(
- reinterpret_cast<const char*>(builder_->GetBufferPointer()),
- builder_->GetSize(), error_reporter);
-}
-
-void FlatBufferModel::ByteSwapTFLiteModelT(tflite::ModelT* tfl_modelt,
- bool from_big_endian) {
- size_t bytes_per_elem = 0;
- bool buffer_swapped[tfl_modelt->buffers.size()] = {};
- for (size_t subgraph_idx = 0; subgraph_idx < tfl_modelt->subgraphs.size();
- subgraph_idx++) {
- tflite::SubGraphT* subgraph = tfl_modelt->subgraphs.at(subgraph_idx).get();
- for (size_t ts_idx = 0; ts_idx < subgraph->tensors.size(); ts_idx++) {
- tflite::TensorT* tensor = subgraph->tensors[ts_idx].get();
- if (tensor->buffer > 0 && tensor->buffer < tfl_modelt->buffers.size() &&
- !buffer_swapped[tensor->buffer]) {
- const auto* buffer = &(tfl_modelt->buffers[tensor->buffer].get()->data);
- if (buffer && buffer->data()) {
- uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
- ByteSwapBuffer(tensor->type, buffer->size(), buff_, from_big_endian);
- buffer_swapped[tensor->buffer] = true;
- }
- }
- }
- }
-}
-
-#endif
-
-void FlatBufferModel::ValidateModelBuffers(ErrorReporter* error_reporter) {
- auto buffers = model_->buffers();
- if (buffers && buffers->size() > 0) {
- auto first_buffer = buffers->Get(0);
- if (first_buffer && first_buffer->size() != 0) {
- // Note the 0th entry of this array must be an empty buffer (sentinel).
- // This is a convention so that tensors without a buffer can provide 0
- // as their buffer.
- TF_LITE_REPORT_ERROR(
- error_reporter,
- "The 0th entry of the model buffer must be an empty buffer.");
- }
- }
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
- std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
- std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
- std::move(allocation), ValidateErrorReporter(error_reporter)));
- if (!model->initialized()) {
- model.reset();
- } else {
- model->ValidateModelBuffers(error_reporter);
- }
- return model;
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
- std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
- ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
- if (!allocation || !allocation->valid()) {
- TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
- return nullptr;
- }
-
- {
- // Flatbuffers can only be smaller than 2GB. The file format appends some
- // data after the actual flabuffer. We truncate the allocation size to 2GB
- // so that the verifier doesn't early exit on us.
- size_t allocation_size =
- std::min(allocation->bytes(),
- static_cast<size_t>(FLATBUFFERS_MAX_BUFFER_SIZE - 1));
- flatbuffers::Verifier base_verifier(
- reinterpret_cast<const uint8_t*>(allocation->base()), allocation_size);
- if (!VerifyModelBuffer(base_verifier)) {
- TF_LITE_REPORT_ERROR(error_reporter,
- "The model is not a valid Flatbuffer buffer");
- return nullptr;
- }
-
- if (extra_verifier &&
- !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
- allocation_size, error_reporter)) {
- // The verifier will have already logged an appropriate error message.
- return nullptr;
- }
- }
-
- return BuildFromAllocation(std::move(allocation), error_reporter);
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
- const tflite::Model* caller_owned_model_spec,
- ErrorReporter* error_reporter) {
- error_reporter = ValidateErrorReporter(error_reporter);
-
- if (CheckBufferOutsideModel(caller_owned_model_spec)) {
- TF_LITE_REPORT_ERROR(error_reporter,
- "The model contains weights not accessible from "
- "tflite::Model *, please use other api");
- return nullptr;
- }
-
- std::unique_ptr<FlatBufferModel> model(
- new FlatBufferModel(caller_owned_model_spec, error_reporter));
- if (!model->initialized()) {
- model.reset();
- } else {
- model->ValidateModelBuffers(error_reporter);
- }
- return model;
-}
-
-bool FlatBufferModel::CheckBufferOutsideModel(const tflite::Model* model) {
- if (!model || !model->metadata()) return false;
-
- for (int i = 0; i < model->metadata()->size(); ++i) {
- auto metadata = model->metadata()->Get(i);
- if (metadata->name()->str() == tflite_metadata_buffer_location) {
- return true;
- }
- }
- return false;
-}
-
-string FlatBufferModel::GetMinimumRuntime() const {
- if (!model_ || !model_->metadata()) return "";
-
- for (int i = 0; i < model_->metadata()->size(); ++i) {
- auto metadata = model_->metadata()->Get(i);
- if (metadata->name()->str() == tflite_metadata_min_runtime_version) {
- auto buf = metadata->buffer();
- auto* buffer = (*model_->buffers())[buf];
- auto* array = buffer->data();
- // Get the real length of the runtime string, since there might be
- // trailing
- // '\0's in the buffer.
- for (int len = 0; len < array->size(); ++len) {
- if (array->data()[len] == '\0') {
- return string(reinterpret_cast<const char*>(array->data()), len);
- }
- }
- // If there is no '\0' in the buffer, this indicates that the flatbuffer
- // is malformed.
- TF_LITE_REPORT_ERROR(
- error_reporter_,
- "Min_runtime_version in model metadata is malformed");
- break;
- }
- }
- return "";
-}
-
-std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata() const {
- return ReadAllMetadata(model_);
-}
-
-std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata(
- const tflite::Model* model) {
- std::map<std::string, std::string> keys_values;
- if (!model || !model->metadata() || !model->buffers()) return keys_values;
-
- for (int i = 0; i < model->metadata()->size(); ++i) {
- auto metadata = model->metadata()->Get(i);
- auto buf = metadata->buffer();
- if (buf >= model->buffers()->size()) continue;
- const tflite::Buffer* buffer = (*model->buffers())[buf];
- if (!buffer || !buffer->data()) continue;
- const flatbuffers::Vector<uint8_t>* array = buffer->data();
- if (!array) continue;
- std::string val =
- string(reinterpret_cast<const char*>(array->data()), array->size());
- // Skip if key or value of metadata is empty.
- if (!metadata->name() || val.empty()) continue;
- keys_values[metadata->name()->str()] = val;
- }
- return keys_values;
-}
-
-bool FlatBufferModel::CheckModelIdentifier() const {
- if (allocation_->bytes() < 7) {
- TF_LITE_REPORT_ERROR(
- error_reporter_,
- "Model provided must have at least 7 bytes to hold identifier.\n");
- return false;
- }
- if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
- const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
- TF_LITE_REPORT_ERROR(
- error_reporter_,
- "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
- ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
- return false;
- }
- return true;
-}
-
-FlatBufferModel::FlatBufferModel(const Model* model,
- ErrorReporter* error_reporter)
- : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
-
-FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
- ErrorReporter* error_reporter)
- : error_reporter_(ValidateErrorReporter(error_reporter)),
- allocation_(std::move(allocation)) {
- if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
- return;
- }
-
- model_ = ::tflite::GetModel(allocation_->base());
-}
-
-FlatBufferModel::~FlatBufferModel() {}
-
-} // namespace tflite
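The deleted implementation, including the big-endian byte-swap helpers, moves into FlatBufferModelBase. The key idea in ByteSwapBuffer was to reverse each element independently, with the element width chosen from the tensor type (2, 4, or 8 bytes; strings are handled through their offset table). The standalone sketch below illustrates only the 32-bit case and does not depend on flatbuffers.

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Illustration of the per-element swap for 32-bit tensor types: an N-byte
    // buffer holds N / 4 elements and each one is byte-reversed in place.
    void SwapBuffer32(uint8_t* buffer, std::size_t buffer_size) {
      for (std::size_t i = 0; i + 4 <= buffer_size; i += 4) {
        uint32_t v;
        std::memcpy(&v, buffer + i, 4);
        v = (v >> 24) | ((v >> 8) & 0x0000FF00u) | ((v << 8) & 0x00FF0000u) |
            (v << 24);
        std::memcpy(buffer + i, &v, 4);
      }
    }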
diff --git a/tensorflow/lite/core/model_builder.h b/tensorflow/lite/core/model_builder.h
index 3337d93..9a36c4b 100644
--- a/tensorflow/lite/core/model_builder.h
+++ b/tensorflow/lite/core/model_builder.h
@@ -22,260 +22,30 @@
/// but should instead include "third_party/tensorflow/lite/model_builder.h".
/// Only the TensorFlow Lite implementation itself should include this
/// file directly.
+
#ifndef TENSORFLOW_LITE_CORE_MODEL_BUILDER_H_
#define TENSORFLOW_LITE_CORE_MODEL_BUILDER_H_
#include <stddef.h>
-#include <map>
-#include <memory>
-#include <string>
-
-#include "tensorflow/lite/allocation.h"
+#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h"
#include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/core/api/verifier.h"
-#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/stderr_reporter.h"
namespace tflite {
-/// An RAII object that represents a read-only tflite model, copied from disk,
-/// or mmapped. This uses flatbuffers as the serialization format.
-///
-/// NOTE: The current API requires that a FlatBufferModel instance be kept alive
-/// by the client as long as it is in use by any dependent Interpreter
-/// instances. As the FlatBufferModel instance is effectively immutable after
-/// creation, the client may safely use a single model with multiple dependent
-/// Interpreter instances, even across multiple threads (though note that each
-/// Interpreter instance is *not* thread-safe).
-///
-/// <pre><code>
-/// using namespace tflite;
-/// StderrReporter error_reporter;
-/// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
-/// &error_reporter);
-/// MyOpResolver resolver; // You need to subclass OpResolver to provide
-/// // implementations.
-/// InterpreterBuilder builder(*model, resolver);
-/// std::unique_ptr<Interpreter> interpreter;
-/// if(builder(&interpreter) == kTfLiteOk) {
-/// .. run model inference with interpreter
-/// }
-/// </code></pre>
-///
-/// OpResolver must be defined to provide your kernel implementations to the
-/// interpreter. This is environment specific and may consist of just the
-/// builtin ops, or some custom operators you defined to extend tflite.
namespace impl {
-class FlatBufferModel {
+class FlatBufferModel : public FlatBufferModelBase<FlatBufferModel> {
public:
- /// Builds a model based on a file.
- /// Caller retains ownership of `error_reporter` and must ensure its lifetime
- /// is longer than the FlatBufferModel instance.
- /// Returns a nullptr in case of failure.
- static std::unique_ptr<FlatBufferModel> BuildFromFile(
- const char* filename,
- ErrorReporter* error_reporter = DefaultErrorReporter());
+ // Use stderr_reporter as the default error reporter.
+ static ErrorReporter* GetDefaultErrorReporter() {
+ return DefaultErrorReporter();
+ }
- /// Verifies whether the content of the file is legit, then builds a model
- /// based on the file.
- /// The extra_verifier argument is an additional optional verifier for the
- /// file contents. By default, we always check with tflite::VerifyModelBuffer.
- /// If extra_verifier is supplied, the file contents is also checked against
- /// the extra_verifier after the check against tflite::VerifyModelBuilder.
- /// Caller retains ownership of `error_reporter` and must ensure its lifetime
- /// is longer than the FlatBufferModel instance.
- /// Returns a nullptr in case of failure.
- static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
- const char* filename, TfLiteVerifier* extra_verifier = nullptr,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Builds a model based on a file descriptor.
- /// Caller retains ownership of `error_reporter` and must ensure its lifetime
- /// is longer than the FlatBufferModel instance. Caller retains ownership of
- /// `fd` and must ensure it is closed after BuildFromFile returns.
- /// Returns a nullptr in case of failure.
- static std::unique_ptr<FlatBufferModel> BuildFromFileDescriptor(
- int fd,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Verifies whether the content of the file descriptor is legit, then builds
- /// a model based on the file.
- /// The extra_verifier argument is an additional optional verifier for the
- /// file contents. By default, we always check with tflite::VerifyModelBuffer.
- /// If extra_verifier is supplied, the file contents is also checked against
- /// the extra_verifier after the check against tflite::VerifyModelBuilder.
- /// Caller retains ownership of `error_reporter` and must ensure its lifetime
- /// is longer than the FlatBufferModel instance.
- /// Returns a nullptr in case of failure.
- static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFileDescriptor(
- int fd, TfLiteVerifier* extra_verifier = nullptr,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Builds a model based on a pre-loaded flatbuffer.
- /// Caller retains ownership of the buffer and should keep it alive until
- /// the returned object is destroyed. Caller also retains ownership of
- /// `error_reporter` and must ensure its lifetime is longer than the
- /// FlatBufferModel instance.
- /// Returns a nullptr in case of failure.
- /// NOTE: this does NOT validate the buffer so it should NOT be called on
- /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
- static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
- const char* caller_owned_buffer, size_t buffer_size,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Verifies whether the content of the buffer is legit, then builds a model
- /// based on the pre-loaded flatbuffer.
- /// The extra_verifier argument is an additional optional verifier for the
- /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
- /// extra_verifier is supplied, the buffer is checked against the
- /// extra_verifier after the check against tflite::VerifyModelBuilder. The
- /// caller retains ownership of the buffer and should keep it alive until the
- /// returned object is destroyed. Caller retains ownership of `error_reporter`
- /// and must ensure its lifetime is longer than the FlatBufferModel instance.
- /// Returns a nullptr in case of failure.
- static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer(
- const char* caller_owned_buffer, size_t buffer_size,
- TfLiteVerifier* extra_verifier = nullptr,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Builds a model directly from an allocation.
- /// Ownership of the allocation is passed to the model, but the caller
- /// retains ownership of `error_reporter` and must ensure its lifetime is
- /// longer than the FlatBufferModel instance.
- /// Returns a nullptr in case of failure (e.g., the allocation is invalid).
- static std::unique_ptr<FlatBufferModel> BuildFromAllocation(
- std::unique_ptr<Allocation> allocation,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Verifies whether the content of the allocation is legit, then builds a
- /// model based on the provided allocation.
- /// The extra_verifier argument is an additional optional verifier for the
- /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
- /// extra_verifier is supplied, the buffer is checked against the
- /// extra_verifier after the check against tflite::VerifyModelBuilder.
- /// Ownership of the allocation is passed to the model, but the caller
- /// retains ownership of `error_reporter` and must ensure its lifetime is
- /// longer than the FlatBufferModel instance.
- /// Returns a nullptr in case of failure.
- static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation(
- std::unique_ptr<Allocation> allocation,
- TfLiteVerifier* extra_verifier = nullptr,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Builds a model directly from a flatbuffer pointer
- /// Caller retains ownership of the buffer and should keep it alive until the
- /// returned object is destroyed. Caller retains ownership of `error_reporter`
- /// and must ensure its lifetime is longer than the FlatBufferModel instance.
- /// Returns a nullptr in case of failure.
- static std::unique_ptr<FlatBufferModel> BuildFromModel(
- const tflite::Model* caller_owned_model_spec,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
-#if FLATBUFFERS_LITTLEENDIAN == 0
- /// Byte swap a constant buffer in place.
- static void ByteSwapBuffer(int8_t tensor_type, size_t buffer_size,
- uint8_t* buffer, bool from_big_endian = true);
-
- /// Byte swap the buffers field of a TFLite Model instance in place.
- static void ByteSwapTFLiteModel(const tflite::Model* tfl_model,
- bool from_big_endian = true);
-
- /// Byte swap the buffers field of a TFLite ModelT instance in place.
- static void ByteSwapTFLiteModelT(tflite::ModelT* tfl_modelt,
- bool from_big_endian = true);
-
- /// Convert the TFLite buffers field between LE and BE format in a
- /// FlatBufferModel which is not empty and return the converted instance.
- static std::unique_ptr<FlatBufferModel> ByteConvertModel(
- std::unique_ptr<FlatBufferModel> model,
- ErrorReporter* error_reporter = DefaultErrorReporter(),
- bool from_big_endian = false);
-
- /// Byte Swap the TFLite buffers field in a FlatBufferModel and return the
- /// swapped instance.
- static std::unique_ptr<FlatBufferModel> ByteSwapFlatBufferModel(
- std::unique_ptr<FlatBufferModel> model,
- ErrorReporter* error_reporter = DefaultErrorReporter(),
- bool from_big_endian = false);
-
- /// Byte Swap the serialized String of a TFLite model in place.
- static void ByteSwapSerializedModel(std::string* serialized_model,
- bool from_big_endian = true);
-#endif
-
- // Releases memory or unmaps mmaped memory.
- ~FlatBufferModel();
-
- // Copying or assignment is disallowed to simplify ownership semantics.
- FlatBufferModel(const FlatBufferModel&) = delete;
- FlatBufferModel& operator=(const FlatBufferModel&) = delete;
-
- bool initialized() const { return model_ != nullptr; }
- const tflite::Model* operator->() const { return model_; }
- const tflite::Model* GetModel() const { return model_; }
- ErrorReporter* error_reporter() const { return error_reporter_; }
- const Allocation* allocation() const { return allocation_.get(); }
-
- // Returns the minimum runtime version from the flatbuffer. This runtime
- // version encodes the minimum required interpreter version to run the
- // flatbuffer model. If the minimum version can't be determined, an empty
- // string will be returned.
- // Note that the returned minimum version is a lower-bound but not a strict
- // lower-bound; ops in the graph may not have an associated runtime version,
- // in which case the actual required runtime might be greater than the
- // reported minimum.
- std::string GetMinimumRuntime() const;
-
- // Return model metadata as a mapping of name & buffer strings.
- // See Metadata table in TFLite schema.
- std::map<std::string, std::string> ReadAllMetadata() const;
-
- // Return model metadata as a mapping of name & buffer strings.
- // See Metadata table in TFLite schema.
- static std::map<std::string, std::string> ReadAllMetadata(
- const ::tflite::Model* model);
-
- // If the buffer is stored as part of the Flatbuffer or outside
- // return false if the buffers are part of the Flatbuffer
- static bool CheckBufferOutsideModel(const tflite::Model* model);
-
- // Validates if the FlatBufferModel's buffer is well-formed. Specifically, it
- // checks if the 0th entry of the model buffers is an empty buffer (sentinel).
- // This is a convention so that tensors without a buffer can provide 0
- // as their buffer.
- // NOTE: The function doesn't explicitly fail for backward compatibility
- // reasons; it just provides a warning in case of failures.
- void ValidateModelBuffers(ErrorReporter* error_reporter);
-
- /// Returns true if the model identifier is correct (otherwise false and
- /// reports an error).
- bool CheckModelIdentifier() const;
-
- private:
- /// Loads a model from a given allocation. FlatBufferModel will take over the
- /// ownership of `allocation`, and delete it in destructor. The ownership of
- /// `error_reporter`remains with the caller and must have lifetime at least
- /// as much as FlatBufferModel. This is to allow multiple models to use the
- /// same ErrorReporter instance.
- explicit FlatBufferModel(
- std::unique_ptr<Allocation> allocation,
- ErrorReporter* error_reporter = DefaultErrorReporter());
-
- /// Loads a model from Model flatbuffer. The `model` has to remain alive and
- /// unchanged until the end of this flatbuffermodel's lifetime.
- FlatBufferModel(const Model* model, ErrorReporter* error_reporter);
-
- /// Flatbuffer traverser pointer. (Model* is a pointer that is within the
- /// allocated memory of the data allocated by allocation's internals.
- const tflite::Model* model_ = nullptr;
- /// The error reporter to use for model errors and subsequent errors when
- /// the interpreter is created
- ErrorReporter* error_reporter_;
- /// The allocator used for holding memory of the model. Note that this will
- /// be null if the client provides a tflite::Model directly.
- std::unique_ptr<Allocation> allocation_;
+ // Inherit all constructors from FlatBufferModelBase since inherited factory
+ // methods refer to them.
+ using FlatBufferModelBase<FlatBufferModel>::FlatBufferModelBase;
};
} // namespace impl
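FlatBufferModel is now a thin CRTP wrapper: the base class supplies the factory methods and storage, and the derived class only contributes GetDefaultErrorReporter() plus the inherited constructors. The skeleton below is an illustrative sketch of that pattern, not the real FlatBufferModelBase API.

    #include <memory>

    // Illustrative CRTP skeleton: the base's factory builds the Derived type and
    // asks it for its default error reporter.
    template <typename Derived>
    class ModelBase {
     public:
      static std::unique_ptr<Derived> Build(const char* filename) {
        auto reporter = Derived::GetDefaultErrorReporter();
        (void)filename;
        (void)reporter;  // A real implementation would load and verify here.
        return std::unique_ptr<Derived>(new Derived());
      }
    };

    class Model : public ModelBase<Model> {
     public:
      static const char* GetDefaultErrorReporter() { return "stderr"; }
    };

    // Usage: auto model = Model::Build("interesting_model.tflite");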
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index 73f192d..b84cb9a 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -299,6 +299,7 @@
":cl_kernel",
":program_cache",
":tensor",
+ "//tensorflow/lite/delegates/gpu/common/task:compiler_options",
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
index 1cc1738..8fd9493 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
@@ -17,6 +17,8 @@
#include <string>
+#include "tensorflow/lite/delegates/gpu/common/task/compiler_options.h"
+
namespace tflite {
namespace gpu {
namespace cl {
@@ -165,6 +167,10 @@
creation_context.context, &operation_->args_,
&operation_->code_));
operation_->args_.ReleaseCPURepresentation();
+ if (creation_context.device->info_.opencl_info.IsCLVK()) {
+ operation_->compiler_options_.push_back(
+ CompilerOptions::kClFastRelaxedMath);
+ }
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
operation_->code_, "main_function", operation_->compiler_options_,
*creation_context.context, *creation_context.device, &kernel_,
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
index 7a40a60..5bd407d 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
@@ -75,6 +75,7 @@
// Ordered records are to be sorted by size of corresponding tensor.
std::vector<TensorUsageWithIndex<size_t>> ordered_records;
+ ordered_records.reserve(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
ordered_records.emplace_back(&usage_records[i], i);
}
diff --git a/tensorflow/lite/delegates/gpu/common/model.cc b/tensorflow/lite/delegates/gpu/common/model.cc
index dc68e70..a7a174f 100644
--- a/tensorflow/lite/delegates/gpu/common/model.cc
+++ b/tensorflow/lite/delegates/gpu/common/model.cc
@@ -333,10 +333,16 @@
model->nodes_.clear();
model->execution_plan_.clear();
model->values_.clear();
+ model->known_graph_outputs_.clear();
for (auto& value_def : values_) {
model->values_.push_back({});
if (value_def.value) {
model->values_.back().value = std::make_unique<Value>(*value_def.value);
+ if (std::find(known_graph_outputs_.begin(), known_graph_outputs_.end(),
+ value_def.value.get()) != known_graph_outputs_.end()) {
+ model->known_graph_outputs_.push_back(
+ model->values_.back().value.get());
+ }
}
}
// Add all nodes first.
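Without the extra bookkeeping, a cloned graph's known_graph_outputs_ would keep pointing at the source graph's Value objects; the fix remaps each recorded output to the freshly copied value. The sketch below shows the same pointer-remapping idiom on simplified types (Value here is a stand-in, not the real model.cc class).

    #include <algorithm>
    #include <memory>
    #include <vector>

    struct Value { int id; };

    // Deep-copies `values` and remaps the recorded output pointers so that the
    // clone refers to its own objects rather than the originals.
    void CloneWithOutputs(const std::vector<std::unique_ptr<Value>>& values,
                          const std::vector<Value*>& outputs,
                          std::vector<std::unique_ptr<Value>>* new_values,
                          std::vector<Value*>* new_outputs) {
      new_values->clear();
      new_outputs->clear();
      for (const auto& v : values) {
        new_values->push_back(std::make_unique<Value>(*v));
        if (std::find(outputs.begin(), outputs.end(), v.get()) != outputs.end()) {
          new_outputs->push_back(new_values->back().get());
        }
      }
    }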
diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
index f80a52d..8554124 100644
--- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
+++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
@@ -56,7 +56,7 @@
LoadDelegateFromSharedLibrary(
"tensorflow/lite/delegates/utils/experimental/"
"sample_stable_delegate/"
- "libtensorflowlite_sample_stable_delegate_for_test.so");
+ "libtensorflowlite_sample_stable_delegate.so");
ASSERT_NE(stable_delegate_handle, nullptr);
EXPECT_STREQ(stable_delegate_handle->delegate_abi_version,
TFL_STABLE_DELEGATE_ABI_VERSION);
diff --git a/tensorflow/lite/examples/label_image/CMakeLists.txt b/tensorflow/lite/examples/label_image/CMakeLists.txt
index 9874801..2fcb09c 100644
--- a/tensorflow/lite/examples/label_image/CMakeLists.txt
+++ b/tensorflow/lite/examples/label_image/CMakeLists.txt
@@ -61,6 +61,11 @@
${TFLITE_SOURCE_DIR}/tools/delegates/external_delegate_provider.cc)
endif()
+include_directories(label_image
+ PUBLIC
+ ${CMAKE_BINARY_DIR}
+)
+
add_executable(label_image
${TFLITE_LABEL_IMAGE_SRCS}
)
@@ -78,4 +83,6 @@
)
target_link_libraries(label_image
tensorflow-lite
+ profiling_info_proto
+ protobuf
)
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index 8ce5e0c..b6c518d 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -22,6 +22,7 @@
visibility = jni_utils_visibility_allowlist(),
deps = [
"//tensorflow/lite:error_reporter",
+ "//tensorflow/lite/core/c:common",
"//tensorflow/lite/java/jni",
],
)
diff --git a/tensorflow/lite/java/src/main/native/jni_utils.h b/tensorflow/lite/java/src/main/native/jni_utils.h
index 1602d77..1796a38 100644
--- a/tensorflow/lite/java/src/main/native/jni_utils.h
+++ b/tensorflow/lite/java/src/main/native/jni_utils.h
@@ -21,6 +21,7 @@
#include <vector>
+#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc
index 4190fd7..d927010 100644
--- a/tensorflow/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/lite/kernels/embedding_lookup.cc
@@ -104,13 +104,13 @@
// Propagate empty tensor if input is empty
return kTfLiteOk;
}
- const int row_bytes = value->bytes / row_size;
+ const int64_t row_bytes = value->bytes / row_size;
char* output_raw = GetTensorData<char>(output);
const char* value_raw = GetTensorData<char>(value);
const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
- int idx = lookup_data[i];
+ int64_t idx = lookup_data[i];
if (idx >= row_size || idx < 0) {
TF_LITE_KERNEL_LOG(context,
"Embedding Lookup: index out of bounds. "
diff --git a/tensorflow/lite/kernels/embedding_lookup_test.cc b/tensorflow/lite/kernels/embedding_lookup_test.cc
index d13ddd4..679975f 100644
--- a/tensorflow/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/lite/kernels/embedding_lookup_test.cc
@@ -92,6 +92,19 @@
}
}
}
+
+ template <typename T>
+ void Set2DWeightMatrix(const std::function<T(int, int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(weight_);
+ int64_t rows = tensor->dims->data[0];
+ int64_t columns = tensor->dims->data[1];
+ T* data = GetTensorData<T>(tensor);
+ for (int64_t i = 0; i < rows; i++) {
+ for (int64_t j = 0; j < columns; j++) {
+ data[i * columns + j] = function(i, j);
+ }
+ }
+ }
};
class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
@@ -144,6 +157,27 @@
})));
}
+#if !defined(GOOGLE_UNSUPPORTED_OS_LOONIX) && defined(__LP64__)
+TEST(EmbeddingLookupOpTest, LargeTableTest) {
+ EmbeddingLookupOpModel m({1}, {256000, 9216});
+ // Choose a value specifically designed to overflow int32.max
+ m.SetInput({235248});
+ m.Set2DWeightMatrix<float>(
+ [](int i, int j) -> float { return j + i / 100.; });
+
+ // This will cause a lookup at index 235248 in a buffer where every row
+ // has 9216 entries * 4 bytes per entry, which will overflow unless
+ // the Op is using a 64-bit offset for address calculation.
+ ASSERT_EQ(m.Invoke(), kTfLiteOk);
+ std::vector<float> exp(9216);
+
+ for (int s = 0; s < exp.size(); s++) {
+ exp[s] = static_cast<float>(s) + 2352.48f;
+ }
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear(exp)));
+}
+#endif
+
TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestUint8) {
HybridEmbeddingLookupOpModel m({3}, {3, 8}, TensorType_UINT8);
m.SetInput({1, 0, 2});
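The test's numbers make the overflow concrete: each row holds 9216 float32 values, i.e. 36,864 bytes, so row 235,248 starts at byte 235,248 × 36,864 = 8,672,182,272, far beyond INT32_MAX (2,147,483,647). Computing that offset in 32-bit arithmetic would wrap, which is why row_bytes and idx are promoted to int64_t in the kernel. A standalone check of the arithmetic:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t idx = 235248;              // lookup row chosen by the test
      const int64_t row_bytes = 9216 * 4;      // 36,864 bytes per float32 row
      const int64_t offset = idx * row_bytes;  // 8,672,182,272
      std::printf("byte offset %lld exceeds INT32_MAX (%d) by %lld bytes\n",
                  static_cast<long long>(offset), INT32_MAX,
                  static_cast<long long>(offset - INT32_MAX));
      return 0;
    }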
diff --git a/tensorflow/lite/kernels/internal/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/portable_tensor_utils.h
index d37fe6e..ed59fd0 100644
--- a/tensorflow/lite/kernels/internal/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/portable_tensor_utils.h
@@ -317,7 +317,7 @@
void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output);
-// Same as above but the internal calcualtion is float.
+// Same as above but the internal calculation is float.
void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output);
@@ -333,7 +333,7 @@
void ApplyTanh(int32_t intger_bits, const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
-// Apply Tanh to a quantized vector. Tbe internal calculation is in float.
+// Apply Tanh to a quantized vector. The internal calculation is in float.
// - Input has 2^(integer_bits) as scale.
// - Output has Q0.15 as scale.
void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
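The corrected comments describe the quantized activation helpers: inputs are int16 with a scale of 2^(integer_bits), and outputs are int16 in Q0.15, i.e. the stored integer divided by 32768 gives the real value in [-1, 1). A small sketch of that output encoding (illustrative, not TFLite code):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Q0.15: real value v in [-1, 1) is stored as round(v * 32768), clamped to
    // the int16 range.
    int16_t ToQ0_15(float value) {
      float scaled = std::round(value * 32768.0f);
      if (scaled > 32767.0f) scaled = 32767.0f;
      if (scaled < -32768.0f) scaled = -32768.0f;
      return static_cast<int16_t>(scaled);
    }

    int main() {
      std::printf("tanh(1.0) = %f -> Q0.15 value %d\n", std::tanh(1.0f),
                  ToQ0_15(std::tanh(1.0f)));
      return 0;
    }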
diff --git a/tensorflow/lite/kernels/internal/quantization_util_test.cc b/tensorflow/lite/kernels/internal/quantization_util_test.cc
index aec0b2b..aa9c274 100644
--- a/tensorflow/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/lite/kernels/internal/quantization_util_test.cc
@@ -160,13 +160,13 @@
// 255 | 30.0
// 128 | 10.0
TEST(QuantizationUtilTest, ChooseQuantizationParams) {
- QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 30.0);
+ QuantizationParams qp = ChooseQuantizationParams<uint8_t>(-10.0, 30.0);
EXPECT_NEAR(qp.scale, 0.156863, 1e-5);
EXPECT_EQ(qp.zero_point, 64);
}
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) {
- QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 30.0);
+ QuantizationParams qp = ChooseQuantizationParams<uint8_t>(0.0, 30.0);
EXPECT_NEAR(qp.scale, 0.117647, 1e-5);
EXPECT_EQ(qp.zero_point, 0);
}
@@ -174,23 +174,23 @@
#if GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) {
// Assumption is that zero is within the range.
- EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, 30.0), "");
+ EXPECT_DEATH(ChooseQuantizationParams<uint8_t>(10.0, 30.0), "");
}
TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) {
// Assumption is that zero is within the range.
- EXPECT_DEATH(ChooseQuantizationParams<uint8>(30.0, 30.0), "");
+ EXPECT_DEATH(ChooseQuantizationParams<uint8_t>(30.0, 30.0), "");
}
#endif // GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) {
- QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 0.0);
+ QuantizationParams qp = ChooseQuantizationParams<uint8_t>(0.0, 0.0);
EXPECT_NEAR(qp.scale, 0.0, 1e-5);
EXPECT_EQ(qp.zero_point, 0);
}
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
- QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 0.0);
+ QuantizationParams qp = ChooseQuantizationParams<uint8_t>(-10.0, 0.0);
EXPECT_NEAR(qp.scale, 0.039216, 1e-5);
EXPECT_EQ(qp.zero_point, 255);
}
@@ -330,7 +330,7 @@
#if GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
- EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
+ EXPECT_DEATH(ChooseQuantizationParams<uint8_t>(10.0, -30.0), "");
}
TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) {
@@ -533,12 +533,12 @@
const std::vector<double> weights = {-4, -2, -1, -0.5, -0.25, -0.125, 0,
0.125, 0.25, 0.5, 1, 2, 4};
const int size = weights.size();
- std::vector<int32> effective_scale_significand(size);
+ std::vector<int32_t> effective_scale_significand(size);
std::vector<int> effective_scale_shift(size);
QuantizeMultiplierArray(weights.data(), size,
effective_scale_significand.data(),
effective_scale_shift.data());
- const std::vector<int32> expected_effective_scale_significand = {
+ const std::vector<int32_t> expected_effective_scale_significand = {
-1073741824, // float scale = -4
-1073741824, // float scale = -2
-1073741824, // float scale = -1
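The expected values in these tests follow from the affine quantization formulas for a uint8_t target: scale = (rmax - rmin) / 255, and the zero point is the quantized level that real 0.0 maps to. The sketch below reproduces the tests' numbers; note that the real ChooseQuantizationParams additionally nudges and clamps the zero point, which these particular ranges happen not to need.

    #include <cmath>
    #include <cstdio>

    void Check(double rmin, double rmax) {
      const double scale = (rmax - rmin) / 255.0;
      const int zero_point = static_cast<int>(std::round(0.0 - rmin / scale));
      std::printf("[%g, %g] -> scale %.6f, zero_point %d\n", rmin, rmax, scale,
                  zero_point);
    }

    int main() {
      Check(-10.0, 30.0);  // scale ~0.156863, zero_point 64
      Check(0.0, 30.0);    // scale ~0.117647, zero_point 0
      Check(-10.0, 0.0);   // scale ~0.039216, zero_point 255
      return 0;
    }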
diff --git a/tensorflow/lite/kernels/parse_example/parse_example.cc b/tensorflow/lite/kernels/parse_example/parse_example.cc
index acec033..ec87aab 100644
--- a/tensorflow/lite/kernels/parse_example/parse_example.cc
+++ b/tensorflow/lite/kernels/parse_example/parse_example.cc
@@ -111,7 +111,7 @@
bool ParseExample(StringRef serialized, Example* example) {
DCHECK(example != nullptr);
tf::protobuf::io::CodedInputStream stream(
- reinterpret_cast<const uint8*>(serialized.str), serialized.len);
+ reinterpret_cast<const uint8_t*>(serialized.str), serialized.len);
tensorflow::example::EnableAliasing(&stream);
return ParseExample(&stream, example);
}
diff --git a/tensorflow/lite/kernels/shim/BUILD b/tensorflow/lite/kernels/shim/BUILD
index 3244635..9aa10f1 100644
--- a/tensorflow/lite/kernels/shim/BUILD
+++ b/tensorflow/lite/kernels/shim/BUILD
@@ -163,6 +163,7 @@
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:macros",
] + if_mobile([
"//tensorflow/core:portable_tensorflow_lib_lite",
]) + if_not_mobile([
diff --git a/tensorflow/lite/kernels/shim/test_op/BUILD b/tensorflow/lite/kernels/shim/test_op/BUILD
index e5703e9..af4cd02 100644
--- a/tensorflow/lite/kernels/shim/test_op/BUILD
+++ b/tensorflow/lite/kernels/shim/test_op/BUILD
@@ -48,7 +48,7 @@
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/platform:tstring",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -132,7 +132,7 @@
"//tensorflow/core/framework:tensor_testutil",
"//tensorflow/core/kernels:ops_testutil",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc b/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc
index db537b72..a37483f 100644
--- a/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc
+++ b/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc
@@ -15,6 +15,7 @@
#include <cstdint>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -22,7 +23,6 @@
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/tstring.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tflite {
namespace shim {
diff --git a/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc b/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc
index c457bcc..8e661d8 100644
--- a/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc
+++ b/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc
@@ -15,13 +15,13 @@
#include <cstdint>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tflite {
namespace shim {
diff --git a/tensorflow/lite/kernels/shim/tf_op_shim.cc b/tensorflow/lite/kernels/shim/tf_op_shim.cc
index 7d12bc8..d71cfa7 100644
--- a/tensorflow/lite/kernels/shim/tf_op_shim.cc
+++ b/tensorflow/lite/kernels/shim/tf_op_shim.cc
@@ -21,10 +21,17 @@
#include <vector>
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/lite/kernels/shim/op_kernel.h"
+#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"
#include "tensorflow/lite/kernels/shim/tf_tensor_view.h"
diff --git a/tensorflow/lite/kernels/shim/tf_op_shim.h b/tensorflow/lite/kernels/shim/tf_op_shim.h
index 834a394..8f6442b 100644
--- a/tensorflow/lite/kernels/shim/tf_op_shim.h
+++ b/tensorflow/lite/kernels/shim/tf_op_shim.h
@@ -21,13 +21,15 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/shape.h"
+#include "tsl/platform/macros.h"
// This file contains the TF adapter. That is, it takes a `OpKernelShim`
// class and provides a TF kernel out of it.
@@ -51,9 +53,9 @@
public:
explicit TfInvokeContext(::tensorflow::OpKernelContext* context);
// Read an input tensor
- ConstTensorViewOr GetInput(const int idx) const;
+ ConstTensorViewOr GetInput(int idx) const;
// Get a mutable output tensor
- TensorViewOr GetOutput(const int idx, const Shape& shape) const;
+ TensorViewOr GetOutput(int idx, const Shape& shape) const;
// Number of input tensors
int NumInputs() const;
// Number of output tensors
@@ -70,11 +72,11 @@
explicit TfShapeInferenceContext(
::tensorflow::shape_inference::InferenceContext* context);
// Read an input tensor shape
- ShapeOr GetInputShape(const int idx) const;
+ ShapeOr GetInputShape(int idx) const;
// Set an output tensor shape
- absl::Status SetOutputShape(const int idx, const Shape& shape);
+ absl::Status SetOutputShape(int idx, const Shape& shape);
// Read an input tensor during shape inference
- ConstTensorViewOr GetInputTensor(const int idx) const;
+ ConstTensorViewOr GetInputTensor(int idx) const;
// Read a given attribute
absl::StatusOr<AttrValue> GetAttr(const std::string& attr_name) const;
// Number of input tensors
diff --git a/tensorflow/lite/profiling/profile_summary_formatter_test.cc b/tensorflow/lite/profiling/profile_summary_formatter_test.cc
index 48b0697..a6ba380 100644
--- a/tensorflow/lite/profiling/profile_summary_formatter_test.cc
+++ b/tensorflow/lite/profiling/profile_summary_formatter_test.cc
@@ -14,13 +14,14 @@
==============================================================================*/
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
+#include <cstddef>
#include <fstream>
#include <ios>
#include <map>
#include <memory>
#include <string>
+#include <tuple>
-#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/match.h"
#include "tensorflow/core/util/stat_summarizer_options.h"
@@ -32,6 +33,127 @@
namespace {
+// LINT.IfChange(OpProfilingStatComparator)
+bool AreOpProfilingStatEqual(const OpProfilingStat& op_profiling_stat_1,
+ const OpProfilingStat& op_profiling_stat_2) {
+ auto proto_to_tuple = [](const OpProfilingStat& op_profiling_stat) {
+ return std::make_tuple(op_profiling_stat.first(), op_profiling_stat.last(),
+ op_profiling_stat.avg(), op_profiling_stat.stddev(),
+ op_profiling_stat.variance(),
+ op_profiling_stat.min(), op_profiling_stat.max(),
+ op_profiling_stat.sum(), op_profiling_stat.count());
+ };
+ return proto_to_tuple(op_profiling_stat_1) ==
+ proto_to_tuple(op_profiling_stat_2);
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:OpProfilingStat)
+
+// LINT.IfChange(OpProfileDataComparator)
+bool AreOpProfileDataEqual(const OpProfileData& op_profile_data_1,
+ const OpProfileData& op_profile_data_2) {
+ auto proto_to_tuple = [](const OpProfileData& op_profile_data) {
+ return std::make_tuple(op_profile_data.node_type(),
+ op_profile_data.times_called(),
+ op_profile_data.name(), op_profile_data.run_order());
+ };
+
+ return (proto_to_tuple(op_profile_data_1) ==
+ proto_to_tuple(op_profile_data_2)) &&
+ AreOpProfilingStatEqual(op_profile_data_1.inference_microseconds(),
+ op_profile_data_2.inference_microseconds()) &&
+ (AreOpProfilingStatEqual(op_profile_data_1.mem_kb(),
+ op_profile_data_2.mem_kb()));
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:OpProfileData)
+
+// LINT.IfChange(SubGraphProfilingDataComparator)
+bool AreSubGraphProfilingDataEqual(
+ const SubGraphProfilingData& subgraph_profiling_data_1,
+ const SubGraphProfilingData& subgraph_profiling_data_2) {
+ auto proto_to_tuple =
+ [](const SubGraphProfilingData& subgraph_profiling_data) {
+ return std::make_tuple(
+ subgraph_profiling_data.subgraph_name(),
+ subgraph_profiling_data.per_op_profiles().size());
+ };
+
+ if (proto_to_tuple(subgraph_profiling_data_1) ==
+ proto_to_tuple(subgraph_profiling_data_2)) {
+ for (size_t i = 0; i < subgraph_profiling_data_1.per_op_profiles().size();
+ ++i) {
+ auto op_profile_data_1 = subgraph_profiling_data_1.per_op_profiles(i);
+ auto op_profile_data_2 = subgraph_profiling_data_2.per_op_profiles(i);
+ if (!AreOpProfileDataEqual(op_profile_data_1, op_profile_data_2)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:SubGraphProfilingData)
+
+// LINT.IfChange(DelegateProfilingDataComparator)
+bool AreDelegateProfilingDataEqual(
+ const DelegateProfilingData& delegate_profiling_data_1,
+ const DelegateProfilingData& delegate_profiling_data_2) {
+ auto proto_to_tuple =
+ [](const DelegateProfilingData& delegate_profiling_data) {
+ return std::make_tuple(
+ delegate_profiling_data.delegate_name(),
+ delegate_profiling_data.per_op_profiles().size());
+ };
+
+ if (proto_to_tuple(delegate_profiling_data_1) ==
+ proto_to_tuple(delegate_profiling_data_2)) {
+ for (size_t i = 0; i < delegate_profiling_data_1.per_op_profiles().size();
+ ++i) {
+ auto op_profile_data_1 = delegate_profiling_data_1.per_op_profiles(i);
+ auto op_profile_data_2 = delegate_profiling_data_2.per_op_profiles(i);
+ if (!AreOpProfileDataEqual(op_profile_data_1, op_profile_data_2)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:DelegateProfilingData)
+
+// LINT.IfChange(ModelProfilingDataComparator)
+bool AreModelProfilingDataEqual(
+ const ModelProfilingData& model_profiling_data_1,
+ const ModelProfilingData& model_profiling_data_2) {
+ if (model_profiling_data_1.subgraph_profiles().size() !=
+ model_profiling_data_2.subgraph_profiles().size()) {
+ return false;
+ }
+ for (size_t i = 0; i < model_profiling_data_1.subgraph_profiles().size();
+ ++i) {
+ auto subgraph_profile_1 = model_profiling_data_1.subgraph_profiles(i);
+ auto subgraph_profile_2 = model_profiling_data_2.subgraph_profiles(i);
+ if (!AreSubGraphProfilingDataEqual(subgraph_profile_1,
+ subgraph_profile_2)) {
+ return false;
+ }
+ }
+ if (model_profiling_data_1.delegate_profiles().size() !=
+ model_profiling_data_2.delegate_profiles().size()) {
+ return false;
+ }
+ for (size_t i = 0; i < model_profiling_data_1.delegate_profiles().size();
+ ++i) {
+ auto delegate_profile_1 = model_profiling_data_1.delegate_profiles(i);
+ auto delegate_profile_2 = model_profiling_data_2.delegate_profiles(i);
+ if (!AreDelegateProfilingDataEqual(delegate_profile_1,
+ delegate_profile_2)) {
+ return false;
+ }
+ }
+ return true;
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:ModelProfilingData)
+
TEST(SummaryWriterTest, SummaryOptionStdOut) {
ProfileSummaryDefaultFormatter writer;
tensorflow::StatSummarizerOptions options = writer.GetStatSummarizerOptions();
@@ -182,8 +304,9 @@
op_profile_data_1.set_name(kernel_name_1);
op_profile_data_1.set_run_order(1);
op_profile_data_1.set_times_called(2);
- EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(0),
- testing::EqualsProto(op_profile_data_1));
+ EXPECT_TRUE(AreOpProfileDataEqual(
+ model_profiling_data.subgraph_profiles(0).per_op_profiles(0),
+ op_profile_data_1));
OpProfileData op_profile_data_2;
op_profile_data_2.set_node_type(op_name_2);
@@ -212,8 +335,9 @@
op_profile_data_2.set_name(kernel_name_2);
op_profile_data_2.set_run_order(2);
- EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(1),
- testing::EqualsProto(op_profile_data_2));
+ EXPECT_TRUE(AreOpProfileDataEqual(
+ model_profiling_data.subgraph_profiles(0).per_op_profiles(1),
+ op_profile_data_2));
ASSERT_EQ(model_profiling_data.subgraph_profiles(1).subgraph_name(),
"Subgraph 1");
@@ -246,8 +370,9 @@
op_profile_data_3.set_times_called(1);
op_profile_data_3.set_name(kernel_name_3);
op_profile_data_3.set_run_order(3);
- EXPECT_THAT(model_profiling_data.subgraph_profiles(1).per_op_profiles(0),
- testing::EqualsProto(op_profile_data_3));
+ EXPECT_TRUE(AreOpProfileDataEqual(
+ model_profiling_data.subgraph_profiles(1).per_op_profiles(0),
+ op_profile_data_3));
}
TEST(SummaryWriterTest, MultiSubgraphHandleOutputForProto) {
@@ -351,10 +476,10 @@
file.close();
ASSERT_TRUE(benchmark_profiling_data.model_name().empty());
- EXPECT_THAT(benchmark_profiling_data.init_profile(),
- testing::EqualsProto(model_profiling_data_init));
- EXPECT_THAT(benchmark_profiling_data.runtime_profile(),
- testing::EqualsProto(model_profiling_data_run));
+ EXPECT_TRUE(AreModelProfilingDataEqual(
+ benchmark_profiling_data.init_profile(), model_profiling_data_init));
+ EXPECT_TRUE(AreModelProfilingDataEqual(
+ benchmark_profiling_data.runtime_profile(), model_profiling_data_run));
}
TEST(SummaryWriterTest, MultiSubgraphShortSummary) {
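The comparator helpers introduced above replace `testing::EqualsProto` with plain field-by-field checks: each helper projects the proto's scalar fields into a `std::tuple` and relies on the tuple's element-wise `operator==`, recursing by hand into nested messages. A self-contained sketch of the same idiom using a hypothetical stand-in struct (not the generated proto classes from profiling_info.proto):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>

// Hypothetical stand-in for a generated proto message, for illustration only.
struct FakeOpProfile {
  std::string node_type;
  int64_t times_called;
  int64_t run_order;
};

// Same pattern as AreOpProfileDataEqual above: pack the fields of interest
// into a std::tuple and compare lexicographically. If a field is added to the
// message, it must also be added here, which is the drift the
// LINT.IfChange/LINT.ThenChange markers are meant to catch.
bool AreFakeOpProfilesEqual(const FakeOpProfile& a, const FakeOpProfile& b) {
  auto to_tuple = [](const FakeOpProfile& p) {
    return std::make_tuple(p.node_type, p.times_called, p.run_order);
  };
  return to_tuple(a) == to_tuple(b);
}

int main() {
  FakeOpProfile x{"CONV_2D", 2, 1};
  FakeOpProfile y{"CONV_2D", 2, 1};
  FakeOpProfile z{"ADD", 2, 1};
  std::cout << AreFakeOpProfilesEqual(x, y) << "\n";  // 1
  std::cout << AreFakeOpProfilesEqual(x, z) << "\n";  // 0
  return 0;
}
```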
diff --git a/tensorflow/lite/profiling/proto/profiling_info.proto b/tensorflow/lite/profiling/proto/profiling_info.proto
index 8116524..5d33571 100644
--- a/tensorflow/lite/profiling/proto/profiling_info.proto
+++ b/tensorflow/lite/profiling/proto/profiling_info.proto
@@ -25,22 +25,27 @@
optional ModelProfilingData runtime_profile = 3;
}
+// LINT.IfChange(ModelProfilingData)
message ModelProfilingData {
repeated SubGraphProfilingData subgraph_profiles = 1;
repeated DelegateProfilingData delegate_profiles = 2;
}
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:ModelProfilingDataComparator)
+// LINT.IfChange(SubGraphProfilingData)
message SubGraphProfilingData {
optional string subgraph_name = 1;
optional int32 subgraph_index = 2;
repeated OpProfileData per_op_profiles = 3;
}
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:SubGraphProfilingDataComparator)
message DelegateProfilingData {
optional string delegate_name = 1;
repeated OpProfileData per_op_profiles = 2;
}
+// LINT.IfChange(OpProfilingStat)
message OpProfilingStat {
optional int64 first = 1;
optional int64 last = 2;
@@ -52,7 +57,9 @@
optional int64 sum = 8;
optional int64 count = 9;
}
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:OpProfilingStatComparator)
+// LINT.IfChange(OpProfileData)
message OpProfileData {
optional string node_type = 1;
optional OpProfilingStat inference_microseconds = 2;
@@ -61,3 +68,4 @@
optional string name = 5;
optional int64 run_order = 6;
}
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:OpProfileDataComparator)
diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD
index 7bf0f18..e064789 100644
--- a/tensorflow/lite/schema/BUILD
+++ b/tensorflow/lite/schema/BUILD
@@ -48,10 +48,10 @@
"upgrade_schema.py",
],
data = [
- "schema_v0.fbs",
- "schema_v1.fbs",
- "schema_v2.fbs",
- "schema_v3.fbs",
+ "//tensorflow/compiler/mlir/lite/schema:schema_v0.fbs",
+ "//tensorflow/compiler/mlir/lite/schema:schema_v1.fbs",
+ "//tensorflow/compiler/mlir/lite/schema:schema_v2.fbs",
+ "//tensorflow/compiler/mlir/lite/schema:schema_v3.fbs",
"@flatbuffers//:flatc",
],
srcs_version = "PY3",
@@ -103,13 +103,6 @@
exports_files([
"conversion_metadata.fbs",
- "schema.fbs",
- "schema_v0.fbs",
- "schema_v1.fbs",
- "schema_v2.fbs",
- "schema_v3.fbs",
- "schema_v3a.fbs",
- "schema_v3b.fbs",
])
flatbuffer_cc_library(
diff --git a/tensorflow/lite/schema/schema_v3b.fbs b/tensorflow/lite/schema/schema_v3b.fbs
deleted file mode 100644
index 9177860..0000000
--- a/tensorflow/lite/schema/schema_v3b.fbs
+++ /dev/null
@@ -1,1242 +0,0 @@
-// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Revision History
-// Version 0: Initial version.
-// Version 1: Add subgraphs to schema.
-// Version 2: Rename operators to conform to NN API.
-// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
-// Version 3a: Add new builtin op code field. Has backward compatibility with
-// version 3.
-// Version 3b: Rename fields in SignatureDef. Has backward compatibility with
-// version 3 and 3a.
-
-namespace tflite;
-
-// This corresponds to the version.
-file_identifier "TFL3";
-// File extension of any written files.
-file_extension "tflite";
-
-// IMPORTANT: All new members of tables, enums and unions must be added at the
-// end to ensure backwards compatibility.
-
-// The type of data stored in a tensor.
-enum TensorType : byte {
- FLOAT32 = 0,
- FLOAT16 = 1,
- INT32 = 2,
- UINT8 = 3,
- INT64 = 4,
- STRING = 5,
- BOOL = 6,
- INT16 = 7,
- COMPLEX64 = 8,
- INT8 = 9,
- FLOAT64 = 10,
- COMPLEX128 = 11,
- UINT64 = 12,
- // Experimental: Resource and variant types are experimental and subject to
- // change. Do not implement custom kernels using resource & variant types
- // now.
- RESOURCE = 13,
- VARIANT = 14,
- UINT32 = 15,
-}
-
-// Custom quantization parameters for experimenting with new quantization
-// techniques.
-table CustomQuantization {
- custom:[ubyte] (force_align: 16);
-}
-
-// Represents a specific quantization technique's parameters.
-union QuantizationDetails {
- CustomQuantization,
-}
-
-// Parameters for converting a quantized tensor back to float.
-table QuantizationParameters {
- // These four parameters are the asymmetric linear quantization parameters.
- // Given a quantized value q, the corresponding float value f should be:
- // f = scale * (q - zero_point)
- // For other quantization types, the QuantizationDetails below is used.
- min:[float]; // For importing back into tensorflow.
- max:[float]; // For importing back into tensorflow.
- scale:[float]; // For dequantizing the tensor's values.
- zero_point:[long];
-
- // If this is not none, the other quantization parameters (i.e. min, max,
- // scale, zero_point fields above) are ignored and the value of the
- // QuantizationDetails union should be used.
- details:QuantizationDetails;
-
- // Specifies the dimension of the Tensor's shape that the scales and
- // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
- // with quantization params:
- // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1
- // will be quantized across the second dimension of t.
- // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
- // t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
- // t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
- quantized_dimension:int;
-}
-
-// Sparse tensors.
-// We use a modification of the TACO format.
-// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf
-//
-// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1),
-// potentially with a k-dimensional block (0 <= k <= n) with dims
-// (dn, ..., dn+k-1), the format needs to specify:
-// 1. In what order to traverse these dimensions. For example, to store a 2-D
-// matrix in row major order, the traversal order would be (d0, d1),
-// whereas to store it in column major order, the traversal order would be
-// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order
-// could be (d0, d1, d2, d3).
-// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original
-// tensor dimension in (d0, ..., dn-1).
-// 3. In the traversal order defined above, the format (dense vs. sparse) and
-// index metadata for each dimension. For a dense dimension, this is just
-// the size of that dimension. For a sparse dimension, it's the same as
-// the compressed index defined in the Compressed Sparse Row (CSR) format.
-// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html)
-
-// The storage type for a dimension. Currently we support:
-// 1. DENSE: each coordinate in this dimension is stored implicitly.
-// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The
-// compression technique is the same as the one CSR uses.
-// More types like a sparse dimension with a different compression technique
-// could be added to the list in the future.
-enum DimensionType : byte {
- DENSE = 0,
- SPARSE_CSR = 1,
-}
-
-table Int32Vector {
- values:[int];
-}
-
-table Uint16Vector {
- values:[ushort] (force_align: 4);
-}
-
-table Uint8Vector {
- values:[ubyte] (force_align: 4);
-}
-
-// Variable-typed buffer to store the index metadata for a sparse dimension.
-// The widest type is Int32 instead of UInt32 because a tensor's shape is an int32
-// vector. We don't want the per-dimensional index to overflow that range.
-union SparseIndexVector {
- Int32Vector,
- Uint16Vector,
- Uint8Vector
-}
-
-table DimensionMetadata {
- // Whether a dimension is dense or sparse.
- format:DimensionType;
- // Index metadata used for a dimension.
- // - If format is DimensionType.DENSE then we use the dense_size field to
- // store the size of that dimension. Each index in that dimension is
- // stored implicitly.
- // - If format is DimensionType.SPARSE_CSR then we use array_segments and
- // array_indices to encode that dimension. array_segments represents how
- // to segment the indices array, each segment corresponds to one element
- // in the previous dimension. array_indices represents the index of the
- // non-zero elements within this dimension (as those in the CSR matrix
- // format, where the first array is row pointers and the second array is
- // column indices).
- dense_size:int;
- array_segments:SparseIndexVector;
- array_indices:SparseIndexVector;
-}
-
-// Parameters to encode a sparse TfLite tensor.
-table SparsityParameters {
- // The traversal order of the dimensions defined in the `shape` field of the
- // conceptual dense tensor. For an n-dimensional tensor with dims (d0, d1,
- // ..., dn-1),
- // - if not block sparse, the traversal_order is just a permutation of (d0,
- // ..., dn-1). For example, a 2-D matrix stored in row-major order would
- // have traversal_order = (d0, d1).
- // - if block sparse with a k-dimensional block (0 <= k <= n), the
- // traversal_order has n + k elements. The first n elements are still a
- // permutation of (d0, ..., dn-1). The last k elements are a permutation
- // of (dn, ..., dn+k-1), defining how to traverse a block internally. For
- // example, a 2-D matrix with 2-D blocks, both stored in row-major order
- // would have traversal_order = (d0, d1, d2, d3).
- traversal_order:[int];
- // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
- // stores how a block dimension in (dn, ..., dn+k-1) maps to the original
- // tensor dimension in (d0, ..., dn).
- // It's stored in the order of (dn, ..., dn+k-1).
- // If not block-sparse, this field is NULL.
- block_map:[int];
- // In the traversal order defined above, the metadata needed for
- // each dimension to locate the non-zero values in the original dense tensor.
- // The size of the dim_metadata array = the size of the traversal_order array
- // = n + k.
- dim_metadata:[DimensionMetadata];
-}
-
-table Tensor {
- // The tensor shape. The meaning of each entry is operator-specific but
- // builtin ops use: [batch size, height, width, number of channels] (That's
- // Tensorflow's NHWC).
- shape:[int];
- type:TensorType;
- // An index that refers to the buffers table at the root of the model. Or,
- // if there is no data buffer associated (i.e. intermediate results), then
- // this is 0 (which refers to an always existent empty buffer).
- //
- // The data_buffer itself is an opaque container, with the assumption that the
- // target device is little-endian. In addition, all builtin operators assume
- // the memory is ordered such that if `shape` is [4, 3, 2], then index
- // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
- buffer:uint;
- name:string; // For debugging and importing back into tensorflow.
- quantization:QuantizationParameters; // Optional.
-
- is_variable:bool = false;
-
- // Parameters to encode a sparse tensor. See the example in
- // tensorflow/lite/testdata/sparse_tensor.json.
- sparsity:SparsityParameters; // Optional.
-
- // Encodes `shape` with unknown dimensions. Unknown dimensions are
- // represented with -1.
- shape_signature:[int]; // Optional.
-}
-
-// A list of builtin operators. Builtin operators are slightly faster than custom
-// ones, but not by much. Moreover, while custom operators accept an opaque
-// object containing configuration parameters, builtins have a predetermined
-// set of acceptable options.
-// LINT.IfChange
-enum BuiltinOperator : int32 {
- ADD = 0,
- AVERAGE_POOL_2D = 1,
- CONCATENATION = 2,
- CONV_2D = 3,
- DEPTHWISE_CONV_2D = 4,
- DEPTH_TO_SPACE = 5,
- DEQUANTIZE = 6,
- EMBEDDING_LOOKUP = 7,
- FLOOR = 8,
- FULLY_CONNECTED = 9,
- HASHTABLE_LOOKUP = 10,
- L2_NORMALIZATION = 11,
- L2_POOL_2D = 12,
- LOCAL_RESPONSE_NORMALIZATION = 13,
- LOGISTIC = 14,
- LSH_PROJECTION = 15,
- LSTM = 16,
- MAX_POOL_2D = 17,
- MUL = 18,
- RELU = 19,
- // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed
- // since different model developers use RELU1 in different ways. Never
- // create another op called RELU1.
- RELU_N1_TO_1 = 20,
- RELU6 = 21,
- RESHAPE = 22,
- RESIZE_BILINEAR = 23,
- RNN = 24,
- SOFTMAX = 25,
- SPACE_TO_DEPTH = 26,
- SVDF = 27,
- TANH = 28,
- CONCAT_EMBEDDINGS = 29,
- SKIP_GRAM = 30,
- CALL = 31,
- CUSTOM = 32,
- EMBEDDING_LOOKUP_SPARSE = 33,
- PAD = 34,
- UNIDIRECTIONAL_SEQUENCE_RNN = 35,
- GATHER = 36,
- BATCH_TO_SPACE_ND = 37,
- SPACE_TO_BATCH_ND = 38,
- TRANSPOSE = 39,
- MEAN = 40,
- SUB = 41,
- DIV = 42,
- SQUEEZE = 43,
- UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
- STRIDED_SLICE = 45,
- BIDIRECTIONAL_SEQUENCE_RNN = 46,
- EXP = 47,
- TOPK_V2 = 48,
- SPLIT = 49,
- LOG_SOFTMAX = 50,
- // DELEGATE is a special op type for the operations which are delegated to
- // other backends.
- // WARNING: Experimental interface, subject to change
- DELEGATE = 51,
- BIDIRECTIONAL_SEQUENCE_LSTM = 52,
- CAST = 53,
- PRELU = 54,
- MAXIMUM = 55,
- ARG_MAX = 56,
- MINIMUM = 57,
- LESS = 58,
- NEG = 59,
- PADV2 = 60,
- GREATER = 61,
- GREATER_EQUAL = 62,
- LESS_EQUAL = 63,
- SELECT = 64,
- SLICE = 65,
- SIN = 66,
- TRANSPOSE_CONV = 67,
- SPARSE_TO_DENSE = 68,
- TILE = 69,
- EXPAND_DIMS = 70,
- EQUAL = 71,
- NOT_EQUAL = 72,
- LOG = 73,
- SUM = 74,
- SQRT = 75,
- RSQRT = 76,
- SHAPE = 77,
- POW = 78,
- ARG_MIN = 79,
- FAKE_QUANT = 80,
- REDUCE_PROD = 81,
- REDUCE_MAX = 82,
- PACK = 83,
- LOGICAL_OR = 84,
- ONE_HOT = 85,
- LOGICAL_AND = 86,
- LOGICAL_NOT = 87,
- UNPACK = 88,
- REDUCE_MIN = 89,
- FLOOR_DIV = 90,
- REDUCE_ANY = 91,
- SQUARE = 92,
- ZEROS_LIKE = 93,
- FILL = 94,
- FLOOR_MOD = 95,
- RANGE = 96,
- RESIZE_NEAREST_NEIGHBOR = 97,
- LEAKY_RELU = 98,
- SQUARED_DIFFERENCE = 99,
- MIRROR_PAD = 100,
- ABS = 101,
- SPLIT_V = 102,
- UNIQUE = 103,
- CEIL = 104,
- REVERSE_V2 = 105,
- ADD_N = 106,
- GATHER_ND = 107,
- COS = 108,
- WHERE = 109,
- RANK = 110,
- ELU = 111,
- REVERSE_SEQUENCE = 112,
- MATRIX_DIAG = 113,
- QUANTIZE = 114,
- MATRIX_SET_DIAG = 115,
- ROUND = 116,
- HARD_SWISH = 117,
- IF = 118,
- WHILE = 119,
- NON_MAX_SUPPRESSION_V4 = 120,
- NON_MAX_SUPPRESSION_V5 = 121,
- SCATTER_ND = 122,
- SELECT_V2 = 123,
- DENSIFY = 124,
- SEGMENT_SUM = 125,
- BATCH_MATMUL = 126,
- PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
- CUMSUM = 128,
- CALL_ONCE = 129,
- BROADCAST_TO = 130,
- RFFT2D = 131,
- CONV_3D = 132,
- IMAG=133,
- REAL=134,
- COMPLEX_ABS=135,
- HASHTABLE = 136,
- HASHTABLE_FIND = 137,
- HASHTABLE_IMPORT = 138,
- HASHTABLE_SIZE = 139,
- REDUCE_ALL = 140,
- CONV_3D_TRANSPOSE = 141,
- VAR_HANDLE = 142,
- READ_VARIABLE = 143,
- ASSIGN_VARIABLE = 144,
-}
-// LINT.ThenChange(nnapi_linter/linter.proto)
-
-// Options for the builtin operators.
-union BuiltinOptions {
- Conv2DOptions,
- DepthwiseConv2DOptions,
- ConcatEmbeddingsOptions,
- LSHProjectionOptions,
- Pool2DOptions,
- SVDFOptions,
- RNNOptions,
- FullyConnectedOptions,
- SoftmaxOptions,
- ConcatenationOptions,
- AddOptions,
- L2NormOptions,
- LocalResponseNormalizationOptions,
- LSTMOptions,
- ResizeBilinearOptions,
- CallOptions,
- ReshapeOptions,
- SkipGramOptions,
- SpaceToDepthOptions,
- EmbeddingLookupSparseOptions,
- MulOptions,
- PadOptions,
- GatherOptions,
- BatchToSpaceNDOptions,
- SpaceToBatchNDOptions,
- TransposeOptions,
- ReducerOptions,
- SubOptions,
- DivOptions,
- SqueezeOptions,
- SequenceRNNOptions,
- StridedSliceOptions,
- ExpOptions,
- TopKV2Options,
- SplitOptions,
- LogSoftmaxOptions,
- CastOptions,
- DequantizeOptions,
- MaximumMinimumOptions,
- ArgMaxOptions,
- LessOptions,
- NegOptions,
- PadV2Options,
- GreaterOptions,
- GreaterEqualOptions,
- LessEqualOptions,
- SelectOptions,
- SliceOptions,
- TransposeConvOptions,
- SparseToDenseOptions,
- TileOptions,
- ExpandDimsOptions,
- EqualOptions,
- NotEqualOptions,
- ShapeOptions,
- PowOptions,
- ArgMinOptions,
- FakeQuantOptions,
- PackOptions,
- LogicalOrOptions,
- OneHotOptions,
- LogicalAndOptions,
- LogicalNotOptions,
- UnpackOptions,
- FloorDivOptions,
- SquareOptions,
- ZerosLikeOptions,
- FillOptions,
- BidirectionalSequenceLSTMOptions,
- BidirectionalSequenceRNNOptions,
- UnidirectionalSequenceLSTMOptions,
- FloorModOptions,
- RangeOptions,
- ResizeNearestNeighborOptions,
- LeakyReluOptions,
- SquaredDifferenceOptions,
- MirrorPadOptions,
- AbsOptions,
- SplitVOptions,
- UniqueOptions,
- ReverseV2Options,
- AddNOptions,
- GatherNdOptions,
- CosOptions,
- WhereOptions,
- RankOptions,
- ReverseSequenceOptions,
- MatrixDiagOptions,
- QuantizeOptions,
- MatrixSetDiagOptions,
- HardSwishOptions,
- IfOptions,
- WhileOptions,
- DepthToSpaceOptions,
- NonMaxSuppressionV4Options,
- NonMaxSuppressionV5Options,
- ScatterNdOptions,
- SelectV2Options,
- DensifyOptions,
- SegmentSumOptions,
- BatchMatMulOptions,
- CumsumOptions,
- CallOnceOptions,
- BroadcastToOptions,
- Rfft2dOptions,
- Conv3DOptions,
- HashtableOptions,
- HashtableFindOptions,
- HashtableImportOptions,
- HashtableSizeOptions,
- VarHandleOptions,
- ReadVariableOptions,
- AssignVariableOptions,
-}
-
-enum Padding : byte { SAME, VALID }
-
-enum ActivationFunctionType : byte {
- NONE = 0,
- RELU = 1,
- RELU_N1_TO_1 = 2,
- RELU6 = 3,
- TANH = 4,
- SIGN_BIT = 5,
-}
-
-table Conv2DOptions {
- padding:Padding;
- stride_w:int;
- stride_h:int;
- fused_activation_function:ActivationFunctionType;
- dilation_w_factor:int = 1;
- dilation_h_factor:int = 1;
-}
-
-// Options for both Conv3D and Conv3DTranspose.
-table Conv3DOptions {
- padding:Padding;
- stride_d:int;
- stride_w:int;
- stride_h:int;
- fused_activation_function:ActivationFunctionType;
- dilation_d_factor:int = 1;
- dilation_w_factor:int = 1;
- dilation_h_factor:int = 1;
-}
-
-table Pool2DOptions {
- padding:Padding;
- stride_w:int;
- stride_h:int;
- filter_width:int;
- filter_height:int;
- fused_activation_function:ActivationFunctionType;
-}
-
-table DepthwiseConv2DOptions {
- // Parameters for DepthwiseConv version 1 or above.
- padding:Padding;
- stride_w:int;
- stride_h:int;
- // `depth_multiplier` is redundant. It's used by CPU kernels in
- // TensorFlow 2.0 or below, but ignored in versions above.
- // See comments in lite/c/builtin_op_data.h for more details.
- depth_multiplier:int;
- fused_activation_function:ActivationFunctionType;
- // Parameters for DepthwiseConv version 2 or above.
- dilation_w_factor:int = 1;
- dilation_h_factor:int = 1;
-}
-
-table ConcatEmbeddingsOptions {
- num_channels:int;
- num_columns_per_channel:[int];
- embedding_dim_per_channel:[int]; // This could be inferred from parameters.
-}
-
-enum LSHProjectionType: byte {
- UNKNOWN = 0,
- SPARSE = 1,
- DENSE = 2,
-}
-
-table LSHProjectionOptions {
- type: LSHProjectionType;
-}
-
-table SVDFOptions {
- rank:int;
- fused_activation_function:ActivationFunctionType;
- // For weights-only quantization, use asymmetric quantization for non
- // constant inputs at evaluation time.
- asymmetric_quantize_inputs:bool;
-}
-
-// An implementation of TensorFlow RNNCell.
-table RNNOptions {
- fused_activation_function:ActivationFunctionType;
- asymmetric_quantize_inputs:bool;
-}
-
-// An implementation of TensorFlow dynamic_rnn with RNNCell.
-table SequenceRNNOptions {
- time_major:bool;
- fused_activation_function:ActivationFunctionType;
- asymmetric_quantize_inputs:bool;
-}
-
-// An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell.
-table BidirectionalSequenceRNNOptions {
- time_major:bool;
- fused_activation_function:ActivationFunctionType;
- merge_outputs: bool;
- asymmetric_quantize_inputs:bool;
-}
-
-enum FullyConnectedOptionsWeightsFormat: byte {
- DEFAULT = 0,
- SHUFFLED4x16INT8 = 1,
-}
-
-// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
-table FullyConnectedOptions {
- // Parameters for FullyConnected version 1 or above.
- fused_activation_function:ActivationFunctionType;
-
- // Parameters for FullyConnected version 2 or above.
- weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
-
- // Parameters for FullyConnected version 5 or above.
- // If set to true, then the number of dimensions is preserved. Furthermore,
- // all but the last dimension of the input and output shapes will be equal.
- keep_num_dims: bool;
-
- // Parameters for FullyConnected version 7 or above.
- // If set to true, then weights-only op will use asymmetric quantization for
- // inputs.
- asymmetric_quantize_inputs: bool;
-}
-
-table SoftmaxOptions {
- beta: float;
-}
-
-// An implementation of TensorFlow concat.
-table ConcatenationOptions {
- axis:int;
- fused_activation_function:ActivationFunctionType;
-}
-
-table AddOptions {
- fused_activation_function:ActivationFunctionType;
- // Parameters supported by version 3.
- pot_scale_int16:bool = true;
-}
-
-table MulOptions {
- fused_activation_function:ActivationFunctionType;
-}
-
-table L2NormOptions {
- // This field is currently ignored in the L2 Norm Op.
- fused_activation_function:ActivationFunctionType;
-}
-
-table LocalResponseNormalizationOptions {
- radius:int;
- bias:float;
- alpha:float;
- beta:float;
-}
-
-enum LSTMKernelType : byte {
- // Full LSTM kernel which supports peephole and projection.
- FULL = 0,
- // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
- BASIC = 1,
-}
-
-// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
-table LSTMOptions {
- // Parameters for LSTM version 1 or above.
- fused_activation_function:ActivationFunctionType;
- cell_clip: float; // Optional, 0.0 means no clipping
- proj_clip: float; // Optional, 0.0 means no clipping
-
- // Parameters for LSTM version 2 or above.
- // Basic kernel is only supported in version 2 or above.
- kernel_type: LSTMKernelType = FULL;
-
- // Parameters for LSTM version 4 or above.
- asymmetric_quantize_inputs: bool;
-}
-
-// An implementation of TensorFlow dynamic_rnn with LSTMCell.
-table UnidirectionalSequenceLSTMOptions {
- fused_activation_function:ActivationFunctionType;
- cell_clip: float; // Optional, 0.0 means no clipping
- proj_clip: float; // Optional, 0.0 means no clipping
-
- // If true then first dimension is sequence, otherwise batch.
- time_major:bool;
-
- // Parameter for Unidirectional Sequence LSTM version 4.
- asymmetric_quantize_inputs:bool;
-}
-
-table BidirectionalSequenceLSTMOptions {
- // Parameters supported by version 1:
- fused_activation_function:ActivationFunctionType;
- cell_clip: float; // Optional, 0.0 means no clipping
- proj_clip: float; // Optional, 0.0 means no clipping
-
- // If true, store the outputs of both directions into the first output.
- merge_outputs: bool;
-
- // Parameters supported by version 2:
- // If true then first dimension is sequence, otherwise batch.
- // Version 1 implementations assumed time_major to be true, so this default
- // value should never change.
- time_major: bool = true;
-
- // Parameters for version 3 or above.
- asymmetric_quantize_inputs:bool;
-}
-
-table ResizeBilinearOptions {
- new_height: int (deprecated);
- new_width: int (deprecated);
- align_corners: bool;
- half_pixel_centers: bool;
-}
-
-table ResizeNearestNeighborOptions {
- align_corners: bool;
- half_pixel_centers: bool;
-}
-
-// A call operation options
-table CallOptions {
- // The subgraph index that needs to be called.
- subgraph:uint;
-}
-
-table PadOptions {
-}
-
-table PadV2Options {
-}
-
-table ReshapeOptions {
- new_shape:[int];
-}
-
-table SpaceToBatchNDOptions {
-}
-
-table BatchToSpaceNDOptions {
-}
-
-table SkipGramOptions {
- ngram_size: int;
- max_skip_size: int;
- include_all_ngrams: bool;
-}
-
-table SpaceToDepthOptions {
- block_size: int;
-}
-
-table DepthToSpaceOptions {
- block_size: int;
-}
-
-table SubOptions {
- fused_activation_function:ActivationFunctionType;
- // Parameters supported by version 5
- pot_scale_int16:bool = true;
-}
-
-table DivOptions {
- fused_activation_function:ActivationFunctionType;
-}
-
-table TopKV2Options {
-}
-
-enum CombinerType : byte {
- SUM = 0,
- MEAN = 1,
- SQRTN = 2,
-}
-
-table EmbeddingLookupSparseOptions {
- combiner:CombinerType;
-}
-
-table GatherOptions {
- axis: int;
- // Parameters for Gather version 5 or above.
- batch_dims: int = 0;
-}
-
-table TransposeOptions {
-}
-
-table ExpOptions {
-}
-
-table CosOptions {
-}
-
-table ReducerOptions {
- keep_dims: bool;
-}
-
-table SqueezeOptions {
- squeeze_dims:[int];
-}
-
-table SplitOptions {
- num_splits: int;
-}
-
-table SplitVOptions {
- num_splits: int;
-}
-
-table StridedSliceOptions {
- begin_mask: int;
- end_mask: int;
- ellipsis_mask: int;
- new_axis_mask: int;
- shrink_axis_mask: int;
-}
-
-table LogSoftmaxOptions {
-}
-
-table CastOptions {
- in_data_type: TensorType;
- out_data_type: TensorType;
-}
-
-table DequantizeOptions {
-}
-
-table MaximumMinimumOptions {
-}
-
-table TileOptions {
-}
-
-table ArgMaxOptions {
- output_type : TensorType;
-}
-
-table ArgMinOptions {
- output_type : TensorType;
-}
-
-table GreaterOptions {
-}
-
-table GreaterEqualOptions {
-}
-
-table LessOptions {
-}
-
-table LessEqualOptions {
-}
-
-table NegOptions {
-}
-
-table SelectOptions {
-}
-
-table SliceOptions {
-}
-
-table TransposeConvOptions {
- padding:Padding;
- stride_w:int;
- stride_h:int;
-}
-
-table ExpandDimsOptions {
-}
-
-table SparseToDenseOptions {
- validate_indices:bool;
-}
-
-table EqualOptions {
-}
-
-table NotEqualOptions {
-}
-
-table ShapeOptions {
- // Optional output type of the operation (int32 or int64). Defaults to int32.
- out_type : TensorType;
-}
-
-table RankOptions {
-}
-
-table PowOptions {
-}
-
-table FakeQuantOptions {
- // Parameters supported by version 1:
- min:float;
- max:float;
- num_bits:int;
-
- // Parameters supported by version 2:
- narrow_range:bool;
-}
-
-table PackOptions {
- values_count:int;
- axis:int;
-}
-
-table LogicalOrOptions {
-}
-
-table OneHotOptions {
- axis:int;
-}
-
-table AbsOptions {
-}
-
-
-table HardSwishOptions {
-}
-
-table LogicalAndOptions {
-}
-
-table LogicalNotOptions {
-}
-
-table UnpackOptions {
- num:int;
- axis:int;
-}
-
-table FloorDivOptions {
-}
-
-table SquareOptions {
-}
-
-table ZerosLikeOptions {
-}
-
-table FillOptions {
-}
-
-table FloorModOptions {
-}
-
-table RangeOptions {
-}
-
-table LeakyReluOptions {
- alpha:float;
-}
-
-table SquaredDifferenceOptions {
-}
-
-enum MirrorPadMode : byte {
- // Doesn't include borders.
- REFLECT = 0,
- // Includes borders.
- SYMMETRIC = 1,
-}
-
-table MirrorPadOptions {
- mode:MirrorPadMode;
-}
-
-table UniqueOptions {
- idx_out_type:TensorType = INT32;
-}
-
-table ReverseV2Options {
-}
-
-table AddNOptions {
-}
-
-table GatherNdOptions {
-}
-
-table WhereOptions {
-}
-
-table ReverseSequenceOptions {
- seq_dim:int;
- batch_dim:int = 0;
-}
-
-table MatrixDiagOptions {
-}
-
-table QuantizeOptions {
-}
-
-table MatrixSetDiagOptions {
-}
-
-table IfOptions {
- then_subgraph_index:int;
- else_subgraph_index:int;
-}
-
-table CallOnceOptions {
- init_subgraph_index:int;
-}
-
-table WhileOptions {
- cond_subgraph_index:int;
- body_subgraph_index:int;
-}
-
-table NonMaxSuppressionV4Options {
-}
-
-table NonMaxSuppressionV5Options {
-}
-
-table ScatterNdOptions {
-}
-
-table SelectV2Options {
-}
-
-table DensifyOptions {
-}
-
-table SegmentSumOptions {
-}
-
-table BatchMatMulOptions {
- adj_x:bool;
- adj_y:bool;
- // Parameters for BatchMatMul version 4 or above.
- // If set to true, then weights-only op will use asymmetric quantization for
- // inputs.
- asymmetric_quantize_inputs: bool;
-}
-
-table CumsumOptions {
- exclusive:bool;
- reverse:bool;
-}
-
-table BroadcastToOptions {
-}
-
-table Rfft2dOptions {
-}
-
-table HashtableOptions {
- // The identity of hash tables. This identity will be used across different
- // subgraphs in the same interpreter instance.
- table_id:int;
- key_dtype:TensorType;
- value_dtype:TensorType;
-}
-
-table HashtableFindOptions {
-}
-
-table HashtableImportOptions {
-}
-
-table HashtableSizeOptions {
-}
-
-table VarHandleOptions {
- container:string;
- shared_name:string;
-}
-
-table ReadVariableOptions {
-}
-
-table AssignVariableOptions {
-}
-
-// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
-// builtin, or a string if the operator is custom.
-table OperatorCode {
- // This field is for backward compatibility. This field will be used when
- // the value of the extended builtin_code field is less than
- // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
- deprecated_builtin_code:byte;
- custom_code:string;
-
- // The version of the operator. The version needs to be bumped whenever new
- // parameters are introduced into an op.
- version:int = 1;
-
- // This field is introduced for resolving the op builtin code shortage problem
- // (the original BuiltinOperator enum field was represented as a byte).
- // This field will be used when the value of the extended builtin_code field
- // is greater than BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
- builtin_code:BuiltinOperator;
-}
-
-enum CustomOptionsFormat : byte {
- FLEXBUFFERS = 0,
-}
-
-// An operator takes tensors as inputs and outputs. The type of operation being
-// performed is determined by an index into the list of valid OperatorCodes,
-// while the specifics of each operation are configured using builtin_options
-// or custom_options.
-table Operator {
- // Index into the operator_codes array. Using an integer here avoids
- // complicated map lookups.
- opcode_index:uint;
-
- // Optional input are indicated by -1.
- inputs:[int];
- outputs:[int];
-
- builtin_options:BuiltinOptions;
- custom_options:[ubyte];
- custom_options_format:CustomOptionsFormat;
-
- // A list of booleans indicating the input tensors which are being mutated by
- // this operator (e.g. used by RNN and LSTM).
- // For example, if the "inputs" array refers to 5 tensors and the second and
- // fifth are mutable variables, then this list will contain
- // [false, true, false, false, true].
- //
- // If the list is empty, no variable is mutated in this operator.
- // The list either has the same length as `inputs`, or is empty.
- mutating_variable_inputs:[bool];
-
- // A list of indices to the subgraph's "tensors" that are internal to an Op.
- // Internal tensors are those that do not flow in or out of the operation,
- // but instead are part of internal computation. As such, the operation's
- // implementation may manage its memory more efficiently. They are needed
- // however (i.e. not just an implementation detail) since they are part of the
- // computation, which may require relevant metadata such as quantization
- // parameters.
- intermediates:[int];
-}
-
-// The root type, defining a subgraph, which typically represents an entire
-// model.
-table SubGraph {
- // A list of all tensors used in this subgraph.
- tensors:[Tensor];
-
- // Indices of the tensors that are inputs into this subgraph. Note this is
- // the list of non-static tensors that feed into the subgraph for inference.
- inputs:[int];
-
- // Indices of the tensors that are outputs out of this subgraph. Note this is
- // the list of output tensors that are considered the product of the
- // subgraph's inference.
- outputs:[int];
-
- // All operators, in execution order.
- operators:[Operator];
-
- // Name of this subgraph (used for debugging).
- name:string;
-}
-
-// Table of raw data buffers (used for constant tensors). Referenced by tensors
-// by index. The generous alignment accommodates mmap-friendly data structures.
-table Buffer {
- data:[ubyte] (force_align: 16);
-}
-
-table Metadata {
- // A human readable string to uniquely identify a Metadata.
- name:string;
- // An index to the buffers table.
- buffer:uint;
-}
-
-// Map from an alias name of tensor to tensor index in the graph.
-// This is used in Signature def.
-table TensorMap {
- // Represents the alias to use for this tensor.
- name:string;
-
- // The actual tensor index in the primary graph, that 'name' corresponds to.
- tensor_index:uint;
-}
-
-// This corresponds to SignatureDef in Tensorflow SavedModel.
-// The SignatureDef will be part of the SavedModel provided for conversion.
-table SignatureDef {
- // Named inputs for this signature.
- inputs:[TensorMap];
-
- // Named outputs for this signature.
- outputs:[TensorMap];
-
- // Key value which was in the Tensorflow SavedModel SignatureDef map.
- signature_key:string;
-
- // Model tag, deprecated.
- deprecated_tag:string (deprecated);
-
- // Index of the subgraph that corresponds to the exported method.
- subgraph_index:uint;
-}
-
-table Model {
- // Version of the schema.
- version:uint;
-
- // A list of all operator codes used in this model. This is
- // kept in order because operators carry an index into this
- // vector.
- operator_codes:[OperatorCode];
-
- // All the subgraphs of the model. The 0th is assumed to be the main
- // model.
- subgraphs:[SubGraph];
-
- // A description of the model.
- description:string;
-
- // Buffers of the model.
- // Note the 0th entry of this array must be an empty buffer (sentinel).
- // This is a convention so that tensors without a buffer can provide 0 as
- // their buffer.
- buffers:[Buffer];
-
- // Metadata about the model. Indirects into the existing buffers list.
- // Deprecated, prefer to use metadata field.
- metadata_buffer:[int];
-
- // Metadata about the model.
- metadata:[Metadata];
-
- // Optional SignatureDefs for the model.
- signature_defs:[SignatureDef];
-}
-
-root_type Model;
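The QuantizationParameters comments in the schema deleted above define the affine dequantization rule f = scale * (q - zero_point), with per-axis parameters selected along quantized_dimension. A small worked sketch of that arithmetic, using the scale/zero_point values from the schema's own example (the code itself is hypothetical, not a TensorFlow Lite API):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Affine dequantization as described in the schema comment:
//   f = scale * (q - zero_point)
float Dequantize(int8_t q, float scale, int64_t zero_point) {
  return scale * static_cast<float>(q - zero_point);
}

int main() {
  // Per-axis parameters from the schema's example:
  //   scale = [1.0, 2.0, 3.0], zero_point = [1, 2, 3], quantized_dimension = 1
  // Channel c along that dimension uses scale[c] and zero_point[c].
  std::vector<float> scale = {1.0f, 2.0f, 3.0f};
  std::vector<int64_t> zero_point = {1, 2, 3};

  const int8_t q = 5;
  for (int c = 0; c < 3; ++c) {
    // c = 0: 1.0 * (5 - 1) = 4;  c = 1: 2.0 * (5 - 2) = 6;  c = 2: 3.0 * (5 - 3) = 6
    std::cout << "channel " << c << ": "
              << Dequantize(q, scale[c], zero_point[c]) << "\n";
  }
  return 0;
}
```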
diff --git a/tensorflow/lite/stateful_error_reporter.h b/tensorflow/lite/stateful_error_reporter.h
index cf66934..10dc096 100644
--- a/tensorflow/lite/stateful_error_reporter.h
+++ b/tensorflow/lite/stateful_error_reporter.h
@@ -15,9 +15,10 @@
#ifndef TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
#define TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
+// LINT.IfChange
#include <string>
-#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
namespace tflite {
@@ -30,5 +31,6 @@
};
} // namespace tflite
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/stateful_error_reporter.h)
#endif // TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
diff --git a/tensorflow/lite/stderr_reporter.h b/tensorflow/lite/stderr_reporter.h
index 2eacb9e..fdac5d4 100644
--- a/tensorflow/lite/stderr_reporter.h
+++ b/tensorflow/lite/stderr_reporter.h
@@ -18,7 +18,6 @@
#include <cstdarg>
#include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/core/c/common.h"
namespace tflite {
diff --git a/tensorflow/lite/testing/op_tests/is_finite.py b/tensorflow/lite/testing/op_tests/is_finite.py
index 2425fa9..493ea05 100644
--- a/tensorflow/lite/testing/op_tests/is_finite.py
+++ b/tensorflow/lite/testing/op_tests/is_finite.py
@@ -52,7 +52,7 @@
input_values[random_index(input_values.shape)] = np.inf
input_values[random_index(input_values.shape)] = -np.inf
- input_values[random_index(input_values.shape)] = np.NAN
+ input_values[random_index(input_values.shape)] = np.nan
input_values[random_index(input_values.shape)] = tf.float32.max
input_values[random_index(input_values.shape)] = tf.float32.min
diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD
index 4f370bb..8c9ccf7 100644
--- a/tensorflow/lite/toco/BUILD
+++ b/tensorflow/lite/toco/BUILD
@@ -302,7 +302,6 @@
"//tensorflow/lite/kernels/internal:strided_slice_logic",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
- "@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
index 89da8a6..d6932b7 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
@@ -17,11 +17,14 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
index cc519e4..6d2b5ca 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
@@ -12,8 +12,9 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/status/status.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
index 66d7f64..84e84aa 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
@@ -17,9 +17,9 @@
#include <unordered_map>
#include <vector>
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index 8f56bfa..b7763e1 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -17,10 +17,11 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc
index 49c380b..60dcf00 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc
@@ -17,11 +17,12 @@
#include <unordered_map>
#include <vector>
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
index c3bfbf5..c98d64d 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
@@ -17,11 +17,12 @@
#include <unordered_map>
#include <vector>
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
index 547e0d8..c60ddff 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
@@ -12,7 +12,9 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/status/status.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
index f493d4e..c945615 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
@@ -17,11 +17,12 @@
#include <unordered_map>
#include <vector>
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
index 4781f4e..71a7d92 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
@@ -14,10 +14,12 @@
==============================================================================*/
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
index 183cb53..8a33ad5 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -15,7 +15,9 @@
#include <string>
#include <vector>
+#include "absl/status/status.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc
index f69afe4..380cdf2 100644
--- a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -17,11 +17,12 @@
#include <unordered_map>
#include <vector>
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/dequantize.cc b/tensorflow/lite/toco/graph_transformations/dequantize.cc
index 1aa5069..5dd4d2e 100644
--- a/tensorflow/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/lite/toco/graph_transformations/dequantize.cc
@@ -17,11 +17,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc
index 0a7af2f..cdd748a 100644
--- a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc
+++ b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -17,11 +17,12 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc
index a076814..d3cfae0 100644
--- a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc
+++ b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -12,10 +12,12 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc
index 22d6d94..f8d639c 100644
--- a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc
+++ b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -17,7 +17,9 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
index b496f51..ed3a89a 100644
--- a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+++ b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
@@ -17,10 +17,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc
index 7d34270..64b91cc 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -17,11 +17,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
index 926d41c..3afa9c4 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -18,11 +18,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index 130cb0b..fa0baf9 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -17,7 +17,9 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
index df1d6da..ba57090 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
@@ -17,10 +17,11 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
index bee6665..125e559 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
@@ -21,9 +21,12 @@
#include <utility>
#include <vector>
-#include "tensorflow/lite/toco/toco_port.h"
-#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/lite/toco/format_port.h"
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/model_flags.pb.h"
+#include "tensorflow/lite/toco/tooling_util.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
index 9f93ee1..c7e2c9de 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
@@ -21,6 +21,8 @@
#include <unordered_set>
#include <vector>
+#include "absl/log/check.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/toco_port.h"
diff --git a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc
index f5a8d16..2da6fbe 100644
--- a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc
+++ b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc
@@ -20,6 +20,8 @@
#include <string>
#include <vector>
+#include "absl/log/check.h"
+#include "absl/status/status.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
index 53c12b4..6f142a4 100644
--- a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -17,7 +17,9 @@
#include <string>
#include <vector>
+#include "absl/status/status.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc
index 026f51a..985e588 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -15,10 +15,12 @@
#include <string>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc
index 4a6dea0..437147f 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc
@@ -18,7 +18,8 @@
#include <unordered_map>
#include <vector>
-#include "tensorflow/core/platform/logging.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/identify_util.h"
#include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc
index b66f0b0..e8a5d20 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -18,10 +18,12 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc
index 91bda7e..a980995a 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -17,10 +17,12 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc
index 18e74ae..df0aa9f 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc
@@ -16,8 +16,12 @@
#include <string>
#include <vector>
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 4b2c497..24299d5 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -18,8 +18,9 @@
#include <utility>
#include <vector>
-#include "absl/memory/memory.h"
-#include "absl/strings/string_view.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index 3de0a71..aea6d93 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -18,8 +18,8 @@
#include <utility>
#include <vector>
-#include "absl/memory/memory.h"
-#include "absl/strings/string_view.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc
index 580b680..1d1d67b 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc
@@ -17,8 +17,9 @@
#include <string>
#include <vector>
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc
index 31edcb4..0f28cb1 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc
@@ -17,10 +17,12 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
// This transformation rule tries to identify the PRelu structure generated by
// Keras, and convert it to a single op.
diff --git a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc
index dad425c..6f2e224 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc
@@ -17,7 +17,8 @@
#include <unordered_map>
#include <vector>
-#include "tensorflow/core/platform/logging.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/identify_util.h"
#include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_util.cc b/tensorflow/lite/toco/graph_transformations/identify_util.cc
index e860511..6ed8e33 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_util.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_util.cc
@@ -18,6 +18,7 @@
#include <string>
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/identify_util.h b/tensorflow/lite/toco/graph_transformations/identify_util.h
index 1a79231..6c59b0b 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_util.h
+++ b/tensorflow/lite/toco/graph_transformations/identify_util.h
@@ -17,6 +17,7 @@
#include <string>
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/lstm_utils.cc b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc
index 7a979b7..676aa75 100644
--- a/tensorflow/lite/toco/graph_transformations/lstm_utils.cc
+++ b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc
@@ -16,6 +16,9 @@
#include <string>
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/tooling_util.h"
+
namespace toco {
void CreateOptionalArray(Model* model, std::string* input_array_buffer,
diff --git a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 290dc7f..0726b32 100644
--- a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -17,12 +17,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
index b07815e..a292b97 100644
--- a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
+++ b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -18,12 +18,14 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
#include "tensorflow/lite/toco/model.h"
-#include "tensorflow/lite/toco/runtime/types.h"
+#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
index 8538413..588a034 100644
--- a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
@@ -15,6 +15,9 @@
#include <algorithm>
#include <string>
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
index 79d8229..fffdde0 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
@@ -17,12 +17,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
#include "tensorflow/lite/toco/model.h"
-#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc
index af801c3..ef0a520 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -17,9 +17,11 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc
index 0f9197c..54b76fb 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc
@@ -17,11 +17,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 10968a9..62d8715 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -17,11 +17,13 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index ab6f407..5136bc0 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -21,11 +21,14 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc
index 87f95a9..9e5e580 100644
--- a/tensorflow/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/lite/toco/graph_transformations/quantize.cc
@@ -21,11 +21,15 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
+#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
index 5e867ea..bf9334f 100644
--- a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
+++ b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
@@ -18,10 +18,11 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
-#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc
index 438c7a6..fc15e8e 100644
--- a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc
+++ b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -17,11 +17,12 @@
#include <unordered_map>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
index fdc4d27..79e6b68 100644
--- a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
+++ b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
@@ -15,7 +15,8 @@
#include <string>
#include <vector>
-#include "tensorflow/core/platform/logging.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc
index 88402f0..45de603 100644
--- a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc
+++ b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -16,10 +16,12 @@
#include <string>
#include <vector>
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc
index 2af2324..860fa05 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc
@@ -248,7 +248,7 @@
TfLiteStatus RemoveInputTensor(ModelT* model,
const std::vector<TensorOpTensor>& inputs,
- int32 original_number_tensors) {
+ int32_t original_number_tensors) {
// Consistency check to make sure that erase start from the end.
int last_op_index = std::numeric_limits<int32_t>::max();
int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -274,7 +274,7 @@
TfLiteStatus RemoveOutputTensor(ModelT* model,
const std::vector<TensorOpTensor>& outputs,
- int32 original_number_tensors) {
+ int32_t original_number_tensors) {
// Consistency check to make sure that erase start from the end.
int last_op_index = std::numeric_limits<int32_t>::max();
int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -298,7 +298,6 @@
return kTfLiteOk;
}
-
int GetOriginalNumberOfTensors(const TensorType& input_type,
const TensorType& output_type, ModelT* model,
ErrorReporter* error_reporter) {
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index ab15889..a1c155f 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -321,6 +321,7 @@
tf_staging/third_party/repo.bzl:
tf_staging/third_party/six.BUILD:
tf_staging/third_party/snappy.BUILD:
+tf_staging/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD:
tf_staging/third_party/sqlite.BUILD:
tf_staging/third_party/stablehlo/BUILD:
tf_staging/third_party/systemlibs/BUILD.tpl:
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 035daec..8aa243a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -932,7 +932,6 @@
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
"//tensorflow/dtensor/cc:dtensor_device_cc", # DTensor
"//tensorflow/dtensor/cc:tensor_layout", # DTensor
- "//tensorflow/lite/kernels/shim:tf_op_shim", # tf_text
"//tensorflow/lite/toco/python:toco_python_api", # toco
"//tensorflow/python/client:tf_session_helper", # tf_session
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
diff --git a/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb b/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb
index 8b7b3e9..44c2ea6 100644
--- a/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb
+++ b/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb
@@ -164,7 +164,7 @@
"source": [
"### Helpful static analysis passes\n",
"\n",
- "The `static_analysis` module contains various helper passes for dataflow analyis.\n",
+ "The `static_analysis` module contains various helper passes for dataflow analysis.\n",
"\n",
"All these passes annotate the AST. These annotations can be extracted using [anno.getanno](https://github.com/tensorflow/tensorflow/blob/40802bcdb5c8a4379da2145441f51051402bd29b/tensorflow/python/autograph/pyct/anno.py#L111). Most of them rely on the `qual_names` annotations, which just simplify the way more complex identifiers like `a.b.c` are accessed.\n",
"\n",
@@ -253,7 +253,7 @@
"\n",
"\n",
"def f(a):\n",
- " if a \u003e 0:\n",
+ " if a > 0:\n",
" return a\n",
" b = -a\n",
"\n",
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index 08f7342..ffacbe4 100644
--- a/tensorflow/python/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -172,7 +172,7 @@
hasattr(root_node.decorator_list[0], 'lineno')):
# Typical case: functions. The line number of the first decorator
# is more accurate than the line number of the function itself in
- # 3.8+. In earier versions they coincide.
+ # 3.8+. In earlier versions they coincide.
self._lineno_offset = context_lineno - root_node.decorator_list[0].lineno
else:
# Fall back to the line number of the root node.
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 8af0f7a..5d6a872 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -570,7 +570,7 @@
node.decorator_list = self.visit_block(node.decorator_list)
if node.returns:
node.returns = self._process_annotation(node.returns)
- # Argument annotartions (includeing defaults) affect the defining context.
+ # Argument annotations (including defaults) affect the defining context.
node = self._visit_arg_annotations(node)
function_name = qual_names.QN(node.name)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
index cdeddaa..ad373a5 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -249,7 +249,7 @@
inner_fn_body = fn_body[1].body[1].body
def_of_a_in_foo = inner_fn_body[0].value
- # Even though `a` is visible in the inner functio above, the late binding
+ # Even though `a` is visible in the inner function above, the late binding
# makes it impossible to assume that the same value will be visible at
# call time.
self.assertHasDefs(def_of_a_in_foo, 0)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
index 5b59a5a..d5ab1f5 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
@@ -63,9 +63,10 @@
ns: namespace
types_ns: types namespace
name: symbol name
+
Returns:
Tuple (type, static_value). The first element is the type to use for
- inferrence. The second is the static value to use. Return None to treat it
+ inference. The second is the static value to use. Return None to treat it
as unknown.
"""
raise NotImplementedError('subclasses must implement')
@@ -383,7 +384,7 @@
for t in f_types:
if isinstance(t, Callable):
- # Note: these are undocummented - may be version-specific!
+ # Note: these are undocumented - may be version-specific!
# Callable[[x], y]: __args__ are (x, y)
args = t.__args__
if args:
diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py
index c190097..005135e 100644
--- a/tensorflow/python/autograph/pyct/transformer.py
+++ b/tensorflow/python/autograph/pyct/transformer.py
@@ -314,13 +314,13 @@
in nodes
after_visit: optional callable that takes in an AST node and returns a
tuple (new_node, new_destination). It is called after visiting each item
- in nodes. Is used in the same was as the
- visit_* methods: new_node will replace the node; if not None,
- new_destination must be a list, and subsequent nodes will be placed
- in this list instead of the list returned by visit_block.
+ in nodes. Is used in the same way as the visit_* methods: new_node will
+ replace the node; if not None, new_destination must be a list, and
+ subsequent nodes will be placed in this list instead of the list
+ returned by visit_block.
Returns:
- A list of AST node objects containing the transformed items fron nodes,
+ A list of AST node objects containing the transformed items from nodes,
except those nodes that have been relocated using after_visit.
"""
if nodes is None:
diff --git a/tensorflow/python/autograph/pyct/transpiler.py b/tensorflow/python/autograph/pyct/transpiler.py
index 013ccc5..f7b9150 100644
--- a/tensorflow/python/autograph/pyct/transpiler.py
+++ b/tensorflow/python/autograph/pyct/transpiler.py
@@ -238,7 +238,7 @@
result = <<transform node>>
return result
- transformer = MyTransfomer()
+ transformer = MyTransformer()
result = transformer.transform(f, ...)
# result is the output
@@ -381,7 +381,7 @@
node = <<transform node, usually using ast.NodeTransformer classes>>
return node
- transformer = MyTransfomer()
+ transformer = MyTransformer()
new_f, module, source_map = transformer.transform_function(f, ...)
# new_f is a function with signature identical to f
@@ -430,7 +430,7 @@
return cached_factory
def transform_function(self, fn, user_context):
- """Transforms a function. See GenericTranspiler.trasnform_function.
+ """Transforms a function. See GenericTranspiler.transform_function.
This overload wraps the parent's `transform_function`, adding caching and
facilities to instantiate the output as a Python object. It also
@@ -441,6 +441,7 @@
fn: A function or lambda.
user_context: An opaque object (may be None) that is forwarded to
transform_ast, through the ctx.user attribute.
+
Returns:
A tuple:
* A function or lambda with the same signature and closure as `fn`
diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc
index b2d3492..00baa13 100644
--- a/tensorflow/python/client/tf_session_wrapper.cc
+++ b/tensorflow/python/client/tf_session_wrapper.cc
@@ -334,7 +334,7 @@
tf_handle(const tf_handle<T>& other) { Reset(other.obj_); }
- tf_handle<T>& operator=(tf_handle<T>&& other) {
+ tf_handle<T>& operator=(tf_handle<T>&& other) noexcept {
if (this == &other) {
return *this;
}
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 77c31ab..a4777a7 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 7, 31)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 8, 12)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
diff --git a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py
index 857c6f7..b7c062b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py
@@ -88,13 +88,6 @@
test_base.default_test_combinations(),
combinations.combine(
num_elements=10,
- asserted_cardinality=1,
- expected_error=errors.FailedPreconditionError,
- expected_error_message=(
- "Input dataset was expected to contain 1 element but "
- "contained at least 2 elements.")) +
- combinations.combine(
- num_elements=10,
asserted_cardinality=100,
expected_error=errors.FailedPreconditionError,
expected_error_message=(
diff --git a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi
index e88ec56..29126c1 100644
--- a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi
+++ b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi
@@ -14,4 +14,3 @@
# ==============================================================================
def TF_DATA_DefaultProtocol() -> str: ...
-def TF_DATA_DisableCompressionAtRegistrationTime() -> bool: ...
diff --git a/tensorflow/python/data/experimental/service/utils_wrapper.cc b/tensorflow/python/data/experimental/service/utils_wrapper.cc
index f949829..c725ff3 100644
--- a/tensorflow/python/data/experimental/service/utils_wrapper.cc
+++ b/tensorflow/python/data/experimental/service/utils_wrapper.cc
@@ -23,8 +23,4 @@
PYBIND11_MODULE(_pywrap_utils_exp, m) {
m.def("TF_DATA_DefaultProtocol",
[]() -> std::string { return tensorflow::data::DefaultProtocol(); });
-
- m.def("TF_DATA_DisableCompressionAtRegistrationTime", []() -> bool {
- return tensorflow::data::DisableCompressionAtRegistrationTime();
- });
};
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 7c20b1f..88de9cd 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -213,9 +213,11 @@
name = "concatenate_test",
size = "medium",
srcs = ["concatenate_test.py"],
+ shard_count = 20,
deps = [
":checkpoint_test_base",
":test_base",
+ "//tensorflow/python/data/experimental/ops:global_shuffle_op",
"//tensorflow/python/data/experimental/ops:random_access",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:options",
diff --git a/tensorflow/python/data/kernel_tests/concatenate_test.py b/tensorflow/python/data/kernel_tests/concatenate_test.py
index be08594..51c7ca4 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_test.py
@@ -13,8 +13,10 @@
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.concatenate()."""
+from typing import Callable, Tuple
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.experimental.ops import global_shuffle_op
from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
@@ -248,5 +250,301 @@
self.evaluate(random_access.at(concatenated, index=5))
+class GlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
+ """Tests for global shuffling of tf.data datasets."""
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testShuffledOutput(self):
+ dataset1 = dataset_ops.Dataset.range(0, 5)
+ dataset2 = dataset_ops.Dataset.range(5, 17)
+
+ dataset = dataset1.concatenate(dataset2)
+
+ dataset = global_shuffle_op._global_shuffle(dataset)
+
+ output = self.getDatasetOutput(dataset, requires_initialization=True)
+ self.assertCountEqual(output, range(0, 17))
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testShuffledWithBatchOutput(self):
+ """Testing with `.batch()` ensures the global shuffle map is stateless."""
+ dataset1 = dataset_ops.Dataset.range(0, 4)
+ dataset2 = dataset_ops.Dataset.range(4, 10)
+
+ dataset = dataset1.concatenate(dataset2)
+ dataset = dataset.batch(3, drop_remainder=True)
+
+ dataset = global_shuffle_op._global_shuffle(dataset)
+
+ got = self.getDatasetOutput(dataset, requires_initialization=True)
+ expected = [
+ np.array([0, 1, 2], dtype=np.int32),
+ np.array([3, 4, 5], dtype=np.int32),
+ np.array([6, 7, 8], dtype=np.int32),
+ ]
+
+ self.assertIsInstance(got, list)
+ # Convert to tuples for lexicographic sorting.
+ got.sort(key=tuple)
+
+ self.assertLen(got, len(expected))
+
+ for element_got, element_expected in zip(got, expected):
+ self.assertAllEqual(element_got, element_expected)
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testNestedConcatenateShuffledOutput(self):
+ dataset1 = dataset_ops.Dataset.range(0, 3)
+ dataset2 = dataset_ops.Dataset.range(3, 6)
+ dataset3 = dataset_ops.Dataset.range(6, 9)
+
+ dataset = dataset1.concatenate(dataset2)
+ dataset = dataset.concatenate(dataset3)
+
+ dataset = global_shuffle_op._global_shuffle(dataset)
+
+ output = self.getDatasetOutput(dataset, requires_initialization=True)
+ self.assertCountEqual(output, range(0, 9))
+
+
+class ConcatenateGlobalShuffleCheckpointTest(
+ checkpoint_test_base.CheckpointTestBase, parameterized.TestCase
+):
+
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ checkpoint_test_base.default_test_combinations(),
+ combinations.combine(
+ dataset_ranges=[(10, 8), (9, 5), (4, 7), (5, 8)],
+ reshuffle_each_iteration=[True, False],
+ symbolic_checkpoint=[True, False],
+ ),
+ )
+ )
+ def testConcatenate(
+ self,
+ verify_fn: Callable[..., None],
+ dataset_ranges: Tuple[int, int],
+ reshuffle_each_iteration: bool,
+ symbolic_checkpoint: bool,
+ ):
+
+ def _build_dataset():
+ first_dataset = dataset_ops.Dataset.range(dataset_ranges[0])
+ second_dataset = dataset_ops.Dataset.range(
+ dataset_ranges[0], dataset_ranges[0] + dataset_ranges[1]
+ )
+ dataset = first_dataset.concatenate(second_dataset)
+ dataset = global_shuffle_op._global_shuffle(
+ dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+ )
+
+ options = options_lib.Options()
+ options.experimental_optimization.apply_default_optimizations = False
+ options.experimental_symbolic_checkpoint = symbolic_checkpoint
+ return dataset.with_options(options)
+
+ verify_fn(
+ self,
+ _build_dataset,
+ num_outputs=sum(dataset_ranges),
+ assert_items_equal=reshuffle_each_iteration,
+ )
+
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ checkpoint_test_base.default_test_combinations(),
+ combinations.combine(
+ dataset_ranges=[(10, 8, 11), (9, 5, 3)],
+ reshuffle_each_iteration=[True, False],
+ symbolic_checkpoint=[True, False],
+ ),
+ )
+ )
+ def testNestedConcatenate(
+ self,
+ verify_fn: Callable[..., None],
+ dataset_ranges: Tuple[int, int, int],
+ reshuffle_each_iteration: bool,
+ symbolic_checkpoint: bool,
+ ):
+
+ def _build_dataset():
+ first_dataset = dataset_ops.Dataset.range(dataset_ranges[0])
+ second_dataset = dataset_ops.Dataset.range(
+ dataset_ranges[0], dataset_ranges[0] + dataset_ranges[1]
+ )
+ third_dataset = dataset_ops.Dataset.range(
+ sum(dataset_ranges[:2]), sum(dataset_ranges[:3])
+ )
+
+ dataset = first_dataset.concatenate(second_dataset)
+ dataset = dataset.concatenate(third_dataset)
+
+ dataset = global_shuffle_op._global_shuffle(
+ dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+ )
+
+ options = options_lib.Options()
+ options.experimental_optimization.apply_default_optimizations = False
+ options.experimental_symbolic_checkpoint = symbolic_checkpoint
+ return dataset.with_options(options)
+
+ verify_fn(
+ self,
+ _build_dataset,
+ num_outputs=sum(dataset_ranges),
+ assert_items_equal=reshuffle_each_iteration,
+ )
+
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ checkpoint_test_base.default_test_combinations(),
+ combinations.combine(
+ dataset_ranges=[(3, 4, 6, 5)],
+ reshuffle_each_iteration=[True, False],
+ symbolic_checkpoint=[True, False],
+ ),
+ )
+ )
+ def testFourNestedConcatenate(
+ self,
+ verify_fn: Callable[..., None],
+ dataset_ranges: Tuple[int, int, int, int],
+ reshuffle_each_iteration: bool,
+ symbolic_checkpoint: bool,
+ ):
+ def _build_dataset():
+ first_dataset = dataset_ops.Dataset.range(dataset_ranges[0])
+ second_dataset = dataset_ops.Dataset.range(
+ dataset_ranges[0], sum(dataset_ranges[:2])
+ )
+ third_dataset = dataset_ops.Dataset.range(
+ sum(dataset_ranges[:2]), sum(dataset_ranges[:3])
+ )
+ fourth_dataset = dataset_ops.Dataset.range(
+ sum(dataset_ranges[:3]), sum(dataset_ranges)
+ )
+
+ left = first_dataset.concatenate(second_dataset)
+ right = third_dataset.concatenate(fourth_dataset)
+
+ dataset = left.concatenate(right)
+ dataset = global_shuffle_op._global_shuffle(
+ dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+ )
+
+ options = options_lib.Options()
+ options.experimental_optimization.apply_default_optimizations = False
+ options.experimental_symbolic_checkpoint = symbolic_checkpoint
+ return dataset.with_options(options)
+
+ verify_fn(
+ self,
+ _build_dataset,
+ num_outputs=sum(dataset_ranges),
+ assert_items_equal=reshuffle_each_iteration,
+ )
+
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ checkpoint_test_base.default_test_combinations(),
+ combinations.combine(
+ dataset_ranges=[(1, 2, 3, 4, 5, 6)],
+ reshuffle_each_iteration=[True, False],
+ symbolic_checkpoint=[True, False],
+ ),
+ )
+ )
+ def testDeepConcatenate(
+ self,
+ verify_fn: Callable[..., None],
+ dataset_ranges: Tuple[int, ...],
+ reshuffle_each_iteration: bool,
+ symbolic_checkpoint: bool,
+ ):
+ def _build_dataset():
+ prefix_sums = [0] * (len(dataset_ranges) + 1)
+ for i, value in enumerate(dataset_ranges):
+ prefix_sums[i + 1] = prefix_sums[i] + value
+
+ dataset = dataset_ops.Dataset.range(prefix_sums[0], prefix_sums[1])
+ for i in range(1, len(dataset_ranges)):
+ to_concat = dataset_ops.Dataset.range(
+ prefix_sums[i], prefix_sums[i + 1]
+ )
+ dataset = dataset.concatenate(to_concat)
+
+ dataset = global_shuffle_op._global_shuffle(
+ dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+ )
+
+ options = options_lib.Options()
+ options.experimental_optimization.apply_default_optimizations = False
+ options.experimental_symbolic_checkpoint = symbolic_checkpoint
+ return dataset.with_options(options)
+
+ verify_fn(
+ self,
+ _build_dataset,
+ num_outputs=sum(dataset_ranges),
+ assert_items_equal=reshuffle_each_iteration,
+ )
+
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ checkpoint_test_base.default_test_combinations(),
+ combinations.combine(
+ dataset_ranges=[(1, 2, 3, 4, 5, 6)],
+ reshuffle_each_iteration=[True, False],
+ symbolic_checkpoint=[True, False],
+ ),
+ )
+ )
+ def testDeepConcatenateWithBatchAndPrefetch(
+ self,
+ verify_fn: Callable[..., None],
+ dataset_ranges: Tuple[int, ...],
+ reshuffle_each_iteration: bool,
+ symbolic_checkpoint: bool,
+ ):
+ def _build_dataset():
+ prefix_sums = [0] * (len(dataset_ranges) + 1)
+ for i, value in enumerate(dataset_ranges):
+ prefix_sums[i + 1] = prefix_sums[i] + value
+
+ dataset = dataset_ops.Dataset.range(prefix_sums[0], prefix_sums[1])
+ for i in range(1, len(dataset_ranges)):
+ to_concat = dataset_ops.Dataset.range(
+ prefix_sums[i], prefix_sums[i + 1]
+ )
+ dataset = dataset.concatenate(to_concat)
+
+ dataset = dataset.batch(2, drop_remainder=True)
+ dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
+
+ dataset = global_shuffle_op._global_shuffle(
+ dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+ )
+ dataset = dataset.unbatch()
+
+ options = options_lib.Options()
+ options.experimental_optimization.apply_default_optimizations = False
+ options.experimental_symbolic_checkpoint = symbolic_checkpoint
+ return dataset.with_options(options)
+
+ verify_fn(
+ self,
+ _build_dataset,
+ num_outputs=(sum(dataset_ranges) // 2) * 2,
+ assert_items_equal=reshuffle_each_iteration,
+ )
+
+
if __name__ == "__main__":
test.main()
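For reference, a minimal sketch of the pattern the new GlobalShuffleTest cases above exercise, assuming the private experimental helper global_shuffle_op._global_shuffle keeps the signature used in these tests:

    from tensorflow.python.data.experimental.ops import global_shuffle_op
    from tensorflow.python.data.ops import dataset_ops

    # Concatenate two ranges, then shuffle globally across the combined dataset.
    dataset = dataset_ops.Dataset.range(0, 5).concatenate(
        dataset_ops.Dataset.range(5, 17))
    dataset = global_shuffle_op._global_shuffle(dataset, seed=10)
    # Iteration yields every value in 0..16 exactly once, in shuffled order.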
diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py
index 00299dd..c2c5f26 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_test.py
@@ -15,6 +15,7 @@
"""Tests for `tf.data.Dataset.flat_map()`."""
import random
from typing import Callable, Optional
+import unittest
from absl.testing import parameterized
import numpy as np
@@ -466,6 +467,10 @@
verify_fn(self, build_dataset, num_outputs=3 * 4 - num_skips)
+@unittest.skip(
+ "TODO: b/355241367 - `flat_map_dataset_op.cc` still needs to be fixed."
+ " Please use concatenate dataset op plus global shuffling instead."
+)
class FlatMapGlobalShuffleTest(
test_base.DatasetTestBase, parameterized.TestCase):
@@ -511,6 +516,10 @@
self.getDatasetOutput(dataset, requires_initialization=True)
+@unittest.skip(
+ "TODO: b/355241367 - `flat_map_dataset_op.cc` still needs to be fixed."
+ " Please use concatenate dataset op plus global shuffling instead."
+)
class FlatMapGlobalShuffleCheckpointTest(
checkpoint_test_base.CheckpointTestBase, parameterized.TestCase
):
diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py
index 2d00e6b..2e58aeb 100644
--- a/tensorflow/python/data/kernel_tests/map_test.py
+++ b/tensorflow/python/data/kernel_tests/map_test.py
@@ -262,13 +262,15 @@
self.assertAllEqual(component[i]**2, result_component)
def _parallel_map_dataset_factory(self, components, apply_map, count,
- num_parallel_calls, buffer_size):
+ num_parallel_calls, buffer_size,
+ use_unbounded_threadpool=False):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
dataset = dataset_ops.Dataset.from_tensor_slices(components)
- dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls)
+ dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls,
+ use_unbounded_threadpool=use_unbounded_threadpool)
dataset = dataset.prefetch(buffer_size).repeat(count)
self.assertEqual(
@@ -284,8 +286,10 @@
combinations.combine(num_parallel_calls=2, buffer_size=2) +
combinations.combine(num_parallel_calls=2, buffer_size=4) +
combinations.combine(num_parallel_calls=8, buffer_size=8) +
- combinations.combine(num_parallel_calls=8, buffer_size=16)))
- def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size):
+ combinations.combine(num_parallel_calls=8, buffer_size=16),
+ combinations.combine(use_unbounded_threadpool=[None, True, False])))
+ def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size,
+ use_unbounded_threadpool):
"""Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
@@ -296,7 +300,8 @@
# Test single-threaded access to the iterator.
get_next = self.getNext(
self._parallel_map_dataset_factory(components, apply_map, 14,
- num_parallel_calls, buffer_size))
+ num_parallel_calls, buffer_size,
+ use_unbounded_threadpool))
for _ in range(14):
for i in range(7):
result = self.evaluate(get_next())
@@ -1537,6 +1542,20 @@
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
+ combinations.combine(
+ use_unbounded_threadpool=[True, False])))
+ def testAutotuneUseUnboundedThreadpool(self, use_unbounded_threadpool):
+ dataset = dataset_ops.Dataset.range(100)
+ dataset = dataset.map(
+ lambda x: x * 2,
+ num_parallel_calls=dataset_ops.AUTOTUNE,
+ use_unbounded_threadpool=use_unbounded_threadpool,
+ deterministic=True,
+ name="map")
+ self.assertDatasetProduces(dataset, [x * 2 for x in range(100)])
+
+ @combinations.generate(
+ combinations.times(test_base.default_test_combinations(),
combinations.combine(num_parallel_calls=[None, 1])))
def testName(self, num_parallel_calls):
dataset = dataset_ops.Dataset.from_tensors(21).map(
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index b2580e3..c8d01f2 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -2157,6 +2157,7 @@
num_parallel_calls=None,
deterministic=None,
synchronous=None,
+ use_unbounded_threadpool=False,
name=None,
) -> "DatasetV2":
"""Maps `map_func` across the elements of this dataset.
@@ -2313,6 +2314,11 @@
saving memory, since even setting `num_parallel_calls=1` will cause one
batch to be buffered, while with `synchronous=True` the map
transformation doesn't buffer anything.
+ use_unbounded_threadpool: (Optional.) By default, map functions run in a
+ limited threadpool based on the number of cores on the machine. This is
+ efficient for CPU-heavy processing, but if the map function performs IO
+ it is better to use an unbounded threadpool by setting it to `True`. It
+ is `False` by default.
name: (Optional.) A name for the tf.data operation.
Returns:
@@ -2329,6 +2335,7 @@
num_parallel_calls=num_parallel_calls,
deterministic=deterministic,
synchronous=synchronous,
+ use_unbounded_threadpool=use_unbounded_threadpool,
name=name,
)
# pylint: enable=g-import-not-at-top,protected-access
@@ -4092,6 +4099,7 @@
num_parallel_calls=None,
deterministic=None,
synchronous=None,
+ use_unbounded_threadpool=False,
name=None,
):
# Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
@@ -4105,12 +4113,17 @@
num_parallel_calls=num_parallel_calls,
deterministic=deterministic,
synchronous=synchronous,
+ use_unbounded_threadpool=use_unbounded_threadpool,
)
# pylint: enable=g-import-not-at-top,protected-access
@deprecation.deprecated(None, "Use `tf.data.Dataset.map()")
def map_with_legacy_function(
- self, map_func, num_parallel_calls=None, deterministic=None
+ self,
+ map_func,
+ num_parallel_calls=None,
+ deterministic=None,
+ use_unbounded_threadpool=False,
) -> "DatasetV1Adapter":
"""Maps `map_func` across the elements of this dataset.
@@ -4133,6 +4146,11 @@
elements out of order to trade determinism for performance. If not
specified, the `tf.data.Options.deterministic` option (`True` by
default) controls the behavior.
+ use_unbounded_threadpool: (Optional.) By default, map functions run in a
+ limited threadpool based on the number of cores on the machine. This is
+ efficient for CPU-heavy processing, but if the map function performs IO
+ it is better to use an unbounded threadpool by setting it to `True`. It
+ is `False` by default.
Returns:
Dataset: A `Dataset`.
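A short usage sketch of the new use_unbounded_threadpool flag documented above, assuming it is exposed on the public tf.data.Dataset.map API exactly as described (the file pattern below is hypothetical):

    import tensorflow as tf

    # An IO-bound map function benefits from not being capped at roughly one
    # thread per core, so it opts into the unbounded threadpool.
    dataset = tf.data.Dataset.list_files("/tmp/records/*")  # hypothetical path
    dataset = dataset.map(
        tf.io.read_file,
        num_parallel_calls=tf.data.AUTOTUNE,
        use_unbounded_threadpool=True)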
diff --git a/tensorflow/python/data/ops/map_op.py b/tensorflow/python/data/ops/map_op.py
index 0a056ab..f301dee 100644
--- a/tensorflow/python/data/ops/map_op.py
+++ b/tensorflow/python/data/ops/map_op.py
@@ -30,6 +30,7 @@
num_parallel_calls=None,
deterministic=None,
synchronous=None,
+ use_unbounded_threadpool=None,
name=None,
):
"""See `Dataset.map()` for details."""
@@ -59,6 +60,7 @@
num_parallel_calls=num_parallel_calls,
deterministic=deterministic,
preserve_cardinality=True,
+ use_unbounded_threadpool=use_unbounded_threadpool,
name=name)
@@ -68,6 +70,7 @@
num_parallel_calls=None,
deterministic=None,
synchronous=None,
+ use_unbounded_threadpool=None, # pylint: disable=unused-argument
):
"""See `Dataset.map()` for details."""
if num_parallel_calls is None or debug_mode.DEBUG_MODE:
@@ -92,7 +95,8 @@
map_func,
num_parallel_calls,
deterministic,
- preserve_cardinality=False))
+ preserve_cardinality=False,
+ use_unbounded_threadpool=False))
def _map_v1_with_legacy_function( # pylint: disable=unused-private-name
@@ -130,7 +134,8 @@
num_parallel_calls,
deterministic,
preserve_cardinality=False,
- use_legacy_function=True))
+ use_legacy_function=True,
+ use_unbounded_threadpool=False))
class _MapDataset(dataset_ops.UnaryDataset):
@@ -189,6 +194,7 @@
use_inter_op_parallelism=True,
preserve_cardinality=False,
use_legacy_function=False,
+ use_unbounded_threadpool=False,
name=None):
"""See `Dataset.map()` for details."""
self._input_dataset = input_dataset
@@ -207,6 +213,7 @@
self._preserve_cardinality = preserve_cardinality
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+ self._use_unbounded_threadpool = use_unbounded_threadpool
self._name = name
variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
input_dataset._variant_tensor, # pylint: disable=protected-access
@@ -216,6 +223,7 @@
deterministic=self._deterministic,
use_inter_op_parallelism=self._use_inter_op_parallelism,
preserve_cardinality=self._preserve_cardinality,
+ use_unbounded_threadpool=self._use_unbounded_threadpool,
**self._common_args)
super().__init__(input_dataset, variant_tensor)
diff --git a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
index cd3d2ee..1f334a0 100644
--- a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
+++ b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
@@ -19,15 +19,9 @@
#include "Python.h"
#include "pybind11/pybind11.h" // from @pybind11
#include "pybind11/stl.h" // from @pybind11
-#include "tensorflow/c/c_api.h"
-#include "tensorflow/c/c_api_experimental.h"
-#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
-#include "tensorflow/c/safe_ptr.h"
-#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
-#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
namespace py = pybind11;
diff --git a/tensorflow/python/flags_pybind.pyi b/tensorflow/python/flags_pybind.pyi
index b34ed2f..7c450b6 100644
--- a/tensorflow/python/flags_pybind.pyi
+++ b/tensorflow/python/flags_pybind.pyi
@@ -24,6 +24,7 @@
enable_function_pruning_before_inlining: Flag
enable_nested_function_shape_inference: Flag
enable_quantized_dtypes_training: Flag
+ enable_skip_encapsulation_for_non_tpu_graphs: Flag
enable_tf2min_ici_weight: Flag
graph_building_optimization: Flag
more_stack_traces: Flag
diff --git a/tensorflow/python/framework/offset_counter_helper_test.cc b/tensorflow/python/framework/offset_counter_helper_test.cc
index dcf6f7c..ef616a3 100644
--- a/tensorflow/python/framework/offset_counter_helper_test.cc
+++ b/tensorflow/python/framework/offset_counter_helper_test.cc
@@ -18,10 +18,10 @@
#include <string>
#include "absl/strings/str_format.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/python/framework/op_reg_offset.pb.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 4cd43ae..1c81b35 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -716,16 +716,22 @@
// These objects are efficiently handled by Numpy. We transform them into
// Numpy arrays and handle them in the Numpy case below. Note that Tensors
// implement the __array__ function, and will be handled in this shortcut.
- Safe_PyObjectPtr array =
- make_safe(PyArray_FromArrayAttr(obj, nullptr, nullptr));
- if (array == nullptr) {
- return nullptr;
+ // We used to call PyArray_FromArrayAttr here, but NumPy 2.0 changed its
+ // semantics such that it errors if a copy of the array is required.
+ // (Ideally no copy would be needed here, but that would be a larger change.)
+ Safe_PyObjectPtr array;
+ if (PyObject_HasAttrString(obj, "__array__")) {
+ array = make_safe(PyObject_CallMethod(obj, "__array__", nullptr));
+ if (array == nullptr) {
+ return nullptr;
+ }
+ if (!PyArray_Check(array.get())) {
+ PyErr_SetString(PyExc_ValueError,
+ "Value returned by __array__ is not a NumPy array");
+ return nullptr;
+ }
}
- if (array.get() == Py_NotImplemented) {
- // The Py_NotImplemented returned from PyArray_FromArrayAttr is not
- // Py_INCREF'ed, so we don't want the Safe_PyObjectPtr to Py_DECREF it.
- array.release();
-
+ if (!array) {
// Try __array_interface__ objects (such as PIL Image).
array = make_safe(PyArray_FromInterface(obj));
if (array == nullptr) {
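A rough Python analogue of the revised conversion logic above (a sketch only, not the actual binding): call __array__ directly instead of going through PyArray_FromArrayAttr, whose NumPy 2.0 semantics reject conversions that require a copy; np.asarray stands in for the remaining fallback paths such as __array_interface__:

    import numpy as np

    def to_ndarray(obj):
        # Prefer an explicit __array__() call, mirroring the C++ shortcut above.
        if hasattr(obj, "__array__"):
            array = obj.__array__()
            if not isinstance(array, np.ndarray):
                raise ValueError("Value returned by __array__ is not a NumPy array")
            return array
        # Fall back to the other conversion paths (e.g. __array_interface__).
        return np.asarray(obj)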
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 0cb65da..76dd388 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -396,7 +396,7 @@
Given a tensor `input`, this operation inserts a dimension of length 1 at the
dimension index `axis` of `input`'s shape. The dimension index follows Python
- indexing rules: It's zero-based, a negative index it is counted backward
+ indexing rules: It's zero-based, and a negative index is counted backward
from the end.
This operation is useful to:
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index ab508d1..d553667 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2010,8 +2010,12 @@
# infer dtype if not explicitly provided
if dtype is None:
dtype_hierarchy = [
- dtypes.int32, dtypes.int64, dtypes.float16, dtypes.bfloat16,
- dtypes.float32, dtypes.float64
+ dtypes.int32,
+ dtypes.int64,
+ dtypes.float16,
+ dtypes.bfloat16,
+ dtypes.float32,
+ dtypes.float64,
]
assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
inferred_dtype = max([arg.dtype for arg in [start, limit, delta]],
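
The list reflow above is cosmetic; the inference rule it feeds is that `tf.range` picks the highest-ranked dtype among `start`, `limit`, and `delta`. A minimal standalone sketch of that rule, using plain strings instead of `tf.dtypes` (helper name hypothetical):

```python
# Sketch of the dtype-inference rule; mirrors the hierarchy in the hunk above.
DTYPE_HIERARCHY = ["int32", "int64", "float16", "bfloat16", "float32", "float64"]

def infer_range_dtype(*arg_dtypes):
    assert all(d in DTYPE_HIERARCHY for d in arg_dtypes)
    # The "largest" dtype wins, where size is position in the hierarchy.
    return max(arg_dtypes, key=DTYPE_HIERARCHY.index)

print(infer_range_dtype("int32", "float32"))  # float32
print(infer_range_dtype("int32", "int64"))    # int64
```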
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 17a2776..b43fcc3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -129,11 +129,11 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "map_with_legacy_function"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 88a2e4f..e4b1c19 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -131,11 +131,11 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "map_with_legacy_function"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index b71ecb1..999d0c0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -131,11 +131,11 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "map_with_legacy_function"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index e4f1c85..123c3b7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -131,11 +131,11 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "map_with_legacy_function"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
index 266a8f1..2b8fdc3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -131,11 +131,11 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "map_with_legacy_function"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
index aff5ee5f..77b396a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -131,11 +131,11 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "map_with_legacy_function"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
index 6986dab..48c8540 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -131,11 +131,11 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "map_with_legacy_function"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 8c5d4bf..c55d7db 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -3014,7 +3014,7 @@
}
member_method {
name: "ParallelMapDatasetV2"
- argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'\', \'None\'], "
+ argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'use_unbounded_threadpool\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'False\', \'\', \'None\'], "
}
member_method {
name: "ParameterizedTruncatedNormal"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 1852f76..00f5d54 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -100,7 +100,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 5f1f368..9230f45 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -102,7 +102,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 85eb696..e89d811 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -101,7 +101,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 42a293d..c936ff4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -102,7 +102,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
index d376170..f31da9c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -102,7 +102,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
index 190a21f..8faf543 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -103,7 +103,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
index e19f932..490b7f7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -102,7 +102,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt
index d6d6439..c2194ab 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt
@@ -103,7 +103,7 @@
}
member_method {
name: "map"
- argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 8c5d4bf..c55d7db 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -3014,7 +3014,7 @@
}
member_method {
name: "ParallelMapDatasetV2"
- argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'\', \'None\'], "
+ argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'use_unbounded_threadpool\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'False\', \'\', \'None\'], "
}
member_method {
name: "ParameterizedTruncatedNormal"
diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc
index c6e4284..1b1b443 100644
--- a/tensorflow/tools/benchmark/benchmark_model_test.cc
+++ b/tensorflow/tools/benchmark/benchmark_model_test.cc
@@ -18,6 +18,7 @@
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -29,7 +30,6 @@
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/stat_summarizer.h"
-#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh
index 72a228f..992aa6d 100644
--- a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh
+++ b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh
@@ -16,5 +16,4 @@
set -x
ARM_SKIP_TESTS="-//tensorflow/lite/... \
--//tensorflow/core/kernels/image:resize_bicubic_op_test \
"
diff --git a/tensorflow/tools/graph_transforms/backports_test.cc b/tensorflow/tools/graph_transforms/backports_test.cc
index 80a954e..155ec29 100644
--- a/tensorflow/tools/graph_transforms/backports_test.cc
+++ b/tensorflow/tools/graph_transforms/backports_test.cc
@@ -192,7 +192,7 @@
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
ASSERT_EQ(1, node_lookup.count("v3_node"));
- EXPECT_TRUE(str_util::EndsWith(node_lookup.at("v3_node")->op(), "V2"));
+ EXPECT_TRUE(absl::EndsWith(node_lookup.at("v3_node")->op(), "V2"));
}
}
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index dcdc3c2..3d388cd 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -210,10 +210,10 @@
for (const NodeDef& node : graph_def.node()) {
const StringPiece name(node.name());
const int occurrence_count = folded_node_map.count(node.name());
- if (str_util::EndsWith(name, "expect_removed")) {
+ if (absl::EndsWith(name, "expect_removed")) {
EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
}
- if (str_util::EndsWith(name, "expect_remains")) {
+ if (absl::EndsWith(name, "expect_remains")) {
EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
}
}
diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD
index 9672bf1..da86a3a 100644
--- a/tensorflow/tools/proto_splitter/cc/BUILD
+++ b/tensorflow/tools/proto_splitter/cc/BUILD
@@ -97,9 +97,9 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:status_matchers",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@riegeli//riegeli/base:initializer",
"@riegeli//riegeli/bytes:cord_reader",
"@riegeli//riegeli/bytes:fd_reader",
@@ -163,11 +163,11 @@
"//tensorflow/tools/proto_splitter/testdata:test_message_proto_cc",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
] + if_oss([
"//tensorflow/tools/proto_splitter:protos_impl",
]),
diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc
index 55aac6a..d62acc5 100644
--- a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc
@@ -30,6 +30,7 @@
#include "riegeli/bytes/fd_reader.h" // from @riegeli
#include "riegeli/bytes/string_reader.h" // from @riegeli
#include "riegeli/records/record_reader.h" // from @riegeli
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
@@ -39,7 +40,6 @@
#include "tensorflow/tools/proto_splitter/cc/util.h"
#include "tensorflow/tools/proto_splitter/chunk.pb.h"
#include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc
index 1d98a3a..1fb19f5 100644
--- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc
@@ -23,6 +23,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/cord.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -34,7 +35,6 @@
#include "tensorflow/tools/proto_splitter/cc/test_util.h"
#include "tensorflow/tools/proto_splitter/cc/util.h"
#include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc
index b03bcc1..1712421 100644
--- a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc
@@ -22,6 +22,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -33,7 +34,6 @@
#include "tensorflow/tools/proto_splitter/cc/max_size.h"
#include "tensorflow/tools/proto_splitter/cc/util.h"
#include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/tools/proto_splitter/cc/util_test.cc b/tensorflow/tools/proto_splitter/cc/util_test.cc
index e318f7c..7880519 100644
--- a/tensorflow/tools/proto_splitter/cc/util_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/util_test.cc
@@ -21,12 +21,12 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/proto_splitter/cc/test_util.h"
#include "tensorflow/tools/proto_splitter/chunk.pb.h"
#include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/tools/proto_splitter/merge_test.cc b/tensorflow/tools/proto_splitter/merge_test.cc
index 5f78f3f..06d40d5 100644
--- a/tensorflow/tools/proto_splitter/merge_test.cc
+++ b/tensorflow/tools/proto_splitter/merge_test.cc
@@ -23,6 +23,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
@@ -31,7 +32,6 @@
#include "tensorflow/tools/proto_splitter/cc/util.h"
#include "tensorflow/tools/proto_splitter/chunk.pb.h"
#include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/statusor.h"
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile
index 7659bd6..b6d1aca 100644
--- a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile
@@ -1,5 +1,5 @@
################################################################################
-FROM ubuntu:22.04@sha256:19478ce7fc2ffbce89df29fea5725a8d12e57de52eb9ea570890dc5852aac1ac as builder
+FROM ubuntu:22.04@sha256:340d9b015b194dc6e2a13938944e0d016e57b9679963fdeb9ce021daac430221 as builder
################################################################################
# Install devtoolset build dependencies
diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
index 06def11..2959fe8 100644
--- a/tensorflow/workspace2.bzl
+++ b/tensorflow/workspace2.bzl
@@ -57,7 +57,6 @@
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("//third_party/tensorrt:workspace.bzl", tensorrt = "repo")
load("//third_party/triton:workspace.bzl", triton = "repo")
-load("//third_party/uv:workspace.bzl", uv = "repo")
load("//third_party/vulkan_headers:workspace.bzl", vulkan_headers = "repo")
def _initialize_third_party():
@@ -93,7 +92,6 @@
vulkan_headers()
tensorrt()
triton()
- uv()
# copybara: tsl vendor
diff --git a/third_party/absl/nvidia_jetson.patch b/third_party/absl/nvidia_jetson.patch
new file mode 100644
index 0000000..5328c3a
--- /dev/null
+++ b/third_party/absl/nvidia_jetson.patch
@@ -0,0 +1,35 @@
+From 372124e6af36a540e74a2ec31d79d7297a831f98 Mon Sep 17 00:00:00 2001
+From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= <frederic.bastien@gmail.com>
+Date: Thu, 1 Aug 2024 12:38:52 -0700
+Subject: [PATCH] PR #1732: Fix build on NVIDIA Jetson board. Fix #1665
+
+Imported from GitHub PR https://github.com/abseil/abseil-cpp/pull/1732
+
+Fix build on NVIDIA Jetson board. Fix #1665
+
+This patch is already used by the spark project.
+I'm fixing this as this break the build of Tensorflow and JAX on Jetson board.
+Merge 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff into 6b8ebb35c0414ef5a2b6fd4a0f59057e41beaff9
+
+Merging this change closes #1732
+
+COPYBARA_INTEGRATE_REVIEW=https://github.com/abseil/abseil-cpp/pull/1732 from nouiz:fix_neon_on_jetson 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff
+PiperOrigin-RevId: 658501520
+Change-Id: If502ede4efc8c877fb3fed227eca6dc7622dd181
+---
+ absl/base/config.h | 2 +-
+ 1 file changed, 1 insertion(+), 1 deletion(-)
+
+diff --git a/absl/base/config.h b/absl/base/config.h
+index 97c9a22a109..ab1e9860a91 100644
+--- a/absl/base/config.h
++++ b/absl/base/config.h
+@@ -926,7 +926,7 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' ||
+ // https://llvm.org/docs/CompileCudaWithLLVM.html#detecting-clang-vs-nvcc-from-code
+ #ifdef ABSL_INTERNAL_HAVE_ARM_NEON
+ #error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set
+-#elif defined(__ARM_NEON) && !defined(__CUDA_ARCH__)
++#elif defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__))
+ #define ABSL_INTERNAL_HAVE_ARM_NEON 1
+ #endif
+
diff --git a/third_party/absl/workspace.bzl b/third_party/absl/workspace.bzl
index 06f7516..9565a82 100644
--- a/third_party/absl/workspace.bzl
+++ b/third_party/absl/workspace.bzl
@@ -44,4 +44,5 @@
system_link_files = SYS_LINKS,
strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
+ patch_file = ["//third_party/absl:nvidia_jetson.patch"],
)
diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl
index 0b85e59..44cdbe3 100644
--- a/third_party/gpus/cuda/BUILD.tpl
+++ b/third_party/gpus/cuda/BUILD.tpl
@@ -249,3 +249,9 @@
# to make bazel query happy.
name = "nvptxcompiler",
)
+
+cc_library(
+ # This is not yet fully supported, but we need the rule
+ # to make bazel query happy.
+ name = "nvjitlink",
+)
\ No newline at end of file
diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl
index c185ca7..ff9b53b 100644
--- a/third_party/gpus/rocm_configure.bzl
+++ b/third_party/gpus/rocm_configure.bzl
@@ -205,6 +205,8 @@
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include")
+ if int(rocm_config.rocm_version_number) >= 60200:
+ inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include")
# Support hcc based off clang 10.0.0 (for ROCm 3.3)
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index 6429d9b..54a3c65 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@
def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d"
- LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd"
+ LLVM_COMMIT = "4c5ef6690040383956461828457ac27f7f912edb"
+ LLVM_SHA256 = "a30da7822f5307bc0aca8c497ffdd6369e3877186e87501e2ac1f3ec5ed1c0b7"
tf_http_archive(
name = name,
diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD
index 4b3ad84..8c73096 100644
--- a/third_party/mkl_dnn/mkldnn_v1.BUILD
+++ b/third_party/mkl_dnn/mkldnn_v1.BUILD
@@ -12,7 +12,7 @@
"#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
"#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
- "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
+ "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH",
"#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE",
"#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
@@ -109,6 +109,7 @@
"-UUSE_CBLAS",
"-DDNNL_ENABLE_MAX_CPU_ISA",
"-DDNNL_ENABLE_ITT_TASKS",
+ "-DDNNL_ENABLE_GRAPH_DUMP",
] + tf_openmp_copts()
_INCLUDES_LIST = [
@@ -119,6 +120,7 @@
"src/cpu",
"src/cpu/gemm",
"src/cpu/x64/xbyak",
+ "src/graph",
]
_TEXTUAL_HDRS_LIST = glob([
@@ -129,6 +131,15 @@
"src/cpu/**/*.hpp",
"src/cpu/jit_utils/**/*.hpp",
"src/cpu/x64/xbyak/*.h",
+ "src/graph/interface/*.hpp",
+ "src/graph/backend/*.hpp",
+ "src/graph/backend/dnnl/*.hpp",
+ "src/graph/backend/fake/*.hpp",
+ "src/graph/backend/dnnl/passes/*.hpp",
+ "src/graph/backend/dnnl/patterns/*.hpp",
+ "src/graph/backend/dnnl/kernels/*.hpp",
+ "src/graph/utils/*.hpp",
+ "src/graph/utils/pm/*.hpp",
]) + [
":dnnl_config_h",
":dnnl_version_h",
@@ -160,6 +171,16 @@
"src/cpu/**/*.cpp",
"src/common/ittnotify/*.c",
"src/cpu/jit_utils/**/*.cpp",
+ "src/cpu/x64/**/*.cpp",
+ "src/graph/interface/*.cpp",
+ "src/graph/backend/*.cpp",
+ "src/graph/backend/dnnl/*.cpp",
+ "src/graph/backend/fake/*.cpp",
+ "src/graph/backend/dnnl/passes/*.cpp",
+ "src/graph/backend/dnnl/patterns/*.cpp",
+ "src/graph/backend/dnnl/kernels/*.cpp",
+ "src/graph/utils/*.cpp",
+ "src/graph/utils/pm/*.cpp",
],
exclude = [
"src/cpu/aarch64/**",
diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch
index 4d99610..1711c1b 100644
--- a/third_party/shardy/temporary.patch
+++ b/third_party/shardy/temporary.patch
@@ -1,15 +1,15 @@
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
-index 9345d8d..6429d9b 100644
+index 4b3f0db..54a3c65 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
def repo(name):
"""Imports LLVM."""
-- LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
-- LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"
-+ LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d"
-+ LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd"
+- LLVM_COMMIT = "16dadecc05fa4986d4522c2c3a09a7628feb0fd4"
+- LLVM_SHA256 = "e7c5195e30f75c6027f90b8196ded71a41a38c586931cfb33d63295d9eed95fd"
++ LLVM_COMMIT = "4c5ef6690040383956461828457ac27f7f912edb"
++ LLVM_SHA256 = "a30da7822f5307bc0aca8c497ffdd6369e3877186e87501e2ac1f3ec5ed1c0b7"
tf_http_archive(
name = name,
diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl
index 200ac3f..a40d827 100644
--- a/third_party/shardy/workspace.bzl
+++ b/third_party/shardy/workspace.bzl
@@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
def repo():
- SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
- SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"
+ SHARDY_COMMIT = "76731821434117cb6d736bdd1b32b7ee4ffbcb4b"
+ SHARDY_SHA256 = "944bdbdc9e97ca95b15ac81bfee151664c06b4e6373661d25091e064a604fd2f"
tf_http_archive(
name = "shardy",
diff --git a/third_party/spirv_llvm_translator/BUILD b/third_party/spirv_llvm_translator/BUILD
new file mode 100644
index 0000000..8d626dc
--- /dev/null
+++ b/third_party/spirv_llvm_translator/BUILD
@@ -0,0 +1,7 @@
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+# spirv_llvm_translator license placeholder
diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
new file mode 100644
index 0000000..557e2e8
--- /dev/null
+++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
@@ -0,0 +1,34 @@
+cc_library(
+ name = "spirv_llvm_translator",
+ srcs = glob([
+ "lib/SPIRV/libSPIRV/*.cpp",
+ "lib/SPIRV/libSPIRV/*.hpp",
+ "lib/SPIRV/libSPIRV/*.h",
+ "lib/SPIRV/Mangler/*.cpp",
+ "lib/SPIRV/Mangler/*.h",
+ "lib/SPIRV/*.cpp",
+ "lib/SPIRV/*.hpp",
+ "lib/SPIRV/*.h",
+ ]),
+ hdrs = glob(["include/*"]),
+ includes = [
+ "include/",
+ "lib/SPIRV/",
+ "lib/SPIRV/Mangler/",
+ "lib/SPIRV/libSPIRV/",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@llvm-project//llvm:Analysis",
+ "@llvm-project//llvm:BitWriter",
+ "@llvm-project//llvm:CodeGen",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Demangle",
+ "@llvm-project//llvm:IRReader",
+ "@llvm-project//llvm:Linker",
+ "@llvm-project//llvm:Passes",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:TransformUtils",
+ "@spirv_headers//:spirv_cpp_headers",
+ ],
+)
diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
new file mode 100644
index 0000000..fc843b1
--- /dev/null
+++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
@@ -0,0 +1,25 @@
+diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h
+index a828add8..924e13b4 100644
+
+The SPIR backend uses different addrspace representations than the NVPTX backend.
+We reorder the enum values here so that XLA LLVM codegen stays simple (avoiding
+changing the addrspace based on the device backend everywhere).
+
+--- a/lib/SPIRV/SPIRVInternal.h
++++ b/lib/SPIRV/SPIRVInternal.h
+@@ -179,11 +179,12 @@ typedef SPIRVMap<Op, Op, IntBoolOpMapId> IntBoolOpMap;
+ "-v512:512:512-v1024:1024:1024"
+
+ enum SPIRAddressSpace {
+- SPIRAS_Private,
++ SPIRAS_Generic,
+ SPIRAS_Global,
+- SPIRAS_Constant,
++ SPIRAS_Internal,
+ SPIRAS_Local,
+- SPIRAS_Generic,
++ SPIRAS_Constant,
++ SPIRAS_Private,
+ SPIRAS_GlobalDevice,
+ SPIRAS_GlobalHost,
+ SPIRAS_Input,
\ No newline at end of file
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index 6a1b6b0..8b13789 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,19 +1 @@
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float32 {
- func.func public @main() {
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float64 {
- func.func public @main() {
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index f9c14a6..f1b02d2 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@
def repo():
# LINT.IfChange
- STABLEHLO_COMMIT = "8555db77763fadbd6be83df0a5532828bc419cba"
- STABLEHLO_SHA256 = "666a88d94e0f1b36e9e5b25411521b878320c61983214859b4e419f36acbf332"
+ STABLEHLO_COMMIT = "24d1807a9a3e0df81103a0be9be7ad28ee34c85a"
+ STABLEHLO_SHA256 = "fc8b165379e6b34a7ab64ebe9334605efec9c01b9106cc78c40ed1163f35a529"
# LINT.ThenChange(Google-internal path)
tf_http_archive(
diff --git a/third_party/triton/BUILD b/third_party/triton/BUILD
index 3c41380..3b74065 100644
--- a/third_party/triton/BUILD
+++ b/third_party/triton/BUILD
@@ -1 +1,13 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+# copybara:uncomment_begin
+# package(default_applicable_licenses = ["//tensorflow:license"])
+#
+# filegroup(
+# name = "patch_files",
+# srcs = glob([
+# "xla_extensions/**",
+# "llvm_integration/**",
+# "temporary/**",
+# ]),
+# visibility = ["//third_party/triton:__subpackages__"],
+# )
+# copybara:uncomment_end
diff --git a/third_party/triton/llvm_integration/BUILD b/third_party/triton/llvm_integration/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/triton/llvm_integration/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/triton/llvm_integration/cl657620552.patch b/third_party/triton/llvm_integration/cl657620552.patch
deleted file mode 100644
index 4a1f47d..0000000
--- a/third_party/triton/llvm_integration/cl657620552.patch
+++ /dev/null
@@ -1,18 +0,0 @@
-# Do not upstream this patch. This has been already upstreamed in
-# https://github.com/triton-lang/triton/commit/de46a0ede6efe7e93c2a9ebef639e36c6177c511
-# Next integration will include it and this patch should be removed then.
-
-diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
---- a/third_party/amd/python/triton_amd.cc
-+++ b/third_party/amd/python/triton_amd.cc
-@@ -193,9 +193,7 @@ void init_triton_amd(py::module &&m) {
- target->createMCAsmBackend(*sti, *mri, mcOptions));
- mcStreamer.reset(target->createMCObjectStreamer(
- triple, ctx, std::move(mab), mab->createObjectWriter(svos),
-- std::move(ce), *sti, mcOptions.MCRelaxAll,
-- mcOptions.MCIncrementalLinkerCompatible,
-- /*DWARFMustBeAtTheEnd=*/false));
-+ std::move(ce), *sti));
-
- std::unique_ptr<llvm::MCAsmParser> parser(
- createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl
index 5348e66..656b9c8 100644
--- a/third_party/triton/llvm_integration/series.bzl
+++ b/third_party/triton/llvm_integration/series.bzl
@@ -8,6 +8,5 @@
"""
llvm_patch_list = [
- "//third_party/triton/llvm_integration:cl657620552.patch",
# Add new patches just above this line
]
diff --git a/third_party/triton/temporary/BUILD b/third_party/triton/temporary/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/triton/temporary/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/triton/temporary/cuda11-temporary.patch b/third_party/triton/temporary/cuda11-temporary.patch
deleted file mode 100644
index a92166e..0000000
--- a/third_party/triton/temporary/cuda11-temporary.patch
+++ /dev/null
@@ -1,35 +0,0 @@
-# This temporary patch has already been included to the public list of Triton
-# patches. It is only here temporarily to be included in the openxla version,
-# but it will be removed during the next triton integration.
-
---- a/third_party/nvidia/backend/driver.c
-+++ b/third_party/nvidia/backend/driver.c
-@@ -154,6 +154,8 @@ static PyObject *loadBinary(PyObject *se
- typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
- int *numClusters, CUfunction func, const CUlaunchConfig *config);
-
-+#if CUDA_VERSION < 12000
-+#else
- typedef CUresult (*cuTensorMapEncodeTiled_t)(
- CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
- cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
-@@ -161,6 +161,7 @@ typedef CUresult (*cuTensorMapEncodeTile
- const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
- CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
- CUtensorMapFloatOOBfill oobFill);
-+#endif
-
- #define defineGetFunctionHandle(name, symbolName) \
- static symbolName##_t name() { \
-@@ -187,8 +187,11 @@ typedef CUresult (*cuTensorMapEncodeTile
- defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
- cuOccupancyMaxActiveClusters);
-
-+#if CUDA_VERSION < 12000
-+#else
- defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
- cuTensorMapEncodeTiled);
-+#endif
-
- static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
- int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl
index f55c41e..4fa5526 100644
--- a/third_party/triton/temporary/series.bzl
+++ b/third_party/triton/temporary/series.bzl
@@ -14,7 +14,5 @@
"""
temporary_patch_list = [
- "//third_party/triton/temporary:cuda11-temporary.patch",
- "//third_party/triton/temporary:undo_tesla_gpu.patch",
# Add new patches just above this line
]
diff --git a/third_party/triton/temporary/undo_tesla_gpu.patch b/third_party/triton/temporary/undo_tesla_gpu.patch
deleted file mode 100644
index 6c2d1d1..0000000
--- a/third_party/triton/temporary/undo_tesla_gpu.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-This can be removed on the next integrate as it already exists in upstream.
-diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-@@ -21,7 +21,7 @@ namespace {
- static int getMMAVersionSafe(int computeCapability, DotOp op) {
- // List supported mma version in order of preference.
- SmallVector<int> versionsSupported;
-- if (computeCapability < 80) {
-+ if (computeCapability < 75) {
- versionsSupported = {1};
- } else if (computeCapability < 90) {
- versionsSupported = {2};
diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl
index f321b4e..012f5b4 100644
--- a/third_party/triton/workspace.bzl
+++ b/third_party/triton/workspace.bzl
@@ -1,15 +1,15 @@
"""Provides the repository macro to import Triton."""
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-load("//third_party/triton/llvm_integration:series.bzl", "llvm_patch_list")
-load("//third_party/triton/temporary:series.bzl", "temporary_patch_list")
-load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_list")
+load("//third_party/triton:llvm_integration/series.bzl", "llvm_patch_list")
+load("//third_party/triton:temporary/series.bzl", "temporary_patch_list")
+load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_list")
def repo():
"""Imports Triton."""
- TRITON_COMMIT = "cl657175856"
- TRITON_SHA256 = "316f421a7d7ead2b7e5adc2e8bb68ce1a8f7809db73dbed8abd54c35bd0c1576"
+ TRITON_COMMIT = "cl659604537"
+ TRITON_SHA256 = "4f4699ca0df9d48649efcd838d331aa6bee4ec7fce905cdeed71715f6ca7e033"
tf_http_archive(
name = "triton",
sha256 = TRITON_SHA256,
diff --git a/third_party/triton/xla_extensions/BUILD b/third_party/triton/xla_extensions/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/triton/xla_extensions/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/triton/xla_extensions/series.bzl b/third_party/triton/xla_extensions/series.bzl
index 19ba85b..be33c18 100644
--- a/third_party/triton/xla_extensions/series.bzl
+++ b/third_party/triton/xla_extensions/series.bzl
@@ -7,7 +7,7 @@
"""
extensions_files_patch_list = [
- "//third_party/triton/xla_extensions:sparse_dot.patch", # Sparsity internal patch
- "//third_party/triton/xla_extensions:sparsity_layout.patch", # Sparsity internal patch
+ "//third_party/triton:xla_extensions/sparse_dot.patch", # Sparsity internal patch
+ "//third_party/triton:xla_extensions/sparsity_layout.patch", # Sparsity internal patch
# Add new patches just above this line
]
diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch
index a1c011d..d9fa4cc 100644
--- a/third_party/triton/xla_extensions/sparse_dot.patch
+++ b/third_party/triton/xla_extensions/sparse_dot.patch
@@ -273,9 +273,9 @@
+ return op->hasTrait<OpTrait::DotLike>() || isa<ttg::SparseDotOp>(op);
+}
+
- // Replace the ForOp's yield with a new one with the given operands appended.
- static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
- // Fix up the yield op.
+ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
+ Value insertIdx, Value extractIdx,
+ tt::CoarseSchedule &schedule,
@@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
diff --git a/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/triton/xla_extensions/sparsity_layout.patch
index b64ddbd..4daf4f2 100644
--- a/third_party/triton/xla_extensions/sparsity_layout.patch
+++ b/third_party/triton/xla_extensions/sparsity_layout.patch
@@ -2,19 +2,20 @@
index 34fb89954..a0172e107 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> std::optional<Value> {
-- llvm_unreachable("Argument rematerialization should not happen in Triton "
-- "-> TritonGPU conversion");
-+ // TODO(b/354860562): reenable or remove.
-+ // llvm_unreachable("Argument rematerialization should not happen in Triton "
-+ // "-> TritonGPU conversion");
++ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++ // remaining arguments that have been converted to a new type.
++ // We use this to rewrite triton_gpu.sparse_dot in a separate pass after
++ // 'convert-triton-to-tritongpu'.
++ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
++ inputs);
+ llvm_unreachable("Argument rematerialization should not happen in Triton "
+ "-> TritonGPU conversion");
return std::nullopt;
- });
-
-@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
@@ -31,7 +32,7 @@
index df3d3b042..e38c184f6 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert
+@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
diff --git a/third_party/uv/uv.BUILD b/third_party/uv/uv.BUILD
index 75a2df3..b04383a 100644
--- a/third_party/uv/uv.BUILD
+++ b/third_party/uv/uv.BUILD
@@ -11,7 +11,48 @@
cc_library(
name = "uv",
- srcs = glob(["src/*.c"]),
+ srcs = [
+ "src/fs-poll.c",
+ "src/idna.c",
+ "src/inet.c",
+ "src/random.c",
+ "src/strscpy.c",
+ "src/threadpool.c",
+ "src/timer.c",
+ "src/uv-common.c",
+ "src/uv-data-getter-setters.c",
+ "src/version.c",
+ ] + [
+ "src/unix/async.c",
+ "src/unix/core.c",
+ "src/unix/dl.c",
+ "src/unix/fs.c",
+ "src/unix/getaddrinfo.c",
+ "src/unix/getnameinfo.c",
+ "src/unix/loop.c",
+ "src/unix/loop-watcher.c",
+ "src/unix/pipe.c",
+ "src/unix/poll.c",
+ "src/unix/process.c",
+ "src/unix/random-devurandom.c",
+ "src/unix/signal.c",
+ "src/unix/stream.c",
+ "src/unix/tcp.c",
+ "src/unix/thread.c",
+ "src/unix/tty.c",
+ "src/unix/udp.c",
+ ] + select({
+ "@platforms//os:osx": [
+ "src/unix/bsd-ifaddrs.c",
+ "src/unix/darwin.c",
+ "src/unix/darwin-proctitle.c",
+ "src/unix/fsevents.c",
+ "src/unix/kqueue.c",
+ "src/unix/proctitle.c",
+ "src/unix/random-getentropy.c",
+ ],
+ }),
+ # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt.
hdrs = [
"include/uv.h",
],
diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc
index 11783a8..edb4659 100644
--- a/third_party/xla/.bazelrc
+++ b/third_party/xla/.bazelrc
@@ -573,6 +573,9 @@
build:rbe_win_clang --linkopt=/FORCE:MULTIPLE
build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE
+# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits.
+build:rbe_windows_x86_cpu --config=rbe_win_clang
+
# END TF REMOTE BUILD EXECUTION OPTIONS
# TFLite build configs for generic embedded Linux
@@ -787,17 +790,19 @@
# PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over
# the whole TF code base. These are usually run continuously or upon presubmit.
-# CPU PYCPP:
+# LINUX CPU PYCPP:
test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# CUDA PYCPP:
+
+# LINUX CUDA PYCPP:
test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# ARM64 PYCPP
+
+# LINUX ARM64 PYCPP
# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on
# Linux x86 so that we can use RBE. Since tests still need to run on the single
# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit.
@@ -830,6 +835,13 @@
# CROSS-COMPILE MACOS X86 PYCPP
build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test
build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test
+# WINDOWS X86-64 CPU PYCPP
+test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600"
+test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only
+test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/...
+
# END TF TEST SUITE OPTIONS
# START CROSS-COMPILE CONFIGS
diff --git a/third_party/xla/.github/workflows/bazel_query.yml b/third_party/xla/.github/workflows/bazel_query.yml
new file mode 100644
index 0000000..253218a
--- /dev/null
+++ b/third_party/xla/.github/workflows/bazel_query.yml
@@ -0,0 +1,38 @@
+# Copyright 2024 The OpenXLA Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+name: Bazel Query
+permissions:
+ contents: read
+on:
+ pull_request:
+
+env:
+ # Have `go install` place binaries in $PATH
+ GOBIN: "/usr/local/bin"
+
+jobs:
+ bazel-query:
+ runs-on: ubuntu-22.04
+ defaults:
+ run:
+ shell: bash
+ timeout-minutes: 2
+ steps:
+ - name: "Checking out repository"
+ uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
+ - name: "Install bazelisk"
+ run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0
+ - name: "Run bazel query //xla/..."
+ run: bazelisk query //xla/...
diff --git a/third_party/xla/.github/workflows/buildifier.yml b/third_party/xla/.github/workflows/buildifier.yml
index 5514067..797b884 100644
--- a/third_party/xla/.github/workflows/buildifier.yml
+++ b/third_party/xla/.github/workflows/buildifier.yml
@@ -13,7 +13,8 @@
# limitations under the License.
# ============================================================================
name: Buildifier
-permissions: read-all
+permissions:
+ contents: read
on:
pull_request:
diff --git a/third_party/xla/.github/workflows/check_contents.yml b/third_party/xla/.github/workflows/check_contents.yml
index fd38adf..1756b36 100644
--- a/third_party/xla/.github/workflows/check_contents.yml
+++ b/third_party/xla/.github/workflows/check_contents.yml
@@ -19,7 +19,8 @@
# files once XLA moves out of Tensorflow internally.
# TODO(ddunleavy): Update this after METADATA files are consolidated.
name: Check Contents
-permissions: read-all
+permissions:
+ contents: read
on:
pull_request:
diff --git a/third_party/xla/.github/workflows/clang_format.yml b/third_party/xla/.github/workflows/clang_format.yml
index 2701311..e22b67e 100644
--- a/third_party/xla/.github/workflows/clang_format.yml
+++ b/third_party/xla/.github/workflows/clang_format.yml
@@ -14,7 +14,8 @@
# ============================================================================
name: Clang Format
-permissions: read-all
+permissions:
+ contents: read
on:
pull_request:
diff --git a/third_party/xla/docs/custom_call.md b/third_party/xla/docs/custom_call.md
index eb97ad7..2471df6 100644
--- a/third_party/xla/docs/custom_call.md
+++ b/third_party/xla/docs/custom_call.md
@@ -1,4 +1,4 @@
-# XLA custom calls
+# XLA Custom Calls
This document describes how to write and use XLA custom calls using XLA FFI
library. Custom call is a mechanism to describe an external "operation" in the
@@ -23,6 +23,269 @@
> and custom call target references or to use C-style namespacing directly in
> the function name.
+## JAX + XLA Custom Calls
+
+See the [JAX documentation](https://jax.readthedocs.io/en/latest/ffi.html) for
+end-to-end examples of integrating custom calls and XLA FFI with JAX.
+
+## XLA FFI Binding
+
+XLA FFI binding is a compile-time specification of the custom call signature:
+custom call arguments, attributes and their types, and additional parameters
+passed via the execution context (e.g., the GPU stream for the GPU backend). An
+XLA FFI binding can be bound to any C++ callable (function pointer, lambda,
+etc.) with a compatible `operator()` signature. The constructed handler decodes
+the XLA FFI call frame (defined by the stable C API), type checks all
+parameters, and forwards the decoded results to the user-defined callback.
+
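+As a preview of the pieces described in the sections below, here is a minimal
+sketch of a binding with one buffer argument, one buffer result, and one
+attribute (the attribute name `multiplier` is purely illustrative):
+
+```c++
+// Parameters are passed to the callable in the order they are bound:
+// here the argument, then the result, then the attribute.
+auto handler = Ffi::Bind()
+                   .Arg<Buffer<F32>>()
+                   .Ret<Buffer<F32>>()
+                   .Attr<int32_t>("multiplier")
+                   .To([](Buffer<F32> arg, Result<Buffer<F32>> res,
+                          int32_t multiplier) -> Error {
+                     return Error::Success();
+                   });
+```
+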
+XLA FFI binding relies heavily on template metaprogramming to be able to
+compile the constructed handler to the most efficient machine code. Run-time
+overhead is on the order of a couple of nanoseconds for each custom call
+parameter.
+
+XLA FFI customization points are implemented as template specializations, and
+users can define how to decode their custom types, e.g., it is possible to
+define custom decoding for user-defined `enum class` types.
+
+### Returning Errors From Custom Calls
+
+Custom call implementations must return an `xla::ffi::Error` value to signal
+success or error to the XLA runtime. It is similar to `absl::Status` and has
+the same set of error codes. We do not use `absl::Status` because it does not
+have a stable ABI, and it would be unsafe to pass it between a dynamically
+loaded custom call library and XLA itself.
+
+```c++
+// Handler that always returns an error.
+auto always_error = Ffi::Bind().To(
+ []() { return Error(ErrorCode::kInternal, "Oops!"); });
+
+// Handler that always returns a success.
+auto always_success = Ffi::Bind().To(
+ []() { return Error::Success(); });
+```
+
+### Buffer Arguments And Results
+
+XLA uses destination-passing style for results: custom calls (or any other XLA
+operations, for that matter) do not allocate memory for results, and instead
+write into destinations passed by the XLA runtime. XLA uses static buffer
+assignment and allocates buffers for all values based on their live ranges at
+compile time.
+
+Results passed to FFI handlers are wrapped into a `Result<T>` template, which
+has pointer-like semantics: `operator->` gives access to the underlying
+parameter.
+
+`AnyBuffer` arguments and results give access to custom call buffer parameters
+of any data type. This is useful when a custom call has a generic
+implementation that works for multiple data types and the implementation does
+run-time dispatching based on the data type. `AnyBuffer` gives access to the
+buffer data type, dimensions, and a pointer to the buffer itself.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+ call_target_name = "foo",
+ api_version = 4 : i32
+} : (tensor<2x2xf32>) -> tensor<2x2xf32>
+```
+
+```c++
+// Buffers of any rank and data type.
+auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
+ [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
+ void* arg_data = arg.untyped_data();
+ void* res_data = res->untyped_data();
+ return Error::Success();
+ });
+```
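+
+When an `AnyBuffer` handler dispatches on the element type at run time, the
+data type can be queried from the buffer itself. A minimal sketch, assuming the
+`element_type()` accessor (check `xla/ffi/api/ffi.h` for the exact accessor
+names):
+
+```c++
+auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
+    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
+      // Run time dispatch on the buffer element type (accessor name assumed).
+      if (arg.element_type() == F32) {
+        float* arg_data = reinterpret_cast<float*>(arg.untyped_data());
+        // ... F32 implementation ...
+      }
+      return Error::Success();
+    });
+```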
+
+### Constrained Buffer Arguments And Results
+
+`Buffer` allows adding constraints on the buffer data type and rank; they are
+automatically checked by the handler, which returns an error to the XLA runtime
+if the run-time arguments do not match the FFI handler signature.
+
+```c++
+// Buffers of any rank and F32 data type.
+auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
+ [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
+ float* arg_data = arg.typed_data();
+ float* res_data = res->typed_data();
+ return Error::Success();
+ });
+```
+
+```c++
+// Buffers of rank 2 and F32 data type.
+auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
+ [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
+ float* arg_data = arg.typed_data();
+ float* res_data = res->typed_data();
+ return Error::Success();
+ });
+```
+
+### Variadic Arguments And Results
+
+If the number of arguments and results can differ across instances of a custom
+call, they can be decoded at run time using `RemainingArgs` and
+`RemainingRets`.
+
+```c++
+auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
+ [](RemainingArgs args, RemainingRets results) -> Error {
+ ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
+ ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
+
+ if (!arg.has_value()) {
+ return Error(ErrorCode::kInternal, arg.error());
+ }
+
+ if (!res.has_value()) {
+ return Error(ErrorCode::kInternal, res.error());
+ }
+
+ return Error::Success();
+ });
+```
+
+Variadic arguments and results can be declared after regular arguments and
+results; however, binding regular arguments and results after variadic ones is
+illegal.
+
+```c++
+auto handler =
+ Ffi::Bind()
+ .Arg<AnyBuffer>()
+ .RemainingArgs()
+ .Ret<AnyBuffer>()
+ .RemainingRets()
+ .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
+ RemainingRets results) -> Error { return Error::Success(); });
+```
+
+### Attributes
+
+XLA FFI supports automatic decoding of `mlir::DictionaryAttr` passed as a
+`custom_call` `backend_config` into FFI handler arguments.
+
+Note: See [stablehlo RFC](https://github.com/openxla/stablehlo/blob/main/rfcs/20240312-standardize-customcallop.md)
+for details, and the `stablehlo.custom_call` operation specification.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+ call_target_name = "foo",
+ backend_config= {
+ i32 = 42 : i32,
+ str = "string"
+ },
+ api_version = 4 : i32
+} : (tensor<f32>) -> tensor<f32>
+```
+
+In this example the custom call has a single buffer argument and two
+attributes, and XLA FFI can automatically decode them and pass them to the
+user-defined callable.
+
+```c++
+auto handler = Ffi::Bind()
+ .Arg<BufferR0<F32>>()
+ .Attr<int32_t>("i32")
+ .Attr<std::string_view>("str")
+ .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
+ return Error::Success();
+ });
+```
+
+### User-Defined Enum Attributes
+
+XLA FFI can automatically decode integral MLIR attributes into user-defined
+enums. The enum class must have the same underlying integral type as the
+attribute, and decoding has to be explicitly registered with XLA FFI.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+ call_target_name = "foo",
+ backend_config= {
+ command = 0 : i32
+ },
+ api_version = 4 : i32
+} : (tensor<f32>) -> tensor<f32>
+```
+
+```c++
+enum class Command : int32_t {
+ kAdd = 0,
+ kMul = 1,
+};
+
+XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
+
+auto handler = Ffi::Bind().Attr<Command>("command").To(
+ [](Command command) -> Error { return Error::Success(); });
+```
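+
+Once decoded, the enum is an ordinary C++ value. A minimal sketch dispatching
+on the decoded command inside the handler (implementation bodies elided):
+
+```c++
+auto handler = Ffi::Bind().Attr<Command>("command").To(
+    [](Command command) -> Error {
+      switch (command) {
+        case Command::kAdd:
+          // ... addition implementation ...
+          break;
+        case Command::kMul:
+          // ... multiplication implementation ...
+          break;
+      }
+      return Error::Success();
+    });
+```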
+
+### Binding All Custom Call Attributes
+
+It is possible to get access to all custom call attributes as a dictionary
+and lazily decode only the attributes that are needed at run time.
+
+```c++
+auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
+ ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
+ return Error::Success();
+});
+```
+
+### User-Defined Struct Attributes
+
+XLA FFI can decode dictionary attributes into user-defined structs.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+ call_target_name = "foo",
+ backend_config= {
+ range = { lo = 0 : i64, hi = 42 : i64 }
+ },
+ api_version = 4 : i32
+} : (tensor<f32>) -> tensor<f32>
+```
+
+In the example above `range` is an `mlir::DictionaryAttr` attribute, and
+instead of accessing dictionary fields by name, it can be automatically decoded
+as a C++ struct. Decoding has to be explicitly registered with the
+`XLA_FFI_REGISTER_STRUCT_ATTR_DECODING` macro (behind the scenes it defines a
+template specialization in the `::xla::ffi` namespace, thus the macro must be
+added to the global namespace).
+
+```c++
+struct Range {
+ int64_t lo;
+ int64_t hi;
+};
+
+XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
+                                      StructMember<int64_t>("hi"));
+
+auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error {
+ return Error::Success();
+});
+```
+
+Custom attributes can be loaded from a dictionary, just like any other
+attribute. In the example below, all custom call attributes are decoded as a
+`Dictionary`, and `range` can be accessed by name.
+
+```c++
+auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
+ ErrorOr<Range> range = attrs.get<Range>("range");
+ return Error::Success();
+});
+```
+
## Create a custom call on CPU
You can create an HLO instruction that represents a custom call via XLA's client
diff --git a/third_party/xla/docs/determinism.md b/third_party/xla/docs/determinism.md
index d8cd934..09a1e4f 100644
--- a/third_party/xla/docs/determinism.md
+++ b/third_party/xla/docs/determinism.md
@@ -8,6 +8,10 @@
measurements different kernels can be picked as the fastest ones in different
compilation runs.
+`--xla_gpu_require_complete_aot_autotune_results` can be used to ensure that no
+autotuning happens on repeated compilations: they either reuse compatible
+results from previous runs or fail.
+
## Execution
Programs compiled by XLA can be non-deterministic on operations like scatter,
diff --git a/third_party/xla/docs/operation_semantics.md b/third_party/xla/docs/operation_semantics.md
index 55ed575..5584997 100644
--- a/third_party/xla/docs/operation_semantics.md
+++ b/third_party/xla/docs/operation_semantics.md
@@ -1214,8 +1214,8 @@
Where `Op` is one of `Add` (addition), `Sub`(subtraction), `Mul`
(multiplication), `Div` (division), `Pow` (power), `Rem` (remainder), `Max`
-(maximum), `Min` (minimum), `LogicalAnd` (logical AND), `LogicalOr` (logical
-OR), `LogicalXor` (logical XOR), `ShiftLeft` (Left Shift),
+(maximum), `Min` (minimum), `And` (logical AND), `Or` (logical
+OR), `Xor` (logical XOR), `ShiftLeft` (Left Shift),
`ShiftRightArithmetic` (arithmetic Right Shift), `ShiftRightLogical` (logical
Right Shift), `Atan2` (2-argument arctangent), or `Complex` (combines real and
imaginary parts into a complex number)
@@ -1305,12 +1305,22 @@
<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
+<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`.
+
<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.
+<b>`Clz(operand)`</b> Element-wise count leading zeros.
+
<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.
+<b>`Erf(operand)`</b> Element-wise error function `x -> erf(x)` where
+
+$$\text{erf}(x) = \frac{2}{\sqrt{\pi}}\int_0^x e^{-t^2} \, dt$$.
+
<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.
+<b>`Expm1(operand)`</b> Element-wise natural exponential minus one `x -> e^x - 1`.
+
<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
<b>`Imag(operand)`</b> Element-wise imaginary part of a complex (or real)
@@ -1323,19 +1333,25 @@
<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
-<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
+<b>`Log1p(operand)`</b> Element-wise shifted natural logarithm `x -> ln(1+x)`.
<b>`Logistic(operand)`</b> Element-wise logistic function computation `x ->
logistic(x)`.
+<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
+
+<b>`Not(operand)`</b> Element-wise logical not `x -> !(x)`.
+
<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
element of `operand`.
-<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
-
<b>`Real(operand)`</b> Element-wise real part of a complex (or real) shape.
`x -> real(x)`. If the operand is a floating point type, returns the same value.
+<b>`Round(operand)`</b> Element-wise rounding, ties away from zero.
+
+<b>`RoundNearestEven(operand)`</b> Element-wise rounding, ties to nearest even.
+
<b>`Rsqrt(operand)`</b> Element-wise reciprocal of square root operation
`x -> 1.0 / sqrt(x)`.
@@ -1345,16 +1361,14 @@
using the comparison operator of the element type of `operand`.
+<b>`Sin(operand)`</b> Element-wise sine `x -> sin(x)`.
+
<b>`Sqrt(operand)`</b> Element-wise square root operation `x -> sqrt(x)`.
-<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`.
+<b>`Tan(operand)`</b> Element-wise tangent `x -> tan(x)`.
<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
-<b>`Round(operand)`</b> Element-wise rounding, ties away from zero.
-
-<b>`RoundNearestEven(operand)`</b> Element-wise rounding, ties to nearest even.
-
Arguments | Type | Semantics
--------- | ------- | ---------------------------
`operand` | `XlaOp` | The operand to the function
diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files
index baafd35..5759a24 100644
--- a/third_party/xla/opensource_only.files
+++ b/third_party/xla/opensource_only.files
@@ -34,6 +34,7 @@
third_party/py/python_repo.bzl:
third_party/python_runtime/BUILD:
third_party/repo.bzl:
+third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD:
third_party/stablehlo/BUILD:
tools/toolchains/BUILD:
tools/toolchains/clang6/BUILD:
diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch
index 4d99610..1711c1b 100644
--- a/third_party/xla/third_party/shardy/temporary.patch
+++ b/third_party/xla/third_party/shardy/temporary.patch
@@ -1,15 +1,15 @@
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
-index 9345d8d..6429d9b 100644
+index 4b3f0db..54a3c65 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
def repo(name):
"""Imports LLVM."""
-- LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
-- LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"
-+ LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d"
-+ LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd"
+- LLVM_COMMIT = "16dadecc05fa4986d4522c2c3a09a7628feb0fd4"
+- LLVM_SHA256 = "e7c5195e30f75c6027f90b8196ded71a41a38c586931cfb33d63295d9eed95fd"
++ LLVM_COMMIT = "4c5ef6690040383956461828457ac27f7f912edb"
++ LLVM_SHA256 = "a30da7822f5307bc0aca8c497ffdd6369e3877186e87501e2ac1f3ec5ed1c0b7"
tf_http_archive(
name = name,
diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl
index 200ac3f..a40d827 100644
--- a/third_party/xla/third_party/shardy/workspace.bzl
+++ b/third_party/xla/third_party/shardy/workspace.bzl
@@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
def repo():
- SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
- SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"
+ SHARDY_COMMIT = "76731821434117cb6d736bdd1b32b7ee4ffbcb4b"
+ SHARDY_SHA256 = "944bdbdc9e97ca95b15ac81bfee151664c06b4e6373661d25091e064a604fd2f"
tf_http_archive(
name = "shardy",
diff --git a/third_party/xla/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/spirv_llvm_translator/BUILD
new file mode 100644
index 0000000..8d626dc
--- /dev/null
+++ b/third_party/xla/third_party/spirv_llvm_translator/BUILD
@@ -0,0 +1,7 @@
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+# spirv_llvm_translator license placeholder
diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
new file mode 100644
index 0000000..557e2e8
--- /dev/null
+++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
@@ -0,0 +1,34 @@
+cc_library(
+ name = "spirv_llvm_translator",
+ srcs = glob([
+ "lib/SPIRV/libSPIRV/*.cpp",
+ "lib/SPIRV/libSPIRV/*.hpp",
+ "lib/SPIRV/libSPIRV/*.h",
+ "lib/SPIRV/Mangler/*.cpp",
+ "lib/SPIRV/Mangler/*.h",
+ "lib/SPIRV/*.cpp",
+ "lib/SPIRV/*.hpp",
+ "lib/SPIRV/*.h",
+ ]),
+ hdrs = glob(["include/*"]),
+ includes = [
+ "include/",
+ "lib/SPIRV/",
+ "lib/SPIRV/Mangler/",
+ "lib/SPIRV/libSPIRV/",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@llvm-project//llvm:Analysis",
+ "@llvm-project//llvm:BitWriter",
+ "@llvm-project//llvm:CodeGen",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Demangle",
+ "@llvm-project//llvm:IRReader",
+ "@llvm-project//llvm:Linker",
+ "@llvm-project//llvm:Passes",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:TransformUtils",
+ "@spirv_headers//:spirv_cpp_headers",
+ ],
+)
diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
new file mode 100644
index 0000000..fc843b1
--- /dev/null
+++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
@@ -0,0 +1,25 @@
+diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h
+index a828add8..924e13b4 100644
+
+The SPIR backend uses different addrspace representations than the NVPTX
+backend. We reorder the enum values here so that we can keep XLA LLVM codegen
+simple (avoiding changing the addrspace based on the device backend everywhere).
+
+--- a/lib/SPIRV/SPIRVInternal.h
++++ b/lib/SPIRV/SPIRVInternal.h
+@@ -179,11 +179,12 @@ typedef SPIRVMap<Op, Op, IntBoolOpMapId> IntBoolOpMap;
+ "-v512:512:512-v1024:1024:1024"
+
+ enum SPIRAddressSpace {
+- SPIRAS_Private,
++ SPIRAS_Generic,
+ SPIRAS_Global,
+- SPIRAS_Constant,
++ SPIRAS_Internal,
+ SPIRAS_Local,
+- SPIRAS_Generic,
++ SPIRAS_Constant,
++ SPIRAS_Private,
+ SPIRAS_GlobalDevice,
+ SPIRAS_GlobalHost,
+ SPIRAS_Input,
\ No newline at end of file
diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch
index 6a1b6b0..8b13789 100755
--- a/third_party/xla/third_party/stablehlo/temporary.patch
+++ b/third_party/xla/third_party/stablehlo/temporary.patch
@@ -1,19 +1 @@
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float32 {
- func.func public @main() {
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float64 {
- func.func public @main() {
diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl
index f9c14a6..f1b02d2 100644
--- a/third_party/xla/third_party/stablehlo/workspace.bzl
+++ b/third_party/xla/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@
def repo():
# LINT.IfChange
- STABLEHLO_COMMIT = "8555db77763fadbd6be83df0a5532828bc419cba"
- STABLEHLO_SHA256 = "666a88d94e0f1b36e9e5b25411521b878320c61983214859b4e419f36acbf332"
+ STABLEHLO_COMMIT = "24d1807a9a3e0df81103a0be9be7ad28ee34c85a"
+ STABLEHLO_SHA256 = "fc8b165379e6b34a7ab64ebe9334605efec9c01b9106cc78c40ed1163f35a529"
# LINT.ThenChange(Google-internal path)
tf_http_archive(
diff --git a/third_party/xla/third_party/triton/BUILD b/third_party/xla/third_party/triton/BUILD
index 3c41380..3b74065 100644
--- a/third_party/xla/third_party/triton/BUILD
+++ b/third_party/xla/third_party/triton/BUILD
@@ -1 +1,13 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+# copybara:uncomment_begin
+# package(default_applicable_licenses = ["//tensorflow:license"])
+#
+# filegroup(
+# name = "patch_files",
+# srcs = glob([
+# "xla_extensions/**",
+# "llvm_integration/**",
+# "temporary/**",
+# ]),
+# visibility = ["//third_party/triton:__subpackages__"],
+# )
+# copybara:uncomment_end
diff --git a/third_party/xla/third_party/triton/llvm_integration/BUILD b/third_party/xla/third_party/triton/llvm_integration/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/triton/llvm_integration/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch b/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch
deleted file mode 100644
index 4a1f47d..0000000
--- a/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch
+++ /dev/null
@@ -1,18 +0,0 @@
-# Do not upstream this patch. This has been already upstreamed in
-# https://github.com/triton-lang/triton/commit/de46a0ede6efe7e93c2a9ebef639e36c6177c511
-# Next integration will include it and this patch should be removed then.
-
-diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
---- a/third_party/amd/python/triton_amd.cc
-+++ b/third_party/amd/python/triton_amd.cc
-@@ -193,9 +193,7 @@ void init_triton_amd(py::module &&m) {
- target->createMCAsmBackend(*sti, *mri, mcOptions));
- mcStreamer.reset(target->createMCObjectStreamer(
- triple, ctx, std::move(mab), mab->createObjectWriter(svos),
-- std::move(ce), *sti, mcOptions.MCRelaxAll,
-- mcOptions.MCIncrementalLinkerCompatible,
-- /*DWARFMustBeAtTheEnd=*/false));
-+ std::move(ce), *sti));
-
- std::unique_ptr<llvm::MCAsmParser> parser(
- createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl
index 5348e66..656b9c8 100644
--- a/third_party/xla/third_party/triton/llvm_integration/series.bzl
+++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl
@@ -8,6 +8,5 @@
"""
llvm_patch_list = [
- "//third_party/triton/llvm_integration:cl657620552.patch",
# Add new patches just above this line
]
diff --git a/third_party/xla/third_party/triton/temporary/BUILD b/third_party/xla/third_party/triton/temporary/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/triton/temporary/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch b/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch
deleted file mode 100644
index a92166e..0000000
--- a/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch
+++ /dev/null
@@ -1,35 +0,0 @@
-# This temporary patch has already been included to the public list of Triton
-# patches. It is only here temporarily to be included in the openxla version,
-# but it will be removed during the next triton integration.
-
---- a/third_party/nvidia/backend/driver.c
-+++ b/third_party/nvidia/backend/driver.c
-@@ -154,6 +154,8 @@ static PyObject *loadBinary(PyObject *se
- typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
- int *numClusters, CUfunction func, const CUlaunchConfig *config);
-
-+#if CUDA_VERSION < 12000
-+#else
- typedef CUresult (*cuTensorMapEncodeTiled_t)(
- CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
- cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
-@@ -161,6 +161,7 @@ typedef CUresult (*cuTensorMapEncodeTile
- const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
- CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
- CUtensorMapFloatOOBfill oobFill);
-+#endif
-
- #define defineGetFunctionHandle(name, symbolName) \
- static symbolName##_t name() { \
-@@ -187,8 +187,11 @@ typedef CUresult (*cuTensorMapEncodeTile
- defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
- cuOccupancyMaxActiveClusters);
-
-+#if CUDA_VERSION < 12000
-+#else
- defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
- cuTensorMapEncodeTiled);
-+#endif
-
- static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
- int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl
index f55c41e..4fa5526 100644
--- a/third_party/xla/third_party/triton/temporary/series.bzl
+++ b/third_party/xla/third_party/triton/temporary/series.bzl
@@ -14,7 +14,5 @@
"""
temporary_patch_list = [
- "//third_party/triton/temporary:cuda11-temporary.patch",
- "//third_party/triton/temporary:undo_tesla_gpu.patch",
# Add new patches just above this line
]
diff --git a/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch b/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch
deleted file mode 100644
index 6c2d1d1..0000000
--- a/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-This can be removed on the next integrate as it already exists in upstream.
-diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-@@ -21,7 +21,7 @@ namespace {
- static int getMMAVersionSafe(int computeCapability, DotOp op) {
- // List supported mma version in order of preference.
- SmallVector<int> versionsSupported;
-- if (computeCapability < 80) {
-+ if (computeCapability < 75) {
- versionsSupported = {1};
- } else if (computeCapability < 90) {
- versionsSupported = {2};
diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl
index f321b4e..012f5b4 100644
--- a/third_party/xla/third_party/triton/workspace.bzl
+++ b/third_party/xla/third_party/triton/workspace.bzl
@@ -1,15 +1,15 @@
"""Provides the repository macro to import Triton."""
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-load("//third_party/triton/llvm_integration:series.bzl", "llvm_patch_list")
-load("//third_party/triton/temporary:series.bzl", "temporary_patch_list")
-load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_list")
+load("//third_party/triton:llvm_integration/series.bzl", "llvm_patch_list")
+load("//third_party/triton:temporary/series.bzl", "temporary_patch_list")
+load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_list")
def repo():
"""Imports Triton."""
- TRITON_COMMIT = "cl657175856"
- TRITON_SHA256 = "316f421a7d7ead2b7e5adc2e8bb68ce1a8f7809db73dbed8abd54c35bd0c1576"
+ TRITON_COMMIT = "cl659604537"
+ TRITON_SHA256 = "4f4699ca0df9d48649efcd838d331aa6bee4ec7fce905cdeed71715f6ca7e033"
tf_http_archive(
name = "triton",
sha256 = TRITON_SHA256,
diff --git a/third_party/xla/third_party/triton/xla_extensions/BUILD b/third_party/xla/third_party/triton/xla_extensions/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/triton/xla_extensions/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/triton/xla_extensions/series.bzl b/third_party/xla/third_party/triton/xla_extensions/series.bzl
index 19ba85b..be33c18 100644
--- a/third_party/xla/third_party/triton/xla_extensions/series.bzl
+++ b/third_party/xla/third_party/triton/xla_extensions/series.bzl
@@ -7,7 +7,7 @@
"""
extensions_files_patch_list = [
- "//third_party/triton/xla_extensions:sparse_dot.patch", # Sparsity internal patch
- "//third_party/triton/xla_extensions:sparsity_layout.patch", # Sparsity internal patch
+ "//third_party/triton:xla_extensions/sparse_dot.patch", # Sparsity internal patch
+ "//third_party/triton:xla_extensions/sparsity_layout.patch", # Sparsity internal patch
# Add new patches just above this line
]
diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
index a1c011d..d9fa4cc 100644
--- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
+++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
@@ -273,9 +273,9 @@
+ return op->hasTrait<OpTrait::DotLike>() || isa<ttg::SparseDotOp>(op);
+}
+
- // Replace the ForOp's yield with a new one with the given operands appended.
- static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
- // Fix up the yield op.
+ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
+ Value insertIdx, Value extractIdx,
+ tt::CoarseSchedule &schedule,
@@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
diff --git a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
index b64ddbd..4daf4f2 100644
--- a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
+++ b/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
@@ -2,19 +2,20 @@
index 34fb89954..a0172e107 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> std::optional<Value> {
-- llvm_unreachable("Argument rematerialization should not happen in Triton "
-- "-> TritonGPU conversion");
-+ // TODO(b/354860562): reenable or remove.
-+ // llvm_unreachable("Argument rematerialization should not happen in Triton "
-+ // "-> TritonGPU conversion");
++ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++ // remaining arguments that have been converted to a new type.
++ // We use this to rewrite triton_gpu.sparse_dot in a separate pass after
++ // 'convert-triton-to-tritongpu'.
++ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
++ inputs);
+ llvm_unreachable("Argument rematerialization should not happen in Triton "
+ "-> TritonGPU conversion");
return std::nullopt;
- });
-
-@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
@@ -31,7 +32,7 @@
index df3d3b042..e38c184f6 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert
+@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc
index 11783a8..edb4659 100644
--- a/third_party/xla/third_party/tsl/.bazelrc
+++ b/third_party/xla/third_party/tsl/.bazelrc
@@ -573,6 +573,9 @@
build:rbe_win_clang --linkopt=/FORCE:MULTIPLE
build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE
+# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits.
+build:rbe_windows_x86_cpu --config=rbe_win_clang
+
# END TF REMOTE BUILD EXECUTION OPTIONS
# TFLite build configs for generic embedded Linux
@@ -787,17 +790,19 @@
# PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over
# the whole TF code base. These are usually run continuously or upon presubmit.
-# CPU PYCPP:
+# LINUX CPU PYCPP:
test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# CUDA PYCPP:
+
+# LINUX CUDA PYCPP:
test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# ARM64 PYCPP
+
+# LINUX ARM64 PYCPP
# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on
# Linux x86 so that we can use RBE. Since tests still need to run on the single
# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit.
@@ -830,6 +835,13 @@
# CROSS-COMPILE MACOS X86 PYCPP
build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test
build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test
+# WINDOWS X86-64 CPU PYCPP
+test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600"
+test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only
+test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/...
+
# END TF TEST SUITE OPTIONS
# START CROSS-COMPILE CONFIGS
diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files
index 300ae95..1d52b26 100644
--- a/third_party/xla/third_party/tsl/opensource_only.files
+++ b/third_party/xla/third_party/tsl/opensource_only.files
@@ -93,6 +93,7 @@
third_party/repo.bzl:
third_party/six.BUILD:
third_party/snappy.BUILD:
+third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD:
third_party/systemlibs/BUILD.tpl:
third_party/systemlibs/BUILD:
third_party/systemlibs/absl_py.BUILD:
diff --git a/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch b/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch
new file mode 100644
index 0000000..5328c3a
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch
@@ -0,0 +1,35 @@
+From 372124e6af36a540e74a2ec31d79d7297a831f98 Mon Sep 17 00:00:00 2001
+From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= <frederic.bastien@gmail.com>
+Date: Thu, 1 Aug 2024 12:38:52 -0700
+Subject: [PATCH] PR #1732: Fix build on NVIDIA Jetson board. Fix #1665
+
+Imported from GitHub PR https://github.com/abseil/abseil-cpp/pull/1732
+
+Fix build on NVIDIA Jetson board. Fix #1665
+
+This patch is already used by the spark project.
+I'm fixing this as this break the build of Tensorflow and JAX on Jetson board.
+Merge 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff into 6b8ebb35c0414ef5a2b6fd4a0f59057e41beaff9
+
+Merging this change closes #1732
+
+COPYBARA_INTEGRATE_REVIEW=https://github.com/abseil/abseil-cpp/pull/1732 from nouiz:fix_neon_on_jetson 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff
+PiperOrigin-RevId: 658501520
+Change-Id: If502ede4efc8c877fb3fed227eca6dc7622dd181
+---
+ absl/base/config.h | 2 +-
+ 1 file changed, 1 insertion(+), 1 deletion(-)
+
+diff --git a/absl/base/config.h b/absl/base/config.h
+index 97c9a22a109..ab1e9860a91 100644
+--- a/absl/base/config.h
++++ b/absl/base/config.h
+@@ -926,7 +926,7 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' ||
+ // https://llvm.org/docs/CompileCudaWithLLVM.html#detecting-clang-vs-nvcc-from-code
+ #ifdef ABSL_INTERNAL_HAVE_ARM_NEON
+ #error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set
+-#elif defined(__ARM_NEON) && !defined(__CUDA_ARCH__)
++#elif defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__))
+ #define ABSL_INTERNAL_HAVE_ARM_NEON 1
+ #endif
+
diff --git a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
index 06f7516..9565a82 100644
--- a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
+++ b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
@@ -44,4 +44,5 @@
system_link_files = SYS_LINKS,
strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
+ patch_file = ["//third_party/absl:nvidia_jetson.patch"],
)
diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl
index 0b85e59..44cdbe3 100644
--- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl
+++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl
@@ -249,3 +249,9 @@
# to make bazel query happy.
name = "nvptxcompiler",
)
+
+cc_library(
+ # This is not yet fully supported, but we need the rule
+ # to make bazel query happy.
+ name = "nvjitlink",
+)
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl
index c185ca7..ff9b53b 100644
--- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl
+++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl
@@ -205,6 +205,8 @@
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include")
+ if int(rocm_config.rocm_version_number) >= 60200:
+ inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include")
# Support hcc based off clang 10.0.0 (for ROCm 3.3)
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD
index 4b3ad84..8c73096 100644
--- a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD
@@ -12,7 +12,7 @@
"#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
"#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
- "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
+ "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH",
"#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE",
"#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
@@ -109,6 +109,7 @@
"-UUSE_CBLAS",
"-DDNNL_ENABLE_MAX_CPU_ISA",
"-DDNNL_ENABLE_ITT_TASKS",
+ "-DDNNL_ENABLE_GRAPH_DUMP",
] + tf_openmp_copts()
_INCLUDES_LIST = [
@@ -119,6 +120,7 @@
"src/cpu",
"src/cpu/gemm",
"src/cpu/x64/xbyak",
+ "src/graph",
]
_TEXTUAL_HDRS_LIST = glob([
@@ -129,6 +131,15 @@
"src/cpu/**/*.hpp",
"src/cpu/jit_utils/**/*.hpp",
"src/cpu/x64/xbyak/*.h",
+ "src/graph/interface/*.hpp",
+ "src/graph/backend/*.hpp",
+ "src/graph/backend/dnnl/*.hpp",
+ "src/graph/backend/fake/*.hpp",
+ "src/graph/backend/dnnl/passes/*.hpp",
+ "src/graph/backend/dnnl/patterns/*.hpp",
+ "src/graph/backend/dnnl/kernels/*.hpp",
+ "src/graph/utils/*.hpp",
+ "src/graph/utils/pm/*.hpp",
]) + [
":dnnl_config_h",
":dnnl_version_h",
@@ -160,6 +171,16 @@
"src/cpu/**/*.cpp",
"src/common/ittnotify/*.c",
"src/cpu/jit_utils/**/*.cpp",
+ "src/cpu/x64/**/*.cpp",
+ "src/graph/interface/*.cpp",
+ "src/graph/backend/*.cpp",
+ "src/graph/backend/dnnl/*.cpp",
+ "src/graph/backend/fake/*.cpp",
+ "src/graph/backend/dnnl/passes/*.cpp",
+ "src/graph/backend/dnnl/patterns/*.cpp",
+ "src/graph/backend/dnnl/kernels/*.cpp",
+ "src/graph/utils/*.cpp",
+ "src/graph/utils/pm/*.cpp",
],
exclude = [
"src/cpu/aarch64/**",
diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD
new file mode 100644
index 0000000..8d626dc
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD
@@ -0,0 +1,7 @@
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+# spirv_llvm_translator license placeholder
diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
new file mode 100644
index 0000000..557e2e8
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
@@ -0,0 +1,34 @@
+cc_library(
+ name = "spirv_llvm_translator",
+ srcs = glob([
+ "lib/SPIRV/libSPIRV/*.cpp",
+ "lib/SPIRV/libSPIRV/*.hpp",
+ "lib/SPIRV/libSPIRV/*.h",
+ "lib/SPIRV/Mangler/*.cpp",
+ "lib/SPIRV/Mangler/*.h",
+ "lib/SPIRV/*.cpp",
+ "lib/SPIRV/*.hpp",
+ "lib/SPIRV/*.h",
+ ]),
+ hdrs = glob(["include/*"]),
+ includes = [
+ "include/",
+ "lib/SPIRV/",
+ "lib/SPIRV/Mangler/",
+ "lib/SPIRV/libSPIRV/",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@llvm-project//llvm:Analysis",
+ "@llvm-project//llvm:BitWriter",
+ "@llvm-project//llvm:CodeGen",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Demangle",
+ "@llvm-project//llvm:IRReader",
+ "@llvm-project//llvm:Linker",
+ "@llvm-project//llvm:Passes",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:TransformUtils",
+ "@spirv_headers//:spirv_cpp_headers",
+ ],
+)
diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
new file mode 100644
index 0000000..fc843b1
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
@@ -0,0 +1,25 @@
+diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h
+index a828add8..924e13b4 100644
+
+The SPIR backend uses different addrspace representations than the NVPTX
+backend. We reorder the enum values here so that we can keep XLA LLVM codegen
+simple (avoiding changing the addrspace based on the device backend everywhere).
+
+--- a/lib/SPIRV/SPIRVInternal.h
++++ b/lib/SPIRV/SPIRVInternal.h
+@@ -179,11 +179,12 @@ typedef SPIRVMap<Op, Op, IntBoolOpMapId> IntBoolOpMap;
+ "-v512:512:512-v1024:1024:1024"
+
+ enum SPIRAddressSpace {
+- SPIRAS_Private,
++ SPIRAS_Generic,
+ SPIRAS_Global,
+- SPIRAS_Constant,
++ SPIRAS_Internal,
+ SPIRAS_Local,
+- SPIRAS_Generic,
++ SPIRAS_Constant,
++ SPIRAS_Private,
+ SPIRAS_GlobalDevice,
+ SPIRAS_GlobalHost,
+ SPIRAS_Input,
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h
index f04fd0c..8d5cf79 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h
@@ -80,7 +80,7 @@
// Move constructor leaves src in a valid but unspecified state (same as
// std::unordered_map).
- FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {}
+ FlatMap(FlatMap&& src) noexcept : rep_(std::move(src.rep_)) {}
template <typename InputIter>
FlatMap(InputIter first, InputIter last, size_t N = 1,
@@ -100,14 +100,14 @@
// Move-assignment operator leaves src in a valid but unspecified state (same
// as std::unordered_map).
- FlatMap& operator=(FlatMap&& src) {
+ FlatMap& operator=(FlatMap&& src) noexcept {
rep_.MoveFrom(std::move(src.rep_));
return *this;
}
~FlatMap() {}
- void swap(FlatMap& x) { rep_.swap(x.rep_); }
+ void swap(FlatMap& x) noexcept { rep_.swap(x.rep_); }
void clear_no_resize() { rep_.clear_no_resize(); }
void clear() { rep_.clear(); }
void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h
index dfc6584..d6c77e7 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h
@@ -58,10 +58,11 @@
CopyEntries(src.array_, src.end_, CopyEntry());
}
- FlatRep(FlatRep&& src)
- // Copy rather than move src.hash_ and src.equal_. This is necessary to
- // leave src in a valid state -- otherwise e.g. if hash_ is an
- // std::function, moving it would null it out.
+ FlatRep(
+ FlatRep&& src) noexcept // Copy rather than move src.hash_ and
+ // src.equal_. This is necessary to leave src in
+ // a valid state -- otherwise e.g. if hash_ is an
+ // std::function, moving it would null it out.
: hash_(src.hash_), equal_(src.equal_) {
// TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap
// as it could be. The fundamental problem is that we need to leave src in
@@ -118,7 +119,7 @@
MaybeResize();
}
- void swap(FlatRep& x) {
+ void swap(FlatRep& x) noexcept {
using std::swap;
swap(array_, x.array_);
swap(end_, x.end_);
diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h
index ec8e9ad..b317822 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h
@@ -63,7 +63,7 @@
// Move constructor leaves src in a valid but unspecified state (same as
// std::unordered_set).
- FlatSet(FlatSet&& src) : rep_(std::move(src.rep_)) {}
+ FlatSet(FlatSet&& src) noexcept : rep_(std::move(src.rep_)) {}
template <typename InputIter>
FlatSet(InputIter first, InputIter last, size_t N = 1,
@@ -83,14 +83,14 @@
// Move-assignment operator leaves src in a valid but unspecified state (same
// as std::unordered_set).
- FlatSet& operator=(FlatSet&& src) {
+ FlatSet& operator=(FlatSet&& src) noexcept {
rep_.MoveFrom(std::move(src.rep_));
return *this;
}
~FlatSet() {}
- void swap(FlatSet& x) { rep_.swap(x.rep_); }
+ void swap(FlatSet& x) noexcept { rep_.swap(x.rep_); }
void clear_no_resize() { rep_.clear_no_resize(); }
void clear() { rep_.clear(); }
void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD
index c103dcf..055931f 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD
@@ -263,12 +263,12 @@
srcs = ["buffered_file_test.cc"],
deps = [
":buffered_file",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env",
"//tsl/platform:env_impl",
"//tsl/platform:test",
"//tsl/platform:test_benchmark",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -443,12 +443,12 @@
deps = [
":buffered_inputstream",
":random_inputstream",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env",
"//tsl/platform:env_impl",
"//tsl/platform:test",
"//tsl/platform:test_benchmark",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -471,7 +471,6 @@
srcs = ["inputbuffer_test.cc"],
deps = [
":inputbuffer",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:coding",
"//tsl/platform:env",
"//tsl/platform:env_impl",
@@ -482,6 +481,7 @@
"//tsl/platform:strcat",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -491,10 +491,10 @@
srcs = ["inputstream_interface_test.cc"],
deps = [
":inputstream_interface",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:errors",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -504,11 +504,11 @@
srcs = ["random_inputstream_test.cc"],
deps = [
":random_inputstream",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env",
"//tsl/platform:env_impl",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -519,7 +519,6 @@
deps = [
":record_reader",
":record_writer",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env",
"//tsl/platform:env_impl",
"//tsl/platform:errors",
@@ -528,6 +527,7 @@
"//tsl/platform:strcat",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
"@zlib",
],
)
@@ -539,7 +539,6 @@
deps = [
":record_reader",
":record_writer",
- "//tsl/lib/core:status_test_util",
"//tsl/lib/hash:crc32c",
"//tsl/lib/random:philox",
"//tsl/platform:coding",
@@ -549,6 +548,7 @@
"//tsl/platform:str_util",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -580,12 +580,12 @@
":zlib_compression_options",
":zlib_inputstream",
":zlib_outputbuffer",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env",
"//tsl/platform:env_impl",
"//tsl/platform:errors",
"//tsl/platform:strcat",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc
index 6fae0b6..f9fa67d 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc
@@ -18,7 +18,7 @@
#include <memory>
#include <utility>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/test.h"
#include "tsl/platform/test_benchmark.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc
index ab1f58e..83e5776 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc
@@ -15,7 +15,7 @@
#include "tsl/lib/io/buffered_inputstream.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/random_inputstream.h"
#include "tsl/platform/env.h"
#include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc
index e384604..d23f06a 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc
@@ -17,7 +17,7 @@
#include <vector>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/coding.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc
index 23d4fb0..c9c34db 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc
@@ -15,7 +15,7 @@
#include "tsl/lib/io/inputstream_interface.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc
index 0b47ef2..dfa4ec8 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc
@@ -15,7 +15,7 @@
#include "tsl/lib/io/random_inputstream.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc
index 67df7831..45934c9 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc
@@ -23,7 +23,7 @@
#include <memory>
#include <vector>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc
index 42adf76..51c2be6 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/lib/hash/crc32c.h"
#include "tsl/lib/io/record_reader.h"
#include "tsl/lib/io/record_writer.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD
index 3f42c5f..0adc5e2 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD
@@ -90,12 +90,12 @@
":snappy_inputbuffer",
":snappy_inputstream",
":snappy_outputbuffer",
- "//tsl/lib/core:status_test_util",
"//tsl/lib/io:inputbuffer",
"//tsl/lib/io:random_inputstream",
"//tsl/platform:env",
"//tsl/platform:env_impl",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc
index d04d8d1..7844b89 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc
@@ -170,7 +170,7 @@
bytes_to_read -= avail_in_;
read_location += avail_in_;
}
- StringPiece data;
+ absl::string_view data;
// Try to read enough data to fill up input_buffer_.
absl::Status s = file_->Read(file_pos_, bytes_to_read, &data, read_location);
if (data.data() != read_location) {
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc
index 6d19c60..e851f58 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc
@@ -40,7 +40,7 @@
}
}
-absl::Status SnappyOutputBuffer::Append(StringPiece data) {
+absl::Status SnappyOutputBuffer::Append(absl::string_view data) {
return Write(data);
}
@@ -58,7 +58,7 @@
return Flush();
}
-absl::Status SnappyOutputBuffer::Name(StringPiece* result) const {
+absl::Status SnappyOutputBuffer::Name(absl::string_view* result) const {
return file_->Name(result);
}
@@ -71,7 +71,7 @@
return file_->Tell(position);
}
-absl::Status SnappyOutputBuffer::Write(StringPiece data) {
+absl::Status SnappyOutputBuffer::Write(absl::string_view data) {
//
// The deflated output is accumulated in output_buffer_ and gets written to
// file as and when needed.
@@ -121,7 +121,7 @@
return input_buffer_capacity_ - avail_in_;
}
-void SnappyOutputBuffer::AddToInputBuffer(StringPiece data) {
+void SnappyOutputBuffer::AddToInputBuffer(absl::string_view data) {
size_t bytes_to_write = data.size();
DCHECK_LE(bytes_to_write, AvailableInputSpace());
@@ -182,7 +182,7 @@
absl::Status SnappyOutputBuffer::FlushOutputBufferToFile() {
size_t bytes_to_write = output_buffer_capacity_ - avail_out_;
if (bytes_to_write > 0) {
- absl::Status s = file_->Append(StringPiece(
+ absl::Status s = file_->Append(absl::string_view(
reinterpret_cast<char*>(output_buffer_.get()), bytes_to_write));
if (s.ok()) {
next_out_ = output_buffer_.get();
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h
index 4c4d664..a3bd447 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h
@@ -64,7 +64,7 @@
//
// The input data is buffered internally and will be written to disk at a
// later time. To immediately write contents to file call `Flush()`.
- absl::Status Append(StringPiece data) override;
+ absl::Status Append(absl::string_view data) override;
#if defined(TF_CORD_SUPPORT)
absl::Status Append(const absl::Cord& cord) override;
@@ -81,7 +81,7 @@
absl::Status Close() override;
// Returns the name of the underlying file.
- absl::Status Name(StringPiece* result) const override;
+ absl::Status Name(absl::string_view* result) const override;
// Deflates any cached input, writes all output to file and syncs it.
absl::Status Sync() override;
@@ -98,7 +98,7 @@
// to file when the buffer is full.
//
// To immediately write contents to file call `Flush()`.
- absl::Status Write(StringPiece data);
+ absl::Status Write(absl::string_view data);
// Compresses any cached input and writes all output to file. This must be
// called before the destructor to avoid any data loss.
@@ -107,7 +107,7 @@
private:
// Appends `data` to `input_buffer_`.
// Throws if `data.size()` > AvailableInputSpace().
- void AddToInputBuffer(StringPiece data);
+ void AddToInputBuffer(absl::string_view data);
// Appends `data` to `output_buffer_`. Flushes buffer contents to file when
// buffer gets full.
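
Aside (not part of the patch): the StringPiece -> absl::string_view rename in this header is mechanical, on the assumption that tsl::StringPiece is an alias of absl::string_view, so existing call sites keep compiling. A minimal sketch of why the signatures stay source-compatible; PayloadBytes and Example are made-up names for illustration.

  // Illustrative only: absl::string_view accepts the same arguments the old alias did.
  #include <string>
  #include "absl/strings/string_view.h"

  size_t PayloadBytes(absl::string_view data) {  // was: StringPiece data
    return data.size();
  }

  void Example() {
    std::string owned = "compressed block";
    PayloadBytes(owned);                                // implicit std::string conversion
    PayloadBytes("literal");                            // const char* conversion
    PayloadBytes(absl::string_view(owned.data(), 4));   // explicit view over a prefix
  }
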
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc
index 33f42bd..78eecf3 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc
@@ -15,7 +15,7 @@
#include <memory>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/inputbuffer.h"
#include "tsl/lib/io/random_inputstream.h"
#include "tsl/lib/io/snappy/snappy_inputbuffer.h"
@@ -77,7 +77,7 @@
compress_output_buf_size);
for (int i = 0; i < num_writes; i++) {
- TF_RETURN_IF_ERROR(out.Write(StringPiece(data)));
+ TF_RETURN_IF_ERROR(out.Write(absl::string_view(data)));
if (with_flush) {
TF_RETURN_IF_ERROR(out.Flush());
}
@@ -96,7 +96,7 @@
std::unique_ptr<RandomAccessFile> file_reader;
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader));
- StringPiece data;
+ absl::string_view data;
size_t file_pos = 0;
size_t bytes_to_read = 256;
char* scratch = new char[bytes_to_read];
@@ -106,14 +106,14 @@
while ((file_reader->Read(file_pos, bytes_to_read, &data, scratch)).ok()) {
file_pos += data.size();
TF_CHECK_OK(
- corrupt_file_writer->Append(StringPiece(buffer, buffer_size)));
+ corrupt_file_writer->Append(absl::string_view(buffer, buffer_size)));
memcpy(buffer, data.data(), data.size());
buffer_size = data.size();
}
// Drop the last byte. File is now corrupt.
- TF_CHECK_OK(
- corrupt_file_writer->Append(StringPiece(buffer, buffer_size - 1)));
+ TF_CHECK_OK(corrupt_file_writer->Append(
+ absl::string_view(buffer, buffer_size - 1)));
TF_CHECK_OK(corrupt_file_writer->Flush());
TF_CHECK_OK(corrupt_file_writer->Close());
delete[] scratch;
@@ -216,7 +216,7 @@
TF_CHECK_OK(env->NewWritableFile(fname, &file_writer));
io::SnappyOutputBuffer out(file_writer.get(), compress_input_buf_size,
compress_output_buf_size);
- TF_CHECK_OK(out.Write(StringPiece(data)));
+ TF_CHECK_OK(out.Write(absl::string_view(data)));
TF_CHECK_OK(out.Flush());
TF_CHECK_OK(file_writer->Flush());
TF_CHECK_OK(file_writer->Close());
@@ -296,7 +296,7 @@
static bool SnappyCompressionSupported() {
string out;
- StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
+ absl::string_view in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
return port::Snappy_Compress(in.data(), in.size(), &out);
}
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc
index 0aa65e8..c2ff61d 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/lib/io/random_inputstream.h"
#include "tsl/lib/io/zlib_compression_options.h"
#include "tsl/lib/io/zlib_inputstream.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD b/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD
deleted file mode 100644
index 699965e..0000000
--- a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD
+++ /dev/null
@@ -1,57 +0,0 @@
-load(
- "@local_tsl//tsl/platform:rules_cc.bzl",
- "cc_library",
-)
-load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility")
-load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup")
-
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
-
-cc_library(
- name = "proto_serialization",
- srcs = ["proto_serialization.cc"],
- hdrs = ["proto_serialization.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tsl/lib/gtl:inlined_vector",
- "//tsl/platform:hash",
- "//tsl/platform:logging",
- "//tsl/platform:macros",
- "//tsl/platform:protobuf",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- ],
-)
-
-filegroup(
- name = "mobile_srcs_only_runtime",
- srcs = [
- "proto_serialization.cc",
- "proto_serialization.h",
- ],
- visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
-
-filegroup(
- name = "legacy_lib_strings_all_headers",
- srcs = [
- "proto_serialization.h",
- ],
- visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
-
-filegroup(
- name = "legacy_lib_string_headers",
- srcs = [
- "proto_serialization.h",
- ],
- visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
-
-filegroup(
- name = "legacy_lib_internal_public_string_headers",
- srcs = [
- "proto_serialization.h",
- ],
- visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc b/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc
deleted file mode 100644
index 139849e..0000000
--- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tsl/lib/strings/proto_serialization.h"
-
-#include <cstring>
-#include <memory>
-
-#include "absl/memory/memory.h"
-#include "absl/strings/string_view.h"
-#include "tsl/lib/gtl/inlined_vector.h"
-#include "tsl/platform/hash.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/macros.h"
-
-namespace tsl {
-namespace {
-
-// Helper for deterministic serialization.
-class DeterministicSerializer {
- public:
- explicit DeterministicSerializer(const protobuf::MessageLite& msg)
- : DeterministicSerializer(msg, msg.ByteSizeLong()) {}
-
- DeterministicSerializer(const protobuf::MessageLite& msg, size_t size)
- : size_(size) {
- char* ptr = space_;
- if (size_ > sizeof(space_)) {
- ptr = new char[size_];
- alloc_.reset(ptr);
- }
- bool ok = SerializeToBufferDeterministic(msg, ptr, size_);
- DCHECK(ok);
- }
-
- size_t size() const { return size_; }
- const char* data() const { return alloc_ == nullptr ? space_ : alloc_.get(); }
-
- private:
- // Avoid InlinedVector since it causes 2x slowdown in the compilation
- // of graphs containing large tensors in debug mode.
- static constexpr int kInlinedBufferSize = 256;
- const size_t size_;
- std::unique_ptr<char[]> alloc_;
- char space_[kInlinedBufferSize];
-};
-} // namespace
-
-bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
- string* result) {
- const size_t size = msg.ByteSizeLong();
- DCHECK_LE(size, static_cast<size_t>(INT_MAX));
- *result = string(size, '\0');
- return SerializeToBufferDeterministic(msg, const_cast<char*>(result->data()),
- result->size());
-}
-
-bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
- char* buffer, size_t size) {
- DCHECK(msg.ByteSizeLong() == size && size <= static_cast<size_t>(INT_MAX));
- protobuf::io::ArrayOutputStream array_stream(buffer, size);
- protobuf::io::CodedOutputStream output_stream(&array_stream);
- output_stream.SetSerializationDeterministic(true);
- msg.SerializeWithCachedSizes(&output_stream);
- return !output_stream.HadError() &&
- size == static_cast<size_t>(output_stream.ByteCount());
-}
-
-bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
- const protobuf::MessageLite& y) {
- const size_t size = x.ByteSizeLong();
- if (size != y.ByteSizeLong()) return false;
- if (size == 0) return true;
- DeterministicSerializer x_serialized(x, size);
- DeterministicSerializer y_serialized(y, size);
- return memcmp(x_serialized.data(), y_serialized.data(), size) == 0;
-}
-
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
- uint64 seed) {
- DeterministicSerializer serialized(proto);
- return Hash64(serialized.data(), serialized.size(), seed);
-}
-
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto) {
- DeterministicSerializer serialized(proto);
- return Hash64(serialized.data(), serialized.size());
-}
-
-} // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h b/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h
deleted file mode 100644
index 96a5c55..0000000
--- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
-#define TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
-
-#include "tsl/platform/protobuf.h"
-
-namespace tsl {
-
-// Wrapper around protocol buffer serialization that requests deterministic
-// serialization, in particular for Map fields, which serialize in a random
-// order by default. Returns true on success.
-// Serialization is guaranteed to be deterministic for a given binary only.
-// See the following for more details:
-// https://github.com/google/protobuf/blob/a1bb147e96b6f74db6cdf3c3fcb00492472dbbfa/src/google/protobuf/io/coded_stream.h#L834
-bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
- string* result);
-
-// As above, but takes a pre-allocated buffer wrapped by result.
-// PRECONDITION: size == msg.ByteSizeLong() && size <= INT_MAX.
-bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
- char* buffer, size_t size);
-
-// Returns true if serializing x and y using
-// SerializeToBufferDeterministic() yields identical strings.
-bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
- const protobuf::MessageLite& y);
-
-// Computes Hash64 of the output of SerializeToBufferDeterministic().
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto);
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
- uint64 seed);
-
-} // namespace tsl
-
-#endif // TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
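
Aside (not part of the patch): the removed proto_serialization helpers were thin wrappers over protobuf's deterministic serialization mode. A hedged sketch of that underlying technique, mirroring the deleted SerializeToBufferDeterministic above; SerializeDeterministically is a made-up name, not the library API.

  // Illustrative only: deterministic serialization via protobuf's coded output stream.
  #include <string>
  #include "google/protobuf/io/coded_stream.h"
  #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
  #include "google/protobuf/message_lite.h"

  bool SerializeDeterministically(const google::protobuf::MessageLite& msg,
                                  std::string* out) {
    out->resize(msg.ByteSizeLong());  // also caches sizes for SerializeWithCachedSizes
    google::protobuf::io::ArrayOutputStream array_stream(
        out->data(), static_cast<int>(out->size()));
    google::protobuf::io::CodedOutputStream coded(&array_stream);
    coded.SetSerializationDeterministic(true);  // stable ordering, e.g. for map fields
    msg.SerializeWithCachedSizes(&coded);
    return !coded.HadError();
  }
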
diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD
index 1f4c491..0f61c77 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD
@@ -1465,7 +1465,7 @@
":subprocess",
":test",
":test_main",
- "//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -1768,7 +1768,7 @@
":str_util",
":test",
":test_main",
- "//tsl/lib/core:status_test_util",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -1784,7 +1784,7 @@
":str_util",
":test",
":test_main",
- "//tsl/lib/core:status_test_util",
"@com_google_absl//absl/time",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD
index bff2db4..e92fd04 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD
@@ -231,7 +231,6 @@
copts = tsl_copts(),
deps = [
":curl_http_request",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:errors",
"//tsl/platform:macros",
"//tsl/platform:protobuf",
@@ -240,6 +239,7 @@
"//tsl/platform:test",
"//tsl/platform:types",
"@curl",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -359,10 +359,10 @@
deps = [
":expiring_lru_cache",
":now_seconds_env",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env_impl",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -373,13 +373,13 @@
deps = [
":now_seconds_env",
":ram_file_block_cache",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:blocking_counter",
"//tsl/platform:env",
"//tsl/platform:env_impl",
"//tsl/platform:notification",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -390,7 +390,6 @@
deps = [
":gcs_file_system",
":http_request_fake",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env_impl",
"//tsl/platform:errors",
"//tsl/platform:str_util",
@@ -399,6 +398,7 @@
"//tsl/platform:test_main",
"//tsl/profiler/backends/cpu:traceme_recorder_impl",
"//tsl/profiler/utils:time_utils_impl",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -423,11 +423,11 @@
linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]),
deps = [
":gcs_throttle",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env_impl",
"//tsl/platform:str_util",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -437,13 +437,13 @@
srcs = ["curl_http_request_test.cc"],
deps = [
":curl_http_request",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env_impl",
"//tsl/platform:path",
"//tsl/platform:platform_port",
"//tsl/platform:test",
"//tsl/platform:test_main",
"@com_google_absl//absl/status",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -459,7 +459,6 @@
deps = [
":http_request_fake",
":oauth_client",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:base64",
"//tsl/platform:env",
"//tsl/platform:env_impl",
@@ -468,6 +467,7 @@
"//tsl/platform:test",
"//tsl/platform:test_main",
"@boringssl//:crypto",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -484,11 +484,11 @@
":google_auth_provider",
":http_request_fake",
":oauth_client",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:env_impl",
"//tsl/platform:path",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
@@ -525,8 +525,8 @@
srcs = ["time_util_test.cc"],
deps = [
":time_util",
- "//tsl/lib/core:status_test_util",
"//tsl/platform:test",
"//tsl/platform:test_main",
+ "@local_xla//xla/tsl/lib/core:status_test_util",
],
)
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h
index 969b8bc..4b1b292 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h
@@ -31,9 +31,9 @@
/// \brief Returns the short-term authentication bearer token.
///
/// Safe for concurrent use by multiple threads.
- virtual Status GetToken(string* t) = 0;
+ virtual absl::Status GetToken(string* t) = 0;
- static Status GetToken(AuthProvider* provider, string* token) {
+ static absl::Status GetToken(AuthProvider* provider, string* token) {
if (!provider) {
return errors::Internal("Auth provider is required.");
}
@@ -44,9 +44,9 @@
/// No-op auth provider, which will only work for public objects.
class EmptyAuthProvider : public AuthProvider {
public:
- Status GetToken(string* token) override {
+ absl::Status GetToken(string* token) override {
*token = "";
- return OkStatus();
+ return absl::OkStatus();
}
};
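
Aside (not part of the patch): this file spells the status types out as absl::Status and absl::OkStatus(), on the assumption that tsl::Status already aliases absl::Status, so signatures can be migrated file by file. A minimal sketch mirroring the EmptyAuthProvider above; FakeAuthProvider is a made-up class for illustration, not part of the cloud API.

  // Illustrative only: the fully qualified status spellings after the migration.
  #include <string>
  #include "absl/status/status.h"

  class FakeAuthProvider {
   public:
    // was: Status GetToken(string* token)
    absl::Status GetToken(std::string* token) {
      if (token == nullptr) {
        return absl::InvalidArgumentError("token must be non-null");
      }
      *token = "fake-bearer";
      return absl::OkStatus();  // was: OkStatus()
    }
  };
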
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc
index 7be3af7..7a41c8f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc
@@ -40,7 +40,7 @@
: http_request_factory_(std::move(http_request_factory)),
retry_config_(config) {}
-Status ComputeEngineMetadataClient::GetMetadata(
+absl::Status ComputeEngineMetadataClient::GetMetadata(
const string& path, std::vector<char>* response_buffer) {
const auto get_metadata_from_gce = [path, response_buffer, this]() {
string metadata_url;
@@ -56,7 +56,7 @@
request->AddHeader("Metadata-Flavor", "Google");
request->SetResultBuffer(response_buffer);
TF_RETURN_IF_ERROR(request->Send());
- return OkStatus();
+ return absl::OkStatus();
};
return RetryingUtils::CallWithRetries(get_metadata_from_gce, retry_config_);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h
index fac94cd..1337d33 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h
@@ -51,8 +51,8 @@
/// To get the zone of an instance:
/// compute_engine_metadata_client.GetMetadata(
/// "instance/zone", response_buffer);
- virtual Status GetMetadata(const string& path,
- std::vector<char>* response_buffer);
+ virtual absl::Status GetMetadata(const string& path,
+ std::vector<char>* response_buffer);
private:
std::shared_ptr<HttpRequest::Factory> http_request_factory_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc
index 7720784..19f2755 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc
@@ -28,15 +28,15 @@
std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client)
: google_metadata_client_(std::move(google_metadata_client)) {}
-Status ComputeEngineZoneProvider::GetZone(string* zone) {
+absl::Status ComputeEngineZoneProvider::GetZone(string* zone) {
if (!cached_zone.empty()) {
*zone = cached_zone;
- return OkStatus();
+ return absl::OkStatus();
}
std::vector<char> response_buffer;
TF_RETURN_IF_ERROR(google_metadata_client_->GetMetadata(kGceMetadataZonePath,
&response_buffer));
- StringPiece location(&response_buffer[0], response_buffer.size());
+ absl::string_view location(&response_buffer[0], response_buffer.size());
std::vector<string> elems = str_util::Split(location, "/");
if (elems.size() == 4) {
@@ -47,7 +47,7 @@
<< string(location);
}
- return OkStatus();
+ return absl::OkStatus();
}
ComputeEngineZoneProvider::~ComputeEngineZoneProvider() {}
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h
index a37b43c..99ed41f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h
@@ -27,7 +27,7 @@
std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client);
virtual ~ComputeEngineZoneProvider();
- Status GetZone(string* zone) override;
+ absl::Status GetZone(string* zone) override;
private:
std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc
index c41f967..44eeab7 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc
@@ -230,8 +230,8 @@
libcurl_->curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, "DELETE"));
}
-Status CurlHttpRequest::SetPutFromFile(const string& body_filepath,
- size_t offset) {
+absl::Status CurlHttpRequest::SetPutFromFile(const string& body_filepath,
+ size_t offset) {
CheckNotSent();
CheckMethodNotSet();
is_method_set_ = true;
@@ -257,7 +257,7 @@
reinterpret_cast<void*>(put_body_)));
// Using the default CURLOPT_READFUNCTION, which is doing an fread() on the
// FILE * userdata set with CURLOPT_READDATA.
- return OkStatus();
+ return absl::OkStatus();
}
void CurlHttpRequest::SetPutEmptyBody() {
@@ -286,7 +286,7 @@
reinterpret_cast<void*>(this)));
CHECK_CURL_OK(libcurl_->curl_easy_setopt(curl_, CURLOPT_READFUNCTION,
&CurlHttpRequest::ReadCallback));
- post_body_buffer_ = StringPiece(buffer, size);
+ post_body_buffer_ = absl::string_view(buffer, size);
}
void CurlHttpRequest::SetPostEmptyBody() {
@@ -397,8 +397,8 @@
size_t nmemb, void* this_object) {
CHECK(ptr);
auto that = reinterpret_cast<CurlHttpRequest*>(this_object);
- StringPiece header(reinterpret_cast<const char*>(ptr), size * nmemb);
- StringPiece name, value;
+ absl::string_view header(reinterpret_cast<const char*>(ptr), size * nmemb);
+ absl::string_view name, value;
// The supplied header has the form "<name>: <value>", parse it.
if (strings::Scanner(header)
.ScanEscapedUntil(':')
@@ -412,7 +412,7 @@
return size * nmemb;
}
-Status CurlHttpRequest::Send() {
+absl::Status CurlHttpRequest::Send() {
CheckNotSent();
CHECK(is_uri_set_) << "URI has not been set.";
@@ -457,7 +457,7 @@
auto get_error_message = [this]() -> string {
string error_message = strings::StrCat(
"Error executing an HTTP request: HTTP response code ", response_code_);
- StringPiece body = GetResponse();
+ absl::string_view body = GetResponse();
if (!body.empty()) {
return strings::StrCat(
error_message, " with body '",
@@ -466,7 +466,7 @@
return error_message;
};
- Status result;
+ absl::Status result;
switch (response_code_) {
// The group of response codes indicating that the request achieved
// the expected goal.
@@ -474,7 +474,7 @@
case 201: // Created
case 204: // No Content
case 206: // Partial Content
- result = OkStatus();
+ result = absl::OkStatus();
break;
case 416: // Requested Range Not Satisfiable
@@ -485,7 +485,7 @@
if (IsDirectResponse()) {
direct_response_.bytes_transferred_ = 0;
}
- result = OkStatus();
+ result = absl::OkStatus();
break;
// INVALID_ARGUMENT indicates a problem with how the request is constructed.
@@ -556,13 +556,14 @@
CHECK(!is_sent_) << "The request has already been sent.";
}
-StringPiece CurlHttpRequest::GetResponse() const {
- StringPiece response;
+absl::string_view CurlHttpRequest::GetResponse() const {
+ absl::string_view response;
if (IsDirectResponse()) {
- response = StringPiece(direct_response_.buffer_,
- direct_response_.bytes_transferred_);
+ response = absl::string_view(direct_response_.buffer_,
+ direct_response_.bytes_transferred_);
} else {
- response = StringPiece(response_buffer_->data(), response_buffer_->size());
+ response =
+ absl::string_view(response_buffer_->data(), response_buffer_->size());
}
return response;
}
@@ -627,10 +628,10 @@
return 0;
}
-Status CurlHttpRequest::CURLcodeToStatus(CURLcode code,
- const char* error_buffer) {
+absl::Status CurlHttpRequest::CURLcodeToStatus(CURLcode code,
+ const char* error_buffer) {
if (code == CURLE_OK) {
- return OkStatus();
+ return absl::OkStatus();
}
string error_message = strings::StrCat(
"Error executing an HTTP request: libcurl code ", code, " meaning '",
@@ -648,7 +649,7 @@
// a response body (e.g. GCS sends one with an error message) but we
// pretend as though they don't, so actually ignore this error.
if (get_response_result == CURLE_OK && response_code == 416) {
- return OkStatus();
+ return absl::OkStatus();
}
return errors::FailedPrecondition(
strings::StrCat(error_message, overflow_message));
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h
index b5c7285..4c64758 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h
@@ -86,7 +86,8 @@
///
/// The request body will be taken from the specified file starting from
/// the given offset.
- Status SetPutFromFile(const string& body_filepath, size_t offset) override;
+ absl::Status SetPutFromFile(const string& body_filepath,
+ size_t offset) override;
/// Makes the request a PUT request with an empty body.
void SetPutEmptyBody() override;
@@ -140,7 +141,7 @@
///
/// If the result buffer was defined, the response will be written there.
/// The object is not designed to be re-used after Send() is executed.
- Status Send() override;
+ absl::Status Send() override;
// Url encodes str and returns a new string.
string EscapeString(const string& str) override;
@@ -167,18 +168,18 @@
curl_off_t ulnow);
void CheckMethodNotSet() const;
void CheckNotSent() const;
- StringPiece GetResponse() const;
+ absl::string_view GetResponse() const;
/// Helper to convert the given CURLcode and error buffer, representing the
/// result of performing a transfer, into a Status with an error message.
- Status CURLcodeToStatus(CURLcode code, const char* error_buffer);
+ absl::Status CURLcodeToStatus(CURLcode code, const char* error_buffer);
LibCurl* libcurl_;
Env* env_;
FILE* put_body_ = nullptr;
- StringPiece post_body_buffer_;
+ absl::string_view post_body_buffer_;
size_t post_body_read_ = 0;
std::vector<char>* response_buffer_ = nullptr;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc
index 36d7108..31cde67 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc
@@ -19,7 +19,7 @@
#include <string>
#include "absl/status/status.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/mem.h"
#include "tsl/platform/path.h"
#include "tsl/platform/test.h"
@@ -151,8 +151,8 @@
posted_content_ = "";
do {
bytes_read = read_callback_(buffer, 1, sizeof(buffer), read_data_);
- posted_content_ =
- strings::StrCat(posted_content_, StringPiece(buffer, bytes_read));
+ posted_content_ = strings::StrCat(
+ posted_content_, absl::string_view(buffer, bytes_read));
} while (bytes_read > 0);
}
if (write_data_ || write_callback_) {
@@ -366,7 +366,7 @@
http_request.SetUri("http://www.testuri.com");
http_request.SetResultBufferDirect(scratch.data(), scratch.size());
- const Status& status = http_request.Send();
+ const absl::Status& status = http_request.Send();
EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
EXPECT_EQ(
"Error executing an HTTP request: libcurl code 23 meaning "
@@ -770,7 +770,7 @@
void RecordResponse(const HttpRequest* request, const string& uri,
HttpRequest::RequestMethod method,
- const Status& result) override {
+ const absl::Status& result) override {
has_recorded_response_ = true;
record_response_request_ = request;
record_response_uri_ = uri;
@@ -787,7 +787,7 @@
string record_response_uri_ = "http://www.testuri.com";
HttpRequest::RequestMethod record_response_method_ =
HttpRequest::RequestMethod::kGet;
- Status record_response_result_;
+ absl::Status record_response_result_;
bool has_recorded_request_ = false;
bool has_recorded_response_ = false;
@@ -864,7 +864,7 @@
http_request.AddAuthBearerHeader("fake-bearer");
http_request.SetRange(100, 199);
http_request.SetResultBuffer(&scratch);
- Status s = http_request.Send();
+ absl::Status s = http_request.Send();
// Check interaction with stats.
ASSERT_TRUE(stats.has_recorded_request_);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h
index 1def81b..d3a15bc 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h
@@ -71,13 +71,13 @@
return LookupLocked(key, value);
}
- typedef std::function<Status(const string&, T*)> ComputeFunc;
+ typedef std::function<absl::Status(const string&, T*)> ComputeFunc;
/// Look up the entry with key `key` and copy it to `value` if found. If not
/// found, call `compute_func`. If `compute_func` returns successfully, store
/// a copy of the output parameter in the cache, and another copy in `value`.
- Status LookupOrCompute(const string& key, T* value,
- const ComputeFunc& compute_func) {
+ absl::Status LookupOrCompute(const string& key, T* value,
+ const ComputeFunc& compute_func) {
if (max_age_ == 0) {
return compute_func(key, value);
}
@@ -88,9 +88,9 @@
// key if this proves to be a significant performance bottleneck.
mutex_lock lock(mu_);
if (LookupLocked(key, value)) {
- return OkStatus();
+ return absl::OkStatus();
}
- Status s = compute_func(key, value);
+ absl::Status s = compute_func(key, value);
if (s.ok()) {
InsertLocked(key, *value);
}
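
Aside (not part of the patch): after this change the cache's ComputeFunc is an absl::Status-returning callable that is only invoked on a miss, as the test further below shows. A hedged usage sketch; the key, constructor arguments, and LookupExample name are illustrative.

  // Illustrative only: LookupOrCompute with an absl::Status-returning lambda.
  #include <string>
  #include "absl/status/status.h"
  #include "tsl/platform/cloud/expiring_lru_cache.h"

  absl::Status LookupExample() {
    tsl::ExpiringLRUCache<int> cache(/*max_age=*/10, /*max_entries=*/4);
    int value = 0;
    return cache.LookupOrCompute(
        "some-key", &value, [](const std::string& key, int* out) {
          *out = static_cast<int>(key.size());  // stand-in for a remote fetch
          return absl::OkStatus();
        });
  }
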
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc
index ce3e0fc..7225dca 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc
@@ -17,7 +17,7 @@
#include <memory>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/cloud/now_seconds_env.h"
#include "tsl/platform/test.h"
@@ -97,7 +97,7 @@
[&num_compute_calls](const string& key, int* value) {
*value = num_compute_calls;
num_compute_calls++;
- return OkStatus();
+ return absl::OkStatus();
};
ExpiringLRUCache<int> cache1(0, 4);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h
index e336a42..5992754 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h
@@ -70,9 +70,9 @@
/// cache is constructed. The returned Status should be OK as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
- typedef std::function<Status(const string& filename, size_t offset,
- size_t buffer_size, char* buffer,
- size_t* bytes_transferred)>
+ typedef std::function<absl::Status(const string& filename, size_t offset,
+ size_t buffer_size, char* buffer,
+ size_t* bytes_transferred)>
BlockFetcher;
virtual ~FileBlockCache() {}
@@ -91,8 +91,8 @@
/// placed in `out`.
/// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
/// in `out`).
- virtual Status Read(const string& filename, size_t offset, size_t n,
- char* buffer, size_t* bytes_transferred) = 0;
+ virtual absl::Status Read(const string& filename, size_t offset, size_t n,
+ char* buffer, size_t* bytes_transferred) = 0;
// Validate the given file signature with the existing file signature in the
// cache. Returns true if the signature doesn't change or the file did not
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc
index 4819b49..594703f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc
@@ -41,7 +41,8 @@
const std::vector<string>& kCachedDomainNames =
*new std::vector<string>{"www.googleapis.com", "storage.googleapis.com"};
-inline void print_getaddrinfo_error(const string& name, Status return_status) {
+inline void print_getaddrinfo_error(const string& name,
+ absl::Status return_status) {
// Status doesn't map well to EAI type errors.
LOG(ERROR) << "Error resolving " << name << ": " << return_status;
}
@@ -104,13 +105,13 @@
/* max_delay_time_us = */ 50 * 1000 * 5000,
/* max_retries = */ 5);
- const Status getaddrinfo_status = RetryingUtils::CallWithRetries(
+ const absl::Status getaddrinfo_status = RetryingUtils::CallWithRetries(
[&name, &hints, &result]() {
int return_code = getaddrinfo(name.c_str(), nullptr, &hints, &result);
absl::Status return_status;
switch (return_code) {
case 0:
- return_status = OkStatus();
+ return_status = absl::OkStatus();
break;
#ifndef _WIN32
case EAI_ADDRFAMILY:
@@ -175,7 +176,7 @@
#endif
}
- return Status(return_status);
+ return absl::Status(return_status);
},
retryConfig);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc
index 069dcb5..a5ce088 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc
@@ -40,8 +40,9 @@
void SetRequestStats(HttpRequest::RequestStats* stats) override {}
void SetDeleteRequest() override {}
- Status SetPutFromFile(const string& body_filepath, size_t offset) override {
- return OkStatus();
+ absl::Status SetPutFromFile(const string& body_filepath,
+ size_t offset) override {
+ return absl::OkStatus();
}
void SetPutEmptyBody() override {}
void SetPostFromBuffer(const char* buffer, size_t size) override {}
@@ -52,7 +53,7 @@
string GetResponseHeader(const string& name) const override { return ""; }
uint64 GetResponseCode() const override { return 0; }
- Status Send() override { return OkStatus(); }
+ absl::Status Send() override { return absl::OkStatus(); }
string EscapeString(const string& str) override { return ""; }
void SetTimeouts(uint32 connection, uint32 inactivity,
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc
index ea65028..c1cc244 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc
@@ -163,9 +163,9 @@
// objects.
constexpr char kComposeAppend[] = "compose";
-Status GetTmpFilename(string* filename) {
+absl::Status GetTmpFilename(string* filename) {
*filename = io::GetTempFilename("");
- return OkStatus();
+ return absl::OkStatus();
}
/// Appends a trailing slash if the name doesn't already have one.
@@ -199,7 +199,7 @@
std::set<string> result;
result.insert(paths.begin(), paths.end());
for (const string& path : paths) {
- StringPiece subpath = io::Dirname(path);
+ absl::string_view subpath = io::Dirname(path);
// If `path` starts with `/`, `subpath` will be `/` and then we get into an
// infinite loop. Same behavior happens if there is a `//` pattern in
// `path`, so we check for that and leave the loop quicker.
@@ -211,32 +211,32 @@
return result;
}
-Status ParseJson(StringPiece json, Json::Value* result) {
+absl::Status ParseJson(absl::string_view json, Json::Value* result) {
Json::Reader reader;
if (!reader.parse(json.data(), json.data() + json.size(), *result)) {
return errors::Internal("Couldn't parse JSON response from GCS.");
}
- return OkStatus();
+ return absl::OkStatus();
}
-Status ParseJson(const std::vector<char>& json, Json::Value* result) {
- return ParseJson(StringPiece{json.data(), json.size()}, result);
+absl::Status ParseJson(const std::vector<char>& json, Json::Value* result) {
+ return ParseJson(absl::string_view{json.data(), json.size()}, result);
}
/// Reads a JSON value with the given name from a parent JSON value.
-Status GetValue(const Json::Value& parent, const char* name,
- Json::Value* result) {
+absl::Status GetValue(const Json::Value& parent, const char* name,
+ Json::Value* result) {
*result = parent.get(name, Json::Value::null);
if (result->isNull()) {
return errors::Internal("The field '", name,
"' was expected in the JSON response.");
}
- return OkStatus();
+ return absl::OkStatus();
}
/// Reads a string JSON value with the given name from a parent JSON value.
-Status GetStringValue(const Json::Value& parent, const char* name,
- string* result) {
+absl::Status GetStringValue(const Json::Value& parent, const char* name,
+ string* result) {
Json::Value result_value;
TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
if (!result_value.isString()) {
@@ -245,21 +245,21 @@
"' in the JSON response was expected to be a string.");
}
*result = result_value.asString();
- return OkStatus();
+ return absl::OkStatus();
}
/// Reads a long JSON value with the given name from a parent JSON value.
-Status GetInt64Value(const Json::Value& parent, const char* name,
- int64_t* result) {
+absl::Status GetInt64Value(const Json::Value& parent, const char* name,
+ int64_t* result) {
Json::Value result_value;
TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
if (result_value.isNumeric()) {
*result = result_value.asInt64();
- return OkStatus();
+ return absl::OkStatus();
}
if (result_value.isString() &&
strings::safe_strto64(result_value.asCString(), result)) {
- return OkStatus();
+ return absl::OkStatus();
}
return errors::Internal(
"The field '", name,
@@ -267,7 +267,8 @@
}
/// Reads a boolean JSON value with the given name from a parent JSON value.
-Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) {
+absl::Status GetBoolValue(const Json::Value& parent, const char* name,
+ bool* result) {
Json::Value result_value;
TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
if (!result_value.isBool()) {
@@ -276,7 +277,7 @@
"' in the JSON response was expected to be a boolean.");
}
*result = result_value.asBool();
- return OkStatus();
+ return absl::OkStatus();
}
/// Get GCS Retry Config by applying user overrides through env if any.
@@ -314,21 +315,21 @@
/// A GCS-based implementation of a random access file with an LRU block cache.
class GcsRandomAccessFile : public RandomAccessFile {
public:
- using ReadFn =
- std::function<Status(const string& filename, uint64 offset, size_t n,
- StringPiece* result, char* scratch)>;
+ using ReadFn = std::function<absl::Status(
+ const string& filename, uint64 offset, size_t n,
+ absl::string_view* result, char* scratch)>;
GcsRandomAccessFile(const string& filename, ReadFn read_fn)
: filename_(filename), read_fn_(std::move(read_fn)) {}
- Status Name(StringPiece* result) const override {
+ absl::Status Name(absl::string_view* result) const override {
*result = filename_;
- return OkStatus();
+ return absl::OkStatus();
}
/// The implementation of reads with an LRU block cache. Thread safe.
- Status Read(uint64 offset, size_t n, StringPiece* result,
- char* scratch) const override {
+ absl::Status Read(uint64 offset, size_t n, absl::string_view* result,
+ char* scratch) const override {
return read_fn_(filename_, offset, n, result, scratch);
}
@@ -342,9 +343,9 @@
/// A GCS-based implementation of a random access file with a read buffer.
class BufferedGcsRandomAccessFile : public RandomAccessFile {
public:
- using ReadFn =
- std::function<Status(const string& filename, uint64 offset, size_t n,
- StringPiece* result, char* scratch)>;
+ using ReadFn = std::function<absl::Status(
+ const string& filename, uint64 offset, size_t n,
+ absl::string_view* result, char* scratch)>;
// Initialize the reader. Provided read_fn should be thread safe.
BufferedGcsRandomAccessFile(const string& filename, uint64 buffer_size,
@@ -355,16 +356,16 @@
buffer_start_(0),
buffer_end_is_past_eof_(false) {}
- Status Name(StringPiece* result) const override {
+ absl::Status Name(absl::string_view* result) const override {
*result = filename_;
- return OkStatus();
+ return absl::OkStatus();
}
/// The implementation of reads with an read buffer. Thread safe.
/// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result`
/// because of EOF.
- Status Read(uint64 offset, size_t n, StringPiece* result,
- char* scratch) const override {
+ absl::Status Read(uint64 offset, size_t n, absl::string_view* result,
+ char* scratch) const override {
if (n > buffer_size_) {
return read_fn_(filename_, offset, n, result, scratch);
}
@@ -375,12 +376,12 @@
if (offset < buffer_end && offset >= buffer_start_) {
copy_size = std::min(n, static_cast<size_t>(buffer_end - offset));
memcpy(scratch, buffer_.data() + (offset - buffer_start_), copy_size);
- *result = StringPiece(scratch, copy_size);
+ *result = absl::string_view(scratch, copy_size);
}
bool consumed_buffer_to_eof =
offset + copy_size >= buffer_end && buffer_end_is_past_eof_;
if (copy_size < n && !consumed_buffer_to_eof) {
- Status status = FillBuffer(offset + copy_size);
+ absl::Status status = FillBuffer(offset + copy_size);
if (!status.ok() && !absl::IsOutOfRange(status)) {
// Empty the buffer to avoid caching bad reads.
buffer_.resize(0);
@@ -389,7 +390,7 @@
size_t remaining_copy = std::min(n - copy_size, buffer_.size());
memcpy(scratch + copy_size, buffer_.data(), remaining_copy);
copy_size += remaining_copy;
- *result = StringPiece(scratch, copy_size);
+ *result = absl::string_view(scratch, copy_size);
}
if (copy_size < n) {
// Forget the end-of-file flag to allow for clients that poll on the
@@ -399,17 +400,17 @@
" bytes from ", offset, ".");
}
}
- return OkStatus();
+ return absl::OkStatus();
}
private:
- Status FillBuffer(uint64 start) const
+ absl::Status FillBuffer(uint64 start) const
TF_EXCLUSIVE_LOCKS_REQUIRED(buffer_mutex_) {
buffer_start_ = start;
buffer_.resize(buffer_size_);
- StringPiece str_piece;
- Status status = read_fn_(filename_, buffer_start_, buffer_size_, &str_piece,
- &(buffer_[0]));
+ absl::string_view str_piece;
+ absl::Status status = read_fn_(filename_, buffer_start_, buffer_size_,
+ &str_piece, &(buffer_[0]));
buffer_end_is_past_eof_ = absl::IsOutOfRange(status);
buffer_.resize(str_piece.size());
return status;
@@ -437,28 +438,28 @@
};
// Function object declaration with params needed to create upload sessions.
-typedef std::function<Status(
+typedef std::function<absl::Status(
uint64 start_offset, const std::string& object_to_upload,
const std::string& bucket, uint64 file_size, const std::string& gcs_path,
UploadSessionHandle* session_handle)>
SessionCreator;
// Function object declaration with params needed to upload objects.
-typedef std::function<Status(const std::string& session_uri,
- uint64 start_offset, uint64 already_uploaded,
- const std::string& tmp_content_filename,
- uint64 file_size, const std::string& file_path)>
+typedef std::function<absl::Status(
+ const std::string& session_uri, uint64 start_offset,
+ uint64 already_uploaded, const std::string& tmp_content_filename,
+ uint64 file_size, const std::string& file_path)>
ObjectUploader;
// Function object declaration with params needed to poll upload status.
-typedef std::function<Status(const string& session_uri, uint64 file_size,
- const std::string& gcs_path, bool* completed,
- uint64* uploaded)>
+typedef std::function<absl::Status(const string& session_uri, uint64 file_size,
+ const std::string& gcs_path, bool* completed,
+ uint64* uploaded)>
StatusPoller;
// Function object declaration with params needed to poll upload status.
-typedef std::function<Status(const string& fname, const string& bucket,
- const string& object, int64_t* generation)>
+typedef std::function<absl::Status(const string& fname, const string& bucket,
+ const string& object, int64_t* generation)>
GenerationGetter;
/// \brief GCS-based implementation of a writeable file.
@@ -534,7 +535,7 @@
std::remove(tmp_content_filename_.c_str());
}
- Status Append(StringPiece data) override {
+ absl::Status Append(absl::string_view data) override {
TF_RETURN_IF_ERROR(CheckWritable());
VLOG(3) << "Append: " << GetGcsPath() << " size " << data.length();
sync_needed_ = true;
@@ -543,37 +544,38 @@
return errors::Internal(
"Could not append to the internal temporary file.");
}
- return OkStatus();
+ return absl::OkStatus();
}
- Status Close() override {
+ absl::Status Close() override {
VLOG(3) << "Close:" << GetGcsPath();
if (outfile_.is_open()) {
- Status sync_status = Sync();
+ absl::Status sync_status = Sync();
if (sync_status.ok()) {
outfile_.close();
}
return sync_status;
}
- return OkStatus();
+ return absl::OkStatus();
}
- Status Flush() override {
+ absl::Status Flush() override {
VLOG(3) << "Flush:" << GetGcsPath();
return Sync();
}
- Status Name(StringPiece* result) const override {
- return errors::Unimplemented("GCSWritableFile does not support Name()");
+ absl::Status Name(absl::string_view* result) const override {
+ *result = object_;
+ return absl::OkStatus();
}
- Status Sync() override {
+ absl::Status Sync() override {
VLOG(3) << "Sync started:" << GetGcsPath();
TF_RETURN_IF_ERROR(CheckWritable());
if (!sync_needed_) {
- return OkStatus();
+ return absl::OkStatus();
}
- Status status = SyncImpl();
+ absl::Status status = SyncImpl();
VLOG(3) << "Sync finished " << GetGcsPath();
if (status.ok()) {
sync_needed_ = false;
@@ -581,12 +583,12 @@
return status;
}
- Status Tell(int64_t* position) override {
+ absl::Status Tell(int64_t* position) override {
*position = outfile_.tellp();
if (*position == -1) {
return errors::Internal("tellp on the internal temporary file failed");
}
- return OkStatus();
+ return absl::OkStatus();
}
private:
@@ -596,7 +598,7 @@
/// In case of a failure, it resumes failed uploads as recommended by the GCS
/// resumable API documentation. When the whole upload needs to be
/// restarted, Sync() returns UNAVAILABLE and relies on RetryingFileSystem.
- Status SyncImpl() {
+ absl::Status SyncImpl() {
outfile_.flush();
if (!outfile_.good()) {
return errors::Internal(
@@ -620,7 +622,7 @@
&session_handle));
uint64 already_uploaded = 0;
bool first_attempt = true;
- const Status upload_status = RetryingUtils::CallWithRetries(
+ const absl::Status upload_status = RetryingUtils::CallWithRetries(
[&first_attempt, &already_uploaded, &session_handle, &start_offset,
this]() {
if (session_handle.resumable && !first_attempt) {
@@ -637,7 +639,7 @@
// It's unclear why UploadToSession didn't return OK in the
// previous attempt, but GCS reports that the file is fully
// uploaded, so succeed.
- return OkStatus();
+ return absl::OkStatus();
}
}
first_attempt = false;
@@ -661,28 +663,28 @@
return upload_status;
}
- Status CheckWritable() const {
+ absl::Status CheckWritable() const {
if (!outfile_.is_open()) {
return errors::FailedPrecondition(
"The internal temporary file is not writable.");
}
- return OkStatus();
+ return absl::OkStatus();
}
- Status GetCurrentFileSize(uint64* size) {
+ absl::Status GetCurrentFileSize(uint64* size) {
const auto tellp = outfile_.tellp();
if (tellp == static_cast<std::streampos>(-1)) {
return errors::Internal(
"Could not get the size of the internal temporary file.");
}
*size = tellp;
- return OkStatus();
+ return absl::OkStatus();
}
/// Initiates a new resumable upload session.
- Status CreateNewUploadSession(uint64 start_offset,
- std::string object_to_upload,
- UploadSessionHandle* session_handle) {
+ absl::Status CreateNewUploadSession(uint64 start_offset,
+ std::string object_to_upload,
+ UploadSessionHandle* session_handle) {
uint64 file_size;
TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
return session_creator_(start_offset, object_to_upload, bucket_, file_size,
@@ -691,7 +693,7 @@
/// Appends the data of append_object to the original object and deletes
/// append_object.
- Status AppendObject(string append_object) {
+ absl::Status AppendObject(string append_object) {
const string append_object_path = GetGcsPathWithObject(append_object);
VLOG(3) << "AppendObject: " << append_object_path << " to " << GetGcsPath();
@@ -718,7 +720,7 @@
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(),
" when composing to ", GetGcsPath());
- return OkStatus();
+ return absl::OkStatus();
},
retry_config_));
@@ -734,8 +736,8 @@
/// If the upload has already succeeded, sets 'completed' to true.
/// Otherwise sets 'completed' to false and 'uploaded' to the currently
/// uploaded size in bytes.
- Status RequestUploadSessionStatus(const string& session_uri, bool* completed,
- uint64* uploaded) {
+ absl::Status RequestUploadSessionStatus(const string& session_uri,
+ bool* completed, uint64* uploaded) {
uint64 file_size;
TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
return status_poller_(session_uri, file_size, GetGcsPath(), completed,
@@ -743,11 +745,11 @@
}
/// Uploads data to object.
- Status UploadToSession(const string& session_uri, uint64 start_offset,
- uint64 already_uploaded) {
+ absl::Status UploadToSession(const string& session_uri, uint64 start_offset,
+ uint64 already_uploaded) {
uint64 file_size;
TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
- Status status =
+ absl::Status status =
object_uploader_(session_uri, start_offset, already_uploaded,
tmp_content_filename_, file_size, GetGcsPath());
if (status.ok()) {
@@ -795,14 +797,14 @@
uint64 length_;
};
-bool StringPieceIdentity(StringPiece str, StringPiece* value) {
+bool StringPieceIdentity(absl::string_view str, absl::string_view* value) {
*value = str;
return true;
}
/// \brief Utility function to split a comma delimited list of strings to an
/// unordered set, lowercasing all values.
-bool SplitByCommaToLowercaseSet(StringPiece list,
+bool SplitByCommaToLowercaseSet(absl::string_view list,
std::unordered_set<string>* set) {
std::vector<string> vector = absl::StrSplit(absl::AsciiStrToLower(list), ',');
*set = std::unordered_set<string>(vector.begin(), vector.end());
@@ -897,14 +899,14 @@
}
// Get the additional header
- StringPiece add_header_contents;
+ absl::string_view add_header_contents;
if (GetEnvVar(kAdditionalRequestHeader, StringPieceIdentity,
&add_header_contents)) {
size_t split = add_header_contents.find(':', 0);
- if (split != StringPiece::npos) {
- StringPiece header_name = add_header_contents.substr(0, split);
- StringPiece header_value = add_header_contents.substr(split + 1);
+ if (split != absl::string_view::npos) {
+ absl::string_view header_name = add_header_contents.substr(0, split);
+ absl::string_view header_value = add_header_contents.substr(split + 1);
if (!header_name.empty() && !header_value.empty()) {
additional_header_.reset(new std::pair<const string, const string>(
@@ -968,7 +970,7 @@
GetEnvVar(kAllowedBucketLocations, SplitByCommaToLowercaseSet,
&allowed_locations_);
- StringPiece append_mode;
+ absl::string_view append_mode;
GetEnvVar(kAppendMode, StringPieceIdentity, &append_mode);
if (append_mode == kComposeAppend) {
compose_append_ = true;
@@ -1006,7 +1008,7 @@
compose_append_(compose_append),
additional_header_(additional_header) {}
-Status GcsFileSystem::NewRandomAccessFile(
+absl::Status GcsFileSystem::NewRandomAccessFile(
const string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) {
string bucket, object;
@@ -1016,7 +1018,7 @@
result->reset(new GcsRandomAccessFile(fname, [this, bucket, object](
const string& fname,
uint64 offset, size_t n,
- StringPiece* result,
+ absl::string_view* result,
char* scratch) {
tf_shared_lock l(block_cache_lock_);
GcsFileStat stat;
@@ -1031,37 +1033,37 @@
<< "File signature has been changed. Refreshing the cache. Path: "
<< fname;
}
- *result = StringPiece();
+ *result = absl::string_view();
size_t bytes_transferred;
TF_RETURN_IF_ERROR(file_block_cache_->Read(fname, offset, n, scratch,
&bytes_transferred));
- *result = StringPiece(scratch, bytes_transferred);
+ *result = absl::string_view(scratch, bytes_transferred);
if (bytes_transferred < n) {
return errors::OutOfRange("EOF reached, ", result->size(),
" bytes were read out of ", n,
" bytes requested.");
}
- return OkStatus();
+ return absl::OkStatus();
}));
} else {
result->reset(new BufferedGcsRandomAccessFile(
fname, block_size_,
[this, bucket, object](const string& fname, uint64 offset, size_t n,
- StringPiece* result, char* scratch) {
- *result = StringPiece();
+ absl::string_view* result, char* scratch) {
+ *result = absl::string_view();
size_t bytes_transferred;
TF_RETURN_IF_ERROR(
LoadBufferFromGCS(fname, offset, n, scratch, &bytes_transferred));
- *result = StringPiece(scratch, bytes_transferred);
+ *result = absl::string_view(scratch, bytes_transferred);
if (bytes_transferred < n) {
return errors::OutOfRange("EOF reached, ", result->size(),
" bytes were read out of ", n,
" bytes requested.");
}
- return OkStatus();
+ return absl::OkStatus();
}));
}
- return OkStatus();
+ return absl::OkStatus();
}
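Both read callbacks above follow the same contract: the caller owns a scratch buffer, the callback fills it, and `result` is re-pointed at the bytes actually transferred. A caller-side sketch of that contract (buffer size and path are illustrative):

  char scratch[64];
  absl::string_view result;
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(fs.NewRandomAccessFile("gs://bucket/object", nullptr, &file));
  absl::Status read_status = file->Read(0, sizeof(scratch), &result, scratch);
  // On a short read, read_status is OutOfRange, but `result` still views the bytes
  // that were transferred; it stays valid only as long as `scratch` does.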
void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes,
@@ -1092,9 +1094,10 @@
}
// A helper function to actually read the data from GCS.
-Status GcsFileSystem::LoadBufferFromGCS(const string& fname, size_t offset,
- size_t n, char* buffer,
- size_t* bytes_transferred) {
+absl::Status GcsFileSystem::LoadBufferFromGCS(const string& fname,
+ size_t offset, size_t n,
+ char* buffer,
+ size_t* bytes_transferred) {
*bytes_transferred = 0;
string bucket, object;
@@ -1148,11 +1151,11 @@
}
}
- return OkStatus();
+ return absl::OkStatus();
}
/// Initiates a new upload session.
-Status GcsFileSystem::CreateNewUploadSession(
+absl::Status GcsFileSystem::CreateNewUploadSession(
uint64 start_offset, const std::string& object_to_upload,
const std::string& bucket, uint64 file_size, const std::string& gcs_path,
UploadSessionHandle* session_handle) {
@@ -1179,15 +1182,13 @@
gcs_path, ": 'Location' header not returned.");
}
}
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::UploadToSession(const std::string& session_uri,
- uint64 start_offset,
- uint64 already_uploaded,
- const std::string& tmp_content_filename,
- uint64 file_size,
- const std::string& file_path) {
+absl::Status GcsFileSystem::UploadToSession(
+ const std::string& session_uri, uint64 start_offset,
+ uint64 already_uploaded, const std::string& tmp_content_filename,
+ uint64 file_size, const std::string& file_path) {
std::unique_ptr<HttpRequest> request;
TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
request->SetUri(session_uri);
@@ -1203,14 +1204,12 @@
start_offset + already_uploaded));
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when uploading ",
file_path);
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::RequestUploadSessionStatus(const string& session_uri,
- uint64 file_size,
- const std::string& gcs_path,
- bool* completed,
- uint64* uploaded) {
+absl::Status GcsFileSystem::RequestUploadSessionStatus(
+ const string& session_uri, uint64 file_size, const std::string& gcs_path,
+ bool* completed, uint64* uploaded) {
CHECK(completed != nullptr) << "RequestUploadSessionStatus() called with out "
"param 'completed' == nullptr."; // Crash ok
CHECK(uploaded != nullptr) << "RequestUploadSessionStatus() called with out "
@@ -1221,10 +1220,10 @@
request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
request->AddHeader("Content-Range", strings::StrCat("bytes */", file_size));
request->SetPutEmptyBody();
- Status status = request->Send();
+ absl::Status status = request->Send();
if (status.ok()) {
*completed = true;
- return OkStatus();
+ return absl::OkStatus();
}
*completed = false;
if (request->GetResponseCode() != HTTP_CODE_RESUME_INCOMPLETE) {
@@ -1235,7 +1234,7 @@
// This means GCS doesn't have any bytes of the file yet.
*uploaded = 0;
} else {
- StringPiece range_piece(received_range);
+ absl::string_view range_piece(received_range);
absl::ConsumePrefix(&range_piece,
"bytes="); // May or may not be present.
@@ -1269,13 +1268,15 @@
// If GCS returned "Range: 0-10", this means 11 bytes were uploaded.
*uploaded = range_parts[1] + 1;
}
- return OkStatus();
+ return absl::OkStatus();
}
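As the comment notes, a "Range: 0-10" response means bytes 0 through 10 inclusive are already stored, i.e. 11 bytes. A small worked sketch of that arithmetic (the header value is illustrative):

  absl::string_view range_piece("bytes=0-10");
  absl::ConsumePrefix(&range_piece, "bytes=");              // range_piece == "0-10"
  std::vector<absl::string_view> parts = absl::StrSplit(range_piece, '-');
  // parts == {"0", "10"}; the code above (partly elided by the diff) converts these
  // to integers and, because the range is inclusive, sets *uploaded = 10 + 1 == 11.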
-Status GcsFileSystem::ParseGcsPathForScheme(StringPiece fname, string scheme,
- bool empty_object_ok,
- string* bucket, string* object) {
- StringPiece parsed_scheme, bucketp, objectp;
+absl::Status GcsFileSystem::ParseGcsPathForScheme(absl::string_view fname,
+ string scheme,
+ bool empty_object_ok,
+ string* bucket,
+ string* object) {
+ absl::string_view parsed_scheme, bucketp, objectp;
io::ParseURI(fname, &parsed_scheme, &bucketp, &objectp);
if (parsed_scheme != scheme) {
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
@@ -1292,11 +1293,12 @@
return errors::InvalidArgument("GCS path doesn't contain an object name: ",
fname);
}
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::ParseGcsPath(StringPiece fname, bool empty_object_ok,
- string* bucket, string* object) {
+absl::Status GcsFileSystem::ParseGcsPath(absl::string_view fname,
+ bool empty_object_ok, string* bucket,
+ string* object) {
return ParseGcsPathForScheme(fname, "gs", empty_object_ok, bucket, object);
}
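A quick illustration of the bucket/object split performed here (the literal path is illustrative):

  string bucket, object;
  TF_RETURN_IF_ERROR(
      ParseGcsPath("gs://bucket-name/path/to/file.txt",
                   /*empty_object_ok=*/false, &bucket, &object));
  // bucket == "bucket-name", object == "path/to/file.txt". A bare
  // "gs://bucket-name" is accepted only when empty_object_ok is true.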
@@ -1308,9 +1310,9 @@
// MatchingPathsCache as well.
}
-Status GcsFileSystem::NewWritableFile(const string& fname,
- TransactionToken* token,
- std::unique_ptr<WritableFile>* result) {
+absl::Status GcsFileSystem::NewWritableFile(
+ const string& fname, TransactionToken* token,
+ std::unique_ptr<WritableFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
@@ -1344,7 +1346,7 @@
},
retry_config_));
*generation = stat.generation_number;
- return OkStatus();
+ return absl::OkStatus();
};
result->reset(new GcsWritableFile(
@@ -1352,20 +1354,20 @@
[this, fname]() { ClearFileCaches(fname); }, retry_config_,
compose_append_, session_creator, object_uploader, status_poller,
generation_getter));
- return OkStatus();
+ return absl::OkStatus();
}
// Reads the file from GCS in chunks and stores it in a tmp file,
// which is then passed to GcsWritableFile.
-Status GcsFileSystem::NewAppendableFile(const string& fname,
- TransactionToken* token,
- std::unique_ptr<WritableFile>* result) {
+absl::Status GcsFileSystem::NewAppendableFile(
+ const string& fname, TransactionToken* token,
+ std::unique_ptr<WritableFile>* result) {
std::unique_ptr<RandomAccessFile> reader;
TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, token, &reader));
std::unique_ptr<char[]> buffer(new char[kReadAppendableFileBufferSize]);
- Status status;
+ absl::Status status;
uint64 offset = 0;
- StringPiece read_chunk;
+ absl::string_view read_chunk;
// Read the file from GCS in chunks and save it to a tmp file.
string old_content_filename;
@@ -1421,7 +1423,7 @@
},
retry_config_));
*generation = stat.generation_number;
- return OkStatus();
+ return absl::OkStatus();
};
// Create a writable file and pass the old content to it.
@@ -1432,10 +1434,10 @@
[this, fname]() { ClearFileCaches(fname); }, retry_config_,
compose_append_, session_creator, object_uploader, status_poller,
generation_getter));
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
+absl::Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
const string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) {
uint64 size;
@@ -1445,21 +1447,22 @@
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, token, &file));
- StringPiece piece;
+ absl::string_view piece;
TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get()));
result->reset(new GcsReadOnlyMemoryRegion(std::move(data), size));
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::FileExists(const string& fname, TransactionToken* token) {
+absl::Status GcsFileSystem::FileExists(const string& fname,
+ TransactionToken* token) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
if (object.empty()) {
bool result;
TF_RETURN_IF_ERROR(BucketExists(bucket, &result));
if (result) {
- return OkStatus();
+ return absl::OkStatus();
} else {
return absl::NotFoundError(
absl::StrCat("The specified bucket ", fname, " was not found."));
@@ -1468,7 +1471,7 @@
// Check if the object exists.
GcsFileStat stat;
- const Status status = StatForObject(fname, bucket, object, &stat);
+ const absl::Status status = StatForObject(fname, bucket, object, &stat);
if (!absl::IsNotFound(status)) {
return status;
}
@@ -1477,31 +1480,32 @@
bool result;
TF_RETURN_IF_ERROR(FolderExists(fname, &result));
if (result) {
- return OkStatus();
+ return absl::OkStatus();
}
return errors::NotFound("The specified path ", fname, " was not found.");
}
-Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket,
- const string& object, bool* result) {
+absl::Status GcsFileSystem::ObjectExists(const string& fname,
+ const string& bucket,
+ const string& object, bool* result) {
GcsFileStat stat;
- const Status status = StatForObject(fname, bucket, object, &stat);
+ const absl::Status status = StatForObject(fname, bucket, object, &stat);
switch (static_cast<int>(status.code())) {
case static_cast<int>(error::Code::OK):
*result = !stat.base.is_directory;
- return OkStatus();
+ return absl::OkStatus();
case static_cast<int>(error::Code::NOT_FOUND):
*result = false;
- return OkStatus();
+ return absl::OkStatus();
default:
return status;
}
}
-Status GcsFileSystem::UncachedStatForObject(const string& fname,
- const string& bucket,
- const string& object,
- GcsFileStat* stat) {
+absl::Status GcsFileSystem::UncachedStatForObject(const string& fname,
+ const string& bucket,
+ const string& object,
+ GcsFileStat* stat) {
std::vector<char> output_buffer;
std::unique_ptr<HttpRequest> request;
TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request),
@@ -1542,7 +1546,7 @@
<< "; mtime_nsec: " << stat->base.mtime_nsec
<< "; updated: " << updated;
- if (str_util::EndsWith(fname, "/")) {
+ if (absl::EndsWith(fname, "/")) {
// In GCS a path can be both a directory and a file, though this is uncommon for
// other file systems. To avoid the ambiguity, if a path ends with "/" in
// GCS, we always regard it as a directory mark or a virtual directory.
@@ -1550,11 +1554,13 @@
} else {
stat->base.is_directory = false;
}
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
- const string& object, GcsFileStat* stat) {
+absl::Status GcsFileSystem::StatForObject(const string& fname,
+ const string& bucket,
+ const string& object,
+ GcsFileStat* stat) {
if (object.empty()) {
return errors::InvalidArgument(strings::Printf(
"'object' must be a non-empty string. (File: %s)", fname.c_str()));
@@ -1565,26 +1571,27 @@
[this, &bucket, &object](const string& fname, GcsFileStat* stat) {
return UncachedStatForObject(fname, bucket, object, stat);
}));
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
- const Status status = GetBucketMetadata(bucket, nullptr);
+absl::Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
+ const absl::Status status = GetBucketMetadata(bucket, nullptr);
switch (static_cast<absl::StatusCode>(status.code())) {
case absl::StatusCode::kOk:
*result = true;
- return OkStatus();
+ return absl::OkStatus();
case absl::StatusCode::kNotFound:
*result = false;
- return OkStatus();
+ return absl::OkStatus();
default:
return status;
}
}
-Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) {
+absl::Status GcsFileSystem::CheckBucketLocationConstraint(
+ const string& bucket) {
if (allowed_locations_.empty()) {
- return OkStatus();
+ return absl::OkStatus();
}
// Avoid calling external API's in the constructor
@@ -1597,7 +1604,7 @@
string location;
TF_RETURN_IF_ERROR(GetBucketLocation(bucket, &location));
if (allowed_locations_.find(location) != allowed_locations_.end()) {
- return OkStatus();
+ return absl::OkStatus();
}
return errors::FailedPrecondition(strings::Printf(
@@ -1606,11 +1613,11 @@
absl::StrJoin(allowed_locations_, ", ").c_str()));
}
-Status GcsFileSystem::GetBucketLocation(const string& bucket,
- string* location) {
+absl::Status GcsFileSystem::GetBucketLocation(const string& bucket,
+ string* location) {
auto compute_func = [this](const string& bucket, string* location) {
std::vector<char> result_buffer;
- Status status = GetBucketMetadata(bucket, &result_buffer);
+ absl::Status status = GetBucketMetadata(bucket, &result_buffer);
Json::Value result;
TF_RETURN_IF_ERROR(ParseJson(result_buffer, &result));
string bucket_location;
@@ -1618,17 +1625,17 @@
GetStringValue(result, kBucketMetadataLocationKey, &bucket_location));
// Lowercase the GCS location to be case insensitive for allowed locations.
*location = absl::AsciiStrToLower(bucket_location);
- return OkStatus();
+ return absl::OkStatus();
};
TF_RETURN_IF_ERROR(
bucket_location_cache_->LookupOrCompute(bucket, location, compute_func));
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::GetBucketMetadata(const string& bucket,
- std::vector<char>* result_buffer) {
+absl::Status GcsFileSystem::GetBucketMetadata(
+ const string& bucket, std::vector<char>* result_buffer) {
std::unique_ptr<HttpRequest> request;
TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
@@ -1641,7 +1648,7 @@
return request->Send();
}
-Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
+absl::Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
StatCache::ComputeFunc compute_func = [this](const string& dirname,
GcsFileStat* stat) {
std::vector<string> children;
@@ -1650,36 +1657,36 @@
true /* include_self_directory_marker */));
if (!children.empty()) {
stat->base = DIRECTORY_STAT;
- return OkStatus();
+ return absl::OkStatus();
} else {
return errors::InvalidArgument("Not a directory!");
}
};
GcsFileStat stat;
- Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname), &stat,
- compute_func);
+ absl::Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname),
+ &stat, compute_func);
if (s.ok()) {
*result = stat.base.is_directory;
- return OkStatus();
+ return absl::OkStatus();
}
if (absl::IsInvalidArgument(s)) {
*result = false;
- return OkStatus();
+ return absl::OkStatus();
}
return s;
}
-Status GcsFileSystem::GetChildren(const string& dirname,
- TransactionToken* token,
- std::vector<string>* result) {
+absl::Status GcsFileSystem::GetChildren(const string& dirname,
+ TransactionToken* token,
+ std::vector<string>* result) {
return GetChildrenBounded(dirname, UINT64_MAX, result,
false /* recursively */,
false /* include_self_directory_marker */);
}
-Status GcsFileSystem::GetMatchingPaths(const string& pattern,
- TransactionToken* token,
- std::vector<string>* results) {
+absl::Status GcsFileSystem::GetMatchingPaths(const string& pattern,
+ TransactionToken* token,
+ std::vector<string>* results) {
MatchingPathsCache::ComputeFunc compute_func =
[this](const string& pattern, std::vector<string>* results) {
results->clear();
@@ -1700,7 +1707,7 @@
// To handle `/` in the object names, we need to remove it from `dir`
// and then use `StrCat` to insert it back.
- const StringPiece dir_no_slash = str_util::StripSuffix(dir, "/");
+ const absl::string_view dir_no_slash = absl::StripSuffix(dir, "/");
// Match all obtained paths to the input pattern.
for (const auto& path : files_and_folders) {
@@ -1715,18 +1722,16 @@
results->push_back(full_path);
}
}
- return OkStatus();
+ return absl::OkStatus();
};
TF_RETURN_IF_ERROR(
matching_paths_cache_->LookupOrCompute(pattern, results, compute_func));
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::GetChildrenBounded(const string& dirname,
- uint64 max_results,
- std::vector<string>* result,
- bool recursive,
- bool include_self_directory_marker) {
+absl::Status GcsFileSystem::GetChildrenBounded(
+ const string& dirname, uint64 max_results, std::vector<string>* result,
+ bool recursive, bool include_self_directory_marker) {
if (!result) {
return errors::InvalidArgument("'result' cannot be null");
}
@@ -1786,7 +1791,7 @@
// The names should be relative to the 'dirname'. That means the
// 'object_prefix', which is part of 'dirname', should be removed from
// the beginning of 'name'.
- StringPiece relative_path(name);
+ absl::string_view relative_path(name);
if (!absl::ConsumePrefix(&relative_path, object_prefix)) {
return errors::Internal(strings::StrCat(
"Unexpected response: the returned file name ", name,
@@ -1796,7 +1801,7 @@
result->emplace_back(relative_path);
}
if (++retrieved_results >= max_results) {
- return OkStatus();
+ return absl::OkStatus();
}
}
}
@@ -1815,7 +1820,7 @@
"response.");
}
const string& prefix_str = prefix.asString();
- StringPiece relative_path(prefix_str);
+ absl::string_view relative_path(prefix_str);
if (!absl::ConsumePrefix(&relative_path, object_prefix)) {
return errors::Internal(
"Unexpected response: the returned folder name ", prefix_str,
@@ -1823,13 +1828,13 @@
}
result->emplace_back(relative_path);
if (++retrieved_results >= max_results) {
- return OkStatus();
+ return absl::OkStatus();
}
}
}
const auto token = root.get("nextPageToken", Json::Value::null);
if (token.isNull()) {
- return OkStatus();
+ return absl::OkStatus();
}
if (!token.isString()) {
return errors::Internal(
@@ -1839,8 +1844,8 @@
}
}
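The listing loop above pages through results: each response may carry a nextPageToken, which is fed back into the next request until it is absent. A schematic sketch of that loop (request construction is elided by the diff and here; field names follow the GCS JSON API used above):

  string page_token;
  while (true) {
    // Build and send the objects.list request, attaching `page_token` when
    // it is non-empty.
    Json::Value root;  // the parsed JSON response
    // ... append entries from "items" and "prefixes" to `result` ...
    const auto token = root.get("nextPageToken", Json::Value::null);
    if (token.isNull()) return absl::OkStatus();  // last page reached
    page_token = token.asString();
  }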
-Status GcsFileSystem::Stat(const string& fname, TransactionToken* token,
- FileStatistics* stat) {
+absl::Status GcsFileSystem::Stat(const string& fname, TransactionToken* token,
+ FileStatistics* stat) {
if (!stat) {
return errors::Internal("'stat' cannot be nullptr.");
}
@@ -1851,16 +1856,16 @@
TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
if (is_bucket) {
*stat = DIRECTORY_STAT;
- return OkStatus();
+ return absl::OkStatus();
}
return errors::NotFound("The specified bucket ", fname, " was not found.");
}
GcsFileStat gcs_stat;
- const Status status = StatForObject(fname, bucket, object, &gcs_stat);
+ const absl::Status status = StatForObject(fname, bucket, object, &gcs_stat);
if (status.ok()) {
*stat = gcs_stat.base;
- return OkStatus();
+ return absl::OkStatus();
}
if (!absl::IsNotFound(status)) {
return status;
@@ -1869,12 +1874,13 @@
TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
if (is_folder) {
*stat = DIRECTORY_STAT;
- return OkStatus();
+ return absl::OkStatus();
}
return errors::NotFound("The specified path ", fname, " was not found.");
}
-Status GcsFileSystem::DeleteFile(const string& fname, TransactionToken* token) {
+absl::Status GcsFileSystem::DeleteFile(const string& fname,
+ TransactionToken* token) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
@@ -1887,11 +1893,11 @@
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when deleting ", fname);
ClearFileCaches(fname);
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::CreateDir(const string& dirname,
- TransactionToken* token) {
+absl::Status GcsFileSystem::CreateDir(const string& dirname,
+ TransactionToken* token) {
string dirname_with_slash = MaybeAppendSlash(dirname);
VLOG(3) << "CreateDir: creating directory with dirname: " << dirname
<< " and dirname_with_slash: " << dirname_with_slash;
@@ -1901,7 +1907,7 @@
if (object.empty()) {
bool is_bucket;
TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
- return is_bucket ? OkStatus()
+ return is_bucket ? absl::OkStatus()
: errors::NotFound("The specified bucket ",
dirname_with_slash, " was not found.");
}
@@ -1924,10 +1930,10 @@
request->SetPostEmptyBody();
request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
- const Status& status = request->Send();
+ const absl::Status& status = request->Send();
if (status.ok()) {
VLOG(3) << "CreateDir: finished uploading directory " << dirname;
- return OkStatus();
+ return absl::OkStatus();
}
if (request->GetResponseCode() != HTTP_CODE_PRECONDITION_FAILED) {
TF_RETURN_WITH_CONTEXT_IF_ERROR(status, " when uploading ",
@@ -1940,8 +1946,8 @@
// Checks that the directory is empty (i.e., no objects with this prefix exist).
// Deletes the GCS directory marker if it exists.
-Status GcsFileSystem::DeleteDir(const string& dirname,
- TransactionToken* token) {
+absl::Status GcsFileSystem::DeleteDir(const string& dirname,
+ TransactionToken* token) {
std::vector<string> children;
// A directory is considered empty either if there are no matching objects
// with the corresponding name prefix or if there is exactly one matching
@@ -1958,11 +1964,12 @@
// This is the directory marker object. Delete it.
return DeleteFile(MaybeAppendSlash(dirname), token);
}
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::GetFileSize(const string& fname, TransactionToken* token,
- uint64* file_size) {
+absl::Status GcsFileSystem::GetFileSize(const string& fname,
+ TransactionToken* token,
+ uint64* file_size) {
if (!file_size) {
return errors::Internal("'file_size' cannot be nullptr.");
}
@@ -1974,11 +1981,11 @@
FileStatistics stat;
TF_RETURN_IF_ERROR(Stat(fname, token, &stat));
*file_size = stat.length;
- return OkStatus();
+ return absl::OkStatus();
}
-Status GcsFileSystem::RenameFile(const string& src, const string& target,
- TransactionToken* token) {
+absl::Status GcsFileSystem::RenameFile(const string& src, const string& target,
+ TransactionToken* token) {
if (!IsDirectory(src, token).ok()) {
return RenameObject(src, target);
}
@@ -1991,11 +1998,12 @@
TF_RETURN_IF_ERROR(
RenameObject(JoinGcsPath(src, subpath), JoinGcsPath(target, subpath)));
}
- return OkStatus();
+ return absl::OkStatus();
}
// Uses a GCS API command to copy the object and then deletes the old one.
-Status GcsFileSystem::RenameObject(const string& src, const string& target) {
+absl::Status GcsFileSystem::RenameObject(const string& src,
+ const string& target) {
VLOG(3) << "RenameObject: started gs://" << src << " to " << target;
string src_bucket, src_object, target_bucket, target_object;
TF_RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object));
@@ -2040,15 +2048,15 @@
[this, &src]() { return DeleteFile(src, nullptr); }, retry_config_);
}
-Status GcsFileSystem::IsDirectory(const string& fname,
- TransactionToken* token) {
+absl::Status GcsFileSystem::IsDirectory(const string& fname,
+ TransactionToken* token) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
if (object.empty()) {
bool is_bucket;
TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
if (is_bucket) {
- return OkStatus();
+ return absl::OkStatus();
}
return errors::NotFound("The specified bucket gs://", bucket,
" was not found.");
@@ -2056,7 +2064,7 @@
bool is_folder;
TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
if (is_folder) {
- return OkStatus();
+ return absl::OkStatus();
}
bool is_object;
TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &is_object));
@@ -2067,10 +2075,10 @@
return errors::NotFound("The specified path ", fname, " was not found.");
}
-Status GcsFileSystem::DeleteRecursively(const string& dirname,
- TransactionToken* token,
- int64_t* undeleted_files,
- int64_t* undeleted_dirs) {
+absl::Status GcsFileSystem::DeleteRecursively(const string& dirname,
+ TransactionToken* token,
+ int64_t* undeleted_files,
+ int64_t* undeleted_dirs) {
if (!undeleted_files || !undeleted_dirs) {
return errors::Internal(
"'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
@@ -2079,7 +2087,7 @@
*undeleted_dirs = 0;
if (!IsDirectory(dirname, token).ok()) {
*undeleted_dirs = 1;
- return Status(
+ return absl::Status(
absl::StatusCode::kNotFound,
strings::StrCat(dirname, " doesn't exist or not a directory."));
}
@@ -2106,7 +2114,7 @@
}
}
}
- return OkStatus();
+ return absl::OkStatus();
}
// Flushes all caches for filesystem metadata and file contents. Useful for
@@ -2148,7 +2156,8 @@
// Creates an HttpRequest and sets several parameters that are common to all
// requests. All code (in GcsFileSystem) that creates an HttpRequest should
// go through this method, rather than directly using http_request_factory_.
-Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
+absl::Status GcsFileSystem::CreateHttpRequest(
+ std::unique_ptr<HttpRequest>* request) {
std::unique_ptr<HttpRequest> new_request{http_request_factory_->Create()};
if (dns_cache_) {
dns_cache_->AnnotateRequest(new_request.get());
@@ -2177,7 +2186,7 @@
}
*request = std::move(new_request);
- return OkStatus();
+ return absl::OkStatus();
}
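Every GCS call in this file follows the same shape, which is why request construction is funneled through CreateHttpRequest. A condensed sketch of that recurring pattern (URI pieces and context message are illustrative):

  std::unique_ptr<HttpRequest> request;
  TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
  request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
  request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading metadata of ", bucket);
  return absl::OkStatus();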
RetryingGcsFileSystem::RetryingGcsFileSystem()
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h
index 17725e8..f7452a4 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h
@@ -66,7 +66,7 @@
// Helper function to extract an environment variable and convert it into a
// value of type T.
template <typename T>
-bool GetEnvVar(const char* varname, bool (*convert)(StringPiece, T*),
+bool GetEnvVar(const char* varname, bool (*convert)(absl::string_view, T*),
T* value) {
const char* env_value = std::getenv(varname);
if (env_value == nullptr) {
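GetEnvVar pairs an environment variable name with a converter, so one helper covers raw strings, lowercased sets, and other value types. Two usage sketches mirroring converters defined in gcs_file_system.cc earlier in this diff (the constants' string values are not shown in the patch):

  absl::string_view append_mode;
  GetEnvVar(kAppendMode, StringPieceIdentity, &append_mode);   // raw value, if set

  std::unordered_set<string> allowed_locations;
  GetEnvVar(kAllowedBucketLocations, SplitByCommaToLowercaseSet, &allowed_locations);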
@@ -144,48 +144,54 @@
TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT;
- Status NewRandomAccessFile(
+ absl::Status NewRandomAccessFile(
const string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) override;
- Status NewWritableFile(const string& fname, TransactionToken* token,
- std::unique_ptr<WritableFile>* result) override;
+ absl::Status NewWritableFile(const string& fname, TransactionToken* token,
+ std::unique_ptr<WritableFile>* result) override;
- Status NewAppendableFile(const string& fname, TransactionToken* token,
- std::unique_ptr<WritableFile>* result) override;
+ absl::Status NewAppendableFile(
+ const string& fname, TransactionToken* token,
+ std::unique_ptr<WritableFile>* result) override;
- Status NewReadOnlyMemoryRegionFromFile(
+ absl::Status NewReadOnlyMemoryRegionFromFile(
const string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
- Status FileExists(const string& fname, TransactionToken* token) override;
+ absl::Status FileExists(const string& fname,
+ TransactionToken* token) override;
- Status Stat(const string& fname, TransactionToken* token,
- FileStatistics* stat) override;
+ absl::Status Stat(const string& fname, TransactionToken* token,
+ FileStatistics* stat) override;
- Status GetChildren(const string& dir, TransactionToken* token,
- std::vector<string>* result) override;
+ absl::Status GetChildren(const string& dir, TransactionToken* token,
+ std::vector<string>* result) override;
- Status GetMatchingPaths(const string& pattern, TransactionToken* token,
- std::vector<string>* results) override;
+ absl::Status GetMatchingPaths(const string& pattern, TransactionToken* token,
+ std::vector<string>* results) override;
- Status DeleteFile(const string& fname, TransactionToken* token) override;
+ absl::Status DeleteFile(const string& fname,
+ TransactionToken* token) override;
- Status CreateDir(const string& dirname, TransactionToken* token) override;
+ absl::Status CreateDir(const string& dirname,
+ TransactionToken* token) override;
- Status DeleteDir(const string& dirname, TransactionToken* token) override;
+ absl::Status DeleteDir(const string& dirname,
+ TransactionToken* token) override;
- Status GetFileSize(const string& fname, TransactionToken* token,
- uint64* file_size) override;
+ absl::Status GetFileSize(const string& fname, TransactionToken* token,
+ uint64* file_size) override;
- Status RenameFile(const string& src, const string& target,
- TransactionToken* token) override;
+ absl::Status RenameFile(const string& src, const string& target,
+ TransactionToken* token) override;
- Status IsDirectory(const string& fname, TransactionToken* token) override;
+ absl::Status IsDirectory(const string& fname,
+ TransactionToken* token) override;
- Status DeleteRecursively(const string& dirname, TransactionToken* token,
- int64_t* undeleted_files,
- int64_t* undeleted_dirs) override;
+ absl::Status DeleteRecursively(const string& dirname, TransactionToken* token,
+ int64_t* undeleted_files,
+ int64_t* undeleted_dirs) override;
void FlushCaches(TransactionToken* token) override;
@@ -267,7 +273,7 @@
write(write) {}
};
- Status CreateHttpRequest(std::unique_ptr<HttpRequest>* request);
+ absl::Status CreateHttpRequest(std::unique_ptr<HttpRequest>* request);
/// \brief Sets a new AuthProvider on the GCS FileSystem.
///
@@ -289,37 +295,38 @@
size_t block_size, size_t max_bytes, uint64 max_staleness);
/// Loads file contents from GCS for a given filename, offset, and length.
- virtual Status LoadBufferFromGCS(const string& fname, size_t offset, size_t n,
- char* buffer, size_t* bytes_transferred);
+ virtual absl::Status LoadBufferFromGCS(const string& fname, size_t offset,
+ size_t n, char* buffer,
+ size_t* bytes_transferred);
// Creates an upload session for an upcoming GCS object upload.
- virtual Status CreateNewUploadSession(uint64 start_offset,
- const std::string& object_to_upload,
- const std::string& bucket,
- uint64 file_size,
- const std::string& gcs_path,
- UploadSessionHandle* session_handle);
+ virtual absl::Status CreateNewUploadSession(
+ uint64 start_offset, const std::string& object_to_upload,
+ const std::string& bucket, uint64 file_size, const std::string& gcs_path,
+ UploadSessionHandle* session_handle);
// Uploads object data to session.
- virtual Status UploadToSession(const std::string& session_uri,
- uint64 start_offset, uint64 already_uploaded,
- const std::string& tmp_content_filename,
- uint64 file_size,
- const std::string& file_path);
+ virtual absl::Status UploadToSession(const std::string& session_uri,
+ uint64 start_offset,
+ uint64 already_uploaded,
+ const std::string& tmp_content_filename,
+ uint64 file_size,
+ const std::string& file_path);
/// \brief Requests status of a previously initiated upload session.
///
/// If the upload has already succeeded, sets 'completed' to true.
/// Otherwise sets 'completed' to false and 'uploaded' to the currently
/// uploaded size in bytes.
- virtual Status RequestUploadSessionStatus(const string& session_uri,
- uint64 file_size,
- const std::string& gcs_path,
- bool* completed, uint64* uploaded);
+ virtual absl::Status RequestUploadSessionStatus(const string& session_uri,
+ uint64 file_size,
+ const std::string& gcs_path,
+ bool* completed,
+ uint64* uploaded);
- Status ParseGcsPathForScheme(StringPiece fname, string scheme,
- bool empty_object_ok, string* bucket,
- string* object);
+ absl::Status ParseGcsPathForScheme(absl::string_view fname, string scheme,
+ bool empty_object_ok, string* bucket,
+ string* object);
/// \brief Splits a GCS path to a bucket and an object.
///
@@ -327,8 +334,9 @@
/// "bucket-name" and "path/to/file.txt".
/// If fname only contains the bucket and empty_object_ok = true, the returned
/// object is empty.
- virtual Status ParseGcsPath(StringPiece fname, bool empty_object_ok,
- string* bucket, string* object);
+ virtual absl::Status ParseGcsPath(absl::string_view fname,
+ bool empty_object_ok, string* bucket,
+ string* object);
std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_;
@@ -348,7 +356,7 @@
/// \brief Checks if the bucket exists. Returns OK if the check succeeded.
///
/// 'result' is set if the function returns OK. 'result' cannot be nullptr.
- Status BucketExists(const string& bucket, bool* result);
+ absl::Status BucketExists(const string& bucket, bool* result);
/// \brief Retrieves the GCS bucket location. Returns OK if the location was
/// retrieved.
@@ -359,28 +367,28 @@
/// This requires the bucket metadata permission.
/// Repeated calls for the same bucket are cached so this function can be
/// called frequently without causing an extra API call
- Status GetBucketLocation(const string& bucket, string* location);
+ absl::Status GetBucketLocation(const string& bucket, string* location);
/// \brief Check if the GCS buckets location is allowed with the current
/// constraint configuration
- Status CheckBucketLocationConstraint(const string& bucket);
+ absl::Status CheckBucketLocationConstraint(const string& bucket);
/// \brief Given the input bucket `bucket`, fills `result_buffer` with the
/// results of the metadata. Returns OK if the API call succeeds without
/// error.
- Status GetBucketMetadata(const string& bucket,
- std::vector<char>* result_buffer);
+ absl::Status GetBucketMetadata(const string& bucket,
+ std::vector<char>* result_buffer);
/// \brief Checks if the object exists. Returns OK if the check succeeded.
///
/// 'result' is set if the function returns OK. 'result' cannot be nullptr.
- Status ObjectExists(const string& fname, const string& bucket,
- const string& object, bool* result);
+ absl::Status ObjectExists(const string& fname, const string& bucket,
+ const string& object, bool* result);
/// \brief Checks if the folder exists. Returns OK if the check succeeded.
///
/// 'result' is set if the function returns OK. 'result' cannot be nullptr.
- Status FolderExists(const string& dirname, bool* result);
+ absl::Status FolderExists(const string& dirname, bool* result);
/// \brief Internal version of GetChildren with more knobs.
///
@@ -390,19 +398,19 @@
/// If 'include_self_directory_marker' is true and there is a GCS directory
/// marker at the path 'dir', GetChildrenBounded will return an empty string
/// as one of the children that represents this marker.
- Status GetChildrenBounded(const string& dir, uint64 max_results,
- std::vector<string>* result, bool recursively,
- bool include_self_directory_marker);
+ absl::Status GetChildrenBounded(const string& dir, uint64 max_results,
+ std::vector<string>* result, bool recursively,
+ bool include_self_directory_marker);
/// Retrieves file statistics assuming fname points to a GCS object. The data
/// may be read from cache or from GCS directly.
- Status StatForObject(const string& fname, const string& bucket,
- const string& object, GcsFileStat* stat);
+ absl::Status StatForObject(const string& fname, const string& bucket,
+ const string& object, GcsFileStat* stat);
/// Retrieves file statistics of file fname directly from GCS.
- Status UncachedStatForObject(const string& fname, const string& bucket,
- const string& object, GcsFileStat* stat);
+ absl::Status UncachedStatForObject(const string& fname, const string& bucket,
+ const string& object, GcsFileStat* stat);
- Status RenameObject(const string& src, const string& target);
+ absl::Status RenameObject(const string& src, const string& target);
// Clear all the caches related to the file with name `filename`.
void ClearFileCaches(const string& fname);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc
index 9221128..9d9d308 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc
@@ -17,7 +17,7 @@
#include <fstream>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/cloud/http_request_fake.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/str_util.h"
@@ -45,17 +45,17 @@
class FakeAuthProvider : public AuthProvider {
public:
- Status GetToken(string* token) override {
+ absl::Status GetToken(string* token) override {
*token = "fake_token";
- return OkStatus();
+ return absl::OkStatus();
}
};
class FakeZoneProvider : public ZoneProvider {
public:
- Status GetZone(string* zone) override {
+ absl::Status GetZone(string* zone) override {
*zone = "us-east1-b";
- return OkStatus();
+ return absl::OkStatus();
}
};
@@ -88,12 +88,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[6];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk.
TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -135,12 +135,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[6];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk.
TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -183,12 +183,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[6];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk.
EXPECT_TRUE(
@@ -230,12 +230,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[10];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk.
TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -271,12 +271,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[5];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk. Even though the backend response is out-of-range,
// we should get an OK status since we're just reading the first 5 bytes.
@@ -323,12 +323,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[5];
- StringPiece result;
+ absl::string_view result;
TF_EXPECT_OK(file->Read(1, sizeof(scratch), &result, scratch));
EXPECT_EQ("12345", result);
@@ -365,12 +365,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[10];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk. Since the first read is out-of-range,
// we don't cache the out-of-range flag and each subsequent read triggers a
@@ -413,12 +413,12 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
- StringPiece filename;
+ absl::string_view filename;
TF_EXPECT_OK(file->Name(&filename));
EXPECT_EQ(filename, "gs://bucket/random_access.txt");
char scratch[10];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk.
EXPECT_TRUE(
@@ -574,7 +574,7 @@
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
char small_scratch[3];
- StringPiece result;
+ absl::string_view result;
// Read the first chunk.
TF_EXPECT_OK(file->Read(0, sizeof(small_scratch), &result, small_scratch));
@@ -629,7 +629,7 @@
nullptr /* gcs additional header */, false /* compose append */);
char scratch[100];
- StringPiece result;
+ absl::string_view result;
{
// We are instantiating this in an enclosed scope to make sure after the
// unique ptr goes out of scope, we can still access result.
@@ -716,7 +716,7 @@
nullptr /* gcs additional header */, false /* compose append */);
char scratch[100];
- StringPiece result;
+ absl::string_view result;
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
@@ -766,7 +766,7 @@
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */, false /* compose append */);
char scratch[100];
- StringPiece result;
+ absl::string_view result;
// There should only be two HTTP requests issued to GCS even though we iterate
// this loop 10 times. This shows that the underlying FileBlockCache persists
// across file close/open boundaries.
@@ -841,7 +841,7 @@
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
char scratch[5];
- StringPiece result;
+ absl::string_view result;
// First read.
TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -908,7 +908,7 @@
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
char scratch[6];
- StringPiece result;
+ absl::string_view result;
EXPECT_TRUE(
errors::IsInternal(file->Read(0, sizeof(scratch), &result, scratch)));
@@ -972,7 +972,7 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/path/writeable", nullptr, &rfile));
char scratch[100];
- StringPiece result;
+ absl::string_view result;
TF_EXPECT_OK(rfile->Read(0, 4, &result, scratch));
EXPECT_EQ("0123", result);
// Open the writable file.
@@ -1107,7 +1107,7 @@
"Timeouts: 5 1 10\n"
"Header Content-Range: bytes */17\n"
"Put: yes\n",
- "", OkStatus(), nullptr, {}, 201),
+ "", absl::OkStatus(), nullptr, {}, 201),
new FakeHttpRequest(
"Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
"path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n"
@@ -1138,7 +1138,7 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/path/writeable", nullptr, &rfile));
char scratch[100];
- StringPiece result;
+ absl::string_view result;
TF_EXPECT_OK(rfile->Read(0, 4, &result, scratch));
EXPECT_EQ("0123", result);
// Now write to the same file. Once the write succeeds, the cached block will
@@ -1402,7 +1402,7 @@
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/path/appendable", nullptr, &rfile));
char scratch[100];
- StringPiece result;
+ absl::string_view result;
TF_EXPECT_OK(rfile->Read(0, 8, &result, scratch));
EXPECT_EQ("content1", result);
// Closing the appendable file will flush its contents to GCS, triggering HTTP
@@ -1496,8 +1496,9 @@
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
"gs://bucket/path/random_access.txt", nullptr, ®ion));
- EXPECT_EQ(content, StringPiece(reinterpret_cast<const char*>(region->data()),
- region->length()));
+ EXPECT_EQ(content,
+ absl::string_view(reinterpret_cast<const char*>(region->data()),
+ region->length()));
}
TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) {
@@ -2262,7 +2263,7 @@
// Do an initial read of the file to load its contents into the block cache.
char scratch[100];
- StringPiece result;
+ absl::string_view result;
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/path/file1.txt", nullptr, &file));
@@ -2656,7 +2657,7 @@
// Do an initial read of the source and destination files to load their
// contents into the block cache.
char scratch[100];
- StringPiece result;
+ absl::string_view result;
std::unique_ptr<RandomAccessFile> src;
std::unique_ptr<RandomAccessFile> dst;
TF_EXPECT_OK(
@@ -3798,7 +3799,7 @@
fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
char scratch[6];
- StringPiece result;
+ absl::string_view result;
TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
EXPECT_EQ("012345", result);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc
index dfd8310..658629f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc
@@ -15,7 +15,7 @@
#include "tsl/platform/cloud/gcs_throttle.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/str_util.h"
#include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc
index f1b62fb..7f1f94d 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc
@@ -82,7 +82,7 @@
}
/// Returns the credentials file name from the env variable.
-Status GetEnvironmentVariableFileName(string* filename) {
+absl::Status GetEnvironmentVariableFileName(string* filename) {
if (!filename) {
return errors::FailedPrecondition("'filename' cannot be nullptr.");
}
@@ -92,11 +92,11 @@
" is not set or corrupt."));
}
*filename = result;
- return OkStatus();
+ return absl::OkStatus();
}
/// Returns the well known file produced by command 'gcloud auth login'.
-Status GetWellKnownFileName(string* filename) {
+absl::Status GetWellKnownFileName(string* filename) {
if (!filename) {
return errors::FailedPrecondition("'filename' cannot be nullptr.");
}
@@ -118,7 +118,7 @@
"Could not find the credentials file in the standard gcloud location.");
}
*filename = result;
- return OkStatus();
+ return absl::OkStatus();
}
} // namespace
@@ -138,42 +138,42 @@
std::move(compute_engine_metadata_client)),
env_(env) {}
-Status GoogleAuthProvider::GetToken(string* t) {
+absl::Status GoogleAuthProvider::GetToken(string* t) {
mutex_lock lock(mu_);
const uint64 now_sec = env_->NowSeconds();
if (now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
*t = current_token_;
- return OkStatus();
+ return absl::OkStatus();
}
if (GetTokenForTesting().ok()) {
*t = current_token_;
- return OkStatus();
+ return absl::OkStatus();
}
auto token_from_files_status = GetTokenFromFiles();
if (token_from_files_status.ok()) {
*t = current_token_;
- return OkStatus();
+ return absl::OkStatus();
}
char* no_gce_check_var = std::getenv(kNoGceCheck);
bool skip_gce_check = no_gce_check_var != nullptr &&
absl::EqualsIgnoreCase(no_gce_check_var, "true");
- Status token_from_gce_status;
+ absl::Status token_from_gce_status;
if (skip_gce_check) {
token_from_gce_status =
- Status(absl::StatusCode::kCancelled,
- strings::StrCat("GCE check skipped due to presence of $",
- kNoGceCheck, " environment variable."));
+ absl::Status(absl::StatusCode::kCancelled,
+ strings::StrCat("GCE check skipped due to presence of $",
+ kNoGceCheck, " environment variable."));
} else {
token_from_gce_status = GetTokenFromGce();
}
if (token_from_gce_status.ok()) {
*t = current_token_;
- return OkStatus();
+ return absl::OkStatus();
}
if (skip_gce_check) {
@@ -203,10 +203,10 @@
}
current_token_ = "";
- return OkStatus();
+ return absl::OkStatus();
}
-Status GoogleAuthProvider::GetTokenFromFiles() {
+absl::Status GoogleAuthProvider::GetTokenFromFiles() {
string credentials_filename;
if (!GetEnvironmentVariableFileName(&credentials_filename).ok() &&
!GetWellKnownFileName(&credentials_filename).ok()) {
@@ -231,33 +231,33 @@
return errors::FailedPrecondition(
"Unexpected content of the JSON credentials file.");
}
- return OkStatus();
+ return absl::OkStatus();
}
-Status GoogleAuthProvider::GetTokenFromGce() {
+absl::Status GoogleAuthProvider::GetTokenFromGce() {
std::vector<char> response_buffer;
const uint64 request_timestamp_sec = env_->NowSeconds();
TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata(
kGceTokenPath, &response_buffer));
- StringPiece response =
- StringPiece(&response_buffer[0], response_buffer.size());
+ absl::string_view response =
+ absl::string_view(&response_buffer[0], response_buffer.size());
TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse(
response, request_timestamp_sec, ¤t_token_,
&expiration_timestamp_sec_));
- return OkStatus();
+ return absl::OkStatus();
}
-Status GoogleAuthProvider::GetTokenForTesting() {
+absl::Status GoogleAuthProvider::GetTokenForTesting() {
const char* token = std::getenv(kGoogleAuthTokenForTesting);
if (!token) {
return errors::NotFound("The env variable for testing was not set.");
}
expiration_timestamp_sec_ = UINT64_MAX;
current_token_ = token;
- return OkStatus();
+ return absl::OkStatus();
}
} // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h
index 63b7ea6..38ab66d 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h
@@ -40,20 +40,20 @@
/// \brief Returns the short-term authentication bearer token.
///
/// Safe for concurrent use by multiple threads.
- Status GetToken(string* token) override;
+ absl::Status GetToken(string* token) override;
private:
/// \brief Gets the bearer token from files.
///
/// Tries the file from $GOOGLE_APPLICATION_CREDENTIALS and the
/// standard gcloud tool's location.
- Status GetTokenFromFiles() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ absl::Status GetTokenFromFiles() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Gets the bearer token from Google Compute Engine environment.
- Status GetTokenFromGce() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ absl::Status GetTokenFromGce() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Gets the bearer token from the system env variable, for testing purposes.
- Status GetTokenForTesting() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ absl::Status GetTokenForTesting() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
std::unique_ptr<OAuthClient> oauth_client_;
std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc
index 6f3072f..e7d6c4a 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc
@@ -17,7 +17,7 @@
#include <stdlib.h>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/cloud/http_request_fake.h"
#include "tsl/platform/path.h"
#include "tsl/platform/test.h"
@@ -40,23 +40,24 @@
class FakeOAuthClient : public OAuthClient {
public:
- Status GetTokenFromServiceAccountJson(
- Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
- string* token, uint64* expiration_timestamp_sec) override {
- provided_credentials_json = json;
- *token = return_token;
- *expiration_timestamp_sec = return_expiration_timestamp;
- return OkStatus();
- }
-
- /// Retrieves a bearer token using a refresh token.
- Status GetTokenFromRefreshTokenJson(
- Json::Value json, StringPiece oauth_server_uri, string* token,
+ absl::Status GetTokenFromServiceAccountJson(
+ Json::Value json, absl::string_view oauth_server_uri,
+ absl::string_view scope, string* token,
uint64* expiration_timestamp_sec) override {
provided_credentials_json = json;
*token = return_token;
*expiration_timestamp_sec = return_expiration_timestamp;
- return OkStatus();
+ return absl::OkStatus();
+ }
+
+ /// Retrieves a bearer token using a refresh token.
+ absl::Status GetTokenFromRefreshTokenJson(
+ Json::Value json, absl::string_view oauth_server_uri, string* token,
+ uint64* expiration_timestamp_sec) override {
+ provided_credentials_json = json;
+ *token = return_token;
+ *expiration_timestamp_sec = return_expiration_timestamp;
+ return absl::OkStatus();
}
string return_token;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h
index a3a3136..8102dd6 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h
@@ -85,7 +85,8 @@
/// RecordResponse is called after the response has been received.
virtual void RecordResponse(const HttpRequest* request, const string& uri,
- RequestMethod method, const Status& result) = 0;
+ RequestMethod method,
+ const absl::Status& result) = 0;
};
HttpRequest() {}
@@ -124,7 +125,8 @@
///
/// The request body will be taken from the specified file starting from
/// the given offset.
- virtual Status SetPutFromFile(const string& body_filepath, size_t offset) = 0;
+ virtual absl::Status SetPutFromFile(const string& body_filepath,
+ size_t offset) = 0;
/// Makes the request a PUT request with an empty body.
virtual void SetPutEmptyBody() = 0;
@@ -169,7 +171,7 @@
///
/// If the result buffer was defined, the response will be written there.
/// The object is not designed to be re-used after Send() is executed.
- virtual Status Send() = 0;
+ virtual absl::Status Send() = 0;
// Url encodes str and returns a new string.
virtual string EscapeString(const string& str) = 0;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h
index ea1f487..869d2ab 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h
@@ -21,7 +21,7 @@
#include <string>
#include <vector>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/cloud/curl_http_request.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/macros.h"
@@ -38,12 +38,13 @@
public:
/// Return the response for the given request.
FakeHttpRequest(const string& request, const string& response)
- : FakeHttpRequest(request, response, OkStatus(), nullptr, {}, 200) {}
+ : FakeHttpRequest(request, response, absl::OkStatus(), nullptr, {}, 200) {
+ }
/// Return the response with headers for the given request.
FakeHttpRequest(const string& request, const string& response,
const std::map<string, string>& response_headers)
- : FakeHttpRequest(request, response, OkStatus(), nullptr,
+ : FakeHttpRequest(request, response, absl::OkStatus(), nullptr,
response_headers, 200) {}
/// \brief Return the response for the request and capture the POST body.
@@ -51,12 +52,12 @@
/// Post body is not expected to be a part of the 'request' parameter.
FakeHttpRequest(const string& request, const string& response,
string* captured_post_body)
- : FakeHttpRequest(request, response, OkStatus(), captured_post_body, {},
- 200) {}
+ : FakeHttpRequest(request, response, absl::OkStatus(), captured_post_body,
+ {}, 200) {}
/// \brief Return the response and the status for the given request.
FakeHttpRequest(const string& request, const string& response,
- Status response_status, uint64 response_code)
+ absl::Status response_status, uint64 response_code)
: FakeHttpRequest(request, response, response_status, nullptr, {},
response_code) {}
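These constructors let a test pin down the exact request text and the canned response. A minimal test-side sketch in the style of the gcs_file_system tests above (the request and response strings are illustrative, and the fake factory that serves the list is not part of this hunk):

  std::vector<HttpRequest*> requests({new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/obj\n"
      "Auth Token: fake_token\n",
      "{\"size\": \"5\"}")});
  // Send() compares the accumulated request text against the expected string
  // and fails the test on any mismatch.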
@@ -65,7 +66,7 @@
///
/// Post body is not expected to be a part of the 'request' parameter.
FakeHttpRequest(const string& request, const string& response,
- Status response_status, string* captured_post_body,
+ absl::Status response_status, string* captured_post_body,
const std::map<string, string>& response_headers,
uint64 response_code)
: expected_request_(request),
@@ -88,20 +89,21 @@
actual_request_ += "Auth Token: " + auth_token + "\n";
}
void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; }
- Status SetPutFromFile(const string& body_filepath, size_t offset) override {
+ absl::Status SetPutFromFile(const string& body_filepath,
+ size_t offset) override {
std::ifstream stream(body_filepath);
const string& content = string(std::istreambuf_iterator<char>(stream),
std::istreambuf_iterator<char>())
.substr(offset);
actual_request_ += "Put body: " + content + "\n";
- return OkStatus();
+ return absl::OkStatus();
}
void SetPostFromBuffer(const char* buffer, size_t size) override {
if (captured_post_body_) {
*captured_post_body_ = string(buffer, size);
} else {
actual_request_ +=
- strings::StrCat("Post body: ", StringPiece(buffer, size), "\n");
+ strings::StrCat("Post body: ", absl::string_view(buffer, size), "\n");
}
}
void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; }
@@ -123,7 +125,7 @@
size_t GetResultBufferDirectBytesTransferred() override {
return direct_result_bytes_transferred_;
}
- Status Send() override {
+ absl::Status Send() override {
EXPECT_EQ(expected_request_, actual_request())
<< "Unexpected HTTP request.";
if (buffer_) {
@@ -182,7 +184,7 @@
string actual_uri_;
string actual_request_;
string response_;
- Status response_status_;
+ absl::Status response_status_;
string* captured_post_body_ = nullptr;
std::map<string, string> response_headers_;
uint64 response_code_ = 0;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc
index c983577..7480680 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc
@@ -49,8 +49,8 @@
constexpr char kGrantType[] =
"urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer";
-Status ReadJsonValue(const Json::Value& json, const string& name,
- Json::Value* value) {
+absl::Status ReadJsonValue(const Json::Value& json, const string& name,
+ Json::Value* value) {
if (!value) {
return errors::FailedPrecondition("'value' cannot be nullptr.");
}
@@ -59,11 +59,11 @@
return errors::FailedPrecondition(
strings::StrCat("Couldn't read a JSON value '", name, "'."));
}
- return OkStatus();
+ return absl::OkStatus();
}
-Status ReadJsonString(const Json::Value& json, const string& name,
- string* value) {
+absl::Status ReadJsonString(const Json::Value& json, const string& name,
+ string* value) {
Json::Value json_value;
TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value));
if (!json_value.isString()) {
@@ -71,11 +71,11 @@
strings::StrCat("JSON value '", name, "' is not string."));
}
*value = json_value.asString();
- return OkStatus();
+ return absl::OkStatus();
}
-Status ReadJsonInt(const Json::Value& json, const string& name,
- int64_t* value) {
+absl::Status ReadJsonInt(const Json::Value& json, const string& name,
+ int64_t* value) {
Json::Value json_value;
TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value));
if (!json_value.isIntegral()) {
@@ -83,11 +83,11 @@
strings::StrCat("JSON value '", name, "' is not integer."));
}
*value = json_value.asInt64();
- return OkStatus();
+ return absl::OkStatus();
}
-Status CreateSignature(RSA* private_key, StringPiece to_sign,
- string* signature) {
+absl::Status CreateSignature(RSA* private_key, absl::string_view to_sign,
+ string* signature) {
if (!private_key || !signature) {
return errors::FailedPrecondition(
"'private_key' and 'signature' cannot be nullptr.");
@@ -126,14 +126,15 @@
if (EVP_DigestSignFinal(md_ctx.get(), sig.get(), &sig_len) != 1) {
return errors::Internal("DigestFinal (signature compute) failed.");
}
- return Base64Encode(StringPiece(reinterpret_cast<char*>(sig.get()), sig_len),
- signature);
+ return Base64Encode(
+ absl::string_view(reinterpret_cast<char*>(sig.get()), sig_len),
+ signature);
}
/// Encodes a claim for a JSON web token (JWT) to make an OAuth request.
-Status EncodeJwtClaim(StringPiece client_email, StringPiece scope,
- StringPiece audience, uint64 request_timestamp_sec,
- string* encoded) {
+absl::Status EncodeJwtClaim(absl::string_view client_email,
+ absl::string_view scope, absl::string_view audience,
+ uint64 request_timestamp_sec, string* encoded) {
// Step 1: create the JSON with the claim.
Json::Value root;
root["iss"] = Json::Value(client_email.data(),
@@ -155,7 +156,7 @@
}
/// Encodes a header for a JSON web token (JWT) to make an OAuth request.
-Status EncodeJwtHeader(StringPiece key_id, string* encoded) {
+absl::Status EncodeJwtHeader(absl::string_view key_id, string* encoded) {
// Step 1: create the JSON with the header.
Json::Value root;
root["alg"] = kCryptoAlgorithm;
@@ -180,9 +181,9 @@
std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env)
: http_request_factory_(std::move(http_request_factory)), env_(env) {}
-Status OAuthClient::GetTokenFromServiceAccountJson(
- Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
- string* token, uint64* expiration_timestamp_sec) {
+absl::Status OAuthClient::GetTokenFromServiceAccountJson(
+ Json::Value json, absl::string_view oauth_server_uri,
+ absl::string_view scope, string* token, uint64* expiration_timestamp_sec) {
if (!token || !expiration_timestamp_sec) {
return errors::FailedPrecondition(
"'token' and 'expiration_timestamp_sec' cannot be nullptr.");
@@ -228,15 +229,15 @@
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
- StringPiece response =
- StringPiece(response_buffer.data(), response_buffer.size());
+ absl::string_view response =
+ absl::string_view(response_buffer.data(), response_buffer.size());
TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token,
expiration_timestamp_sec));
- return OkStatus();
+ return absl::OkStatus();
}
-Status OAuthClient::GetTokenFromRefreshTokenJson(
- Json::Value json, StringPiece oauth_server_uri, string* token,
+absl::Status OAuthClient::GetTokenFromRefreshTokenJson(
+ Json::Value json, absl::string_view oauth_server_uri, string* token,
uint64* expiration_timestamp_sec) {
if (!token || !expiration_timestamp_sec) {
return errors::FailedPrecondition(
@@ -260,17 +261,17 @@
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
- StringPiece response =
- StringPiece(response_buffer.data(), response_buffer.size());
+ absl::string_view response =
+ absl::string_view(response_buffer.data(), response_buffer.size());
TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token,
expiration_timestamp_sec));
- return OkStatus();
+ return absl::OkStatus();
}
-Status OAuthClient::ParseOAuthResponse(StringPiece response,
- uint64 request_timestamp_sec,
- string* token,
- uint64* expiration_timestamp_sec) {
+absl::Status OAuthClient::ParseOAuthResponse(absl::string_view response,
+ uint64 request_timestamp_sec,
+ string* token,
+ uint64* expiration_timestamp_sec) {
if (!token || !expiration_timestamp_sec) {
return errors::FailedPrecondition(
"'token' and 'expiration_timestamp_sec' cannot be nullptr.");
@@ -292,7 +293,7 @@
*expiration_timestamp_sec = request_timestamp_sec + expires_in;
TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token));
- return OkStatus();
+ return absl::OkStatus();
}
} // namespace tsl
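For reference, a hedged usage sketch of the updated ParseOAuthResponse signature; the JSON body and timestamp are made up, `client` is assumed to be an already-constructed OAuthClient, and the fragment is assumed to sit inside namespace tsl:

    string token;
    uint64 expiration_sec = 0;
    absl::Status s = client.ParseOAuthResponse(
        R"({"token_type":"Bearer","access_token":"abc","expires_in":3600})",
        /*request_timestamp_sec=*/1000, &token, &expiration_sec);
    // On success: token == "abc" and expiration_sec == 1000 + 3600 == 4600.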
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h
index 895c2d0..19d8b4f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h
@@ -37,20 +37,20 @@
///
/// Retrieves the authentication bearer token using a JSON file
/// with the client's private key.
- virtual Status GetTokenFromServiceAccountJson(
- Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
- string* token, uint64* expiration_timestamp_sec);
+ virtual absl::Status GetTokenFromServiceAccountJson(
+ Json::Value json, absl::string_view oauth_server_uri,
+ absl::string_view scope, string* token, uint64* expiration_timestamp_sec);
/// Retrieves a bearer token using a refresh token.
- virtual Status GetTokenFromRefreshTokenJson(Json::Value json,
- StringPiece oauth_server_uri,
- string* token,
- uint64* expiration_timestamp_sec);
+ virtual absl::Status GetTokenFromRefreshTokenJson(
+ Json::Value json, absl::string_view oauth_server_uri, string* token,
+ uint64* expiration_timestamp_sec);
/// Parses the JSON response with the token from an OAuth 2.0 server.
- virtual Status ParseOAuthResponse(StringPiece response,
- uint64 request_timestamp_sec, string* token,
- uint64* expiration_timestamp_sec);
+ virtual absl::Status ParseOAuthResponse(absl::string_view response,
+ uint64 request_timestamp_sec,
+ string* token,
+ uint64* expiration_timestamp_sec);
private:
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc
index 8979f44..dc4c116 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc
@@ -20,7 +20,7 @@
#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/base64.h"
#include "tsl/platform/cloud/http_request_fake.h"
#include "tsl/platform/env.h"
@@ -118,7 +118,7 @@
EXPECT_EQ(13920, expiration_timestamp);
// Now look at the JWT claim that was sent to the OAuth server.
- StringPiece grant_type, assertion;
+ absl::string_view grant_type, assertion;
ASSERT_TRUE(strings::Scanner(post_body)
.OneLiteral("grant_type=")
.RestartCapture()
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc
index f16ab81..57330d1 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc
@@ -68,12 +68,12 @@
}
/// Move the block to the front of the LRU list if it isn't already there.
-Status RamFileBlockCache::UpdateLRU(const Key& key,
- const std::shared_ptr<Block>& block) {
+absl::Status RamFileBlockCache::UpdateLRU(const Key& key,
+ const std::shared_ptr<Block>& block) {
mutex_lock lock(mu_);
if (block->timestamp == 0) {
// The block was evicted from another thread. Allow it to remain evicted.
- return OkStatus();
+ return absl::OkStatus();
}
if (block->lru_iterator != lru_list_.begin()) {
lru_list_.erase(block->lru_iterator);
@@ -95,11 +95,11 @@
Trim();
- return OkStatus();
+ return absl::OkStatus();
}
-Status RamFileBlockCache::MaybeFetch(const Key& key,
- const std::shared_ptr<Block>& block) {
+absl::Status RamFileBlockCache::MaybeFetch(
+ const Key& key, const std::shared_ptr<Block>& block) {
bool downloaded_block = false;
auto reconcile_state =
absl::MakeCleanup([this, &downloaded_block, &key, &block] {
@@ -123,7 +123,7 @@
// Loop until either block content is successfully fetched, or our request
// encounters an error.
mutex_lock l(block->mu);
- Status status = OkStatus();
+ absl::Status status = absl::OkStatus();
while (true) {
switch (block->state) {
case FetchState::ERROR:
@@ -155,23 +155,24 @@
case FetchState::FETCHING:
block->cond_var.wait_for(l, std::chrono::seconds(60));
if (block->state == FetchState::FINISHED) {
- return OkStatus();
+ return absl::OkStatus();
}
// Re-loop in case of errors.
break;
case FetchState::FINISHED:
- return OkStatus();
+ return absl::OkStatus();
}
}
return errors::Internal(
"Control flow should never reach the end of RamFileBlockCache::Fetch.");
}
-Status RamFileBlockCache::Read(const string& filename, size_t offset, size_t n,
- char* buffer, size_t* bytes_transferred) {
+absl::Status RamFileBlockCache::Read(const string& filename, size_t offset,
+ size_t n, char* buffer,
+ size_t* bytes_transferred) {
*bytes_transferred = 0;
if (n == 0) {
- return OkStatus();
+ return absl::OkStatus();
}
if (!IsCacheEnabled() || (n > max_bytes_)) {
// The cache is effectively disabled, so we pass the read through to the
@@ -226,7 +227,7 @@
}
}
*bytes_transferred = total_bytes_transferred;
- return OkStatus();
+ return absl::OkStatus();
}
bool RamFileBlockCache::ValidateAndUpdateFileSignature(const string& filename,
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h
index 76cf7eb..627cf6f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h
@@ -45,9 +45,9 @@
/// cache is constructed. The returned Status should be OK as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
- typedef std::function<Status(const string& filename, size_t offset,
- size_t buffer_size, char* buffer,
- size_t* bytes_transferred)>
+ typedef std::function<absl::Status(const string& filename, size_t offset,
+ size_t buffer_size, char* buffer,
+ size_t* bytes_transferred)>
BlockFetcher;
RamFileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness,
@@ -88,8 +88,8 @@
/// placed in `out`.
/// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
/// in `out`).
- Status Read(const string& filename, size_t offset, size_t n, char* buffer,
- size_t* bytes_transferred) override;
+ absl::Status Read(const string& filename, size_t offset, size_t n,
+ char* buffer, size_t* bytes_transferred) override;
// Validate the given file signature with the existing file signature in the
// cache. Returns true if the signature doesn't change or the file doesn't
@@ -197,14 +197,14 @@
/// Look up a Key in the block cache.
std::shared_ptr<Block> Lookup(const Key& key) TF_LOCKS_EXCLUDED(mu_);
- Status MaybeFetch(const Key& key, const std::shared_ptr<Block>& block)
+ absl::Status MaybeFetch(const Key& key, const std::shared_ptr<Block>& block)
TF_LOCKS_EXCLUDED(mu_);
/// Trim the block cache to make room for another entry.
void Trim() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Update the LRU iterator for the block at `key`.
- Status UpdateLRU(const Key& key, const std::shared_ptr<Block>& block)
+ absl::Status UpdateLRU(const Key& key, const std::shared_ptr<Block>& block)
TF_LOCKS_EXCLUDED(mu_);
/// Remove all blocks of a file, with mu_ already held.
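A sketch of a BlockFetcher written against the new absl::Status-returning typedef above; it mirrors the fetchers used in the test file below, and the block size, byte budget, and fill byte are arbitrary:

    auto fetcher = [](const string& filename, size_t offset, size_t n,
                      char* buffer, size_t* bytes_transferred) -> absl::Status {
      memset(buffer, 'x', n);  // pretend n bytes were read from `filename`
      *bytes_transferred = n;
      return absl::OkStatus();
    };
    RamFileBlockCache cache(/*block_size=*/16, /*max_bytes=*/64,
                            /*max_staleness=*/0, fetcher);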
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc
index 5d17d73..cc71601 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc
@@ -17,7 +17,7 @@
#include <cstring>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/cloud/now_seconds_env.h"
#include "tsl/platform/env.h"
@@ -27,12 +27,12 @@
namespace tsl {
namespace {
-Status ReadCache(RamFileBlockCache* cache, const string& filename,
- size_t offset, size_t n, std::vector<char>* out) {
+absl::Status ReadCache(RamFileBlockCache* cache, const string& filename,
+ size_t offset, size_t n, std::vector<char>* out) {
out->clear();
out->resize(n, 0);
size_t bytes_transferred = 0;
- Status status =
+ absl::Status status =
cache->Read(filename, offset, n, out->data(), &bytes_transferred);
EXPECT_LE(bytes_transferred, n);
out->resize(bytes_transferred, n);
@@ -43,7 +43,7 @@
auto fetcher = [](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred) {
// Do nothing.
- return OkStatus();
+ return absl::OkStatus();
};
RamFileBlockCache cache1(0, 0, 0, fetcher);
RamFileBlockCache cache2(16, 0, 0, fetcher);
@@ -63,7 +63,7 @@
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
string filename = "file";
RamFileBlockCache cache(16, 32, 0, fetcher);
@@ -99,7 +99,7 @@
calls++;
memset(buffer, 'x', got_n);
*bytes_transferred = got_n;
- return OkStatus();
+ return absl::OkStatus();
};
// If block_size, max_bytes, or both are zero, or want_n is larger than
// max_bytes the cache is a pass-through.
@@ -136,7 +136,7 @@
} else {
*bytes_transferred = 0;
}
- return OkStatus();
+ return absl::OkStatus();
};
for (size_t block_size = 2; block_size <= 4; block_size++) {
// Make a cache of N-byte block size (1 block) and verify that reads of
@@ -181,7 +181,7 @@
calls.insert(offset);
memset(buffer, 'x', n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
const uint32 block_count = 256;
RamFileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
@@ -222,7 +222,7 @@
second_block = true;
}
*bytes_transferred = bytes_to_copy;
- return OkStatus();
+ return absl::OkStatus();
};
RamFileBlockCache cache(block_size, block_size, 0, fetcher);
std::vector<char> out;
@@ -233,7 +233,7 @@
// Reading at offset file_size + 4 will read the second block (since the read
// at file_size + 4 = 28 will be aligned to an offset of 16) but will return
// OutOfRange because the offset is past the end of the 24-byte file.
- Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
+ absl::Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
EXPECT_TRUE(second_block);
// Reading the second full block will return 8 bytes, from a cache hit.
@@ -255,7 +255,7 @@
EXPECT_GE(n, 1);
memset(buffer, 'x', 1);
*bytes_transferred = 1;
- return OkStatus();
+ return absl::OkStatus();
};
RamFileBlockCache cache(block_size, 2 * block_size, 0, fetcher);
std::vector<char> out;
@@ -264,7 +264,7 @@
EXPECT_EQ(out.size(), 1);
// Now read the first block; this should yield an INTERNAL error because we
// had already cached a partial block at a later position.
- Status status = ReadCache(&cache, "", 0, block_size, &out);
+ absl::Status status = ReadCache(&cache, "", 0, block_size, &out);
EXPECT_EQ(status.code(), error::INTERNAL);
}
@@ -282,7 +282,7 @@
}
memset(buffer, 'x', n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
const uint32 block_count = 2;
RamFileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
@@ -324,7 +324,7 @@
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
std::vector<char> out;
std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
@@ -369,7 +369,7 @@
}
memset(buffer, c, n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
// This cache has space for 4 blocks; we'll read from two files.
const size_t n = 3;
@@ -426,7 +426,7 @@
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
std::vector<char> out;
// Our fake environment is initialized with the current timestamp.
@@ -493,7 +493,7 @@
}
memset(buffer, 'x', n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
const int block_size = 8;
RamFileBlockCache cache(block_size, 2 * callers * block_size, 0, fetcher);
@@ -529,7 +529,7 @@
notification.Notify();
// Wait for other thread to issue read.
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
- return OkStatus();
+ return absl::OkStatus();
};
RamFileBlockCache cache(block_size, block_size, 0, fetcher);
// Fork off thread for parallel read.
@@ -554,7 +554,7 @@
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
- return OkStatus();
+ return absl::OkStatus();
};
RamFileBlockCache cache(16, 32, 0, fetcher);
std::vector<char> out;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc
index fecba6b..62f8258 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc
@@ -34,7 +34,7 @@
// Only implements one special case of RFC 3339 which is returned by
// GCS API, e.g 2016-04-29T23:15:24.896Z.
-Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec) {
+absl::Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec) {
tm parsed{0};
float seconds;
if (sscanf(time.c_str(), "%4d-%2d-%2dT%2d:%2d:%fZ", &(parsed.tm_year),
@@ -52,7 +52,7 @@
static_cast<int64_t>(std::floor((seconds - int_seconds) *
kNanosecondsPerSecond));
- return OkStatus();
+ return absl::OkStatus();
}
} // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h
index 5eb116c..4dd2d29 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h
@@ -22,7 +22,7 @@
/// Parses the timestamp in RFC 3339 format and returns it
/// as nanoseconds since epoch.
-Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec);
+absl::Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec);
} // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc
index 3a96555..6b54787 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc
@@ -15,7 +15,7 @@
#include "tsl/platform/cloud/time_util.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/test.h"
namespace tsl {
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h
index 8c000e0..14b64ea 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h
@@ -34,9 +34,9 @@
/// Returns an empty string in the case where the zone does not match the
/// expected format
/// Safe for concurrent use by multiple threads.
- virtual Status GetZone(string* zone) = 0;
+ virtual absl::Status GetZone(string* zone) = 0;
- static Status GetZone(ZoneProvider* provider, string* zone) {
+ static absl::Status GetZone(ZoneProvider* provider, string* zone) {
if (!provider) {
return errors::Internal("Zone provider is required.");
}
diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc
index a1f3eba..a1934f8 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc
@@ -26,7 +26,7 @@
namespace tsl {
-string RocmRoot() {
+std::string RocmRoot() {
#if TENSORFLOW_USE_ROCM
if (const char* rocm_path_env = std::getenv("ROCM_PATH")) {
VLOG(3) << "ROCM root = " << rocm_path_env;
@@ -40,6 +40,12 @@
#endif
}
-string RocdlRoot() { return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); }
+std::string RocdlRoot() {
+ if (const char* device_lib_path_env = std::getenv("HIP_DEVICE_LIB_PATH")) {
+ return device_lib_path_env;
+ } else {
+ return io::JoinPath(RocmRoot(), "amdgcn/bitcode");
+ }
+}
} // namespace tsl
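The new RocdlRoot behavior: HIP_DEVICE_LIB_PATH, when set, takes precedence over the <RocmRoot()>/amdgcn/bitcode default. A hedged check, with an illustrative path value (assumes <cstdlib> and tsl/platform/logging.h):

    setenv("HIP_DEVICE_LIB_PATH", "/opt/rocm/amdgcn/bitcode", /*overwrite=*/1);
    CHECK_EQ(tsl::RocdlRoot(), "/opt/rocm/amdgcn/bitcode");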
diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc
index 8477cdb..33792c8 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc
@@ -17,7 +17,7 @@
#include <fstream>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/str_util.h"
#include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc
index 5d55ec3..0024168 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc
@@ -18,7 +18,7 @@
#include <fstream>
#include "absl/time/time.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/str_util.h"
diff --git a/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc b/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc
index 1b1bbcb..1b01a8c 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc
@@ -20,7 +20,7 @@
#include <algorithm>
#include <string>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/path.h"
#include "tsl/platform/strcat.h"
#include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h
index 478dae8..ef30366 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h
@@ -46,9 +46,9 @@
ProfilerLock& operator=(const ProfilerLock&) = delete;
// Movable.
- ProfilerLock(ProfilerLock&& other)
+ ProfilerLock(ProfilerLock&& other) noexcept
: active_(std::exchange(other.active_, false)) {}
- ProfilerLock& operator=(ProfilerLock&& other) {
+ ProfilerLock& operator=(ProfilerLock&& other) noexcept {
active_ = std::exchange(other.active_, false);
return *this;
}
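Rationale for the added noexcept (here and in traceme.h and array.h below): standard containers such as std::vector only move elements during reallocation when the element's move constructor is noexcept; for copyable types they otherwise fall back to copying. A compile-time check along these lines should now hold (assumes <type_traits> and the usual tsl::profiler namespace):

    static_assert(
        std::is_nothrow_move_constructible<tsl::profiler::ProfilerLock>::value,
        "ProfilerLock should be nothrow move constructible");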
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h
index 75c2902..da9fe21 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h
@@ -146,8 +146,8 @@
}
// Movable.
- TraceMe(TraceMe&& other) { *this = std::move(other); }
- TraceMe& operator=(TraceMe&& other) {
+ TraceMe(TraceMe&& other) noexcept { *this = std::move(other); }
+ TraceMe& operator=(TraceMe&& other) noexcept {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(other.start_time_ != kUntracedActivity)) {
name_.Emplace(std::move(other.name_).Consume());
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
index 203657d..39113eb 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
@@ -363,6 +363,7 @@
":tpu_xplane_utils",
":xplane_schema",
":xplane_utils",
+ ":xplane_visitor",
"//tsl/platform:test",
"//tsl/platform:test_main",
"//tsl/profiler/protobuf:xplane_proto_cc",
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc
index 19841f5..9274a1d 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc
@@ -17,6 +17,7 @@
#include <optional>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tsl/platform/regexp.h"
#include "tsl/profiler/protobuf/xplane.pb.h"
#include "tsl/profiler/utils/xplane_schema.h"
@@ -48,5 +49,11 @@
return std::nullopt;
}
+std::optional<int> GetSparseCoreId(absl::string_view plane_name) {
+ std::optional<int> core_id;
+ RE2::FullMatch(plane_name, {kSparseCorePlaneRegex}, &core_id);
+ return core_id;
+}
+
} // namespace profiler
} // namespace tsl
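Usage sketch for the new helper, matching the kSparseCorePlaneRegex added to xplane_schema.cc further down; the plane names follow the new test:

    std::optional<int> id = GetSparseCoreId("/device:TPU:0 SparseCore 1");
    // id == 1
    std::optional<int> none = GetSparseCoreId("/device:TPU:0");
    // none == std::nullopt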
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h
index f3a150c..2fb7c67 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h
@@ -36,6 +36,10 @@
// TensorCore plane name.
std::optional<int> GetTensorCoreId(absl::string_view plane_name);
+// Get SparseCore Id from the SparseCore plane name if the plane name is a
+// valid SparseCore plane name.
+std::optional<int> GetSparseCoreId(absl::string_view plane_name);
+
} // namespace profiler
} // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc
index a385c77..e5bcd73 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc
@@ -21,11 +21,13 @@
#include "tsl/profiler/protobuf/xplane.pb.h"
#include "tsl/profiler/utils/xplane_schema.h"
#include "tsl/profiler/utils/xplane_utils.h"
+#include "tsl/profiler/utils/xplane_visitor.h"
namespace tsl {
namespace profiler {
namespace {
+using ::testing::Optional;
using ::testing::UnorderedElementsAre;
TEST(TpuXPlaneUtilsTest, GetTensorCoreXPlanesFromXSpace) {
@@ -65,6 +67,22 @@
GetTensorCoreId(absl::StrCat("/prefix", TpuPlaneName(0))).has_value());
}
+TEST(TpuXPlaneUtilsTest, GetSparseCorePlanesFromXSpace) {
+ XSpace space;
+ XPlane* p1 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(0));
+ XPlane* p2 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(1));
+ XPlane* p3 = FindOrAddMutablePlaneWithName(
+ &space, absl::StrCat(TpuPlaneName(0), " SparseCore 0"));
+ XPlane* p4 = FindOrAddMutablePlaneWithName(
+ &space, absl::StrCat(TpuPlaneName(0), " SparseCore 1"));
+
+ EXPECT_THAT(FindTensorCorePlanes(space), UnorderedElementsAre(p1, p2));
+ EXPECT_THAT(FindPlanesWithPrefix(space, kTpuPlanePrefix),
+ UnorderedElementsAre(p1, p2, p3, p4));
+ EXPECT_THAT(GetSparseCoreId(p3->name()), Optional(0));
+ EXPECT_THAT(GetSparseCoreId(p4->name()), Optional(1));
+}
+
} // namespace
} // namespace profiler
} // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
index 33de2b0..2cd8aaa 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
@@ -33,6 +33,8 @@
const absl::string_view kTpuPlanePrefix = "/device:TPU:";
const absl::string_view kTpuNonCorePlaneNamePrefix = "#Chip";
const char kTpuPlaneRegex[] = {"/device:TPU:([0-9]*)$"};
+const char kSparseCorePlaneRegex[] = {
+ "/device:TPU:[0-9]+ SparseCore ([0-9]+)$"};
// TODO(b/195582092): change it to /device:custom once all literals are
// migrated.
const absl::string_view kCustomPlanePrefix = "/device:CUSTOM:";
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
index 2e693b4..edf808b 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
@@ -41,6 +41,8 @@
TF_CONST_INIT extern const absl::string_view kTpuPlanePrefix;
// Regex for XPlanes that contain TensorCore planes.
TF_CONST_INIT extern const char kTpuPlaneRegex[];
+// Regex for XPlanes that contain SparseCore planes.
+TF_CONST_INIT extern const char kSparseCorePlaneRegex[];
// Name prefix of XPlane that contains custom device events.
TF_CONST_INIT extern const absl::string_view kCustomPlanePrefix;
// Name prefix of XPlane that contains TPU non-core events such as HBM, ICI etc.
diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl
index 0a2993f..8b8f1de 100644
--- a/third_party/xla/third_party/tsl/workspace2.bzl
+++ b/third_party/xla/third_party/tsl/workspace2.bzl
@@ -160,13 +160,13 @@
tf_http_archive(
name = "mkl_dnn_acl_compatible",
- build_file = "//tensorflow/third_party/mkl_dnn:mkldnn_acl.BUILD",
+ build_file = "//third_party/mkl_dnn:mkldnn_acl.BUILD",
patch_file = [
- "//tensorflow/third_party/mkl_dnn:onednn_acl_threadcap.patch",
- "//tensorflow/third_party/mkl_dnn:onednn_acl_reorder.patch",
- "//tensorflow/third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch",
- "//tensorflow/third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch",
- "//tensorflow/third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch",
+ "//third_party/mkl_dnn:onednn_acl_threadcap.patch",
+ "//third_party/mkl_dnn:onednn_acl_reorder.patch",
+ "//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch",
+ "//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch",
+ "//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch",
],
sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3",
strip_prefix = "oneDNN-3.2.1",
@@ -591,6 +591,22 @@
urls = tf_mirror_urls("https://github.com/google/glog/archive/refs/tags/v0.4.0.tar.gz"),
)
+ tf_http_archive(
+ name = "spirv_headers",
+ sha256 = "11d835c60297b26532c05c3f3b581ba7a2787b5ae7399e94f72c392169216f11",
+ strip_prefix = "SPIRV-Headers-b73e168ca5e123dcf3dea8a34b19a5130f421ae1",
+ urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-Headers/archive/b73e168ca5e123dcf3dea8a34b19a5130f421ae1.tar.gz"),
+ )
+
+ tf_http_archive(
+ name = "spirv_llvm_translator",
+ sha256 = "d499769f4fd1e0ce9d4dbd3622ee7e3e641b5623dcdf811521e3e7c0bdb1e6c2",
+ strip_prefix = "SPIRV-LLVM-Translator-dad1f0eaab8047a4f73c50ed5f3d1694b78aae97",
+ build_file = "//third_party/spirv_llvm_translator:spirv_llvm_translator.BUILD",
+ patch_file = ["//third_party/spirv_llvm_translator:spirv_llvm_translator.patch"],
+ urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-LLVM-Translator/archive/dad1f0eaab8047a4f73c50ed5f3d1694b78aae97.tar.gz"),
+ )
+
# buildifier: disable=unnamed-macro
def workspace():
# Check the bazel version before executing any repository rules, in case
diff --git a/third_party/xla/third_party/uv/uv.BUILD b/third_party/xla/third_party/uv/uv.BUILD
index 75a2df3..b04383a 100644
--- a/third_party/xla/third_party/uv/uv.BUILD
+++ b/third_party/xla/third_party/uv/uv.BUILD
@@ -11,7 +11,48 @@
cc_library(
name = "uv",
- srcs = glob(["src/*.c"]),
+ srcs = [
+ "src/fs-poll.c",
+ "src/idna.c",
+ "src/inet.c",
+ "src/random.c",
+ "src/strscpy.c",
+ "src/threadpool.c",
+ "src/timer.c",
+ "src/uv-common.c",
+ "src/uv-data-getter-setters.c",
+ "src/version.c",
+ ] + [
+ "src/unix/async.c",
+ "src/unix/core.c",
+ "src/unix/dl.c",
+ "src/unix/fs.c",
+ "src/unix/getaddrinfo.c",
+ "src/unix/getnameinfo.c",
+ "src/unix/loop.c",
+ "src/unix/loop-watcher.c",
+ "src/unix/pipe.c",
+ "src/unix/poll.c",
+ "src/unix/process.c",
+ "src/unix/random-devurandom.c",
+ "src/unix/signal.c",
+ "src/unix/stream.c",
+ "src/unix/tcp.c",
+ "src/unix/thread.c",
+ "src/unix/tty.c",
+ "src/unix/udp.c",
+ ] + select({
+ "@platforms//os:osx": [
+ "src/unix/bsd-ifaddrs.c",
+ "src/unix/darwin.c",
+ "src/unix/darwin-proctitle.c",
+ "src/unix/fsevents.c",
+ "src/unix/kqueue.c",
+ "src/unix/proctitle.c",
+ "src/unix/random-getentropy.c",
+ ],
+ }),
+ # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt.
hdrs = [
"include/uv.h",
],
diff --git a/third_party/xla/warnings.bazelrc b/third_party/xla/warnings.bazelrc
index a5711c9..ac21913 100644
--- a/third_party/xla/warnings.bazelrc
+++ b/third_party/xla/warnings.bazelrc
@@ -93,7 +93,5 @@
build:warnings --copt=-Wnon-virtual-dtor
build:warnings --copt=-Wimplicit-fallthrough
build:warnings --copt=-Wthread-safety-analysis
-build:warnings --copt=-Wno-tautological-type-limit-compare
-build:warnings --copt=-Wno-nullability-completeness
build:warnings --copt=-Wno-builtin-macro-redefined
build:warnings --copt=-Wno-macro-redefined
diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl
index dea8d37..da31481 100644
--- a/third_party/xla/workspace2.bzl
+++ b/third_party/xla/workspace2.bzl
@@ -16,7 +16,6 @@
load("//third_party/shardy:workspace.bzl", shardy = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
load("//third_party/triton:workspace.bzl", triton = "repo")
-load("//third_party/uv:workspace.bzl", uv = "repo")
def _initialize_third_party():
""" Load third party repositories. See above load() statements. """
@@ -28,7 +27,6 @@
shardy()
stablehlo()
triton()
- uv()
# Define all external repositories required by TensorFlow
def _tf_repositories():
@@ -52,9 +50,9 @@
name = "cudnn_frontend_archive",
build_file = "//third_party:cudnn_frontend.BUILD",
patch_file = ["//third_party:cudnn_frontend_header_fix.patch"],
- sha256 = "281789777ac296f5f8215a7c4bd066de8816d240eb44c760788beebf8d25a99f",
- strip_prefix = "cudnn-frontend-1.5.1",
- urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.1.zip"),
+ sha256 = "313f4a38a54e578ed668809697c96754497141ba62332bdcc019faaeb1e3c6f6",
+ strip_prefix = "cudnn-frontend-1.6.0",
+ urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.6.0.zip"),
)
tf_http_archive(
diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD
index 5c72564..5d4476a 100644
--- a/third_party/xla/xla/BUILD
+++ b/third_party/xla/xla/BUILD
@@ -516,6 +516,7 @@
":test",
":util",
":xla_data_proto_cc",
+ "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -1048,7 +1049,9 @@
":array2d",
":array3d",
":array4d",
+ ":literal",
":literal_util",
+ ":shape_util",
":util",
":window_util",
":xla_data_proto_cc",
@@ -1056,9 +1059,11 @@
"//xla/client:xla_builder",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/hlo/ir:hlo",
+ "//xla/service:hlo_module_config",
"//xla/service:shape_inference",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:function_ref",
+ "@com_google_absl//absl/log:check",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/lib/math:math_util",
"@local_tsl//tsl/platform:logging",
@@ -1072,7 +1077,9 @@
":array2d",
":array3d",
":array4d",
+ ":error_spec",
":literal",
+ ":literal_util",
":reference_util",
":test",
":xla_data_proto_cc",
@@ -1254,10 +1261,10 @@
deps = [
":autotune_results_proto_cc",
":autotuning_proto_cc",
+ "//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
- "@local_tsl//tsl/lib/strings:proto_serialization",
],
)
@@ -1312,6 +1319,33 @@
visibility = ["//visibility:public"],
)
+cc_library(
+ name = "sort_json",
+ srcs = ["sort_json.cc"],
+ hdrs = ["sort_json.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "sort_json_test",
+ srcs = ["sort_json_test.cc"],
+ deps = [
+ ":sort_json",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
# Needed to workaround https://github.com/bazelbuild/bazel/issues/21519
alias(
name = "bazel_issue_21519",
diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h
index 6a6f505..03c5f3b 100644
--- a/third_party/xla/xla/array.h
+++ b/third_party/xla/xla/array.h
@@ -603,12 +603,12 @@
std::fill(data.get(), data.get() + size, init);
}
- OwnedBuffer(OwnedBuffer&& other)
+ OwnedBuffer(OwnedBuffer&& other) noexcept
: data(std::move(other.data)), size(other.size) {
other.size = 0;
}
- OwnedBuffer& operator=(OwnedBuffer&& other) {
+ OwnedBuffer& operator=(OwnedBuffer&& other) noexcept {
data = std::move(other.data);
size = other.size;
other.size = 0;
diff --git a/third_party/xla/xla/autotune_result_wrapper.cc b/third_party/xla/xla/autotune_result_wrapper.cc
index 855c8aa..ee92f17 100644
--- a/third_party/xla/xla/autotune_result_wrapper.cc
+++ b/third_party/xla/xla/autotune_result_wrapper.cc
@@ -21,7 +21,7 @@
#include "absl/status/status.h"
#include "xla/autotune_results.pb.h"
#include "xla/autotuning.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
namespace xla {
diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h
index 0aee389b..1228b3b 100644
--- a/third_party/xla/xla/backends/interpreter/executor.h
+++ b/third_party/xla/xla/backends/interpreter/executor.h
@@ -86,15 +86,6 @@
absl::Status Init() override { return absl::OkStatus(); }
int device_ordinal() const override { return device_ordinal_; };
- absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
- const MultiKernelLoaderSpec &spec) override {
- return absl::UnimplementedError("Not Implemented");
- }
- absl::Status Launch(Stream *stream, const ThreadDim &thread_dims,
- const BlockDim &block_dims, const Kernel &kernel,
- const KernelArgs &args) override {
- return absl::UnimplementedError("Not Implemented");
- }
DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
void Deallocate(DeviceMemoryBase *mem) override;
@@ -103,7 +94,7 @@
uint64_t size) override {
return std::make_unique<HostMemoryAllocation>(new char[size], size, this);
}
- void HostMemoryDeallocate(void *mem, uint64_t size) override {
+ void HostMemoryDeallocate(void *mem) override {
delete[] static_cast<char *>(mem);
}
diff --git a/third_party/xla/xla/backends/interpreter/platform.cc b/third_party/xla/xla/backends/interpreter/platform.cc
index 8b77eb1..1a541f1 100644
--- a/third_party/xla/xla/backends/interpreter/platform.cc
+++ b/third_party/xla/xla/backends/interpreter/platform.cc
@@ -54,6 +54,13 @@
return GetExecutor(config);
}
+absl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::FindExisting(
+ int ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ return executor_cache_.Get(config);
+}
+
absl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
const StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate(
diff --git a/third_party/xla/xla/backends/interpreter/platform.h b/third_party/xla/xla/backends/interpreter/platform.h
index da3d18e..08a8d21 100644
--- a/third_party/xla/xla/backends/interpreter/platform.h
+++ b/third_party/xla/xla/backends/interpreter/platform.h
@@ -49,9 +49,13 @@
absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
+ absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
+ // Returns a device constructed with the options specified in "config" without
+ // looking in or storing to the Platform's executor cache.
+ // Ownership IS transferred to the caller.
absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
- const StreamExecutorConfig& config) override;
+ const StreamExecutorConfig& config);
private:
// This platform's name.
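A hedged sketch of the lookup path added above: FindExisting(ordinal) only consults the executor cache and returns an error if no executor has been created for that ordinal, whereas ExecutorForDevice creates one on demand.

    absl::StatusOr<StreamExecutor*> found = platform->FindExisting(/*ordinal=*/0);
    if (!found.ok()) {
      // Nothing cached yet; platform->ExecutorForDevice(0) would create one.
    }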
diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
index e53b017..65f8d57 100644
--- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
+++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
@@ -618,7 +618,7 @@
const std::vector<RocmTracerEvent> ApiActivityInfoExchange()
TF_EXCLUSIVE_LOCKS_REQUIRED(event_maps_mutex_);
- absl::flat_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
+ absl::node_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
};
//==========
diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD
index 1e7f1bf..d322caf 100644
--- a/third_party/xla/xla/client/BUILD
+++ b/third_party/xla/xla/client/BUILD
@@ -53,6 +53,7 @@
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/lib/math:math_util",
"@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
],
)
@@ -71,24 +72,23 @@
srcs = ["client.cc"],
hdrs = ["client.h"],
deps = [
- ":global_data",
":xla_computation",
- "//xla:debug_options_flags",
"//xla:execution_options_util",
"//xla:literal",
+ "//xla:shape_util",
"//xla:status_macros",
"//xla:types",
+ "//xla:util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/service",
"//xla/service:hlo_proto_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:protobuf",
+ "@local_tsl//tsl/platform:statusor",
],
)
@@ -111,9 +111,9 @@
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
],
)
@@ -126,14 +126,18 @@
":client",
":executable_build_options",
":xla_computation",
+ "//xla:debug_options_flags",
"//xla:executable_run_options",
+ "//xla:literal",
"//xla:shape_tree",
+ "//xla:shape_util",
+ "//xla:util",
"//xla:xla_data_proto_cc",
"//xla/service:backend",
"//xla/service:compiler",
+ "//xla/service:computation_layout",
"//xla/service:dump",
"//xla/service:executable",
- "//xla/service:hlo_proto_cc",
"//xla/service:local_service",
"//xla/service:maybe_owning_device_memory",
"//xla/service:shaped_buffer",
@@ -141,9 +145,14 @@
"//xla/service:stream_pool",
"//xla/stream_executor",
"//xla/stream_executor:device_memory_allocator",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
],
)
@@ -154,12 +163,18 @@
deps = [
":client",
":xla_computation",
+ "//xla:shape_util",
"//xla:status_macros",
"//xla:xla_data_proto_cc",
+ "//xla:xla_proto_cc",
"//xla/service:compile_only_service",
"//xla/service:compiler",
+ "//xla/service:hlo_module_config",
"//xla/stream_executor",
+ "@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
],
@@ -174,18 +189,20 @@
deps = [
":compile_only_client",
":local_client",
- "//xla:status_macros",
"//xla:types",
- "//xla:util",
- "//xla/service:backend",
+ "//xla/service",
"//xla/service:compile_only_service",
"//xla/service:local_service",
"//xla/service:platform_util",
"//xla/stream_executor",
"//xla/stream_executor:device_memory_allocator",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/synchronization",
"@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
],
)
@@ -200,6 +217,7 @@
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
+ "@com_google_absl//absl/log:check",
],
)
@@ -214,6 +232,7 @@
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/service:hlo_proto_cc",
+ "@com_google_absl//absl/status:statusor",
],
)
@@ -233,8 +252,13 @@
"//xla:xla_data_proto_cc",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/hlo/ir:hlo",
+ "//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/hash",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@@ -305,6 +329,7 @@
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_parser",
+ "//xla/service:hlo_proto_cc",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/tests:xla_internal_test_main",
diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc
index 6e89947..f5e174d 100644
--- a/third_party/xla/xla/client/client.cc
+++ b/third_party/xla/xla/client/client.cc
@@ -22,16 +22,22 @@
#include <vector>
#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "xla/client/xla_computation.h"
-#include "xla/debug_options_flags.h"
#include "xla/execution_options_util.h"
+#include "xla/layout.h"
#include "xla/literal.h"
+#include "xla/service/hlo.pb.h"
+#include "xla/service/service.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
#include "xla/status_macros.h"
-#include "xla/types.h"
-#include "tsl/platform/errors.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/protobuf.h"
+#include "tsl/platform/statusor.h"
namespace xla {
diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h
index f3eacbc..1201568 100644
--- a/third_party/xla/xla/client/client.h
+++ b/third_party/xla/xla/client/client.h
@@ -21,12 +21,15 @@
#include <utility>
#include <vector>
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/client/xla_computation.h"
+#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/service.h"
+#include "xla/shape.h"
#include "xla/types.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
diff --git a/third_party/xla/xla/client/client_library.cc b/third_party/xla/xla/client/client_library.cc
index b55691b..476208d 100644
--- a/third_party/xla/xla/client/client_library.cc
+++ b/third_party/xla/xla/client/client_library.cc
@@ -20,11 +20,17 @@
#include <set>
#include <utility>
-#include "xla/service/backend.h"
+#include "absl/synchronization/mutex.h"
+#include "xla/client/compile_only_client.h"
+#include "xla/client/local_client.h"
+#include "xla/service/compile_only_service.h"
+#include "xla/service/local_service.h"
#include "xla/service/platform_util.h"
-#include "xla/status_macros.h"
-#include "xla/util.h"
+#include "xla/service/service.h"
+#include "xla/stream_executor/platform.h"
#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
namespace xla {
diff --git a/third_party/xla/xla/client/client_library.h b/third_party/xla/xla/client/client_library.h
index db86732..0e4f3a9a 100644
--- a/third_party/xla/xla/client/client_library.h
+++ b/third_party/xla/xla/client/client_library.h
@@ -28,13 +28,16 @@
#include <string>
#include <vector>
+#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
+#include "absl/synchronization/mutex.h"
#include "xla/client/compile_only_client.h"
#include "xla/client/local_client.h"
#include "xla/service/compile_only_service.h"
#include "xla/service/local_service.h"
#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/types.h"
diff --git a/third_party/xla/xla/client/compile_only_client.cc b/third_party/xla/xla/client/compile_only_client.cc
index 23c07b37..1aa6a4f 100644
--- a/third_party/xla/xla/client/compile_only_client.cc
+++ b/third_party/xla/xla/client/compile_only_client.cc
@@ -18,9 +18,19 @@
#include <memory>
#include <vector>
+#include "absl/log/check.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/TargetParser/Triple.h"
+#include "xla/service/compile_only_service.h"
+#include "xla/service/compiler.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
#include "xla/status_macros.h"
+#include "xla/xla.pb.h"
namespace xla {
diff --git a/third_party/xla/xla/client/compile_only_client.h b/third_party/xla/xla/client/compile_only_client.h
index 8dde8c8..2dcb977 100644
--- a/third_party/xla/xla/client/compile_only_client.h
+++ b/third_party/xla/xla/client/compile_only_client.h
@@ -20,11 +20,16 @@
#include <vector>
#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "xla/client/client.h"
#include "xla/client/xla_computation.h"
#include "xla/service/compile_only_service.h"
#include "xla/service/compiler.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
#include "xla/stream_executor/stream_executor.h"
+#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
namespace xla {
diff --git a/third_party/xla/xla/client/executable_build_options.cc b/third_party/xla/xla/client/executable_build_options.cc
index 46b810d..e484319 100644
--- a/third_party/xla/xla/client/executable_build_options.cc
+++ b/third_party/xla/xla/client/executable_build_options.cc
@@ -27,14 +27,13 @@
#include "xla/debug_options_flags.h"
#include "xla/execution_options_util.h"
#include "xla/layout_util.h"
+#include "xla/pjrt/compile_options.pb.h"
#include "xla/service/compilation_environments.h"
#include "xla/service/computation_placer.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
namespace xla {
diff --git a/third_party/xla/xla/client/executable_build_options.h b/third_party/xla/xla/client/executable_build_options.h
index c849230..f1129d6 100644
--- a/third_party/xla/xla/client/executable_build_options.h
+++ b/third_party/xla/xla/client/executable_build_options.h
@@ -24,6 +24,10 @@
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "xla/pjrt/compile_options.pb.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/compilation_environments.h"
diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD
index 28f6c27..7e2cc0d 100644
--- a/third_party/xla/xla/client/lib/BUILD
+++ b/third_party/xla/xla/client/lib/BUILD
@@ -227,6 +227,7 @@
xla_test(
name = "math_test",
+ timeout = "long",
srcs = ["math_test.cc"],
backend_tags = {
# Times out.
diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc
index c388dc4..05056ba 100644
--- a/third_party/xla/xla/client/local_client.cc
+++ b/third_party/xla/xla/client/local_client.cc
@@ -20,13 +20,37 @@
#include <utility>
#include <vector>
+#include "absl/log/check.h"
+#include "absl/log/log.h"
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "xla/client/executable_build_options.h"
#include "xla/client/xla_computation.h"
+#include "xla/debug_options_flags.h"
+#include "xla/executable_run_options.h"
+#include "xla/literal.h"
#include "xla/service/backend.h"
+#include "xla/service/compiler.h"
+#include "xla/service/computation_layout.h"
#include "xla/service/dump.h"
+#include "xla/service/executable.h"
+#include "xla/service/maybe_owning_device_memory.h"
#include "xla/service/service_executable_run_options.h"
+#include "xla/service/shaped_buffer.h"
#include "xla/service/source_map_util.h"
#include "xla/service/stream_pool.h"
+#include "xla/shape.h"
+#include "xla/shape_tree.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
using xla::source_map_util::InvalidParameterArgument;
diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h
index 236ebe0..07c6e6e 100644
--- a/third_party/xla/xla/client/local_client.h
+++ b/third_party/xla/xla/client/local_client.h
@@ -21,21 +21,28 @@
#include <utility>
#include <vector>
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/client/client.h"
#include "xla/client/executable_build_options.h"
#include "xla/client/xla_computation.h"
#include "xla/executable_run_options.h"
+#include "xla/literal.h"
+#include "xla/service/backend.h"
#include "xla/service/compiler.h"
#include "xla/service/executable.h"
#include "xla/service/local_service.h"
#include "xla/service/maybe_owning_device_memory.h"
+#include "xla/service/service_executable_run_options.h"
#include "xla/service/shaped_buffer.h"
+#include "xla/service/stream_pool.h"
#include "xla/shape_tree.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
namespace xla {
diff --git a/third_party/xla/xla/client/padding.cc b/third_party/xla/xla/client/padding.cc
index 37abc59..daf26d5 100644
--- a/third_party/xla/xla/client/padding.cc
+++ b/third_party/xla/xla/client/padding.cc
@@ -20,9 +20,11 @@
#include <vector>
#include "absl/status/status.h"
+#include "absl/types/span.h"
#include "xla/util.h"
#include "tsl/lib/math/math_util.h"
#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
namespace xla {
diff --git a/third_party/xla/xla/client/padding.h b/third_party/xla/xla/client/padding.h
index e715226..e717183 100644
--- a/third_party/xla/xla/client/padding.h
+++ b/third_party/xla/xla/client/padding.h
@@ -19,6 +19,7 @@
#include <utility>
#include <vector>
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/types.h"
diff --git a/third_party/xla/xla/client/sharding_builder.cc b/third_party/xla/xla/client/sharding_builder.cc
index e2324d6..7b179b8 100644
--- a/third_party/xla/xla/client/sharding_builder.cc
+++ b/third_party/xla/xla/client/sharding_builder.cc
@@ -17,6 +17,12 @@
#include <vector>
+#include "absl/log/check.h"
+#include "xla/shape.h"
+#include "xla/shape_tree.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
namespace xla {
namespace sharding_builder {
diff --git a/third_party/xla/xla/client/sharding_builder.h b/third_party/xla/xla/client/sharding_builder.h
index 98d6512..eef395e 100644
--- a/third_party/xla/xla/client/sharding_builder.h
+++ b/third_party/xla/xla/client/sharding_builder.h
@@ -19,6 +19,7 @@
#include <vector>
#include "xla/array.h"
+#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/types.h"
diff --git a/third_party/xla/xla/client/value_inference.cc b/third_party/xla/xla/client/value_inference.cc
index 1ba694a..2f0b6e2 100644
--- a/third_party/xla/xla/client/value_inference.cc
+++ b/third_party/xla/xla/client/value_inference.cc
@@ -21,11 +21,17 @@
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/hash/hash.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
+#include "xla/client/xla_builder.h"
#include "xla/comparison_util.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/ir/dfs_hlo_visitor.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_opcode.h"
@@ -33,6 +39,8 @@
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo.pb.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
diff --git a/third_party/xla/xla/client/value_inference.h b/third_party/xla/xla/client/value_inference.h
index 6f1685f..84c1c99 100644
--- a/third_party/xla/xla/client/value_inference.h
+++ b/third_party/xla/xla/client/value_inference.h
@@ -19,6 +19,9 @@
#include <utility>
#include "absl/container/flat_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
#include "xla/client/xla_builder.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/ir/dfs_hlo_visitor.h"
@@ -26,6 +29,7 @@
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
+#include "xla/shape_util.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc
index 0b679af..98e7dad 100644
--- a/third_party/xla/xla/client/xla_builder.cc
+++ b/third_party/xla/xla/client/xla_builder.cc
@@ -1450,6 +1450,42 @@
});
}
+XlaOp XlaBuilder::CompositeCall(const XlaComputation& computation,
+ absl::Span<const XlaOp> operands,
+ const std::string& name,
+ std::optional<absl::string_view> attributes,
+ std::optional<int64_t> version) {
+ return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
+ TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
+ computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
+ operand_shape_ptrs,
+ /*to_apply=*/called_program_shape));
+ *instr.mutable_shape() = shape.ToProto();
+
+ AddCalledComputation(computation, &instr);
+ instr.set_is_composite(true);
+
+ TF_ASSIGN_OR_RETURN(
+ XlaOp instruction,
+ AddInstruction(std::move(instr), HloOpcode::kCall, operands));
+ TF_RETURN_IF_ERROR(
+ SetInstructionFrontendAttribute(instruction, "composite.name", name));
+ TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute(
+ instruction, "composite.attributes",
+ attributes.has_value() ? std::string(*attributes) : "{}"));
+ TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute(
+ instruction, "composite.version",
+ version.has_value() ? std::to_string(*version) : "0"));
+ return instruction;
+ });
+}
+
XlaOp XlaBuilder::Parameter(
int64_t parameter_number, const Shape& shape, const std::string& name,
const std::vector<bool>& replicated_at_leaf_buffers) {
@@ -5196,6 +5232,14 @@
return builder->Call(computation, operands);
}
+XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation,
+ absl::Span<const XlaOp> operands, const std::string& name,
+ std::optional<absl::string_view> attributes,
+ std::optional<int64_t> version) {
+ return builder->CompositeCall(computation, operands, name, attributes,
+ version);
+}
+
XlaOp CustomCall(
XlaBuilder* builder, const std::string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
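For orientation (not part of the patch), a minimal sketch of how the new entry point could be driven, modeled on the CompositeCall tests added to xla_builder_test.cc below; all names and attribute values are illustrative:

// Sketch: wrap a scalar add computation in a composite call.
XlaBuilder b("composite_call_example");
const Shape scalar = ShapeUtil::MakeShape(F32, {});

XlaBuilder bsum("add");
Add(Parameter(&bsum, 0, scalar, "arg0"), Parameter(&bsum, 1, scalar, "arg1"));
XlaComputation computation = bsum.Build().value();

std::vector<XlaOp> operands = {Parameter(&b, 0, scalar, "arg0"),
                               Parameter(&b, 1, scalar, "arg1")};
CompositeCall(&b, computation, absl::MakeSpan(operands),
              /*name=*/"foo.bar",
              /*attributes=*/"{n = 1 : i32}",
              /*version=*/1);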
diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h
index c1192bf..53683cf 100644
--- a/third_party/xla/xla/client/xla_builder.h
+++ b/third_party/xla/xla/client/xla_builder.h
@@ -731,6 +731,12 @@
XlaOp Call(const XlaComputation& computation,
absl::Span<const XlaOp> operands);
+ XlaOp CompositeCall(
+ const XlaComputation& computation, absl::Span<const XlaOp> operands,
+ const std::string& name,
+ std::optional<absl::string_view> attributes = std::nullopt,
+ std::optional<int64_t> version = std::nullopt);
+
XlaOp CustomCall(
const std::string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const std::string& opaque,
@@ -1378,6 +1384,14 @@
const std::string& outfeed_config);
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
+
+ friend XlaOp CompositeCall(XlaBuilder* builder,
+ const XlaComputation& computation,
+ absl::Span<const XlaOp> operands,
+ const std::string& name,
+ std::optional<absl::string_view> attributes,
+ std::optional<int64_t> version);
+
friend XlaOp CustomCall(
XlaBuilder* builder, const std::string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
@@ -2305,6 +2319,12 @@
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
+// Enqueues a composite call instruction onto the computation.
+XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation,
+ absl::Span<const XlaOp> operands, const std::string& name,
+ std::optional<absl::string_view> attributes = std::nullopt,
+ std::optional<int64_t> version = std::nullopt);
+
// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external code,
// and the external code is expected to produce a result of the given
diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc
index 9828d50..8ecf243 100644
--- a/third_party/xla/xla/client/xla_builder_test.cc
+++ b/third_party/xla/xla/client/xla_builder_test.cc
@@ -48,6 +48,7 @@
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/layout_util.h"
+#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_parser.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
@@ -330,6 +331,176 @@
m::Call(m::Constant(), m::Constant()))));
}
+TEST(XlaBuilderTest, CompositeCall) {
+ XlaBuilder b(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {});
+ const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+ XlaBuilder bsum(TestName());
+ Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+ TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+ std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+ Parameter(&b, 1, shape, "arg1")};
+ CompositeCall(&b, computation, absl::MakeSpan(operands),
+ /*name=*/"foo.bar",
+ /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+ /*version=*/1);
+
+ TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+ EXPECT_THAT(GetRoot(*module),
+ GmockMatch(m::Call(m::Parameter(), m::Parameter())));
+}
+
+TEST(XlaBuilderTest, CompositeCallFrontendAttributesStayLocal) {
+ XlaBuilder b(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {});
+ const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+ XlaBuilder bsum(TestName());
+ Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+ TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+ std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+ Parameter(&b, 1, shape, "arg1")};
+ CompositeCall(&b, computation, absl::MakeSpan(operands),
+ /*name=*/"foo.bar",
+ /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+ /*version=*/1);
+ Add(operands[0], operands[1]);
+
+ TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+ EXPECT_TRUE(GetRoot(*module)->frontend_attributes().map().empty());
+}
+
+TEST(XlaBuilderTest, CompositeCallMissingName) {
+ XlaBuilder b(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {});
+ const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+ XlaBuilder bsum(TestName());
+ Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+ TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+ std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+ Parameter(&b, 1, shape, "arg1")};
+ CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"",
+ /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+ /*version=*/1);
+
+ auto statusor = BuildHloModule(b);
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().message(),
+ HasSubstr("A composite call op must have frontend attributes "
+ "with key composite.name whose value is non-empty"));
+}
+
+TEST(XlaBuilderTest, CompositeCallMissingAttribute) {
+ XlaBuilder b(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {});
+ const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+ XlaBuilder bsum(TestName());
+ Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+ TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+ std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+ Parameter(&b, 1, shape, "arg1")};
+ CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"foo.bar",
+ /*attributes=*/"", /*version=*/1);
+
+ auto statusor = BuildHloModule(b);
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().message(),
+ HasSubstr(
+ "A composite call op must have frontend attributes with key "
+ "composite.attributes whose value is default: {} or non-empty"));
+}
+
+TEST(XlaBuilderTest, CompositeCallNonNegativeVersion) {
+ XlaBuilder b(TestName());
+
+ FrontendAttributes frontend_attributes = b.frontend_attributes();
+ frontend_attributes.mutable_map()->insert({"foo", "bar"});
+ b.SetFrontendAttributes(frontend_attributes);
+
+ const Shape shape = ShapeUtil::MakeShape(F32, {});
+ const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+ XlaBuilder bsum(TestName());
+ Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+ TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+ std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+ Parameter(&b, 1, shape, "arg1")};
+ CompositeCall(&b, computation, absl::MakeSpan(operands),
+ /*name=*/"foo.bar",
+ /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+ /*version=*/-1);
+
+ auto statusor = BuildHloModule(b);
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().message(),
+ HasSubstr("A composite call op must have frontend attributes "
+ "with a composite.version whose value is a "
+ "non-negative integer but got: -1"));
+}
+
+TEST(XlaBuilderTest, CompositeCallOptionalVersionAndAttribute) {
+ XlaBuilder b(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {});
+ const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+ XlaBuilder bsum(TestName());
+ Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+ TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+ std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+ Parameter(&b, 1, shape, "arg1")};
+ CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"foo.bar");
+
+ TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+ ASSERT_THAT(GetRoot(*module),
+ GmockMatch(m::Call(m::Parameter(), m::Parameter())));
+ ASSERT_TRUE(GetRoot(*module)->frontend_attributes().map().contains(
+ "composite.attributes"));
+ EXPECT_EQ(
+ GetRoot(*module)->frontend_attributes().map().at("composite.attributes"),
+ "{}");
+ EXPECT_EQ(
+ GetRoot(*module)->frontend_attributes().map().at("composite.version"),
+ "0");
+}
+
+TEST(XlaBuilderTest, CompositeCallWithExtraFrontendAttributes) {
+ XlaBuilder b(TestName());
+
+ FrontendAttributes frontend_attributes = b.frontend_attributes();
+ frontend_attributes.mutable_map()->insert({"foo", "bar"});
+ b.SetFrontendAttributes(frontend_attributes);
+
+ const Shape shape = ShapeUtil::MakeShape(F32, {});
+ const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+ XlaBuilder bsum(TestName());
+ Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+ TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+ std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+ Parameter(&b, 1, shape, "arg1")};
+ CompositeCall(&b, computation, absl::MakeSpan(operands),
+ /*name=*/"foo.bar",
+ /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+ /*version=*/1);
+
+ TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+ EXPECT_THAT(GetRoot(*module),
+ GmockMatch(m::Call(m::Parameter(), m::Parameter())));
+ ASSERT_TRUE(GetRoot(*module)->frontend_attributes().map().contains("foo"));
+ EXPECT_EQ(GetRoot(*module)->frontend_attributes().map().at("foo"), "bar");
+}
+
TEST(XlaBuilderTest, BinopHasDegenerateBroadcast) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
diff --git a/third_party/xla/xla/client/xla_computation.cc b/third_party/xla/xla/client/xla_computation.cc
index c92de63..fc55846 100644
--- a/third_party/xla/xla/client/xla_computation.cc
+++ b/third_party/xla/xla/client/xla_computation.cc
@@ -18,6 +18,9 @@
#include <memory>
#include <utility>
+#include "absl/status/statusor.h"
+#include "xla/service/hlo.pb.h"
+#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/util.h"
diff --git a/third_party/xla/xla/client/xla_computation.h b/third_party/xla/xla/client/xla_computation.h
index e21a92d..52a54aa 100644
--- a/third_party/xla/xla/client/xla_computation.h
+++ b/third_party/xla/xla/client/xla_computation.h
@@ -20,6 +20,7 @@
#include <string>
#include <utility>
+#include "absl/status/statusor.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index 79e716b..890bd25 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -237,6 +237,7 @@
opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
opts.set_xla_gpu_enable_libnvptxcompiler(false);
+ opts.set_xla_gpu_enable_libnvjitlink(false);
opts.set_xla_gpu_enable_dot_strength_reduction(true);
@@ -246,7 +247,7 @@
opts.set_xla_gpu_nccl_p2p_max_nchannels(0);
#if GOOGLE_CUDA
- opts.set_xla_gpu_mlir_emitter_level(3);
+ opts.set_xla_gpu_mlir_emitter_level(4);
#else
opts.set_xla_gpu_mlir_emitter_level(0);
#endif
@@ -281,6 +282,8 @@
opts.set_xla_enable_command_buffers_during_profiling(false);
+ opts.set_xla_gpu_cudnn_gemm_max_plans(5);
+
return opts;
}
@@ -1834,6 +1837,12 @@
"Experimental: Enable command buffers while a profiling active. "
"By default, enabling profiling switches from command buffers to "
"op-by-op mode."));
+ flag_list->push_back(tsl::Flag(
+ "xla_gpu_cudnn_gemm_max_plans",
+ int32_setter_for(&DebugOptions::set_xla_gpu_cudnn_gemm_max_plans),
+ debug_options->xla_gpu_cudnn_gemm_max_plans(),
+ "Limit for the number of kernel configurations (plans) to use during "
+ "autotuning of cuDNN GEMM fusions."));
} // NOLINT(readability/fn_size)
// Allocates flag_values and flag_objects; this function must not be called more
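For reference (not part of the patch), the new limit is adjustable like any other XLA debug option, typically by passing --xla_gpu_cudnn_gemm_max_plans through the XLA_FLAGS environment variable; a minimal C++ sketch using the setter wired up above (the value 3 is arbitrary):

// Sketch: cap cuDNN GEMM fusion autotuning at 3 candidate plans.
xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
opts.set_xla_gpu_cudnn_gemm_max_plans(3);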
diff --git a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc
index 897a1e9..49a99ee 100644
--- a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc
+++ b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc
@@ -62,10 +62,8 @@
// PlatformUtil::GetPlatform("CUDA"));
TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
PlatformUtil::GetPlatform("cpu"));
- se::StreamExecutorConfig config;
- config.ordinal = 0;
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor,
- platform->GetExecutor(config));
+ platform->ExecutorForDevice(/*ordinal=*/0));
// LocalDeviceState and PjRtStreamExecutorDevice describes the state of a
// device which can do computation or transfer buffers. This could represent a
diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD
index f01ddef..2e9deaa 100644
--- a/third_party/xla/xla/ffi/BUILD
+++ b/third_party/xla/xla/ffi/BUILD
@@ -118,8 +118,10 @@
":api",
":execution_context",
":execution_state",
+ "//xla:executable_run_options",
"//xla:shape_util",
"//xla:types",
+ "//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi/api:c_api",
"//xla/ffi/api:c_api_internal",
@@ -148,6 +150,7 @@
":execution_context",
":execution_state",
":type_id_registry",
+ "//xla:executable_run_options",
"//xla:util",
"//xla/ffi/api:c_api",
"//xla/ffi/api:c_api_internal",
diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h
index 7675c3a..dccbcc6 100644
--- a/third_party/xla/xla/ffi/api/api.h
+++ b/third_party/xla/xla/ffi/api/api.h
@@ -33,7 +33,6 @@
#include <tuple>
#include <type_traits>
#include <utility>
-#include <variant>
#include <vector>
// This is a header-only base C++ library that defines templates for decoding
@@ -232,16 +231,17 @@
template <typename... Args>
static std::string StrCat(Args... args);
- static inline XLA_FFI_Error* MakeError(const XLA_FFI_Api* api,
- XLA_FFI_Error_Code errc,
- std::string message);
+ static XLA_FFI_Error* Success();
- static inline XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api,
- std::string message);
+ static XLA_FFI_Error* MakeError(const XLA_FFI_Api* api,
+ XLA_FFI_Error_Code errc, std::string message);
- static inline XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api,
- std::string_view struct_name,
- size_t expected, size_t actual);
+ static XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api,
+ std::string message);
+
+ static XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api,
+ std::string_view struct_name,
+ size_t expected, size_t actual);
};
XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api,
@@ -266,8 +266,11 @@
return ss.str();
}
-XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, XLA_FFI_Error_Code errc,
- std::string message) {
+inline XLA_FFI_Error* Ffi::Success() { return nullptr; }
+
+inline XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api,
+ XLA_FFI_Error_Code errc,
+ std::string message) {
XLA_FFI_Error_Create_Args args;
args.struct_size = XLA_FFI_Error_Create_Args_STRUCT_SIZE;
args.priv = nullptr;
@@ -276,15 +279,15 @@
return api->XLA_FFI_Error_Create(&args);
}
-XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api,
- std::string message) {
+inline XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api,
+ std::string message) {
return MakeError(api, XLA_FFI_Error_Code_INVALID_ARGUMENT,
std::move(message));
}
-XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api,
- std::string_view struct_name,
- size_t expected, size_t actual) {
+inline XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api,
+ std::string_view struct_name,
+ size_t expected, size_t actual) {
if (expected != actual) {
return InvalidArgument(
api, StrCat("Unexpected ", struct_name, " size: expected ", expected,
@@ -306,12 +309,13 @@
// parameter packs. We need this to be able to pattern match FFI handler
// signature at compile time.
+// A type tag for decoding an optional argument.
+template <typename T>
+struct OptionalArgTag {};
+
// A type tag to forward all remaining args as `RemainingArgs`.
struct RemainingArgsTag {};
-// A type tag to forward all remaining results as `RemainingRets`.
-struct RemainingRetsTag {};
-
// A type tag to distinguish parameters tied to results in the `Binding`
// variadic template. In XLA FFI we use destination passing style APIs and don't
// return anything from the handler, but instead pass a destination where the
@@ -319,6 +323,13 @@
template <typename T>
struct RetTag {};
+// A type tag for decoding an optional result.
+template <typename T>
+struct OptionalRetTag {};
+
+// A type tag to forward all remaining results as `RemainingRets`.
+struct RemainingRetsTag {};
+
// A type tag to distinguish parameters tied to the attributes in the
// `Binding` variadic template.
template <typename T>
@@ -357,12 +368,30 @@
//----------------------------------------------------------------------------//
-// Checks if remaining arguments are in the parameter pack.
+template <typename T>
+struct IsOptionalArgTag : std::false_type {};
+template <typename T>
+struct IsOptionalArgTag<OptionalArgTag<T>> : std::true_type {};
+
+template <typename T>
+struct IsOptionalRetTag : std::false_type {};
+template <typename T>
+struct IsOptionalRetTag<OptionalRetTag<T>> : std::true_type {};
+
+// Checks if parameter pack has an optional argument.
+template <typename... Ts>
+using HasOptionalArgTag = std::disjunction<IsOptionalArgTag<Ts>...>;
+
+// Checks if parameter pack has remaining arguments.
template <typename... Ts>
using HasRemainingArgsTag =
std::disjunction<std::is_same<RemainingArgsTag, Ts>...>;
-// Checks if remaining results are in the parameter pack.
+// Checks if parameter pack has an optional result.
+template <typename... Ts>
+using HasOptionalRetTag = std::disjunction<IsOptionalRetTag<Ts>...>;
+
+// Checks if parameter pack has remaining results.
template <typename... Ts>
using HasRemainingRetsTag =
std::disjunction<std::is_same<RemainingRetsTag, Ts>...>;
@@ -413,11 +442,34 @@
public:
template <typename T>
Binding<stage, Ts..., T> Arg() && {
+ static_assert(!internal::HasOptionalArgTag<Ts...>::value,
+ "argument can't be passed after optional argument");
+ static_assert(!internal::HasRemainingArgsTag<Ts...>::value,
+ "argument can't be passed after remaining arguments");
return {std::move(*this)};
}
template <typename T>
Binding<stage, Ts..., internal::RetTag<T>> Ret() && {
+ static_assert(!internal::HasOptionalRetTag<Ts...>::value,
+ "result can't be passed after optional result");
+ static_assert(!internal::HasRemainingRetsTag<Ts...>::value,
+ "result can't be passed after remaining results");
+ return {std::move(*this)};
+ }
+
+ template <typename T>
+ Binding<stage, Ts..., internal::OptionalArgTag<T>> OptionalArg() && {
+ static_assert(
+ !internal::HasRemainingArgsTag<Ts...>::value,
+ "optional argument can't be passed after remaining arguments");
+ return {std::move(*this)};
+ }
+
+ template <typename T>
+ Binding<stage, Ts..., internal::OptionalRetTag<T>> OptionalRet() && {
+ static_assert(!internal::HasRemainingRetsTag<Ts...>::value,
+ "optional result can't be passed after remaining results");
return {std::move(*this)};
}
@@ -427,7 +479,7 @@
return {std::move(*this)};
}
- Binding<stage, Ts..., internal::RemainingRetsTag> RemainingResults() && {
+ Binding<stage, Ts..., internal::RemainingRetsTag> RemainingRets() && {
static_assert(!internal::HasRemainingRetsTag<Ts...>::value,
"remaining results can be passed just once");
return {std::move(*this)};
@@ -900,10 +952,20 @@
}
};
-} // namespace internal
+template <typename T>
+struct Decode<OptionalArgTag<T>> {
+ static std::optional<std::optional<T>> call(DecodingOffsets& offsets,
+ DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ if (offsets.args >= ctx.call_frame->args.size) {
+ return std::optional<T>(std::nullopt);
+ }
+ return Decode<T>::call(offsets, ctx, diagnostic);
+ }
+};
template <typename T>
-struct internal::Decode<internal::RetTag<T>> {
+struct Decode<RetTag<T>> {
static std::optional<Result<T>> call(DecodingOffsets& offsets,
DecodingContext& ctx,
DiagnosticEngine& diagnostic) {
@@ -914,7 +976,19 @@
};
template <typename T>
-struct internal::Decode<internal::AttrTag<T>> {
+struct Decode<OptionalRetTag<T>> {
+ static std::optional<std::optional<Result<T>>> call(
+ DecodingOffsets& offsets, DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ if (offsets.rets >= ctx.call_frame->rets.size) {
+ return std::optional<Result<T>>(std::nullopt);
+ }
+ return Decode<RetTag<T>>::call(offsets, ctx, diagnostic);
+ }
+};
+
+template <typename T>
+struct Decode<AttrTag<T>> {
using R = typename AttrDecoding<T>::Type;
static std::optional<R> call(DecodingOffsets& offsets, DecodingContext& ctx,
@@ -946,7 +1020,7 @@
};
template <typename T>
-struct internal::Decode<internal::CtxTag<T>> {
+struct Decode<CtxTag<T>> {
using R = typename CtxDecoding<T>::Type;
static std::optional<R> call(DecodingOffsets& offsets, DecodingContext& ctx,
@@ -956,75 +1030,17 @@
}
};
-//===----------------------------------------------------------------------===//
-// Expected
-//===----------------------------------------------------------------------===//
-
-// Forward declare.
-template <typename E>
-class Unexpected;
-
-// TODO(slebedev): Replace with `std::expected` when C++23 is available.
-template <typename T, typename E>
-class Expected {
- public:
- constexpr Expected(T value) : data_(std::move(value)) {} // NOLINT
- constexpr Expected(Unexpected<E> u); // NOLINT
-
- constexpr operator bool() const { // NOLINT
- return has_value();
- }
-
- constexpr T& operator*() & { return value(); }
- constexpr const T& operator*() const& { return value(); }
- constexpr T&& operator*() && { return std::move(value()); }
- constexpr const T& operator*() const&& { return std::move(value()); }
-
- constexpr T* operator->() { return &value(); }
- constexpr const T* operator->() const { return &value(); }
-
- constexpr bool has_value() const { return std::holds_alternative<T>(data_); }
- constexpr bool has_error() const { return std::holds_alternative<E>(data_); }
-
- constexpr T& value() & { return std::get<T>(data_); }
- constexpr const T& value() const& { return std::get<T>(data_); }
- constexpr T&& value() && { return std::get<T>(std::move(data_)); }
- constexpr const T& value() const&& { return std::get<T>(std::move(data_)); }
-
- constexpr E& error() & { return std::get<E>(data_); }
- constexpr const E& error() const& { return std::get<E>(data_); }
- constexpr E&& error() && { return std::get<E>(std::move(data_)); }
- constexpr const E&& error() const&& { return std::get<E>(std::move(data_)); }
-
- private:
- std::variant<T, E> data_;
-};
-
-template <typename E>
-class Unexpected {
- public:
- constexpr Unexpected(E error) : error_(std::move(error)) {} // NOLINT
-
- private:
- template <typename, typename>
- friend class Expected;
-
- E error_;
-};
-
-Unexpected(const char*) -> Unexpected<std::string>;
-
-template <typename T, typename E>
-constexpr Expected<T, E>::Expected(Unexpected<E> u)
- : data_(std::move(u.error_)) {}
+} // namespace internal
//===----------------------------------------------------------------------===//
// Type-safe wrapper for accessing a variable number of arguments.
//===----------------------------------------------------------------------===//
-class RemainingArgs {
+namespace internal {
+
+class RemainingArgsBase {
public:
- RemainingArgs(const XLA_FFI_Args* args, size_t offset)
+ RemainingArgsBase(const XLA_FFI_Args* args, size_t offset)
: args_(args), offset_(offset) {
assert(offset <= args_->size && "illegal remaining args offset");
}
@@ -1032,43 +1048,26 @@
size_t size() const { return args_->size - offset_; }
bool empty() const { return size() == 0; }
- template <typename T>
- Expected<T, std::string> get(size_t index) const {
- size_t idx = offset_ + index;
- if (idx >= args_->size) {
- return Unexpected("Index out of range.");
- }
-
- DiagnosticEngine diagnostic;
- auto value_opt =
- ArgDecoding<T>::Decode(args_->types[idx], args_->args[idx], diagnostic);
- if (!value_opt.has_value()) {
- return Unexpected(diagnostic.Result());
- }
- return *value_opt;
- }
+ protected:
+ const XLA_FFI_Args* args() const { return args_; }
+ size_t offset() const { return offset_; }
private:
- const XLA_FFI_Args* args_; // not owned
+ const XLA_FFI_Args* args_;
size_t offset_;
};
-template <>
-struct internal::Decode<internal::RemainingArgsTag> {
- static std::optional<RemainingArgs> call(DecodingOffsets& offsets,
- DecodingContext& ctx,
- DiagnosticEngine& diagnostic) {
- return RemainingArgs(&ctx.call_frame->args, offsets.args);
- }
-};
+} // namespace internal
//===----------------------------------------------------------------------===//
// Type-safe wrapper for accessing a variable number of results.
//===----------------------------------------------------------------------===//
-class RemainingResults {
+namespace internal {
+
+class RemainingRetsBase {
public:
- RemainingResults(const XLA_FFI_Rets* rets, size_t offset)
+ RemainingRetsBase(const XLA_FFI_Rets* rets, size_t offset)
: rets_(rets), offset_(offset) {
assert(offset <= rets_->size && "illegal remaining rets offset");
}
@@ -1076,43 +1075,30 @@
size_t size() const { return rets_->size - offset_; }
bool empty() const { return size() == 0; }
- template <typename T>
- Expected<T, std::string> get(size_t index) const {
- size_t idx = offset_ + index;
- if (idx >= rets_->size) {
- return Unexpected("Index out of range.");
- }
-
- DiagnosticEngine diagnostic;
- auto value_opt =
- RetDecoding<T>::Decode(rets_->types[idx], rets_->rets[idx], diagnostic);
- if (!value_opt.has_value()) {
- return Unexpected(diagnostic.Result());
- }
- return **value_opt;
- }
+ protected:
+ const XLA_FFI_Rets* rets() const { return rets_; }
+ size_t offset() const { return offset_; }
private:
const XLA_FFI_Rets* rets_; // not owned
size_t offset_;
};
-template <>
-struct internal::Decode<internal::RemainingRetsTag> {
- static std::optional<RemainingResults> call(DecodingOffsets& offsets,
- DecodingContext& ctx,
- DiagnosticEngine& diagnostic) {
- return RemainingResults(&ctx.call_frame->rets, offsets.rets);
- }
-};
+} // namespace internal
//===----------------------------------------------------------------------===//
// Type-safe wrapper for accessing dictionary attributes.
//===----------------------------------------------------------------------===//
-class Dictionary {
+namespace internal {
+
+// Forward declare dictionary attribute decoding defined below.
+template <typename T, typename... Ts>
+struct DecodeDictionaryAttr;
+
+class DictionaryBase {
public:
- explicit Dictionary(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {}
+ explicit DictionaryBase(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {}
size_t size() const { return attrs_->size; }
@@ -1120,21 +1106,15 @@
return Find(name) < attrs_->size;
}
- template <typename T>
- Expected<T, std::string> get(std::string_view name) const {
- DiagnosticEngine diagnostic;
- auto value_opt = get<T>(name, diagnostic);
- if (!value_opt.has_value()) {
- return Unexpected(diagnostic.Result());
- }
- return *value_opt;
- }
+ protected:
+ template <typename T, typename... Ts>
+ friend struct DecodeDictionaryAttr;
template <typename T>
std::optional<T> get(std::string_view name,
DiagnosticEngine& diagnostic) const {
size_t idx = Find(name);
- if (idx >= attrs_->size) {
+ if (XLA_FFI_PREDICT_FALSE(idx >= attrs_->size)) {
return diagnostic.Emit("Unexpected attribute: ") << name;
}
@@ -1161,15 +1141,11 @@
const XLA_FFI_Attrs* attrs_;
};
-// Decode `AttrsTag` into a generic `Dictionary` attribute.
-template <>
-struct internal::Decode<internal::AttrsTag<Dictionary>> {
- static std::optional<Dictionary> call(DecodingOffsets& offsets,
- DecodingContext& ctx,
- DiagnosticEngine& diagnostic) {
- return Dictionary(&ctx.call_frame->attrs);
- }
-};
+} // namespace internal
+
+//===----------------------------------------------------------------------===//
+// Decoding for aggregate attributes (decoding dictionaries into structs).
+//===----------------------------------------------------------------------===//
// Decode `AttrsTag` into a type `T` relying on struct decoding defined below.
template <typename T>
@@ -1186,6 +1162,13 @@
// Template metaprogramming for decoding handler signature
//===----------------------------------------------------------------------===//
+// Forward declare classes for decoding a variadic number of arguments and
+// results. They are defined in the `ffi.h` headers (internal and external) so
+// that the internal and external FFIs can use slightly different
+// implementations (`absl::StatusOr` vs `ffi::ErrorOr`).
+class RemainingArgs;
+class RemainingRets;
+
namespace internal {
// A helper struct to extract the type of the handler argument.
template <typename T>
@@ -1193,23 +1176,31 @@
using Type = T;
};
+template <typename T>
+struct FnArgType<internal::OptionalArgTag<T>> {
+ using Type = std::optional<T>;
+};
+
template <>
struct FnArgType<internal::RemainingArgsTag> {
using Type = RemainingArgs;
};
-template <>
-struct FnArgType<internal::RemainingRetsTag> {
- using Type = RemainingResults;
-};
-
-// Extracts the underlying type from the returned result type tag.
template <typename T>
struct FnArgType<internal::RetTag<T>> {
using Type = Result<T>;
};
-// Extracts the underlying type from the attribute type tag.
+template <typename T>
+struct FnArgType<internal::OptionalRetTag<T>> {
+ using Type = std::optional<Result<T>>;
+};
+
+template <>
+struct FnArgType<internal::RemainingRetsTag> {
+ using Type = RemainingRets;
+};
+
template <typename T>
struct FnArgType<internal::AttrTag<T>> {
using Type = typename AttrDecoding<T>::Type;
@@ -1220,7 +1211,6 @@
using Type = T;
};
-// Extracts the underlying type from the context type tag.
template <typename T>
struct FnArgType<internal::CtxTag<T>> {
using Type = typename CtxDecoding<T>::Type;
@@ -1230,20 +1220,27 @@
// a special decoding rule defined by template specialization.
template <typename>
struct IsTagged : std::false_type {};
+
+template <typename T>
+struct IsTagged<OptionalArgTag<T>> : std::true_type {};
template <typename T>
struct IsTagged<RetTag<T>> : std::true_type {};
template <typename T>
+struct IsTagged<OptionalRetTag<T>> : std::true_type {};
+template <typename T>
struct IsTagged<AttrTag<T>> : std::true_type {};
template <typename T>
struct IsTagged<AttrsTag<T>> : std::true_type {};
template <typename T>
struct IsTagged<CtxTag<T>> : std::true_type {};
+
template <>
struct IsTagged<RemainingArgsTag> : std::true_type {};
template <>
struct IsTagged<RemainingRetsTag> : std::true_type {};
-// A template for counting regular arguments in the Ts pack.
+// A template for counting regular arguments in the Ts pack (arguments that are
+// not wrapped into a special tag).
template <typename... Ts>
struct NumArgs;
@@ -1269,9 +1266,15 @@
static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value;
+ static constexpr int64_t kNumOptionalArgs =
+ internal::NumTagged<internal::OptionalArgTag, Ts...>::value;
+
static constexpr int64_t kNumRets =
internal::NumTagged<internal::RetTag, Ts...>::value;
+ static constexpr int64_t kNumOptionalRets =
+ internal::NumTagged<internal::OptionalRetTag, Ts...>::value;
+
static constexpr int64_t kNumAttrs =
internal::NumTagged<internal::AttrTag, Ts...>::value;
@@ -1292,22 +1295,22 @@
public:
XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame) const override {
// Sanity checking call frame struct size.
- if (auto* err = CheckStructSize(call_frame->api, "XLA_FFI_CallFrame",
- XLA_FFI_CallFrame_STRUCT_SIZE,
- call_frame->struct_size))
+ if (XLA_FFI_Error* err = CheckStructSize(
+ call_frame->api, "XLA_FFI_CallFrame", XLA_FFI_CallFrame_STRUCT_SIZE,
+ call_frame->struct_size)) {
return err;
+ }
// Check the API versions.
- auto api_version = call_frame->api->api_version;
+ const XLA_FFI_Api_Version& api_version = call_frame->api->api_version;
if (api_version.major_version != XLA_FFI_API_MAJOR ||
api_version.minor_version != XLA_FFI_API_MINOR) {
return InvalidArgument(
call_frame->api,
StrCat("FFI handler's API version (", XLA_FFI_API_MAJOR, ".",
- XLA_FFI_API_MINOR,
- ") does not match the framework's API version (",
- api_version.major_version, ".", api_version.minor_version,
- ")"));
+ XLA_FFI_API_MINOR, ") does not match the framework's API ",
+ "version (", api_version.major_version, ".",
+ api_version.minor_version, ")"));
}
// Check that handler is called during correct execution stage.
@@ -1321,12 +1324,21 @@
// Check that the number of passed arguments matches the signature. Each
// individual argument decoding will check the actual type.
- if (internal::HasRemainingArgsTag<Ts...>::value) {
+ if constexpr (internal::HasRemainingArgsTag<Ts...>::value) {
if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) {
return InvalidArgument(
call_frame->api,
StrCat("Wrong number of arguments: expected at least ",
- kNumArgs - 1, " but got ", call_frame->args.size));
+ kNumArgs - kNumOptionalArgs - 1, " but got ",
+ call_frame->args.size));
+ }
+ } else if constexpr (internal::HasOptionalArgTag<Ts...>::value) {
+ if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) {
+ return InvalidArgument(
+ call_frame->api,
+ StrCat("Wrong number of arguments: expected at least ",
+ kNumArgs - kNumOptionalArgs, " but got ",
+ call_frame->args.size));
}
} else {
if (XLA_FFI_PREDICT_FALSE(call_frame->args.size != kNumArgs)) {
@@ -1339,12 +1351,21 @@
// Check that the number of results matches the signature. Each individual
// result decoding will check the actual type.
- if (internal::HasRemainingRetsTag<Ts...>::value) {
+ if constexpr (internal::HasRemainingRetsTag<Ts...>::value) {
if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) {
return InvalidArgument(
call_frame->api,
- StrCat("Wrong number of results: expected at least ", kNumRets - 1,
- " but got ", call_frame->rets.size));
+ StrCat("Wrong number of results: expected at least ",
+ kNumRets - kNumOptionalRets - 1, " but got ",
+ call_frame->rets.size));
+ }
+ } else if constexpr (internal::HasOptionalRetTag<Ts...>::value) {
+ if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) {
+ return InvalidArgument(
+ call_frame->api,
+ StrCat("Wrong number of results: expected at least ",
+ kNumRets - kNumOptionalRets, " but got ",
+ call_frame->rets.size));
}
} else {
if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size != kNumRets)) {
@@ -1515,21 +1536,6 @@
}
};
-template <>
-struct AttrDecoding<Dictionary> {
- using Type = Dictionary;
- static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
- DiagnosticEngine& diagnostic) {
- if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
- return diagnostic.Emit("Wrong attribute type: expected ")
- << XLA_FFI_AttrType_DICTIONARY << " but got " << type;
- }
-
- auto* attrs = reinterpret_cast<XLA_FFI_Attrs*>(attr);
- return Dictionary(attrs);
- }
-};
-
//===----------------------------------------------------------------------===//
// Automatic dictionary attributes to structs decoding.
//===----------------------------------------------------------------------===//
@@ -1574,7 +1580,7 @@
//
// Consider using `static auto decoder = ...` below, and compute mapping in
// constructor. Add benchmarks first to know what to improve!
- Dictionary dict(attrs);
+ internal::DictionaryBase dict(attrs);
std::tuple<std::optional<Ts>...> members = {
dict.get<Ts>(names[Is], diagnostic)...};
@@ -1637,7 +1643,7 @@
// type to decode the attribute as a scalar value and cast it to the enum type.
#define XLA_FFI_REGISTER_ENUM_ATTR_DECODING(T) \
template <> \
- struct ::xla::ffi::AttrDecoding<T> { \
+ struct xla::ffi::AttrDecoding<T> { \
using Type = T; \
using U = std::underlying_type_t<Type>; \
static_assert(std::is_enum<Type>::value, "Expected enum class"); \
diff --git a/third_party/xla/xla/ffi/api/c_api_internal.h b/third_party/xla/xla/ffi/api/c_api_internal.h
index 3c5c2ba..da5ea32 100644
--- a/third_party/xla/xla/ffi/api/c_api_internal.h
+++ b/third_party/xla/xla/ffi/api/c_api_internal.h
@@ -74,6 +74,11 @@
typedef void* XLA_FFI_INTERNAL_ExecutionState_Get(
XLA_FFI_ExecutionContext* ctx);
+// Returns a pointer to the `Eigen::ThreadPoolDevice` passed via run options,
+// which allows FFI handlers to execute tasks in the same thread pool as XLA.
+typedef void* XLA_FFI_INTERNAL_IntraOpThreadPool_Get(
+ XLA_FFI_ExecutionContext* ctx);
+
//===----------------------------------------------------------------------===//
// API access
//===----------------------------------------------------------------------===//
@@ -89,6 +94,7 @@
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_CalledComputation_Get);
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_ExecutionContext_Get);
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_ExecutionState_Get);
+ _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_IntraOpThreadPool_Get);
};
#undef _XLA_FFI_INTERNAL_API_STRUCT_FIELD
diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h
index c48996b..c35fdce 100644
--- a/third_party/xla/xla/ffi/api/ffi.h
+++ b/third_party/xla/xla/ffi/api/ffi.h
@@ -36,6 +36,7 @@
#include <string_view>
#include <type_traits>
#include <utility>
+#include <variant>
#include <vector>
#include "xla/ffi/api/c_api.h"
@@ -75,6 +76,30 @@
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
};
+// Create aliases in the ::xla::ffi namespace for all DataTypes, for
+// consistency with XLA, which defines PrimitiveType enums in the ::xla
+// namespace.
+inline constexpr DataType PRED = DataType::PRED;
+inline constexpr DataType S8 = DataType::S8;
+inline constexpr DataType S16 = DataType::S16;
+inline constexpr DataType S32 = DataType::S32;
+inline constexpr DataType S64 = DataType::S64;
+inline constexpr DataType U8 = DataType::U8;
+inline constexpr DataType U16 = DataType::U16;
+inline constexpr DataType U32 = DataType::U32;
+inline constexpr DataType U64 = DataType::U64;
+inline constexpr DataType F16 = DataType::F16;
+inline constexpr DataType F32 = DataType::F32;
+inline constexpr DataType F64 = DataType::F64;
+inline constexpr DataType BF16 = DataType::BF16;
+inline constexpr DataType C64 = DataType::C64;
+inline constexpr DataType C128 = DataType::C128;
+inline constexpr DataType TOKEN = DataType::TOKEN;
+inline constexpr DataType F8E5M2 = DataType::F8E5M2;
+inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
+inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
+inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
+inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
+
inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
}
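In handler signatures the aliases drop the DataType:: qualifier; a minimal sketch matching the ffi_test.cc updates later in this change:

// Sketch: BufferR2<DataType::F32> can now be spelled BufferR2<F32>.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().To(
    [](auto buffer) { return Error::Success(); });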
@@ -149,7 +174,7 @@
};
//===----------------------------------------------------------------------===//
-// Error and ErrorOr
+// Error
//===----------------------------------------------------------------------===//
enum class ErrorCode : uint8_t {
@@ -182,19 +207,89 @@
Error(XLA_FFI_Error_Code errc, std::string message)
: Error(static_cast<ErrorCode>(errc), std::move(message)) {}
- static Error Success() { return Error(); }
-
bool success() const { return errc_ == ErrorCode::kOk; }
bool failure() const { return !success(); }
std::optional<ErrorCode> errc() const { return errc_; }
const std::string& message() const { return message_; }
+ static Error Success() { return Error(); }
+
+ static Error Internal(std::string message) {
+ return Error(ErrorCode::kInternal, std::move(message));
+ }
+
+ static Error InvalidArgument(std::string message) {
+ return Error(ErrorCode::kInvalidArgument, std::move(message));
+ }
+
private:
ErrorCode errc_ = ErrorCode::kOk;
std::string message_;
};
+//===----------------------------------------------------------------------===//
+// Expected<T, E> and ErrorOr<T>
+//===----------------------------------------------------------------------===//
+
+// Forward declare.
+template <typename E>
+class Unexpected;
+
+// TODO(slebedev): Replace with `std::expected` when C++23 is available.
+template <typename T, typename E>
+class Expected {
+ public:
+ constexpr Expected(T value) : data_(std::move(value)) {} // NOLINT
+ constexpr Expected(Unexpected<E> u); // NOLINT
+
+ constexpr operator bool() const { // NOLINT
+ return has_value();
+ }
+
+ constexpr T& operator*() & { return value(); }
+ constexpr const T& operator*() const& { return value(); }
+ constexpr T&& operator*() && { return std::move(value()); }
+ constexpr const T& operator*() const&& { return std::move(value()); }
+
+ constexpr T* operator->() { return &value(); }
+ constexpr const T* operator->() const { return &value(); }
+
+ constexpr bool has_value() const { return std::holds_alternative<T>(data_); }
+ constexpr bool has_error() const { return std::holds_alternative<E>(data_); }
+
+ constexpr T& value() & { return std::get<T>(data_); }
+ constexpr const T& value() const& { return std::get<T>(data_); }
+ constexpr T&& value() && { return std::get<T>(std::move(data_)); }
+ constexpr const T& value() const&& { return std::get<T>(std::move(data_)); }
+
+ constexpr E& error() & { return std::get<E>(data_); }
+ constexpr const E& error() const& { return std::get<E>(data_); }
+ constexpr E&& error() && { return std::get<E>(std::move(data_)); }
+ constexpr const E&& error() const&& { return std::get<E>(std::move(data_)); }
+
+ private:
+ std::variant<T, E> data_;
+};
+
+template <typename E>
+class Unexpected {
+ public:
+ constexpr Unexpected(E error) : error_(std::move(error)) {} // NOLINT
+
+ private:
+ template <typename, typename>
+ friend class Expected;
+
+ E error_;
+};
+
+Unexpected(const char*) -> Unexpected<std::string>;
+
+template <typename T, typename E>
+constexpr Expected<T, E>::Expected(Unexpected<E> u)
+ : data_(std::move(u.error_)) {}
+
template <typename T>
class ErrorOr : public Expected<T, Error> {
public:
@@ -484,6 +579,42 @@
};
//===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of arguments.
+//===----------------------------------------------------------------------===//
+
+class RemainingArgs : public internal::RemainingArgsBase {
+ public:
+ using internal::RemainingArgsBase::RemainingArgsBase;
+
+ template <typename T>
+ ErrorOr<T> get(size_t index) const {
+ size_t idx = offset() + index;
+ if (XLA_FFI_PREDICT_FALSE(idx >= args()->size)) {
+ return Unexpected(
+ Error(ErrorCode::kInvalidArgument, "Index out of range"));
+ }
+
+ DiagnosticEngine diagnostic;
+ std::optional<T> value = ArgDecoding<T>::Decode(
+ args()->types[idx], args()->args[idx], diagnostic);
+ if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
+ return Unexpected(Error::Internal(diagnostic.Result()));
+ }
+
+ return *value;
+ }
+};
+
+template <>
+struct internal::Decode<internal::RemainingArgsTag> {
+ static std::optional<RemainingArgs> call(DecodingOffsets& offsets,
+ DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ return RemainingArgs(&ctx.call_frame->args, offsets.args);
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Results decoding
//===----------------------------------------------------------------------===//
@@ -524,6 +655,42 @@
};
//===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of results.
+//===----------------------------------------------------------------------===//
+
+class RemainingRets : public internal::RemainingRetsBase {
+ public:
+ using internal::RemainingRetsBase::RemainingRetsBase;
+
+ template <typename T>
+ ErrorOr<Result<T>> get(size_t index) const {
+ size_t idx = offset() + index;
+ if (XLA_FFI_PREDICT_FALSE(idx >= rets()->size)) {
+ return Unexpected(
+ Error(ErrorCode::kInvalidArgument, "Index out of range"));
+ }
+
+ DiagnosticEngine diagnostic;
+ std::optional<Result<T>> value = RetDecoding<T>::Decode(
+ rets()->types[idx], rets()->rets[idx], diagnostic);
+ if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
+ return Unexpected(Error::Internal(diagnostic.Result()));
+ }
+
+ return *value;
+ }
+};
+
+template <>
+struct internal::Decode<internal::RemainingRetsTag> {
+ static std::optional<RemainingRets> call(DecodingOffsets& offsets,
+ DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ return RemainingRets(&ctx.call_frame->rets, offsets.rets);
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Attributes decoding
//===----------------------------------------------------------------------===//
@@ -581,6 +748,49 @@
};
//===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing dictionary attributes.
+//===----------------------------------------------------------------------===//
+
+class Dictionary : public internal::DictionaryBase {
+ public:
+ using internal::DictionaryBase::DictionaryBase;
+
+ template <typename T>
+ ErrorOr<T> get(std::string_view name) const {
+ DiagnosticEngine diagnostic;
+ std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
+ if (!value.has_value()) {
+ return Unexpected(Error::Internal(diagnostic.Result()));
+ }
+ return *value;
+ }
+};
+
+// Decode `AttrsTag` (all attributes) into a `Dictionary`.
+template <>
+struct internal::Decode<internal::AttrsTag<Dictionary>> {
+ static std::optional<Dictionary> call(DecodingOffsets& offsets,
+ DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ return Dictionary(&ctx.call_frame->attrs);
+ }
+};
+
+// Decode an individual attribute into the `Dictionary` type.
+template <>
+struct AttrDecoding<Dictionary> {
+ using Type = Dictionary;
+ static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
+ DiagnosticEngine& diagnostic) {
+ if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
+ return diagnostic.Emit("Wrong attribute type: expected ")
+ << XLA_FFI_AttrType_DICTIONARY << " but got " << type;
+ }
+ return Dictionary(reinterpret_cast<XLA_FFI_Attrs*>(attr));
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Error helpers
//===----------------------------------------------------------------------===//
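A short sketch (not part of the patch) of the Dictionary wrapper in use, condensed from the AttrsAsDictionary test added below; the attribute name "i32" is illustrative:

// Sketch: decode all attributes of the call frame into a Dictionary.
auto handler = Ffi::Bind().Attrs().To([](Dictionary dict) {
  ErrorOr<int32_t> i32 = dict.get<int32_t>("i32");
  return i32.has_value() ? Error::Success()
                         : Error::InvalidArgument("missing i32 attribute");
});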
@@ -758,6 +968,7 @@
internal::DestroyError(api_, error);
return std::nullopt;
}
+ allocations_.push_back({size, args.data});
return args.data;
}
diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc
index 8db0d46..2bbfd04 100644
--- a/third_party/xla/xla/ffi/api/ffi_test.cc
+++ b/third_party/xla/xla/ffi/api/ffi_test.cc
@@ -19,7 +19,9 @@
#include <cstdint>
#include <limits>
#include <memory>
+#include <optional>
#include <string>
+#include <string_view>
#include <type_traits>
#include <vector>
@@ -237,12 +239,11 @@
builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
auto call_frame = builder.Build();
- auto handler =
- Ffi::Bind().Arg<BufferR2<DataType::F32>>().To([&](auto buffer) {
- EXPECT_EQ(buffer.typed_data(), storage.data());
- EXPECT_EQ(buffer.dimensions().size(), 2);
- return Error::Success();
- });
+ auto handler = Ffi::Bind().Arg<BufferR2<F32>>().To([&](auto buffer) {
+ EXPECT_EQ(buffer.typed_data(), storage.data());
+ EXPECT_EQ(buffer.dimensions().size(), 2);
+ return Error::Success();
+ });
auto status = Call(*handler, call_frame);
TF_ASSERT_OK(status);
@@ -270,7 +271,7 @@
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
auto call_frame = builder.Build();
- auto handler = Ffi::Bind().Arg<BufferR1<DataType::F32>>().To(
+ auto handler = Ffi::Bind().Arg<BufferR1<F32>>().To(
[](auto) { return Error::Success(); });
auto status = Call(*handler, call_frame);
@@ -286,7 +287,7 @@
builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
auto call_frame = builder.Build();
- auto handler = Ffi::Bind().Arg<BufferR1<DataType::F32>>().To(
+ auto handler = Ffi::Bind().Arg<BufferR1<F32>>().To(
[](auto) { return Error::Success(); });
auto status = Call(*handler, call_frame);
@@ -303,7 +304,7 @@
builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2});
auto call_frame = builder.Build();
- auto handler = Ffi::Bind().Arg<BufferR2<DataType::F32>>().To(
+ auto handler = Ffi::Bind().Arg<BufferR2<F32>>().To(
[](auto) { return Error::Success(); });
auto status = Call(*handler, call_frame);
@@ -322,7 +323,7 @@
auto fn = [&](Token tok) {
EXPECT_EQ(tok.typed_data(), nullptr);
EXPECT_EQ(tok.dimensions().size(), 0);
- return ffi::Error::Success();
+ return Error::Success();
};
auto handler = Ffi::Bind().Arg<Token>().To(fn);
@@ -330,6 +331,182 @@
TF_ASSERT_OK(status);
}
+TEST(FfiTest, RemainingArgs) {
+ std::vector<float> storage(4, 0.0f);
+ se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+ CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0);
+ builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+ auto call_frame = builder.Build();
+
+ auto fn = [&](RemainingArgs args) {
+ EXPECT_EQ(args.size(), 1);
+
+ ErrorOr<AnyBuffer> arg0 = args.get<AnyBuffer>(0);
+ ErrorOr<AnyBuffer> arg1 = args.get<AnyBuffer>(1);
+
+ EXPECT_TRUE(arg0.has_value());
+ EXPECT_FALSE(arg1.has_value());
+
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().RemainingArgs().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, RemainingRets) {
+ std::vector<float> storage(4, 0.0f);
+ se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+ CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/2);
+ builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+ builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+ auto call_frame = builder.Build();
+
+ auto fn = [&](Result<AnyBuffer> ret, RemainingRets rets) {
+ EXPECT_EQ(rets.size(), 1);
+
+ ErrorOr<Result<AnyBuffer>> ret0 = rets.get<AnyBuffer>(0);
+ ErrorOr<Result<AnyBuffer>> ret1 = rets.get<AnyBuffer>(1);
+
+ EXPECT_TRUE(ret0.has_value());
+ EXPECT_FALSE(ret1.has_value());
+
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().Ret<AnyBuffer>().RemainingRets().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, OptionalArgs) {
+ std::vector<float> storage(4, 0.0f);
+ se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+ CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0);
+ builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+ auto call_frame = builder.Build();
+
+ { // Single optional argument.
+ auto fn = [&](std::optional<AnyBuffer> arg0) {
+ EXPECT_TRUE(arg0.has_value());
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Two optional arguments.
+ auto fn = [&](std::optional<AnyBuffer> arg0,
+ std::optional<AnyBuffer> arg1) {
+ EXPECT_TRUE(arg0.has_value());
+ EXPECT_FALSE(arg1.has_value());
+ return Error::Success();
+ };
+
+ auto handler =
+ Ffi::Bind().OptionalArg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Optional argument after a regular one.
+ auto fn = [&](AnyBuffer arg0, std::optional<AnyBuffer> arg1) {
+ EXPECT_FALSE(arg1.has_value());
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().Arg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Remaining arguments after optional one.
+ auto fn = [&](std::optional<AnyBuffer> arg0, RemainingArgs args) {
+ EXPECT_TRUE(arg0.has_value());
+ EXPECT_EQ(args.size(), 0);
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().RemainingArgs().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+}
+
+TEST(FfiTest, OptionalRets) {
+ std::vector<float> storage(4, 0.0f);
+ se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+ CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1);
+ builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+ auto call_frame = builder.Build();
+
+ { // Single optional result.
+ auto fn = [&](std::optional<Result<AnyBuffer>> ret0) {
+ EXPECT_TRUE(ret0.has_value());
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Two optional results.
+ auto fn = [&](std::optional<Result<AnyBuffer>> ret0,
+ std::optional<Result<AnyBuffer>> ret1) {
+ EXPECT_TRUE(ret0.has_value());
+ EXPECT_FALSE(ret1.has_value());
+ return Error::Success();
+ };
+
+ auto handler =
+ Ffi::Bind().OptionalRet<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Optional result after a regular one.
+ auto fn = [&](Result<AnyBuffer> ret0,
+ std::optional<Result<AnyBuffer>> ret1) {
+ EXPECT_FALSE(ret1.has_value());
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().Ret<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Remaining results after optional one.
+ auto fn = [&](std::optional<Result<AnyBuffer>> ret0, RemainingRets rets) {
+ EXPECT_TRUE(ret0.has_value());
+ EXPECT_EQ(rets.size(), 0);
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().RemainingRets().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+}
+
TEST(FfiTest, AutoBinding) {
static constexpr char kI32[] = "i32";
@@ -463,6 +640,150 @@
TF_ASSERT_OK(status);
}
+TEST(FfiTest, AttrsAsDictionary) {
+ CallFrameBuilder::AttributesBuilder attrs;
+ attrs.Insert("i32", 42);
+ attrs.Insert("f32", 42.0f);
+ attrs.Insert("str", "foo");
+
+ CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+ builder.AddAttributes(attrs.Build());
+ auto call_frame = builder.Build();
+
+ auto fn = [&](Dictionary dict) {
+ EXPECT_EQ(dict.size(), 3);
+
+ EXPECT_TRUE(dict.contains("i32"));
+ EXPECT_TRUE(dict.contains("f32"));
+ EXPECT_TRUE(dict.contains("str"));
+
+ ErrorOr<int32_t> i32 = dict.get<int32_t>("i32");
+ ErrorOr<float> f32 = dict.get<float>("f32");
+ ErrorOr<std::string_view> str = dict.get<std::string_view>("str");
+
+ EXPECT_TRUE(i32.has_value());
+ EXPECT_TRUE(f32.has_value());
+ EXPECT_TRUE(str.has_value());
+
+ if (i32.has_value()) EXPECT_EQ(*i32, 42);
+ if (f32.has_value()) EXPECT_EQ(*f32, 42.0f);
+ if (str.has_value()) EXPECT_EQ(*str, "foo");
+
+ EXPECT_FALSE(dict.contains("i64"));
+ EXPECT_FALSE(dict.get<int64_t>("i32").has_value());
+ EXPECT_FALSE(dict.get<int64_t>("i64").has_value());
+
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().Attrs().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, DictionaryAttr) {
+ CallFrameBuilder::FlatAttributesMap dict0;
+ dict0.try_emplace("i32", 42);
+
+ CallFrameBuilder::FlatAttributesMap dict1;
+ dict1.try_emplace("f32", 42.0f);
+
+ CallFrameBuilder::AttributesBuilder attrs;
+ attrs.Insert("dict0", dict0);
+ attrs.Insert("dict1", dict1);
+
+ CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+ builder.AddAttributes(attrs.Build());
+ auto call_frame = builder.Build();
+
+ auto fn = [&](Dictionary dict0, Dictionary dict1) {
+ EXPECT_EQ(dict0.size(), 1);
+ EXPECT_EQ(dict1.size(), 1);
+
+ EXPECT_TRUE(dict0.contains("i32"));
+ EXPECT_TRUE(dict1.contains("f32"));
+
+ ErrorOr<int32_t> i32 = dict0.get<int32_t>("i32");
+ ErrorOr<float> f32 = dict1.get<float>("f32");
+
+ EXPECT_TRUE(i32.has_value());
+ EXPECT_TRUE(f32.has_value());
+
+ if (i32.has_value()) EXPECT_EQ(*i32, 42);
+ if (f32.has_value()) EXPECT_EQ(*f32, 42.0f);
+
+ return Error::Success();
+ };
+
+ auto handler =
+ Ffi::Bind().Attr<Dictionary>("dict0").Attr<Dictionary>("dict1").To(fn);
+
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+}
+
+struct PairOfI32AndF32 {
+ int32_t i32;
+ float f32;
+};
+
+XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(PairOfI32AndF32,
+ StructMember<int32_t>("i32"),
+ StructMember<float>("f32"));
+
+TEST(FfiTest, StructAttr) {
+ CallFrameBuilder::FlatAttributesMap dict;
+ dict.try_emplace("i32", 42);
+ dict.try_emplace("f32", 42.0f);
+
+ CallFrameBuilder::AttributesBuilder attrs;
+ attrs.Insert("str", "foo");
+ attrs.Insert("i32_and_f32", dict);
+
+ CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+ builder.AddAttributes(attrs.Build());
+ auto call_frame = builder.Build();
+
+ auto fn = [&](std::string_view str, PairOfI32AndF32 i32_and_f32) {
+ EXPECT_EQ(str, "foo");
+ EXPECT_EQ(i32_and_f32.i32, 42);
+ EXPECT_EQ(i32_and_f32.f32, 42.0f);
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind()
+ .Attr<std::string_view>("str")
+ .Attr<PairOfI32AndF32>("i32_and_f32")
+ .To(fn);
+
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, AttrsAsStruct) {
+ CallFrameBuilder::AttributesBuilder attrs;
+ attrs.Insert("i32", 42);
+ attrs.Insert("f32", 42.0f);
+
+ CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+ builder.AddAttributes(attrs.Build());
+ auto call_frame = builder.Build();
+
+ auto fn = [&](PairOfI32AndF32 i32_and_f32) {
+ EXPECT_EQ(i32_and_f32.i32, 42);
+ EXPECT_EQ(i32_and_f32.f32, 42.0f);
+ return Error::Success();
+ };
+
+ auto handler = Ffi::Bind().Attrs<PairOfI32AndF32>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+}
+
TEST(FfiTest, PointerAttr) {
std::string foo = "foo";
@@ -641,14 +962,19 @@
// A test-only memory allocator that returns a fixed memory address.
struct TestDeviceMemoryAllocator final : public se::DeviceMemoryAllocator {
- TestDeviceMemoryAllocator() : se::DeviceMemoryAllocator(nullptr) {}
+ size_t count;
+
+ TestDeviceMemoryAllocator()
+ : se::DeviceMemoryAllocator(nullptr), count(0) {}
absl::StatusOr<se::OwningDeviceMemory> Allocate(int, uint64_t size, bool,
int64_t) final {
+ count++;
return se::OwningDeviceMemory(se::DeviceMemoryBase(kAddr, size), 0, this);
}
absl::Status Deallocate(int, se::DeviceMemoryBase mem) final {
+ count--;
EXPECT_EQ(mem.opaque(), kAddr);
return absl::OkStatus();
}
@@ -672,11 +998,25 @@
CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
CallOptions options;
- options.allocator = &allocator;
+ options.backend_options = CallOptions::GpuOptions{nullptr, &allocator};
auto status = Call(*handler, call_frame, options);
TF_ASSERT_OK(status);
+ EXPECT_EQ(allocator.count, 0);
+}
+
+TEST(FfiTest, ScratchAllocatorUnimplemented) {
+ auto fn = [&](ScratchAllocator scratch_allocator) {
+ auto mem = scratch_allocator.Allocate(1024);
+ EXPECT_FALSE(mem.has_value());
+ return Error::Success();
+ };
+ auto handler = Ffi::Bind().Ctx<ScratchAllocator>().To(fn);
+ CallFrame call_frame =
+ CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
+ auto status = Call(*handler, call_frame);
+ TF_ASSERT_OK(status);
}
//===----------------------------------------------------------------------===//
@@ -747,7 +1087,7 @@
void BM_BufferArgX1(benchmark::State& state) {
auto call_frame = WithBufferArgs(1).Build();
- auto handler = Ffi::Bind().Arg<BufferR4<DataType::F32>>().To([](auto buffer) {
+ auto handler = Ffi::Bind().Arg<BufferR4<F32>>().To([](auto buffer) {
benchmark::DoNotOptimize(buffer);
return Error::Success();
});
@@ -767,10 +1107,10 @@
auto call_frame = WithBufferArgs(4).Build();
auto handler = Ffi::Bind()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
.To([](auto b0, auto b1, auto b2, auto b3) {
benchmark::DoNotOptimize(b0);
benchmark::DoNotOptimize(b1);
@@ -794,14 +1134,14 @@
auto call_frame = WithBufferArgs(8).Build();
auto handler = Ffi::Bind()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
- .Arg<BufferR4<DataType::F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
+ .Arg<BufferR4<F32>>()
.To([](auto b0, auto b1, auto b2, auto b3, auto b4,
auto b5, auto b6, auto b7) {
benchmark::DoNotOptimize(b0);
diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h
index 2ace88a..82076c5 100644
--- a/third_party/xla/xla/ffi/ffi.h
+++ b/third_party/xla/xla/ffi/ffi.h
@@ -27,6 +27,8 @@
#include <limits>
#include <memory>
#include <optional>
+#include <string>
+#include <string_view>
// IWYU pragma: begin_exports
#include "xla/ffi/api/api.h"
@@ -38,6 +40,7 @@
#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
+#include "xla/executable_run_options.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep
#include "xla/ffi/execution_context.h"
@@ -49,6 +52,7 @@
#include "xla/stream_executor/scratch_allocator.h"
#include "xla/stream_executor/stream.h"
#include "xla/types.h" // IWYU pragma: keep
+#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/logging.h"
@@ -60,6 +64,7 @@
struct Allocator {}; // binds `se::DeviceMemoryAllocator*`
struct ScratchAllocator {}; // binds `se::OwningScratchAllocator`
struct CalledComputation {}; // binds `HloComputation*`
+struct IntraOpThreadPool {}; // binds `const Eigen::ThreadPoolDevice*`
//===----------------------------------------------------------------------===//
// Arguments
@@ -239,6 +244,41 @@
};
//===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of arguments.
+//===----------------------------------------------------------------------===//
+
+class RemainingArgs : public internal::RemainingArgsBase {
+ public:
+ using internal::RemainingArgsBase::RemainingArgsBase;
+
+ template <typename T>
+ absl::StatusOr<T> get(size_t index) const {
+ size_t idx = offset() + index;
+ if (ABSL_PREDICT_FALSE(idx >= args()->size)) {
+ return InvalidArgument("Index out of range.");
+ }
+
+ DiagnosticEngine diagnostic;
+ std::optional<T> value = ArgDecoding<T>::Decode(
+ args()->types[idx], args()->args[idx], diagnostic);
+ if (ABSL_PREDICT_FALSE(!value.has_value())) {
+ return Internal("%s", diagnostic.Result());
+ }
+
+ return *value;
+ }
+};
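+
+// An illustrative usage sketch (not part of this change): a handler bound
+// with `RemainingArgs` can iterate over the trailing arguments and decode
+// each one into a typed buffer, propagating decoding failures as a status.
+//
+//   auto handler = Ffi::Bind().RemainingArgs().To([](RemainingArgs args) {
+//     for (size_t i = 0; i < args.size(); ++i) {
+//       absl::StatusOr<AnyBuffer> buf = args.get<AnyBuffer>(i);
+//       if (!buf.ok()) return buf.status();
+//     }
+//     return absl::OkStatus();
+//   });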
+
+template <>
+struct internal::Decode<internal::RemainingArgsTag> {
+ static std::optional<RemainingArgs> call(DecodingOffsets& offsets,
+ DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ return RemainingArgs(&ctx.call_frame->args, offsets.args);
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Results decoding
//===----------------------------------------------------------------------===//
@@ -272,6 +312,41 @@
};
//===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of results.
+//===----------------------------------------------------------------------===//
+
+class RemainingRets : public internal::RemainingRetsBase {
+ public:
+ using internal::RemainingRetsBase::RemainingRetsBase;
+
+ template <typename T>
+ absl::StatusOr<Result<T>> get(size_t index) const {
+ size_t idx = offset() + index;
+ if (ABSL_PREDICT_FALSE(idx >= rets()->size)) {
+ return InvalidArgument("Index out of range.");
+ }
+
+ DiagnosticEngine diagnostic;
+ std::optional<Result<T>> value = RetDecoding<T>::Decode(
+ rets()->types[idx], rets()->rets[idx], diagnostic);
+ if (ABSL_PREDICT_FALSE(!value.has_value())) {
+ return Internal("%s", diagnostic.Result());
+ }
+
+ return *value;
+ }
+};
+
+template <>
+struct internal::Decode<internal::RemainingRetsTag> {
+ static std::optional<RemainingRets> call(DecodingOffsets& offsets,
+ DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ return RemainingRets(&ctx.call_frame->rets, offsets.rets);
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Attributes decoding
//===----------------------------------------------------------------------===//
@@ -330,6 +405,49 @@
};
//===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing dictionary attributes.
+//===----------------------------------------------------------------------===//
+
+class Dictionary : public internal::DictionaryBase {
+ public:
+ using internal::DictionaryBase::DictionaryBase;
+
+ template <typename T>
+ absl::StatusOr<T> get(std::string_view name) const {
+ DiagnosticEngine diagnostic;
+ std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
+ if (!value.has_value()) {
+ return Internal("%s", diagnostic.Result());
+ }
+ return *value;
+ }
+};
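+
+// An illustrative usage sketch (not part of this change): binding with
+// `Attrs()` passes all call-frame attributes to the handler as a single
+// `Dictionary`, whose entries can be looked up by name with a typed `get`
+// (the attribute name "i32" below is just an example).
+//
+//   auto handler = Ffi::Bind().Attrs().To([](Dictionary dict) {
+//     absl::StatusOr<int32_t> i32 = dict.get<int32_t>("i32");
+//     if (!i32.ok()) return i32.status();
+//     return absl::OkStatus();
+//   });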
+
+// Decode `AttrsTag` (all attributes) into a `Dictionary`.
+template <>
+struct internal::Decode<internal::AttrsTag<Dictionary>> {
+ static std::optional<Dictionary> call(DecodingOffsets& offsets,
+ DecodingContext& ctx,
+ DiagnosticEngine& diagnostic) {
+ return Dictionary(&ctx.call_frame->attrs);
+ }
+};
+
+// Decodes an individual attribute into a `Dictionary`.
+template <>
+struct AttrDecoding<Dictionary> {
+ using Type = Dictionary;
+ static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
+ DiagnosticEngine& diagnostic) {
+ if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
+ return diagnostic.Emit("Wrong attribute type: expected ")
+ << XLA_FFI_AttrType_DICTIONARY << " but got " << type;
+ }
+ return Dictionary(reinterpret_cast<XLA_FFI_Attrs*>(attr));
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Context decoding
//===----------------------------------------------------------------------===//
@@ -365,7 +483,7 @@
DiagnosticEngine&) {
void* device_allocator =
api->internal_api->XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(ctx);
- return reinterpret_cast<se::DeviceMemoryAllocator*>(device_allocator);
+ return reinterpret_cast<Type>(device_allocator);
}
};
@@ -399,6 +517,19 @@
}
};
+template <>
+struct CtxDecoding<IntraOpThreadPool> {
+ using Type = const Eigen::ThreadPoolDevice*;
+
+ static std::optional<Type> Decode(const XLA_FFI_Api* api,
+ XLA_FFI_ExecutionContext* ctx,
+ DiagnosticEngine&) {
+ void* intra_op_thread_pool =
+ api->internal_api->XLA_FFI_INTERNAL_IntraOpThreadPool_Get(ctx);
+ return reinterpret_cast<Type>(intra_op_thread_pool);
+ }
+};
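+
+// An illustrative usage sketch (not part of this change): a CPU handler can
+// request the intra-op thread pool through the `IntraOpThreadPool` context
+// tag; the pool is the one supplied via
+// `CallOptions::CpuOptions::intra_op_thread_pool`.
+//
+//   auto handler = Ffi::Bind().Ctx<IntraOpThreadPool>().To(
+//       [](const Eigen::ThreadPoolDevice* device) {
+//         // Schedule intra-op work on `device` here.
+//         return absl::OkStatus();
+//       });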
+
//===----------------------------------------------------------------------===//
// UserData
//===----------------------------------------------------------------------===//
diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc
index f402ed2..c5f07eb 100644
--- a/third_party/xla/xla/ffi/ffi_api.cc
+++ b/third_party/xla/xla/ffi/ffi_api.cc
@@ -15,12 +15,14 @@
#include "xla/ffi/ffi_api.h"
+#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <string>
#include <string_view>
#include <utility>
+#include <variant>
#include <vector>
#include "absl/base/optimization.h"
@@ -31,6 +33,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
+#include "xla/executable_run_options.h"
#include "xla/ffi/api/api.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep
@@ -56,10 +59,19 @@
};
struct XLA_FFI_ExecutionContext {
- int32_t device_ordinal = -1;
+ struct CpuContext {
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr;
+ };
- stream_executor::Stream* stream = nullptr;
- stream_executor::DeviceMemoryAllocator* allocator = nullptr;
+ struct GpuContext {
+ stream_executor::Stream* stream = nullptr;
+ stream_executor::DeviceMemoryAllocator* allocator = nullptr;
+ };
+
+ using BackendContext = std::variant<std::monostate, CpuContext, GpuContext>;
+
+ int32_t device_ordinal = -1;
+ BackendContext backend_context = {};
const xla::HloComputation* called_computation = nullptr;
const xla::ffi::ExecutionContext* execution_context = nullptr;
@@ -76,10 +88,27 @@
static XLA_FFI_ExecutionContext CreateExecutionContext(
const CallOptions& options) {
+ using BackendContext = XLA_FFI_ExecutionContext::BackendContext;
+
+ // Converts CallOptions to the corresponding backend context.
+ struct BackendVisitor {
+ BackendContext operator()(const std::monostate&) const {
+ return std::monostate{};
+ }
+
+ BackendContext operator()(const CallOptions::CpuOptions& options) const {
+ return XLA_FFI_ExecutionContext::CpuContext{options.intra_op_thread_pool};
+ }
+
+ BackendContext operator()(const CallOptions::GpuOptions& options) const {
+ return XLA_FFI_ExecutionContext::GpuContext{options.stream,
+ options.allocator};
+ }
+ };
+
return XLA_FFI_ExecutionContext{
options.device_ordinal,
- options.stream,
- options.allocator,
+ std::visit(BackendVisitor{}, options.backend_options),
options.called_computation,
internal::ScopedExecutionContext::GetCallExecutionContext(options),
options.execution_state,
@@ -376,12 +405,20 @@
"XLA_FFI_Stream_Get", XLA_FFI_Stream_Get_Args_STRUCT_SIZE,
args->struct_size));
- if (args->ctx->stream == nullptr) {
+ auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+ &args->ctx->backend_context);
+
+ if (ABSL_PREDICT_FALSE(gpu == nullptr)) {
return new XLA_FFI_Error{
- InvalidArgument("XLA FFI stream is not available")};
+ Unimplemented("XLA FFI GPU context is not available")};
}
- auto handle = args->ctx->stream->platform_specific_handle();
+ if (ABSL_PREDICT_FALSE(gpu->stream == nullptr)) {
+ return new XLA_FFI_Error{
+ Unimplemented("XLA FFI GPU stream is not available")};
+ }
+
+ auto handle = gpu->stream->platform_specific_handle();
args->stream = handle.stream;
return nullptr;
@@ -459,6 +496,22 @@
"XLA_FFI_DeviceMemory_Allocate_Args",
XLA_FFI_DeviceMemory_Allocate_Args_STRUCT_SIZE, args->struct_size));
+ auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+ &args->ctx->backend_context);
+
+ // TODO(ezhulenev): Device memory allocation should be supported for all
+ // backends, not just GPU, although for CPU it doesn't make much sense, as
+ // plain `new` is sufficient.
+ if (ABSL_PREDICT_FALSE(gpu == nullptr)) {
+ return new XLA_FFI_Error{
+ InvalidArgument("XLA FFI GPU context is not available")};
+ }
+
+ if (ABSL_PREDICT_FALSE(gpu->allocator == nullptr)) {
+ return new XLA_FFI_Error{
+ Unimplemented("No device memory allocator available on this platform")};
+ }
+
// TODO(ezhulenev): We happen to have the same alignment requirement for
// device memory on CPU and GPU backends, but instead of hardcoding it here
// we should query it for the platform the XLA FFI handler is registered with.
@@ -471,7 +524,7 @@
}
absl::StatusOr<stream_executor::OwningDeviceMemory> memory =
- args->ctx->allocator->Allocate(args->ctx->device_ordinal, args->size);
+ gpu->allocator->Allocate(args->ctx->device_ordinal, args->size);
if (!memory.ok()) {
return new XLA_FFI_Error{std::move(memory).status()};
}
@@ -486,7 +539,23 @@
"XLA_FFI_DeviceMemory_Free_Args",
XLA_FFI_DeviceMemory_Free_Args_STRUCT_SIZE, args->struct_size));
- absl::Status status = args->ctx->allocator->Deallocate(
+ auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+ &args->ctx->backend_context);
+
+ // TODO(ezhulenev): Device memory allocation should be supported for all
+ // backends, not just GPU, although for CPU it doesn't make much sense, as
+ // plain `new` is sufficient.
+ if (ABSL_PREDICT_FALSE(gpu == nullptr)) {
+ return new XLA_FFI_Error{
+ Unimplemented("XLA FFI GPU context is not available")};
+ }
+
+ if (ABSL_PREDICT_FALSE(gpu->allocator == nullptr)) {
+ return new XLA_FFI_Error{
+ Unimplemented("No device memory allocator available on this platform")};
+ }
+
+ absl::Status status = gpu->allocator->Deallocate(
args->ctx->device_ordinal,
stream_executor::DeviceMemoryBase(args->data, args->size));
if (!status.ok()) {
@@ -509,7 +578,13 @@
}
static void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx) {
- return ctx->stream;
+ if (auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+ &ctx->backend_context)) {
+ return gpu->stream;
+ }
+
+ return new XLA_FFI_Error{
+ InvalidArgument("XLA FFI GPU context is not available")};
}
static int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get(
@@ -519,7 +594,13 @@
static void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(
XLA_FFI_ExecutionContext* ctx) {
- return ctx->allocator;
+ if (auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+ &ctx->backend_context)) {
+ return gpu->allocator;
+ }
+
+ return new XLA_FFI_Error{
+ InvalidArgument("XLA FFI GPU context is not available")};
}
static void* XLA_FFI_INTERNAL_CalledComputation_Get(
@@ -537,6 +618,16 @@
return const_cast<ffi::ExecutionState*>(ctx->execution_state);
}
+void* XLA_FFI_INTERNAL_IntraOpThreadPool_Get(XLA_FFI_ExecutionContext* ctx) {
+ if (auto* cpu = std::get_if<XLA_FFI_ExecutionContext::CpuContext>(
+ &ctx->backend_context)) {
+ return const_cast<Eigen::ThreadPoolDevice*>(cpu->intra_op_thread_pool);
+ }
+
+ return new XLA_FFI_Error{
+ InvalidArgument("XLA FFI CPU context is not available")};
+}
+
//===----------------------------------------------------------------------===//
// XLA FFI Api access
//===----------------------------------------------------------------------===//
@@ -551,6 +642,7 @@
XLA_FFI_INTERNAL_CalledComputation_Get,
XLA_FFI_INTERNAL_ExecutionContext_Get,
XLA_FFI_INTERNAL_ExecutionState_Get,
+ XLA_FFI_INTERNAL_IntraOpThreadPool_Get,
};
static XLA_FFI_Api api = {
diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h
index 7a6e5aa..f583fa3 100644
--- a/third_party/xla/xla/ffi/ffi_api.h
+++ b/third_party/xla/xla/ffi/ffi_api.h
@@ -19,10 +19,12 @@
#include <cstdint>
#include <string>
#include <string_view>
+#include <variant>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "xla/executable_run_options.h"
#include "xla/ffi/api/api.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep
@@ -47,11 +49,23 @@
// Calling XLA FFI handlers
//===----------------------------------------------------------------------===//
+// Options for calling XLA FFI handlers. Backend-specific options must be
+// constructed from `xla::ExecutableRunOptions` to give FFI handlers access to
+// XLA runtime internals.
struct CallOptions {
- int32_t device_ordinal = -1;
+ struct CpuOptions {
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr;
+ };
- se::Stream* stream = nullptr;
- se::DeviceMemoryAllocator* allocator = nullptr;
+ struct GpuOptions {
+ se::Stream* stream = nullptr;
+ se::DeviceMemoryAllocator* allocator = nullptr;
+ };
+
+ using BackendOptions = std::variant<std::monostate, CpuOptions, GpuOptions>;
+
+ int32_t device_ordinal = -1;
+ BackendOptions backend_options = {};
const HloComputation* called_computation = nullptr;
const ExecutionContext* execution_context = nullptr;
diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc
index ab8d200..9fb4ff8 100644
--- a/third_party/xla/xla/ffi/ffi_test.cc
+++ b/third_party/xla/xla/ffi/ffi_test.cc
@@ -317,21 +317,21 @@
EXPECT_TRUE(dict.contains("f32"));
EXPECT_TRUE(dict.contains("str"));
- auto i32 = dict.get<int32_t>("i32");
- auto f32 = dict.get<float>("f32");
- auto str = dict.get<std::string_view>("str");
+ absl::StatusOr<int32_t> i32 = dict.get<int32_t>("i32");
+ absl::StatusOr<float> f32 = dict.get<float>("f32");
+ absl::StatusOr<std::string_view> str = dict.get<std::string_view>("str");
- EXPECT_TRUE(i32.has_value());
- EXPECT_TRUE(f32.has_value());
- EXPECT_TRUE(str.has_value());
+ EXPECT_TRUE(i32.ok());
+ EXPECT_TRUE(f32.ok());
+ EXPECT_TRUE(str.ok());
- if (i32) EXPECT_EQ(*i32, 42);
- if (f32) EXPECT_EQ(*f32, 42.0f);
- if (str) EXPECT_EQ(*str, "foo");
+ if (i32.ok()) EXPECT_EQ(*i32, 42);
+ if (f32.ok()) EXPECT_EQ(*f32, 42.0f);
+ if (str.ok()) EXPECT_EQ(*str, "foo");
EXPECT_FALSE(dict.contains("i64"));
- EXPECT_FALSE(dict.get<int64_t>("i32").has_value());
- EXPECT_FALSE(dict.get<int64_t>("i64").has_value());
+ EXPECT_FALSE(dict.get<int64_t>("i32").ok());
+ EXPECT_FALSE(dict.get<int64_t>("i64").ok());
return absl::OkStatus();
};
@@ -364,14 +364,14 @@
EXPECT_TRUE(dict0.contains("i32"));
EXPECT_TRUE(dict1.contains("f32"));
- auto i32 = dict0.get<int32_t>("i32");
- auto f32 = dict1.get<float>("f32");
+ absl::StatusOr<int32_t> i32 = dict0.get<int32_t>("i32");
+ absl::StatusOr<float> f32 = dict1.get<float>("f32");
- EXPECT_TRUE(i32.has_value());
- EXPECT_TRUE(f32.has_value());
+ EXPECT_TRUE(i32.ok());
+ EXPECT_TRUE(f32.ok());
- if (i32) EXPECT_EQ(*i32, 42);
- if (f32) EXPECT_EQ(*f32, 42.0f);
+ if (i32.ok()) EXPECT_EQ(*i32, 42);
+ if (f32.ok()) EXPECT_EQ(*f32, 42.0f);
return absl::OkStatus();
};
@@ -631,8 +631,14 @@
auto fn = [&](RemainingArgs args) {
EXPECT_EQ(args.size(), 1);
- EXPECT_TRUE(args.get<AnyBuffer>(0).has_value());
- EXPECT_FALSE(args.get<AnyBuffer>(1).has_value());
+
+ absl::StatusOr<AnyBuffer> arg0 = args.get<AnyBuffer>(0);
+ absl::StatusOr<AnyBuffer> arg1 = args.get<AnyBuffer>(1);
+
+ EXPECT_TRUE(arg0.ok());
+ EXPECT_THAT(arg1.status(), StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("Index out of range")));
+
return absl::OkStatus();
};
@@ -651,19 +657,148 @@
builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
auto call_frame = builder.Build();
- auto fn = [&](Result<AnyBuffer> ret, RemainingResults rets) {
+ auto fn = [&](Result<AnyBuffer> ret, RemainingRets rets) {
EXPECT_EQ(rets.size(), 1);
- EXPECT_TRUE(rets.get<AnyBuffer>(0).has_value());
- EXPECT_FALSE(rets.get<AnyBuffer>(1).has_value());
+
+ absl::StatusOr<Result<AnyBuffer>> ret0 = rets.get<AnyBuffer>(0);
+ absl::StatusOr<Result<AnyBuffer>> ret1 = rets.get<AnyBuffer>(1);
+
+ EXPECT_TRUE(ret0.ok());
+ EXPECT_THAT(ret1.status(), StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("Index out of range")));
+
return absl::OkStatus();
};
- auto handler = Ffi::Bind().Ret<AnyBuffer>().RemainingResults().To(fn);
+ auto handler = Ffi::Bind().Ret<AnyBuffer>().RemainingRets().To(fn);
auto status = Call(*handler, call_frame);
TF_ASSERT_OK(status);
}
+TEST(FfiTest, OptionalArgs) {
+ std::vector<float> storage(4, 0.0f);
+ se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+ CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0);
+ builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+ auto call_frame = builder.Build();
+
+ { // Single optional argument.
+ auto fn = [&](std::optional<AnyBuffer> arg0) {
+ EXPECT_TRUE(arg0.has_value());
+ return absl::OkStatus();
+ };
+
+ auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Two optional arguments.
+ auto fn = [&](std::optional<AnyBuffer> arg0,
+ std::optional<AnyBuffer> arg1) {
+ EXPECT_TRUE(arg0.has_value());
+ EXPECT_FALSE(arg1.has_value());
+ return absl::OkStatus();
+ };
+
+ auto handler =
+ Ffi::Bind().OptionalArg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Optional argument after a regular one.
+ auto fn = [&](AnyBuffer arg0, std::optional<AnyBuffer> arg1) {
+ EXPECT_FALSE(arg1.has_value());
+ return absl::OkStatus();
+ };
+
+ auto handler = Ffi::Bind().Arg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Remaining arguments after optional one.
+ auto fn = [&](std::optional<AnyBuffer> arg0, RemainingArgs args) {
+ EXPECT_TRUE(arg0.has_value());
+ EXPECT_EQ(args.size(), 0);
+ return absl::OkStatus();
+ };
+
+ auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().RemainingArgs().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+}
+
+TEST(FfiTest, OptionalRets) {
+ std::vector<float> storage(4, 0.0f);
+ se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+ CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1);
+ builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+ auto call_frame = builder.Build();
+
+ { // Single optional result.
+ auto fn = [&](std::optional<Result<AnyBuffer>> ret0) {
+ EXPECT_TRUE(ret0.has_value());
+ return absl::OkStatus();
+ };
+
+ auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Two optional results.
+ auto fn = [&](std::optional<Result<AnyBuffer>> ret0,
+ std::optional<Result<AnyBuffer>> ret1) {
+ EXPECT_TRUE(ret0.has_value());
+ EXPECT_FALSE(ret1.has_value());
+ return absl::OkStatus();
+ };
+
+ auto handler =
+ Ffi::Bind().OptionalRet<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Optional result after a regular one.
+ auto fn = [&](Result<AnyBuffer> ret0,
+ std::optional<Result<AnyBuffer>> ret1) {
+ EXPECT_FALSE(ret1.has_value());
+ return absl::OkStatus();
+ };
+
+ auto handler = Ffi::Bind().Ret<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+
+ { // Remaining results after optional one.
+ auto fn = [&](std::optional<Result<AnyBuffer>> ret0, RemainingRets rets) {
+ EXPECT_TRUE(ret0.has_value());
+ EXPECT_EQ(rets.size(), 0);
+ return absl::OkStatus();
+ };
+
+ auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().RemainingRets().To(fn);
+ auto status = Call(*handler, call_frame);
+
+ TF_ASSERT_OK(status);
+ }
+}
+
TEST(FfiTest, RunOptionsCtx) {
auto call_frame = CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
auto* expected = reinterpret_cast<se::Stream*>(0x01234567);
@@ -674,7 +809,7 @@
};
CallOptions options;
- options.stream = expected;
+ options.backend_options = CallOptions::GpuOptions{expected};
auto handler = Ffi::Bind().Ctx<Stream>().To(fn);
auto status = Call(*handler, call_frame, options);
diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
index 7610060..b91e50f 100644
--- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
+++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
@@ -1823,6 +1823,7 @@
auto generate_twiddles = [](int64_t length, bool inverse) {
std::vector<ComplexType> twiddles;
// Need only half the twiddles.
+ twiddles.reserve(length / 2);
for (int64_t k = 0; k < length / 2; k++) {
twiddles.push_back(Twiddle(k, length, inverse));
}
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
index 30a0283..acd6fff 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
@@ -345,6 +345,7 @@
],
deps = [
":auto_sharding",
+ ":auto_sharding_cost_graph",
":auto_sharding_option",
":auto_sharding_strategy",
":auto_sharding_util",
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 7fa0c4e..558c6ff 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -1989,6 +1989,7 @@
request.mutable_max_cost()->set_coeff(*max_cost);
}
for (const auto& [edge, edge_cost] : cost_graph.edge_costs_) {
+ const auto normalized_edge_cost = Normalize(edge_cost);
AutoShardingSolverRequest_Pair raw_edge;
raw_edge.set_first(edge.first);
raw_edge.set_second(edge.second);
@@ -1997,8 +1998,8 @@
AutoShardingSolverRequest_Costs mij;
for (NodeStrategyIdx i = 0; i < edge_cost.n_; i++) {
for (NodeStrategyIdx j = 0; j < edge_cost.m_; j++) {
- rij.add_costs(edge_cost(i, j).communication_cost);
- mij.add_costs(edge_cost(i, j).memory_cost);
+ rij.add_costs(normalized_edge_cost(i, j).communication_cost);
+ mij.add_costs(normalized_edge_cost(i, j).memory_cost);
}
}
request.mutable_resharding_costs()->Add(std::move(rij));
@@ -3016,7 +3017,6 @@
for (size_t i = 0; i < cur->operand_count(); ++i) {
HloInstruction* operand = cur->mutable_operand(i);
- operand = PassThroughCustomCallMarkerOperand(operand, cur);
if (!visited.contains(operand) && !IsAlwaysReplicated(operand) &&
GetShardingStrategy(operand, strategy_map, cost_graph, s_val)
@@ -3040,9 +3040,6 @@
// Propagation ends at output.
const HloInstruction* output = instructions.back();
- if (IsCustomCallMarker(output)) {
- output = output->operand(0);
- }
// A debug option: whether to do all-gather after backward pass.
// This controls the location of all-gather.
@@ -3118,8 +3115,7 @@
while (true) {
path.push_back(root);
if (root->opcode() == HloOpcode::kGetTupleElement) {
- root = PassThroughCustomCallMarkerOperand(root->mutable_operand(0),
- root);
+ root = root->mutable_operand(0);
} else {
break;
}
@@ -3215,14 +3211,6 @@
insert_all_gather.push_back(alias_map.at(to_split));
} else {
insert_all_gather.push_back(to_split);
-
- if (to_split->opcode() == HloOpcode::kGetTupleElement &&
- IsCustomCallMarker(to_split->operand(0)) &&
- to_split->users().size() == 1 &&
- to_split->users().front() == output) {
- insert_all_gather.push_back(PassThroughCustomCallMarkerOperand(
- to_split->mutable_operand(0), to_split));
- }
}
}
} else {
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
index 8512788..9d28df3 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
@@ -18,6 +18,7 @@
#include <algorithm>
#include <cstddef>
#include <cstdlib>
+#include <limits>
#include <numeric>
#include <string>
#include <utility>
@@ -34,6 +35,24 @@
namespace xla {
namespace spmd {
+EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost) {
+ double min_communication_cost = std::numeric_limits<double>::max();
+ for (int i = 0; i < edge_cost.n_; ++i) {
+ for (int j = 0; j < edge_cost.m_; ++j) {
+ min_communication_cost =
+ std::min(min_communication_cost, edge_cost(i, j).communication_cost);
+ }
+ }
+ if (min_communication_cost >= 0) return edge_cost;
+ EdgeReshardingCostMatrix normalized_edge_cost = edge_cost;
+ for (int i = 0; i < edge_cost.n_; ++i) {
+ for (int j = 0; j < edge_cost.m_; ++j) {
+ normalized_edge_cost(i, j).communication_cost -= min_communication_cost;
+ }
+ }
+ return normalized_edge_cost;
+}
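+
+// Illustrative example (not part of this change): for a 1x2 edge cost matrix
+// with communication costs {-5, 3}, the minimum is -5, so Normalize shifts
+// every entry by +5 and returns communication costs {0, 8}; memory costs are
+// left unchanged. When the minimum is already non-negative, the input matrix
+// is returned as is.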
+
CostGraph::CostGraph(const StrategyGroups& strategy_groups,
const AssociativeDotPairs& associative_dot_pairs) {
node_lens_.reserve(strategy_groups.size());
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
index fda06ee..3d6bac1 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
@@ -55,6 +55,10 @@
using EdgeReshardingCostMatrix = Matrix<EdgeReshardingCost>;
+// Normalizes the edge cost matrix by a fixed constant to ensure there are no
+// negative communication costs.
+EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost);
+
// A graph data structure to simplify the edge cost graph. It merges nodes and
// performs path compression.
class CostGraph {
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index 969fdf3..dbd161b 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -39,7 +39,6 @@
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
-#include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h"
#include "xla/hlo/experimental/auto_sharding/cluster_environment.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -57,14 +56,8 @@
namespace spmd {
namespace {
-using DimMap = StableMap</*tensor dim*/ int, /* mesh dim*/ int>;
-using MeshDims = absl::Span<const int64_t>;
-
-struct Enumeration {
- MeshDims mesh_dims;
- int64_t i;
- int64_t j;
-};
+using MeshDimSet = StableSet<int>;
+using DimMap = StableMap</*tensor dim*/ int, /*mesh dims*/ MeshDimSet>;
// Contains base functionality common to both DotHandler and ConvHandler.
class HandlerBase {
@@ -88,7 +81,6 @@
option_(option),
call_graph_(call_graph),
device_mesh_(cluster_env.device_mesh_),
- device_mesh_1d_(cluster_env.device_mesh_1d_),
lhs_(ins->operand(0)),
rhs_(ins->operand(1)) {}
@@ -102,10 +94,12 @@
HloSharding CreateInputSpec(const HloInstruction* ins, const DimMap& dim_map,
const Array<int64_t>& device_mesh) const {
if (dim_map.empty()) return HloSharding::Replicate();
- std::vector<int64_t> tensor_dims, mesh_dims;
- for (const auto& [tensor_dim, mesh_dim] : dim_map) {
+ std::vector<int64_t> tensor_dims;
+ std::vector<std::vector<int64_t>> mesh_dims;
+ for (const auto& [tensor_dim, mesh_dim_set] : dim_map) {
tensor_dims.push_back(tensor_dim);
- mesh_dims.push_back(mesh_dim);
+ mesh_dims.push_back(
+ std::vector<int64_t>(mesh_dim_set.begin(), mesh_dim_set.end()));
}
return Tile(ins->shape(), tensor_dims, mesh_dims, device_mesh);
}
@@ -116,7 +110,7 @@
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost = 0,
+ double compute_cost = 0,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn = std::nullopt);
@@ -126,7 +120,7 @@
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost = 0,
+ double compute_cost = 0,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn = std::nullopt);
@@ -137,7 +131,7 @@
virtual void AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost) {}
+ double compute_cost) {}
// Given an existing (allreduce) sharding candidate, generate a corresponding
// candidate by additionally sharding (if possible) the dot/conv output, such
@@ -146,7 +140,7 @@
virtual void AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost) {}
+ double compute_cost) {}
std::optional<HloSharding> GetShardingFromUser(const HloSharding& lhs_spec,
const HloSharding& rhs_spec);
@@ -155,32 +149,65 @@
// where a subset of all tensor dims is mapped to a subset of mesh dims, such
// that each tensor dim is mapped to at most one mesh dim, and no two tensor dims
// are mapped to the same mesh dim.
- // TODO(b/226977360): We might need to generalize this to also allow cases
- // where a tensor dim can be mapped to multiple mesh dims.
- void EnumerateGeneral(std::function<void(const DimMap&)> split_func,
- int tensor_rank, int current_tensor_dim,
- const absl::flat_hash_set<int>& unassigned_mesh_dims,
- const DimMap& current_dim_map) {
- if (current_tensor_dim == tensor_rank) {
+ void Enumerate(std::function<void(const DimMap&)> split_func, int tensor_rank,
+ int current_mesh_dim_idx,
+ const std::vector<int>& unassigned_mesh_dims,
+ const DimMap& current_dim_map) {
+ if (current_mesh_dim_idx == unassigned_mesh_dims.size()) {
split_func(current_dim_map);
return;
}
- // current_tensor_dim is unsharded
- EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1,
- unassigned_mesh_dims, current_dim_map);
- // current_tensor_dim is sharded across one of the remaining mesh dims
- for (int mesh_dim : unassigned_mesh_dims) {
+ // Current mesh dim is not assigned to any tensor dim
+ Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1,
+ unassigned_mesh_dims, current_dim_map);
+
+ for (int i = 0; i < tensor_rank; ++i) {
DimMap updated_dim_map = current_dim_map;
- updated_dim_map[current_tensor_dim] = mesh_dim;
- absl::flat_hash_set<int> updated_unassigned_mesh_dims =
- unassigned_mesh_dims;
- updated_unassigned_mesh_dims.erase(
- updated_unassigned_mesh_dims.find(mesh_dim));
- EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1,
- updated_unassigned_mesh_dims, updated_dim_map);
+ if (!updated_dim_map[i].empty() && !option_.allow_mixed_mesh_shape) {
+ continue;
+ }
+ updated_dim_map[i].insert(unassigned_mesh_dims[current_mesh_dim_idx]);
+ Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1,
+ unassigned_mesh_dims, updated_dim_map);
}
}
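+
+  // Illustrative example for Enumerate (not part of this change): for a
+  // rank-2 tensor and unassigned mesh dims {0, 1}, it visits {} (nothing
+  // sharded), {0: {0}}, {1: {0}}, {0: {1}}, {1: {1}}, {0: {0}, 1: {1}} and
+  // {0: {1}, 1: {0}}, plus maps such as {0: {0, 1}} (one tensor dim spanning
+  // both mesh dims) when option_.allow_mixed_mesh_shape is set.
+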
+ bool IsMeshDimSetNonTrivial(const MeshDimSet& mesh_dim_set) {
+ return absl::c_any_of(mesh_dim_set, [&](int mesh_dim) {
+ return device_mesh_.dim(mesh_dim) > 1;
+ });
+ }
+
+ bool IsFullyReplicatedSharding(const DimMap& dim_map,
+ const Array<int64_t>& device_mesh) {
+ if (dim_map.empty()) {
+ return true;
+ }
+ for (const auto& [_, mesh_dim_set] : dim_map) {
+ if (IsMeshDimSetNonTrivial(mesh_dim_set)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool IsFullyReplicatedStrategy(const DimMap& output_dim_map,
+ const DimMap& lhs_dim_map,
+ const DimMap& rhs_dim_map,
+ const Array<int64_t>& device_mesh) {
+ return IsFullyReplicatedSharding(output_dim_map, device_mesh) &&
+ IsFullyReplicatedSharding(lhs_dim_map, device_mesh) &&
+ IsFullyReplicatedSharding(rhs_dim_map, device_mesh);
+ }
+
+ bool IsFullySharded(const DimMap& dim_map, int num_mesh_dims) {
+ int num_mesh_dims_used = 0;
+ for (const auto& [_, mesh_dims] : dim_map) {
+ num_mesh_dims_used += mesh_dims.size();
+ }
+ return num_mesh_dims_used >= num_mesh_dims;
+ }
+
// Sorts strategies in increasing order of their memory costs. Anecdotal
// experience suggests that such a sorted list of strategies works better.
void SortStrategies();
@@ -197,7 +224,6 @@
const CallGraph& call_graph_;
const Array<int64_t>& device_mesh_;
- const Array<int64_t>& device_mesh_1d_;
const HloInstruction* lhs_;
const HloInstruction* rhs_;
};
@@ -234,12 +260,13 @@
void AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost) override;
+ double compute_cost) override;
- void AppendReduceScatterWindowedEinsumStrategy(
- const std::string& name, const DimMap& lhs_dim_map,
- const DimMap& rhs_dim_map, const DimMap& output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost) override;
+ void AppendReduceScatterWindowedEinsumStrategy(const std::string& name,
+ const DimMap& lhs_dim_map,
+ const DimMap& rhs_dim_map,
+ const DimMap& output_dim_map,
+ double compute_cost) override;
absl::Status RegisterStrategies();
@@ -324,29 +351,28 @@
}));
}
-// Given lhs and rhs dim maps, infers a sharding for the output by relying on
-// the sharding_propagation pass. Given that this is a relatively new change
-// (as of 11/2023), we also take an optional expected output dim map as an
-// argument, to verify that sharding propagation in fact infers the sharding
-// we expect (and to crash if it doesn't).
+// Given lhs and rhs dim maps, infers a sharding for the output by relying
+// on the sharding_propagation pass. Given that this is a relatively new
+// change (as of 11/2023), we also take an optional expected output dim map
+// as an argument, to verify that sharding propagation in fact infers the
+// sharding we expect (and to crash if it doesn't).
// TODO(b/309638633) As we build more confidence in this, we should remove
// this expected_output_dim_map argument and fully rely on sharding
// propagation.
void HandlerBase::MaybeAppendInternal(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
- const std::optional<DimMap>& expected_output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost,
+ const std::optional<DimMap>& expected_output_dim_map, double compute_cost,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn) {
- HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh);
- HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh);
+ HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh_);
+ HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh_);
std::optional<HloSharding> output_spec =
GetShardingFromUser(lhs_spec, rhs_spec);
if (output_spec.has_value()) {
if (expected_output_dim_map.has_value()) {
HloSharding expected_output_spec =
- CreateInputSpec(ins_, *expected_output_dim_map, device_mesh);
+ CreateInputSpec(ins_, *expected_output_dim_map, device_mesh_);
// TODO(b/308687597) Once the bug is resolved, we ideally either want
// have a CHECK statement verifying that the sharding inferred by
// sharding propagation is in fact what we expect, or we trust sharding
@@ -366,7 +392,7 @@
}
} else {
CHECK(expected_output_dim_map.has_value());
- output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh);
+ output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh_);
LOG(WARNING)
<< "Sharding propagation could not infer output sharding for:\n "
<< ins_->ToString() << "\n LHS Spec: " << lhs_spec
@@ -384,29 +410,27 @@
void HandlerBase::MaybeAppend(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
- const std::optional<DimMap>& expected_output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost,
+ const std::optional<DimMap>& expected_output_dim_map, double compute_cost,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn) {
MaybeAppendInternal(name, lhs_dim_map, rhs_dim_map, expected_output_dim_map,
- device_mesh, compute_cost, communication_cost_fn);
+ compute_cost, communication_cost_fn);
if (!option_.generate_windowed_einsum_strategies ||
!expected_output_dim_map.has_value()) {
return;
}
if (absl::StrContains(name, "allreduce")) {
CHECK(communication_cost_fn.has_value());
- AppendReduceScatterWindowedEinsumStrategy(name, lhs_dim_map, rhs_dim_map,
- *expected_output_dim_map,
- device_mesh, compute_cost);
+ AppendReduceScatterWindowedEinsumStrategy(
+ name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map, compute_cost);
} else {
CHECK(!communication_cost_fn.has_value());
AppendAllGatherWindowedEinsumStrategyForOperand(
0, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
- device_mesh, compute_cost);
+ compute_cost);
AppendAllGatherWindowedEinsumStrategyForOperand(
1, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
- device_mesh, compute_cost);
+ compute_cost);
}
}
@@ -437,14 +461,15 @@
}
void HandlerBase::SortStrategies() {
- absl::c_sort(strategy_group_->strategies,
- [](const ShardingStrategy& s1, const ShardingStrategy& s2) {
- if (s1.memory_cost == s2.memory_cost) {
- return s1.name < s2.name;
- } else {
- return s1.memory_cost < s2.memory_cost;
- }
- });
+ absl::c_stable_sort(
+ strategy_group_->strategies,
+ [](const ShardingStrategy& s1, const ShardingStrategy& s2) {
+ if (s1.memory_cost == s2.memory_cost) {
+ return s1.name < s2.name;
+ } else {
+ return s1.memory_cost < s2.memory_cost;
+ }
+ });
}
/************** DotHandler function definitions **************/
@@ -520,6 +545,15 @@
}
}
+std::string ToString(const MeshDimSet& set) { return absl::StrJoin(set, "-"); }
+std::string ToString(const DimMap& map) {
+ std::vector<std::string> strings;
+ for (const auto& [tdim, mdims] : map) {
+ strings.push_back(absl::StrCat("[", tdim, ": ", ToString(mdims), "]"));
+ }
+ return absl::StrJoin(strings, ", ");
+}
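+
+// Illustrative example (not part of this change): for a DimMap that maps
+// tensor dim 0 to mesh dims {0, 1} and tensor dim 2 to mesh dim {1}, the
+// overloads above produce "[0: 0-1], [2: 1]".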
+
std::string DotHandler::GenerateNameForDotSharding(const DimMap& output_dim_map,
const DimMap& lhs_dim_map) {
std::string name;
@@ -529,12 +563,12 @@
absl::string_view identifier) {
for (size_t i = 0; i < out_dims.size(); ++i) {
int output_batch_dim = out_dims[i];
- int mesh_dim = -1;
+ MeshDimSet mesh_dim_set;
auto it = dim_map.find(output_batch_dim);
- if (it != dim_map.end() && it->second >= 0) {
- mesh_dim = it->second;
+ if (it != dim_map.end() && !it->second.empty()) {
+ mesh_dim_set = it->second;
}
- absl::StrAppend(&name, identifier, mesh_dim);
+ absl::StrAppend(&name, identifier, ToString(mesh_dim_set));
}
};
@@ -554,9 +588,9 @@
bool contraction_dim_sharded = false;
for (size_t i = 0; i < lhs_con_dims_.size(); ++i) {
if (auto it = lhs_dim_map.find(lhs_con_dims_[i]);
- it != lhs_dim_map.end() && it->second >= 0) {
+ it != lhs_dim_map.end() && !it->second.empty()) {
contraction_dim_sharded =
- contraction_dim_sharded || (device_mesh_.dim(it->second) > 1);
+ contraction_dim_sharded || IsMeshDimSetNonTrivial(it->second);
}
}
@@ -566,34 +600,17 @@
return name;
}
-bool IsFullyReplicatedSharding(const DimMap& dim_map,
- const Array<int64_t>& device_mesh) {
- if (dim_map.empty()) {
- return true;
- }
- for (const auto& [_, mesh_dim] : dim_map) {
- if (device_mesh.dim(mesh_dim) > 1) {
- return false;
- }
- }
- return true;
-}
-
-bool IsFullyReplicatedStrategy(const DimMap& output_dim_map,
- const DimMap& lhs_dim_map,
- const DimMap& rhs_dim_map,
- const Array<int64_t>& device_mesh) {
- return IsFullyReplicatedSharding(output_dim_map, device_mesh) &&
- IsFullyReplicatedSharding(lhs_dim_map, device_mesh) &&
- IsFullyReplicatedSharding(rhs_dim_map, device_mesh);
-}
-
-bool IsFullySharded(const DimMap& dim_map, int num_mesh_dims) {
- return dim_map.size() >= num_mesh_dims;
-}
-
void DotHandler::GenerateDotShardingStrategiesFromOutputSharding(
const DimMap& output_dim_map) {
+ // This early return is added to ensure parity with the older strategy
+ // generation code. Removing it will only increase the search space.
+ for (const auto& [_, mesh_dims] : output_dim_map) {
+ if (mesh_dims.size() > 1 &&
+ mesh_dims.size() != device_mesh_.num_dimensions()) {
+ return;
+ }
+ }
+
DimMap lhs_dim_map, rhs_dim_map;
absl::flat_hash_set<int> used_mesh_dims;
@@ -603,11 +620,11 @@
int lhs_batch_dim = lhs_batch_dims_[i];
int rhs_batch_dim = rhs_batch_dims_[i];
auto it = output_dim_map.find(output_batch_dim);
- if (it != output_dim_map.end() && it->second >= 0) {
- int mesh_dim = it->second;
- used_mesh_dims.insert(mesh_dim);
- lhs_dim_map[lhs_batch_dim] = mesh_dim;
- rhs_dim_map[rhs_batch_dim] = mesh_dim;
+ if (it != output_dim_map.end() && !it->second.empty()) {
+ const StableSet<int>& mesh_dim_set = it->second;
+ used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+ lhs_dim_map[lhs_batch_dim] = mesh_dim_set;
+ rhs_dim_map[rhs_batch_dim] = mesh_dim_set;
}
}
@@ -617,10 +634,10 @@
int lhs_space_dim = lhs_space_dims_[i];
int output_space_dim = out_lhs_space_dims_[i];
auto it = output_dim_map.find(output_space_dim);
- if (it != output_dim_map.end() && it->second >= 0) {
- int mesh_dim = it->second;
- used_mesh_dims.insert(mesh_dim);
- lhs_dim_map[lhs_space_dim] = mesh_dim;
+ if (it != output_dim_map.end() && !it->second.empty()) {
+ const StableSet<int>& mesh_dim_set = it->second;
+ used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+ lhs_dim_map[lhs_space_dim] = mesh_dim_set;
}
}
@@ -629,10 +646,10 @@
int rhs_space_dim = rhs_space_dims_[i];
int output_space_dim = out_rhs_space_dims_[i];
auto it = output_dim_map.find(output_space_dim);
- if (it != output_dim_map.end() && it->second >= 0) {
- int mesh_dim = it->second;
- used_mesh_dims.insert(mesh_dim);
- rhs_dim_map[rhs_space_dim] = mesh_dim;
+ if (it != output_dim_map.end() && !it->second.empty()) {
+ const MeshDimSet& mesh_dim_set = it->second;
+ used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+ rhs_dim_map[rhs_space_dim] = mesh_dim_set;
}
}
@@ -646,7 +663,7 @@
// generation code. Removing it will only increase the search space.
IsFullySharded(output_dim_map, device_mesh_.num_dimensions())) {
MaybeAppend(GenerateNameForDotSharding(output_dim_map, lhs_dim_map),
- lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_);
+ lhs_dim_map, rhs_dim_map, output_dim_map);
}
// Generate shardings for contraction dimensions
@@ -654,10 +671,10 @@
return;
}
- absl::flat_hash_set<int> unused_mesh_dims;
+ std::vector<int> unused_mesh_dims;
for (size_t i = 0; i < device_mesh_.num_dimensions(); ++i) {
if (!used_mesh_dims.contains(i) && device_mesh_.dim(i) > 1) {
- unused_mesh_dims.insert(i);
+ unused_mesh_dims.push_back(i);
}
}
@@ -675,11 +692,11 @@
DimMap lhs_dim_map_with_contractions = lhs_dim_map;
DimMap rhs_dim_map_with_contractions = rhs_dim_map;
- for (const auto& [reducton_dim_index, mesh_dim] : reduction_dim_map) {
+ for (const auto& [reduction_dim_index, mesh_dim_set] : reduction_dim_map) {
lhs_dim_map_with_contractions
- [lhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim;
+ [lhs_con_dims_[reduction_dims[reduction_dim_index]]] = mesh_dim_set;
rhs_dim_map_with_contractions
- [rhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim;
+ [rhs_con_dims_[reduction_dims[reduction_dim_index]]] = mesh_dim_set;
}
// Skip the fully replicated strategy here as we add that outside of
// HandleDot in auto_sharding_strategy.
@@ -696,8 +713,10 @@
double memory_cost =
ByteSizeOfShapeWithSharding(ins_->shape(), output_sharding);
double total_cost = 0;
- for (const auto& [_, mesh_dim] : reduction_dim_map) {
- total_cost += cluster_env_.AllReduceCost(memory_cost, mesh_dim);
+ for (const auto& [_, mesh_dim_set] : reduction_dim_map) {
+ for (int mesh_dim : mesh_dim_set) {
+ total_cost += cluster_env_.AllReduceCost(memory_cost, mesh_dim);
+ }
}
return total_cost;
};
@@ -705,24 +724,24 @@
MaybeAppend(GenerateNameForDotSharding(output_dim_map,
lhs_dim_map_with_contractions),
lhs_dim_map_with_contractions, rhs_dim_map_with_contractions,
- output_dim_map, device_mesh_,
+ output_dim_map,
/*compute_cost=*/0, communication_cost_fn);
};
- EnumerateGeneral(split_func, reduction_dims.size(),
- /*current_tensor_dim=*/0, unused_mesh_dims,
- /*current_dim_map=*/{});
+ Enumerate(split_func, reduction_dims.size(),
+ /*current_mesh_dim_idx=*/0, unused_mesh_dims,
+ /*current_dim_map=*/{});
}
void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost) {
+ double compute_cost) {
const HloInstruction* operand = ins_->operand(operand_num);
const DimMap& operand_dim_map = operand_num == 0 ? lhs_dim_map : rhs_dim_map;
absl::flat_hash_set<int64_t> used_mesh_dims;
- for (const auto [tensor_dim, mesh_dim] : operand_dim_map) {
- used_mesh_dims.insert(mesh_dim);
+ for (const auto& [tensor_dim, mesh_dim_set] : operand_dim_map) {
+ used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
}
if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
used_mesh_dims.size() == operand->shape().rank()) {
@@ -732,16 +751,16 @@
for (int64_t tensor_dim = 0; tensor_dim < operand->shape().rank();
++tensor_dim) {
if (auto it = operand_dim_map.find(tensor_dim);
- it != operand_dim_map.end() && device_mesh.dim(it->second) > 1) {
+ it != operand_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
continue;
}
- for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
+ for (int mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
++mesh_dim) {
if (used_mesh_dims.contains(mesh_dim)) {
continue;
}
DimMap further_sharded_dim_map = operand_dim_map;
- further_sharded_dim_map[tensor_dim] = mesh_dim;
+ further_sharded_dim_map[tensor_dim] = MeshDimSet{mesh_dim};
auto communication_cost_fn =
[](const HloSharding& output_sharding) -> double {
@@ -756,7 +775,7 @@
updated_name,
operand_num == 0 ? further_sharded_dim_map : lhs_dim_map,
operand_num == 1 ? further_sharded_dim_map : rhs_dim_map,
- output_dim_map, device_mesh, compute_cost, communication_cost_fn);
+ output_dim_map, compute_cost, communication_cost_fn);
}
}
}
@@ -764,10 +783,10 @@
void DotHandler::AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
- const Array<int64_t>& device_mesh, double compute_cost) {
+ double compute_cost) {
absl::flat_hash_set<int64_t> used_mesh_dims;
- for (const auto [tensor_dim, mesh_dim] : output_dim_map) {
- used_mesh_dims.insert(mesh_dim);
+ for (const auto& [tensor_dim, mesh_dim_set] : output_dim_map) {
+ used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
}
if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
@@ -778,16 +797,16 @@
for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().rank();
++tensor_dim) {
if (auto it = output_dim_map.find(tensor_dim);
- it != output_dim_map.end() && device_mesh.dim(it->second) > 1) {
+ it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
continue;
}
- for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
+ for (int mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
++mesh_dim) {
if (used_mesh_dims.contains(mesh_dim)) {
continue;
}
DimMap further_sharded_dim_map = output_dim_map;
- further_sharded_dim_map[tensor_dim] = mesh_dim;
+ further_sharded_dim_map[tensor_dim] = MeshDimSet{mesh_dim};
auto communication_cost_fn =
[](const HloSharding& output_sharding) -> double {
@@ -799,23 +818,21 @@
name,
absl::StrFormat("|rs_windowed_einsum_t%dm%d", tensor_dim, mesh_dim));
MaybeAppendInternal(updated_name, lhs_dim_map, rhs_dim_map,
- further_sharded_dim_map, device_mesh, compute_cost,
+ further_sharded_dim_map, compute_cost,
communication_cost_fn);
}
}
}
absl::Status DotHandler::RegisterStrategies() {
- absl::flat_hash_set<int> all_mesh_dims;
- for (int i = 0; i < device_mesh_.num_dimensions(); ++i) {
- all_mesh_dims.insert(i);
- }
- EnumerateGeneral(
+ std::vector<int> all_mesh_dims(device_mesh_.num_dimensions());
+ std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0);
+ Enumerate(
/*split_func=*/
[&](const DimMap& output_dim_map) {
GenerateDotShardingStrategiesFromOutputSharding(output_dim_map);
},
- ins_->shape().rank(), /*current_tensor_dim=*/0, all_mesh_dims,
+ ins_->shape().rank(), /*current_mesh_dim_idx=*/0, all_mesh_dims,
/*current_dim_map=*/{});
SortStrategies();
return absl::OkStatus();
@@ -853,27 +870,27 @@
// Propagate batch dim sharding
auto it = output_dim_map.find(out_batch_dim_);
- if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) {
- int mesh_dim = it->second;
- lhs_dim_map[lhs_batch_dim_] = mesh_dim;
- used_mesh_dims.insert(mesh_dim);
- absl::StrAppend(&name, "b", mesh_dim);
+ if (it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
+ const MeshDimSet& mesh_dim_set = it->second;
+ lhs_dim_map[lhs_batch_dim_] = mesh_dim_set;
+ used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+ absl::StrAppend(&name, "b", ToString(mesh_dim_set));
} else {
absl::StrAppend(&name, "b-1");
}
// Propagate out channel dim sharding
it = output_dim_map.find(out_out_channel_dim_);
- if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) {
- int mesh_dim = it->second;
- lhs_dim_map[rhs_out_channel_dim_] = mesh_dim;
- used_mesh_dims.insert(mesh_dim);
- absl::StrAppend(&name, "oc", mesh_dim);
+ if (it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
+ const MeshDimSet& mesh_dim_set = it->second;
+ lhs_dim_map[rhs_out_channel_dim_] = mesh_dim_set;
+ used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+ absl::StrAppend(&name, "oc", ToString(mesh_dim_set));
} else {
absl::StrAppend(&name, "oc-1");
}
- MaybeAppend(name, lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_);
+ MaybeAppend(name, lhs_dim_map, rhs_dim_map, output_dim_map);
// Generate shardings for contraction dimensions
if (used_mesh_dims.size() == device_mesh_.num_dimensions()) {
@@ -891,12 +908,12 @@
return;
}
- for (int64_t mesh_dim : unused_mesh_dims) {
+ for (int mesh_dim : unused_mesh_dims) {
DimMap lhs_dim_map_with_contractions = lhs_dim_map;
DimMap rhs_dim_map_with_contractions = rhs_dim_map;
- lhs_dim_map_with_contractions[lhs_in_channel_dim_] = mesh_dim;
- rhs_dim_map_with_contractions[rhs_in_channel_dim_] = mesh_dim;
+ lhs_dim_map_with_contractions[lhs_in_channel_dim_] = MeshDimSet{mesh_dim};
+ rhs_dim_map_with_contractions[rhs_in_channel_dim_] = MeshDimSet{mesh_dim};
absl::StrAppend(&name, "ic", mesh_dim, "@allreduce");
auto communication_cost_fn = [&](const HloSharding& output_sharding) {
@@ -906,7 +923,7 @@
};
MaybeAppend(name, lhs_dim_map_with_contractions,
- rhs_dim_map_with_contractions, output_dim_map, device_mesh_,
+ rhs_dim_map_with_contractions, output_dim_map,
/*compute_cost=*/0, communication_cost_fn);
}
}
@@ -931,15 +948,13 @@
SplitDepthwise(false);
}
- absl::flat_hash_set<int> all_mesh_dims;
- for (int i = 0; i < device_mesh_.num_dimensions(); ++i) {
- all_mesh_dims.insert(i);
- }
- EnumerateGeneral(
+ std::vector<int> all_mesh_dims(device_mesh_.num_dimensions());
+ std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0);
+ Enumerate(
[&](const DimMap& output_dim_map) {
GenerateConvolutionShardingStrategiesFromOutputSharding(output_dim_map);
},
- 2, /*current_tensor_dim=*/0, all_mesh_dims,
+ 2, /*current_mesh_dim_idx=*/0, all_mesh_dims,
/*current_dim_map=*/{});
// If force_batch_dim_to_mesh_dim is set, filter out invalid strategies
@@ -957,38 +972,37 @@
void ConvHandler::SplitDepthwise(bool forward) {
std::function<void(const DimMap&)> split_func =
[&](const DimMap& output_dim_map) {
- int out_batch_mesh_dim = -1;
- int out_out_channel_mesh_dim = -1;
+ MeshDimSet out_batch_mesh_dim_set;
+ MeshDimSet out_out_channel_mesh_dim_set;
if (auto it = output_dim_map.find(out_batch_dim_);
it != output_dim_map.end()) {
- out_batch_mesh_dim = it->second;
+ out_batch_mesh_dim_set = it->second;
}
if (auto it = output_dim_map.find(out_out_channel_dim_);
it != output_dim_map.end()) {
- out_out_channel_mesh_dim = it->second;
+ out_out_channel_mesh_dim_set = it->second;
}
- if (out_batch_mesh_dim == -1 || out_out_channel_mesh_dim == -1) {
+ if (out_batch_mesh_dim_set.empty() ||
+ out_out_channel_mesh_dim_set.empty()) {
return;
}
DimMap lhs_dim_map, rhs_dim_map;
lhs_dim_map[lhs_batch_dim_] =
- forward ? out_batch_mesh_dim : out_out_channel_mesh_dim;
+ forward ? out_batch_mesh_dim_set : out_out_channel_mesh_dim_set;
lhs_dim_map[lhs_in_channel_dim_] =
- forward ? out_out_channel_mesh_dim : out_batch_mesh_dim;
+ forward ? out_out_channel_mesh_dim_set : out_batch_mesh_dim_set;
- rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim;
+ rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim_set;
- MaybeAppend(absl::StrCat("b", out_batch_mesh_dim, "oc",
- out_out_channel_mesh_dim, "|depthwise"),
- lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_);
+ MaybeAppend(
+ absl::StrCat("b", ToString(out_batch_mesh_dim_set), "oc",
+ ToString(out_out_channel_mesh_dim_set), "|depthwise"),
+ lhs_dim_map, rhs_dim_map, output_dim_map);
};
- absl::flat_hash_set<int> all_mesh_dims;
- for (int i = 0; i < device_mesh_.num_dimensions(); ++i) {
- all_mesh_dims.insert(i);
- }
- EnumerateGeneral(split_func, 2, /*current_tensor_dim=*/0, all_mesh_dims,
- /*current_dim_map=*/{});
+ std::vector<int> all_mesh_dims(device_mesh_.num_dimensions());
+ std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0);
+ Enumerate(split_func, 2, /*current_mesh_dim_idx=*/0, all_mesh_dims,
+ /*current_dim_map=*/{});
}
} // namespace
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
index d2f124b..56ca325 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
@@ -82,8 +82,6 @@
absl::StrCat("allow_recompute_heavy_op: ", allow_recompute_heavy_op));
lines.push_back(
absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape));
- lines.push_back(
- absl::StrCat("grad_acc_num_micro_batches: ", grad_acc_num_micro_batches));
lines.push_back(absl::StrCat("solve_nd_sharding_iteratively: ",
solve_nd_sharding_iteratively));
lines.push_back(
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
index 51eceae..5e73af7 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
@@ -119,11 +119,6 @@
// If true, allow adding 1d strategies in 2d logical mesh.
bool allow_mixed_mesh_shape = true;
- // The number of micro batches if gradient accumulation is used.
- // If this is not 1, the cost of all-reduce for gradient synchronization
- // is divided by this number.
- int grad_acc_num_micro_batches = 1;
-
// If true, N-D sharding (e.g., N maybe be 2 or 3) will be solved in N
// iterations, where one iteration chooses one tensor dimension to shard. If
// false, solve N-D sharding directly, i.e., generating all possible sharding
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
index f204ff4..67a6fed 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
@@ -67,6 +67,10 @@
// solver cannot guarantee exact numerical precision.
constexpr double kMaxCostEpsilon = 1.0001;
+// Memory contributions in the Mixed ILP are converted to units in this range;
+// beware that significantly larger / smaller values can cause numerical issues.
+constexpr double kMemoryMultiplier = 1e6;
+
bool AutoShardingSolverOutput::operator==(
const AutoShardingSolverOutput& other) const {
return s_val == other.s_val && cost == other.cost &&
@@ -261,7 +265,7 @@
reduced_groups.push_back({group.prims().begin(), group.prims().end()});
}
}
- solver.MakeIntVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(),
+ solver.MakeNumVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(),
absl::StrCat("group_", prim_type), &group_vars);
for (int64_t group_idx = 0; group_idx < group_vars.size(); ++group_idx) {
MPConstraint* constraint = solver.MakeRowConstraint(
@@ -271,7 +275,7 @@
for (const int64_t prim_idx : reduced_groups[group_idx]) {
for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) {
double memory_cost = memory_costs.at(prim_idx).costs(j);
- memory_cost /= request.memory_budget() / 100.0;
+ memory_cost /= request.memory_budget() / kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(prim_vars[prim_idx][j]);
constraint->SetCoefficient(prim_vars[prim_idx][j],
@@ -302,9 +306,12 @@
time_idx <= intervals[prim_idx].second; ++time_idx) {
if (!reduced_times.contains(time_idx)) continue;
if (!constraints.contains(time_idx)) {
- MPConstraint* constraint = solver.MakeRowConstraint(
- -MPSolver::infinity(), 100.0, absl::StrCat("mem[", time_idx, "]"));
- if (overbudget_var) constraint->SetCoefficient(overbudget_var, -100.0);
+ MPConstraint* constraint =
+ solver.MakeRowConstraint(-MPSolver::infinity(), kMemoryMultiplier,
+ absl::StrCat("mem[", time_idx, "]"));
+ if (overbudget_var) {
+ constraint->SetCoefficient(overbudget_var, -kMemoryMultiplier);
+ }
constraints[time_idx] = constraint;
}
MPConstraint* constraint = constraints[time_idx];
@@ -314,7 +321,7 @@
}
for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) {
double memory_cost = memory_costs.at(prim_idx).costs(j);
- memory_cost /= request.memory_budget() / 100.0;
+ memory_cost /= request.memory_budget() / kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(prim_vars[prim_idx][j]);
constraint->SetCoefficient(prim_vars[prim_idx][j],
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
index 5a237b2..05631bc 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
@@ -441,6 +441,46 @@
EXPECT_EQ(result, expected_result);
}
+TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
+ AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
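+  // Nodes 0 and 1 form a group whose memory costs are tiny relative to the
+  // budget; with continuous group variables they should contribute their
+  // exact values rather than being rounded up.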
+ const std::vector<std::pair<int64_t, int64_t>> node_intervals =
+ {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
+ const std::vector<std::pair<int64_t, int64_t>> edge_intervals =
+ {{1, 2}, {2, 3}};
+ const std::vector<std::vector<int64_t>> node_groups = {{0, 1}};
+ const std::vector<std::vector<int64_t>> edge_groups = {};
+ const CostMatrix memory_costs = {{1, 1, 1, 1}, // These values are tiny and
+ {2, 2, 2}, // shouldn't be rounded up.
+ {300, 300, 300, 300, 300, 300, 300},
+ {4000, 4000, 4000, 4000, 4000, 4000, 4000},
+ {50000, 50000, 50000}};
+ const CostMatrix memory_edge_costs = {{0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0},
+ {0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0}};
+ request.clear_live();
+ request.clear_memory_costs();
+ AddIntervals(request.mutable_node_intervals(), node_intervals);
+ AddIntervals(request.mutable_edge_intervals(), edge_intervals);
+ AddGroups(request.mutable_node_groups(), node_groups);
+ AddGroups(request.mutable_edge_groups(), edge_groups);
+ AddCosts(request.mutable_memory_costs(), memory_costs);
+ AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
+ request.set_enable_memory_edge_costs(true);
+ request.set_memory_budget(4321);
+
+ const AutoShardingSolverResult result = CallORToolsSolver(request);
+
+ const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
+ const double objective_value = 7650.0;
+ const AutoShardingSolverOutput expected_output = {s_val, objective_value};
+ const AutoShardingSolverResult expected_result = {expected_output, false};
+ EXPECT_EQ(result, expected_result);
+}
+
TEST(CallORToolsSolverTest, SolvesWithEquivalences) {
const AutoShardingSolverRequest request =
AutoShardingSolverRequestWithEquivalences();
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
index dd6d38c..27ddc79 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -824,15 +824,7 @@
}
};
- if (IsCustomCallMarker(ins)) {
- const HloInstruction* operand = ins->operand(0);
- const StrategyGroup* src_strategy_group =
- strategy_map.at(operand).get();
- CHECK(src_strategy_group->is_tuple);
- strategy_group = MaybeFollowInsStrategyGroup(
- src_strategy_group, ins->shape(), instruction_id, strategy_groups,
- cluster_env, pretrimmed_strategy_map);
- } else if (IsSPMDFullToShardShapeCustomCall(ins)) {
+ if (IsSPMDFullToShardShapeCustomCall(ins)) {
return absl::InternalError(
"An SPMDFullToShardShape call found outside a manually "
"partitioned sub-graph.");
@@ -1030,24 +1022,6 @@
strategy_map[ins] = std::move(strategy_group);
} // end of for loop
- // If gradient accumulation is used, adjust the cost of all-reduce for
- // gradient synchronization.
- if (option.grad_acc_num_micro_batches > 1) {
- // find gradient-computation instructions
- std::vector<const HloInstruction*> grad_insts =
- GetGradientComputationInstructions(instructions);
- for (const HloInstruction* inst : grad_insts) {
- StrategyGroup* stra_vector = strategy_map[inst].get();
- CHECK(!stra_vector->is_tuple);
-
- for (auto& stra : stra_vector->strategies) {
- if (absl::StrContains(stra.name, "allreduce")) {
- stra.communication_cost /= option.grad_acc_num_micro_batches;
- }
- }
- }
- }
-
return std::make_tuple(std::move(strategy_map), std::move(strategy_groups),
std::move(associative_dot_pairs));
}
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
index c899a0c..8ab3824 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -29,6 +29,7 @@
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
+#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
@@ -554,7 +555,7 @@
TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
VLOG(10) << module->ToString();
EXPECT_TRUE(changed);
- auto* instruction = FindInstruction(module.get(), "p0");
+ const HloInstruction* instruction = FindInstruction(module.get(), "p0");
ASSERT_NE(instruction, nullptr);
EXPECT_THAT(instruction, op::Sharding("{replicated}"));
}
@@ -687,14 +688,49 @@
TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
VLOG(10) << module->ToString();
EXPECT_TRUE(changed);
- auto* param0 = FindInstruction(module.get(), "param.0");
- auto* param1 = FindInstruction(module.get(), "param.1");
+ const HloInstruction* param0 = FindInstruction(module.get(), "param.0");
+ const HloInstruction* param1 = FindInstruction(module.get(), "param.1");
ASSERT_NE(param0, nullptr);
 ASSERT_NE(param1, nullptr);
EXPECT_THAT(param0, op::Sharding("{replicated}"));
EXPECT_THAT(param1, op::Sharding("{replicated}"));
}
+TEST_F(AutoShardingTest, DotMixedMeshStrategies) {
+ constexpr absl::string_view kHloString = R"(
+HloModule module
+ENTRY %entry {
+ %param0 = f32[8192,23]{1,0} parameter(0), sharding={devices=[4,1]0,1,2,3}
+ %param1 = f32[23,23]{1,0} parameter(1)
+ %dot = f32[8192,23]{1,0} dot(%param0, %param1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ ROOT %copy = f32[8192,23]{1,0} copy(%dot)
+})";
+
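+  // param0 arrives with a user-specified 1D sharding over the 2D {2,2} mesh;
+  // keeping it (kKeepAllShardings) requires mixing 1D and 2D mesh strategies
+  // for the dot.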
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+ AutoShardingOption option;
+ option.enable = true;
+ option.device_mesh_shape = {2, 2};
+ option.device_mesh_ids = {0, 1, 2, 3};
+ option.device_mesh_alpha = {1.0, 1.0};
+ option.device_mesh_beta = {0.01, 1.0};
+ option.solve_nd_sharding_iteratively = false;
+ option.preserve_shardings =
+ AutoShardingOption::PreserveShardingsType::kKeepAllShardings;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
+ VLOG(2) << module->ToString();
+ EXPECT_TRUE(changed);
+ const HloInstruction* param0 = FindInstruction(module.get(), "param0");
+ const HloInstruction* param1 = FindInstruction(module.get(), "param1");
+ const HloInstruction* dot = FindInstruction(module.get(), "dot");
+ ASSERT_NE(param0, nullptr);
+ ASSERT_NE(param1, nullptr);
+ ASSERT_NE(dot, nullptr);
+ EXPECT_THAT(param0, op::Sharding("{devices=[4,1]0,1,2,3}"));
+ EXPECT_THAT(param1, op::Sharding("{replicated}"));
+ EXPECT_THAT(dot, op::Sharding("{devices=[4,1]0,1,2,3}"));
+}
+
TEST_F(AutoShardingTest, DotLHSTwoNonContractingDims) {
constexpr absl::string_view kHloString = R"(
HloModule module
@@ -717,9 +753,9 @@
TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
VLOG(2) << module->ToString();
EXPECT_TRUE(changed);
- auto* param0 = FindInstruction(module.get(), "param0");
- auto* param1 = FindInstruction(module.get(), "param1");
- auto* dot = FindInstruction(module.get(), "dot");
+ const HloInstruction* param0 = FindInstruction(module.get(), "param0");
+ const HloInstruction* param1 = FindInstruction(module.get(), "param1");
+ const HloInstruction* dot = FindInstruction(module.get(), "dot");
ASSERT_NE(param0, nullptr);
ASSERT_NE(param1, nullptr);
ASSERT_NE(dot, nullptr);
@@ -770,9 +806,9 @@
TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
VLOG(2) << module->ToString();
EXPECT_TRUE(changed);
- auto* param0 = FindInstruction(module.get(), "param0");
- auto* param1 = FindInstruction(module.get(), "param1");
- auto* dot = FindInstruction(module.get(), "dot");
+ const HloInstruction* param0 = FindInstruction(module.get(), "param0");
+ const HloInstruction* param1 = FindInstruction(module.get(), "param1");
+ const HloInstruction* dot = FindInstruction(module.get(), "dot");
ASSERT_NE(param0, nullptr);
ASSERT_NE(param1, nullptr);
ASSERT_NE(dot, nullptr);
@@ -2482,6 +2518,36 @@
input_output_alias_config_after.ToString());
}
+TEST(NormalizeTest, NormalizeHandlesNegativeCosts) {
+ EdgeReshardingCostMatrix edge_cost(2, 2);
+ edge_cost(0, 0).communication_cost = -100;
+ edge_cost(0, 1).communication_cost = 200;
+ edge_cost(1, 0).communication_cost = 300;
+ edge_cost(1, 1).communication_cost = 400;
+
+ const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost);
+
+ EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 0);
+ EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 300);
+ EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 400);
+ EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 500);
+}
+
+TEST(NormalizeTest, NormalizeHandlesPositiveCosts) {
+ EdgeReshardingCostMatrix edge_cost(2, 2);
+ edge_cost(0, 0).communication_cost = 100;
+ edge_cost(0, 1).communication_cost = 200;
+ edge_cost(1, 0).communication_cost = 300;
+ edge_cost(1, 1).communication_cost = 400;
+
+ const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost);
+
+ EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 100);
+ EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 200);
+ EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 300);
+ EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 400);
+}
+
} // namespace
} // namespace spmd
} // namespace xla
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
index cb8ec2c..d49f16b 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -59,16 +59,10 @@
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
namespace xla {
namespace spmd {
-inline const HloInstruction* PassThroughCustomCallMarkerGetSource(
- const HloInstruction* ins);
-inline HloInstruction* PassThroughCustomCallMarkerUser(
- HloInstruction* raw_user, const HloInstruction* inst);
-
std::optional<HloSharding> GetInputSharding(const HloInstruction* ins,
int64_t op_index,
const HloSharding& output_sharding,
@@ -109,26 +103,6 @@
return inferred_sharding;
}
-// Return whether the instruction is an activation from another pipeline stage.
-bool IsActivationFromAnotherStage(const HloInstruction* ins,
- const InstructionBatchDimMap& batch_dim_map) {
- if (!(ins->opcode() == HloOpcode::kParameter &&
- batch_dim_map.contains(GetBatchDimMapKey(ins)))) {
- return false;
- }
-
- for (const HloInstruction* user : ins->users()) {
- if (!(user->opcode() == HloOpcode::kTuple && user->users().size() == 1 &&
- user->users().front()->IsCustomCall(kPipelineMarker) &&
- absl::StrContains(user->users().front()->metadata().op_type(),
- "start"))) {
- return false;
- }
- }
-
- return true;
-}
-
// Propagate sharding for dim-wise operations (e.g., slice, pad) which works
// independently on each dimension.
// The sharding can successfully propagate if the operation only happens
@@ -194,11 +168,6 @@
if (degree_dict[inst] == 0) {
depth_map[inst] = 0;
- // Add some initial depth for activations from other pipeline stages.
- if (IsActivationFromAnotherStage(inst, batch_dim_map)) {
- depth_map[inst] = 20;
- }
-
current_frontier.push_back(inst);
collected++;
}
@@ -246,10 +215,6 @@
if (reset) {
depth_map[node] = 0;
- } else if (node->opcode() == HloOpcode::kGetTupleElement &&
- IsCustomCallMarker(node->operand(0))) {
- depth_map[node] =
- depth_map.at(PassThroughCustomCallMarkerGetSource(node));
} else {
int64_t max_depth = depth_map.at(inst) + delta;
for (const HloInstruction* operand : node->operands()) {
@@ -813,12 +778,6 @@
batch_map[ins->name()] = batch_dim_of_source;
}
}
-
- if (ins->IsCustomCall(kPipelineMarker) &&
- absl::StrContains(ins->metadata().op_type(), "start")) {
- // Reset the status after meet a new pipeline marker.
- set_the_next_dot_conv = true;
- }
}
int64_t previous_cnt = 0;
while (true) {
@@ -968,8 +927,7 @@
for (HloInstruction* node : boundary_set) {
HloInstruction* cur = node;
while (cur->operand_count() == 1) {
- HloInstruction* operand =
- PassThroughCustomCallMarkerOperand(cur->mutable_operand(0), cur);
+ HloInstruction* operand = cur->mutable_operand(0);
if (replicated_set.contains(operand)) {
path.insert(cur);
}
@@ -1007,8 +965,7 @@
// Find the add instruction for grad accumulation, skip the identity marker
// for remat and other elementwise ops.
- HloInstruction* add =
- PassThroughCustomCallMarkerUser(inst->users().front(), inst);
+ HloInstruction* add = inst->users().front();
if (add->opcode() == HloOpcode::kGetTupleElement ||
add->opcode() == HloOpcode::kTranspose) {
if (add->users().size() != 1) {
@@ -1025,7 +982,7 @@
}
CHECK_EQ(add->users().size(), 1);
// Skip the end marker of backward computation
- add = PassThroughCustomCallMarkerUser(add->users().front(), add);
+ add = add->users().front();
// Do not partition the dot, add and parameter, so we can generate
// all-reduce for grad accumulation.
@@ -1037,7 +994,7 @@
replicated_set.erase(cur);
for (auto x : cur->operands()) {
- dfs_remove(PassThroughCustomCallMarkerOperand(x, cur));
+ dfs_remove(x);
}
};
@@ -1138,39 +1095,23 @@
return sharded_dims <= 0;
}
-absl::StatusOr<std::vector<int64_t>> GetTensorDimToMeshDimNoCrash(
- int64_t tensor_shape_rank, const HloSharding& spec,
- const Array<int64_t>& device_mesh, bool consider_reverse_device_meshes) {
- if (spec.IsReplicated()) {
- return std::vector<int64_t>(tensor_shape_rank, -1);
- }
- // Check the compatibility of tensor_shape_rank and spec
- if (tensor_shape_rank != spec.TiledDataRank()) {
- return absl::InvalidArgumentError(
- "Tensor shape rank should be equal to the tiled data rank of the input "
- "spec.");
- }
-
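+// Returns the permutation of device mesh axes (optionally also considering
+// meshes with some axes reversed) whose transposed mesh matches the
+// sharding's tile assignment, or NotFoundError if no such permutation exists.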
+absl::StatusOr<std::vector<int64_t>> GetMeshDimPermutationOrderInShardingSpec(
+ const HloSharding& spec, const Array<int64_t>& device_mesh,
+ bool consider_reverse_device_meshes) {
auto check_mesh =
[&](const Array<int64_t>& mesh) -> std::optional<std::vector<int64_t>> {
// Permute the dimensions (or axes in numpy term), find the transform that
// makes tile_assignment == device_mesh.
std::vector<int64_t> axes(mesh.num_dimensions());
absl::c_iota(axes, 0);
- bool found = false;
do {
Array<int64_t> transposed_mesh = Transpose(mesh, axes);
if (std::equal(transposed_mesh.begin(), transposed_mesh.end(),
spec.tile_assignment().array().begin())) {
- found = true;
- break;
+ return axes;
}
} while (absl::c_next_permutation(axes));
- if (found) {
- return std::optional<std::vector<int64_t>>(axes);
- } else {
- return std::nullopt;
- }
+ return std::nullopt;
};
// This is an expensive search, as we try all possible meshes obtained by
@@ -1178,7 +1119,6 @@
// the somewhat rare kReverse HLO op. The hope therefore is that most calls to
// the function that reach here will find a mapping within the first iteration
// of the loop below.
- bool found = false;
std::vector<int64_t> axes(device_mesh.num_dimensions());
size_t num_subsets =
consider_reverse_device_meshes ? (1 << device_mesh.num_dimensions()) : 1;
@@ -1199,24 +1139,35 @@
*device = device_mesh(original_indices);
});
if (auto result = check_mesh(new_mesh); result.has_value()) {
- axes = result.value();
- found = true;
- break;
+ return result.value();
}
}
+ return absl::NotFoundError(absl::StrCat("Could not find mapping for ",
+ spec.ToString(), " with device mesh ",
+ device_mesh.ToString()));
+}
- if (!found) {
- return absl::NotFoundError(
- absl::StrCat("Could not find mapping for ", spec.ToString(),
- " with device mesh ", device_mesh.ToString()));
+absl::StatusOr<std::vector<int64_t>> GetTensorDimToMeshDimNoCrash(
+ int64_t tensor_shape_rank, const HloSharding& spec,
+ const Array<int64_t>& device_mesh, bool consider_reverse_device_meshes) {
+ if (spec.IsReplicated()) {
+ return std::vector<int64_t>(tensor_shape_rank, -1);
}
-
+ // Check the compatibility of tensor_shape_rank and spec
+ if (tensor_shape_rank != spec.TiledDataRank()) {
+ return absl::InvalidArgumentError(
+ "Tensor shape rank should be equal to the tiled data rank of the input "
+ "spec.");
+ }
if (!TileAssignmentMatchesMesh(spec, device_mesh)) {
return absl::InvalidArgumentError(
"Device mesh and tile assignment need to have the same number of "
"sharded dims.");
}
+ TF_ASSIGN_OR_RETURN(std::vector<int64_t> axes,
+ GetMeshDimPermutationOrderInShardingSpec(
+ spec, device_mesh, consider_reverse_device_meshes));
// Transform tile_assignment_dimensions using found transformation (axes).
std::vector<int64_t> tensor_dim_to_device_dim(tensor_shape_rank, -1);
int mesh_index = 0;
@@ -1558,7 +1509,7 @@
// Create a HloSharding that tiles some tensor dims on some device mesh dims.
HloSharding Tile(const Shape& tensor_shape,
absl::Span<const int64_t> tensor_dims,
- absl::Span<const int64_t> mesh_dims,
+ const std::vector<std::vector<int64_t>>& mesh_dims,
const Array<int64_t>& device_mesh) {
CHECK_EQ(tensor_dims.size(), mesh_dims.size());
CHECK(tensor_shape.IsArray());
@@ -1567,8 +1518,12 @@
// Split on certain mesh dimensions
int64_t split_prod = 1;
for (size_t i = 0; i < tensor_dims.size(); ++i) {
- tile_assignment_dimensions[tensor_dims[i]] = device_mesh.dim(mesh_dims[i]);
- split_prod *= device_mesh.dim(mesh_dims[i]);
+ int64_t num_devices_for_tensor_dim = 1;
+ for (int64_t mesh_dim_idx : mesh_dims[i]) {
+ num_devices_for_tensor_dim *= device_mesh.dim(mesh_dim_idx);
+ }
+ tile_assignment_dimensions[tensor_dims[i]] = num_devices_for_tensor_dim;
+ split_prod *= num_devices_for_tensor_dim;
}
// Replicate on remaining mesh dimensions
bool replicate_on_last_tile_dim = false;
@@ -1582,35 +1537,58 @@
std::vector<int64_t> tile_assignment_devices;
tile_assignment_devices.reserve(device_mesh.num_elements());
- std::vector<int64_t> tmp_indices(device_mesh.num_dimensions(), 0);
- std::function<void(int64_t, std::vector<int64_t>)>
+ std::function<void(int64_t, int64_t, std::vector<int64_t>)>
generate_tile_assignment_devices;
- generate_tile_assignment_devices = [&](int64_t tensor_dim,
+ generate_tile_assignment_devices = [&](int64_t current_tensor_dim,
+ int64_t current_mesh_dim_idx,
std::vector<int64_t> mesh_indices) {
- if (tensor_dim == tensor_shape.rank() - 1) {
- AppendFlattenElements(&tile_assignment_devices, device_mesh, mesh_indices,
- -1, tmp_indices);
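+    // Recursively walks the tensor dims in order. A tensor dim sharded across
+    // several mesh dims visits each of its mesh dims before advancing to the
+    // next tensor dim; once the last tensor dim (and its last mesh dim) has
+    // been processed, the accumulated mesh_indices select the devices appended
+    // to the flattened tile assignment.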
+ int64_t current_tensor_dim_index =
+ GetIndex(tensor_dims, current_tensor_dim);
+ bool proceed_to_next_tensor_dim = false;
+ if (current_tensor_dim_index >= 0) {
+ proceed_to_next_tensor_dim =
+ (current_mesh_dim_idx ==
+ mesh_dims[current_tensor_dim_index].size() - 1);
} else {
- int64_t next_tensor_dim = tensor_dim + 1;
- int64_t next_mesh_dim = -1;
+ proceed_to_next_tensor_dim = true;
+ }
- int64_t index = GetIndex(tensor_dims, next_tensor_dim);
- if (index >= 0) {
- next_mesh_dim = mesh_dims[index];
- }
+ if (proceed_to_next_tensor_dim &&
+ current_tensor_dim == tensor_shape.rank() - 1) {
+ AppendFlattenElements(&tile_assignment_devices, device_mesh,
+ mesh_indices);
+ return;
+ }
- for (int64_t i = 0; i < tile_assignment_dimensions[next_tensor_dim];
- ++i) {
- if (next_mesh_dim != -1) {
- mesh_indices[next_mesh_dim] = i;
- }
- generate_tile_assignment_devices(next_tensor_dim, mesh_indices);
+ int64_t next_tensor_dim, next_mesh_dim_idx = -1, next_mesh_dim = -1;
+ if (proceed_to_next_tensor_dim) {
+ next_tensor_dim = current_tensor_dim + 1;
+ next_mesh_dim_idx = -1;
+ int64_t next_tensor_dim_index = GetIndex(tensor_dims, next_tensor_dim);
+ if (next_tensor_dim_index >= 0) {
+ next_mesh_dim_idx = 0;
+ next_mesh_dim = mesh_dims[next_tensor_dim_index][0];
}
+ } else {
+ next_tensor_dim = current_tensor_dim;
+ next_mesh_dim_idx = current_mesh_dim_idx + 1;
+ next_mesh_dim = mesh_dims[current_tensor_dim_index][next_mesh_dim_idx];
+ }
+
+ int64_t limit =
+ (next_mesh_dim_idx >= 0) ? device_mesh.dim(next_mesh_dim) : 1;
+ for (int64_t i = 0; i < limit; ++i) {
+ if (next_mesh_dim != -1) {
+ mesh_indices[next_mesh_dim] = i;
+ }
+ generate_tile_assignment_devices(next_tensor_dim, next_mesh_dim_idx,
+ mesh_indices);
}
};
std::vector<int64_t> mesh_indices(device_mesh.num_dimensions(), -1);
- generate_tile_assignment_devices(-1, mesh_indices);
+ generate_tile_assignment_devices(/*current_tensor_dim=*/-1,
+ /*current_mesh_dim_idx=*/-1, mesh_indices);
// Make HloSharding
Array<int64_t> tile_assignment(tile_assignment_dimensions);
@@ -1625,6 +1603,17 @@
: HloSharding::Tile(std::move(tile_assignment));
}
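+// Convenience overload for the common case where each tensor dim is sharded
+// along exactly one mesh dim: wraps each mesh dim into a singleton list and
+// forwards to the general overload above.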
+HloSharding Tile(const Shape& tensor_shape,
+ absl::Span<const int64_t> tensor_dims,
+ absl::Span<const int64_t> mesh_dims,
+ const Array<int64_t>& device_mesh) {
+ std::vector<std::vector<int64_t>> mesh_dims_general(mesh_dims.size());
+ for (int i = 0; i < mesh_dims.size(); ++i) {
+ mesh_dims_general[i].push_back(mesh_dims[i]);
+ }
+ return Tile(tensor_shape, tensor_dims, mesh_dims_general, device_mesh);
+}
+
AliasMap BuildAliasMap(const HloModule* module,
const HloInputOutputAliasConfig& alias_config) {
AliasMap alias_map;
@@ -1633,10 +1622,6 @@
const auto& parameter_instructions = entry->parameter_instructions();
const HloInstruction* output_tuple = entry->root_instruction();
- if (IsCustomCallMarker(output_tuple)) {
- output_tuple = output_tuple->operand(0);
- }
-
absl::flat_hash_map<int64_t, absl::flat_hash_map<int64_t, HloInstruction*>>
parameter_index_to_operand_map;
alias_config.ForEachAlias([&](const ShapeIndex& output_index,
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h
index 64749c6..f114a3d 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h
@@ -48,10 +48,7 @@
namespace xla {
namespace spmd {
-inline constexpr absl::string_view kPipelineMarker = "xla_pipeline_marker";
inline constexpr absl::string_view kIdentityMarker = "identity";
-inline constexpr absl::string_view kPipelineMarkerStartType = "start";
-inline constexpr absl::string_view kPipelineMarkerEndType = "end";
inline constexpr int64_t kAutoShardingPointerSize = 8;
@@ -94,9 +91,11 @@
// Append elements of `array` to `result`. The `indices` is a generalized
// multi-dimensional index that can index a whole row (use -1 to indicate this).
template <typename T>
-void AppendFlattenElements(std::vector<T>* result, const Array<T>& array,
- absl::Span<const int64_t> indices, int cur_depth,
- std::vector<int64_t> cur_indices) {
+void AppendFlattenElementsInternal(std::vector<T>* result,
+ const Array<T>& array,
+ absl::Span<const int64_t> indices,
+ int cur_depth,
+ std::vector<int64_t> cur_indices) {
if (cur_depth == array.num_dimensions() - 1) {
result->push_back(array(cur_indices));
} else {
@@ -106,15 +105,25 @@
if (index == -1) {
for (int64_t i = 0; i < array.dim(next_depth); ++i) {
cur_indices[next_depth] = i;
- AppendFlattenElements(result, array, indices, next_depth, cur_indices);
+ AppendFlattenElementsInternal(result, array, indices, next_depth,
+ cur_indices);
}
} else {
cur_indices[next_depth] = index;
- AppendFlattenElements(result, array, indices, next_depth, cur_indices);
+ AppendFlattenElementsInternal(result, array, indices, next_depth,
+ cur_indices);
}
}
}
+template <typename T>
+void AppendFlattenElements(std::vector<T>* result, const Array<T>& array,
+ absl::Span<const int64_t> indices) {
+ std::vector<int64_t> tmp_indices(array.num_dimensions(), 0);
+ AppendFlattenElementsInternal(result, array, indices,
+ /*cur_depth=*/-1, tmp_indices);
+}
+
// Return the index of key in a span. -1 means not found.
template <typename T>
int64_t GetIndex(absl::Span<const T> v, const T& key) {
@@ -201,11 +210,6 @@
}
}
-// Return whether this instruction is a custom call marker introduced by us.
-inline bool IsCustomCallMarker(const HloInstruction* inst) {
- return inst->IsCustomCall({kPipelineMarker, kIdentityMarker});
-}
-
// Return whether this instruction is a TopK custom call.
inline bool IsTopKCustomCall(const HloInstruction* inst) {
return inst->opcode() == HloOpcode::kCustomCall &&
@@ -218,70 +222,6 @@
inst->custom_call_target() == "PartialReduce";
}
-// Pass through the custom call marker and get the source instruction
-inline const HloInstruction* PassThroughCustomCallMarkerGetSource(
- const HloInstruction* ins) {
- while (ins->opcode() == HloOpcode::kGetTupleElement &&
- IsCustomCallMarker(ins->operand(0))) {
- const HloInstruction* custom_call = ins->operand(0);
- const HloInstruction* tuple = custom_call->operand(0);
- while (IsCustomCallMarker(tuple)) {
- tuple = tuple->operand(0);
- }
- ins = tuple->operand(ins->tuple_index());
- }
- return ins;
-}
-
-// Pass through the custom call marker and get the acutal operand.
-inline HloInstruction* PassThroughCustomCallMarkerOperand(
- HloInstruction* raw_operand, const HloInstruction* inst) {
- if (!IsCustomCallMarker(raw_operand)) {
- return raw_operand;
- }
-
- CHECK_EQ(inst->opcode(), HloOpcode::kGetTupleElement);
-
- int index = inst->tuple_index();
- return raw_operand->mutable_operand(0)->mutable_operand(index);
-}
-
-// Return whether the tuple is only used by a custom call marker.
-inline bool IsCustomCallMarkerTuple(const HloInstruction* inst) {
- return inst->opcode() == HloOpcode::kTuple && inst->users().size() == 1 &&
- IsCustomCallMarker(inst->users().front());
-}
-
-// Pass through the custom call marker and get the actual user.
-inline HloInstruction* PassThroughCustomCallMarkerUser(
- HloInstruction* raw_user, const HloInstruction* inst) {
- if (!IsCustomCallMarkerTuple(raw_user)) {
- return raw_user;
- }
-
- const HloInstruction* custom_call = raw_user->users().front();
-
- int index = -1;
- for (int i = 0; i < raw_user->operand_count(); i++) {
- if (raw_user->operand(i) == inst) {
- index = i;
- break;
- }
- }
- CHECK_NE(index, -1);
-
- HloInstruction* ret = nullptr;
- for (HloInstruction* user : custom_call->users()) {
- CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement);
- if (user->tuple_index() == index) {
- CHECK_EQ(ret, nullptr);
- ret = user;
- }
- }
-
- return ret == nullptr ? raw_user : ret;
-}
-
// Return the users of an instruction and its alias,
// excluding the final output tuple.
inline InstructionSet UsersWithAlias(const HloInstruction* inst,
@@ -289,8 +229,7 @@
const HloInstruction* output) {
InstructionSet users;
for (HloInstruction* user : inst->users()) {
- HloInstruction* pass_through_user =
- PassThroughCustomCallMarkerUser(user, inst);
+ HloInstruction* pass_through_user = user;
if (pass_through_user == output) {
continue;
}
@@ -300,8 +239,7 @@
auto iter = alias_map.find(inst);
if (iter != alias_map.end()) {
for (HloInstruction* user : iter->second->users()) {
- HloInstruction* pass_through_user =
- PassThroughCustomCallMarkerUser(user, iter->second);
+ HloInstruction* pass_through_user = user;
if (pass_through_user == output) {
continue;
}
@@ -356,10 +294,6 @@
const xla::CallGraph& call_graph,
int64_t num_devices);
-// Return whether the instruction is an activation from another pipeline stage.
-bool IsActivationFromAnotherStage(const HloInstruction* inst,
- const InstructionBatchDimMap& batch_dim_map);
-
// Depth analysis (breadth first search) that compute the depth of each
// instruction. We also assign a much larger distance to heavy operators (e.g.,
// dot, convolution).
@@ -464,41 +398,6 @@
const Array<int64_t>& device_mesh,
ReshardingCache* resharding_cache);
-/*
- * Gradient accumulation
- */
-// Find all instructions that compute gradients in gradient accumulation.
-// This is done by using the hint from pipeline_marker (gradient marker).
-inline std::vector<const HloInstruction*> GetGradientComputationInstructions(
- const std::vector<HloInstruction*>& instructions) {
- std::vector<const HloInstruction*> ret;
-
- for (size_t i = 0; i < instructions.size(); ++i) {
- const HloInstruction* ins = instructions[i];
- if (ins->IsCustomCall(kPipelineMarker) &&
- (absl::StrContains(ins->metadata().op_name(), "compute_grad") ||
- absl::StrContains(ins->metadata().op_name(), "backward")) &&
- ins->metadata().op_type() == kPipelineMarkerEndType) {
- const HloInstruction* tuple = ins->operand(0);
- for (size_t j = 0; j < tuple->operand_count(); ++j) {
- const HloInstruction* add = tuple->operand(j);
- while (add->opcode() == HloOpcode::kAdd) {
- ret.push_back(add->operand(0));
- ret.push_back(add->operand(1));
-
- if (add->operand(0)->opcode() == HloOpcode::kAdd) {
- add = add->operand(0);
- } else {
- add = add->operand(1);
- }
- }
- }
- }
- }
-
- return ret;
-}
-
// Gets the mapping vector from dim_from to dim_to.
// Example: GetDimensionMapping([2], 3) = [0, 1, -1]
std::vector<int64_t> GetDimensionMapping(
@@ -545,6 +444,11 @@
HloSharding Tile(const Shape& tensor_shape,
absl::Span<const int64_t> tensor_dims,
+ const std::vector<std::vector<int64_t>>& mesh_dims,
+ const Array<int64_t>& device_mesh);
+
+HloSharding Tile(const Shape& tensor_shape,
+ absl::Span<const int64_t> tensor_dims,
absl::Span<const int64_t> mesh_dims,
const Array<int64_t>& device_mesh);
diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD
index 352f75c..e65c48d 100644
--- a/third_party/xla/xla/hlo/ir/BUILD
+++ b/third_party/xla/xla/hlo/ir/BUILD
@@ -25,7 +25,6 @@
"dfs_hlo_visitor.cc",
"dynamic_parameter_binding.cc",
"hlo_computation.cc",
- "hlo_frontend_attributes.cc",
"hlo_input_output_alias_config.cc",
"hlo_instruction.cc",
"hlo_instructions.cc",
@@ -47,7 +46,6 @@
"hlo_clone_context.h",
"hlo_computation.h",
"hlo_domain_metadata.h",
- "hlo_frontend_attributes.h",
"hlo_input_output_alias_config.h",
"hlo_instruction.h",
"hlo_instructions.h",
@@ -72,6 +70,7 @@
"//xla:protobuf_util",
"//xla:shape_tree",
"//xla:shape_util",
+ "//xla:sort_json",
"//xla:status_macros",
"//xla:types",
"//xla:util",
@@ -79,6 +78,7 @@
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/service:compilation_environments",
+ "//xla/service:computation_layout",
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_lexer",
"//xla/service:hlo_module_config",
@@ -100,13 +100,13 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/lib/gtl:iterator_range",
"@local_tsl//tsl/lib/gtl:map_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:fingerprint",
- "@local_tsl//tsl/platform:human_readable_json",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:status",
diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc
index 025b1ce..4fbf057 100644
--- a/third_party/xla/xla/hlo/ir/hlo_computation.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc
@@ -1103,6 +1103,19 @@
return call_instruction;
}
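+// Wraps `instructions_to_call` into a composite call instruction; see the
+// declaration in hlo_computation.h for the ordering requirements.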
+HloInstruction* HloComputation::CreateCompositeCallInstruction(
+ absl::Span<HloInstruction* const> instructions_to_call,
+ const std::string& name, const std::string& attributes, int64_t version) {
+ HloInstruction* root = instructions_to_call.front();
+ HloInstruction* call_instruction =
+ AddInstruction(HloInstruction::CreateCompositeCall(
+ root->shape(), root, name, attributes, version),
+ root->name());
+ AppendInstructionsIntoCalledComputation(instructions_to_call,
+ call_instruction);
+ return call_instruction;
+}
+
absl::StatusOr<HloInstruction*> HloComputation::CreateAsyncInstructions(
HloInstruction* instruction, absl::Span<const Shape> context_shapes,
absl::string_view async_execution_thread, bool replace,
diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h
index 956cf1a..3e73a68 100644
--- a/third_party/xla/xla/hlo/ir/hlo_computation.h
+++ b/third_party/xla/xla/hlo/ir/hlo_computation.h
@@ -17,18 +17,20 @@
#define XLA_HLO_IR_HLO_COMPUTATION_H_
#include <cstdint>
-#include <list>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
@@ -42,9 +44,14 @@
#include "xla/printer.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/name_uniquer.h"
+#include "xla/shape.h"
#include "xla/shape_tree.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
+#include "tsl/lib/gtl/iterator_range.h"
+#include "tsl/platform/errors.h"
namespace xla {
@@ -465,7 +472,7 @@
absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction::FusionKind fusion_kind);
- // Creates a call instruction containing the given instructions. Instructions
+ // Creates a call instruction containing the given instructions. Instructions
// must be in reverse topological order (root of the called computation
// first). Replaces all uses of the original root instruction with the call
// instruction. The original instructions are removed if they have no uses
@@ -473,6 +480,16 @@
HloInstruction* CreateCallInstruction(
absl::Span<HloInstruction* const> instructions_to_call);
+ // Creates a composite call instruction containing the given instructions.
+ // Instructions must be in reverse topological order (root of the called
+ // computation first). Replaces all uses of the original root instruction with
+ // the composite call instruction. The original instructions are removed if
+ // they have no uses after creating the composite call (this is necessarily
+ // true for at least the root).
+ HloInstruction* CreateCompositeCallInstruction(
+ absl::Span<HloInstruction* const> instructions_to_call,
+ const std::string& name, const std::string& attributes, int64_t version);
+
// Creates an async start/done instruction pair where instruction is wrapped
// inside an asynchronous computation. The context shapes are appended to the
// output tuple of the asynchronous start which is backend specific. Returns
diff --git a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc b/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc
deleted file mode 100644
index 347edce..0000000
--- a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-
-namespace xla {
-
-std::string FrontendAttributesToString(
- const FrontendAttributes& frontend_attributes) {
- std::vector<std::pair<std::string, std::string>> sorted_attributes(
- frontend_attributes.map().begin(), frontend_attributes.map().end());
- absl::c_sort(sorted_attributes);
- // Frontend attribute is a comma-separated list of attribute="value" pairs,
- // e.g., frontend_attributes={name="value_a",type="int32_t"}.
- const auto formatter = [](std::string* out,
- const std::pair<std::string, std::string>& item) {
- absl::StrAppend(out, item.first, "=\"", item.second, "\"");
- };
- return absl::StrFormat("{%s}",
- absl::StrJoin(sorted_attributes, ",", formatter));
-}
-
-} // namespace xla
diff --git a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h b/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h
deleted file mode 100644
index 7348691..0000000
--- a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h
+++ /dev/null
@@ -1,28 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_
-#define XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_
-
-#include <string>
-
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-std::string FrontendAttributesToString(
- const FrontendAttributes& frontend_attributes);
-} // namespace xla
-
-#endif // XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_
diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
index 7261134..7764750 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
@@ -55,7 +55,6 @@
#include "xla/hlo/ir/hlo_clone_context.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_domain_metadata.h"
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_op_metadata.h"
@@ -75,14 +74,16 @@
#include "xla/service/name_uniquer.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
+#include "xla/sort_json.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/lib/gtl/iterator_range.h"
#include "tsl/lib/gtl/map_util.h"
#include "tsl/platform/errors.h"
-#include "tsl/platform/human_readable_json.h"
#include "tsl/platform/logging.h" // IWYU pragma: keep
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
namespace xla {
@@ -1159,12 +1160,43 @@
<< instruction->opcode() << proto.name();
TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
- auto call_instruction = new HloCallInstruction(
- shape, all_operands(),
- computation_map.at(proto.called_computation_ids()[0]));
- call_instruction->set_output_to_operand_aliasing(
- output_to_operand_aliasing());
- instruction = absl::WrapUnique(call_instruction);
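+        // Composite calls carry their metadata in frontend attributes:
+        // composite.name (required, non-empty), composite.attributes
+        // (optional, defaults to "{}"), and composite.version (optional,
+        // non-negative integer, defaults to 0).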
+ if (proto.is_composite()) {
+ TF_RET_CHECK(proto.has_frontend_attributes())
+ << "A composite call op must have frontend attributes";
+ auto map = proto.frontend_attributes().map();
+ auto name = map.find("composite.name");
+ TF_RET_CHECK(name != map.end() && !name->second.empty())
+ << "A composite call op must have frontend attributes with key "
+ "composite.name whose value is non-empty";
+
+ auto attributes = map.find("composite.attributes");
+ TF_RET_CHECK(attributes == map.end() || !attributes->second.empty())
+ << "A composite call op must have frontend attributes with key "
+ "composite.attributes whose value is default: {} or non-empty";
+
+ auto version_str = map.find("composite.version");
+ int64_t version = 0;
+ TF_RET_CHECK(
+ version_str == map.end() ||
+ (absl::SimpleAtoi(version_str->second, &version) && version >= 0))
+ << "A composite call op must have frontend attributes with a "
+ "composite.version whose value is a non-negative integer but "
+ "got: "
+ << version_str->second;
+
+ instruction = CreateCompositeCall(
+ shape, all_operands(),
+ computation_map.at(proto.called_computation_ids()[0]), name->second,
+ attributes == map.end() ? "{}" : attributes->second, version);
+ instruction->set_output_to_operand_aliasing(
+ output_to_operand_aliasing());
+ } else {
+ instruction = std::make_unique<HloCallInstruction>(
+ shape, all_operands(),
+ computation_map.at(proto.called_computation_ids()[0]));
+ instruction->set_output_to_operand_aliasing(
+ output_to_operand_aliasing());
+ }
break;
}
default: {
@@ -1230,7 +1262,6 @@
const xla::OriginalValueProto& original_value_proto =
proto.original_value();
auto original_value = std::make_shared<OriginalValue>(shape);
- std::cerr << __func__ << ", shape: " << shape.ToString() << "\n";
for (const auto& leaf : original_value_proto.leaves()) {
*original_value->mutable_element(ShapeIndex(leaf.leaf_shape_index())) = {
@@ -2252,6 +2283,27 @@
return std::make_unique<HloCallInstruction>(shape, operands, computation);
}
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCompositeCall(const Shape& shape,
+ HloInstruction* decomposition_root,
+ const std::string& name,
+ const std::string& attributes,
+ int64_t version) {
+ return std::make_unique<HloCallInstruction>(shape, decomposition_root, name,
+ attributes, version);
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCompositeCall(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* decomposition,
+ const std::string& name,
+ const std::string& attributes,
+ int64_t version) {
+ return std::make_unique<HloCallInstruction>(shape, operands, decomposition,
+ name, attributes, version);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target, std::string opaque,
@@ -3635,6 +3687,13 @@
}
if (options.print_backend_config() && !backend_config_.empty()) {
absl::string_view config = backend_config_.GetRawString();
+ std::string sorted_config;
+ if (options.sort_backend_config()) {
+ // Use `value_or` below, because the backend config string isn't
+ // guaranteed to be a JSON string.
+ sorted_config = SortJson(config).value_or(std::string(config));
+ config = sorted_config;
+ }
printer->Append(", backend_config=");
// In the common case that the backend-config is valid-ish JSON, the parser
// doesn't need it delimited by quotes, so we can print it without
@@ -3782,6 +3841,10 @@
PrintNameInternal(printer, to_apply()->name(), options);
});
}
+ if (opcode() == HloOpcode::kCall && is_composite()) {
+ printer.Next(
+ [](Printer* printer) { printer->Append("is_composite=true"); });
+ }
} else if (opcode() == HloOpcode::kCustomCall) {
if (!called_computations().empty()) {
printer.Next([this, &options](Printer* printer) {
@@ -3878,6 +3941,10 @@
to_apply()->Print(printer, new_options);
});
}
+ if (opcode() == HloOpcode::kCall && is_composite()) {
+ printer.Next(
+ [](Printer* printer) { printer->Append("is_composite=true"); });
+ }
break;
default:
if (!called_computations().empty()) {
@@ -3900,13 +3967,18 @@
sharding().Print(printer, options.print_metadata());
});
}
- if (!rare()->frontend_attributes.map().empty()) {
+ if (!frontend_attributes().map().empty()) {
printer.Next([this](Printer* printer) {
AppendCat(printer, "frontend_attributes=",
- FrontendAttributesToString(rare()->frontend_attributes));
+ FrontendAttributesToString(frontend_attributes()));
});
}
+ if (opcode() != HloOpcode::kCall) {
+ CHECK(!is_composite())
+ << "Only kCall instructions should have is_composite set";
+ }
+
if (options.print_control_dependencies() && !control_predecessors().empty()) {
printer.Next([this, &options](Printer* printer) {
printer->Append("control-predecessors={");
@@ -3952,6 +4024,23 @@
return std::move(multi_string_printer).ConsumeStrings();
}
+std::string FrontendAttributesToString(
+ const FrontendAttributes& frontend_attributes) {
+ std::vector<std::pair<std::string, std::string>> sorted_attributes(
+ frontend_attributes.map().begin(), frontend_attributes.map().end());
+ absl::c_sort(sorted_attributes);
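+  // Frontend attributes print as a comma-separated list of key=value pairs,
+  // e.g. frontend_attributes={name="value_a",type="int32_t"}. Values that lex
+  // as a JSON dictionary are emitted unquoted; all other values are quoted.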
+ const auto formatter = [](std::string* out,
+ const std::pair<std::string, std::string>& item) {
+ if (LexesAsJsonDict(item.second)) {
+ absl::StrAppend(out, item.first, "=", item.second);
+ } else {
+ absl::StrAppend(out, item.first, "=\"", item.second, "\"");
+ }
+ };
+ return absl::StrFormat("{%s}",
+ absl::StrJoin(sorted_attributes, ",", formatter));
+}
+
std::string HloInstruction::ToShortString() const {
return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
StrJoin(operands_, ", ",
@@ -3990,6 +4079,7 @@
}
*proto.mutable_frontend_attributes() = frontend_attributes();
+ proto.set_is_composite(is_composite());
*proto.mutable_statistics_viz() = statistics_viz();
diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h
index 6d8821c..a98f996 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instruction.h
+++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h
@@ -25,6 +25,7 @@
#include <cstdint>
#include <functional>
#include <iosfwd>
+#include <iterator>
#include <map>
#include <memory>
#include <optional>
@@ -98,6 +99,7 @@
print_metadata_(true),
print_metadata_only_op_name_(false),
print_backend_config_(true),
+ sort_backend_config_(false),
print_infeed_outfeed_config_(true),
compact_operands_(false),
include_layout_in_shapes_(true),
@@ -218,6 +220,14 @@
return *this;
}
+ // If true, attempts to sort the backend config's JSON representation before
+ // printing it. If the backend config is a raw string that is not JSON, it is
+ // printed as-is, without sorting.
+ HloPrintOptions& set_sort_backend_config(bool value) {
+ sort_backend_config_ = value;
+ return *this;
+ }
+
// If true, infeed_config and outfeed_config will be printed.
HloPrintOptions& set_print_infeed_outfeed_config(bool value) {
print_infeed_outfeed_config_ = value;
@@ -382,6 +392,7 @@
return print_metadata_only_op_name_;
}
bool print_backend_config() const { return print_backend_config_; }
+ bool sort_backend_config() const { return sort_backend_config_; }
bool print_infeed_outfeed_config() const {
return print_infeed_outfeed_config_;
}
@@ -422,6 +433,7 @@
bool print_metadata_;
bool print_metadata_only_op_name_;
bool print_backend_config_;
+ bool sort_backend_config_;
bool print_infeed_outfeed_config_;
bool compact_operands_;
bool include_layout_in_shapes_;
@@ -1351,6 +1363,17 @@
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation);
+  // Creates a composite call instruction that applies the given computation to
+  // the given operands. "shape" is the resultant shape.
+ static std::unique_ptr<HloInstruction> CreateCompositeCall(
+ const Shape& shape, HloInstruction* decomposition_root,
+ const std::string& name, const std::string& attributes, int64_t version);
+
+ static std::unique_ptr<HloInstruction> CreateCompositeCall(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ HloComputation* decomposition, const std::string& name,
+ const std::string& attributes, int64_t version);
+
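Editor's note: based only on the signatures declared here, a builder might create a composite call as sketched below; `entry`, `lhs`, `rhs`, `decomposition`, and `result_shape` are hypothetical, and the attributes string follows the composite.* frontend-attribute convention this change introduces.

    // Hypothetical builder code, not taken from this change.
    xla::HloInstruction* composite = entry->AddInstruction(
        xla::HloInstruction::CreateCompositeCall(
            result_shape, /*operands=*/{lhs, rhs},
            /*decomposition=*/decomposition,
            /*name=*/"my.custom_op",
            /*attributes=*/"{precision = \"high\"}",
            /*version=*/1));
    // The resulting kCall carries is_composite=true plus the frontend
    // attributes composite.name, composite.attributes, and composite.version.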
// Creates a custom call instruction that applies the given custom call target
// to the given operands. "opaque" can be an arbitrary string with a
// backend-specific interpretation. "shape" is the resultant shape.
@@ -2093,6 +2116,9 @@
mutable_rare()->frontend_attributes = std::move(frontend_attributes);
}
+  // Appends the given frontend attributes to the existing ones. If there are
+  // no existing frontend attributes, the provided ones are stored as-is.
void add_frontend_attributes(FrontendAttributes frontend_attributes) {
if (!frontend_attributes.map().empty()) {
mutable_rare()->frontend_attributes.mutable_map()->insert(
@@ -2100,10 +2126,25 @@
}
}
+ bool has_frontend_attributes() const {
+ return has_rare() && !rare()->frontend_attributes.map().empty();
+ }
+
const FrontendAttributes& frontend_attributes() const {
return rare()->frontend_attributes;
}
+ void set_is_composite(bool is_composite) {
+ if (!has_rare() && !is_composite) {
+ return;
+ }
+ mutable_rare()->is_composite = is_composite;
+ }
+
+  // Returns the is_composite attribute. This attribute is only relevant for
+  // kCall instructions used as a composite op.
+ bool is_composite() const { return has_rare() && rare()->is_composite; }
+
void add_single_statistic(Statistic statistic) {
*mutable_rare()->statistics_viz.add_statistics() = std::move(statistic);
}
@@ -2199,8 +2240,8 @@
void set_metadata_preserve_layout(bool preserve_layout) {
metadata_->set_preserve_layout(preserve_layout);
}
- void set_metadata_scheduling_name(const std::string& name) {
- metadata_->set_scheduling_name(name);
+ void set_metadata_scheduling_name(absl::string_view name) {
+ metadata_->set_scheduling_name(std::string(name));
}
const OpMetadata& metadata() const { return *metadata_; }
@@ -2688,6 +2729,9 @@
// z' = const(20), frontend_attributes={?}
FrontendAttributes frontend_attributes;
+ // Used by kCall to determine if the Call instruction is a composite.
+ bool is_composite;
+
    // Used to render an HLO graph when tracking the propagation of desired
    // values through it.
StatisticsViz statistics_viz;
@@ -2827,6 +2871,14 @@
// Custom (de)stringification functions for protos that live inside
// HloInstruction.
std::string PaddingConfigToString(const PaddingConfig& padding);
+
+// Returns a string representation of the frontend attributes.
+// Frontend attributes are a list of attribute=<value> pairs, where each value
+// is either a "string" or a JSON-like dict surrounded by {}. Similar to the
+// custom_call backend config, this can be used to store stringified
+// MLIR dictionaries with pretty printing.
+std::string FrontendAttributesToString(
+ const FrontendAttributes& frontend_attributes);
std::string StatisticsVizToString(const StatisticsViz& statistics_viz);
std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
std::string RandomDistributionToString(const RandomDistribution& distribution);
diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc
index d9801cf..cff0907 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc
@@ -1877,6 +1877,35 @@
}
}
+HloCallableInstruction::HloCallableInstruction(HloOpcode opcode,
+ const Shape& shape,
+ const std::string& name,
+ const std::string& attributes,
+ int64_t version)
+ : HloInstruction(opcode, shape) {
+ auto frontend_attributes =
+ BuildFrontendAttributesForComposite(name, attributes, version);
+ add_frontend_attributes(frontend_attributes);
+ set_is_composite(true);
+}
+
+HloCallableInstruction::HloCallableInstruction(
+ HloOpcode opcode, const Shape& shape,
+ absl::Span<HloInstruction* const> operands, HloComputation* decomposition,
+ const std::string& name, const std::string& attributes, int64_t version)
+ : HloInstruction(opcode, shape) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+ SetAndSanitizeName(HloOpcodeString(opcode));
+ AppendComputation(decomposition);
+
+ auto frontend_attributes =
+ BuildFrontendAttributesForComposite(name, attributes, version);
+ add_frontend_attributes(frontend_attributes);
+ set_is_composite(true);
+}
+
HloCallableInstruction::~HloCallableInstruction() { ClearCalledComputations(); }
HloComputation* HloCallableInstruction::called_computation() const {
@@ -1924,7 +1953,7 @@
return u->opcode() == HloOpcode::kGetTupleElement;
});
if (called_computations().empty()) {
- // New fusion instruction. It should not be a multioutput instruction.
+ // New fusion instruction. It should not be a multi-output instruction.
CHECK(!add_output);
auto builder = HloComputation::Builder(default_called_computation_name());
builder.AddInstruction(instruction_to_append->Clone(/*suffix=*/""));
@@ -2552,6 +2581,47 @@
: HloCallableInstruction(HloOpcode::kCall, shape, operands,
called_computation) {}
+HloCallInstruction::HloCallInstruction(const Shape& shape,
+ HloInstruction* decomposition_root,
+ const std::string& name,
+ const std::string& attributes,
+ int64_t version)
+ : HloCallableInstruction(HloOpcode::kCall, shape, name, attributes,
+ version) {
+ CHECK(decomposition_root != nullptr);
+ SetAndSanitizeName(HloOpcodeString(opcode()));
+
+ FrontendAttributes frontend_attributes;
+ frontend_attributes.mutable_map()->insert({"composite.name", name});
+ frontend_attributes.mutable_map()->insert(
+ {"composite.attributes", attributes});
+ frontend_attributes.mutable_map()->insert(
+ {"composite.version", std::to_string(version)});
+
+ add_frontend_attributes(frontend_attributes);
+ set_is_composite(true);
+ set_parent(decomposition_root->parent());
+ set_metadata(decomposition_root->metadata());
+ CloneAndAppendInstructionIntoCalledComputation(decomposition_root);
+}
+
+HloCallInstruction::HloCallInstruction(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ HloComputation* decomposition, const std::string& name,
+ const std::string& attributes, int64_t version)
+ : HloCallableInstruction(HloOpcode::kCall, shape, operands, decomposition,
+ name, attributes, version) {
+ FrontendAttributes frontend_attributes;
+ frontend_attributes.mutable_map()->insert({"composite.name", name});
+ frontend_attributes.mutable_map()->insert(
+ {"composite.attributes", attributes});
+ frontend_attributes.mutable_map()->insert(
+ {"composite.version", std::to_string(version)});
+
+ add_frontend_attributes(frontend_attributes);
+ set_is_composite(true);
+}
+
HloRngInstruction::HloRngInstruction(
const Shape& shape, RandomDistribution distribution,
absl::Span<HloInstruction* const> parameters)
@@ -3500,13 +3570,23 @@
AppendJoin(printer, dim_numbers.collapsed_slice_dims(), ",");
printer->Append("}, start_index_map={");
AppendJoin(printer, dim_numbers.start_index_map(), ",");
+ if (dim_numbers.operand_batching_dims_size()) {
+ printer->Append("}, operand_batching_dims={");
+ AppendJoin(printer, dim_numbers.operand_batching_dims(), ",");
+ }
+ if (dim_numbers.start_indices_batching_dims_size()) {
+ printer->Append("}, start_indices_batching_dims={");
+ AppendJoin(printer, dim_numbers.start_indices_batching_dims(), ",");
+ }
AppendCat(printer, "}, index_vector_dim=", dim_numbers.index_vector_dim());
}
/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
absl::Span<const int64_t> offset_dims,
absl::Span<const int64_t> collapsed_slice_dims,
- absl::Span<const int64_t> start_index_map, int64_t index_vector_dim) {
+ absl::Span<const int64_t> start_index_map, int64_t index_vector_dim,
+ absl::Span<const int64_t> operand_batching_dims,
+ absl::Span<const int64_t> start_indices_batching_dims) {
GatherDimensionNumbers gather_dim_numbers;
for (int64_t output_window_dim : offset_dims) {
gather_dim_numbers.add_offset_dims(output_window_dim);
@@ -3517,6 +3597,13 @@
for (int64_t gather_dim_to_input_dim : start_index_map) {
gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
}
+ for (int64_t operand_batching_dim : operand_batching_dims) {
+ gather_dim_numbers.add_operand_batching_dims(operand_batching_dim);
+ }
+ for (int64_t start_indices_batching_dim : start_indices_batching_dims) {
+ gather_dim_numbers.add_start_indices_batching_dims(
+ start_indices_batching_dim);
+ }
gather_dim_numbers.set_index_vector_dim(index_vector_dim);
return gather_dim_numbers;
@@ -3601,6 +3688,14 @@
AppendJoin(printer, dim_numbers.inserted_window_dims(), ",");
printer->Append("}, scatter_dims_to_operand_dims={");
AppendJoin(printer, dim_numbers.scatter_dims_to_operand_dims(), ",");
+ if (dim_numbers.input_batching_dims_size()) {
+ printer->Append("}, input_batching_dims={");
+ AppendJoin(printer, dim_numbers.input_batching_dims(), ",");
+ }
+ if (dim_numbers.scatter_indices_batching_dims_size()) {
+ printer->Append("}, scatter_indices_batching_dims={");
+ AppendJoin(printer, dim_numbers.scatter_indices_batching_dims(), ",");
+ }
AppendCat(printer, "}, index_vector_dim=", dim_numbers.index_vector_dim());
}
@@ -3609,7 +3704,8 @@
absl::Span<const int64_t> update_window_dims,
absl::Span<const int64_t> inserted_window_dims,
absl::Span<const int64_t> scatter_dims_to_operand_dims,
- int64_t index_vector_dim) {
+ int64_t index_vector_dim, absl::Span<const int64_t> input_batching_dims,
+ absl::Span<const int64_t> scatter_indices_batching_dims) {
ScatterDimensionNumbers scatter_dim_numbers;
for (int64_t update_window_dim : update_window_dims) {
scatter_dim_numbers.add_update_window_dims(update_window_dim);
@@ -3621,6 +3717,13 @@
scatter_dim_numbers.add_scatter_dims_to_operand_dims(
scatter_dim_to_operand_dim);
}
+ for (int64_t input_batching_dim : input_batching_dims) {
+ scatter_dim_numbers.add_input_batching_dims(input_batching_dim);
+ }
+ for (int64_t scatter_indices_batching_dim : scatter_indices_batching_dims) {
+ scatter_dim_numbers.add_scatter_indices_batching_dims(
+ scatter_indices_batching_dim);
+ }
scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
return scatter_dim_numbers;
}
diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h
index b0e337a..c0f0324 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instructions.h
+++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h
@@ -19,13 +19,13 @@
#define XLA_HLO_IR_HLO_INSTRUCTIONS_H_
#include <cstdint>
-#include <list>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
+#include "absl/base/attributes.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
@@ -38,7 +38,6 @@
#include "xla/hlo/ir/hlo_domain_metadata.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/iterator_util.h"
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/printer.h"
@@ -1343,6 +1342,15 @@
absl::Span<HloInstruction* const> operands,
absl::Span<HloComputation* const> called_computations);
+ HloCallableInstruction(HloOpcode opcode, const Shape& shape,
+ const std::string& name, const std::string& attributes,
+ int64_t version);
+
+ HloCallableInstruction(HloOpcode opcode, const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* decomposition, const std::string& name,
+ const std::string& attributes, int64_t version);
+
~HloCallableInstruction() override;
// Adds a new operand to the callable instruction.
@@ -1402,6 +1410,21 @@
output_to_operand_aliasing_ = std::move(aliasing);
}
+ FrontendAttributes BuildFrontendAttributesForComposite(
+ const std::string& name,
+ std::optional<absl::string_view> attributes = std::nullopt,
+ std::optional<int64_t> version = std::nullopt) {
+ FrontendAttributes frontend_attributes;
+ frontend_attributes.mutable_map()->insert({"composite.name", name});
+ frontend_attributes.mutable_map()->insert(
+ {"composite.attributes",
+ attributes.has_value() ? std::string(*attributes) : "{}"});
+ frontend_attributes.mutable_map()->insert(
+ {"composite.version",
+ version.has_value() ? std::to_string(*version) : "0"});
+ return frontend_attributes;
+ }
+
protected:
// Returns the default called computation name.
virtual std::string default_called_computation_name() const = 0;
@@ -1450,7 +1473,7 @@
void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
// Merges the fused instructions from instruction_to_merge into the fused
- // instruction set of 'this' and generates multioutput fusion instructions.
+ // instruction set of 'this' and generates multi-output fusion instructions.
// All the users of instruction_to_merge will be redirected to 'this'
// instruction. instruction_to_merge will be removed from its parent
// computation.
@@ -1555,6 +1578,15 @@
absl::Span<HloInstruction* const> operands,
HloComputation* called_computation);
+ HloCallInstruction(const Shape& shape, HloInstruction* decomposition_root,
+ const std::string& name, const std::string& attributes,
+ int64_t version);
+
+ HloCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* decomposition, const std::string& name,
+ const std::string& attributes, int64_t version);
+
static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kCall;
}
@@ -2313,7 +2345,9 @@
static GatherDimensionNumbers MakeGatherDimNumbers(
absl::Span<const int64_t> offset_dims,
absl::Span<const int64_t> collapsed_slice_dims,
- absl::Span<const int64_t> start_index_map, int64_t index_vector_dim);
+ absl::Span<const int64_t> start_index_map, int64_t index_vector_dim,
+ absl::Span<const int64_t> operand_batching_dims = {},
+ absl::Span<const int64_t> start_indices_batching_dims = {});
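Editor's note: for illustration only, a call that exercises the two new optional spans might look like the following; the dimension values are made up, and existing call sites that omit the trailing arguments keep their old behavior.

    xla::GatherDimensionNumbers dnums =
        xla::HloGatherInstruction::MakeGatherDimNumbers(
            /*offset_dims=*/{2},
            /*collapsed_slice_dims=*/{1},
            /*start_index_map=*/{1},
            /*index_vector_dim=*/2,
            /*operand_batching_dims=*/{0},
            /*start_indices_batching_dims=*/{0});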
// Returns the dump string of the given gather dimension numbers.
static std::string GatherDimensionNumbersToString(
const GatherDimensionNumbers& dim_numbers);
@@ -2378,7 +2412,9 @@
absl::Span<const int64_t> update_window_dims,
absl::Span<const int64_t> inserted_window_dims,
absl::Span<const int64_t> scatter_dims_to_operand_dims,
- int64_t index_vector_dim);
+ int64_t index_vector_dim,
+ absl::Span<const int64_t> input_batching_dims = {},
+ absl::Span<const int64_t> scatter_indices_batching_dims = {});
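Editor's note: the scatter counterpart mirrors this, again with made-up values; dropping the last two arguments reproduces the pre-change signature.

    xla::ScatterDimensionNumbers dnums =
        xla::HloScatterInstruction::MakeScatterDimNumbers(
            /*update_window_dims=*/{2},
            /*inserted_window_dims=*/{1},
            /*scatter_dims_to_operand_dims=*/{1},
            /*index_vector_dim=*/2,
            /*input_batching_dims=*/{0},
            /*scatter_indices_batching_dims=*/{0});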
// Returns the dump string of the given scatter dimension numbers.
static std::string ScatterDimensionNumbersToString(
const ScatterDimensionNumbers& dim_numbers);
diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc
index 0711d49..cc8dda9 100644
--- a/third_party/xla/xla/hlo/ir/hlo_module.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_module.cc
@@ -17,8 +17,8 @@
#include <algorithm>
#include <atomic>
+#include <cstddef>
#include <cstdint>
-#include <functional>
#include <iterator>
#include <memory>
#include <optional>
@@ -30,24 +30,36 @@
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/map_util.h"
#include "xla/printer.h"
#include "xla/service/compilation_environments.h"
+#include "xla/service/computation_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/mapped_ptr_container_sorter.h"
+#include "xla/service/name_uniquer.h"
+#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
+#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/lib/gtl/map_util.h"
+#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/fingerprint.h"
#include "tsl/platform/logging.h"
@@ -405,8 +417,8 @@
? MakeComputationSorted()
: MakeComputationPostOrder();
for (const HloComputation* computation : computations) {
- // Don't print async computations when the sytax sugar is enabled since that
- // is redundant information.
+ // Don't print async computations when the syntax sugar is enabled since
+ // that is redundant information.
if (options.syntax_sugar_async_ops() && computation->IsAsyncComputation() &&
computation->CanExpandIntoSingleInstruction()) {
continue;
@@ -848,7 +860,7 @@
outlined_instruction);
// Mark instruction_to_outline an output if it is used outside the
- // subcomputation or is the output of the original computation (i.e. used
+ // sub-computation or is the output of the original computation (i.e. used
// externally).
if (instruction_to_outline->user_count() == 0 ||
IsUsedOutsideSubcomputation(*instruction_to_outline,
@@ -917,7 +929,7 @@
if (computations_.empty()) {
return {};
}
- // First determine all root computations by building a set of nonroot
+ // First determine all root computations by building a set of non-root
// computations (computations which are called by an instruction in the
// module).
absl::flat_hash_set<HloComputation*> nonroot_computations;
diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD
index 0bae42a..8f20f63 100644
--- a/third_party/xla/xla/hlo/utils/BUILD
+++ b/third_party/xla/xla/hlo/utils/BUILD
@@ -158,6 +158,7 @@
"//xla/service:pattern_matcher",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)
diff --git a/third_party/xla/xla/hlo/utils/hlo_query.cc b/third_party/xla/xla/hlo/utils/hlo_query.cc
index 69a6fef..147f548 100644
--- a/third_party/xla/xla/hlo/utils/hlo_query.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_query.cc
@@ -17,6 +17,7 @@
#include <algorithm>
#include <cstdint>
+#include <utility>
#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
@@ -269,22 +270,46 @@
return gte;
}
-bool IsBeforeInComputation(const HloComputation* computation,
- absl::string_view inst1, absl::string_view inst2) {
- int index1 = -1;
- int index2 = -1;
+HloComputation* FindComputation(HloModule* module, absl::string_view name) {
+ auto computations = module->computations();
+ auto it = absl::c_find_if(
+ computations, [&](HloComputation* c) { return c->name() == name; });
+ if (it == computations.end()) {
+ return nullptr;
+ }
+ return *it;
+}
+
+std::pair<HloInstruction*, int> FindFirstInstruction(
+ const HloComputation* computation, absl::string_view name) {
int current_index = 0;
- for (auto instruction : computation->instructions()) {
- if (instruction->name() == inst1) {
- index1 = current_index;
- }
- if (instruction->name() == inst2) {
- index2 = current_index;
+ for (auto* instruction : computation->instructions()) {
+ if (instruction->name() == name) {
+      return {instruction, current_index};
}
current_index++;
}
- current_index++;
- return index1 < index2;
+ return {nullptr, -1};
+}
+
+std::pair<HloInstruction*, int> FindFirstInstruction(
+ const HloComputation* computation, HloOpcode opcode) {
+ int current_index = 0;
+ for (auto* instruction : computation->instructions()) {
+ if (instruction->opcode() == opcode) {
+      return {instruction, current_index};
+ }
+ current_index++;
+ }
+ return {nullptr, -1};
+}
+
+bool IsBeforeInComputation(const HloComputation* computation,
+ absl::string_view inst1, absl::string_view inst2) {
+ return FindFirstInstruction(computation, inst1).second <
+ FindFirstInstruction(computation, inst2).second;
}
} // namespace hlo_query
} // namespace xla
diff --git a/third_party/xla/xla/hlo/utils/hlo_query.h b/third_party/xla/xla/hlo/utils/hlo_query.h
index 950082a..ec5c0b2 100644
--- a/third_party/xla/xla/hlo/utils/hlo_query.h
+++ b/third_party/xla/xla/hlo/utils/hlo_query.h
@@ -17,6 +17,7 @@
#define XLA_HLO_UTILS_HLO_QUERY_H_
#include <cstdint>
+#include <utility>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
@@ -153,7 +154,19 @@
HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand,
int64_t index);
-// TODO: b/356153995 - refactor hlo_test_base
+// Gets the computation from the given module with the given name.
+HloComputation* FindComputation(HloModule* module, absl::string_view name);
+// Gets the first instruction and its index from the given computation with the
+// given instruction name. The function returns {nullptr, -1} if the instruction
+// cannot be found.
+std::pair<HloInstruction*, int> FindFirstInstruction(
+ const HloComputation* computation, absl::string_view name);
+// Gets the first instruction and its index from the given computation with the
+// given instruction opcode. The function returns {nullptr, -1} if the
+// instruction cannot be found.
+std::pair<HloInstruction*, int> FindFirstInstruction(
+ const HloComputation* computation, HloOpcode opcode);
+
// Check that one instruction comes before another one for a given computation.
// The function returns true if the first instruction comes before the second
// one, and false otherwise. This is useful for partial checks on the
diff --git a/third_party/xla/xla/hlo/utils/hlo_query_test.cc b/third_party/xla/xla/hlo/utils/hlo_query_test.cc
index acefa21..e4dad10 100644
--- a/third_party/xla/xla/hlo/utils/hlo_query_test.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_query_test.cc
@@ -40,6 +40,14 @@
return counter;
}
+constexpr absl::string_view kConstantAdditionHloString = R"(
+HloModule test
+ENTRY main {
+ zero = f32[] constant(0)
+ five = f32[] constant(5)
+ ROOT out = f32[] add(zero, five)
+})";
+
TEST_F(HloQueryTest,
GetInstructionWithOpCodeReturnsMatchingInstructionForModule) {
constexpr absl::string_view kHloString = R"(
@@ -132,5 +140,66 @@
EXPECT_EQ(gte2, nullptr);
}
+TEST_F(HloQueryTest, FindComputationTest) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+ EXPECT_NE(hlo_query::FindComputation(module.get(), "main"), nullptr);
+ EXPECT_EQ(hlo_query::FindComputation(module.get(), "foo"), nullptr);
+}
+
+TEST_F(HloQueryTest, FindInstructionUsingNameTest) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+ const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+ EXPECT_NE(hlo_query::FindFirstInstruction(main, "zero").first, nullptr);
+ EXPECT_NE(hlo_query::FindFirstInstruction(main, "five").first, nullptr);
+ EXPECT_NE(hlo_query::FindFirstInstruction(main, "out").first, nullptr);
+ EXPECT_EQ(hlo_query::FindFirstInstruction(main, "foo").first, nullptr);
+}
+
+TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+ const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+ EXPECT_NE(
+ hlo_query::FindFirstInstruction(main, StringToHloOpcode("add").value())
+ .first,
+ nullptr);
+ EXPECT_NE(hlo_query::FindFirstInstruction(
+ main, StringToHloOpcode("constant").value())
+ .first,
+ nullptr);
+ EXPECT_EQ(
+ hlo_query::FindFirstInstruction(main, StringToHloOpcode("select").value())
+ .first,
+ nullptr);
+}
+
+TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+ const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+ EXPECT_NE(main, nullptr);
+ auto find_beef = hlo_query::FindFirstInstruction(main, "deadbeef");
+ auto find_nothing = hlo_query::FindFirstInstruction(main, "");
+ EXPECT_EQ(find_beef.first, nullptr);
+ EXPECT_EQ(find_beef.second, -1);
+ EXPECT_EQ(find_nothing.first, nullptr);
+ EXPECT_EQ(find_nothing.second, -1);
+}
+
+TEST_F(HloQueryTest, IsBeforeInComputationTest) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+ const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+ EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "zero", "five"));
+ EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "five", "out"));
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
index a4e8c6a..28c9f0f 100644
--- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
@@ -2247,7 +2247,6 @@
// %indices = concatenate(..., %iota.1, ...)
// ... = gather(..., %indices)
// is common for tf.reverse_sequence and would match this case.
- absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
const int num_indices = index_map.size();
std::vector<int64_t> index_parallel_in_dim(num_indices, -1);
@@ -2733,21 +2732,21 @@
// 2. Try borrow dimensions from replicable_dims in order, and group sharding.
if (sharding.IsTiled()) {
- int64_t max_replicable_dimensions =
+ const int64_t reps_on_last_tile_dim =
sharding.ReplicateOnLastTileDim()
? sharding.tile_assignment().dimensions().back()
: 1;
- max_replicable_dimensions = absl::c_accumulate(
- replicable_dims, max_replicable_dimensions,
+
+ const int64_t max_replicable_dimensions = absl::c_accumulate(
+ replicable_dims, reps_on_last_tile_dim,
[&](int64_t product, int64_t dim) {
return product * sharding.tile_assignment().dim(dim);
});
- if (max_replicable_dimensions % num_groups == 0) {
+
+ if (max_replicable_dimensions % num_groups == 0 &&
+ num_groups % reps_on_last_tile_dim == 0) {
auto tile_assignment = [&]() -> std::optional<TileAssignment> {
- int dimensions_to_borrow =
- num_groups / (sharding.ReplicateOnLastTileDim()
- ? sharding.tile_assignment().dimensions().back()
- : 1);
+ int dimensions_to_borrow = num_groups / reps_on_last_tile_dim;
DimensionVector tile_dims(
sharding.tile_assignment().dimensions().begin(),
sharding.tile_assignment().dimensions().end());
diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h
index f6e5f58..dc28277 100644
--- a/third_party/xla/xla/literal_util.h
+++ b/third_party/xla/xla/literal_util.h
@@ -533,7 +533,7 @@
template <typename NativeT>
/* static */ Literal LiteralUtil::MakeScalarMatrixR2(int64_t size,
NativeT scalar) {
- Array2D<NativeT> array(size, size, 0);
+ Array2D<NativeT> array(size, size, NativeT(0));
for (int64_t i = 0; i < size; ++i) {
array(i, i) = scalar;
}
@@ -542,7 +542,7 @@
template <typename NativeT>
/* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) {
- return MakeScalarMatrixR2<NativeT>(size, 1);
+ return MakeScalarMatrixR2<NativeT>(size, NativeT(1));
}
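Editor's note: the switch from bare 0/1 literals to NativeT(0)/NativeT(1) presumably matters for element types whose conversions from built-in integers are explicit (XLA's low-precision float types are the likely motivation). A self-contained illustration with a toy type rather than the real ones:

    #include <cstdio>

    // Toy element type with an explicit constructor, standing in for element
    // types where `NativeT value = 0;` would not compile.
    struct Narrow {
      explicit Narrow(float v) : value(v) {}
      float value;
    };

    template <typename NativeT>
    NativeT MakeZero() {
      // return 0;        // ill-formed for Narrow: the constructor is explicit
      return NativeT(0);  // direct-initialization works for float and Narrow
    }

    int main() {
      std::printf("%f %f\n", MakeZero<float>(), MakeZero<Narrow>().value);
    }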
template <typename NativeT>
@@ -550,7 +550,7 @@
NativeT scale) {
NativeT row_factor = log10(m) + 1;
NativeT col_factor = log10(n) + 1;
- Array2D<NativeT> array(m, n, 0);
+ Array2D<NativeT> array(m, n, NativeT(0));
for (int64_t i = 0; i < m; ++i) {
for (int64_t j = 0; j < n; ++j) {
array(i, i) = scale * (row_factor * i + col_factor * j);
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc
index 0ab6c24..dfd3702 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc
@@ -87,56 +87,6 @@
op->erase();
}
-// Ensure that there aren't any implicit capture before exporting.
-static void prepareWhileOp(WhileOp whileOp) {
- llvm::SetVector<Value> implicitInputs;
- getUsedValuesDefinedAbove(whileOp->getRegions(), implicitInputs);
- if (implicitInputs.empty()) return;
- // Each captured value has to be passed as operand to the while, become then
- // an operand to the condition region and the body region, and an extra
- // operand to the return op in the body. It also becomes an extra result for
- // the while operation, even if it is unused.
- // We'll process the captured values one at a time and patch the body and
- // condition regions as we go, but we'll accumulate the new operands and
- // result type and recreate a new while op to replace the existing one at the
- // end.
- SmallVector<Type> returnedTypes(whileOp->getResultTypes().begin(),
- whileOp->getResultTypes().end());
- SmallVector<Value> operands(whileOp->getOperands().begin(),
- whileOp->getOperands().end());
- Region &condRegion = whileOp.getCond();
- Region &bodyRegion = whileOp.getBody();
-
- for (Value input : implicitInputs) {
- returnedTypes.push_back(input.getType());
- operands.push_back(input);
-
- Value condArg =
- condRegion.front().addArgument(input.getType(), input.getLoc());
- Value bodyArg =
- bodyRegion.front().addArgument(input.getType(), input.getLoc());
- for (OpOperand &operand : llvm::make_early_inc_range(input.getUses())) {
- if (condRegion.isAncestor(operand.getOwner()->getParentRegion()))
- operand.set(condArg);
- else if (bodyRegion.isAncestor(operand.getOwner()->getParentRegion()))
- operand.set(bodyArg);
- }
- auto returnOp = cast<mhlo::ReturnOp>(bodyRegion.front().back());
- returnOp->insertOperands(returnOp->getNumOperands(), bodyArg);
- }
- OpBuilder builder(whileOp);
- auto newWhileOp =
- builder.create<mhlo::WhileOp>(whileOp.getLoc(), returnedTypes, operands);
- newWhileOp.getCond().getBlocks().clear();
- newWhileOp.getCond().takeBody(whileOp.getCond());
- newWhileOp.getBody().getBlocks().clear();
- newWhileOp.getBody().takeBody(whileOp.getBody());
- for (auto zippedResults :
- llvm::zip_first(whileOp.getResults(), newWhileOp.getResults()))
- std::get<0>(zippedResults).replaceAllUsesWith(std::get<1>(zippedResults));
- whileOp->erase();
-}
-
static void prepareBroadcastInDim(BroadcastInDimOp bcast) {
DenseIntElementsAttr dims = bcast.getBroadcastDimensions();
// If dimensions aren't sorted, there is a transpose fused into the op, which
@@ -200,7 +150,6 @@
mlir::SplatElementsAttr attr;
if (matchPattern(op, m_Constant(&attr))) return prepareConstantOp(op, attr);
- if (auto whileOp = dyn_cast<WhileOp>(op)) return prepareWhileOp(whileOp);
if (auto bcastOp = dyn_cast<BroadcastInDimOp>(op))
return prepareBroadcastInDim(bcastOp);
// IfOp, CaseOp, WhileOp are already being handled during
diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp
index fb5728f..68ec969 100644
--- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp
+++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp
@@ -104,7 +104,7 @@
auto res = verifyCustomCallOpAttributes(
op, rewriter, [&](NamedAttribute attr) -> LogicalResult {
if (attr.getName() != "largest") return success();
- if (cast<BoolAttr>(attr.getValue()).getValue() == false)
+ if (!cast<BoolAttr>(attr.getValue()).getValue())
return rewriter.notifyMatchFailure(
op, "largest = false is not supported.");
return success();
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir
index 6944e79..a214de5 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir
@@ -22,94 +22,6 @@
// -----
-// CHECK-LABEL: @while_without_implicit_capture
-func.func @while_without_implicit_capture(%arg0: tensor<i64>) -> tensor<i64> {
- // CHECK: mhlo.while
- // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0)
- // CHECK-SAME: {mhlo.sharding = "{{\{}}{replicated},{replicated}}"}
- %0:2 = "mhlo.while"(%arg0, %arg0) ({
- ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
- %1 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
- "mhlo.return"(%1) : (tensor<i1>) -> ()
- }, {
- ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
- %2 = mhlo.add %arg1, %arg1 : tensor<i64>
- "mhlo.return"(%2, %arg2) : (tensor<i64>, tensor<i64>) -> ()
- }) {mhlo.sharding = "{{replicated},{replicated}}"} : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
- func.return %0#0 : tensor<i64>
-}
-
-// -----
-
-// CHECK-LABEL: @while_with_implicit_arg_capture
-func.func @while_with_implicit_arg_capture(%arg0: tensor<i64>) -> tensor<i64> {
- // CHECK: mhlo.while
- // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0)
- %0 = "mhlo.while"(%arg0) ({
- ^bb0(%arg1: tensor<i64>):
- // CHECK: mhlo.compare
- // CHECK-SAME: %[[ARG2]], %[[ARG1]]
- %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
- "mhlo.return"(%1) : (tensor<i1>) -> ()
- }, {
- ^bb0(%arg1: tensor<i64>):
- // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG1]], %[[ARG1]]
- %2 = mhlo.add %arg1, %arg1 : tensor<i64>
- // CHECK: mhlo.return
- // CHECK-SAME: %[[ADD]], %[[ARG2]]
- "mhlo.return"(%2) : (tensor<i64>) -> ()
- }) : (tensor<i64>) -> tensor<i64>
- func.return %0 : tensor<i64>
-}
-
-// -----
-
-// CHECK-LABEL: @while_with_implicit_capture
-// func @while_with_implicit_capture(%arg0 : tuple<tensor<i1>, tensor<5xi32>>) -> tuple<tensor<i1>, tensor<5xi32>> {
-func.func @while_with_implicit_capture(%arg0 : tensor<i1>, %arg1 : tensor<5xi32>) -> tuple<tensor<i1>, tensor<5xi32>> {
- %0 = mhlo.constant dense<0> : tensor<i32>
- %1 = mhlo.constant dense<false> : tensor<i1>
- // Check that the iota implicit capture is made explicit
- // CHECK: %[[IOTA:.*]] = "mhlo.iota
- %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32>
- // CHECK: mhlo.while{{.*}} %[[IOTA]])
- %3:2 = "mhlo.while"(%arg0, %arg1) ({
- ^bb0(%arg2: tensor<i1>, %arg3 : tensor<5xi32>):
- "mhlo.return"(%arg2) : (tensor<i1>) -> ()
- }, {
- ^bb0(%arg2: tensor<i1>, %arg3 : tensor<5xi32>):
- "mhlo.return"(%arg2, %2) : (tensor<i1>, tensor<5xi32>) -> ()
- }) : (tensor<i1>, tensor<5xi32>) -> (tensor<i1>, tensor<5xi32>)
- %4 = "mhlo.tuple"(%3#0, %3#1) : (tensor<i1>, tensor<5xi32>) -> tuple<tensor<i1>, tensor<5xi32>>
- func.return %4 : tuple<tensor<i1>, tensor<5xi32>>
- }
-
-// -----
-
-// Verifies that a value captured multiple times gets all of its uses updated.
-// CHECK-LABEL: @while_with_multiple_capture
-func.func @while_with_multiple_capture(%arg0: tensor<i64>) -> tensor<i64> {
- // CHECK: mhlo.while
- // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0)
- %0 = "mhlo.while"(%arg0) ({
- ^bb0(%arg1: tensor<i64>):
- // CHECK: mhlo.compare
- // CHECK-SAME: %[[ARG2]], %[[ARG1]]
- %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
- "mhlo.return"(%1) : (tensor<i1>) -> ()
- }, {
- ^bb0(%arg1: tensor<i64>):
- // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG1]]
- %2 = mhlo.add %arg0, %arg1 : tensor<i64>
- // CHECK: mhlo.return
- // CHECK-SAME: %[[ADD]], %[[ARG2]]
- "mhlo.return"(%2) : (tensor<i64>) -> ()
- }) : (tensor<i64>) -> tensor<i64>
- func.return %0 : tensor<i64>
-}
-
-// -----
-
// CHECK-LABEL: @broadcast_in_dim_dimension_unsorted
func.func @broadcast_in_dim_dimension_unsorted(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
// Unfuse the transpose from the broadcastInDim before export.
diff --git a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc
index 2562c06..0c8a4de 100644
--- a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc
+++ b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc
@@ -100,9 +100,10 @@
return calcMultiDimIndex(b, loc, linearIndex, shapeVec);
}
-SmallVector<Value> calcMultiDimIndexForFirstOperand(OpBuilder& b, Location loc,
- Value linearIndex,
- Operation* op) {
+static SmallVector<Value> calcMultiDimIndexForFirstOperand(OpBuilder& b,
+ Location loc,
+ Value linearIndex,
+ Operation* op) {
assert(op->getDialect()->getNamespace() == "lmhlo");
Value operandMemref = op->getOperand(0);
return calcMultiDimIndex(b, loc, linearIndex, operandMemref);
diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD
index b73a59f..adf30e0 100644
--- a/third_party/xla/xla/pjrt/BUILD
+++ b/third_party/xla/xla/pjrt/BUILD
@@ -311,6 +311,10 @@
deps = [
":pjrt_common",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc
index eee88ad..e17b04d 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc
@@ -77,8 +77,8 @@
TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) {
// Prepares a device memory ptr on GPU.
- std::unique_ptr<PJRT_Buffer, ::pjrt::PJRT_BufferDeleter> buffer =
- create_buffer().first;
+ auto [buffer, buffer_future] = create_buffer();
+ TF_CHECK_OK(buffer_future.Await());
PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args device_buffer_ptr_args;
device_buffer_ptr_args.struct_size =
PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE;
diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD
index f3ae2c5..0ac2c99 100644
--- a/third_party/xla/xla/pjrt/cpu/BUILD
+++ b/third_party/xla/xla/pjrt/cpu/BUILD
@@ -191,6 +191,7 @@
"//xla/stream_executor",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/concurrency:ref_count",
+ "//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:dynamic_annotations",
@@ -208,7 +209,6 @@
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool.
"@llvm-project//mlir:IR",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:denormal",
"@local_tsl//tsl/platform:env",
diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc
index 3b2caa6..78779c8 100644
--- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc
+++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc
@@ -103,10 +103,10 @@
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/denormal.h"
#include "tsl/platform/env.h"
@@ -398,6 +398,11 @@
num_threads, options.asynchronous));
}
+// An upper bound on the number of threads to use for intra-op parallelism. It
+// is nearly impossible to efficiently utilize more than 256 threads for the
+// compute-intensive operations that are supposed to run inside the intra-op
+// threadpool.
+static const size_t kMaxIntraOpThreads = 256;
+
static tsl::ThreadOptions GetThreadOptions() {
tsl::ThreadOptions thread_options;
// On Mac OS the default stack size is 512KiB, which is too small for some
@@ -415,16 +420,17 @@
: process_index_(process_index),
owned_devices_(std::move(devices)),
computation_placer_(std::make_unique<ComputationPlacer>()),
+ eigen_intraop_pool_(new tsl::thread::ThreadPool(
+ tsl::Env::Default(), "XLAEigen",
+ std::min(num_threads, kMaxIntraOpThreads))),
+ eigen_intraop_device_(
+ new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(),
+ eigen_intraop_pool_->NumThreads())),
pjrt_client_thread_pool_(
new tsl::thread::ThreadPool(tsl::Env::Default(), GetThreadOptions(),
"XLATfrtCpuClient", num_threads)),
async_work_runner_(std::make_unique<ThreadPoolAsyncWorkRunner>(
pjrt_client_thread_pool_.get())),
- eigen_intraop_pool_(new tsl::thread::ThreadPool(tsl::Env::Default(),
- "XLAEigen", num_threads)),
- eigen_intraop_device_(
- new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(),
- eigen_intraop_pool_->NumThreads())),
last_collective_launch_event_(
tsl::MakeAvailableAsyncValueRef<CpuEvent>()),
transpose_cache_(1024),
@@ -463,10 +469,10 @@
owned_memory_spaces_.push_back(std::move(memory_space));
}
- LOG(INFO) << "TfrtCpuClient created.";
+ VLOG(1) << "TfrtCpuClient created.";
}
-TfrtCpuClient::~TfrtCpuClient() { LOG(INFO) << "TfrtCpuClient destroyed."; }
+TfrtCpuClient::~TfrtCpuClient() { VLOG(1) << "TfrtCpuClient destroyed."; }
absl::StatusOr<PjRtDevice*> TfrtCpuClient::LookupDevice(
xla::PjRtGlobalDeviceId global_device_id) const {
diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h
index 54f75f3..ba4426b 100644
--- a/third_party/xla/xla/pjrt/cpu/cpu_client.h
+++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h
@@ -446,14 +446,14 @@
// Pointers to `owned_memory_spaces_`.
std::vector<PjRtMemorySpace*> memory_spaces_;
- // Thread pool for running PjRtClient tasks.
- std::unique_ptr<tsl::thread::ThreadPool> pjrt_client_thread_pool_;
- std::unique_ptr<AsyncWorkRunner> async_work_runner_;
-
// TODO(zhangqiaorjc): Use tsl::compat::EigenHostContextThreadPool.
std::unique_ptr<tsl::thread::ThreadPool> eigen_intraop_pool_;
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_intraop_device_;
+ // Thread pool for running PjRtClient tasks.
+ std::unique_ptr<tsl::thread::ThreadPool> pjrt_client_thread_pool_;
+ std::unique_ptr<AsyncWorkRunner> async_work_runner_;
+
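Editor's note: moving these two members below the Eigen pool members is likely about keeping the declaration order consistent with the new constructor initializer list in cpu_client.cc: C++ initializes non-static members in declaration order regardless of the order written in the initializer list (compilers warn about mismatches via -Wreorder). A standalone reminder of the rule, unrelated to the XLA types:

    #include <cassert>

    struct Order {
      // a_ is declared first, so it is initialized first; b_ may then safely
      // depend on it. Swapping the two declarations (without touching the
      // initializer list) would leave b_ reading an uninitialized a_.
      int a_;
      int b_;
      Order() : a_(1), b_(a_ + 1) {}
    };

    int main() {
      Order o;
      assert(o.a_ == 1 && o.b_ == 2);
    }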
// Launching collectives are prone to deadlock when we use fixed-sized
// threadpools since ExecuteHelper will block until all replicas reach the
// barrier. We ensure that
@@ -589,7 +589,7 @@
}
memory_stats.serialized_hlo_proto = proto->SerializeAsString();
memory_stats.PopulateBufferStatsFromAllocations(
- cpu_executable_.get()->GetAllocations());
+ cpu_executable_->GetAllocations());
return memory_stats;
}
diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD
index 90ab11e..f9019fc 100644
--- a/third_party/xla/xla/pjrt/gpu/BUILD
+++ b/third_party/xla/xla/pjrt/gpu/BUILD
@@ -91,6 +91,7 @@
"//xla/tsl/framework:bfc_allocator",
"//xla/tsl/framework:device_id",
"//xla/tsl/framework:device_id_impl",
+ "//xla/tsl/lib/strings:proto_serialization",
"//xla/tsl/util:env_var",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
@@ -108,7 +109,6 @@
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
index 27d3f18..9fe8b2e 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -82,7 +82,7 @@
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/framework/allocator.h"
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/fingerprint.h"
diff --git a/third_party/xla/xla/pjrt/pjrt_device_description.h b/third_party/xla/xla/pjrt/pjrt_device_description.h
index ed85269..77107fd 100644
--- a/third_party/xla/xla/pjrt/pjrt_device_description.h
+++ b/third_party/xla/xla/pjrt/pjrt_device_description.h
@@ -20,12 +20,35 @@
#include <string_view>
#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "xla/pjrt/pjrt_common.h"
namespace xla {
using PjRtDeviceAttribute = PjRtValueType;
+class PjRtMemorySpaceDescription {
+ public:
+ PjRtMemorySpaceDescription(absl::string_view kind, int kind_id)
+ : kind_(kind), kind_id_(kind_id) {}
+
+ // A platform-dependent string that uniquely identifies the kind of the
+ // memory space.
+ absl::string_view kind() const { return kind_; }
+  // An ID that uniquely identifies the kind of the memory space among those
+  // attached to the same `PjRtClient`. The ID assigned to a kind is
+  // implementation-specific.
+ // specific.
+ int kind_id() const { return kind_id_; }
+
+ private:
+ absl::string_view kind_;
+ int kind_id_;
+};
+
class PjRtDeviceDescription {
public:
virtual ~PjRtDeviceDescription() = default;
@@ -60,6 +83,19 @@
// reference will remain valid for the lifetime of the PjRtDevice.
virtual const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
Attributes() const = 0;
+
+ // Returns all memory spaces attached to this device.
+ // The memory spaces are in no particular order.
+ virtual absl::Span<const PjRtMemorySpaceDescription* const> memory_spaces()
+ const {
+ return {};
+ }
+
+ // Returns the default memory space attached to this device.
+ virtual absl::StatusOr<const PjRtMemorySpaceDescription*>
+ default_memory_space() const {
+ return absl::UnimplementedError("default_memory_space Not implemented.");
+ }
};
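Editor's note: a sketch of how an implementation might surface the new hooks; the subclass and its members are hypothetical, and the remaining pure-virtual overrides of PjRtDeviceDescription are omitted, so this is an outline rather than a complete class.

    #include <vector>

    #include "absl/status/statusor.h"
    #include "absl/types/span.h"
    #include "xla/pjrt/pjrt_device_description.h"

    // Hypothetical device description exposing a single "device" memory space.
    class MyDeviceDescription : public xla::PjRtDeviceDescription {
     public:
      MyDeviceDescription()
          : memory_space_(/*kind=*/"device", /*kind_id=*/0),
            memory_space_ptrs_({&memory_space_}) {}

      absl::Span<const xla::PjRtMemorySpaceDescription* const> memory_spaces()
          const override {
        return memory_space_ptrs_;
      }

      absl::StatusOr<const xla::PjRtMemorySpaceDescription*>
      default_memory_space() const override {
        return &memory_space_;
      }

      // id(), process_index(), device_kind(), DebugString(), ToString(), and
      // Attributes() overrides omitted for brevity.

     private:
      xla::PjRtMemorySpaceDescription memory_space_;
      std::vector<const xla::PjRtMemorySpaceDescription*> memory_space_ptrs_;
    };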
} // namespace xla
diff --git a/third_party/xla/xla/protobuf_util.cc b/third_party/xla/xla/protobuf_util.cc
index a8d6dfa..4c6815d 100644
--- a/third_party/xla/xla/protobuf_util.cc
+++ b/third_party/xla/xla/protobuf_util.cc
@@ -49,20 +49,5 @@
return absl::HashOf(serialized);
}
-absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
- const std::string& directory,
- const std::string& file_name,
- std::string* full_path) {
- tsl::Env* env = tsl::Env::Default();
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
- std::string safe_file_name = SanitizeFileName(file_name) + ".pb";
- std::string full_path_impl;
- if (!full_path) {
- full_path = &full_path_impl;
- }
- *full_path = tsl::io::JoinPath(directory, safe_file_name);
- return tsl::WriteBinaryProto(env, *full_path, message);
-}
-
} // namespace protobuf_util
} // namespace xla
diff --git a/third_party/xla/xla/protobuf_util.h b/third_party/xla/xla/protobuf_util.h
index 79f0077..81f7952 100644
--- a/third_party/xla/xla/protobuf_util.h
+++ b/third_party/xla/xla/protobuf_util.h
@@ -55,17 +55,6 @@
return ProtobufHash(m);
}
};
-// Writes the given message in binary proto to the path formed by joining
-// 'directory/file_name.pb'. The 'directory' is recursively created if it
-// doesn't already exist, and the 'file_name' is sanitized by replacing
-// illegal characters with underscore '_'.
-//
-// If 'full_name' is not null then it is set to the name of the file the
-// protobuf was written to.
-absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
- const std::string& directory,
- const std::string& file_name,
- std::string* full_path = nullptr);
// Registers a function that may either expand a dirpath or forward the original
// dirpath along as-is.
diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD
index 3bf1ebf..99c49a1 100644
--- a/third_party/xla/xla/python/BUILD
+++ b/third_party/xla/xla/python/BUILD
@@ -3,10 +3,6 @@
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library")
load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
-load(
- "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
- "if_cuda_is_configured",
-)
load("//xla:pytype.default.bzl", "pytype_strict_library")
load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test")
load(
@@ -18,7 +14,6 @@
"//xla/tsl:tsl.bzl",
"if_cuda_or_rocm",
"if_google",
- "if_oss",
"internal_visibility",
)
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_pybind_extension")
@@ -1052,12 +1047,7 @@
"@local_tsl//tsl/profiler/rpc:profiler_server_impl",
"@local_tsl//tsl/profiler/rpc/client:capture_profile",
"@local_tsl//tsl/profiler/rpc/client:profiler_client_impl",
- ] + select({
- ":gpu_enabled": [
- "//xla/backends/profiler/gpu:device_tracer",
- ],
- "//conditions:default": [],
- }),
+ ],
)
cc_library(
@@ -1180,7 +1170,7 @@
"//xla/service:hlo_proto_cc",
"//xla/service:name_uniquer",
"//xla/service:tuple_simplifier",
- "@local_tsl//tsl/lib/strings:proto_serialization",
+ "//xla/tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
@@ -1194,26 +1184,6 @@
cc_api_version = 2,
)
-# TODO(phawkins): the configuration settings here are overly confusing. The right fix is to split
-# xla_extension.so so that each backend is a separate plugin, however that must wait for a clean
-# ABI separation between devices.
-config_setting(
- name = "link_gpu_plugin",
- define_values = {"xla_python_enable_gpu": "true"},
-)
-
-bool_flag(
- name = "enable_gpu",
- build_setting_default = True,
-)
-
-config_setting(
- name = "gpu_enabled",
- flag_values = {
- ":enable_gpu": "True",
- },
-)
-
# If this flag is enabled, it sets RPATH on the xla_extension to values that are suitable for
# finding NVIDIA's CUDA libraries when they are installed as pip packages.
bool_flag(
@@ -1228,17 +1198,6 @@
},
)
-# We cannot nest select and if_cuda_is_configured so we introduce
-# a standalone cc_library target.
-cc_library(
- name = "gpu_plugin_deps",
- deps = [
- "//xla/service:gpu_plugin",
- ] + if_cuda_is_configured([
- "//xla/stream_executor:cuda_platform",
- ]),
-)
-
cc_library(
name = "logging",
srcs = ["logging.cc"],
@@ -1301,13 +1260,6 @@
"-fexceptions",
"-fno-strict-aliasing",
],
- defines = if_google(
- [],
- select({
- ":gpu_enabled": ["XLA_PYTHON_ENABLE_GPU=1"],
- "//conditions:default": [],
- }),
- ),
features = ["-use_header_modules"],
linkopts = select({
":use_jax_cuda_pip_rpaths": [
@@ -1407,10 +1359,7 @@
"@local_tsl//tsl/platform/cloud:gcs_file_system",
] + select({
# gloo transport only builds on linux
- "//xla/tsl:macos": [
- "//xla/pjrt/cpu:gloo_collectives",
- "//xla/pjrt/cpu:gloo_kv_store",
- ] + if_oss(["@gloo//:transport_uv"]),
+ "//xla/tsl:macos": [],
"//xla/tsl:windows": [],
"//conditions:default": [
"//xla/pjrt/cpu:gloo_collectives",
@@ -1423,20 +1372,7 @@
"//conditions:default": [
"//xla/pjrt/cpu:mpi_collectives",
],
- }) + if_google(
- [],
- select({
- ":gpu_enabled": [
- ":gpu_support",
- ],
- "//conditions:default": [],
- }) + select({
- ":link_gpu_plugin": [
- ":gpu_plugin_deps",
- ],
- "//conditions:default": [],
- }),
- ),
+ }),
)
cc_library(
diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc
index 6821c7e..080f0fa 100644
--- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc
+++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc
@@ -182,6 +182,46 @@
int output_index;
};
+mlir::LogicalResult VerifyElementTypeAndPerShardShapeAreEqual(
+ mlir::Operation* op, IfrtArrayType in, int in_index, IfrtArrayType out,
+ int out_index) {
+ if (in.getShape().getElementType() != out.getShape().getElementType()) {
+ return op->emitOpError()
+ << "can't alias input #" << in_index << " to output #" << out_index
+ << " with different element types: " << in << " vs " << out;
+ }
+
+ absl::StatusOr<llvm::SmallVector<int64_t>> in_per_shard_shape =
+ in.getShardingAttr().LocalShapeFromGlobalShape(in.getShape().getShape());
+ if (!in_per_shard_shape.ok()) {
+ return op->emitOpError()
+ << "unable to get per-shard shape of aliased input #" << in_index
+ << ": " << in_per_shard_shape.status().message();
+ }
+ absl::StatusOr<llvm::SmallVector<int64_t>> out_per_shard_shape =
+ out.getShardingAttr().LocalShapeFromGlobalShape(
+ out.getShape().getShape());
+ if (!out_per_shard_shape.ok()) {
+ return op->emitOpError()
+ << "unable to get per-shard shape of aliased output #" << out_index
+ << ": " << out_per_shard_shape.status().message();
+ }
+ if (in_per_shard_shape->size() != out_per_shard_shape->size()) {
+ return op->emitOpError()
+ << "can't alias input #" << in_index << " to output #" << out_index
+ << " with different per-shard shapes: " << in << " vs " << out;
+ }
+ for (const auto& [in_dim, out_dim] :
+ llvm::zip(*in_per_shard_shape, *out_per_shard_shape)) {
+ if (in_dim != out_dim) {
+ return op->emitOpError()
+ << "can't alias input #" << in_index << " to output #" << out_index
+ << " with different per-shard shapes: " << in << " vs " << out;
+ }
+ }
+ return mlir::success();
+}
+
mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias,
llvm::ArrayRef<IfrtArrayType> inputs,
llvm::ArrayRef<IfrtArrayType> outputs) {
@@ -198,11 +238,12 @@
<< " outputs";
}
if (inputs[io_alias.input_index] != outputs[io_alias.output_index]) {
- return op->emitOpError()
- << "can't alias input #" << io_alias.input_index << " to output #"
- << io_alias.output_index
- << " with different types: " << inputs[io_alias.input_index]
- << " vs " << outputs[io_alias.output_index];
+ // TODO(icgog): Relax this aliasing check to allow for different per-shard
+ // shapes as long as the byte size is the same. We cannot do this now
+ // because we do not have layout information.
+ return VerifyElementTypeAndPerShardShapeAreEqual(
+ op, inputs[io_alias.input_index], io_alias.input_index,
+ outputs[io_alias.output_index], io_alias.output_index);
}
return mlir::success();
}
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir
index edaff92..8c70318 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir
@@ -238,7 +238,6 @@
}
}
-
// -----
!array0 = !ifrt.array<tensor<2xi32>,
@@ -259,3 +258,21 @@
return %arg0 : tensor<2xi32>
}
}
+
+// -----
+
+!array = !ifrt.array<tensor<2xi32>,
+ #ifrt.sharding_param<2 to [0] on 2>, [0, 1]>
+module @copy_arrays_with_already_donated_array_error {
+ func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array)
+ attributes {ifrt.function} {
+ %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1]
+ {io_aliases=[array<i32: 0, 0>]} : (!array) -> !array
+ // expected-error @+1 {{'func.return' op result #1 of op at}}
+ return %0, %arg0 : !array, !array
+ }
+
+ func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+ return %arg0 : tensor<2xi32>
+ }
+}
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir
index 28a1dda..4fef087 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir
@@ -3,7 +3,7 @@
#device = #ifrt<devices[0,1]>
#sharding = #ifrt.sharding_param<2x1 to [0] on 2>
// CHECK-LABEL: @identity_axis0_sharded
-module @identity_axis0_sharded attributes {ifrt.devices = #device} {
+module @identity_axis0_sharded attributes {ifrt.num_devices = 2} {
// CHECK-NEXT: func.func @main
// CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32>
// CHECK-NEXT: return %[[ARG]]
@@ -23,7 +23,7 @@
#sharding = #ifrt.sharding_param<1x2 to [0] on 2>
// CHECK-LABEL: @identity_axis1_sharded
module @identity_axis1_sharded
- attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} {
+ attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} {
// CHECK-NEXT: func.func @entry_func
// CHECK-SAME: %[[ARG:.*]]: tensor<2x1xi32>
// CHECK-NEXT: return %[[ARG]]
@@ -42,7 +42,7 @@
#device = #ifrt<devices[0,1,2,3,4,5]>
#sharding = #ifrt.sharding_param<3x2 to [1,0] on 2x3>
// CHECK-LABEL: @identify_both_axes_sharded
-module @identify_both_axes_sharded attributes {ifrt.devices = #device} {
+module @identify_both_axes_sharded attributes {ifrt.num_devices = 6} {
// CHECK-NEXT: func.func @main
// CHECK-SAME: %[[ARG:.*]]: tensor<1x1xi32>
// CHECK-NEXT: return %[[ARG]]
@@ -60,7 +60,7 @@
#device = #ifrt<devices[0,1]>
// CHECK-LABEL: @with_func_call
-module @with_func_call attributes {ifrt.devices = #device} {
+module @with_func_call attributes {ifrt.num_devices = 2} {
// CHECK-NEXT: func.func @main
// CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32>
// CHECK-SAME: tensor<1x2xi32>
@@ -94,7 +94,7 @@
#device = #ifrt<devices[0,1]>
// CHECK-LABEL: @with_nested_func_call
-module @with_nested_func_call attributes {ifrt.devices = #device} {
+module @with_nested_func_call attributes {ifrt.num_devices = 2} {
// CHECK-NEXT: func.func @main
// CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32>
// CHECK-SAME: tensor<1x2xi32>
@@ -139,11 +139,10 @@
// -----
-#device = #ifrt<devices[0,1]>
#sharding = #ifrt.sharding_param<1x2 to [0] on 2>
// expected-error@+1 {{cannot find entry function `main`}}
module @missing_main_function
- attributes {ifrt.devices = #device} {
+ attributes {ifrt.num_devices = 2} {
}
// -----
@@ -152,7 +151,7 @@
#sharding = #ifrt.sharding_param<1x2 to [0] on 2>
// expected-error@+1 {{cannot find entry function `entry_func`}}
module @missing_entry_function
- attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} {
+ attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} {
func.func @main(
%arg0: tensor<2x2xi32> {ifrt.sharding = #sharding,
ifrt.devices = #device})
@@ -166,7 +165,7 @@
#device = #ifrt<devices[0,1]>
#sharding = #ifrt.sharding_param<2x1 to [0] on 2>
-module @non_divisible_global_shape attributes {ifrt.devices = #device} {
+module @non_divisible_global_shape attributes {ifrt.num_devices = 2} {
// expected-error@+1 {{Global shape is not divisible by the number of shards in dimension 0. Global size: 3, number of shards: 2}}
func.func @main(
%arg0: tensor<3x2xi32> {ifrt.sharding = #sharding,
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir
index b318756..e512b26 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir
@@ -355,11 +355,32 @@
// -----
-func.func @io_aliases_should_have_same_type(
+!array0 = !ifrt.array<tensor<2x1xi32>,
+ #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]>
+!array1 = !ifrt.array<tensor<1x1xi32>,
+ #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
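+// Aliasing input #0 to output #0 is allowed here: both arrays have per-shard
+// shape tensor<1x1xi32> even though their global array types differ.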
+func.func @io_aliases_of_different_type_but_same_per_shard_shape(%arg0: !array0)
+ attributes {ifrt.function} {
+ %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1]
+ {io_aliases=[array<i32: 0, 0>]} : (!array0) -> !array1
+ return
+}
+
+func.func @callee(%arg0: tensor<2x1xi32>) -> tensor<1x1xi32> {
+ %0 = mhlo.constant dense<-2147483648> : tensor<i32>
+ %1 = mhlo.reduce(%arg0 init: %0) applies mhlo.maximum across dimensions = [0, 1]
+ : (tensor<2x1xi32>, tensor<i32>) -> tensor<i32>
+ %2 = mhlo.reshape %1 : (tensor<i32>) -> tensor<1x1xi32>
+ return %2 : tensor<1x1xi32>
+}
+
+// -----
+
+func.func @io_aliases_should_alias_arrays_with_same_per_shard_shape(
%arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
[0,1]>)
attributes {ifrt.function} {
- // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different types: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
+ // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different per-shard shapes: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
%0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1]
{io_aliases=[array<i32: 0, 0>]}
: (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir
index d522841..14485f4 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir
@@ -215,7 +215,7 @@
%arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
[0,1]>)
attributes {ifrt.function} {
- // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different types: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
+ // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different per-shard shapes: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
%0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0)
{io_aliases=[array<i32: 0, 0>]}
: (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h b/third_party/xla/xla/python/ifrt/ir/transforms/constants.h
index 98bfd12..fff0484 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/constants.h
@@ -21,6 +21,8 @@
namespace xla::ifrt {
inline constexpr llvm::StringLiteral kIfrtDevicesAttrName = "ifrt.devices";
+inline constexpr llvm::StringLiteral kIfrtNumDevicesAttrName =
+ "ifrt.num_devices";
inline constexpr llvm::StringLiteral kIfrtShardingAttrName = "ifrt.sharding";
inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName =
"ifrt.entry_function";
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc
index bdfcf78..7e34921 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc
@@ -16,10 +16,10 @@
#include <memory>
#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
@@ -63,10 +63,15 @@
};
void IfrtVerifyDonationPass::runOnOperation() {
- mlir::ModuleOp module_op = getOperation();
+ mlir::func::FuncOp func_op = getOperation();
+ // We only need to run this pass on IFRT functions.
+ if (!func_op->hasAttr(kIfrtFunctionAttrName) &&
+ !func_op->hasAttr(kIfrtReshardFunctionAttrName)) {
+ return;
+ }
llvm::DenseMap<mlir::Value, mlir::Operation*> donated_value_to_op;
- mlir::WalkResult result = module_op.walk([&](mlir::Operation* op)
- -> mlir::WalkResult {
+ mlir::WalkResult result = func_op.walk([&](mlir::Operation* op)
+ -> mlir::WalkResult {
auto result =
llvm::TypeSwitch<mlir::Operation*, mlir::LogicalResult>(op)
.Case<xla::ifrt::CallOp, xla::ifrt::CallLoadedExecutableOp>(
@@ -136,6 +141,20 @@
}
return mlir::success();
})
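+          // Returning a value that has already been donated to another op is
+          // an error; point the user at the op that received the donation.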
+ .Case<mlir::func::ReturnOp>([&](mlir::func::ReturnOp return_op) {
+ for (const auto& [idx, result] :
+ llvm::enumerate(return_op.getOperands())) {
+ auto donated_it = donated_value_to_op.find(result);
+ if (donated_it != donated_value_to_op.end()) {
+ return_op.emitOpError()
+ << "result #" << idx << " of op at " << return_op.getLoc()
+ << " was already donated to the op at "
+ << donated_it->second->getLoc();
+ return mlir::failure();
+ }
+ }
+ return mlir::success();
+ })
.Default(mlir::success());
if (mlir::failed(result)) {
@@ -151,7 +170,7 @@
} // namespace
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateIfrtVerifyDonationPass() {
return std::make_unique<IfrtVerifyDonationPass>();
}
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td
index d299c6b..c8c8e99 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td
@@ -139,7 +139,8 @@
let constructor = "CreateIfrtMergeReshardsPass()";
}
-def IfrtVerifyDonationPass : Pass<"ifrt-verify-donation", "mlir::ModuleOp"> {
+def IfrtVerifyDonationPass :
+ Pass<"ifrt-verify-donation", "mlir::func::FuncOp"> {
let summary = "Verify that `!ifrt.array` are not donated more than once.";
let description = [{
    Verify that no `!ifrt.array` is donated more than once, and that all
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc
index 2669dfd..07a8c8d 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc
@@ -35,7 +35,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
-#include "xla/python/ifrt/ir/ifrt_dialect.h"
#include "xla/python/ifrt/ir/ifrt_interfaces.h"
#include "xla/python/ifrt/ir/transforms/constants.h"
#include "xla/python/ifrt/ir/transforms/passes.h"
@@ -272,15 +271,15 @@
void SpmdExpansionPass::runOnOperation() {
mlir::ModuleOp module_op = getOperation();
// Skip single-device case.
- auto devices = module_op->getAttrOfType<xla::ifrt::IfrtDevicesAttr>(
- kIfrtDevicesAttrName);
- if (devices == nullptr) {
+ auto num_devices =
+ module_op->getAttrOfType<mlir::IntegerAttr>(kIfrtNumDevicesAttrName);
+ if (num_devices == nullptr) {
module_op->emitOpError()
<< "`" << module_op.getName()->str() << "` requires `"
- << kIfrtDevicesAttrName << "` attribute.";
+ << kIfrtNumDevicesAttrName << "` attribute.";
return signalPassFailure();
}
- if (devices.getIds().size() == 1) {
+ if (num_devices.getInt() == 1) {
return;
}
diff --git a/third_party/xla/xla/python/ifrt/sharding.cc b/third_party/xla/xla/python/ifrt/sharding.cc
index e302535..f5ca885 100644
--- a/third_party/xla/xla/python/ifrt/sharding.cc
+++ b/third_party/xla/xla/python/ifrt/sharding.cc
@@ -50,6 +50,14 @@
namespace {
+// Returns a canonicalized memory kind for the given devices.
+// REQUIRES: !devices.empty()
+MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind,
+ const DeviceList& devices) {
+ CHECK(!devices.empty());
+ return CanonicalizeMemoryKind(memory_kind, devices.front());
+}
+
// Returns if `sharding_param` indicates a fully replicated sharding.
bool ComputeIsFullyReplicated(const ShardingParam& sharding_param) {
return llvm::all_of(sharding_param.dim_shards(),
@@ -155,6 +163,12 @@
char DeserializeShardingOptions::ID = 0;
+Sharding::Sharding(DeviceList devices, MemoryKind memory_kind,
+ bool is_fully_replicated)
+ : devices_(std::move(devices)),
+ memory_kind_(memory_kind),
+ is_fully_replicated_(is_fully_replicated) {}
+
bool Sharding::operator==(const Sharding& other) const {
if (this == &other) {
return true;
@@ -184,6 +198,7 @@
std::unique_ptr<SingleDeviceSharding> SingleDeviceSharding::Create(
Device* device, MemoryKind memory_kind) {
+ memory_kind = CanonicalizeMemoryKind(memory_kind, device);
return std::unique_ptr<SingleDeviceSharding>(
new SingleDeviceSharding(device, memory_kind));
}
@@ -247,6 +262,7 @@
std::unique_ptr<OpaqueSharding> OpaqueSharding::Create(DeviceList devices,
MemoryKind memory_kind) {
+ memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
return std::unique_ptr<OpaqueSharding>(
new OpaqueSharding(std::move(devices), memory_kind));
}
@@ -318,6 +334,7 @@
DeviceList devices, MemoryKind memory_kind, Shape shape,
std::vector<Shape> shard_shapes) {
CHECK_EQ(devices.size(), shard_shapes.size());
+ memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
return std::unique_ptr<ConcreteSharding>(
new ConcreteSharding(std::move(devices), memory_kind, std::move(shape),
std::move(shard_shapes)));
@@ -327,6 +344,7 @@
DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape,
std::vector<DynamicShape> shard_dynamic_shapes) {
CHECK_EQ(devices.size(), shard_dynamic_shapes.size());
+ memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
return std::unique_ptr<ConcreteSharding>(new ConcreteSharding(
std::move(devices), memory_kind, std::move(dynamic_shape),
std::move(shard_dynamic_shapes)));
@@ -472,6 +490,7 @@
std::unique_ptr<ConcreteEvenSharding> ConcreteEvenSharding::Create(
DeviceList devices, MemoryKind memory_kind, Shape shape, Shape shard_shape,
bool is_fully_replicated) {
+ memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
return std::unique_ptr<ConcreteEvenSharding>(new ConcreteEvenSharding(
std::move(devices), memory_kind, std::move(shape), std::move(shard_shape),
is_fully_replicated));
@@ -586,6 +605,7 @@
"%d",
device_count, devices.size());
}
+ memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
return std::unique_ptr<ShardingParamSharding>(new ShardingParamSharding(
std::move(sharding_param), std::move(devices), memory_kind));
}
@@ -595,7 +615,8 @@
DeviceList devices,
MemoryKind memory_kind)
: llvm::RTTIExtends<ShardingParamSharding, Sharding>(
- devices, memory_kind, ComputeIsFullyReplicated(sharding_param)),
+ std::move(devices), memory_kind,
+ ComputeIsFullyReplicated(sharding_param)),
sharding_param_(sharding_param) {}
absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
diff --git a/third_party/xla/xla/python/ifrt/sharding.h b/third_party/xla/xla/python/ifrt/sharding.h
index c7fbd25..91b8b8a 100644
--- a/third_party/xla/xla/python/ifrt/sharding.h
+++ b/third_party/xla/xla/python/ifrt/sharding.h
@@ -125,10 +125,8 @@
static char ID; // NOLINT
protected:
- Sharding(DeviceList devices, MemoryKind memory_kind, bool is_fully_replicated)
- : devices_(devices),
- memory_kind_(memory_kind),
- is_fully_replicated_(is_fully_replicated) {}
+ Sharding(DeviceList devices, MemoryKind memory_kind,
+ bool is_fully_replicated);
DeviceList devices_;
MemoryKind memory_kind_;
@@ -189,6 +187,7 @@
class OpaqueSharding : public llvm::RTTIExtends<OpaqueSharding, Sharding> {
public:
// Creates an opaque sharding. `Disassemble()` will fail.
+ // REQUIRES: !devices.empty()
static std::unique_ptr<OpaqueSharding> Create(DeviceList devices,
MemoryKind memory_kind);
@@ -230,6 +229,7 @@
public:
// Creates a concrete sharding that may contain non-identical shard shapes.
// REQUIRES: `devices`.size() == `shard_shapes`.size()
+ // REQUIRES: !devices.empty()
static std::unique_ptr<ConcreteSharding> Create(
DeviceList devices, MemoryKind memory_kind, Shape shape,
std::vector<Shape> shard_shapes);
@@ -237,6 +237,7 @@
// Creates a concrete sharding that may contain non-identical shard dynamic
// shapes.
// REQUIRES: `devices`.size() == `shard_dynamic_shapes`.size()
+ // REQUIRES: !devices.empty()
static std::unique_ptr<ConcreteSharding> Create(
DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape,
std::vector<DynamicShape> shard_dynamic_shapes);
@@ -321,6 +322,7 @@
// Creates a concrete even sharding.
// TODO(hyeontaek): Remove the default value of `is_fully_replicated` once all
// callers are updated to provide it explicitly.
+ // REQUIRES: !devices.empty()
static std::unique_ptr<ConcreteEvenSharding> Create(
DeviceList devices, MemoryKind memory_kind, Shape shape,
Shape shard_shape, bool is_fully_replicated = false);
@@ -371,6 +373,7 @@
class ShardingParamSharding
: public llvm::RTTIExtends<ShardingParamSharding, Sharding> {
public:
+ // REQUIRES: !devices.empty()
static absl::StatusOr<std::unique_ptr<ShardingParamSharding>> Create(
ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind);
diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc
index 15c2935..09b4ccb 100644
--- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc
+++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc
@@ -82,7 +82,6 @@
for (const auto& d : init_response.devices()) {
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>
pjrt_device_attributes;
- AttributeMap::Map attributes;
if (rpc_helper->version().protocol_version() <= 3) {
for (const auto& [key, attr] : d.deprecated_attributes()) {
TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value,
diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
index 6f79e56..ec675f4 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
@@ -97,11 +97,20 @@
return result;
}
+// Returns a canonicalized memory kind for the given devices.
+// REQUIRES: !devices.empty()
+MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind,
+ const DeviceList& devices) {
+ CHECK(!devices.empty());
+ return CanonicalizeMemoryKind(memory_kind, devices.front());
+}
+
} // namespace
std::unique_ptr<HloSharding> HloSharding::Create(
DeviceList devices, MemoryKind memory_kind,
xla::HloSharding xla_hlo_sharding) {
+ memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
return std::unique_ptr<HloSharding>(new HloSharding(
std::move(devices), memory_kind, std::move(xla_hlo_sharding)));
}
diff --git a/third_party/xla/xla/python/profiler/internal/python_hooks.h b/third_party/xla/xla/python/profiler/internal/python_hooks.h
index a9b502e..29e6b83 100644
--- a/third_party/xla/xla/python/profiler/internal/python_hooks.h
+++ b/third_party/xla/xla/python/profiler/internal/python_hooks.h
@@ -77,7 +77,7 @@
Py_XDECREF(m_module);
}
- PythonTraceEntry(PythonTraceEntry&& other) {
+ PythonTraceEntry(PythonTraceEntry&& other) noexcept {
start_time_ns = other.start_time_ns;
end_time_ns = other.end_time_ns;
co_firstlineno = other.co_firstlineno;
diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc
index 9d9db9a..6f5aff6 100644
--- a/third_party/xla/xla/python/py_compile_only_client.cc
+++ b/third_party/xla/xla/python/py_compile_only_client.cc
@@ -19,12 +19,14 @@
#include <functional>
#include <memory>
#include <optional>
+#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
@@ -79,6 +81,40 @@
namespace {
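+// A compile-only stand-in for ifrt::Memory, backed by a
+// PjRtMemorySpaceDescription from the topology rather than a live runtime
+// memory space.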
+class CompileOnlyMemory
+ : public llvm::RTTIExtends<CompileOnlyMemory, ifrt::Memory> {
+ public:
+ explicit CompileOnlyMemory(
+ int id, const PjRtMemorySpaceDescription* memory_description,
+ ifrt::Device* device)
+ : id_(id),
+ kind_(memory_description->kind()),
+ debug_string_(absl::StrFormat("CompileOnlyMemory(id=%d, kind=%s)", id,
+ memory_description->kind())),
+ device_(device) {}
+
+ ifrt::MemoryId Id() const override { return ifrt::MemoryId(id_); }
+
+ const ifrt::MemoryKind& Kind() const override { return kind_; }
+
+ absl::string_view ToString() const override { return debug_string_; }
+ absl::string_view DebugString() const override { return debug_string_; }
+
+ absl::Span<ifrt::Device* const> Devices() const override {
+ return absl::Span<ifrt::Device* const>{&device_, 1};
+ }
+
+ static char ID; // NOLINT
+
+ private:
+ int id_;
+ ifrt::MemoryKind kind_;
+ std::string debug_string_;
+ ifrt::Device* device_;
+};
+
+[[maybe_unused]] char CompileOnlyMemory::ID = 0;
+
class CompileOnlyDevice
: public llvm::RTTIExtends<CompileOnlyDevice, ifrt::Device> {
public:
@@ -108,16 +144,31 @@
return description_->DebugString();
}
- absl::Span<ifrt::Memory* const> Memories() const override { return {}; }
+ absl::Span<ifrt::Memory* const> Memories() const override {
+ return unowned_memories_;
+ }
absl::StatusOr<ifrt::Memory*> DefaultMemory() const override {
+ if (default_memory_) {
+ return default_memory_;
+ }
return Unimplemented("DefaultMemory is not supported");
}
const ifrt::AttributeMap& Attributes() const override { return attributes_; }
+ void AttachMemory(std::unique_ptr<ifrt::Memory> memory) {
+ unowned_memories_.push_back(memory.get());
+ owned_memories_.push_back(std::move(memory));
+ }
+
+ void SetDefaultMemory(ifrt::Memory* memory) { default_memory_ = memory; }
+
private:
const PjRtDeviceDescription* description_;
ifrt::AttributeMap attributes_;
+ ifrt::Memory* default_memory_ = nullptr;
+ std::vector<ifrt::Memory*> unowned_memories_;
+ std::vector<std::unique_ptr<ifrt::Memory>> owned_memories_;
};
class InvalidIfrtCompiler final
@@ -153,10 +204,24 @@
: topology_(std::move(topology)),
descriptions_(topology_->DeviceDescriptions()),
attributes_(ifrt::AttributeMap::Map()) {
+ int offset = 0;
for (auto& description : descriptions_) {
owned_devices_.push_back(
std::make_unique<CompileOnlyDevice>(description.get()));
- devices_.push_back(owned_devices_.back().get());
+ auto* device = owned_devices_.back().get();
+ devices_.push_back(device);
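+      // Only attach memory spaces to devices in the current process, and
+      // record the description's default memory space when it is available.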
+ if (description->process_index() == process_index()) {
+ auto default_memory = description->default_memory_space();
+ for (auto* memory_description : description->memory_spaces()) {
+ auto memory = std::make_unique<CompileOnlyMemory>(
+ offset, memory_description, device);
+ if (default_memory.ok() && memory_description == *default_memory) {
+ device->SetDefaultMemory(memory.get());
+ }
+ device->AttachMemory(std::move(memory));
+ ++offset;
+ }
+ }
}
}
diff --git a/third_party/xla/xla/python/py_values.h b/third_party/xla/xla/python/py_values.h
index 9733a42..51bfdb9 100644
--- a/third_party/xla/xla/python/py_values.h
+++ b/third_party/xla/xla/python/py_values.h
@@ -49,8 +49,8 @@
// dangerous due to `owning_pybuffer`.
DevicePutResult(const DevicePutResult&) = delete;
DevicePutResult& operator=(const DevicePutResult&) = delete;
- DevicePutResult(DevicePutResult&&) = default;
- DevicePutResult& operator=(DevicePutResult&&) = default;
+ DevicePutResult(DevicePutResult&&) noexcept = default;
+ DevicePutResult& operator=(DevicePutResult&&) noexcept = default;
// Points to the on-device array. Not owned.
tsl::RCReference<ifrt::Array> ifrt_array;
diff --git a/third_party/xla/xla/python/python_ref_manager.h b/third_party/xla/xla/python/python_ref_manager.h
index 815e80e..4f1d821 100644
--- a/third_party/xla/xla/python/python_ref_manager.h
+++ b/third_party/xla/xla/python/python_ref_manager.h
@@ -57,7 +57,7 @@
ManagedPyObjects(const ManagedPyObjects& other) = delete;
ManagedPyObjects(ManagedPyObjects&& other) = default;
ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete;
- ManagedPyObjects& operator=(ManagedPyObjects&& other) = default;
+ ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default;
private:
PythonRefManager* manager_ = nullptr;
diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc
index 68a483c..65bfb3f 100644
--- a/third_party/xla/xla/python/pytree.cc
+++ b/third_party/xla/xla/python/pytree.cc
@@ -1249,7 +1249,7 @@
nb::cast<std::string_view>(nb::repr(node_data->first))));
}
node.kind = registration->kind;
- if (node.kind == PyTreeKind::kCustom) {
+ if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) {
node.custom = registration;
node.node_data = node_data->second;
} else if (node.kind == PyTreeKind::kNamedTuple) {
diff --git a/third_party/xla/xla/python/pytree_test.py b/third_party/xla/xla/python/pytree_test.py
index 4125d7a..922a4d7 100644
--- a/third_party/xla/xla/python/pytree_test.py
+++ b/third_party/xla/xla/python/pytree_test.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import collections
+import dataclasses
from absl.testing import absltest
@@ -44,6 +45,15 @@
registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable)
+@dataclasses.dataclass
+class Custom:
+ a: int
+ b: str
+
+
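+# Field "a" is flattened as pytree data (a child); field "b" is kept as static
+# metadata in the treedef.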
+registry.register_dataclass_node(Custom, ["a"], ["b"])
+
+
class PyTreeTest(absltest.TestCase):
def roundtrip(self, example):
@@ -92,6 +102,15 @@
y = registry.flatten((0, 0))[1]
self.assertEqual((x.compose(y)).num_leaves, 2)
+ def testDataclassMakeFromNodeData(self):
+ c = Custom(1, "a")
+ c_leafs, c_tree = registry.flatten(c)
+ c_tree2 = c_tree.make_from_node_data_and_children(
+ registry, c_tree.node_data(), c_tree.children()
+ )
+ self.assertEqual(c_tree2.unflatten(c_leafs), c)
+ self.assertEqual(str(c_tree2), str(c_tree))
+
if __name__ == "__main__":
absltest.main()
diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD
index 6d57e56..cc0c5e0 100644
--- a/third_party/xla/xla/python/tools/BUILD
+++ b/third_party/xla/xla/python/tools/BUILD
@@ -86,7 +86,7 @@
":types",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
- #internal proto upb dep
+ # copybara:uncomment "//third_party/py/google/protobuf:use_fast_cpp_protos",
"//third_party/py/numpy",
"//xla:xla_data_proto_py",
],
diff --git a/third_party/xla/xla/python/traceback.cc b/third_party/xla/xla/python/traceback.cc
index b3cdfce..19e4f94 100644
--- a/third_party/xla/xla/python/traceback.cc
+++ b/third_party/xla/xla/python/traceback.cc
@@ -99,7 +99,8 @@
}
}
-Traceback::Traceback(Traceback&& other) : frames_(std::move(other.frames_)) {
+Traceback::Traceback(Traceback&& other) noexcept
+ : frames_(std::move(other.frames_)) {
// absl::InlinedVector does not always clear itself if moved. Since we rely on
// its empty() method to destroy Traceback differently, we explicitly clear
// here.
diff --git a/third_party/xla/xla/python/traceback.h b/third_party/xla/xla/python/traceback.h
index c93860b..da80362 100644
--- a/third_party/xla/xla/python/traceback.h
+++ b/third_party/xla/xla/python/traceback.h
@@ -48,7 +48,7 @@
~Traceback();
Traceback(const Traceback&) = delete;
- Traceback(Traceback&& other);
+ Traceback(Traceback&& other) noexcept;
Traceback& operator=(const Traceback&) = delete;
Traceback& operator=(Traceback&&) = delete;
diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc
index b9f9b6e..4feb8cb 100644
--- a/third_party/xla/xla/python/xla.cc
+++ b/third_party/xla/xla/python/xla.cc
@@ -62,19 +62,12 @@
#include "xla/python/py_program.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/tsl/python/lib/core/numpy.h" //NOLINT
-#ifdef XLA_PYTHON_ENABLE_GPU
-#include "xla/python/gpu_support.h"
-#endif // XLA_PYTHON_ENABLE_GPU
#ifdef __linux__
#include "gloo/transport/tcp/attr.h"
#include "gloo/transport/tcp/device.h"
#include "xla/pjrt/cpu/gloo_collectives.h"
#include "xla/pjrt/cpu/gloo_kv_store.h"
-#elif __APPLE__
-#include "gloo/transport/uv/device.h"
-#include "xla/pjrt/cpu/gloo_collectives.h"
-#include "xla/pjrt/cpu/gloo_kv_store.h"
#endif // __linux__
#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE)
@@ -261,7 +254,7 @@
std::optional<std::string> hostname,
std::optional<std::string> interface)
-> std::shared_ptr<xla::cpu::CollectivesInterface> {
-#if defined(__linux__)
+#ifdef __linux__
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
if (distributed_client != nullptr) {
kv_store = GetDistributedKeyValueStore(distributed_client,
@@ -278,26 +271,9 @@
auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs);
return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
std::move(tcp_device));
-#elif defined(__APPLE__)
- std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
- if (distributed_client != nullptr) {
- kv_store = GetDistributedKeyValueStore(distributed_client,
- /*key_prefix=*/"cpu:");
- }
- auto gloo_kv_store = std::make_unique<cpu::GlooKeyValueStore>(kv_store);
- auto uv_attrs = gloo::transport::uv::attr();
- if (hostname) {
- uv_attrs.hostname = *hostname;
- }
- if (interface) {
- uv_attrs.iface = *interface;
- }
- auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs);
- return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
- std::move(uv_device));
#else // __linux__
throw xla::XlaRuntimeError(
- "make_gloo_tcp_collectives only implemented for linux and macos");
+ "make_gloo_tcp_collectives only implemented for linux");
#endif // __linux__
},
nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt,
@@ -387,10 +363,6 @@
return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name));
});
-#ifdef XLA_PYTHON_ENABLE_GPU
- RegisterGpuClientAndDefineGpuAllocatorConfig(m_nb);
-#endif // XLA_PYTHON_ENABLE_GPU
-
m_nb.def(
"get_c_api_client",
[](std::string platform_name,
diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi
index bf5a6d9..8731080 100644
--- a/third_party/xla/xla/python/xla_client.pyi
+++ b/third_party/xla/xla/python/xla_client.pyi
@@ -222,12 +222,16 @@
collapsed_slice_dims: list[int]
start_index_map: list[int]
index_vector_dim: int
+ operand_batching_dims: list[int]
+ start_indices_batching_dims: list[int]
class ScatterDimensionNumbers:
update_window_dims: list[int]
inserted_window_dims: list[int]
scatter_dims_to_operand_dims: list[int]
index_vector_dim: int
+ input_batching_dims: list[int]
+ scatter_indices_batching_dims: list[int]
class ReplicaGroup:
replica_ids: list[int]
diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py
index 65d6d7f..37484cc 100644
--- a/third_party/xla/xla/python/xla_client_test.py
+++ b/third_party/xla/xla/python/xla_client_test.py
@@ -65,8 +65,12 @@
# pylint: disable=invalid-name
-def jax_array_convert_to_array(self):
- return self._single_device_array_to_np_array()
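+# Accepts NumPy __array__-style (dtype, copy) arguments: `copy` is ignored and
+# `dtype`, when given, casts the converted array.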
+def jax_array_convert_to_array(self, dtype=None, copy=None):
+ del copy
+ out = self._single_device_array_to_np_array()
+ if dtype is not None:
+ out = out.astype(dtype)
+ return out
def jax_array_device(self):
@@ -586,7 +590,10 @@
def testScalarTimesVector(self, dtype):
c = self._NewComputation()
arg0 = np.array(3, dtype=dtype)
- arg1 = np.array([10, 15, -2, 7], dtype=dtype)
+ if np.issubdtype(dtype, np.unsignedinteger):
+ arg1 = np.array([10, 15, 2, 7], dtype=dtype)
+ else:
+ arg1 = np.array([10, 15, -2, 7], dtype=dtype)
p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
ops.Mul(p0, p1)
diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc
index 7e2504c..f58a59a 100644
--- a/third_party/xla/xla/python/xla_compiler.cc
+++ b/third_party/xla/xla/python/xla_compiler.cc
@@ -80,10 +80,10 @@
#include "xla/service/tuple_simplifier.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc
index 33fb500..d7461ca 100644
--- a/third_party/xla/xla/reference_util.cc
+++ b/third_party/xla/xla/reference_util.cc
@@ -25,11 +25,19 @@
#include "absl/container/flat_hash_set.h"
#include "absl/functional/function_ref.h"
+#include "absl/types/span.h"
+#include "xla/array2d.h"
+#include "xla/array3d.h"
+#include "xla/array4d.h"
+#include "xla/client/padding.h"
#include "xla/client/xla_builder.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/literal.h"
#include "xla/literal_util.h"
+#include "xla/service/hlo_module_config.h"
#include "xla/service/shape_inference.h"
+#include "xla/shape.h"
#include "xla/window_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/lib/math/math_util.h"
diff --git a/third_party/xla/xla/reference_util.h b/third_party/xla/xla/reference_util.h
index 9a124d6..a086fdb 100644
--- a/third_party/xla/xla/reference_util.h
+++ b/third_party/xla/xla/reference_util.h
@@ -24,6 +24,7 @@
#include <vector>
#include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
#include "absl/types/span.h"
#include "xla/array2d.h"
#include "xla/array3d.h"
diff --git a/third_party/xla/xla/reference_util_test.cc b/third_party/xla/xla/reference_util_test.cc
index 320d1ca..c27e541 100644
--- a/third_party/xla/xla/reference_util_test.cc
+++ b/third_party/xla/xla/reference_util_test.cc
@@ -22,7 +22,9 @@
#include "xla/array3d.h"
#include "xla/array4d.h"
#include "xla/client/padding.h"
+#include "xla/error_spec.h"
#include "xla/literal.h"
+#include "xla/literal_util.h"
#include "xla/test.h"
#include "xla/tests/literal_test_util.h"
#include "xla/xla_data.pb.h"
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 2582534..208c9d6 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -257,53 +257,6 @@
)
cc_library(
- name = "all_reduce_splitter",
- srcs = ["all_reduce_splitter.cc"],
- hdrs = ["all_reduce_splitter.h"],
- deps = [
- ":collective_opt_utils",
- ":hlo_module_config",
- ":hlo_pass",
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "@com_google_absl//absl/cleanup",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "all_reduce_splitter_test",
- srcs = ["all_reduce_splitter_test.cc"],
- deps = [
- ":all_reduce_splitter",
- ":hlo_module_config",
- ":hlo_pass_pipeline",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:gpu_reduce_scatter_creator",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "float_support",
srcs = ["float_support.cc"],
hdrs = ["float_support.h"],
@@ -687,6 +640,7 @@
"//xla:util",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
+ "//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_set",
@@ -703,7 +657,6 @@
"@llvm-project//mlir:Transforms",
"@local_tsl//tsl/lib/io:zlib_compression_options",
"@local_tsl//tsl/lib/io:zlib_outputbuffer",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:path",
@@ -1089,7 +1042,8 @@
deps = [
":pattern_matcher",
":pattern_matcher_gmock",
- "//xla:literal",
+ "//xla:comparison_util",
+ "//xla:literal_util",
"//xla:protobuf_util",
"//xla:shape_util",
"//xla:test",
@@ -1103,6 +1057,9 @@
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
],
)
@@ -1217,17 +1174,20 @@
srcs = ["call_inliner_test.cc"],
deps = [
":call_inliner",
- ":hlo_pass",
+ ":hlo_parser",
"//xla:literal",
+ "//xla:literal_util",
"//xla:shape_util",
"//xla:test",
- "//xla:types",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
],
)
@@ -1737,6 +1697,7 @@
"//xla/stream_executor",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",
+ "//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
@@ -1744,7 +1705,6 @@
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
@@ -4549,7 +4509,9 @@
":hlo_parser",
":pattern_matcher",
":pattern_matcher_gmock",
- "//xla:literal",
+ "//xla:comparison_util",
+ "//xla:literal_util",
+ "//xla:shape_tree",
"//xla:shape_util",
"//xla:test",
"//xla:test_helpers",
@@ -4559,7 +4521,9 @@
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
],
)
@@ -4581,10 +4545,10 @@
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
+ "//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
@@ -6943,6 +6907,7 @@
"@local_tsl//tsl/lib/gtl:map_util",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
],
)
@@ -6951,9 +6916,12 @@
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
+ ":hlo_lexer",
+ ":hlo_module_config",
":hlo_parser",
":pattern_matcher",
":pattern_matcher_gmock",
+ "//xla:array",
"//xla:shape_util",
"//xla:window_util",
"//xla:xla_data_proto_cc",
@@ -6961,8 +6929,13 @@
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
@@ -7926,8 +7899,7 @@
deps = [
":hlo_creation_utils",
":hlo_pass",
- "//xla/service/cpu:onednn_convolution_rewriter",
- "//xla/service/cpu:onednn_matmul_rewriter",
+ "//xla/service/cpu:onednn_contraction_rewriter",
],
)
@@ -8385,4 +8357,42 @@
],
)
+cc_library(
+ name = "add_original_value",
+ srcs = ["add_original_value.cc"],
+ hdrs = ["add_original_value.h"],
+ deps = [
+ ":hlo_pass",
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ ],
+)
+
+xla_cc_test(
+ name = "add_original_value_test",
+ srcs = ["add_original_value_test.cc"],
+ deps = [
+ ":add_original_value",
+ ":pattern_matcher",
+ ":pattern_matcher_gmock",
+ "//xla:shape_util",
+ "//xla:window_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:verified_hlo_module",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/lib/core:status_test_util",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"])
diff --git a/third_party/xla/xla/service/add_original_value.cc b/third_party/xla/xla/service/add_original_value.cc
new file mode 100644
index 0000000..37cab3c
--- /dev/null
+++ b/third_party/xla/xla/service/add_original_value.cc
@@ -0,0 +1,68 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/add_original_value.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_original_value.h"
+#include "xla/shape_util.h"
+
+namespace xla {
+
+absl::StatusOr<bool> AddOriginalValue::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+
+ for (const auto computation : module->computations()) {
+ for (const auto instruction : computation->instructions()) {
+ auto original_value =
+ std::make_shared<OriginalValue>(instruction->shape());
+
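+      // GTEs and tuples derive their original value from their operands'
+      // original values; any other instruction records its own name.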
+ if (instruction->opcode() == HloOpcode::kGetTupleElement) {
+ const auto* tuple = instruction->operand(0);
+ original_value->CopySubtreeFrom(*tuple->original_value(),
+ {instruction->tuple_index()}, {});
+ } else if (instruction->opcode() == HloOpcode::kTuple) {
+ for (int64_t operand_number = 0;
+ operand_number < instruction->operand_count(); ++operand_number) {
+ original_value->CopySubtreeFrom(
+ *instruction->operand(operand_number)->original_value(), {},
+ {operand_number});
+ }
+ } else {
+ for (auto& leaf : original_value->leaves()) {
+ leaf.second = {std::string(instruction->name()), leaf.first};
+ }
+ }
+ instruction->set_original_value(original_value);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/add_original_value.h b/third_party/xla/xla/service/add_original_value.h
new file mode 100644
index 0000000..b4fb093
--- /dev/null
+++ b/third_party/xla/xla/service/add_original_value.h
@@ -0,0 +1,38 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_ADD_ORIGINAL_VALUE_H_
+#define XLA_SERVICE_ADD_ORIGINAL_VALUE_H_
+
+#include "absl/status/statusor.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This pass adds the original_value attribute to each op in the HLO graph;
+// the attribute is used for HLO value tracking. See go/hlo-value-tracking for
+// more details.
+class AddOriginalValue : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "add-original-value"; }
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_ADD_ORIGINAL_VALUE_H_
diff --git a/third_party/xla/xla/service/add_original_value_test.cc b/third_party/xla/xla/service/add_original_value_test.cc
new file mode 100644
index 0000000..f69ba94
--- /dev/null
+++ b/third_party/xla/xla/service/add_original_value_test.cc
@@ -0,0 +1,118 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/add_original_value.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace {
+
+using AddOriginalValueTest = HloTestBase;
+
+using ::absl::string_view;
+
+TEST_F(AddOriginalValueTest, Basic) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={(s32[]{:T(256)})->u32[2]{0:T(256)}}
+
+ENTRY test {
+ Arg_0.1 = s32[] parameter(0)
+ constant.2 = s32[] constant(32)
+ shift-right-logical.3 = s32[] shift-right-logical(Arg_0.1, constant.2)
+ convert.4 = u32[] convert(shift-right-logical.3)
+ reshape.5 = u32[1]{0} reshape(convert.4)
+ convert.6 = u32[] convert(Arg_0.1)
+ reshape.7 = u32[1]{0} reshape(convert.6)
+ ROOT concatenate.8 = u32[2]{0} concatenate(reshape.5, reshape.7), dimensions={0}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ AddOriginalValue pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+ EXPECT_TRUE(changed);
+}
+
+TEST_F(AddOriginalValueTest, Tuple) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})}
+
+ENTRY test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]{0}), f32[2,3]{1,0}) {
+ v1 = f32[] parameter(0)
+ v2 = f32[3]{0} parameter(1)
+ v3 = f32[2,3]{1,0} parameter(2)
+ t1 = (f32[], f32[3]{0}) tuple(f32[] v1, f32[3]{0} v2)
+ ROOT t2 = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) t1, f32[2,3]{1,0} v3)
+}
+
+)";
+
+ RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"(
+CHECK: %[[V1:.*]] = f32[] parameter(0), original_value={{[{]}}{"[[V1]]"}
+CHECK: %[[V2:.*]] = f32[3]{0} parameter(1), original_value={{[{]}}{"[[V2]]"}
+CHECK: %[[TUPLE:.*]] = (f32[], f32[3]{0}) tuple(%[[V1]], %[[V2]]), original_value={({"[[V1]]"}, {"[[V2]]"})}
+CHECK: %[[V3:.*]] = f32[2,3]{1,0} parameter(2), original_value={{[{]}}{"[[V3]]"}
+CHECK: ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), original_value={(({"v1"}, {"v2"}), {"v3"})}
+ )");
+}
+
+TEST_F(AddOriginalValueTest, GetTupleElement) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={()->s32[2,3]{1,0}}
+
+ENTRY test {
+ constant = f32[3]{0} constant({1, 2, 3})
+ constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
+ tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} constant, s32[2,3]{1,0} constant.1)
+ ROOT get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) tuple), index=1
+}
+
+)";
+
+ RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"(
+CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), original_value={{[{]}}{"[[CONSTANT1]]"}
+CHECK: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), original_value={{[{]}}{"[[CONSTANT2]]"}
+CHECK: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), original_value={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})}
+CHECK: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, original_value={{[{]}}{"[[CONSTANT2]]"}
+ )");
+}
+
+TEST_F(AddOriginalValueTest, GetTupleElementNonSymbolic) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={((f32[], s32[]))->s32[]}
+
+ENTRY test {
+ p = (f32[], s32[]) parameter(0)
+ ROOT get-tuple-element = s32[] get-tuple-element(p), index=1
+}
+
+)";
+
+ RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"(
+CHECK: %[[PARAM:.*]] = (f32[], s32[]) parameter(0), original_value={({"p" {0}{{[}]}}, {"p" {1}})}
+CHECK: s32[] get-tuple-element(%[[PARAM]]), index=1, original_value={{[{]}}{"[[PARAM]]" {1}
+ )");
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc
index 3ee1ffa..a3a7f76 100644
--- a/third_party/xla/xla/service/algebraic_simplifier.cc
+++ b/third_party/xla/xla/service/algebraic_simplifier.cc
@@ -179,6 +179,11 @@
using NativeT = NativeTypeOf<primitive_type_constant>;
return static_cast<double>(
inst->literal().GetFirstElement<NativeT>());
+ } else if constexpr (primitive_util::IsIntegralType(
+ primitive_type_constant)) {
+ using NativeT = NativeTypeOf<primitive_type_constant>;
+ return static_cast<int64_t>(
+ inst->literal().GetFirstElement<NativeT>());
}
return std::nullopt;
},
@@ -608,6 +613,11 @@
using NativeT = NativeTypeOf<primitive_type_constant>;
return HloInstruction::CreateConstant(
LiteralUtil::CreateR0<NativeT>(static_cast<NativeT>(multiplier)));
+ } else if constexpr (primitive_util::IsIntegralType(
+ primitive_type_constant)) {
+ using NativeT = NativeTypeOf<primitive_type_constant>;
+ return HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<NativeT>(static_cast<NativeT>(multiplier)));
}
LOG(FATAL) << "Unsupported data type: "
<< target->shape().element_type();
@@ -3331,12 +3341,11 @@
} else if (opcode == HloOpcode::kBroadcast) {
// Broadcasts of dot contracting dimensions can be reordered to reduces
// of the corresponding contracting dimensions in the other dot operand
- DimensionVector reduce_dims, broadcast_dim_sizes;
+ DimensionVector reduce_dims;
const int64_t pre_broadcast_rank =
reorder_from->mutable_operand(0)->shape().rank();
int64_t post_broadcast_rank = reorder_from->shape().rank();
Shape new_broadcast_shape = reorder_from->shape();
- DimensionVector contracting_reordered;
// Construct map from broadcasted shape to its original shape. Broadcast
// dimensions are mapped to -1 since they were not present
@@ -3554,48 +3563,28 @@
other_index = outer_dnums.lhs_batch_dimensions(i);
}
- // Once we have the inner_index, we determine whether this index
- // corresponds to a dimension coming from the lhs or rhs of inner
- bool from_inner_lhs = map_inner_rhs[inner_index] == -1;
+ auto add_batch_dims = [](DotDimensionNumbers& dnums, int64_t lhs_ix,
+ int64_t rhs_ix) {
+ dnums.add_lhs_batch_dimensions(lhs_ix);
+ dnums.add_rhs_batch_dimensions(rhs_ix);
+ };
- // The map we use depends on which operand of inner this dim comes from
- std::vector<int64_t> map;
- if (from_inner_lhs) {
- map = map_inner_lhs;
- } else {
- map = map_inner_rhs;
- }
-
- // Whether the mapped value goes into the lhs or rhs of the new dnums
- // depends on whether inner was the lhs or rhs operand of outer
- int64_t lhs_index, rhs_index;
- if (outer_lhs_dot) {
- lhs_index = map[inner_index];
- rhs_index = other_index;
- } else {
- lhs_index = other_index;
- rhs_index = map[inner_index];
- }
-
- // Finally, we have to determine which dnums to add to
- DotDimensionNumbers* dnums;
- if (outer_lhs_dot) {
- if (from_inner_lhs) {
- dnums = &ac_dnums;
- } else {
- dnums = &bc_dnums;
- }
- } else {
- if (from_inner_lhs) {
- dnums = &ab_dnums;
- } else {
- dnums = &ac_dnums;
+ for (auto& map : {map_inner_lhs, map_inner_rhs}) {
+ int64_t mapped_index = map[inner_index];
+ if (mapped_index != -1) {
+          // Whether the mapped value becomes the lhs or rhs of the new dnums
+          // depends on whether inner is the lhs or rhs operand of outer; which
+          // dnums to update depends on this and also on which map we are
+          // iterating through.
+ if (outer_lhs_dot) {
+ add_batch_dims(map == map_inner_lhs ? ac_dnums : bc_dnums,
+ mapped_index, other_index);
+ } else {
+ add_batch_dims(map == map_inner_lhs ? ab_dnums : ac_dnums,
+ other_index, mapped_index);
+ }
}
}
-
- // Add the batch dimensions
- dnums->add_lhs_batch_dimensions(lhs_index);
- dnums->add_rhs_batch_dimensions(rhs_index);
}
// We now do the same thing for the contracting dimensions of outer
@@ -3614,7 +3603,14 @@
// Once we have the inner_index, we determine whether this index
// corresponds to a dimension coming from the lhs or rhs of inner
- bool from_inner_lhs = map_inner_rhs[inner_index] == -1;
+ bool from_inner_lhs = map_inner_lhs[inner_index] != -1;
+ bool from_inner_rhs = map_inner_rhs[inner_index] != -1;
+
+      // If a dimension of inner is the result of batching and it is
+      // contracted in outer, we stop trying to reorder.
+ if (from_inner_lhs && from_inner_rhs) {
+ return absl::OkStatus();
+ }
// The map we use depends on which operand of inner this dim comes from
std::vector<int64_t> map;
@@ -3714,8 +3710,11 @@
rhs_index = other_index;
}
- new_outer_dnums.add_lhs_batch_dimensions(lhs_index);
- new_outer_dnums.add_rhs_batch_dimensions(rhs_index);
+ if (!absl::c_linear_search(new_outer_dnums.lhs_batch_dimensions(),
+ lhs_index)) {
+ new_outer_dnums.add_lhs_batch_dimensions(lhs_index);
+ new_outer_dnums.add_rhs_batch_dimensions(rhs_index);
+ }
}
for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
int64_t new_inner_index, other_index;
@@ -4280,6 +4279,19 @@
}
} // namespace
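+// Used by HandleMaximum to sink such ops through a max: the listed ops are
+// non-decreasing and asymptotically satisfy |f(x)| <= |x|.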
+bool AlgebraicSimplifierVisitor::IsNondecreasingSublinear(
+ const HloInstruction* hlo) {
+ switch (hlo->opcode()) {
+ case HloOpcode::kCbrt:
+ case HloOpcode::kErf:
+ case HloOpcode::kLogistic:
+ case HloOpcode::kTanh:
+ return true;
+ default:
+ return false;
+ }
+}
+
absl::Status AlgebraicSimplifierVisitor::HandleMaximum(
HloInstruction* maximum) {
HloInstruction *lhs, *rhs;
@@ -4350,6 +4362,33 @@
}
}
+  // If the operands of the max are the same non-decreasing function, then we
+  // can sink it, e.g. max(tanh(x), tanh(y)) to tanh(max(x, y)).
+  // We only do this if the function asymptotically satisfies |f(x)| <= |x| to
+  // guarantee that no overflow occurs. Proof of correctness:
+ /* https://cvc5.github.io/app/
+ (set-logic ALL)
+ (declare-fun f (Float32) Float32)
+ (assert (forall ((x Float32) (y Float32))
+ (=> (fp.lt x y) (fp.leq (f x) (f y))))) ; NonDecreasing
+ (assert (forall ((x Float32))
+ (fp.leq (fp.abs (f x)) (fp.abs x)))) ; Sublinear
+ (assert (not (forall ((x Float32) (y Float32))
+ (fp.eq (fp.max (f x) (f y))
+ (f (fp.max x y)))))) ; Expect unsat
+ (check-sat)
+ */
+ if (lhs->opcode() == rhs->opcode() && IsNondecreasingSublinear(lhs)) {
+ TF_ASSIGN_OR_RETURN(
+ auto new_maximum,
+ MakeBinaryHlo(HloOpcode::kMaximum, lhs->mutable_operand(0),
+ rhs->mutable_operand(0)));
+ VLOG(10) << "Sinking nondecreasing op through max";
+ return ReplaceWithNewInstruction(
+ maximum, HloInstruction::CreateUnary(maximum->shape(), lhs->opcode(),
+ new_maximum));
+ }
+
return absl::OkStatus();
}
@@ -7839,27 +7878,6 @@
}
}
- // Replace Reduce(Broadcast(x), dims, Sum()) with Broadcast(x * prod(dims)).
- if (HloInstruction * broadcast_arg;
- Match(arg, m::Broadcast(m::ConstantScalar(&broadcast_arg))) &&
- Match(function->root_instruction(),
- m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
- if (auto broadcast_value = GetConstantValue(broadcast_arg);
- broadcast_value.has_value() &&
- // Skip float64, where product is too accurate compared to repeated-sum.
- broadcast_arg->shape().element_type() != PrimitiveType::F64) {
- auto result_value = broadcast_value.value() *
- ShapeUtil::ElementsIn(arg->shape()) /
- ShapeUtil::ElementsIn(reduce_result_shape);
- return ReplaceWithNewInstruction(
- reduce, HloInstruction::CreateBroadcast(
- reduce_result_shape,
- reduce->AddInstruction(
- MakeScalarInstruction(reduce, result_value)),
- {}));
- }
- }
-
// For Computation equal to Min, Max, And or Or, replace Reduce(Broadcast(x),
// a, Computation()) with Computation(x, a) when x is a scalar and the
// broadcast is reduced to a scalar.
@@ -7883,26 +7901,52 @@
}
}
- // Replace Reduce(Broadcast(Scalar)) with Broadcast(Multiply(Scalar)) when the
- // reduction operation is addition
+  // Replace Reduce(Broadcast(x), +, init_value) with Broadcast(Add(Multiply(x),
+  // init_value)) if all reduction dimensions were introduced by the Broadcast.
if (arg->opcode() == HloOpcode::kBroadcast &&
- ShapeUtil::IsScalar(arg->operand(0)->shape())) {
- if (Match(reduce->to_apply()->root_instruction(),
- m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) &&
- IsScalarConstantZero(init_value)) {
- int64_t reduction_dims_prod = 1;
- for (auto i : reduce->dimensions()) {
- reduction_dims_prod *= arg->shape().dimensions(i);
- }
+ Match(reduce->to_apply()->root_instruction(),
+ m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
+ bool only_reduce_dims_from_broadcast = true;
+ int64_t common_dims_prod = 1;
+ int64_t num_common_dims = 0;
+ Shape new_broadcast_shape = arg->shape();
+ std::vector<int64_t> new_broadcast_dims;
+      // Build the new broadcast shape and dims vector, folding dimensions
+      // added by the broadcast and removed by the reduce into a multiplier.
+ for (int64_t i = 0; i < arg->shape().rank(); ++i) {
+ bool added_by_broadcast = !absl::c_linear_search(arg->dimensions(), i);
+ bool removed_by_reduce = absl::c_linear_search(reduce->dimensions(), i);
+
+ if (removed_by_reduce && !added_by_broadcast) {
+ only_reduce_dims_from_broadcast = false;
+ break;
+ } else if (removed_by_reduce && added_by_broadcast) {
+ new_broadcast_shape.DeleteDimension(i - num_common_dims);
+ common_dims_prod *= arg->shape().dimensions(i);
+ num_common_dims++;
+ } else if (!removed_by_reduce && !added_by_broadcast) {
+ new_broadcast_dims.push_back(i - num_common_dims);
+ }
+ }
+
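+      // E.g. in the ReduceOfNonScalarBroadcast test, f32[64,1001] is
+      // broadcast with dimensions={0,3} to f32[64,7,7,1001] and reduced over
+      // {2}; this gives common_dims_prod = 7, new_broadcast_shape =
+      // f32[64,7,1001], and new_broadcast_dims = {0,2}.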
+ if (only_reduce_dims_from_broadcast) {
+ // HloConstantFolding will later remove any unnecessary multiply and add
+ // instructions.
HloInstruction* multiplier =
- MakeScalarLike(arg->mutable_operand(0), reduction_dims_prod);
+ MakeScalarLike(arg->mutable_operand(0), common_dims_prod);
TF_ASSIGN_OR_RETURN(HloInstruction * multiplied_scalar,
MakeBinaryHlo(HloOpcode::kMultiply,
arg->mutable_operand(0), multiplier));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * add,
+ MakeBinaryHlo(
+ HloOpcode::kAdd,
+ MakeBroadcastHlo(init_value, {}, multiplied_scalar->shape()),
+ multiplied_scalar));
+ VLOG(10) << "Converting common reduce(broadcast) dimensions to multiply";
return ReplaceWithNewInstruction(
- reduce, HloInstruction::CreateBroadcast(reduce->shape(),
- multiplied_scalar, {}));
+ reduce, HloInstruction::CreateBroadcast(new_broadcast_shape, add,
+ new_broadcast_dims));
}
}
diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h
index bdd4f91..49e87fa 100644
--- a/third_party/xla/xla/service/algebraic_simplifier.h
+++ b/third_party/xla/xla/service/algebraic_simplifier.h
@@ -487,6 +487,10 @@
static bool IsNonNegative(const HloInstruction* hlo,
const AlgebraicSimplifierOptions& options);
+  // Check whether the given instruction applies an elementwise non-decreasing
+  // function that asymptotically satisfies |f(x)| <= |x|.
+ static bool IsNondecreasingSublinear(const HloInstruction* hlo);
+
// Modify the layout dimensions of result_shape, so that it becomes the
// re-shaped result of applying bitcast to the original_shape, by using
// dim_map to re-shape layout dimensions of original_shape. Returns the
diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc
index d9e4252..60003eb 100644
--- a/third_party/xla/xla/service/algebraic_simplifier_test.cc
+++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc
@@ -6411,6 +6411,94 @@
m::Dot(m::Parameter(1), m::Parameter(2)))));
}
+TEST_F(AlgebraicSimplifierTest, DotLeftDotSharedBatchReorder) {
+ const char* hlo_string = R"(
+ HloModule module
+
+ ENTRY test {
+ a = f32[5,150,5] parameter(0)
+ b = f32[5,5,5] parameter(1)
+ c = f32[5,5,5] parameter(2)
+
+ inner = f32[5,150,5] dot(a,b),
+ lhs_batch_dims={0}, lhs_contracting_dims={2},
+ rhs_batch_dims={0}, rhs_contracting_dims={2}
+ ROOT outer = f32[5,150,5] dot(inner,c),
+ lhs_batch_dims={0}, lhs_contracting_dims={2},
+ rhs_batch_dims={0}, rhs_contracting_dims={2}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ AlgebraicSimplifierOptions options;
+ options.set_use_associative_reordering(true);
+ options.set_associative_reordering_threshold(1.5);
+ AlgebraicSimplifier simplifier(options);
+ EXPECT_TRUE(simplifier.Run(module.get()).value());
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Parameter(0),
+ m::Dot(m::Parameter(1), m::Parameter(2)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotRightDotSharedBatchReorder) {
+ const char* hlo_string = R"(
+ HloModule module
+
+ ENTRY test {
+ a = f32[2,3,3] parameter(0)
+ b = f32[2,3,3] parameter(1)
+ c = f32[2,3,16] parameter(2)
+
+ inner = f32[2,3,16] dot(b,c),
+ lhs_batch_dims={0}, lhs_contracting_dims={2},
+ rhs_batch_dims={0}, rhs_contracting_dims={1}
+ ROOT outer = f32[2,3,16] dot(a,inner),
+ lhs_batch_dims={0}, lhs_contracting_dims={2},
+ rhs_batch_dims={0}, rhs_contracting_dims={1}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ AlgebraicSimplifierOptions options;
+ options.set_use_associative_reordering(true);
+ options.set_associative_reordering_threshold(1.5);
+ AlgebraicSimplifier simplifier(options);
+ EXPECT_TRUE(simplifier.Run(module.get()).value());
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Dot(m::Parameter(0), m::Parameter(1)),
+ m::Parameter(2))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotRightDotContractBatchReorder) {
+ const char* hlo_string = R"(
+ HloModule module
+
+ ENTRY test {
+ a = f32[80,38,1536] parameter(0)
+ b = f32[80,38,4] parameter(1)
+ c = f32[80,4,1536] parameter(2)
+ inner = f32[80,38,1536] dot(b, c),
+ lhs_batch_dims={0},
+ lhs_contracting_dims={2},
+ rhs_batch_dims={0},
+ rhs_contracting_dims={1}
+ ROOT outer = f32[1536,1536] dot(a, inner),
+ lhs_contracting_dims={0,1},
+ rhs_contracting_dims={0,1}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ AlgebraicSimplifierOptions options;
+ options.set_use_associative_reordering(true);
+ options.set_associative_reordering_threshold(1.5);
+ AlgebraicSimplifier simplifier(options);
+ EXPECT_FALSE(simplifier.Run(module.get()).value());
+}
+
TEST_F(AlgebraicSimplifierTest, DotReverseLeftReorder) {
const char* hlo_string = R"(
HloModule module
@@ -10075,24 +10163,25 @@
TEST_F(AlgebraicSimplifierTest, ReplaceReduceSumOfConstantBroadcast) {
const char* kModuleStr = R"(
-HloModule ReplaceReduceSumOfConstantBroadcast
+ HloModule ReplaceReduceSumOfConstantBroadcast
-add_f32 {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT r = f32[] add(p0, p1)
-}
+ add_f32 {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT r = f32[] add(p0, p1)
+ }
-ENTRY main {
- init_value = f32[] constant(0)
- const_value = f32[] constant(1)
- const_bcast = f32[8, 128] broadcast(f32[] const_value), dimensions={}
- ROOT reduce = f32[8] reduce(f32[8, 128] const_bcast, f32[] init_value), dimensions={1}, to_apply=add_f32
-}
-)";
+ ENTRY main {
+ init_value = f32[] constant(0)
+ const_value = f32[] constant(1)
+ const_bcast = f32[8, 128] broadcast(f32[] const_value), dimensions={}
+ ROOT reduce = f32[8] reduce(f32[8, 128] const_bcast, f32[] init_value), dimensions={1}, to_apply=add_f32
+ }
+ )";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
- ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+ HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+ EXPECT_TRUE(simplifier.Run(m.get()).value());
int64_t reduce_count =
absl::c_count_if(m->entry_computation()->instructions(),
HloPredicateIsOp<HloOpcode::kReduce>);
@@ -11714,6 +11803,88 @@
HloOpcode::kParameter);
}
+TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastS32) {
+ const std::string hlo_string = R"(
+ HloModule test
+ add_s32 {
+ p0 = s32[] parameter(0)
+ p1 = s32[] parameter(1)
+ ROOT r = s32[] add(p0, p1)
+ }
+ ENTRY test.1 {
+ one = s32[] constant(2)
+ init = s32[] constant(10)
+ bcast = s32[1,7,7,1] broadcast(one), dimensions={}
+ ROOT out = s32[1,7,1] reduce(bcast, init), dimensions={1}, to_apply=add_s32
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+ auto clone = m->Clone();
+ HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+ EXPECT_TRUE(simplifier.Run(m.get()).value());
+ int64_t reduce_count =
+ absl::c_count_if(m->entry_computation()->instructions(),
+ HloPredicateIsOp<HloOpcode::kReduce>);
+ // Expect no Reduce operation after simplification.
+ EXPECT_EQ(0, reduce_count);
+}
+
+TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastBF16) {
+ const std::string hlo_string = R"(
+ HloModule test
+ add_bf16 {
+ p0 = bf16[] parameter(0)
+ p1 = bf16[] parameter(1)
+ ROOT r = bf16[] add(p0, p1)
+ }
+ ENTRY test.1 {
+ one = bf16[] constant(2.12)
+ init = bf16[] constant(10.34)
+ bcast = bf16[1,7,7,1] broadcast(one), dimensions={}
+ ROOT out = bf16[1,7,1] reduce(bcast, init), dimensions={1}, to_apply=add_bf16
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+ auto clone = m->Clone();
+ HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+ EXPECT_TRUE(simplifier.Run(m.get()).value());
+ int64_t reduce_count =
+ absl::c_count_if(m->entry_computation()->instructions(),
+ HloPredicateIsOp<HloOpcode::kReduce>);
+ // Expect no Reduce operation after simplification.
+ EXPECT_EQ(0, reduce_count);
+}
+
+TEST_F(AlgebraicSimplifierTest, ReduceOfNonScalarBroadcast) {
+ const std::string hlo_string = R"(
+ HloModule module
+ add {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT sum = f32[] add(a, b)
+ }
+
+ ENTRY test {
+ a = f32[64,1001] parameter(0)
+ broadcast = f32[64,7,7,1001] broadcast(a), dimensions={0,3}
+ zero = f32[] constant(0)
+ ROOT reduce = f32[64,7,1001] reduce(broadcast, zero), dimensions={2},
+ to_apply=add
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+ HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+ EXPECT_TRUE(simplifier.Run(m.get()).value());
+ HloInstruction* root = m->entry_computation()->root_instruction();
+ int64_t reduce_count =
+ absl::c_count_if(m->entry_computation()->instructions(),
+ HloPredicateIsOp<HloOpcode::kReduce>);
+ // Expect no Reduce operation after simplification.
+ EXPECT_EQ(0, reduce_count);
+ EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Multiply())));
+}
+
TEST_F(AlgebraicSimplifierTest, RemoveConvertConstant) {
const std::string hlo_string = R"(
HloModule module
@@ -11773,11 +11944,31 @@
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+ HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+ EXPECT_TRUE(simplifier.Run(m.get()).value());
HloInstruction* root = m->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMultiply);
}
+TEST_F(AlgebraicSimplifierTest, SinkCbrtThroughMax) {
+ absl::string_view hlo_string = R"(
+ HloModule module
+
+ ENTRY test {
+ a = bf16[17,96,120] parameter(0)
+ b = bf16[17,96,120] parameter(1)
+ cbrt_a = bf16[17,96,120] cbrt(a)
+ cbrt_b = bf16[17,96,120] cbrt(b)
+ ROOT max = bf16[17,96,120] maximum(cbrt_a, cbrt_b)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+ HloInstruction* root = m->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root, GmockMatch(m::Cbrt(m::Maximum(m::Parameter(0), m::Parameter(1)))));
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/service/all_reduce_splitter.cc b/third_party/xla/xla/service/all_reduce_splitter.cc
deleted file mode 100644
index ce1e0e2..0000000
--- a/third_party/xla/xla/service/all_reduce_splitter.cc
+++ /dev/null
@@ -1,436 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/all_reduce_splitter.h"
-
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <variant>
-#include <vector>
-
-#include "absl/cleanup/cleanup.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/collective_device_list.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_opt_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-// Structure containing the newly calculated replica groups.
-struct ARReplicaGroups {
- // First AR's replica group.
- std::vector<ReplicaGroup> first_ar_replica_groups;
- // Second AR's replica group.
- std::vector<ReplicaGroup> second_ar_replica_groups;
-};
-
-// Contains relevant data to rewrite the AR + DS into AR + DS + AR.
-struct AllReduceRewriteSpec {
- // Determines a dimension on which DS occurs.
- int split_dim;
- // Determines the size of the process group.
- int group_size;
- // AllReduce instruction to be rewritten.
- HloAllReduceInstruction* all_reduce;
- // DynamicSlice following the `all_reduce` indicating logical RS.
- HloDynamicSliceInstruction* dynamic_slice;
- // New replica groups for an `all_reduce`.
- ARReplicaGroups replica_groups;
-
- std::string ToString() {
- return absl::Substitute(
- "{\n split_dim=$0\n group_size=$1\n all_reduce=$2\n "
- "dynamic_slice=$3\n}\n",
- split_dim, group_size, all_reduce->ToString(),
- dynamic_slice->ToString());
- }
-};
-
-// Contains the relevant metadata for debugging why rewrite is infeasible.
-struct RewriteInfeasibleReason {
- // Instruction for which it is infeasible to do a rewrite.
- const HloInstruction* ar;
- // Describes a reason of infeasibility.
- std::string message;
-};
-
-// Hashable container to hold replica groups.
-struct ReplicaGroups {
- std::vector<ReplicaGroup> replica_groups;
-
- template <typename H>
- friend H AbslHashValue(H h, const ReplicaGroups& rg) {
- return H::combine(std::move(h), rg.replica_groups.size());
- }
-
- friend bool operator==(const ReplicaGroups& item,
- const ReplicaGroups& other) {
- if (item.replica_groups.size() != other.replica_groups.size()) {
- return false;
- }
- for (int i = 0; i < item.replica_groups.size(); i++) {
- const ReplicaGroup& item_replica_group = item.replica_groups[i];
- const ReplicaGroup& other_replica_group = other.replica_groups[i];
- for (int i = 0; i < item_replica_group.replica_ids_size(); i++) {
- if (item_replica_group.replica_ids(i) !=
- other_replica_group.replica_ids(i)) {
- return false;
- }
- }
- }
- return true;
- }
-};
-
-using ARReplicaGroupMap =
- absl::flat_hash_map<ReplicaGroups,
- std::vector<const HloAllReduceInstruction*>>;
-
-using RewriteDecision =
- std::variant<AllReduceRewriteSpec, RewriteInfeasibleReason>;
-
-// Returns a single dimension which is being split by `ds`. Returns
-// std::nullopt if there are more, or no dimension to be split.
-std::optional<int> GetSplitDim(const HloAllReduceInstruction& ar,
- const HloDynamicSliceInstruction& ds) {
- int split_dim = -1;
- int num_dims = 0;
- for (int64_t dim = 0; dim < ar.shape().rank(); ++dim) {
- if (ar.shape().dimensions(dim) != ds.shape().dimensions(dim)) {
- num_dims++;
- split_dim = dim;
- }
- }
- if (num_dims != 1) {
- VLOG(2) << "No support for multiple nor 0 split dims.";
- return std::nullopt;
- }
- return split_dim;
-}
-
-// For input collective instruction `ar` get the process group size (# shards).
-std::optional<int> GetProcessGroupSize(const HloAllReduceInstruction& ar,
- const HloDynamicSliceInstruction& ds) {
- CHECK(ds.operand(0) == &ar) << "Irrelevant AR + DS pair.";
- std::optional<int> split_dim = GetSplitDim(ar, ds);
- if (!split_dim.has_value()) {
- return std::nullopt;
- }
-
- return ar.shape().dimensions(*split_dim) /
- ds.dynamic_slice_sizes()[*split_dim];
-}
-
-ARReplicaGroupMap GetReplicaGroupsMap(HloComputation& computation) {
- ARReplicaGroupMap map;
- hlo_query::ForEachInstructionWithOpcode(
- computation, HloOpcode::kAllReduce,
- [&map](const HloInstruction* instruction) {
- const HloAllReduceInstruction* ar =
- Cast<HloAllReduceInstruction>(instruction);
- auto rgs = ReplicaGroups{ar->replica_groups()};
- map[rgs].push_back(ar);
- });
- return map;
-}
-
-ARReplicaGroups GetNewReplicaGroups(int group_size, int num_partitions) {
- CHECK_EQ(num_partitions % group_size, 0);
-
- std::vector<ReplicaGroup> first_ar_rgs, second_ar_rgs;
- int num_units = num_partitions / group_size;
- first_ar_rgs.reserve(num_units);
- second_ar_rgs.reserve(group_size);
-
- // Construct first AR replica groups.
- for (int u = 0; u < group_size * num_units; u += group_size) {
- ReplicaGroup& group = first_ar_rgs.emplace_back();
- for (int r = u; r < u + group_size; r++) {
- group.add_replica_ids(r);
- }
- }
-
- // Construct second AR replica groups.
- for (int g = 0; g < group_size; g++) {
- ReplicaGroup& group = second_ar_rgs.emplace_back();
- for (int r = g; r < group_size * num_units; r += group_size) {
- group.add_replica_ids(r);
- }
- }
- return {
- /*first_ar_replica_groups=*/first_ar_rgs,
- /*second_ar_replica_groups=*/second_ar_rgs,
- };
-}
-
-// Returns true if `spec` can be transformed into a logical reduce scatter.
-// False otherwise.
-bool IsLogicalReduceScatter(const HloModule& module,
- const AllReduceRewriteSpec& spec,
- HloComputation& computation) {
- HloAllReduceInstruction& ar = *spec.all_reduce;
- CHECK_EQ(ar.user_count(), 1);
- CHECK_EQ(module.config().replica_count(), 1);
-
- HloInstruction* first_ar =
- computation.AddInstruction(HloInstruction::CreateAllReduce(
- ar.shape(), ar.operands(), ar.to_apply(),
- CollectiveDeviceList(spec.replica_groups.first_ar_replica_groups),
- ar.constrain_layout(), hlo_query::NextChannelId(module),
- ar.use_global_device_ids()));
-
- HloInstruction* ds = ar.users()[0];
- auto* old_operand = ds->mutable_operand(0);
- if (!ds->ReplaceOperandWith(0, first_ar).ok()) {
- return false;
- }
- absl::Cleanup _ = [&] {
- CHECK_OK(ds->ReplaceOperandWith(0, old_operand));
- CHECK_OK(computation.RemoveInstruction(first_ar));
- };
- return MatchReduceScatter(Cast<HloAllReduceInstruction>(first_ar),
- module.config().num_partitions(),
- module.config().replica_count(),
- /*allow_multiple_split_dims=*/false,
- /*allow_intervening_reshape=*/true)
- .has_value();
-}
-
-// Determine whether the given `spec`'s AllReduce instruction is profitable to
-// split. Currently it employs a simple heuristic, and it checks whether there
-// exists at least one all reduce with same replica groups as any of the all
-// reduce's replica groups after the potential split.
-bool IsProfitableToSplit(const ARReplicaGroupMap& replica_map,
- const AllReduceRewriteSpec& spec) {
- auto new_rgs = spec.replica_groups;
- bool first_replica_exists =
- replica_map.contains(ReplicaGroups{new_rgs.first_ar_replica_groups});
- bool second_replica_exists =
- replica_map.contains(ReplicaGroups{new_rgs.second_ar_replica_groups});
- return first_replica_exists || second_replica_exists;
-}
-
-RewriteDecision CanRewrite(const HloModule& module,
- const ARReplicaGroupMap& replica_map,
- HloComputation& computation,
- HloInstruction& instruction) {
- // We rely on SPMD partitioning enabled, thus asserting `replica_count` = 1.
- const HloModuleConfig& config = module.config();
- if (config.use_auto_spmd_partitioning() || !config.use_spmd_partitioning() ||
- config.replica_count() != 1) {
- return RewriteInfeasibleReason{
- &instruction,
- "Supporting only SPMD partitioning scheme.",
- };
- }
-
- if (instruction.opcode() != HloOpcode::kAllReduce) {
- return RewriteInfeasibleReason{
- &instruction,
- "Cannot rewrite an AllReduce, since it's not AllReduce.",
- };
- }
-
- auto* ar = Cast<HloAllReduceInstruction>(&instruction);
-
- if (!ar->use_global_device_ids()) {
- return RewriteInfeasibleReason{
- &instruction,
- "Only global ids are supported currently.",
- };
- }
-
- if (ar->user_count() != 1 ||
- ar->users().front()->opcode() != HloOpcode::kDynamicSlice) {
- return RewriteInfeasibleReason{
- &instruction,
- "Cannot rewrite AllReduce if it is not a logical reduce scatter.",
- };
- }
-
- auto* ds = Cast<HloDynamicSliceInstruction>(ar->users().front());
-
- if (ds->user_count() > 1) {
- return RewriteInfeasibleReason{
- &instruction,
- "Exactly one user of dynamic slice is required for a rewrite.",
- };
- }
-
- int num_partitions = config.num_partitions();
-
- std::vector<ReplicaGroup> rgs = ar->replica_groups();
- if (rgs.size() != 1 || rgs.front().replica_ids_size() != num_partitions) {
- return RewriteInfeasibleReason{
- &instruction,
- absl::StrCat("Cannot determine a valid split with num_partitions: ",
- num_partitions),
- };
- }
-
- std::optional<int> split_dim = GetSplitDim(*ar, *ds);
- if (!split_dim.has_value()) {
- return RewriteInfeasibleReason{
- &instruction,
- "Cannot get a split dim.",
- };
- }
-
- std::optional<int> group_size = GetProcessGroupSize(*ar, *ds);
- if (!group_size.has_value()) {
- return RewriteInfeasibleReason{
- &instruction,
- "Cannot determine a group size.",
- };
- }
-
- if (num_partitions == group_size) {
- return RewriteInfeasibleReason{
- &instruction,
- "Nothing to rewrite",
- };
- }
-
- if (num_partitions % *group_size != 0) {
- return RewriteInfeasibleReason{
- &instruction,
- "Group size does not evenly divide the number of partitions",
- };
- }
-
- auto spec = AllReduceRewriteSpec{
- /*split_dim=*/*split_dim,
- /*group_size=*/*group_size,
- /*all_reduce=*/ar,
- /*dynamic_slice=*/ds,
- /*replica_groups=*/GetNewReplicaGroups(*group_size, num_partitions),
- };
-
- if (!IsLogicalReduceScatter(module, spec, computation)) {
- return RewriteInfeasibleReason{
- &instruction,
- "Not a logical reduce scatter.",
- };
- }
-
- if (!IsProfitableToSplit(replica_map, spec)) {
- return RewriteInfeasibleReason{
- &instruction,
- "Splitting is not profitable.",
- };
- }
-
- return spec;
-}
-
-absl::StatusOr<bool> SplitAllReduce(const HloModuleConfig& config,
- AllReduceRewriteSpec spec,
- HloComputation& computation) {
- int64_t next_channel_id =
- hlo_query::NextChannelId(*spec.all_reduce->GetModule());
- VLOG(1) << "AR splitting spec: " << spec.ToString();
- // Create first AR.
- int num_partitions = config.num_partitions();
- // # of shards within a replica
- int group_size = spec.group_size;
-
- CHECK_EQ(num_partitions % group_size, 0);
-
- HloAllReduceInstruction& ar = *spec.all_reduce;
- HloDynamicSliceInstruction& ds = *spec.dynamic_slice;
-
- const auto& [first_ar_replica_groups, second_ar_replica_groups] =
- spec.replica_groups;
- int channel_id = next_channel_id++;
- HloInstruction* first_ar =
- computation.AddInstruction(HloInstruction::CreateAllReduce(
- ar.shape(), ar.operands(), ar.to_apply(),
- CollectiveDeviceList(first_ar_replica_groups), ar.constrain_layout(),
- channel_id, ar.use_global_device_ids()));
-
- // Create second AR.
- channel_id = next_channel_id++;
- HloInstruction* second_ar =
- computation.AddInstruction(HloInstruction::CreateAllReduce(
- ds.shape(), {&ds}, ar.to_apply(),
- CollectiveDeviceList(second_ar_replica_groups), ar.constrain_layout(),
- channel_id, ar.use_global_device_ids()));
-
- // Rewire.
- TF_RETURN_IF_ERROR(computation.ReplaceInstruction(&ar, first_ar));
- if (ds.IsRoot()) {
- computation.set_root_instruction(second_ar);
- }
- TF_RETURN_IF_ERROR(ds.ReplaceAllUsesWith(second_ar));
- return true; // changed
-}
-
-// Splits `instruction` if it finds it is feasible and profitable to do so.
-// Return true if `instruction` has been rewritten, or false otherwise.
-absl::StatusOr<bool> SplitAllReduce(const HloModule& module,
- const ARReplicaGroupMap& replica_map,
- HloComputation& computation,
- HloInstruction& instruction) {
- RewriteDecision spec =
- CanRewrite(module, replica_map, computation, instruction);
- if (std::holds_alternative<RewriteInfeasibleReason>(spec)) {
- auto reason = std::get<RewriteInfeasibleReason>(spec);
- VLOG(1) << "Cannot process {" << reason.ar->ToString()
- << "} due to : " << reason.message;
- return false; // changed
- }
- return SplitAllReduce(module.config(), std::get<AllReduceRewriteSpec>(spec),
- computation); // changed
-}
-
-} // namespace
-
-absl::StatusOr<bool> AllReduceSplitter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
-
- for (auto* computation : module->computations(execution_threads)) {
- ARReplicaGroupMap replica_map = GetReplicaGroupsMap(*computation);
- for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
- TF_ASSIGN_OR_RETURN(bool rewritten, SplitAllReduce(*module, replica_map,
- *computation, *instr));
- changed |= rewritten;
- }
- }
-
- return changed;
-}
-
-} // namespace xla
diff --git a/third_party/xla/xla/service/all_reduce_splitter.h b/third_party/xla/xla/service/all_reduce_splitter.h
deleted file mode 100644
index ac8dec7..0000000
--- a/third_party/xla/xla/service/all_reduce_splitter.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_ALL_REDUCE_SPLITTER_H_
-#define XLA_SERVICE_ALL_REDUCE_SPLITTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Rewrites global AR if it is in the form of AR + DS and matches existing
-// replica groups into a logical RS followed by AR.
-//
-// If the pass detects AR followed by DS, then it checks whether
-// it is profitable to break it down into a logical RS (but AR + DS still),
-// followed by an AR to keep the rewrite numerically equivalent.
-//
-// Consider following example:
-//
-// Input program:
-// HloModule m, num_partitions=8
-// p = partition_id()
-// ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3,4,5,6,7}}
-// ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
-//
-// There is a global AR performing a reduction over 8 partitions.
-// However DS is performing 8-sized slice of a 32-sized tensor which implies
-// only 4 distinct slices of a tensor, which further implies 2 replicas of each
-// calculated slice. This can be expressed as RS within the replicas followed by
-// AR across the replicas. The transformation limits collectives to the data
-// that is actually needed for the requested slice.
-//
-// Output program:
-// HloModule m, num_partitions=8
-// p = partition_id()
-// ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3},{4,5,6,7}}
-// ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
-// ar.2 = bf16[32] all-reduce(ds), replica_groups={{0,4},{1,5},{2,6},{3,7}}
-//
-// In addition the pass does the rewrite only if it finds it profitable to do
-// so. The profitability function is simple, and just checks whether there are
-// any collectives with same replica groups. If there are then the combiner pass
-// can pick it up, and fuse it into the same NCCL call.
-//
-// While the solution is orthogonal to existing known distribution patterns, in
-// practice it is profitable for HSDP style communication pattern.
-// https://arxiv.org/pdf/2203.11014
-//
-class AllReduceSplitter : public HloModulePass {
- public:
- absl::string_view name() const override { return "all-reduce-splitter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla
-
-#endif // XLA_SERVICE_ALL_REDUCE_SPLITTER_H_
diff --git a/third_party/xla/xla/service/all_reduce_splitter_test.cc b/third_party/xla/xla/service/all_reduce_splitter_test.cc
deleted file mode 100644
index 3902a97..0000000
--- a/third_party/xla/xla/service/all_reduce_splitter_test.cc
+++ /dev/null
@@ -1,506 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/all_reduce_splitter.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <string>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::tsl::testing::IsOkAndHolds;
-
-class AllReduceSplitterTest : public HloTestBase {
- public:
- absl::StatusOr<std::unique_ptr<HloModule>> PrepareModule(
- absl::string_view hlo_module, int64_t num_replicas,
- int64_t num_partitions) {
- HloModuleConfig config = GetModuleConfigForTest(
- /*replica_count=*/num_replicas,
- /*num_partitions=*/num_partitions);
- config.set_use_spmd_partitioning(num_partitions > 1);
- return ParseAndReturnVerifiedModule(hlo_module, config);
- }
-
- size_t AllReduceCount(const HloModule &module) {
- return CollectiveCount(module, HloOpcode::kAllReduce);
- }
-
- private:
- size_t CollectiveCount(const HloModule &module, HloOpcode opcode) {
- return absl::c_count_if(
- module.entry_computation()->instructions(),
- [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
- }
-};
-
-class AllReduceSplitterFilecheckTest : public AllReduceSplitterTest {
- public:
- absl::Status FileCheck(const std::string &hlo_text,
- absl::string_view pattern) {
- TF_ASSIGN_OR_RETURN(bool matched, RunFileCheck(hlo_text, pattern));
- if (!matched) {
- return absl::InternalError("Filecheck failed.");
- }
- return absl::OkStatus();
- }
-};
-
-TEST_F(
- AllReduceSplitterFilecheckTest,
- MatchBasicPatternIfDynamicSliceIsRootAndThereExistsAllReduceWithSameReplicaGroups) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
- TF_EXPECT_OK(FileCheck(module->ToString(), R"(
- CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
- CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
- CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]}
- CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
- CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
- CHECK: %[[AR1:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
- CHECK-SAME: replica_groups={[[DESIRED_RGS]]}
- CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR1]], s32[] %[[_:.*]])
- CHECK-SAME: dynamic_slice_sizes={1024}
- CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
- CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
- )"));
-}
-
-TEST_F(
- AllReduceSplitterTest,
- DoesNotMatchMatchBasicPatternIfDynamicSliceIsRootAndThereIsNoAllReduceWithSameReplicaGroups) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-
- EXPECT_EQ(AllReduceCount(*module), 1);
-}
-
-TEST_F(
- AllReduceSplitterFilecheckTest,
- MatchBasicPatternIfDynamicSliceIsNotRootAndThereExistsAllReduceWithSameReplicaGroups) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- zero = bf16[] constant(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
- broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
- ROOT _ = tuple(broadcast, first.ar)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
- TF_EXPECT_OK(FileCheck(module->ToString(), R"(
- CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
- CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
- CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
- CHECK: %[[AR0:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
- CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]}
- CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR0]], s32[] %[[_:.*]])
- CHECK-SAME: dynamic_slice_sizes={1024}
- CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
- CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
- CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
- CHECK-SAME: replica_groups={[[DESIRED_RGS]]}
- CHECK: ROOT
- CHECK-NOT: %[[AR1]]
- CHECK-SAME: %[[EXISTING_AR]]
- )"));
-}
-
-TEST_F(
- AllReduceSplitterTest,
- DoesNotMatchBasicPatternIfDynamicSliceIsNotRootAndThereIsNoAllReduceWithSameReplicaGroups) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- p.1 = bf16[2,4096,4096] parameter(1)
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
- broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
- add = bf16[2,4096,4096] add(p,p.1)
- ROOT _ = tuple(broadcast, add)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
- EXPECT_EQ(AllReduceCount(*module), 1);
-}
-
-TEST_F(AllReduceSplitterTest,
- DoesNotMatchBasicPatternIfDynamicSliceIsFullySharded) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(512)
- offset = s32[] multiply(reshape, slice_size)
- ROOT _ = bf16[512] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={512}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
- EXPECT_EQ(AllReduceCount(*module), 2);
-}
-
-TEST_F(AllReduceSplitterTest,
- DoesNotMatchBasicPatternIfItIsNotCompiledWithSPMDPartitioning) {
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
- HloModuleConfig config =
- GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/8);
- config.set_use_spmd_partitioning(false);
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string, config));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
- EXPECT_THAT(AllReduceCount(*module), 2);
-}
-
-TEST_F(AllReduceSplitterTest,
- DoesNotMatchBasicPatternIfUseGlobalDeviceIdsIsFalse) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, channel_id=1
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, channel_id=2
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-
- EXPECT_EQ(AllReduceCount(*module), 2);
-}
-
-TEST_F(AllReduceSplitterTest,
- DoesNotMatchBasicPatternIfIsNotCrossAllPartitionsAllReduce) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-
- EXPECT_EQ(AllReduceCount(*module), 2);
-}
-
-TEST_F(
- AllReduceSplitterFilecheckTest,
- PipelineMatchesBasicPatternWithDynamicSliceAsRootAndRewritesToReduceScatter) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- zero = bf16[] constant(0)
- reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- HloPassPipeline pipeline("all-reduce-splitter-rewrite");
- pipeline.AddPass<AllReduceSplitter>();
- pipeline.AddPass<ReduceScatterCreator>();
- EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
- TF_EXPECT_OK(FileCheck(module->ToString(), R"(
- CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
- CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
- CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]}
- CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
- CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
- CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
- CHECK-SAME: replica_groups={[[DESIRED_RGS]]}
- CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
- CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
- )"));
-}
-
-TEST_F(
- AllReduceSplitterFilecheckTest,
- PipelineMatchesBasicPatternWithDynamicSliceNotAsRootAndRewritesToReduceScatter) { // NOLINT
- absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
- a = bf16[] parameter(0)
- b = bf16[] parameter(1)
- ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
- p = bf16[2,4096,4096] parameter(0)
- zero = bf16[] constant(0)
- first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
- all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
- table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
- pid = u32[] partition-id()
- id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
- reshape = s32[] reshape(id)
- slice_size = s32[] constant(1024)
- offset = s32[] multiply(reshape, slice_size)
- dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
- broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
- ROOT _ = tuple(broadcast, first.ar)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
- HloPassPipeline pipeline("all-reduce-splitter-rewrite");
- pipeline.AddPass<AllReduceSplitter>();
- pipeline.AddPass<ReduceScatterCreator>();
- EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
- TF_EXPECT_OK(FileCheck(module->ToString(), R"(
- CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
- CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
- CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
- CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
- CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
- CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
- CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
- CHECK: ROOT
- CHECK-NOT: %[[AR1]]
- CHECK-SAME: %[[EXISTING_AR]]
- )"));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/call_graph.cc b/third_party/xla/xla/service/call_graph.cc
index ea16ca0..80515e1 100644
--- a/third_party/xla/xla/service/call_graph.cc
+++ b/third_party/xla/xla/service/call_graph.cc
@@ -214,8 +214,8 @@
} else if (a == b) {
return a;
} else {
- // Contexts are different and neither is kNone, ie one is kSequential and
- // the other is kParallel.
+ // Contexts are different and neither is kNone, i.e. one is kControlFlow and
+ // the other is kEmbedded.
return CallContext::kBoth;
}
}
diff --git a/third_party/xla/xla/service/call_inliner.cc b/third_party/xla/xla/service/call_inliner.cc
index 579de41..0605fbd 100644
--- a/third_party/xla/xla/service/call_inliner.cc
+++ b/third_party/xla/xla/service/call_inliner.cc
@@ -139,6 +139,16 @@
CallInliner::Inline(HloInstruction* call) {
TF_RET_CHECK(call->opcode() == HloOpcode::kCall)
<< "Instruction was not a call op: " << call->opcode();
+ if (call->is_composite()) {
+    // Remove the composite frontend attributes before inlining; otherwise they
+    // would appear on every inlined instruction.
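+    // A composite call carries attributes such as composite.name="foo.bar",
+    // composite.version="1", and composite.attributes={...}, as in the
+    // InlineCompositeCall test.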
+ FrontendAttributes frontend_attributes = call->frontend_attributes();
+ frontend_attributes.mutable_map()->erase("composite.name");
+ frontend_attributes.mutable_map()->erase("composite.attributes");
+ frontend_attributes.mutable_map()->erase("composite.version");
+ call->set_frontend_attributes(frontend_attributes);
+ }
+
const auto& callees = call->called_computations();
TF_RET_CHECK(callees.size() == 1);
HloComputation* callee = callees[0];
diff --git a/third_party/xla/xla/service/call_inliner_test.cc b/third_party/xla/xla/service/call_inliner_test.cc
index da73fe6..ad6ee73 100644
--- a/third_party/xla/xla/service/call_inliner_test.cc
+++ b/third_party/xla/xla/service/call_inliner_test.cc
@@ -15,25 +15,25 @@
#include "xla/service/call_inliner.h"
+#include <cstdint>
#include <memory>
-#include <optional>
#include <string>
-#include <utility>
-#include <vector>
+#include "absl/log/log.h"
+#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_matchers.h"
-#include "xla/layout_util.h"
#include "xla/literal.h"
-#include "xla/service/hlo_pass_fix.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/types.h"
#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
namespace op = xla::testing::opcode_matchers;
@@ -346,5 +346,36 @@
op::Constant(LiteralUtil::CreateR0<uint32_t>(2))));
}
}
+
+TEST_F(CallInlinerTest, InlineCompositeCall) {
+ const absl::string_view hlo_string = R"(
+ HloModule composite
+
+ %add (lhs: f32[]) -> f32[] {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] constant(2)
+ ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+ }
+
+ ENTRY %main () -> f32[] {
+ %lhs = f32[] constant(42)
+ ROOT %call = f32[] call(f32[] %lhs), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+ })";
+
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ CallInliner call_inliner(/*single_call_site=*/true);
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ ASSERT_TRUE(mutated);
+
+ ASSERT_EQ(module->entry_computation()->instruction_count(), 3);
+ auto inst = module->entry_computation()->instructions().begin();
+ EXPECT_THAT(*inst, op::Constant());
+ ++inst;
+ EXPECT_THAT(*inst, op::Constant());
+ ++inst;
+ EXPECT_THAT(*inst, op::Add());
+ EXPECT_TRUE((*inst)->frontend_attributes().map().empty());
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/service/change_op_data_type.cc b/third_party/xla/xla/service/change_op_data_type.cc
index 7c907d8..3c7875a 100644
--- a/third_party/xla/xla/service/change_op_data_type.cc
+++ b/third_party/xla/xla/service/change_op_data_type.cc
@@ -19,8 +19,7 @@
#include "xla/service/hlo_creation_utils.h"
#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
#endif // INTEL_MKL && ENABLE_ONEDNN_V3
namespace xla {
@@ -65,11 +64,11 @@
}
#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
if (instr->opcode() == HloOpcode::kDot &&
- cpu::OneDnnMatMulRewriter::ShouldRewrite(instr)) {
+ cpu::OneDnnContractionRewriter::ShouldRewriteDot(instr, true)) {
continue;
}
if (instr->opcode() == HloOpcode::kConvolution &&
- cpu::OneDnnConvolutionRewriter::ShouldRewrite(instr)) {
+ cpu::OneDnnContractionRewriter::ShouldRewriteConv(instr)) {
continue;
}
#endif // INTEL_MKL && ENABLE_ONEDNN_V3
diff --git a/third_party/xla/xla/service/collective_permute_decomposer_test.cc b/third_party/xla/xla/service/collective_permute_decomposer_test.cc
index 8de403e..eac5ab0 100644
--- a/third_party/xla/xla/service/collective_permute_decomposer_test.cc
+++ b/third_party/xla/xla/service/collective_permute_decomposer_test.cc
@@ -106,7 +106,7 @@
EXPECT_THAT(
recv->ToString(),
HasSubstr(
- "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+ "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
check_metadata(recv);
check_not_pipelined(recv);
HloInstruction* recv_done = FindInstruction(module.get(), "recv-done");
@@ -118,7 +118,7 @@
EXPECT_THAT(
send->ToString(),
HasSubstr(
- "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+ "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
check_metadata(send);
check_not_pipelined(send);
HloInstruction* send_done = FindInstruction(module.get(), "send-done");
@@ -212,7 +212,7 @@
EXPECT_THAT(
recv->ToString(),
HasSubstr(
- "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+ "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
EXPECT_THAT(recv->ToString(), HasSubstr("_xla_other_attribute=\"xyz\""));
HloInstruction* recv_done = FindInstruction(module.get(), "recv-done");
@@ -224,7 +224,7 @@
EXPECT_THAT(
send->ToString(),
HasSubstr(
- "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+ "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
EXPECT_THAT(send->ToString(), HasSubstr("_xla_other_attribute=\"xyz\""));
HloInstruction* send_done = FindInstruction(module.get(), "send-done");
@@ -290,18 +290,18 @@
HloInstruction* recv = FindInstruction(module.get(), "recv");
EXPECT_EQ(recv->channel_id().value(), 1);
EXPECT_THAT(recv->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}"));
EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
HloInstruction* send = FindInstruction(module.get(), "send");
EXPECT_THAT(send->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}"));
EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
HloInstruction* recv1 = FindInstruction(module.get(), "recv.1");
EXPECT_EQ(recv1->channel_id().value(), 2);
EXPECT_THAT(
recv1->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}"));
EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
HloInstruction* recv_done1 = FindInstruction(module.get(), "recv-done.1");
EXPECT_THAT(recv_done1->ToString(),
@@ -309,7 +309,7 @@
HloInstruction* send1 = FindInstruction(module.get(), "send.1");
EXPECT_THAT(
send1->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}"));
EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
HloInstruction* send_done1 = FindInstruction(module.get(), "send-done.1");
EXPECT_THAT(send_done1->ToString(),
@@ -489,22 +489,22 @@
EXPECT_EQ(recv->channel_id().value(), 1);
EXPECT_THAT(
recv->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}"));
EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
HloInstruction* send = FindInstruction(module.get(), "send");
EXPECT_THAT(
send->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}"));
EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
HloInstruction* recv1 = FindInstruction(module.get(), "recv.1");
EXPECT_EQ(recv1->channel_id().value(), 2);
EXPECT_THAT(recv1->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}"));
EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
HloInstruction* send1 = FindInstruction(module.get(), "send.1");
EXPECT_THAT(send1->ToString(),
- HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\""));
+ HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}"));
EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
}
diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc
index 859b6c9..72eeb36 100644
--- a/third_party/xla/xla/service/collective_pipeliner.cc
+++ b/third_party/xla/xla/service/collective_pipeliner.cc
@@ -445,7 +445,6 @@
ops.end());
formatting_set.insert(source_ops.begin(), source_ops.end());
std::vector<HloInstruction*> to_return;
- absl::flat_hash_set<HloInstruction*> already_inserted;
for (const HloInstruction* op : ops) {
for (HloInstruction* operand : op->operands()) {
if (!formatting_set.count(operand)) {
@@ -1380,7 +1379,6 @@
// to still visit before visiting the HLO itself.
std::vector<std::pair<const HloInstruction*, int>> stack(
1, std::make_pair(instr, 0));
- absl::flat_hash_set<const HloInstruction*> visited;
while (!stack.empty()) {
auto& current = stack.back();
invariant_cache[std::get<0>(current)] = true;
@@ -2775,8 +2773,6 @@
instruction, false, CollectivePipeliner::PipeliningDirection::kBackward,
loop_analysis));
}
- absl::flat_hash_map<const HloInstruction*, HloInstruction*>
- loop_cond_replacements;
auto cond_builder =
HloComputation::Builder(while_loop->while_condition()->name());
HloInstruction* new_cond_param =
diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc
index 5492cbc..a924a9a 100644
--- a/third_party/xla/xla/service/collective_pipeliner_test.cc
+++ b/third_party/xla/xla/service/collective_pipeliner_test.cc
@@ -239,14 +239,14 @@
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
// CHECK: HloModule
// CHECK: %while_body
- // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}}"
+ // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}}
// CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
// CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]])
// CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
// CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}})
// CHECK: }
// CHECK: ENTRY %entry
- // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}"
+ // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}
// CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
// CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]])
// CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
@@ -315,14 +315,14 @@
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
// CHECK: HloModule
// CHECK: %while_body
- // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}}"
+ // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}}
// CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
// CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]])
// CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
// CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}})
// CHECK: }
// CHECK: ENTRY %entry
- // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}"
+ // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}
// CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
// CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]])
// CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
@@ -1507,13 +1507,13 @@
XLA_VLOG_LINES(1, module->ToString());
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
// CHECK: %while_body
- // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}"}
+ // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}}
// CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}})
// CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
// CHECK: ENTRY %entry
// CHECK: %[[while:.+]] = {{.+}} while({{.+}})
// CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1
- // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}"
+ // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}
// CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}})
// CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
// CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1
@@ -1586,13 +1586,13 @@
XLA_VLOG_LINES(1, module->ToString());
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
// CHECK: %while_body
- // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}"}
+ // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}}
// CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}})
// CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
// CHECK: ENTRY %entry
// CHECK: %[[while:.+]] = {{.+}} while({{.+}})
// CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1
- // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}"
+ // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}
// CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}})
// CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
// CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1
@@ -2426,7 +2426,7 @@
EXPECT_EQ(recv1->channel_id(), send1->channel_id());
- const char* kSourceTarget = "_xla_send_recv_source_target_pairs=\"{{3,0}}\"";
+ const char* kSourceTarget = "_xla_send_recv_source_target_pairs={{3,0}}";
const char* kPeeledAttr = "_xla_other_attr=\"1\"";
const char* kRotatedAttr = "_xla_other_attr=\"2\"";
EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kSourceTarget));
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD
index a97736d..3ed45fe 100644
--- a/third_party/xla/xla/service/cpu/BUILD
+++ b/third_party/xla/xla/service/cpu/BUILD
@@ -222,8 +222,7 @@
":ir_emission_utils",
":ir_emitter",
":ir_emitter2",
- ":onednn_convolution_rewriter",
- ":onednn_matmul_rewriter",
+ ":onednn_contraction_rewriter",
":onednn_ops_rewriter",
":parallel_task_assignment",
":simple_orc_jit",
@@ -589,6 +588,7 @@
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
+ "@llvm-project//llvm:Core",
"@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:OrcShared",
@@ -664,6 +664,28 @@
)
xla_cc_test(
+ name = "ir_emitter_test",
+ srcs = ["ir_emitter_test.cc"],
+ deps = [
+ ":ir_emitter",
+ ":ir_function",
+ ":target_machine_features_fake",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:buffer_assignment",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_ordering",
+ "//xla/service:hlo_parser",
+ "//xla/service:logical_buffer",
+ "//xla/tests:hlo_test_base",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+xla_cc_test(
name = "ir_emitter2_test",
srcs = ["ir_emitter2_test.cc"],
deps = [
@@ -1850,17 +1872,19 @@
)
cc_library(
- name = "onednn_matmul_rewriter",
- srcs = ["onednn_matmul_rewriter.cc"],
+ name = "onednn_contraction_rewriter",
+ srcs = ["onednn_contraction_rewriter.cc"],
hdrs = [
+ "onednn_contraction_rewriter.h",
+ "onednn_convolution.h",
"onednn_matmul.h",
- "onednn_matmul_rewriter.h",
"//xla/tsl/util:onednn_util_hdrs",
],
copts = tsl_copts(),
deps = [
":backend_config_proto_cc",
":onednn_config_proto_cc",
+ ":onednn_convolution",
":onednn_matmul",
":onednn_memory_util",
":onednn_pattern_utils",
@@ -1876,6 +1900,7 @@
"//xla/service:hlo_pass",
"//xla/service:pattern_matcher",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/status:statusor",
"@eigen_archive//:eigen3",
"@local_tsl//tsl/platform:blocking_counter",
"@local_tsl//tsl/platform:env",
@@ -1909,43 +1934,12 @@
)
cc_library(
- name = "onednn_convolution_rewriter",
- srcs = ["onednn_convolution_rewriter.cc"],
- hdrs = ["onednn_convolution_rewriter.h"],
- copts = tsl_copts(),
- deps = [
- ":backend_config_proto_cc",
- ":onednn_config_proto_cc",
- ":onednn_convolution",
- ":onednn_memory_util",
- ":onednn_util",
- "//xla:executable_run_options",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/evaluator:hlo_evaluator",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/status:statusor",
- "@eigen_archive//:eigen3",
- "@local_tsl//tsl/platform:blocking_counter",
- "@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:platform_port",
- ] + mkl_deps(),
-)
-
-cc_library(
name = "cpu_float_support",
srcs = ["cpu_float_support.cc"],
hdrs = ["cpu_float_support.h"],
copts = tsl_copts(),
deps = [
- ":onednn_convolution_rewriter",
- ":onednn_matmul_rewriter",
+ ":onednn_contraction_rewriter",
"//xla/service:float_support",
],
)
diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc
index 53be150..fc292c9 100644
--- a/third_party/xla/xla/service/cpu/cpu_compiler.cc
+++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc
@@ -205,8 +205,7 @@
#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
#include "xla/service/cpu/cpu_float_support.h"
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
#include "xla/service/cpu/onednn_ops_rewriter.h"
#include "xla/service/simplify_fp_conversions.h"
#endif
@@ -759,11 +758,10 @@
if (debug_options.xla_allow_excess_precision()) {
pipeline.AddPass<SimplifyFPConversions>();
}
- pipeline.AddPass<OneDnnConvolutionRewriter>();
- pipeline.AddPass<OneDnnMatMulRewriter>(max_parallelism,
- compile_options.thread_pool);
+ pipeline.AddPass<OneDnnContractionRewriter>(max_parallelism,
+ compile_options.thread_pool);
// Run SimplifyFPConversions pass again to remove redundant Convert ops
- // that may exist as a result of running OneDnnMatMulRewriter pass.
+ // that may exist as a result of running OneDnnContractionRewriter pass.
if (debug_options.xla_allow_excess_precision()) {
pipeline.AddPass<SimplifyFPConversions>();
}
@@ -1273,11 +1271,17 @@
cantFail((*jit)->AddModule(llvm::orc::ThreadSafeModule(
std::move(llvm_module), std::move(llvm_context))));
+ auto mangle = [&](std::string_view name) {
+ llvm::SmallVector<char, 40> mangled;
+ llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout());
+ return std::string(mangled.begin(), mangled.end());
+ };
+
// TODO(ezhulenev): We should be able to make it lazy on-demand, but today
// we capture obj_files by reference and it leads to asan errors. Figure out
// lifetime issues and move compilation to Thunk initialization stage.
for (const auto& kernel : ir_emitter2.kernels()) {
- if (auto sym = (*jit)->FindCompiledSymbol(kernel.name); !sym) {
+ if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) {
return Internal("Failed to find compiled symbol for kernel %s",
kernel.name);
}
@@ -1285,7 +1289,7 @@
// Compile auxiliary comparator functions used by sort thunks.
for (const auto& comparator : ir_emitter2.comparators()) {
- if (auto sym = (*jit)->FindCompiledSymbol(comparator.name); !sym) {
+ if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) {
return Internal("Failed to find compiled symbol for comparator %s",
comparator.name);
}
@@ -1785,16 +1789,22 @@
TF_ASSIGN_OR_RETURN(ThunkSequence thunks,
thunk_emitter.EmitEntryComputation(*module));
+ auto mangle = [&](std::string_view name) {
+ llvm::SmallVector<char, 40> mangled;
+ llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout());
+ return std::string(mangled.begin(), mangled.end());
+ };
+
// Lookup all kernel functions by name in the loaded object file.
for (const auto& kernel : ir_emitter2.kernels()) {
- if (auto sym = (*jit)->FindCompiledSymbol(kernel.name); !sym) {
+ if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) {
return Internal("Failed to find compiled symbol for kernel %s",
kernel.name);
}
}
for (const auto& comparator : ir_emitter2.comparators()) {
- if (auto sym = (*jit)->FindCompiledSymbol(comparator.name); !sym) {
+ if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) {
return Internal("Failed to find compiled symbol for comparator %s",
comparator.name);
}
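A minimal sketch of the symbol-name mangling applied by the two mangle lambdas above (and by the FunctionRegistry::Mangle helper added to cpu_executable.cc below). The helper name here is illustrative and only the LLVM Mangler/DataLayout APIs already included by these files are assumed:

#include <string>
#include <string_view>

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Mangler.h"

// On targets whose DataLayout declares a global prefix (e.g. '_' on Mach-O),
// the ORC JIT publishes symbols under the prefixed name, so lookups via
// FindCompiledSymbol must apply the same mangling to the plain IR name.
static std::string MangleForLookup(std::string_view name,
                                   const llvm::DataLayout& data_layout) {
  llvm::SmallVector<char, 40> mangled;
  llvm::Mangler::getNameWithPrefix(mangled, name, data_layout);
  return std::string(mangled.begin(), mangled.end());
}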
diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc
index 816ad13..0bec735 100644
--- a/third_party/xla/xla/service/cpu/cpu_executable.cc
+++ b/third_party/xla/xla/service/cpu/cpu_executable.cc
@@ -37,7 +37,9 @@
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
+#include "llvm/IR/Mangler.h"
#include "llvm/Support/Error.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/hlo_computation.h"
@@ -80,12 +82,18 @@
FunctionRegistry::FunctionRegistry(SimpleOrcJIT* jit) : jit_(jit) {}
+std::string FunctionRegistry::Mangle(std::string_view name) {
+ llvm::SmallVector<char, 40> mangled;
+ llvm::Mangler::getNameWithPrefix(mangled, name, jit_->data_layout());
+ return std::string(mangled.begin(), mangled.end());
+}
+
absl::StatusOr<FunctionRegistry::Kernel> FunctionRegistry::FindKernel(
std::string_view name) {
VLOG(3) << "Find host kernel with a name " << name;
llvm::Expected<llvm::orc::ExecutorSymbolDef> sym =
- jit_->FindCompiledSymbol(std::string(name));
+ jit_->FindCompiledSymbol(Mangle(name));
if (!sym) {
return absl::InvalidArgumentError(
absl::StrCat("Can't resolve host kernel with a name ", name,
@@ -99,7 +107,7 @@
VLOG(3) << "Find comparator with a name " << name;
llvm::Expected<llvm::orc::ExecutorSymbolDef> sym =
- jit_->FindCompiledSymbol(std::string(name));
+ jit_->FindCompiledSymbol(Mangle(name));
if (!sym) {
return absl::InvalidArgumentError(
absl::StrCat("Can't resolve comparator with a name ", name,
@@ -175,6 +183,7 @@
std::move(hlo_profile_index_map), std::move(assignment)));
executable->jit_ = std::move(jit);
+ executable->jit_->DoneCompiling();
executable->function_registry_ = FunctionRegistry(executable->jit_.get());
TF_ASSIGN_OR_RETURN(executable->thunks_,
diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h
index b129674..8c8883d 100644
--- a/third_party/xla/xla/service/cpu/cpu_executable.h
+++ b/third_party/xla/xla/service/cpu/cpu_executable.h
@@ -151,6 +151,8 @@
absl::StatusOr<Comparator> FindComparator(std::string_view name) final;
private:
+ std::string Mangle(std::string_view name);
+
SimpleOrcJIT* jit_;
};
diff --git a/third_party/xla/xla/service/cpu/cpu_float_support.cc b/third_party/xla/xla/service/cpu/cpu_float_support.cc
index 6914e65..c590716 100644
--- a/third_party/xla/xla/service/cpu/cpu_float_support.cc
+++ b/third_party/xla/xla/service/cpu/cpu_float_support.cc
@@ -17,8 +17,7 @@
#include "xla/service/cpu/cpu_float_support.h"
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
namespace xla {
namespace cpu {
@@ -28,10 +27,10 @@
// oneDNN rewritable ops
case HloOpcode::kDot:
return LowPrecisionType() == BF16 &&
- OneDnnMatMulRewriter::ShouldRewrite(&hlo);
+ OneDnnContractionRewriter::ShouldRewriteDot(&hlo, true);
case HloOpcode::kConvolution:
return LowPrecisionType() == BF16 &&
- OneDnnConvolutionRewriter::ShouldRewrite(&hlo);
+ OneDnnContractionRewriter::ShouldRewriteConv(&hlo);
// Collective ops.
case HloOpcode::kAllGather:
case HloOpcode::kAllReduce:
diff --git a/third_party/xla/xla/service/cpu/executable.proto b/third_party/xla/xla/service/cpu/executable.proto
index bca8a2c..d222660 100644
--- a/third_party/xla/xla/service/cpu/executable.proto
+++ b/third_party/xla/xla/service/cpu/executable.proto
@@ -17,7 +17,6 @@
package xla.cpu;
-import "xla/service/cpu/xla_framework.proto";
import "xla/service/hlo.proto";
import "xla/xla.proto";
diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc
index d8bae80..1ee2f69 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter.cc
@@ -451,6 +451,18 @@
}
}
+void IrEmitter::AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load) const {
+ AttachInvariantLoadMetadataForLoad(load, hlo_module_config_);
+}
+
+/*static*/ void IrEmitter::AttachInvariantLoadMetadataForLoad(
+ llvm::LoadInst* load, const HloModuleConfig& config) {
+ if (config.debug_options().xla_llvm_enable_invariant_load_metadata()) {
+ load->setMetadata(llvm::LLVMContext::MD_invariant_load,
+ llvm::MDNode::get(load->getContext(), /*MDs=*/{}));
+ }
+}
+
absl::Status IrEmitter::HandleGetTupleElement(
HloInstruction* get_tuple_element) {
// A tuple is an array of pointers, one for each operand. Each pointer points
@@ -4073,12 +4085,8 @@
GetBufferTableArgument(), b()->getPtrTy(), slice.index(), b());
llvm::LoadInst* tempbuf_address_base =
Load(b()->getPtrTy(), tempbuf_address_ptr);
- if (hlo_module_config_.debug_options()
- .xla_llvm_enable_invariant_load_metadata()) {
- tempbuf_address_base->setMetadata(
- llvm::LLVMContext::MD_invariant_load,
- llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
- }
+
+ AttachInvariantLoadMetadataForLoad(tempbuf_address_base);
AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h
index 0e6b1a3..d2c94a9 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter.h
+++ b/third_party/xla/xla/service/cpu/ir_emitter.h
@@ -760,8 +760,14 @@
// result with the dereferenceable bytes required by the shape / buffer size.
void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
const Shape& shape);
- void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
- int64_t buffer_size);
+ static void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ int64_t buffer_size);
+
+ // Given a load instruction, annotate the load's result with the invariant
+ // load metadata.
+ void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load) const;
+ static void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load,
+ const HloModuleConfig& config);
// Calculate the alignment of a buffer allocated for a given shape.
int MinimumAlignmentForShape(const Shape& shape);
diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc
index eeac5bc..2e64bea 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter2.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc
@@ -40,6 +40,7 @@
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -47,7 +48,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
-#include "llvm/Support/Casting.h"
+#include "llvm/Support/CodeGen.h"
#include "xla/cpu_function_runtime.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -564,6 +565,10 @@
/*is_top_level_computation=*/true, schedule,
/*allow_reassociation=*/false));
+ // Generate unwind information so that GDB can crawl through the stack frames
+ // created by the JIT compiled code.
+ comparator_function->setUWTableKind(llvm::UWTableKind::Default);
+
return comparators_.emplace_back(
ComparatorInfo{comparator_function->getName().str()});
}
@@ -710,6 +715,15 @@
// emit metadata to allow LLVM to use that information for optimization.
llvm_ir::SetAlignmentMetadataForLoad(data, cpu_function_runtime::MinAlign());
+ // All buffer pointers passed to host kernels are expected to be
+ // dereferenceable.
+ IrEmitter::AttachDereferenceableMetadataForLoad(data, ByteSizeOf(shape));
+
+ // All buffer pointers passed to host kernels are expected to be invariant
+ // over the whole program. Note that the metadata is attached only to the
+ // loads of buffer pointers, not to loads from the buffers themselves.
+ AttachInvariantLoadMetadataForLoad(data);
+
return llvm_ir::IrArray(data, llvm_ir::ShapeToIrType(shape, module_), shape);
}
@@ -791,11 +805,15 @@
}
}
- // Create a kernel function with HostKernel API.
- llvm::Function* function = llvm::dyn_cast<llvm::Function>(
- module_->getOrInsertFunction(name, KernelFunctionTy(ctx)).getCallee());
+ // Create a kernel function with HostKernel API. We use external linkage
+ // because we'll be resolving this function from the XLA runtime.
+ llvm::Function* function = llvm::Function::Create(
+ KernelFunctionTy(ctx), llvm::GlobalValue::ExternalLinkage, name, module_);
function->setCallingConv(llvm::CallingConv::C);
- function->setDoesNotThrow();
+
+ // Generate unwind information so that GDB can crawl through the stack frames
+ // created by the JIT compiled code.
+ function->setUWTableKind(llvm::UWTableKind::Default);
// Set prefer-vector-width attribute to allow LLVM to use wider vector
// registers (by default LLVM uses at most 256-bit registers).
@@ -1027,4 +1045,17 @@
return se::ThreadDim();
}
+// This is a convenience function taken from IrEmitter; it uses the module_
+// class field. If more functions end up using module_, we should consider
+// refactoring (as we did for compute_function_ and builder_).
+int64_t IrEmitter2::ByteSizeOf(const Shape& shape) const {
+ return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
+}
+
+void IrEmitter2::AttachInvariantLoadMetadataForLoad(
+ llvm::LoadInst* instr) const {
+ nested_ir_emitter_->AttachInvariantLoadMetadataForLoad(instr,
+ hlo_module_.config());
+}
+
} // namespace xla::cpu
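A small self-contained sketch of the metadata annotations that ir_emitter2.cc now attaches to kernel-argument pointer loads, which the updated FileCheck expectations below look for (!invariant.load and !dereferenceable). The function name and the byte count are illustrative, not part of the change:

#include <cstdint>

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Type.h"

// Marks a load of a buffer pointer as invariant over the whole program and
// records how many bytes the loaded pointer is dereferenceable for.
static void AnnotateArgumentPointerLoad(llvm::LoadInst* load,
                                        int64_t dereferenceable_bytes) {
  llvm::LLVMContext& ctx = load->getContext();
  load->setMetadata(llvm::LLVMContext::MD_invariant_load,
                    llvm::MDNode::get(ctx, /*MDs=*/{}));
  llvm::Metadata* bytes = llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
      llvm::Type::getInt64Ty(ctx), dereferenceable_bytes));
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(ctx, bytes));
}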
diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h
index c998840..a205e91 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter2.h
+++ b/third_party/xla/xla/service/cpu/ir_emitter2.h
@@ -28,6 +28,7 @@
#include "absl/types/span.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -240,6 +241,13 @@
bool fast_min_max() const;
+ // Returns the size, in bytes, of a value of the given shape.
+ int64_t ByteSizeOf(const Shape& shape) const;
+
+ // Given a load instruction, annotate the load's result with the invariant
+ // load metadata.
+ void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* instr) const;
+
const HloModule& hlo_module_;
llvm::Module* module_;
diff --git a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc
index 539facd..a70241b 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc
@@ -85,45 +85,45 @@
ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"(
CHECK: define ptr @test(ptr %0) #0 {
- CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 0
- CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 0
- CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 1
- CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 2
+ CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 0
+ CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 0
+ CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 1
+ CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 2
CHECK: load i64
CHECK: load i64
CHECK: load i64
- CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 1
- CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 0
- CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 1
- CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 2
+ CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 1
+ CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 0
+ CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 1
+ CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 2
CHECK: load i64
CHECK: load i64
CHECK: load i64
- CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+ CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 0, i32 0
- CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT:.+]]
+ CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.*]], !align ![[ALIGNMENT:.+]]
- CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+ CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 1, i32 0
- CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]]
+ CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
- CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+ CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 2, i32 0
- CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]]
+ CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
- CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+ CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 3, i32 0
- CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]]
+ CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
CHECK-NEXT: %[[PTR0:.+]] = getelementptr inbounds float, ptr %[[ARG0]]
CHECK: load float, ptr %[[PTR0]], align 4,
- CHECK-SAME: !invariant.load ![[SCOPE0:.+]],
+ CHECK-SAME: !invariant.load ![[SCOPE0]],
CHECK-SAME: !noalias ![[SCOPE1:.+]]
CHECK-NEXT: %[[PTR1:.+]] = getelementptr inbounds float, ptr %[[ARG1]]
@@ -142,6 +142,8 @@
CHECK: ret ptr null
CHECK: }
+ #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
+ CHECK-DAG: ![[DEREF_BYTES]] = !{i64 32}
CHECK-DAG: ![[ALIGNMENT]] = !{i64 16}
CHECK-DAG: ![[SCOPE0]] = !{}
CHECK-DAG: ![[SCOPE1]] = !{![[RES0:.+]], ![[RES1:.+]]}
diff --git a/third_party/xla/xla/service/cpu/ir_emitter_test.cc b/third_party/xla/xla/service/cpu/ir_emitter_test.cc
new file mode 100644
index 0000000..7102d20
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/ir_emitter_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/cpu/ir_emitter.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/Casting.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/cpu/ir_function.h"
+#include "xla/service/cpu/target_machine_features_fake.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_ordering.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/logical_buffer.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla::cpu {
+namespace {
+
+using IrEmitterTest = HloTestBase;
+
+static std::pair<llvm::Function*, llvm::BasicBlock*> CreateFunction(
+ llvm::LLVMContext& context, llvm::Module* module, llvm::IRBuilder<>* b) {
+ llvm::PointerType* ptrtype = llvm::PointerType::getUnqual(context);
+ llvm::FunctionType* ftype = llvm::FunctionType::get(ptrtype, ptrtype, false);
+
+ llvm::Function* function = llvm::dyn_cast<llvm::Function>(
+ module->getOrInsertFunction("func2", ftype).getCallee());
+
+ llvm::BasicBlock* return_block =
+ llvm::BasicBlock::Create(context, "", function);
+ b->SetInsertPoint(return_block);
+ [[maybe_unused]] llvm::ReturnInst* ret = b->CreateRet(
+ llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(context)));
+
+ return std::make_pair(function, return_block);
+}
+
+TEST_F(IrEmitterTest, ComputeFuncStack) {
+ llvm::LLVMContext context;
+ auto module = std::make_unique<llvm::Module>("test", context);
+
+ const char* hlo_text = R"(
+ HloModule m
+ ENTRY main {
+ ROOT %zero = f32[] constant(0)
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text));
+ const HloInstruction* zero = FindInstruction(hlo.get(), "zero");
+ ASSERT_NE(zero, nullptr);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<BufferAssignment> buffer_assignment,
+ BufferAssigner::Run(
+ hlo.get(), std::make_unique<DependencyHloOrdering>(hlo.get()),
+ backend().compiler()->BufferSizeBytesFunction(),
+ [](LogicalBuffer::Color) { return /*alignment=*/1; }));
+
+ TargetMachineFeaturesWithFakeAlignmentLogic target_machine(
+ [](int64_t size) { return 1; });
+
+ IrEmitter ir_emitter(nullptr, *hlo, *buffer_assignment, module.get(), {}, {},
+ {}, &target_machine, false);
+
+ llvm::IRBuilder<>* b = ir_emitter.b();
+ ASSERT_NE(b, nullptr);
+
+ const std::pair<llvm::Function*, llvm::BasicBlock*> fb =
+ CreateFunction(context, module.get(), b);
+
+ llvm::Function* function = fb.first;
+ llvm::BasicBlock* return_block = fb.second;
+
+ ASSERT_NE(function, nullptr);
+ ASSERT_NE(return_block, nullptr);
+
+ const auto funcname = "func1";
+ const auto linkagetype = llvm::GlobalValue::LinkageTypes::ExternalLinkage;
+ const HloModuleConfig module_config;
+ ir_emitter.PushComputeFunction(funcname, linkagetype, module_config,
+ module.get(), 0);
+ ASSERT_EQ(ir_emitter.compute_function()->function()->getName().str(),
+ funcname);
+
+ ir_emitter.PushComputeFunction(b, module.get(), 0, function, nullptr,
+ return_block);
+ ASSERT_EQ(ir_emitter.compute_function()->function(), function);
+
+ ir_emitter.PopComputeFunction();
+ ASSERT_EQ(ir_emitter.compute_function()->function()->getName().str(),
+ funcname);
+
+ ir_emitter.PopComputeFunction();
+}
+
+} // namespace
+} // namespace xla::cpu
diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc
new file mode 100644
index 0000000..19122b3
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc
@@ -0,0 +1,1258 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
+
+#define EIGEN_USE_THREADS
+
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
+
+#include "xla/executable_run_options.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/cpu/backend_config.pb.h"
+#include "xla/service/cpu/onednn_config.pb.h"
+#include "xla/service/cpu/onednn_convolution.h"
+#include "xla/service/cpu/onednn_matmul.h"
+#include "xla/service/cpu/onednn_memory_util.h"
+#include "xla/service/cpu/onednn_pattern_utils.h"
+#include "xla/service/cpu/onednn_util.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/status_macros.h"
+#include "xla/tsl/util/onednn_threadpool.h"
+#include "tsl/platform/logging.h" // IWYU pragma: keep
+
+namespace xla {
+namespace cpu {
+
+namespace {
+namespace m = match;
+namespace pu = ::xla::cpu::onednn_pattern_utils_internal;
+
+inline absl::Status ValidateDotDimensionNumbers(
+ const DotDimensionNumbers& dim_numbers) {
+ // Checks some invariants that do not hold in general, but DotDecomposer
+ // should have established for us.
+ TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
+ std::vector<int64_t> batch_dim_numbers(
+ dim_numbers.lhs_batch_dimensions_size());
+ absl::c_iota(batch_dim_numbers, 0);
+ TF_RET_CHECK(
+ absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
+ TF_RET_CHECK(
+ absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
+ return absl::OkStatus();
+}
+
+// Whether the element type of instr is compatible with oneDNN kernels.
+// TODO(intel-tf): Restrict compatible types based on instruction kind.
+inline bool CompatibleElementType(const HloInstruction* instr) {
+ PrimitiveType element_type = instr->shape().element_type();
+ return element_type == BF16 || element_type == F32 || element_type == F16;
+}
+
+inline bool IsRowMajor(const Shape& shape) {
+ return LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
+}
+
+template <typename Pattern>
+inline auto BitcastWithReshapeSemantics(HloInstruction** bitcast,
+ Pattern pattern) {
+ // TODO(intel-tf): Add a stronger condition that the Bitcast does not have
+ // transpose semantics. Some HLO passes replace Transpose with Bitcast. Here
+ // the layouts are checked to be row-major since the current pass runs after
+ // layout assignment and oneDNN matmul is enabled for row-major layouts.
+ auto is_reshape = [](const HloInstruction* instr) -> bool {
+ if (!instr) return false;
+ auto input_shape = instr->operand(0)->shape();
+ auto output_shape = instr->shape();
+ bool is_same_type = ShapeUtil::SameElementType(input_shape, output_shape);
+ bool has_equal_num_elems = ShapeUtil::ElementsIn(input_shape) ==
+ ShapeUtil::ElementsIn(output_shape);
+ bool has_rowmajor_layout =
+ IsRowMajor(input_shape) && IsRowMajor(output_shape);
+ return is_same_type && has_equal_num_elems && has_rowmajor_layout;
+ };
+ return m::Bitcast(bitcast, pattern).WithPredicate(is_reshape);
+}
+
+template <typename Pattern>
+auto ElementwiseSafeIntermediates(HloInstruction** instr,
+ HloInstruction** optional_bitcast,
+ Pattern pattern) {
+ return m::AnyOf<HloInstruction>(
+ m::Broadcast(instr, pattern.WithOneUser()),
+ m::Slice(instr, pattern.WithOneUser()),
+ m::Bitcast(instr, pattern.WithOneUser()),
+ m::Reshape(instr, pattern.WithOneUser()),
+ pu::SupportedConvert(instr, pattern.WithOneUser()),
+ pu::SupportedConvert(instr, BitcastWithReshapeSemantics(
+ optional_bitcast, pattern.WithOneUser())),
+ pattern);
+}
+
+inline auto OneDnnMatmulInstr(HloInstruction** instr) {
+ return m::CustomCall(instr, {"__onednn$matmul"});
+}
+
+inline auto ConvertBF16ToF32(HloInstruction** instr) {
+ return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16))
+ .WithElementType(PrimitiveType::F32);
+}
+
+inline auto BcastConstScalar(HloInstruction** instr, double value) {
+ return m::Broadcast(instr, m::ConstantScalar(value));
+}
+
+inline auto BcastConstScalar(double value) {
+ return BcastConstScalar(nullptr, value);
+}
+
+inline auto BcastConvertConstScalar(double value) {
+ return m::Broadcast(pu::OptionalConvert(m::ConstantScalar(value)));
+}
+
+inline bool IsBatchDot(const HloInstruction& instr) {
+ if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
+ return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
+ }
+ return false;
+}
+
+auto ConstScalarNear(double value) {
+ return m::ConstantScalar().WithPredicate(
+ [expected = value](const HloInstruction* instr) {
+ // Not a very robust floating-point comparison, but good enough for our
+ // purposes.
+ std::optional<double> actual =
+ static_cast<const HloConstantInstruction*>(instr)
+ ->literal()
+ .GetAsDouble({});
+ if (!actual.has_value()) return false;
+ double epsilon;
+ switch (instr->shape().element_type()) {
+ case F16:
+ epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
+ break;
+ case BF16:
+ epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
+ break;
+ case F32:
+ epsilon = 128 * std::numeric_limits<float>::epsilon();
+ break;
+ case F64:
+ epsilon = 128 * std::numeric_limits<double>::epsilon();
+ break;
+ default:
+ return false;
+ }
+ return abs(*actual - expected) < (abs(*actual + expected) * epsilon);
+ });
+}
+
+bool IsScalar(const HloInstruction* instr) {
+ return ShapeUtil::IsEffectiveScalar(instr->shape());
+}
+
+std::optional<float> GetConstantValueAsFloat32(const HloInstruction* inst) {
+ if (!IsScalar(inst)) {
+ return std::nullopt;
+ }
+ switch (inst->shape().element_type()) {
+ case F16:
+ return inst->literal().GetFirstElement<half>();
+ case BF16:
+ return inst->literal().GetFirstElement<bfloat16>();
+ case F32:
+ return inst->literal().GetFirstElement<float>();
+ default:
+ return std::nullopt;
+ }
+}
+
+inline auto BcastConstScalarNear(double value) {
+ return m::Broadcast(ConstScalarNear(value));
+}
+
+// The associativity and commutativity properties of multiply result in various
+// patterns for an equivalent computation. This function tries to capture most
+// of the variations for a computation a * b * c. For example, patterns could be
+// any of (a * b) * c or a * (b * c), along with the variations resulting from
+// commutative patterns.
+template <typename PatternA, typename PatternB, typename PatternC>
+inline auto MultiplyMultiplyAnyOrder(PatternA a, PatternB b, PatternC c) {
+ return m::AnyOf<HloInstruction>(
+ m::MultiplyAnyOrder(a, m::MultiplyAnyOrder(b, c)),
+ m::MultiplyAnyOrder(b, m::MultiplyAnyOrder(a, c)),
+ m::MultiplyAnyOrder(c, m::MultiplyAnyOrder(a, b)));
+}
+
+auto GELUActivation(HloInstruction* instr, HloInstruction** src) {
+ // Attempt to match GELU_TANH activation or GELU_ERF activation
+ // (https://arxiv.org/abs/1606.08415), where:
+ // gelu_tanh(x) = x * cdf(x)
+ //    cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
+ // -------------errf_approximate------------
+ //
+ // gelu_erf(x) = x * cdf(x)
+ // cdf(x) = 0.5 * (1 + erf(x / sqrt(2)))
+ // --errf_exact--
+
+ HloInstruction* errf;
+
+ // The expression 0.5 * x * (1 + errf) is the common pattern for both the
+ // exact and approximate GELU activations.
+ auto common_pattern = MultiplyMultiplyAnyOrder(
+ BcastConstScalar(0.5), m::Op(src),
+ m::AddAnyOrder(BcastConstScalar(1.0), m::Op(&errf).WithOneUser()));
+
+ bool matched = Match(instr, common_pattern);
+ if (matched) {
+ // The subexpression 0.044715 * x**3 appears in GELU approximate activation.
+ // However, it is often optimized by other HLO passes into an expression of
+ // 0.044715 * x * (x * x). Since there are three consecutive multiplies,
+ // there could be a large number of patterns. We try to capture some of
+ // those:
+ //
+ // 1. (0.044715 * x) * x * x
+ // 2. 0.044715 * (x * x) * x
+ //
+ // Note each of the above could in turn have various patterns due to
+ // associativity and commutativity properties of multiply.
+ auto subexpr_pattern = m::AnyOf<HloInstruction>(
+ MultiplyMultiplyAnyOrder(
+ m::MultiplyAnyOrder(BcastConstScalarNear(0.044715),
+ m::Op().Is(*src)),
+ m::Op().Is(*src), m::Op().Is(*src)),
+ MultiplyMultiplyAnyOrder(
+ BcastConstScalarNear(0.044715),
+ m::Multiply(m::Op().Is(*src), m::Op().Is(*src)), m::Op().Is(*src)));
+
+ auto errf_apprx_pattern =
+ m::Tanh(m::MultiplyAnyOrder(
+ BcastConstScalarNear(sqrt(M_2_PI)),
+ m::AddAnyOrder(m::Op().Is(*src), subexpr_pattern)
+ .WithOneUser()))
+ .WithOneUser();
+
+ HloInstruction* erf;
+ auto errf_exact_pattern =
+ m::Op(&erf)
+ .WithOpcode(HloOpcode::kErf)
+ .WithOperand(
+ 0, m::MultiplyAnyOrder(m::Op(src),
+ m::AnyOf<HloInstruction>(
+ BcastConstScalarNear(0.707106769),
+ BcastConstScalarNear(0.70703125),
+ BcastConstScalarNear(0.707182348)))
+ .WithOneUser())
+ .WithOneUser();
+
+ if (Match(errf, errf_apprx_pattern)) {
+ // Matched Gelu-approximate pattern
+ return OneDnnFusionConfig::GELU_TANH;
+ } else if (Match(errf, errf_exact_pattern)) {
+ // Matched Gelu-exact pattern
+ return OneDnnFusionConfig::GELU_ERF;
+ }
+ }
+ return OneDnnFusionConfig::UNDEFINED;
+}
+
+// OneDNN matmul can fuse an add operation, with automatic broadcasting along
+// the addend's dimensions that are 1s. When compatible, the Broadcast can be
+// replaced by a Bitcast, which is much cheaper. Compute the new shape for the
+// Bitcast.
+absl::StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr,
+ const Shape& dot_shape) {
+ if (broadcast_instr->opcode() != HloOpcode::kBroadcast) {
+ return absl::InvalidArgumentError(
+ "HLO instruction is not a Broadcast instruction.");
+ }
+ auto bcast = Cast<HloBroadcastInstruction>(broadcast_instr);
+ Shape new_shape = bcast->shape();
+ // Broadcast instruction has "dimensions" parameter along which its input's
+ // dimensions should not change. For example,
+ // dot = f32[3,4,5,6] dot(...)
+ // arg = f32[3,6]{1,0} parameter(0)
+ // broad = f32[3,4,5,6]{3,2,1,0} broadcast(arg), dimensions={0,3}
+ //     add = f32[3,4,5,6]{3,2,1,0} add(dot, broad)
+ // can be replaced with the following
+ // arg = f32[3,6]{1,0} parameter(0)
+ // bitcast = f32[3,1,1,6]{3,2,1,0} bitcast(arg)
+ //     fused = f32[3,4,5,6]{3,2,1,0} custom-call(..., bitcast)
+ auto kept_dimensions = bcast->dimensions();
+ for (int i = 0; i < new_shape.rank(); i++) {
+ if (!absl::c_linear_search(kept_dimensions, i)) {
+ new_shape.set_dimensions(i, 1);
+ }
+ }
+
+ // If rank(new_shape) > rank(dot), extra dimensions with value = 1 can be
+ // deleted from the new_shape.
+ int64_t rank_difference = new_shape.rank() - dot_shape.rank();
+ auto new_dims = new_shape.dimensions();
+ std::vector<int64_t> dims_to_delete;
+ for (int i = 0; i < rank_difference; ++i) {
+ if (new_dims[i] == 1) {
+ dims_to_delete.push_back(i);
+ }
+ }
+ new_shape = ShapeUtil::DeleteDimensions(dims_to_delete, new_shape);
+
+ // New shape for bias should satisfy the condition:
+ // rank(new_shape) <= rank(dot).
+ if (new_shape.rank() > dot_shape.rank()) {
+ return absl::CancelledError(
+ "Bias shape could not be adjusted for a fusion.");
+ }
+
+ return new_shape;
+};
+
+inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) {
+ // Check if the operand's shape is compatible with matmul for fusion.
+ // An operand is fusible if
+ // 1. rank(operand) <= rank(dot) and
+ // 2. walking backward from the last dimension, each dimension size of the
+ // operand is either 1 or equal to the corresponding dimension of dot.
+ auto operand_dims = operand->shape().dimensions();
+ auto dot_dims = dot->shape().dimensions();
+ if (operand_dims.size() > dot_dims.size()) return false;
+ int operand_idx = operand_dims.size() - 1;
+ int dot_idx = dot_dims.size() - 1;
+ for (; operand_idx >= 0; --operand_idx, --dot_idx) {
+ if (operand_dims[operand_idx] != 1 &&
+ operand_dims[operand_idx] != dot_dims[dot_idx])
+ return false;
+ }
+ return true;
+}
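// Illustration with hypothetical shapes (not taken from any test): for a dot
// of shape f32[4,5], an addend of shape f32[5] or f32[1,5] satisfies both
// conditions above and is fusible; f32[3,5] fails condition 2 (3 is neither 1
// nor 4), and f32[2,4,5] fails condition 1 (rank 3 > rank 2).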
+
+template <typename Pattern>
+inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert,
+ HloInstruction** optional_bitcast,
+ Pattern pattern) {
+ // Checks the presence of some intermediate operations that can be moved /
+ // folded to allow dot fusion with add.
+ // Try to match either of the following:
+ // 1. pattern-root -> bf16/f16-to-fp32 convert -> bitcast
+ // 2. pattern-root -> bf16/f16-to-fp32 convert
+ // 3. pattern-root -> bitcast
+ // 4. pattern-root
+ auto common = m::AnyOf<HloInstruction>(
+ pu::SupportedConvert(optional_convert, std::move(pattern).WithOneUser())
+ .WithElementType(PrimitiveType::F32),
+ std::move(pattern).WithOneUser());
+ return m::AnyOf<HloInstruction>(
+ BitcastWithReshapeSemantics(optional_bitcast, common), common);
+}
+
+} // namespace
+
+bool OneDnnContractionRewriter::ShouldRewriteDot(
+ const HloInstruction* dot_instr, bool before_layout_assignment) {
+ // Currently, we do not rewrite instructions that have control dependencies.
+ if (dot_instr->HasControlDependencies()) return false;
+ if (!IsSupportedType(dot_instr->shape().element_type())) return false;
+ if (dot_instr->operands().size() != 2) return false;
+
+ // Currently, we rewrite when the data type is F32 or BF16. Note that we do
+ // not need to check that the contraction dim-sizes of the operands are
+ // equal; the HLO verifier already does that. We do, however, need to check
+ // that the contraction is over only one dimension (a.k.a. the K dimension in
+ // matrix-multiplication parlance). We also require that the batch dimensions
+ // of the operands match.
+ const Shape& lhs_shape = dot_instr->operand(0)->shape();
+ const Shape& rhs_shape = dot_instr->operand(1)->shape();
+ const Shape& output_shape = dot_instr->shape();
+ // None of the operands and result should be ZeroElementArray.
+ if (ShapeUtil::IsZeroElementArray(lhs_shape) ||
+ ShapeUtil::IsZeroElementArray(rhs_shape) ||
+ ShapeUtil::IsZeroElementArray(output_shape)) {
+ return false;
+ }
+ // OneDNN only supports rank <= kOneDnnMaxNDims and singular non-contracting
+ // dimensions. We should not rewrite if any of these conditions are violated.
+ if (lhs_shape.rank() <= 0 || lhs_shape.rank() > kOneDnnMaxNDims ||
+ rhs_shape.rank() <= 0 || rhs_shape.rank() > kOneDnnMaxNDims ||
+ output_shape.rank() > std::min({lhs_shape.rank(), rhs_shape.rank(),
+ static_cast<int64_t>(kOneDnnMaxNDims)})) {
+ return false;
+ }
+
+ // Layouts should be row-major; the contraction dimensions capture transpose
+ // scenarios in the last two dimensions.
+ // Col-major layouts are corrected to row-major for the BatchDot operation as
+ // part of the layout-assignment pass.
+ // Skip the row-major layout check before the layout-assignment pass.
+ if (!before_layout_assignment) {
+ bool row_major = IsRowMajor(lhs_shape) && IsRowMajor(rhs_shape) &&
+ IsRowMajor(output_shape);
+ if (!row_major) return false;
+ }
+
+ auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
+ int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
+ int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
+ // Contraction is supported only in one of the last two dimensions.
+ if (lhs_dim_k < lhs_shape.rank() - 2 || rhs_dim_k < rhs_shape.rank() - 2) {
+ return false;
+ }
+
+ // OneDNN matmul has scratch-allocation and copy overheads. The overheads
+ // can be amortized if there is a sufficient number of flops. We don't
+ // rewrite small cases (threshold determined empirically).
+ // TODO(intel-tf): Relax the condition once oneDNN matmul is further
+ // optimized.
+ auto num_flops = xla::HloCostAnalysis::GetDotFlops(lhs_shape, output_shape,
+ dot_dim_numbers);
+ auto rank = output_shape.rank();
+ auto flops_threshold = (rank <= 2) ? (1 << 24) : (1 << 19);
+ return (num_flops >= flops_threshold);
+}
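// Worked example of the threshold above, assuming (as in HloCostAnalysis)
// two flops per multiply-add: a f32[1024,1024] x f32[1024,1024] dot costs
// about 2 * 1024^3 ~= 2.1e9 flops, far above 1 << 24 (~1.7e7), so it is
// rewritten; a f32[64,64] x f32[64,64] dot costs about 2 * 64^3 ~= 5.2e5
// flops and is left to the default dot emitter.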
+
+bool OneDnnContractionRewriter::ShouldRewriteConv(
+ const HloInstruction* conv_instr) {
+ if (conv_instr->HasControlDependencies()) return false;
+ if (!IsSupportedType(conv_instr->shape().element_type())) return false;
+ if (conv_instr->batch_group_count() != 1) return false;
+
+ // TODO(intel-tf): Remove this restriction after enabling backward weights
+ // support
+ if (conv_instr->operand(1)->opcode() == HloOpcode::kReverse) return false;
+
+ const Shape& inp_shape = conv_instr->operand(0)->shape();
+ const Shape& ker_shape = conv_instr->operand(1)->shape();
+ const Shape& out_shape = conv_instr->shape();
+ if (ShapeUtil::IsZeroElementArray(inp_shape) ||
+ ShapeUtil::IsZeroElementArray(ker_shape) ||
+ ShapeUtil::IsZeroElementArray(out_shape)) {
+ return false;
+ }
+
+ auto dims = conv_instr->window().dimensions().size();
+ if (dims >= 4 || dims <= 0) return false;
+
+ if (inp_shape.rank() != ker_shape.rank() ||
+ inp_shape.rank() != out_shape.rank()) {
+ return false;
+ }
+
+ return true;
+}
+
+class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
+ public:
+ // Matches patterns for possible MatMul fusions that are supported by the
+ // oneDNN library. Matched HLO instruction(s) are replaced by a custom call.
+ absl::Status HandleDot(HloInstruction* instr) override {
+ HloInstruction* dot_instr;
+ auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot);
+ if (!Match(instr, pattern)) return absl::OkStatus();
+
+ TF_RETURN_IF_ERROR(
+ ValidateDotDimensionNumbers(dot_instr->dot_dimension_numbers()));
+ if (!OneDnnContractionRewriter::ShouldRewriteDot(dot_instr)) {
+ TF_RETURN_IF_ERROR(UpcastDotToF32(dot_instr));
+ return absl::OkStatus();
+ }
+ TF_ASSIGN_OR_RETURN(dot_instr, ReconfigureDotDimensions(dot_instr));
+ auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
+ const Shape& lhs_shape = dot_instr->operand(0)->shape();
+ const Shape& rhs_shape = dot_instr->operand(1)->shape();
+ const Shape& output_shape = dot_instr->shape();
+
+ int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
+ int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
+
+ HloInstruction* matmul_call =
+ dot_instr->AddInstruction(HloInstruction::CreateCustomCall(
+ output_shape,
+ {dot_instr->mutable_operand(0), dot_instr->mutable_operand(1)},
+ "__onednn$matmul"));
+ // Set additional info via config, e.g., transpose and fusion info.
+ BackendConfig backend_config;
+ OneDnnMatMulConfig* matmul_config =
+ backend_config.mutable_onednn_matmul_config();
+ bool transpose_a = (lhs_dim_k != lhs_shape.rank() - 1);
+ bool transpose_b = (rhs_dim_k != rhs_shape.rank() - 2);
+ matmul_config->set_transpose_a(transpose_a);
+ matmul_config->set_transpose_b(transpose_b);
+ TF_RETURN_IF_ERROR(matmul_call->set_backend_config(backend_config));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, matmul_call));
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleConvolution(HloInstruction* conv) override {
+ if (!OneDnnContractionRewriter::ShouldRewriteConv(conv)) {
+ return absl::OkStatus();
+ }
+
+ const Shape& conv_shape = conv->shape();
+ auto dims = conv->window().dimensions().size();
+ const ConvolutionDimensionNumbers& conv_dims =
+ conv->convolution_dimension_numbers();
+
+ BackendConfig backend_config;
+ OneDnnConvolutionConfig* conv_config =
+ backend_config.mutable_onednn_conv_config();
+
+ conv_config->set_dims(conv_shape.rank());
+ conv_config->set_feature_groups(conv->feature_group_count());
+ conv_config->mutable_input()->mutable_data()->set_batch_dim(
+ conv_dims.input_batch_dimension());
+ conv_config->mutable_kernel()->mutable_filter()->set_input_feature_dim(
+ conv_dims.kernel_input_feature_dimension());
+ conv_config->mutable_output()->mutable_data()->set_batch_dim(
+ conv_dims.output_batch_dimension());
+ conv_config->mutable_input()->mutable_data()->set_feature_dim(
+ conv_dims.input_feature_dimension());
+ conv_config->mutable_kernel()->mutable_filter()->set_output_feature_dim(
+ conv_dims.kernel_output_feature_dimension());
+ conv_config->mutable_output()->mutable_data()->set_feature_dim(
+ conv_dims.output_feature_dimension());
+
+ const Shape& output_shape = conv->shape();
+
+ for (auto it = conv->window().dimensions().begin();
+ it != conv->window().dimensions().end(); it++) {
+ if ((*it).padding_low() < 0 || (*it).padding_high() < 0 ||
+ (*it).stride() < 0 || (*it).base_dilation() != 1 ||
+ (*it).window_reversal()) {
+ return absl::OkStatus();
+ }
+ // Changing the input subspace of uint repeated fields from whole numbers
+ // to natural numbers to avoid misinterpretation of buffer values.
+ conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1);
+ conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1);
+ conv_config->mutable_window()->add_strides((*it).stride() + 1);
+ conv_config->mutable_window()->add_window_dilations(
+ (*it).window_dilation() + 1);
+ }
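// Illustration of the +1 encoding above (values are hypothetical): an HLO
// window dimension with stride=1, padding_low=0, padding_high=0 and
// window_dilation=1 is stored as strides=2, pad_left=1, pad_right=1 and
// window_dilations=2, so a zero in the proto can only mean "unset"; the
// runtime side is expected to subtract one when reading these fields back.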
+
+ for (int i = 0; i < dims; i++) {
+ conv_config->mutable_input()->mutable_data()->add_spatial_dims(
+ conv_dims.input_spatial_dimensions()[i] + 1);
+ conv_config->mutable_kernel()->mutable_filter()->add_spatial_dims(
+ conv_dims.kernel_spatial_dimensions()[i] + 1);
+ conv_config->mutable_output()->mutable_data()->add_spatial_dims(
+ conv_dims.output_spatial_dimensions()[i] + 1);
+ }
+
+ HloInstruction* custom_call =
+ conv->AddInstruction(HloInstruction::CreateCustomCall(
+ output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)},
+ "__onednn$convolution"));
+
+ TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call));
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleAdd(HloInstruction* instr) override {
+ // Try to fuse Dot (onednn-matmul) + Add. However, the HLO Add instruction
+ // might receive the addends only after additional processing (e.g.,
+ // Broadcast, Bitcast, Convert) has been applied to the raw addends. Here,
+ // the following possible pattern is matched.
+ //
+ // clang-format off
+ //
+ // Dot addend
+ // | |
+ // v v
+ // optional instructions optional instructions
+ // (e.g, Convert, Bitcast) (e.g, Convert, Broadcast)
+ // | |
+ // +--------------+-------------------+
+ // |
+ // v
+ // Add
+ //
+ // clang-format on
+
+ HloInstruction *addend_intermediate, *dot;
+ HloInstruction* optional_dot_bitcast = nullptr;
+ HloInstruction* optional_dot_convert = nullptr;
+
+ auto pattern = m::AddAnyOrder(
+ &instr,
+ OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast,
+ OneDnnMatmulInstr(&dot))
+ .WithOneUser(),
+ m::Op(&addend_intermediate));
+
+ if (Match(instr, pattern)) {
+ if (!IsSupportedType(dot->shape().element_type()))
+ return absl::OkStatus();
+ // TODO(intel-tf): Remove the condition below when the fusion Dot +
+ // Add(bias) + Add(e.g., residual) is enabled.
+ if (!dot->backend_config<BackendConfig>()
+ ->mutable_onednn_matmul_config()
+ ->mutable_fusions()
+ ->ops()
+ .empty() &&
+ dot->backend_config<BackendConfig>()
+ ->mutable_onednn_matmul_config()
+ ->mutable_fusions()
+ ->ops(0) == OneDnnFusionConfig::BIAS) {
+ return absl::OkStatus();
+ }
+ std::vector<HloInstruction*> new_operands;
+ for (auto operand : dot->operands()) {
+ new_operands.push_back(operand);
+ }
+
+ // At this point, the addend could have one of the following
+      // possibilities that the current fusion can handle:
+ //
+ // - addend -> Convert -> Broadcast -> Add
+ // - addend -> Broadcast -> Convert -> Add
+ // - addend -> Convert
+ // - addend -> Broadcast
+ // - addend
+ //
+      // Hunt for the addend through the possible sequences above and check
+      // that it is compatible with the onednn-matmul fusion.
+ HloInstruction* addend = nullptr;
+ HloInstruction* optional_addend_broadcast = nullptr;
+ auto addend_pattern = m::AnyOf<HloInstruction>(
+ m::Broadcast(&optional_addend_broadcast,
+ m::Convert(&addend, m::Op())),
+ m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
+ m::Convert(&addend, m::Op()),
+ m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
+ m::Op(&addend));
+ if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();
+
+ if (optional_addend_broadcast && addend->shape().rank() != 1) {
+ auto new_shape =
+ AdjustBiasShape(optional_addend_broadcast, dot->shape());
+ if (new_shape.ok()) {
+ addend = addend->AddInstruction(
+ HloInstruction::CreateBitcast(new_shape.value(), addend));
+ } else {
+ VLOG(2) << new_shape.status();
+ return absl::OkStatus();
+ }
+ }
+
+ // Validate addend for fusion.
+ if (IsSupportedType(addend->shape().element_type()) &&
+ IsOperandFusible(addend, dot)) {
+ new_operands.push_back(addend);
+ } else {
+ return absl::OkStatus();
+ }
+
+ // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
+ // implementation for broadcasted add across all dimensions.
+ OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
+ kind = (addend->shape().rank() == 1)
+ ? (dot->backend_config<BackendConfig>()
+ ->mutable_onednn_matmul_config()
+ ->fusions()
+ .ops()
+ .empty()
+ ? OneDnnFusionConfig::BIAS
+ : OneDnnFusionConfig::UNDEFINED)
+ : OneDnnFusionConfig::BINARY_ADD;
+ if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
+
+ auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
+ dot->CloneWithNewOperands(dot->shape(), new_operands)));
+
+ auto backend_config = matmul_call->backend_config<BackendConfig>();
+ backend_config->mutable_onednn_matmul_config()
+ ->mutable_fusions()
+ ->add_ops(kind);
+
+ if (optional_addend_broadcast) {
+ backend_config->mutable_onednn_matmul_config()
+ ->mutable_optimization_config()
+ ->set_bias_broadcast(true);
+ }
+ TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
+
+ HloInstruction* new_instr;
+      // If the matched pattern has custom-call -> bitcast -> add, then we need
+      // to insert a bitcast after the new fusion to maintain the correct shape
+      // (new-custom-call -> bitcast). For the bf16 case this is optionally
+      // followed by a convert to avoid a datatype mismatch.
+ if (optional_dot_bitcast != nullptr &&
+ optional_dot_bitcast->opcode() == HloOpcode::kBitcast) {
+ if (optional_dot_convert != nullptr &&
+ optional_dot_convert->opcode() == HloOpcode::kConvert) {
+ auto bitcast_call =
+ matmul_call->AddInstruction(HloInstruction::CreateBitcast(
+ ShapeUtil::ChangeElementType(
+ instr->shape(), matmul_call->shape().element_type()),
+ matmul_call));
+ new_instr =
+ bitcast_call->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(
+ bitcast_call->shape(),
+ optional_dot_convert->shape().element_type()),
+ bitcast_call));
+ } else {
+ new_instr = matmul_call->AddInstruction(
+ HloInstruction::CreateBitcast(instr->shape(), matmul_call));
+ }
+ } else {
+ if (optional_dot_convert != nullptr &&
+ optional_dot_convert->opcode() == HloOpcode::kConvert) {
+ new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(
+ matmul_call->shape(),
+ optional_dot_convert->shape().element_type()),
+ matmul_call));
+ } else {
+ new_instr = matmul_call;
+ }
+ }
+ TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
+ }
+
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleMaximum(HloInstruction* instr) override {
+ HloInstruction* matmul_call;
+ HloInstruction* intermediate_instr = nullptr;
+ HloInstruction* optional_bitcast = nullptr;
+ // Attempt to elide maximum and fuse ReLU activation into GEMM, including
+ // when slicing or bitcasting is applied to the result.
+ if (Match(instr,
+ m::MaximumAnyOrder(ElementwiseSafeIntermediates(
+ &intermediate_instr, &optional_bitcast,
+ OneDnnMatmulInstr(&matmul_call))
+ .WithOneUser(),
+ BcastConstScalar(0)))) {
+ return FuseActivation(OneDnnFusionConfig::RELU, instr, matmul_call,
+ intermediate_instr, optional_bitcast);
+ }
+ return absl::OkStatus();
+ }
+
+ auto ELUActivation(HloInstruction* instr, HloInstruction** src) {
+ // Reference: tensorflow/compiler/tf2xla/kernels/elu_op.cc
+ // const auto zero = ScalarLike(x, 0);
+ // const auto pred = Gt(x, zero);
+ // const auto expm1 = Expm1(x);
+ // return Select(pred, x, expm1);
+ auto pattern = m::Select(
+ m::Gt(pu::OptionalConvert(m::Op(src)), BcastConvertConstScalar(0)),
+ m::Op(src),
+ pu::OptionalConvert(m::Expm1(pu::OptionalConvert(m::Op(src)))));
+ return Match(instr, pattern);
+ }
+
+ absl::Status HandleSelect(HloInstruction* instr) override {
+ HloInstruction* matmul_call;
+ HloInstruction* intermediate_instr = nullptr;
+ HloInstruction* optional_bitcast = nullptr;
+ HloInstruction* src;
+ // Attempt to elide ELU subgraph and fuse ELU activation into GEMM,
+ // including when slicing or bitcasting is applied to the result.
+ if (ELUActivation(instr, &src)) {
+ if (Match(src, ElementwiseSafeIntermediates(
+ &intermediate_instr, &optional_bitcast,
+ OneDnnMatmulInstr(&matmul_call)))) {
+ return FuseActivation(OneDnnFusionConfig::ELU, instr, matmul_call,
+ intermediate_instr);
+ }
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleTanh(HloInstruction* instr) override {
+ HloInstruction* matmul_call;
+ HloInstruction* intermediate_instr = nullptr;
+ HloInstruction* optional_bitcast = nullptr;
+ // Attempt to elide Tanh and fuse Tanh activation into GEMM, including
+ // when slicing or bitcasting is applied to the result.
+ if (Match(instr, m::Tanh(ElementwiseSafeIntermediates(
+ &intermediate_instr, &optional_bitcast,
+ OneDnnMatmulInstr(&matmul_call))
+ .WithOneUser()))) {
+ return FuseActivation(OneDnnFusionConfig::TANH, instr, matmul_call,
+ intermediate_instr);
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleClamp(HloInstruction* instr) override {
+ HloInstruction* matmul_call;
+ HloInstruction* intermediate_instr = nullptr;
+ HloInstruction* optional_bitcast = nullptr;
+ // Attempt to elide RELU6 and fuse RELU6 activation into GEMM, including
+ // when slicing or bitcasting is applied to the result.
+ if (Match(instr, m::Clamp(BcastConstScalar(0),
+ ElementwiseSafeIntermediates(
+ &intermediate_instr, &optional_bitcast,
+ OneDnnMatmulInstr(&matmul_call))
+ .WithOneUser(),
+ BcastConstScalar(6)))) {
+ return FuseActivation(OneDnnFusionConfig::RELU6, instr, matmul_call,
+ intermediate_instr);
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleMultiply(HloInstruction* instr) override {
+ HloInstruction* matmul_call;
+ HloInstruction* intermediate_instr = nullptr;
+ HloInstruction* src;
+ auto activation = GELUActivation(instr, &src);
+ if (activation != OneDnnFusionConfig::UNDEFINED) {
+ HloInstruction* optional_bitcast = nullptr;
+ if (Match(src, ElementwiseSafeIntermediates(
+ &intermediate_instr, &optional_bitcast,
+ OneDnnMatmulInstr(&matmul_call)))) {
+ return FuseActivation(activation, instr, matmul_call,
+ intermediate_instr, optional_bitcast);
+ }
+ }
+
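+    // Otherwise, attempt to fuse a multiplication by a broadcasted scalar
+    // constant into the matmul as a LINEAR post-op, passing the constant as
+    // the scaling factor.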
+ HloInstruction *dot, *constant;
+ HloInstruction* optional_convert = nullptr;
+ auto pattern = m::Op(&instr)
+ .WithOpcode(HloOpcode::kMultiply)
+ .WithBinaryOperandsAnyOrder(
+ m::AnyOf<HloInstruction>(
+ pu::SupportedConvert(&optional_convert,
+ OneDnnMatmulInstr(&dot))
+ .WithElementType(PrimitiveType::F32),
+ OneDnnMatmulInstr(&dot))
+ .WithOneUser(),
+ m::Broadcast(m::Constant(&constant)));
+
+ if (Match(instr, pattern)) {
+ std::vector<HloInstruction*> new_operands;
+ auto constant_value = GetConstantValueAsFloat32(constant);
+ if (!constant_value) {
+ return absl::OkStatus();
+ }
+
+ for (auto operand : dot->operands()) {
+ new_operands.push_back(operand);
+ }
+ auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
+ dot->CloneWithNewOperands(instr->shape(), new_operands)));
+ auto backend_config = matmul_call->backend_config<BackendConfig>();
+ backend_config->mutable_onednn_matmul_config()
+ ->mutable_fusions()
+ ->add_ops(OneDnnFusionConfig::LINEAR);
+      // Cast to int32 because the proto config has issues handling decimal
+      // types.
+ backend_config->mutable_onednn_matmul_config()
+ ->mutable_fusions()
+ ->set_alpha_typecast(
+ *(reinterpret_cast<int32_t*>(&constant_value.value())));
+ TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
+ HloInstruction* new_instr;
+ if (optional_convert != nullptr &&
+ optional_convert->opcode() == HloOpcode::kConvert) {
+ new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(
+ matmul_call->shape(), optional_convert->shape().element_type()),
+ matmul_call));
+ } else {
+ new_instr = matmul_call;
+ }
+
+ TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
+ }
+ return absl::OkStatus();
+ }
+
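+  // Matches the explicit sigmoid subgraph 1 / (1 + exp(-x)) and captures x in
+  // `src`.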
+ auto SigmoidActivation(HloInstruction* instr, HloInstruction** src) {
+ return Match(instr,
+ m::Divide(BcastConstScalar(1.0),
+ m::AddAnyOrder(BcastConstScalar(1.0),
+ m::Exp(m::Negate(m::Op(src))))));
+ }
+
+ absl::Status HandleDivide(HloInstruction* instr) override {
+ HloInstruction* matmul_call;
+ HloInstruction* intermediate_instr = nullptr;
+ HloInstruction* optional_bitcast = nullptr;
+ HloInstruction* src;
+ if (SigmoidActivation(instr, &src)) {
+ if (Match(src, ElementwiseSafeIntermediates(
+ &intermediate_instr, &optional_bitcast,
+ OneDnnMatmulInstr(&matmul_call))
+ .WithOneUser())) {
+ return FuseActivation(OneDnnFusionConfig::SIGMOID, instr, matmul_call,
+ intermediate_instr, optional_bitcast);
+ }
+ }
+ return absl::OkStatus();
+ }
+
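+  // Appends `kind` to the fusion list of the matmul custom-call's backend
+  // config and replaces the activation instruction with the fused custom-call,
+  // re-applying any intermediate bitcast/convert so the output shape and
+  // element type are preserved.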
+ absl::Status FuseActivation(OneDnnFusionConfig_FusionKind kind,
+ HloInstruction* activation,
+ HloInstruction* matmul,
+ HloInstruction* intermediate_instr = nullptr,
+ HloInstruction* optional_bitcast = nullptr) {
+ TF_ASSIGN_OR_RETURN(auto backend_config,
+ matmul->backend_config<BackendConfig>());
+ auto* matmul_config = backend_config.mutable_onednn_matmul_config();
+ matmul_config->mutable_fusions()->add_ops(kind);
+ TF_RETURN_IF_ERROR(matmul->set_backend_config(backend_config));
+ std::unique_ptr<HloInstruction> output = matmul->Clone();
+ if (optional_bitcast != nullptr &&
+ optional_bitcast->opcode() == HloOpcode::kBitcast) {
+ HloInstruction* new_instr = nullptr;
+ if (intermediate_instr != nullptr &&
+ intermediate_instr->opcode() == HloOpcode::kConvert) {
+ auto bitcast_call =
+ matmul->AddInstruction(HloInstruction::CreateBitcast(
+ ShapeUtil::ChangeElementType(optional_bitcast->shape(),
+ matmul->shape().element_type()),
+ matmul));
+ new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(
+ bitcast_call->shape(),
+ intermediate_instr->shape().element_type()),
+ bitcast_call));
+ return ReplaceInstruction(activation, new_instr);
+ }
+ } else if (intermediate_instr) {
+ output = intermediate_instr->CloneWithNewOperands(
+ intermediate_instr->shape(),
+ {matmul->parent()->AddInstruction(std::move(output))});
+ }
+
+ return ReplaceWithNewInstruction(activation, std::move(output));
+ }
+
+  // This function changes the dot instruction for supported matrix
+  // multiplication scenarios. In particular, it changes the shapes
+  // of the lhs, rhs, and result arrays.
+ // - lhs configuration scenario
+ // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
+ // result: [batch_dims,feature_dim] to [batch_dims,1,feature_dim]
+ //
+ // - rhs configuration scenario
+ // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
+ // result: [batch_dims,feature_dim] to [batch_dims,feature_dim, 1]
+ //
+ // - both lhs and rhs configuration scenario
+ // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
+ // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
+ // result: [batch_dims] to [batch_dims,1,1]
+ absl::StatusOr<HloInstruction*> ReconfigureDotDimensions(
+ HloInstruction* dot_instr) {
+ HloInstruction* lhs = dot_instr->mutable_operand(0);
+ HloInstruction* rhs = dot_instr->mutable_operand(1);
+ DotDimensionNumbers dim_numbers = dot_instr->dot_dimension_numbers();
+
+ auto lhs_batch_dims = dim_numbers.lhs_batch_dimensions();
+ auto lhs_contraction_dims = dim_numbers.lhs_contracting_dimensions();
+ bool is_lhs_vector = lhs->shape().rank() ==
+ (lhs_batch_dims.size() + lhs_contraction_dims.size());
+
+ auto rhs_batch_dims = dim_numbers.rhs_batch_dimensions();
+ auto rhs_contraction_dims = dim_numbers.rhs_contracting_dimensions();
+ bool is_rhs_vector = rhs->shape().rank() ==
+ (rhs_batch_dims.size() + rhs_contraction_dims.size());
+
+ if (!is_lhs_vector && !is_rhs_vector) return dot_instr;
+
+ std::vector<int64_t> adjusted_lhs_dims(lhs->shape().dimensions().begin(),
+ lhs->shape().dimensions().end());
+ std::vector<int64_t> adjusted_rhs_dims(rhs->shape().dimensions().begin(),
+ rhs->shape().dimensions().end());
+ std::vector<int64_t> adjusted_dot_dims(
+ dot_instr->shape().dimensions().begin(),
+ dot_instr->shape().dimensions().end());
+
+ if (is_lhs_vector) {
+ auto lhs_it = adjusted_lhs_dims.begin() + lhs_batch_dims.size();
+ adjusted_lhs_dims.insert(lhs_it, 1, 1);
+ auto result_it = adjusted_dot_dims.begin() + lhs_batch_dims.size();
+ adjusted_dot_dims.insert(result_it, 1, 1);
+ auto lhs_contraction_dim =
+ dot_instr->dot_dimension_numbers().lhs_contracting_dimensions(0);
+ dim_numbers.set_lhs_contracting_dimensions(0, lhs_contraction_dim + 1);
+ lhs = lhs->AddInstruction(HloInstruction::CreateBitcast(
+ ShapeUtil::MakeShape(lhs->shape().element_type(), adjusted_lhs_dims),
+ lhs));
+ }
+
+ if (is_rhs_vector) {
+ auto it = adjusted_rhs_dims.end();
+ adjusted_rhs_dims.insert(it, 1, 1);
+ auto result_it = adjusted_dot_dims.end();
+ adjusted_dot_dims.insert(result_it, 1, 1);
+ rhs = rhs->AddInstruction(HloInstruction::CreateBitcast(
+ ShapeUtil::MakeShape(rhs->shape().element_type(), adjusted_rhs_dims),
+ rhs));
+ }
+
+ HloInstruction* adjusted_dot =
+ dot_instr->AddInstruction(HloInstruction::CreateDot(
+ ShapeUtil::MakeShape(dot_instr->shape().element_type(),
+ adjusted_dot_dims),
+ lhs, rhs, dim_numbers, dot_instr->precision_config()));
+
+ HloInstruction* replacement_instr = adjusted_dot->AddInstruction(
+ HloInstruction::CreateBitcast(dot_instr->shape(), adjusted_dot));
+
+ TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr));
+ return adjusted_dot;
+ }
+
+ // This function upcasts BF16 dots to F32 if we are unable to rewrite them to
+ // oneDNN custom calls.
+ absl::Status UpcastDotToF32(HloInstruction* dot_instr) {
+ if (dot_instr->shape().element_type() != BF16) return absl::OkStatus();
+ std::vector<HloInstruction*> new_operands;
+ auto bf16_operands = dot_instr->operands();
+
+ std::for_each(
+ bf16_operands.begin(), bf16_operands.end(),
+ [&new_operands](HloInstruction* instr) {
+ new_operands.push_back(
+ instr->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(instr->shape(), F32), instr)));
+ });
+
+ HloInstruction* f32_dot =
+ dot_instr->AddInstruction(dot_instr->CloneWithNewOperands(
+ ShapeUtil::ChangeElementType(dot_instr->shape(), F32),
+ new_operands));
+
+ HloInstruction* replacement_instr =
+ f32_dot->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(f32_dot->shape(), BF16), f32_dot));
+
+ TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr));
+ return absl::OkStatus();
+ }
+};
+
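+// Post-rewrite pass over oneDNN custom-calls: adds a scratchpad buffer to the
+// custom-call result and prepacks constant weights into oneDNN's preferred
+// layout at compile time.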
+class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
+ public:
+ OneDnnPostRewriteVisitor(int intra_op_parallelism,
+ const tsl::thread::ThreadPool* compile_threadpool)
+ : intra_op_parallelism_(intra_op_parallelism > 0
+ ? intra_op_parallelism
+ : tsl::port::MaxParallelism()),
+ evaluator_(/*max_loop_iterations=*/0) {
+ if (compile_threadpool) {
+ threadpool_device_.reset(
+ new Eigen::ThreadPoolDevice(compile_threadpool->AsEigenThreadPool(),
+ compile_threadpool->NumThreads()));
+ } else {
+ threadpool_handle_.reset(new tsl::thread::ThreadPool(
+ tsl::Env::Default(), "XLACpuCompile", tsl::port::MaxParallelism()));
+ threadpool_device_.reset(
+ new Eigen::ThreadPoolDevice(threadpool_handle_->AsEigenThreadPool(),
+ threadpool_handle_->NumThreads()));
+ }
+
+#ifndef ENABLE_ONEDNN_OPENMP
+    // Set the oneDNN concurrency setting (which is thread-local).
+ tsl::OneDnnThreadPool::set_onednn_max_threads(intra_op_parallelism_);
+#endif
+ }
+
+ absl::Status HandleCustomCall(HloInstruction* custom_call) override {
+ HloInstruction* matmul;
+ if (Match(custom_call, OneDnnMatmulInstr(&matmul))) {
+ return HandleCustomCallInternal<dnnl::matmul::primitive_desc>(
+ custom_call);
+ }
+
+ return DefaultAction(custom_call);
+ }
+
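+  // Applies the scratchpad and weight-prepacking optimizations to a matched
+  // oneDNN custom-call; failures in either step are logged and treated as
+  // non-fatal.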
+ template <typename PrimDesc>
+ absl::Status HandleCustomCallInternal(HloInstruction* custom_call) {
+ auto scratch_add = AddScratch<PrimDesc>(custom_call);
+ if (scratch_add.ok()) {
+ custom_call = *scratch_add;
+ } else {
+ VLOG(2) << scratch_add.status();
+ }
+ auto weights_prepack = PrepackWeights<PrimDesc>(custom_call);
+ if (!weights_prepack.ok()) {
+ VLOG(2) << weights_prepack.status();
+ }
+ return absl::OkStatus();
+ }
+
+ template <typename>
+ absl::Status SetWeightsPrepack(HloInstruction*, bool);
+
+ template <typename>
+ absl::Status SetUserScratch(HloInstruction*, bool);
+
+ template <typename>
+ bool GetWeightsPrepack(HloInstruction*);
+
+ template <typename>
+ bool GetUserScratch(HloInstruction*);
+
+  // Adds a scratchpad for the matmul by changing the result of the custom-call
+  // to tuple(result, scratch).
+ template <typename PrimDesc>
+ absl::StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) {
+ if (GetUserScratch<PrimDesc>(custom_call)) {
+ return custom_call;
+ }
+ TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, true));
+ auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
+ int64_t scratch_size = prim_desc->scratchpad_desc().get_size();
+ Shape scratch_shape = ShapeUtil::MakeShape(U8, {scratch_size});
+ Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({custom_call->shape(), scratch_shape});
+ auto new_custom_call = custom_call->AddInstruction(
+ custom_call->CloneWithNewShape(tuple_shape));
+ HloInstruction* gte =
+ new_custom_call->AddInstruction(HloInstruction::CreateGetTupleElement(
+ custom_call->shape(), new_custom_call, 0));
+ auto status = ReplaceInstruction(custom_call, gte);
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, false));
+ return absl::CancelledError("Adding scratch is unsuccessful.");
+ }
+ return new_custom_call;
+ }
+
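+  // Reorders constant 2D weights into the layout preferred by the oneDNN
+  // primitive at compile time and replaces the weight operand with the packed
+  // constant, so no reorder is needed at runtime.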
+ template <typename PrimDesc>
+ absl::StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) {
+ if (GetWeightsPrepack<PrimDesc>(custom_call)) {
+ return custom_call;
+ }
+ auto weights = custom_call->operand(1);
+ auto weights_shape = weights->shape();
+ Literal weights_literal;
+ if (!(weights_shape.rank() == 2 &&
+ evaluator_.TryEvaluate(weights, &weights_literal, true))) {
+ return absl::CancelledError(
+ "Cannot prepack weights. Not constant 2D weights.");
+ }
+ auto plain_weights_md = ShapeToMemDesc(weights_shape);
+ if constexpr (std::is_same<PrimDesc, dnnl::matmul::primitive_desc>::value) {
+ TF_ASSIGN_OR_RETURN(auto backend_config,
+ custom_call->backend_config<BackendConfig>());
+ TRANSPOSE_LAST_TWO_DIMS_IF(
+ backend_config.onednn_matmul_config().transpose_b(),
+ plain_weights_md);
+ }
+ TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, true));
+ auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
+ auto packed_weights_md = prim_desc->weights_desc();
+ auto packed_weights_shape = MemDescToXlaShapeFlattened(packed_weights_md);
+ auto packed_weights_literal = Literal(packed_weights_shape);
+ ReorderWeight(plain_weights_md, weights_literal.untyped_data(),
+ packed_weights_md, packed_weights_literal.untyped_data());
+ HloInstruction* reordered_weight = custom_call->AddInstruction(
+ HloInstruction::CreateConstant(std::move(packed_weights_literal)));
+ auto status =
+ custom_call->ReplaceOperandWithDifferentShape(1, reordered_weight);
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, false));
+ return absl::CancelledError(
+ "Cannot replace plain weights with prepacked weights.");
+ } else {
+ return custom_call;
+ }
+ }
+
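+  // Copies the weights from `src_buf` (described by `src_md`) into `dst_buf`
+  // (described by `dst_md`) using a oneDNN reorder primitive executed on the
+  // compile-time thread pool.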
+ void ReorderWeight(const dnnl::memory::desc& src_md, void* src_buf,
+ const dnnl::memory::desc& dst_md, void* dst_buf) {
+ auto onednn_threadpool = CreateOneDnnThreadPool(threadpool_device_.get());
+ dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);
+ auto onednn_stream = MakeOneDnnStream(cpu_engine, onednn_threadpool.get());
+ auto src_mem = dnnl::memory(src_md, cpu_engine, src_buf);
+ auto dst_mem = dnnl::memory(dst_md, cpu_engine, dst_buf);
+ dnnl::reorder reorder_prim{src_mem, dst_mem};
+ reorder_prim.execute(onednn_stream, src_mem, dst_mem);
+ onednn_stream.wait();
+ }
+
+ private:
+ int intra_op_parallelism_;
+ HloEvaluator evaluator_;
+ std::unique_ptr<tsl::thread::ThreadPool> threadpool_handle_;
+ std::unique_ptr<Eigen::ThreadPoolDevice> threadpool_device_;
+};
+
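+// The macros below emit template specializations that read or write a single
+// boolean field of the optimization config stored in the custom-call's
+// backend config.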
+#define EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GETTER, PRIM_DESC, CONFIG, \
+ SUB_CONFIG, FIELD) \
+ template <> \
+ inline bool OneDnnPostRewriteVisitor::GETTER<PRIM_DESC>(HloInstruction * \
+ custom_call) { \
+ auto backend_config = custom_call->backend_config<BackendConfig>(); \
+ return backend_config.ok() ? backend_config->CONFIG().SUB_CONFIG().FIELD() \
+ : false; \
+ }
+
+EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetUserScratch,
+ dnnl::matmul::primitive_desc,
+ onednn_matmul_config,
+ optimization_config, user_scratchpad);
+EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetWeightsPrepack,
+ dnnl::matmul::primitive_desc,
+ onednn_matmul_config,
+ optimization_config, weights_prepacked);
+
+#define EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SETTER, PRIM_DESC, CONFIG_TYPE, \
+ CONFIG, SUB_CONFIG, FIELD) \
+ template <> \
+ inline absl::Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>( \
+ HloInstruction * custom_call, bool value) { \
+ TF_ASSIGN_OR_RETURN(auto backend_config, \
+ custom_call->backend_config<BackendConfig>()); \
+ CONFIG_TYPE* config = backend_config.mutable_##CONFIG(); \
+ config->mutable_##SUB_CONFIG()->set_##FIELD(value); \
+ return custom_call->set_backend_config(backend_config); \
+ }
+
+EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetWeightsPrepack,
+ dnnl::matmul::primitive_desc,
+ OneDnnMatMulConfig, onednn_matmul_config,
+ optimization_config, weights_prepacked);
+EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch,
+ dnnl::matmul::primitive_desc,
+ OneDnnMatMulConfig, onednn_matmul_config,
+ optimization_config, user_scratchpad);
+
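+// Runs the contraction rewrite visitor followed by the post-rewrite visitor
+// and reports whether either of them changed the module.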
+absl::StatusOr<bool> OneDnnContractionRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ OneDnnContractionRewriteVisitor visitor;
+ TF_ASSIGN_OR_RETURN(auto result,
+ visitor.RunOnModule(module, execution_threads));
+
+ OneDnnPostRewriteVisitor reorder_visitor(intra_op_parallelism_,
+ compile_threadpool_);
+ TF_ASSIGN_OR_RETURN(auto result2,
+ reorder_visitor.RunOnModule(module, execution_threads));
+
+ return {result || result2};
+}
+
+} // namespace cpu
+} // namespace xla
+
+#endif // INTEL_MKL && ENABLE_ONEDNN_V3
diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h
new file mode 100644
index 0000000..7864dae
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h
@@ -0,0 +1,63 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_
+#define XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_
+#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
+
+#include <optional>
+
+#include "absl/algorithm/container.h"
+#include "unsupported/Eigen/CXX11/Tensor"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace cpu {
+
+// This pass pattern-matches HLO Dot and Convolution instructions and rewrites
+// them into oneDNN custom calls.
+class OneDnnContractionRewriter : public HloModulePass {
+ public:
+ OneDnnContractionRewriter(int intra_op_parallelism,
+ const tsl::thread::ThreadPool* compile_threadpool)
+ : intra_op_parallelism_(intra_op_parallelism),
+ compile_threadpool_(compile_threadpool) {}
+ OneDnnContractionRewriter() = default;
+ absl::string_view name() const override {
+ return "onednn-contraction-rewriter";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ static bool ShouldRewriteDot(const HloInstruction* dot_instr,
+ bool before_layout_assignment = false);
+ static bool ShouldRewriteConv(const HloInstruction* conv_instr);
+
+ private:
+ int intra_op_parallelism_;
+ const tsl::thread::ThreadPool* compile_threadpool_;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // INTEL_MKL && ENABLE_ONEDNN_V3
+#endif // XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_
diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc
deleted file mode 100644
index 0c65c5d..0000000
--- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc
+++ /dev/null
@@ -1,143 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/cpu/backend_config.pb.h"
-#include "xla/service/cpu/onednn_config.pb.h"
-#include "xla/service/cpu/onednn_memory_util.h"
-#include "xla/service/cpu/onednn_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/status_macros.h"
-
-namespace xla {
-namespace cpu {
-
-namespace {
-namespace m = match;
-} // namespace
-
-bool OneDnnConvolutionRewriter::ShouldRewrite(const HloInstruction* conv) {
- if (conv->HasControlDependencies()) return false;
- if (!IsSupportedType(conv->shape().element_type())) return false;
- if (conv->batch_group_count() != 1) return false;
-
- if (conv->operand(1)->opcode() == HloOpcode::kReverse) return false;
-
- const Shape& inp_shape = conv->operand(0)->shape();
- const Shape& ker_shape = conv->operand(1)->shape();
- const Shape& out_shape = conv->shape();
- if (ShapeUtil::IsZeroElementArray(inp_shape) ||
- ShapeUtil::IsZeroElementArray(ker_shape) ||
- ShapeUtil::IsZeroElementArray(out_shape)) {
- return false;
- }
-
- auto dims = conv->window().dimensions().size();
- if (dims >= 4 || dims <= 0) return false;
-
- if (inp_shape.rank() != ker_shape.rank() ||
- inp_shape.rank() != out_shape.rank()) {
- return false;
- }
-
- return true;
-}
-
-class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleConvolution(HloInstruction* conv) override {
- auto pattern = match::Op(&conv).WithOpcode(HloOpcode::kConvolution);
- if (!Match(conv, pattern)) return absl::OkStatus();
- if (!OneDnnConvolutionRewriter::ShouldRewrite(conv)) {
- return absl::OkStatus();
- }
-
- const Shape& conv_shape = conv->shape();
- auto dims = conv->window().dimensions().size();
- const ConvolutionDimensionNumbers& conv_ddata =
- conv->convolution_dimension_numbers();
-
- BackendConfig backend_config;
- OneDnnConvolutionConfig* conv_config =
- backend_config.mutable_onednn_conv_config();
-
- conv_config->set_dims(conv_shape.rank());
- conv_config->set_feature_groups(conv->feature_group_count());
- conv_config->mutable_input()->mutable_data()->set_batch_dim(
- conv_ddata.input_batch_dimension());
- conv_config->mutable_kernel()->mutable_filter()->set_input_feature_dim(
- conv_ddata.kernel_input_feature_dimension());
- conv_config->mutable_output()->mutable_data()->set_batch_dim(
- conv_ddata.output_batch_dimension());
- conv_config->mutable_input()->mutable_data()->set_feature_dim(
- conv_ddata.input_feature_dimension());
- conv_config->mutable_kernel()->mutable_filter()->set_output_feature_dim(
- conv_ddata.kernel_output_feature_dimension());
- conv_config->mutable_output()->mutable_data()->set_feature_dim(
- conv_ddata.output_feature_dimension());
-
- const Shape& output_shape = conv->shape();
-
- for (auto it = conv->window().dimensions().begin();
- it != conv->window().dimensions().end(); it++) {
- if ((*it).padding_low() < 0 || (*it).padding_high() < 0 ||
- (*it).stride() < 0) {
- return absl::OkStatus();
- }
- conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1);
- conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1);
- conv_config->mutable_window()->add_strides((*it).stride() + 1);
- conv_config->mutable_window()->add_window_dilations(
- (*it).window_dilation() + 1);
- if ((*it).base_dilation() != 1 || (*it).window_reversal()) {
- return absl::OkStatus();
- }
- }
-
- for (int i = 0; i < dims; i++) {
- conv_config->mutable_input()->mutable_data()->add_spatial_dims(
- conv_ddata.input_spatial_dimensions()[i] + 1);
- conv_config->mutable_kernel()->mutable_filter()->add_spatial_dims(
- conv_ddata.kernel_spatial_dimensions()[i] + 1);
- conv_config->mutable_output()->mutable_data()->add_spatial_dims(
- conv_ddata.output_spatial_dimensions()[i] + 1);
- }
-
- HloInstruction* custom_call =
- conv->AddInstruction(HloInstruction::CreateCustomCall(
- output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)},
- "__onednn$convolution"));
-
- TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
- TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call));
- return absl::OkStatus();
- }
-};
-
-absl::StatusOr<bool> OneDnnConvolutionRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- OneDnnConvolutionRewriterVisitor visitor;
- return visitor.RunOnModule(module, execution_threads);
-}
-
-} // namespace cpu
-} // namespace xla
-
-#endif // INTEL_MKL && ENABLE_ONEDNN_V3
diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h
deleted file mode 100644
index 2dbd3a6..0000000
--- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_
-#define XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#include <optional>
-
-#include "absl/algorithm/container.h"
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace cpu {
-
-// This pass converts hlo convolution instructions into a single oneDNN
-// operation and rewrites into custom calls.
-class OneDnnConvolutionRewriter : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "onednn-convolution-rewriter";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- static bool ShouldRewrite(const HloInstruction* instr);
-};
-
-} // namespace cpu
-} // namespace xla
-
-#endif // INTEL_MKL && ENABLE_ONEDNN_V3
-#endif // XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_
diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
deleted file mode 100644
index 45c6bc1..0000000
--- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
+++ /dev/null
@@ -1,1129 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#define EIGEN_USE_THREADS
-
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
-
-#include "xla/executable_run_options.h"
-#include "xla/hlo/evaluator/hlo_evaluator.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/cpu/backend_config.pb.h"
-#include "xla/service/cpu/onednn_config.pb.h"
-#include "xla/service/cpu/onednn_matmul.h"
-#include "xla/service/cpu/onednn_memory_util.h"
-#include "xla/service/cpu/onednn_pattern_utils.h"
-#include "xla/service/cpu/onednn_util.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/status_macros.h"
-#include "xla/tsl/util/onednn_threadpool.h"
-#include "tsl/platform/logging.h" // IWYU pragma: keep
-
-namespace xla {
-namespace cpu {
-
-namespace {
-namespace m = match;
-namespace pu = ::xla::cpu::onednn_pattern_utils_internal;
-
-inline absl::Status ValidateDotDimensionNumbers(
- const DotDimensionNumbers& dim_numbers) {
- // Checks some invariants that do not hold in general, but DotDecomposer
- // should have established for us.
- TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
- std::vector<int64_t> batch_dim_numbers(
- dim_numbers.lhs_batch_dimensions_size());
- absl::c_iota(batch_dim_numbers, 0);
- TF_RET_CHECK(
- absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
- TF_RET_CHECK(
- absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
- return absl::OkStatus();
-}
-
-// Whether the element type of instr is compatible with oneDNN kernels.
-// TODO(intel-tf): Restict compatible types based on instruction kind.
-inline bool CompatibleElementType(const HloInstruction* instr) {
- PrimitiveType element_type = instr->shape().element_type();
- return element_type == BF16 || element_type == F32 || element_type == F16;
-}
-
-inline bool IsRowMajor(const Shape& shape) {
- return LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
-}
-
-template <typename Pattern>
-inline auto BitcastWithReshapeSemantics(HloInstruction** bitcast,
- Pattern pattern) {
- // TODO(intel-tf): Add stronger condition that Bitcast does not have transpose
- // semantics. Some of the HLO passes replaces Transpose with Bitcast. Here
- // the layouts are checked to be rowmajor since the current pass runs after
- // the layout assignment and oneDNN matmul is enabled for rowmajor layouts.
- auto is_reshape = [](const HloInstruction* instr) -> bool {
- if (!instr) return false;
- auto input_shape = instr->operand(0)->shape();
- auto output_shape = instr->shape();
- bool is_same_type = ShapeUtil::SameElementType(input_shape, output_shape);
- bool has_equal_num_elems = ShapeUtil::ElementsIn(input_shape) ==
- ShapeUtil::ElementsIn(output_shape);
- bool has_rowmajor_layout =
- IsRowMajor(input_shape) && IsRowMajor(output_shape);
- return is_same_type && has_equal_num_elems && has_rowmajor_layout;
- };
- return m::Bitcast(bitcast, pattern).WithPredicate(is_reshape);
-}
-
-template <typename Pattern>
-auto ElementwiseSafeIntermediates(HloInstruction** instr,
- HloInstruction** optional_bitcast,
- Pattern pattern) {
- return m::AnyOf<HloInstruction>(
- m::Broadcast(instr, pattern.WithOneUser()),
- m::Slice(instr, pattern.WithOneUser()),
- m::Bitcast(instr, pattern.WithOneUser()),
- m::Reshape(instr, pattern.WithOneUser()),
- pu::SupportedConvert(instr, pattern.WithOneUser()),
- pu::SupportedConvert(instr, BitcastWithReshapeSemantics(
- optional_bitcast, pattern.WithOneUser())),
- pattern);
-}
-
-inline auto OneDnnMatmulInstr(HloInstruction** instr) {
- return m::CustomCall(instr, {"__onednn$matmul"});
-}
-
-inline auto ConvertBF16ToF32(HloInstruction** instr) {
- return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16))
- .WithElementType(PrimitiveType::F32);
-}
-
-inline auto BcastConstScalar(HloInstruction** instr, double value) {
- return m::Broadcast(instr, m::ConstantScalar(value));
-}
-
-inline auto BcastConstScalar(double value) {
- return BcastConstScalar(nullptr, value);
-}
-
-inline auto BcastConvertConstScalar(double value) {
- return m::Broadcast(pu::OptionalConvert(m::ConstantScalar(value)));
-}
-
-inline bool IsBatchDot(const HloInstruction& instr) {
- if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
- return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
- }
- return false;
-}
-
-auto ConstScalarNear(double value) {
- return m::ConstantScalar().WithPredicate(
- [expected = value](const HloInstruction* instr) {
- // Not a very robust floating-point comparison, but good enough for our
- // purposes.
- std::optional<double> actual =
- static_cast<const HloConstantInstruction*>(instr)
- ->literal()
- .GetAsDouble({});
- if (!actual.has_value()) return false;
- double epsilon;
- switch (instr->shape().element_type()) {
- case F16:
- epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
- break;
- case BF16:
- epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
- break;
- case F32:
- epsilon = 128 * std::numeric_limits<float>::epsilon();
- break;
- case F64:
- epsilon = 128 * std::numeric_limits<double>::epsilon();
- break;
- default:
- return false;
- }
- return abs(*actual - expected) < (abs(*actual + expected) * epsilon);
- });
-}
-
-bool IsScalar(const HloInstruction* instr) {
- return ShapeUtil::IsEffectiveScalar(instr->shape());
-}
-
-std::optional<float> GetConstantValueAsFloat32(const HloInstruction* inst) {
- if (!IsScalar(inst)) {
- return std::nullopt;
- }
- switch (inst->shape().element_type()) {
- case F16:
- return inst->literal().GetFirstElement<half>();
- case BF16:
- return inst->literal().GetFirstElement<bfloat16>();
- case F32:
- return inst->literal().GetFirstElement<float>();
- default:
- return std::nullopt;
- }
-}
-
-inline auto BcastConstScalarNear(double value) {
- return m::Broadcast(ConstScalarNear(value));
-}
-
-// Associativity and commutativity properties of multiply results in various
-// patterns for an equivalent computation. This function tries to capture most
-// of the variations for a computation a * b * c. For example, patterns could be
-// any of (a * b) * c or a * (b * c), along with the variations resulting from
-// commutative patterns.
-template <typename PatternA, typename PatternB, typename PatternC>
-inline auto MultiplyMultiplyAnyOrder(PatternA a, PatternB b, PatternC c) {
- return m::AnyOf<HloInstruction>(
- m::MultiplyAnyOrder(a, m::MultiplyAnyOrder(b, c)),
- m::MultiplyAnyOrder(b, m::MultiplyAnyOrder(a, c)),
- m::MultiplyAnyOrder(c, m::MultiplyAnyOrder(a, b)));
-}
-
-auto GELUActivation(HloInstruction* instr, HloInstruction** src) {
- // Attempt to match GELU_TANH activation or GELU_ERF activation
- // (https://arxiv.org/abs/1606.08415), where:
- // gelu_tanh(x) = x * cdf(x)
- // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))
- // -------------errf_approximate------------
- //
- // gelu_erf(x) = x * cdf(x)
- // cdf(x) = 0.5 * (1 + erf(x / sqrt(2)))
- // --errf_exact--
-
- HloInstruction* errf;
-
- // The expression 0.5 * x * (1 + errf) as common pattern for GELU exact and
- // approximate activations.
- auto common_pattern = MultiplyMultiplyAnyOrder(
- BcastConstScalar(0.5), m::Op(src),
- m::AddAnyOrder(BcastConstScalar(1.0), m::Op(&errf).WithOneUser()));
-
- bool matched = Match(instr, common_pattern);
- if (matched) {
- // The subexpression 0.044715 * x**3 appears in GELU approximate activation.
- // However, it is often optimized by other HLO passes into an expression of
- // 0.044715 * x * (x * x). Since there are three consecutive multiplies,
- // there could be a large number of patterns. We try to capture some of
- // those:
- //
- // 1. (0.044715 * x) * x * x
- // 2. 0.044715 * (x * x) * x
- //
- // Note each of the above could in turn have various patterns due to
- // associativity and commutativity properties of multiply.
- auto subexpr_pattern = m::AnyOf<HloInstruction>(
- MultiplyMultiplyAnyOrder(
- m::MultiplyAnyOrder(BcastConstScalarNear(0.044715),
- m::Op().Is(*src)),
- m::Op().Is(*src), m::Op().Is(*src)),
- MultiplyMultiplyAnyOrder(
- BcastConstScalarNear(0.044715),
- m::Multiply(m::Op().Is(*src), m::Op().Is(*src)), m::Op().Is(*src)));
-
- auto errf_apprx_pattern =
- m::Tanh(m::MultiplyAnyOrder(
- BcastConstScalarNear(sqrt(M_2_PI)),
- m::AddAnyOrder(m::Op().Is(*src), subexpr_pattern)
- .WithOneUser()))
- .WithOneUser();
-
- HloInstruction* erf;
- auto errf_exact_pattern =
- m::Op(&erf)
- .WithOpcode(HloOpcode::kErf)
- .WithOperand(
- 0, m::MultiplyAnyOrder(m::Op(src),
- m::AnyOf<HloInstruction>(
- BcastConstScalarNear(0.707106769),
- BcastConstScalarNear(0.70703125),
- BcastConstScalarNear(0.707182348)))
- .WithOneUser())
- .WithOneUser();
-
- if (Match(errf, errf_apprx_pattern)) {
- // Matched Gelu-approximate pattern
- return OneDnnFusionConfig::GELU_TANH;
- } else if (Match(errf, errf_exact_pattern)) {
- // Matched Gelu-exact pattern
- return OneDnnFusionConfig::GELU_ERF;
- }
- }
- return OneDnnFusionConfig::UNDEFINED;
-}
-
-// OneDNN matmul can fuse add operation with automatic broadcasting along the
-// addend's dimensions that are 1s. When compatible, Broadcast can be replaced
-// by Bitcast, which is much cheaper. Compute new shape for the Bitcast.
-absl::StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr,
- const Shape& dot_shape) {
- if (broadcast_instr->opcode() != HloOpcode::kBroadcast) {
- return absl::InvalidArgumentError(
- "Hlo instruction is not a Broadcast insruction.");
- }
- auto bcast = Cast<HloBroadcastInstruction>(broadcast_instr);
- Shape new_shape = bcast->shape();
- // Broadcast instruction has "dimensions" parameter along which its input's
- // dimensions should not change. For example,
- // dot = f32[3,4,5,6] dot(...)
- // arg = f32[3,6]{1,0} parameter(0)
- // broad = f32[3,4,5,6]{3,2,1,0} broadcast(arg), dimensions={0,3}
- // add = f32[3,4,5,6]{3,2,1,0} add(dot, arg)
- // can be replaced with the following
- // arg = f32[3,6]{1,0} parameter(0)
- // bitcast = f32[3,1,1,6]{3,2,1,0} bitcast(arg)
- // fused = f32[3,4,5,6]{3,2,1,0} custom-call((..., bitcast)
- auto kept_dimensions = bcast->dimensions();
- for (int i = 0; i < new_shape.rank(); i++) {
- if (!absl::c_linear_search(kept_dimensions, i)) {
- new_shape.set_dimensions(i, 1);
- }
- }
-
- // If rank(new_shape) > rank(dot), extra dimensions with value = 1 can be
- // deleted from the new_shape.
- int64_t rank_difference = new_shape.rank() - dot_shape.rank();
- auto new_dims = new_shape.dimensions();
- std::vector<int64_t> dims_to_delete;
- for (int i = 0; i < rank_difference; ++i) {
- if (new_dims[i] == 1) {
- dims_to_delete.push_back(i);
- }
- }
- new_shape = ShapeUtil::DeleteDimensions(dims_to_delete, new_shape);
-
- // New shape for bias should satisfy the condition:
- // rank(new_shape) <= rank(dot).
- if (new_shape.rank() > dot_shape.rank()) {
- return absl::CancelledError(
- "Bias shape could not be adjusted for a fusion.");
- }
-
- return new_shape;
-};
-
-inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) {
- // Check if the operand's shape is compatible with matmul for fusion.
- // An operand is fusable if
- // 1. rank(operand) <= rank(dot) and
- // 2. Starting from the last dim in backward direction, the dimension
- // size of operand is either 1 or same to dot.
- auto operand_dims = operand->shape().dimensions();
- auto dot_dims = dot->shape().dimensions();
- if (operand_dims.size() > dot_dims.size()) return false;
- int operand_idx = operand_dims.size() - 1;
- int dot_idx = dot_dims.size() - 1;
- for (; operand_idx >= 0; --operand_idx, --dot_idx) {
- if (operand_dims[operand_idx] != 1 &&
- operand_dims[operand_idx] != dot_dims[dot_idx])
- return false;
- }
- return true;
-}
-
-template <typename Pattern>
-inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert,
- HloInstruction** optional_bitcast,
- Pattern pattern) {
- // Checks the presence of some intermediate operations that can be moved /
- // folded to allow dot fusion with add.
- // Try to match either of the following:
- // 1. pattern-root -> bf16/f16-to-fp32 convert -> bitcast
- // 2. pattern-root -> bf16/f16-to-fp32 convert
- // 3. pattern-root -> bitcast
- // 4. pattern-root
- auto common = m::AnyOf<HloInstruction>(
- pu::SupportedConvert(optional_convert, std::move(pattern).WithOneUser())
- .WithElementType(PrimitiveType::F32),
- std::move(pattern).WithOneUser());
- return m::AnyOf<HloInstruction>(
- BitcastWithReshapeSemantics(optional_bitcast, common), common);
-}
-
-} // namespace
-
-bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) {
- // Currently, blocking control dependencies
- if (dot_instr->HasControlDependencies()) return false;
- if (!IsSupportedType(dot_instr->shape().element_type())) return false;
- if (dot_instr->operands().size() != 2) return false;
-
- // Currently, we rewrite when the data type is F32 or BF16. Note we do not
- // need to check equality of contraction dim-size of the operands. HLO
- // verifier already does the job. We, however, need to check if contraction
- // is over only 1 dimension (a.k.a. K dimension in matrix-multiplication
- // parlance). We also restrict that batch dimensions of the operands
- // match.
- const Shape& lhs_shape = dot_instr->operand(0)->shape();
- const Shape& rhs_shape = dot_instr->operand(1)->shape();
- const Shape& output_shape = dot_instr->shape();
- // None of the operands and result should be ZeroElementArray.
- if (ShapeUtil::IsZeroElementArray(lhs_shape) ||
- ShapeUtil::IsZeroElementArray(rhs_shape) ||
- ShapeUtil::IsZeroElementArray(output_shape)) {
- return false;
- }
- // OneDNN only supports rank <= kOneDnnMaxNDims and singular non-contracting
- // dimensions. We should not rewrite if any of these conditions are violated.
- if (lhs_shape.rank() <= 0 || lhs_shape.rank() > kOneDnnMaxNDims ||
- rhs_shape.rank() <= 0 || rhs_shape.rank() > kOneDnnMaxNDims ||
- output_shape.rank() > std::min({lhs_shape.rank(), rhs_shape.rank(),
- static_cast<int64_t>(kOneDnnMaxNDims)})) {
- return false;
- }
-
- // Layout should be row-major, contraction dimensions captures transpose
- // scenarios in last two dimensions.
- // Col-major layouts are corrected to row-majow for BatchDot operation as
- // part of the layout-pass.
- if (!IsBatchDot(*dot_instr) &&
- (!IsRowMajor(lhs_shape) || !IsRowMajor(rhs_shape) ||
- !IsRowMajor(output_shape))) {
- return false;
- }
-
- auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
- int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
- int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
- // Supported contraction is only in one of last two dimensions.
- if (lhs_dim_k < lhs_shape.rank() - 2 || rhs_dim_k < rhs_shape.rank() - 2) {
- return false;
- }
-
- // OneDNN matmul has scratch allocation and copy overheads. The overheads
- // can be amortized if there is sufficient number of flops. We don't rewrite
- // for small cases (determined empirically).
- // TODO(intel-tf): Relax the condition when more optimizations in oneDNN
- // matmul is achieved.
- auto num_flops = xla::HloCostAnalysis::GetDotFlops(lhs_shape, output_shape,
- dot_dim_numbers);
- auto rank = output_shape.rank();
- auto flops_threshold = (rank <= 2) ? (1 << 24) : (1 << 19);
- return (num_flops >= flops_threshold);
-}
-
-class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
- public:
- // Matches patterns for possible MatMul fusions that are supported by oneDNN
- // library. Matched HLO instruction(s) are replaced by custom call.
- absl::Status HandleDot(HloInstruction* instr) override {
- HloInstruction* dot_instr;
- auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot);
- if (!Match(instr, pattern)) return absl::OkStatus();
-
- TF_RETURN_IF_ERROR(
- ValidateDotDimensionNumbers(dot_instr->dot_dimension_numbers()));
- if (!OneDnnMatMulRewriter::ShouldRewrite(dot_instr))
- return absl::OkStatus();
- TF_ASSIGN_OR_RETURN(dot_instr, ReconfigureDotDimensions(dot_instr));
- auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
- const Shape& lhs_shape = dot_instr->operand(0)->shape();
- const Shape& rhs_shape = dot_instr->operand(1)->shape();
- const Shape& output_shape = dot_instr->shape();
-
- int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
- int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
-
- HloInstruction* matmul_call =
- dot_instr->AddInstruction(HloInstruction::CreateCustomCall(
- output_shape,
- {dot_instr->mutable_operand(0), dot_instr->mutable_operand(1)},
- "__onednn$matmul"));
- // Set additional info via config, e.g., transpose and fusion info.
- BackendConfig backend_config;
- OneDnnMatMulConfig* matmul_config =
- backend_config.mutable_onednn_matmul_config();
- bool transpose_a = (lhs_dim_k != lhs_shape.rank() - 1);
- bool transpose_b = (rhs_dim_k != rhs_shape.rank() - 2);
- matmul_config->set_transpose_a(transpose_a);
- matmul_config->set_transpose_b(transpose_b);
- TF_RETURN_IF_ERROR(matmul_call->set_backend_config(backend_config));
- TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, matmul_call));
- return absl::OkStatus();
- }
-
- absl::Status HandleAdd(HloInstruction* instr) override {
- // Try to do a fusion for Dot(onednn-matmul) + Add. However,
- // HLO Add instruction might receive the addends after additional
- // processing like Broadcast, Bitcast, Convert, etc. is applied to the raw
- // addends. Here, the following possible pattern is matched.
- //
- // clang-format off
- //
- // Dot addend
- // | |
- // v v
- // optional instructions optional instructions
- // (e.g, Convert, Bitcast) (e.g, Convert, Broadcast)
- // | |
- // +--------------+-------------------+
- // |
- // v
- // Add
- //
- // clang-format on
-
- HloInstruction *addend_intermediate, *dot;
- HloInstruction* optional_dot_bitcast = nullptr;
- HloInstruction* optional_dot_convert = nullptr;
-
- auto pattern = m::AddAnyOrder(
- &instr,
- OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast,
- OneDnnMatmulInstr(&dot))
- .WithOneUser(),
- m::Op(&addend_intermediate));
-
- if (Match(instr, pattern)) {
- if (!IsSupportedType(dot->shape().element_type()))
- return absl::OkStatus();
- // TODO(intel-tf): Remove the condition below when the fusion Dot +
- // Add(bias) + Add(e.g., residual) is enabled.
- if (!dot->backend_config<BackendConfig>()
- ->mutable_onednn_matmul_config()
- ->mutable_fusions()
- ->ops()
- .empty() &&
- dot->backend_config<BackendConfig>()
- ->mutable_onednn_matmul_config()
- ->mutable_fusions()
- ->ops(0) == OneDnnFusionConfig::BIAS) {
- return absl::OkStatus();
- }
- std::vector<HloInstruction*> new_operands;
- for (auto operand : dot->operands()) {
- new_operands.push_back(operand);
- }
-
- // At this point, the addend could have one of the following
- // possiblities that the current fusion can handle:
- //
- // - addend -> Convert -> Broadcast -> Add
- // - addend -> Broadcast -> Convert -> Add
- // - addend -> Convert
- // - addend -> Broadcast
- // - addend
- //
- // Hunt for addend through possible sequences above and check the addend
- // is compatible to onednn-matmul fusion.
- HloInstruction* addend = nullptr;
- HloInstruction* optional_addend_broadcast = nullptr;
- auto addend_pattern = m::AnyOf<HloInstruction>(
- m::Broadcast(&optional_addend_broadcast,
- m::Convert(&addend, m::Op())),
- m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
- m::Convert(&addend, m::Op()),
- m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
- m::Op(&addend));
- if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();
-
- if (optional_addend_broadcast && addend->shape().rank() != 1) {
- auto new_shape =
- AdjustBiasShape(optional_addend_broadcast, dot->shape());
- if (new_shape.ok()) {
- addend = addend->AddInstruction(
- HloInstruction::CreateBitcast(new_shape.value(), addend));
- } else {
- VLOG(2) << new_shape.status();
- return absl::OkStatus();
- }
- }
-
- // Validate addend for fusion.
- if (IsSupportedType(addend->shape().element_type()) &&
- IsOperandFusible(addend, dot)) {
- new_operands.push_back(addend);
- } else {
- return absl::OkStatus();
- }
-
- // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
- // implementation for broadcasted add across all dimensions.
- OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
- kind = (addend->shape().rank() == 1)
- ? (dot->backend_config<BackendConfig>()
- ->mutable_onednn_matmul_config()
- ->fusions()
- .ops()
- .empty()
- ? OneDnnFusionConfig::BIAS
- : OneDnnFusionConfig::UNDEFINED)
- : OneDnnFusionConfig::BINARY_ADD;
- if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
-
- auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
- dot->CloneWithNewOperands(dot->shape(), new_operands)));
-
- auto backend_config = matmul_call->backend_config<BackendConfig>();
- backend_config->mutable_onednn_matmul_config()
- ->mutable_fusions()
- ->add_ops(kind);
-
- if (optional_addend_broadcast) {
- backend_config->mutable_onednn_matmul_config()
- ->mutable_optimization_config()
- ->set_bias_broadcast(true);
- }
- TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
-
- HloInstruction* new_instr;
- // If the matched pattern has custom-call -> bitcast -> add, then we need
- // to insert a bitcast after the new fusion to maintain the correct shape
- // (new-custom-call -> bitcast). For the bf16 case this is optionally
- // followed by a convert to avoid a datatype mismatch.
- if (optional_dot_bitcast != nullptr &&
- optional_dot_bitcast->opcode() == HloOpcode::kBitcast) {
- if (optional_dot_convert != nullptr &&
- optional_dot_convert->opcode() == HloOpcode::kConvert) {
- auto bitcast_call =
- matmul_call->AddInstruction(HloInstruction::CreateBitcast(
- ShapeUtil::ChangeElementType(
- instr->shape(), matmul_call->shape().element_type()),
- matmul_call));
- new_instr =
- bitcast_call->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::ChangeElementType(
- bitcast_call->shape(),
- optional_dot_convert->shape().element_type()),
- bitcast_call));
- } else {
- new_instr = matmul_call->AddInstruction(
- HloInstruction::CreateBitcast(instr->shape(), matmul_call));
- }
- } else {
- if (optional_dot_convert != nullptr &&
- optional_dot_convert->opcode() == HloOpcode::kConvert) {
- new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::ChangeElementType(
- matmul_call->shape(),
- optional_dot_convert->shape().element_type()),
- matmul_call));
- } else {
- new_instr = matmul_call;
- }
- }
- TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
- }
-
- return absl::OkStatus();
- }
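The BIAS/BINARY_ADD decision above (a 1-D addend becomes BIAS only when no fusion is attached yet; anything else becomes BINARY_ADD) can be summarized in a small, self-contained sketch. `FusionKind` and `ChooseAddFusionKind` below are illustrative names, not part of the XLA sources.

```cpp
// Minimal sketch of the fusion-kind decision in HandleAdd above.
enum class FusionKind { kUndefined, kBias, kBinaryAdd };

FusionKind ChooseAddFusionKind(int addend_rank, bool has_existing_fusions) {
  if (addend_rank == 1) {
    // A 1-D addend is fused as BIAS, but only as the first fused op.
    return has_existing_fusions ? FusionKind::kUndefined : FusionKind::kBias;
  }
  // Higher-rank addends fall back to a general binary add fusion.
  return FusionKind::kBinaryAdd;
}
```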
-
- absl::Status HandleMaximum(HloInstruction* instr) override {
- HloInstruction* matmul_call;
- HloInstruction* intermediate_instr = nullptr;
- HloInstruction* optional_bitcast = nullptr;
- // Attempt to elide maximum and fuse ReLU activation into GEMM, including
- // when slicing or bitcasting is applied to the result.
- if (Match(instr,
- m::MaximumAnyOrder(ElementwiseSafeIntermediates(
- &intermediate_instr, &optional_bitcast,
- OneDnnMatmulInstr(&matmul_call))
- .WithOneUser(),
- BcastConstScalar(0)))) {
- return FuseActivation(OneDnnFusionConfig::RELU, instr, matmul_call,
- intermediate_instr, optional_bitcast);
- }
- return absl::OkStatus();
- }
-
- auto ELUActivation(HloInstruction* instr, HloInstruction** src) {
- // Reference: tensorflow/compiler/tf2xla/kernels/elu_op.cc
- // const auto zero = ScalarLike(x, 0);
- // const auto pred = Gt(x, zero);
- // const auto expm1 = Expm1(x);
- // return Select(pred, x, expm1);
- auto pattern = m::Select(
- m::Gt(pu::OptionalConvert(m::Op(src)), BcastConvertConstScalar(0)),
- m::Op(src),
- pu::OptionalConvert(m::Expm1(pu::OptionalConvert(m::Op(src)))));
- return Match(instr, pattern);
- }
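As a reference for the decomposition being matched, here is a scalar sketch of ELU (not XLA code; the helper name is illustrative).

```cpp
#include <cmath>

// Scalar reference for the ELU pattern matched above:
// Select(Gt(x, 0), x, Expm1(x)).
float EluReference(float x) { return x > 0.0f ? x : std::expm1(x); }
```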
-
- absl::Status HandleSelect(HloInstruction* instr) override {
- HloInstruction* matmul_call;
- HloInstruction* intermediate_instr = nullptr;
- HloInstruction* optional_bitcast = nullptr;
- HloInstruction* src;
- // Attempt to elide ELU subgraph and fuse ELU activation into GEMM,
- // including when slicing or bitcasting is applied to the result.
- if (ELUActivation(instr, &src)) {
- if (Match(src, ElementwiseSafeIntermediates(
- &intermediate_instr, &optional_bitcast,
- OneDnnMatmulInstr(&matmul_call)))) {
- return FuseActivation(OneDnnFusionConfig::ELU, instr, matmul_call,
- intermediate_instr);
- }
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleTanh(HloInstruction* instr) override {
- HloInstruction* matmul_call;
- HloInstruction* intermediate_instr = nullptr;
- HloInstruction* optional_bitcast = nullptr;
- // Attempt to elide Tanh and fuse Tanh activation into GEMM, including
- // when slicing or bitcasting is applied to the result.
- if (Match(instr, m::Tanh(ElementwiseSafeIntermediates(
- &intermediate_instr, &optional_bitcast,
- OneDnnMatmulInstr(&matmul_call))
- .WithOneUser()))) {
- return FuseActivation(OneDnnFusionConfig::TANH, instr, matmul_call,
- intermediate_instr);
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleClamp(HloInstruction* instr) override {
- HloInstruction* matmul_call;
- HloInstruction* intermediate_instr = nullptr;
- HloInstruction* optional_bitcast = nullptr;
- // Attempt to elide RELU6 and fuse RELU6 activation into GEMM, including
- // when slicing or bitcasting is applied to the result.
- if (Match(instr, m::Clamp(BcastConstScalar(0),
- ElementwiseSafeIntermediates(
- &intermediate_instr, &optional_bitcast,
- OneDnnMatmulInstr(&matmul_call))
- .WithOneUser(),
- BcastConstScalar(6)))) {
- return FuseActivation(OneDnnFusionConfig::RELU6, instr, matmul_call,
- intermediate_instr);
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleMultiply(HloInstruction* instr) override {
- HloInstruction* matmul_call;
- HloInstruction* intermediate_instr = nullptr;
- HloInstruction* src;
- auto activation = GELUActivation(instr, &src);
- if (activation != OneDnnFusionConfig::UNDEFINED) {
- HloInstruction* optional_bitcast = nullptr;
- if (Match(src, ElementwiseSafeIntermediates(
- &intermediate_instr, &optional_bitcast,
- OneDnnMatmulInstr(&matmul_call)))) {
- return FuseActivation(activation, instr, matmul_call,
- intermediate_instr, optional_bitcast);
- }
- }
-
- HloInstruction *dot, *constant;
- HloInstruction* optional_convert = nullptr;
- auto pattern = m::Op(&instr)
- .WithOpcode(HloOpcode::kMultiply)
- .WithBinaryOperandsAnyOrder(
- m::AnyOf<HloInstruction>(
- pu::SupportedConvert(&optional_convert,
- OneDnnMatmulInstr(&dot))
- .WithElementType(PrimitiveType::F32),
- OneDnnMatmulInstr(&dot))
- .WithOneUser(),
- m::Broadcast(m::Constant(&constant)));
-
- if (Match(instr, pattern)) {
- std::vector<HloInstruction*> new_operands;
- auto constant_value = GetConstantValueAsFloat32(constant);
- if (!constant_value) {
- return absl::OkStatus();
- }
-
- for (auto operand : dot->operands()) {
- new_operands.push_back(operand);
- }
- auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
- dot->CloneWithNewOperands(instr->shape(), new_operands)));
- auto backend_config = matmul_call->backend_config<BackendConfig>();
- backend_config->mutable_onednn_matmul_config()
- ->mutable_fusions()
- ->add_ops(OneDnnFusionConfig::LINEAR);
- // Cast to int32 because the proto config has issues handling decimal
- // types.
- backend_config->mutable_onednn_matmul_config()
- ->mutable_fusions()
- ->set_alpha_typecast(
- *(reinterpret_cast<int32_t*>(&constant_value.value())));
- TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
- HloInstruction* new_instr;
- if (optional_convert != nullptr &&
- optional_convert->opcode() == HloOpcode::kConvert) {
- new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::ChangeElementType(
- matmul_call->shape(), optional_convert->shape().element_type()),
- matmul_call));
- } else {
- new_instr = matmul_call;
- }
-
- TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
- }
- return absl::OkStatus();
- }
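The `alpha_typecast` step above stores the float scale as its raw int32 bit pattern in the proto. A self-contained sketch of that round trip follows; the helper names are hypothetical, and `std::memcpy` is used instead of `reinterpret_cast` to keep the bit copy well-defined.

```cpp
#include <cstdint>
#include <cstring>

// Round-trip a float through its raw int32 bits, as done for alpha_typecast.
int32_t FloatToBits(float value) {
  int32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

float BitsToFloat(int32_t bits) {
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}
// Example: BitsToFloat(FloatToBits(0.125f)) == 0.125f
```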
-
- auto SigmoidActivation(HloInstruction* instr, HloInstruction** src) {
- return Match(instr,
- m::Divide(BcastConstScalar(1.0),
- m::AddAnyOrder(BcastConstScalar(1.0),
- m::Exp(m::Negate(m::Op(src))))));
- }
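And a scalar reference for the sigmoid pattern, 1 / (1 + exp(-x)); again only a sketch, not XLA code.

```cpp
#include <cmath>

// Scalar reference for the sigmoid pattern matched by SigmoidActivation.
float SigmoidReference(float x) { return 1.0f / (1.0f + std::exp(-x)); }
```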
-
- absl::Status HandleDivide(HloInstruction* instr) override {
- HloInstruction* matmul_call;
- HloInstruction* intermediate_instr = nullptr;
- HloInstruction* optional_bitcast = nullptr;
- HloInstruction* src;
- if (SigmoidActivation(instr, &src)) {
- if (Match(src, ElementwiseSafeIntermediates(
- &intermediate_instr, &optional_bitcast,
- OneDnnMatmulInstr(&matmul_call))
- .WithOneUser())) {
- return FuseActivation(OneDnnFusionConfig::SIGMOID, instr, matmul_call,
- intermediate_instr, optional_bitcast);
- }
- }
- return absl::OkStatus();
- }
-
- absl::Status FuseActivation(OneDnnFusionConfig_FusionKind kind,
- HloInstruction* activation,
- HloInstruction* matmul,
- HloInstruction* intermediate_instr = nullptr,
- HloInstruction* optional_bitcast = nullptr) {
- TF_ASSIGN_OR_RETURN(auto backend_config,
- matmul->backend_config<BackendConfig>());
- auto* matmul_config = backend_config.mutable_onednn_matmul_config();
- matmul_config->mutable_fusions()->add_ops(kind);
- TF_RETURN_IF_ERROR(matmul->set_backend_config(backend_config));
- std::unique_ptr<HloInstruction> output = matmul->Clone();
- if (optional_bitcast != nullptr &&
- optional_bitcast->opcode() == HloOpcode::kBitcast) {
- HloInstruction* new_instr = nullptr;
- if (intermediate_instr != nullptr &&
- intermediate_instr->opcode() == HloOpcode::kConvert) {
- auto bitcast_call =
- matmul->AddInstruction(HloInstruction::CreateBitcast(
- ShapeUtil::ChangeElementType(optional_bitcast->shape(),
- matmul->shape().element_type()),
- matmul));
- new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::ChangeElementType(
- bitcast_call->shape(),
- intermediate_instr->shape().element_type()),
- bitcast_call));
- return ReplaceInstruction(activation, new_instr);
- }
- } else if (intermediate_instr) {
- output = intermediate_instr->CloneWithNewOperands(
- intermediate_instr->shape(),
- {matmul->parent()->AddInstruction(std::move(output))});
- }
-
- return ReplaceWithNewInstruction(activation, std::move(output));
- }
-
- // This function modifies the dot instruction for supported matrix
- // multiplication scenarios. In particular, it changes the shapes
- // of the lhs, rhs, and result arrays.
- // - lhs configuration scenario
- // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
- // result: [batch_dims,feature_dim] to [batch_dims,1,feature_dim]
- //
- // - rhs configuration scenario
- // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
- // result: [batch_dims,feature_dim] to [batch_dims,feature_dim,1]
- //
- // - both lhs and rhs configuration scenario
- // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
- // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
- // result: [batch_dims] to [batch_dims,1,1]
- absl::StatusOr<HloInstruction*> ReconfigureDotDimensions(
- HloInstruction* dot_instr) {
- HloInstruction* lhs = dot_instr->mutable_operand(0);
- HloInstruction* rhs = dot_instr->mutable_operand(1);
- DotDimensionNumbers dim_numbers = dot_instr->dot_dimension_numbers();
-
- auto lhs_batch_dims = dim_numbers.lhs_batch_dimensions();
- auto lhs_contraction_dims = dim_numbers.lhs_contracting_dimensions();
- bool is_lhs_vector = lhs->shape().rank() ==
- (lhs_batch_dims.size() + lhs_contraction_dims.size());
-
- auto rhs_batch_dims = dim_numbers.rhs_batch_dimensions();
- auto rhs_contraction_dims = dim_numbers.rhs_contracting_dimensions();
- bool is_rhs_vector = rhs->shape().rank() ==
- (rhs_batch_dims.size() + rhs_contraction_dims.size());
-
- if (!is_lhs_vector && !is_rhs_vector) return dot_instr;
-
- std::vector<int64_t> adjusted_lhs_dims(lhs->shape().dimensions().begin(),
- lhs->shape().dimensions().end());
- std::vector<int64_t> adjusted_rhs_dims(rhs->shape().dimensions().begin(),
- rhs->shape().dimensions().end());
- std::vector<int64_t> adjusted_dot_dims(
- dot_instr->shape().dimensions().begin(),
- dot_instr->shape().dimensions().end());
-
- if (is_lhs_vector) {
- auto lhs_it = adjusted_lhs_dims.begin() + lhs_batch_dims.size();
- adjusted_lhs_dims.insert(lhs_it, 1, 1);
- auto result_it = adjusted_dot_dims.begin() + lhs_batch_dims.size();
- adjusted_dot_dims.insert(result_it, 1, 1);
- auto lhs_contraction_dim =
- dot_instr->dot_dimension_numbers().lhs_contracting_dimensions(0);
- dim_numbers.set_lhs_contracting_dimensions(0, lhs_contraction_dim + 1);
- lhs = lhs->AddInstruction(HloInstruction::CreateBitcast(
- ShapeUtil::MakeShape(lhs->shape().element_type(), adjusted_lhs_dims),
- lhs));
- }
-
- if (is_rhs_vector) {
- auto it = adjusted_rhs_dims.end();
- adjusted_rhs_dims.insert(it, 1, 1);
- auto result_it = adjusted_dot_dims.end();
- adjusted_dot_dims.insert(result_it, 1, 1);
- rhs = rhs->AddInstruction(HloInstruction::CreateBitcast(
- ShapeUtil::MakeShape(rhs->shape().element_type(), adjusted_rhs_dims),
- rhs));
- }
-
- HloInstruction* adjusted_dot =
- dot_instr->AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(dot_instr->shape().element_type(),
- adjusted_dot_dims),
- lhs, rhs, dim_numbers, dot_instr->precision_config()));
-
- HloInstruction* replacement_instr = adjusted_dot->AddInstruction(
- HloInstruction::CreateBitcast(dot_instr->shape(), adjusted_dot));
-
- TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr));
- return adjusted_dot;
- }
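The dimension adjustments described above amount to inserting a unit dimension after the batch dimensions of a vector-like lhs and appending a trailing unit dimension to a vector-like rhs. A self-contained sketch on plain dimension vectors (illustrative helper names, not the XLA shape API):

```cpp
#include <cstdint>
#include <vector>

// Insert a unit dimension after the batch dims of a vector-like lhs:
// [batch_dims, contracting_dim] -> [batch_dims, 1, contracting_dim].
std::vector<int64_t> AdjustLhsDims(std::vector<int64_t> dims,
                                   int64_t num_batch_dims) {
  dims.insert(dims.begin() + num_batch_dims, 1);
  return dims;
}

// Append a trailing unit dimension to a vector-like rhs:
// [batch_dims, contracting_dim] -> [batch_dims, contracting_dim, 1].
std::vector<int64_t> AdjustRhsDims(std::vector<int64_t> dims) {
  dims.push_back(1);
  return dims;
}

// Example: AdjustLhsDims({8, 64}, /*num_batch_dims=*/1) yields {8, 1, 64};
//          AdjustRhsDims({8, 64}) yields {8, 64, 1}.
```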
-};
-
-class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
- public:
- OneDnnPostRewriteVisitor(int intra_op_parallelism,
- const tsl::thread::ThreadPool* compile_threadpool)
- : intra_op_parallelism_(intra_op_parallelism > 0
- ? intra_op_parallelism
- : tsl::port::MaxParallelism()),
- evaluator_(/*max_loop_iterations=*/0) {
- if (compile_threadpool) {
- threadpool_device_.reset(
- new Eigen::ThreadPoolDevice(compile_threadpool->AsEigenThreadPool(),
- compile_threadpool->NumThreads()));
- } else {
- threadpool_handle_.reset(new tsl::thread::ThreadPool(
- tsl::Env::Default(), "XLACpuCompile", tsl::port::MaxParallelism()));
- threadpool_device_.reset(
- new Eigen::ThreadPoolDevice(threadpool_handle_->AsEigenThreadPool(),
- threadpool_handle_->NumThreads()));
- }
-
-#ifndef ENABLE_ONEDNN_OPENMP
- // Set oneDNN concurrency settings (which are thread-local).
- tsl::OneDnnThreadPool::set_onednn_max_threads(intra_op_parallelism_);
-#endif
- }
-
- absl::Status HandleCustomCall(HloInstruction* custom_call) override {
- HloInstruction* matmul;
- if (Match(custom_call, OneDnnMatmulInstr(&matmul))) {
- return HandleCustomCallInternal<dnnl::matmul::primitive_desc>(
- custom_call);
- }
-
- return DefaultAction(custom_call);
- }
-
- template <typename PrimDesc>
- absl::Status HandleCustomCallInternal(HloInstruction* custom_call) {
- auto scratch_add = AddScratch<PrimDesc>(custom_call);
- if (scratch_add.ok()) {
- custom_call = *scratch_add;
- } else {
- VLOG(2) << scratch_add.status();
- }
- auto weights_prepack = PrepackWeights<PrimDesc>(custom_call);
- if (!weights_prepack.ok()) {
- VLOG(2) << weights_prepack.status();
- }
- return absl::OkStatus();
- }
-
- template <typename>
- absl::Status SetWeightsPrepack(HloInstruction*, bool);
-
- template <typename>
- absl::Status SetUserScratch(HloInstruction*, bool);
-
- template <typename>
- bool GetWeightsPrepack(HloInstruction*);
-
- template <typename>
- bool GetUserScratch(HloInstruction*);
-
- // Add scratch for matmul by changing the result of custom-call to
- // tuple(result, scratch)
- template <typename PrimDesc>
- absl::StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) {
- if (GetUserScratch<PrimDesc>(custom_call)) {
- return custom_call;
- }
- TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, true));
- auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
- int64_t scratch_size = prim_desc->scratchpad_desc().get_size();
- Shape scratch_shape = ShapeUtil::MakeShape(U8, {scratch_size});
- Shape tuple_shape =
- ShapeUtil::MakeTupleShape({custom_call->shape(), scratch_shape});
- auto new_custom_call = custom_call->AddInstruction(
- custom_call->CloneWithNewShape(tuple_shape));
- HloInstruction* gte =
- new_custom_call->AddInstruction(HloInstruction::CreateGetTupleElement(
- custom_call->shape(), new_custom_call, 0));
- auto status = ReplaceInstruction(custom_call, gte);
- if (!status.ok()) {
- TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, false));
- return absl::CancelledError("Adding scratch is unsuccessful.");
- }
- return new_custom_call;
- }
-
- template <typename PrimDesc>
- absl::StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) {
- if (GetWeightsPrepack<PrimDesc>(custom_call)) {
- return custom_call;
- }
- auto weights = custom_call->operand(1);
- auto weights_shape = weights->shape();
- Literal weights_literal;
- if (!(weights_shape.rank() == 2 &&
- evaluator_.TryEvaluate(weights, &weights_literal, true))) {
- return absl::CancelledError(
- "Cannot prepack weights. Not constant 2D weights.");
- }
- auto plain_weights_md = ShapeToMemDesc(weights_shape);
- if constexpr (std::is_same<PrimDesc, dnnl::matmul::primitive_desc>::value) {
- TF_ASSIGN_OR_RETURN(auto backend_config,
- custom_call->backend_config<BackendConfig>());
- TRANSPOSE_LAST_TWO_DIMS_IF(
- backend_config.onednn_matmul_config().transpose_b(),
- plain_weights_md);
- }
- TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, true));
- auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
- auto packed_weights_md = prim_desc->weights_desc();
- auto packed_weights_shape = MemDescToXlaShapeFlattened(packed_weights_md);
- auto packed_weights_literal = Literal(packed_weights_shape);
- ReorderWeight(plain_weights_md, weights_literal.untyped_data(),
- packed_weights_md, packed_weights_literal.untyped_data());
- HloInstruction* reordered_weight = custom_call->AddInstruction(
- HloInstruction::CreateConstant(std::move(packed_weights_literal)));
- auto status =
- custom_call->ReplaceOperandWithDifferentShape(1, reordered_weight);
- if (!status.ok()) {
- TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, false));
- return absl::CancelledError(
- "Cannot replace plain weights with prepacked weights.");
- } else {
- return custom_call;
- }
- }
-
- void ReorderWeight(const dnnl::memory::desc& src_md, void* src_buf,
- const dnnl::memory::desc& dst_md, void* dst_buf) {
- auto onednn_threadpool = CreateOneDnnThreadPool(threadpool_device_.get());
- dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);
- auto onednn_stream = MakeOneDnnStream(cpu_engine, onednn_threadpool.get());
- auto src_mem = dnnl::memory(src_md, cpu_engine, src_buf);
- auto dst_mem = dnnl::memory(dst_md, cpu_engine, dst_buf);
- dnnl::reorder reorder_prim{src_mem, dst_mem};
- reorder_prim.execute(onednn_stream, src_mem, dst_mem);
- onednn_stream.wait();
- }
-
- private:
- int intra_op_parallelism_;
- HloEvaluator evaluator_;
- std::unique_ptr<tsl::thread::ThreadPool> threadpool_handle_;
- std::unique_ptr<Eigen::ThreadPoolDevice> threadpool_device_;
-};
-
-#define EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GETTER, PRIM_DESC, CONFIG, \
- SUB_CONFIG, FIELD) \
- template <> \
- inline bool OneDnnPostRewriteVisitor::GETTER<PRIM_DESC>(HloInstruction * \
- custom_call) { \
- auto backend_config = custom_call->backend_config<BackendConfig>(); \
- return backend_config.ok() ? backend_config->CONFIG().SUB_CONFIG().FIELD() \
- : false; \
- }
-
-EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetUserScratch,
- dnnl::matmul::primitive_desc,
- onednn_matmul_config,
- optimization_config, user_scratchpad);
-EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetWeightsPrepack,
- dnnl::matmul::primitive_desc,
- onednn_matmul_config,
- optimization_config, weights_prepacked);
-
-#define EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SETTER, PRIM_DESC, CONFIG_TYPE, \
- CONFIG, SUB_CONFIG, FIELD) \
- template <> \
- inline absl::Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>( \
- HloInstruction * custom_call, bool value) { \
- TF_ASSIGN_OR_RETURN(auto backend_config, \
- custom_call->backend_config<BackendConfig>()); \
- CONFIG_TYPE* config = backend_config.mutable_##CONFIG(); \
- config->mutable_##SUB_CONFIG()->set_##FIELD(value); \
- return custom_call->set_backend_config(backend_config); \
- }
-
-EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetWeightsPrepack,
- dnnl::matmul::primitive_desc,
- OneDnnMatMulConfig, onednn_matmul_config,
- optimization_config, weights_prepacked);
-EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch,
- dnnl::matmul::primitive_desc,
- OneDnnMatMulConfig, onednn_matmul_config,
- optimization_config, user_scratchpad);
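For readability, the GETTER macro above expands to a specialization along these lines; this is a hand-expanded sketch of EMIT_GET_BACKEND_CONFIG_SPECIALIZATION for `GetUserScratch`, and the macro itself remains the real definition.

```cpp
template <>
inline bool OneDnnPostRewriteVisitor::GetUserScratch<dnnl::matmul::primitive_desc>(
    HloInstruction* custom_call) {
  auto backend_config = custom_call->backend_config<BackendConfig>();
  return backend_config.ok() ? backend_config->onednn_matmul_config()
                                   .optimization_config()
                                   .user_scratchpad()
                             : false;
}
```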
-
-absl::StatusOr<bool> OneDnnMatMulRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- OneDnnMatMulRewriteVisitor visitor;
- TF_ASSIGN_OR_RETURN(auto result,
- visitor.RunOnModule(module, execution_threads));
-
- OneDnnPostRewriteVisitor reorder_visitor(intra_op_parallelism_,
- compile_threadpool_);
- TF_ASSIGN_OR_RETURN(auto result2,
- reorder_visitor.RunOnModule(module, execution_threads));
-
- return {result || result2};
-}
-
-} // namespace cpu
-} // namespace xla
-
-#endif // INTEL_MKL && ENABLE_ONEDNN_V3
diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h
deleted file mode 100644
index 7ad7f76..0000000
--- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_
-#define XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#include <optional>
-
-#include "absl/algorithm/container.h"
-#include "unsupported/Eigen/CXX11/Tensor"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace cpu {
-
-// This pass pattern-matches HLO Dot instructions and rewrites them into
-// custom calls.
-class OneDnnMatMulRewriter : public HloModulePass {
- public:
- OneDnnMatMulRewriter(int intra_op_parallelism,
- const tsl::thread::ThreadPool* compile_threadpool)
- : intra_op_parallelism_(intra_op_parallelism),
- compile_threadpool_(compile_threadpool) {}
- OneDnnMatMulRewriter() = default;
- absl::string_view name() const override { return "onednn-matmul-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- static bool ShouldRewrite(const HloInstruction* dot_instr);
-
- private:
- int intra_op_parallelism_;
- const tsl::thread::ThreadPool* compile_threadpool_;
-};
-
-} // namespace cpu
-} // namespace xla
-
-#endif // INTEL_MKL && ENABLE_ONEDNN_V3
-#endif // XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_
diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD
index 9e5bd2e..a7e97b0 100644
--- a/third_party/xla/xla/service/cpu/runtime/BUILD
+++ b/third_party/xla/xla/service/cpu/runtime/BUILD
@@ -160,6 +160,7 @@
"//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc
index a0cd9f4..32a452a 100644
--- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc
@@ -205,10 +205,6 @@
return op_buffers_.source_shapes[index];
}
-absl::Span<const Shape> CollectiveThunk::source_shapes() const {
- return op_buffers_.source_shapes;
-}
-
const BufferAllocation::Slice& CollectiveThunk::destination_buffer(
int64_t index) const {
return op_buffers_.destination_buffers[index];
@@ -223,8 +219,4 @@
return op_buffers_.destination_shapes[index];
}
-absl::Span<const Shape> CollectiveThunk::destination_shapes() const {
- return op_buffers_.destination_shapes;
-}
-
} // namespace xla::cpu
diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h
index 5bcf16b..5ae9c98 100644
--- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h
+++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h
@@ -77,7 +77,6 @@
OpBuffers op_buffers, OpResources op_resources);
const OpParams& op_params() const { return op_params_; }
- const OpBuffers& op_buffers() const { return op_buffers_; }
// Resolves operation's device memory from the buffers and buffer allocations.
absl::StatusOr<OpDeviceMemory> GetOpDeviceMemory(const ExecuteParams& params);
@@ -109,13 +108,11 @@
absl::Span<const BufferAllocation::Slice> source_buffers() const;
const Shape& source_shape(int64_t index) const;
- absl::Span<const Shape> source_shapes() const;
const BufferAllocation::Slice& destination_buffer(int64_t index) const;
absl::Span<const BufferAllocation::Slice> destination_buffers() const;
const Shape& destination_shape(int64_t index) const;
- absl::Span<const Shape> destination_shapes() const;
private:
OpParams op_params_;
diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc
index 1161673..8c6deca 100644
--- a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc
@@ -196,11 +196,11 @@
// Forward ExecutableRunOptions to the FFI handlers via the call options.
CustomCallExecuteParams* custom_call_params = params.custom_call_params;
- ffi::CallOptions call_options = {custom_call_params->device_ordinal,
- custom_call_params->stream,
- custom_call_params->allocator,
- /*called_computation=*/nullptr,
- custom_call_params->ffi_execution_context};
+ ffi::CallOptions call_options = {
+ custom_call_params->device_ordinal,
+ ffi::CallOptions::CpuOptions{custom_call_params->intra_op_thread_pool},
+ /*called_computation=*/nullptr,
+ custom_call_params->ffi_execution_context};
// Call the function and check execution status.
auto status = ffi::Call(handler->bundle.execute, call_frame, call_options);
diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
index 50b42c7..5ab801c 100644
--- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
@@ -200,7 +200,7 @@
// TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk
// initialization stage.
- se::host::HostKernel* kernel = kernel_ptr_.load(std::memory_order_relaxed);
+ se::host::HostKernel* kernel = kernel_ptr_.load(std::memory_order_acquire);
// Because thunks are owned by a parent CpuExecutable, we can safely assume
// that kernel pointer will not change after we find it the first time.
@@ -209,8 +209,10 @@
params.function_registry->FindKernel(kernel_name_));
absl::MutexLock lock(&mutex_);
- kernel_.emplace(num_kernel_args_, kernel_fn, nullptr);
- kernel_ptr_.store(kernel = &kernel_.value());
+ if ((kernel = kernel_ptr_.load(std::memory_order_relaxed)) == nullptr) {
+ kernel = &kernel_.emplace(num_kernel_args_, kernel_fn, nullptr);
+ kernel_ptr_.store(kernel, std::memory_order_release);
+ }
}
// Use a fast path if kernel called just once.
diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc
index a24a227..041bf03 100644
--- a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc
@@ -471,6 +471,9 @@
case 25:
sort(std::integral_constant<size_t, 25>{});
break;
+ case 29:
+ sort(std::integral_constant<size_t, 29>{});
+ break;
default:
return Internal("Unsupported number of sorted inputs: %d", data.size());
}
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc
index 9588b02..9228de3 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc
@@ -140,18 +140,16 @@
? run_options->device_ordinal()
: run_options->stream()->parent()->device_ordinal();
- return CustomCallExecuteParams{device_ordinal, run_options->stream(),
- run_options->allocator(),
+ return CustomCallExecuteParams{device_ordinal,
+ run_options->intra_op_thread_pool(),
run_options->ffi_execution_context()};
}
Thunk::CustomCallExecuteParams::CustomCallExecuteParams(
- int32_t device_ordinal, stream_executor::Stream* stream,
- stream_executor::DeviceMemoryAllocator* allocator,
+ int32_t device_ordinal, const Eigen::ThreadPoolDevice* intra_op_thread_pool,
const ffi::ExecutionContext* ffi_execution_context)
: device_ordinal(device_ordinal),
- stream(stream),
- allocator(allocator),
+ intra_op_thread_pool(intra_op_thread_pool),
ffi_execution_context(ffi_execution_context) {}
tsl::AsyncValueRef<Thunk::ExecuteEvent> Thunk::OkExecuteEventSingleton() {
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h
index 0e645f2..9141da7 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk.h
+++ b/third_party/xla/xla/service/cpu/runtime/thunk.h
@@ -208,14 +208,12 @@
const ExecutableRunOptions* run_options);
int32_t device_ordinal;
- stream_executor::Stream* stream = nullptr;
- stream_executor::DeviceMemoryAllocator* allocator = nullptr;
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr;
const ffi::ExecutionContext* ffi_execution_context = nullptr;
private:
CustomCallExecuteParams(int32_t device_ordinal,
- stream_executor::Stream* stream,
- stream_executor::DeviceMemoryAllocator* allocator,
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool,
const ffi::ExecutionContext* ffi_execution_context);
};
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
index 805840a..9b4c735 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
+++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
@@ -16,6 +16,7 @@
#include "xla/service/cpu/runtime/thunk_executor.h"
#include <atomic>
+#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
@@ -61,7 +62,7 @@
}
// Erase redundant edges between nodes.
- int64_t num_erased_edges = TransitiveReduction();
+ int64_t num_erased_edges = RunTransitiveReductionAndUpdatePriorities();
// Check if constructed execution DAG is sequential: every node depends on the
// completion of the previous node.
@@ -160,8 +161,14 @@
// Create async execution state on heap and kick-off execution.
auto state = std::make_unique<ExecuteState>(this, params.task_runner);
- Execute(state.get(), params, ReadyQueue(source_.begin(), source_.end()),
- /*lock=*/params.session.Join());
+
+ if (options_.use_priority_ready_queue) {
+ Execute(state.get(), params, PriorityReadyQueue(nodes_defs_, source_),
+ /*lock=*/params.session.Join());
+ } else {
+ Execute(state.get(), params, FifoReadyQueue(source_),
+ /*lock=*/params.session.Join());
+ }
// If execution already completed (all kernels executed in the caller thread),
// immediately return the result to avoid wasteful reference counting below.
@@ -256,13 +263,14 @@
event.SetStateConcrete();
}
+template <typename ReadyQueue>
void ThunkExecutor::Execute(ExecuteState* state,
const Thunk::ExecuteParams& params,
ReadyQueue ready_queue,
Thunk::ExecuteSession::Lock lock) {
tsl::profiler::TraceMe trace("ThunkExecutor::Execute");
- DCHECK(!ready_queue.empty()) << "Ready queue must not be empty";
+ DCHECK(!ready_queue.Empty()) << "Ready queue must not be empty";
DCHECK(lock) << "Execute session lock must be set";
bool has_runner = state->runner != nullptr;
@@ -270,8 +278,8 @@
// Threshold for splitting ready queue into separate thunk executor tasks.
int64_t split_threshold = params.session.split_threshold();
- for (int64_t i = 0; i < ready_queue.size(); ++i) {
- NodeId id = ready_queue[i];
+ while (!ready_queue.Empty()) {
+ NodeId id = ready_queue.Pop();
ExecuteState::Node& node = state->node(id);
int64_t cnt = node.counter.load(std::memory_order_acquire);
@@ -279,9 +287,9 @@
// If we have multiple ready thunks, split the ready queue and offload
// thunks processing to the task runner.
- int64_t num_ready_thunks = ready_queue.size() - i;
+ int64_t num_ready_thunks = ready_queue.Size();
if (ABSL_PREDICT_FALSE(has_runner && num_ready_thunks > split_threshold)) {
- SplitReadyQueue(state, params, /*start_index=*/i + 1, ready_queue);
+ SplitReadyQueue(state, params, ready_queue, split_threshold);
}
// Execute thunk for the given node id. If execution is aborted, we keep
@@ -307,13 +315,13 @@
// the same execute session.
execute_event.AndThen([&params, &node, state,
execute_event = execute_event.AsPtr(),
+ ready_queue = ready_queue.CreateEmptyReadyQueue(),
lock = params.session.Join()]() mutable {
- ReadyQueue ready_queue;
state->executor->ProcessOutEdges(state, execute_event, node,
ready_queue);
// If ready queue is empty it might mean that we have completed an
// execution and destroyed the `state`.
- if (ABSL_PREDICT_TRUE(!ready_queue.empty())) {
+ if (ABSL_PREDICT_TRUE(!ready_queue.Empty())) {
state->executor->Execute(state, params, std::move(ready_queue),
std::move(lock));
}
@@ -322,17 +330,17 @@
}
}
+template <typename ReadyQueue>
inline ABSL_ATTRIBUTE_ALWAYS_INLINE void ThunkExecutor::SplitReadyQueue(
ExecuteState* state, const Thunk::ExecuteParams& params,
- int64_t start_index, ReadyQueue& ready_queue) {
+ ReadyQueue& ready_queue, int64_t split_threshold) {
DCHECK(state->runner) << "TaskRunner must be set";
- int64_t end_index = ready_queue.size();
// We use recursive work splitting to push the tail of the ready queue to
// the task runner. Recursive work splitting creates a more uniform work
// distribution across the task runner threads and avoids a situation where
// a long tail of work is processed by a single thread.
- while (end_index > start_index) {
+ while (ready_queue.Size() > split_threshold) {
// Try to acquire a lock to offload ready thunks to the task runner. If
// we can't get a lock, we will keep processing the ready queue in the
// current thread as it means that we have enough concurrent workers
@@ -342,22 +350,16 @@
break;
}
- // Execute [mid_index, end_index) nodes in the task runner.
- int64_t mid_index = (start_index + end_index) / 2;
- (*state->runner)([&params, state,
- ready_queue = ReadyQueue(ready_queue.begin() + mid_index,
- ready_queue.begin() + end_index),
+ // Execute half of the ready queue nodes in the task runner.
+ (*state->runner)([&params, state, ready_queue = ready_queue.PopHalf(),
lock = std::move(task_runner_lock)]() mutable {
state->executor->Execute(state, params, std::move(ready_queue),
std::move(lock));
});
- end_index = mid_index;
}
-
- // Erase ready nodes passed to the task runner.
- ready_queue.erase(ready_queue.begin() + end_index, ready_queue.end());
}
+template <typename ReadyQueue>
void ThunkExecutor::ProcessOutEdges(
ExecuteState* state, tsl::AsyncValuePtr<Thunk::ExecuteEvent> node_event,
ExecuteState::Node& node, ReadyQueue& ready_queue) {
@@ -380,7 +382,7 @@
int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release);
DCHECK_GE(cnt, 1) << "Node counter can't drop below 0";
- if (cnt == 1) ready_queue.push_back(out_edge);
+ if (cnt == 1) ready_queue.Push(out_edge);
}
// Drop the pending sink nodes counter if the node is a sink.
@@ -431,7 +433,7 @@
return 0;
}
-int64_t ThunkExecutor::TransitiveReduction() {
+int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities() {
int64_t num_erased_edges = 0;
// Keep workspace for DFS traversal between iterations.
@@ -454,11 +456,11 @@
stack.clear();
visited.assign(nodes_defs_.size(), false);
- // Initialize stack with nodes reachable via immediate out nodes. We don't
- // need to add source node and immediate out nodes to the visited set
- // because graph is acyclic and we don't visit them again.
+ // Initialize stack with nodes reachable via immediate out nodes. We mark
+ // immediate out nodes as visited to correctly compute node priority below.
for (int64_t out_id : source_node.out_edges) {
NodeDef& out_node = nodes_defs_[out_id];
+ visited[out_id] = true;
for (int64_t start_id : out_node.out_edges) add_to_stack(start_id);
}
@@ -472,6 +474,9 @@
for (int64_t out_id : node.out_edges) add_to_stack(out_id);
}
+
+ // Set node priority to the number of visited nodes in the DFS traversal.
+ source_node.priority = absl::c_count(visited, true);
}
return num_erased_edges;
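With this change, a node's priority ends up being the number of nodes reachable from it in the DAG, so nodes with longer downstream chains are favored by the priority ready queue. A self-contained sketch of that count (an illustrative helper; the actual pass reuses the transitive-reduction DFS above rather than a separate traversal):

```cpp
#include <cstdint>
#include <vector>

// Count the nodes reachable from `start` (excluding `start` itself) with an
// iterative DFS over out-edges; this matches the priority assigned above.
int64_t CountReachable(const std::vector<std::vector<int64_t>>& out_edges,
                       int64_t start) {
  std::vector<bool> visited(out_edges.size(), false);
  std::vector<int64_t> stack = {start};
  int64_t count = -1;  // do not count the start node itself
  while (!stack.empty()) {
    int64_t id = stack.back();
    stack.pop_back();
    if (visited[id]) continue;
    visited[id] = true;
    ++count;
    for (int64_t out : out_edges[id]) stack.push_back(out);
  }
  return count;
}
// For a chain 0 -> 1 -> 2 this yields priorities 2, 1, 0, matching the
// expectations added to thunk_executor_test.cc below.
```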
@@ -495,14 +500,88 @@
const Thunk& thunk = *thunk_sequence_[i];
bool is_source = absl::c_find(source_, i) != source_.end();
bool is_sink = absl::c_find(sink_, i) != sink_.end();
- absl::StrAppendFormat(
- &str,
- "\n thunk #%05d: op_name=%s, dependencies=[%s], source=%v, sink=%v", i,
- thunk.info().op_name, absl::StrJoin(in_edges[i], ", "), is_source,
- is_sink);
+ absl::StrAppendFormat(&str,
+ "\n thunk #%05d: op_name=%s, dependencies=[%s], "
+ "source=%v, sink=%v, priority=%d",
+ i, thunk.info().op_name,
+ absl::StrJoin(in_edges[i], ", "), is_source, is_sink,
+ nodes_defs_[i].priority);
}
return str;
}
+ThunkExecutor::FifoReadyQueue::FifoReadyQueue(
+ absl::Span<const NodeId> ready_nodes)
+ : queue_(ready_nodes.begin(), ready_nodes.end()) {}
+
+void ThunkExecutor::FifoReadyQueue::Push(NodeId id) { queue_.push_back(id); }
+
+ThunkExecutor::NodeId ThunkExecutor::FifoReadyQueue::Pop() {
+ DCHECK(!Empty()) << "Queue must not be empty";
+ return queue_[head_++];
+}
+
+ThunkExecutor::FifoReadyQueue ThunkExecutor::FifoReadyQueue::PopHalf() {
+ DCHECK(!Empty()) << "Queue must not be empty";
+ auto mid = queue_.begin() + head_ + Size() / 2;
+ FifoReadyQueue popped(absl::MakeConstSpan(&*mid, queue_.end() - mid));
+ queue_.resize(mid - queue_.begin());
+ return popped;
+}
+
+size_t ThunkExecutor::FifoReadyQueue::Size() const {
+ return queue_.size() - head_;
+}
+
+bool ThunkExecutor::FifoReadyQueue::Empty() const {
+ return head_ == queue_.size();
+}
+
+ThunkExecutor::FifoReadyQueue
+ThunkExecutor::FifoReadyQueue::CreateEmptyReadyQueue() const {
+ return FifoReadyQueue(absl::Span<const NodeId>());
+}
+
+ThunkExecutor::PriorityReadyQueue::PriorityReadyQueue(
+ absl::Span<const NodeDef> nodes_defs, absl::Span<const NodeId> ready_nodes)
+ : nodes_defs_(nodes_defs),
+ queue_(ready_nodes.begin(), ready_nodes.end(), Compare{nodes_defs}) {}
+
+void ThunkExecutor::PriorityReadyQueue::Push(NodeId id) { queue_.push(id); }
+
+ThunkExecutor::NodeId ThunkExecutor::PriorityReadyQueue::Pop() {
+ DCHECK(!Empty()) << "Queue must not be empty";
+ NodeId id = queue_.top();
+ queue_.pop();
+ return id;
+}
+
+ThunkExecutor::PriorityReadyQueue ThunkExecutor::PriorityReadyQueue::PopHalf() {
+ DCHECK(!Empty()) << "Queue must not be empty";
+ int64_t keep_top_nodes = queue_.size() / 2;
+
+ // First pop nodes with highest priority from the queue.
+ PriorityReadyQueue popped(nodes_defs_, {});
+ while (keep_top_nodes-- > 0) {
+ popped.queue_.push(queue_.top());
+ queue_.pop();
+ }
+
+ // Swap the popped nodes with the remaining nodes, so the caller receives the
+ // lower-priority nodes while the higher-priority nodes stay in the queue.
+ popped.queue_.swap(queue_);
+
+ return popped;
+}
+
+size_t ThunkExecutor::PriorityReadyQueue::Size() const { return queue_.size(); }
+
+bool ThunkExecutor::PriorityReadyQueue::Empty() const { return queue_.empty(); }
+
+ThunkExecutor::PriorityReadyQueue
+ThunkExecutor::PriorityReadyQueue::CreateEmptyReadyQueue() const {
+ return PriorityReadyQueue(nodes_defs_, {});
+}
+
} // namespace xla::cpu
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h b/third_party/xla/xla/service/cpu/runtime/thunk_executor.h
index 10df02c..f0df6cf 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h
+++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor.h
@@ -21,6 +21,7 @@
#include <cstdint>
#include <limits>
#include <new>
+#include <queue>
#include <string>
#include <type_traits>
#include <vector>
@@ -45,6 +46,10 @@
// `execute_sequential_buffer_threshold`, we mark execution as sequential, as
// concurrency overheads will likely dominate the overall execution time.
size_t execute_sequential_buffer_threshold = 512;
+
+ // Use a priority ready queue to execute nodes according to their priority.
+ // By default a FIFO ready queue is used.
+ bool use_priority_ready_queue = false;
};
} // namespace internal
@@ -72,6 +77,7 @@
// NodeDef defines an execution order for all thunks in a sequence.
struct NodeDef {
NodeId id = kInvalidNodeId;
+ int64_t priority = 0;
std::vector<NodeId> in_edges;
std::vector<NodeId> out_edges;
};
@@ -97,6 +103,57 @@
bool is_sequential() const { return is_sequential_; }
+ // A ready queue that executes nodes in FIFO order.
+ class FifoReadyQueue {
+ public:
+ explicit FifoReadyQueue(absl::Span<const NodeId> ready_nodes);
+
+ void Push(NodeId id);
+
+ NodeId Pop();
+ FifoReadyQueue PopHalf();
+
+ size_t Size() const;
+ bool Empty() const;
+
+ FifoReadyQueue CreateEmptyReadyQueue() const;
+
+ private:
+ absl::InlinedVector<NodeId, 8> queue_;
+ size_t head_ = 0;
+ };
+
+ // A ready queue that executes nodes sorted by NodeDef priority.
+ class PriorityReadyQueue {
+ public:
+ PriorityReadyQueue(absl::Span<const NodeDef> nodes_defs,
+ absl::Span<const NodeId> ready_nodes);
+
+ void Push(NodeId id);
+
+ NodeId Pop();
+ PriorityReadyQueue PopHalf();
+
+ size_t Size() const;
+ bool Empty() const;
+
+ PriorityReadyQueue CreateEmptyReadyQueue() const;
+
+ private:
+ struct Compare {
+ bool operator()(NodeId a, NodeId b) const {
+ return nodes_defs[a].priority < nodes_defs[b].priority;
+ }
+ absl::Span<const NodeDef> nodes_defs;
+ };
+
+ using InlinedPriorityQueue =
+ std::priority_queue<NodeId, absl::InlinedVector<NodeId, 8>, Compare>;
+
+ absl::Span<const NodeDef> nodes_defs_;
+ InlinedPriorityQueue queue_;
+ };
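The Compare functor above follows the common priority-queue-over-ids idiom: the queue stores NodeIds and the comparator looks each id's priority up in a side table. A self-contained sketch of the same idiom with standard containers:

```cpp
#include <cstdint>
#include <queue>
#include <vector>

int main() {
  // priority[id] plays the role of nodes_defs[id].priority.
  std::vector<int64_t> priority = {5, 1, 9, 3};
  auto cmp = [&](int64_t a, int64_t b) { return priority[a] < priority[b]; };
  std::priority_queue<int64_t, std::vector<int64_t>, decltype(cmp)> queue(cmp);
  for (int64_t id = 0; id < 4; ++id) queue.push(id);
  // Ids pop in order of decreasing priority: 2, 0, 3, 1.
  while (!queue.empty()) queue.pop();
  return 0;
}
```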
+
private:
// Align all atomic counters to a cache line boundary to avoid false
// sharing between multiple worker threads.
@@ -107,8 +164,6 @@
64;
#endif
- using ReadyQueue = absl::InlinedVector<NodeId, 8>;
-
// A struct to keep the state of a running ThunkExecutor.
struct ExecuteState {
// At run time NodeDef instantiated as a Node with an atomic counter that
@@ -162,26 +217,29 @@
tsl::AsyncValueRef<ExecuteEvent> event);
// Executes nodes in the ready queue with given thunk parameters.
+ template <typename ReadyQueue>
void Execute(ExecuteState* state, const Thunk::ExecuteParams& params,
ReadyQueue ready_queue, Thunk::ExecuteSession::Lock lock);
// Splits the ready queue into ThunkExecutor tasks and offloads them to the
// task runner.
+ template <typename ReadyQueue>
void SplitReadyQueue(ExecuteState* state, const Thunk::ExecuteParams& params,
- int64_t start_index, ReadyQueue& ready_queue);
+ ReadyQueue& ready_queue, int64_t split_threshold);
// Processes out edges of a completed `node` and updates `ready_queue` with
// nodes that are ready to execute. If `event` is in error state, aborts the
// execution and records the error status to forward it to the caller.
+ template <typename ReadyQueue>
void ProcessOutEdges(ExecuteState* state,
tsl::AsyncValuePtr<Thunk::ExecuteEvent> node_event,
ExecuteState::Node& node, ReadyQueue& ready_queue);
- // Runs a transitive reduction on the NodeDef graph to remove redundant edges.
- // Returns the number of removed edges.
+ // Runs a transitive reduction on the NodeDef graph to remove redundant edges,
+ // and updates node priorities. Returns the number of removed edges.
//
// See: https://en.wikipedia.org/wiki/Transitive_reduction
- int64_t TransitiveReduction();
+ int64_t RunTransitiveReductionAndUpdatePriorities();
ThunkSequence thunk_sequence_;
Options options_;
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
index 2bbb932..60996eb 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
+++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
@@ -26,6 +26,7 @@
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
@@ -218,6 +219,123 @@
return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0};
}
+TEST(ThunkExecutorTest, FifoReadyQueueTest) {
+ ThunkExecutor::FifoReadyQueue queue({});
+
+ // Check basic queue properties.
+ EXPECT_TRUE(queue.Empty());
+ EXPECT_EQ(queue.Size(), 0);
+
+ queue.Push(1);
+ queue.Push(2);
+ queue.Push(3);
+
+ EXPECT_EQ(queue.Size(), 3);
+
+ EXPECT_EQ(queue.Pop(), 1);
+ EXPECT_EQ(queue.Pop(), 2);
+ EXPECT_EQ(queue.Pop(), 3);
+
+ EXPECT_TRUE(queue.Empty());
+ EXPECT_EQ(queue.Size(), 0);
+
+ // Prepare queue for PopHalf test case.
+ queue.Push(1);
+ queue.Push(2);
+ queue.Push(3);
+
+ // Pop half of the queue.
+ ThunkExecutor::FifoReadyQueue half0 = queue.PopHalf();
+ EXPECT_EQ(half0.Size(), 2);
+ EXPECT_EQ(half0.Pop(), 2);
+ EXPECT_EQ(half0.Pop(), 3);
+
+ // Check that the rest is still in the queue.
+ EXPECT_EQ(queue.Size(), 1);
+
+ // Pop the rest of the queue.
+ ThunkExecutor::FifoReadyQueue half1 = queue.PopHalf();
+ EXPECT_EQ(half1.Size(), 1);
+
+ // Check that all nodes were returned from PopHalf.
+ EXPECT_EQ(queue.Size(), 0);
+
+ // Add 5 elements to test Pop followed by PopHalf.
+ queue.Push(1);
+ queue.Push(2);
+ queue.Push(3);
+ queue.Push(4);
+ queue.Push(5);
+
+ EXPECT_EQ(queue.Pop(), 1);
+
+ // Check that PopHalf returns the last 2 nodes.
+ ThunkExecutor::FifoReadyQueue half2 = queue.PopHalf();
+ EXPECT_EQ(half2.Size(), 2);
+ EXPECT_EQ(half2.Pop(), 4);
+ EXPECT_EQ(half2.Pop(), 5);
+}
+
+TEST(ThunkExecutorTest, PriorityReadyQueueTest) {
+ std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
+ for (size_t i = 0; i < nodes_defs.size(); ++i) {
+ nodes_defs[i].priority = i;
+ }
+
+ ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {});
+ // Check basic queue properties.
+ EXPECT_TRUE(queue.Empty());
+ EXPECT_EQ(queue.Size(), 0);
+
+ queue.Push(1);
+ queue.Push(3);
+ queue.Push(2);
+
+ EXPECT_EQ(queue.Pop(), 3);
+ EXPECT_EQ(queue.Pop(), 2);
+ EXPECT_EQ(queue.Pop(), 1);
+
+ EXPECT_TRUE(queue.Empty());
+ EXPECT_EQ(queue.Size(), 0);
+
+ // Prepare queue for PopHalf test case.
+ queue.Push(2);
+ queue.Push(1);
+ queue.Push(3);
+
+ // Pop half of the queue.
+ ThunkExecutor::PriorityReadyQueue half0 = queue.PopHalf();
+ EXPECT_EQ(half0.Size(), 2);
+ EXPECT_EQ(half0.Pop(), 2);
+ EXPECT_EQ(half0.Pop(), 1);
+
+ // Check that the rest is still in the queue.
+ EXPECT_EQ(queue.Size(), 1);
+
+ // Pop the rest of the queue.
+ ThunkExecutor::PriorityReadyQueue half1 = queue.PopHalf();
+ EXPECT_EQ(half1.Size(), 1);
+ EXPECT_EQ(half1.Pop(), 3);
+
+ // Check that all nodes were returned from PopHalf.
+ EXPECT_EQ(queue.Size(), 0);
+
+ // Add 5 elements to test Pop followed by PopHalf.
+ queue.Push(4);
+ queue.Push(3);
+ queue.Push(5);
+ queue.Push(1);
+ queue.Push(2);
+
+ EXPECT_EQ(queue.Pop(), 5);
+
+ // Check that PopHalf returns the 2 lowest-priority nodes.
+ ThunkExecutor::PriorityReadyQueue half2 = queue.PopHalf();
+ EXPECT_EQ(half2.Size(), 2);
+ EXPECT_EQ(half2.Pop(), 2);
+ EXPECT_EQ(half2.Pop(), 1);
+}
+
TEST(ThunkExecutorTest, DependencyOrdering) {
BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
@@ -237,6 +355,10 @@
EXPECT_FALSE(executor.is_sequential());
EXPECT_THAT(executor.source(), ElementsAre(0, 1));
EXPECT_THAT(executor.sink(), ElementsAre(2));
+
+ EXPECT_EQ(executor.node_def(0).priority, 1);
+ EXPECT_EQ(executor.node_def(1).priority, 1);
+ EXPECT_EQ(executor.node_def(2).priority, 0);
}
TEST(ThunkExecutorTest, SequentialOrdering) {
@@ -255,6 +377,10 @@
EXPECT_TRUE(executor.is_sequential());
EXPECT_THAT(executor.source(), ElementsAre(0));
EXPECT_THAT(executor.sink(), ElementsAre(2));
+
+ EXPECT_EQ(executor.node_def(0).priority, 2);
+ EXPECT_EQ(executor.node_def(1).priority, 1);
+ EXPECT_EQ(executor.node_def(2).priority, 0);
}
TEST(ThunkExecutorTest, ResourceOrdering) {
@@ -278,6 +404,9 @@
EXPECT_TRUE(executor.is_sequential());
EXPECT_THAT(executor.source(), ElementsAre(0));
EXPECT_THAT(executor.sink(), ElementsAre(1));
+
+ EXPECT_EQ(executor.node_def(0).priority, 1);
+ EXPECT_EQ(executor.node_def(1).priority, 0);
}
TEST(ThunkExecutorTest, TransitiveReduction) {
@@ -300,6 +429,10 @@
EXPECT_THAT(executor.node_def(1).in_edges, ElementsAre(0));
EXPECT_THAT(executor.node_def(1).out_edges, ElementsAre(2));
EXPECT_THAT(executor.node_def(2).in_edges, ElementsAre(1));
+
+ EXPECT_EQ(executor.node_def(0).priority, 2);
+ EXPECT_EQ(executor.node_def(1).priority, 1);
+ EXPECT_EQ(executor.node_def(2).priority, 0);
}
TEST(ThunkExecutorTest, Execute) {
@@ -333,7 +466,7 @@
Thunk::ExecuteParams params = {nullptr, &allocations};
params.task_runner = &task_runner;
params.session =
- Thunk::ExecuteSession(/*max_workers=*/8, /*split_threshold=*/1);
+ Thunk::ExecuteSession(/*max_workers=*/8, /*split_threshold=*/0);
auto execute_event = executor.Execute(params);
@@ -433,11 +566,11 @@
// and optionally uses a thread pool to execute thunk executor tasks.
class ThunkExecutorStressTest
: public testing::TestWithParam<
- std::tuple<int32_t, bool, bool, SharedResourceUse, bool>> {
+ std::tuple<int32_t, bool, bool, SharedResourceUse, bool, bool>> {
public:
void SetUp() override {
auto& [num_thunks, use_task_runner, use_device, shared_resource_use,
- inject_errors] = GetParam();
+ inject_errors, use_priority_ready_queue] = GetParam();
use_task_runner_ = use_task_runner;
use_device_ = use_device;
@@ -477,16 +610,21 @@
TEST_P(ThunkExecutorStressTest, Execute) {
auto [num_thunks, use_task_runner, use_device, shared_resource_use,
- inject_errors] = GetParam();
+ inject_errors, use_priority_ready_queue] = GetParam();
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GeneratedThunkSequence> g,
GenerateThunkSequence(/*num_elements=*/1024, num_thunks,
shared_resource_use, inject_errors));
+ ThunkExecutor::Options executor_options = {
+ /*execute_sequential_buffer_threshold=*/0,
+ /*use_priority_ready_queue=*/use_priority_ready_queue,
+ };
+
TF_ASSERT_OK_AND_ASSIGN(
ThunkExecutor executor,
- ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()));
+ ThunkExecutor::Create(std::move(g->sequence), executor_options));
BufferAllocations allocations(g->buffers);
Thunk::ExecuteParams params = {nullptr, &allocations, nullptr, device(),
@@ -516,12 +654,95 @@
testing::Values(SharedResourceUse::kNo,
SharedResourceUse::kAll,
SharedResourceUse::kRandom),
- /*inject_errors=*/testing::Bool()));
+ /*inject_errors=*/testing::Bool(),
+ /*use_priority_ready_queue=*/testing::Bool()));
//===----------------------------------------------------------------------===//
// Performance benchmarks below
//===----------------------------------------------------------------------===//
+static void BM_FifoReadyQueuePushPop(benchmark::State& state) {
+ ThunkExecutor::FifoReadyQueue queue({});
+ const size_t num_push_pop = state.range(0);
+
+ for (auto _ : state) {
+ for (int i = 0; i < num_push_pop; ++i) {
+ queue.Push(i);
+ }
+ for (int i = 0; i < num_push_pop; ++i) {
+ benchmark::DoNotOptimize(queue.Pop());
+ }
+ }
+}
+
+static void BM_FifoReadyQueuePushPopHalf(benchmark::State& state) {
+ ThunkExecutor::FifoReadyQueue queue({});
+ const size_t num_push_pop = state.range(0);
+
+ for (auto _ : state) {
+ for (int i = 0; i < num_push_pop; ++i) {
+ queue.Push(i);
+ }
+ benchmark::DoNotOptimize(queue.PopHalf());
+ }
+}
+
+static void BM_PriorityReadyQueuePushPop(benchmark::State& state) {
+ std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
+ for (size_t i = 0; i < nodes_defs.size(); ++i) {
+ nodes_defs[i].priority = i;
+ }
+
+ std::default_random_engine rng;
+ absl::c_shuffle(nodes_defs, rng);
+
+ ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {});
+ const size_t num_push_pop = state.range(0);
+
+ for (auto _ : state) {
+ for (int i = 0; i < num_push_pop; ++i) {
+ queue.Push(i);
+ }
+ for (int i = 0; i < num_push_pop; ++i) {
+ benchmark::DoNotOptimize(queue.Pop());
+ }
+ }
+}
+
+static void BM_PriorityReadyQueuePushPopHalf(benchmark::State& state) {
+ std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
+ for (size_t i = 0; i < nodes_defs.size(); ++i) {
+ nodes_defs[i].priority = i;
+ }
+
+ std::default_random_engine rng;
+ absl::c_shuffle(nodes_defs, rng);
+
+ ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {});
+ const size_t num_push_pop = state.range(0);
+
+ for (auto _ : state) {
+ for (int i = 0; i < num_push_pop; ++i) {
+ queue.Push(i);
+ }
+ benchmark::DoNotOptimize(queue.PopHalf());
+ }
+}
+
+#define BENCHMARK_READY_QUEUE(name) \
+ BENCHMARK(name) \
+ ->MeasureProcessCPUTime() \
+ ->Arg(1) \
+ ->Arg(2) \
+ ->Arg(4) \
+ ->Arg(8) \
+ ->Arg(16)
+
+BENCHMARK_READY_QUEUE(BM_FifoReadyQueuePushPop);
+BENCHMARK_READY_QUEUE(BM_FifoReadyQueuePushPopHalf);
+BENCHMARK_READY_QUEUE(BM_PriorityReadyQueuePushPop);
+BENCHMARK_READY_QUEUE(BM_PriorityReadyQueuePushPopHalf);
+
static void BM_SequentialThunkExecutor(benchmark::State& state) {
const size_t num_thunks = state.range(0);
diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
index 3c9c8d1..6b07b41 100644
--- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
+++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
@@ -142,9 +142,9 @@
// Forward executable run options to the FFI handlers via the call options.
ffi::CallOptions call_options = {
- run_options->device_ordinal(), run_options->stream(),
- run_options->allocator(), /*called_computation=*/nullptr,
- run_options->ffi_execution_context()};
+ run_options->device_ordinal(),
+ ffi::CallOptions::CpuOptions{run_options->intra_op_thread_pool()},
+ /*called_computation=*/nullptr, run_options->ffi_execution_context()};
ffi::CallFrame call_frame = builder.Build();
return ffi::Call(registration->bundle.execute, call_frame, call_options);
diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD
index 1c8c111..be51e84 100644
--- a/third_party/xla/xla/service/cpu/tests/BUILD
+++ b/third_party/xla/xla/service/cpu/tests/BUILD
@@ -361,6 +361,7 @@
name = "onednn_matmul_test",
srcs = ["onednn_matmul_test.cc"],
copts = tsl_copts(),
+ shard_count = 4,
tags = [
"no_oss",
"notap",
@@ -372,7 +373,7 @@
"//xla:test_helpers",
"//xla/hlo/utils:hlo_matchers",
"//xla/service:cpu_plugin",
- "//xla/service/cpu:onednn_matmul_rewriter",
+ "//xla/service/cpu:onednn_contraction_rewriter",
"//xla/service/cpu:onednn_util",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
@@ -393,7 +394,7 @@
"//xla:test_helpers",
"//xla/hlo/utils:hlo_matchers",
"//xla/service:cpu_plugin",
- "//xla/service/cpu:onednn_matmul_rewriter",
+ "//xla/service/cpu:onednn_contraction_rewriter",
"//xla/service/cpu:onednn_util",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
diff --git a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc
index 50c0e8f..6bceebc 100644
--- a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc
+++ b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc
@@ -19,7 +19,7 @@
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
#include "xla/service/cpu/onednn_util.h"
#include "xla/shape_util.h"
#include "xla/test.h"
diff --git a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc
index d7fb39f..57f7c09 100644
--- a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc
+++ b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc
@@ -19,7 +19,7 @@
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
#include "xla/service/cpu/onednn_util.h"
#include "xla/shape_util.h"
#include "xla/test.h"
@@ -803,7 +803,9 @@
; CHECK: backend_config={
; CHECK-DAG: "outer_dimension_partitions":[],
; CHECK-DAG: "onednn_matmul_config":{
- ; CHECK-NOT: "fused_ops":["LINEAR"]
+ ; CHECK-NOT: "fusions":{
+ ; CHECK-NOT: "ops":["LINEAR"]
+ ; CHECK-NOT: }
; CHECK-DAG: }
; CHECK: }
)");
@@ -1499,47 +1501,44 @@
)");
}
+TEST_F(MatmulTest, ColMajorBF16DotBeforeLayoutAssignment) {
+ if (!IsSupportedType(PrimitiveType::BF16)) {
+ GTEST_SKIP() << "CPU does not support BF16.";
+ }
+
+ const char* matmul_module_str = R"(
+ HloModule matmul.colmajor.test
+ ENTRY matmul.colmajor.test.bf16 {
+ arg.0 = bf16[500,500]{0,1} parameter(0)
+ arg.1 = bf16[500,500]{1,0} parameter(1)
+ transpose.0 = bf16[500,500]{0,1} transpose(arg.1), dimensions={1,0}
+ ROOT dot.0 = bf16[500,500]{1,0} dot(arg.0, arg.1), lhs_contracting_dims={1},
+ rhs_contracting_dims={0}
+ })";
+
+ EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec(1e-2, 1e-2)));
+ MatchOptimizedHlo(matmul_module_str,
+ R"(
+ ; CHECK: (bf16[500,500]{1,0}, u8[{{.*}}]{0})
+ ; CHECK-SAME: custom_call_target="__onednn$matmul"
+ )");
+}
+
TEST_F(MatmulTest, ConsecutiveBinaryAdd) {
const char* matmul_module_str = R"(
HloModule matmul.test.f32
- region_0.22 {
- Arg_0.23 = f32[] parameter(0)
- Arg_1.24 = f32[] parameter(1)
- ROOT add.25 = f32[] add(Arg_0.23, Arg_1.24)
- }
-
- region_1.29 {
- Arg_0.30 = f32[] parameter(0)
- Arg_1.31 = f32[] parameter(1)
- ROOT add.32 = f32[] add(Arg_0.30, Arg_1.31)
- }
-
- ENTRY main {
- constant.2 = f32[] constant(1e-06)
- broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={}
- constant.7 = f32[] constant(1)
- broadcast.8 = f32[1000000,3] broadcast(constant.7), dimensions={}
- Arg_0.1 = f32[3] parameter(0)
- reshape.10 = f32[1,3] reshape(Arg_0.1)
- broadcast.11 = f32[1,3] broadcast(reshape.10), dimensions={0,1}
- reshape.12 = f32[3] reshape(broadcast.11)
- broadcast.13 = f32[1000000,3] broadcast(reshape.12), dimensions={1}
- subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13)
- constant.4 = f32[] constant(0)
- broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={}
- dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot.16 = f32[1000000,3] dot(broadcast.3, dot.15), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
- dot.17 = f32[1000000,3] dot(broadcast.3, subtract.14), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
- dot.18 = f32[1000000,3] dot(dot.17, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- add.19 = f32[1000000,3] add(dot.16, dot.18)
- constant.9 = f32[3] constant({1, 2, 3})
- dot.20 = f32[1000000,3] dot(broadcast.3, constant.9), lhs_contracting_dims={}, rhs_contracting_dims={}
- add.21 = f32[1000000,3] add(add.19, dot.20)
- constant.6 = f32[] constant(0)
- reduce.26 = f32[3] reduce(add.21, constant.6), dimensions={0}, to_apply=region_0.22
- reshape.27 = f32[1,3] reshape(reduce.26)
- negate.28 = f32[1,3] negate(reshape.27)
- ROOT reduce.33 = f32[3] reduce(negate.28, constant.6), dimensions={0}, to_apply=region_1.29
+ ENTRY matmul.test.f32 {
+ arg0.1 = f32[128,32,4,4] parameter(0)
+ arg0.2 = f32[128,32,4,4] parameter(1)
+ dot.7 = f32[128,32,4,4] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ const.0 = f32[128,32] constant({...})
+ bcast.1 = f32[128,32,4,4] broadcast(const.0), dimensions={0,1}
+ add.0 = f32[128,32,4,4] add(dot.7,bcast.1)
+ const.1 = f32[4] constant({1,2,3,4})
+ bcast.2 = f32[128,32,4,4] broadcast(const.1), dimensions={3}
+ add.1 = f32[128,32,4,4] add(add.0, bcast.2)
+ tuple.12 = (f32[128,32,4,4]) tuple(add.1)
+ ROOT get-tuple-element.13 = f32[128,32,4,4] get-tuple-element(tuple.12), index=0
})";
EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
index 0d4317b..534b359 100644
--- a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
+++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
@@ -31,6 +31,7 @@
switch (instruction->opcode()) {
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
+ case HloOpcode::kCall:
case HloOpcode::kConstant:
case HloOpcode::kConcatenate:
case HloOpcode::kConvert:
diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
index 7bb40d7..ec4d07e 100644
--- a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
+++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
@@ -65,10 +65,15 @@
const char* const hlo_string = R"(
HloModule Module
- ENTRY main {
+ bcast {
p0 = u4[] parameter(0)
ROOT out = u4[3, 3] broadcast(p0), dimensions={}
}
+
+ ENTRY main {
+ p0 = u4[] parameter(0)
+ ROOT out = u4[3, 3] call(p0), to_apply=bcast
+ }
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc
index 13b3032..3aa3a88 100644
--- a/third_party/xla/xla/service/dump.cc
+++ b/third_party/xla/xla/service/dump.cc
@@ -50,10 +50,10 @@
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/hlo_graph_dumper.h"
#include "xla/service/hlo_proto_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/util.h"
#include "tsl/lib/io/zlib_compression_options.h"
#include "tsl/lib/io/zlib_outputbuffer.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/file_system.h"
@@ -64,6 +64,24 @@
namespace xla {
+absl::Status CreateDirIfNeeded(const std::string& dir, tsl::Env* env) {
+ if (!env->IsDirectory(dir).ok()) {
+ absl::Status status = env->RecursivelyCreateDir(dir);
+ // Two threads can race to observe the absence of the dump directory and
+ // simultaneously try to create it, causing the "losing" thread to get a
+ // "directory already exists" error. We can work around this by checking
+ // again whether the dir exists.
+ if (!status.ok()) {
+ status = env->IsDirectory(dir);
+ if (!status.ok()) {
+        LOG(ERROR) << "Could not create directory " << dir << ": " << status;
+ return status;
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
std::string RenderGraph(absl::string_view label, const HloModule& module,
RenderedGraphFormat format,
bool show_fusion_subcomputations) {
@@ -299,17 +317,8 @@
VLOG(1) << "Dumping " << filename << " to " << dir;
tsl::Env* env = tsl::Env::Default();
- // Two threads can race to observe the absence of the dump directory and
- // simultaneously try to create it, causing the "losing" thread to get a
- // "directory already exists" error. We can work around this by checking
- // again whether the dir exists.
- if (!env->IsDirectory(dir).ok()) {
- auto status = env->RecursivelyCreateDir(dir);
- if (!status.ok() && !env->IsDirectory(dir).ok()) {
- LOG(ERROR) << "Could not create directory " << dir
- << " for dumping XLA debug data: " << status;
- return std::nullopt;
- }
+ if (!CreateDirIfNeeded(dir, env).ok()) {
+ return std::nullopt;
}
// Make sure we are not going to dump more modules than the user has asked.
@@ -677,15 +686,7 @@
if (dir.empty()) {
return;
}
- if (!env->IsDirectory(dir).ok()) {
- auto status = env->RecursivelyCreateDir(dir);
- if (!status.ok()) {
- LOG(ERROR) << "Could not create directory " << dir
- << " for dumping: " << status;
- return;
- }
- }
- if (!env->IsDirectory(dir).ok()) {
+ if (!CreateDirIfNeeded(dir, env).ok()) {
return;
}
const std::string path = tsl::io::JoinPath(dir, filename);
@@ -884,4 +885,20 @@
}
}
+absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
+ const std::string& directory,
+ const std::string& file_name,
+ std::string* full_path) {
+ tsl::Env* env = tsl::Env::Default();
+  // CreateDirIfNeeded already creates the directory recursively and tolerates
+  // concurrent-creation races, so no direct RecursivelyCreateDir call is needed.
+  TF_RETURN_IF_ERROR(CreateDirIfNeeded(directory, env));
+ std::string safe_file_name = SanitizeFileName(file_name) + ".pb";
+ std::string full_path_impl;
+ if (!full_path) {
+ full_path = &full_path_impl;
+ }
+ *full_path = tsl::io::JoinPath(directory, safe_file_name);
+ return tsl::WriteBinaryProto(env, *full_path, message);
+}
+
} // namespace xla
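
The CreateDirIfNeeded helper factored out above now backs the dump-directory call sites in this file and is exported through dump.h below. A minimal usage sketch, with a hypothetical dump path that is not taken from this patch:

#include <string>

#include "absl/status/status.h"
#include "tsl/platform/env.h"
#include "xla/service/dump.h"

absl::Status EnsureDumpDir(const std::string& dir) {
  tsl::Env* env = tsl::Env::Default();
  // Safe to call from several threads at once: a thread that loses the
  // creation race falls back to re-checking that the directory now exists.
  return xla::CreateDirIfNeeded(dir, env);
}

Callers that previously open-coded the IsDirectory/RecursivelyCreateDir dance can route through this helper instead.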
diff --git a/third_party/xla/xla/service/dump.h b/third_party/xla/xla/service/dump.h
index 0b1a6d2..623e729 100644
--- a/third_party/xla/xla/service/dump.h
+++ b/third_party/xla/xla/service/dump.h
@@ -43,6 +43,10 @@
class BufferAssignment;
class HloSnapshot;
+// Creates the directory if it doesn't exist (analogue of `mkdir -p`) and works
+// around creation races by re-checking that the directory exists on collision.
+absl::Status CreateDirIfNeeded(const std::string& dir, tsl::Env* env);
+
// Get a timestamp which we can use as a filename prefix specific to this
// module.
std::string TimestampFor(const HloModule& module);
@@ -173,6 +177,18 @@
// writing to two files, but you don't want to print twice.
bool DumpingToStdout(const DebugOptions& opts);
+// Writes the given message in binary proto to the path formed by joining
+// 'directory/file_name.pb'. The 'directory' is recursively created if it
+// doesn't already exist, and the 'file_name' is sanitized by replacing
+// illegal characters with underscore '_'.
+//
+// If 'full_path' is not null then it is set to the path of the file the
+// protobuf was written to.
+absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
+ const std::string& directory,
+ const std::string& file_name,
+ std::string* full_path = nullptr);
+
} // namespace xla
#endif // XLA_SERVICE_DUMP_H_
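
DumpProtoToDirectory, declared above, provides a generic way to write a protobuf into a dump directory. A minimal caller sketch, where the HloProto argument, directory, and file name are illustrative choices rather than anything mandated by this header:

#include <string>

#include "absl/status/status.h"
#include "tsl/platform/logging.h"
#include "xla/service/dump.h"
#include "xla/service/hlo.pb.h"

absl::Status DumpModuleProto(const xla::HloProto& proto,
                             const std::string& dump_dir) {
  std::string written_path;
  // Writes <dump_dir>/<sanitized file name>.pb, creating dump_dir if needed.
  absl::Status status = xla::DumpProtoToDirectory(
      proto, dump_dir, "module_before_optimizations", &written_path);
  if (status.ok()) {
    LOG(INFO) << "Wrote proto to " << written_path;
  }
  return status;
}

Passing nullptr (the default) for the last argument simply skips reporting the resolved path.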
diff --git a/third_party/xla/xla/service/executable.cc b/third_party/xla/xla/service/executable.cc
index ed86114..aa81fce 100644
--- a/third_party/xla/xla/service/executable.cc
+++ b/third_party/xla/xla/service/executable.cc
@@ -25,7 +25,6 @@
#include "xla/service/maybe_owning_device_memory.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 7c04293..080c7e8 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -36,7 +36,6 @@
load(
"//xla/tsl:tsl.bzl",
"if_google",
- "if_oss",
"internal_visibility",
"tsl_copts",
"tsl_gpu_library",
@@ -314,7 +313,6 @@
":execution_stream_assignment",
":gpu_asm_opts_util",
":gpu_conv_runner",
- ":gpu_fused_mha_runner",
":gpu_norm_runner",
":hlo_fusion_analysis",
":ir_emission_utils",
@@ -357,9 +355,9 @@
"//xla/service/gpu/runtime:conditional_thunk",
"//xla/service/gpu/runtime:convolution_thunk",
"//xla/service/gpu/runtime:copy_thunk",
+ "//xla/service/gpu/runtime:cudnn_thunk",
"//xla/service/gpu/runtime:custom_call_thunk",
"//xla/service/gpu/runtime:fft_thunk",
- "//xla/service/gpu/runtime:fused_mha_thunk",
"//xla/service/gpu/runtime:gemm_thunk",
"//xla/service/gpu/runtime:gpublas_lt_matmul_thunk",
"//xla/service/gpu/runtime:infeed_thunk",
@@ -486,135 +484,6 @@
)
cc_library(
- name = "gemm_fusion_autotuner",
- srcs = if_cuda_is_configured(["gemm_fusion_autotuner.cc"]),
- hdrs = if_cuda_is_configured(["gemm_fusion_autotuner.h"]),
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
- deps = if_cuda_is_configured([
- ":autotuner_compile_util",
- ":autotuner_util",
- ":backend_configs_cc",
- ":buffer_comparator",
- ":gemm_rewriter",
- ":gpu_float_support",
- ":gpu_fusible",
- ":instruction_fusion",
- ":ir_emission_utils",
- ":matmul_utils",
- ":split_k_gemm_rewriter",
- ":stream_executor_util",
- ":cudnn_fusion_compiler",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- "@local_config_cuda//cuda:cuda_headers",
- "//xla:autotuning_proto_cc",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla/tools:hlo_decomposer_lib",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/pjrt/distributed:key_value_store_interface",
- "//xla/service:algorithm_util",
- "//xla/service:dump",
- "//xla/service:executable",
- "//xla/service:float_normalization",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla/service:shaped_buffer",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:device_memory",
- "//xla/stream_executor",
- "//xla/stream_executor/gpu:redzone_allocator",
- "@local_tsl//tsl/lib/core:bits",
- "@local_tsl//tsl/platform:blocking_counter",
- "@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:protobuf",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/profiler/lib:scoped_annotation",
- "//xla/tsl/util/proto:proto_utils",
- "//xla/service/gpu:hlo_traversal",
- ":fusion_wrapper",
- ":priority_fusion",
- "//xla/service/gpu/model:gpu_hlo_cost_analysis",
- "//xla/stream_executor:stream_executor_memory_allocator",
- "@local_tsl//tsl/platform:path",
- ]),
-)
-
-xla_test(
- name = "gemm_fusion_autotuner_test",
- srcs = if_cuda_is_configured(["gemm_fusion_autotuner_test.cc"]),
- backend_tags = {"gpu": [
- "requires-gpu-sm80",
- ]},
- backends = [
- "gpu",
- ],
- tags = [
- "nomac",
- ],
- deps = [
- ":autotuner_util",
- ":backend_configs_cc",
- ":gemm_fusion",
- ":gemm_fusion_autotuner",
- ":gemm_rewriter",
- ":ir_emission_utils",
- ":matmul_utils",
- "//xla:autotuning_proto_cc",
- "//xla:error_spec",
- "//xla:xla_data_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/pjrt/distributed:key_value_store_interface",
- "//xla/service:call_inliner",
- "//xla/service:dump",
- "//xla/service:executable",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass_pipeline",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:device_description_proto_cc",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "//xla/tests:test_utils",
- "//xla/tests:verified_hlo_module",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "//xla/tools:hlo_decomposer_lib",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:path",
- "@local_tsl//tsl/platform:platform_port",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- ] + if_cuda_is_configured([
- "@local_config_cuda//cuda:cuda_headers",
- ]),
-)
-
-cc_library(
name = "triton_call",
srcs = if_gpu_is_configured(["triton_call.cc"]),
hdrs = ["triton_call.h"],
@@ -789,6 +658,7 @@
"//xla/service/llvm_ir:llvm_type_conversion_util",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor:device_description",
+ "//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@@ -800,7 +670,6 @@
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:statusor",
],
@@ -817,11 +686,11 @@
"//xla:literal_util",
"//xla:shape_util",
"//xla:types",
- "//xla:util",
"//xla/hlo/ir:backend_config",
"//xla/service:buffer_assignment",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
@@ -833,6 +702,7 @@
name = "reduction_utils",
srcs = ["reduction_utils.cc"],
hdrs = ["reduction_utils.h"],
+ compatible_with = get_compatible_with_portable(),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":ir_emission_utils",
@@ -914,45 +784,6 @@
)
cc_library(
- name = "gemm_rewriter",
- srcs = ["gemm_rewriter.cc"],
- hdrs = ["gemm_rewriter.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":ir_emission_utils",
- ":matmul_utils",
- "//xla:literal",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:types",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/evaluator:hlo_evaluator",
- "//xla/hlo/ir:hlo",
- "//xla/service:algorithm_util",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "//xla/stream_executor:blas",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor/gpu:gpu_blas_lt",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:ml_dtypes",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/protobuf:dnn_proto_cc",
- ],
-)
-
-cc_library(
name = "triton_tiling_propagation",
srcs = ["triton_tiling_propagation.cc"],
hdrs = ["triton_tiling_propagation.h"],
@@ -1017,9 +848,9 @@
name = "triton_fusion_analysis_test",
srcs = ["triton_fusion_analysis_test.cc"],
deps = [
- ":gemm_fusion",
":triton_fusion_analysis",
"//xla/hlo/ir:hlo",
+ "//xla/service/gpu/transforms:gemm_fusion",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
@@ -1031,100 +862,6 @@
)
cc_library(
- name = "gemm_fusion",
- srcs = ["gemm_fusion.cc"],
- hdrs = ["gemm_fusion.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_padding_requirements",
- ":ir_emission_utils",
- ":matmul_utils",
- ":triton_fusion_analysis",
- ":triton_tiling_propagation",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:instruction_fusion",
- "//xla/service/gpu/fusions/triton:triton_support",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:tensor_float_32_utils",
- ],
-)
-
-xla_cc_test(
- name = "gemm_fusion_test",
- srcs = ["gemm_fusion_test.cc"],
- deps = [
- ":cublas_padding_requirements",
- ":gemm_fusion",
- ":triton_fusion_analysis",
- "//xla:autotuning_proto_cc",
- "//xla:xla_data_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/stream_executor:device_description",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "//xla/tests:verified_hlo_module",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gemv_rewriter",
- srcs = ["gemv_rewriter.cc"],
- hdrs = ["gemv_rewriter.h"],
- deps = [
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gemv_rewriter_test",
- srcs = ["gemv_rewriter_test.cc"],
- deps = [
- ":gemv_rewriter",
- "//xla/hlo/ir:hlo",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_absl//absl/status:statusor",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "split_k_gemm_rewriter",
srcs = ["split_k_gemm_rewriter.cc"],
hdrs = ["split_k_gemm_rewriter.h"],
@@ -1185,216 +922,6 @@
)
cc_library(
- name = "softmax_rewriter_triton",
- srcs = ["softmax_rewriter_triton.cc"],
- hdrs = ["softmax_rewriter_triton.h"],
- deps = [
- ":backend_configs_cc",
- ":hlo_traversal",
- ":ir_emission_utils",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:hlo_cost_analysis",
- "//xla/service:hlo_pass",
- "//xla/service:instruction_fusion",
- "//xla/service/gpu/fusions/triton:triton_support",
- "//xla/service/gpu/model:fusion_analysis_cache",
- "//xla/service/gpu/model:gpu_indexing_performance_model",
- "//xla/service/gpu/model:symbolic_tile_analysis",
- "//xla/service/gpu/model:tiled_hlo_computation",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gemm_algorithm_picker",
- srcs = if_gpu_is_configured(["gemm_algorithm_picker.cc"]),
- hdrs = if_gpu_is_configured(["gemm_algorithm_picker.h"]),
- deps = if_gpu_is_configured([
- ":backend_configs_cc",
- ":buffer_comparator",
- ":cublas_cudnn",
- ":gpu_asm_opts_util",
- ":gpu_conv_runner",
- ":ir_emission_utils",
- ":matmul_utils",
- ":stream_executor_util",
- ":variant_visitor",
- ":autotuner_compile_util",
- ":autotuner_util",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:span",
- "//xla:autotune_results_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla:status_macros",
- "//xla/stream_executor",
- "//xla/stream_executor:blas",
- "//xla/stream_executor/gpu:gpu_blas_lt",
- "//xla/stream_executor:device_memory",
- "//xla/stream_executor:device_memory_allocator",
- "//xla/stream_executor/gpu:redzone_allocator",
- "//xla/tsl/util/proto:proto_utils",
- "//xla:util",
- "//xla:autotuning_proto_cc",
- "//xla:shape_util",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/profiler/lib:scoped_annotation",
- ]) + ["@com_google_absl//absl/status"],
-)
-
-cc_library(
- name = "autotuner_util",
- srcs = if_gpu_is_configured(["autotuner_util.cc"]),
- hdrs = if_gpu_is_configured(["autotuner_util.h"]),
- deps = if_gpu_is_configured([
- ":gpu_asm_opts_util",
- ":stream_executor_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/base",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/hash",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@llvm-project//llvm:Core",
- "@llvm-project//llvm:Support",
- "//xla:autotune_results_proto_cc",
- "//xla:autotuning_proto_cc",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:types",
- "//xla:util",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:compilation_environments",
- "//xla/stream_executor:stream_executor_memory_allocator",
- "//xla/stream_executor",
- "//xla/stream_executor/gpu:redzone_allocator",
- "@local_tsl//tsl/platform:base64",
- "@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:path",
- "@local_tsl//tsl/platform:protobuf",
- "@local_tsl//tsl/platform:statusor",
- ]),
-)
-
-# We need a separate target, as runtime executable cannot depend on compilation
-# pipeline.
-cc_library(
- name = "autotuner_compile_util",
- srcs = if_gpu_is_configured(["autotuner_compile_util.cc"]),
- hdrs = if_gpu_is_configured(["autotuner_compile_util.h"]),
- deps = if_gpu_is_configured([
- ":autotuner_util",
- ":gpu_executable_run_options",
- ":ir_emission_utils",
- "@com_google_absl//absl/functional:any_invocable",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_absl//absl/types:span",
- "//xla/hlo/ir:hlo",
- "//xla/service:compiler",
- "//xla/service:executable",
- "//xla/service:hlo_module_config",
- "//xla/service:maybe_owning_device_memory",
- "//xla/service:shaped_buffer",
- "//xla/stream_executor",
- "//xla/stream_executor/gpu:gpu_stream_header",
- "//xla/stream_executor/gpu:redzone_allocator",
- "//xla:executable_run_options",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_proto_cc",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ]) + ["@com_google_absl//absl/status"],
-)
-
-xla_test(
- name = "autotuner_compile_util_test",
- srcs = if_gpu_is_configured(["autotuner_compile_util_test.cc"]),
- backends = ["gpu"],
- deps = if_gpu_is_configured(
- [
- ":autotuner_compile_util",
- ":autotuner_util",
- "@com_google_googletest//:gtest_main",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "//xla/hlo/ir:hlo",
- "//xla/service:platform_util",
- "//xla/stream_executor:platform",
- "//xla/tests:hlo_test_base",
- "@local_tsl//tsl/platform:statusor",
- ],
- if_false = [
- "@com_google_googletest//:gtest_main", # b/317293391
- ],
- ),
-)
-
-xla_test(
- name = "gemm_algorithm_picker_test",
- srcs = if_gpu_is_configured(["gemm_algorithm_picker_test.cc"]),
- backends = [
- "gpu_v100",
- "gpu_amd_any",
- ],
- deps = [
- ":autotuner_util",
- ":backend_configs_cc",
- ":gemm_algorithm_picker",
- ":gemm_rewriter",
- ":variant_visitor",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service:platform_util",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:platform",
- "//xla/tests:hlo_test_base",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- "@local_tsl//tsl/protobuf:dnn_proto_cc",
- ],
-)
-
-cc_library(
name = "matmul_utils",
srcs = ["matmul_utils.cc"],
hdrs = ["matmul_utils.h"],
@@ -1465,239 +992,6 @@
)
cc_library(
- name = "dot_dimension_sorter",
- srcs = ["dot_dimension_sorter.cc"],
- hdrs = ["dot_dimension_sorter.h"],
- deps = [
- "//xla:permutation_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- ],
-)
-
-xla_test(
- name = "dot_dimension_sorter_test",
- srcs = ["dot_dimension_sorter_test.cc"],
- backends = ["gpu"],
- deps = [
- ":dot_dimension_sorter",
- "//xla:error_spec",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu/tests:gpu_codegen_test",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "dot_sparsity_rewriter",
- srcs = ["dot_sparsity_rewriter.cc"],
- hdrs = ["dot_sparsity_rewriter.h"],
- deps = [
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "dot_sparsity_rewriter_test",
- srcs = ["dot_sparsity_rewriter_test.cc"],
- deps = [
- ":dot_sparsity_rewriter",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_async_collective_annotator",
- srcs = ["gpu_async_collective_annotator.cc"],
- hdrs = ["gpu_async_collective_annotator.h"],
- deps = [
- ":backend_configs_cc",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gpu_async_collective_annotator_test",
- srcs = ["gpu_async_collective_annotator_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":gpu_async_collective_annotator",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/tests:hlo_test_base",
- "//xla/tests:test_macros_header",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_convert_async_collectives_to_sync",
- srcs = ["gpu_convert_async_collectives_to_sync.cc"],
- hdrs = ["gpu_convert_async_collectives_to_sync.h"],
- deps = [
- ":backend_configs_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:convert_async_collectives_to_sync",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gpu_convert_async_collectives_to_sync_test",
- srcs = ["gpu_convert_async_collectives_to_sync_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":gpu_convert_async_collectives_to_sync",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "conv_algorithm_picker",
- srcs = if_gpu_is_configured(["conv_algorithm_picker.cc"]),
- hdrs = if_gpu_is_configured(["conv_algorithm_picker.h"]),
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
- "TENSORFLOW_USE_ROCM=1",
- ]),
- deps = if_gpu_is_configured([
- ":autotuner_compile_util",
- ":autotuner_util",
- ":backend_configs_cc",
- ":buffer_comparator",
- ":cublas_cudnn",
- ":gpu_asm_opts_util",
- ":gpu_autotuning_proto_cc",
- ":gpu_conv_runner",
- ":hlo_algorithm_denylist",
- ":stream_executor_util",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_absl//absl/types:span",
- "@local_config_cuda//cuda:cudnn_header",
- "//xla:autotune_results_proto_cc",
- "//xla:autotuning_proto_cc",
- "//xla:debug_options_flags",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:executable",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla/service:slow_operation_alarm",
- "//xla/stream_executor",
- "//xla/stream_executor:dnn",
- "//xla/stream_executor:numeric_options",
- "//xla/stream_executor:scratch_allocator",
- "//xla/stream_executor:device_memory_allocator",
- "//xla/stream_executor:lazy_op_runner",
- "//xla/stream_executor/cuda:cuda_platform_id",
- "//xla/stream_executor/gpu:redzone_allocator",
- "//xla/stream_executor/rocm:rocm_platform_id",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:numbers",
- "//xla/tsl/util:env_var",
- "@local_tsl//tsl/platform:statusor",
- "//xla/tsl/util/proto:proto_utils",
- "@local_tsl//tsl/platform:status",
- ]),
-)
-
-xla_test(
- name = "conv_algorithm_picker_test",
- srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]),
- backends = [
- "gpu_v100",
- "gpu_amd_any",
- ],
- tags = [
- "noasan",
- "nomsan",
- ],
- deps = [
- ":autotuner_util",
- ":conv_algorithm_picker",
- ":gpu_conv_rewriter",
- ":stream_executor_util",
- "//xla:debug_options_flags",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service:platform_util",
- "//xla/service:tuple_simplifier",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:platform",
- "//xla/tests:hlo_test_base",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
name = "gpu_conv_runner",
srcs = ["gpu_conv_runner.cc"],
hdrs = ["gpu_conv_runner.h"],
@@ -1751,167 +1045,6 @@
)
cc_library(
- name = "gpu_fused_mha_runner",
- srcs = ["gpu_fused_mha_runner.cc"],
- hdrs = ["gpu_fused_mha_runner.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":stream_executor_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/stream_executor",
- "//xla/stream_executor:dnn",
- "//xla/stream_executor:lazy_op_runner",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@eigen_archive//:eigen3",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_conv_rewriter",
- srcs = ["gpu_conv_rewriter.cc"],
- hdrs = ["gpu_conv_rewriter.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- "//xla:permutation_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:window_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_sort_rewriter",
- srcs = if_gpu_is_configured(
- ["gpu_sort_rewriter.cc"],
- ["gpu_sort_rewriter_stub.cc"],
- ),
- hdrs = ["gpu_sort_rewriter.h"],
- deps = [
- ":cublas_cudnn",
- "//xla:comparison_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:stable_sort_expander",
- "//xla/service/gpu/runtime:cub_sort_thunk",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "move_copy_to_users",
- srcs = ["move_copy_to_users.cc"],
- hdrs = ["move_copy_to_users.h"],
- deps = [
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "move_copy_to_users_test",
- srcs = ["move_copy_to_users_test.cc"],
- deps = [
- ":move_copy_to_users",
- "//xla/service:layout_assignment",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-xla_cc_test(
- name = "gpu_conv_rewriter_test",
- srcs = ["gpu_conv_rewriter_test.cc"],
- deps = [
- ":cublas_cudnn",
- ":gpu_conv_rewriter",
- "//xla:array4d",
- "//xla:literal_util",
- "//xla:protobuf_util",
- "//xla:shape_util",
- "//xla:test",
- "//xla:test_helpers",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service:shape_inference",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/strings:str_format",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- ],
-)
-
-xla_test(
- name = "gpu_sort_rewriter_test",
- srcs = if_cuda_is_configured(["gpu_sort_rewriter_test.cc"]),
- backends = ["gpu"],
- tags = ["no_oss"],
- deps = [
- ":cublas_cudnn",
- ":gpu_sort_rewriter",
- "//xla:error_spec",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- ],
-)
-
-cc_library(
name = "cusolver_context",
srcs = if_gpu_is_configured(["cusolver_context.cc"]),
hdrs = if_gpu_is_configured(["cusolver_context.h"]),
@@ -1942,83 +1075,6 @@
]),
)
-cc_library(
- name = "cusolver_rewriter",
- srcs = if_gpu_is_configured(["cusolver_rewriter.cc"]),
- hdrs = if_gpu_is_configured(["cusolver_rewriter.h"]),
- deps = if_gpu_is_configured([
- ":cusolver_context",
- ":ir_emission_utils",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "//xla:comparison_util",
- "//xla:literal",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/stream_executor",
- "//xla/stream_executor:blas",
- "//xla/stream_executor:device_memory_allocator",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ]),
-)
-
-cc_library(
- name = "instruction_fusion",
- srcs = ["instruction_fusion.cc"],
- hdrs = ["instruction_fusion.h"],
- deps = [
- ":gpu_fusible",
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:fusion_node_indexing_evaluation",
- "//xla/service:fusion_queue",
- "//xla/service:hlo_pass",
- "//xla/service:instruction_fusion",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/meta:type_traits",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- ],
-)
-
-xla_cc_test(
- name = "instruction_fusion_test",
- srcs = ["instruction_fusion_test.cc"],
- tags = [
- "nomsan",
- "not_run:arm",
- ],
- deps = [
- ":gpu_device_info_for_tests",
- ":gpu_fusible",
- ":instruction_fusion",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:test_utils",
- "//xla/tests:verified_hlo_module",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
tf_proto_library(
name = "fusion_process_dump_proto",
srcs = ["fusion_process_dump.proto"],
@@ -2074,311 +1130,6 @@
)
cc_library(
- name = "priority_fusion",
- srcs = ["priority_fusion.cc"],
- hdrs = ["priority_fusion.h"],
- deps = [
- ":backend_configs_cc",
- ":fusion_process_dump_proto_cc",
- ":gpu_fusible",
- ":hlo_fusion_analysis",
- ":hlo_traversal",
- "//xla:debug_options_flags",
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:dump",
- "//xla/service:fusion_queue",
- "//xla/service:hlo_cost_analysis",
- "//xla/service:hlo_graph_dumper",
- "//xla/service:hlo_pass",
- "//xla/service:instruction_fusion",
- "//xla/service/gpu/model:fusion_analysis_cache",
- "//xla/service/gpu/model:gpu_hlo_cost_analysis",
- "//xla/service/gpu/model:gpu_performance_model",
- "//xla/service/gpu/model:gpu_performance_model_base",
- "//xla/service/gpu/model:symbolic_tile_analysis",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/meta:type_traits",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:blocking_counter",
- "@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:status",
- ],
-)
-
-xla_cc_test(
- name = "priority_fusion_test",
- srcs = ["priority_fusion_test.cc"],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
- tags = ["no_pip"],
- deps = [
- ":backend_configs_cc",
- ":gpu_device_info_for_tests",
- ":gpu_fusible",
- ":hlo_fusion_analysis",
- ":priority_fusion",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_cost_analysis",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service/gpu/model:gpu_hlo_cost_analysis",
- "//xla/tests:hlo_test_base",
- "//xla/tests:verified_hlo_module",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "multi_output_fusion",
- srcs = ["multi_output_fusion.cc"],
- hdrs = ["multi_output_fusion.h"],
- deps = [
- ":gpu_fusible",
- "//xla:debug_options_flags",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/ir:hlo_dfs_reachability",
- "//xla/service:hlo_cost_analysis",
- "//xla/service:hlo_graph_dumper",
- "//xla/service:hlo_pass",
- "//xla/service:instruction_fusion",
- "//xla/service/gpu/model:gpu_hlo_cost_analysis",
- "//xla/service/gpu/model:gpu_performance_model",
- "//xla/service/gpu/model:gpu_performance_model_base",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "multi_output_fusion_test",
- srcs = ["multi_output_fusion_test.cc"],
- tags = [
- "nomsan",
- ],
- deps = [
- ":gpu_device_info_for_tests",
- ":gpu_fusible",
- ":multi_output_fusion",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_cost_analysis",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "rename_fusions",
- srcs = ["rename_fusions.cc"],
- hdrs = ["rename_fusions.h"],
- deps = [
- ":hlo_traversal",
- ":ir_emission_utils",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:btree",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- ],
-)
-
-xla_cc_test(
- name = "rename_fusions_test",
- srcs = ["rename_fusions_test.cc"],
- deps = [
- ":rename_fusions",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest",
- ],
-)
-
-xla_cc_test(
- name = "softmax_rewriter_triton_test",
- srcs = ["softmax_rewriter_triton_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":gpu_device_info_for_tests",
- ":softmax_rewriter_triton",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:instruction_fusion",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service/gpu/fusions/triton:triton_support",
- "//xla/service/gpu/model:gpu_hlo_cost_analysis",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # build_cleaner: keep
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_sanitize_constant_names",
- srcs = ["gpu_sanitize_constant_names.cc"],
- hdrs = ["gpu_sanitize_constant_names.h"],
- deps = [
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:name_uniquer",
- "//xla/service/llvm_ir:buffer_assignment_util",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:logging",
- ],
-)
-
-xla_cc_test(
- name = "gpu_sanitize_constant_names_test",
- srcs = ["gpu_sanitize_constant_names_test.cc"],
- deps = [
- ":gpu_sanitize_constant_names",
- "//xla:literal_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- ],
-)
-
-cc_library(
- name = "fusion_merger",
- srcs = ["fusion_merger.cc"],
- hdrs = ["fusion_merger.h"],
- deps = [
- ":gpu_fusible",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_cost_analysis",
- "//xla/service:hlo_graph_dumper",
- "//xla/service:hlo_pass",
- "//xla/service:instruction_fusion",
- "//xla/service/gpu/model:gpu_hlo_cost_analysis",
- "//xla/service/gpu/model:gpu_performance_model",
- "//xla/service/gpu/model:gpu_performance_model_base",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:status",
- ],
-)
-
-xla_cc_test(
- name = "fusion_merger_test",
- srcs = ["fusion_merger_test.cc"],
- tags = [
- "nomsan",
- ],
- deps = [
- ":fusion_merger",
- ":gpu_device_info_for_tests",
- ":gpu_fusible",
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_cost_analysis",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/types:span",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "gpu_conv_padding_legalization",
- srcs = ["gpu_conv_padding_legalization.cc"],
- hdrs = ["gpu_conv_padding_legalization.h"],
- deps = [
- ":cublas_cudnn",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:window_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:shape_inference",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gpu_conv_padding_legalization_test",
- srcs = ["gpu_conv_padding_legalization_test.cc"],
- deps = [
- ":cublas_cudnn",
- ":gpu_conv_padding_legalization",
- "//xla:shape_util",
- "//xla:test",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@local_tsl//tsl/platform:test",
- ],
-)
-
-cc_library(
name = "cudnn_support_utils",
srcs = ["cudnn_support_utils.cc"],
hdrs = ["cudnn_support_utils.h"],
@@ -2420,183 +1171,6 @@
)
cc_library(
- name = "cudnn_pad_for_convolutions",
- srcs = ["cudnn_pad_for_convolutions.cc"],
- hdrs = ["cudnn_pad_for_convolutions.h"],
- deps = [
- ":cublas_cudnn",
- ":cudnn_support_utils",
- ":stream_executor_util",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/stream_executor",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/functional:bind_front",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "cudnn_pad_for_convolutions_test",
- srcs = ["cudnn_pad_for_convolutions_test.cc"],
- deps = [
- ":cublas_cudnn",
- ":cudnn_pad_for_convolutions",
- "//xla/service:hlo_parser",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # build_cleaner: keep
- "@com_google_googletest//:gtest",
- ],
-)
-
-cc_library(
- name = "cudnn_vectorize_convolutions",
- srcs = ["cudnn_vectorize_convolutions.cc"],
- hdrs = ["cudnn_vectorize_convolutions.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":cudnn_support_utils",
- ":stream_executor_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla/client:xla_builder",
- "//xla/client:xla_computation",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla/stream_executor",
- "//xla/stream_executor:dnn",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "cudnn_vectorize_convolutions_test",
- srcs = ["cudnn_vectorize_convolutions_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":cudnn_vectorize_convolutions",
- "//xla:util",
- "//xla/service:call_inliner",
- "//xla/service:hlo_parser",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:dnn",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # build_cleaner: keep
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/status:statusor",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "cudnn_simplify_padding",
- srcs = ["cudnn_simplify_padding.cc"],
- hdrs = ["cudnn_simplify_padding.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- "//xla:literal",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "cudnn_simplify_padding_test",
- srcs = ["cudnn_simplify_padding_test.cc"],
- deps = [
- ":cudnn_pad_for_convolutions",
- ":cudnn_simplify_padding",
- ":cudnn_vectorize_convolutions",
- "//xla:literal",
- "//xla:util",
- "//xla/service:algebraic_simplifier",
- "//xla/service:call_inliner",
- "//xla/service:hlo_pass",
- "//xla/service:hlo_pass_pipeline",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service:reshape_mover",
- "//xla/service:tuple_simplifier",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:dnn",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # build_cleaner: keep
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/functional:function_ref",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "cublas_pad_for_gemms",
- srcs = ["cublas_pad_for_gemms.cc"],
- hdrs = ["cublas_pad_for_gemms.h"],
- deps = [
- ":gemm_fusion",
- ":ir_emission_utils",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service/gpu/fusions/triton:triton_support",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "cublas_padding_requirements",
srcs = ["cublas_padding_requirements.cc"],
hdrs = ["cublas_padding_requirements.h"],
@@ -2609,92 +1183,6 @@
],
)
-xla_cc_test(
- name = "cublas_pad_for_gemms_test",
- srcs = ["cublas_pad_for_gemms_test.cc"],
- tags = [
- "nomsan",
- ],
- deps = [
- ":cublas_pad_for_gemms",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # build_cleaner: keep
- "@com_google_googletest//:gtest",
- ],
-)
-
-cc_library(
- name = "cudnn_fusion_compiler",
- srcs = if_cuda_is_configured(["cudnn_fusion_compiler.cc"]),
- hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]),
- deps = if_cuda_is_configured([
- ":backend_configs_cc",
- ":cudnn_support_utils",
- ":ir_emission_utils",
- ":kernel_reuse_cache",
- ":matmul_utils",
- ":triton_fusion_analysis",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_config_cuda//cuda:cudnn_header",
- "//xla:shape_util",
- "//xla:comparison_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:hlo_pass",
- "//xla/stream_executor:dnn",
- "//xla/stream_executor:stream_executor_h",
- "//xla/service:dump",
- "//xla/stream_executor/cuda:cudnn_frontend_helpers",
- "//xla/stream_executor/cuda:cudnn_plugin",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ]),
-)
-
-cc_library(
- name = "cudnn_workspace_rewriter",
- srcs = if_cuda_is_configured(["cudnn_workspace_rewriter.cc"]),
- hdrs = if_cuda_is_configured(["cudnn_workspace_rewriter.h"]),
- deps = if_cuda_is_configured([
- ":backend_configs_cc",
- ":ir_emission_utils",
- ":gpu_fused_mha_runner",
- ":cublas_cudnn",
- ":stream_executor_util",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_config_cuda//cuda:cudnn_header",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/stream_executor/cuda:cudnn_frontend_helpers",
- "//xla/stream_executor/cuda:cudnn_plugin",
- "//xla/stream_executor:dnn",
- "//xla/stream_executor:stream_executor_h",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- "//xla:status_macros",
- ]),
-)
-
tf_proto_library(
name = "executable_proto",
srcs = ["executable.proto"],
@@ -2749,43 +1237,6 @@
)
cc_library(
- name = "gpu_reduce_scatter_creator",
- srcs = ["gpu_reduce_scatter_creator.cc"],
- hdrs = ["gpu_reduce_scatter_creator.h"],
- deps = [
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:collective_opt_utils",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- ],
-)
-
-cc_library(
- name = "gpu_all_gather_optimizer",
- srcs = ["gpu_all_gather_optimizer.cc"],
- hdrs = ["gpu_all_gather_optimizer.h"],
- deps = [
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:collective_ops_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- ],
-)
-
-cc_library(
name = "gpu_float_support",
srcs = ["gpu_float_support.cc"],
hdrs = ["gpu_float_support.h"],
@@ -2856,256 +1307,10 @@
)
cc_library(
- name = "command_buffer_scheduling",
- srcs = ["command_buffer_scheduling.cc"],
- hdrs = ["command_buffer_scheduling.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":hlo_fusion_analysis",
- ":hlo_traversal",
- ":ir_emission_utils",
- ":variant_visitor",
- "//xla:shape_util",
- "//xla:util",
- "//xla/ffi:ffi_api",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "command_buffer_scheduling_test",
- srcs = ["command_buffer_scheduling_test.cc"],
- deps = [
- ":command_buffer_scheduling",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_parser",
- "//xla/stream_executor:device_description",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "//xla/tests:verified_hlo_module",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "custom_kernel_fusion_autotuner",
- srcs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.cc"]),
- hdrs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.h"]),
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
- deps = if_gpu_is_configured([
- ":autotuner_compile_util",
- ":autotuner_util",
- ":backend_configs_cc",
- ":buffer_comparator",
- ":gemm_rewriter",
- ":gpu_float_support",
- ":gpu_fusible",
- ":instruction_fusion",
- ":ir_emission_utils",
- ":matmul_utils",
- ":split_k_gemm_rewriter",
- "//xla/service/gpu/kernels:custom_kernel",
- "//xla/service/gpu/kernels:custom_kernel_fusion",
- ":stream_executor_util",
- ":cudnn_fusion_compiler",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- "@local_config_cuda//cuda:cuda_headers",
- "//xla:autotuning_proto_cc",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla/tools:hlo_decomposer_lib",
- "//xla:statusor",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:algorithm_util",
- "//xla/service:dump",
- "//xla/service:executable",
- "//xla/service:float_normalization",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla/service:shaped_buffer",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:device_memory",
- "//xla/stream_executor",
- "//xla/stream_executor/gpu:redzone_allocator",
- "@local_tsl//tsl/lib/core:bits",
- "@local_tsl//tsl/platform:blocking_counter",
- "@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:protobuf",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/profiler/lib:scoped_annotation",
- "//xla/tsl/util/proto:proto_utils",
- "//xla/service/gpu:hlo_traversal",
- ]) + [
- "//xla/stream_executor:stream_executor_memory_allocator",
- "@com_google_absl//absl/status",
- "@local_tsl//tsl/platform:path",
- ],
-)
-
-xla_test(
- name = "custom_kernel_fusion_autotuner_test",
- srcs = if_cuda_is_configured(["custom_kernel_fusion_autotuner_test.cc"]),
- backends = [
- "gpu",
- ],
- deps = [
- ":autotuner_util",
- ":custom_kernel_fusion_autotuner",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass_pipeline",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:path",
- "@local_tsl//tsl/platform:test",
- ],
-)
-
-cc_library(
- name = "custom_kernel_fusion_rewriter",
- srcs = ["custom_kernel_fusion_rewriter.cc"],
- hdrs = ["custom_kernel_fusion_rewriter.h"],
- deps = [
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service/gpu/kernels:custom_fusion_library",
- "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "custom_kernel_fusion_rewriter_test",
- srcs = ["custom_kernel_fusion_rewriter_test.cc"],
- deps = [
- ":custom_kernel_fusion_rewriter",
- ":gpu_device_info_for_tests",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
- name = "dynamic_slice_fusion_rewriter",
- srcs = ["dynamic_slice_fusion_rewriter.cc"],
- hdrs = ["dynamic_slice_fusion_rewriter.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":gpu_constants",
- ":hlo_traversal",
- ":ir_emission_utils",
- "//xla:shape_util",
- "//xla:util",
- "//xla/ffi:ffi_api",
- "//xla/ffi/api:c_api",
- "//xla/hlo/ir:hlo",
- "//xla/service:custom_call_target_registry",
- "//xla/service:hlo_pass",
- "//xla/service/gpu/kernels:custom_fusion_library",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "dynamic_slice_fusion_rewriter_test",
- srcs = if_cuda_is_configured(["dynamic_slice_fusion_rewriter_test.cc"]),
- deps = [
- ":dynamic_slice_fusion_rewriter",
- ":gpu_device_info_for_tests",
- "//xla:shape_util",
- "//xla/client:xla_builder",
- "//xla/client/lib:constants",
- "//xla/ffi",
- "//xla/ffi:ffi_api",
- "//xla/hlo/ir:hlo",
- "//xla/service:buffer_value",
- "//xla/service:custom_call_target_registry",
- "//xla/service:executable",
- "//xla/service:hlo_memory_scheduler",
- "//xla/service:hlo_module_config",
- "//xla/stream_executor",
- "//xla/stream_executor/gpu:gpu_types_header",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/status",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
name = "fusion_pipeline",
srcs = ["fusion_pipeline.cc"],
hdrs = ["fusion_pipeline.h"],
deps = [
- ":fusion_merger",
- ":horizontal_input_fusion",
- ":horizontal_loop_fusion",
- ":instruction_fusion",
- ":multi_output_fusion",
- ":priority_fusion",
- ":variadic_op_splitter",
"//xla:xla_proto_cc",
"//xla/service:cpu_gpu_shape_verifier",
"//xla/service:hlo_cost_analysis",
@@ -3116,6 +1321,13 @@
"//xla/service:hlo_verifier",
"//xla/service:layout_assignment",
"//xla/service/gpu/model:gpu_hlo_cost_analysis",
+ "//xla/service/gpu/transforms:fusion_merger",
+ "//xla/service/gpu/transforms:horizontal_input_fusion",
+ "//xla/service/gpu/transforms:horizontal_loop_fusion",
+ "//xla/service/gpu/transforms:instruction_fusion",
+ "//xla/service/gpu/transforms:multi_output_fusion",
+ "//xla/service/gpu/transforms:priority_fusion",
+ "//xla/service/gpu/transforms:variadic_op_splitter",
"//xla/stream_executor:device_description",
"@local_tsl//tsl/platform:env",
],
@@ -3126,10 +1338,6 @@
srcs = ["prepare_hlo_for_ir_emitting_pipeline.cc"],
hdrs = ["prepare_hlo_for_ir_emitting_pipeline.h"],
deps = [
- ":alias_passthrough_params",
- ":copy_fusion",
- ":gpu_sanitize_constant_names",
- ":horizontal_loop_fusion",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:copy_insertion",
@@ -3140,6 +1348,10 @@
"//xla/service:hlo_verifier",
"//xla/service:layout_assignment",
"//xla/service:loop_schedule_linearizer",
+ "//xla/service/gpu/transforms:alias_passthrough_params",
+ "//xla/service/gpu/transforms:copy_fusion",
+ "//xla/service/gpu/transforms:horizontal_loop_fusion",
+ "//xla/service/gpu/transforms:sanitize_constant_names",
],
)
@@ -3153,53 +1365,20 @@
]),
deps = if_gpu_is_configured([
# go/keep-sorted start prefix_order=":,,
- ":algorithm_checker",
- ":alias_passthrough_params",
- ":all_reduce_blueconnect",
- ":autotuner_util",
":buffer_sharing",
- ":collective_permute_cycle_decomposer",
- ":collective_permute_valid_iteration_annotator",
- ":command_buffer_scheduling",
":compile_module_to_llvm_ir",
":conv_layout_normalization",
- ":copy_fusion",
- ":custom_kernel_fusion_autotuner",
- ":custom_kernel_fusion_rewriter",
- ":dot_dimension_sorter",
- ":dot_operand_converter",
- ":double_buffer_loop_unrolling",
- ":dynamic_slice_fusion_rewriter",
":executable_proto_cc",
":execution_stream_assignment",
- ":fusion_merger",
":fusion_pipeline",
- ":fusion_wrapper",
- ":gemm_broadcast_folding_rewriter",
- ":gemm_fusion",
- ":gemm_rewriter",
- ":gemv_rewriter",
- ":gpu_algebraic_simplifier",
- ":gpu_all_gather_optimizer",
- ":gpu_async_collective_annotator",
":gpu_constants",
- ":gpu_conv_rewriter",
- ":gpu_convert_async_collectives_to_sync",
":gpu_executable",
":gpu_float_support",
":gpu_hlo_schedule",
":gpu_latency_hiding_scheduler",
- ":gpu_layout_assignment",
":gpu_p2p_pipeliner",
- ":gpu_reduce_scatter_creator",
- ":gpu_sanitize_constant_names",
- ":gpu_scatter_expander",
":gpu_spmd_pipeline",
- ":gpu_windowed_einsum_handler",
":hlo_fusion_stats",
- ":horizontal_input_fusion",
- ":horizontal_loop_fusion",
- ":instruction_fusion",
":ir_emission_utils",
":ir_emitter",
":ir_emitter_context",
@@ -3207,28 +1386,10 @@
":kernel_reuse_cache",
":matmul_utils",
":metrics",
- ":move_copy_to_users",
- ":multi_output_fusion",
- ":pipelined_p2p_rewriter",
":prepare_hlo_for_ir_emitting_pipeline",
- ":priority_fusion",
- ":reduction_degenerate_dim_remover",
- ":reduction_dimension_grouper",
- ":reduction_layout_normalizer",
- ":reduction_splitter",
":reduction_utils",
- ":rename_fusions",
":runtime_intrinsics",
- ":scatter_slice_simplifier",
- ":softmax_rewriter_triton",
- ":stream_attribute_annotator",
- ":stream_attribute_async_wrapper",
":stream_executor_util",
- ":topk_specializer",
- ":topk_splitter",
- ":tree_reduction_rewriter",
- ":triton_fusion_numerics_verifier",
- ":variadic_op_splitter",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@@ -3252,9 +1413,53 @@
"@llvm-project//mlir:Support",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
+ "//xla/service/gpu/autotuning:autotuner_util",
+ "//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner",
"//xla/service/gpu/model:gpu_cost_model_stats_collection",
"//xla/service/gpu/model:gpu_hlo_cost_analysis",
"//xla/service/gpu/runtime:thunk",
+ "//xla/service/gpu/transforms:algebraic_simplifier",
+ "//xla/service/gpu/transforms:algorithm_checker",
+ "//xla/service/gpu/transforms:all_gather_optimizer",
+ "//xla/service/gpu/transforms:all_reduce_blueconnect",
+ "//xla/service/gpu/transforms:all_reduce_splitter",
+ "//xla/service/gpu/transforms:async_collective_annotator",
+ "//xla/service/gpu/transforms:collective_permute_cycle_decomposer",
+ "//xla/service/gpu/transforms:collective_permute_valid_iteration_annotator",
+ "//xla/service/gpu/transforms:command_buffer_scheduling",
+ "//xla/service/gpu/transforms:conv_rewriter",
+ "//xla/service/gpu/transforms:convert_async_collectives_to_sync",
+ "//xla/service/gpu/transforms:cudnn_custom_call_converter",
+ "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
+ "//xla/service/gpu/transforms:dot_dimension_sorter",
+ "//xla/service/gpu/transforms:dot_operand_converter",
+ "//xla/service/gpu/transforms:double_buffer_loop_unrolling",
+ "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
+ "//xla/service/gpu/transforms:fusion_wrapper",
+ "//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter",
+ "//xla/service/gpu/transforms:gemm_fusion",
+ "//xla/service/gpu/transforms:gemm_rewriter",
+ "//xla/service/gpu/transforms:gemv_rewriter",
+ "//xla/service/gpu/transforms:layout_assignment",
+ "//xla/service/gpu/transforms:move_copy_to_users",
+ "//xla/service/gpu/transforms:pipelined_p2p_rewriter",
+ "//xla/service/gpu/transforms:reduce_scatter_creator",
+ "//xla/service/gpu/transforms:reduction_degenerate_dim_remover",
+ "//xla/service/gpu/transforms:reduction_dimension_grouper",
+ "//xla/service/gpu/transforms:reduction_layout_normalizer",
+ "//xla/service/gpu/transforms:reduction_splitter",
+ "//xla/service/gpu/transforms:rename_fusions",
+ "//xla/service/gpu/transforms:sanitize_constant_names",
+ "//xla/service/gpu/transforms:scatter_expander",
+ "//xla/service/gpu/transforms:scatter_slice_simplifier",
+ "//xla/service/gpu/transforms:softmax_rewriter_triton",
+ "//xla/service/gpu/transforms:stream_attribute_annotator",
+ "//xla/service/gpu/transforms:stream_attribute_async_wrapper",
+ "//xla/service/gpu/transforms:topk_specializer",
+ "//xla/service/gpu/transforms:topk_splitter",
+ "//xla/service/gpu/transforms:tree_reduction_rewriter",
+ "//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
+ "//xla/service/gpu/transforms:windowed_einsum_handler",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/spmd:collective_permute_motion",
"//xla/service:algebraic_simplifier",
@@ -3265,7 +1470,6 @@
"//xla/service:all_reduce_folder",
"//xla/service:all_reduce_promotion",
"//xla/service:all_reduce_reassociate",
- "//xla/service:all_reduce_splitter",
"//xla/service:async_collective_creator",
"//xla/service:batchnorm_expander",
"//xla/service:bitcast_dtypes_expander",
@@ -3395,11 +1599,10 @@
xla_test(
name = "gpu_compiler_test",
- srcs = if_gpu_is_configured(["gpu_compiler_test.cc"]),
+ srcs = ["gpu_compiler_test.cc"],
backends = ["gpu"],
data = ["gpu_compiler_test_autotune_db.textproto"],
deps = [
- ":autotuner_util",
":gpu_compiler",
":gpu_hlo_schedule",
":metrics",
@@ -3413,6 +1616,7 @@
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service:xla_debug_info_manager",
+ "//xla/service/gpu/autotuning:autotuner_util",
"//xla/stream_executor:device_description",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
@@ -3449,7 +1653,7 @@
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_memory_scheduler",
"//xla/service:hlo_rematerialization",
- "//xla/service/gpu:stream_attribute_annotator",
+ "//xla/service/gpu/transforms:stream_attribute_annotator",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
@@ -3511,34 +1715,13 @@
"no_rocm",
],
deps = [
- ":autotuner_util",
":buffer_sharing",
- ":conv_algorithm_picker",
- ":cublas_pad_for_gemms",
":cublas_padding_requirements",
- ":cudnn_fused_conv_rewriter",
- ":cudnn_fused_mha_rewriter",
- ":cudnn_fused_mha_transpose_fusion",
- ":cudnn_fusion_compiler",
- ":cudnn_norm_rewriter",
- ":cudnn_pad_for_convolutions",
- ":cudnn_simplify_padding",
- ":cudnn_vectorize_convolutions",
- ":cudnn_workspace_rewriter",
- ":cusolver_rewriter",
- ":dot_sparsity_rewriter",
- ":gemm_algorithm_picker",
- ":gemm_fusion_autotuner",
- ":gpu_algebraic_simplifier",
":gpu_asm_opts_util",
":gpu_compiler",
- ":gpu_conv_padding_legalization",
- ":gpu_conv_rewriter",
- ":gpu_sort_rewriter",
":ir_emission_utils",
":metrics",
":target_constants",
- ":triangular_solve_rewriter",
"//xla:autotune_results_proto_cc",
"//xla:util",
"//xla:xla_proto_cc",
@@ -3561,13 +1744,36 @@
"//xla/service:hlo_verifier",
"//xla/service:reshape_mover",
"//xla/service:tuple_simplifier",
+ "//xla/service/gpu/autotuning:autotuner_util",
+ "//xla/service/gpu/autotuning:conv_algorithm_picker",
+ "//xla/service/gpu/autotuning:gemm_algorithm_picker",
+ "//xla/service/gpu/autotuning:gemm_fusion_autotuner",
"//xla/service/gpu/llvm_gpu_backend",
+ "//xla/service/gpu/transforms:algebraic_simplifier",
+ "//xla/service/gpu/transforms:conv_padding_legalization",
+ "//xla/service/gpu/transforms:conv_rewriter",
+ "//xla/service/gpu/transforms:cublas_pad_for_gemms",
+ "//xla/service/gpu/transforms:cudnn_custom_call_compiler",
+ "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter",
+ "//xla/service/gpu/transforms:cudnn_fused_mha_rewriter",
+ "//xla/service/gpu/transforms:cudnn_fused_mha_transpose_fusion",
+ "//xla/service/gpu/transforms:cudnn_fusion_compiler",
+ "//xla/service/gpu/transforms:cudnn_norm_rewriter",
+ "//xla/service/gpu/transforms:cudnn_pad_for_convolutions",
+ "//xla/service/gpu/transforms:cudnn_simplify_padding",
+ "//xla/service/gpu/transforms:cudnn_vectorize_convolutions",
+ "//xla/service/gpu/transforms:dot_sparsity_rewriter",
+ "//xla/service/gpu/transforms:gpusolver_rewriter",
+ "//xla/service/gpu/transforms:sort_rewriter",
+ "//xla/service/gpu/transforms:triangular_solve_rewriter",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor",
"//xla/stream_executor:dnn",
"//xla/stream_executor/cuda:cuda_asm_compiler",
"//xla/stream_executor/cuda:cuda_diagnostics",
"//xla/stream_executor/cuda:cuda_platform_id",
+ "//xla/stream_executor/cuda:nvjitlink",
+ "//xla/stream_executor/cuda:nvjitlink_support",
"//xla/stream_executor/cuda:ptx_compilation_method",
"//xla/stream_executor/cuda:ptx_compiler",
"//xla/stream_executor/cuda:ptx_compiler_support",
@@ -3615,7 +1821,6 @@
"gpu_a100",
],
tags = [
- "gpu",
"no_rocm",
"nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan.
],
@@ -3654,7 +1859,6 @@
"gpu",
],
tags = [
- "gpu",
"no_rocm",
"nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan.
],
@@ -3666,6 +1870,7 @@
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/stream_executor:device_description",
+ "//xla/stream_executor/cuda:nvjitlink_support",
"//xla/stream_executor/cuda:ptx_compilation_method",
"//xla/stream_executor/cuda:ptx_compiler_support",
"//xla/stream_executor/cuda:ptx_linking_method",
@@ -3745,45 +1950,6 @@
)
cc_library(
- name = "gpu_algebraic_simplifier",
- srcs = [
- "gpu_algebraic_simplifier.cc",
- ],
- hdrs = [
- "gpu_algebraic_simplifier.h",
- ],
- deps = [
- ":matmul_utils",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:algebraic_simplifier",
- "//xla/service:hlo_pass",
- "//xla/service/gpu/fusions/triton:triton_support",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- ],
-)
-
-xla_cc_test(
- name = "gpu_algebraic_simplifier_test",
- srcs = ["gpu_algebraic_simplifier_test.cc"],
- deps = [
- ":gpu_algebraic_simplifier",
- "//xla/hlo/ir:hlo",
- "//xla/service:algebraic_simplifier",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "amdgpu_compiler_impl",
srcs = [
"amdgpu_compiler.cc",
@@ -3793,20 +1959,9 @@
],
tags = ["manual"],
deps = [
- ":autotuner_util",
- ":conv_algorithm_picker",
- ":cublas_pad_for_gemms",
":cublas_padding_requirements",
- ":cudnn_fused_conv_rewriter",
- ":cusolver_rewriter",
- ":gemm_algorithm_picker",
- ":gpu_algebraic_simplifier",
":gpu_compiler",
- ":gpu_conv_padding_legalization",
- ":gpu_conv_rewriter",
- ":gpu_sort_rewriter",
":target_constants",
- ":triangular_solve_rewriter",
"//xla:util",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
@@ -3823,7 +1978,18 @@
"//xla/service:hlo_verifier",
"//xla/service:reshape_mover",
"//xla/service:tuple_simplifier",
+ "//xla/service/gpu/autotuning:autotuner_util",
+ "//xla/service/gpu/autotuning:conv_algorithm_picker",
+ "//xla/service/gpu/autotuning:gemm_algorithm_picker",
"//xla/service/gpu/llvm_gpu_backend",
+ "//xla/service/gpu/transforms:algebraic_simplifier",
+ "//xla/service/gpu/transforms:conv_padding_legalization",
+ "//xla/service/gpu/transforms:conv_rewriter",
+ "//xla/service/gpu/transforms:cublas_pad_for_gemms",
+ "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter",
+ "//xla/service/gpu/transforms:gpusolver_rewriter",
+ "//xla/service/gpu/transforms:sort_rewriter",
+ "//xla/service/gpu/transforms:triangular_solve_rewriter",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:dnn",
@@ -3843,53 +2009,6 @@
)
cc_library(
- name = "all_reduce_blueconnect",
- srcs = ["all_reduce_blueconnect.cc"],
- hdrs = ["all_reduce_blueconnect.h"],
- deps = [
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:collective_ops_utils",
- "//xla/service:computation_placer_hdr",
- "//xla/service:global_device_id",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:btree",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "all_reduce_blueconnect_test",
- srcs = ["all_reduce_blueconnect_test.cc"],
- deps = [
- ":all_reduce_blueconnect",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:computation_placer_hdr",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
name = "xfeed_queue",
hdrs = ["xfeed_queue.h"],
deps = [
@@ -3934,109 +2053,12 @@
)
cc_library(
- name = "gpu_layout_assignment",
- srcs = ["gpu_layout_assignment.cc"],
- hdrs = ["gpu_layout_assignment.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":matmul_utils",
- ":reduction_utils",
- ":stream_executor_util",
- "//xla:shape_layout",
- "//xla:shape_util",
- "//xla:util",
- "//xla:window_util",
- "//xla:xla_data_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:computation_layout",
- "//xla/service:host_memory_offload_annotations_hdr",
- "//xla/service:layout_assignment",
- "//xla/service:logical_buffer",
- "//xla/stream_executor",
- "//xla/stream_executor:dnn",
- "//xla/tsl/util:env_var",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gpu_layout_assignment_test",
- srcs = ["gpu_layout_assignment_test.cc"],
- deps = [
- ":gpu_layout_assignment",
- ":stream_executor_util",
- "//xla:shape_layout",
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:computation_layout",
- "//xla/service:hlo_parser",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:dnn",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # build_cleaner: keep
- "@com_google_absl//absl/types:span",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_schedule_postprocessing",
- srcs = ["gpu_schedule_postprocessing.cc"],
- hdrs = ["gpu_schedule_postprocessing.h"],
- deps = [
- ":backend_configs_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gpu_schedule_postprocessing_test",
- srcs = ["gpu_schedule_postprocessing_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":gpu_schedule_postprocessing",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_parser",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "gpu_hlo_schedule",
srcs = ["gpu_hlo_schedule.cc"],
hdrs = ["gpu_hlo_schedule.h"],
deps = [
":backend_configs_cc",
":gpu_latency_hiding_scheduler",
- ":gpu_schedule_postprocessing",
- ":scheduling_instruction_annotator",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
@@ -4049,6 +2071,8 @@
"//xla/service:p2p_schedule_preparation",
"//xla/service:profile_guided_latency_estimator",
"//xla/service/gpu/model:analytical_latency_estimator",
+ "//xla/service/gpu/transforms:schedule_postprocessing",
+ "//xla/service/gpu/transforms:scheduling_instruction_annotator",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@@ -4085,6 +2109,7 @@
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
+ "//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
@@ -4141,7 +2166,6 @@
srcs = ["gpu_spmd_pipeline.cc"],
hdrs = ["gpu_spmd_pipeline.h"],
deps = [
- ":gpu_algebraic_simplifier",
":runtime_intrinsics",
"//xla/hlo/ir:hlo",
"//xla/hlo/transforms:hlo_constant_splitter",
@@ -4160,6 +2184,7 @@
"//xla/service:tuple_simplifier",
"//xla/service:while_loop_constant_sinking",
"//xla/service:while_loop_simplifier",
+ "//xla/service/gpu/transforms:algebraic_simplifier",
"//xla/service/spmd:collective_permute_motion",
"//xla/service/spmd:stateful_rng_spmd_partitioner",
"//xla/service/spmd/shardy:shardy_xla_pass",
@@ -4292,6 +2317,7 @@
name = "hlo_fusion_analysis",
srcs = ["hlo_fusion_analysis.cc"],
hdrs = ["hlo_fusion_analysis.h"],
+ compatible_with = get_compatible_with_portable(),
deps = [
":backend_configs_cc",
":hlo_traversal",
@@ -4319,6 +2345,8 @@
":gpu_device_info_for_tests",
":hlo_fusion_analysis",
":hlo_traversal",
+ ":ir_emission_utils",
+ "//xla:protobuf_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/tests:hlo_test_base",
@@ -4426,10 +2454,13 @@
name = "gpu_fusible",
srcs = ["gpu_fusible.cc"],
hdrs = ["gpu_fusible.h"],
+ compatible_with = get_compatible_with_portable(),
deps = [
":backend_configs_cc",
+ ":hlo_fusion_analysis",
":hlo_traversal",
":ir_emission_utils",
+ ":launch_dimensions",
":reduction_utils",
"//xla:permutation_util",
"//xla:shape_util",
@@ -4467,265 +2498,6 @@
],
)
-cc_library(
- name = "cudnn_fused_conv_rewriter",
- srcs = ["cudnn_fused_conv_rewriter.cc"],
- hdrs = ["cudnn_fused_conv_rewriter.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- "//xla:comparison_util",
- "//xla:debug_options_flags",
- "//xla:literal",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "//xla/stream_executor",
- "//xla/stream_executor:dnn",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:ml_dtypes",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_test(
- name = "cudnn_fused_conv_rewriter_test",
- srcs = ["cudnn_fused_conv_rewriter_test.cc"],
- backend_tags = {
- "gpu_a100": [
- "noasan",
- "nomsan",
- "no_rocm",
- ],
- },
- backends = [
- "gpu_a100",
- "gpu_amd_any",
- ] + if_oss(["gpu_any"]),
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
- "TENSORFLOW_USE_ROCM=1",
- ]),
- shard_count = 10,
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":cudnn_fused_conv_rewriter",
- ":gpu_conv_rewriter",
- ":stream_executor_util",
- "//xla:comparison_util",
- "//xla:error_spec",
- "//xla/hlo/ir:hlo",
- "//xla/service:algebraic_simplifier",
- "//xla/service:convert_mover",
- "//xla/service:hlo_constant_folding",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla/service:hlo_pass_pipeline",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service:reshape_mover",
- "//xla/service/gpu/tests:gpu_codegen_test",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:dnn",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "//xla/tests:verified_hlo_module",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test_main",
- ] + if_cuda_is_configured([
- "@local_config_cuda//cuda:cuda_headers",
- "@local_config_cuda//cuda:cudnn_header",
- ]) + if_rocm_is_configured([
- "@local_config_rocm//rocm:rocm_headers",
- ]),
-)
-
-cc_library(
- name = "cudnn_norm_rewriter",
- srcs = ["cudnn_norm_rewriter.cc"],
- hdrs = ["cudnn_norm_rewriter.h"],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- "//xla:shape_util",
- "//xla:types",
- "//xla:util",
- "//xla:window_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "//xla/stream_executor",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/protobuf:dnn_proto_cc",
- ] + if_cuda_is_configured([
- "@local_config_cuda//cuda:cuda_headers",
- "@local_config_cuda//cuda:cudnn_header",
- ]) + if_google([
- "@com_google_protobuf//:wrappers_cc_proto",
- ]),
-)
-
-xla_test(
- name = "cudnn_norm_rewriter_test",
- srcs = ["cudnn_norm_rewriter_test.cc"],
- backends = ["gpu"],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
- deps = [
- ":cublas_cudnn",
- ":cudnn_norm_rewriter",
- "//xla:error_spec",
- "//xla/service/gpu/tests:gpu_codegen_test",
- "//xla/stream_executor:device_description",
- "//xla/tests:filecheck",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_googletest//:gtest_main",
- ] + if_cuda_is_configured([
- "@local_config_cuda//cuda:cuda_headers",
- "@local_config_cuda//cuda:cudnn_header",
- ]),
-)
-
-cc_library(
- name = "cudnn_fused_mha_rewriter",
- srcs = ["cudnn_fused_mha_rewriter.cc"],
- hdrs = ["cudnn_fused_mha_rewriter.h"],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":matmul_utils",
- ":stream_executor_util",
- "//xla:permutation_util",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:types",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "//xla/stream_executor",
- "//xla/stream_executor:dnn",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ] + if_cuda_is_configured([
- "@local_config_cuda//cuda:cuda_headers",
- ]),
-)
-
-cc_library(
- name = "cudnn_fused_mha_transpose_fusion",
- srcs = ["cudnn_fused_mha_transpose_fusion.cc"],
- hdrs = ["cudnn_fused_mha_transpose_fusion.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":matmul_utils",
- "//xla:permutation_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_test(
- name = "cudnn_fused_mha_rewriter_test",
- srcs = ["cudnn_fused_mha_rewriter_test.cc"],
- backend_tags = {"gpu": [
- "requires-gpu-nvidia",
- "no_rocm",
- ]},
- backends = [
- "gpu",
- ],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- ":cudnn_fused_mha_rewriter",
- ":cudnn_fused_mha_transpose_fusion",
- "//xla:error_spec",
- "//xla:test_helpers",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:algebraic_simplifier",
- "//xla/service:computation_layout",
- "//xla/service:hlo_cse",
- "//xla/service:hlo_dce",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_parser",
- "//xla/service:hlo_verifier",
- "//xla/service:layout_normalization",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service:reshape_decomposer",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:dnn",
- "//xla/tests:hlo_test_base",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test_main",
- ] + if_cuda_is_configured([
- "@local_config_cuda//cuda:cuda_headers",
- "@local_config_cuda//cuda:cudnn_header",
- ]),
-)
-
xla_test(
name = "float_support_test",
srcs = ["float_support_test.cc"],
@@ -4763,78 +2535,15 @@
)
cc_library(
- name = "variadic_op_splitter",
- srcs = ["variadic_op_splitter.cc"],
- hdrs = ["variadic_op_splitter.h"],
- deps = [
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_scatter_expander",
- srcs = ["gpu_scatter_expander.cc"],
- hdrs = ["gpu_scatter_expander.h"],
- deps = [
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:scatter_expander",
- "@com_google_absl//absl/strings:string_view",
- ],
-)
-
-xla_cc_test(
- name = "variadic_op_splitter_test",
- srcs = ["variadic_op_splitter_test.cc"],
- tags = [
- "nomsan",
- ],
- deps = [
- ":variadic_op_splitter",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_parser",
- "//xla/service:pattern_matcher",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-tf_proto_library(
- name = "gpu_autotuning_proto",
- srcs = ["gpu_autotuning.proto"],
- cc_api_version = 2,
- protodeps = [
- ":backend_configs",
- "//xla:xla_data_proto",
- "//xla/service:hlo_proto",
- "//xla:autotuning_proto",
- ],
-)
-
-cc_library(
name = "hlo_algorithm_denylist",
srcs = ["hlo_algorithm_denylist.cc"],
hdrs = ["hlo_algorithm_denylist.h"],
deps = [
":backend_configs_cc",
- ":gpu_autotuning_proto_cc",
"//xla:autotuning_proto_cc",
"//xla:debug_options_flags",
"//xla/hlo/ir:backend_config",
+ "//xla/service/gpu/autotuning:gpu_autotuning_proto_cc",
"//xla/stream_executor:dnn",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
@@ -4861,133 +2570,6 @@
],
)
-cc_library(
- name = "alias_passthrough_params",
- srcs = ["alias_passthrough_params.cc"],
- hdrs = ["alias_passthrough_params.h"],
- deps = [
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- ],
-)
-
-xla_cc_test(
- name = "alias_passthrough_params_test",
- srcs = ["alias_passthrough_params_test.cc"],
- tags = [
- "nomsan",
- ],
- deps = [
- ":alias_passthrough_params",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "//xla/tsl/lib/core:status_test_util",
- "@local_tsl//tsl/platform:test",
- ],
-)
-
-cc_library(
- name = "horizontal_loop_fusion",
- srcs = ["horizontal_loop_fusion.cc"],
- hdrs = ["horizontal_loop_fusion.h"],
- deps = [
- ":gpu_fusible",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:sub_byte_normalization",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_test(
- name = "horizontal_loop_fusion_test",
- srcs = ["horizontal_loop_fusion_test.cc"],
- backends = ["gpu"],
- deps = [
- ":gpu_device_info_for_tests",
- ":horizontal_loop_fusion",
- ":instruction_fusion",
- "//xla:error_spec",
- "//xla:shape_util",
- "//xla:test",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_dce",
- "//xla/service:hlo_parser",
- "//xla/service:hlo_pass",
- "//xla/service:hlo_pass_pipeline",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/log",
- ],
-)
-
-cc_library(
- name = "horizontal_input_fusion",
- srcs = ["horizontal_input_fusion.cc"],
- hdrs = ["horizontal_input_fusion.h"],
- deps = [
- ":gpu_fusible",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_test(
- name = "horizontal_input_fusion_test",
- srcs = ["horizontal_input_fusion_test.cc"],
- backends = ["gpu"],
- deps = [
- ":gpu_device_info_for_tests",
- ":horizontal_input_fusion",
- "//xla:error_spec",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:test",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service/gpu/tests:gpu_codegen_test",
- "//xla/stream_executor:device_description",
- "//xla/tests:xla_internal_test_main",
- ],
-)
-
xla_cc_test(
name = "gpu_float_support_test",
srcs = ["gpu_float_support_test.cc"],
@@ -5010,151 +2592,6 @@
)
cc_library(
- name = "reduction_degenerate_dim_remover",
- srcs = ["reduction_degenerate_dim_remover.cc"],
- hdrs = ["reduction_degenerate_dim_remover.h"],
- deps = [
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "reduction_dimension_grouper",
- srcs = ["reduction_dimension_grouper.cc"],
- hdrs = ["reduction_dimension_grouper.h"],
- deps = [
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "reduction_splitter",
- srcs = ["reduction_splitter.cc"],
- hdrs = ["reduction_splitter.h"],
- deps = [
- ":reduction_utils",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "reduction_splitter_test",
- srcs = ["reduction_splitter_test.cc"],
- deps = [
- ":reduction_splitter",
- "//xla:shape_util",
- "//xla:test",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_parser",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- ],
-)
-
-cc_library(
- name = "reduction_layout_normalizer",
- srcs = ["reduction_layout_normalizer.cc"],
- hdrs = ["reduction_layout_normalizer.h"],
- deps = [
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "tree_reduction_rewriter",
- srcs = ["tree_reduction_rewriter.cc"],
- hdrs = ["tree_reduction_rewriter.h"],
- deps = [
- ":reduction_utils",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:collective_ops_utils",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/numeric:bits",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gemm_broadcast_folding_rewriter",
- srcs = ["gemm_broadcast_folding_rewriter.cc"],
- hdrs = ["gemm_broadcast_folding_rewriter.h"],
- deps = [
- ":backend_configs_cc",
- ":cublas_cudnn",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "metrics",
srcs = ["metrics.cc"],
hdrs = ["metrics.h"],
@@ -5166,48 +2603,6 @@
)
cc_library(
- name = "dot_operand_converter",
- srcs = ["dot_operand_converter.cc"],
- hdrs = ["dot_operand_converter.h"],
- deps = [
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:op_expander_pass",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- ],
-)
-
-xla_test(
- name = "dot_operand_converter_test",
- srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]),
- backends = [
- "gpu_a100",
- "gpu_p100",
- "gpu_v100",
- "gpu_amd_any",
- ],
- deps = if_gpu_is_configured(
- [
- ":dot_operand_converter",
- "@com_google_googletest//:gtest",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_matchers",
- "//xla/service:pattern_matcher",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@local_tsl//tsl/platform:statusor",
- ],
- ["@local_tsl//tsl/platform:test_main"], # b/317293391
- ) + ["//xla:xla_data_proto_cc"],
-)
-
-cc_library(
name = "make_batch_pointers",
srcs = if_gpu_is_configured(["make_batch_pointers.cc"]),
hdrs = if_gpu_is_configured(["make_batch_pointers.h"]),
@@ -5237,25 +2632,6 @@
],
)
-cc_library(
- name = "triangular_solve_rewriter",
- srcs = ["triangular_solve_rewriter.cc"],
- hdrs = ["triangular_solve_rewriter.h"],
- deps = [
- ":cublas_cudnn",
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
tsl_gpu_library(
name = "runtime_intrinsics",
srcs = ["runtime_intrinsics.cc"],
@@ -5324,43 +2700,6 @@
)
cc_library(
- name = "scatter_slice_simplifier",
- srcs = ["scatter_slice_simplifier.cc"],
- hdrs = ["scatter_slice_simplifier.h"],
- deps = [
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "scatter_slice_simplifier_test",
- srcs = ["scatter_slice_simplifier_test.cc"],
- deps = [
- ":scatter_slice_simplifier",
- "//xla:shape_util",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
name = "conv_layout_normalization",
srcs = ["conv_layout_normalization.cc"],
hdrs = ["conv_layout_normalization.h"],
@@ -5377,133 +2716,6 @@
],
)
-cc_library(
- name = "topk_specializer",
- srcs = ["topk_specializer.cc"],
- hdrs = ["topk_specializer.h"],
- deps = [
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:hlo_proto_cc",
- "//xla/service:tuple_util",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "topk_splitter",
- srcs = ["topk_splitter.cc"],
- hdrs = ["topk_splitter.h"],
- deps = [
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/numeric:bits",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "topk_splitter_test",
- srcs = ["topk_splitter_test.cc"],
- deps = [
- ":topk_splitter",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_dce",
- "//xla/service:pattern_matcher",
- "//xla/service:topk_rewriter",
- "//xla/tests:hlo_test_base",
- "//xla/tests:verified_hlo_module",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- ],
-)
-
-xla_test(
- name = "topk_test",
- srcs = ["topk_test.cc"],
- backends = ["gpu"],
- deps = [
- ":topk_specializer",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service:platform_util",
- "//xla/service:topk_rewriter",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main", # fixdeps: keep
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
- name = "copy_fusion",
- srcs = ["copy_fusion.cc"],
- hdrs = ["copy_fusion.h"],
- deps = [
- ":gpu_fusible",
- ":hlo_traversal",
- ":ir_emission_utils",
- ":reduction_utils",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- ],
-)
-
-cc_library(
- name = "algorithm_checker",
- srcs = ["algorithm_checker.cc"],
- hdrs = ["algorithm_checker.h"],
- deps = [
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:algorithm_util",
- "//xla/service:hlo_pass",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- ],
-)
-
xla_test(
name = "dot_algorithm_support_test",
srcs = if_gpu_is_configured(["dot_algorithm_support_test.cc"]),
@@ -5618,138 +2830,6 @@
],
)
-cc_library(
- name = "fusion_wrapper",
- srcs = ["fusion_wrapper.cc"],
- hdrs = ["fusion_wrapper.h"],
- deps = [
- ":gpu_fusible",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:errors",
- ],
-)
-
-xla_cc_test(
- name = "fusion_wrapper_test",
- srcs = ["fusion_wrapper_test.cc"],
- deps = [
- ":fusion_wrapper",
- "//xla/tests:hlo_test_base",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-xla_cc_test(
- name = "copy_fusion_test",
- srcs = ["copy_fusion_test.cc"],
- deps = [
- ":copy_fusion",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
- ],
-)
-
-xla_cc_test(
- name = "autotuner_util_test",
- srcs = if_cuda_is_configured(["autotuner_util_test.cc"]),
- data = [
- "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb",
- "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb",
- "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb",
- ],
- deps = if_cuda_is_configured([
- # keep sorted
- ":autotuner_util",
- "//xla:autotune_results_proto_cc",
- "//xla:autotuning_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:device_description_proto_cc",
- "//xla/stream_executor:platform",
- "//xla/stream_executor:platform_manager",
- "//xla/stream_executor/host:host_platform",
- "//xla/tests:hlo_test_base",
- "//xla/tests:verified_hlo_module",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/base:log_severity",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/log:scoped_mock_log",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:env",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:path",
- "@local_tsl//tsl/platform:protobuf",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- ]) + [
- "//xla/tests:xla_internal_test_main", # Keep outside GPU guard
- ],
-)
-
-cc_library(
- name = "double_buffer_loop_unrolling",
- srcs = ["double_buffer_loop_unrolling.cc"],
- hdrs = ["double_buffer_loop_unrolling.h"],
- deps = [
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/ir:hlo_instruction_utils",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:collective_ops_utils",
- "//xla/service:flatten_call_graph",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "double_buffer_loop_unrolling_test",
- srcs = ["double_buffer_loop_unrolling_test.cc"],
- deps = [
- ":double_buffer_loop_unrolling",
- "//xla:test",
- "//xla:xla_data_proto_cc",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:tuple_simplifier",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/container:flat_hash_set",
- "@local_tsl//tsl/platform:status_matchers",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
xla_test(
name = "determinism_test",
srcs = if_gpu_is_configured(["determinism_test.cc"]),
@@ -5762,7 +2842,7 @@
]),
deps = if_gpu_is_configured(
[
- ":autotuner_util",
+ "//xla/service/gpu/autotuning:autotuner_util",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/strings",
"//xla:literal",
@@ -5791,283 +2871,6 @@
)
cc_library(
- name = "collective_permute_cycle_decomposer",
- srcs = ["collective_permute_cycle_decomposer.cc"],
- hdrs = ["collective_permute_cycle_decomposer.h"],
- deps = [
- ":backend_configs_cc",
- "//xla:comparison_util",
- "//xla:literal_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:collective_ops_utils",
- "//xla/service:hlo_parser",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- ],
-)
-
-xla_cc_test(
- name = "collective_permute_cycle_decomposer_test",
- srcs = ["collective_permute_cycle_decomposer_test.cc"],
- deps = [
- ":collective_permute_cycle_decomposer",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_parser",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
- name = "collective_permute_valid_iteration_annotator",
- srcs = ["collective_permute_valid_iteration_annotator.cc"],
- hdrs = ["collective_permute_valid_iteration_annotator.h"],
- deps = [
- "//xla:literal_util",
- "//xla/hlo/ir:hlo",
- "//xla/service:collective_ops_utils",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "//xla/service:while_loop_analysis",
- ],
-)
-
-xla_cc_test(
- name = "collective_permute_valid_iteration_annotator_test",
- srcs = ["collective_permute_valid_iteration_annotator_test.cc"],
- deps = [
- ":collective_permute_valid_iteration_annotator",
- "//xla/hlo/ir:hlo",
- "//xla/service:collective_ops_utils",
- "//xla/service:hlo_pass_pipeline",
- "//xla/service:while_loop_trip_count_annotator",
- "//xla/tests:hlo_test_base",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
- name = "stream_attribute_annotator",
- srcs = ["stream_attribute_annotator.cc"],
- hdrs = ["stream_attribute_annotator.h"],
- deps = [
- ":backend_configs_cc",
- ":gpu_fusible",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:hlo_pass",
- "//xla/service/gpu/runtime:thunk",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "stream_attribute_annotator_test",
- srcs = ["stream_attribute_annotator_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":stream_attribute_annotator",
- "//xla/hlo/ir:hlo",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "stream_attribute_async_wrapper",
- srcs = ["stream_attribute_async_wrapper.cc"],
- hdrs = ["stream_attribute_async_wrapper.h"],
- deps = [
- ":backend_configs_cc",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "//xla/service/gpu/runtime:thunk",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "stream_attribute_async_wrapper_test",
- srcs = ["stream_attribute_async_wrapper_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":stream_attribute_async_wrapper",
- "//xla/hlo/ir:hlo",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "gpu_windowed_einsum_handler",
- srcs = ["gpu_windowed_einsum_handler.cc"],
- hdrs = ["gpu_windowed_einsum_handler.h"],
- deps = [
- ":backend_configs_cc",
- "//xla:literal_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:hlo_creation_utils",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "//xla/service:shape_inference",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gpu_windowed_einsum_handler_test",
- srcs = ["gpu_windowed_einsum_handler_test.cc"],
- deps = [
- ":backend_configs_cc",
- ":gpu_windowed_einsum_handler",
- "//xla/hlo/ir:hlo",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "triton_fusion_numerics_verifier",
- srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier.cc"]),
- hdrs = if_gpu_is_configured(["triton_fusion_numerics_verifier.h"]),
- deps = if_gpu_is_configured([
- ":autotuner_compile_util",
- ":autotuner_util",
- ":backend_configs_cc",
- ":buffer_comparator",
- ":ir_emission_utils",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/functional:any_invocable",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:executable",
- "//xla/service:hlo_pass",
- "//xla/service:shaped_buffer",
- "//xla/service:hlo_module_config",
- "//xla/stream_executor:stream",
- "//xla/tools:hlo_decomposer_lib",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ]),
-)
-
-xla_test(
- name = "triton_fusion_numerics_verifier_test",
- srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier_test.cc"]),
- backend_tags = {"gpu": [
- "requires-gpu-sm80",
- ]},
- backends = ["gpu"],
- deps = [
- ":autotuner_compile_util",
- ":autotuner_util",
- ":triton_fusion_numerics_verifier",
- "//xla:shape_util",
- "//xla:test_helpers",
- "//xla/hlo/ir:hlo",
- "//xla/service:platform_util",
- "//xla/stream_executor:platform",
- "//xla/tests:hlo_test_base",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "pipelined_p2p_rewriter",
- srcs = ["pipelined_p2p_rewriter.cc"],
- hdrs = ["pipelined_p2p_rewriter.h"],
- deps = [
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/hlo/utils:hlo_query",
- "//xla/service:collective_ops_utils",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "pipelined_p2p_rewriter_test",
- srcs = ["pipelined_p2p_rewriter_test.cc"],
- deps = [
- ":pipelined_p2p_rewriter",
- "//xla/hlo/ir:hlo",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-cc_library(
name = "execution_stream_assignment",
srcs = ["execution_stream_assignment.cc"],
hdrs = ["execution_stream_assignment.h"],
@@ -6138,32 +2941,3 @@
"@local_tsl//tsl/platform:statusor",
],
)
-
-cc_library(
- name = "scheduling_instruction_annotator",
- srcs = ["scheduling_instruction_annotator.cc"],
- hdrs = ["scheduling_instruction_annotator.h"],
- deps = [
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_pass",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "scheduling_instruction_annotator_test",
- srcs = ["scheduling_instruction_annotator_test.cc"],
- deps = [
- ":scheduling_instruction_annotator",
- "//xla/hlo/ir:hlo",
- "//xla/tests:filecheck",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
diff --git a/third_party/xla/xla/service/gpu/algorithm_checker.cc b/third_party/xla/xla/service/gpu/algorithm_checker.cc
deleted file mode 100644
index 3104293..0000000
--- a/third_party/xla/xla/service/gpu/algorithm_checker.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/algorithm_checker.h"
-
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/algorithm_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-bool HasNonDefaultOperandPrecision(const PrecisionConfig& config) {
- return absl::c_any_of(config.operand_precision(), [](int precision) {
- return static_cast<PrecisionConfig::Precision>(precision) !=
- PrecisionConfig::DEFAULT;
- });
-}
-
-class AlgorithmCheckerVisitor : public ConstDfsHloVisitorWithDefault {
- public:
- explicit AlgorithmCheckerVisitor(
- se::GpuComputeCapability gpu_compute_capability)
- : gpu_compute_capability_(std::move(gpu_compute_capability)) {}
-
- absl::Status RunOnModule(
- const HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_RETURN_IF_ERROR(computation->Accept(this));
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleDot(const HloInstruction* hlo) override {
- VLOG(1) << "Handling dot: " << hlo->ToString();
- const PrecisionConfig& config = hlo->precision_config();
-
- if (config.algorithm() != PrecisionConfig::ALG_UNSET &&
- HasNonDefaultOperandPrecision(config)) {
- LOG(WARNING)
- << "There is no need to set precisions when we set the algorithm: "
- << hlo->ToString();
- }
-
- if (config.algorithm() == PrecisionConfig::ALG_UNSET) {
- return absl::OkStatus();
- }
-
- PrimitiveType lhs_storage_type = hlo->operand(0)->shape().element_type();
- PrimitiveType rhs_storage_type = hlo->operand(1)->shape().element_type();
- PrimitiveType output_storage_type = hlo->shape().element_type();
-
- if (lhs_storage_type != rhs_storage_type) {
- return absl::UnimplementedError(absl::StrFormat(
- "Dot operands must have the same type when using an algorithm: %s",
- hlo->ToString()));
- }
-
- return algorithm_util::IsSupportedDotAlgorithmOnGpu(
- config.algorithm(), gpu_compute_capability_, lhs_storage_type,
- output_storage_type)
- ? absl::OkStatus()
- : absl::UnimplementedError(absl::StrFormat(
- "Unsupported algorithm on the current device(s): %s",
- PrecisionConfig::Algorithm_Name(config.algorithm())));
- }
-
- absl::Status DefaultAction(const HloInstruction* hlo) override {
- return absl::OkStatus();
- }
-
- private:
- se::GpuComputeCapability gpu_compute_capability_;
-};
-
-} // namespace
-
-absl::StatusOr<bool> AlgorithmChecker::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- TF_RETURN_IF_ERROR(AlgorithmCheckerVisitor(gpu_compute_capability_)
- .RunOnModule(module, execution_threads));
- // No change was made.
- return false;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/algorithm_checker.h b/third_party/xla/xla/service/gpu/algorithm_checker.h
deleted file mode 100644
index f3b30c1..0000000
--- a/third_party/xla/xla/service/gpu/algorithm_checker.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_
-#define XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// This checks if the requested algorithms are supported. This can give an early
-// and specific error if an unsupported algorithm is requested.
-//
-// Note: Maybe we can make this more generic and move it outside of GPU.
-class AlgorithmChecker : public HloModulePass {
- public:
- explicit AlgorithmChecker(se::GpuComputeCapability gpu_compute_capability)
- : gpu_compute_capability_(std::move(gpu_compute_capability)){};
-
- absl::string_view name() const override { return "algorithm-checker"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::GpuComputeCapability gpu_compute_capability_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_
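Note: the deleted algorithm_checker pass above validates explicitly requested dot algorithms: when a dot sets an algorithm in its precision config, both operands must share a storage type and the (algorithm, storage types, compute capability) combination must be supported, otherwise the pass fails early with an Unimplemented error. A rough, self-contained sketch of that decision only (the enums and IsSupportedOnGpu helper are illustrative stand-ins, not the XLA API):

#include <iostream>
#include <string>

// Stand-ins for PrecisionConfig::Algorithm and PrimitiveType.
enum class Algorithm { kUnset, kBf16Bf16F32, kTf32Tf32F32 };
enum class Type { kF32, kBf16 };

// Placeholder support table; the real check
// (algorithm_util::IsSupportedDotAlgorithmOnGpu) also consults the GPU
// compute capability.
bool IsSupportedOnGpu(Algorithm alg, Type lhs_storage, Type out_storage) {
  return alg == Algorithm::kBf16Bf16F32 && lhs_storage == Type::kBf16 &&
         out_storage == Type::kF32;
}

std::string CheckDot(Algorithm alg, Type lhs, Type rhs, Type out) {
  if (alg == Algorithm::kUnset) return "ok (no algorithm requested)";
  if (lhs != rhs) return "error: operands must have the same type";
  return IsSupportedOnGpu(alg, lhs, out) ? "ok"
                                         : "error: unsupported algorithm";
}

int main() {
  std::cout << CheckDot(Algorithm::kBf16Bf16F32, Type::kBf16, Type::kBf16,
                        Type::kF32)
            << "\n";
  std::cout << CheckDot(Algorithm::kTf32Tf32F32, Type::kBf16, Type::kF32,
                        Type::kF32)
            << "\n";
}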
diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params.cc b/third_party/xla/xla/service/gpu/alias_passthrough_params.cc
deleted file mode 100644
index 5dea5bc..0000000
--- a/third_party/xla/xla/service/gpu/alias_passthrough_params.cc
+++ /dev/null
@@ -1,69 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/alias_passthrough_params.h"
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> AliasPassthroughParams::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- const HloInstruction* root = module->entry_computation()->root_instruction();
- if (module->entry_computation()->num_parameters() == 0 ||
- root->opcode() != HloOpcode::kTuple) {
- return false;
- }
- bool changed = false;
- absl::flat_hash_set<int64_t> used_params;
- for (int64_t i = 0; i < root->operand_count(); ++i) {
- if (root->operand(i)->opcode() == HloOpcode::kParameter &&
- used_params.count(root->operand(i)->parameter_number()) == 0) {
- VLOG(2) << "Parameter " << root->operand(i)->parameter_number()
- << " with shape " << root->operand(i)->shape().ToString()
- << " in module " << module->name()
- << " is passed-through to root tuple element " << i << ": "
- << root->shape().ToString();
-
- if (module->input_output_alias_config().OutputHasAlias({i}) ||
- module->input_output_alias_config().ParameterHasAlias(
- root->operand(i)->parameter_number(), /*param_index=*/{})) {
- VLOG(2) << "Skip setting the above pass-through alias as an alias may"
- << " have been set up for alising resource update.";
- continue;
- }
-
- TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
- /*output_index=*/{i},
- /*param_number=*/root->operand(i)->parameter_number(),
- /*param_index=*/{}));
- used_params.insert(root->operand(i)->parameter_number());
- changed = true;
- }
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params.h b/third_party/xla/xla/service/gpu/alias_passthrough_params.h
deleted file mode 100644
index 029068a..0000000
--- a/third_party/xla/xla/service/gpu/alias_passthrough_params.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_
-#define XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// This pass aliases input and output buffers that are associated with a
-// parameter that is passed through to the module root unmodified.
-//
-// This pass assumes that parameters and the root use unnested shapes, which is
-// the case for XLA:GPU.
-//
-// This pass must run prior to copy insertion.
-class AliasPassthroughParams : public HloModulePass {
- public:
- AliasPassthroughParams() = default;
- ~AliasPassthroughParams() override = default;
- absl::string_view name() const override { return "alias_passthrough_params"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_
diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc b/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc
deleted file mode 100644
index 2c09daf..0000000
--- a/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/alias_passthrough_params.h"
-
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-
-class AliasPassthroughParamsTest : public HloTestBase {};
-
-TEST_F(AliasPassthroughParamsTest, AliasPassThroughParams) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- p0 = f16[2048,1024] parameter(0)
- p1 = f16[2048,1024] parameter(1)
- sum = f16[2048,1024] add(p0, p1)
- ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
- })")
- .value();
- EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
- const auto& alias_config = module->input_output_alias_config();
- EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
- EXPECT_FALSE(alias_config.OutputHasAlias({1}));
- EXPECT_EQ(1, alias_config.GetAliasedParameter({2})->parameter_number);
-}
-
-TEST_F(AliasPassthroughParamsTest, DoNotAliasPassThroughParamsMoreThanOnce) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- p0 = f16[2048,1024] parameter(0)
- ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p0)
- })")
- .value();
- EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
- const auto& alias_config = module->input_output_alias_config();
- EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
- EXPECT_FALSE(alias_config.OutputHasAlias({1}));
-}
-
-TEST_F(AliasPassthroughParamsTest, PresetAliases) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- p0 = f16[2048,1024] parameter(0)
- p1 = f16[2048,1024] parameter(1)
- sum = f16[2048,1024] add(p0, p1)
- ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
- })")
- .value();
-
- // Presetting an alias for p0 -> Sum. This could happen in a case of
- // `alias_resource_update`.
- auto& preset_alias = module->input_output_alias_config();
- TF_EXPECT_OK(preset_alias.SetUpAlias(/*output_index=*/{1},
- /*param_number=*/0,
- /*param_index=*/{}));
-
- EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
- const auto& alias_result = module->input_output_alias_config();
- // Assert that an alias p1 -> output {2} is established by `AliasPassthroughParams`.
- EXPECT_EQ(1, alias_result.GetAliasedParameter({2})->parameter_number);
- EXPECT_FALSE(alias_result.OutputHasAlias({0}));
-}
-
-} // namespace gpu
-} // namespace xla
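Note: the pass-through aliasing decision exercised by the deleted tests above reduces to "output tuple index i is aliased to parameter p when the root tuple's i-th operand is parameter p, p has not already been claimed by an earlier output, and no conflicting alias is preset". A toy sketch of that bookkeeping, with the preset-alias check omitted (PassthroughAliases and its types are illustrative, not the XLA API):

#include <cstdio>
#include <optional>
#include <set>
#include <utility>
#include <vector>

// Each root-tuple operand is a parameter number, or nullopt for a computed
// value such as `sum` in the deleted test. Returns {output_index,
// parameter_number} pairs, aliasing each parameter at most once.
std::vector<std::pair<int, int>> PassthroughAliases(
    const std::vector<std::optional<int>>& root_operands) {
  std::vector<std::pair<int, int>> aliases;
  std::set<int> used_params;
  for (int i = 0; i < static_cast<int>(root_operands.size()); ++i) {
    if (root_operands[i].has_value() &&
        used_params.insert(*root_operands[i]).second) {
      aliases.emplace_back(i, *root_operands[i]);
    }
  }
  return aliases;
}

int main() {
  // Mirrors the AliasPassThroughParams test above: root = (p0, sum, p1)
  // yields output {0} <- parameter 0 and output {2} <- parameter 1.
  for (const auto& [out, param] : PassthroughAliases({0, std::nullopt, 1})) {
    std::printf("output {%d} <- parameter %d\n", out, param);
  }
}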
diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc b/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc
deleted file mode 100644
index 2e75ffa..0000000
--- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc
+++ /dev/null
@@ -1,373 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/all_reduce_blueconnect.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/btree_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/computation_placer.h"
-#include "xla/service/global_device_id.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-std::vector<HloInstruction*> GetOutputs(HloInstruction& instruction) {
- if (!instruction.shape().IsTuple()) {
- return {&instruction};
- }
-
- std::vector<HloInstruction*> outputs;
- outputs.reserve(instruction.shape().tuple_shapes_size());
-
- HloComputation& computation = *instruction.parent(); // never null
- for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
- outputs.push_back(computation.AddInstruction(
- HloInstruction::CreateGetTupleElement(&instruction, i)));
- }
- return outputs;
-}
-
-struct DecomposedReplicaGroups {
- std::vector<ReplicaGroup> scatter_gather_groups;
- std::vector<ReplicaGroup> new_all_reduce_groups;
-};
-
-// Returns the global device id for the given replica id. Returns nullopt if
-// the replica id can refer to multiple devices, or if the pass does not
-// support the CollectiveOpGroupMode.
-std::optional<GlobalDeviceId> TryConvertingReplicaIdToDeviceId(
- int64_t replica_id, const DeviceAssignment& device_assignment,
- CollectiveOpGroupMode collective_group_mode) {
- if (collective_group_mode == CollectiveOpGroupMode::kCrossReplica) {
- if (device_assignment.computation_count() != 1) {
- // If there are multiple partitions, the replica_id may refer to multiple
- // devices on different partitions.
- return std::nullopt;
- }
- return GlobalDeviceId{device_assignment(replica_id, /*computation_id=*/0)};
- } else if (collective_group_mode == CollectiveOpGroupMode::kFlattenedID) {
- int partition_count = device_assignment.computation_count();
- int64_t actual_replica_id = replica_id / partition_count;
- int64_t partition_id = replica_id % partition_count;
- return GlobalDeviceId{device_assignment(actual_replica_id, partition_id)};
- }
-
- // kCrossPartition and kCrossReplicaAndPartition are unsupported.
- VLOG(1) << "Skip AllReduceBlueConnect because of unsupported "
- "CollectiveOpGroupMode "
- << CollectiveOpGroupModeToString(collective_group_mode);
- return std::nullopt;
-}
-
-absl::StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroup(
- const ReplicaGroup& replica_group,
- const DeviceAssignment& device_assignment, size_t num_devices_per_host,
- CollectiveOpGroupMode collective_group_mode) {
- int group_size = replica_group.replica_ids_size();
- TF_RET_CHECK(group_size > 0);
-
- absl::btree_map<int, std::vector<int64_t>> replica_ids_by_host;
- for (int64_t replica_id : replica_group.replica_ids()) {
- std::optional<GlobalDeviceId> device_id = TryConvertingReplicaIdToDeviceId(
- replica_id, device_assignment, collective_group_mode);
- if (!device_id.has_value()) {
- return {std::nullopt};
- }
- TF_RET_CHECK(*device_id >= 0);
- // We assume that devices are ordered by host.
- int host_id = device_id->value() / num_devices_per_host;
- replica_ids_by_host[host_id].push_back(replica_id);
- }
-
- size_t num_local_devices = replica_ids_by_host.begin()->second.size();
- bool same_num_devices_on_each_host =
- absl::c_all_of(replica_ids_by_host, [&](const auto& entry) {
- return entry.second.size() == num_local_devices;
- });
-
- if (!same_num_devices_on_each_host) {
- return {std::nullopt};
- }
-
- std::vector<int64_t> sorted_replica_group;
- sorted_replica_group.reserve(group_size);
- for (const auto& entry : replica_ids_by_host) {
- absl::c_copy(entry.second, std::back_inserter(sorted_replica_group));
- }
-
- size_t scatter_group_size = std::max(num_local_devices, size_t(2));
- size_t num_scatter_groups = group_size / scatter_group_size;
-
- if ((group_size % scatter_group_size != 0) || (num_scatter_groups < 2)) {
- return {std::nullopt};
- }
-
- std::vector<ReplicaGroup> scatter_gather_groups(num_scatter_groups);
- std::vector<ReplicaGroup> new_all_reduce_groups(scatter_group_size);
-
- for (size_t i = 0; i < group_size; ++i) {
- int64_t replica_id = sorted_replica_group[i];
- scatter_gather_groups[i / scatter_group_size].add_replica_ids(replica_id);
- new_all_reduce_groups[i % scatter_group_size].add_replica_ids(replica_id);
- }
-
- return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
- std::move(new_all_reduce_groups)}};
-}
-
-absl::StatusOr<std::optional<DecomposedReplicaGroups>>
-TryDecomposeReplicaGroups(const HloAllReduceInstruction& all_reduce,
- size_t num_devices_per_host) {
- const DeviceAssignment& device_assignment =
- all_reduce.GetModule()->config().static_device_assignment();
-
- absl::Span<const ReplicaGroup> replica_groups = all_reduce.replica_groups();
-
- ReplicaGroup all_replicas; // only populated if replica groups not present.
- if (replica_groups.empty()) {
- for (int i = 0; i < device_assignment.replica_count(); ++i) {
- all_replicas.add_replica_ids(i);
- }
- replica_groups = absl::MakeSpan(&all_replicas, 1);
- }
-
- TF_ASSIGN_OR_RETURN(
- CollectiveOpGroupMode collective_op_group_mode,
- GetCollectiveOpGroupMode(all_reduce.channel_id().has_value(),
- all_reduce.use_global_device_ids()));
-
- std::vector<ReplicaGroup> scatter_gather_groups;
- std::vector<ReplicaGroup> new_all_reduce_groups;
-
- // Try to find a valid decomposition for each replica group.
- for (const ReplicaGroup& replica_group : replica_groups) {
- TF_ASSIGN_OR_RETURN(
- std::optional<DecomposedReplicaGroups> decomposed_groups,
- TryDecomposeReplicaGroup(replica_group, device_assignment,
- num_devices_per_host,
- collective_op_group_mode));
-
- if (!decomposed_groups) return {std::nullopt};
-
- int scatter_group_size =
- decomposed_groups->scatter_gather_groups[0].replica_ids_size();
-
- if (scatter_gather_groups.empty()) {
- // Check that every operand is exactly divisible by scatter group sizes.
- for (const HloInstruction* operand : all_reduce.operands()) {
- TF_RET_CHECK(operand->shape().IsArray());
- int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
- if (num_elements % scatter_group_size != 0) {
- return {std::nullopt};
- }
- }
-
- scatter_gather_groups.reserve(
- replica_groups.size() *
- decomposed_groups->scatter_gather_groups.size());
- new_all_reduce_groups.reserve(
- replica_groups.size() *
- decomposed_groups->new_all_reduce_groups.size());
- } else if (scatter_group_size !=
- scatter_gather_groups[0].replica_ids_size()) {
- // Reduce-scatter would have different output shapes on different devices.
- return {std::nullopt};
- }
-
- absl::c_move(decomposed_groups->scatter_gather_groups,
- std::back_inserter(scatter_gather_groups));
- absl::c_move(decomposed_groups->new_all_reduce_groups,
- std::back_inserter(new_all_reduce_groups));
- }
-
- return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
- std::move(new_all_reduce_groups)}};
-}
-
-// Attempts to decompose all-reduces as described by the BlueConnect paper.
-//
-// If possible, the all-reduce will be transformed into:
-// 1. reduce-scatter
-// 2. all-reduce
-// 3. all-gather
-//
-// If the all-reduce replica groups have more than one device within the same
-// host, the reduce-scatter will be performed over all devices within each host.
-// Otherwise, the reduce-scatter will be performed between pairs of devices on
-// different hosts.
-//
-// When applied repeatedly, this transformation will reproduce the same pattern
-// as described in the BlueConnect paper.
-absl::StatusOr<bool> TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce,
- size_t num_devices_per_host) {
- TF_RET_CHECK(all_reduce);
- TF_RET_CHECK(!all_reduce->has_sharding());
-
- HloComputation& computation = *all_reduce->parent(); // never null
- PrimitiveType element_type = all_reduce->operand(0)->shape().element_type();
-
- TF_ASSIGN_OR_RETURN(
- std::optional<DecomposedReplicaGroups> decomposed_groups,
- TryDecomposeReplicaGroups(*all_reduce, num_devices_per_host));
-
- if (!decomposed_groups) return false;
-
- // Bitcast operands to 1D to guarantee that first dimension is divisible by
- // scatter group size (we checked num elements was divisible above).
- std::vector<HloInstruction*> flat_operands;
- flat_operands.reserve(all_reduce->operand_count());
- std::vector<Shape> flat_shapes;
- flat_shapes.reserve(all_reduce->operand_count());
- std::vector<Shape> scattered_shapes;
- scattered_shapes.reserve(all_reduce->operand_count());
-
- int scatter_group_size =
- decomposed_groups->scatter_gather_groups[0].replica_ids_size();
-
- for (HloInstruction* operand : all_reduce->operands()) {
- TF_RET_CHECK(operand->shape().IsArray());
- int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
- Shape flat_shape = ShapeUtil::MakeShape(element_type, {num_elements});
- flat_operands.push_back(computation.AddInstruction(
- HloInstruction::CreateBitcast(flat_shape, operand)));
- flat_shapes.push_back(std::move(flat_shape));
- scattered_shapes.push_back(ShapeUtil::MakeShape(
- element_type, {num_elements / scatter_group_size}));
- }
-
- Shape reduce_scatter_shape = ShapeUtil::MakeMaybeTupleShape(scattered_shapes);
-
- int64_t next_channel_id = hlo_query::NextChannelId(*computation.parent());
- auto get_channel_id = [&]() -> std::optional<int64_t> {
- if (all_reduce->channel_id().has_value()) {
- return next_channel_id++;
- }
- return std::nullopt;
- };
-
- HloInstruction* reduce_scatter =
- computation.AddInstruction(HloInstruction::CreateReduceScatter(
- reduce_scatter_shape, flat_operands, all_reduce->to_apply(),
- CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
- /*constrain_layout=*/false, get_channel_id(),
- all_reduce->use_global_device_ids(),
- /*scatter_dimension=*/0));
-
- HloInstruction* new_all_reduce =
- computation.AddInstruction(HloInstruction::CreateAllReduce(
- reduce_scatter_shape, GetOutputs(*reduce_scatter),
- all_reduce->to_apply(),
- CollectiveDeviceList(decomposed_groups->new_all_reduce_groups),
- /*constrain_layout=*/false, all_reduce->channel_id(),
- all_reduce->use_global_device_ids()));
-
- HloInstruction* all_gather =
- computation.AddInstruction(HloInstruction::CreateAllGather(
- ShapeUtil::MakeMaybeTupleShape(flat_shapes),
- GetOutputs(*new_all_reduce),
- /*all_gather_dimension=*/0,
- CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
- /*constrain_layout=*/false, get_channel_id(),
- all_reduce->use_global_device_ids()));
-
- // Bitcast back to the original shapes and replace all-reduce with decomposed
- // implementation.
- std::vector<HloInstruction*> outputs = GetOutputs(*all_gather);
- for (int64_t i = 0; i < outputs.size(); ++i) {
- outputs[i] = computation.AddInstruction(HloInstruction::CreateBitcast(
- all_reduce->operand(i)->shape(), outputs[i]));
- }
- HloInstruction* replacement = MaybeMakeTuple(outputs);
-
- TF_RETURN_IF_ERROR(
- all_reduce->CopyAllControlDepsTo(reduce_scatter, replacement));
-
- TF_RETURN_IF_ERROR(all_reduce->DropAllControlDeps());
- TF_RETURN_IF_ERROR(computation.ReplaceInstruction(all_reduce, replacement));
-
- // Try to apply decomposition recursively.
- TF_RETURN_IF_ERROR(
- TryDecomposeAllReduce(Cast<HloAllReduceInstruction>(new_all_reduce),
- num_devices_per_host)
- .status());
- return true;
-}
-
-} // namespace
-
-absl::StatusOr<bool> AllReduceBlueConnect::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- VLOG(1) << "Running AllReduceBlueConnect";
-
- if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
- VLOG(1)
- << "Skip AllReduceBlueConnect because the module contains all-reduce "
- "with constrained layouts";
- return false;
- }
- if (!module->config().has_static_device_assignment()) {
- VLOG(1)
- << "Skip AllReduceBlueConnect because the module doesn't have static "
- "device assignment";
- return false;
- }
-
- std::vector<HloAllReduceInstruction*> all_reduces;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kAllReduce) {
- all_reduces.push_back(Cast<HloAllReduceInstruction>(instruction));
- }
- }
- }
-
- bool changed = false;
- for (HloAllReduceInstruction* all_reduce : all_reduces) {
- TF_ASSIGN_OR_RETURN(
- bool all_reduce_changed,
- TryDecomposeAllReduce(all_reduce, num_devices_per_host_));
- changed |= all_reduce_changed;
- }
-
- return changed;
-}
-
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h b/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h
deleted file mode 100644
index 8633c77..0000000
--- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_
-#define XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_
-
-#include <cstddef>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Decomposes all-reduce operations using the BlueConnect algorithm.
-//
-// Paper: "BLUECONNECT: DECOMPOSING ALL-REDUCE FOR DEEP LEARNING ON
-// HETEROGENEOUS NETWORK HIERARCHY"
-// https://mlsys.org/Conferences/2019/doc/2019/130.pdf
-//
-// This algorithm attempts to minimize the number of levels of network hierarchy
-// traversed for as much data transfer as possible. This implementation assumes
-// that host IDs are ordered corresponding to network hierarchy.
-class AllReduceBlueConnect : public HloModulePass {
- public:
- explicit AllReduceBlueConnect(size_t num_devices_per_host)
- : num_devices_per_host_(num_devices_per_host) {}
-
- absl::string_view name() const override { return "all-reduce-blueconnect"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- size_t num_devices_per_host_;
-};
-
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_
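Note: the replica-group arithmetic deleted above is compact enough to show in isolation. Under kFlattenedID mode, a flattened replica id r with p partitions maps to device_assignment(r / p, r % p); after sorting a replica group by host, member i lands in scatter/gather group i / scatter_group_size and in the new all-reduce group i % scatter_group_size. A standalone sketch for 8 replicas with 4 devices per host (toy types only; the deleted OneStage test that follows asserts the same grouping):

#include <cstdio>
#include <vector>

int main() {
  // One replica group of 8 replicas, already sorted by host, with 4 devices
  // per host => scatter_group_size = 4 and two scatter/gather groups.
  const std::vector<int> sorted_replica_group = {0, 1, 2, 3, 4, 5, 6, 7};
  const int scatter_group_size = 4;
  const int num_scatter_groups =
      static_cast<int>(sorted_replica_group.size()) / scatter_group_size;

  std::vector<std::vector<int>> scatter_gather(num_scatter_groups);
  std::vector<std::vector<int>> new_all_reduce(scatter_group_size);
  for (int i = 0; i < static_cast<int>(sorted_replica_group.size()); ++i) {
    scatter_gather[i / scatter_group_size].push_back(sorted_replica_group[i]);
    new_all_reduce[i % scatter_group_size].push_back(sorted_replica_group[i]);
  }

  // Prints {0,1,2,3} {4,5,6,7} then {0,4} {1,5} {2,6} {3,7}.
  for (const auto& groups : {scatter_gather, new_all_reduce}) {
    for (const auto& g : groups) {
      std::printf("{");
      for (size_t j = 0; j < g.size(); ++j) {
        std::printf(j ? ",%d" : "%d", g[j]);
      }
      std::printf("} ");
    }
    std::printf("\n");
  }
}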
diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc b/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc
deleted file mode 100644
index a6a66c5..0000000
--- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc
+++ /dev/null
@@ -1,414 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/all_reduce_blueconnect.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/computation_placer.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-using ::tsl::testing::IsOkAndHolds;
-namespace m = ::xla::match;
-
-using AllReduceBlueConnectTest = HloTestBase;
-
-HloPredicate MatchChannelId(std::optional<int64_t> channel_id) {
- return [channel_id](const HloInstruction* instruction) {
- return instruction->channel_id() == channel_id;
- };
-}
-
-void SetModuleConfig(HloModuleConfig* module_config, size_t replica_count,
- size_t partition_count = 1) {
- DeviceAssignment device_assignment(replica_count,
- /*computation_count=*/partition_count);
- device_assignment.FillIota(0);
- module_config->set_replica_count(replica_count);
- module_config->set_num_partitions(partition_count);
- module_config->set_static_device_assignment(device_assignment);
-}
-
-void SetModuleConfig(HloModule& module, size_t replica_count,
- size_t partition_count = 1) {
- SetModuleConfig(&module.mutable_config(), replica_count, partition_count);
-}
-
-TEST_F(AllReduceBlueConnectTest, OneStage) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[4,4] parameter(0)
- ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/8);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
- // clang-format off
- std::vector<std::vector<int64_t>> scatter_gather_groups = {
- {0, 1, 2, 3}, {4, 5, 6, 7}};
- std::vector<std::vector<int64_t>> new_all_reduce_groups = {
- {0, 4}, {1, 5}, {2, 6}, {3, 7}};
- // clang-format on
-
- auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
- auto reduce_scatter = m::ReduceScatter(bitcast)
- .WithShape(F32, {4})
- .WithReplicaGroups(scatter_gather_groups)
- .WithPredicate(MatchChannelId(std::nullopt));
- auto all_reduce = m::AllReduce(reduce_scatter)
- .WithShape(F32, {4})
- .WithReplicaGroups(new_all_reduce_groups)
- .WithPredicate(MatchChannelId(std::nullopt));
- auto all_gather = m::AllGather(all_reduce)
- .WithShape(F32, {16})
- .WithReplicaGroups(scatter_gather_groups)
- .WithPredicate(MatchChannelId(std::nullopt));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Bitcast(all_gather).WithShape(F32, {4, 4})));
-}
-
-TEST_F(AllReduceBlueConnectTest, TwoStage) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[4,4] parameter(0)
- ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/16);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
- std::vector<std::vector<int64_t>> outer_scatter_gather_groups = {
- {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
- std::vector<std::vector<int64_t>> inner_scatter_gather_groups = {
- {0, 4}, {8, 12}, {1, 5}, {9, 13}, {2, 6}, {10, 14}, {3, 7}, {11, 15}};
- std::vector<std::vector<int64_t>> new_all_reduce_groups = {
- {0, 8}, {4, 12}, {1, 9}, {5, 13}, {2, 10}, {6, 14}, {3, 11}, {7, 15}};
-
- auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
- auto reduce_scatter0 =
- m::ReduceScatter(bitcast0).WithShape(F32, {4}).WithReplicaGroups(
- outer_scatter_gather_groups);
- auto bitcast1 = m::Bitcast(reduce_scatter0).WithShape(F32, {4});
- auto reduce_scatter1 =
- m::ReduceScatter(bitcast1).WithShape(F32, {2}).WithReplicaGroups(
- inner_scatter_gather_groups);
- auto all_reduce = m::AllReduce(reduce_scatter1)
- .WithShape(F32, {2})
- .WithReplicaGroups(new_all_reduce_groups);
- auto all_gather0 = m::AllGather(all_reduce)
- .WithShape(F32, {4})
- .WithReplicaGroups(inner_scatter_gather_groups);
- auto bitcast2 = m::Bitcast(all_gather0).WithShape(F32, {4});
- auto all_gather1 =
- m::AllGather(bitcast2).WithShape(F32, {16}).WithReplicaGroups(
- outer_scatter_gather_groups);
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Bitcast(all_gather1).WithShape(F32, {4, 4})));
-}
-
-TEST_F(AllReduceBlueConnectTest, TwoOperands) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[4,4] parameter(0)
- p1 = f32[4,4,2] parameter(1)
- ROOT crs = (f32[4,4], f32[4,4,2]) all-reduce(p0, p1), to_apply=add
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/8);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
- // clang-format off
- std::vector<std::vector<int64_t>> scatter_gather_groups = {
- {0, 1, 2, 3}, {4, 5, 6, 7}};
- std::vector<std::vector<int64_t>> new_all_reduce_groups = {
- {0, 4}, {1, 5}, {2, 6}, {3, 7}};
- // clang-format on
-
- auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
- auto bitcast1 = m::Bitcast(m::Parameter(1)).WithShape(F32, {32});
-
- Shape expected0 = ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});
- Shape expected1 = ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F32, {16}), ShapeUtil::MakeShape(F32, {32})});
- auto reduce_scatter = m::ReduceScatter(bitcast0, bitcast1)
- .WithShapeEqualTo(&expected0)
- .WithReplicaGroups(scatter_gather_groups);
- auto all_reduce = m::AllReduce(m::GetTupleElement(reduce_scatter, 0),
- m::GetTupleElement(reduce_scatter, 1))
- .WithShapeEqualTo(&expected0)
- .WithReplicaGroups(new_all_reduce_groups);
- auto all_gather = m::AllGather(m::GetTupleElement(all_reduce, 0),
- m::GetTupleElement(all_reduce, 1))
- .WithShapeEqualTo(&expected1)
- .WithReplicaGroups(scatter_gather_groups);
- auto bitcast2 =
- m::Bitcast(m::GetTupleElement(all_gather, 0)).WithShape(F32, {4, 4});
- auto bitcast3 =
- m::Bitcast(m::GetTupleElement(all_gather, 1)).WithShape(F32, {4, 4, 2});
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(bitcast2, bitcast3)));
-}
-
-TEST_F(AllReduceBlueConnectTest, MultiplePartitionsFilecheck) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[8,8] parameter(0)
- ROOT crs = f32[8,8] all-reduce(p0), channel_id=1,
- replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=add
-})";
- HloModuleConfig module_config;
- SetModuleConfig(&module_config, /*replica_count=*/1, /*partition_count=*/8);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- // Note: When matching strings like "replica_groups={{0,1,2,3}}", FileCheck
- // interprets the string inside the double braces as regex. So to match such
- // strings, we use "replica_groups={{..0,1,2,3..}}", where the dots match the
- // opening and closing braces.
- RunAndFilecheckHloRewrite(hlo_string, std::move(pass), R"(
- CHECK: %p0 = f32[8,8]{1,0} parameter(0)
- CHECK-NEXT: [[bitcast:%[^ ]+]] = f32[64]{0} bitcast(%p0)
- CHECK-NEXT: [[reduce_scatter:%[^ ]+]] = f32[16]{0} reduce-scatter([[bitcast]]), channel_id=2, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
- CHECK-NEXT: [[all_reduce:%[^ ]+]] = f32[16]{0} all-reduce([[reduce_scatter]]), channel_id=1, replica_groups={{..0,4.,.1,5.,.2,6.,.3,7..}}, use_global_device_ids=true, to_apply=%add
- CHECK-NEXT: [[all_gather:%[^ ]+]] = f32[64]{0} all-gather([[all_reduce]]), channel_id=3, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, dimensions={0}, use_global_device_ids=true
- CHECK-NEXT: ROOT [[output:%[^ ]+]] = f32[8,8]{1,0} bitcast([[all_gather]])
-}
-)",
- /*after_pass_checks=*/nullptr, &module_config);
-}
-
-TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesWithinReplicaGroup) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[4,4] parameter(0)
- ROOT crs = f32[4,4] all-reduce(p0),
- replica_groups={{0,1,2,7},{3,4,5,6}}, to_apply=add
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/8);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesAcrossReplicaGroups) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[4,4] parameter(0)
- ROOT crs = f32[4,4] all-reduce(p0),
- replica_groups={{0,1,4,5},{2,3,6,7},{8,9,10,11},{12,13,14,15}}, to_apply=add
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/16);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(AllReduceBlueConnectTest, OperandIndivisible) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[4,4] parameter(0)
- p1 = f32[9] parameter(1)
- ROOT crs = (f32[4,4], f32[9]) all-reduce(p0, p1), to_apply=add
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/8);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(AllReduceBlueConnectTest, ControlDeps) {
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[4,4] parameter(0)
- p1 = f32[4,4] parameter(1)
- add = f32[4,4] add(p0, p1)
- crs = f32[4,4] all-reduce(p0), to_apply=add, control-predecessors={add}
- ROOT add1 = f32[4,4] add(crs, add), control-predecessors={crs}
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/8);
-
- // Remember all-reduce's control succ and preds.
- const HloInstruction* ar =
- module->entry_computation()->root_instruction()->operand(0);
- auto expected_preds = ar->control_predecessors();
- auto expected_succs = ar->control_successors();
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
- // clang-format off
- std::vector<std::vector<int64_t>> scatter_gather_groups = {
- {0, 1, 2, 3}, {4, 5, 6, 7}};
- std::vector<std::vector<int64_t>> new_all_reduce_groups = {
- {0, 4}, {1, 5}, {2, 6}, {3, 7}};
- // clang-format on
-
- const HloInstruction *matched_rs, *matched_bitcast;
- auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
- auto reduce_scatter = m::ReduceScatter(&matched_rs, bitcast)
- .WithShape(F32, {4})
- .WithReplicaGroups(scatter_gather_groups);
- auto all_reduce = m::AllReduce(reduce_scatter)
- .WithShape(F32, {4})
- .WithReplicaGroups(new_all_reduce_groups);
- auto all_gather = m::AllGather(all_reduce)
- .WithShape(F32, {16})
- .WithReplicaGroups(scatter_gather_groups);
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Add()));
-
- EXPECT_THAT(
- root->operand(0),
- GmockMatch(
- m::Bitcast(&matched_bitcast, all_gather).WithShape(F32, {4, 4})));
-
- // Verify that control dependencies are transferred correctly.
- EXPECT_THAT(matched_rs, GmockMatch(m::Op().WithControlDeps(
- absl::MakeSpan(expected_preds), {})));
- EXPECT_THAT(matched_bitcast, GmockMatch(m::Op().WithControlDeps(
- {}, absl::MakeSpan(expected_succs))));
-}
-
-TEST_F(AllReduceBlueConnectTest, ReduceScatterUnchanged) {
- // Tests that this pass does not affect reduce-scatter. In principle, the
- // BlueConnect algorithm could be applied to reduce-scatter, but for now it
- // doesn't.
- constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
- p0 = f32[8,4] parameter(0)
- ROOT crs = f32[1,4] reduce-scatter(p0), dimensions={0}, to_apply=add
-})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- SetModuleConfig(*module, /*replica_count=*/8);
-
- AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
- EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
index 04483ce..ae541ba 100644
--- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
@@ -32,21 +32,21 @@
#include "xla/service/dot_dimension_merger.h"
#include "xla/service/float_normalization.h"
#include "xla/service/float_support.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/conv_algorithm_picker.h"
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-#include "xla/service/gpu/cusolver_rewriter.h"
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
#include "xla/service/gpu/gpu_compiler.h"
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/gpu_sort_rewriter.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "xla/service/gpu/target_constants.h"
-#include "xla/service/gpu/triangular_solve_rewriter.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+#include "xla/service/gpu/transforms/gpusolver_rewriter.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+#include "xla/service/gpu/transforms/triangular_solve_rewriter.h"
#include "xla/service/hlo_constant_folding.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_pass_fix.h"
@@ -123,8 +123,8 @@
pipeline.AddPass<FloatNormalization>(&conv_bf16_support);
pipeline.AddPass<GpusolverRewriter>();
- pipeline.AddPass<GpuConvRewriter>(gpu_version);
- pipeline.AddPass<GpuConvPaddingLegalization>();
+ pipeline.AddPass<ConvRewriter>(gpu_version);
+ pipeline.AddPass<ConvPaddingLegalization>();
auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version,
GetToolkitVersion());
@@ -135,7 +135,7 @@
pipeline.AddPass<CallInliner>();
pipeline.AddPass<TupleSimplifier>();
- // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter
+ // tf2xla bridge, DepthwiseConvolutionConverter and ConvRewriter
// introduce reshapes and transposes that can be eliminated using
// AlgebraicSimplifier. We run algsimp to a fixed point.
AlgebraicSimplifierOptions options =
@@ -144,7 +144,7 @@
options.set_enable_unconditional_reduce_of_concat_replacement(false);
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(options, gpu_version);
- // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
+ // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and
// CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
// to a fixed point. Include algsimp because ReshapeMover relies on it.
[&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
@@ -166,7 +166,7 @@
pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
}();
- // GpuConvRewriter, GpuConvPaddingLegalization and
+ // ConvRewriter, ConvPaddingLegalization and
// CudnnConvPadForTensorCores may add instructions which can be simplified
// by constant folding.
pipeline.AddPass<HloConstantFolding>();
@@ -240,7 +240,7 @@
absl::Status AMDGPUCompiler::AddCustomKernelReplacementPasses(
HloPassPipeline* pipeline, const DebugOptions& debug_options) {
if (debug_options.xla_gpu_enable_cub_radix_sort()) {
- pipeline->AddPass<GpuSortRewriter>();
+ pipeline->AddPass<SortRewriter>();
}
return absl::OkStatus();
}
diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.h b/third_party/xla/xla/service/gpu/amdgpu_compiler.h
index 483647b..062a0ef 100644
--- a/third_party/xla/xla/service/gpu/amdgpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.h
@@ -20,7 +20,7 @@
#include "absl/status/statusor.h"
#include "llvm/IR/Module.h"
#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_pass_pipeline.h"
diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc
deleted file mode 100644
index b3e880b..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc
+++ /dev/null
@@ -1,284 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_compile_util.h"
-
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "xla/executable_run_options.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/compiler.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/gpu_executable_run_options.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/maybe_owning_device_memory.h"
-#include "xla/service/service_executable_run_options.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-std::vector<ExecutionInput> ExecutionInputsFromBuffers(
- absl::Span<se::DeviceMemoryBase const> buffers,
- absl::Span<Shape const> shapes) {
- CHECK_EQ(buffers.size(), shapes.size());
- std::vector<ExecutionInput> inputs;
- for (int i = 0; i < buffers.size(); ++i) {
- inputs.emplace_back(shapes.at(i));
- // Our executable doesn't have input-output aliasing, so we can pass
- // unowned input buffers.
- inputs.back().SetUnownedBuffer(
- /*index=*/{}, MaybeOwningDeviceMemory(/*unowned=*/buffers.at(i)));
- }
- return inputs;
-}
-
-} // namespace
-
-AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config,
- Compiler* compiler,
- se::StreamExecutor& stream_executor,
- se::Stream& stream,
- se::DeviceMemoryAllocator& allocator,
- const DebugOptions& opts)
- : config_(config),
- compiler_(compiler),
- stream_executor_(stream_executor),
- stream_(stream),
- allocator_(allocator),
- opts_(opts) {
- // Avoid dumping compilation steps.
- opts_.set_xla_enable_dumping(false);
- opts_.set_xla_gpu_dump_autotune_results_to("");
- opts_.set_xla_gpu_load_autotune_results_from("");
- opts_.set_xla_gpu_dump_llvmir(false);
- opts_.set_xla_gpu_dump_autotune_logs_to("");
- // Avoid using another thread pool.
- opts_.set_xla_gpu_force_compilation_parallelism(1);
- opts_.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
- // Avoid using GPU graphs as we don't want to measure graph construction time.
- opts_.clear_xla_gpu_enable_command_buffer();
- opts_.set_xla_embed_ir_in_executable(false);
- opts_.set_xla_gpu_kernel_cache_file("");
-}
-
-absl::StatusOr<std::optional<AutotunerCompileUtil::ProfilingOutput>>
-AutotunerCompileUtil::ProfileExecutable(
- Executable* executable, se::Stream* stream,
- absl::Span<se::DeviceMemoryBase const> input_buffers,
- absl::Span<Shape const> input_shapes) {
- {
- std::vector<ExecutionInput> execution_inputs =
- ExecutionInputsFromBuffers(input_buffers, input_shapes);
- // Warm-up run: input and output buffers are reused while probing different
- // configs, so GPU caches should be in a comparable state during measurements.
- absl::StatusOr<ExecutionOutput> execution_output =
- Execute(*executable, std::move(execution_inputs));
- if (!execution_output.ok()) {
- // Treat register allocation error gracefully. If the compilation happens
- // with the driver during execution then the error could surface here.
- // It's enough to check this once here.
- if (execution_output.status().code() ==
- absl::StatusCode::kResourceExhausted) {
- return {std::nullopt};
- }
- return execution_output.status();
- }
-
- TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
- }
- std::vector<ExecutionInput> execution_inputs =
- ExecutionInputsFromBuffers(input_buffers, input_shapes);
- ExecutionProfile profile;
- // Flag that a warm-up run was executed so that GpuTimer can use the more
- // accurate delay kernel implementation.
- profile.set_warmup_run_executed(true);
- TF_ASSIGN_OR_RETURN(
- ExecutionOutput execution_output,
- Execute(*executable, std::move(execution_inputs), &profile));
- return std::make_optional<ProfilingOutput>(
- absl::Nanoseconds(profile.compute_time_ns()),
- execution_output.Commit().ConsumeResult());
-}
-
-absl::StatusOr<std::unique_ptr<Executable>> AutotunerCompileUtil::Compile(
- GenerateModuleFn extractor) {
- absl::StatusOr<std::unique_ptr<HloModule>> new_hlo_module = extractor(opts_);
- if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) {
- // Incompatible value of split-k is an example of an expected failure.
- return std::unique_ptr<Executable>();
- } else if (!new_hlo_module.status().ok()) {
- return new_hlo_module.status();
- }
-
- absl::StatusOr<std::unique_ptr<Executable>> out = compiler_->RunBackend(
- std::move(*new_hlo_module), &stream_executor_,
- Compiler::CompileOptions{&allocator_, /*thread_pool=*/nullptr,
- /*layout_canonicalization_callback=*/{},
- /*is_autotuning_compilation=*/true});
- if (out.status().code() == absl::StatusCode::kResourceExhausted ||
- out.status().code() == absl::StatusCode::kCancelled) {
- // Being out of shared memory budget or registers is an expected failure.
- // Cancelling upon register spilling is also an expected failure.
- return std::unique_ptr<Executable>();
- }
- return out;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> AutotunerCompileUtil::ExtractModule(
- GenerateModuleFn extractor) {
- return extractor(opts_);
-}
-
-/*static*/ absl::StatusOr<std::optional<AutotunerCompileUtil>>
-AutotunerCompileUtil::Create(const AutotuneConfig& config,
- const DebugOptions& opts) {
- if (config.IsDeviceless()) {
- return std::nullopt;
- }
- se::StreamExecutor* stream_exec = config.GetExecutor();
- se::DeviceMemoryAllocator* allocator = config.GetAllocator();
- TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream());
- TF_ASSIGN_OR_RETURN(Compiler * compiler,
- Compiler::GetForPlatform(stream_exec->GetPlatform()));
- return AutotunerCompileUtil(config, compiler, *stream_exec, *stream,
- *allocator, opts);
-}
-
-absl::StatusOr<ExecutionOutput> AutotunerCompileUtil::Execute(
- Executable& executable, std::vector<ExecutionInput> arguments,
- ExecutionProfile* profile) {
- // Require exclusive GPU lock to prevent other runs during autotuning.
- GpuExecutableRunOptions gpu_opts;
- gpu_opts.set_requires_exclusive_lock_on_gpu();
-
- ExecutableRunOptions run_options;
- run_options.set_device_ordinal(stream_executor_.device_ordinal());
- run_options.set_stream(&stream_);
- run_options.set_allocator(&allocator_);
- run_options.set_gpu_executable_run_options(&gpu_opts);
- run_options.set_execution_profile(profile);
- ServiceExecutableRunOptions service_run_options(run_options);
- TF_ASSIGN_OR_RETURN(ExecutionOutput output,
- executable.ExecuteAsyncOnStreamWrapper(
- &service_run_options, std::move(arguments)));
-
- return std::move(output);
-}
-
-absl::StatusOr<RedzoneBuffers> RedzoneBuffers::FromInstruction(
- const HloInstruction& instruction, const AutotuneConfig& config,
- const DebugOptions& debug_options, BuffersToCreate buffers_to_create) {
- RedzoneBuffers buffers;
-
- TF_ASSIGN_OR_RETURN(auto rz_allocator, AutotunerUtil::CreateRedzoneAllocator(
- config, debug_options));
- buffers.redzone_allocator_ =
- std::make_unique<se::RedzoneAllocator>(std::move(rz_allocator));
-
- int64_t rng_state = 0;
-
- TF_RETURN_IF_ERROR(
- buffers.CreateInputs(instruction, config, debug_options, rng_state));
-
- if (buffers_to_create == BuffersToCreate::kAllInputsAllOutputs ||
- buffers_to_create == BuffersToCreate::kAllInputsOutputsNoScratch) {
- TF_RETURN_IF_ERROR(buffers.CreateOutputs(instruction, config, debug_options,
- buffers_to_create, rng_state));
- }
-
- return buffers;
-}
-
-absl::Status RedzoneBuffers::CreateInputs(const HloInstruction& instruction,
- const AutotuneConfig& config,
- const DebugOptions& debug_options,
- int64_t& rng_state) {
- for (const auto* operand : instruction.operands()) {
- TF_ASSIGN_OR_RETURN(
- se::DeviceMemoryBase buf,
- AutotunerUtil::CreateBuffer(*redzone_allocator_, operand->shape(),
- config, rng_state));
- input_buffers_.push_back(buf);
- input_shapes_.push_back(operand->shape());
- }
- return absl::OkStatus();
-}
-
-absl::Status RedzoneBuffers::CreateOutputs(const HloInstruction& instruction,
- const AutotuneConfig& config,
- const DebugOptions& debug_options,
- BuffersToCreate buffers_to_create,
- int64_t& rng_state) {
- if (!instruction.shape().IsTuple()) {
- TF_ASSIGN_OR_RETURN(
- se::DeviceMemoryBase buf,
- AutotunerUtil::CreateBuffer(*redzone_allocator_, instruction.shape(),
- config, rng_state));
- output_buffers_.push_back(buf);
- output_shape_ = instruction.shape();
- return absl::OkStatus();
- }
-
- // The output is a tuple.
-
- auto current_shape_it = instruction.shape().tuple_shapes().begin();
- auto end = instruction.shape().tuple_shapes().end();
- end -= buffers_to_create == kAllInputsAllOutputs ? 0 : 1;
-
- output_shape_ = std::distance(current_shape_it, end) == 1
-                     ? *current_shape_it
-                     : ShapeUtil::MakeTupleShape(
-                           std::vector<Shape>{current_shape_it, end});
-
- for (; current_shape_it < end; current_shape_it++) {
- if (current_shape_it->IsTuple()) {
- return Unimplemented("Nested tuples are unsupported by RedzoneBuffers.");
- }
- TF_ASSIGN_OR_RETURN(
- se::DeviceMemoryBase buf,
- AutotunerUtil::CreateBuffer(*redzone_allocator_, *current_shape_it,
- config, rng_state));
- output_buffers_.push_back(buf);
- }
- return absl::OkStatus();
-}
-
-} // namespace gpu
-} // namespace xla
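
Editor's note: the deleted autotuner_compile_util.cc above implements the Create() / Compile() / ProfileExecutable() flow that the autotuning passes drive. The sketch below is illustrative only and not part of the patch; it assumes an XLA build, assumes the post-move headers live under xla/service/gpu/autotuning/ (as the amdgpu_compiler.cc hunks above suggest), and `ProfileOneCandidate` is a made-up wrapper name.

// Usage sketch only (not part of the patch). Shows the typical flow the
// deleted file supports: Create() -> Compile() -> ProfileExecutable().
#include <memory>
#include <optional>
#include <utility>

#include "xla/service/gpu/autotuning/autotuner_compile_util.h"  // assumed new path
#include "xla/service/gpu/autotuning/autotuner_util.h"           // assumed new path
#include "tsl/platform/statusor.h"

namespace xla::gpu {

// Returns the measured compute time of one candidate, or std::nullopt when
// the candidate fails in one of the "expected" ways described above.
absl::StatusOr<std::optional<absl::Duration>> ProfileOneCandidate(
    const AutotuneConfig& config, const DebugOptions& opts,
    AutotunerCompileUtil::GenerateModuleFn candidate_module,
    const RedzoneBuffers& buffers) {
  TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> util,
                      AutotunerCompileUtil::Create(config, opts));
  if (!util.has_value()) return std::nullopt;  // Deviceless config.

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      util->Compile(std::move(candidate_module)));
  if (executable == nullptr) return std::nullopt;  // Expected compile failure.

  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream());
  TF_ASSIGN_OR_RETURN(
      std::optional<AutotunerCompileUtil::ProfilingOutput> output,
      util->ProfileExecutable(executable.get(), stream, buffers.input_buffers(),
                              buffers.input_shapes()));
  if (!output.has_value()) return std::nullopt;  // Expected runtime failure.
  return output->duration;
}

}  // namespace xla::gpu
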
diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.h b/third_party/xla/xla/service/gpu/autotuner_compile_util.h
deleted file mode 100644
index 5137fcf..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_compile_util.h
+++ /dev/null
@@ -1,177 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_
-#define XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/functional/any_invocable.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/compiler.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// Autotuning utils which require compiling fusions separately. Requires a
-// separate target, as runtime autotuning cannot perform compilation.
-class AutotunerCompileUtil {
- public:
- // The GenerateModuleFn must generate/extract a module using the provided
- // debug options. Typically it should set the debug options on the extracted
- // module before transforming it, so that the transforms can use them. In
- // justified cases, it may override some of the provided debug options.
- using GenerateModuleFn =
- absl::AnyInvocable<absl::StatusOr<std::unique_ptr<HloModule>>(
- const DebugOptions&)>;
-
- // Generates a compile util for a platform associated with the `stream`.
- //
- // Returns an empty optional if the AutotuneConfig is deviceless, as
- // autotuning is impossible in that case.
- static absl::StatusOr<std::optional<AutotunerCompileUtil>> Create(
- const AutotuneConfig& config, const DebugOptions& opts);
-
- struct ProfilingOutput {
- ProfilingOutput(absl::Duration duration, ScopedShapedBuffer&& buffer)
- : duration(duration), output(std::move(buffer)) {}
-
- absl::Duration duration;
- ScopedShapedBuffer output;
- };
-
- // Profiles the given executable on `stream`, feeding it the provided input
- // buffers (with the corresponding `input_shapes`). A warm-up run is executed
- // before the measured run. Returns `std::nullopt` on expected failure, bad
- // `Status` otherwise.
- absl::StatusOr<std::optional<ProfilingOutput>> ProfileExecutable(
- Executable* executable, se::Stream* stream,
- absl::Span<se::DeviceMemoryBase const> input_buffers,
- absl::Span<Shape const> input_shapes);
-
- // Generic method to compile a generated module from `extractor` in isolation.
- //
- // Returns:
- // - `nullptr` on *expected* failure
- // - `Executable` if everything goes fine.
- // - `Status` on *unexpected* failure.
- absl::StatusOr<std::unique_ptr<Executable>> Compile(
- GenerateModuleFn extractor);
-
- // Generic method to extract an HLO using the debug options of the
- // AutotunerCompileUtil.
- //
- // Typically we can use Compile directly.
- absl::StatusOr<std::unique_ptr<HloModule>> ExtractModule(
- GenerateModuleFn extractor);
-
- private:
- AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler,
- se::StreamExecutor& stream_executor, se::Stream& stream,
- se::DeviceMemoryAllocator& allocator,
- const DebugOptions& opts);
-
- absl::StatusOr<ExecutionOutput> Execute(Executable& executable,
- std::vector<ExecutionInput> arguments,
- ExecutionProfile* profile = nullptr);
-
- AutotuneConfig config_;
- Compiler* compiler_;
- se::StreamExecutor& stream_executor_;
- se::Stream& stream_;
- se::DeviceMemoryAllocator& allocator_;
- DebugOptions opts_;
-};
-
-// A RedZone allocator and a collection of buffers that store the inputs and
-// outputs of an HloInstruction. These are used when running the instruction
-// for autotuning.
-class RedzoneBuffers {
- public:
- enum BuffersToCreate {
- // Create a buffer for all of the instruction's operands. The result shape
- // is ignored.
- kAllInputs = 0,
- // Create a buffer for all of the instruction's operands and the entire
- // result shape. If the result shape is a tuple, a separate buffer is
- // created for each subshape.
- kAllInputsAllOutputs = 1,
- // Create a buffer for all of the instruction's operands and all of the
- // subshapes of the result tuple, except for the last one. The last subshape
- // is considered a scratch buffer and is assumed to be allocated elsewhere.
- // If the result shape is not a tuple, this will create a buffer
- // corresponding to the entire shape - equivalent to `kAllInputsAllOutputs`.
- kAllInputsOutputsNoScratch = 2,
- };
- static absl::StatusOr<RedzoneBuffers> FromInstruction(
- const HloInstruction& instruction, const AutotuneConfig& config,
- const DebugOptions& debug_options, BuffersToCreate buffers_to_create);
-
- const std::vector<se::DeviceMemoryBase>& input_buffers() const {
- return input_buffers_;
- }
- const std::vector<Shape>& input_shapes() const { return input_shapes_; }
- const std::vector<se::DeviceMemoryBase>& output_buffers() const {
- return output_buffers_;
- }
- const Shape& output_shape() const { return output_shape_; }
- se::RedzoneAllocator& RedzoneAllocator() const { return *redzone_allocator_; }
-
- private:
- absl::Status CreateInputs(const HloInstruction& instruction,
- const AutotuneConfig& config,
- const DebugOptions& debug_options,
- int64_t& rng_state);
-
- absl::Status CreateOutputs(const HloInstruction& instruction,
- const AutotuneConfig& config,
- const DebugOptions& debug_options,
- BuffersToCreate buffers_to_create,
- int64_t& rng_state);
-
- std::unique_ptr<se::RedzoneAllocator> redzone_allocator_;
- std::vector<se::DeviceMemoryBase> input_buffers_;
- std::vector<Shape> input_shapes_;
- std::vector<se::DeviceMemoryBase> output_buffers_;
- Shape output_shape_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_
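
Editor's note: the deleted header above documents that a GenerateModuleFn should apply the provided DebugOptions to the extracted module before transforming it. A minimal sketch of such an extractor follows; it is not part of the patch, `MakeExtractor` and `ExtractSomeFusion` are hypothetical names, and the include path assumes the post-move autotuning/ location.

// Sketch only (not part of the patch): a GenerateModuleFn that follows the
// contract documented in the deleted header.
#include <memory>
#include <utility>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/autotuning/autotuner_compile_util.h"  // assumed new path
#include "tsl/platform/statusor.h"

namespace xla::gpu {

// Hypothetical extraction helper; a real autotuner supplies its own.
absl::StatusOr<std::unique_ptr<HloModule>> ExtractSomeFusion(
    const HloInstruction& fusion);

AutotunerCompileUtil::GenerateModuleFn MakeExtractor(
    const HloInstruction* fusion) {
  return [fusion](const DebugOptions& opts)
             -> absl::StatusOr<std::unique_ptr<HloModule>> {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                        ExtractSomeFusion(*fusion));
    // Apply the autotuner's debug options before any further rewrites so the
    // rewrites observe them.
    module->mutable_config().set_debug_options(opts);
    return module;
  };
}

}  // namespace xla::gpu
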
diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc
deleted file mode 100644
index 1db5afb..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc
+++ /dev/null
@@ -1,196 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_compile_util.h"
-
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/platform_util.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using AutotunerCompileUtilTest = HloTestBase;
-
-TEST_F(AutotunerCompileUtilTest, VerifyOutputNotATuple) {
- constexpr absl::string_view kHlo = R"(
-HloModule hlo
-ENTRY main {
- p0 = f32[2,2] parameter(0)
- p1 = f32[4,4] parameter(1)
- p2 = f32[6,6] parameter(2)
- ROOT root = f32[1,2,3] custom-call(p0, p1, p2), custom_call_target="fake"
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
-
- se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
- PlatformUtil::GetStreamExecutors(platform));
-
- AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
- GetDebugOptionsForTest()};
-
- auto& root = *module->entry_computation()->root_instruction();
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputs));
-
- EXPECT_EQ(rzb.input_shapes().size(), 3);
- EXPECT_EQ(rzb.input_buffers().size(), 3);
- EXPECT_EQ(rzb.output_buffers().size(), 0);
- EXPECT_NE(rzb.output_shape(), root.shape());
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputsAllOutputs));
-
- EXPECT_EQ(rzb2.input_shapes().size(), 3);
- EXPECT_EQ(rzb2.input_buffers().size(), 3);
- EXPECT_EQ(rzb2.output_buffers().size(), 1);
- EXPECT_EQ(rzb2.output_shape(), root.shape());
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputsOutputsNoScratch));
-
- EXPECT_EQ(rzb3.input_shapes().size(), 3);
- EXPECT_EQ(rzb3.input_buffers().size(), 3);
- EXPECT_EQ(rzb3.output_buffers().size(), 1);
- EXPECT_EQ(rzb3.output_shape(), root.shape());
-}
-
-TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleOneElement) {
- constexpr absl::string_view kHlo = R"(
-HloModule hlo
-ENTRY main {
- p0 = f32[2,2] parameter(0)
- p1 = f32[4,4] parameter(1)
- p2 = f32[6,6] parameter(2)
- ROOT root = (f32[1,2,3]) custom-call(p0, p1, p2), custom_call_target="fake"
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
-
- se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
- PlatformUtil::GetStreamExecutors(platform));
-
- AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
- GetDebugOptionsForTest()};
-
- auto& root = *module->entry_computation()->root_instruction();
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputs));
-
- EXPECT_EQ(rzb.input_shapes().size(), 3);
- EXPECT_EQ(rzb.input_buffers().size(), 3);
- EXPECT_EQ(rzb.output_buffers().size(), 0);
- EXPECT_NE(rzb.output_shape(), root.shape());
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputsAllOutputs));
-
- EXPECT_EQ(rzb2.input_shapes().size(), 3);
- EXPECT_EQ(rzb2.input_buffers().size(), 3);
- EXPECT_EQ(rzb2.output_buffers().size(), 1);
- EXPECT_FALSE(rzb2.output_shape().IsTuple());
- EXPECT_EQ(rzb2.output_shape(), root.shape().tuple_shapes(0));
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputsOutputsNoScratch));
-
- EXPECT_EQ(rzb3.input_shapes().size(), 3);
- EXPECT_EQ(rzb3.input_buffers().size(), 3);
- EXPECT_EQ(rzb3.output_buffers().size(), 0);
-}
-
-TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleTwoElements) {
- constexpr absl::string_view kHlo = R"(
-HloModule hlo
-ENTRY main {
- p0 = f32[2,2] parameter(0)
- p1 = f32[4,4] parameter(1)
- p2 = f32[6,6] parameter(2)
- ROOT root = (f32[1,2,3], u8[1,2]) custom-call(p0, p1, p2), custom_call_target="fake"
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
-
- se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
- PlatformUtil::GetStreamExecutors(platform));
-
- AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
- GetDebugOptionsForTest()};
-
- auto& root = *module->entry_computation()->root_instruction();
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputs));
-
- EXPECT_EQ(rzb.input_shapes().size(), 3);
- EXPECT_EQ(rzb.input_buffers().size(), 3);
- EXPECT_EQ(rzb.output_buffers().size(), 0);
- EXPECT_NE(rzb.output_shape(), root.shape());
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputsAllOutputs));
-
- EXPECT_EQ(rzb2.input_shapes().size(), 3);
- EXPECT_EQ(rzb2.input_buffers().size(), 3);
- EXPECT_EQ(rzb2.output_buffers().size(), 2);
- EXPECT_TRUE(rzb2.output_shape().IsTuple());
- EXPECT_EQ(rzb2.output_shape(), root.shape());
-
- TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
- RedzoneBuffers::FromInstruction(
- root, autotune_config, GetDebugOptionsForTest(),
- RedzoneBuffers::kAllInputsOutputsNoScratch));
-
- EXPECT_EQ(rzb3.input_shapes().size(), 3);
- EXPECT_EQ(rzb3.input_buffers().size(), 3);
- EXPECT_EQ(rzb3.output_buffers().size(), 1);
- EXPECT_FALSE(rzb3.output_shape().IsTuple());
- EXPECT_EQ(rzb3.output_shape(), root.shape().tuple_shapes(0));
-}
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuner_util.cc
deleted file mode 100644
index 93c946f..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_util.cc
+++ /dev/null
@@ -1,547 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_util.h"
-
-#include <algorithm>
-#include <array>
-#include <cmath>
-#include <cstdint>
-#include <limits>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-
-#include "absl/base/const_init.h"
-#include "absl/base/thread_annotations.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/match.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/SHA256.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_asm_opts_util.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "tsl/platform/base64.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/protobuf.h" // IWYU pragma: keep
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Bump this version whenever you change the structure of the results.
-// LINT.IfChange(version)
-constexpr int kVersion = 3;
-// LINT.ThenChange()
-
-} // namespace
-
-using AutotuneCacheMap = absl::flat_hash_map<AutotuneCacheKey, AutotuneResult>;
-
-static absl::Mutex autotune_cache_mu(absl::kConstInit);
-static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) =
- *new AutotuneCacheMap();
-
-absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s) {
- llvm::SHA256 sha256;
- sha256.update(llvm::StringRef(s));
- std::array<uint8_t, 32> hash = sha256.final();
- // C++ strict aliasing rules allow reinterpret casting to (const) char*.
- absl::string_view hash_view(reinterpret_cast<const char*>(hash.data()),
- hash.size());
- std::string base64_encoded_hash;
- TF_RETURN_IF_ERROR(tsl::Base64Encode(hash_view, &base64_encoded_hash));
- return base64_encoded_hash;
-}
-
-namespace {
-
-// Get the path corresponding to the given key.
-absl::StatusOr<std::string> GetCacheFilePath(absl::string_view cache_dir,
- const AutotuneCacheKey& key) {
- if (cache_dir.empty()) {
- return absl::InvalidArgumentError("autotune_cache_dir should not be empty");
- }
-
- TF_ASSIGN_OR_RETURN(std::string key_hash,
- GetBase64EncodedSha256Hash(key.ToString()));
- return tsl::io::JoinPath(cache_dir, absl::StrCat(key_hash, ".textproto"));
-}
-
-struct ResultAndInserted {
- // The result that ended up in the cache. This is the existing result if
- // inserted is false, and the new result if inserted is true.
- //
- // We return a value, not a pointer, for thread safety reasons.
- AutotuneResult result;
- // Did we insert the given result into the cache?
- bool inserted;
-};
-
-ResultAndInserted AddResultToInMemoryCache(const AutotuneCacheKey& key,
- AutotuneResult result)
- ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
- absl::MutexLock lock(&autotune_cache_mu);
- auto [it, inserted] = autotune_cache.emplace(key, std::move(result));
- return {it->second, inserted};
-}
-
-absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
- AutotuneResult result,
- std::string_view cache_dir)
- ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
- if (cache_dir.empty()) {
- return absl::OkStatus();
- }
-
- TF_ASSIGN_OR_RETURN(const std::string file_path,
- GetCacheFilePath(cache_dir, key));
-
- VLOG(1) << "Writing autotune result to file: " << file_path;
-
- std::string result_str;
- if (!tsl::protobuf::TextFormat::PrintToString(result, &result_str)) {
- return absl::InternalError("Failed to serialize autotune result.");
- }
-
- // Rename trick: Write to a temporary file, then rename it to the final file
- // to avoid mingled files when multiple processes are writing to the same
- // file. Also avoids reading incomplete files. (This may not work on all file
- // systems.)
- std::string temp_file_path = tsl::io::GetTempFilename(".textproto");
- tsl::Env* default_env = tsl::Env::Default();
- TF_RETURN_IF_ERROR(
- tsl::WriteStringToFile(default_env, temp_file_path, result_str));
- return default_env->RenameFile(temp_file_path, file_path);
-}
-
-absl::StatusOr<ResultAndInserted> AddResultToCaches(const AutotuneCacheKey& key,
- AutotuneResult result,
- std::string_view cache_dir)
- ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
- ResultAndInserted result_and_inserted = AddResultToInMemoryCache(key, result);
- if (result_and_inserted.inserted) {
- TF_RETURN_IF_ERROR(AddResultToFileBasedCacheIfEnabled(
- key, result_and_inserted.result, cache_dir));
- }
- return result_and_inserted;
-}
-
-std::optional<AutotuneResult> TryToFindInInMemoryCache(
- const AutotuneCacheKey& key) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
- absl::MutexLock lock(&autotune_cache_mu);
- auto it = autotune_cache.find(key);
- if (it == autotune_cache.end()) {
- return std::nullopt;
- }
- return it->second;
-}
-
-absl::StatusOr<std::optional<AutotuneResult>>
-TryToFindInFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
- absl::string_view cache_dir)
- ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
- if (cache_dir.empty()) {
- return std::nullopt;
- }
-
- TF_ASSIGN_OR_RETURN(const std::string file_path,
- GetCacheFilePath(cache_dir, key));
- if (!tsl::Env::Default()->FileExists(file_path).ok()) {
- VLOG(1) << "Autotune result file not found: " << file_path;
- return std::nullopt;
- }
-
- VLOG(1) << "Autotune result file found: " << file_path;
- std::string autotune_result_str;
- TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), file_path,
- &autotune_result_str));
- AutotuneResult result;
- if (!tsl::protobuf::TextFormat::ParseFromString(autotune_result_str,
- &result)) {
- return absl::InvalidArgumentError("Failed to parse autotune result.");
- }
- return result;
-}
-
-// Sort the results so that they're deterministic.
-void SortAutotuneResults(AutotuneResults* results) {
- std::sort(results->mutable_results()->pointer_begin(),
- results->mutable_results()->pointer_end(),
- [](const auto* a, const auto* b) {
- return std::make_pair(absl::string_view(a->device()),
- absl::string_view(a->hlo())) <
- std::make_pair(absl::string_view(b->device()),
- absl::string_view(b->hlo()));
- });
-}
-
-} // namespace
-
-// Serialize `results` to string as a proto.
-absl::StatusOr<std::string> AutotuneResultsToString(
- const AutotuneResults& results, bool as_textproto) {
- if (as_textproto) {
- std::string textproto;
- if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) {
- return textproto;
- } else {
- return Internal("Failed to serialize autotune results.");
- }
- }
- return results.SerializeAsString();
-}
-
-namespace {
-// Serialize a single entry to `results`.
-void SerializeAutotuneEntry(AutotuneResults* results, const AutotuneCacheKey& k,
- const AutotuneResult* res) {
- auto& entry = *results->add_results();
- entry.set_device(std::string(k.GetModelStr()));
- entry.set_hlo(std::string(k.GetHlo()));
- *entry.mutable_result() = *res;
-}
-} // namespace
-
-/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults(
- AutotuneResults* results) {
- absl::MutexLock lock(&autotune_cache_mu);
- for (const auto& [k, result] : autotune_cache) {
- SerializeAutotuneEntry(results, k, &result);
- }
-
- results->set_version(kVersion);
- SortAutotuneResults(results);
-
- return absl::OkStatus();
-}
-
-/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
- const AutotuneResults& results) {
- absl::MutexLock lock(&autotune_cache_mu);
- for (const AutotuneResults::Entry& result : results.results()) {
- if (auto [it, inserted] = autotune_cache.emplace(
- AutotuneCacheKey(result.device(), result.hlo()), result.result());
- !inserted) {
- return absl::InternalError(absl::StrCat(
- "Duplicate autotuning result for ", it->first.ToString()));
- }
- }
- return absl::OkStatus();
-}
-
-/*static*/ void AutotunerUtil::ClearAutotuneResults() {
- absl::MutexLock lock(&autotune_cache_mu);
- autotune_cache.clear();
-}
-
-/*static*/ bool AutotunerUtil::ResultCacheIsEmpty() {
- absl::MutexLock lock(&autotune_cache_mu);
- return autotune_cache.empty();
-}
-
-/* static*/ absl::StatusOr<se::DeviceMemoryBase> AutotunerUtil::CreateBuffer(
- se::RedzoneAllocator& allocator, const Shape& shape,
- const AutotuneConfig& config, int64_t& rng_state) {
- TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
- allocator.AllocateBytes(ShapeUtil::ByteSizeOf(shape)));
- if (config.should_init_buffers()) {
- InitializeBuffer(allocator.stream(), shape.element_type(), &rng_state,
- buffer);
- }
- return buffer;
-}
-
-namespace {
-std::string ToCanonicalString(const HloInstruction* instr) {
- auto options = HloPrintOptions::Canonical();
- if (instr->opcode() != HloOpcode::kFusion) {
- options.set_print_backend_config(true);
- return instr->ToString(options);
- }
- options.set_print_subcomputation_mode(
- HloPrintOptions::PrintSubcomputationMode::kOff);
- options.set_print_infeed_outfeed_config(false);
- options.set_print_only_essential_constants(true);
- options.set_print_operand_shape(true);
- options.set_print_ids(false);
- options.set_canonicalize_computations(true);
-
- // TODO(b/266210099): This is unsound. We should probably do the fingerprint
- // of the HLO computation proto instead.
- return instr->called_computations()[0]->ToString(options);
-}
-
-} // namespace
-
-AutotuneCacheKey::AutotuneCacheKey(absl::string_view model_str,
- const HloInstruction& instr)
- : AutotuneCacheKey(model_str, ToCanonicalString(&instr)) {}
-
-/*static*/ std::string AutotuneCacheKey::DeviceDescriptionToCacheKey(
- const se::DeviceDescription& device_description) {
- std::string compute_capability;
- if (auto* ccc = std::get_if<se::CudaComputeCapability>(
- &device_description.gpu_compute_capability())) {
- compute_capability = absl::StrCat("CUDA: ", ccc->major, ".", ccc->minor);
- } else {
- auto* rcc = std::get_if<se::RocmComputeCapability>(
- &device_description.gpu_compute_capability());
- CHECK(rcc != nullptr) << "Unknown compute capability type";
- compute_capability = absl::StrCat("ROCM: ", rcc->gfx_version());
- }
-
- // The string below should include only as much information as is needed to
- // make it a valid key. Information that should not be included is:
- // - specs that are directly derivable from the compute capability, e.g.
- // shared memory size. For NVIDIA GPUs, you can see what is derivable from
- // the SM version here:
- // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
- // - specs that are irrelevant for autotuning. E.g. the total available memory
- // on a device is not relevant, because by itself, it does not affect the
- // performance of single kernels.
- //
- // See b/344573710 for some discussion.
-
- double memory_bandwidth = device_description.memory_bandwidth() / 1e9;
- // Round the memory bandwidth to make the final string nicer to read.
- // This will also cause minute differences in bandwidth to yield the same
- // cache key, but that's fine, since the difference is inconsequential.
- memory_bandwidth = std::round(memory_bandwidth);
-
- constexpr double kBytesPerMegabyte = 1 << 20;
- double l2_cache_size = device_description.l2_cache_size() / kBytesPerMegabyte;
-
- return absl::StrCat(compute_capability,
- ", Cores: ", device_description.core_count(),
- ", GPU clock: ", device_description.clock_rate_ghz(),
- " GHz, Memory bandwidth: ", memory_bandwidth,
- " GB/s, L2 cache: ", l2_cache_size, " MB");
-}
-
-namespace {
-absl::StatusOr<std::optional<AutotuneResult>> TryFindInCache(
- const AutotuneCacheKey& key, absl::string_view cache_dir)
- ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
- std::optional<AutotuneResult> opt_result = TryToFindInInMemoryCache(key);
- if (opt_result.has_value()) {
- if (VLOG_IS_ON(1)) {
- LOG(INFO) << "In-memory autotune cache hit";
- } else if (VLOG_IS_ON(2)) {
- LOG(INFO) << "In-memory autotune cache hit: key = " << key.ToString();
- }
- return opt_result;
- }
-
- TF_ASSIGN_OR_RETURN(opt_result,
- TryToFindInFileBasedCacheIfEnabled(key, cache_dir));
- if (opt_result.has_value()) {
- AddResultToInMemoryCache(key, opt_result.value());
-
- if (VLOG_IS_ON(1)) {
- LOG(INFO) << "File-based autotune cache hit";
- } else if (VLOG_IS_ON(2)) {
- LOG(INFO) << "File-based autotune cache hit: key = " << key.ToString();
- }
- return opt_result;
- }
-
- if (VLOG_IS_ON(1)) {
- LOG(INFO) << "Autotune cache miss";
- } else if (VLOG_IS_ON(2)) {
- LOG(INFO) << "Autotune cache miss: key = " << key.ToString();
- }
- return std::nullopt;
-}
-} // namespace
-
-/*static*/ AutotuneCacheKey AutotunerUtil::GetKey(
- const HloInstruction* instr, const AutotuneConfig& config) {
- return AutotuneCacheKey(config.GetModelStr(), *instr);
-}
-
-/*static*/ absl::StatusOr<bool> AutotunerUtil::IsInCache(
- const AutotuneCacheKey& key, const AutotuneConfig& config) {
- TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
- TryFindInCache(key, config.autotune_cache_dir()));
- return opt_res.has_value();
-}
-
-/*static*/ absl::StatusOr<bool> AutotunerUtil::AddResult(
- const AutotuneCacheKey& key, AutotuneResult result,
- const AutotuneConfig& config) {
- TF_ASSIGN_OR_RETURN(
- ResultAndInserted result_and_inserted,
- AddResultToCaches(key, std::move(result), config.autotune_cache_dir()));
- return result_and_inserted.inserted;
-}
-
-/*static*/ absl::StatusOr<AutotuneResult> AutotunerUtil::Autotune(
- const HloInstruction* instr, const AutotuneConfig& config,
- const AutotuneNoCacheFn& autotune_fn) {
- const AutotuneCacheKey key = GetKey(instr, config);
- TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
- TryFindInCache(key, config.autotune_cache_dir()));
- if (opt_res.has_value()) {
- return opt_res.value();
- }
-
- // Cache miss.
- if (config.should_require_complete_aot_autotune_results()) {
- return NotFound(
- "Complete XLA AOT autotuning results are required, but no AOT result "
- "was found for key: %s",
- key.ToString());
- }
-
- TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn());
-
- TF_ASSIGN_OR_RETURN(ResultAndInserted result_and_inserted,
- AddResultToCaches(key, std::move(autotune_result),
- config.autotune_cache_dir()));
- return result_and_inserted.result;
-}
-
-namespace {
-
-bool IsTextProtoPath(absl::string_view file_path) {
- return absl::EndsWith(file_path, ".txt") ||
- absl::EndsWith(file_path, ".textproto") ||
- absl::EndsWith(file_path, ".prototxt") ||
- absl::EndsWith(file_path, ".pbtxt");
-}
-
-} // anonymous namespace
-
-/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
- absl::string_view data, bool as_textproto) {
- AutotuneResults results;
- // The cast here is necessary for MacOS builds.
- bool parse_success =
- as_textproto ? tsl::protobuf::TextFormat::ParseFromString(
- std::string(data), &results) // NOLINT
- : results.ParseFromString(std::string(data)); // NOLINT
- if (!parse_success) {
- return absl::InvalidArgumentError(
- "Failed to parse autotune results string.");
- }
- if (results.version() != kVersion) {
- return absl::InvalidArgumentError(absl::StrFormat(
- "Version mismatch in autotune results. Expected %d but was %d",
- kVersion, results.version()));
- }
-
- TF_RETURN_IF_ERROR(LoadAutotuneResults(results));
- return absl::OkStatus();
-}
-
-/*static*/ absl::StatusOr<std::string> AutotunerUtil::SerializeAutotuneResults(
- bool as_textproto) {
- AutotuneResults results;
- TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
- return AutotuneResultsToString(results, as_textproto);
-}
-
-/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
- const AutotuneResults& results, absl::string_view file_path) {
- TF_RET_CHECK(!file_path.empty());
- TF_RET_CHECK(results.version() > 0)
- << "Did you call SerializeAutotuneResults to get this AutotuneResults?";
-
- std::string resolved_path;
- if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
- return FailedPrecondition("File path can not be resolved: %s", file_path);
- }
-
- TF_ASSIGN_OR_RETURN(
- std::string autotune_results_str,
- AutotuneResultsToString(results, IsTextProtoPath(resolved_path)));
- TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), resolved_path,
- autotune_results_str));
- LOG(INFO) << "Autotune results serialized to file: " << resolved_path;
-
- return absl::OkStatus();
-}
-
-/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
- absl::string_view file_path) {
- AutotuneResults results;
- TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
- return SerializeAutotuneResultsToFile(results, file_path);
-}
-
-/*static*/ absl::Status AutotunerUtil::LoadAutotuneResultsFromFile(
- absl::string_view file_path) {
- TF_RET_CHECK(!file_path.empty());
-
- std::string resolved_path;
- if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
- return FailedPrecondition("File path can not be resolved: %s", file_path);
- }
-
- if (!tsl::Env::Default()->FileExists(resolved_path).ok()) {
- return FailedPrecondition("Autotune results file does not exist: %s",
- resolved_path);
- }
- std::string autotune_results_str;
- TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), resolved_path,
- &autotune_results_str));
-
- TF_RETURN_IF_ERROR(LoadAutotuneResults(autotune_results_str,
- IsTextProtoPath(resolved_path)));
-
- LOG(INFO) << "Autotune results loaded from file: " << resolved_path;
-
- return absl::OkStatus();
-}
-
-/*static*/ absl::StatusOr<se::RedzoneAllocator>
-AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config,
- const DebugOptions& opts) {
- TF_ASSIGN_OR_RETURN(se::Stream * stream, config.GetStream());
- return se::RedzoneAllocator(
- stream, config.GetAllocator(), PtxOptsFromDebugOptions(opts),
- /*memory_limit=*/std::numeric_limits<int64_t>::max(),
- /*redzone_size=*/config.should_check_correctness()
- ? opts.xla_gpu_redzone_padding_bytes()
- : 0);
-}
-
-} // namespace gpu
-} // namespace xla
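
Editor's note: the deleted autotuner_util.cc above stores one textproto per cache key (named by the base64-encoded SHA-256 of the key string) and commits each file with a write-to-temp-then-rename step, so concurrent writers never expose a partially written entry. The standard-library sketch below is not part of the patch and only illustrates that commit pattern; the real code uses tsl::Env, tsl::WriteStringToFile, and tsl::io::GetTempFilename, and hashing of the key is elided here.

// Standalone sketch (not part of the patch) of the write-to-temp-then-rename
// pattern used by the deleted AddResultToFileBasedCacheIfEnabled.
#include <filesystem>
#include <fstream>
#include <random>
#include <string>

namespace fs = std::filesystem;

// `file_name` stands in for "<base64(sha256(key))>.textproto";
// `serialized_result` would be the textproto of an AutotuneResult.
bool WriteCacheEntryAtomically(const fs::path& cache_dir,
                               const std::string& file_name,
                               const std::string& serialized_result) {
  std::error_code ec;
  fs::create_directories(cache_dir, ec);

  // Unique-ish temp name per writer; the real code gets this from
  // tsl::io::GetTempFilename().
  std::random_device rd;
  const fs::path temp_path =
      cache_dir / (file_name + ".tmp." + std::to_string(rd()));
  {
    std::ofstream out(temp_path, std::ios::binary | std::ios::trunc);
    if (!out) return false;
    out << serialized_result;
    if (!out.good()) return false;
  }

  // Rename is the "commit": readers see either no file or a complete one,
  // never a half-written entry (on most file systems).
  const fs::path final_path = cache_dir / file_name;
  fs::rename(temp_path, final_path, ec);
  return !ec;
}
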
diff --git a/third_party/xla/xla/service/gpu/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuner_util.h
deleted file mode 100644
index 4634fc2..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_util.h
+++ /dev/null
@@ -1,334 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_
-#define XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_
-
-#include <algorithm>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/xla.pb.h"
-
-namespace xla {
-namespace gpu {
-
-struct DeviceConfig {
- se::StreamExecutor* stream_exec; // never null
-
- // If the `allocator` parameter is not null, we will use it to allocate temp
- // memory while timing the various convolution algorithms. If it's null,
- // we'll use the default allocator on the StreamExecutor.
- se::DeviceMemoryAllocator* allocator = nullptr; // may be null
-};
-
-struct DevicelessConfig {
- // The device description of the target device.
- se::DeviceDescription device_description;
-};
-
-class AutotuneCacheKey {
- public:
- AutotuneCacheKey(const se::DeviceDescription& device_description,
- const HloInstruction& instruction)
- : AutotuneCacheKey(DeviceDescriptionToCacheKey(device_description),
- instruction.ToString()) {}
-
- AutotuneCacheKey(absl::string_view model_str,
- const HloInstruction& instruction);
-
- explicit AutotuneCacheKey(absl::string_view model_str,
- absl::string_view hlo_canonical)
- : model_str_(model_str), hlo_canonical_(hlo_canonical) {}
-
- absl::string_view GetModelStr() const { return model_str_; }
-
- absl::string_view GetHlo() const { return hlo_canonical_; }
-
- template <typename H>
- friend H AbslHashValue(H h, const AutotuneCacheKey& w) {
- return H::combine(std::move(h), w.model_str_, w.hlo_canonical_);
- }
-
- bool operator==(const AutotuneCacheKey& w) const {
- return model_str_ == w.model_str_ && hlo_canonical_ == w.hlo_canonical_;
- }
-
- std::string ToString() const {
- return absl::StrFormat("<key model='%s', hlo='%s'>", model_str_,
- hlo_canonical_);
- }
-
- static std::string DeviceDescriptionToCacheKey(
- const se::DeviceDescription& device_description);
-
- private:
- std::string model_str_;
- std::string hlo_canonical_;
-};
-
-class AutotuneConfig {
- public:
- bool should_init_buffers() const { return autotune_level_ >= 2; }
- bool should_reinit_output_buffer() const { return autotune_level_ >= 3; }
- bool should_check_correctness() const { return autotune_level_ >= 4; }
- bool should_skip_wrong_results() const { return autotune_level_ >= 5; }
- bool should_crash_on_check_failure() const {
- return should_crash_on_check_failure_;
- }
- bool should_require_complete_aot_autotune_results() const {
- return require_complete_aot_autotune_results_;
- }
- // Empty string means no cache is used.
- const std::string& autotune_cache_dir() const { return autotune_cache_dir_; }
-
- AutotuneConfig(const AutotuneConfig& right)
- : config_(right.config_),
- autotune_level_(right.autotune_level_),
- should_crash_on_check_failure_(right.should_crash_on_check_failure_),
- exhaustive_tiling_search_(right.exhaustive_tiling_search_),
- require_complete_aot_autotune_results_(
- right.require_complete_aot_autotune_results_),
- autotune_cache_dir_(right.autotune_cache_dir_) {}
-
- AutotuneConfig(const std::variant<DeviceConfig, DevicelessConfig>& config,
- const DebugOptions& debug_options)
- : config_(config),
- autotune_level_(debug_options.xla_gpu_autotune_level()),
- should_crash_on_check_failure_(
- debug_options.xla_gpu_crash_on_verification_failures()),
- exhaustive_tiling_search_(
- debug_options.xla_gpu_exhaustive_tiling_search()),
- require_complete_aot_autotune_results_(
- debug_options.xla_gpu_require_complete_aot_autotune_results()),
- autotune_cache_dir_(
- debug_options.xla_gpu_per_fusion_autotune_cache_dir()) {}
-
- std::string GetModelStr() const {
- if (auto deviceless_config = std::get_if<DevicelessConfig>(&config_)) {
- return AutotuneCacheKey::DeviceDescriptionToCacheKey(
- deviceless_config->device_description);
- }
-
- const auto& device_config = std::get<DeviceConfig>(config_);
- return AutotuneCacheKey::DeviceDescriptionToCacheKey(
- device_config.stream_exec->GetDeviceDescription());
- }
-
- se::StreamExecutor* GetExecutor() const {
- CHECK(std::holds_alternative<DeviceConfig>(config_));
- return std::get<DeviceConfig>(config_).stream_exec;
- }
-
- se::DeviceMemoryAllocator* GetAllocator() const {
- CHECK(std::holds_alternative<DeviceConfig>(config_));
- auto& cf = std::get<DeviceConfig>(config_);
- if (cf.allocator != nullptr) {
- return cf.allocator;
- }
- if (allocator_ == nullptr) {
- allocator_ =
- std::make_unique<se::StreamExecutorMemoryAllocator>(GetExecutor());
- }
- return allocator_.get();
- }
-
- absl::StatusOr<se::Stream*> GetStream() const {
- CHECK(std::holds_alternative<DeviceConfig>(config_));
- return GetAllocator()->GetStream(GetExecutor()->device_ordinal());
- }
-
- const se::GpuComputeCapability& GetGpuComputeCapability() const {
- if (auto c = std::get_if<DeviceConfig>(&config_)) {
- return c->stream_exec->GetDeviceDescription().gpu_compute_capability();
- }
- return std::get<DevicelessConfig>(config_)
- .device_description.gpu_compute_capability();
- }
-
- bool IsDeviceless() const {
- return std::holds_alternative<DevicelessConfig>(config_);
- }
-
- bool ExhaustiveTilingSearch() const { return exhaustive_tiling_search_; }
-
- private:
- std::variant<DeviceConfig, DevicelessConfig> config_;
- int32_t autotune_level_;
- bool should_crash_on_check_failure_;
- bool exhaustive_tiling_search_;
- bool require_complete_aot_autotune_results_;
- mutable std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
- std::string autotune_cache_dir_;
-};
-
-using AutotuneNoCacheFn = std::function<absl::StatusOr<AutotuneResult>()>;
-
-struct AutotunerUtil {
- // Create a buffer for a given operation using redzone checker, initialize
- // based on a given rng state.
- static absl::StatusOr<se::DeviceMemoryBase> CreateBuffer(
- se::RedzoneAllocator& allocator, const Shape& shape,
- const AutotuneConfig& config, int64_t& rng_state);
-
- static absl::StatusOr<AutotuneResult> Autotune(
- const HloInstruction* instr, const AutotuneConfig& config,
- const AutotuneNoCacheFn& autotune_fn);
-
- // Returns the same cache key that would be used inside Autotune().
- //
- // Normally, we don't have to use this low level method.
- static AutotuneCacheKey GetKey(const HloInstruction* instr,
- const AutotuneConfig& config);
-
- // Checks if the key is in the autotune cache.
- //
- // Normally, we don't have to use this low level method.
- static absl::StatusOr<bool> IsInCache(const AutotuneCacheKey& key,
- const AutotuneConfig& config);
-
- // Adds the result to the autotune cache.
- //
- // Returns true if the entry is inserted.
- //
- // Normally, we don't have to use this low level method.
- static absl::StatusOr<bool> AddResult(const AutotuneCacheKey& key,
- AutotuneResult result,
- const AutotuneConfig& config);
-
- // Creates a RedzoneAllocator from a given config.
- static absl::StatusOr<se::RedzoneAllocator> CreateRedzoneAllocator(
- const AutotuneConfig& config, const DebugOptions& opts);
-
- // Functions to save/load XLA's autotuning results.
- //
- // This is used for ahead-of-time autotuning. Specifically:
- //
- // When XLA calls cublas (for matmuls, aka "gemm" or "dot") or cudnn (for
- // convolutions), it usually has to choose an "algorithm" for the particular
- // dot/conv. XLA queries cublas/cudnn for a list of candidate algorithms.
- // Then it runs all of them and picks the fastest one. This is what we call
- // "autotuning". It happens in GemmAlgorithmPicker and GpuConvAlgorithmPicker.
- //
- // Autotuning is necessary to get good performance for dot/conv. But it also
- // has some disadvantages.
- //
- // - Because it relies on timing data, it is fundamentally nondeterministic.
- // But even if two algorithms have similar runtimes, our choice of
- // algorithm may be visible to the user: Different algorithms can have
- // different numerics, and sometimes they can even have different bugs!
- //
- //   - Trying all the candidate algorithms can be slow, especially when some
- //     of the candidates are "very bad" and run far more slowly than
- // the optimal candidate. This slows down compilation.
- //
- // To address the disadvantages above, we allow users to save/restore the
- // autotuning choices that XLA has made, using the functions below.
- //
- // Loading autotuning results does not erase existing autotuning choices, but
- // in the event of a disagreement between the existing data and the new data,
- // the new algorithm is chosen.
- //
- // Note that even if you call LoadAutotuneResults(), if XLA encounters a
- // dot/conv that is *not* covered by the loaded data, it will go ahead and
- // autotune it like normal. In other words, the behavior of XLA should be
- // identical with or without ahead-of-time autotuning, modulo nondeterminism.
- //
- // This is important if you want to be able to use the same autotuning file
- // with different versions of XLA, because as XLA changes, exactly which
- // dots/convs it wants to run can also change. For example, XLA might change
- // the conv padding heuristics it uses, and we don't want that to mean that
- // all users of ahead-of-time autotuning are broken.
- static absl::StatusOr<std::string> SerializeAutotuneResults(
- bool as_textproto = false);
-
- // Serializes autotune results into the given proto.
- static absl::Status SerializeAutotuneResults(AutotuneResults* results);
-
- // Loads autotune results from the given string of bytes.
- //
- // Warning: The results are only loaded to the in-memory cache.
- static absl::Status LoadAutotuneResults(absl::string_view data,
- bool as_textproto = false);
-
- // Loads autotune results from the given proto.
- //
- // Warning: The results are only loaded to the in-memory cache.
- static absl::Status LoadAutotuneResults(const AutotuneResults& results);
-
- // Serializes autotune results into a file.
- //
- // If `file_path` ends with ".txt" or ".textproto", then the textproto format
- // is used, otherwise the binary protobuf format.
- static absl::Status SerializeAutotuneResultsToFile(
- absl::string_view file_path);
-
- // As above, but if you already called SerializeAutotuneResults to get a
- // proto.
- static absl::Status SerializeAutotuneResultsToFile(
- const AutotuneResults& results, absl::string_view file_path);
-
- // Loads autotune results from a file.
- //
- // If `file_path` ends with ".txt" or ".textproto", then the file is
- // considered to be in the textproto format, otherwise the binary protobuf
- // format.
- //
- // Warning: The results are only loaded to the in-memory cache.
- static absl::Status LoadAutotuneResultsFromFile(absl::string_view file_path);
-
- // Warning: This only clears the in-memory cache. If you use a file-based
- // cache, you're responsible for clearing the cache directory when you want to.
- static void ClearAutotuneResults();
-
- // Warning: This only checks the in-memory cache. If you use a file-based
- // cache, you're responsible for checking whether the cache directory is
- // empty.
- static bool ResultCacheIsEmpty();
-};
-
-absl::StatusOr<std::string> AutotuneResultsToString(
- const AutotuneResults& results, bool as_textproto);
-
-// Exposed only for testing. Returns the SHA-256 hash of the input string,
-// encoded in base64.
-//
-// SHA-256 was chosen to follow industry best practices and avoid collisions.
-// Git is also transitioning to SHA-256. This is probably better than
-// tsl::Fingerprint128.
-absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s);
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_
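
For orientation, the following is a minimal sketch of the ahead-of-time autotuning flow that the header above documents: serialize the in-memory cache after a warm compilation, then restore it in a later process before compiling again. It assumes an XLA build where these declarations are visible under their post-move path; the helper name `SaveAndReloadAutotuneResults` is illustrative, not part of this change.

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "tsl/platform/errors.h"

namespace xla::gpu {

// Hypothetical helper, shown only to illustrate the API described above.
absl::Status SaveAndReloadAutotuneResults(absl::string_view path) {
  // A ".txt" or ".textproto" suffix selects the textproto format; any other
  // suffix selects the binary protobuf format.
  TF_RETURN_IF_ERROR(AutotunerUtil::SerializeAutotuneResultsToFile(path));

  // Loading only fills the in-memory cache; any dot/conv not covered by the
  // file is still autotuned just-in-time.
  return AutotunerUtil::LoadAutotuneResultsFromFile(path);
}

}  // namespace xla::gpu
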
diff --git a/third_party/xla/xla/service/gpu/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuner_util_test.cc
deleted file mode 100644
index 69c1395..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_util_test.cc
+++ /dev/null
@@ -1,459 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_util.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/platform_manager.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h" // IWYU pragma: keep
-#include "tsl/platform/path.h"
-#include "tsl/platform/protobuf.h" // IWYU pragma: keep
-#include "tsl/platform/status.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-using ::testing::HasSubstr;
-using ::testing::IsEmpty;
-using ::testing::Not;
-using ::testing::TempDir;
-using ::tsl::testing::StatusIs;
-
-class AutotunerUtilTest : public HloTestBase {
- protected:
- static constexpr absl::string_view kHloText = R"(
-HloModule t
-
-ENTRY e {
- p0 = f16[1,16,17,3] parameter(0)
- p1 = s8[16,17,3] parameter(1)
- cp1 = f16[16,17,3] convert(p1)
- ROOT _ = f16[1,16,16] dot(p0, cp1),
- lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-})";
-
- static constexpr absl::string_view kResultText = R"(
-version: 3
-results {
- device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
- hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
- result {
- run_time {
- nanos: 31744
- }
- triton {
- block_m: 32
- block_n: 32
- block_k: 32
- split_k: 1
- num_stages: 1
- num_warps: 4
- num_ctas: 1
- }
- }
-})";
-
- void SetUp() override { AutotunerUtil::ClearAutotuneResults(); }
-
- std::string GetUniqueTempFilePath(absl::string_view suffix) {
- std::string filename = TempDir();
- CHECK(tsl::Env::Default()->CreateUniqueFileName(&filename,
- std::string(suffix)));
- return filename;
- }
-
- std::string ExpectToReadNonEmptyFile(absl::string_view file_path) {
- std::string str;
- tsl::Env* env = tsl::Env::Default();
- TF_EXPECT_OK(tsl::ReadFileToString(env, std::string(file_path), &str));
- EXPECT_THAT(str, Not(IsEmpty()));
- return str;
- }
-
- static std::unique_ptr<stream_executor::StreamExecutor> NewStreamExecutor() {
- stream_executor::Platform* platform =
- stream_executor::PlatformManager::PlatformWithName("Host").value();
- stream_executor::StreamExecutorConfig config(/*ordinal=*/0);
- return platform->GetUncachedExecutor(config).value();
- }
-
- absl::Status PopulateResultCache() {
- EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
- TF_RETURN_IF_ERROR(AutotunerUtil::LoadAutotuneResults(kResultText, true));
- EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
- return absl::OkStatus();
- }
-};
-
-TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto1) {
- TF_EXPECT_OK(PopulateResultCache());
- std::string kFilePath = GetUniqueTempFilePath(".txt");
- TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-
- std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
- AutotuneResults results;
- EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
- &results));
- EXPECT_GT(results.results_size(), 0);
-}
-
-TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto2) {
- TF_EXPECT_OK(PopulateResultCache());
- std::string kFilePath = GetUniqueTempFilePath(".textproto");
- TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-
- std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
- AutotuneResults results;
- EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
- &results));
-}
-
-TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_Protobuf) {
- TF_EXPECT_OK(PopulateResultCache());
- std::string kFilePath = GetUniqueTempFilePath(".pb");
- TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-
- std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
- AutotuneResults results;
- EXPECT_TRUE(results.ParseFromString(autotune_results_str));
-}
-
-TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto1) {
- TF_EXPECT_OK(PopulateResultCache());
- std::string kFilePath = GetUniqueTempFilePath(".txt");
- TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
- AutotunerUtil::ClearAutotuneResults();
- EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
-
- TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
- EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
-}
-
-TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto2) {
- TF_EXPECT_OK(PopulateResultCache());
- std::string kFilePath = GetUniqueTempFilePath(".textproto");
- TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
- AutotunerUtil::ClearAutotuneResults();
- EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
-
- TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
- EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
-}
-
-TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_Protobuf) {
- TF_EXPECT_OK(PopulateResultCache());
- std::string kFilePath = GetUniqueTempFilePath(".pb");
- TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
- AutotunerUtil::ClearAutotuneResults();
- EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
-
- TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
- EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
-}
-
-TEST_F(AutotunerUtilTest, ResultConflictsAreDetected) {
- TF_EXPECT_OK(PopulateResultCache());
- std::string kFilePath = GetUniqueTempFilePath(".pb");
- TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
- EXPECT_THAT(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath),
- StatusIs(absl::StatusCode::kInternal,
- HasSubstr("Duplicate autotuning result")));
-}
-
-// Test that when complete AOT autotuning is required and there is a cache
-// miss, a `NotFound` error is raised.
-TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) {
- std::string kFilePath = GetUniqueTempFilePath(".txt");
- auto hlo_module = GetOptimizedModule(kHloText);
- TF_EXPECT_OK(hlo_module.status());
- std::vector<HloComputation*> computations =
- (*hlo_module)
- ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
- EXPECT_THAT(computations, Not(IsEmpty()));
- const HloInstruction* instruction = *computations[0]->instructions().begin();
- std::unique_ptr<stream_executor::StreamExecutor> executor =
- NewStreamExecutor();
- auto options = DebugOptions();
- options.set_xla_gpu_require_complete_aot_autotune_results(true);
- AutotuneConfig config(DeviceConfig{executor.get()}, options);
- EXPECT_THAT(
- AutotunerUtil::Autotune(instruction, config,
- [&] { return AutotuneResult(); }),
- StatusIs(
- absl::StatusCode::kNotFound,
- HasSubstr("Complete XLA AOT autotuning results are required, but "
- "no AOT result was found for key: <key model")));
-}
-
-// Test that when JIT autotuning is disabled but there is no cache miss thanks
-// to AOT autotuning, `Autotune` still returns an OK status.
-TEST_F(AutotunerUtilTest, OkIfJitAutotuningDisabledButAlreadyLoadedAOT) {
- auto hlo_module = GetOptimizedModule(kHloText);
- std::vector<HloComputation*> computations =
- (*hlo_module)
- ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
- EXPECT_THAT(computations, Not(IsEmpty()));
- const HloInstruction* instruction = *computations[0]->instructions().begin();
- std::unique_ptr<stream_executor::StreamExecutor> executor =
- NewStreamExecutor();
-
- {
- // By default, JIT autotuning is OK.
- AutotuneConfig config(DeviceConfig{executor.get()}, DebugOptions());
- TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
- return AutotuneResult();
- }).status());
- }
-
- // Now require complete AOT autotuning results.
- auto options = DebugOptions();
- options.set_xla_gpu_require_complete_aot_autotune_results(true);
-
- AutotuneConfig config(DeviceConfig{executor.get()}, options);
- // Even though JIT autotuning is disabled, there is no cache miss when running
- // autotuning for the same entry, so no error should be raised either.
- TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
- return AutotuneResult();
- }).status());
-}
-
-class FileBasedCacheTest : public AutotunerUtilTest {
- public:
- static std::string ToString(const AutotuneResult& message) {
- std::string textproto;
- CHECK(tsl::protobuf::TextFormat::PrintToString(message, &textproto));
- return textproto;
- }
-
- static std::vector<std::string> GetFilesInDir(
- const absl::string_view cache_dir) {
- std::vector<std::string> files_in_cache;
- TF_CHECK_OK(tsl::Env::Default()->GetChildren(std::string(cache_dir),
- &files_in_cache));
- return files_in_cache;
- }
-
- static std::string Read(const absl::string_view filepath) {
- std::string file_content;
- TF_CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(),
- std::string(filepath), &file_content));
- return file_content;
- }
-
- static void Write(const absl::string_view filepath,
- const absl::string_view content) {
- TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(),
- std::string(filepath), content));
- }
-
- std::unique_ptr<stream_executor::StreamExecutor> executor_ =
- NewStreamExecutor();
- std::unique_ptr<HloModule> module_ =
- ParseAndReturnVerifiedModule(kHloText).value();
- const HloInstruction* dot_ = hlo_query::GetFirstInstructionWithOpcode(
- *module_->entry_computation(), HloOpcode::kDot);
- std::string cache_dir_ = [] {
- tsl::Env* default_env = tsl::Env::Default();
- std::string cache_dir;
- CHECK(default_env->LocalTempFilename(&cache_dir));
- CHECK_OK(default_env->CreateDir(cache_dir));
- return cache_dir;
- }();
- AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_.get()}, [&] {
- DebugOptions options;
- options.set_xla_gpu_per_fusion_autotune_cache_dir(cache_dir_);
- return options;
- }());
- AutotuneCacheKey cache_key_ = AutotunerUtil::GetKey(dot_, config_);
- std::string cache_filename_ = [&] {
- absl::StatusOr<std::string> key_hash =
- GetBase64EncodedSha256Hash(cache_key_.ToString());
- CHECK_OK(key_hash.status());
- return absl::StrCat(key_hash.value(), ".textproto");
- }();
- std::string cache_file_path_ = tsl::io::JoinPath(cache_dir_, cache_filename_);
- const AutotuneResult result1_ = [] {
- AutotuneResult result;
- result.set_scratch_bytes(1);
- return result;
- }();
- const AutotuneResult result2_ = [] {
- AutotuneResult result;
- result.set_scratch_bytes(2);
- return result;
- }();
-};
-
-TEST_F(FileBasedCacheTest, AutotuneWritesResultToTheCacheDir) {
- TF_ASSERT_OK_AND_ASSIGN(
- AutotuneResult result,
- AutotunerUtil::Autotune(dot_, config_, [&] { return result1_; }));
- EXPECT_EQ(ToString(result), ToString(result1_));
-
- ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
- EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
-}
-
-TEST_F(FileBasedCacheTest, AutotuneReadsResultFromTheCacheDir) {
- Write(cache_file_path_, ToString(result1_));
-
- bool cache_hit = true;
- TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
- AutotunerUtil::Autotune(dot_, config_, [&] {
- cache_hit = false;
- return result2_;
- }));
-
- EXPECT_TRUE(cache_hit);
- EXPECT_EQ(ToString(result), ToString(result1_));
-}
-
-TEST_F(FileBasedCacheTest,
- RepeatedAutotuneCallsDontReadOrWriteTheCacheFileAgain) {
- auto check_autotune_cache_hit = [](const HloInstruction* instr,
- const AutotuneConfig& config,
- const AutotuneResult& expected_result) {
- bool cache_hit = true;
- TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
- AutotunerUtil::Autotune(instr, config, [&] {
- cache_hit = false;
- AutotuneResult new_result;
- new_result.set_scratch_bytes(2);
- return new_result;
- }));
- EXPECT_TRUE(cache_hit);
- EXPECT_EQ(ToString(result), ToString(expected_result));
- };
-
- Write(cache_file_path_, ToString(result1_));
- check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
-
- constexpr absl::string_view kPlaceholderContent = "placeholder content";
- Write(cache_file_path_, kPlaceholderContent);
- // File was not read again:
- check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
- // File was not written again:
- EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
-}
-
-TEST_F(FileBasedCacheTest,
- IsInCacheReturnsTrueIfTheResultIsInTheFileBasedCache) {
- Write(cache_file_path_, ToString(result1_));
-
- TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
- AutotunerUtil::IsInCache(cache_key_, config_));
-
- EXPECT_TRUE(is_in_cache);
-}
-
-TEST_F(FileBasedCacheTest, IsInCacheReturnsFalseIfTheResultIsNotInEitherCache) {
- TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
- AutotunerUtil::IsInCache(cache_key_, config_));
-
- EXPECT_FALSE(is_in_cache);
-}
-
-TEST_F(FileBasedCacheTest, AddResultAddsTheResultToTheFileBasedCache) {
- TF_ASSERT_OK_AND_ASSIGN(
- bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
- EXPECT_TRUE(added);
-
- ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
- EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
-}
-
-TEST_F(FileBasedCacheTest, RepeatedAddResultDoesNotWriteTheFileAgain) {
- {
- TF_ASSERT_OK_AND_ASSIGN(
- bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
- EXPECT_TRUE(added);
- }
- ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
- EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
- constexpr absl::string_view kPlaceholderContent = "placeholder content";
- Write(cache_file_path_, kPlaceholderContent);
-
- {
- TF_ASSERT_OK_AND_ASSIGN(
- bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
- EXPECT_FALSE(added);
- }
-
- // File was not written again:
- EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
-}
-
-TEST(AutotuneCacheKeyTest, DeviceDescriptionToCacheKey) {
- auto device_description =
- [](absl::string_view spec_file_name) -> se::DeviceDescription {
- se::GpuTargetConfigProto proto;
- std::string spec_string;
- CHECK_OK(tsl::ReadFileToString(
- tsl::Env::Default(),
- tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "hlo_opt",
- "gpu_specs", spec_file_name),
- &spec_string));
- EXPECT_TRUE(
- tsl::protobuf::TextFormat::ParseFromString(spec_string, &proto));
- return se::DeviceDescription(proto.gpu_device_info());
- };
-
- EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
- device_description("a100_sxm_40.txtpb")),
- "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
- "1555 GB/s, L2 cache: 40 MB");
-
- EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
- device_description("a100_sxm_80.txtpb")),
- "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
- "2039 GB/s, L2 cache: 40 MB");
-
- EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
- device_description("mi200.txtpb")),
- "ROCM: gfx90a, Cores: 110, GPU clock: 1.7 GHz, Memory bandwidth: "
- "1638 GB/s, L2 cache: 8 MB");
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD
new file mode 100644
index 0000000..aa82b86
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/BUILD
@@ -0,0 +1,542 @@
+# Description:
+# Components that implement GPU autotuning.
+
+load(
+ "@local_tsl//tsl/platform:build_config.bzl",
+ "tf_proto_library",
+)
+load(
+ "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
+ "if_cuda_is_configured",
+)
+load("//xla:xla.bzl", "xla_cc_test")
+load("//xla/tests:build_defs.bzl", "xla_test")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = [":friends"],
+ licenses = ["notice"],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//xla:friends",
+ ],
+)
+
+cc_library(
+ name = "gemm_fusion_autotuner",
+ srcs = ["gemm_fusion_autotuner.cc"],
+ hdrs = ["gemm_fusion_autotuner.h"],
+ tags = [
+ "gpu",
+ "no_rocm",
+ ],
+ deps = [
+ ":autotuner_compile_util",
+ ":autotuner_util",
+ "//xla:autotuning_proto_cc",
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/pjrt/distributed:key_value_store_interface",
+ "//xla/service:algorithm_util",
+ "//xla/service:dump",
+ "//xla/service:executable",
+ "//xla/service:float_normalization",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service:shaped_buffer",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:buffer_comparator",
+ "//xla/service/gpu:gpu_float_support",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu:split_k_gemm_rewriter",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+ "//xla/service/gpu/transforms:cudnn_fusion_compiler",
+ "//xla/service/gpu/transforms:fusion_wrapper",
+ "//xla/service/gpu/transforms:gemm_rewriter",
+ "//xla/service/gpu/transforms:instruction_fusion",
+ "//xla/service/gpu/transforms:priority_fusion",
+ "//xla/stream_executor",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:device_memory",
+ "//xla/stream_executor:stream_executor_memory_allocator",
+ "//xla/stream_executor/gpu:redzone_allocator",
+ "//xla/tools:hlo_decomposer_lib",
+ "//xla/tsl/util/proto:proto_utils",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:span",
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_tsl//tsl/lib/core:bits",
+ "@local_tsl//tsl/platform:blocking_counter",
+ "@local_tsl//tsl/platform:env",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:path",
+ "@local_tsl//tsl/platform:protobuf",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/profiler/lib:scoped_annotation",
+ ],
+)
+
+xla_test(
+ name = "gemm_fusion_autotuner_test",
+ timeout = "long",
+ srcs = ["gemm_fusion_autotuner_test.cc"],
+ backend_tags = {"gpu": [
+ "requires-gpu-sm80",
+ ]},
+ backends = [
+ "gpu",
+ ],
+ tags = [
+ "no_rocm",
+ "nomac",
+ ],
+ deps = [
+ ":autotuner_compile_util",
+ ":autotuner_util",
+ ":gemm_fusion_autotuner",
+ "//xla:autotuning_proto_cc",
+ "//xla:error_spec",
+ "//xla:xla_data_proto_cc",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/pjrt/distributed:key_value_store_interface",
+ "//xla/service:call_inliner",
+ "//xla/service:dump",
+ "//xla/service:executable",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass_pipeline",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu/transforms:gemm_fusion",
+ "//xla/service/gpu/transforms:gemm_rewriter",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:device_description_proto_cc",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:test_utils",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//xla/tools:hlo_decomposer_lib",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest",
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_tsl//tsl/platform:env",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:path",
+ "@local_tsl//tsl/platform:platform_port",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "gemm_algorithm_picker",
+ srcs = ["gemm_algorithm_picker.cc"],
+ hdrs = ["gemm_algorithm_picker.h"],
+ tags = ["gpu"],
+ deps = [
+ ":autotuner_compile_util",
+ ":autotuner_util",
+ "//xla:autotune_results_proto_cc",
+ "//xla:autotuning_proto_cc",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:buffer_comparator",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/service/gpu:variant_visitor",
+ "//xla/stream_executor",
+ "//xla/stream_executor:blas",
+ "//xla/stream_executor:device_memory",
+ "//xla/stream_executor:device_memory_allocator",
+ "//xla/stream_executor/gpu:redzone_allocator",
+ "//xla/tsl/util/proto:proto_utils",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/profiler/lib:scoped_annotation",
+ ],
+)
+
+cc_library(
+ name = "autotuner_util",
+ srcs = ["autotuner_util.cc"],
+ hdrs = ["autotuner_util.h"],
+ tags = ["gpu"],
+ deps = [
+ "//xla:autotune_results_proto_cc",
+ "//xla:autotuning_proto_cc",
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:dump",
+ "//xla/service/gpu:gpu_asm_opts_util",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:device_memory",
+ "//xla/stream_executor:device_memory_allocator",
+ "//xla/stream_executor:stream",
+ "//xla/stream_executor:stream_executor_h",
+ "//xla/stream_executor:stream_executor_memory_allocator",
+ "//xla/stream_executor/gpu:redzone_allocator",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@llvm-project//llvm:Support",
+ "@local_tsl//tsl/platform:base64",
+ "@local_tsl//tsl/platform:env",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:path",
+ "@local_tsl//tsl/platform:protobuf",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+# We need a separate target, as the runtime executable cannot depend on the
+# compilation pipeline.
+cc_library(
+ name = "autotuner_compile_util",
+ srcs = ["autotuner_compile_util.cc"],
+ hdrs = ["autotuner_compile_util.h"],
+ tags = ["gpu"],
+ deps = [
+ ":autotuner_util",
+ "//xla:executable_run_options",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:compiler",
+ "//xla/service:executable",
+ "//xla/service:maybe_owning_device_memory",
+ "//xla/service:shaped_buffer",
+ "//xla/service/gpu:gpu_executable_run_options",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/stream_executor",
+ "//xla/stream_executor/gpu:redzone_allocator",
+ "@com_google_absl//absl/functional:any_invocable",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "autotuner_compile_util_test",
+ srcs = ["autotuner_compile_util_test.cc"],
+ backends = ["gpu"],
+ deps = [
+ ":autotuner_compile_util",
+ ":autotuner_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:platform_util",
+ "//xla/stream_executor:platform",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "gemm_algorithm_picker_test",
+ srcs = ["gemm_algorithm_picker_test.cc"],
+ backends = [
+ "gpu_v100",
+ "gpu_amd_any",
+ ],
+ deps = [
+ ":autotuner_util",
+ ":gemm_algorithm_picker",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service:platform_util",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:variant_visitor",
+ "//xla/service/gpu/transforms:gemm_rewriter",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:platform",
+ "//xla/tests:hlo_test_base",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ "@local_tsl//tsl/protobuf:dnn_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "conv_algorithm_picker",
+ srcs = ["conv_algorithm_picker.cc"],
+ hdrs = ["conv_algorithm_picker.h"],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+ tags = ["gpu"],
+ deps = [
+ ":autotuner_compile_util",
+ ":autotuner_util",
+ ":gpu_autotuning_proto_cc",
+ "//xla:autotune_results_proto_cc",
+ "//xla:autotuning_proto_cc",
+ "//xla:debug_options_flags",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:executable",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service:slow_operation_alarm",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:gpu_asm_opts_util",
+ "//xla/service/gpu:gpu_conv_runner",
+ "//xla/service/gpu:hlo_algorithm_denylist",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor",
+ "//xla/stream_executor:device_memory_allocator",
+ "//xla/stream_executor:dnn",
+ "//xla/stream_executor:lazy_op_runner",
+ "//xla/stream_executor:numeric_options",
+ "//xla/stream_executor:scratch_allocator",
+ "//xla/stream_executor/cuda:cuda_platform_id",
+ "//xla/stream_executor/rocm:rocm_platform_id",
+ "//xla/tsl/util:env_var",
+ "//xla/tsl/util/proto:proto_utils",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:numbers",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ] + if_cuda_is_configured([
+ # keep sorted
+ "//xla/service/gpu:buffer_comparator",
+ "//xla/stream_executor/gpu:redzone_allocator",
+ "@local_config_cuda//cuda:cudnn_header",
+ ]),
+)
+
+xla_test(
+ name = "conv_algorithm_picker_test",
+ srcs = ["conv_algorithm_picker_test.cc"],
+ backends = [
+ "gpu_v100",
+ "gpu_amd_any",
+ ],
+ tags = [
+ "noasan",
+ "nomsan",
+ ],
+ deps = [
+ ":autotuner_util",
+ ":conv_algorithm_picker",
+ "//xla:debug_options_flags",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service:platform_util",
+ "//xla/service:tuple_simplifier",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/service/gpu/transforms:conv_rewriter",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:platform",
+ "//xla/tests:hlo_test_base",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "custom_kernel_fusion_autotuner",
+ srcs = ["custom_kernel_fusion_autotuner.cc"],
+ hdrs = ["custom_kernel_fusion_autotuner.h"],
+ tags = ["gpu"],
+ deps = [
+ ":autotuner_compile_util",
+ ":autotuner_util",
+ "//xla:autotuning_proto_cc",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:executable",
+ "//xla/service:hlo_pass",
+ "//xla/service:shaped_buffer",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:buffer_comparator",
+ "//xla/service/gpu:gpu_float_support",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu:split_k_gemm_rewriter",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/service/gpu/kernels:custom_kernel",
+ "//xla/service/gpu/kernels:custom_kernel_fusion",
+ "//xla/stream_executor",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:stream_executor_memory_allocator",
+ "//xla/stream_executor/gpu:redzone_allocator",
+ "//xla/tools:hlo_decomposer_lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:path",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "custom_kernel_fusion_autotuner_test",
+ srcs = ["custom_kernel_fusion_autotuner_test.cc"],
+ backends = [
+ "gpu",
+ ],
+ tags = ["no_rocm"],
+ deps = [
+ ":autotuner_util",
+ ":custom_kernel_fusion_autotuner",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass_pipeline",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:path",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+tf_proto_library(
+ name = "gpu_autotuning_proto",
+ srcs = ["gpu_autotuning.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ "//xla/service/gpu:backend_configs",
+ "//xla:xla_data_proto",
+ "//xla/service:hlo_proto",
+ "//xla:autotuning_proto",
+ ],
+)
+
+xla_cc_test(
+ name = "autotuner_util_test",
+ srcs = ["autotuner_util_test.cc"],
+ data = [
+ "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb",
+ "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb",
+ "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb",
+ ],
+ tags = [
+ "gpu",
+ "no_rocm",
+ ],
+ deps = [
+ ":autotuner_util",
+ "//xla:autotune_results_proto_cc",
+ "//xla:autotuning_proto_cc",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:dump",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:device_description_proto_cc",
+ "//xla/stream_executor:platform",
+ "//xla/stream_executor:platform_manager",
+ "//xla/stream_executor/host:host_platform",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tests:xla_internal_test_main",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/base:log_severity",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/log:scoped_mock_log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:env",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:path",
+ "@local_tsl//tsl/platform:protobuf",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc
new file mode 100644
index 0000000..9922ea2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc
@@ -0,0 +1,284 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/executable_run_options.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/compiler.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/gpu_executable_run_options.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/maybe_owning_device_memory.h"
+#include "xla/service/service_executable_run_options.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+std::vector<ExecutionInput> ExecutionInputsFromBuffers(
+ absl::Span<se::DeviceMemoryBase const> buffers,
+ absl::Span<Shape const> shapes) {
+ CHECK_EQ(buffers.size(), shapes.size());
+ std::vector<ExecutionInput> inputs;
+ for (int i = 0; i < buffers.size(); ++i) {
+ inputs.emplace_back(shapes.at(i));
+ // Our executable doesn't have input-output aliasing, so we can pass
+ // unowned input buffers.
+ inputs.back().SetUnownedBuffer(
+ /*index=*/{}, MaybeOwningDeviceMemory(/*unowned=*/buffers.at(i)));
+ }
+ return inputs;
+}
+
+} // namespace
+
+AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config,
+ Compiler* compiler,
+ se::StreamExecutor& stream_executor,
+ se::Stream& stream,
+ se::DeviceMemoryAllocator& allocator,
+ const DebugOptions& opts)
+ : config_(config),
+ compiler_(compiler),
+ stream_executor_(stream_executor),
+ stream_(stream),
+ allocator_(allocator),
+ opts_(opts) {
+ // Avoid dumping compilation steps.
+ opts_.set_xla_enable_dumping(false);
+ opts_.set_xla_gpu_dump_autotune_results_to("");
+ opts_.set_xla_gpu_load_autotune_results_from("");
+ opts_.set_xla_gpu_dump_llvmir(false);
+ opts_.set_xla_gpu_dump_autotune_logs_to("");
+ // Avoid using another thread pool.
+ opts_.set_xla_gpu_force_compilation_parallelism(1);
+ opts_.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
+ // Avoid using GPU graphs as we don't want to measure graph construction time.
+ opts_.clear_xla_gpu_enable_command_buffer();
+ opts_.set_xla_embed_ir_in_executable(false);
+ opts_.set_xla_gpu_kernel_cache_file("");
+}
+
+absl::StatusOr<std::optional<AutotunerCompileUtil::ProfilingOutput>>
+AutotunerCompileUtil::ProfileExecutable(
+ Executable* executable, se::Stream* stream,
+ absl::Span<se::DeviceMemoryBase const> input_buffers,
+ absl::Span<Shape const> input_shapes) {
+ {
+ std::vector<ExecutionInput> execution_inputs =
+ ExecutionInputsFromBuffers(input_buffers, input_shapes);
+    // Warm-up run: input and output buffers are reused while probing
+    // different configs, so GPU caches are in comparable states during the
+    // actual measurements.
+ absl::StatusOr<ExecutionOutput> execution_output =
+ Execute(*executable, std::move(execution_inputs));
+ if (!execution_output.ok()) {
+      // Treat register allocation errors gracefully. If compilation happens
+      // in the driver during execution, the error can surface here. Checking
+      // once, on the warm-up run, is enough.
+ if (execution_output.status().code() ==
+ absl::StatusCode::kResourceExhausted) {
+ return {std::nullopt};
+ }
+ return execution_output.status();
+ }
+
+ TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+ }
+ std::vector<ExecutionInput> execution_inputs =
+ ExecutionInputsFromBuffers(input_buffers, input_shapes);
+ ExecutionProfile profile;
+  // Flag that a warm-up run was executed so that GpuTimer can use the more
+  // accurate delay-kernel implementation.
+ profile.set_warmup_run_executed(true);
+ TF_ASSIGN_OR_RETURN(
+ ExecutionOutput execution_output,
+ Execute(*executable, std::move(execution_inputs), &profile));
+ return std::make_optional<ProfilingOutput>(
+ absl::Nanoseconds(profile.compute_time_ns()),
+ execution_output.Commit().ConsumeResult());
+}
+
+absl::StatusOr<std::unique_ptr<Executable>> AutotunerCompileUtil::Compile(
+ GenerateModuleFn extractor) {
+ absl::StatusOr<std::unique_ptr<HloModule>> new_hlo_module = extractor(opts_);
+ if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) {
+ // Incompatible value of split-k is an example of an expected failure.
+ return std::unique_ptr<Executable>();
+ } else if (!new_hlo_module.status().ok()) {
+ return new_hlo_module.status();
+ }
+
+ absl::StatusOr<std::unique_ptr<Executable>> out = compiler_->RunBackend(
+ std::move(*new_hlo_module), &stream_executor_,
+ Compiler::CompileOptions{&allocator_, /*thread_pool=*/nullptr,
+ /*layout_canonicalization_callback=*/{},
+ /*is_autotuning_compilation=*/true});
+ if (out.status().code() == absl::StatusCode::kResourceExhausted ||
+ out.status().code() == absl::StatusCode::kCancelled) {
+ // Being out of shared memory budget or registers is an expected failure.
+ // Cancelling upon register spilling is also an expected failure.
+ return std::unique_ptr<Executable>();
+ }
+ return out;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> AutotunerCompileUtil::ExtractModule(
+ GenerateModuleFn extractor) {
+ return extractor(opts_);
+}
+
+/*static*/ absl::StatusOr<std::optional<AutotunerCompileUtil>>
+AutotunerCompileUtil::Create(const AutotuneConfig& config,
+ const DebugOptions& opts) {
+ if (config.IsDeviceless()) {
+ return std::nullopt;
+ }
+ se::StreamExecutor* stream_exec = config.GetExecutor();
+ se::DeviceMemoryAllocator* allocator = config.GetAllocator();
+ TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream());
+ TF_ASSIGN_OR_RETURN(Compiler * compiler,
+ Compiler::GetForPlatform(stream_exec->GetPlatform()));
+ return AutotunerCompileUtil(config, compiler, *stream_exec, *stream,
+ *allocator, opts);
+}
+
+absl::StatusOr<ExecutionOutput> AutotunerCompileUtil::Execute(
+ Executable& executable, std::vector<ExecutionInput> arguments,
+ ExecutionProfile* profile) {
+ // Require exclusive GPU lock to prevent other runs during autotuning.
+ GpuExecutableRunOptions gpu_opts;
+ gpu_opts.set_requires_exclusive_lock_on_gpu();
+
+ ExecutableRunOptions run_options;
+ run_options.set_device_ordinal(stream_executor_.device_ordinal());
+ run_options.set_stream(&stream_);
+ run_options.set_allocator(&allocator_);
+ run_options.set_gpu_executable_run_options(&gpu_opts);
+ run_options.set_execution_profile(profile);
+ ServiceExecutableRunOptions service_run_options(run_options);
+ TF_ASSIGN_OR_RETURN(ExecutionOutput output,
+ executable.ExecuteAsyncOnStreamWrapper(
+ &service_run_options, std::move(arguments)));
+
+ return std::move(output);
+}
+
+absl::StatusOr<RedzoneBuffers> RedzoneBuffers::FromInstruction(
+ const HloInstruction& instruction, const AutotuneConfig& config,
+ const DebugOptions& debug_options, BuffersToCreate buffers_to_create) {
+ RedzoneBuffers buffers;
+
+ TF_ASSIGN_OR_RETURN(auto rz_allocator, AutotunerUtil::CreateRedzoneAllocator(
+ config, debug_options));
+ buffers.redzone_allocator_ =
+ std::make_unique<se::RedzoneAllocator>(std::move(rz_allocator));
+
+ int64_t rng_state = 0;
+
+ TF_RETURN_IF_ERROR(
+ buffers.CreateInputs(instruction, config, debug_options, rng_state));
+
+ if (buffers_to_create == BuffersToCreate::kAllInputsAllOutputs ||
+ buffers_to_create == BuffersToCreate::kAllInputsOutputsNoScratch) {
+ TF_RETURN_IF_ERROR(buffers.CreateOutputs(instruction, config, debug_options,
+ buffers_to_create, rng_state));
+ }
+
+ return buffers;
+}
+
+absl::Status RedzoneBuffers::CreateInputs(const HloInstruction& instruction,
+ const AutotuneConfig& config,
+ const DebugOptions& debug_options,
+ int64_t& rng_state) {
+ for (const auto* operand : instruction.operands()) {
+ TF_ASSIGN_OR_RETURN(
+ se::DeviceMemoryBase buf,
+ AutotunerUtil::CreateBuffer(*redzone_allocator_, operand->shape(),
+ config, rng_state));
+ input_buffers_.push_back(buf);
+ input_shapes_.push_back(operand->shape());
+ }
+ return absl::OkStatus();
+}
+
+absl::Status RedzoneBuffers::CreateOutputs(const HloInstruction& instruction,
+ const AutotuneConfig& config,
+ const DebugOptions& debug_options,
+ BuffersToCreate buffers_to_create,
+ int64_t& rng_state) {
+ if (!instruction.shape().IsTuple()) {
+ TF_ASSIGN_OR_RETURN(
+ se::DeviceMemoryBase buf,
+ AutotunerUtil::CreateBuffer(*redzone_allocator_, instruction.shape(),
+ config, rng_state));
+ output_buffers_.push_back(buf);
+ output_shape_ = instruction.shape();
+ return absl::OkStatus();
+ }
+
+  // The output is a tuple: create one buffer per subshape. For
+  // kAllInputsOutputsNoScratch, the trailing subshape is the scratch buffer
+  // and is skipped below.
+
+ auto current_shape_it = instruction.shape().tuple_shapes().begin();
+ auto end = instruction.shape().tuple_shapes().end();
+ end -= buffers_to_create == kAllInputsAllOutputs ? 0 : 1;
+
+ output_shape_ = std::distance(current_shape_it, end) == 1
+                      ? *current_shape_it
+ : ShapeUtil::MakeTupleShape(
+ std::vector<Shape>{current_shape_it, end});
+
+ for (; current_shape_it < end; current_shape_it++) {
+ if (current_shape_it->IsTuple()) {
+ return Unimplemented("Nested tuples are unsupported by RedzoneBuffers.");
+ }
+ TF_ASSIGN_OR_RETURN(
+ se::DeviceMemoryBase buf,
+ AutotunerUtil::CreateBuffer(*redzone_allocator_, *current_shape_it,
+ config, rng_state));
+ output_buffers_.push_back(buf);
+ }
+ return absl::OkStatus();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h
new file mode 100644
index 0000000..02b1bfa
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h
@@ -0,0 +1,177 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/compiler.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+
+namespace xla {
+namespace gpu {
+
+// Autotuning utils which require compiling fusions separately. Requires a
+// separate target, as runtime autotuning cannot perform compilation.
+class AutotunerCompileUtil {
+ public:
+  // The GenerateModuleFn must generate/extract a module using the provided
+  // debug options. Typically it should set the provided debug options on the
+  // extracted module before transforming it, so that the transforms can use
+  // them. In justified cases, it may override some of the provided debug
+  // options.
+ using GenerateModuleFn =
+ absl::AnyInvocable<absl::StatusOr<std::unique_ptr<HloModule>>(
+ const DebugOptions&)>;
+
+ // Generates a compile util for a platform associated with the `stream`.
+ //
+ // Returns an empty optional if the AutotuneConfig is deviceless, as
+ // autotuning is impossible in that case.
+ static absl::StatusOr<std::optional<AutotunerCompileUtil>> Create(
+ const AutotuneConfig& config, const DebugOptions& opts);
+
+ struct ProfilingOutput {
+ ProfilingOutput(absl::Duration duration, ScopedShapedBuffer&& buffer)
+ : duration(duration), output(std::move(buffer)) {}
+
+ absl::Duration duration;
+ ScopedShapedBuffer output;
+ };
+
+  // Profiles the given `executable` on `stream` with `input_buffers` of the
+  // given `input_shapes`: performs one warm-up run, then a timed run.
+  //
+  // Returns `std::nullopt` on expected failure (e.g. running out of
+  // registers), a bad `Status` on unexpected failure.
+ absl::StatusOr<std::optional<ProfilingOutput>> ProfileExecutable(
+ Executable* executable, se::Stream* stream,
+ absl::Span<se::DeviceMemoryBase const> input_buffers,
+ absl::Span<Shape const> input_shapes);
+
+ // Generic method to compile a generated module from `extractor` in isolation.
+ //
+ // Returns:
+ // - `nullptr` on *expected* failure
+ // - `Executable` if everything goes fine.
+ // - `Status` on *unexpected* failure.
+ absl::StatusOr<std::unique_ptr<Executable>> Compile(
+ GenerateModuleFn extractor);
+
+ // Generic method to extract an HLO using the debug options of the
+ // AutotunerCompileUtil.
+ //
+ // Typically we can use Compile directly.
+ absl::StatusOr<std::unique_ptr<HloModule>> ExtractModule(
+ GenerateModuleFn extractor);
+
+ private:
+ AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler,
+ se::StreamExecutor& stream_executor, se::Stream& stream,
+ se::DeviceMemoryAllocator& allocator,
+ const DebugOptions& opts);
+
+ absl::StatusOr<ExecutionOutput> Execute(Executable& executable,
+ std::vector<ExecutionInput> arguments,
+ ExecutionProfile* profile = nullptr);
+
+ AutotuneConfig config_;
+ Compiler* compiler_;
+ se::StreamExecutor& stream_executor_;
+ se::Stream& stream_;
+ se::DeviceMemoryAllocator& allocator_;
+ DebugOptions opts_;
+};
+
+// A RedZone allocator and a collection of buffers that store the inputs and
+// outputs of an HloInstruction. These are used when running the instruction
+// for autotuning.
+class RedzoneBuffers {
+ public:
+ enum BuffersToCreate {
+ // Create a buffer for all of the instruction's operands. The result shape
+ // is ignored.
+ kAllInputs = 0,
+ // Create a buffer for all of the instruction's operands and the entire
+ // result shape. If the result shape is a tuple, a separate buffer is
+ // created for each subshape.
+ kAllInputsAllOutputs = 1,
+ // Create a buffer for all of the instruction's operands and all of the
+ // subshapes of the result tuple, except for the last one. The last subshape
+ // is considered a scratch buffer and is assumed to be allocated elsewhere.
+ // If the result shape is not a tuple, this will create a buffer
+ // corresponding to the entire shape - equivalent to `kAllInputsAllOutputs`.
+ kAllInputsOutputsNoScratch = 2,
+ };
+ static absl::StatusOr<RedzoneBuffers> FromInstruction(
+ const HloInstruction& instruction, const AutotuneConfig& config,
+ const DebugOptions& debug_options, BuffersToCreate buffers_to_create);
+
+ const std::vector<se::DeviceMemoryBase>& input_buffers() const {
+ return input_buffers_;
+ }
+ const std::vector<Shape>& input_shapes() const { return input_shapes_; }
+ const std::vector<se::DeviceMemoryBase>& output_buffers() const {
+ return output_buffers_;
+ }
+ const Shape& output_shape() const { return output_shape_; }
+ se::RedzoneAllocator& RedzoneAllocator() const { return *redzone_allocator_; }
+
+ private:
+ absl::Status CreateInputs(const HloInstruction& instruction,
+ const AutotuneConfig& config,
+ const DebugOptions& debug_options,
+ int64_t& rng_state);
+
+ absl::Status CreateOutputs(const HloInstruction& instruction,
+ const AutotuneConfig& config,
+ const DebugOptions& debug_options,
+ BuffersToCreate buffers_to_create,
+ int64_t& rng_state);
+
+ std::unique_ptr<se::RedzoneAllocator> redzone_allocator_;
+ std::vector<se::DeviceMemoryBase> input_buffers_;
+ std::vector<Shape> input_shapes_;
+ std::vector<se::DeviceMemoryBase> output_buffers_;
+ Shape output_shape_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_
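
As a usage note, here is a minimal sketch of the intended call sequence for the utilities declared above: create an `AutotunerCompileUtil` from an `AutotuneConfig`, compile one candidate module, and profile it. The surrounding setup (config, debug options, module generator, input buffers) is assumed to exist, and the helper name `ProfileOneCandidate` is illustrative, not part of this change.

#include <memory>
#include <optional>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "xla/xla.pb.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {

// Hypothetical helper, shown only to illustrate how the pieces fit together.
absl::StatusOr<std::optional<absl::Duration>> ProfileOneCandidate(
    const AutotuneConfig& config, const DebugOptions& opts,
    AutotunerCompileUtil::GenerateModuleFn generator,
    absl::Span<se::DeviceMemoryBase const> input_buffers,
    absl::Span<Shape const> input_shapes) {
  // An empty optional means the config is deviceless; autotuning is impossible.
  TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> util,
                      AutotunerCompileUtil::Create(config, opts));
  if (!util.has_value()) {
    return {std::nullopt};
  }

  // A null executable signals an *expected* compilation failure (e.g. running
  // out of shared memory or registers); the candidate is simply skipped.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      util->Compile(std::move(generator)));
  if (executable == nullptr) {
    return {std::nullopt};
  }

  // Profile with a warm-up run followed by a timed run; nullopt again means an
  // expected failure at run time.
  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream());
  TF_ASSIGN_OR_RETURN(
      std::optional<AutotunerCompileUtil::ProfilingOutput> output,
      util->ProfileExecutable(executable.get(), stream, input_buffers,
                              input_shapes));
  if (!output.has_value()) {
    return {std::nullopt};
  }
  return std::make_optional(output->duration);
}

}  // namespace xla::gpu
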
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc
new file mode 100644
index 0000000..a8b9594
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc
@@ -0,0 +1,196 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using AutotunerCompileUtilTest = HloTestBase;
+
+TEST_F(AutotunerCompileUtilTest, VerifyOutputNotATuple) {
+ constexpr absl::string_view kHlo = R"(
+HloModule hlo
+ENTRY main {
+ p0 = f32[2,2] parameter(0)
+ p1 = f32[4,4] parameter(1)
+ p2 = f32[6,6] parameter(2)
+ ROOT root = f32[1,2,3] custom-call(p0, p1, p2), custom_call_target="fake"
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+ PlatformUtil::GetStreamExecutors(platform));
+
+ AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
+ GetDebugOptionsForTest()};
+
+ auto& root = *module->entry_computation()->root_instruction();
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputs));
+
+ EXPECT_EQ(rzb.input_shapes().size(), 3);
+ EXPECT_EQ(rzb.input_buffers().size(), 3);
+ EXPECT_EQ(rzb.output_buffers().size(), 0);
+ EXPECT_NE(rzb.output_shape(), root.shape());
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputsAllOutputs));
+
+ EXPECT_EQ(rzb2.input_shapes().size(), 3);
+ EXPECT_EQ(rzb2.input_buffers().size(), 3);
+ EXPECT_EQ(rzb2.output_buffers().size(), 1);
+ EXPECT_EQ(rzb2.output_shape(), root.shape());
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+ EXPECT_EQ(rzb3.input_shapes().size(), 3);
+ EXPECT_EQ(rzb3.input_buffers().size(), 3);
+ EXPECT_EQ(rzb3.output_buffers().size(), 1);
+ EXPECT_EQ(rzb3.output_shape(), root.shape());
+}
+
+TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleOneElement) {
+ constexpr absl::string_view kHlo = R"(
+HloModule hlo
+ENTRY main {
+ p0 = f32[2,2] parameter(0)
+ p1 = f32[4,4] parameter(1)
+ p2 = f32[6,6] parameter(2)
+ ROOT root = (f32[1,2,3]) custom-call(p0, p1, p2), custom_call_target="fake"
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+ PlatformUtil::GetStreamExecutors(platform));
+
+ AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
+ GetDebugOptionsForTest()};
+
+ auto& root = *module->entry_computation()->root_instruction();
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputs));
+
+ EXPECT_EQ(rzb.input_shapes().size(), 3);
+ EXPECT_EQ(rzb.input_buffers().size(), 3);
+ EXPECT_EQ(rzb.output_buffers().size(), 0);
+ EXPECT_NE(rzb.output_shape(), root.shape());
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputsAllOutputs));
+
+ EXPECT_EQ(rzb2.input_shapes().size(), 3);
+ EXPECT_EQ(rzb2.input_buffers().size(), 3);
+ EXPECT_EQ(rzb2.output_buffers().size(), 1);
+ EXPECT_FALSE(rzb2.output_shape().IsTuple());
+ EXPECT_EQ(rzb2.output_shape(), root.shape().tuple_shapes(0));
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+ EXPECT_EQ(rzb3.input_shapes().size(), 3);
+ EXPECT_EQ(rzb3.input_buffers().size(), 3);
+ EXPECT_EQ(rzb3.output_buffers().size(), 0);
+}
+
+TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleTwoElements) {
+ constexpr absl::string_view kHlo = R"(
+HloModule hlo
+ENTRY main {
+ p0 = f32[2,2] parameter(0)
+ p1 = f32[4,4] parameter(1)
+ p2 = f32[6,6] parameter(2)
+ ROOT root = (f32[1,2,3], u8[1,2]) custom-call(p0, p1, p2), custom_call_target="fake"
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+ PlatformUtil::GetStreamExecutors(platform));
+
+ AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
+ GetDebugOptionsForTest()};
+
+ auto& root = *module->entry_computation()->root_instruction();
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputs));
+
+ EXPECT_EQ(rzb.input_shapes().size(), 3);
+ EXPECT_EQ(rzb.input_buffers().size(), 3);
+ EXPECT_EQ(rzb.output_buffers().size(), 0);
+ EXPECT_NE(rzb.output_shape(), root.shape());
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputsAllOutputs));
+
+ EXPECT_EQ(rzb2.input_shapes().size(), 3);
+ EXPECT_EQ(rzb2.input_buffers().size(), 3);
+ EXPECT_EQ(rzb2.output_buffers().size(), 2);
+ EXPECT_TRUE(rzb2.output_shape().IsTuple());
+ EXPECT_EQ(rzb2.output_shape(), root.shape());
+
+ TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
+ RedzoneBuffers::FromInstruction(
+ root, autotune_config, GetDebugOptionsForTest(),
+ RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+ EXPECT_EQ(rzb3.input_shapes().size(), 3);
+ EXPECT_EQ(rzb3.input_buffers().size(), 3);
+ EXPECT_EQ(rzb3.output_buffers().size(), 1);
+ EXPECT_FALSE(rzb3.output_shape().IsTuple());
+ EXPECT_EQ(rzb3.output_shape(), root.shape().tuple_shapes(0));
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc
new file mode 100644
index 0000000..79bb744
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc
@@ -0,0 +1,551 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstdint>
+#include <limits>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "absl/base/const_init.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SHA256.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/dump.h"
+#include "xla/service/gpu/gpu_asm_opts_util.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "tsl/platform/base64.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/protobuf.h" // IWYU pragma: keep
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Bump this version whenever you change the structure of the results.
+// LINT.IfChange(version)
+constexpr int kVersion = 3;
+// LINT.ThenChange()
+
+} // namespace
+
+using AutotuneCacheMap = absl::flat_hash_map<AutotuneCacheKey, AutotuneResult>;
+
+static absl::Mutex autotune_cache_mu(absl::kConstInit);
+static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) =
+ *new AutotuneCacheMap();
+
+absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s) {
+ llvm::SHA256 sha256;
+ sha256.update(llvm::StringRef(s));
+ std::array<uint8_t, 32> hash = sha256.final();
+ // C++ strict aliasing rules allow reinterpret casting to (const) char*.
+ absl::string_view hash_view(reinterpret_cast<const char*>(hash.data()),
+ hash.size());
+ std::string base64_encoded_hash;
+ TF_RETURN_IF_ERROR(tsl::Base64Encode(hash_view, &base64_encoded_hash));
+ return base64_encoded_hash;
+}
+
+namespace {
+
+// Returns the cache file path for the given key:
+// <cache_dir>/<base64-encoded SHA-256 of key.ToString()>.textproto
+absl::StatusOr<std::string> GetCacheFilePath(absl::string_view cache_dir,
+ const AutotuneCacheKey& key) {
+ if (cache_dir.empty()) {
+ return absl::InvalidArgumentError("autotune_cache_dir should not be empty");
+ }
+
+ TF_ASSIGN_OR_RETURN(std::string key_hash,
+ GetBase64EncodedSha256Hash(key.ToString()));
+ return tsl::io::JoinPath(cache_dir, absl::StrCat(key_hash, ".textproto"));
+}
+
+struct ResultAndInserted {
+ // The result that ended up in the cache. This is the existing result if
+ // inserted is false, and the new result if inserted is true.
+ //
+ // We return a value, not a pointer, for thread safety reasons.
+ AutotuneResult result;
+ // Did we insert the given result into the cache?
+ bool inserted;
+};
+
+ResultAndInserted AddResultToInMemoryCache(const AutotuneCacheKey& key,
+ AutotuneResult result)
+ ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+ absl::MutexLock lock(&autotune_cache_mu);
+ auto [it, inserted] = autotune_cache.emplace(key, std::move(result));
+ return {it->second, inserted};
+}
+
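+// Serializes `result` as a textproto and writes it to
+// <cache_dir>/<hashed key>.textproto. No-op if `cache_dir` is empty.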
+absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
+ AutotuneResult result,
+ std::string_view cache_dir)
+ ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+ if (cache_dir.empty()) {
+ return absl::OkStatus();
+ }
+
+ tsl::Env* default_env = tsl::Env::Default();
+ TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(cache_dir), default_env));
+
+ TF_ASSIGN_OR_RETURN(const std::string file_path,
+ GetCacheFilePath(cache_dir, key));
+
+ VLOG(1) << "Writing autotune result to file: " << file_path;
+
+ std::string result_str;
+ if (!tsl::protobuf::TextFormat::PrintToString(result, &result_str)) {
+ return absl::InternalError("Failed to serialize autotune result.");
+ }
+
+  // Rename trick: Write to a temporary file, then rename it to the final path.
+  // This prevents interleaved writes when multiple processes target the same
+  // file and keeps readers from ever seeing an incomplete file. (The rename
+  // may not be atomic on all file systems.)
+ std::string temp_file_path = tsl::io::GetTempFilename(".textproto");
+ TF_RETURN_IF_ERROR(
+ tsl::WriteStringToFile(default_env, temp_file_path, result_str));
+ return default_env->RenameFile(temp_file_path, file_path);
+}
+
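+// Adds the result to the in-memory cache and, if that insertion actually took
+// place (i.e. the key was not already present), persists it to the file-based
+// cache as well (when a cache directory is configured).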
+absl::StatusOr<ResultAndInserted> AddResultToCaches(const AutotuneCacheKey& key,
+ AutotuneResult result,
+ std::string_view cache_dir)
+ ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+ ResultAndInserted result_and_inserted = AddResultToInMemoryCache(key, result);
+ if (result_and_inserted.inserted) {
+ TF_RETURN_IF_ERROR(AddResultToFileBasedCacheIfEnabled(
+ key, result_and_inserted.result, cache_dir));
+ }
+ return result_and_inserted;
+}
+
+std::optional<AutotuneResult> TryToFindInInMemoryCache(
+ const AutotuneCacheKey& key) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+ absl::MutexLock lock(&autotune_cache_mu);
+ auto it = autotune_cache.find(key);
+ if (it == autotune_cache.end()) {
+ return std::nullopt;
+ }
+ return it->second;
+}
+
+absl::StatusOr<std::optional<AutotuneResult>>
+TryToFindInFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
+ absl::string_view cache_dir)
+ ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+ if (cache_dir.empty()) {
+ return std::nullopt;
+ }
+
+ TF_ASSIGN_OR_RETURN(const std::string file_path,
+ GetCacheFilePath(cache_dir, key));
+ if (!tsl::Env::Default()->FileExists(file_path).ok()) {
+ VLOG(1) << "Autotune result file not found: " << file_path;
+ return std::nullopt;
+ }
+
+ VLOG(1) << "Autotune result file found: " << file_path;
+ std::string autotune_result_str;
+ TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), file_path,
+ &autotune_result_str));
+ AutotuneResult result;
+ if (!tsl::protobuf::TextFormat::ParseFromString(autotune_result_str,
+ &result)) {
+ return absl::InvalidArgumentError("Failed to parse autotune result.");
+ }
+ return result;
+}
+
+// Sort the results so that they're deterministic.
+void SortAutotuneResults(AutotuneResults* results) {
+ std::sort(results->mutable_results()->pointer_begin(),
+ results->mutable_results()->pointer_end(),
+ [](const auto* a, const auto* b) {
+ return std::make_pair(absl::string_view(a->device()),
+ absl::string_view(a->hlo())) <
+ std::make_pair(absl::string_view(b->device()),
+ absl::string_view(b->hlo()));
+ });
+}
+
+} // namespace
+
+// Serializes `results` to a string: textproto if `as_textproto` is true,
+// binary proto otherwise.
+absl::StatusOr<std::string> AutotuneResultsToString(
+ const AutotuneResults& results, bool as_textproto) {
+ if (as_textproto) {
+ std::string textproto;
+ if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) {
+ return textproto;
+ } else {
+ return Internal("Failed to serialize autotune results.");
+ }
+ }
+ return results.SerializeAsString();
+}
+
+namespace {
+// Serialize a single entry to `results`.
+void SerializeAutotuneEntry(AutotuneResults* results, const AutotuneCacheKey& k,
+ const AutotuneResult* res) {
+ auto& entry = *results->add_results();
+ entry.set_device(std::string(k.GetModelStr()));
+ entry.set_hlo(std::string(k.GetHlo()));
+ *entry.mutable_result() = *res;
+}
+} // namespace
+
+/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults(
+ AutotuneResults* results) {
+ absl::MutexLock lock(&autotune_cache_mu);
+ for (const auto& [k, result] : autotune_cache) {
+ SerializeAutotuneEntry(results, k, &result);
+ }
+
+ results->set_version(kVersion);
+ SortAutotuneResults(results);
+
+ return absl::OkStatus();
+}
+
+/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
+ const AutotuneResults& results) {
+ absl::MutexLock lock(&autotune_cache_mu);
+ for (const AutotuneResults::Entry& result : results.results()) {
+ if (auto [it, inserted] = autotune_cache.emplace(
+ AutotuneCacheKey(result.device(), result.hlo()), result.result());
+ !inserted) {
+ return absl::InternalError(absl::StrCat(
+ "Duplicate autotuning result for ", it->first.ToString()));
+ }
+ }
+ return absl::OkStatus();
+}
+
+/*static*/ void AutotunerUtil::ClearAutotuneResults() {
+ absl::MutexLock lock(&autotune_cache_mu);
+ autotune_cache.clear();
+}
+
+/*static*/ bool AutotunerUtil::ResultCacheIsEmpty() {
+ absl::MutexLock lock(&autotune_cache_mu);
+ return autotune_cache.empty();
+}
+
+/*static*/ absl::StatusOr<se::DeviceMemoryBase> AutotunerUtil::CreateBuffer(
+ se::RedzoneAllocator& allocator, const Shape& shape,
+ const AutotuneConfig& config, int64_t& rng_state) {
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
+ allocator.AllocateBytes(ShapeUtil::ByteSizeOf(shape)));
+ if (config.should_init_buffers()) {
+ InitializeBuffer(allocator.stream(), shape.element_type(), &rng_state,
+ buffer);
+ }
+ return buffer;
+}
+
+namespace {
+std::string ToCanonicalString(const HloInstruction* instr) {
+ auto options = HloPrintOptions::Canonical();
+ if (instr->opcode() != HloOpcode::kFusion) {
+ options.set_print_backend_config(true);
+ options.set_sort_backend_config(true);
+ return instr->ToString(options);
+ }
+ options.set_print_subcomputation_mode(
+ HloPrintOptions::PrintSubcomputationMode::kOff);
+ options.set_print_infeed_outfeed_config(false);
+ options.set_print_only_essential_constants(true);
+ options.set_print_operand_shape(true);
+ options.set_print_ids(false);
+ options.set_canonicalize_computations(true);
+
+ // TODO(b/266210099): This is unsound. We should probably do the fingerprint
+ // of the HLO computation proto instead.
+ return instr->called_computations()[0]->ToString(options);
+}
+
+} // namespace
+
+AutotuneCacheKey::AutotuneCacheKey(absl::string_view model_str,
+ const HloInstruction& instr)
+ : AutotuneCacheKey(model_str, ToCanonicalString(&instr)) {}
+
+/*static*/ std::string AutotuneCacheKey::DeviceDescriptionToCacheKey(
+ const se::DeviceDescription& device_description) {
+ std::string compute_capability;
+ if (auto* ccc = std::get_if<se::CudaComputeCapability>(
+ &device_description.gpu_compute_capability())) {
+ compute_capability = absl::StrCat("CUDA: ", ccc->major, ".", ccc->minor);
+ } else {
+ auto* rcc = std::get_if<se::RocmComputeCapability>(
+ &device_description.gpu_compute_capability());
+ CHECK(rcc != nullptr) << "Unknown compute capability type";
+ compute_capability = absl::StrCat("ROCM: ", rcc->gfx_version());
+ }
+
+ // The string below should include only as much information as is needed to
+ // make it a valid key. Information that should not be included is:
+ // - specs that are directly derivable from the compute capability, e.g.
+ // shared memory size. For NVIDIA GPUs, you can see what is derivable from
+ // the SM version here:
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+ // - specs that are irrelevant for autotuning. E.g. the total available memory
+ // on a device is not relevant, because by itself, it does not affect the
+ // performance of single kernels.
+ //
+ // See b/344573710 for some discussion.
+
+ double memory_bandwidth = device_description.memory_bandwidth() / 1e9;
+ // Round the memory bandwidth to make the final string nicer to read.
+ // This will also cause minute differences in bandwidth to yield the same
+ // cache key, but that's fine, since the difference is inconsequential.
+ memory_bandwidth = std::round(memory_bandwidth);
+
+ constexpr double kBytesPerMegabyte = 1 << 20;
+ double l2_cache_size = device_description.l2_cache_size() / kBytesPerMegabyte;
+
+ return absl::StrCat(compute_capability,
+ ", Cores: ", device_description.core_count(),
+ ", GPU clock: ", device_description.clock_rate_ghz(),
+ " GHz, Memory bandwidth: ", memory_bandwidth,
+ " GB/s, L2 cache: ", l2_cache_size, " MB");
+}
+
+namespace {
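+// Looks the key up in the in-memory cache first; on a miss, consults the
+// file-based cache (if enabled) and promotes any hit into the in-memory cache.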
+absl::StatusOr<std::optional<AutotuneResult>> TryFindInCache(
+ const AutotuneCacheKey& key, absl::string_view cache_dir)
+ ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+ std::optional<AutotuneResult> opt_result = TryToFindInInMemoryCache(key);
+ if (opt_result.has_value()) {
+ if (VLOG_IS_ON(1)) {
+ LOG(INFO) << "In-memory autotune cache hit";
+ } else if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "In-memory autotune cache hit: key = " << key.ToString();
+ }
+ return opt_result;
+ }
+
+ TF_ASSIGN_OR_RETURN(opt_result,
+ TryToFindInFileBasedCacheIfEnabled(key, cache_dir));
+ if (opt_result.has_value()) {
+ AddResultToInMemoryCache(key, opt_result.value());
+
+ if (VLOG_IS_ON(1)) {
+ LOG(INFO) << "File-based autotune cache hit";
+ } else if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "File-based autotune cache hit: key = " << key.ToString();
+ }
+ return opt_result;
+ }
+
+ if (VLOG_IS_ON(1)) {
+ LOG(INFO) << "Autotune cache miss";
+ } else if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "Autotune cache miss: key = " << key.ToString();
+ }
+ return std::nullopt;
+}
+} // namespace
+
+/*static*/ AutotuneCacheKey AutotunerUtil::GetKey(
+ const HloInstruction* instr, const AutotuneConfig& config) {
+ return AutotuneCacheKey(config.GetModelStr(), *instr);
+}
+
+/*static*/ absl::StatusOr<bool> AutotunerUtil::IsInCache(
+ const AutotuneCacheKey& key, const AutotuneConfig& config) {
+ TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
+ TryFindInCache(key, config.autotune_cache_dir()));
+ return opt_res.has_value();
+}
+
+/*static*/ absl::StatusOr<bool> AutotunerUtil::AddResult(
+ const AutotuneCacheKey& key, AutotuneResult result,
+ const AutotuneConfig& config) {
+ TF_ASSIGN_OR_RETURN(
+ ResultAndInserted result_and_inserted,
+ AddResultToCaches(key, std::move(result), config.autotune_cache_dir()));
+ return result_and_inserted.inserted;
+}
+
+/*static*/ absl::StatusOr<AutotuneResult> AutotunerUtil::Autotune(
+ const HloInstruction* instr, const AutotuneConfig& config,
+ const AutotuneNoCacheFn& autotune_fn) {
+ const AutotuneCacheKey key = GetKey(instr, config);
+ TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
+ TryFindInCache(key, config.autotune_cache_dir()));
+ if (opt_res.has_value()) {
+ return opt_res.value();
+ }
+
+ // Cache miss.
+ if (config.should_require_complete_aot_autotune_results()) {
+ return NotFound(
+ "Complete XLA AOT autotuning results are required, but no AOT result "
+ "was found for key: %s",
+ key.ToString());
+ }
+
+ TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn());
+
+ TF_ASSIGN_OR_RETURN(ResultAndInserted result_and_inserted,
+ AddResultToCaches(key, std::move(autotune_result),
+ config.autotune_cache_dir()));
+ return result_and_inserted.result;
+}
+
+namespace {
+
+bool IsTextProtoPath(absl::string_view file_path) {
+ return absl::EndsWith(file_path, ".txt") ||
+ absl::EndsWith(file_path, ".textproto") ||
+ absl::EndsWith(file_path, ".prototxt") ||
+ absl::EndsWith(file_path, ".pbtxt");
+}
+
+} // anonymous namespace
+
+/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
+ absl::string_view data, bool as_textproto) {
+ AutotuneResults results;
+ // The cast here is necessary for MacOS builds.
+ bool parse_success =
+ as_textproto ? tsl::protobuf::TextFormat::ParseFromString(
+ std::string(data), &results) // NOLINT
+ : results.ParseFromString(std::string(data)); // NOLINT
+ if (!parse_success) {
+ return absl::InvalidArgumentError(
+ "Failed to parse autotune results string.");
+ }
+ if (results.version() != kVersion) {
+ return absl::InvalidArgumentError(absl::StrFormat(
+ "Version mismatch in autotune results. Expected %d but was %d",
+ kVersion, results.version()));
+ }
+
+ TF_RETURN_IF_ERROR(LoadAutotuneResults(results));
+ return absl::OkStatus();
+}
+
+/*static*/ absl::StatusOr<std::string> AutotunerUtil::SerializeAutotuneResults(
+ bool as_textproto) {
+ AutotuneResults results;
+ TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
+ return AutotuneResultsToString(results, as_textproto);
+}
+
+/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
+ const AutotuneResults& results, absl::string_view file_path) {
+ TF_RET_CHECK(!file_path.empty());
+ TF_RET_CHECK(results.version() > 0)
+ << "Did you call SerializeAutotuneResults to get this AutotuneResults?";
+
+ std::string resolved_path;
+ if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
+ return FailedPrecondition("File path can not be resolved: %s", file_path);
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::string autotune_results_str,
+ AutotuneResultsToString(results, IsTextProtoPath(resolved_path)));
+ TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), resolved_path,
+ autotune_results_str));
+ LOG(INFO) << "Autotune results serialized to file: " << resolved_path;
+
+ return absl::OkStatus();
+}
+
+/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
+ absl::string_view file_path) {
+ AutotuneResults results;
+ TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
+ return SerializeAutotuneResultsToFile(results, file_path);
+}
+
+/*static*/ absl::Status AutotunerUtil::LoadAutotuneResultsFromFile(
+ absl::string_view file_path) {
+ TF_RET_CHECK(!file_path.empty());
+
+ std::string resolved_path;
+ if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
+ return FailedPrecondition("File path can not be resolved: %s", file_path);
+ }
+
+ if (!tsl::Env::Default()->FileExists(resolved_path).ok()) {
+ return FailedPrecondition("Autotune results file does not exist: %s",
+ resolved_path);
+ }
+ std::string autotune_results_str;
+ TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), resolved_path,
+ &autotune_results_str));
+
+ TF_RETURN_IF_ERROR(LoadAutotuneResults(autotune_results_str,
+ IsTextProtoPath(resolved_path)));
+
+ LOG(INFO) << "Autotune results loaded from file: " << resolved_path;
+
+ return absl::OkStatus();
+}
+
+/*static*/ absl::StatusOr<se::RedzoneAllocator>
+AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config,
+ const DebugOptions& opts) {
+ TF_ASSIGN_OR_RETURN(se::Stream * stream, config.GetStream());
+ return se::RedzoneAllocator(
+ stream, config.GetAllocator(), PtxOptsFromDebugOptions(opts),
+ /*memory_limit=*/std::numeric_limits<int64_t>::max(),
+ /*redzone_size=*/config.should_check_correctness()
+ ? opts.xla_gpu_redzone_padding_bytes()
+ : 0);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
new file mode 100644
index 0000000..6d5c321
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
@@ -0,0 +1,334 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/xla.pb.h"
+
+namespace xla {
+namespace gpu {
+
+struct DeviceConfig {
+ se::StreamExecutor* stream_exec; // never null
+
+ // If the `allocator` parameter is not null, we will use it to allocate temp
+ // memory while timing the various convolution algorithms. If it's null,
+ // we'll use the default allocator on the StreamExecutor.
+ se::DeviceMemoryAllocator* allocator = nullptr; // may be null
+};
+
+struct DevicelessConfig {
+ // The device description of the target device.
+ se::DeviceDescription device_description;
+};
+
+class AutotuneCacheKey {
+ public:
+ AutotuneCacheKey(const se::DeviceDescription& device_description,
+ const HloInstruction& instruction)
+ : AutotuneCacheKey(DeviceDescriptionToCacheKey(device_description),
+ instruction.ToString()) {}
+
+ AutotuneCacheKey(absl::string_view model_str,
+ const HloInstruction& instruction);
+
+ explicit AutotuneCacheKey(absl::string_view model_str,
+ absl::string_view hlo_canonical)
+ : model_str_(model_str), hlo_canonical_(hlo_canonical) {}
+
+ absl::string_view GetModelStr() const { return model_str_; }
+
+ absl::string_view GetHlo() const { return hlo_canonical_; }
+
+ template <typename H>
+ friend H AbslHashValue(H h, const AutotuneCacheKey& w) {
+ return H::combine(std::move(h), w.model_str_, w.hlo_canonical_);
+ }
+
+ bool operator==(const AutotuneCacheKey& w) const {
+ return model_str_ == w.model_str_ && hlo_canonical_ == w.hlo_canonical_;
+ }
+
+ std::string ToString() const {
+ return absl::StrFormat("<key model='%s', hlo='%s'>", model_str_,
+ hlo_canonical_);
+ }
+
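+  // Renders a device description as a stable cache-key string, e.g. (taken
+  // from autotuner_util_test.cc):
+  //   "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz,
+  //    Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"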
+ static std::string DeviceDescriptionToCacheKey(
+ const se::DeviceDescription& device_description);
+
+ private:
+ std::string model_str_;
+ std::string hlo_canonical_;
+};
+
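+// Bundles the target device (or a deviceless description of it) with the
+// autotuning knobs taken from DebugOptions.
+//
+// Illustrative construction (mirrors the tests in this directory; `executor`
+// is an se::StreamExecutor* and `debug_options` a DebugOptions value):
+//
+//   AutotuneConfig config(DeviceConfig{executor, /*allocator=*/nullptr},
+//                         debug_options);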
+class AutotuneConfig {
+ public:
+ bool should_init_buffers() const { return autotune_level_ >= 2; }
+ bool should_reinit_output_buffer() const { return autotune_level_ >= 3; }
+ bool should_check_correctness() const { return autotune_level_ >= 4; }
+ bool should_skip_wrong_results() const { return autotune_level_ >= 5; }
+ bool should_crash_on_check_failure() const {
+ return should_crash_on_check_failure_;
+ }
+ bool should_require_complete_aot_autotune_results() const {
+ return require_complete_aot_autotune_results_;
+ }
+  // An empty string disables the file-based cache; the in-memory cache is
+  // always used.
+ const std::string& autotune_cache_dir() const { return autotune_cache_dir_; }
+
+ AutotuneConfig(const AutotuneConfig& right)
+ : config_(right.config_),
+ autotune_level_(right.autotune_level_),
+ should_crash_on_check_failure_(right.should_crash_on_check_failure_),
+ exhaustive_tiling_search_(right.exhaustive_tiling_search_),
+ require_complete_aot_autotune_results_(
+ right.require_complete_aot_autotune_results_),
+ autotune_cache_dir_(right.autotune_cache_dir_) {}
+
+ AutotuneConfig(const std::variant<DeviceConfig, DevicelessConfig>& config,
+ const DebugOptions& debug_options)
+ : config_(config),
+ autotune_level_(debug_options.xla_gpu_autotune_level()),
+ should_crash_on_check_failure_(
+ debug_options.xla_gpu_crash_on_verification_failures()),
+ exhaustive_tiling_search_(
+ debug_options.xla_gpu_exhaustive_tiling_search()),
+ require_complete_aot_autotune_results_(
+ debug_options.xla_gpu_require_complete_aot_autotune_results()),
+ autotune_cache_dir_(
+ debug_options.xla_gpu_per_fusion_autotune_cache_dir()) {}
+
+ std::string GetModelStr() const {
+ if (auto deviceless_config = std::get_if<DevicelessConfig>(&config_)) {
+ return AutotuneCacheKey::DeviceDescriptionToCacheKey(
+ deviceless_config->device_description);
+ }
+
+ const auto& device_config = std::get<DeviceConfig>(config_);
+ return AutotuneCacheKey::DeviceDescriptionToCacheKey(
+ device_config.stream_exec->GetDeviceDescription());
+ }
+
+ se::StreamExecutor* GetExecutor() const {
+ CHECK(std::holds_alternative<DeviceConfig>(config_));
+ return std::get<DeviceConfig>(config_).stream_exec;
+ }
+
+ se::DeviceMemoryAllocator* GetAllocator() const {
+ CHECK(std::holds_alternative<DeviceConfig>(config_));
+ auto& cf = std::get<DeviceConfig>(config_);
+ if (cf.allocator != nullptr) {
+ return cf.allocator;
+ }
+ if (allocator_ == nullptr) {
+ allocator_ =
+ std::make_unique<se::StreamExecutorMemoryAllocator>(GetExecutor());
+ }
+ return allocator_.get();
+ }
+
+ absl::StatusOr<se::Stream*> GetStream() const {
+ CHECK(std::holds_alternative<DeviceConfig>(config_));
+ return GetAllocator()->GetStream(GetExecutor()->device_ordinal());
+ }
+
+ const se::GpuComputeCapability& GetGpuComputeCapability() const {
+ if (auto c = std::get_if<DeviceConfig>(&config_)) {
+ return c->stream_exec->GetDeviceDescription().gpu_compute_capability();
+ }
+ return std::get<DevicelessConfig>(config_)
+ .device_description.gpu_compute_capability();
+ }
+
+ bool IsDeviceless() const {
+ return std::holds_alternative<DevicelessConfig>(config_);
+ }
+
+ bool ExhaustiveTilingSearch() const { return exhaustive_tiling_search_; }
+
+ private:
+ std::variant<DeviceConfig, DevicelessConfig> config_;
+ int32_t autotune_level_;
+ bool should_crash_on_check_failure_;
+ bool exhaustive_tiling_search_;
+ bool require_complete_aot_autotune_results_;
+ mutable std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
+ std::string autotune_cache_dir_;
+};
+
+using AutotuneNoCacheFn = std::function<absl::StatusOr<AutotuneResult>()>;
+
+struct AutotunerUtil {
+  // Creates a buffer of the given shape using the redzone allocator and, if
+  // the config requests it, initializes the buffer from the given rng state.
+ static absl::StatusOr<se::DeviceMemoryBase> CreateBuffer(
+ se::RedzoneAllocator& allocator, const Shape& shape,
+ const AutotuneConfig& config, int64_t& rng_state);
+
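+  // Returns the result for `instr` from the autotune cache (in-memory or
+  // file-based) if present; otherwise runs `autotune_fn`, stores its result
+  // in the cache(s), and returns it.
+  //
+  // Illustrative call (MeasureCandidates is a hypothetical helper standing in
+  // for the caller's measurement logic):
+  //
+  //   TF_ASSIGN_OR_RETURN(AutotuneResult best,
+  //                       AutotunerUtil::Autotune(instr, config, [&] {
+  //                         return MeasureCandidates(instr);
+  //                       }));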
+ static absl::StatusOr<AutotuneResult> Autotune(
+ const HloInstruction* instr, const AutotuneConfig& config,
+ const AutotuneNoCacheFn& autotune_fn);
+
+ // Returns the same cache key that would be used inside Autotune().
+ //
+  // Normally, we don't have to use this low-level method.
+ static AutotuneCacheKey GetKey(const HloInstruction* instr,
+ const AutotuneConfig& config);
+
+ // Checks if the key is in the autotune cache.
+ //
+  // Normally, we don't have to use this low-level method.
+ static absl::StatusOr<bool> IsInCache(const AutotuneCacheKey& key,
+ const AutotuneConfig& config);
+
+ // Adds the result to the autotune cache.
+ //
+ // Returns true if the entry is inserted.
+ //
+  // Normally, we don't have to use this low-level method.
+ static absl::StatusOr<bool> AddResult(const AutotuneCacheKey& key,
+ AutotuneResult result,
+ const AutotuneConfig& config);
+
+ // Creates a RedzoneAllocator from a given config.
+ static absl::StatusOr<se::RedzoneAllocator> CreateRedzoneAllocator(
+ const AutotuneConfig& config, const DebugOptions& opts);
+
+ // Functions to save/load XLA's autotuning results.
+ //
+ // This is used for ahead-of-time autotuning. Specifically:
+ //
+ // When XLA calls cublas (for matmuls, aka "gemm" or "dot") or cudnn (for
+ // convolutions), it usually has to choose an "algorithm" for the particular
+ // dot/conv. XLA queries cublas/cudnn for a list of candidate algorithms.
+ // Then it runs all of them and picks the fastest one. This is what we call
+ // "autotuning". It happens in GemmAlgorithmPicker and GpuConvAlgorithmPicker.
+ //
+ // Autotuning is necessary to get good performance for dot/conv. But it also
+ // has some disadvantages.
+ //
+ // - Because it relies on timing data, it is fundamentally nondeterministic.
+ // But even if two algorithms have similar runtimes, our choice of
+ // algorithm may be visible to the user: Different algorithms can have
+ // different numerics, and sometimes they can even have different bugs!
+ //
+  // - Trying all the candidate algorithms can be slow, especially when some
+  //   of the candidates are "very bad" and run far more slowly than the
+  //   optimal candidate. This slows down compilation.
+ //
+ // To address the disadvantages above, we allow users to save/restore the
+ // autotuning choices that XLA has made, using the functions below.
+ //
+  // Loading autotuning results does not erase existing autotuning choices, but
+  // if a loaded entry conflicts with one that is already in the in-memory
+  // cache, loading fails with an error (see LoadAutotuneResults below).
+ //
+ // Note that even if you call LoadAutotuneResults(), if XLA encounters a
+ // dot/conv that is *not* covered by the loaded data, it will go ahead and
+ // autotune it like normal. In other words, the behavior of XLA should be
+ // identical with or without ahead-of-time autotuning, modulo nondeterminism.
+ //
+ // This is important if you want to be able to use the same autotuning file
+ // with different versions of XLA, because as XLA changes, exactly which
+ // dots/convs it wants to run can also change. For example, XLA might change
+ // the conv padding heuristics it uses, and we don't want that to mean that
+ // all users of ahead-of-time autotuning are broken.
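+  //
+  // Illustrative round trip (the file path is hypothetical):
+  //
+  //   // In the process that performed autotuning:
+  //   TF_RETURN_IF_ERROR(AutotunerUtil::SerializeAutotuneResultsToFile(
+  //       "/tmp/xla_autotune.textproto"));
+  //
+  //   // In a later process, before compiling the same modules:
+  //   TF_RETURN_IF_ERROR(AutotunerUtil::LoadAutotuneResultsFromFile(
+  //       "/tmp/xla_autotune.textproto"));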
+ static absl::StatusOr<std::string> SerializeAutotuneResults(
+ bool as_textproto = false);
+
+ // Serializes autotune results into the given proto.
+ static absl::Status SerializeAutotuneResults(AutotuneResults* results);
+
+ // Loads autotune results from the given string of bytes.
+ //
+ // Warning: The results are only loaded to the in-memory cache.
+ static absl::Status LoadAutotuneResults(absl::string_view data,
+ bool as_textproto = false);
+
+ // Loads autotune results from the given proto.
+ //
+ // Warning: The results are only loaded to the in-memory cache.
+ static absl::Status LoadAutotuneResults(const AutotuneResults& results);
+
+ // Serializes autotune results into a file.
+ //
+  // If `file_path` ends with ".txt", ".textproto", ".prototxt" or ".pbtxt",
+  // then the textproto format is used, otherwise the binary protobuf format.
+ static absl::Status SerializeAutotuneResultsToFile(
+ absl::string_view file_path);
+
+ // As above, but if you already called SerializeAutotuneResults to get a
+ // proto.
+ static absl::Status SerializeAutotuneResultsToFile(
+ const AutotuneResults& results, absl::string_view file_path);
+
+ // Loads autotune results from a file.
+ //
+  // If `file_path` ends with ".txt", ".textproto", ".prototxt" or ".pbtxt",
+  // then the file is considered to be in the textproto format, otherwise the
+  // binary protobuf format.
+ //
+ // Warning: The results are only loaded to the in-memory cache.
+ static absl::Status LoadAutotuneResultsFromFile(absl::string_view file_path);
+
+  // Warning: This only clears the in-memory cache. If you use a file-based
+  // cache, you're responsible for clearing the cache directory yourself.
+ static void ClearAutotuneResults();
+
+  // Warning: This only checks the in-memory cache. If you use a file-based
+  // cache, you're responsible for checking whether the cache directory is
+  // empty.
+ static bool ResultCacheIsEmpty();
+};
+
+absl::StatusOr<std::string> AutotuneResultsToString(
+ const AutotuneResults& results, bool as_textproto);
+
+// Exposed only for testing. Returns the SHA-256 hash of the input string,
+// encoded in base64.
+//
+// SHA-256 was chosen to follow industry best practices and avoid collisions.
+// Git is also transitioning to SHA-256. This is probably better than
+// tsl::Fingerprint128.
+absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc
new file mode 100644
index 0000000..974f4d4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc
@@ -0,0 +1,456 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/dump.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h" // IWYU pragma: keep
+#include "tsl/platform/path.h"
+#include "tsl/platform/protobuf.h" // IWYU pragma: keep
+#include "tsl/platform/status.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::testing::TempDir;
+using ::tsl::testing::StatusIs;
+
+class AutotunerUtilTest : public HloTestBase {
+ protected:
+ static constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+ p0 = f16[1,16,17,3] parameter(0)
+ p1 = s8[16,17,3] parameter(1)
+ cp1 = f16[16,17,3] convert(p1)
+ ROOT _ = f16[1,16,16] dot(p0, cp1),
+ lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+})";
+
+ static constexpr absl::string_view kResultText = R"(
+version: 3
+results {
+ device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
+ hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
+ result {
+ run_time {
+ nanos: 31744
+ }
+ triton {
+ block_m: 32
+ block_n: 32
+ block_k: 32
+ split_k: 1
+ num_stages: 1
+ num_warps: 4
+ num_ctas: 1
+ }
+ }
+})";
+
+ void SetUp() override { AutotunerUtil::ClearAutotuneResults(); }
+
+ std::string GetUniqueTempFilePath(absl::string_view suffix) {
+ std::string filename = TempDir();
+ CHECK(tsl::Env::Default()->CreateUniqueFileName(&filename,
+ std::string(suffix)));
+ return filename;
+ }
+
+ std::string ExpectToReadNonEmptyFile(absl::string_view file_path) {
+ std::string str;
+ tsl::Env* env = tsl::Env::Default();
+ TF_EXPECT_OK(tsl::ReadFileToString(env, std::string(file_path), &str));
+ EXPECT_THAT(str, Not(IsEmpty()));
+ return str;
+ }
+
+ static stream_executor::StreamExecutor* NewStreamExecutor() {
+ stream_executor::Platform* platform =
+ stream_executor::PlatformManager::PlatformWithName("Host").value();
+ return platform->ExecutorForDevice(/*ordinal=*/0).value();
+ }
+
+ absl::Status PopulateResultCache() {
+ EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+ TF_RETURN_IF_ERROR(AutotunerUtil::LoadAutotuneResults(kResultText, true));
+ EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+ return absl::OkStatus();
+ }
+};
+
+TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto1) {
+ TF_EXPECT_OK(PopulateResultCache());
+ std::string kFilePath = GetUniqueTempFilePath(".txt");
+ TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+
+ std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
+ AutotuneResults results;
+ EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
+ &results));
+ EXPECT_GT(results.results_size(), 0);
+}
+
+TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto2) {
+ TF_EXPECT_OK(PopulateResultCache());
+ std::string kFilePath = GetUniqueTempFilePath(".textproto");
+ TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+
+ std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
+ AutotuneResults results;
+ EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
+ &results));
+}
+
+TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_Protobuf) {
+ TF_EXPECT_OK(PopulateResultCache());
+ std::string kFilePath = GetUniqueTempFilePath(".pb");
+ TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+
+ std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
+ AutotuneResults results;
+ EXPECT_TRUE(results.ParseFromString(autotune_results_str));
+}
+
+TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto1) {
+ TF_EXPECT_OK(PopulateResultCache());
+ std::string kFilePath = GetUniqueTempFilePath(".txt");
+ TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+ AutotunerUtil::ClearAutotuneResults();
+ EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+
+ TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
+ EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto2) {
+ TF_EXPECT_OK(PopulateResultCache());
+ std::string kFilePath = GetUniqueTempFilePath(".textproto");
+ TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+ AutotunerUtil::ClearAutotuneResults();
+ EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+
+ TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
+ EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_Protobuf) {
+ TF_EXPECT_OK(PopulateResultCache());
+ std::string kFilePath = GetUniqueTempFilePath(".pb");
+ TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+ AutotunerUtil::ClearAutotuneResults();
+ EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+
+ TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
+ EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(AutotunerUtilTest, ResultConflictsAreDetected) {
+ TF_EXPECT_OK(PopulateResultCache());
+ std::string kFilePath = GetUniqueTempFilePath(".pb");
+ TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+ EXPECT_THAT(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath),
+ StatusIs(absl::StatusCode::kInternal,
+ HasSubstr("Duplicate autotuning result")));
+}
+
+// Test that when complete AOT autotuning is required and there is a cache
+// miss, a `NotFound` error is raised.
+TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) {
+ std::string kFilePath = GetUniqueTempFilePath(".txt");
+ auto hlo_module = GetOptimizedModule(kHloText);
+ TF_EXPECT_OK(hlo_module.status());
+ std::vector<HloComputation*> computations =
+ (*hlo_module)
+ ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
+ EXPECT_THAT(computations, Not(IsEmpty()));
+ const HloInstruction* instruction = *computations[0]->instructions().begin();
+ stream_executor::StreamExecutor* executor = NewStreamExecutor();
+ auto options = DebugOptions();
+ options.set_xla_gpu_require_complete_aot_autotune_results(true);
+ AutotuneConfig config(DeviceConfig{executor}, options);
+ EXPECT_THAT(
+ AutotunerUtil::Autotune(instruction, config,
+ [&] { return AutotuneResult(); }),
+ StatusIs(
+ absl::StatusCode::kNotFound,
+ HasSubstr("Complete XLA AOT autotuning results are required, but "
+ "no AOT result was found for key: <key model")));
+}
+
+// Test that when JIT autotuning is disabled, but there is no cache miss (the
+// result is already cached), `Autotune` still returns an OK status.
+TEST_F(AutotunerUtilTest, OkIfJitAutotuningDisabledButAlreadyLoadedAOT) {
+ auto hlo_module = GetOptimizedModule(kHloText);
+ std::vector<HloComputation*> computations =
+ (*hlo_module)
+ ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
+ EXPECT_THAT(computations, Not(IsEmpty()));
+ const HloInstruction* instruction = *computations[0]->instructions().begin();
+ stream_executor::StreamExecutor* executor = NewStreamExecutor();
+
+ {
+ // By default, JIT autotuning is OK.
+ AutotuneConfig config(DeviceConfig{executor}, DebugOptions());
+ TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
+ return AutotuneResult();
+ }).status());
+ }
+
+ // Now require complete AOT autotuning results.
+ auto options = DebugOptions();
+ options.set_xla_gpu_require_complete_aot_autotune_results(true);
+
+ AutotuneConfig config(DeviceConfig{executor}, options);
+  // Even though JIT autotuning is now disabled, autotuning the same entry is a
+  // cache hit, so no error should be raised.
+ TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
+ return AutotuneResult();
+ }).status());
+}
+
+class FileBasedCacheTest : public AutotunerUtilTest {
+ public:
+ static std::string ToString(const AutotuneResult& message) {
+ std::string textproto;
+ CHECK(tsl::protobuf::TextFormat::PrintToString(message, &textproto));
+ return textproto;
+ }
+
+ static std::vector<std::string> GetFilesInDir(
+ const absl::string_view cache_dir) {
+ std::vector<std::string> files_in_cache;
+ TF_CHECK_OK(tsl::Env::Default()->GetChildren(std::string(cache_dir),
+ &files_in_cache));
+ return files_in_cache;
+ }
+
+ static std::string Read(const absl::string_view filepath) {
+ std::string file_content;
+ TF_CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(),
+ std::string(filepath), &file_content));
+ return file_content;
+ }
+
+ void Write(const absl::string_view filepath,
+ const absl::string_view content) {
+ TF_CHECK_OK(CreateDirIfNeeded(cache_dir_, tsl::Env::Default()));
+ TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(),
+ std::string(filepath), content));
+ }
+
+ stream_executor::StreamExecutor* executor_ = NewStreamExecutor();
+ std::unique_ptr<HloModule> module_ =
+ ParseAndReturnVerifiedModule(kHloText).value();
+ const HloInstruction* dot_ = hlo_query::GetFirstInstructionWithOpcode(
+ *module_->entry_computation(), HloOpcode::kDot);
+ std::string cache_dir_ = [] {
+ tsl::Env* default_env = tsl::Env::Default();
+ std::string cache_dir;
+ CHECK(default_env->LocalTempFilename(&cache_dir));
+ return cache_dir;
+ }();
+ AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_}, [&] {
+ DebugOptions options;
+ options.set_xla_gpu_per_fusion_autotune_cache_dir(cache_dir_);
+ return options;
+ }());
+ AutotuneCacheKey cache_key_ = AutotunerUtil::GetKey(dot_, config_);
+ std::string cache_filename_ = [&] {
+ absl::StatusOr<std::string> key_hash =
+ GetBase64EncodedSha256Hash(cache_key_.ToString());
+ CHECK_OK(key_hash.status());
+ return absl::StrCat(key_hash.value(), ".textproto");
+ }();
+ std::string cache_file_path_ = tsl::io::JoinPath(cache_dir_, cache_filename_);
+ const AutotuneResult result1_ = [] {
+ AutotuneResult result;
+ result.set_scratch_bytes(1);
+ return result;
+ }();
+ const AutotuneResult result2_ = [] {
+ AutotuneResult result;
+ result.set_scratch_bytes(2);
+ return result;
+ }();
+};
+
+TEST_F(FileBasedCacheTest, AutotuneWritesResultToTheCacheDir) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ AutotuneResult result,
+ AutotunerUtil::Autotune(dot_, config_, [&] { return result1_; }));
+ EXPECT_EQ(ToString(result), ToString(result1_));
+
+ ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
+ EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
+}
+
+TEST_F(FileBasedCacheTest, AutotuneReadsResultFromTheCacheDir) {
+ Write(cache_file_path_, ToString(result1_));
+
+ bool cache_hit = true;
+ TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
+ AutotunerUtil::Autotune(dot_, config_, [&] {
+ cache_hit = false;
+ return result2_;
+ }));
+
+ EXPECT_TRUE(cache_hit);
+ EXPECT_EQ(ToString(result), ToString(result1_));
+}
+
+TEST_F(FileBasedCacheTest,
+ RepeatedAutotuneCallsDontReadOrWriteTheCacheFileAgain) {
+ auto check_autotune_cache_hit = [](const HloInstruction* instr,
+ const AutotuneConfig& config,
+ const AutotuneResult& expected_result) {
+ bool cache_hit = true;
+ TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
+ AutotunerUtil::Autotune(instr, config, [&] {
+ cache_hit = false;
+ AutotuneResult new_result;
+ new_result.set_scratch_bytes(2);
+ return new_result;
+ }));
+ EXPECT_TRUE(cache_hit);
+ EXPECT_EQ(ToString(result), ToString(expected_result));
+ };
+
+ Write(cache_file_path_, ToString(result1_));
+ check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
+
+ constexpr absl::string_view kPlaceholderContent = "placeholder content";
+ Write(cache_file_path_, kPlaceholderContent);
+ // File was not read again:
+ check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
+ // File was not written again:
+ EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
+}
+
+TEST_F(FileBasedCacheTest,
+ IsInCacheReturnsTrueIfTheResultIsInTheFileBasedCache) {
+ Write(cache_file_path_, ToString(result1_));
+
+ TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
+ AutotunerUtil::IsInCache(cache_key_, config_));
+
+ EXPECT_TRUE(is_in_cache);
+}
+
+TEST_F(FileBasedCacheTest, IsInCacheReturnsFalseIfTheResultIsNotInEitherCache) {
+ TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
+ AutotunerUtil::IsInCache(cache_key_, config_));
+
+ EXPECT_FALSE(is_in_cache);
+}
+
+TEST_F(FileBasedCacheTest, AddResultAddsTheResultToTheFileBasedCache) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
+ EXPECT_TRUE(added);
+
+ ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
+ EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
+}
+
+TEST_F(FileBasedCacheTest, RepeatedAddResultDoesNotWriteTheFileAgain) {
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
+ EXPECT_TRUE(added);
+ }
+ ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
+ EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
+ constexpr absl::string_view kPlaceholderContent = "placeholder content";
+ Write(cache_file_path_, kPlaceholderContent);
+
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
+ EXPECT_FALSE(added);
+ }
+
+ // File was not written again:
+ EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
+}
+
+TEST(AutotuneCacheKeyTest, DeviceDescriptionToCacheKey) {
+ auto device_description =
+ [](absl::string_view spec_file_name) -> se::DeviceDescription {
+ se::GpuTargetConfigProto proto;
+ std::string spec_string;
+ CHECK_OK(tsl::ReadFileToString(
+ tsl::Env::Default(),
+ tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "hlo_opt",
+ "gpu_specs", spec_file_name),
+ &spec_string));
+ EXPECT_TRUE(
+ tsl::protobuf::TextFormat::ParseFromString(spec_string, &proto));
+ return se::DeviceDescription(proto.gpu_device_info());
+ };
+
+ EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
+ device_description("a100_sxm_40.txtpb")),
+ "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
+ "1555 GB/s, L2 cache: 40 MB");
+
+ EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
+ device_description("a100_sxm_80.txtpb")),
+ "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
+ "2039 GB/s, L2 cache: 40 MB");
+
+ EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
+ device_description("mi200.txtpb")),
+ "ROCM: gfx90a, Cores: 110, GPU clock: 1.7 GHz, Memory bandwidth: "
+ "1638 GB/s, L2 cache: 8 MB");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc
new file mode 100644
index 0000000..6173915
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc
@@ -0,0 +1,1194 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/autotuning.pb.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/gpu_autotuning.pb.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/gpu_conv_runner.h"
+#include "xla/service/gpu/hlo_algorithm_denylist.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/slow_operation_alarm.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/lazy_op_runner.h"
+#include "xla/stream_executor/numeric_options.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/rocm/rocm_platform_id.h"
+#include "xla/stream_executor/scratch_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/tsl/util/env_var.h"
+#include "xla/tsl/util/proto/proto_utils.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/numbers.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#if CUDNN_VERSION >= 90000
+#include "third_party/gpus/cudnn/cudnn_ops.h"
+#else
+#include "third_party/gpus/cudnn/cudnn_ops_infer.h"
+#endif // CUDNN_VERSION >= 90000
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#endif
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using se::DeviceMemoryBase;
+using se::dnn::AlgorithmDesc;
+using std::optional;
+
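+// Scratch allocator that keeps ownership of every buffer it hands out (in
+// `allocated_buffers_`), so all scratch memory used during autotuning is
+// released when the allocator goes out of scope.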
+class ScratchAllocator : public se::ScratchAllocator {
+ public:
+ ScratchAllocator(int device_ordinal,
+ se::DeviceMemoryAllocator* memory_allocator)
+ : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+ int64_t GetMemoryLimitInBytes() override {
+ return ScratchAllocator::GetDefaultMemoryLimitInBytes();
+ }
+ int64_t TotalAllocatedBytes() { return total_allocated_bytes_; }
+
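+ // Scratch cap, read from TF_CUDNN_WORKSPACE_LIMIT_IN_MB (a value in MiB,
+ // converted to bytes below); without an override it defaults to
+ // 1 << 12 MiB, i.e. 4 GiB.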
+ static int64_t GetDefaultMemoryLimitInBytes() {
+ int64_t value;
+ TF_CHECK_OK(tsl::ReadInt64FromEnvVar("TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
+ 1LL << 12, &value));
+ return value * (1LL << 20);
+ }
+
+ absl::StatusOr<se::DeviceMemory<uint8_t>> AllocateBytes(
+ int64_t byte_size) override;
+
+ template <typename T>
+ absl::StatusOr<se::DeviceMemory<T>> Allocate(int64_t num_elements) {
+ TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8_t> bytes,
+ AllocateBytes(num_elements * sizeof(T)));
+ return se::DeviceMemory<T>(bytes);
+ }
+
+ private:
+ const int device_ordinal_;
+ se::DeviceMemoryAllocator* memory_allocator_;
+ std::vector<se::OwningDeviceMemory> allocated_buffers_;
+ int64_t total_allocated_bytes_ = 0;
+};
+
+absl::StatusOr<se::DeviceMemory<uint8_t>> ScratchAllocator::AllocateBytes(
+ int64_t byte_size) {
+ CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
+ if (byte_size > GetMemoryLimitInBytes()) {
+ return absl::ResourceExhaustedError(absl::StrFormat(
+ "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size,
+ GetMemoryLimitInBytes()));
+ }
+
+ TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
+ memory_allocator_->Allocate(device_ordinal_, byte_size,
+ /*retry_on_failure=*/false));
+ total_allocated_bytes_ += byte_size;
+
+ se::DeviceMemoryBase buffer_addr = *allocated_buffer;
+ allocated_buffers_.push_back(std::move(allocated_buffer));
+ return se::DeviceMemory<uint8_t>(buffer_addr);
+}
+
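+// Enumerates the candidate conv runners for `config` through the stream
+// executor's DNN backend. Fused, graph, and plain convolutions each go
+// through a dedicated entry point below.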
+absl::StatusOr<std::vector<GenericConvRunner>> GetAlgorithms(
+ const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend,
+ bool use_fallback, const se::NumericOptions& numeric_options) {
+ TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
+ GetDNNConvKindFromCudnnConvKind(config.kind));
+
+ TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type,
+ GetDNNDataTypeFromPrimitiveType(config.input_type));
+
+ TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type,
+ GetDNNDataTypeFromPrimitiveType(config.output_type));
+
+ se::StreamExecutor* stream_exec = stream->parent();
+ std::vector<GenericConvRunner> result;
+
+ auto dnn = stream_exec->AsDnn();
+ if (dnn == nullptr) {
+ return absl::InvalidArgumentError("No DNN in stream executor.");
+ }
+ switch (kind) {
+ default:
+ return Internal("Unknown ConvolutionKind %d", kind);
+ case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: {
+ if (!config.fusion) {
+ return Internal(
+ "GpuConvConfig had fusion ConvolutionKind but no FusionConfig.");
+ }
+ std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>> runners;
+ TF_RETURN_IF_ERROR(dnn->GetFusedConvolveRunners(
+ use_cudnn_frontend,
+ // This refers to the kind of convolution op inside the fusion, not
+ // the whole fused graph.
+ se::dnn::ConvolutionKind::FORWARD, input_type,
+ BiasTypeForInputType(input_type), output_type,
+ /* conv_input_scale = */ config.conv_result_scale,
+ /* side_input_scale = */ config.fusion->side_input_scale,
+ /* leakyrelu_alpha = */ config.fusion->leakyrelu_alpha, stream,
+ config.input_descriptor, config.filter_descriptor,
+ config.bias_descriptor, config.output_descriptor, config.conv_desc,
+ use_fallback, config.fusion->mode, numeric_options, &runners));
+ for (auto& runner : runners) {
+ TF_ASSIGN_OR_RETURN(
+ auto runner_cache,
+ se::dnn::LazyOpRunner<se::dnn::FusedConvOp>::FromOpRunner(
+ std::move(runner)));
+ result.emplace_back(std::move(runner_cache));
+ }
+ break;
+ }
+
+ case se::dnn::ConvolutionKind::FORWARD_GRAPH: {
+ std::vector<std::unique_ptr<const se::dnn::GraphConvRunner>> runners;
+ // This path is cuDNN-only. Unlike the plain convolution case below, the
+ // graph API does not take buffer or allocator arguments at enumeration
+ // time.
+ TF_RETURN_IF_ERROR(dnn->GetGraphConvolveRunners(
+ kind, input_type, output_type, stream, config.input_descriptor,
+ config.filter_descriptor, config.output_descriptor, config.conv_desc,
+ use_fallback, numeric_options, &runners, config.serialized_graph));
+ for (auto& runner : runners) {
+ TF_ASSIGN_OR_RETURN(
+ auto runner_cache,
+ se::dnn::LazyOpRunner<se::dnn::GraphConvOp>::FromOpRunner(
+ std::move(runner)));
+ result.emplace_back(std::move(runner_cache));
+ }
+ break;
+ }
+
+ case se::dnn::ConvolutionKind::FORWARD:
+ case se::dnn::ConvolutionKind::BACKWARD_DATA:
+ case se::dnn::ConvolutionKind::BACKWARD_FILTER: {
+ std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
+ // This path is cuDNN-only, where the DeviceMemoryBase arguments and the
+ // allocator are unused; so, they're all provided as nullptr.
+ TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
+ use_cudnn_frontend, kind, input_type, output_type, stream,
+ config.input_descriptor,
+ /* input_data = */ DeviceMemoryBase(nullptr),
+ config.filter_descriptor,
+ /* filter_data = */ DeviceMemoryBase(nullptr),
+ config.output_descriptor,
+ /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
+ use_fallback, nullptr, numeric_options, &runners));
+
+ for (auto& runner : runners) {
+ TF_ASSIGN_OR_RETURN(
+ auto runner_cache,
+ se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
+ std::move(runner)));
+ result.emplace_back(std::move(runner_cache));
+ }
+ break;
+ }
+ }
+
+ return result;
+}
+
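+// ROCm counterpart of GetAlgorithms: MIOpen needs the actual operand/result
+// buffers and a scratch allocator at enumeration time, so they are passed in
+// here instead of the nullptr placeholders used on the cuDNN path.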
+absl::StatusOr<std::vector<std::unique_ptr<const se::dnn::ConvRunner>>>
+GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ absl::Span<se::DeviceMemoryBase> result_buffers,
+ se::StreamExecutor* stream_exec,
+ ScratchAllocator* scratch_allocator, se::Stream* stream,
+ const se::NumericOptions& numeric_options) {
+ TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
+
+ TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
+ GetDNNConvKindFromCudnnConvKind(config.kind));
+
+ TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
+ GetDNNDataTypeFromPrimitiveType(config.output_type));
+
+ TF_ASSIGN_OR_RETURN(
+ GpuConvParams params,
+ GetGpuConvParams(config, operand_buffers, result_buffers));
+
+ std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
+ auto dnn = stream_exec->AsDnn();
+ if (dnn == nullptr) {
+ return absl::InvalidArgumentError("No DNN in stream executor.");
+ }
+ TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
+ /* use_cudnn_frontend = */ false, kind, dtype, dtype, stream,
+ params.config->input_descriptor, params.input_buf,
+ params.config->filter_descriptor, params.filter_buf,
+ params.config->output_descriptor, params.output_buf,
+ params.config->conv_desc,
+ /* use_fallback = */ false, scratch_allocator, numeric_options,
+ &runners));
+
+ return runners;
+}
+
+std::string NumBytesToString(int64_t bytes) {
+ return absl::StrCat(tsl::strings::HumanReadableNumBytes(bytes), " (", bytes,
+ "B)");
+}
+
+CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
+ se::dnn::VersionInfo version = GetDnnVersionInfoOrDefault(stream_executor);
+ CudnnVersion cudnn_version;
+ cudnn_version.set_major(version.major_version());
+ cudnn_version.set_minor(version.minor_version());
+ cudnn_version.set_patch(version.patch());
+
+ return cudnn_version;
+}
+
+ComputeCapability GetComputeCapability(se::StreamExecutor* stream_executor) {
+ ComputeCapability cc;
+ se::CudaComputeCapability se_cc =
+ stream_executor->GetDeviceDescription().cuda_compute_capability();
+ cc.set_major(se_cc.major);
+ cc.set_minor(se_cc.minor);
+ return cc;
+}
+
+void PrintPlatformInfo(const se::Stream* stream) {
+ auto* se = stream->parent();
+ const auto& desc = se->GetDeviceDescription();
+ LOG(ERROR) << "Device: " << desc.name();
+ LOG(ERROR) << "Platform: " << desc.platform_version();
+ LOG(ERROR) << "Driver: " << desc.driver_version();
+ LOG(ERROR) << "Runtime: " << desc.runtime_version();
+
+ auto dnn_version = GetDnnVersionInfo(se);
+ if (dnn_version.ok()) {
+ auto v = dnn_version.value();
+ LOG(ERROR) << "cudnn version: " << v.major_version() << "."
+ << v.minor_version() << "." << v.patch();
+ }
+}
+
+// Returns true if the redzones in `allocator`'s allocations are unmodified.
+//
+// If the redzones are modified, logs an error, sets the appropriate failure
+// bits on `result`, and returns false.
+//
+// Returns an error absl::Status if an unexpected error has occurred and the
+// stream has been poisoned.
+//
+// `name` is a user-friendly name for the set of redzones being checked, e.g.
+// "input/output" or "scratch".
+absl::StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
+ se::Stream* stream, absl::string_view name,
+ std::string_view instr_str,
+ AutotuneResult* result) {
+ XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
+ 2);
+ using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
+ TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
+ allocator.CheckRedzones());
+ if (redzone_check.ok()) {
+ return true;
+ }
+
+ auto* fail = result->mutable_failure();
+ fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
+ *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
+ fail->set_buffer_address(
+ reinterpret_cast<uint64_t>(redzone_check.user_buffer_address));
+
+ LOG(ERROR) << absl::StreamFormat(
+ "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
+ "cudnn bug. We will skip this algorithm in the future, but your GPU "
+ "state may already be corrupted, leading to incorrect results. Within "
+ "Google, no action is needed on your part. Outside of Google, please "
+ "ensure you're running the latest version of cudnn. If that doesn't fix "
+ "the problem, please file a bug with this full error message and we'll "
+ "contact nvidia.",
+ name);
+ LOG(ERROR) << redzone_check.RedzoneFailureMsg();
+ LOG(ERROR) << "HloInstruction " << instr_str;
+ PrintPlatformInfo(stream);
+ return false;
+}
+
+} // anonymous namespace
+
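+// The two helpers below gate optional autotuning work on the
+// --xla_gpu_autotune_level debug option: level >= 2 enables initializing conv
+// buffers with data, and level >= 4 additionally enables result checking
+// (redzone checks and cross-algorithm buffer comparison).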
+bool ShouldInitConvData(const HloModuleConfig& hlo_module_config) {
+ const int32_t conv_autotune_level =
+ hlo_module_config.debug_options().xla_gpu_autotune_level();
+ return conv_autotune_level >= 2;
+}
+
+bool ShouldCheckConv(const HloModuleConfig& hlo_module_config) {
+ const int32_t conv_autotune_level =
+ hlo_module_config.debug_options().xla_gpu_autotune_level();
+ return conv_autotune_level >= 4;
+}
+
+absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
+ const HloCustomCallInstruction* instr) {
+ return AutotunerUtil::Autotune(
+ instr, config_, [&] { return PickBestAlgorithmNoCache(instr); });
+}
+
+absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithmNoCache(
+ const HloCustomCallInstruction* instr) {
+ if (config_.IsDeviceless()) {
+ // Return an autotune result with algo id -1, which means that we autotune
+ // at runtime.
+ AutotuneResult result;
+ result.mutable_algorithm()->set_algo_id(-1);
+ return result;
+ }
+
+ se::StreamExecutor* stream_exec = config_.GetExecutor();
+ // Don't run this function concurrently on the same GPU.
+ //
+ // This is a bit of a hack and doesn't protect us against arbitrary concurrent
+ // use of a GPU, but it's sufficient to let us compile two HLO modules
+ // concurrently and then run them sequentially.
+ //
+ // Putting the lock in here rather than in PickBestAlgorithm lets us avoid
+ // ever doing duplicate work. If we have a cache miss, only one thread will
+ // run the actual autotuning for a particular device.
+ absl::MutexLock lock(&GetGpuMutex(stream_exec));
+
+ // Make sure any previous activity on this executor is done. We don't want
+ // other work still running on the GPU to interfere with autotuning.
+ if (!stream_exec->SynchronizeAllActivity()) {
+ return Internal(
+ "Failed to synchronize GPU for autotuning conv instruction");
+ }
+
+ absl::StatusOr<AutotuneResult> result_or(Internal("Unknown platform."));
+ // Check which platform the StreamExecutor is on; the ROCm and CUDA
+ // implementations have diverged. Specifically, we need to make sure that
+ // redzone-allocator-related utilities are not used in the ROCm routine.
+ se::Platform::Id platform_id = stream_exec->GetPlatform()->id();
+ if (platform_id == se::rocm::kROCmPlatformId) {
+ result_or = PickBestAlgorithmNoCacheRocm(instr);
+ } else if (platform_id == se::cuda::kCudaPlatformId) {
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+ result_or = PickBestAlgorithmNoCacheCuda(instr);
+#endif
+ }
+
+ return result_or;
+}
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+
+absl::StatusOr<GpuConvAlgorithmPicker::AutotuneRuntimeArguments>
+GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction(
+ const HloCustomCallInstruction* instr, const AutotuneConfig& config,
+ const DebugOptions& debug_options) {
+ TF_ASSIGN_OR_RETURN(auto rz_buffers,
+ RedzoneBuffers::FromInstruction(
+ *instr, config, debug_options,
+ RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+ // Get canonical HLO.
+ std::string canonical_hlo(
+ AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr)
+ .GetHlo());
+
+ TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr));
+
+ GpuConvAlgorithmPicker::AutotuneRuntimeArguments runtime_arguments = {
+ instr->GetModule()->config(),
+ std::move(rz_buffers),
+ std::move(gpu_conv_config),
+ {canonical_hlo}};
+
+ return runtime_arguments;
+}
+
+struct CudnnVersionRange {
+ using TupleVersion = std::tuple<int, int, int>;
+ TupleVersion begin;
+ TupleVersion end;
+
+ bool IsInRange(const CudnnVersion& other) const {
+ TupleVersion other_version{other.major(), other.minor(), other.patch()};
+ return begin <= other_version && other_version < end;
+ }
+
+ CudnnVersionRange(const CudnnVersion& begin, const CudnnVersion& end)
+ : begin(begin.major(), begin.minor(), begin.patch()),
+ end(end.major(), end.minor(), end.patch()) {}
+
+ CudnnVersionRange(const TupleVersion& begin, const TupleVersion& end)
+ : begin(begin), end(end) {}
+};
+
+struct ComputeCapabilityRange {
+ using TupleComputeCapability = std::tuple<int, int>;
+ TupleComputeCapability begin;
+ TupleComputeCapability end;
+
+ bool IsInRange(const ComputeCapability& other) const {
+ TupleComputeCapability other_cc{other.major(), other.minor()};
+ return begin <= other_cc && other_cc < end;
+ }
+};
+
+struct DisabledAlgorithm {
+ CudnnVersionRange cudnn_version_range;
+ ComputeCapabilityRange compute_capability_range;
+ int algo_id;
+};
+
+// TODO(b/343101418): Remove this once the bug is fixed in upstream cuDNN and
+// once we have updated to that cuDNN version.
+static const DisabledAlgorithm kDisabledAlgorithms[] = {
+ {/*.cudnn_version_range=*/{/*.begin=*/{9, 0, 0}, /*.end=*/{10, 0, 0}},
+ /*.compute_capability_range=*/{/*.begin=*/{6, 0}, /*.end=*/{8, 0}},
+ /*.algo_id=*/14}};
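+// The single entry above disqualifies cuDNN engine 14 on compute capabilities
+// in [6.0, 8.0) when running against cuDNN versions in [9.0.0, 10.0.0); both
+// ranges are half-open, as implemented by IsInRange.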
+
+// There are three tiers of errors possible here: returning a failed
+// absl::StatusOr means autotuning fails immediately; returning an
+// AutotuneResult with a failure code other than DISQUALIFIED means autotuning
+// fails if crash_on_checking_failure is set; and returning a DISQUALIFIED
+// AutotuneResult simply skips the engine/algorithm while recording a reason for
+// skipping it.
+absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
+ GenericConvRunner* const runner,
+ std::optional<ReferenceResult>* reference_result,
+ absl::Span<const AlgorithmDesc> disabled_algos,
+ std::optional<AutotuneCacheKey> instruction_info,
+ const AutotuneRuntimeArguments& runtime_arguments) {
+ auto alg = runner->ToAlgorithmDesc();
+
+ se::StreamExecutor* stream_exec = config_.GetExecutor();
+ XLA_SCOPED_LOGGING_TIMER_LEVEL(
+ absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
+ alg.ToString()),
+ 2);
+
+ auto make_failure = [&alg](AutotuneResult::FailureKind kind,
+ absl::string_view msg) {
+ AutotuneResult result;
+ *result.mutable_algorithm() = alg.ToProto();
+ result.mutable_failure()->set_kind(kind);
+ result.mutable_failure()->set_msg(/* *sigh* */ msg.data(), msg.size());
+ return result;
+ };
+
+ AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);
+
+ std::string instr_str = instruction_info.has_value()
+ ? std::string(instruction_info->GetHlo())
+ : "<unknown>";
+
+ for (const auto& disabled_algo : kDisabledAlgorithms) {
+ if (disabled_algo.cudnn_version_range.IsInRange(
+ GetCudnnVersion(stream_exec)) &&
+ disabled_algo.compute_capability_range.IsInRange(
+ GetComputeCapability(stream_exec)) &&
+ disabled_algo.algo_id == alg.algo_id()) {
+ LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
+ << " for conv " << instr_str;
+ return make_failure(AutotuneResult::DISQUALIFIED,
+ "Disqualified for being known-buggy.");
+ }
+ }
+
+ if (absl::c_linear_search(disabled_algos, alg_key)) {
+ LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
+ << " for conv " << instr_str;
+ return make_failure(AutotuneResult::DISQUALIFIED,
+ "Disqualified for being known-buggy.");
+ }
+
+ GpuConvConfig config = runtime_arguments.gpu_conv_config;
+ auto activation_mode =
+ config.fusion ? config.fusion->mode : se::dnn::ActivationMode::kNone;
+
+ // For fused convolutions with the identity function as the activation, only
+ // ALGO_IMPLICIT_PRECOMP_GEMM does the right thing. Other algorithms
+ // silently do Relu. See
+ // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
+ //
+ // For cuDNN Frontend, there is no way to check whether we're using a broken
+ // algorithm, so on versions where some algorithms are broken, we don't use
+ // the cuDNN Frontend for these convs at all. As such, if we get a
+ // frontend-based runner, we can be sure it's not one of the broken
+ // algorithms we're checking for.
+ if (!alg.is_cudnn_frontend() &&
+ config.kind == CudnnConvKind::kForwardActivation &&
+ activation_mode == se::dnn::ActivationMode::kNone &&
+ alg.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
+ return make_failure(AutotuneResult::DISQUALIFIED,
+ "Disqualified for implicit RELU.");
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ se::RedzoneAllocator scratch_allocator,
+ AutotunerUtil::CreateRedzoneAllocator(
+ config_, runtime_arguments.hlo_module_config.debug_options()));
+
+ se::dnn::ProfileResult profile_result;
+ VLOG(4) << "Trying algorithm " << alg.ToString() << " for " << instr_str;
+
+ SlowOperationAlarm alarm(absl::Seconds(1), [&] {
+ return absl::StrFormat(
+ "Trying algorithm %s for conv %s is taking a while...", alg.ToString(),
+ instr_str);
+ });
+
+ std::optional<size_t> workspace_size =
+ runner->ToAlgorithmDesc().workspace_size();
+ if (!workspace_size) {
+ return make_failure(AutotuneResult::UNKNOWN,
+ "Internal error: missing workspace size from "
+ "OpRunner::ToAlgorithmDesc()");
+ }
+
+ auto scratch_or = scratch_allocator.AllocateBytes(*workspace_size);
+ if (!scratch_or.ok()) {
+ return make_failure(AutotuneResult::DISQUALIFIED,
+ absl::StrCat("Scratch allocation failed: ",
+ scratch_or.status().ToString()));
+ }
+ se::DeviceMemoryBase scratch_memory = scratch_or.value();
+
+ // Use assignment instead of brace-list to make GCC 4.9 happy.
+ RunConvOptions options;
+ options.runner_cache = runner;
+ // The following plan timing code is based on
+ // https://github.com/NVIDIA/cudnn-frontend/blob/60496f42fdc7a4ccc059f5934e306e728a756755/include/cudnn_frontend_find_plan.h
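+ // In short: after a warm-up run, re-run the conv up to kMaxIter times and
+ // keep the fastest time, stopping early once a measurement is within
+ // kThreshold (5%) of the best seen so far.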
+ float max_time = 0;
+ float min_time = std::numeric_limits<float>::max();
+ absl::Status launch_status;
+ std::vector<se::DeviceMemoryBase> operand_buffers =
+ runtime_arguments.rz_buffers.input_buffers();
+ std::vector<se::DeviceMemoryBase> result_buffers =
+ runtime_arguments.rz_buffers.output_buffers();
+
+ TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+
+ // Dry-run to warmup the plan.
+ launch_status = RunGpuConv(config, operand_buffers, result_buffers,
+ scratch_memory, stream, options);
+ // Flag that a warm-up run has been executed; this allows the GpuTimer for
+ // the main measurement to safely use the delay kernel pattern, even if lazy
+ // module loading is enabled.
+ options.profile_result = &profile_result;
+ profile_result.set_warmup_run_executed(true);
+ constexpr int kMaxIter = 10;
+ // Iterate until the new measurement is within kThreshold of the current
+ // minimum.
+ int num_iters = 0;
+ for (; num_iters < kMaxIter && launch_status.ok(); ++num_iters) {
+ launch_status = RunGpuConv(config, operand_buffers, result_buffers,
+ scratch_memory, stream, options);
+ if (!profile_result.is_valid()) {
+ break;
+ }
+ float old_min_time = min_time;
+ min_time = std::min(min_time, profile_result.elapsed_time_in_ms());
+ max_time = std::max(max_time, profile_result.elapsed_time_in_ms());
+
+ constexpr float kThreshold = 0.05f;
+ if (std::abs(profile_result.elapsed_time_in_ms() - old_min_time) /
+ old_min_time <
+ kThreshold) {
+ break;
+ }
+ }
+ if (!launch_status.ok()) {
+ VLOG(5) << "Launch failed: " << launch_status;
+ return make_failure(
+ AutotuneResult::DISQUALIFIED,
+ absl::StrCat("Profiling failure on cuDNN engine ", alg.ToString(), ": ",
+ launch_status.ToString()));
+ }
+ if (!profile_result.is_valid()) {
+ VLOG(5) << "Launch succeeded but profile result is invalid.";
+ // Not DISQUALIFIED: this means something went wrong internally.
+ return make_failure(
+ AutotuneResult::UNKNOWN,
+ absl::StrCat("Launch succeeded but profile result is invalid, "
+ "with cuDNN engine ",
+ alg.ToString(), ": ", launch_status.ToString()));
+ }
+ VLOG(4) << "Best time: " << min_time << " ms. Worst time: " << max_time
+ << " ms. Total iterations: " << num_iters;
+ int64_t scratch_bytes_used =
+ scratch_allocator.TotalAllocatedBytesExcludingRedzones();
+
+ AutotuneResult result;
+ *result.mutable_algorithm() = alg.ToProto();
+ result.set_scratch_bytes(scratch_bytes_used);
+ *result.mutable_run_time() =
+ tsl::proto_utils::ToDurationProto(absl::Milliseconds(min_time));
+
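+ // If result checking is disabled, still record the first successful
+ // algorithm as the reference (with empty placeholder buffers) so the caller
+ // can tell that at least one algorithm worked.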
+ if (!ShouldCheckConv(runtime_arguments.hlo_module_config)) {
+ if (!reference_result->has_value()) {
+ (*reference_result) = {
+ alg, std::vector<DeviceMemoryBase>(result_buffers.size())};
+ }
+ return result;
+ }
+
+ // Check for writes to redzones.
+ TF_ASSIGN_OR_RETURN(
+ bool input_output_allocator_redzone_clear,
+ CheckRedzones(runtime_arguments.rz_buffers.RedzoneAllocator(), stream,
+ "input/output", instr_str, &result));
+
+ TF_ASSIGN_OR_RETURN(
+ bool scratch_allocator_redzone_clear,
+ CheckRedzones(scratch_allocator, stream, "scratch", instr_str, &result));
+
+ if (!input_output_allocator_redzone_clear ||
+ !scratch_allocator_redzone_clear) {
+ if (runtime_arguments.canonical_hlo.has_value()) {
+ std::string canonical_hlo = runtime_arguments.canonical_hlo.value();
+ std::string blas_version;
+ if (auto* blas = stream_exec->AsBlas()) {
+ (void)blas->GetVersion(&blas_version);
+ }
+
+ AlgorithmDenylist proto;
+ auto entry = proto.add_entries();
+ entry->set_hlo(canonical_hlo);
+ *entry->mutable_cc() = GetComputeCapability(stream_exec);
+ *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec);
+ entry->set_blas_version(blas_version);
+ auto algo = entry->add_algos();
+ algo->set_id(alg.algo_id());
+ algo->set_tensor_ops(alg.tensor_ops_enabled());
+
+ LOG(ERROR) << "To denylist this algorithm for this convolution, "
+ "copy-paste the following "
+ "proto to the denylist file pointed by XLA_FLAGS "
+ "--xla_gpu_algorithm_denylist_path="
+ << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
+ << " : " << proto.ShortDebugString();
+ }
+
+ // CheckRedzones has modified the result in-place to include a failure.
+ return result;
+ }
+
+ if (reference_result->has_value()) {
+ XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
+
+ const DebugOptions& debug_options =
+ runtime_arguments.hlo_module_config.debug_options();
+ BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(),
+ debug_options.xla_gpu_autotune_gemm_rtol());
+ for (int i = 0; i < result_buffers.size(); ++i) {
+ absl::StatusOr<bool> compare_result = comparator.CompareEqual(
+ stream, (*reference_result)->buffers[i], result_buffers[i]);
+ if (!compare_result.ok()) {
+ LOG(ERROR) << "Unable to compare "
+ << (*reference_result)->algorithm.ToString() << " against "
+ << alg.ToString() << " for " << instr_str << ": "
+ << compare_result.status();
+ if (compare_result.status().code() ==
+ absl::StatusCode::kResourceExhausted) {
+ // Possibly OOM. Propagate the error.
+ return compare_result.status();
+ }
+ CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
+ } else if (!compare_result.value()) {
+ LOG(ERROR)
+ << "Results mismatch between different convolution algorithms. "
+ "This is likely a bug/unexpected loss of precision in cudnn.\n"
+ << instr_str << " for " << (*reference_result)->algorithm.ToString()
+ << " vs " << alg.ToString();
+ PrintPlatformInfo(stream);
+ if (instruction_info.has_value()) {
+ VLOG(2) << "Full module on failure: \n"
+ << instruction_info->GetModelStr();
+ }
+ auto* fail = result.mutable_failure();
+ fail->set_kind(AutotuneResult::WRONG_RESULT);
+ fail->set_buffer_address(
+ reinterpret_cast<uint64_t>(result_buffers[i].opaque()));
+ *fail->mutable_reference_algorithm() =
+ (*reference_result)->algorithm.ToProto();
+ }
+ }
+ } else {
+ XLA_SCOPED_LOGGING_TIMER_LEVEL("Memcpy Reference Result", 2);
+ std::vector<DeviceMemoryBase> reference_result_buffers(
+ result_buffers.size());
+ for (int i = 0; i < result_buffers.size(); ++i) {
+ TF_ASSIGN_OR_RETURN(
+ reference_result_buffers[i],
+ runtime_arguments.rz_buffers.RedzoneAllocator().AllocateBytes(
+ result_buffers[i].size()));
+ TF_RETURN_IF_ERROR(stream->Memcpy(&reference_result_buffers[i],
+ result_buffers[i],
+ result_buffers[i].size()));
+ }
+ (*reference_result) = {alg, reference_result_buffers};
+ }
+
+ return result;
+}
+
+absl::StatusOr<AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
+ const HloCustomCallInstruction* instr) {
+ AutotuneCacheKey instruction_info{config_.GetModelStr(), *instr};
+ std::string instr_str(instruction_info.GetHlo());
+ XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+ "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr_str));
+
+ const DebugOptions& debug_options =
+ instr->GetModule()->config().debug_options();
+ const bool crash_on_checking_failure =
+ debug_options.xla_gpu_crash_on_verification_failures();
+
+ std::string blas_version;
+ se::StreamExecutor* stream_exec = config_.GetExecutor();
+ if (auto* blas = stream_exec->AsBlas()) {
+ (void)blas->GetVersion(&blas_version);
+ }
+
+ std::vector<AlgorithmDesc> disabled_algos;
+ TF_ASSIGN_OR_RETURN(
+ AutotuneRuntimeArguments runtime_arguments,
+ AutotuneRuntimeArguments::FromInstruction(instr, config_, debug_options));
+ if (runtime_arguments.canonical_hlo.has_value()) {
+ disabled_algos = GetDisabledConvAlgorithms(
+ GetComputeCapability(stream_exec), GetCudnnVersion(stream_exec),
+ blas_version, runtime_arguments.canonical_hlo.value());
+ }
+
+ const bool cudnn_frontend_enabled =
+ debug_options.xla_gpu_enable_cudnn_frontend();
+ bool allow_tf32 = true;
+ // TODO(b/284371623): Properly set allow_tf32 even if instr==nullptr, which is
+ // the case when running an AOT compiled executable with runtime autotuning.
+ if (instr) {
+ allow_tf32 = absl::c_all_of(
+ instr->precision_config().operand_precision(),
+ [](int precision) { return precision <= PrecisionConfig::HIGH; });
+ }
+ const se::NumericOptions numeric_options{
+ RequireDeterminism(instr->GetModule()->config()), allow_tf32};
+
+ // Use the first supported algorithm as the reference. There isn't a
+ // particular reason to prefer it; any algorithm suffices. Being the
+ // reference doesn't mean the algorithm is considered correct, though.
+ std::optional<ReferenceResult> reference_result;
+
+ TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+ TF_ASSIGN_OR_RETURN(
+ std::vector<GenericConvRunner> runners,
+ GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
+ cudnn_frontend_enabled,
+ /* use_fallback = */ false, numeric_options));
+
+ std::vector<AutotuneResult> profile_results;
+ for (auto& runner_cache : runners) {
+ TF_ASSIGN_OR_RETURN(
+ auto result,
+ AutotuneOneConvRunner(&runner_cache, &reference_result, disabled_algos,
+ instruction_info, runtime_arguments));
+ profile_results.emplace_back(std::move(result));
+ }
+
+ // If any algorithm has worked, we'll skip the fallback algorithms, since
+ // they include some very slow algorithms.
+ if (!reference_result) {
+ LOG(WARNING) << "None of the algorithms provided by cuDNN heuristics "
+ "worked; trying fallback algorithms.";
+ if (runtime_arguments.canonical_hlo.has_value()) {
+ LOG(WARNING) << "Conv: " << runtime_arguments.canonical_hlo.value();
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<GenericConvRunner> fallback_runners,
+ GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
+ cudnn_frontend_enabled,
+ /* use_fallback = */ true, numeric_options));
+
+ for (auto& runner_cache : fallback_runners) {
+ TF_ASSIGN_OR_RETURN(
+ auto result, AutotuneOneConvRunner(&runner_cache, &reference_result,
+ disabled_algos, instruction_info,
+ runtime_arguments));
+ profile_results.emplace_back(std::move(result));
+ }
+ }
+
+ // Log the autotuning result.
+ if (instr) {
+ AutotuningLog log;
+ {
+ ConvInstructionLog instr_log;
+ *instr_log.mutable_instruction() = instr->ToProto();
+ for (int i = 0; i < instr->operand_count(); i++) {
+ *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
+ instr_log.add_operand_addresses(reinterpret_cast<uint64_t>(
+ runtime_arguments.rz_buffers.input_buffers()[i].opaque()));
+ }
+ for (se::DeviceMemoryBase result_buffer :
+ runtime_arguments.rz_buffers.output_buffers()) {
+ instr_log.add_result_addresses(
+ reinterpret_cast<uint64_t>(result_buffer.opaque()));
+ }
+ log.mutable_instr()->PackFrom(instr_log);
+ }
+ for (const auto& profile : profile_results) {
+ *log.add_results() = profile;
+ }
+ *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
+ *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
+ log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
+ log.set_blas_version(blas_version);
+ VLOG(2) << "Autotuning result: " << log.ShortDebugString();
+ // If we crash on checking failure, we are in a testing/benchmark mode.
+ if (crash_on_checking_failure) {
+ // Crash on miscompares and redzone violations if desired.
+ for (const auto& profile : profile_results) {
+ if (profile.has_failure() &&
+ profile.failure().kind() != AutotuneResult::DISQUALIFIED) {
+ LOG(FATAL) << "crash_on_checking_failure encountered errors:\n\n"
+ << log.DebugString(); // NOLINT
+ }
+ }
+ }
+ }
+
+ TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
+ PickBestResult(profile_results, instr_str,
+ runtime_arguments.hlo_module_config));
+ return selected_algorithm;
+}
+#endif
+
+absl::StatusOr<AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
+ const HloCustomCallInstruction* instr) {
+ XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+ "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
+
+ const bool allow_tf32 = absl::c_all_of(
+ instr->precision_config().operand_precision(),
+ [](int precision) { return precision <= PrecisionConfig::HIGH; });
+ const se::NumericOptions numeric_options{
+ RequireDeterminism(instr->GetModule()->config()), allow_tf32};
+
+ se::StreamExecutor* stream_exec = config_.GetExecutor();
+ const auto device_ordinal = stream_exec->device_ordinal();
+ std::vector<se::DeviceMemoryBase> operand_buffers;
+
+ // `allocator` comes from the autotune config; if no allocator was supplied
+ // there, it is a se::StreamExecutorMemoryAllocator for stream_exec.
+ se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
+ ScratchAllocator input_output_allocator(device_ordinal, allocator);
+ TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+ const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
+ // Although we don't have evidence this matters, zero out the buffers
+ // before autotuning. It's conceivable that using uninitialized memory as
+ // the inputs might affect performance if e.g. the inputs contain
+ // denormals, and this is easy enough.
+ return stream->MemZero(&buffer, buffer.size());
+ };
+
+ // Allocate space for the input, filter, and output of the convolution. We
+ // use a ScratchAllocator for this instead of calling allocator_ directly so
+ // that our allocations don't leak.
+ for (const auto* operand : instr->operands()) {
+ TF_ASSIGN_OR_RETURN(auto buffer,
+ input_output_allocator.AllocateBytes(
+ ShapeUtil::ByteSizeOf(operand->shape())));
+ TF_RETURN_IF_ERROR(initialize_buffer(buffer));
+ operand_buffers.push_back(buffer);
+ }
+
+ std::vector<se::DeviceMemoryBase> result_buffers(
+ instr->shape().tuple_shapes_size());
+ if (instr->shape().IsTuple()) {
+ for (int i = 0; i < instr->shape().tuple_shapes_size(); ++i) {
+ TF_ASSIGN_OR_RETURN(
+ result_buffers[i],
+ input_output_allocator.AllocateBytes(
+ ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(i))));
+ TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[i]));
+ }
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ result_buffers[0],
+ input_output_allocator.AllocateBytes(
+ ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
+ TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[0]));
+ }
+
+ ScratchAllocator scratch_allocator(device_ordinal, allocator);
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners,
+ GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers),
+ absl::MakeSpan(result_buffers), stream_exec,
+ &scratch_allocator, stream, numeric_options));
+
+ std::vector<AutotuneResult> profile_results;
+
+ if (runners.size() == 1) {
+ TF_ASSIGN_OR_RETURN(auto alg, runners[0]->ToAlgorithmDesc());
+ auto algorithm_proto = alg.ToProto();
+ profile_results.emplace_back();
+ auto& result = profile_results.back();
+ *result.mutable_algorithm() = algorithm_proto;
+
+ result.set_scratch_bytes(runners[0]->GetWorkspaceSize());
+
+ // TODO(awpr): if the profile result time for a singleton algorithm is
+ // needed, plumb it via OpRunner; we'll need to do this to let TF ops avoid
+ // re-profiling ROCm algorithms anyway.
+ *result.mutable_run_time() =
+ tsl::proto_utils::ToDurationProto(absl::Milliseconds(-1));
+ } else {
+ TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
+ for (auto& runner : runners) {
+ TF_ASSIGN_OR_RETURN(auto alg, runner->ToAlgorithmDesc());
+ XLA_SCOPED_LOGGING_TIMER_LEVEL(
+ absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
+ alg.ToString()),
+ 2);
+
+ se::dnn::ProfileResult profile_result;
+ VLOG(4) << "Trying algorithm " << alg.ToString() << " for "
+ << instr->ToString();
+
+ TF_ASSIGN_OR_RETURN(
+ DeviceMemoryBase scratch_memory,
+ scratch_allocator.AllocateBytes(runner->GetWorkspaceSize()));
+
+ TF_ASSIGN_OR_RETURN(auto lazy_runner,
+ se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
+ std::move(runner)));
+
+ GenericConvRunner runner_cache(std::move(lazy_runner));
+
+ // Use assignment instead of brace-list to make GCC 4.9 happy.
+ RunConvOptions options;
+ options.profile_result = &profile_result;
+ options.runner_cache = &runner_cache;
+ absl::Status launch_status =
+ RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffers,
+ scratch_memory, stream, options);
+
+ if (!launch_status.ok()) {
+ continue;
+ }
+
+ if (!profile_result.is_valid()) {
+ continue;
+ }
+
+ profile_results.emplace_back();
+ AutotuneResult& result = profile_results.back();
+ *result.mutable_algorithm() = alg.ToProto();
+
+ int64_t scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
+ result.set_scratch_bytes(scratch_bytes_used);
+ *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
+ absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+ }
+ }
+
+ TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
+ PickBestResult(profile_results, instr->ToString(),
+ instr->GetModule()->config()));
+ return selected_algorithm;
+}
+
+absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(
+ HloInstruction* instr) {
+ CHECK(IsCustomCallToDnnConvolution(*instr));
+
+ const bool strict = instr->parent()
+ ->parent()
+ ->config()
+ .debug_options()
+ .xla_gpu_strict_conv_algorithm_picker();
+
+ absl::StatusOr<AutotuneResult> best_algo_or =
+ PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
+ if (!best_algo_or.ok()) {
+ auto msg = absl::StrFormat(
+ "Failed to determine best cudnn convolution algorithm for:\n%s\n\n"
+ "Original error: %s",
+ instr->ToString(), best_algo_or.status().ToString());
+
+ if (strict) {
+ return Unknown(
+ "%s\n\nTo ignore this failure and try to use a fallback algorithm "
+ "(which may have suboptimal performance), use "
+ "XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please "
+ "also file a bug for the root cause of failing autotuning.",
+ msg);
+ }
+ LOG(WARNING)
+ << msg << "\n\nAs a result, convolution performance may be suboptimal.";
+ return false;
+ }
+
+ auto best_algo = std::move(best_algo_or).value();
+ VLOG(3) << "Setting cudnn conv to use algorithm "
+ << best_algo.conv().algorithm() << " and "
+ << NumBytesToString(best_algo.scratch_bytes())
+ << " of scratch memory: " << instr->ToString()
+ << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();
+
+ // Replace instr with a new CustomCall which has the correct algorithm, and
+ // whose output shape has the appropriate amount of scratch memory.
+ HloComputation* computation = instr->parent();
+ std::vector<Shape> new_call_element_shapes;
+ // Add the shapes of the outputs of the convolution.
+ new_call_element_shapes.reserve(instr->shape().tuple_shapes_size() - 1);
+ for (int i = 0; i < instr->shape().tuple_shapes_size() - 1; ++i) {
+ new_call_element_shapes.emplace_back(instr->shape().tuple_shapes(i));
+ }
+ // The final element is the size of the workspace.
+ new_call_element_shapes.emplace_back(
+ ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()}));
+ Shape new_call_shape = ShapeUtil::MakeTupleShape(new_call_element_shapes);
+
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
+ instr->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& backend_config =
+ *gpu_backend_config.mutable_cudnn_conv_backend_config();
+ *backend_config.mutable_algorithm() = best_algo.algorithm();
+ backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
+ best_algo.scratch_bytes());
+
+ HloInstruction* new_call = computation->AddInstruction(
+ instr->CloneWithNewOperands(new_call_shape, instr->operands()));
+
+ // Preserve the name of the old instruction. This is safe because we're going
+ // to remove the old one anyway, and it makes it easier to trace how our conv
+ // is transformed through all our passes.
+ new_call->SetAndSanitizeName(instr->name());
+
+ VLOG(3) << "Replacing convolution " << instr->ToString() << " with "
+ << new_call->ToString();
+
+ TF_RETURN_IF_ERROR(new_call->set_backend_config(gpu_backend_config));
+
+ std::vector<HloInstruction*> new_tuple_elements;
+ new_tuple_elements.reserve(new_call->shape().tuple_shapes_size() - 1);
+ for (int i = 0; i < new_call->shape().tuple_shapes_size() - 1; ++i) {
+ new_tuple_elements.emplace_back(
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_call->shape().tuple_shapes(i), new_call, i)));
+ }
+ new_tuple_elements.emplace_back(computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({}))));
+
+ // Repackage new_call so it has the same shape as the original call, namely
+ // (conv_result, u8[0]).
+ HloInstruction* new_tuple = computation->AddInstruction(
+ HloInstruction::CreateTuple(new_tuple_elements));
+
+ TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
+ return true;
+}
+
+absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
+ HloComputation* computation) {
+ std::vector<HloInstruction*> convs;
+ for (HloInstruction* instr : computation->instructions()) {
+ if (IsCandidate(instr)) {
+ convs.push_back(instr);
+ }
+ }
+
+ bool changed = false;
+ for (HloInstruction* instr : convs) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
+ changed |= result;
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> GpuConvAlgorithmPicker::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_SCOPED_LOGGING_TIMER(
+ absl::StrCat("GpuConvAlgorithmPicker for ", module->name()));
+
+ if (!IsEnabled(module)) {
+ VLOG(3) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
+ "returning early.";
+ return false;
+ }
+
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h
new file mode 100644
index 0000000..173a0c6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h
@@ -0,0 +1,153 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/gpu_conv_runner.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// Choose the fastest algorithm for each conv.
+// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
+// each and adding explicit scratch space to the CustomCalls.
+//
+// We pick the algorithm before fusion so that we can generate better HLO. After
+// GpuConvRewriter, our convolutions are CustomCalls which return a
+// tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
+// scratch:
+//
+// customcall = (f32[...], f32[0])
+// return gte(customcall, 0)
+//
+// The algorithm picker then chooses the best algorithm, and potentially
+// increases the scratch space. It replaces customcall with new_tuple,
+// giving us the following:
+//
+// new_customcall = (f32[...], f32[N])
+// new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
+// return gte(new_tuple, 0)
+//
+// The new tuple and gte instructions can be simplified away, because
+// nobody is expected to use the scratch value.
+//
+// However, if we were to run GpuConvAlgorithmPicker after fusion, the
+// gte(customcall, 0) would probably already be inside a fusion node. We
+// can't simplify across HloComputation boundaries, so in that case we
+// wouldn't be able to simplify away the new_tuple bits.
+//
+// It supports two modes: device and deviceless.
+// In device mode, we run autotuning on the device and store autotune results.
+//
+// In deviceless mode, we pass in some information related to the device and
+// use stored autotune results to rewrite convolutions. If the required autotune
+// result is not stored, then the performance of convolution will be suboptimal.
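+//
+// A minimal usage sketch (assuming an existing AutotuneConfig and an
+// HloPassPipeline; the pipeline name here is illustrative):
+//
+//   HloPassPipeline pipeline("conv-algorithm-picking");
+//   pipeline.AddPass<GpuConvAlgorithmPicker>(autotune_config);
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
+//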
+class GpuConvAlgorithmPicker : public HloModulePass {
+ public:
+ explicit GpuConvAlgorithmPicker(AutotuneConfig config) : config_(config) {}
+
+ absl::string_view name() const override {
+ return "gpu-conv-algorithm-picker";
+ }
+
+ static bool IsEnabled(const HloModule* module) {
+ return module->config().debug_options().xla_gpu_autotune_level() != 0;
+ }
+
+ static bool IsCandidate(const HloInstruction* instr) {
+ return IsCustomCallToDnnConvolution(*instr);
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+ absl::StatusOr<bool> RunOnInstruction(HloInstruction* instr);
+
+ absl::StatusOr<AutotuneResult> PickBestAlgorithm(
+ const HloCustomCallInstruction* instr);
+ absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCache(
+ const HloCustomCallInstruction* instr);
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+ // Simple bundle of an algorithm and its output, for comparing results across
+ // autotuned algorithms.
+ struct ReferenceResult {
+ stream_executor::dnn::AlgorithmDesc algorithm;
+ std::vector<stream_executor::DeviceMemoryBase> buffers;
+ };
+
+ // Execution environment for autotuning. Runtime autotuning requires runtime
+ // information such as input/output buffers in order to run. It can be
+ // constructed from the autotuned instruction by FromInstruction.
+ struct AutotuneRuntimeArguments {
+ const HloModuleConfig hlo_module_config;
+ RedzoneBuffers rz_buffers;
+ const GpuConvConfig gpu_conv_config;
+ std::optional<std::string> canonical_hlo;
+
+ static absl::StatusOr<AutotuneRuntimeArguments> FromInstruction(
+ const HloCustomCallInstruction* instr, const AutotuneConfig& config,
+ const DebugOptions& debug_options);
+ };
+
+ absl::StatusOr<AutotuneResult> AutotuneOneConvRunner(
+ GenericConvRunner* runner,
+ std::optional<ReferenceResult>* reference_result,
+ absl::Span<const stream_executor::dnn::AlgorithmDesc> disabled_algos,
+ std::optional<AutotuneCacheKey> instruction_info,
+ const AutotuneRuntimeArguments& runtime_arguments);
+
+ // Pick the best algorithm for CUDA platform.
+ absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheCuda(
+ const HloCustomCallInstruction* instr);
+#endif
+
+ absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheRocm(
+ const HloCustomCallInstruction* instr);
+
+ private:
+ AutotuneConfig config_;
+};
+
+} // namespace gpu
+} // namespace xla
+#endif // XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc
new file mode 100644
index 0000000..9652014
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+
+#include <cstdint>
+#include <variant>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/tuple_simplifier.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GpuConvAlgorithmPickerTest : public HloTestBase {
+ public:
+ GpuConvAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
+};
+
+TEST_F(GpuConvAlgorithmPickerTest, SetAlgorithm) {
+ constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+ %arg0 = f32[3,56,56,16]{2,1,0,3} parameter(0)
+ %arg1 = f32[3,3,3,64]{2,1,0,3} parameter(1)
+ ROOT %conv = f32[54,54,16,64]{1,0,3,2} convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo));
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+ PlatformUtil::GetStreamExecutors(platform));
+ ASSERT_GT(executors.size(), 0);
+ se::StreamExecutor* stream_exec = executors[0];
+
+ const se::GpuComputeCapability& cc = backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .gpu_compute_capability();
+ bool changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get()));
+ changed = false;
+ DebugOptions opts = DefaultDebugOptionsIgnoringFlags();
+
+ AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts};
+ TF_ASSERT_OK_AND_ASSIGN(changed,
+ RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
+ ASSERT_TRUE(changed);
+
+ AutotuneResults results;
+ TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
+ ASSERT_EQ(results.results_size(), 1);
+ auto& result = *results.mutable_results(0)->mutable_result();
+ int64_t old_scratch_bytes = result.scratch_bytes();
+ int64_t new_scratch_bytes = old_scratch_bytes + 1;
+ result.set_scratch_bytes(new_scratch_bytes);
+
+ AutotunerUtil::ClearAutotuneResults();
+ TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
+
+ // Now send the same module through GpuConvAlgorithmPicker again. The conv
+ // should have the new scratch bytes.
+ TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo));
+ changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get()));
+ changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(changed,
+ RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
+ ASSERT_TRUE(changed);
+
+ // TupleSimplifier cleans this up a bit before we pattern-match below.
+ TF_ASSERT_OK(RunHloPass(TupleSimplifier(), m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ HloInstruction* conv;
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall(&conv))));
+ EXPECT_THAT(
+ conv->shape(),
+ GmockMatch(m::Shape().WithSubshape(
+ {1}, m::Shape().WithElementType(U8).WithDims({new_scratch_bytes}))));
+
+ // Algorithm 14 is disabled for cuDNN 9.x on V100 (compute capability 7.0).
+ TF_ASSERT_OK_AND_ASSIGN(auto dnn_version, GetDnnVersionInfo(stream_exec));
+ if (dnn_version.major_version() >= 9 && dnn_version.major_version() < 10 &&
+ std::holds_alternative<stream_executor::CudaComputeCapability>(cc) &&
+ std::get<stream_executor::CudaComputeCapability>(cc).major == 7 &&
+ std::get<stream_executor::CudaComputeCapability>(cc).minor == 0) {
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->has_cudnn_conv_backend_config() &&
+ conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .algorithm()
+ .algo_id() != 14);
+ }
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc
new file mode 100644
index 0000000..a920b99
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc
@@ -0,0 +1,220 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <tuple>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/kernels/custom_kernel.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+absl::StatusOr<std::unique_ptr<HloModule>> ExtractFusionModule(
+ HloInstruction* fusion_instruction, int64_t kernel_index) {
+ std::unique_ptr<HloModule> hlo_module =
+ ExtractInstructionIntoNewModule(*fusion_instruction);
+
+ HloInstruction* instruction =
+ hlo_module->entry_computation()->root_instruction();
+ GpuBackendConfig gpu_config =
+ instruction->backend_config<GpuBackendConfig>().value();
+ gpu_config.mutable_fusion_backend_config()
+ ->mutable_custom_fusion_config()
+ ->set_kernel_index(kernel_index);
+ TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
+
+ return hlo_module;
+}
+
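+// Compiles the fusion once per candidate kernel index and measures the
+// execution time of each candidate, returning (kernel index, duration) pairs.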
+absl::StatusOr<std::vector<std::tuple<int, absl::Duration>>> ProfileKernels(
+ std::vector<CustomKernel>& kernels, HloInstruction* fusion_instruction,
+ AutotunerCompileUtil& compile_util, const AutotuneConfig& autotune_config,
+ const DebugOptions& debug_options) {
+ se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
+ std::vector<std::tuple<int, absl::Duration>> results;
+ for (int i = 0; i < kernels.size(); ++i) {
+ TF_ASSIGN_OR_RETURN(absl::StatusOr<std::unique_ptr<Executable>> executable,
+ compile_util.Compile([&](const DebugOptions& opt) {
+ return ExtractFusionModule(fusion_instruction, i);
+ }));
+
+ se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator();
+ std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
+ if (allocator == nullptr) {
+ owned_allocator =
+ std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
+ allocator = owned_allocator.get();
+ }
+ TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream());
+
+ TF_ASSIGN_OR_RETURN(auto rz_buffers,
+ RedzoneBuffers::FromInstruction(
+ *fusion_instruction, autotune_config, debug_options,
+ RedzoneBuffers::kAllInputs));
+
+ std::optional<ScopedShapedBuffer> reference_buffer;
+ std::optional<AutotunerCompileUtil::ProfilingOutput> profiling_output;
+ TF_ASSIGN_OR_RETURN(profiling_output, compile_util.ProfileExecutable(
+ executable->get(), stream,
+ rz_buffers.input_buffers(),
+ rz_buffers.input_shapes()));
+ results.push_back({i, profiling_output->duration});
+ }
+ return results;
+}
+
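+// Returns the kernel index with the smallest measured duration.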
+absl::StatusOr<int> FindFastestKernel(
+ const std::vector<std::tuple<int, absl::Duration>>& results) {
+ auto iter = absl::c_min_element(
+ results, [](const std::tuple<int, absl::Duration>& lhs,
+ const std::tuple<int, absl::Duration>& rhs) {
+ return std::get<1>(lhs) < std::get<1>(rhs);
+ });
+ if (iter == results.end()) {
+ return absl::InternalError("Failed to find fastest kernel.");
+ }
+ return std::get<0>(*iter);
+}
+
+absl::Status UpdateFusionInstructionKernelIndex(
+ HloInstruction* fusion_instruction, int kernel_index) {
+ GpuBackendConfig gpu_config =
+ fusion_instruction->backend_config<GpuBackendConfig>().value();
+ gpu_config.mutable_fusion_backend_config()
+ ->mutable_custom_fusion_config()
+ ->set_kernel_index(kernel_index);
+ TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config));
+
+ return absl::OkStatus();
+}
+
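+// Looks up the custom kernel fusion named in the backend config and loads all
+// candidate kernels it provides for the current device.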
+absl::StatusOr<std::vector<CustomKernel>> LoadKernels(
+ const HloInstruction* fusion_instruction,
+ const AutotuneConfig& autotune_config) {
+ auto config = fusion_instruction->backend_config<GpuBackendConfig>()
+ ->fusion_backend_config()
+ .custom_fusion_config();
+ auto* registry = CustomKernelFusionRegistry::Default();
+ auto* custom_kernel_fusion = registry->Lookup(config.name());
+
+ // If the custom fusion is not found, it means that some of its build targets
+ // might not be statically linked into the binary.
+ if (custom_kernel_fusion == nullptr) {
+ return absl::InternalError(
+ absl::StrCat("Custom kernel fusion ", config.name(),
+ " not found in a default registry."));
+ }
+
+ se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
+ if (!stream_exec->SynchronizeAllActivity()) {
+ return Internal("Failed to synchronize GPU for autotuning.");
+ }
+ se::DeviceDescription device_description =
+ stream_exec->GetDeviceDescription();
+
+ // Load custom kernels that can implement a fusion computation.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<CustomKernel> kernels,
+ custom_kernel_fusion->LoadKernels(
+ device_description,
+ fusion_instruction->fused_instructions_computation()));
+
+ return kernels;
+}
+
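+// Profiles all candidate kernels for `fusion_instruction`, writes the fastest
+// kernel index into its backend config, and returns whether the index changed.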
+absl::StatusOr<bool> AutotuneCustomKernelFusion(
+ HloInstruction* fusion_instruction, const AutotuneConfig& autotune_config,
+ AutotunerCompileUtil& compile_util, const DebugOptions& debug_options) {
+ int previous_kernel_index =
+ fusion_instruction->backend_config<GpuBackendConfig>()
+ ->fusion_backend_config()
+ .custom_fusion_config()
+ .kernel_index();
+
+ TF_ASSIGN_OR_RETURN(std::vector<CustomKernel> kernels,
+ LoadKernels(fusion_instruction, autotune_config));
+
+ std::vector<std::tuple<int, absl::Duration>> results;
+ TF_ASSIGN_OR_RETURN(results,
+ ProfileKernels(kernels, fusion_instruction, compile_util,
+ autotune_config, debug_options));
+
+ TF_ASSIGN_OR_RETURN(int fastest_kernel_index, FindFastestKernel(results));
+
+ TF_RETURN_IF_ERROR(UpdateFusionInstructionKernelIndex(fusion_instruction,
+ fastest_kernel_index));
+
+ return previous_kernel_index != fastest_kernel_index;
+}
+} // namespace
+
+absl::StatusOr<bool> CustomKernelFusionAutotuner::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ const DebugOptions& debug_options = module->config().debug_options();
+ TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> compile_util,
+ AutotunerCompileUtil::Create(config_, debug_options));
+ TF_RET_CHECK(compile_util.has_value());
+
+ bool hlo_changed = false;
+ for (const HloComputation* computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ TF_ASSIGN_OR_RETURN(
+ bool instruction_changed,
+ AutotuneCustomKernelFusion(computation->FusionInstruction(), config_,
+ compile_util.value(), debug_options));
+ if (instruction_changed) {
+ hlo_changed = true;
+ }
+ }
+ }
+
+ return hlo_changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h
new file mode 100644
index 0000000..07aad07
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h
@@ -0,0 +1,53 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/xla.pb.h"
+
+namespace xla {
+namespace gpu {
+
+// Finds the best custom kernel for each custom kernel fusion in the module.
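+//
+// Usage sketch (mirroring the accompanying test; names are illustrative):
+//   AutotuneConfig cfg{DeviceConfig{stream_exec, allocator}, debug_options};
+//   pipeline.AddPass<CustomKernelFusionAutotuner>(cfg);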
+class CustomKernelFusionAutotuner : public HloModulePass {
+ public:
+ explicit CustomKernelFusionAutotuner(const AutotuneConfig& config)
+ : config_(config) {}
+
+ absl::string_view name() const override {
+ return "custom_kernel-fusion-autotuner";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const AutotuneConfig config_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc
new file mode 100644
index 0000000..8defca9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CustomKernelFusionAutotunerTest : public HloTestBase {
+ public:
+ CustomKernelFusionAutotunerTest()
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/true) {}
+
+ void SetUp() override { HloTestBase::SetUp(); }
+
+ void TearDown() override { HloTestBase::TearDown(); }
+};
+
+TEST_F(CustomKernelFusionAutotunerTest,
+ CustomKernelFusionAutotunerPassSucceeds) {
+ const std::string hlo_string = R"(
+ HloModule extracted
+
+ cutlass_gemm {
+ p0 = f32[15,19]{1,0} parameter(0)
+ p1 = f32[19,17]{1,0} parameter(1)
+ ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+ ENTRY region_198.14436 {
+ p.0 = f32[15,19]{1,0} parameter(0)
+ p.1 = f32[19,17]{1,0} parameter(1)
+ ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom, calls=cutlass_gemm, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":0}},"force_earliest_schedule":false}
+ }
+ )";
+ std::unique_ptr<HloModule> hlo_module =
+ ParseAndReturnVerifiedModule(hlo_string).value();
+
+ HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
+ DebugOptions debug_options;
+ AutotuneConfig autotune_config =
+ AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
+ backend().memory_allocator()},
+ debug_options};
+ pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
+ ASSERT_TRUE(pipeline.Run(hlo_module.get()).ok());
+}
+
+TEST_F(CustomKernelFusionAutotunerTest,
+ CustomKernelFusionAutotunerPassUpdatesKernelIndex) {
+ const std::string hlo_string = R"(
+ HloModule extracted
+
+ cutlass_gemm {
+ p0 = f32[15,19]{1,0} parameter(0)
+ p1 = f32[19,17]{1,0} parameter(1)
+ ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+ rhs_contracting_dims={0}
+ }
+
+ ENTRY region_198.14436 {
+ p.0 = f32[15,19]{1,0} parameter(0)
+ p.1 = f32[19,17]{1,0} parameter(1)
+ ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom,
+ calls=cutlass_gemm,
+ backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":-1}},"force_earliest_schedule":false}
+ }
+ )";
+
+ HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
+ DebugOptions debug_options;
+ AutotuneConfig autotune_config =
+ AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
+ backend().memory_allocator()},
+ debug_options};
+ pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
+
+ std::string expected = R"(
+ CHECK: "kernel_index":0
+ )";
+ RunAndFilecheckHloRewrite(hlo_string, std::move(pipeline), expected);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc
new file mode 100644
index 0000000..8a870f8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc
@@ -0,0 +1,500 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/variant_visitor.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/tsl/util/proto/proto_utils.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/profiler/lib/scoped_annotation.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using se::gpu::BlasLt;
+
+absl::StatusOr<BlasLt::Epilogue> AsBlasLtEpilogue(
+ GemmBackendConfig_Epilogue epilogue) {
+ switch (epilogue) {
+ case GemmBackendConfig::DEFAULT:
+ return BlasLt::Epilogue::kDefault;
+ case GemmBackendConfig::RELU:
+ return BlasLt::Epilogue::kReLU;
+ case GemmBackendConfig::GELU:
+ return BlasLt::Epilogue::kGELU;
+ case GemmBackendConfig::GELU_AUX:
+ return BlasLt::Epilogue::kGELUWithAux;
+ case GemmBackendConfig::BIAS:
+ return BlasLt::Epilogue::kBias;
+ case GemmBackendConfig::BIAS_RELU:
+ return BlasLt::Epilogue::kBiasThenReLU;
+ case GemmBackendConfig::BIAS_GELU:
+ return BlasLt::Epilogue::kBiasThenGELU;
+ case GemmBackendConfig::BIAS_GELU_AUX:
+ return BlasLt::Epilogue::kBiasThenGELUWithAux;
+ default:
+ return Internal("Unsupported Epilogue.");
+ }
+}
+
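+// Autotunes a single GEMM instruction, dispatching to cuBLASLt
+// (TuneGpuBlasLt) or legacy cuBLAS (TuneGpuBlas) depending on the custom-call
+// target, and returns the best AutotuneResult found.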
+class GemmAutotuner {
+ const AutotuneConfig& autotune_config_;
+ RedzoneBuffers rz_buffers_;
+ se::Stream* stream_ = nullptr;
+ bool deterministic_ops_ = false;
+ size_t solutions_limit_ = 0;
+ size_t num_algorithms_left_ = 0;
+
+ public:
+ explicit GemmAutotuner(const AutotuneConfig& autotune_config)
+ : autotune_config_(autotune_config) {}
+
+ size_t num_algorithms_left() const { return num_algorithms_left_; }
+
+ absl::StatusOr<AutotuneResult> operator()(const HloInstruction* gemm,
+ const AutotuneCacheKey& key) {
+ num_algorithms_left_ = 0;
+ if (autotune_config_.IsDeviceless()) {
+ // Return empty result, will tune at runtime.
+ return AutotuneResult{};
+ }
+ VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
+
+ TF_ASSIGN_OR_RETURN(stream_, autotune_config_.GetStream());
+ const DebugOptions& debug_options =
+ gemm->GetModule()->config().debug_options();
+ deterministic_ops_ = RequireDeterminism(gemm->GetModule()->config());
+ solutions_limit_ = debug_options.xla_gpu_autotune_max_solutions();
+
+ TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm));
+
+ // Don't run autotuning concurrently on the same GPU.
+ absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent()));
+
+ TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromInstruction(
+ *gemm, autotune_config_, debug_options,
+ RedzoneBuffers::kAllInputsAllOutputs));
+
+ return IsCublasLtMatmul(*gemm) || IsCublasLtMatmulF8(*gemm)
+ ? TuneGpuBlasLt(gemm, gemm_config)
+ : TuneGpuBlas(gemm, gemm_config);
+ }
+
+ private:
+ se::DeviceMemoryBase LhsBuffer() { return rz_buffers_.input_buffers().at(0); }
+ se::DeviceMemoryBase RhsBuffer() { return rz_buffers_.input_buffers().at(1); }
+ se::DeviceMemoryBase OutputBuffer() {
+ return rz_buffers_.output_buffers().at(0);
+ }
+
+ const Shape& GetOutputShape(const HloInstruction* gemm) {
+ return gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0)
+ : gemm->shape();
+ }
+
+ absl::StatusOr<AutotuneResult> TuneGpuBlasLt(const HloInstruction* gemm,
+ const GemmConfig& gemm_config) {
+ auto workspace_buffer =
+ rz_buffers_.output_buffers().at(gemm->shape().tuple_shapes_size() - 1);
+
+ GpuBackendConfig gpu_config =
+ gemm->backend_config<GpuBackendConfig>().value();
+ const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config();
+
+ bool has_matrix_bias = gemm_config.beta != 0.;
+
+ TF_ASSIGN_OR_RETURN(
+ bool has_vector_bias,
+ gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue()));
+
+ TF_ASSIGN_OR_RETURN(
+ bool has_aux_output,
+ gpublas_lt::EpilogueHasAuxiliaryOutput(backend_config.epilogue()));
+
+ TF_ASSIGN_OR_RETURN(auto epilogue,
+ AsBlasLtEpilogue(backend_config.epilogue()));
+
+ se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer,
+ d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer;
+
+ if (has_vector_bias) {
+ bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2);
+ }
+ if (has_aux_output) {
+ aux_buffer = rz_buffers_.output_buffers().at(1);
+ }
+
+ TF_ASSIGN_OR_RETURN(auto plan,
+ BlasLt::GetMatmulPlan(stream_, gemm_config, epilogue));
+
+ TF_ASSIGN_OR_RETURN(
+ auto algorithms,
+ plan->GetAlgorithms(/*max_algorithm_count*/ 128,
+ /*max_workspace_size*/ workspace_buffer.size()));
+
+ auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm)
+ -> absl::StatusOr<se::blas::ProfileResult> {
+ // Run a warmup iteration without the profiler active.
+ TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
+ stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
+ bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
+ c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
+ workspace_buffer));
+ se::blas::ProfileResult profile_result;
+ profile_result.set_warmup_run_executed(true);
+ TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
+ stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
+ bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
+ c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
+ workspace_buffer, &profile_result));
+ return std::move(profile_result);
+ };
+
+ return GetBestAlgorithm<BlasLt::MatmulAlgorithm>(
+ gemm, algorithms, gemm_config.beta, /*return_algo_index*/ true,
+ tuned_func);
+ }
+
+ absl::StatusOr<AutotuneResult> TuneGpuBlas(const HloInstruction* gemm,
+ const GemmConfig& gemm_config) {
+ auto workspace_buffer = rz_buffers_.output_buffers().at(1);
+
+ std::vector<se::blas::AlgorithmType> algorithms;
+ TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc,
+ gemm_config.GetMatrixDescriptors(
+ LhsBuffer(), RhsBuffer(), OutputBuffer()));
+
+ auto blas = stream_->parent()->AsBlas();
+ if (blas == nullptr) {
+ return absl::InternalError("No BLAS support for stream");
+ }
+ blas->GetBlasGemmAlgorithms(stream_, desc.lhs, desc.rhs, &desc.output,
+ &gemm_config.alpha, &gemm_config.beta,
+ &algorithms);
+
+ auto tuned_func = [&](const se::blas::AlgorithmType& algorithm)
+ -> absl::StatusOr<se::blas::ProfileResult> {
+ // Do a warm-up run first, without a profile result. RunGemm swallows
+ // error codes when profile_result is passed, as it is in the measurement
+ // below, but not otherwise. It is, therefore, consistent to ignore the
+ // error code here.
+ static_cast<void>(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
+ OutputBuffer(), workspace_buffer,
+ deterministic_ops_, stream_, algorithm));
+ se::blas::ProfileResult profile_result;
+ // Allow GpuTimer to use its delay kernel implementation to improve
+ // accuracy.
+ profile_result.set_warmup_run_executed(true);
+ // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
+ // for all algorithms if we're targeting < sm_50. But because we pass a
+ // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
+ // and the actual success-ness is returned in ProfileResult::is_valid.
+ TF_RETURN_IF_ERROR(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
+ OutputBuffer(), workspace_buffer,
+ deterministic_ops_, stream_, algorithm,
+ &profile_result));
+ return std::move(profile_result);
+ };
+
+ return GetBestAlgorithm<se::blas::AlgorithmType>(
+ gemm, algorithms, gemm_config.beta, /*return_algo_index*/ false,
+ tuned_func);
+ }
+
+ // Returns the index (into `algorithms`) of the fastest algorithm.
+ template <typename AlgoT, typename TunedFunc>
+ absl::StatusOr<AutotuneResult> GetBestAlgorithm(
+ const HloInstruction* gemm, absl::Span<const AlgoT> algorithms,
+ double beta, bool return_algo_index, TunedFunc&& run_benchmark) {
+ static_assert(std::is_invocable_r_v<absl::StatusOr<se::blas::ProfileResult>,
+ TunedFunc, const AlgoT&>,
+ "Tuned function has incorrect prototype!");
+
+ if (!stream_->parent()->SynchronizeAllActivity()) {
+ return Internal("Failed to synchronize GPU for autotuning.");
+ }
+ tsl::profiler::ScopedAnnotation annotation([&] {
+ return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
+ gemm->name());
+ });
+
+ auto& hlo_module_config = gemm->GetModule()->mutable_config();
+ const auto& output_shape = GetOutputShape(gemm);
+
+ se::DeviceMemoryBase reference_buffer;
+ if (autotune_config_.should_check_correctness()) {
+ TF_ASSIGN_OR_RETURN(reference_buffer,
+ rz_buffers_.RedzoneAllocator().AllocateBytes(
+ ShapeUtil::ByteSizeOf(output_shape)));
+ }
+
+ // Do not print error messages if should_skip_wrong_results() is ON.
+ BufferComparator comparator(
+ output_shape,
+ hlo_module_config.debug_options().xla_gpu_autotune_gemm_rtol(),
+ /* verbose */ !autotune_config_.should_skip_wrong_results());
+ std::vector<AutotuneResult> results;
+ results.reserve(algorithms.size());
+ std::optional<int64_t> reference_algorithm;
+
+ auto num = algorithms.size();
+ if (solutions_limit_ > 0) num = std::min(num, solutions_limit_);
+ for (size_t i = 0; i < num; i++) {
+ const AlgoT& algorithm = algorithms[i];
+ // Make sure the output buffer always has the same value if we use
+ // the bias parameter.
+ if (autotune_config_.should_reinit_output_buffer() && beta != 0) {
+ int64_t rng_state = 0;
+ InitializeBuffer(stream_, output_shape.element_type(), &rng_state,
+ OutputBuffer());
+ }
+ TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm));
+
+ AutotuneResult& result = results.emplace_back();
+ result.mutable_gemm()->set_algorithm(profile_result.algorithm());
+
+ if (!profile_result.is_valid()) { // Unsupported algorithm.
+ result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
+ continue;
+ }
+
+ VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took "
+ << profile_result.elapsed_time_in_ms() << "ms";
+
+ *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
+ absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
+ if (!autotune_config_.should_check_correctness()) {
+ num_algorithms_left_++;
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(
+ se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
+ rz_buffers_.RedzoneAllocator().CheckRedzones());
+
+ if (!rz_check_status.ok()) {
+ result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
+ *result.mutable_failure()->mutable_msg() =
+ rz_check_status.RedzoneFailureMsg();
+ LOG(ERROR) << "Detected out-of-bounds write in gemm buffer";
+ CHECK(!autotune_config_.should_crash_on_check_failure());
+ continue;
+ }
+
+ num_algorithms_left_++;
+ if (!reference_algorithm) {
+ TF_RETURN_IF_ERROR(stream_->Memcpy(&reference_buffer, OutputBuffer(),
+ OutputBuffer().size()));
+ reference_algorithm = profile_result.algorithm();
+ continue;
+ }
+ // Perform the comparison versus the reference algorithm.
+ TF_ASSIGN_OR_RETURN(
+ bool outputs_match,
+ comparator.CompareEqual(stream_, /*current=*/OutputBuffer(),
+ /*expected=*/reference_buffer));
+ if (!outputs_match) {
+ LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
+ << "This is likely a bug/unexpected loss of precision.";
+ CHECK(!autotune_config_.should_crash_on_check_failure());
+
+ // By default, the autotuner does NOT actually skip wrong results, but
+ // merely prints the above error message, which can be quite confusing.
+ // When should_skip_wrong_results() is set to true, solutions with
+ // accuracy problems are disqualified.
+ auto kind = AutotuneResult::WRONG_RESULT;
+ if (autotune_config_.should_skip_wrong_results()) {
+ kind = AutotuneResult::DISQUALIFIED;
+ num_algorithms_left_--; // Decrement again since we disqualified it.
+ }
+ result.mutable_failure()->set_kind(kind);
+ result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
+ *reference_algorithm);
+ }
+ } // for algorithms
+
+ absl::StatusOr<AutotuneResult> best =
+ PickBestResult(results, gemm->ToString(), hlo_module_config);
+ if (best.ok()) {
+ // Note that cublas-lt returns an opaque object as an algorithm ID,
+ // therefore we need to convert it to the index from the algorithms list
+ // (otherwise, we cannot store this ID inside a gemm_backend_config).
+ // In contrast, legacy cublas returns a 32-bit integer algorithm ID which
+ // can be readily stored inside an HLO (hence return_algo_index is false
+ // for cublas case).
+ if (!return_algo_index) return best;
+ // Otherwise, map a real algorithm ID to its index among the results.
+ for (size_t i = 0; i < results.size(); ++i) {
+ if (best->gemm().algorithm() == results[i].gemm().algorithm()) {
+ best->mutable_gemm()->set_algorithm(i);
+ return best;
+ }
+ }
+ return Internal("unknown best algorithm");
+ }
+ LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance "
+ "might be suboptimal: "
+ << best.status();
+ return AutotuneResult{};
+ } // GetBestAlgorithm
+}; // GemmAutotuner
+
+// Autotunes the given cuBLAS GEMM instruction. Results are taken from the
+// autotune cache when present; on a cache miss the GemmAutotuner is run on
+// the device (or returns an empty result in deviceless mode).
+absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
+ const AutotuneConfig& config,
+ size_t* num_algorithms_left) {
+ VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString();
+
+ GpuBackendConfig gpu_config =
+ gemm->backend_config<GpuBackendConfig>().value();
+ GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config();
+
+ *num_algorithms_left = 0;
+ // Degenerate gemms are replaced with a memzero operation; no need to autotune them.
+ if (backend_config.alpha_real() == 0.0 &&
+ backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) {
+ VLOG(3) << "Skip degenerate gemm instruction auto tuning";
+ return false;
+ }
+
+ AutotuneCacheKey key(config.GetModelStr(), *gemm);
+ GemmAutotuner autotuner(config);
+ TF_ASSIGN_OR_RETURN(AutotuneResult algorithm,
+ AutotunerUtil::Autotune(
+ gemm, config, [&] { return autotuner(gemm, key); }));
+
+ *num_algorithms_left = autotuner.num_algorithms_left();
+ auto old_algorithm = backend_config.selected_algorithm();
+ bool update_algorithm =
+ IsCublasLtMatmulF8(*gemm) ||
+ std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) {
+ // We only set the 'algorithm' field on
+ // non-Ampere architectures, as for Ampere
+ // it's ignored in any case.
+ return !cc.IsAtLeast(
+ se::CudaComputeCapability::AMPERE);
+ },
+ [](const se::RocmComputeCapability&) {
+ return true; // TODO: not decided yet
+ }},
+ config.GetGpuComputeCapability());
+
+ if (update_algorithm) {
+ int64_t new_algorithm{};
+ if (algorithm.has_gemm()) {
+ new_algorithm = algorithm.gemm().algorithm();
+ } else {
+ // NOTE: runtime autotuning is no longer available => set to default
+ new_algorithm = se::blas::kDefaultAlgorithm;
+ }
+
+ if (new_algorithm == old_algorithm &&
+ backend_config.has_selected_algorithm()) {
+ // We don't need to update the backend config if
+ // the algorithm hasn't changed unless previously
+ // the algorithm wasn't set explicitly.
+ return false;
+ }
+
+ backend_config.set_selected_algorithm(new_algorithm);
+ TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config));
+ return true; // We changed `gemm`
+ }
+
+ return false; // No change to `gemm`
+}
+
+absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
+ AutotuneConfig config,
+ size_t* num_algorithms_left) {
+ bool changed = false;
+
+ for (HloInstruction* instr : computation->instructions()) {
+ if (IsCublasGemm(*instr)) {
+ size_t num_left;
+ TF_ASSIGN_OR_RETURN(bool result,
+ RunOnInstruction(instr, config, &num_left));
+ // Gathering statistics on the algorithms left after tuning (for testing)
+ *num_algorithms_left = std::max(*num_algorithms_left, num_left);
+ changed |= result;
+ }
+ }
+ return changed;
+}
+
+} // namespace
+
+absl::StatusOr<bool> GemmAlgorithmPicker::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_SCOPED_LOGGING_TIMER(
+ absl::StrCat("GemmAlgorithmPicker for ", module->name()));
+
+ num_algorithms_left_ = 0;
+ if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
+ VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
+ return false;
+ }
+
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_,
+ &num_algorithms_left_));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h
new file mode 100644
index 0000000..2373583
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h
@@ -0,0 +1,70 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_
+
+#include <functional>
+#include <optional>
+#include <string_view>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// GemmAlgorithmPicker supports two modes: device and deviceless.
+// In device mode, we run autotuning on the device and store autotune results.
+// In deviceless mode, we pass in some information related to the device and
+// use stored autotune results to rewrite Gemm instructions. If the required
+// autotune result is not stored, the selected algorithm is set to the default
+// one (runtime autotuning is no longer available).
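+//
+// Usage sketch (based on the accompanying tests; names are illustrative):
+//   // Device mode: autotune on the attached GPU.
+//   AutotuneConfig cfg{DeviceConfig{stream_exec, /*allocator=*/nullptr}, opts};
+//   // Deviceless mode: rewrite using previously stored autotune results.
+//   AutotuneConfig deviceless_cfg{DevicelessConfig{device_description}, opts};
+//   GemmAlgorithmPicker picker(cfg);  // then run it over the HloModule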
+class GemmAlgorithmPicker : public HloModulePass {
+ public:
+ explicit GemmAlgorithmPicker(AutotuneConfig config) : config_(config) {}
+
+ absl::string_view name() const override { return "gemm-algorithm-picker"; }
+
+ size_t num_algorithms_left() const { return num_algorithms_left_; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ AutotuneConfig config_;
+ // The number of valid algorithms used for autotuning (from the last call),
+ // to be used for testing purposes.
+ size_t num_algorithms_left_ = 0;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc
new file mode 100644
index 0000000..f1bd187
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc
@@ -0,0 +1,301 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
+
+#include <cstdint>
+#include <variant>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/gpu/variant_visitor.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+#include "tsl/protobuf/dnn.pb.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GemmAlgorithmPickerTest : public HloTestBase,
+ public ::testing::WithParamInterface<bool> {
+ public:
+ GemmAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_cublaslt(GetParam());
+ debug_options.set_xla_gpu_enable_triton_gemm(false);
+ return debug_options;
+ }
+
+ const se::DeviceDescription& device_desc() {
+ return backend().default_stream_executor()->GetDeviceDescription();
+ }
+
+ se::StreamExecutor* stream_exec() {
+ return backend().default_stream_executor();
+ }
+ const se::DeviceDescription& gpu_device_desc() {
+ return stream_exec()->GetDeviceDescription();
+ }
+ const se::GpuComputeCapability& gpu_comp() {
+ return gpu_device_desc().gpu_compute_capability();
+ }
+
+ void SetUp() override {
+ std::string_view name =
+ ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ // We need special handling for BlasGetVersion test.
+ bool blas_get_version = name.rfind("BlasGetVersion") == 0;
+
+ std::visit(
+ VariantVisitor{
+ [&](const se::CudaComputeCapability& cc) {
+ if (!blas_get_version && cc.IsAtLeastAmpere()) {
+ GTEST_SKIP()
+ << "Skipping this test for Ampere+ as it is supported "
+ "and recommended with the Nvidia Volta+ GPUs.";
+ }
+ },
+ [&](const se::RocmComputeCapability& cc) {
+ if (blas_get_version) {
+ auto version = std::stol(device_desc().runtime_version());
+ if (version < 60200) {
+ GTEST_SKIP()
+ << "This API is not available on ROCM 6.1 and below.";
+ }
+ } else if (GetDebugOptionsForTest().xla_gpu_enable_cublaslt() &&
+ !cc.has_hipblaslt()) {
+ GTEST_SKIP() << "No gpublas-lt support on this architecture!";
+ }
+ }},
+ gpu_comp());
+ }
+};
+
+TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) {
+ auto* blas = backend().default_stream_executor()->AsBlas();
+ ASSERT_TRUE(blas != nullptr);
+ std::string version;
+ ASSERT_TRUE(blas->GetVersion(&version).ok());
+ VLOG(0) << "Blas version: " << version;
+ ASSERT_TRUE(!version.empty());
+}
+
+TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) {
+ constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+ %arg0 = f32[100,100]{1,0} parameter(0)
+ %arg1 = f32[100,100]{1,0} parameter(1)
+ ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ auto module_cfg = GetModuleConfigForTest();
+ auto debug_opts = module_cfg.debug_options();
+ size_t num_left1 = 0, num_left2 = 0;
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(kHlo, module_cfg));
+
+ {
+ // Run first with default settings (autotune level = 4), keep the number of
+ // algorithms left after autotuning
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool changed,
+ RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
+ module.get()));
+
+ AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
+ GemmAlgorithmPicker gpicker(cfg);
+ // Note that we do not care whether the algorithm index has changed: what
+ // matters is the number of algorithms left after the filtering step.
+ TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
+ num_left1 = gpicker.num_algorithms_left();
+ if (num_left1 < 2) {
+ GTEST_SKIP() << "Too few algorithms left after the first step";
+ }
+ }
+
+ // Clear cache before the second run!
+ AutotunerUtil::ClearAutotuneResults();
+ {
+ // Run once again, but now with autotune level 5 and an embarrassingly tight
+ // rtol, which should disqualify most of the algorithms.
+
+ // Note that there are "two sources of truth" for GemmAlgorithmPicker: the
+ // debug_options are used to initialize both 'HloModuleConfig' and
+ // 'AutotuneConfig'.
+ debug_opts.set_xla_gpu_autotune_gemm_rtol(1e-12);
+ debug_opts.set_xla_gpu_autotune_level(5);
+ module->mutable_config().set_debug_options(debug_opts);
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool changed,
+ RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
+ module.get()));
+
+ AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
+ GemmAlgorithmPicker gpicker(cfg);
+ TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
+ num_left2 = gpicker.num_algorithms_left();
+ }
+ // Assert that we have fewer algorithms left after the second run.
+ ASSERT_TRUE(num_left1 > num_left2);
+}
+
+TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) {
+ constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+ %arg0 = f32[100,100]{1,0} parameter(0)
+ %arg1 = f32[100,100]{1,0} parameter(1)
+ ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ auto module_cfg = GetModuleConfigForTest();
+ TF_ASSERT_OK_AND_ASSIGN(auto m,
+ ParseAndReturnVerifiedModule(kHlo, module_cfg));
+
+ bool changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(
+ changed,
+ RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
+ changed = false;
+ DebugOptions opts;
+ AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
+ TF_ASSERT_OK_AND_ASSIGN(changed,
+ RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
+ ASSERT_TRUE(changed);
+
+ AutotuneResults results;
+ TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
+ ASSERT_EQ(results.results_size(), 1);
+ auto& result = *results.mutable_results(0)->mutable_result();
+ int64_t old_algo_id = result.algorithm().algo_id();
+ int64_t new_algo_id = old_algo_id + 1;
+ result.mutable_gemm()->set_algorithm(new_algo_id);
+
+ AutotunerUtil::ClearAutotuneResults();
+ TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
+
+ // Now send the same module through GemmAlgorithmPicker again. The dot should
+ // have the new algorithm.
+ TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
+ changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(
+ changed,
+ RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
+ changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(changed,
+ RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
+ ASSERT_TRUE(changed);
+
+ SCOPED_TRACE(m->ToString());
+ HloInstruction* dot;
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
+
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ dot->backend_config<GpuBackendConfig>());
+ const GemmBackendConfig& config = gpu_config.gemm_backend_config();
+ EXPECT_EQ(config.selected_algorithm(), new_algo_id);
+}
+
+TEST_P(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) {
+ constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+ %arg0 = f32[100,100]{1,0} parameter(0)
+ %arg1 = f32[100,100]{1,0} parameter(1)
+ ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto m, ParseAndReturnVerifiedModule(kHlo, GetModuleConfigForTest()));
+
+ bool changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(
+ changed,
+ RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
+ changed = false;
+
+ DebugOptions opts;
+ AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
+
+ TF_ASSERT_OK_AND_ASSIGN(changed,
+ RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
+ ASSERT_TRUE(changed);
+
+ AutotuneResults results;
+ TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
+ ASSERT_EQ(results.results_size(), 1);
+ auto& result = *results.mutable_results(0)->mutable_result();
+ int64_t old_algo_id = result.algorithm().algo_id();
+ int64_t new_algo_id = old_algo_id + 1;
+ result.mutable_gemm()->set_algorithm(new_algo_id);
+
+ AutotunerUtil::ClearAutotuneResults();
+ TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
+
+ auto module_cfg = GetModuleConfigForTest();
+ // Now send the same module through GemmAlgorithmPicker again. The dot should
+ // have the new algorithm.
+ TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
+ changed = false;
+
+ DevicelessConfig deviceless_config{gpu_device_desc()};
+ AutotuneConfig deviceless_cfg{deviceless_config, opts};
+ TF_ASSERT_OK_AND_ASSIGN(changed,
+ RunHloPass(GemmRewriter(gpu_comp(),
+ /*toolkit_version=*/12040),
+ m.get()));
+ changed = false;
+ TF_ASSERT_OK_AND_ASSIGN(
+ changed, RunHloPass(GemmAlgorithmPicker(deviceless_cfg), m.get()));
+ ASSERT_TRUE(changed);
+
+ SCOPED_TRACE(m->ToString());
+ HloInstruction* dot;
+
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
+
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ dot->backend_config<GpuBackendConfig>());
+ const GemmBackendConfig& config = gpu_config.gemm_backend_config();
+
+ EXPECT_EQ(config.selected_algorithm(), new_algo_id);
+}
+
+INSTANTIATE_TEST_SUITE_P(GemmAlgorithmPickerTestSuite, GemmAlgorithmPickerTest,
+ ::testing::Bool());
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
new file mode 100644
index 0000000..e480f8b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -0,0 +1,1294 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h"
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "third_party/gpus/cuda/include/cublas_v2.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/pjrt/distributed/key_value_store_interface.h"
+#include "xla/primitive_util.h"
+#include "xla/service/algorithm_util.h"
+#include "xla/service/dump.h"
+#include "xla/service/executable.h"
+#include "xla/service/float_normalization.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/service/gpu/gpu_float_support.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/split_k_gemm_rewriter.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/gpu/transforms/priority_fusion.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/tsl/util/proto/proto_utils.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/lib/core/bits.h"
+#include "tsl/platform/blocking_counter.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/protobuf.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/threadpool.h"
+#include "tsl/profiler/lib/scoped_annotation.h"
+
+// Log levels used in this file:
+// VLOG(1): Overview
+// VLOG(2): Autotuning progress
+// VLOG(3): Autotuning progress - more frequent
+// VLOG(4): Print all fusions
+// VLOG(5): Profiling information for every tiling
+// VLOG(10): Print fusion computations and each configuration
+
+// TODO(b/317016172): Update usages of TritonGemmConfig to use newly exposed
+// parameters.
+
+namespace xla {
+namespace gpu {
+
+using Config = GemmFusionAutotunerImpl::Config;
+using TilingConfigs = GemmFusionAutotunerImpl::TilingConfigs;
+using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
+
+namespace {
+
+// Minimum tile size.
+constexpr int kMinTileSize = 16;
+
+// Default tiling when autotuning is disabled.
+constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4};
+
+// Split-K is enabled when the estimated number of waves is lower than the limit.
+constexpr int kMaxWavesForSplitK = 5;
+
+// Search space for exhaustive matmul autotuning.
+constexpr std::array<int, 6> kBlockSizes = {16, 32, 64, 128, 256, 512};
+constexpr std::array<int, 4> kNumStages = {1, 2, 3, 4};
+constexpr std::array<int, 4> kNumWarps = {2, 4, 8, 16};
+constexpr std::array<int, 5> kSplitK = {1, 2, 4, 8, 16};
+constexpr std::array<int, 5> kNumCtas = {1, 2, 4, 8, 16};
+
+using AutoTuneCacheKeyCount = absl::flat_hash_map<AutotuneCacheKey, uint64_t>;
+
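+// Applies autotuning decisions from the autotune cache to GEMM fusions:
+// installs the chosen Triton tiling (running the split-K rewrite if needed),
+// falls back to cuBLAS by converting the fusion into an inlinable call, or
+// switches the fusion to a cuDNN plan.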
+class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config)
+ : config_(config) {}
+
+ absl::Status HandleFusion(HloInstruction* hlo) override {
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ hlo->backend_config<GpuBackendConfig>());
+ FusionBackendConfig& backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+ if (backend_config.kind() != kTritonGemmFusionKind &&
+ backend_config.kind() != kCuDnnFusionKind) {
+ return absl::OkStatus();
+ }
+
+ VLOG(4) << "Processing " << hlo->ToString();
+ if (!backend_config.has_triton_gemm_config() &&
+ !backend_config.has_cudnn_fusion_config()) {
+ TF_ASSIGN_OR_RETURN(
+ AutotuneResult autotune_result,
+ AutotunerUtil::Autotune(
+ hlo, config_, [&]() -> absl::StatusOr<AutotuneResult> {
+ if (config_.IsDeviceless()) {
+ return absl::InternalError(absl::StrCat(
+ "Expect autotune result cache hit for deviceless "
+ "compilation (HLO: ",
+ hlo->ToString(), ")"));
+ }
+ return absl::InternalError("Expect autotune result cache hit.");
+ }));
+ VLOG(4) << "Result: " << autotune_result.ShortDebugString();
+
+ if (autotune_result.has_triton()) {
+ *backend_config.mutable_triton_gemm_config() = autotune_result.triton();
+ TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
+ } else if (autotune_result.has_gemm()) {
+ // Falling back to cuBLAS: Converting the fusion to a Call, so that it
+ // can be inlined back again.
+ HloComputation* const computation = hlo->parent();
+ HloInstruction* const call = computation->AddInstruction(
+ HloInstruction::CreateCall(hlo->shape(), hlo->operands(),
+ hlo->fused_instructions_computation()));
+ TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call));
+ hlo = call;
+ } else {
+ CHECK(autotune_result.has_algorithm());
+ backend_config.set_kind(std::string(kCuDnnFusionKind));
+ backend_config.mutable_cudnn_fusion_config()->set_plan_id(
+ autotune_result.algorithm().algo_id());
+ TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
+ }
+ }
+
+ if (backend_config.has_triton_gemm_config()) {
+ TF_ASSIGN_OR_RETURN(
+ const TritonGemmConfig config,
+ TritonGemmConfig::FromProto(backend_config.triton_gemm_config()));
+ if (config.split_k > 1) {
+ TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config));
+ }
+ }
+
+ MarkAsChanged();
+ return absl::OkStatus();
+ }
+
+ private:
+ AutotuneConfig config_;
+};
+
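+// Collects, for every GEMM fusion that has no config chosen yet and no cache
+// entry, the set of candidate configs to autotune; it also counts how often
+// each fusion (cache key) occurs in the module.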
+class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {
+ public:
+ explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl)
+ : impl_(impl) {}
+
+ // Find configurations to tune.
+ absl::StatusOr<TilingConfigs> CollectGemmConfigSets(
+ const HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
+ error_out_on_cache_miss_ =
+ module->config()
+ .debug_options()
+ .xla_gpu_require_complete_aot_autotune_results();
+ gemm_config_sets_.clear();
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_RETURN_IF_ERROR(computation->Accept(this));
+ }
+ return std::move(gemm_config_sets_);
+ }
+
+ AutoTuneCacheKeyCount GetFusionsCount() {
+ return std::move(fusion_count_map_);
+ }
+
+ absl::Status HandleFusion(const HloInstruction* hlo) override {
+ const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ hlo->backend_config<GpuBackendConfig>());
+ const FusionBackendConfig& backend_config =
+ gpu_config.fusion_backend_config();
+
+ AutotuneCacheKey key = AutotunerUtil::GetKey(hlo, impl_->GetConfig());
+
+ auto [iterator, inserted] = fusion_count_map_.insert({key, 1});
+ if (!inserted) {
+ ++(iterator->second);
+ }
+
+ TF_ASSIGN_OR_RETURN(bool is_in_cache,
+ AutotunerUtil::IsInCache(key, impl_->GetConfig()));
+ if (is_in_cache || handled_fusions_.contains(key)) {
+ return absl::OkStatus();
+ }
+
+ bool missing_config = (backend_config.kind() == kTritonGemmFusionKind &&
+ !backend_config.has_triton_gemm_config()) ||
+ (backend_config.kind() == kCuDnnFusionKind &&
+ !backend_config.has_cudnn_fusion_config());
+ if (missing_config) {
+ if (error_out_on_cache_miss_) {
+ return absl::NotFoundError(absl::StrCat(
+ "Complete autotuning results are required, but no cache result "
+ "found for key: ",
+ key.ToString()));
+ }
+
+ TF_ASSIGN_OR_RETURN(std::vector<Config> configs,
+ impl_->GenerateConfigs(*fusion));
+ gemm_config_sets_.push_back({fusion, std::move(configs)});
+ }
+
+ handled_fusions_.insert(key);
+ return absl::OkStatus();
+ }
+
+ absl::Status DefaultAction(const HloInstruction* hlo) override {
+ return absl::OkStatus();
+ }
+
+ private:
+ bool error_out_on_cache_miss_;
+ GemmFusionAutotunerImpl* impl_;
+ TilingConfigs gemm_config_sets_;
+ AutoTuneCacheKeyCount fusion_count_map_;
+ absl::flat_hash_set<AutotuneCacheKey> handled_fusions_;
+};
+
+struct TileSizeLimit {
+ int block_m = 0;
+ int block_n = 0;
+ int block_k = 0;
+};
+
+absl::StatusOr<TileSizeLimit> GetLimits(const HloDotInstruction& dot) {
+ TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_lhs,
+ NonContractingDimensionIndex(dot, /*operand_number=*/0));
+ TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_rhs,
+ NonContractingDimensionIndex(dot, /*operand_number=*/1));
+ TF_ASSIGN_OR_RETURN(int64_t contracting_index,
+ ContractingDimensionIndex(dot, /*operand_number=*/1));
+ // This is not a sharp upper limit; the actual m value can be much smaller,
+ // depending on how much of the m dimension is physically contiguous.
+ // TODO(tdanyluk): Get the exact m value by running a TritonFusionAnalysis.
+ const int max_m = tsl::NextPowerOfTwoS64(
+ dot.operand(0)->shape().dimensions(non_contracting_index_lhs));
+ // Theoretically the same is true as for m, but that is not possible in
+ // practice with the current implementation.
+ const int max_n = tsl::NextPowerOfTwoS64(
+ dot.operand(1)->shape().dimensions(non_contracting_index_rhs));
+ // This is before doing the split-k transform.
+ const int max_k = tsl::NextPowerOfTwoS64(
+ dot.operand(1)->shape().dimensions(contracting_index));
+
+ return TileSizeLimit{
+ /*block_m=*/std::max(max_m, kMinTileSize),
+ /*block_n=*/std::max(max_n, kMinTileSize),
+ /*block_k=*/std::max(max_k, kMinTileSize),
+ };
+}
+
+int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; }
+
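+// Extracts `fusion` into a standalone module with the given Triton tiling
+// applied. For split-K configs the extracted module is additionally run
+// through float normalization, priority fusion, and fusion wrapping, so that
+// it resembles what the real compilation pipeline would produce.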
+absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
+ const TritonGemmConfig& config,
+ const se::DeviceDescription& gpu_device_info,
+ const HloFusionInstruction* fusion, DebugOptions debug_opts,
+ bool allow_filtering_kernels_spilling_registers) {
+ std::unique_ptr<HloModule> new_module =
+ ExtractInstructionIntoNewModule(*fusion);
+ if (!allow_filtering_kernels_spilling_registers) {
+ debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
+ false);
+ }
+ new_module->mutable_config().set_debug_options(debug_opts);
+
+ HloComputation* entry_computation = new_module->entry_computation();
+ HloInstruction* cloned_dot_fusion = entry_computation->root_instruction();
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ cloned_dot_fusion->backend_config<GpuBackendConfig>());
+ FusionBackendConfig& backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+
+ *backend_config.mutable_triton_gemm_config() = config.ToProto();
+ TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(gpu_config));
+
+ if (config.split_k > 1) {
+ TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config));
+ for (PrimitiveType type :
+ {BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
+ GpuFloatSupport float_support(gpu_device_info.cuda_compute_capability(),
+ type);
+ FloatNormalization float_normalization(&float_support);
+ TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status());
+ }
+
+ auto shape_size_function = [&](const Shape& shape) {
+ // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the
+ // pointer size is used only to determine the size of tuple types. We
+ // shouldn't have any tuples in the autotuned module, so it's safe to use
+ // a constant here, instead of piping the real value.
+ constexpr int64_t kPointerSize = 8;
+ return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+ };
+ PriorityFusion priority_fusion(
+ /*thread_pool=*/nullptr, gpu_device_info,
+ GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function,
+ /*per_second_rates=*/{},
+ /*count_multiple_input_accesses=*/true});
+ TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status());
+
+ // If the priority fusion pass above skipped some instructions, turn them
+ // into fusions.
+ FusionWrapper fusion_wrapper;
+ TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status());
+ }
+ return new_module;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> CublasGemmAutotuneExtractor(
+ const AutotuneConfig& config, const int32_t toolkit_version,
+ const HloFusionInstruction* fusion, const DebugOptions& debug_opts) {
+ const HloComputation* fusion_computation =
+ fusion->called_computations().at(0);
+ std::unique_ptr<HloModule> new_module =
+ ExtractComputationIntoNewModule(*fusion_computation);
+ new_module->mutable_config().set_debug_options(debug_opts);
+
+ auto* dot = hlo_query::GetFirstInstructionWithOpcode(
+ *new_module->entry_computation(), HloOpcode::kDot);
+  // Substitute algorithms that are not supported by cuBLAS so the check can
+  // still run, even though cuBLAS is not used in the end. This assumes that
+  // the substituted algorithm produces results close enough for the check in
+  // this file.
+ if (dot->precision_config().algorithm() ==
+ PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
+ dot->precision_config().algorithm() ==
+ PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) {
+ dot->mutable_precision_config()->set_algorithm(
+ PrecisionConfig::ALG_DOT_F32_F32_F32);
+ }
+
+ for (GemmRewriterOptions::DType dtype :
+ {GemmRewriterOptions::DType::kFp8Only,
+ GemmRewriterOptions::DType::kNonFp8Only}) {
+ GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version,
+ GemmRewriterOptions{dtype});
+ GpuInstructionFusion fusion_pass(
+ /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription());
+ TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status());
+ TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status());
+ }
+ // TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS
+ // performance. It is probably not needed on Ampere and later because cuBLAS
+ // ignores the algorithm parameter for those targets. If we run
+ // GemmAlgorithmPicker, we probably should not run this in parallel with other
+ // compilations.
+ return new_module;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> FusionExtractor(
+ const HloFusionInstruction& fusion, const DebugOptions& debug_opts) {
+ std::unique_ptr<HloModule> module = ExtractInstructionIntoNewModule(fusion);
+ module->mutable_config().set_debug_options(debug_opts);
+ return module;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> CuDnnFusionExtractor(
+ const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
+ const int plan_id) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ FusionExtractor(fusion, debug_opts));
+
+ GpuBackendConfig gpu_config;
+ FusionBackendConfig& backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+ backend_config.set_kind(std::string(kCuDnnFusionKind));
+  // Given a plan ID, the autotuner compiles just that one plan.
+ backend_config.mutable_cudnn_fusion_config()->set_plan_id(plan_id);
+ TF_RETURN_IF_ERROR(
+ module->entry_computation()->root_instruction()->set_backend_config(
+ gpu_config));
+ return module;
+}
+
+bool IsFusionKind(const HloInstruction& hlo, absl::string_view kind) {
+ auto gpu_config = hlo.backend_config<GpuBackendConfig>();
+ if (!gpu_config.ok()) {
+ return false;
+ }
+ return gpu_config->fusion_backend_config().kind() == kind;
+}
+
+int GetCuDnnPlanCount(const HloInstruction& hlo,
+ const AutotuneConfig& autotune_config) {
+ if (auto gpu_config = hlo.backend_config<GpuBackendConfig>();
+ !gpu_config.ok() ||
+ gpu_config->fusion_backend_config().has_cudnn_fusion_config()) {
+ return {};
+ }
+ return CuDnnFusionCompiler::GetAvailablePlanCount(
+ *autotune_config.GetExecutor(), *DynCast<HloFusionInstruction>(&hlo));
+}
+
+AutotuneResult FromConfig(const Config& config) {
+ AutotuneResult res;
+ if (std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(config)) {
+ res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT);
+ } else if (std::holds_alternative<GemmFusionAutotunerImpl::CuDnnConfig>(
+ config)) {
+ res.mutable_algorithm()->set_algo_id(
+ std::get<GemmFusionAutotunerImpl::CuDnnConfig>(config).plan_id);
+ } else if (std::holds_alternative<TritonGemmConfig>(config)) {
+ *res.mutable_triton() = std::get<TritonGemmConfig>(config).ToProto();
+ } else {
+ LOG(FATAL) << "Unsupported config type: " << config.index();
+ }
+ return res;
+}
+
+absl::Status DumpOriginalFusion(AutotunerCompileUtil& util,
+ const HloFusionInstruction& fusion,
+ int fusion_id) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ util.ExtractModule([&](const DebugOptions& debug_opts) {
+ return FusionExtractor(fusion, debug_opts);
+ }));
+ module->set_name(std::string(fusion.name()));
+ // Using the original module for its debug info and name in the first
+ // parameter. It's better to include the name of both the original module
+ // and the extracted module, to avoid name clashes.
+ std::string rendered_graph_name =
+ absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".dot");
+ std::string rendered_graph = RenderGraph(rendered_graph_name, *module,
+ RenderedGraphFormat::kDot, true);
+ DumpToFileInDir(
+ /*module=*/*fusion.GetModule(),
+ /*file_prefix=*/"",
+ /*file_suffix=*/rendered_graph_name,
+ /*contents=*/rendered_graph);
+ DumpToFileInDirOrStdout(
+ /*module=*/*fusion.GetModule(),
+ /*file_prefix=*/"",
+ /*file_suffix=*/
+ absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".txt"),
+ /*contents=*/module->ToString());
+ return absl::OkStatus();
+}
+
+absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config,
+ const int32_t toolkit_version,
+ AutotunerCompileUtil& util,
+ const AutotuneResult result,
+ const HloFusionInstruction* fusion,
+ int fusion_id) {
+ TritonGemmConfig triton_gemm_config;
+ if (result.has_triton()) {
+ TF_ASSIGN_OR_RETURN(triton_gemm_config,
+ TritonGemmConfig::FromProto(result.triton()));
+ }
+ const se::DeviceDescription& device_desc =
+ autotune_config.GetExecutor()->GetDeviceDescription();
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> module,
+ util.ExtractModule([&](const DebugOptions& debug_opts) {
+ if (result.has_algorithm()) {
+ return CuDnnFusionExtractor(*fusion, debug_opts,
+ result.algorithm().algo_id());
+ } else if (result.has_triton()) {
+ return TritonGemmAutotuneExtractor(
+ triton_gemm_config, device_desc, fusion, debug_opts,
+ /*allow_filtering_kernels_spilling_registers=*/true);
+ } else if (result.has_gemm()) {
+ return CublasGemmAutotuneExtractor(autotune_config, toolkit_version,
+ fusion, debug_opts);
+ } else {
+ LOG(FATAL) << "Unknown result type: " << result.DebugString();
+ }
+ }));
+ module->set_name(std::string(fusion->name()));
+ // Using the original module for its debug info and name in the first
+ // parameter. It's better to include the name of both the original module
+ // and the extracted module, to avoid name clashes.
+ DumpToFileInDirOrStdout(
+ /*module=*/*fusion->GetModule(),
+ /*file_prefix=*/"",
+ /*file_suffix=*/
+ absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(),
+ ".optimized.txt"),
+ /*contents=*/module->ToString());
+ return absl::OkStatus();
+}
+
+std::string Serialize(const Config& config) {
+ if (auto triton_config = std::get_if<TritonGemmConfig>(&config)) {
+ tsl::protobuf::TextFormat::Printer printer;
+ printer.SetSingleLineMode(true);
+ std::string result;
+ printer.PrintToString(triton_config->ToProto(), &result);
+ return result;
+ }
+ return GemmFusionAutotunerImpl::ToString(config);
+}
+
+} // anonymous namespace
+
+// Methods required for sorting the configs.
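+// Note: candidates are sorted as std::variant values, which compare by
+// alternative index first; hence a CuBlasConfig reference (if present) always
+// precedes cuDNN and Triton candidates.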
+bool GemmFusionAutotunerImpl::CuBlasConfig::operator<(
+ const CuBlasConfig& other) const {
+ return false;
+}
+bool GemmFusionAutotunerImpl::CuDnnConfig::operator<(
+ const CuDnnConfig& other) const {
+ return plan_id < other.plan_id;
+}
+
+bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const {
+ return debug_options_.xla_gpu_autotune_level() > 0 &&
+ !debug_options_.xla_gpu_deterministic_ops();
+}
+
+/*static*/ std::string GemmFusionAutotunerImpl::ToString(const Config& config) {
+ if (std::holds_alternative<TritonGemmConfig>(config)) {
+ return std::get<TritonGemmConfig>(config).ToString();
+ } else if (std::holds_alternative<CuDnnConfig>(config)) {
+ return absl::StrFormat("cuDNN plan %d",
+ std::get<CuDnnConfig>(config).plan_id);
+ } else if (std::holds_alternative<CuBlasConfig>(config)) {
+ return "reference (cublas)";
+ } else {
+ LOG(FATAL) << "Unsupported config type: " << config.index();
+ }
+}
+
+absl::StatusOr<std::vector<Config>> GemmFusionAutotunerImpl::GenerateConfigs(
+ const HloFusionInstruction& fusion) {
+ const HloDotInstruction* dot =
+ Cast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
+ *fusion.called_computations().at(0), HloOpcode::kDot));
+
+ // Add cuBLAS reference config, if available.
+ std::vector<Config> configs;
+ if (algorithm_util::IsSupportedByCublasOrCublasLt(
+ dot->precision_config().algorithm()) &&
+ !dot->sparse_operands() && IsAutotuningEnabled()) {
+ configs.push_back(CuBlasConfig{});
+ }
+
+ // Add cuDNN plans, if available.
+ bool is_hopper =
+ !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
+ bool is_cudnn_enabled =
+ debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
+ GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
+ if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
+ (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
+ algorithm_util::IsSupportedByCudnn(
+ dot->precision_config().algorithm()) &&
+ !dot->sparse_operands() && IsAutotuningEnabled())) {
+ const int plan_count = GetCuDnnPlanCount(fusion, config_);
+ for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
+ configs.push_back(CuDnnConfig{plan_id});
+ }
+ }
+ if (IsFusionKind(fusion, kCuDnnFusionKind)) {
+ if (!IsAutotuningEnabled()) {
+ configs.push_back(CuDnnConfig{-1});
+ }
+ return configs;
+ }
+
+ // Add triton configs.
+ TF_ASSIGN_OR_RETURN(std::vector<TritonGemmConfig> triton_configs,
+ GenerateTritonConfigs(*dot));
+ for (TritonGemmConfig& config : triton_configs) {
+ configs.push_back(std::move(config));
+ }
+ return configs;
+}
+
+absl::StatusOr<std::vector<TritonGemmConfig>>
+GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
+  // Retrieve the minimum bit width participating in the dot. It is used to
+  // restrict the values of tile_k and thereby avoid autotuning configurations
+  // that are not supported by Triton.
+ std::vector<const HloInstruction*> converts =
+ HloBfsFindAll({&dot}, [&](const HloInstruction* node) {
+ return node->opcode() == HloOpcode::kConvert;
+ });
+ int minBitWidth = primitive_util::BitWidth(dot.shape().element_type());
+ for (auto convert : converts) {
+ auto in_type = convert->operand(0)->shape().element_type();
+ auto out_type = convert->shape().element_type();
+ minBitWidth = std::min({minBitWidth, primitive_util::BitWidth(in_type),
+ primitive_util::BitWidth(out_type)});
+ }
+
+ std::vector<TritonGemmConfig> result_configs;
+ TF_ASSIGN_OR_RETURN(TileSizeLimit limits, GetLimits(dot));
+
+ // Generate the list of configurations (once).
+ if (triton_configs_.empty()) {
+ triton_configs_ = !IsAutotuningEnabled()
+ ? std::vector(1, kDefaultGemmTiling)
+ : debug_options_.xla_gpu_exhaustive_tiling_search()
+ ? GetExhaustiveTritonConfigs()
+ : GetDefaultTritonConfigs();
+ }
+
+ // Avoid autotuning tiny fusions.
+ constexpr int kMinGemmElements = 32 * 32;
+ bool small_dot =
+ ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements &&
+ ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements;
+ std::vector<TritonGemmConfig> triton_configs =
+ small_dot ? std::vector(1, kDefaultGemmTiling) : triton_configs_;
+
+ // Split-K optimization enables more even utilization of a GPU in cases
+ // where tiling just the non-contracting dimensions of a GEMM does not create
+ // a sufficient number of thread block programs to occupy all available cores.
+ // Around 5 full waves completely avoid the need for split-K.
+ // n_tiles = split_k * (M * N) / (block_m * block_n)
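+  // Worked example (illustrative values): with 100 cores and 5 waves about
+  // 500 tiles are wanted; a 1024x1024 output tiled at 128x128 gives
+  // 500 * 128 * 128 / (1024 * 1024) = 7 (integer division), so max_split_k
+  // below becomes 1 << Log2Floor(7) = 4.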
+ const int kCoreCount =
+ !config_.IsDeviceless()
+ ? config_.GetExecutor()->GetDeviceDescription().core_count()
+ : 100; // some sensible default
+ const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount;
+ const int64_t result_size = ShapeUtil::ElementsIn(dot.shape());
+
+ // Triton configurations are adjusted and deduplicated.
+ absl::flat_hash_set<TritonGemmConfig> added;
+ bool is_hopper =
+ !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
+ for (TritonGemmConfig& config : triton_configs) {
+ config.block_m = std::min(config.block_m, limits.block_m);
+ config.block_n = std::min(config.block_n, limits.block_n);
+ config.block_k = std::min(config.block_k, limits.block_k);
+ int max_split_k = 1;
+ if (debug_options_.xla_gpu_enable_split_k_autotuning()) {
+ int64_t ratio = kSufficientNumberOfTiles * config.block_m *
+ config.block_n / result_size;
+ max_split_k = 1 << std::max<int>(tsl::Log2Floor64(ratio), 0);
+ }
+ config.split_k = std::min(config.split_k, max_split_k);
+
+ // TODO(b/337839570): Triton currently has a limitation where it crashes
+ // on small block_k values depending on the bit-width of the inputs to the
+ // dot. The logic below accounts for this limitation.
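+    // E.g., with 8-bit inputs (minBitWidth = 8) block_k is raised to at least
+    // kLdmatrixGranularity / 8 = 256 / 8 = 32.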
+ constexpr int kLdmatrixGranularity = 256;
+ config.block_k =
+ std::max(config.block_k, kLdmatrixGranularity / minBitWidth);
+
+ // Sparse meta should have at least one element per thread.
+ // Note: only 2:4 structured sparsity is currently supported.
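+    // Illustrative values: block_m = 64 and block_k = 64 give 64 * 64 / 16 =
+    // 256 meta elements, capping num_warps at 256 / 32 = 8 for 32-thread
+    // warps.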
+ if (dot.sparse_operands()) {
+ if (is_hopper) {
+ config.block_m = std::max(config.block_m, 64);
+ config.num_warps = std::max(config.num_warps, 4);
+ }
+ config.block_k = std::max(
+ config.block_k,
+ 2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
+ int meta_elements = config.block_m * config.block_k / 16;
+ config.num_warps =
+ std::min<int>(config.num_warps, meta_elements / WarpSize());
+ }
+
+ if (added.insert(config).second) {
+ result_configs.push_back(config);
+ }
+ }
+ return result_configs;
+}
+
+absl::StatusOr<absl::flat_hash_map<
+ const HloFusionInstruction*,
+ std::vector<GemmFusionAutotunerImpl::ExecutableCandidate>>>
+GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
+ const TilingConfigs& task) {
+ tsl::profiler::ScopedAnnotation annotation("XlaAutotunerCompilation");
+ absl::Mutex results_mu;
+ absl::flat_hash_map<const HloFusionInstruction*,
+ std::vector<ExecutableCandidate>>
+ results;
+ if (task.empty()) {
+ return results;
+ }
+
+ const int log_every_n = GetLogEveryN();
+ int64_t config_count = 0;
+ for (const auto& [unused, configs] : task) {
+ config_count += configs.size();
+ }
+
+ std::atomic<int> done_count = 0;
+ std::atomic<int> good_count = 0;
+ auto log = [&](bool success) {
+ const int done_so_far = done_count.fetch_add(1) + 1;
+ const int good_so_far =
+ success ? good_count.fetch_add(1) + 1 : good_count.load();
+ if (done_so_far % log_every_n == 0) {
+ VLOG(2) << "Compiled " << done_so_far << " of " << config_count
+ << " configs (successful: " << good_so_far << ")";
+ }
+ };
+
+ auto compile = [&](const HloFusionInstruction* fusion, const Config& config,
+ bool allow_filtering_kernels_spilling_registers)
+ -> absl::StatusOr<bool> {
+ std::unique_ptr<Executable> executable;
+ if (std::holds_alternative<TritonGemmConfig>(config)) {
+ TF_ASSIGN_OR_RETURN(
+ executable, compile_util.Compile([&](const DebugOptions& opts) {
+ return TritonGemmAutotuneExtractor(
+ std::get<TritonGemmConfig>(config),
+ config_.GetExecutor()->GetDeviceDescription(), fusion, opts,
+ allow_filtering_kernels_spilling_registers);
+ }));
+ } else if (std::holds_alternative<CuDnnConfig>(config)) {
+ executable =
+ compile_util
+ .Compile([&](const DebugOptions& opts) {
+ return CuDnnFusionExtractor(
+ *fusion, opts, std::get<CuDnnConfig>(config).plan_id);
+ })
+ .value_or(nullptr);
+ } else if (std::holds_alternative<CuBlasConfig>(config)) {
+ TF_ASSIGN_OR_RETURN(executable,
+ compile_util.Compile([&](const DebugOptions& opts) {
+ return CublasGemmAutotuneExtractor(
+ config_, toolkit_version_, fusion, opts);
+ }));
+ } else {
+ LOG(FATAL) << "Unsupported config type: " << config.index();
+ }
+ if (executable != nullptr) {
+ absl::MutexLock lock(&results_mu);
+ results[fusion].push_back({config, std::move(executable)});
+ return true;
+ }
+ return false;
+ };
+
+ // If the thread pool has only one thread, then it is actually slower to
+ // offload the tasks there.
+ if (thread_pool_ && thread_pool_->NumThreads() > 1 &&
+ debug_options_.xla_gpu_force_compilation_parallelism() != 1) {
+ if (task.size() == 1) {
+ absl::string_view fusion_name = task.begin()->first->name();
+ VLOG(1) << "Compiling " << config_count << " configs for " << fusion_name
+ << " on " << thread_pool_->NumThreads() << " threads.";
+ } else {
+ VLOG(1) << "Compiling " << config_count << " configs for " << task.size()
+ << " fusions on " << thread_pool_->NumThreads() << " threads.";
+ }
+
+ tsl::BlockingCounter counter(config_count);
+ for (const auto& key_value : task) {
+ const HloFusionInstruction* fusion = key_value.first;
+ const std::vector<Config>& gemm_config_set = key_value.second;
+
+ VLOG(10) << "Compiling fusion: " << fusion->name();
+ VLOG(10) << "Dumping fusion computation: "
+ << fusion->called_computation()->ToString();
+ for (const Config& config : gemm_config_set) {
+ thread_pool_->Schedule([&, fusion] {
+ VLOG(10) << "Trying configuration forceable through: "
+ "--xla_gpu_override_gemm_autotuner='"
+ << Serialize(config) << "'";
+ VLOG(10) << "WARNING: you are running in multithreaded-mode, the "
+ "last configuration printed out might not be the one "
+ "causing issues! Use "
+ "--xla_gpu_force_compilation_parallelism=1 to fix.";
+ absl::StatusOr<bool> has_executable =
+ compile(fusion, config, gemm_config_set.size() > 1);
+ TF_CHECK_OK(has_executable.status())
+ << "Failure occured when compiling fusion " << fusion->name()
+ << " with config '" << ToString(config)
+ << "'\nFused HLO computation:\n"
+ << fusion->fused_instructions_computation()->ToString();
+ log(has_executable.value());
+ counter.DecrementCount();
+ });
+ }
+ }
+ counter.Wait();
+ } else {
+ if (task.size() == 1) {
+ absl::string_view fusion_name = task.begin()->first->name();
+ LOG(WARNING) << "Compiling " << config_count << " configs for "
+ << fusion_name << " on a single thread.";
+ } else {
+ LOG(WARNING) << "Compiling " << config_count << " configs for "
+ << task.size() << " fusions on a single thread.";
+ }
+
+ for (const auto& [fusion, gemm_config_set] : task) {
+ VLOG(10) << "Compiling fusion: " << fusion->name();
+ VLOG(10) << "Dumping fusion computation: "
+ << fusion->called_computation()->ToString();
+ for (const Config& config : gemm_config_set) {
+ VLOG(10) << "Trying configuration forceable through: "
+ "--xla_gpu_override_gemm_autotuner='"
+ << Serialize(config) << "'";
+ TF_ASSIGN_OR_RETURN(
+ bool has_executable,
+ compile(fusion, config, gemm_config_set.size() > 1));
+ log(has_executable);
+ }
+ }
+ }
+
+ VLOG(1) << "Done compiling (successful: " << good_count.load() << ").";
+ return results;
+}
+
+absl::StatusOr<std::vector<AutotuneResult>> GemmFusionAutotunerImpl::Profile(
+ AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
+ absl::Span<const ExecutableCandidate> candidates) {
+ const HloComputation* fusion_computation = fusion.called_computations().at(0);
+
+ se::StreamExecutor* stream_exec = config_.GetExecutor();
+ if (!stream_exec->SynchronizeAllActivity()) {
+ return Internal("Failed to synchronize GPU for autotuning.");
+ }
+ tsl::profiler::ScopedAnnotation annotation([&] {
+ return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
+ fusion.name());
+ });
+ se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
+ std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
+ if (allocator == nullptr) {
+ owned_allocator =
+ std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
+ allocator = owned_allocator.get();
+ }
+ TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+
+ const HloInstruction& root = *fusion_computation->root_instruction();
+ BufferComparator comparator(root.shape(),
+ debug_options_.xla_gpu_autotune_gemm_rtol());
+
+ TF_ASSIGN_OR_RETURN(auto rz_buffers,
+ RedzoneBuffers::FromInstruction(
+ *fusion_computation->FusionInstruction(), config_,
+ debug_options_, RedzoneBuffers::kAllInputs));
+
+ const int log_every_n = GetLogEveryN();
+ std::vector<AutotuneResult> results;
+ std::optional<ScopedShapedBuffer> reference_buffer;
+ for (const ExecutableCandidate& candidate : candidates) {
+ VLOG(5) << "Trying : " << ToString(candidate.config);
+ AutotuneResult res = FromConfig(candidate.config);
+
+ std::optional<ProfilingOutput> profiling_output;
+ if (IsAutotuningEnabled()) {
+ TF_ASSIGN_OR_RETURN(
+ profiling_output,
+ compile_util.ProfileExecutable(candidate.executable.get(), stream,
+ rz_buffers.input_buffers(),
+ rz_buffers.input_shapes()));
+ if (std::holds_alternative<CuBlasConfig>(candidate.config) &&
+ config_.should_check_correctness()) {
+ reference_buffer = std::move(profiling_output->output);
+ }
+
+ int ran_so_far = results.size() + 1;
+ if (ran_so_far % log_every_n == 0) {
+ VLOG(2) << "Ran " << ran_so_far << " configs of " << candidates.size()
+ << ".";
+ }
+ if (!profiling_output) {
+ VLOG(5) << "Skipping this tiling.";
+ continue;
+ }
+
+ VLOG(5) << "Running the kernel took: " << profiling_output->duration;
+ if (profiling_output->duration >= absl::Seconds(1)) {
+ LOG(WARNING) << "Slow kernel for "
+ << fusion.called_computations()[0]->ToString()
+ << " took: " << profiling_output->duration << ". "
+ << ToString(candidate.config);
+ }
+ *res.mutable_run_time() =
+ tsl::proto_utils::ToDurationProto(profiling_output->duration);
+ }
+
+ // Reference buffer is available when `config.should_check_correctness()`
+    // is set and the reference executable was compiled.
+ if (reference_buffer.has_value() &&
+ !std::holds_alternative<CuBlasConfig>(candidate.config)) {
+ TF_ASSIGN_OR_RETURN(
+ se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
+ rz_buffers.RedzoneAllocator().CheckRedzones());
+ if (!rz_check_status.ok()) {
+ LOG(ERROR) << "Red zone modified";
+ res.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
+ res.mutable_failure()->set_msg(rz_check_status.RedzoneFailureMsg());
+ CHECK(!config_.should_crash_on_check_failure());
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ bool outputs_match,
+ comparator.CompareEqual(
+ stream, /*current=*/profiling_output->output.root_buffer(),
+ /*expected=*/reference_buffer->root_buffer()));
+ if (!outputs_match) {
+ const char kMessage[] =
+ "Results do not match the reference. This is likely a "
+ "bug/unexpected loss of precision.";
+ LOG(ERROR) << kMessage;
+ CHECK(!config_.should_crash_on_check_failure());
+ // WRONG_RESULT is not taken seriously by PickBestResult(), so
+ // use DISQUALIFIED.
+ res.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
+ res.mutable_failure()->set_msg(kMessage);
+ }
+ }
+ results.push_back(std::move(res));
+ }
+ VLOG(2) << "Done running.";
+ return results;
+}
+
+std::vector<TritonGemmConfig>
+GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
+ std::vector<TritonGemmConfig> configs;
+ se::CudaComputeCapability cc = GetComputeCapability();
+ bool tune_ctas =
+ debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
+
+ for (int num_stages : kNumStages) {
+ // Volta doesn't support num_stages > 2.
+ if (!cc.IsAtLeastAmpere() && num_stages > 2) {
+ break;
+ }
+ for (int tile_m : kBlockSizes) {
+ for (int tile_n : kBlockSizes) {
+ for (int tile_k : kBlockSizes) {
+ const int tile_lhs = tile_m * tile_k;
+ const int tile_rhs = tile_k * tile_n;
+ for (int num_warps : kNumWarps) {
+ // Each thread should read at least one input element.
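+            // E.g., a 16x16x16 tiling has tile_lhs = tile_rhs = 256, so with
+            // 32-thread warps configurations above 8 warps are skipped.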
+ if (num_warps * WarpSize() > std::min(tile_lhs, tile_rhs)) {
+ break;
+ }
+ for (int split_k : kSplitK) {
+ // Split-K autotuning may be disabled by a flag.
+ if (!debug_options_.xla_gpu_enable_split_k_autotuning() &&
+ split_k > 1) {
+ break;
+ }
+ for (int num_ctas : kNumCtas) {
+ // Clusters are only supported on Hopper.
+ // Autotuning this parameter is enabled by a flag.
+ if (!tune_ctas && num_ctas > 1) {
+ break;
+ }
+ if (num_ctas > num_warps) {
+ break;
+ }
+ configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
+ split_k, num_stages,
+ num_warps, num_ctas));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return configs;
+}
+
+std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
+ const {
+ using Config = TritonGemmConfig;
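+  // Each Config below lists (block_m, block_n, block_k, split_k, num_stages,
+  // num_warps); num_ctas is left at its default.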
+ std::vector<Config> configs = {
+ Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
+ Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4),
+ Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
+ Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
+ Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
+ Config(64, 32, 64, 1, 2, 8)};
+ if (GetComputeCapability().IsAtLeastAmpere()) {
+ absl::c_copy(
+ std::vector<Config>{
+ Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8),
+ Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4),
+ Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4),
+ Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4),
+ Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4),
+ Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4),
+ Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4),
+ Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4),
+ Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8),
+ Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4),
+ Config(256, 256, 128, 1, 3, 8)},
+ std::back_inserter(configs));
+ }
+ if (GetComputeCapability().IsAtLeastHopper()) {
+ absl::c_copy(
+ std::vector<Config>{
+ Config(16, 32, 32, 8, 1, 2),
+ Config(16, 64, 128, 8, 1, 4),
+ Config(16, 64, 128, 16, 3, 4),
+ },
+ std::back_inserter(configs));
+ }
+ return configs;
+}
+
+absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts,
+ const AutotuningLogs& autotuning_logs) {
+ if (absl::string_view file_path = debug_opts.xla_gpu_dump_autotune_logs_to();
+ !file_path.empty()) {
+ std::string resolved_path;
+ if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
+      return FailedPrecondition("File path cannot be resolved: %s", file_path);
+ }
+
+ std::string textproto;
+ tsl::protobuf::TextFormat::PrintToString(autotuning_logs, &textproto);
+
+ TF_RETURN_IF_ERROR(
+ tsl::WriteStringToFile(tsl::Env::Default(), resolved_path, textproto));
+ LOG(INFO) << "Autotune logs serialized to file: " << resolved_path;
+ }
+ return absl::OkStatus();
+}
+
+absl::Status GemmFusionAutotunerImpl::Autotune(
+ AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
+ AutoTuneCacheKeyCount fusion_count_map) {
+ TF_ASSIGN_OR_RETURN(auto executable_sets,
+ CompileAll(compile_util, gemm_config_sets));
+
+ // Sort the candidates to make their execution order well-defined for each
+ // fusion.
+ for (auto& [unused, candidates] : executable_sets) {
+ absl::c_sort(candidates, [](const auto& a, const auto& b) {
+ return a.config < b.config;
+ });
+ }
+
+ AutotuningLogs autotuning_logs;
+ int fusion_id = 0;
+ for (const auto& [fusion, candidates] : executable_sets) {
+ TF_ASSIGN_OR_RETURN(std::vector<AutotuneResult> results,
+ Profile(compile_util, *fusion, candidates));
+
+ // The reference config (if it exists) will be the first in the results,
+    // due to how the config variants are sorted.
+ if (!debug_options_.xla_gpu_cublas_fallback() &&
+ results.front().has_gemm()) {
+ results.erase(results.begin());
+ }
+
+ const HloInstruction* root =
+ fusion->called_computations().at(0)->root_instruction();
+ TF_ASSIGN_OR_RETURN(
+ AutotuneResult best,
+ PickBestResult(results, root->ToString(), root->GetModule()->config()));
+ VLOG(2) << "Best time: "
+ << tsl::proto_utils::FromDurationProto(best.run_time());
+
+ if (debug_options_.xla_gpu_dump_autotuned_gemm_fusions()) {
+ TF_RETURN_IF_ERROR(DumpOriginalFusion(compile_util, *fusion, fusion_id));
+ TF_RETURN_IF_ERROR(DumpAutotunedFusion(
+ config_, toolkit_version_, compile_util, best, fusion, fusion_id++));
+ }
+
+ const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
+ TF_ASSIGN_OR_RETURN(
+ bool added, AutotunerUtil::AddResult(key, std::move(best), config_));
+ if (!added) {
+      // In the context of a model server, concurrent autotuning is expected
+      // and insertion of identical autotuning keys is accepted.
+ LOG(WARNING) << "AutotunerUtil::AddResult already existed: "
+ << key.ToString();
+ }
+
+ if (!debug_options_.xla_gpu_dump_autotune_logs_to().empty()) {
+ auto autotuning_log = autotuning_logs.add_logs();
+ autotuning_log->set_fusion_name(std::string(fusion->name()));
+
+ for (const auto& autotune_result : results) {
+ auto log_result = autotuning_log->add_results();
+ log_result->CopyFrom(autotune_result);
+ }
+
+ if (auto fusion_key_count = fusion_count_map.find(key);
+ fusion_key_count != fusion_count_map.end()) {
+ auto fusion_key = fusion_key_count->first;
+ auto fusion_count = fusion_key_count->second;
+ autotuning_log->set_fusion_count(fusion_count);
+ }
+ }
+ }
+
+ TF_RETURN_IF_ERROR(DumpAutotuningLogs(debug_options_, autotuning_logs));
+
+ return absl::OkStatus();
+}
+
+// Trim the set of configs to what one rank has to run.
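+// E.g., 10 fusions sharded over 4 ranks yields buckets of size 3: ranks 0-2
+// get 3 fusions each and rank 3 gets the remaining one.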
+static TilingConfigs TrimConfigs(const TilingConfigs& gemm_config_sets,
+ const int shard_index, const int shard_count) {
+ const uint64_t bucket_size =
+ (gemm_config_sets.size() + shard_count - 1) / shard_count;
+ const uint64_t start = bucket_size * shard_index;
+ const uint64_t end = std::min(start + bucket_size, gemm_config_sets.size());
+ if (start >= end) {
+ return {};
+ }
+ return TilingConfigs(gemm_config_sets.cbegin() + start,
+ gemm_config_sets.cbegin() + end);
+}
+
+// Exchange the results with the other ranks.
+absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store,
+ const int module_id, const int shard_index,
+ const int shard_count) {
+ AutotuneResults results;
+ TF_RETURN_IF_ERROR(AutotunerUtil::SerializeAutotuneResults(&results));
+ TF_ASSIGN_OR_RETURN(std::string results_str,
+ AutotuneResultsToString(results, true));
+ constexpr absl::string_view kKeyPrefix = "gemm_fusion_autotuning_results";
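+  // Keys have the form "<prefix>_<module_id>_<shard_index>", e.g.
+  // "gemm_fusion_autotuning_results_7_0".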
+ TF_RETURN_IF_ERROR(key_value_store.Set(
+ absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, shard_index),
+ results_str));
+ VLOG(2) << "Rank " << shard_index << ": published results";
+ for (int i = 0; i < shard_count; ++i) {
+ if (i == shard_index) {
+ continue;
+ }
+ VLOG(2) << "Rank " << shard_index << ": waiting for results from rank " << i
+ << " / " << shard_count;
+ TF_ASSIGN_OR_RETURN(
+ std::string autotune_results_str,
+ key_value_store.Get(
+ absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, i),
+ absl::InfiniteDuration()));
+ TF_RETURN_IF_ERROR(
+ AutotunerUtil::LoadAutotuneResults(autotune_results_str, true));
+ }
+ return absl::OkStatus();
+}
+
+absl::StatusOr<bool> GemmFusionAutotuner::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_SCOPED_LOGGING_TIMER("GEMM fusion autotuner");
+
+ const DebugOptions& debug_options = module->config().debug_options();
+ GemmFusionAutotunerImpl autotuner(config_, toolkit_version_, debug_options,
+ thread_pool_);
+ GemmConfigSetCollector gemm_config_set_collector(&autotuner);
+ TF_ASSIGN_OR_RETURN(TilingConfigs gemm_config_sets,
+ gemm_config_set_collector.CollectGemmConfigSets(
+ module, execution_threads));
+ const int total_fusion_count = gemm_config_sets.size();
+
+ AutoTuneCacheKeyCount fusion_count_map =
+ gemm_config_set_collector.GetFusionsCount();
+
+ if (!autotuner.IsAutotuningEnabled()) {
+ // Pick the first option for each gemm instead of autotuning.
+ for (const auto& [fusion, tilings] : gemm_config_sets) {
+ const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
+ AutotuneResult res = FromConfig(tilings[0]);
+ *res.mutable_run_time() =
+ tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
+ TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
+ }
+ } else if (!debug_options.xla_gpu_override_gemm_autotuner().empty()) {
+ // TODO(gflegar): support overriding with non-Triton configs (cuBLAS, cuDNN)
+ AutotuneResult::TritonGemmKey gemm_key;
+ CHECK(tsl::protobuf::TextFormat::ParseFromString(
+ debug_options.xla_gpu_override_gemm_autotuner(), &gemm_key));
+ VLOG(1) << "Overriding GEMM autotuner with the following config: "
+ << gemm_key.DebugString();
+ for (const auto& [fusion, unused] : gemm_config_sets) {
+ const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
+ AutotuneResult res;
+ *res.mutable_triton() = gemm_key;
+ *res.mutable_run_time() =
+ tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
+ TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
+ }
+ } else if (!config_.IsDeviceless()) {
+ TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
+ AutotunerCompileUtil::Create(config_, debug_options));
+ TF_RET_CHECK(opt_compile_util.has_value());
+ std::string correctness_check_str = config_.should_check_correctness()
+ ? "(with correctness check)"
+ : "(without correctness check)";
+
+ const bool shard_autotuning = debug_options.xla_gpu_shard_autotuning() &&
+ key_value_store_.process_count > 1 &&
+ total_fusion_count > 0;
+ if (shard_autotuning) {
+ if (key_value_store_.key_value_store == nullptr) {
+ return absl::FailedPreconditionError(
+ "Sharded autotuning requested but key-value store is missing.");
+ }
+ gemm_config_sets =
+ TrimConfigs(gemm_config_sets, key_value_store_.process_index,
+ key_value_store_.process_count);
+ }
+
+ VLOG(1) << absl::StrFormat(
+ "Shard %d / %d: autotuning %d / %d fusions for %s %s.",
+ key_value_store_.process_index + 1, key_value_store_.process_count,
+ gemm_config_sets.size(), total_fusion_count, module->name(),
+ correctness_check_str);
+ TF_RETURN_IF_ERROR(autotuner.Autotune(*opt_compile_util, gemm_config_sets,
+ std::move(fusion_count_map)));
+ VLOG(1) << "Done autotuning.";
+
+ if (shard_autotuning) {
+ TF_RETURN_IF_ERROR(ExchangeResults(
+ *key_value_store_.key_value_store, module->unique_id(),
+ key_value_store_.process_index, key_value_store_.process_count));
+ }
+ }
+
+ return GemmFusionAutotunerVisitor(config_).RunOnModule(module,
+ execution_threads);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h
new file mode 100644
index 0000000..b12fe00
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h
@@ -0,0 +1,147 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/pjrt/distributed/key_value_store_interface.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+
+// Finds the best tiling configuration for each outlined Triton fusion.
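+// The pass is meant to run after GemmFusion has created the fusions; see
+// gemm_fusion_autotuner_test.cc in this change for an example of wiring it
+// into an HloPassPipeline with an AutotuneConfig, thread pool, and key-value
+// store.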
+class GemmFusionAutotuner : public HloModulePass {
+ public:
+ explicit GemmFusionAutotuner(const AutotuneConfig& config,
+ const int32_t toolkit_version,
+ tsl::thread::ThreadPool* thread_pool,
+ const MultiProcessKeyValueStore& key_value_store)
+ : config_(config),
+ toolkit_version_(toolkit_version),
+ thread_pool_(thread_pool),
+ key_value_store_(key_value_store) {}
+
+ absl::string_view name() const override { return "triton-autotuner"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const AutotuneConfig config_;
+ const int32_t toolkit_version_;
+ tsl::thread::ThreadPool* thread_pool_;
+ MultiProcessKeyValueStore key_value_store_;
+};
+
+// Autotuner implementation.
+class GemmFusionAutotunerImpl {
+ public:
+ GemmFusionAutotunerImpl(const AutotuneConfig config,
+ const int32_t toolkit_version,
+ const DebugOptions debug_options,
+ tsl::thread::ThreadPool* thread_pool)
+ : config_(std::move(config)),
+ toolkit_version_(toolkit_version),
+ debug_options_(std::move(debug_options)),
+ thread_pool_(thread_pool) {}
+
+ struct CuBlasConfig {
+ bool operator<(const CuBlasConfig& other) const;
+ };
+ struct CuDnnConfig {
+ int64_t plan_id;
+ bool operator<(const CuDnnConfig& other) const;
+ };
+ using Config = std::variant<CuBlasConfig, CuDnnConfig, TritonGemmConfig>;
+ using TilingConfigs =
+ std::vector<std::pair<const HloFusionInstruction*, std::vector<Config>>>;
+
+ struct ExecutableCandidate {
+ Config config;
+ std::unique_ptr<Executable> executable;
+ };
+
+ // Generate all possible configs for a dot operation.
+ absl::StatusOr<std::vector<Config>> GenerateConfigs(
+ const HloFusionInstruction& fusion);
+ absl::StatusOr<std::vector<TritonGemmConfig>> GenerateTritonConfigs(
+ const HloDotInstruction& dot);
+
+ // Compile all executables for all fusions.
+ absl::StatusOr<absl::flat_hash_map<const HloFusionInstruction*,
+ std::vector<ExecutableCandidate>>>
+ CompileAll(AutotunerCompileUtil& compile_util, const TilingConfigs& task);
+
+ // Profile all executables for a fusion.
+ absl::StatusOr<std::vector<AutotuneResult>> Profile(
+ AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
+ absl::Span<const ExecutableCandidate> candidates);
+
+ // Autotune and save the results to the autotuning cache.
+ absl::Status Autotune(
+ AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
+ absl::flat_hash_map<AutotuneCacheKey, uint64_t> fusion_count_map);
+
+ // Helper methods.
+ const AutotuneConfig& GetConfig() const { return config_; }
+ bool IsAutotuningEnabled() const;
+ static std::string ToString(const Config& config);
+
+ private:
+ se::CudaComputeCapability GetComputeCapability() const {
+ return std::get<se::CudaComputeCapability>(
+ config_.GetGpuComputeCapability());
+ }
+
+ std::vector<TritonGemmConfig> GetDefaultTritonConfigs() const;
+ std::vector<TritonGemmConfig> GetExhaustiveTritonConfigs() const;
+
+ const AutotuneConfig config_;
+ const int32_t toolkit_version_;
+ const DebugOptions debug_options_;
+ tsl::thread::ThreadPool* thread_pool_;
+ std::vector<TritonGemmConfig> triton_configs_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
new file mode 100644
index 0000000..03239a1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -0,0 +1,999 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "xla/autotuning.pb.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/pjrt/distributed/key_value_store_interface.h"
+#include "xla/service/call_inliner.h"
+#include "xla/service/dump.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_description.pb.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_utils.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/cpu_info.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using HloExtractionTest = HloTestBase;
+
+TEST_F(HloExtractionTest, InstructionExtractionIsCorrect) {
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+triton_gemm_dot {
+ p0 = s8[10,10] parameter(0)
+ p1 = f32[10,10] parameter(1)
+ c0 = f32[10,10] convert(p0)
+ ROOT dot.0 = f32[10,10] dot(c0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY entry {
+ p0 = s8[10,10] parameter(0)
+ p1 = f32[10,10] parameter(1)
+ s = f32[10,10] sqrt(p1)
+ d = f32[10,10] fusion(p0, p1),
+ kind=kCustom, calls=triton_gemm_dot
+ ROOT r = f32[10,10] add(d, s)
+})")
+ .value();
+
+ std::unique_ptr<HloModule> extracted_module = ExtractInstructionIntoNewModule(
+ *module->entry_computation()->root_instruction()->operand(0));
+
+ // Destroy the original module to be sure that the extracted one has no
+ // dependency on it.
+ module.release();
+
+ EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+ EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 3);
+ TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false));
+}
+
+TEST_F(HloExtractionTest, ComputationExtractionIsCorrect) {
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+triton_gemm_dot {
+ p0 = s8[10,10] parameter(0)
+ p1 = f32[10,10] parameter(1)
+ c0 = f32[10,10] convert(p0)
+ ROOT dot.0 = f32[10,10] dot(c0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY entry {
+ p0 = s8[10,10] parameter(0)
+ p1 = f32[10,10] parameter(1)
+ s = f32[10,10] sqrt(p1)
+ d = f32[10,10] fusion(p0, p1),
+ kind=kCustom, calls=triton_gemm_dot
+ ROOT r = f32[10,10] add(d, s)
+})")
+ .value();
+
+ std::unique_ptr<HloModule> extracted_module =
+ ExtractComputationIntoNewModule(*module->entry_computation()
+ ->root_instruction()
+ ->operand(0)
+ ->fused_instructions_computation());
+
+ // Destroy the original module to be sure that the extracted one has no
+ // dependency on it.
+ module.release();
+
+ EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter())));
+ EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 4);
+ TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false));
+}
+
+class StatelessAutotunerTest : public HloTestBase {
+ public:
+ StatelessAutotunerTest()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {}
+
+ int32_t GetToolkitVersion() const { return CUDA_VERSION; }
+
+ void SetUp() override {
+ AutotunerUtil::ClearAutotuneResults();
+ HloTestBase::SetUp();
+ }
+
+ void TearDown() override {
+ AutotunerUtil::ClearAutotuneResults();
+ HloTestBase::TearDown();
+ }
+};
+
+class GemmFusionAutotunerTest : public StatelessAutotunerTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options =
+ StatelessAutotunerTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_triton_gemm(true);
+ debug_options.set_xla_gpu_cublas_fallback(false);
+ debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0);
+ return debug_options;
+ }
+
+ se::CudaComputeCapability GetCudaComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability();
+ }
+
+ void CheckTritonAutotuning(absl::string_view hlo,
+ absl::string_view expected) {
+ HloPassPipeline pipeline("gemm_rewrite");
+ pipeline.AddPass<GemmFusion>(backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability());
+ tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
+ tsl::port::MaxParallelism());
+ DebugOptions opts;
+ MultiProcessKeyValueStore key_value_store;
+ pipeline.AddPass<GemmFusionAutotuner>(
+ AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
+ backend().memory_allocator()},
+ opts},
+ GetToolkitVersion(), &thread_pool, key_value_store);
+
+ RunAndFilecheckHloRewrite(
+ hlo, std::move(pipeline), expected, [](const HloModule* m) {
+ VLOG(5) << m->ToString();
+ const HloInstruction* dot_fusion =
+ m->entry_computation()->root_instruction();
+ if (dot_fusion->opcode() == HloOpcode::kReduce) {
+ dot_fusion = dot_fusion->operand(0);
+ }
+ CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion);
+ if (!dot_fusion->backend_config<GpuBackendConfig>()
+ ->fusion_backend_config()
+ .has_cudnn_fusion_config()) {
+ CHECK_GT(dot_fusion->backend_config<GpuBackendConfig>()
+ .value()
+ .fusion_backend_config()
+ .triton_gemm_config()
+ .block_m(),
+ 0);
+ }
+ });
+ }
+};
+
+class GemmFusionAutotunerTestWithMorePreciseReduction
+ : public GemmFusionAutotunerTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options =
+ GemmFusionAutotunerTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(
+ true);
+ return debug_options;
+ }
+};
+
+absl::StatusOr<std::vector<TritonGemmConfig>> GetPossibleMatmulAutotuneConfigs(
+ const HloDotInstruction& dot,
+ const se::CudaComputeCapability& compute_capability,
+ const int32_t toolkit_version, const DebugOptions& debug_options) {
+ se::GpuDeviceInfoProto deviceless_proto;
+ auto ccc = deviceless_proto.mutable_cuda_compute_capability();
+ ccc->set_major(compute_capability.major);
+ ccc->set_minor(compute_capability.minor);
+ DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}};
+ AutotuneConfig autotune_config{test_config, debug_options};
+ GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version,
+ debug_options, nullptr);
+ return autotuner.GenerateTritonConfigs(dot);
+}
+
+TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) {
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[1024,1024] parameter(0)
+ p1 = f32[1024,1024] parameter(1)
+ ROOT r = f32[1024,1024] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+ .value();
+ const se::CudaComputeCapability compute_capability{
+ se::CudaComputeCapability::AMPERE, /*minor=*/0};
+ TF_ASSERT_OK_AND_ASSIGN(
+ const std::vector<TritonGemmConfig> configs,
+ GetPossibleMatmulAutotuneConfigs(
+ *Cast<HloDotInstruction>(
+ module->entry_computation()->root_instruction()),
+ compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+ EXPECT_TRUE(std::any_of(
+ configs.begin(), configs.end(),
+ [](const TritonGemmConfig& config) { return config.num_stages > 2; }));
+}
+
+TEST_F(GemmFusionAutotunerTest, SmallOutputCanUseLargeSplitK) {
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[1024,1024] parameter(0)
+ p1 = f32[1024,1024] parameter(1)
+ ROOT r = f32[1024,1024] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+ .value();
+ const se::CudaComputeCapability compute_capability{
+ se::CudaComputeCapability::AMPERE, /*minor=*/0};
+ TF_ASSERT_OK_AND_ASSIGN(
+ const std::vector<TritonGemmConfig> configs,
+ GetPossibleMatmulAutotuneConfigs(
+ *Cast<HloDotInstruction>(
+ module->entry_computation()->root_instruction()),
+ compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+ EXPECT_TRUE(std::any_of(
+ configs.begin(), configs.end(),
+ [](const TritonGemmConfig& config) { return config.split_k >= 4; }));
+}
+
+TEST_F(GemmFusionAutotunerTest, LargeOutputDoesNotUseLargeSplitK) {
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[20480,20480] parameter(0)
+ p1 = f32[20480,20480] parameter(1)
+ ROOT r = f32[20480,20480] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+ .value();
+ const se::CudaComputeCapability compute_capability{
+ se::CudaComputeCapability::AMPERE, /*minor=*/0};
+ TF_ASSERT_OK_AND_ASSIGN(
+ const std::vector<TritonGemmConfig> configs,
+ GetPossibleMatmulAutotuneConfigs(
+ *Cast<HloDotInstruction>(
+ module->entry_computation()->root_instruction()),
+ compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+ EXPECT_FALSE(std::any_of(
+ configs.begin(), configs.end(),
+ [](const TritonGemmConfig& config) { return config.split_k > 1; }));
+}
+
+TEST_F(GemmFusionAutotunerTest, Int8FusedGemm) {
+ const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+ x = s8[128,64] parameter(0)
+ c = f16[128,64] convert(x)
+
+ y = f16[64,6144] parameter(1)
+
+ ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+ CheckTritonAutotuning(hlo, R"(
+// CHECK: ENTRY
+// CHECK: ROOT
+// CHECK-SAME: kCustom
+// CHECK-SAME: block_m
+)");
+
+ EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3}));
+}
+
+TEST_F(GemmFusionAutotunerTest, Int8FusedGemm256) {
+ const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+ x = s8[128,256] parameter(0)
+ c = f16[128,256] convert(x)
+
+ y = f16[256,6144] parameter(1)
+
+ ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ CheckTritonAutotuning(hlo, R"(
+// CHECK: ENTRY
+// CHECK: ROOT
+// CHECK-SAME: kCustom
+// CHECK-SAME: block_m
+)");
+
+ EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}));
+}
+
+TEST_F(GemmFusionAutotunerTest, SelectsSplitK) {
+  // Shapes with K >> M, N are expected to force split-K configurations.
+ const std::string kHloText = R"(
+HloModule t
+
+ENTRY e {
+ p0 = s8[7,8192] parameter(0)
+ p0c = f16[7,8192] convert(p0)
+ p1 = f16[8192,18] parameter(1)
+ ROOT dot.0 = f16[7,18] dot(p0c, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ MatchOptimizedHlo(kHloText, R"(
+; CHECK: reduce
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: kCustom
+; CHECK-NEXT: kLoop
+)");
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1, /*arel=*/0.5}));
+}
+
+TEST_F(GemmFusionAutotunerTestWithMorePreciseReduction, SelectsSplitK) {
+  // Shapes with K >> M, N are expected to force split-K configurations.
+ constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+ p0 = s8[7,8192] parameter(0)
+ p0c = f16[7,8192] convert(p0)
+ p1 = f16[8192,18] parameter(1)
+ ROOT dot.0 = f16[7,18] dot(p0c, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ MatchOptimizedHlo(kHloText, R"(
+; CHECK: reduce
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: kCustom
+; CHECK-NEXT: kLoop
+)");
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3}));
+}
+
+TEST_F(GemmFusionAutotunerTest, ApplySplitKWithoutAlteringTiling) {
+ const std::string kHloText = R"(
+triton_dot {
+ p0 = f16[55,120] parameter(0)
+ p1 = f16[120,20] parameter(1)
+ ROOT dot = f16[55,20] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+ p0 = f16[55,120]{1,0} parameter(0)
+ p1 = f16[120,20]{1,0} parameter(1)
+ ROOT _ = f16[55,20] fusion(p0, p1), kind=kCustom, calls=triton_dot,
+ backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}}}
+})";
+
+ MatchOptimizedHlo(kHloText, R"(
+; CHECK: f16[3,55,20]
+; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}
+; CHECK: f16[55,20]{1,0} {{(reduce|fusion)}}
+)");
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+// Modify block_k back to 16 once b/337839570 is fixed.
+// TODO(b/344770374): Make this test not fragile.
+TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
+ const std::string kHloText = R"(
+HloModule m
+
+%triton_gemm_dot {
+ %p1 = s8[4,12288]{1,0} parameter(1)
+ %p0 = s8[12288,1536]{1,0} parameter(0)
+ %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
+ %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
+ %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
+}
+
+ENTRY %e {
+ %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
+ %convert = s8[4,12288]{1,0} parameter(1)
+ ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
+ backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
+})";
+
+ auto module = ParseAndReturnVerifiedModule(kHloText).value();
+ EXPECT_THAT(
+ backend().compiler()->RunBackend(std::move(module),
+ backend().default_stream_executor(),
+ {/*device_allocator=*/nullptr,
+ /*thread_pool=*/nullptr,
+ /*layout_canonicalization_callback=*/{},
+ /*is_autotuning_compilation=*/true}),
+ ::testing::AnyOf(
+ tsl::testing::StatusIs(
+ tsl::error::CANCELLED,
+ absl::StrFormat(
+ "Compilation result discarded due to register spilling")),
+ // Hopper can't spill registers since wgmma instructions are
+ // asynchronous; instead it just runs out of them.
+ tsl::testing::StatusIs(
+ tsl::error::RESOURCE_EXHAUSTED,
+ absl::StrFormat("Register allocation failed"))));
+}
+
+// Modify block_k back to 16 once b/337839570 is fixed.
+// TODO(b/344770374): Make this test not fragile.
+TEST_F(GemmFusionAutotunerTest,
+ DoNotFilterOutAutotuningKernelSpillingRegisters) {
+ if (GetCudaComputeCapability().IsAtLeastHopper()) {
+ GTEST_SKIP() << "Hopper and newer runs out of registers for such HLOs";
+ }
+ const std::string kHloText = R"(
+HloModule m
+
+%triton_gemm_dot {
+ %p1 = s8[4,12288]{1,0} parameter(1)
+ %p0 = s8[12288,1536]{1,0} parameter(0)
+ %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
+ %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
+ %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
+}
+
+ENTRY %e {
+ %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
+ %convert = s8[4,12288]{1,0} parameter(1)
+ ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
+ backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
+})";
+
+ auto module = ParseAndReturnVerifiedModule(kHloText).value();
+ HloModuleConfig config = module->config();
+ DebugOptions debug_options = config.debug_options();
+ debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
+ false);
+ config.set_debug_options(debug_options);
+ module->set_config(config);
+
+ std::unique_ptr<Executable> executable =
+ backend()
+ .compiler()
+ ->RunBackend(std::move(module), backend().default_stream_executor(),
+ {/*device_allocator=*/nullptr,
+ /*thread_pool=*/nullptr,
+ /*layout_canonicalization_callback=*/{},
+ /*is_autotuning_compilation=*/true})
+ .value();
+ EXPECT_NE(executable, nullptr);
+}
+
+// Modify block_k back to 16 once b/337839570 is fixed.
+TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) {
+ const std::string kHloText = R"(
+HloModule m
+
+%triton_gemm_dot {
+ %p1 = f16[4,12288]{1,0} parameter(1)
+ %p0 = s8[12288,1536]{1,0} parameter(0)
+ %convert.10406 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
+ ROOT %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %p1, f16[12288,1536]{1,0} %convert.10406), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY %e {
+ %p0 = s8[12288,1536]{1,0} parameter(0)
+ %p1 = f16[4,12288]{1,0} parameter(1)
+ ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot,
+ backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}}
+})";
+
+ auto module = ParseAndReturnVerifiedModule(kHloText).value();
+ std::unique_ptr<Executable> executable =
+ backend()
+ .compiler()
+ ->RunBackend(std::move(module), backend().default_stream_executor(),
+ {/*device_allocator=*/nullptr,
+ /*thread_pool=*/nullptr,
+ /*layout_canonicalization_callback=*/{},
+ /*is_autotuning_compilation=*/true})
+ .value();
+ EXPECT_NE(executable, nullptr);
+}
+
+using GemmFusionAutotunerDumpTest = GemmFusionAutotunerTest;
+
+TEST_F(GemmFusionAutotunerDumpTest, Fp8CublasltFallbackSupport) {
+ const std::string kHloText = R"(
+HloModule o
+
+gemm_fusion {
+ p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
+ p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
+ ROOT %dot.0 = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY main {
+ p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
+ p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
+ ROOT %dot.0 = f32[64,64]{1,0} fusion(p0, p1), kind=kCustom, calls=gemm_fusion, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(kHloText));
+
+ DebugOptions opts;
+ AutotuneConfig autotune_config{
+ DeviceConfig{backend().default_stream_executor(),
+ backend().memory_allocator()},
+ opts};
+ AutotuneCacheKey cache_key(autotune_config.GetModelStr(),
+ *module->entry_computation()->root_instruction());
+
+ TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override,
+ ParseTextProto<AutotuneResults>(R"pb(
+ version: 3
+ results {
+ device: "..."
+ hlo: "..."
+ result {
+ gemm { algorithm: -1 }
+ run_time { nanos: 14 }
+ }
+ })pb"));
+ autotune_results_override.mutable_results(0)->set_device(
+ std::string(cache_key.GetModelStr()));
+ autotune_results_override.mutable_results(0)->set_hlo(
+ std::string(cache_key.GetHlo()));
+ CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override));
+
+ HloPassPipeline pipeline("gemm_autotune");
+ tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
+ tsl::port::MaxParallelism());
+ MultiProcessKeyValueStore key_value_store;
+ pipeline.AddPass<GemmFusionAutotuner>(autotune_config, GetToolkitVersion(),
+ &thread_pool, key_value_store);
+ pipeline.AddPass<CallInliner>();
+ for (GemmRewriterOptions::DType dtype :
+ {GemmRewriterOptions::DType::kFp8Only,
+ GemmRewriterOptions::DType::kNonFp8Only}) {
+ pipeline.AddPass<GemmRewriter>(autotune_config.GetGpuComputeCapability(),
+ GetToolkitVersion(),
+ GemmRewriterOptions{dtype});
+ }
+
+ TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get()));
+ const bool is_at_least_hopper =
+ std::holds_alternative<se::CudaComputeCapability>(
+ autotune_config.GetGpuComputeCapability()) &&
+ std::get<se::CudaComputeCapability>(
+ autotune_config.GetGpuComputeCapability())
+ .IsAtLeastHopper();
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(module->ToString(), is_at_least_hopper
+ ? "// CHECK: __cublas$lt"
+ : "// CHECK: __cublas$gemm"));
+ EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
+ HloModuleConfig config;
+ DebugOptions options = GetDebugOptionsForTest();
+ options.set_xla_gpu_cublas_fallback(true);
+ options.set_xla_gpu_dump_autotuned_gemm_fusions(true);
+ std::string output_directory;
+ if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) {
+ output_directory = tsl::testing::TmpDir();
+ }
+ options.set_xla_dump_to(output_directory);
+ config.set_debug_options(options);
+ // Computation is chosen such that relatively heavy math operations before the
+ // GEMM are not worth fusing because they would get duplicated many times and
+ // slow down execution. Therefore autotuning picks cuBLAS here.
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+fusion1 {
+ p0 = f32[333,333] parameter(0)
+ s = f32[333,333] sine(p0)
+ p1 = f32[333,333] parameter(1)
+ c = f32[333,333] cosine(p1)
+ ROOT dot = f32[333,333] dot(s, c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+ p0 = f32[333,333] parameter(0)
+ p1 = f32[333,333] parameter(1)
+ ROOT rr = f32[333,333] fusion(p0, p1), kind=kCustom, calls=fusion1,
+ backend_config={"fusion_backend_config": {kind: "__triton_gemm"}}
+})",
+ config));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(std::move(module)));
+
+ std::string dump;
+ TF_EXPECT_OK(tsl::ReadFileToString(
+ tsl::Env::Default(),
+ tsl::io::JoinPath(output_directory,
+ FilenameFor(*optimized_module, /*prefix=*/"",
+ /*suffix=*/"gemm_fusion_0.rr.txt")),
+ &dump));
+ EXPECT_TRUE(*RunFileCheck(dump, R"(
+CHECK: HloModule rr
+CHECK-NOT: cublas
+CHECK: __triton_gemm
+CHECK-NOT: block_m
+)"));
+
+ dump.clear();
+
+ TF_EXPECT_OK(tsl::ReadFileToString(
+ tsl::Env::Default(),
+ tsl::io::JoinPath(
+ output_directory,
+ FilenameFor(*optimized_module, /*prefix=*/"",
+ /*suffix=*/"gemm_fusion_0.rr.optimized.txt")),
+ &dump));
+ EXPECT_TRUE(*RunFileCheck(dump, R"(
+CHECK: HloModule rr
+CHECK-NOT: triton
+CHECK: cublas
+)"));
+}
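
Outside of this test, the same dumps can be requested through DebugOptions before compilation. The fragment below uses only the setters exercised above and assumes the surrounding XLA headers; the dump directory is a hypothetical example.

DebugOptions options;
options.set_xla_gpu_cublas_fallback(true);              // allow the cuBLAS fallback as a candidate
options.set_xla_gpu_dump_autotuned_gemm_fusions(true);  // dump fusions before/after autotuning
options.set_xla_dump_to("/tmp/xla_dump");               // hypothetical output directory
HloModuleConfig config;
config.set_debug_options(options);
// Compiling a module with `config` should leave dump files whose names embed
// the fusion index and the fusion root's name (here gemm_fusion_0 and rr):
// one for the fusion handed to the autotuner and one ".optimized" version
// after the winning config is applied, as the FileCheck assertions above verify.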
+
+TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) {
+ const std::string kHlo = R"(
+fusion1 {
+ p0 = f32[3,28,32] parameter(0)
+ p1 = f32[3,28,32] parameter(1)
+ ROOT d = f32[3,32,32] dot(p0, p1),
+ lhs_batch_dims={0}, rhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+ p0 = f32[3,28,32] parameter(0)
+ p1 = f32[3,28,32] parameter(1)
+ ROOT _ = f32[3,32,32] fusion(p0, p1), kind=kCustom, calls=fusion1,
+ backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
+})";
+
+ CheckTritonAutotuning(kHlo, R"(
+// CHECK: "plan_id":
+)");
+}
+
+// TODO(b/281489442): Write a testcase called
+// `SkipConfigsProducingDeviantResults` or similar.
+
+class GemmFusionAutotunerLevelTest : public StatelessAutotunerTest,
+ public ::testing::WithParamInterface<int> {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options =
+ StatelessAutotunerTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_autotune_level(GetParam());
+ debug_options.set_xla_gpu_cublas_fallback(false);
+ return debug_options;
+ }
+};
+
+TEST_P(GemmFusionAutotunerLevelTest, AllAutotuningLevelsWorkCorrectly) {
+ const std::string kHloText = R"(
+HloModule m
+
+ENTRY e {
+ p0 = pred[64,10] parameter(0)
+ p0c = f32[64,10] convert(p0)
+ p1 = f32[10,128] parameter(1)
+ ROOT r = f32[64,128] dot(p0c, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ MatchOptimizedHlo(kHloText, R"(
+; CHECK: kind=kCustom
+; CHECK-SAME: block_m
+ )");
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_P(GemmFusionAutotunerLevelTest, Deviceless) {
+ const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+ x = s8[16,16] parameter(0)
+ c = f16[16,16] convert(x)
+ y = f16[16,16] parameter(1)
+ ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ HloPassPipeline pipeline("gemm_rewrite_deviceless");
+ pipeline.AddPass<GemmFusion>(backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability());
+ tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
+ tsl::port::MaxParallelism());
+ DebugOptions opts;
+ MultiProcessKeyValueStore key_value_store;
+ pipeline.AddPass<GemmFusionAutotuner>(
+ AutotuneConfig{
+ DevicelessConfig{
+ backend().default_stream_executor()->GetDeviceDescription()},
+ opts},
+ GetToolkitVersion(), &thread_pool, key_value_store);
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) {
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ HloTestBase::RunHloPass(&pipeline, module.get()));
+ EXPECT_TRUE(changed);
+
+ // Check default configuration.
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(
+ module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+ R"(
+// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"16","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}},"force_earliest_schedule":false}
+ )"));
+ EXPECT_TRUE(filecheck_matches);
+ } else {
+ EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()),
+ tsl::testing::StatusIs(
+ tsl::error::INTERNAL,
+ ::testing::HasSubstr(
+ "Expect autotune result cache hit for deviceless")));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerLevelSweep,
+ GemmFusionAutotunerLevelTest, ::testing::Range(0, 5));
+
+class GemmFusionAutotunerExhaustiveTest : public GemmFusionAutotunerTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options =
+ GemmFusionAutotunerTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_exhaustive_tiling_search(true);
+ return debug_options;
+ }
+};
+
+TEST_F(GemmFusionAutotunerExhaustiveTest, DISABLED_CompileOnly) {
+ const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+ x = s8[16,16] parameter(0)
+ c = f16[16,16] convert(x)
+ y = f16[16,16] parameter(1)
+ ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ CheckTritonAutotuning(hlo, R"(
+// CHECK: %triton_gemm_out_computation (
+// CHECK: ROOT %out.1 = f16[16,16]{1,0} dot(%c.1, %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+// CHECK: ROOT %triton_gemm_out = f16[16,16]{1,0} fusion(%x, %y), kind=kCustom, calls=%triton_gemm_out_computation
+// CHECK-SAME: "block_m":
+)");
+}
+
+// TODO(b/337839570): Triton currently has a limitation where it crashes
+// on small block_k values depending on the bit-width of the inputs to the
+// dot. For this test case, it should skip any block_k values that are <= 16
+// since the smallest type has a bit-width of 8.
+TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingTileKConfig) {
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+ENTRY e {
+ x = s8[33,33]{1,0} parameter(0)
+ c = f16[33,33]{1,0} convert(x)
+ y = f16[33,33]{1,0} parameter(1)
+ ROOT out = f16[33,33]{1,0} dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)")
+ .value();
+ const se::CudaComputeCapability compute_capability{
+ se::CudaComputeCapability::AMPERE, /*minor=*/0};
+ TF_ASSERT_OK_AND_ASSIGN(
+ const std::vector<TritonGemmConfig> configs,
+ GetPossibleMatmulAutotuneConfigs(
+ *Cast<HloDotInstruction>(
+ module->entry_computation()->root_instruction()),
+ compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+ EXPECT_TRUE(std::all_of(
+ configs.begin(), configs.end(),
+ [](const TritonGemmConfig& config) { return config.block_k > 16; }));
+}
+
+class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options =
+ GemmFusionAutotunerTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_split_k_autotuning(false);
+ return debug_options;
+ }
+};
+
+TEST_F(GemmFusionAutotunerDisableSplitK, SplitKIsDisabled) {
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[1024,1024] parameter(0)
+ p1 = f32[1024,1024] parameter(1)
+ ROOT r = f32[1024,1024] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+ .value();
+ const se::CudaComputeCapability compute_capability{
+ se::CudaComputeCapability::AMPERE, /*minor=*/0};
+ TF_ASSERT_OK_AND_ASSIGN(
+ const std::vector<TritonGemmConfig> configs,
+ GetPossibleMatmulAutotuneConfigs(
+ *Cast<HloDotInstruction>(
+ module->entry_computation()->root_instruction()),
+ compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+ EXPECT_TRUE(std::all_of(
+ configs.begin(), configs.end(),
+ [](const TritonGemmConfig& config) { return config.split_k == 1; }));
+}
+
+class GemmFusionAutotunerConfigTest
+ : public StatelessAutotunerTest,
+ public ::testing::WithParamInterface<bool> {};
+
+TEST_P(GemmFusionAutotunerConfigTest, SparseDotDiscardsUnsupportedTiles) {
+ const std::string kHloText = R"(
+HloModule test
+ENTRY wais {
+ lhs = f16[5,1600] parameter(0)
+ rhs = f16[3200,10] parameter(1)
+ meta = u16[5,200] parameter(2)
+ ROOT dot = f32[5,10] dot(lhs, rhs, meta),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+ const se::CudaComputeCapability compute_capability{
+ se::CudaComputeCapability::AMPERE, /*minor=*/0};
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_exhaustive_tiling_search(GetParam());
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ const std::vector<TritonGemmConfig> configs,
+ GetPossibleMatmulAutotuneConfigs(
+ *Cast<HloDotInstruction>(
+ module->entry_computation()->root_instruction()),
+ compute_capability, GetToolkitVersion(), debug_options));
+ for (const auto& config : configs) {
+ int metadata_size = config.block_m * config.block_k / 16;
+ EXPECT_LE(config.num_warps * WarpSize(), metadata_size);
+ EXPECT_GT(config.block_k, 16); // kMinTileSize
+ }
+}
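
A worked instance of the two checks in the loop above, assuming WarpSize() returns 32 (NVIDIA) and that one u16 metadata element covers 16 logical K elements for 2:4 structured sparsity (consistent with the u16[5,200] meta operand for a logical K of 3200 in the HLO above):

constexpr int kWarpSize = 32;  // assumption: warp size on NVIDIA GPUs
constexpr int kBlockM = 64, kBlockK = 32, kNumWarps = 4;  // an example tiling
constexpr int kMetadataSize = kBlockM * kBlockK / 16;     // = 128 u16 elements
// 4 warps * 32 threads = 128 <= 128, so such a config passes EXPECT_LE above,
// and kBlockK = 32 > 16 also clears the kMinTileSize check.
static_assert(kNumWarps * kWarpSize <= kMetadataSize);
static_assert(kBlockK > 16);

Tiles therefore appear to be kept only when they provide at least one metadata element per thread and a block_k above the minimum tile size.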
+
+INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep,
+ GemmFusionAutotunerConfigTest, ::testing::Bool());
+
+TEST_F(GemmFusionAutotunerTest, SplitKFloatNormalization) {
+ if (!GetCudaComputeCapability().IsAtLeastHopper()) {
+ GTEST_SKIP() << "f8 types are only supported from Hopper onwards.";
+ }
+ const se::CudaComputeCapability compute_capability =
+ GetCudaComputeCapability();
+ se::GpuDeviceInfoProto deviceless_proto;
+ auto ccc = deviceless_proto.mutable_cuda_compute_capability();
+ ccc->set_major(compute_capability.major);
+ ccc->set_minor(compute_capability.minor);
+ DeviceConfig test_config{backend().default_stream_executor(),
+ backend().memory_allocator()};
+ AutotuneConfig autotune_config{test_config, GetDebugOptionsForTest()};
+ GemmFusionAutotunerImpl autotuner(autotune_config, GetToolkitVersion(),
+ GetDebugOptionsForTest(), nullptr);
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto compile_util,
+ AutotunerCompileUtil::Create(autotune_config, GetDebugOptionsForTest()))
+
+ std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+%gemm_fusion_dot_computation (parameter_0: f8e5m2[256,256], parameter_1: f8e4m3fn[128,256]) -> f8e5m2[256,128] {
+ %parameter_0 = f8e5m2[256,256]{1,0} parameter(0)
+ %parameter_1 = f8e4m3fn[128,256]{1,0} parameter(1)
+ %dot.1 = f32[256,128]{1,0} dot(f8e5m2[256,256]{1,0} %parameter_0, f8e4m3fn[128,256]{1,0} %parameter_1), lhs_contracting_dims={0}, rhs_contracting_dims={1}
+ ROOT %convert.2 = f8e5m2[256,128]{1,0} convert(f32[256,128]{1,0} %dot.1)
+}
+ENTRY entry {
+ %p0 = f8e5m2[256,256]{1,0} parameter(0)
+ %p1 = f8e4m3fn[128,256]{1,0} parameter(1)
+ ROOT r = f8e5m2[256,128]{1,0} fusion(f8e5m2[256,256]{1,0} %p0, f8e4m3fn[128,256]{1,0} %p1), kind=kCustom, calls=%gemm_fusion_dot_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+})")
+ .value();
+ GemmFusionAutotunerImpl::TilingConfigs configs;
+ configs.emplace_back(DynCast<HloFusionInstruction>(
+ module->entry_computation()->root_instruction()),
+ std::vector<GemmFusionAutotunerImpl::Config>{
+ GemmFusionAutotunerImpl::Config(TritonGemmConfig(
+ /*block_m=*/32,
+ /*block_n=*/64,
+ /*block_k=*/64,
+ /*split_k=*/4,
+ /*num_stages=*/1,
+ /*num_warps=*/4,
+ /*num_ctas=*/1))});
+ CHECK_OK(autotuner.CompileAll(*compile_util, configs));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_autotuning.proto b/third_party/xla/xla/service/gpu/autotuning/gpu_autotuning.proto
similarity index 100%
rename from third_party/xla/xla/service/gpu/gpu_autotuning.proto
rename to third_party/xla/xla/service/gpu/autotuning/gpu_autotuning.proto
diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc
index 624d324..0ffb8e3 100644
--- a/third_party/xla/xla/service/gpu/buffer_sharing.cc
+++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc
@@ -79,7 +79,7 @@
// first, i.e. before processing other outputs (that may overwrite the input).
stream_executor::GpuDeviceInfoProto device_info;
stream_executor::DeviceDescription device_description(device_info);
- auto analysis = HloFusionAnalysis::Create(fusion, &device_description);
+ auto analysis = HloFusionAnalysis::Create(*user, device_description);
bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
HloFusionAnalysis::EmitterFusionKind::kReduction;
const HloInstruction* reduction_hero =
diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc b/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc
deleted file mode 100644
index 9102d75..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc
+++ /dev/null
@@ -1,231 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_cycle_decomposer.h"
-
-#include <cstdint>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/literal_util.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-
-namespace xla {
-
-namespace {
-using SourceTargetPair = std::pair<int64_t, int64_t>;
-using SourceTargetPairs = std::vector<SourceTargetPair>;
-enum class CycleType { kUnknown, kForward, kBackward };
-
-// Returns true if the CollectivePermute instruction has a cycle in its
-// source-target pairs and should be decomposed.
-CycleType ShouldDecomposeWithCycleType(
- const HloCollectivePermuteInstruction& collective_permute,
- int64_t threshold_in_bytes) {
- if (!collective_permute.channel_id().has_value()) {
- return CycleType::kUnknown;
- }
-
- if (collective_permute.operand_count() != 1) {
- return CycleType::kUnknown;
- }
-
- const Shape& result_shape = collective_permute.shape();
- // Skip the transformation if there is any context data.
- if (result_shape.IsTuple()) {
- return CycleType::kUnknown;
- }
-
- CHECK(result_shape.IsArray());
- if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
- return CycleType::kUnknown;
- }
-
- const SourceTargetPairs& pairs = collective_permute.source_target_pairs();
- if (pairs.size() == 1) {
- return CycleType::kUnknown;
- }
-
- return IsForwardCycle(pairs) ? CycleType::kForward
- : IsBackwardCycle(pairs) ? CycleType::kBackward
- : CycleType::kUnknown;
-}
-
-// Constructs the frontend attributes for the two decomposed CollectivePermute
-// instructions.
-absl::Status GetFrontendAttributes(HloCollectivePermuteInstruction* cp,
- CycleType cycle_type,
- xla::FrontendAttributes& cp1_attr,
- xla::FrontendAttributes& cp2_attr) {
- cp1_attr = cp->frontend_attributes();
- cp2_attr = cp->frontend_attributes();
- auto validation_it =
- cp->frontend_attributes().map().find(kSendRecvValidationAttr);
- if (validation_it == cp->frontend_attributes().map().end() ||
- validation_it->second == "invalid") {
- return absl::OkStatus();
- }
-
- auto statusor_bounds = ParseReplicaGroupsOnly(validation_it->second);
- if (!statusor_bounds.ok()) {
- return statusor_bounds.status();
- }
- const std::vector<ReplicaGroup>& bounds = statusor_bounds.value();
- if (bounds.size() < 2) {
- return Internal("Invalid number of replica groups");
- }
-
- int64_t num_pairs = bounds.size();
- // A forward cycle has its backedge at the end while a backward cycle has its
- // backedge at the beginning.
- auto backedge_start = cycle_type == CycleType::kBackward
- ? bounds.begin()
- : bounds.begin() + num_pairs - 1;
- auto other_edges_start =
- cycle_type == CycleType::kBackward ? bounds.begin() + 1 : bounds.begin();
- std::vector<ReplicaGroup> cp1_bounds(backedge_start, backedge_start + 1);
- std::vector<ReplicaGroup> cp2_bounds(other_edges_start,
- other_edges_start + num_pairs - 1);
- auto bounds_to_string = [](const std::vector<ReplicaGroup> groups) {
- return "{" +
- absl::StrJoin(groups, ",",
- [](std::string* out, const ReplicaGroup& value) {
- absl::StrAppend(out, "{", value.replica_ids(0), ",",
- value.replica_ids(1), "}");
- }) +
- "}";
- };
- std::string cp1_validation_str = bounds_to_string(cp1_bounds);
- std::string cp2_validation_str = bounds_to_string(cp2_bounds);
- (*cp1_attr.mutable_map())[kSendRecvValidationAttr] = cp1_validation_str;
- (*cp2_attr.mutable_map())[kSendRecvValidationAttr] = cp2_validation_str;
- return absl::OkStatus();
-}
-
-// Decomposes a CollectivePermute instruction with a cycle in its source-target
-// pairs into two CollectivePermute instructions.
-absl::Status DecomposeCollectivePermuteCycle(
- HloCollectivePermuteInstruction* cp, HloComputation* computation,
- HloModule* module, int64_t next_channel_id, CycleType cycle_type) {
- const SourceTargetPairs& pairs = cp->source_target_pairs();
- int64_t num_pairs = pairs.size();
- // A forward cycle has its backedge at the end as in
- // {{0,1},{1,2},{2,3},{3,0}} while a backward cycle has its backedge at the
- // beginning as in {{0,3},{1,0},{2,1},{3,2}}.
- auto backedge_start = cycle_type == CycleType::kBackward
- ? pairs.begin()
- : pairs.begin() + num_pairs - 1;
- auto other_edges_start =
- cycle_type == CycleType::kBackward ? pairs.begin() + 1 : pairs.begin();
- SourceTargetPairs backedge(backedge_start, backedge_start + 1);
- SourceTargetPairs other_edges(other_edges_start,
- other_edges_start + num_pairs - 1);
- const OpMetadata& metadata = cp->metadata();
- xla::FrontendAttributes cp1_attr, cp2_attr;
- TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr));
-
- // Create the CollectivePermute instruction for the communication represented
- // by the backedge.
- HloInstruction* cp1 =
- computation->AddInstruction(HloInstruction::CreateCollectivePermute(
- cp->shape(), cp->mutable_operand(0), backedge,
- cp->channel_id().value()));
- cp1->set_metadata(metadata);
- cp1->set_frontend_attributes(cp1_attr);
- int64_t cp1_receiver = backedge.back().second;
-
- // Create the CollectivePermute instruction for the communication represented
- // byt other edges.
- HloInstruction* cp2 =
- computation->AddInstruction(HloInstruction::CreateCollectivePermute(
- cp->shape(), cp->mutable_operand(0), other_edges, next_channel_id));
- cp2->set_metadata(metadata);
- cp2->set_frontend_attributes(cp2_attr);
-
- // Calculate the received data as follows:
- // partition = u32[] partition-id()
- // constant = u32[] constant(cp1_receiver)
- // compare0 = pred[] compare(partition, cp1_received), direction=EQ
- // compare = pred[?] broadcast(compare0), dimensions={}
- // recv-data = type[?] select(compare, cp1_done, cp2_done)
- HloInstruction* partition =
- computation->AddInstruction(HloInstruction::CreatePartitionId());
- HloInstruction* constant = computation->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0(U32, cp1_receiver)));
- HloInstruction* compare0 = computation->AddInstruction(
- HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), partition,
- constant, Comparison::Direction::kEq));
- HloInstruction* compare =
- computation->AddInstruction(HloInstruction::CreateBroadcast(
- ShapeUtil::MakeShape(PRED, cp1->shape().dimensions()), compare0, {}));
- HloInstruction* recv_data =
- computation->AddInstruction(HloInstruction::CreateTernary(
- cp1->shape(), HloOpcode::kSelect, compare, cp1, cp2));
-
- TF_RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data));
- TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp));
-
- return absl::OkStatus();
-}
-} // namespace
-
-absl::StatusOr<bool> CollectivePermuteCycleDecomposer::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- int64_t next_channel_id;
- for (auto comp : module->computations(execution_threads)) {
- for (auto hlo : comp->MakeInstructionPostOrder()) {
- if (hlo->opcode() != HloOpcode::kCollectivePermute) {
- continue;
- }
- auto collective_permute = Cast<HloCollectivePermuteInstruction>(hlo);
- CycleType cycle_type = ShouldDecomposeWithCycleType(*collective_permute,
- threshold_in_bytes_);
- if (cycle_type != CycleType::kUnknown) {
- if (changed == false) {
- next_channel_id = hlo_query::NextChannelId(*module);
- changed = true;
- }
- TF_RETURN_IF_ERROR(DecomposeCollectivePermuteCycle(
- collective_permute, comp, module, next_channel_id++, cycle_type));
- }
- }
- }
- return changed;
-}
-
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h b/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h
deleted file mode 100644
index 508a859..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
-#define XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// CollectivePermuteCycleDecomposer is a pass that converts CollectivePermute
-// instructions with all participants forming either a forward cycle (such as
-// {{0,1},{1,2},{2,3},{3,0}) or a backward cycle (such as {{3,2},{2,1},{1,0},
-// {0,3}}) into two CollectivePermute instructions. We currently restrict
-// this transformation to CollectivePermute using partition mode, with one
-// input, without any context data. Here is an example.
-//
-// before transformation:
-// start = (<rt>, <rt>) collective-permute(data),
-// source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
-//
-// after transformation:
-// partition-id = u32[] partition-id()
-// constant = u32[] constant(0)
-// compare = pred[] compare(u32[] partition-id, u32[] constant),
-// direction=EQ
-// pred = pred[] broadcast(pred[] compare), dimensions={}
-// cp1 = (<rt>, <rt>) collective-permute(data), source_target_pairs={{3,0}}
-// cp2 = (<rt>, <rt>) collective-permute(data),
-// source_target_pairs={{0,1},{1,2},{2,3}}
-// data = <rt> select(pred, cp1, cp2)
-//
-class CollectivePermuteCycleDecomposer : public HloModulePass {
- public:
- explicit CollectivePermuteCycleDecomposer(int64_t threshold_in_bytes)
- : threshold_in_bytes_(threshold_in_bytes) {}
- absl::string_view name() const override {
- return "collective-permute-cycle-decomposer";
- }
-
- using HloPassInterface::Run;
- // Runs CollectivePermuteCycleDecomposer pass on computations in 'module'.
- // Returns whether the 'module' was changed.
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- // Transform only if the size of the CollectivePermute data >= threshold.
- int64_t threshold_in_bytes_;
-};
-
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc b/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc
deleted file mode 100644
index 19436ee..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc
+++ /dev/null
@@ -1,237 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_cycle_decomposer.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-using ::testing::HasSubstr;
-using CollectivePermuteCycleDecomposerTest = HloTestBase;
-
-using ::testing::HasSubstr;
-using CollectivePermuteDecomposerTest = HloTestBase;
-
-TEST_F(CollectivePermuteDecomposerTest, DefaultChannelNotTransformed) {
- const absl::string_view kModuleStr = R"(
- HloModule test
- ENTRY test_computation {
- p = u32[] replica-id()
- ROOT start = u32[] collective-permute(p),
- source_target_pairs={{0,1},{1,0}}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kModuleStr)));
- CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
- const absl::string_view kModuleStr = R"(
- HloModule test
- ENTRY test_computation {
- p = u32[] partition-id()
- ROOT start = u32[] collective-permute(p), channel_id=1,
- source_target_pairs={{0,0}}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kModuleStr)));
- CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
- const absl::string_view kModuleStr = R"(
- HloModule test
- ENTRY test_computation {
- p = u32[] partition-id()
- ROOT start = u32[] collective-permute(p), channel_id=1,
- source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kModuleStr)));
- CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
- const absl::string_view kModuleStr = R"(
- HloModule test
- ENTRY test_computation {
- p = u32[] partition-id()
- ROOT start = u32[3,2] collective-permute(p), channel_id=1,
- source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
- frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
- metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kModuleStr)));
- CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
- EXPECT_TRUE(changed);
-
- auto check_metadata = [](const HloInstruction* inst) {
- EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
- EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
- EXPECT_EQ(inst->metadata().source_line(), 35);
- };
-
- HloCollectivePermuteInstruction* cp1 =
- DynCast<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), "collective-permute"));
- HloCollectivePermuteInstruction* cp2 =
- DynCast<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), "collective-permute.1"));
- EXPECT_NE(cp1, nullptr);
- EXPECT_NE(cp2, nullptr);
- EXPECT_EQ(cp1->operand(0), cp2->operand(0));
- EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
- EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
- EXPECT_THAT(cp1->ToString(),
- HasSubstr("_xla_send_recv_validation=\"{{3,10}}\""));
- EXPECT_THAT(cp2->ToString(),
- HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
- EXPECT_THAT(cp2->ToString(),
- HasSubstr("_xla_send_recv_validation=\"{{0,7},{1,8},{2,9}}\""));
- check_metadata(cp1);
- check_metadata(cp2);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
- const absl::string_view kModuleStr = R"(
- HloModule test
-
- while_cond {
- param = (u32[], f32[2,2], f32[2,2]) parameter(0)
- iter = u32[] get-tuple-element(param), index=0
- max_iter = u32[] constant(3)
- ROOT cmp = pred[] compare(iter, max_iter), direction=LT
- }
-
- while_body {
- param = (u32[], f32[2,2], f32[2,2]) parameter(0)
- iter = u32[] get-tuple-element(param), index=0
- data = f32[2,2] get-tuple-element(param), index=1
- weights = f32[2,2] get-tuple-element(param), index=2
- cp = f32[2,2] collective-permute(data),
- channel_id=1,
- source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}},
- frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}
- matmul = f32[2,2] dot(weights, cp), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- iter_increment = u32[] constant(1)
- next_iter = u32[] add(iter, iter_increment)
- ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights)
- }
-
- ENTRY test_computation {
- iter = u32[] constant(0)
- data = f32[2,2] parameter(0)
- weights = f32[2,2] parameter(1)
- input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights)
- while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
- ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kModuleStr)));
- CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
- EXPECT_TRUE(changed);
- HloCollectivePermuteInstruction* cp1 =
- DynCast<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), "collective-permute"));
- HloCollectivePermuteInstruction* cp2 =
- DynCast<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), "collective-permute.1"));
- EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
- EXPECT_THAT(cp1->ToString(),
- HasSubstr("_xla_send_recv_validation=\"{{3,10}}\""));
- EXPECT_THAT(cp2->ToString(),
- HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
- EXPECT_THAT(cp2->ToString(),
- HasSubstr("_xla_send_recv_validation=\"{{0,7},{1,8},{2,9}}\""));
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
- const absl::string_view kModuleStr = R"(
- HloModule test
- ENTRY test_computation {
- p = u32[] partition-id()
- ROOT start = u32[] collective-permute(p), channel_id=1,
- source_target_pairs={{0,3},{1,0},{2,1},{3,2}},
- frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
- metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kModuleStr)));
- CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
- EXPECT_TRUE(changed);
- auto check_metadata = [](const HloInstruction* inst) {
- EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
- EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
- EXPECT_EQ(inst->metadata().source_line(), 35);
- };
-
- HloCollectivePermuteInstruction* cp1 =
- DynCast<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), "collective-permute"));
- HloCollectivePermuteInstruction* cp2 =
- DynCast<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), "collective-permute.1"));
- EXPECT_NE(cp1, nullptr);
- EXPECT_NE(cp2, nullptr);
- EXPECT_EQ(cp1->operand(0), cp2->operand(0));
- EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
- EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{0,3}}"));
- EXPECT_THAT(cp1->ToString(),
- HasSubstr("_xla_send_recv_validation=\"{{0,7}}\""));
- EXPECT_THAT(cp2->ToString(),
- HasSubstr("source_target_pairs={{1,0},{2,1},{3,2}}"));
- EXPECT_THAT(cp2->ToString(),
- HasSubstr("_xla_send_recv_validation=\"{{1,8},{2,9},{3,10}}\""));
- check_metadata(cp1);
- check_metadata(cp2);
-}
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc b/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc
deleted file mode 100644
index b1e8812..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h"
-
-#include "xla/literal_util.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/while_loop_analysis.h"
-
-namespace xla {
-
-// Finds and returns the non-constant operand in instr.
-//
-// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
-static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
- const HloInstruction* result = nullptr;
- for (const HloInstruction* operand : instr->operands()) {
- if (!operand->IsConstant()) {
- if (result != nullptr) {
- CHECK_EQ(result, operand);
- }
- result = operand;
- }
- }
- CHECK_NE(result, nullptr);
- return result;
-}
-
-// Finds the step (k) for while instruction, if the loop is of the form:
-//
-// while(cond) {
-// ind_var = ind_var + k
-// }
-//
-// If this pattern is not found, it returns std::nullopt.
-std::optional<int64_t> GetStep(HloInstruction* while_inst) {
- // Get the update operation
- std::optional<int64_t> indvar_tuple_idx =
- GetLoopInductionVarTupleIdx(while_inst);
- if (!indvar_tuple_idx) {
- return std::nullopt;
- };
- auto* while_body_indvar_update =
- while_inst->while_body()->root_instruction()->mutable_operand(
- *indvar_tuple_idx);
- auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
-
- HloInstruction* trip_count_increase_step_instr = nullptr;
- if (!Match(while_body_indvar_update,
- match::AddAnyOrder(match::Op().Is(while_body_indvar),
- match::Op(&trip_count_increase_step_instr)))) {
- return std::nullopt;
- }
- return LiteralUtil::LiteralAsScalarInt64(
- trip_count_increase_step_instr->literal());
-}
-
-absl::StatusOr<bool> CollectivePermuteValidIterationAnnotator::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* comp : module->computations(execution_threads)) {
- for (HloInstruction* inst : comp->instructions()) {
- if (inst->opcode() != HloOpcode::kCollectivePermute) {
- continue;
- }
-
- if (inst->frontend_attributes().map().find(kSendRecvValidationAttr) !=
- inst->frontend_attributes().map().end()) {
- continue;
- }
- auto sourceTargetPairs = inst->source_target_pairs();
- if (!IsForwardCycle(sourceTargetPairs) &&
- !IsBackwardCycle(sourceTargetPairs)) {
- continue;
- }
-
- VLOG(2) << "Collective permute with cycle: " << inst->ToString();
-
- int64_t max_device_num = -1;
- for (auto [source, target] : sourceTargetPairs) {
- max_device_num = std::max(std::max(source, target), max_device_num);
- }
- int64_t num_devices = max_device_num + 1;
-
- HloInstruction* whileOp = inst->parent()->WhileCallInstruction();
- if (whileOp == nullptr) {
- VLOG(2) << "No surrounding while op found. Ignoring " << inst->name();
- continue;
- }
- if (!whileOp->frontend_attributes().map().contains(
- "is_pipelined_while_loop"))
- continue;
- TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
- whileOp->backend_config<WhileLoopBackendConfig>());
- if (!config.has_known_trip_count()) {
- VLOG(2) << "Trip count for while loop (" << whileOp->name()
- << "): unknown";
- continue;
- }
-
- int64_t trip_count = config.known_trip_count().n();
- std::optional<int64_t> step = GetStep(whileOp);
- VLOG(2) << "Trip count for while loop (" << whileOp->name()
- << "): " << trip_count;
- if (!step) {
- VLOG(2) << "Could not find step for while operation";
- continue;
- }
- VLOG(2) << "Step for while loop (" << whileOp->name() << "): " << *step;
- if (*step != 1) {
- VLOG(2) << "Step is not 1. Skipping...";
- continue;
- }
-
- // For each source i, the send/recv iteration instances are {i, i+offset}
- // where offset is `number of microbatches * CR - 1`. We know that
- // `trip_count = number_of_microbatches * CR + num_devices - 1` So, offset
- // = number_of_microbatches * CR - 1 = trip_count - num_devices.
- int64_t offset = trip_count - num_devices;
-
- std::vector<std::pair<int64_t, int64_t>> sendRecvValidation(
- sourceTargetPairs.size());
-
- for (size_t currIdx = 0; currIdx < sourceTargetPairs.size(); currIdx++) {
- sendRecvValidation[currIdx] = {currIdx, currIdx + offset};
- }
-
- if (IsBackwardCycle(sourceTargetPairs)) {
- std::reverse(sendRecvValidation.begin(), sendRecvValidation.end());
- }
-
- xla::FrontendAttributes attributes;
- std::string iteration_instances =
- "{" +
- absl::StrJoin(sendRecvValidation, ",",
- [](std::string* out, std::pair<int64_t, int64_t> item) {
- absl::StrAppend(out, "{", item.first, ",",
- item.second, "}");
- }) +
- "}";
- (*attributes.mutable_map())[kSendRecvValidationAttr] =
- iteration_instances;
-
- inst->add_frontend_attributes(attributes);
- VLOG(1) << "Adding " << kSendRecvValidationAttr << " to " << inst->name()
- << ": " << iteration_instances;
- changed = true;
- }
- }
- return changed;
-}
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h b/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h
deleted file mode 100644
index e6b04c9..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
-
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// This is an unsafe transformation that is triggered only if the attribute
-// `is_pipelined_while_loop` is present on a while loop.
-//
-// If a while loop is known to be a pipelined while loop, has a known trip count
-// and increments with step=1, then this pass annotates the `collective-permute`
-// operations within the while loop with valid iterations for each GPU. This is
-// only done when the source-target pairs of the `collective-permute` operation
-// form a forward or backward cycle.
-//
-// For example, if the trip count is 10 (iteration 0 to 9), with step=1, and the
-// source-target pairs of a `collective-permute` operation are
-// `{{0,1},{1,2},{2,3},{3,0}}`, then this pass would annotate such operation
-// with `_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"`. This annotation
-// means that
-// - for GPU index 0, the valid iterations are 0,1,2,3,4,5,6.
-// - for GPU index 1, the valid iterations are 1,2,3,4,5,6,7.
-// - for GPU index 2, the valid iterations are 2,3,4,5,6,7,8.
-// - for GPU index 3, the valid iterations are 3,4,5,6,7,8,9.
-//
-// The index in the list denotes the device index and the bounds {start,end} are
-// inclusive. For more examples, look at
-// `xla/service/spmd/collective_permute_valid_iteration_annotator_tests.cc`.
-class CollectivePermuteValidIterationAnnotator : public HloModulePass {
- public:
- CollectivePermuteValidIterationAnnotator() = default;
- absl::string_view name() const override {
- return "collective-permute-valid-iteration-annotator";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc b/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc
deleted file mode 100644
index 3d1d0b4..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc
+++ /dev/null
@@ -1,174 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h"
-
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/while_loop_trip_count_annotator.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace {
-
-using CollectivePermuteValidIterationAnnotatorTest = HloTestBase;
-
-TEST_F(CollectivePermuteValidIterationAnnotatorTest, NoChange) {
- // We expect no changes here because the while loop is not labelled as
- // `is_pipelined_while_loop`.
- absl::string_view hlo_string = R"(
- HloModule test, entry_computation_layout={()->(s32[], s32[])}
- %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
- %param = (s32[], s32[]) parameter(0)
- %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
- %one = s32[] constant(1)
- %i_plus_one = s32[] add(s32[] %i, s32[] %one)
- %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
- ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %permute)
- }
- %Cond (param.1: (s32[], s32[])) -> pred[] {
- %param.1 = (s32[], s32[]) parameter(0)
- %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
- %trip_count = s32[] constant(10)
- ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
- }
- ENTRY %test () -> (s32[], s32[]) {
- %i_start = s32[] constant(0)
- %p_start = s32[] constant(0)
- %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
- ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string, 1, 4));
-
- HloPassPipeline pipeline("my-pass-pipeline");
-
- pipeline.AddPass<WhileLoopTripCountAnnotator>();
- pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
- EXPECT_FALSE(changed);
-
- HloCollectivePermuteInstruction* cp =
- DynCastOrNull<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), HloOpcode::kCollectivePermute));
-
- ASSERT_NE(cp, nullptr);
-
- auto sendRecvValidationIt =
- cp->frontend_attributes().map().find(kSendRecvValidationAttr);
- ASSERT_EQ(sendRecvValidationIt, cp->frontend_attributes().map().end());
-}
-
-TEST_F(CollectivePermuteValidIterationAnnotatorTest, ForwardCycle) {
- absl::string_view hlo_string = R"(
- HloModule test, entry_computation_layout={()->(s32[], s32[])}
- %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
- %param = (s32[], s32[]) parameter(0)
- %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
- %one = s32[] constant(1)
- %i_plus_one = s32[] add(s32[] %i, s32[] %one)
- %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
- ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
- }
- %Cond (param.1: (s32[], s32[])) -> pred[] {
- %param.1 = (s32[], s32[]) parameter(0)
- %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
- %trip_count = s32[] constant(10)
- ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
- }
- ENTRY %test () -> (s32[], s32[]) {
- %i_start = s32[] constant(0)
- %p_start = s32[] constant(0)
- %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
- ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string, 1, 4));
-
- HloPassPipeline pipeline("my-pass-pipeline");
-
- pipeline.AddPass<WhileLoopTripCountAnnotator>();
- pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloCollectivePermuteInstruction* cp =
- DynCastOrNull<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), HloOpcode::kCollectivePermute));
-
- ASSERT_NE(cp, nullptr);
-
- auto sendRecvValidationIt =
- cp->frontend_attributes().map().find(kSendRecvValidationAttr);
- ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
- std::string sendRecvValidationAttr = sendRecvValidationIt->second;
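- // With a trip count of 10 and a 4-device forward cycle, source-target pair i
- // is expected to be valid for iterations [i, i+6].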
- EXPECT_EQ(sendRecvValidationAttr, "{{0,6},{1,7},{2,8},{3,9}}");
-}
-
-TEST_F(CollectivePermuteValidIterationAnnotatorTest, BackwardCycle) {
- absl::string_view hlo_string = R"(
- HloModule test, entry_computation_layout={()->(s32[], s32[])}
- %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
- %param = (s32[], s32[]) parameter(0)
- %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
- %one = s32[] constant(1)
- %i_plus_one = s32[] add(s32[] %i, s32[] %one)
- %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
- ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
- }
- %Cond (param.1: (s32[], s32[])) -> pred[] {
- %param.1 = (s32[], s32[]) parameter(0)
- %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
- %trip_count = s32[] constant(10)
- ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
- }
- ENTRY %test () -> (s32[], s32[]) {
- %i_start = s32[] constant(0)
- %p_start = s32[] constant(0)
- %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
- ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string, 1, 4));
-
- HloPassPipeline pipeline("my-pass-pipeline");
-
- pipeline.AddPass<WhileLoopTripCountAnnotator>();
- pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloCollectivePermuteInstruction* cp =
- DynCastOrNull<HloCollectivePermuteInstruction>(
- FindInstruction(module.get(), HloOpcode::kCollectivePermute));
-
- ASSERT_NE(cp, nullptr);
-
- auto sendRecvValidationIt =
- cp->frontend_attributes().map().find(kSendRecvValidationAttr);
- ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
- std::string sendRecvValidationAttr = sendRecvValidationIt->second;
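- // For the 4-device backward cycle the bounds are mirrored across the pairs:
- // source-target pair i is expected to be valid for iterations [3-i, 9-i].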
- EXPECT_EQ(sendRecvValidationAttr, "{{3,9},{2,8},{1,7},{0,6}}");
-}
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
deleted file mode 100644
index d113a6b..0000000
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
+++ /dev/null
@@ -1,811 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/command_buffer_scheduling.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/strings/match.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/ffi/ffi_api.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/variant_visitor.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-using CommandBuffer = CommandBufferScheduling::CommandBuffer;
-using CommandBufferConfig = CommandBufferScheduling::CommandBufferConfig;
-
-// Returns true if HLO computation can be executed as a command buffer.
-static bool IsCommand(const HloComputation* computation,
- const CommandBufferConfig& config);
-
-//===----------------------------------------------------------------------===//
-// No-op HLO operations.
-//===----------------------------------------------------------------------===//
-
-// Some of the HLO operations do not have corresponding operations at run time
-// and can be safely wrapped into command buffers together with load-bearing
-// commands.
-
-static bool IsConstant(const HloInstruction* hlo) {
- return hlo->opcode() == HloOpcode::kConstant;
-}
-
-static bool IsParameter(const HloInstruction* hlo) {
- return hlo->opcode() == HloOpcode::kParameter;
-}
-
-// Returns true if the instruction is a no-op at run time and doesn't have a
-// corresponding Thunk or Command (metadata-only operation).
-static bool IsNoOp(const HloInstruction* hlo) {
- return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
- HloOpcode::kGetTupleElement>(hlo);
-}
-
-//===----------------------------------------------------------------------===//
-// Synchronous HLO operations mapped to commands.
-//===----------------------------------------------------------------------===//
-
-// Synchronous HLO operations can be wrapped into command buffers when they
-// have corresponding commands.
-
-// This is a template to define pattern-matching functions for HLO instructions
-// that do not have a corresponding instruction class.
-template <HloOpcode op>
-static bool IsCommand(const HloInstruction*, const CommandBufferConfig&);
-
-// While loops can be executed inside command buffers only if condition and body
-// regions can be executed as command buffers.
-template <>
-bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
- const CommandBufferConfig& config) {
- return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
- IsCommand(hlo->while_body(), config) &&
- IsCommand(hlo->while_condition(), config);
-}
-
-// A conditional can be executed inside a command buffer only if all of its
-// branch computations can be executed as command buffers.
-template <>
-bool IsCommand<HloOpcode::kConditional>(const HloInstruction* hlo,
- const CommandBufferConfig& config) {
- return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
- absl::c_all_of(hlo->branch_computations(),
- [&](const HloComputation* comp) {
- return IsCommand(comp, config);
- });
-}
-
-static bool IsCommand(const HloCustomCallInstruction* hlo,
- const CommandBufferConfig& config) {
- // cuBLAS gemms are represented in the HLO as custom call instructions.
- if (config.enabled_commands.contains(DebugOptions::CUBLAS) &&
- IsLegacyCublasMatmul(*hlo)) {
- return true;
- }
-
- if (config.enabled_commands.contains(DebugOptions::CUBLASLT) &&
- (IsCublasLtMatmul(*hlo) || IsCublasLtMatmulF8(*hlo))) {
- return true;
- }
-
- if (config.enabled_commands.contains(DebugOptions::CUDNN) &&
- IsCustomCallTofMHA(*hlo)) {
- VLOG(3) << "Recording FusedMHA, target " << hlo->custom_call_target()
- << " into command buffer.";
- return true;
- }
-
- if (!config.enabled_commands.contains(DebugOptions::CUSTOM_CALL)) {
- return false;
- }
-
- if (config.enabled_legacy_custom_call_targets.contains(
- hlo->custom_call_target())) {
- VLOG(3) << "Recording legacy custom call target "
- << hlo->custom_call_target() << " into command buffer.";
- return true;
- }
-
- // A special case for the jax-triton kernel while it is not yet ported to FFI.
- if (hlo->custom_call_target() == "triton_kernel_call" &&
- // TODO(b/327718087): This is an ugly hack to prevent capturing triton
- // custom calls that might do autotuning at run time.
- !absl::StrContains(hlo->metadata().op_name(), "Autotuner")) {
- return true;
- }
-
- // Check if FFI handler is compatible with command buffers.
- auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
- return registration.ok()
- ? ffi::IsCommandBufferCompatible(registration->traits)
- : false;
-}
-
-static bool IsCommand(const HloInstruction* hlo,
- const CommandBufferConfig& config) {
- if (auto* fusion = DynCast<HloFusionInstruction>(hlo)) {
- auto gpu_config = fusion->backend_config<GpuBackendConfig>();
- const FusionBackendConfig& backend_config =
- gpu_config->fusion_backend_config();
- if (backend_config.kind() == kCuDnnFusionKind) {
- return config.enabled_commands.contains(DebugOptions::CUDNN);
- }
- const auto& custom_config = backend_config.custom_fusion_config();
- if (custom_config.name() == "address_computation") {
- auto fusion_analysis =
- HloFusionAnalysis::Create(fusion, &config.device_description);
- const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
- auto custom_call_adaptor = HloBfsFindIf(
- adaptor.GetRoots(), adaptor,
- [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
- const auto* custom_call = static_cast<const HloCustomCallInstruction*>(
- &custom_call_adaptor->instruction());
- return IsCommand(custom_call, config);
- }
- if (custom_config.name() == "dynamic_address_computation") {
- return false;
- }
- return config.enabled_commands.contains(DebugOptions::FUSION);
- }
-
- if (auto* sort = DynCast<HloSortInstruction>(hlo))
- return config.enabled_commands.contains(DebugOptions::FUSION);
-
- if (hlo->opcode() == HloOpcode::kPartitionId ||
- hlo->opcode() == HloOpcode::kReplicaId) {
- return config.enabled_commands.contains(DebugOptions::FUSION);
- }
-
- if (auto* custom_call = DynCast<HloCustomCallInstruction>(hlo))
- return IsCommand(custom_call, config);
-
- if (hlo->opcode() == HloOpcode::kWhile)
- return IsCommand<HloOpcode::kWhile>(hlo, config);
-
- if (hlo->opcode() == HloOpcode::kConditional)
- return IsCommand<HloOpcode::kConditional>(hlo, config);
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// Asynchronous HLO operations mapped to commands.
-//===----------------------------------------------------------------------===//
-
-// Asynchronous HLO operations can be wrapped into command buffers only when
-// both start and done operations can be put into the same command buffer.
-// Command buffer semantics imply that when command buffer execution
-// completes, all recorded commands have also completed, which means that if
-// the done operation is not part of the same command buffer, we would change
-// the execution semantics and create an additional synchronization point.
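-//
-// Illustrative sketch (not actual HLO): if only the start operation were
-// captured, the end of the command buffer would act as a synchronization
-// point and the collective would no longer overlap with the work that
-// follows it:
-//
-//   command_buffer { %start = all-reduce-start(...) }   <-- implicit sync
-//   %done = all-reduce-done(%start)
-//
-// This is why start and done operations are only captured as matched pairs.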
-
-static bool IsAsyncStartCommand(const HloInstruction* hlo,
- const CommandBufferConfig& config) {
- if (hlo->opcode() == HloOpcode::kAllReduceStart ||
- hlo->opcode() == HloOpcode::kAllGatherStart) {
- return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
- }
-
- if (hlo->opcode() == HloOpcode::kAsyncStart) {
- if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
- return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
- }
- }
-
- return false;
-}
-
-static bool IsAsyncDoneCommand(const HloInstruction* hlo,
- const CommandBufferConfig& config) {
- if (hlo->opcode() == HloOpcode::kAllReduceDone ||
- hlo->opcode() == HloOpcode::kAllGatherDone) {
- return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
- }
-
- if (hlo->opcode() == HloOpcode::kAsyncDone) {
- if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
- return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
- }
- }
-
- return false;
-}
-
-// Finds the async-done HLO operation corresponding to an async-start one.
-static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) {
- if (start->opcode() == HloOpcode::kAllReduceStart ||
- start->opcode() == HloOpcode::kAllGatherStart) {
- CHECK(start->users().size() == 1); // NOLINT, checked by HLO verifier
- return start->users().front();
- } else if (start->opcode() == HloOpcode::kAsyncStart) {
- return start->async_chain_done();
- }
-
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// HLO computations mapped to command buffers.
-//===----------------------------------------------------------------------===//
-
-// Returns true if HLO computation can be executed as a command buffer.
-static bool IsCommand(const HloComputation* computation,
- const CommandBufferConfig& config) {
- return absl::c_all_of(
- computation->instructions(), [&](const HloInstruction* inst) {
- return IsNoOp(inst) || IsConstant(inst) || IsParameter(inst) ||
- IsCommand(inst, config) || IsAsyncStartCommand(inst, config) ||
- IsAsyncDoneCommand(inst, config);
- });
-}
-
-//===----------------------------------------------------------------------===//
-
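-// Removes no-op instructions from the end of a collected sequence so that the
-// last instruction in a command buffer sequence is a load-bearing command.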
-static void RemoveTrailingNoOps(HloInstructionSequence& seq) {
- std::vector<HloInstruction*> instructions = seq.instructions();
- for (int i = instructions.size() - 1; i >= 0; i--) {
- if (HloInstruction* inst = instructions[i]; IsNoOp(inst)) {
- seq.remove_instruction(inst);
- } else {
- break;
- }
- }
-}
-
-//===----------------------------------------------------------------------===//
-// Discovering sequences of compatible Hlo instructions
-//===----------------------------------------------------------------------===//
-
-// The input is a scheduled sequence of instructions. This function collects
-// subsequences that will be extracted as command buffers.
-std::vector<HloInstructionSequence>
-CommandBufferScheduling::CollectCommandBufferSequences(
- const HloInstructionSequence schedule, const CommandBufferConfig& config,
- int32_t min_num_commands) {
- std::vector<HloInstructionSequence> sequences;
-
- HloInstructionSequence current_seq;
- int64_t num_commands_in_current_seq = 0;
-
- // Adds `current_seq` to `sequences` if it has enough commands in it.
- auto collect_current_seq = [&]() {
- if (num_commands_in_current_seq >= std::max(1, min_num_commands)) {
- RemoveTrailingNoOps(current_seq);
- sequences.push_back(std::move(current_seq));
- }
- current_seq = HloInstructionSequence();
- num_commands_in_current_seq = 0;
- };
-
- auto& instructions = schedule.instructions();
-
- // Collect the sequence of instructions that contains the async start and its
- // corresponding done instruction. If there is another start instruction
- // between the original start and done, we may potentially extend the sequence
- // to include its corresponding done instruction. For example, if we call this
- // function on async_start_a in the following sequence:
- //
- // async_start_a
- // async_start_b
- // async_done_a
- // async_done_b
- //
- // The returned sequence will contain async_done_b, so that all async pairs
- // are captured by the same command buffer.
- auto collect_async_region = [&](const HloInstruction* start) {
- auto get_index = [&](const HloInstruction* inst) -> size_t {
- auto it = std::find(instructions.begin(), instructions.end(), inst);
- return std::distance(instructions.begin(), it);
- };
-
- HloInstructionSequence seq;
- size_t done_index = get_index(FindAsyncDoneCommand(start));
- for (size_t i = get_index(start); i <= done_index; i++) {
- HloInstruction* inst = instructions.at(i);
- if (IsAsyncStartCommand(inst, config)) {
- const HloInstruction* done = FindAsyncDoneCommand(inst);
- done_index = std::max(done_index, get_index(done));
- }
- seq.push_back(inst);
- }
- return seq;
- };
-
- // Check that instructions are safe to be captured by a command buffer, and
- // that we do not capture an unmatched async done instruction.
- auto check_async_region = [&](const HloInstructionSequence& seq) {
- if (!absl::c_all_of(seq.instructions(), [&](HloInstruction* inst) {
- return IsNoOp(inst) || IsCommand(inst, config) ||
- IsAsyncStartCommand(inst, config) ||
- IsAsyncDoneCommand(inst, config);
- })) {
- return false;
- }
-
- absl::flat_hash_set<HloInstruction*> done_instructions;
- for (const HloInstruction* inst : seq.instructions()) {
- if (IsAsyncStartCommand(inst, config)) {
- done_instructions.insert(FindAsyncDoneCommand(inst));
- }
- if (IsAsyncDoneCommand(inst, config)) {
- if (!done_instructions.contains(inst)) {
- return false;
- }
- }
- }
- return true;
- };
-
- for (size_t i = 0; i < instructions.size(); i++) {
- HloInstruction* inst = instructions.at(i);
-
- // We add no-op instructions to the current sequence only if they act as glue
- // between commands. We do not create command sequences consisting only of
- // no-op instructions: the first and last instructions in a command buffer
- // are always load-bearing commands.
- if (IsNoOp(inst) && num_commands_in_current_seq) {
- current_seq.push_back(inst);
- continue;
- }
-
- // Synchronous commands can always be added to the instruction sequence.
- if (IsCommand(inst, config)) {
- num_commands_in_current_seq++;
- current_seq.push_back(inst);
- continue;
- }
-
- // We capture async commands if all instructions between start and done can
- // be outlined into a command buffer.
- if (IsAsyncStartCommand(inst, config)) {
- HloInstructionSequence seq = collect_async_region(inst);
- if (check_async_region(seq)) {
- num_commands_in_current_seq += seq.instructions().size();
- for (HloInstruction* inst : seq.instructions()) {
- current_seq.push_back(inst);
- }
- i += seq.instructions().size() - 1;
- continue;
- }
- }
-
- // If we didn't find the next command, collect the current sequence and
- // start a new one.
- collect_current_seq();
- }
-
- // Don't forget to collect the final command sequence.
- collect_current_seq();
- return sequences;
-}
-
-// This function moves kParameter and kConstant instructions in a computation to
-// the beginning of the computation. This simplifies the construction of command
-// buffer computations because we don't need to deal with parameters and
-// constants that have users outside of a command buffer.
-absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
- HloComputation* computation) {
- HloInstructionSequence new_sequence;
- HloSchedule& schedule = computation->parent()->schedule();
- HloInstructionSequence& sequence = schedule.GetOrCreateSequence(computation);
-
- for (HloInstruction* inst : sequence.instructions()) {
- if (IsParameter(inst) || IsConstant(inst)) {
- new_sequence.push_back(inst);
-
- // Because we move the instruction to the front of the computation it can't
- // have any control predecessors; however, silently dropping them is unsafe
- // because transitive dependencies can define the schedule order, so we
- // forward the control predecessors to all users.
- for (HloInstruction* control_predecessor : inst->control_predecessors()) {
- for (HloInstruction* user : inst->users()) {
- TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user));
- }
- }
- TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
- }
- }
-
- for (HloInstruction* inst : sequence.instructions()) {
- if (!IsParameter(inst) && !IsConstant(inst)) {
- new_sequence.push_back(inst);
- }
- }
-
- schedule.set_sequence(computation, new_sequence);
- return absl::OkStatus();
-}
-
-//===----------------------------------------------------------------------===//
-// Prepares command buffer from sequence of instructions
-//===----------------------------------------------------------------------===//
-
-absl::StatusOr<CommandBuffer> CommandBufferScheduling::PrepareCommandBuffer(
- const HloInstructionSequence& seq, HloModule* module) {
- auto builder = HloComputation::Builder("command_buffer");
-
- absl::Span<HloInstruction* const> instructions =
- absl::MakeSpan(seq.instructions());
-
- // A set of instructions that will be moved into command buffer computation.
- absl::flat_hash_set<HloInstruction*> in_command_buffer(instructions.begin(),
- instructions.end());
-
- // The sequence might use results of instructions that are not captured by the
- // sequence. We pass those results as parameters and map the producers of the
- // results to their corresponding parameter instructions.
- absl::flat_hash_map<HloInstruction*, HloParameterInstruction*> parameters;
-
- // Mapping from command buffer instructions to their clones in the command
- // buffer computation body.
- absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
-
- // Maps HLO instructions in the original computation to instructions in the
- // command buffer: (a) a parameter corresponding to a captured value, or
- // (b) a cloned instruction corresponding to a command.
- auto mapped_operands = [&](HloInstruction* instr) {
- absl::InlinedVector<HloInstruction*, 4> operands;
- for (HloInstruction* operand : instr->operands()) {
- if (auto it = inst_mapping.find(operand); it != inst_mapping.end())
- operands.push_back(it->second);
- }
- return operands;
- };
-
- // Create parameters in the command buffer computation for captured values.
- for (HloInstruction* inst : instructions) {
- for (HloInstruction* operand : inst->operands()) {
- // We already mapped this operand to a parameter.
- if (parameters.contains(operand)) continue;
-
- // Operand instruction is a part of the command buffer.
- if (in_command_buffer.contains(operand)) continue;
-
- // Create a new parameter for a value defined outside of the command buffer.
- int64_t parameter_id = parameters.size();
- auto* parameter = Cast<HloParameterInstruction>(
- builder.AddInstruction(HloInstruction::CreateParameter(
- parameter_id, operand->shape(), "p")));
-
- parameter->UniquifyName(module);
- parameter->UniquifyId(module);
- inst_mapping[operand] = parameters[operand] = parameter;
- }
- }
-
- // Clone commands into the command buffer body with mapped operands.
- for (HloInstruction* inst : seq.instructions()) {
- HloCloneContext ctx(inst->GetModule());
-
- // Cloned instructions should call the same computations, as the original
- // instructions will be dead-code eliminated.
- for (HloComputation* called_computation : inst->called_computations()) {
- // Async computations can only be referenced by a single async chain at
- // a time. Detach the current chain to let its copy bind to the
- // computation.
- if (called_computation->IsAsyncComputation()) {
- called_computation->RemoveAsyncStart();
- }
- ctx.MapComputation(called_computation, called_computation);
- }
-
- inst_mapping[inst] = builder.AddInstruction(
- inst->CloneWithNewOperands(inst->shape(), mapped_operands(inst), &ctx));
- inst_mapping[inst]->UniquifyId(module);
- }
-
- // Convert parameters to command buffer arguments.
- std::vector<HloInstruction*> arguments(parameters.size());
- for (auto& [argument, parameter] : parameters) {
- arguments[parameter->parameter_number()] = argument;
- }
-
- // Collect command buffer `results` (instructions replaced in the original
- // computation) and `returned` (their clones in the command buffer).
- std::vector<HloInstruction*> results;
- std::vector<HloInstruction*> returned;
-
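- // An instruction's value must be returned from the command buffer if it is
- // the root of the computation or has users outside of the captured sequence.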
- auto has_external_users = [&](HloInstruction* inst) {
- return inst->IsRoot() || absl::c_any_of(inst->users(), [&](auto* user) {
- return !in_command_buffer.contains(user);
- });
- };
-
- for (HloInstruction* inst : instructions) {
- if (has_external_users(inst)) {
- results.push_back(inst);
- returned.push_back(inst_mapping[inst]);
- }
- }
-
- // If we return multiple results, wrap them into a tuple.
- if (returned.size() > 1) {
- HloInstruction* inst =
- builder.AddInstruction(HloInstruction::CreateTuple(returned));
- inst->UniquifyName(module);
- inst->UniquifyId(module);
- }
-
- std::unique_ptr<HloComputation> comp = builder.Build();
- comp->UniquifyName(module);
- comp->SetUniqueId(comp->root_instruction()->unique_id());
-
- return CommandBuffer{std::move(arguments), std::move(results),
- std::move(comp), std::move(inst_mapping)};
-}
-
-//===----------------------------------------------------------------------===//
-// Rewrites original computation into command buffer call
-//===----------------------------------------------------------------------===//
-
-absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
- HloComputation* parent, const HloInstructionSequence& seq,
- CommandBuffer command_buffer) {
- if (command_buffer.results.empty())
- return absl::InternalError("command buffer results must not be empty");
-
- // If we have more than one result, we return them as a tuple and get
- // individual values using `get-tuple-element` instructions. Otherwise we
- // simply return the result from the command buffer computation.
- Shape cmd_buffer_result_shape;
- bool has_single_result = command_buffer.results.size() == 1;
-
- if (has_single_result) {
- cmd_buffer_result_shape = command_buffer.results[0]->shape();
- } else {
- absl::InlinedVector<Shape, 4> shapes;
- shapes.reserve(command_buffer.results.size());
- for (auto* res : command_buffer.results) shapes.push_back(res->shape());
- cmd_buffer_result_shape = ShapeUtil::MakeTupleShape(shapes);
- }
-
- HloComputation* computation =
- parent->parent()->AddComputation(std::move(command_buffer.computation),
- /*is_entry=*/false);
-
- HloInstruction* call = parent->AddInstruction(HloInstruction::CreateCall(
- cmd_buffer_result_shape, command_buffer.arguments, computation));
-
- // Replace all users of the original results with the command buffer results.
- if (has_single_result) {
- TF_RETURN_IF_ERROR(command_buffer.results[0]->ReplaceAllUsesWith(call));
- } else {
- for (int i = 0; i < command_buffer.results.size(); i++) {
- TF_RETURN_IF_ERROR(
- command_buffer.results[i]->ReplaceAllUsesWith(parent->AddInstruction(
- HloInstruction::CreateGetTupleElement(call, i))));
- }
- }
-
- // As we are running after scheduling, we have to keep the schedule valid.
- HloSchedule& schedule = parent->parent()->schedule();
-
- // Update schedule to replace the last instruction with a command buffer call.
- // Removal of the rest of the instructions in the sequence is handled by
- // schedule update below.
- HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent);
- sequence.replace_instruction(seq.instructions().back(), call);
-
- // Rebuild the original instruction sequence schedule in the newly created
- // command buffer computation to guarantee that we'll get exactly the same
- // buffer assignment result as if we were running without command buffers.
- HloInstructionSequence cmd_buffer_schedule;
- for (auto* argument : command_buffer.arguments) {
- cmd_buffer_schedule.push_back(command_buffer.inst_mapping[argument]);
- }
- for (auto* inst : seq.instructions()) {
- cmd_buffer_schedule.push_back(command_buffer.inst_mapping[inst]);
- }
- if (!has_single_result) {
- cmd_buffer_schedule.push_back(computation->root_instruction());
- }
- schedule.set_sequence(computation, cmd_buffer_schedule);
-
- // Forward control dependencies between original instructions to instructions
- // in the command buffer computation.
- auto& inst_mapping = command_buffer.inst_mapping;
- for (HloInstruction* inst : seq.instructions()) {
- HloInstruction* cmd_inst = inst_mapping[inst];
-
- // Forward control dependencies to the new instruction inside the command
- // buffer. If the dependent instruction is not captured by the command
- // buffer, forward the dependency to the command buffer call instead.
- for (HloInstruction* predecessor : inst->control_predecessors()) {
- if (auto it = inst_mapping.find(predecessor); it != inst_mapping.end()) {
- // If the predecessor mapped to a parameter instruction, we need to forward
- // the control dependency to the call operation; otherwise we add a control
- // dependency between commands in the command buffer.
- HloInstruction* cmd_predecessor = it->second;
- if (IsParameter(cmd_predecessor)) {
- TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
- } else {
- TF_RETURN_IF_ERROR(cmd_predecessor->AddControlDependencyTo(cmd_inst));
- }
- } else {
- TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
- }
- }
-
- for (HloInstruction* successor : inst->control_successors()) {
- if (auto it = inst_mapping.find(successor); it != inst_mapping.end()) {
- HloInstruction* cmd_successor = it->second;
- TF_RETURN_IF_ERROR(cmd_inst->AddControlDependencyTo(cmd_successor));
- } else {
- TF_RETURN_IF_ERROR(call->AddControlDependencyTo(successor));
- }
- }
-
- TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
- }
-
- // Traverse in reverse order, as the original sequence was topologically
- // sorted and we can't remove instructions that still have users.
- for (int32_t i = seq.instructions().size() - 1; i >= 0; i--) {
- TF_RETURN_IF_ERROR(parent->RemoveInstruction(seq.instructions()[i]));
- }
-
- return computation;
-}
-
-//===----------------------------------------------------------------------===//
-
-CommandBufferScheduling::CommandBufferScheduling(
- const se::DeviceDescription& device_description,
- int32_t gpu_toolkit_version, int32_t gpu_driver_version)
- : device_description_(device_description),
- gpu_toolkit_version_(gpu_toolkit_version),
- gpu_driver_version_(gpu_driver_version) {}
-
-absl::StatusOr<bool> CommandBufferScheduling::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- // We run command buffer scheduling after regular scheduling to guarantee that
- // command buffers will not change the execution order and buffer assignment
- // compared to regular execution. Some operations (e.g. async collectives)
- // can't be captured into command buffers, and forming too-large command
- // buffers too early can hurt the scheduling of async operations.
- if (!module->has_schedule()) return Internal("module is not scheduled");
-
- const DebugOptions& debug_options = module->config().debug_options();
-
- absl::flat_hash_set<DebugOptions::CommandBufferCmdType> commands;
- for (auto cmd_type : debug_options.xla_gpu_enable_command_buffer()) {
- commands.insert(static_cast<DebugOptions::CommandBufferCmdType>(cmd_type));
- }
-
- absl::flat_hash_set<std::string> legacy_custom_call_targets;
- for (const auto& target :
- debug_options.legacy_command_buffer_custom_call_targets()) {
- legacy_custom_call_targets.insert(target);
- }
-
- CommandBufferConfig config{std::move(commands),
- std::move(legacy_custom_call_targets),
- device_description_};
-
- // Erase command buffer cmd types that are not supported by the gpu runtime.
- static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS};
- static constexpr auto kRequireTracing = {
- DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN,
- DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES};
-
- auto erase = [&](absl::Span<const DebugOptions::CommandBufferCmdType> cmds) {
- for (auto cmd : cmds) {
- if (config.enabled_commands.erase(cmd)) {
- VLOG(1) << "Removed command buffer support for "
- << DebugOptions::CommandBufferCmdType_Name(cmd)
- << " as it's not supported with gpu toolkit version "
- << gpu_toolkit_version_ << " and driver version "
- << gpu_driver_version_
- << ". This might negatively impact performance. To enable "
- << DebugOptions::CommandBufferCmdType_Name(cmd)
- << " support in command buffers use cuda-compat package: "
-#if defined(PLATFORM_GOOGLE)
- << "set CUDA_COMPAT_LOAD=1 env variable.";
-#else
- << "https://docs.nvidia.com/deploy/cuda-compatibility/.";
-#endif
- }
- }
- };
-
- // Check if CUDA/ROCM driver supports required features.
- auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
- if (std::min(gpu_toolkit_version_, gpu_driver_version_) < 12030) {
- erase(kRequireTracing); // cuStreamBeginCaptureToGraph
- erase(kRequireConditionals); // on-device control flow
- }
- };
- auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
- erase(kRequireConditionals); // on-device control flow
- };
-
- std::visit(VariantVisitor{erase_cuda, erase_rocm},
- device_description_.gpu_compute_capability());
-
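- // Visit computations from callers to callees, tracking computations that end
- // up nested inside command buffers so they can be skipped below.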
- auto order = module->MakeComputationPostOrder();
- std::reverse(order.begin(), order.end());
- absl::flat_hash_set<HloComputation*> processed_command_buffers;
-
- for (HloComputation* comp : order) {
- // Skip special computations that are not lowered to thunks.
- if (comp->IsFusionComputation() || comp->IsAsyncComputation() ||
- comp->IsCustomCallComputation())
- continue;
-
- // Skip computations that are already part of command buffers.
- if (processed_command_buffers.contains(comp)) continue;
-
- TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp));
-
- std::vector<HloInstructionSequence> sequences =
- CollectCommandBufferSequences(
- module->schedule().sequence(comp), config,
- debug_options.xla_gpu_graph_min_graph_size());
-
- for (const HloInstructionSequence& seq : sequences) {
- TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer,
- PrepareCommandBuffer(seq, comp->parent()));
- TF_ASSIGN_OR_RETURN(
- HloComputation * command_buffer_computation,
- RewriteCommandBuffer(comp, seq, std::move(command_buffer)));
-
- // All computations reachable from a command buffer computation are nested
- // command buffers (e.g. body computations attached to a while operation).
- for (HloComputation* called :
- command_buffer_computation->MakeEmbeddedComputationsList()) {
- processed_command_buffers.insert(called);
- }
- }
- }
- TF_RETURN_IF_ERROR(module->schedule().Update());
-
- return true;
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h
deleted file mode 100644
index 78590a8..0000000
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h
+++ /dev/null
@@ -1,143 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_
-#define XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_
-
-#include <cstdint>
-#include <memory>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla::gpu {
-
-// Lift fusion instructions to command buffers.
-//
-// Before the pass:
-// %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
-// ...
-// }
-//
-// ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-// %a = s32[] parameter(0)
-// %b = s32[] parameter(1)
-// ROOT %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop,
-// calls=%fused_computation
-// }
-//
-// After the pass:
-// %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
-// ...
-// }
-//
-// %command_buffer (param_0: s32[], param_1: s32[]) -> s32[] {
-// %param_0 = s32[] parameter(0)
-// %param_1 = s32[] parameter(1)
-// ROOT %fusion = s32[] fusion(s32[] %param_0, s32[] %param_1), kind=kLoop,
-// calls=%fused_computation
-// }
-//
-// ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-// %a = s32[] parameter(0)
-// %b = s32[] parameter(1)
-// ROOT %call = s32[] call(s32[] %a, s32[] %b), to_apply=%command_buffer
-// }
-//
-// We currently do not have a command_buffer HLO operation, so we'll start with
-// a kCall opcode with an attached HLO computation. We'll consider graduating
-// command buffers to a first-class operation later.
-class CommandBufferScheduling : public HloModulePass {
- public:
- struct CommandBufferConfig {
- // DebugOptions controls which commands are enabled. Long term we want to
- // remove this flag and enable all supported commands by default.
- absl::flat_hash_set<DebugOptions::CommandBufferCmdType> enabled_commands;
- absl::flat_hash_set<std::string> enabled_legacy_custom_call_targets;
- const se::DeviceDescription& device_description;
- };
-
- CommandBufferScheduling(const se::DeviceDescription& device_description,
- int32_t gpu_toolkit_version,
- int32_t gpu_driver_version);
-
- absl::string_view name() const override {
- return "command-buffer-scheduling";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- static std::vector<HloInstructionSequence> CollectCommandBufferSequences(
- HloInstructionSequence schedule, const CommandBufferConfig& config,
- int32_t min_num_commands = 1);
-
- // Moves kParameter and kConstant instructions in a computation to
- // the beginning of the computation. This simplifies the construction of
- // command buffer computations because we don't need to deal with parameters
- // and constants that have users outside of a command buffer.
- static absl::Status MoveParametersAndConstantsToFront(
- HloComputation* computation);
-
- struct CommandBuffer {
- // Command buffer arguments (call instruction arguments).
- std::vector<HloInstruction*> arguments;
-
- // Command buffer result (call instruction result tuple).
- std::vector<HloInstruction*> results;
-
- // Hlo computation corresponding to a command buffer body.
- std::unique_ptr<HloComputation> computation;
-
- // Mapping from original instructions to their clones in the command buffer.
- absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
- };
-
- // Prepares a command buffer from the instruction sequence. Values that are
- // used by the sequence but produced by instructions outside of it are passed
- // in as parameters. Results of instructions in the sequence are returned in a
- // tuple (if the command buffer has a single result we don't wrap it into a tuple).
- static absl::StatusOr<CommandBuffer> PrepareCommandBuffer(
- const HloInstructionSequence& seq, HloModule* module);
-
- // Rewrites the prepared command buffer computation into HLO operations in
- // the parent computation (calls the command buffer and replaces all users).
- static absl::StatusOr<HloComputation*> RewriteCommandBuffer(
- HloComputation* parent, const HloInstructionSequence& seq,
- CommandBuffer command_buffer);
-
- private:
- se::DeviceDescription device_description_;
- // For NVIDIA GPUs, XLA can be compiled with a CUDA version that is newer than
- // the version supported by the driver, e.g. we can compile for CUDA 12.3 but
- // have a 12.1 driver installed. When deciding which command buffer features
- // we can use, we have to consider both versions.
- int32_t gpu_toolkit_version_;
- int32_t gpu_driver_version_;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
deleted file mode 100644
index 3a46193..0000000
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
+++ /dev/null
@@ -1,1018 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/command_buffer_scheduling.h"
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-class CommandBufferSchedulingTest : public HloTestBase {
- public:
- // Use CUDA version 12.3 for testing, as it has all the features we rely on.
- static constexpr int32_t kCudaVersion = 12030;
-
- const se::DeviceDescription& device_desc() {
- return backend().default_stream_executor()->GetDeviceDescription();
- }
-
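- // Enable a broad set of command types and lower the minimum graph size so
- // that two-command sequences are outlined into command buffers.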
- DebugOptions GetDebugOptionsForTest() override {
- auto debug_options = HloTestBase::GetDebugOptionsForTest();
- debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
- debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS);
- debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
- debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
- debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
- debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
- debug_options.set_xla_gpu_graph_min_graph_size(2);
- return debug_options;
- }
-};
-
-using CommandBuffer = CommandBufferScheduling::CommandBuffer;
-
-TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[], b: s32[]) -> s32[] {
- %a = s32[] parameter(0)
- %b = s32[] parameter(1)
- %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
- %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1
- ROOT %custom-call = s32[] custom-call(s32[] %fusion, s32[] %fusion.1), custom_call_target="some target"
- })";
-
- const char* expected = R"(
-// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
-// CHECK: %[[P0]] = s32[] parameter(0)
-// CHECK: %[[P1]] = s32[] parameter(1)
-// CHECK: %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
-// CHECK: %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
-// CHECK: ROOT %tuple = (s32[], s32[]) tuple(%fusion, %fusion.1)
-// CHECK: }
-//
-// CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-// CHECK: %a = s32[] parameter(0)
-// CHECK: %b = s32[] parameter(1)
-// CHECK: %call = (s32[], s32[]) call(%a, %b), to_apply=%command_buffer
-// CHECK: %get-tuple-element = s32[] get-tuple-element(%call), index=0
-// CHECK: %get-tuple-element.1 = s32[] get-tuple-element(%call), index=1
-// CHECK: ROOT %custom-call = s32[] custom-call(%get-tuple-element, %get-tuple-element.1), custom_call_target="some target"
-// CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
- %a = s32[] parameter(0)
- %b = s32[] parameter(1)
- %c = (s32[], s32[]) parameter(2)
- %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
- %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
- %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
- %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
- %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
- %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
- %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
- ROOT %custom-call.1 = s32[] custom-call(s32[] %fusion.3), custom_call_target="some target"
- })";
-
- const char* expected = R"(
-// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[], [[P2:.+]]: (s32[], s32[])) -> s32[] {
-// CHECK: %[[P0]] = s32[] parameter(0)
-// CHECK: %[[P1]] = s32[] parameter(1)
-// CHECK: %[[P2]] = (s32[], s32[]) parameter(2)
-// CHECK: %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
-// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%[[P2]]), index=0
-// CHECK: ROOT {{.*}} = s32[] fusion(%[[F0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
-// CHECK: }
-
-// CHECK: %command_buffer.2 ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
-// CHECK: %[[P0]] = s32[] parameter(0)
-// CHECK: %[[P1]] = s32[] parameter(1)
-// CHECK: %[[F2:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.2
-// CHECK: ROOT {{.*}} = s32[] fusion(%[[P0]], %[[F2]]), kind=kLoop, calls=%fused_computation.3
-// CHECK: }
-
-// CHECK: ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
-// CHECK: %a = s32[] parameter(0)
-// CHECK: %b = s32[] parameter(1)
-// CHECK: %c = (s32[], s32[]) parameter(2)
-// CHECK: %[[CMD0:.+]] = s32[] call(%a, %b, %c), to_apply=%command_buffer
-// CHECK: %e = s32[] get-tuple-element(%c), index=1
-// CHECK: %[[CALL:.+]] = s32[] custom-call(%[[CMD0]], %e), custom_call_target="some target"
-// CHECK: %[[CMD1:.+]] = s32[] call(%[[CALL]], %a), to_apply=%command_buffer.2
-// CHECK: ROOT {{.*}} = s32[] custom-call(%[[CMD1]]), custom_call_target="some target"
-// CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %add (p0: s32[4], p1: s32[4]) -> s32[4] {
- %p0 = s32[4] parameter(0)
- %p1 = s32[4] parameter(1)
- ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
- }
-
- ENTRY %main (a: s32[4]) -> s32[4] {
- %a = s32[4] parameter(0)
- %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
- replica_groups={{0,1}}, to_apply=%add,
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
- ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
- CHECK: %[[P0]] = s32[4]{0} parameter(0)
- CHECK: %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
- CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
- CHECK: }
-
- CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
- CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
- CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
- CHECK: to_apply=%command_buffer
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- ENTRY %main (a: s32[2]) -> s32[4] {
- %a = s32[2] parameter(0)
-
- %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a),
- channel_id=555, replica_groups={{0,1}}, dimensions={0},
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-
- ROOT %done = s32[4]{0} all-gather-done(%start)
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s32[2]) -> s32[4] {
- CHECK: %[[P0]] = s32[2]{0} parameter(0)
- CHECK: %[[START:.+]] = {{.*}} all-gather-start(%[[P0]])
- CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-gather-done(%[[START]])
- CHECK: }
-
- CHECK: ENTRY %main (a: s32[2]) -> s32[4] {
- CHECK: %[[A:.+]] = s32[2]{0} parameter(0)
- CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
- CHECK: to_apply=%command_buffer
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %add (p0: s32[], p1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[4]) -> s32[2] {
- %a = s32[4] parameter(0)
-
- %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a),
- channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add,
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-
- ROOT %done = s32[2]{0} reduce-scatter-done(%start)
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[2] {
- CHECK: %[[P0]] = s32[4]{0} parameter(0)
- CHECK: %[[START:.+]] = {{.*}} reduce-scatter-start(%[[P0]])
- CHECK: ROOT %[[DONE:.+]] = s32[2]{0} reduce-scatter-done(%[[START]])
- CHECK: }
-
- CHECK: ENTRY %main (a: s32[4]) -> s32[2] {
- CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
- CHECK: ROOT %[[CALL:.+]] = s32[2]{0} call(%[[A]]),
- CHECK: to_apply=%command_buffer
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %add (p0: s32[4], p1: s32[4]) -> s32[4] {
- %p0 = s32[4] parameter(0)
- %p1 = s32[4] parameter(1)
- ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
- }
-
- ENTRY %main (a: s32[4]) -> s32[4] {
- %a = s32[4] parameter(0)
- %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
- replica_groups={{0,1}}, to_apply=%add,
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
- %bitcast = s32[4] bitcast(s32[4]{0} %a)
- ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
- CHECK: %[[P0]] = s32[4]{0} parameter(0)
- CHECK: %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
- CHECK: %[[BITCAST:.+]] = s32[4]{0} bitcast(%[[P0]])
- CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
- CHECK: }
-
- CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
- CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
- CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
- CHECK: to_apply=%command_buffer
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %add (p0: s32[4], p1: s32[4]) -> s32[4] {
- %p0 = s32[4] parameter(0)
- %p1 = s32[4] parameter(1)
- ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
- }
-
- ENTRY %main (a: s32[4]) -> s32[4] {
- %a = s32[4] parameter(0)
- %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
- replica_groups={{0,1}}, to_apply=%add,
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
- %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
- replica_groups={{0,1}}, to_apply=%add,
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
- %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
- ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
- CHECK: %[[P0]] = s32[4]{0} parameter(0)
- CHECK: %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
- CHECK: %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
- CHECK: %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
- CHECK: ROOT %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
- CHECK: }
-
- CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
- CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
- CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
- CHECK: to_apply=%command_buffer
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %add (p0: s32[4], p1: s32[4]) -> s32[4] {
- %p0 = s32[4] parameter(0)
- %p1 = s32[4] parameter(1)
- ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
- }
-
- ENTRY %main (a: s32[4], b:s32[]) -> s32[] {
- %a = s32[4] parameter(0)
- %b = s32[] parameter(1)
- %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
- replica_groups={{0,1}}, to_apply=%add,
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
- %c = s32[] custom-call(), custom_call_target="target"
- %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
- replica_groups={{0,1}}, to_apply=%add,
- backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
- %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
- %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
- %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation
- ROOT %fusion.1 = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation.1
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
- CHECK: %[[P0]] = s32[] parameter(0)
- CHECK: %[[P1]] = s32[] parameter(1)
- CHECK: %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
- CHECK: ROOT %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
- CHECK: }
-
- CHECK: ENTRY %main (a: s32[4], b: s32[]) -> s32[] {
- CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
- CHECK: %[[B:.+]] = s32[] parameter(1)
- CHECK: %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[A]])
- CHECK: %[[C:.+]] = s32[] custom-call()
- CHECK: %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[A]])
- CHECK: %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
- CHECK: %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
- CHECK: %call = s32[] call(%b, %c), to_apply=%command_buffer
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
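-// Unit test for CollectCommandBufferSequences: the non-command custom call
-// splits the schedule into two sequences (fusion/get-tuple-element/fusion and
-// fusion/fusion).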
-TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
- %a = s32[] parameter(0)
- %b = s32[] parameter(1)
- %c = (s32[], s32[]) parameter(2)
- %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
- %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
- %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
- %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
- %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
- %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
- ROOT %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(hlo));
-
- HloInstructionSequence seq;
- for (HloInstruction* x : module->entry_computation()->instructions()) {
- seq.push_back(x);
- }
- EXPECT_EQ(seq.size(), 10);
-
- CommandBufferScheduling::CommandBufferConfig config{
- {DebugOptions::FUSION}, {}, device_desc()};
-
- std::vector<HloInstructionSequence> command_buffer_sequences =
- CommandBufferScheduling::CollectCommandBufferSequences(seq, config);
- EXPECT_EQ(command_buffer_sequences.size(), 2);
-
- std::vector<HloInstruction*> seq_0 =
- command_buffer_sequences[0].instructions();
- EXPECT_EQ(seq_0.size(), 3);
- EXPECT_EQ(seq_0[0]->opcode(), HloOpcode::kFusion);
- EXPECT_EQ(seq_0[1]->opcode(), HloOpcode::kGetTupleElement);
- EXPECT_EQ(seq_0[2]->opcode(), HloOpcode::kFusion);
-
- std::vector<HloInstruction*> seq_1 =
- command_buffer_sequences[1].instructions();
- EXPECT_EQ(seq_1.size(), 2);
- EXPECT_EQ(seq_1[0]->opcode(), HloOpcode::kFusion);
- EXPECT_EQ(seq_1[1]->opcode(), HloOpcode::kFusion);
-}
-
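-// MoveParametersAndConstantsToFront hoists parameter %c above the first
-// fusion so that all parameters come before other instructions.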
-TEST_F(CommandBufferSchedulingTest, MoveParametersToFront) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
- %a = s32[] parameter(0)
- %b = s32[] parameter(1)
- %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
- %c = s32[] parameter(2)
- ROOT %fusion.1 = s32[] fusion(s32[] %a, s32[] %c), kind=kLoop, calls=%fused_computation.1
- })";
-
- const char* expected = R"(
-// CHECK: ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
-// CHECK: %a = s32[] parameter(0)
-// CHECK: %b = s32[] parameter(1)
-// CHECK: %c = s32[] parameter(2)
-// CHECK: %fusion = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation
-// CHECK: ROOT %fusion.1 = s32[] fusion(%a, %c), kind=kLoop, calls=%fused_computation.1
-// CHECK: })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(hlo));
- TF_ASSERT_OK(CommandBufferScheduling::MoveParametersAndConstantsToFront(
- module->entry_computation()));
- TF_ASSERT_OK_AND_ASSIGN(
- bool filecheck_matches,
- RunFileCheck(
- module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
- expected));
- EXPECT_TRUE(filecheck_matches);
-}
-
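-// PrepareCommandBuffer builds the %command_buffer computation from the
-// selected sequence: its parameters are the external operands (%a and the
-// custom call %b) and its root tuple returns the values used outside the
-// sequence.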
-TEST_F(CommandBufferSchedulingTest, PrepareCommandBuffer) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation(param_0: s32[], param_1: s32[]) -> (s32[], s32[]) {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %tuple.1 = (s32[], s32[]) tuple(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[], b: s32[]) -> s32[] {
- %a = s32[] parameter(0)
- %b = s32[] custom-call(), custom_call_target="target"
- %fusion = (s32[], s32[]) fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
- %d = s32[] get-tuple-element((s32[], s32[]) %fusion), index=0
- %fusion.1 = s32[] fusion(s32[] %a, s32[] %d), kind=kLoop, calls=%fused_computation.1
- ROOT %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %d), custom_call_target="some target"
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule(hlo));
-
- EXPECT_EQ(module->entry_computation()->instruction_count(), 6);
- std::vector<HloInstruction*> instructions;
- HloInstructionSequence seq;
- for (HloInstruction* inst : module->entry_computation()->instructions()) {
- if (inst->opcode() == HloOpcode::kFusion ||
- inst->opcode() == HloOpcode::kGetTupleElement) {
- seq.push_back(inst);
- }
- instructions.push_back(inst);
- }
-
- TF_ASSERT_OK_AND_ASSIGN(
- CommandBuffer command_buffer,
- CommandBufferScheduling::PrepareCommandBuffer(seq, module.get()));
- HloComputation* computation = module->AddComputation(
- std::move(command_buffer.computation), /*is_entry=*/false);
-
- const char* expected = R"(
-// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
-// CHECK: %[[P0]] = s32[] parameter(0)
-// CHECK: %[[P1]] = s32[] parameter(1)
-// CHECK: %fusion = (s32[], s32[]) fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
-// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%fusion), index=0
-// CHECK: %fusion.1 = s32[] fusion(%[[P0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
-// CHECK: ROOT {{.*}} = (s32[], s32[]) tuple(%[[V0]], %fusion.1)
-// CHECK:})";
-
- TF_ASSERT_OK_AND_ASSIGN(
- bool filecheck_matches,
- RunFileCheck(computation->ToString(
- HloPrintOptions{}.set_print_operand_shape(false)),
- expected));
- EXPECT_TRUE(filecheck_matches);
-
- auto& arguments = command_buffer.arguments;
- ASSERT_EQ(arguments.size(), 2);
- EXPECT_EQ(arguments[0], instructions[0]);
- EXPECT_EQ(arguments[1], instructions[1]);
-
- auto& results = command_buffer.results;
- ASSERT_EQ(results.size(), 2);
- EXPECT_EQ(results[0], instructions[3]);
- EXPECT_EQ(results[1], instructions[4]);
-}
-
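-// Control dependencies are rewired onto the command buffer call: the call
-// inherits %custom-call as a control predecessor, and %fusion.2's dependency
-// on %fusion.1 becomes a dependency on the call.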
-TEST_F(CommandBufferSchedulingTest, ForwardControlDependencies) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.2 (param_0: s32[], param_1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[], b: s32[]) -> s32[] {
- %a = s32[] parameter(0)
- %b = s32[] parameter(1)
- %custom-call = s32[] custom-call(), custom_call_target="some target"
- %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation, control-predecessors={%custom-call}
- %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion}
- %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
- %fusion.2 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%fusion.1}
- ROOT %custom-call.2 = s32[] custom-call(s32[] %fusion.1, s32[] %fusion.2), custom_call_target="some target"
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
- CHECK: %[[P0]] = s32[] parameter(0)
- CHECK: %[[P1]] = s32[] parameter(1)
- CHECK: %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]])
- CHECK: ROOT {{.*}} = s32[] fusion(%[[P0]], %[[P1]]), {{.*}} control-predecessors={%[[F0]]}
- CHECK: }
-
- CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
- CHECK: %a = s32[] parameter(0)
- CHECK: %b = s32[] parameter(1)
- CHECK: %custom-call = s32[] custom-call(), custom_call_target="some target"
- CHECK: %call = s32[] call(%a, %b), to_apply=%command_buffer, control-predecessors={%custom-call}
- CHECK: %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
- CHECK: %[[F3:.+]] = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%call}
- CHECK: ROOT %custom-call.2 = s32[] custom-call(%call, %[[F3]]), custom_call_target="some target"
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
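-// When a captured instruction both consumes and control-depends on an
-// external custom call, the command buffer call takes that custom call as an
-// operand and inherits the control dependency.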
-TEST_F(CommandBufferSchedulingTest, ForwardControlDependenciesToParams) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation.0 (p0: s32[], p1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- %fused_computation.1 (p0: s32[], p1: s32[]) -> s32[] {
- %p0 = s32[] parameter(0)
- %p1 = s32[] parameter(1)
- ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
- }
-
- ENTRY %main (a: s32[], b: s32[]) -> s32[] {
- %a = s32[] parameter(0)
- %b = s32[] parameter(1)
- %custom-call = s32[] custom-call(), custom_call_target="some target"
- %fusion = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.0, control-predecessors={%custom-call}
- ROOT %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %b), kind=kLoop, calls=%fused_computation.1
- })";
-
- const char* expected = R"(
- CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
- CHECK: %a = s32[] parameter(0)
- CHECK: %b = s32[] parameter(1)
- CHECK: %[[CUSTOM_CALL:.+]] = s32[] custom-call(), custom_call_target="some target"
- CHECK: ROOT {{.*}} call(%[[CUSTOM_CALL]], %a, %b), to_apply=%command_buffer, control-predecessors={%[[CUSTOM_CALL]]}
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
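-// A while op whose body contains a non-command custom call is not captured;
-// instead, the fusions inside the body are outlined into their own command
-// buffer.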
-TEST_F(CommandBufferSchedulingTest, WhileNotCommand) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation (param_0: f32[1]) -> f32[1] {
- %param_0 = f32[1]{0} parameter(0)
- ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
- }
-
- %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
- %param_0.1 = f32[1]{0} parameter(0)
- %param_1 = f32[1]{0} parameter(1)
- ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
- }
-
- %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
- %param_0.2 = f32[1]{0} parameter(0)
- %param_1.1 = f32[1]{0} parameter(1)
- ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
- }
-
- %fused_computation.3 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
- %param_0.1 = f32[1]{0} parameter(0)
- %param_1 = f32[1]{0} parameter(1)
- ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
- }
-
- %body (Arg_.3: f32[1]) -> f32[1] {
- %constant_4 = f32[1]{0} constant({1})
- %Arg_.3 = f32[1]{0} parameter(0)
- %custom-call = s32[] custom-call(), custom_call_target="some target"
- %add = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1, control-predecessors={%custom-call}
- ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %add, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.3, control-predecessors={%custom-call}
- }
-
- %cond (Arg_.11: f32[1]) -> pred[] {
- %constant = f32[1]{0} constant({100})
- %Arg_.11 = f32[1]{0} parameter(0)
- %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
- ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
- }
-
- ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
- %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
- %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
- %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
- ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: f32[1], [[P1:.+]]: f32[1]) -> f32[1] {
- CHECK: %[[P0]] = f32[1]{0} parameter(0)
- CHECK: %[[P1]] = f32[1]{0} parameter(1)
- CHECK: %[[ADD:.*]] = f32[1]{0} fusion(%[[P0]], %[[P1]]), kind=kLoop
- CHECK: ROOT {{.*}} = f32[1]{0} fusion(%[[ADD]], %[[P1]]), kind=kLoop
- CHECK: }
-
- CHECK: %[[BODY:[a-z_0-9.]+]] ([[P0:.+]]: f32[1]) -> f32[1] {
- CHECK: %[[C1:.*]] = f32[1]{0} constant({1})
- CHECK: %[[P0]] = f32[1]{0} parameter(0)
- CHECK: %[[CC:.*]] = s32[] custom-call(), custom_call_target="some target"
- CHECK: ROOT %call = f32[1]{0} call(%[[P0]], %[[C1]]), to_apply=%command_buffer, control-predecessors={%[[CC]]}
- CHECK: }
-
- CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
- CHECK: %[[ARG0]] = f32[1]{0} parameter(0)
- CHECK: %[[COPY:.*]] = f32[1]{0} fusion(%[[ARG0]]), kind=kLoop
- CHECK: %[[WHILE:.*]] = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY]]
- CHECK: ROOT %[[BC:.+]] = f32[] bitcast(%[[WHILE]])
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
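-// A while op whose body and condition contain only command-compatible
-// fusions is captured into the command buffer together with the preceding
-// fusion.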
-TEST_F(CommandBufferSchedulingTest, While) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation (param_0: f32[1]) -> f32[1] {
- %param_0 = f32[1]{0} parameter(0)
- ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
- }
-
- %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
- %param_0.1 = f32[1]{0} parameter(0)
- %param_1 = f32[1]{0} parameter(1)
- ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
- }
-
- %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
- %param_0.2 = f32[1]{0} parameter(0)
- %param_1.1 = f32[1]{0} parameter(1)
- ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
- }
-
- %body (Arg_.3: f32[1]) -> f32[1] {
- %constant_4 = f32[1]{0} constant({1})
- %Arg_.3 = f32[1]{0} parameter(0)
- ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1
- }
-
- %cond (Arg_.11: f32[1]) -> pred[] {
- %constant = f32[1]{0} constant({100})
- %Arg_.11 = f32[1]{0} parameter(0)
- %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
- ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
- }
-
- ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
- %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
- %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
- %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
- ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: f32[1]) -> f32[1] {
- CHECK: %[[P0]] = f32[1]{0} parameter(0)
- CHECK: %[[COPY:.*]] = f32[1]{0} fusion(%[[P0]]), kind=kLoop
- CHECK: ROOT {{.*}} = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY:[a-z_0-9.]+]]
- CHECK: }
-
- CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
- CHECK: %[[ARG0]] = f32[1]{0} parameter(0)
- CHECK: %call = f32[1]{0} call(%[[ARG0]]), to_apply=%command_buffer
- CHECK: ROOT %[[BC:.+]] = f32[] bitcast(%call)
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
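-// A conditional whose branches contain only command-compatible fusions is
-// captured into the command buffer together with the fusion computing its
-// branch index.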
-TEST_F(CommandBufferSchedulingTest, Conditional) {
- const char* hlo = R"(
- HloModule TestModule, is_scheduled=true
-
- %fused_computation.1 (param_0.2: s32[5]) -> s32[5] {
- %param_0.2 = s32[5]{0} parameter(0)
- ROOT %negate.2 = s32[5]{0} negate(s32[5]{0} %param_0.2)
- }
-
- %region_0.7 (Arg_.8: s32[5]) -> (s32[5]) {
- %Arg_.8 = s32[5]{0} parameter(0)
- %wrapped_negate.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.8), kind=kLoop, calls=%fused_computation.1
- ROOT %tuple.3 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_negate.1)
- }
-
- %fused_computation.2 (param_0.3: s32[5]) -> s32[5] {
- %param_0.3 = s32[5]{0} parameter(0)
- ROOT %not.2 = s32[5]{0} not(s32[5]{0} %param_0.3)
- }
-
- %region_1.10 (Arg_.11: s32[5]) -> (s32[5]) {
- %Arg_.11 = s32[5]{0} parameter(0)
- %wrapped_not.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.11), kind=kLoop, calls=%fused_computation.2
- ROOT %tuple.4 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_not.1)
- }
-
- %fused_computation.3 (param_0.4: s32[5]) -> s32[5] {
- %param_0.4 = s32[5]{0} parameter(0)
- ROOT %multiply.2 = s32[5]{0} multiply(s32[5]{0} %param_0.4, s32[5]{0} %param_0.4)
- }
-
- %region_2.13 (Arg_.14: s32[5]) -> (s32[5]) {
- %Arg_.14 = s32[5]{0} parameter(0)
- %wrapped_multiply.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.14), kind=kLoop, calls=%fused_computation.3
- ROOT %tuple.5 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_multiply.1)
- }
-
- %fused_computation (param_0.1: s64[]) -> s32[] {
- %constant_1 = s32[] constant(0)
- %param_0.1 = s64[] parameter(0)
- %convert.2 = s32[] convert(s64[] %param_0.1)
- %constant_0 = s32[] constant(2)
- ROOT %clamp.2 = s32[] clamp(s32[] %constant_1, s32[] %convert.2, s32[] %constant_0)
- }
-
- ENTRY %main.17 (Arg_0.1: s64[], Arg_1.2: s32[5]) -> s32[5] {
- %Arg_0.1 = s64[] parameter(0), sharding={replicated}
- %fusion = s32[] fusion(s64[] %Arg_0.1), kind=kLoop, calls=%fused_computation
- %Arg_1.2 = s32[5]{0} parameter(1), sharding={replicated}
- %conditional.16.clone = (s32[5]{0}) conditional(s32[] %fusion, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2), branch_computations={%region_0.7, %region_1.10, %region_2.13}
- ROOT %get-tuple-element = s32[5]{0} get-tuple-element((s32[5]{0}) %conditional.16.clone), index=0
- })";
-
- const char* expected = R"(
- CHECK: %command_buffer ([[P0:.+]]: s64[], [[P1:.+]]: s32[5]) -> (s32[5]) {
- CHECK: %[[P0]] = s64[] parameter(0)
- CHECK: %[[P1]] = s32[5]{0} parameter(1)
- CHECK: %[[FUSION:.*]] = s32[] fusion(%[[P0]]), kind=kLoop
- CHECK: ROOT {{.*}} = (s32[5]{0}) conditional(%[[FUSION]], %[[P1]], %[[P1]], %[[P1]]), branch_computations={%[[B1:[a-z_0-9.]+]], %[[B2:[a-z_0-9.]+]], %[[B3:[a-z_0-9.]+]]}
- CHECK: }
-
- CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: s64[], [[ARG1:.+]]: s32[5]) -> s32[5] {
- CHECK: %[[ARG0]] = s64[] parameter(0)
- CHECK: %[[ARG1]] = s32[5]{0} parameter(1)
- CHECK: %call = (s32[5]{0}) call(%[[ARG0]], %[[ARG1]]), to_apply=%command_buffer
- CHECK: ROOT %[[GEP:.+]] = s32[5]{0} get-tuple-element(%call)
- CHECK: })";
-
- RunAndFilecheckHloRewrite(
- hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- expected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
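-// cuDNN graph fusions (__cudnn$fusion) are command-buffer compatible: the
-// whole entry computation collapses into a single call to %command_buffer.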
-TEST_F(CommandBufferSchedulingTest, CuDnnFusionGraphCaptureWorks) {
- const std::string kHloText = R"(
-HloModule m, is_scheduled=true
-
-fusion0 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- ROOT d = f32[64,64] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-fusion1 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- ROOT d = f32[64,64] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-}
-
-fusion_a {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- ROOT a = f32[64,64] add(p0, p1)
-}
-
-ENTRY e {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- d0 = f32[64,64] fusion(p0, p1), kind=kCustom,
- calls=fusion0,
- backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
- a = f32[64,64] fusion(d0, d0), kind=kLoop, calls=fusion_a
- ROOT d1 = f32[64,64] fusion(a, p1), kind=kCustom,
- calls=fusion1,
- backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
-})";
-
- const std::string kExpected = R"(
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: ROOT
-; CHECK-SAME: call(
-; CHECK-SAME: to_apply=%command_buffer
-})";
-
- RunAndFilecheckHloRewrite(
- kHloText,
- CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
- kExpected, [](HloModule* module) {
- EXPECT_TRUE(module->has_schedule());
- TF_CHECK_OK(module->schedule().Verify());
- });
-}
-
-} // namespace
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
deleted file mode 100644
index 40bbb7a..0000000
--- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
+++ /dev/null
@@ -1,1193 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/conv_algorithm_picker.h"
-
-#include <algorithm>
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <string>
-#include <string_view>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "xla/autotuning.pb.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_autotuning.pb.h"
-#include "xla/service/gpu/gpu_conv_runner.h"
-#include "xla/service/gpu/hlo_algorithm_denylist.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/slow_operation_alarm.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/cuda/cuda_platform_id.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/numeric_options.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/rocm/rocm_platform_id.h"
-#include "xla/stream_executor/scratch_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/tsl/util/env_var.h"
-#include "xla/tsl/util/proto/proto_utils.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/numbers.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#if CUDNN_VERSION >= 90000
-#include "third_party/gpus/cudnn/cudnn_ops.h"
-#else
-#include "third_party/gpus/cudnn/cudnn_ops_infer.h"
-#endif // CUDNN_VERSION >= 90000
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#endif
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using se::DeviceMemoryBase;
-using se::dnn::AlgorithmDesc;
-using std::optional;
-
-class ScratchAllocator : public se::ScratchAllocator {
- public:
- ScratchAllocator(int device_ordinal,
- se::DeviceMemoryAllocator* memory_allocator)
- : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
-
- int64_t GetMemoryLimitInBytes() override {
- return ScratchAllocator::GetDefaultMemoryLimitInBytes();
- }
- int64_t TotalAllocatedBytes() { return total_allocated_bytes_; }
-
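- // The default workspace limit is 1 << 12 MiB (4 GiB); it can be overridden
- // via the TF_CUDNN_WORKSPACE_LIMIT_IN_MB environment variable. The value is
- // converted from MiB to bytes below.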
- static int64_t GetDefaultMemoryLimitInBytes() {
- int64_t value;
- TF_CHECK_OK(tsl::ReadInt64FromEnvVar("TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
- 1LL << 12, &value));
- return value * (1LL << 20);
- }
-
- absl::StatusOr<se::DeviceMemory<uint8_t>> AllocateBytes(
- int64_t byte_size) override;
-
- template <typename T>
- absl::StatusOr<se::DeviceMemory<T>> Allocate(int64_t num_elements) {
- TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8_t> bytes,
- AllocateBytes(num_elements * sizeof(T)));
- return se::DeviceMemory<T>(bytes);
- }
-
- private:
- const int device_ordinal_;
- se::DeviceMemoryAllocator* memory_allocator_;
- std::vector<se::OwningDeviceMemory> allocated_buffers_;
- int64_t total_allocated_bytes_ = 0;
-};
-
-absl::StatusOr<se::DeviceMemory<uint8_t>> ScratchAllocator::AllocateBytes(
- int64_t byte_size) {
- CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
- if (byte_size > GetMemoryLimitInBytes()) {
- return absl::ResourceExhaustedError(absl::StrFormat(
- "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size,
- GetMemoryLimitInBytes()));
- }
-
- TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
- memory_allocator_->Allocate(device_ordinal_, byte_size,
- /*retry_on_failure=*/false));
- total_allocated_bytes_ += byte_size;
-
- se::DeviceMemoryBase buffer_addr = *allocated_buffer;
- allocated_buffers_.push_back(std::move(allocated_buffer));
- return se::DeviceMemory<uint8_t>(buffer_addr);
-}
-
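-// Builds the list of candidate convolution runners for `config`, dispatching
-// on the convolution kind (fused, graph, or plain forward/backward).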
-absl::StatusOr<std::vector<GenericConvRunner>> GetAlgorithms(
- const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend,
- bool use_fallback, const se::NumericOptions& numeric_options) {
- TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
- GetDNNConvKindFromCudnnConvKind(config.kind));
-
- TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type,
- GetDNNDataTypeFromPrimitiveType(config.input_type));
-
- TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type,
- GetDNNDataTypeFromPrimitiveType(config.output_type));
-
- se::StreamExecutor* stream_exec = stream->parent();
- std::vector<GenericConvRunner> result;
-
- auto dnn = stream_exec->AsDnn();
- if (dnn == nullptr) {
- return absl::InvalidArgumentError("No DNN in stream executor.");
- }
- switch (kind) {
- default:
- return Internal("Unknown ConvolutionKind %d", kind);
- case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: {
- if (!config.fusion) {
- return Internal(
- "GpuConvConfig had fusion ConvolutionKind but no FusionConfig.");
- }
- std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>> runners;
- TF_RETURN_IF_ERROR(dnn->GetFusedConvolveRunners(
- use_cudnn_frontend,
- // This refers to the kind of convolution op inside the fusion, not
- // the whole fused graph.
- se::dnn::ConvolutionKind::FORWARD, input_type,
- BiasTypeForInputType(input_type), output_type,
- /* conv_input_scale = */ config.conv_result_scale,
- /* side_input_scale = */ config.fusion->side_input_scale,
- /* leakyrelu_alpha = */ config.fusion->leakyrelu_alpha, stream,
- config.input_descriptor, config.filter_descriptor,
- config.bias_descriptor, config.output_descriptor, config.conv_desc,
- use_fallback, config.fusion->mode, numeric_options, &runners));
- for (auto& runner : runners) {
- TF_ASSIGN_OR_RETURN(
- auto runner_cache,
- se::dnn::LazyOpRunner<se::dnn::FusedConvOp>::FromOpRunner(
- std::move(runner)));
- result.emplace_back(std::move(runner_cache));
- }
- break;
- }
-
- case se::dnn::ConvolutionKind::FORWARD_GRAPH: {
- std::vector<std::unique_ptr<const se::dnn::GraphConvRunner>> runners;
- // This path is cuDNN-only; constructing graph-convolution runners here
- // takes no device buffers or scratch allocator.
- TF_RETURN_IF_ERROR(dnn->GetGraphConvolveRunners(
- kind, input_type, output_type, stream, config.input_descriptor,
- config.filter_descriptor, config.output_descriptor, config.conv_desc,
- use_fallback, numeric_options, &runners, config.serialized_graph));
- for (auto& runner : runners) {
- TF_ASSIGN_OR_RETURN(
- auto runner_cache,
- se::dnn::LazyOpRunner<se::dnn::GraphConvOp>::FromOpRunner(
- std::move(runner)));
- result.emplace_back(std::move(runner_cache));
- }
- break;
- }
-
- case se::dnn::ConvolutionKind::FORWARD:
- case se::dnn::ConvolutionKind::BACKWARD_DATA:
- case se::dnn::ConvolutionKind::BACKWARD_FILTER: {
- std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
- // This path is cuDNN-only, where the DeviceMemoryBase arguments and the
- // scratch allocator are unused, so they are all passed as nullptr.
- TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
- use_cudnn_frontend, kind, input_type, output_type, stream,
- config.input_descriptor,
- /* input_data = */ DeviceMemoryBase(nullptr),
- config.filter_descriptor,
- /* filter_data = */ DeviceMemoryBase(nullptr),
- config.output_descriptor,
- /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
- use_fallback, nullptr, numeric_options, &runners));
-
- for (auto& runner : runners) {
- TF_ASSIGN_OR_RETURN(
- auto runner_cache,
- se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
- std::move(runner)));
- result.emplace_back(std::move(runner_cache));
- }
- break;
- }
- }
-
- return result;
-}
-
-absl::StatusOr<std::vector<std::unique_ptr<const se::dnn::ConvRunner>>>
-GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
- absl::Span<se::DeviceMemoryBase> operand_buffers,
- absl::Span<se::DeviceMemoryBase> result_buffers,
- se::StreamExecutor* stream_exec,
- ScratchAllocator* scratch_allocator, se::Stream* stream,
- const se::NumericOptions& numeric_options) {
- TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
-
- TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
- GetDNNConvKindFromCudnnConvKind(config.kind));
-
- TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
- GetDNNDataTypeFromPrimitiveType(config.output_type));
-
- TF_ASSIGN_OR_RETURN(
- GpuConvParams params,
- GetGpuConvParams(config, operand_buffers, result_buffers));
-
- std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
- auto dnn = stream_exec->AsDnn();
- if (dnn == nullptr) {
- return absl::InvalidArgumentError("No DNN in stream executor.");
- }
- TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
- /* use_cudnn_frontend = */ false, kind, dtype, dtype, stream,
- params.config->input_descriptor, params.input_buf,
- params.config->filter_descriptor, params.filter_buf,
- params.config->output_descriptor, params.output_buf,
- params.config->conv_desc,
- /* use_fallback = */ false, scratch_allocator, numeric_options,
- &runners));
-
- return runners;
-}
-
-std::string NumBytesToString(int64_t bytes) {
- return absl::StrCat(tsl::strings::HumanReadableNumBytes(bytes), " (", bytes,
- "B)");
-}
-
-CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
- se::dnn::VersionInfo version = GetDnnVersionInfoOrDefault(stream_executor);
- CudnnVersion cudnn_version;
- cudnn_version.set_major(version.major_version());
- cudnn_version.set_minor(version.minor_version());
- cudnn_version.set_patch(version.patch());
-
- return cudnn_version;
-}
-
-ComputeCapability GetComputeCapability(se::StreamExecutor* stream_executor) {
- ComputeCapability cc;
- se::CudaComputeCapability se_cc =
- stream_executor->GetDeviceDescription().cuda_compute_capability();
- cc.set_major(se_cc.major);
- cc.set_minor(se_cc.minor);
- return cc;
-}
-
-void PrintPlatformInfo(const se::Stream* stream) {
- auto* se = stream->parent();
- const auto& desc = se->GetDeviceDescription();
- LOG(ERROR) << "Device: " << desc.name();
- LOG(ERROR) << "Platform: " << desc.platform_version();
- LOG(ERROR) << "Driver: " << desc.driver_version();
- LOG(ERROR) << "Runtime: " << desc.runtime_version();
-
- auto dnn_version = GetDnnVersionInfo(se);
- if (dnn_version.ok()) {
- auto v = dnn_version.value();
- LOG(ERROR) << "cudnn version: " << v.major_version() << "."
- << v.minor_version() << "." << v.patch();
- }
-}
-
-// Returns true if the redzones in `allocator`'s allocations are unmodified.
-//
-// If the redzones are modified, logs an error, sets the appropriate failure
-// bits on `result`, and returns false.
-//
-// Returns a non-OK absl::Status if an unexpected error has occurred and the
-// stream has been poisoned.
-//
-// `name` is a user-friendly name for the set of redzones being checked, e.g.
-// "input/output" or "scratch".
-absl::StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
- se::Stream* stream, absl::string_view name,
- std::string_view instr_str,
- AutotuneResult* result) {
- XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
- 2);
- using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
- TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
- allocator.CheckRedzones());
- if (redzone_check.ok()) {
- return true;
- }
-
- auto* fail = result->mutable_failure();
- fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
- *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
- fail->set_buffer_address(
- reinterpret_cast<uint64_t>(redzone_check.user_buffer_address));
-
- LOG(ERROR) << absl::StreamFormat(
- "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
- "cudnn bug. We will skip this algorithm in the future, but your GPU "
- "state may already be corrupted, leading to incorrect results. Within "
- "Google, no action is needed on your part. Outside of Google, please "
- "ensure you're running the latest version of cudnn. If that doesn't fix "
- "the problem, please file a bug with this full error message and we'll "
- "contact nvidia.",
- name);
- LOG(ERROR) << redzone_check.RedzoneFailureMsg();
- LOG(ERROR) << "HloInstruction " << instr_str;
- PrintPlatformInfo(stream);
- return false;
-}
-
-} // anonymous namespace
-
-bool ShouldInitConvData(const HloModuleConfig& hlo_module_config) {
- const int32_t conv_autotune_level =
- hlo_module_config.debug_options().xla_gpu_autotune_level();
- return conv_autotune_level >= 2;
-}
-
-bool ShouldCheckConv(const HloModuleConfig& hlo_module_config) {
- const int32_t conv_autotune_level =
- hlo_module_config.debug_options().xla_gpu_autotune_level();
- return conv_autotune_level >= 4;
-}
-
-absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
- const HloCustomCallInstruction* instr) {
- return AutotunerUtil::Autotune(
- instr, config_, [&] { return PickBestAlgorithmNoCache(instr); });
-}
-
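-// Autotunes without consulting the autotune cache: serializes access to the
-// GPU, then dispatches to the CUDA or ROCm implementation depending on the
-// platform of the configured stream executor.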
-absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithmNoCache(
- const HloCustomCallInstruction* instr) {
- if (config_.IsDeviceless()) {
- // Return an autotune result with algo id -1, which means that we autotune
- // at runtime.
- AutotuneResult result;
- result.mutable_algorithm()->set_algo_id(-1);
- return result;
- }
-
- se::StreamExecutor* stream_exec = config_.GetExecutor();
- // Don't run this function concurrently on the same GPU.
- //
- // This is a bit of a hack and doesn't protect us against arbitrary concurrent
- // use of a GPU, but it's sufficient to let us compile two HLO modules
- // concurrently and then run them sequentially.
- //
- // Putting the lock in here rather than in PickBestAlgorithmNoCache lets us
- // avoid ever doing duplicate work. If we have a cache miss, only one thread
- // will run PickBestAlgorithmImpl for a particular device.
- absl::MutexLock lock(&GetGpuMutex(stream_exec));
-
- // Make sure any previous activity on this executor is done. We don't want
- // other work still running on the GPU to interfere with autotuning.
- if (!stream_exec->SynchronizeAllActivity()) {
- return Internal(
- "Failed to synchronize GPU for autotuning conv instruction");
- }
-
- absl::StatusOr<AutotuneResult> result_or(Internal("Unknown platform."));
- // Check which platform the StreamExecutor is on: the ROCm and CUDA
- // implementations have diverged. Specifically, we need to make sure that
- // redzone-allocator-related utilities are not used in the ROCm routine.
- se::Platform::Id platform_id = stream_exec->GetPlatform()->id();
- if (platform_id == se::rocm::kROCmPlatformId) {
- result_or = PickBestAlgorithmNoCacheRocm(instr);
- } else if (platform_id == se::cuda::kCudaPlatformId) {
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
- result_or = PickBestAlgorithmNoCacheCuda(instr);
-#endif
- }
-
- return result_or;
-}
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-
-absl::StatusOr<GpuConvAlgorithmPicker::AutotuneRuntimeArguments>
-GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction(
- const HloCustomCallInstruction* instr, const AutotuneConfig& config,
- const DebugOptions& debug_options) {
- TF_ASSIGN_OR_RETURN(auto rz_buffers,
- RedzoneBuffers::FromInstruction(
- *instr, config, debug_options,
- RedzoneBuffers::kAllInputsOutputsNoScratch));
-
- // Get canonical HLO.
- std::string canonical_hlo(
- AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr)
- .GetHlo());
-
- TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr));
-
- GpuConvAlgorithmPicker::AutotuneRuntimeArguments runtime_arguments = {
- instr->GetModule()->config(),
- std::move(rz_buffers),
- std::move(gpu_conv_config),
- {canonical_hlo}};
-
- return runtime_arguments;
-}
-
-struct CudnnVersionRange {
- using TupleVersion = std::tuple<int, int, int>;
- TupleVersion begin;
- TupleVersion end;
-
- bool IsInRange(const CudnnVersion& other) const {
- TupleVersion other_version{other.major(), other.minor(), other.patch()};
- return begin <= other_version && other_version < end;
- }
-
- CudnnVersionRange(const CudnnVersion& begin, const CudnnVersion& end)
- : begin(begin.major(), begin.minor(), begin.patch()),
- end(end.major(), end.minor(), end.patch()) {}
-
- CudnnVersionRange(const TupleVersion& begin, const TupleVersion& end)
- : begin(begin), end(end) {}
-};
-
-struct ComputeCapabilityRange {
- using TupleComputeCapability = std::tuple<int, int>;
- TupleComputeCapability begin;
- TupleComputeCapability end;
-
- bool IsInRange(const ComputeCapability& other) const {
- TupleComputeCapability other_cc{other.major(), other.minor()};
- return begin <= other_cc && other_cc < end;
- }
-};
-
-struct DisabledAlgorithm {
- CudnnVersionRange cudnn_version_range;
- ComputeCapabilityRange compute_capability_range;
- int algo_id;
-};
-
-// TODO(b/343101418): Remove this once the bug is fixed in upstream cuDNN and
-// we have updated to that cuDNN version.
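-// The single entry below disables engine 14 for cuDNN versions in
-// [9.0.0, 10.0.0) on compute capabilities in [6.0, 8.0).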
-static const DisabledAlgorithm kDisabledAlgorithms[] = {
- {/*.cudnn_version_range=*/{/*.begin=*/{9, 0, 0}, /*.end=*/{10, 0, 0}},
- /*.compute_capability_range=*/{/*.begin=*/{6, 0}, /*.end=*/{8, 0}},
- /*.algo_id=*/14}};
-
-// There are three tiers of errors possible here: returning a failed
-// absl::StatusOr means autotuning fails immediately; returning an
-// AutotuneResult with a failure code other than DISQUALIFIED means autotuning
-// fails if crash_on_checking_failure is set; and returning a DISQUALIFIED
-// AutotuneResult simply skips the engine/algorithm while recording a reason for
-// skipping it.
-absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
- GenericConvRunner* const runner,
- std::optional<ReferenceResult>* reference_result,
- absl::Span<const AlgorithmDesc> disabled_algos,
- std::optional<AutotuneCacheKey> instruction_info,
- const AutotuneRuntimeArguments& runtime_arguments) {
- auto alg = runner->ToAlgorithmDesc();
-
- se::StreamExecutor* stream_exec = config_.GetExecutor();
- XLA_SCOPED_LOGGING_TIMER_LEVEL(
- absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
- alg.ToString()),
- 2);
-
- auto make_failure = [&alg](AutotuneResult::FailureKind kind,
- absl::string_view msg) {
- AutotuneResult result;
- *result.mutable_algorithm() = alg.ToProto();
- result.mutable_failure()->set_kind(kind);
- result.mutable_failure()->set_msg(/* *sigh* */ msg.data(), msg.size());
- return result;
- };
-
- AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);
-
- std::string instr_str = instruction_info.has_value()
- ? std::string(instruction_info->GetHlo())
- : "<unknown>";
-
- for (const auto& disabled_algo : kDisabledAlgorithms) {
- if (disabled_algo.cudnn_version_range.IsInRange(
- GetCudnnVersion(stream_exec)) &&
- disabled_algo.compute_capability_range.IsInRange(
- GetComputeCapability(stream_exec)) &&
- disabled_algo.algo_id == alg.algo_id()) {
- LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
- << " for conv " << instr_str;
- return make_failure(AutotuneResult::DISQUALIFIED,
- "Disqualified for being known-buggy.");
- }
- }
-
- if (absl::c_linear_search(disabled_algos, alg_key)) {
- LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
- << " for conv " << instr_str;
- return make_failure(AutotuneResult::DISQUALIFIED,
- "Disqualified for being known-buggy.");
- }
-
- GpuConvConfig config = runtime_arguments.gpu_conv_config;
- auto activation_mode =
- config.fusion ? config.fusion->mode : se::dnn::ActivationMode::kNone;
-
- // For fused convolutions with the identity function as the activation, only
- // ALGO_IMPLICIT_PRECOMP_GEMM does the right thing. Other algorithms
- // silently do Relu. See
- // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
- //
- // For cuDNN Frontend, there is no way to check whether we're using a broken
- // algorithm, so on versions where some algorithms are broken, we don't use
- // the cuDNN Frontend for these convs at all. As such, if we get a
- // frontend-based runner, we can be sure it's not one of the broken
- // algorithms we're checking for.
- if (!alg.is_cudnn_frontend() &&
- config.kind == CudnnConvKind::kForwardActivation &&
- activation_mode == se::dnn::ActivationMode::kNone &&
- alg.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
- return make_failure(AutotuneResult::DISQUALIFIED,
- "Disqualified for implicit RELU.");
- }
-
- TF_ASSIGN_OR_RETURN(
- se::RedzoneAllocator scratch_allocator,
- AutotunerUtil::CreateRedzoneAllocator(
- config_, runtime_arguments.hlo_module_config.debug_options()));
-
- se::dnn::ProfileResult profile_result;
- VLOG(4) << "Trying algorithm " << alg.ToString() << " for " << instr_str;
-
- SlowOperationAlarm alarm(absl::Seconds(1), [&] {
- return absl::StrFormat(
- "Trying algorithm %s for conv %s is taking a while...", alg.ToString(),
- instr_str);
- });
-
- std::optional<size_t> workspace_size =
- runner->ToAlgorithmDesc().workspace_size();
- if (!workspace_size) {
- return make_failure(AutotuneResult::UNKNOWN,
- "Internal error: missing workspace size from "
- "OpRunner::ToAlgorithmDesc()");
- }
-
- auto scratch_or = scratch_allocator.AllocateBytes(*workspace_size);
- if (!scratch_or.ok()) {
- return make_failure(AutotuneResult::DISQUALIFIED,
- absl::StrCat("Scratch allocation failed: ",
- scratch_or.status().ToString()));
- }
- se::DeviceMemoryBase scratch_memory = scratch_or.value();
-
- // Use assignment instead of brace-list to make GCC 4.9 happy.
- RunConvOptions options;
- options.runner_cache = runner;
- // The following plan timing code is based on
- // https://github.com/NVIDIA/cudnn-frontend/blob/60496f42fdc7a4ccc059f5934e306e728a756755/include/cudnn_frontend_find_plan.h
- float max_time = 0;
- float min_time = std::numeric_limits<float>::max();
- absl::Status launch_status;
- std::vector<se::DeviceMemoryBase> operand_buffers =
- runtime_arguments.rz_buffers.input_buffers();
- std::vector<se::DeviceMemoryBase> result_buffers =
- runtime_arguments.rz_buffers.output_buffers();
-
- TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
-
- // Dry-run to warmup the plan.
- launch_status = RunGpuConv(config, operand_buffers, result_buffers,
- scratch_memory, stream, options);
- // Flag that a warm-up run has been executed; this allows the GpuTimer for
- // the main measurement to safely use the delay kernel pattern, even if lazy
- // module loading is enabled.
- options.profile_result = &profile_result;
- profile_result.set_warmup_run_executed(true);
- constexpr int kMaxIter = 10;
- // Iterate until the new measurement is within kThreshold of the current
- // minimum.
- int num_iters = 0;
- for (; num_iters < kMaxIter && launch_status.ok(); ++num_iters) {
- launch_status = RunGpuConv(config, operand_buffers, result_buffers,
- scratch_memory, stream, options);
- if (!profile_result.is_valid()) {
- break;
- }
- float old_min_time = min_time;
- min_time = std::min(min_time, profile_result.elapsed_time_in_ms());
- max_time = std::max(max_time, profile_result.elapsed_time_in_ms());
-
- constexpr float kThreshold = 0.05f;
- if (std::abs(profile_result.elapsed_time_in_ms() - old_min_time) /
- old_min_time <
- kThreshold) {
- break;
- }
- }
- if (!launch_status.ok()) {
- VLOG(5) << "Launch failed: " << launch_status;
- return make_failure(
- AutotuneResult::DISQUALIFIED,
- absl::StrCat("Profiling failure on cuDNN engine ", alg.ToString(), ": ",
- launch_status.ToString()));
- }
- if (!profile_result.is_valid()) {
- VLOG(5) << "Launch succeeded but profile result is invalid.";
- // Not DISQUALIFIED: this means something went wrong internally.
- return make_failure(
- AutotuneResult::UNKNOWN,
- absl::StrCat("Launch succeeded but profile result is invalid, "
- "with cuDNN engine ",
- alg.ToString(), ": ", launch_status.ToString()));
- }
- VLOG(4) << "Best time: " << min_time << " ms. Worst time: " << max_time
- << " ms. Total iterations: " << num_iters;
- int64_t scratch_bytes_used =
- scratch_allocator.TotalAllocatedBytesExcludingRedzones();
-
- AutotuneResult result;
- *result.mutable_algorithm() = alg.ToProto();
- result.set_scratch_bytes(scratch_bytes_used);
- *result.mutable_run_time() =
- tsl::proto_utils::ToDurationProto(absl::Milliseconds(min_time));
-
- if (!ShouldCheckConv(runtime_arguments.hlo_module_config)) {
- if (!reference_result->has_value()) {
- (*reference_result) = {
- alg, std::vector<DeviceMemoryBase>(result_buffers.size())};
- }
- return result;
- }
-
- // Check for writes to redzones.
- TF_ASSIGN_OR_RETURN(
- bool input_output_allocator_redzone_clear,
- CheckRedzones(runtime_arguments.rz_buffers.RedzoneAllocator(), stream,
- "input/output", instr_str, &result));
-
- TF_ASSIGN_OR_RETURN(
- bool scratch_allocator_redzone_clear,
- CheckRedzones(scratch_allocator, stream, "scratch", instr_str, &result));
-
- if (!input_output_allocator_redzone_clear ||
- !scratch_allocator_redzone_clear) {
- if (runtime_arguments.canonical_hlo.has_value()) {
- std::string canonical_hlo = runtime_arguments.canonical_hlo.value();
- std::string blas_version;
- if (auto* blas = stream_exec->AsBlas()) {
- (void)blas->GetVersion(&blas_version);
- }
-
- AlgorithmDenylist proto;
- auto entry = proto.add_entries();
- entry->set_hlo(canonical_hlo);
- *entry->mutable_cc() = GetComputeCapability(stream_exec);
- *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec);
- entry->set_blas_version(blas_version);
- auto algo = entry->add_algos();
- algo->set_id(alg.algo_id());
- algo->set_tensor_ops(alg.tensor_ops_enabled());
-
- LOG(ERROR) << "To denylist this algorithm for this convolution, "
- "copy-paste the following "
- "proto to the denylist file pointed by XLA_FLAGS "
- "--xla_gpu_algorithm_denylist_path="
- << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
- << " : " << proto.ShortDebugString();
- }
-
- // CheckRedzones has modified the result in-place to include a failure.
- return result;
- }
-
- if (reference_result->has_value()) {
- XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
-
- const DebugOptions& debug_options =
- runtime_arguments.hlo_module_config.debug_options();
- BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(),
- debug_options.xla_gpu_autotune_gemm_rtol());
- for (int i = 0; i < result_buffers.size(); ++i) {
- absl::StatusOr<bool> compare_result = comparator.CompareEqual(
- stream, (*reference_result)->buffers[i], result_buffers[i]);
- if (!compare_result.ok()) {
- LOG(ERROR) << "Unable to compare "
- << (*reference_result)->algorithm.ToString() << " against "
- << alg.ToString() << " for " << instr_str << ": "
- << compare_result.status();
- if (compare_result.status().code() ==
- absl::StatusCode::kResourceExhausted) {
- // Possibly OOM. Propagate the error.
- return compare_result.status();
- }
- CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
- } else if (!compare_result.value()) {
- LOG(ERROR)
- << "Results mismatch between different convolution algorithms. "
- "This is likely a bug/unexpected loss of precision in cudnn.\n"
- << instr_str << " for " << (*reference_result)->algorithm.ToString()
- << " vs " << alg.ToString();
- PrintPlatformInfo(stream);
- if (instruction_info.has_value()) {
- VLOG(2) << "Full module on failure: \n"
- << instruction_info->GetModelStr();
- }
- auto* fail = result.mutable_failure();
- fail->set_kind(AutotuneResult::WRONG_RESULT);
- fail->set_buffer_address(
- reinterpret_cast<uint64_t>(result_buffers[i].opaque()));
- *fail->mutable_reference_algorithm() =
- (*reference_result)->algorithm.ToProto();
- }
- }
- } else {
- XLA_SCOPED_LOGGING_TIMER_LEVEL("Memcpy Reference Result", 2);
- std::vector<DeviceMemoryBase> reference_result_buffers(
- result_buffers.size());
- for (int i = 0; i < result_buffers.size(); ++i) {
- TF_ASSIGN_OR_RETURN(
- reference_result_buffers[i],
- runtime_arguments.rz_buffers.RedzoneAllocator().AllocateBytes(
- result_buffers[i].size()));
- TF_RETURN_IF_ERROR(stream->Memcpy(&reference_result_buffers[i],
- result_buffers[i],
- result_buffers[i].size()));
- }
- (*reference_result) = {alg, reference_result_buffers};
- }
-
- return result;
-}
-
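-// CUDA path: profiles every candidate cuDNN engine, trying the fallback
-// engines only if none of the heuristics-provided engines produced a valid
-// result, and then picks the best profiling result.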
-absl::StatusOr<AutotuneResult>
-GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
- const HloCustomCallInstruction* instr) {
- AutotuneCacheKey instruction_info{config_.GetModelStr(), *instr};
- std::string instr_str(instruction_info.GetHlo());
- XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
- "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr_str));
-
- const DebugOptions& debug_options =
- instr->GetModule()->config().debug_options();
- const bool crash_on_checking_failure =
- debug_options.xla_gpu_crash_on_verification_failures();
-
- std::string blas_version;
- se::StreamExecutor* stream_exec = config_.GetExecutor();
- if (auto* blas = stream_exec->AsBlas()) {
- (void)blas->GetVersion(&blas_version);
- }
-
- std::vector<AlgorithmDesc> disabled_algos;
- TF_ASSIGN_OR_RETURN(
- AutotuneRuntimeArguments runtime_arguments,
- AutotuneRuntimeArguments::FromInstruction(instr, config_, debug_options));
- if (runtime_arguments.canonical_hlo.has_value()) {
- disabled_algos = GetDisabledConvAlgorithms(
- GetComputeCapability(stream_exec), GetCudnnVersion(stream_exec),
- blas_version, runtime_arguments.canonical_hlo.value());
- }
-
- const bool cudnn_frontend_enabled =
- debug_options.xla_gpu_enable_cudnn_frontend();
- bool allow_tf32 = true;
- // TODO(b/284371623): Properly set allow_tf32 even if instr==nullptr, which is
- // the case when running an AOT compiled executable with runtime autotuning.
- if (instr) {
- allow_tf32 = absl::c_all_of(
- instr->precision_config().operand_precision(),
- [](int precision) { return precision <= PrecisionConfig::HIGH; });
- }
- const se::NumericOptions numeric_options{
- RequireDeterminism(instr->GetModule()->config()), allow_tf32};
-
- // Use the first supported algorithm as the reference. There is no particular
- // reason to prefer it; any algorithm would do. Being the reference does not
- // mean the algorithm is considered correct, though.
- std::optional<ReferenceResult> reference_result;
-
- TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
- TF_ASSIGN_OR_RETURN(
- std::vector<GenericConvRunner> runners,
- GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
- cudnn_frontend_enabled,
- /* use_fallback = */ false, numeric_options));
-
- std::vector<AutotuneResult> profile_results;
- for (auto& runner_cache : runners) {
- TF_ASSIGN_OR_RETURN(
- auto result,
- AutotuneOneConvRunner(&runner_cache, &reference_result, disabled_algos,
- instruction_info, runtime_arguments));
- profile_results.emplace_back(std::move(result));
- }
-
- // If any algorithm has worked, we'll skip the fallback algorithms, since
- // they include some very slow algorithms.
- if (!reference_result) {
- LOG(WARNING) << "None of the algorithms provided by cuDNN heuristics "
- "worked; trying fallback algorithms.";
- if (runtime_arguments.canonical_hlo.has_value()) {
- LOG(WARNING) << "Conv: " << runtime_arguments.canonical_hlo.value();
- }
-
- TF_ASSIGN_OR_RETURN(
- std::vector<GenericConvRunner> fallback_runners,
- GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
- cudnn_frontend_enabled,
- /* use_fallback = */ true, numeric_options));
-
- for (auto& runner_cache : fallback_runners) {
- TF_ASSIGN_OR_RETURN(
- auto result, AutotuneOneConvRunner(&runner_cache, &reference_result,
- disabled_algos, instruction_info,
- runtime_arguments));
- profile_results.emplace_back(std::move(result));
- }
- }
-
- // Log the autotuning result.
- if (instr) {
- AutotuningLog log;
- {
- ConvInstructionLog instr_log;
- *instr_log.mutable_instruction() = instr->ToProto();
- for (int i = 0; i < instr->operand_count(); i++) {
- *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
- instr_log.add_operand_addresses(reinterpret_cast<uint64_t>(
- runtime_arguments.rz_buffers.input_buffers()[i].opaque()));
- }
- for (se::DeviceMemoryBase result_buffer :
- runtime_arguments.rz_buffers.output_buffers()) {
- instr_log.add_result_addresses(
- reinterpret_cast<uint64_t>(result_buffer.opaque()));
- }
- log.mutable_instr()->PackFrom(instr_log);
- }
- for (const auto& profile : profile_results) {
- *log.add_results() = profile;
- }
- *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
- *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
- log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
- log.set_blas_version(blas_version);
- VLOG(2) << "Autotuning result: " << log.ShortDebugString();
- // If we crash on checking failure, we are in a testing/benchmark mode.
- if (crash_on_checking_failure) {
- // Crash on miscompares and redzone violations if desired.
- for (const auto& profile : profile_results) {
- if (profile.has_failure() &&
- profile.failure().kind() != AutotuneResult::DISQUALIFIED) {
- LOG(FATAL) << "crash_on_checking_failure encountered errors:\n\n"
- << log.DebugString(); // NOLINT
- }
- }
- }
- }
-
- TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
- PickBestResult(profile_results, instr_str,
- runtime_arguments.hlo_module_config));
- return selected_algorithm;
-}
-#endif
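Stripped of logging and cache bookkeeping, the CUDA path above reduces to a two-phase search: profile every runner suggested by the cuDNN heuristics, and only if none of them yielded a usable reference result, profile the fallback runners as well. A self-contained sketch of just that control flow, using hypothetical names and plain standard C++ rather than the XLA types:

#include <functional>
#include <vector>

// Hypothetical stand-ins for AutotuneResult and GenericConvRunner; this is an
// illustration of the control flow only, not the removed implementation.
struct CandidateResult {
  double runtime_ms;
  bool ok;
};
using Candidate = std::function<CandidateResult()>;

std::vector<CandidateResult> ProfileWithFallback(
    const std::vector<Candidate>& heuristic_candidates,
    const std::vector<Candidate>& fallback_candidates) {
  std::vector<CandidateResult> results;
  bool any_ok = false;
  for (const Candidate& run : heuristic_candidates) {
    results.push_back(run());
    any_ok = any_ok || results.back().ok;
  }
  // Fallback candidates are profiled only when no heuristic candidate worked,
  // because they include some very slow algorithms.
  if (!any_ok) {
    for (const Candidate& run : fallback_candidates) {
      results.push_back(run());
    }
  }
  return results;
}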
-
-absl::StatusOr<AutotuneResult>
-GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
- const HloCustomCallInstruction* instr) {
- XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
- "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
-
- const bool allow_tf32 = absl::c_all_of(
- instr->precision_config().operand_precision(),
- [](int precision) { return precision <= PrecisionConfig::HIGH; });
- const se::NumericOptions numeric_options{
- RequireDeterminism(instr->GetModule()->config()), allow_tf32};
-
- se::StreamExecutor* stream_exec = config_.GetExecutor();
- const auto device_ordinal = stream_exec->device_ordinal();
- std::vector<se::DeviceMemoryBase> operand_buffers;
-
- // allocator either points to this->allocator_ or, if that's null, to a
- // se::StreamExecutorMemoryAllocator for stream_exec.
- se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
- ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
- const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
- // Although we don't have evidence this matters, zero out the buffers
- // before autotuning. It's conceivable that using uninitialized memory as
- // the inputs might affect performance if e.g. the inputs contain
- // denormals, and this is easy enough.
- return stream->MemZero(&buffer, buffer.size());
- };
-
- // Allocate space for the input, filter, and output of the convolution. We
- // use a ScratchAllocator for this instead of calling allocator_ directly so
- // that our allocations don't leak.
- for (const auto* operand : instr->operands()) {
- TF_ASSIGN_OR_RETURN(auto buffer,
- input_output_allocator.AllocateBytes(
- ShapeUtil::ByteSizeOf(operand->shape())));
- TF_RETURN_IF_ERROR(initialize_buffer(buffer));
- operand_buffers.push_back(buffer);
- }
-
-  // Conv custom calls normally return a (result, scratch) tuple, but size the
-  // buffer vector so that the non-tuple branch below stays in bounds.
-  std::vector<se::DeviceMemoryBase> result_buffers(
-      instr->shape().IsTuple() ? instr->shape().tuple_shapes_size() : 1);
-  if (instr->shape().IsTuple()) {
-    for (int i = 0; i < instr->shape().tuple_shapes_size(); ++i) {
-      TF_ASSIGN_OR_RETURN(
-          result_buffers[i],
-          input_output_allocator.AllocateBytes(
-              ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(i))));
-      TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[i]));
-    }
-  } else {
-    TF_ASSIGN_OR_RETURN(result_buffers[0],
-                        input_output_allocator.AllocateBytes(
-                            ShapeUtil::ByteSizeOf(instr->shape())));
-    TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[0]));
-  }
-
- ScratchAllocator scratch_allocator(device_ordinal, allocator);
-
- TF_ASSIGN_OR_RETURN(
- std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners,
- GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers),
- absl::MakeSpan(result_buffers), stream_exec,
- &scratch_allocator, stream, numeric_options));
-
- std::vector<AutotuneResult> profile_results;
-
- if (runners.size() == 1) {
- TF_ASSIGN_OR_RETURN(auto alg, runners[0]->ToAlgorithmDesc());
- auto algorithm_proto = alg.ToProto();
- profile_results.emplace_back();
- auto& result = profile_results.back();
- *result.mutable_algorithm() = algorithm_proto;
-
- result.set_scratch_bytes(runners[0]->GetWorkspaceSize());
-
- // TODO(awpr): if the profile result time for a singleton algorithm is
- // needed, plumb it via OpRunner; we'll need to do this to let TF ops avoid
- // re-profiling ROCm algorithms anyway.
- *result.mutable_run_time() =
- tsl::proto_utils::ToDurationProto(absl::Milliseconds(-1));
- } else {
- TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
- for (auto& runner : runners) {
- TF_ASSIGN_OR_RETURN(auto alg, runner->ToAlgorithmDesc());
- XLA_SCOPED_LOGGING_TIMER_LEVEL(
- absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
- alg.ToString()),
- 2);
-
- se::dnn::ProfileResult profile_result;
- VLOG(4) << "Trying algorithm " << alg.ToString() << " for "
- << instr->ToString();
-
- TF_ASSIGN_OR_RETURN(
- DeviceMemoryBase scratch_memory,
- scratch_allocator.AllocateBytes(runner->GetWorkspaceSize()));
-
- TF_ASSIGN_OR_RETURN(auto lazy_runner,
- se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
- std::move(runner)));
-
- GenericConvRunner runner_cache(std::move(lazy_runner));
-
- // Use assignment instead of brace-list to make GCC 4.9 happy.
- RunConvOptions options;
- options.profile_result = &profile_result;
- options.runner_cache = &runner_cache;
- absl::Status launch_status =
- RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffers,
- scratch_memory, stream, options);
-
- if (!launch_status.ok()) {
- continue;
- }
-
- if (!profile_result.is_valid()) {
- continue;
- }
-
- profile_results.emplace_back();
- AutotuneResult& result = profile_results.back();
- *result.mutable_algorithm() = alg.ToProto();
-
- int64_t scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
- result.set_scratch_bytes(scratch_bytes_used);
- *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
- absl::Milliseconds(profile_result.elapsed_time_in_ms()));
- }
- }
-
- TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
- PickBestResult(profile_results, instr->ToString(),
- instr->GetModule()->config()));
- return selected_algorithm;
-}
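The final selection above is delegated to PickBestResult. Purely as an illustration of the shape of that decision, and not of the real implementation (which also handles disqualified results and the unmeasured single-runner case), a minimal standalone sketch:

#include <optional>
#include <vector>

// Hypothetical stand-in for one profiled algorithm.
struct ProfiledAlgo {
  int algo_id;
  double run_time_ms;
  bool has_failure;
};

// Minimal selection rule: the fastest profiled algorithm that did not fail.
std::optional<ProfiledAlgo> PickFastest(
    const std::vector<ProfiledAlgo>& results) {
  std::optional<ProfiledAlgo> best;
  for (const ProfiledAlgo& r : results) {
    if (r.has_failure) continue;
    if (!best || r.run_time_ms < best->run_time_ms) best = r;
  }
  return best;
}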
-
-absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(
- HloInstruction* instr) {
- CHECK(IsCustomCallToDnnConvolution(*instr));
-
- const bool strict = instr->parent()
- ->parent()
- ->config()
- .debug_options()
- .xla_gpu_strict_conv_algorithm_picker();
-
- absl::StatusOr<AutotuneResult> best_algo_or =
- PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
- if (!best_algo_or.ok()) {
- auto msg = absl::StrFormat(
- "Failed to determine best cudnn convolution algorithm for:\n%s\n\n"
- "Original error: %s",
- instr->ToString(), best_algo_or.status().ToString());
-
- if (strict) {
- return Unknown(
- "%s\n\nTo ignore this failure and try to use a fallback algorithm "
- "(which may have suboptimal performance), use "
- "XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please "
- "also file a bug for the root cause of failing autotuning.",
- msg);
- }
- LOG(WARNING)
- << msg << "\n\nAs a result, convolution performance may be suboptimal.";
- return false;
- }
-
- auto best_algo = std::move(best_algo_or).value();
- VLOG(3) << "Setting cudnn conv to use algorithm "
- << best_algo.conv().algorithm() << " and "
- << NumBytesToString(best_algo.scratch_bytes())
- << " of scratch memory: " << instr->ToString()
- << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();
-
- // Replace instr with a new CustomCall which has the correct algorithm, and
- // whose output shape has the appropriate amount of scratch memory.
- HloComputation* computation = instr->parent();
- std::vector<Shape> new_call_element_shapes;
- // Add the shapes of the outputs of the convolution.
- new_call_element_shapes.reserve(instr->shape().tuple_shapes_size() - 1);
- for (int i = 0; i < instr->shape().tuple_shapes_size() - 1; ++i) {
- new_call_element_shapes.emplace_back(instr->shape().tuple_shapes(i));
- }
- // The final element is the size of the workspace.
- new_call_element_shapes.emplace_back(
- ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()}));
- Shape new_call_shape = ShapeUtil::MakeTupleShape(new_call_element_shapes);
-
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
- instr->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& backend_config =
- *gpu_backend_config.mutable_cudnn_conv_backend_config();
- *backend_config.mutable_algorithm() = best_algo.algorithm();
- backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
- best_algo.scratch_bytes());
-
- HloInstruction* new_call = computation->AddInstruction(
- instr->CloneWithNewOperands(new_call_shape, instr->operands()));
-
- // Preserve the name of the old instruction. This is safe because we're going
- // to remove the old one anyway, and it makes it easier to trace how our conv
- // is transformed through all our passes.
- new_call->SetAndSanitizeName(instr->name());
-
- VLOG(3) << "Replacing convolution " << instr->ToString() << " with "
- << new_call->ToString();
-
- TF_RETURN_IF_ERROR(new_call->set_backend_config(gpu_backend_config));
-
- std::vector<HloInstruction*> new_tuple_elements;
- new_tuple_elements.reserve(new_call->shape().tuple_shapes_size() - 1);
- for (int i = 0; i < new_call->shape().tuple_shapes_size() - 1; ++i) {
- new_tuple_elements.emplace_back(
- computation->AddInstruction(HloInstruction::CreateGetTupleElement(
- new_call->shape().tuple_shapes(i), new_call, i)));
- }
- new_tuple_elements.emplace_back(computation->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({}))));
-
- // Repackage new_call so it has the same shape as the original call, namely
- // (conv_result, u8[0]).
- HloInstruction* new_tuple = computation->AddInstruction(
- HloInstruction::CreateTuple(new_tuple_elements));
-
- TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
- return true;
-}
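Restating the shape surgery in RunOnInstruction in isolation: the replacement custom call keeps every output of the old one except the trailing workspace, which is resized to u8[scratch_bytes] for the winning algorithm, and the wrapper tuple then restores the original (conv_result, u8[0]) signature. A hypothetical sketch of that shape assembly, using plain strings in place of xla::Shape:

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: shape strings stand in for xla::Shape.
std::vector<std::string> NewCallElementShapes(
    const std::vector<std::string>& old_elements, int64_t scratch_bytes) {
  // Keep every element except the old zero-sized workspace...
  std::vector<std::string> result(old_elements.begin(), old_elements.end() - 1);
  // ...and append a u8 buffer sized for the selected algorithm.
  result.push_back("u8[" + std::to_string(scratch_bytes) + "]");
  return result;
}

int main() {
  // {"f32[54,54,16,64]", "u8[0]"} becomes {"f32[54,54,16,64]", "u8[4096]"}.
  auto shapes = NewCallElementShapes({"f32[54,54,16,64]", "u8[0]"}, 4096);
  return shapes.size() == 2 ? 0 : 1;
}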
-
-absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
- HloComputation* computation) {
- std::vector<HloInstruction*> convs;
- for (HloInstruction* instr : computation->instructions()) {
- if (IsCandidate(instr)) {
- convs.push_back(instr);
- }
- }
-
- bool changed = false;
- for (HloInstruction* instr : convs) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
- changed |= result;
- }
- return changed;
-}
-
-absl::StatusOr<bool> GpuConvAlgorithmPicker::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_SCOPED_LOGGING_TIMER(
- absl::StrCat("GpuConvAlgorithmPicker for ", module->name()));
-
- if (!IsEnabled(module)) {
- VLOG(3) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
- "returning early.";
- return false;
- }
-
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.h b/third_party/xla/xla/service/gpu/conv_algorithm_picker.h
deleted file mode 100644
index e6dea8b..0000000
--- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.h
+++ /dev/null
@@ -1,159 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
-#define XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
-
-#include <optional>
-#include <string>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_conv_runner.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#endif
-
-namespace xla {
-namespace gpu {
-
-// Choose the fastest algorithm for each conv.
-// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
-// each and adding explicit scratch space to the CustomCalls.
-//
-// We pick the algorithm before fusion so that we can generate better HLO. After
-// GpuConvRewriter, our convolutions are CustomCalls which return a
-// tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
-// scratch:
-//
-// customcall = (f32[...], f32[0])
-// return gte(customcall, 0)
-//
-// The algorithm picker then chooses the best algorithm, and potentially
-// increases the scratch space. It replaces customcall with new_tuple,
-// giving us the following:
-//
-// new_customcall = (f32[...], f32[N])
-// new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
-// return gte(new_tuple, 0)
-//
-// The new tuple and gte instructions can be simplified away, because
-// nobody is expected to use the scratch value.
-//
-// However, if we were to run GpuConvAlgorithmPicker after fusion,
-// the gte(customcall, 0) would probably already be inside a fusion node. We
-// can't simplify across HloComputation boundaries, so in this case we
-// wouldn't be able to simplify away the new_tuple bits.
-//
-// It supports two modes: device and deviceless.
-// In device mode, we run autotuning on the device and store autotune results.
-//
-// In deviceless mode, we pass in some information related to the device and
-// use stored autotune results to rewrite convolutions. If the required autotune
-// result is not stored, then the performance of convolution will be suboptimal.
-class GpuConvAlgorithmPicker : public HloModulePass {
- public:
- explicit GpuConvAlgorithmPicker(AutotuneConfig config) : config_(config) {}
-
- absl::string_view name() const override {
- return "gpu-conv-algorithm-picker";
- }
-
- static bool IsEnabled(const HloModule* module) {
- return module->config().debug_options().xla_gpu_autotune_level() != 0;
- }
-
- static bool IsCandidate(const HloInstruction* instr) {
- return IsCustomCallToDnnConvolution(*instr);
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
- absl::StatusOr<bool> RunOnInstruction(HloInstruction* instr);
-
- absl::StatusOr<AutotuneResult> PickBestAlgorithm(
- const HloCustomCallInstruction* instr);
- absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCache(
- const HloCustomCallInstruction* instr);
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
- // Simple bundle of an algorithm and its output, for comparing results across
- // autotuned algorithms.
- struct ReferenceResult {
- stream_executor::dnn::AlgorithmDesc algorithm;
- std::vector<stream_executor::DeviceMemoryBase> buffers;
- };
-
- // Execution environment for autotuning. Runtime autotuning requires runtime
- // information such as input/output buffers in order to run. It can be
- // constructed from the autotuned instruction by FromInstruction.
- struct AutotuneRuntimeArguments {
- const HloModuleConfig hlo_module_config;
- RedzoneBuffers rz_buffers;
- const GpuConvConfig gpu_conv_config;
- std::optional<std::string> canonical_hlo;
-
- static absl::StatusOr<AutotuneRuntimeArguments> FromInstruction(
- const HloCustomCallInstruction* instr, const AutotuneConfig& config,
- const DebugOptions& debug_options);
- };
-
- absl::StatusOr<AutotuneResult> AutotuneOneConvRunner(
- GenericConvRunner* runner,
- std::optional<ReferenceResult>* reference_result,
- absl::Span<const stream_executor::dnn::AlgorithmDesc> disabled_algos,
- std::optional<AutotuneCacheKey> instruction_info,
- const AutotuneRuntimeArguments& runtime_arguments);
-
- // Pick the best algorithm for CUDA platform.
- absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheCuda(
- const HloCustomCallInstruction* instr);
-#endif
-
- absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheRocm(
- const HloCustomCallInstruction* instr);
-
- private:
- AutotuneConfig config_;
-};
-
-} // namespace gpu
-} // namespace xla
-#endif // XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
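For context, a pass with this interface is normally registered on an HloPassPipeline. The snippet below is an assumption based on the standard pipeline API rather than anything stated in this diff, and AddConvAutotuningPass is a hypothetical helper name:

#include "xla/service/gpu/conv_algorithm_picker.h"
#include "xla/service/hlo_pass_pipeline.h"

// Sketch only: assumes an AutotuneConfig already populated for the target
// device (or for deviceless compilation with loaded autotune results).
void AddConvAutotuningPass(xla::HloPassPipeline& pipeline,
                           const xla::gpu::AutotuneConfig& autotune_config) {
  // The pass is a no-op when xla_gpu_autotune_level is 0 (see IsEnabled).
  pipeline.AddPass<xla::gpu::GpuConvAlgorithmPicker>(autotune_config);
}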
diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc
deleted file mode 100644
index aa7c5e2..0000000
--- a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/conv_algorithm_picker.h"
-
-#include <cstdint>
-#include <variant>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/platform_util.h"
-#include "xla/service/tuple_simplifier.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuConvAlgorithmPickerTest : public HloTestBase {
- public:
- GpuConvAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
-};
-
-TEST_F(GpuConvAlgorithmPickerTest, SetAlgorithm) {
- constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
- %arg0 = f32[3,56,56,16]{2,1,0,3} parameter(0)
- %arg1 = f32[3,3,3,64]{2,1,0,3} parameter(1)
- ROOT %conv = f32[54,54,16,64]{1,0,3,2} convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo));
-
- se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
- PlatformUtil::GetStreamExecutors(platform));
- ASSERT_GT(executors.size(), 0);
- se::StreamExecutor* stream_exec = executors[0];
-
- const se::GpuComputeCapability& cc = backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .gpu_compute_capability();
- bool changed = false;
- TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get()));
- changed = false;
- DebugOptions opts = DefaultDebugOptionsIgnoringFlags();
-
- AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts};
- TF_ASSERT_OK_AND_ASSIGN(changed,
- RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
- ASSERT_TRUE(changed);
-
- AutotuneResults results;
- TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
- ASSERT_EQ(results.results_size(), 1);
- auto& result = *results.mutable_results(0)->mutable_result();
- int64_t old_scratch_bytes = result.scratch_bytes();
- int64_t new_scratch_bytes = old_scratch_bytes + 1;
- result.set_scratch_bytes(new_scratch_bytes);
-
- AutotunerUtil::ClearAutotuneResults();
- TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
-
- // Now send the same module through GpuConvAlgorithmPicker again. The conv
- // should have the new scratch bytes.
- TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo));
- changed = false;
- TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get()));
- changed = false;
- TF_ASSERT_OK_AND_ASSIGN(changed,
- RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
- ASSERT_TRUE(changed);
-
- // TupleSimplifier cleans this up a bit before we pattern-match
- TF_ASSERT_OK(RunHloPass(TupleSimplifier(), m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- HloInstruction* conv;
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall(&conv))));
- EXPECT_THAT(
- conv->shape(),
- GmockMatch(m::Shape().WithSubshape(
- {1}, m::Shape().WithElementType(U8).WithDims({new_scratch_bytes}))));
-
- // Algorithm 14 is disabled for cuDNN 9 on V100
- TF_ASSERT_OK_AND_ASSIGN(auto dnn_version, GetDnnVersionInfo(stream_exec));
- if (dnn_version.major_version() >= 9 && dnn_version.major_version() < 10 &&
- std::holds_alternative<stream_executor::CudaComputeCapability>(cc) &&
- std::get<stream_executor::CudaComputeCapability>(cc).major == 7 &&
- std::get<stream_executor::CudaComputeCapability>(cc).minor == 0) {
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->has_cudnn_conv_backend_config() &&
- conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .algorithm()
- .algo_id() != 14);
- }
-}
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/copy_fusion.cc b/third_party/xla/xla/service/gpu/copy_fusion.cc
deleted file mode 100644
index a83354c..0000000
--- a/third_party/xla/xla/service/gpu/copy_fusion.cc
+++ /dev/null
@@ -1,197 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/copy_fusion.h"
-
-#include <cstdint>
-#include <queue>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-bool OnlyElementwiseOpsReachableFromParams(HloComputation* fused_computation) {
- std::queue<const HloInstruction*> q;
- absl::flat_hash_set<const HloInstruction*> visited;
- for (auto param : fused_computation->parameter_instructions()) {
- q.push(param);
- visited.insert(param);
- }
- while (!q.empty()) {
- const HloInstruction* hlo = q.front();
- q.pop();
- for (auto user : hlo->users()) {
- if ((!user->IsElementwiseOnOperand(user->operand_index(hlo)) ||
- user->opcode() == HloOpcode::kCopy) &&
- user->opcode() != HloOpcode::kBitcast &&
- user->opcode() != HloOpcode::kTuple) {
- return false;
- }
- if (visited.insert(user).second) {
- q.push(user);
- }
- }
- }
- return true;
-}
-
-absl::StatusOr<bool> CopyFusion::DoCopyFusion(HloComputation* computation) {
- bool changed = false;
- std::vector<HloInstruction*> defs_before_uses =
- computation->MakeInstructionPostOrder();
-
- for (HloInstruction* hlo : defs_before_uses) {
- if (hlo->opcode() != HloOpcode::kFusion) {
- continue;
- }
- std::vector<HloInstruction*> copies;
- std::vector<HloInstruction*> other_users;
- HloComputation* fused_computation = hlo->fused_instructions_computation();
- if (!OnlyElementwiseOpsReachableFromParams(fused_computation)) {
- continue;
- }
- HloInstruction* root = fused_computation->root_instruction();
- if (IsReductionFromOrToContiguousDimensions(*root) ||
- root->opcode() == HloOpcode::kScatter ||
- (hlo->IsMultiOutputFusion() &&
- absl::c_all_of(root->operands(), [](const HloInstruction* slice) {
- return slice->opcode() == HloOpcode::kSlice;
- }))) {
- continue;
- }
- for (auto user : hlo->users()) {
- HloInstruction* copy_user = user;
- // Skip get-tuple-element ops.
- if (copy_user->opcode() == HloOpcode::kGetTupleElement &&
- copy_user->user_count() == 1) {
- if (IsReductionFromOrToContiguousDimensions(
- *(root->operand(copy_user->tuple_index())))) {
- other_users.push_back(user);
- continue;
- }
- copy_user = copy_user->users()[0];
- }
- // Skip bitcast ops.
- if (copy_user->opcode() == HloOpcode::kBitcast &&
- copy_user->user_count() == 1) {
- copy_user = copy_user->users()[0];
- }
- if (copy_user->opcode() == HloOpcode::kCopy &&
- copy_user->shape() == copy_user->operand(0)->shape() &&
- !copy_user->shape().IsTuple() &&
- !copy_user->HasControlDependencies()) {
- copies.push_back(copy_user);
- } else {
- other_users.push_back(user);
- }
- }
- if (copies.empty()) {
- continue;
- }
- auto fusion_adaptor = HloFusionAdaptor::ForComputation(fused_computation);
- auto dynamic_update_slices =
- GetOutputDefiningDynamicUpdateSlices(fusion_adaptor->GetRoots());
- // Skip dynamic update slice fusions which might be emitted in-place.
- if (!dynamic_update_slices.empty() &&
- (root->opcode() != HloOpcode::kTuple ||
- dynamic_update_slices.size() == root->shape().tuple_shapes_size())) {
- continue;
- }
- changed = true;
-
- HloInstruction::InstructionVector tuple_elements;
- int64_t num_outputs =
- hlo->IsMultiOutputFusion() ? root->operand_count() : int64_t{1};
- tuple_elements.reserve(copies.size() + num_outputs);
- if (hlo->IsMultiOutputFusion()) {
- for (HloInstruction* operand : root->operands()) {
- tuple_elements.push_back(operand);
- }
- } else {
- tuple_elements.push_back(root);
- }
-
- for (auto copy : copies) {
- HloInstruction* user = copy;
- std::vector<HloInstruction*> operand_chain;
- operand_chain.push_back(user);
- while (user->operand(0) != hlo) {
- user = user->mutable_operand(0);
- operand_chain.push_back(user);
- }
- HloInstruction* clone_operand = root;
- if (hlo->IsMultiOutputFusion()) {
- clone_operand = root->mutable_operand(user->tuple_index());
- CHECK_EQ(operand_chain.back()->opcode(), HloOpcode::kGetTupleElement);
- operand_chain.pop_back();
- }
- for (int64_t i = operand_chain.size() - 1; i >= 0; --i) {
- HloInstruction* user = operand_chain[i];
- clone_operand = fused_computation->AddInstruction(
- user->CloneWithNewOperands(user->shape(), {clone_operand}));
- }
- tuple_elements.push_back(clone_operand);
- }
-
- HloInstruction* new_root = fused_computation->AddInstruction(
- HloInstruction::CreateTuple(tuple_elements));
- fused_computation->set_root_instruction(new_root,
- /*accept_different_shape=*/true);
- *hlo->mutable_shape() = new_root->shape();
-
- if (root->opcode() == HloOpcode::kTuple) {
- TF_RETURN_IF_ERROR(fused_computation->RemoveInstruction(root));
- } else {
- auto get_tuple_element_root = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(hlo, 0));
- TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(
- other_users, get_tuple_element_root));
- }
- for (int64_t i = 0; i < copies.size(); ++i) {
- auto get_tuple_element = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(hlo, num_outputs + i));
- TF_RETURN_IF_ERROR(
- computation->ReplaceInstruction(copies[i], get_tuple_element));
- }
- }
- return changed;
-}
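A note on the bookkeeping above: after the rewrite, the fusion root is a tuple whose first num_outputs elements are the original outputs and whose remaining elements are the fused copies, so the i-th outside copy is replaced by a get-tuple-element at index num_outputs + i. A trivial standalone sketch of that index mapping, with hypothetical names:

#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the index layout used by DoCopyFusion:
// original fusion outputs occupy indices [0, num_outputs); fused copies
// follow at num_outputs + copy_index.
int64_t FusedCopyTupleIndex(int64_t num_outputs, int64_t copy_index) {
  return num_outputs + copy_index;
}

int main() {
  // A multi-output fusion with 2 original outputs and 3 fused copies.
  assert(FusedCopyTupleIndex(2, 0) == 2);
  assert(FusedCopyTupleIndex(2, 2) == 4);
  return 0;
}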
-
-absl::StatusOr<bool> CopyFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  // Only for the entry computation can we be sure that the copies do not share
-  // a buffer with a parameter of the fusion that they will be fused with. For
-  // example, while-loop computations have tuple parameters that need to share
- // the buffers with the output tuples, and copies inserted by the
- // CopyInsertion pass will share a buffer with the tuple output (and thus
- // with the tuple input as well).
- return DoCopyFusion(module->entry_computation());
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/copy_fusion.h b/third_party/xla/xla/service/gpu/copy_fusion.h
deleted file mode 100644
index 973b671..0000000
--- a/third_party/xla/xla/service/gpu/copy_fusion.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_COPY_FUSION_H_
-#define XLA_SERVICE_GPU_COPY_FUSION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// CopyFusion checks if a fusion is followed by multiple copies and if so, adds
-// those copies to the fusion, replacing the copies with get_tuple_elements.
-class CopyFusion : public HloModulePass {
- public:
- CopyFusion() = default;
-
- absl::string_view name() const override { return "copy_fusion"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- absl::StatusOr<bool> DoCopyFusion(HloComputation* computation);
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_COPY_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/copy_fusion_test.cc
deleted file mode 100644
index d2116eb..0000000
--- a/third_party/xla/xla/service/gpu/copy_fusion_test.cc
+++ /dev/null
@@ -1,500 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/copy_fusion.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/str_cat.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-
-namespace m = ::xla::match;
-
-class CopyFusionTest : public HloTestBase {
- public:
- CopyFusion cf_;
-};
-
-const char kModulePrefix[] = R"(
- HloModule test_module
-
- scalar_add_computation {
- scalar_lhs.0 = f32[] parameter(0)
- scalar_rhs.0 = f32[] parameter(1)
- ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
- }
- scalar_mul_computation {
- scalar_lhs.1 = f32[] parameter(0)
- scalar_rhs.1 = f32[] parameter(1)
- ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
- })";
-
-TEST_F(CopyFusionTest, CopyFusionTransposeOfBroadcastedConstantTwoCopies) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- two = f32[] constant(2.0)
- broadcast = f32[16,32]{1,0} broadcast(two), dimensions={}
- s.1 = f32[16,32]{1,0} sqrt(broadcast)
- ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
- }
-
- ENTRY main {
- fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
- copy.1 = f32[32,16]{1,0} copy(fusion)
- copy.2 = f32[32,16]{1,0} copy(fusion)
- ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement())));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionTransposeTwoCopies) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- param_0.1 = f32[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
- }
-
- ENTRY main {
- p = f32[16,32]{1,0} parameter(0)
- fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
- copy.1 = f32[32,16]{1,0} copy(fusion)
- copy.2 = f32[32,16]{1,0} copy(fusion)
- ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
- })"))
- .value();
- ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopies) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
- }
-
- ENTRY entry {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
- copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
- copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
- ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement())));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithReduce) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- const.1 = f32[] parameter(0)
- ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
- }
-
- ENTRY entry {
- p0 = f32[] parameter(0)
- p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
- copy.1 = f32[512]{0} copy(fusion)
- copy.2 = f32[512]{0} copy(fusion)
- ROOT root = (f32[512]{0}, f32[512]{0}) tuple(copy.1, copy.2)
- })"))
- .value();
- ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldRunWithUncopiedReduce) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- two = f32[] constant(2.0)
- broadcast = f32[128,512,28,28]{3,2,1,0} broadcast(two)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(broadcast, broadcast)
- const = f32[] constant(0.0)
- reduce = f32[512]{0} reduce(mul, const), dimensions={0,2,3}, to_apply=scalar_add_computation
- ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(mul, reduce)
- }
-
- ENTRY entry {
- fusion = (f32[128,512,28,28]{3,2,1,0}, f32[512]) fusion(), kind=kInput, calls=fused_computation
- gte = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
- gte.2 = f32[512]{0} get-tuple-element(fusion), index=1
- copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte)
- ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(copy.1, gte.2)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement())));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Multiply(), m::Reduce(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotFuseForSliceMultioutputFusion) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
- slice1 = f32[128,100,28,28]{3,2,1,0} slice(mul), slice={[0:128],[0:100],[0:28],[0:28]}
- slice2 = f32[128,200,28,28]{3,2,1,0} slice(mul), slice={[0:128],[50:250],[0:28],[0:28]}
- ROOT tuple = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) tuple(slice1, slice2)
- }
-
- ENTRY entry {
- p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- ROOT fusion = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) fusion(p1), kind=kInput, calls=fused_computation
- })"))
- .value();
- ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithScatter) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
- input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} negate(p0)
- ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(input_tensor, scatter_indices, updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=scalar_add_computation
-}
-
- ENTRY entry {
- param.0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- param.1 = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- param.2 = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
- fusion = f32[50,49,48,47,46]{4,3,2,1,0} fusion(param.0, param.1, param.2), kind=kInput, calls=fused_computation
- ROOT copy = f32[50,49,48,47,46]{4,3,2,1,0} copy(fusion)
- })"))
- .value();
- ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunOutsideEntryComputation) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-fused_computation.549 {
- param_0.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
- bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.8511)
- slice = bf16[15,1,2,48,128,1]{5,4,3,2,1,0} slice(bitcast.52601), slice={[0:15:1], [0:1:1], [0:2:1], [0:48:1], [0:128:1], [0:1:1]}
- bitcast = bf16[15,1,2,48,128]{4,3,2,1,0} bitcast(slice)
- ROOT broadcast = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} broadcast(bitcast), dimensions={0,1,2,3,4}
-}
-
-condition {
- constant_6915 = s32[] constant(15)
- param.218 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
- get-tuple-element.3714 = s32[] get-tuple-element(param.218), index=1
- ROOT compare.1738 = pred[] compare(get-tuple-element.3714, constant_6915), direction=LT
-}
-
-body {
- tuple_param = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
- param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(tuple_param), index=0
- param_1 = s32[] get-tuple-element(tuple_param), index=1
- fusion.549 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation.549
- bitcast = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(fusion.549)
- copy = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(bitcast)
- constant_one = s32[] constant(1)
- add = s32[] add(param_1, constant_one), control-predecessors={fusion.549}
- ROOT tuple = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) tuple(copy, add)
-}
-
-ENTRY main {
- param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
- zero = s32[] constant(0)
- copy.0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(param_0)
- copy.1 = s32[] copy(zero)
- tuple = tuple(copy.0, copy.1)
- ROOT while = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) while(tuple), condition=condition, body=body, backend_config="{\"known_trip_count\":{\"n\":\"15\"}}"
-})"))
- .value();
- ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithDynamicUpdateSliceInplace) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p.0 = f16[50,96,1024]{2,1,0} parameter(0)
- p.1 = f16[1,96,1024]{2,1,0} parameter(1)
- c.0 = s32[3]{0} constant({0, 0, 0})
- ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
- }
-
- ENTRY entry {
- p0 = f16[50,96,1024]{2,1,0} parameter(0)
- p1 = f16[1,96,1024]{2,1,0} parameter(1)
- fusion = f16[50,96,1024]{2,1,0} fusion(p0, p1), kind=kInput, calls=fused_computation
- copy.1 = f16[50,96,1024]{2,1,0} copy(fusion)
- copy.2 = f16[50,96,1024]{2,1,0} copy(fusion)
- ROOT root = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy.1, copy.2)
- })"))
- .value();
- ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionWithDynamicUpdateSliceNotInplace) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- one = f32[] constant(1.0)
- zero = f32[] constant(0.0)
- p.0 = f16[50,96,1024]{2,1,0} broadcast(one), dimensions={}
- p.1 = f16[1,96,1024]{2,1,0} broadcast(zero), dimensions={}
- c.0 = s32[3]{0} constant({0, 0, 0})
- dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
- neg = f16[50,96,1024]{2,1,0} negate(dynamic-update-slice)
- ROOT tuple = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(dynamic-update-slice, neg)
- }
-
- ENTRY entry {
- fusion = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) fusion(), kind=kInput, calls=fused_computation
- gte.0 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=0
- gte.1 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=1
- bitcast = f16[1,50,96,1024]{3,2,1,0} bitcast(gte.0)
- copy = f16[1,50,96,1024]{3,2,1,0} copy(bitcast)
- ROOT root = (f16[1,50,96,1024]{3,2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy, gte.1)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement())));
- EXPECT_THAT(
- fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::DynamicUpdateSlice(), m::Negate(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionTransposeAndThreeCopies) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- two = f32[] constant(2.0)
- param_0.1 = f32[16,32]{1,0} broadcast(two), dimensions={}
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
- }
-
- ENTRY entry {
- fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
- copy.1 = f32[32,16]{1,0} copy(fusion)
- copy.2 = f32[32,16]{1,0} copy(fusion)
- copy.3 = f32[32,16]{1,0} copy(fusion)
- ROOT root = (f32[32,16]{1,0}, f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.1, copy.2, copy.3)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root,
- GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement(), m::GetTupleElement())));
- EXPECT_THAT(
- fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneCopy) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
- }
-
- ENTRY entry {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
- ROOT copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::GetTupleElement(m::Fusion(&fusion))));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Negate(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopiesAndTransposeCopy) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
- }
-
- ENTRY entry {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
- copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
- transpose = f32[128,512,28,28]{2,3,0,1} copy(fusion)
- bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
- copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
- ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::Bitcast(), m::GetTupleElement())));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneNonTransposeCopy) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
- }
-
- ENTRY entry {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
- copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
- transpose.1 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
- bitcast.1 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.1)
- transpose.2 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
- bitcast.2 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.2)
- ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}) tuple(copy.1, bitcast.1, bitcast.2)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::Bitcast(), m::Bitcast())));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Negate(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionSkipTupleCopies) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
- neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
- ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
- }
-
- ENTRY entry {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
- copy.1 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
- copy.2 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
- ROOT root = ((f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}),(f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0})) tuple(copy.1, copy.2)
- })"))
- .value();
- ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionTupleAndGetTuple) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
- neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
- ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
- }
-
- ENTRY entry {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
- gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
- gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
- copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
- copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
- ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement())));
- EXPECT_THAT(
- fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionWithFusionReturningTupleAndOtherUser) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
- neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
- ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
- }
-
- ENTRY entry {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
- gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
- gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
- copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
- copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
- transpose = f32[128,512,28,28]{2,3,0,1} copy(gte.1)
- bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
- ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
- })"))
- .value();
- ASSERT_TRUE(cf_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root,
- GmockMatch(m::Tuple(m::Copy(), m::Bitcast(),
- m::GetTupleElement(m::Fusion(&fusion)))));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy())));
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc
deleted file mode 100644
index f0da0e5..0000000
--- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc
+++ /dev/null
@@ -1,210 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/gemm_fusion.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-static absl::StatusOr<bool> PadForGemm(HloDotInstruction* dot,
- PrimitiveType datatype,
- int pad_to_multiple_of) {
- auto* lhs = dot->mutable_operand(0);
- auto* rhs = dot->mutable_operand(1);
-
- Shape lshape = lhs->shape();
- Shape rshape = rhs->shape();
- Shape result_shape = dot->shape();
-
- if (lshape.element_type() != datatype || rshape.element_type() != datatype) {
- return false;
- }
-
- auto pad_dim = [&](Shape& s, int dim) {
- s.set_dimensions(dim,
- RoundUpTo<int64_t>(s.dimensions(dim), pad_to_multiple_of));
- };
-
- auto pad_matrix_dims = [&pad_dim](Shape s) {
- // Since the dot instruction is canonicalized, the last two dimensions for
- // each operand represent non-batch dimensions, and the others are the same
- // for both operands and correspond to batch dimensions.
- pad_dim(s, s.rank() - 2);
- pad_dim(s, s.rank() - 1);
- return s;
- };
-
- Shape new_lshape = pad_matrix_dims(lshape);
- Shape new_rshape = pad_matrix_dims(rshape);
- Shape new_result_shape = pad_matrix_dims(result_shape);
-
- if (new_lshape == lshape && new_rshape == rshape) {
- return false;
- }
-
- VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
- VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
- << new_result_shape;
-
- auto create_padding_config = [](Shape& shape, Shape& new_shape) {
- PaddingConfig padding_config;
- for (int i = 0; i < shape.rank(); ++i) {
- auto dimension = padding_config.add_dimensions();
- dimension->set_edge_padding_high(new_shape.dimensions()[i] -
- shape.dimensions()[i]);
- dimension->set_edge_padding_low(0);
- dimension->set_interior_padding(0);
- }
- return padding_config;
- };
-
- auto l_padding_config = create_padding_config(lshape, new_lshape);
- auto r_padding_config = create_padding_config(rshape, new_rshape);
-
- HloComputation* parent = dot->parent();
-
- HloInstruction* zero_float = parent->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::Zero(datatype)));
- zero_float->set_metadata(dot->metadata());
-
- HloInstruction* lpad = parent->AddInstruction(
- HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
- lpad->set_metadata(dot->metadata());
-
- HloInstruction* rpad = parent->AddInstruction(
- HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
- rpad->set_metadata(dot->metadata());
-
- HloInstruction* new_dot = parent->AddInstruction(
- dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));
-
- std::vector<int64_t> start_indices(result_shape.rank(), 0);
- std::vector<int64_t> strides(result_shape.rank(), 1);
- HloInstruction* slice = parent->AddInstruction(
- HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
- result_shape.dimensions(), strides));
- slice->set_metadata(dot->metadata());
-
- bool is_root = dot->user_count() == 0;
-
- TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));
-
- if (is_root) {
- parent->set_root_instruction(slice);
- }
-
- return true;
-}
-
-namespace {
-
-// We need this check because PadForGemm works under the assumption that
-// the dot instruction is canonicalized.
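-// For illustration (hypothetical dots): a dot whose rank-4 operands have
-// batch_dimensions={0,1} and a single contracting dimension passes this check,
-// while a dot with rank-4 operands, one batch dimension, and two contracting
-// dimensions does not, and is skipped.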
-bool CheckCanonical(HloDotInstruction* dot) {
- const auto& dimension_numbers = dot->dot_dimension_numbers();
-
- if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
- dot->operand(0)->shape().rank() ||
- dimension_numbers.rhs_batch_dimensions_size() + 2 !=
- dot->operand(1)->shape().rank()) {
- VLOG(2)
- << dot->ToString()
- << " is not canonical: Expected all dimensions but 2 to be "
- "batch_dimensions. Hence, this dot is not a candidate for padding.";
- return false;
- }
-
- std::vector<int64_t> canonical_batch_dims(
- dimension_numbers.lhs_batch_dimensions_size());
- absl::c_iota(canonical_batch_dims, 0);
- if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
- canonical_batch_dims) ||
- !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
- canonical_batch_dims)) {
- VLOG(2)
- << dot->ToString()
- << " is not canonical: Expected batch dimensions to be all "
- "dimensions except for the last 2 ones. Hence, this dot is not a "
- "candidate for padding.";
- return false;
- }
-
- return true;
-}
-
-} // namespace
-
-static std::vector<HloDotInstruction*> GetRelevantDots(
- const se::GpuComputeCapability& gpu_compute_capability,
- HloComputation* comp, PrimitiveType datatype) {
- std::vector<HloDotInstruction*> gemms;
-
- for (HloInstruction* instr : comp->instructions()) {
- if (IsMatrixMultiplication(*instr)) {
- HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
- if (instr->operand(0)->shape().element_type() == datatype &&
- CheckCanonical(dot) &&
- !(instr->GetModule()
- ->config()
- .debug_options()
- .xla_gpu_enable_triton_gemm() &&
- legacy_triton::IsTritonSupportedInstruction(
- *dot, gpu_compute_capability) &&
- ShouldTritonHandleGEMM(*dot, gpu_compute_capability))) {
- gemms.push_back(dot);
- }
- }
- }
- return gemms;
-}
-
-absl::StatusOr<bool> CublasPadForGemms::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloDotInstruction* dot :
- GetRelevantDots(gpu_compute_capability_, comp, datatype_)) {
- TF_ASSIGN_OR_RETURN(bool result,
- PadForGemm(dot, datatype_, pad_to_multiple_of_));
- changed |= result;
- }
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h
deleted file mode 100644
index 2a1f9c6..0000000
--- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_
-#define XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Adds padding to dot operations to make them run faster on GPUs.
-//
-// This can be used to pad f16 dots so they use tensor cores, or to pad s8 dots
-// to multiples of four.
-//
-// This pass depends on the xla::DotDecomposer pass, so it must run strictly
-// after it.
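-// Illustrative usage (the pipeline variable is hypothetical):
-//   pipeline.AddPass<CublasPadForGemms>(gpu_compute_capability,
-//                                       PrimitiveType::F16,
-//                                       /*pad_to_multiple_of=*/8);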
-class CublasPadForGemms : public HloModulePass {
- public:
- CublasPadForGemms(const se::GpuComputeCapability gpu_compute_capability,
- PrimitiveType datatype, int32_t pad_to_multiple_of)
- : gpu_compute_capability_(gpu_compute_capability),
- datatype_(datatype),
- pad_to_multiple_of_(pad_to_multiple_of) {}
-
- absl::string_view name() const override { return "cublas-pad-for-gemms"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const se::GpuComputeCapability gpu_compute_capability_;
- PrimitiveType datatype_;
- int32_t pad_to_multiple_of_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_
diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc
deleted file mode 100644
index d20dd94..0000000
--- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc
+++ /dev/null
@@ -1,306 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace m = ::xla::match;
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class CublasGemmPadForTensorCoresTest : public HloTestBase {
- protected:
- bool PadForF16Gemms(HloModule* module) {
- return CublasPadForGemms(se::CudaComputeCapability(7, 0),
- PrimitiveType::F16, 8)
- .Run(module)
- .value();
- }
-
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
- // Some pads would not be added if we detect that Triton will handle the
- // given dot operation.
- debug_options.set_xla_gpu_triton_gemm_any(false);
- return debug_options;
- }
-};
-
-TEST_F(CublasGemmPadForTensorCoresTest, OneDotRootComputation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f16[2048,1024] parameter(0)
- %param2 = f16[1024,33708] parameter(1)
- ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
- f16[1024,33708]{0,1} %param2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })")
- .value();
-
- EXPECT_TRUE(PadForF16Gemms(module.get()));
- SCOPED_TRACE(module->ToString());
-
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(
- m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {2048, 1024}),
- m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {1024, 33712}))
- .WithShape(F16, {2048, 33712})
- .WithContractingDims(/*lhs_contracting_dims=*/{1},
- /*rhs_contracting_dims=*/{0}))
- .WithShape(F16, {2048, 33708})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, OneDotS8RootComputation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = s8[2047,1023] parameter(0)
- %param2 = s8[1023,33707] parameter(1)
- ROOT %dot.2309 = s32[2047,33707]{1,0} dot(s8[2047,1023]{1,0} %param1,
- s8[1023,33707]{0,1} %param2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })")
- .value();
-
- EXPECT_TRUE(
- CublasPadForGemms(se::CudaComputeCapability(7, 0), PrimitiveType::S8, 4)
- .Run(module.get())
- .value());
- SCOPED_TRACE(module->ToString());
-
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(
- m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(S8, {2047, 1023}),
- m::Constant().WithShape(S8, {}))
- .WithShape(S8, {2048, 1024}),
- m::Pad(m::Parameter().WithShape(S8, {1023, 33707}),
- m::Constant().WithShape(S8, {}))
- .WithShape(S8, {1024, 33708}))
- .WithShape(S32, {2048, 33708})
- .WithContractingDims(/*lhs_contracting_dims=*/{1},
- /*rhs_contracting_dims=*/{0}))
- .WithShape(S32, {2047, 33707})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, TwoDotsComputation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f16[2048, 1024] parameter(0)
- %param2 = f16[1024, 33708] parameter(1)
- %param3 = f16[33708, 1] parameter(2)
- %dot1 = f16[2048, 33708]{1,0} dot(f16[2048, 1024]{1,0} %param1,
- f16[1024, 33708]{0,1} %param2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT %dot2 = f16[2048, 1]{1,0} dot(f16[2048, 33708]{1,0} %dot1,
- f16[33708, 1]{0,1} %param3),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })")
- .value();
-
- EXPECT_TRUE(PadForF16Gemms(module.get()));
- SCOPED_TRACE(module->ToString());
-
- auto* root = module->entry_computation()->root_instruction();
- const HloInstruction* dot2 = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(
- m::Slice(
- m::Dot(
- m::Pad(m::Slice(m::Dot(&dot2,
- m::Pad().WithShape(F16, {2048, 1024}),
- m::Pad().WithShape(F16, {1024, 33712}))
- .WithContractingDims(
- /*lhs_contracting_dims=*/{1},
- /*rhs_contracting_dims=*/{0})
- .WithShape(F16, {2048, 33712}))
- .WithShape(F16, {2048, 33708}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {2048, 33712}),
-
- m::Pad(m::Parameter().WithShape(F16, {33708, 1}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {33712, 8}))
- .WithShape(F16, {2048, 8})
- .WithContractingDims(/*lhs_contracting_dims=*/{1},
- /*rhs_contracting_dims=*/{0}))
- .WithShape(F16, {2048, 1})));
-
- EXPECT_THAT(
- dot2,
- GmockMatch(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {2048, 1024}),
- m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {1024, 33712}))
- .WithContractingDims(/*lhs_contracting_dims=*/{1},
- /*rhs_contracting_dims=*/{0})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, DotWithBatchDimensions) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f16[3, 5, 2048, 1024] parameter(0)
- %param2 = f16[3, 5, 1024, 33708] parameter(1)
- ROOT %dot.2309 = f16[3, 5, 2048, 33708]{3, 2, 1,0} dot(f16[3, 5, 2048, 1024]{3, 2, 1,0} %param1,
- f16[3, 5, 1024, 33708]{2, 3, 0,1} %param2), lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1}, lhs_contracting_dims={3}, rhs_contracting_dims={2}})")
- .value();
-
- EXPECT_TRUE(PadForF16Gemms(module.get()));
- SCOPED_TRACE(module->ToString());
-
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(
- m::Slice(
- m::Dot(m::Pad(m::Parameter().WithShape(F16, {3, 5, 2048, 1024}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {3, 5, 2048, 1024}),
- m::Pad(m::Parameter().WithShape(F16, {3, 5, 1024, 33708}),
- m::Constant().WithShape(F16, {}))
- .WithShape(F16, {3, 5, 1024, 33712}))
- .WithShape(F16, {3, 5, 2048, 33712})
- .WithContractingDims(/*lhs_contracting_dims=*/{3},
- /*rhs_contracting_dims=*/{2}))
- .WithShape(F16, {3, 5, 2048, 33708})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, NoDotComputation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %x = f32[] parameter(0)
- %y = f32[] parameter(1)
- ROOT %maximum = f32[] maximum(f32[] %x, f32[] %y)
- })")
- .value();
-
- EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, F32DotComputation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f32[2048,1024] parameter(0)
- %param2 = f32[1024,33708] parameter(1)
- ROOT %dot.2309 = f32[2048,33708]{1,0} dot(f32[2048,1024]{1,0} %param1,
- f32[1024,33708]{0,1} %param2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
- .value();
-
- EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, F64DotComputation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f64[2048,1024] parameter(0)
- %param2 = f64[1024,33708] parameter(1)
- ROOT %dot.2309 = f64[2048,33708]{1,0} dot(f64[2048,1024]{1,0} %param1,
- f64[1024,33708]{0,1} %param2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
- .value();
-
- EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, MultiplesOf8DotComputation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f16[2048,1024] parameter(0)
- %param2 = f16[1024,33712] parameter(1)
- ROOT %dot.2309 = f16[2048,33712]{1,0} dot(f16[2048,1024]{1,0} %param1,
- f16[1024,33712]{0,1} %param2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
- .value();
-
- EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, CheckSavingMetadata) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f16[2048,1024] parameter(0)
- %param2 = f16[1024,33708] parameter(1)
- ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
- f16[1024,33708]{0,1} %param2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0},
- metadata={op_type="MatMul" op_name="transformer_v2/Transformer/decode/embedding_shared_weights_1/presoftmax_linear/MatMul"}
- })")
- .value();
-
- SCOPED_TRACE(module->ToString());
-
- EXPECT_TRUE(PadForF16Gemms(module.get()));
- auto metadata = module->entry_computation()->root_instruction()->metadata();
- EXPECT_EQ("MatMul", metadata.op_type());
- EXPECT_EQ(
- "transformer_v2/Transformer/decode/embedding_shared_weights_1/"
- "presoftmax_linear/MatMul",
- metadata.op_name());
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, NotCanonicalizedDot) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- %param1 = f16[3, 5, 2048, 1024] parameter(0)
- %param2 = f16[3, 5, 1024, 33708] parameter(1)
- ROOT %dot.2309 = f16[3,2048, 33708]{2, 1, 0} dot(f16[3, 5, 2048, 1024]{3, 2, 1, 0} %param1, f16[3, 5, 1024, 33708]{3, 2, 1, 0} %param2), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={3, 1}, rhs_contracting_dims={2, 1}})")
- .value();
-
- EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-} // anonymous namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc
deleted file mode 100644
index e9cb21b..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc
+++ /dev/null
@@ -1,1566 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-
-#include <algorithm>
-#include <array>
-#include <cstdint>
-#include <functional>
-#include <limits>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/ml_dtypes.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = match;
-
-bool IsConvCustomCall(const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kCustomCall &&
- (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
- instr->custom_call_target() ==
- kCudnnConvBiasActivationForwardCallTarget);
-}
-
-bool IsConvDepthwise(const HloInstruction* instr) {
- int64_t feature_group_count = instr->feature_group_count();
- if (feature_group_count == 1) {
- return false;
- }
-
- const HloInstruction* input = instr->operand(0);
- int64_t input_feature_dimension =
- instr->convolution_dimension_numbers().input_feature_dimension();
- int64_t input_feature_count =
- input->shape().dimensions(input_feature_dimension);
- return input_feature_count == feature_group_count;
-}
-
-// We don't want to upgrade depthwise convolutions to ConvBiasActivation,
-// because the fused CUDNN functions are slower for some of those.
-bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
- return IsConvCustomCall(instr) && !IsConvDepthwise(instr);
-}
-
-bool IsROCm(se::GpuComputeCapability cc) {
- return std::holds_alternative<se::RocmComputeCapability>(cc);
-}
-
-// elu, relu6, and leaky-relu activations are supported in cudnn via the
-// "runtime fusion" engine, which JIT compiles C++ code. This can be slow to
-// compile, so we guard it with a debug option.
-//
-// nvidia currently recommends that we enable this only on Ampere+, but we've
-// tested on Turing (sm75) and it seems to work fine.
-//
-// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default
-// due to apparent bugs in cudnn 8.9.0. See debug_options_flags.cc for details.
-bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts,
- se::GpuComputeCapability cc) {
- const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&cc);
- if (cuda_cc != nullptr)
- return debug_opts.xla_gpu_use_runtime_fusion() && cuda_cc->IsAtLeast(7, 5);
- else
- return true;
-}
-
-bool IsSuitableForCudnnRuntimeFusion(HloInstruction* conv) {
- // cudnn runtime fusion is pathologically slow on convs with side-inputs.
- // TODO(kaixih@nvidia): remove this check when cuDNN fixes it.
- if (conv->operands().size() > 3) {
- return false;
- }
-
- // cuDNN runtime fusion kernels require 32-bit aligned data access, which
- // means that the number of in/out channels must be divisible by 2 for fp16.
- // (We don't currently do runtime fusion for int8.)
- if (conv->operand(0)->shape().element_type() != F16) {
- return false;
- }
- const Shape& shape = conv->operand(1)->shape();
- int64_t num_input_features = shape.dimensions(
- conv->convolution_dimension_numbers().kernel_input_feature_dimension());
- int64_t num_output_features = shape.dimensions(
- conv->convolution_dimension_numbers().kernel_output_feature_dimension());
- if (num_input_features % 2 != 0 || num_output_features % 2 != 0) {
- return false;
- }
-
- return true;
-}
-
-// Can instr be converted to type `dst_ty` without losing any precision? For
-// our purposes, this is true if:
-//
-// - instr already has type dst_ty, or
-// - instr is convert<wider type>(op_with_dst_ty), or
-// - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
-// get back exactly the original value, or
-// - instr is a broadcast, reshape, or transpose of one of the above.
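-//
-// For example (illustrative constants): an f32 constant 3.0 converts to S8 and
-// back unchanged, so it is losslessly convertible to S8, whereas an f32
-// constant 3.5 is not.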
-bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
- PrimitiveType dst_ty) {
- if (instr->shape().element_type() == dst_ty) {
- return true;
- }
-
- if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
- // Check that the convert from dst_ty to instr->element_type() doesn't lose
- // precision. Otherwise, this convert is not lossless.
- return primitive_util::CastPreservesValues(dst_ty,
- instr->shape().element_type());
- }
-
- if (instr->opcode() == HloOpcode::kConstant) {
- if (!instr->shape().IsArray()) {
- return false;
- }
- // Check if instr's literal roundtrips to dst_ty and back to its original type
- // without modification.
- PrimitiveType orig_ty = instr->shape().element_type();
-
- // The only reason Convert() should fail is if we don't support converting
- // from x to y, which indeed means it's not losslessly-convertible.
- absl::StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
- if (!converted1.ok()) {
- return false;
- }
- absl::StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
- if (!converted2.ok()) {
- return false;
- }
-
- return instr->literal() == *converted2;
- }
-
- if (instr->opcode() == HloOpcode::kBroadcast ||
- instr->opcode() == HloOpcode::kReshape ||
- instr->opcode() == HloOpcode::kTranspose) {
- return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
- }
-
- return false;
-}
-
-// Helpers suitable for use in m::Op().WithPredicate(...).
-bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
- return IsLosslesslyConvertibleTo(instr, S8);
-}
-bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
- return IsLosslesslyConvertibleTo(instr, F16);
-}
-
-// If `conv` is a vanilla forward conv, transforms it into a
-// conv-bias-activation. If it's already a conv-bias-activation, does nothing.
-//
-// If `conv` is anything else, returns an error.
-absl::StatusOr<HloInstruction*> EnsureIsConvBiasActivation(
- HloInstruction* conv) {
- CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);
-
- if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) {
- return conv;
- }
-
- if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
- HloComputation* comp = conv->parent();
-
- const Shape& shape = conv->shape().tuple_shapes(0);
- int64_t num_output_features = shape.dimensions(
- conv->convolution_dimension_numbers().output_feature_dimension());
-
- // bias for integer convs is always f32, see
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- PrimitiveType bias_ty;
- if (primitive_util::IsIntegralType(shape.element_type())) {
- bias_ty = F32;
- } else {
- bias_ty = shape.element_type();
- }
- auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});
-
- absl::InlinedVector<HloInstruction*, 3> new_operands(
- conv->operands().begin(), conv->operands().end());
- new_operands.push_back(bias);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), new_operands));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
- new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
- comp->parent()->SetAndUniquifyInstrName(new_conv,
- "cudnn-conv-bias-activation");
- return new_conv;
- }
-
- return FailedPrecondition("Unsupported conv: %s", conv->ToString());
-}
-
-// convert<cvt_type>(gte(custom-call<conv_type>(int8_x, int8_w))) ->
-// gte(custom-call<cvt_type>(int8_x, int8_w))
-absl::StatusOr<bool> FuseConvertTypeIntoConv(HloComputation* comp,
- PrimitiveType conv_type,
- PrimitiveType cvt_type) {
- bool changed = false;
- for (auto instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv = nullptr;
- auto tuple_elem =
- m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
- .WithElementType(conv_type);
- auto pattern =
- m::Convert(tuple_elem.WithOneUser()).WithElementType(cvt_type);
- if (!Match(instr, pattern)) {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvertTypeIntoConv: ", conv->ToString());
- })) {
- continue;
- }
-
- Shape new_shape = conv->shape();
- new_shape.mutable_tuple_shapes(0)->set_element_type(cvt_type);
- HloInstruction* new_conv =
- comp->AddInstruction(conv->CloneWithNewShape(new_shape));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));
-
- changed = true;
- }
-
- return changed;
-}
-
-struct ConvConvertTypes {
- PrimitiveType convolution_type;
- PrimitiveType conversion_type;
-};
-
-// Removes the convert around a convolution by making the convolution type
-// (custom call) the same as the conversion result.
-// For example: convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
-// gte(custom-call<float>(int8_x, int8_w))
-absl::StatusOr<bool> FuseRemoveConvertInConv(HloComputation* comp) {
- bool changed = false;
- // Note: We are eliminating F16->F32 because it fails on internal tests.
- std::array<ConvConvertTypes, 3> types{{
- {S32, F32},
- {S8, F32},
- {F32, S8},
- }};
- for (auto [conv_type, cvt_type] : types) {
- TF_ASSIGN_OR_RETURN(bool curr_change,
- FuseConvertTypeIntoConv(comp, conv_type, cvt_type));
- changed |= curr_change;
- }
- return changed;
-}
-
-// alpha * gte(custom-call(...)) ->
-// gte(custom-call(..., backend_config={alpha})).
-absl::StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
- bool changed = false;
- for (auto instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv = nullptr;
- HloInstruction* gte = nullptr;
- HloInstruction* alpha = nullptr;
-
- auto pattern = m::MultiplyAnyOrder(
- m::GetTupleElement(
- >e, m::Op(&conv).WithPredicate(IsNonDepthwiseConvCustomCall), 0)
- .WithOneUse(),
- m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
- if (!Match(instr, pattern)) {
- continue;
- }
-
- // alpha is f32 except for f64 convs, where it's f64. See
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
- if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
-
- if (config.conv_result_scale() != 1) {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvAlpha: ", conv->ToString());
- })) {
- continue;
- }
-
- // StreamExecutor doesn't support the alpha parameter on non-bias-activation
- // convs, so we have to upgrade `conv`.
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-
- TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
- config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());
- TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));
-
- changed = true;
- }
- return changed;
-}
-
-// The format of the serialized graph describing a sequence of ops fused
-// into the cuDNN convolution Custom Call is
-// "UID:[output_type]conv();UID[output_type]:op_name(operand
-// UID);UID:[output_type]op_name(operand UID);..." with the convolution assumed
-// to be the first op in the graph. Operand UIDs identifying ops outside the
-// serialized graph are elided. Currently, multiplication and division by a
-// broadcast scalar, addition of a matrix bias, the application of a ReLU
-// activation and the calculation of the maximum of the absolute value are
-// supported.
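-// For example (UIDs are illustrative), a convolution followed by a scale and a
-// ReLU would serialize as "1:[f32]conv();2:[f32]scale(1);3:[f32]relu(2);".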
-class GraphString {
- public:
- GraphString() = default;
-
- bool AppendOp(std::string op_name, HloInstruction* op,
- std::vector<HloInstruction*> operands = {}) {
- std::optional<int64_t> operand_uid;
- int num_operands_in_graph = 0;
- for (HloInstruction* operand : operands) {
- if (OpInGraph(operand->unique_id())) {
- num_operands_in_graph++;
- // Ops with more than one operand in the graph are not supported.
- if (num_operands_in_graph > 1) {
- return false;
- }
- operand_uid = operand->unique_id();
- }
- }
- graph_.emplace_back(OpDescriptor(
- {op->unique_id(), op->shape().element_type(), op_name, operand_uid}));
- return true;
- }
-
- void ChangeDataType(PrimitiveType type) {
- DCHECK(!graph_.empty());
- graph_.back().output_type = type;
- }
-
- std::string Graph() const {
- std::string graph;
- for (OpDescriptor op : graph_) {
- graph.append(std::to_string(op.uid));
- graph.append(":[" +
- primitive_util::LowercasePrimitiveTypeName(op.output_type) +
- "]");
- graph.append(op.name);
- graph.append("(");
- if (op.operand.has_value()) {
- graph.append(std::to_string(*op.operand));
- }
- graph.append(");");
- }
- return graph;
- }
-
- bool OpInGraph(int64_t uid, std::string op_name = "") const {
- auto op_filter = [&](OpDescriptor op) -> bool {
- if (op_name.empty()) {
- return op.uid == uid;
- } else {
- return op.uid == uid && op.name == op_name;
- }
- };
- return std::find_if(graph_.begin(), graph_.end(), op_filter) !=
- graph_.end();
- }
-
- private:
- struct OpDescriptor {
- int64_t uid;
- PrimitiveType output_type;
- std::string name;
- std::optional<int64_t> operand;
- };
-
- std::vector<OpDescriptor> graph_;
-};
-
-bool IsF8Type(const HloInstruction* instr) {
- return primitive_util::IsF8Type(instr->shape().element_type());
-}
-
-bool IsScalar(const HloInstruction* instr) {
- return ShapeUtil::IsScalar(instr->shape());
-}
-
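-// Returns the F8 type if `instr` is a saturating cast to F8, i.e. a convert of
-// a clamp whose bounds are the finite range of the target type. E.g.
-// (illustrative): convert(clamp(broadcast(-448), x, broadcast(448))) to
-// f8e4m3fn qualifies, since +/-448 is the finite range of F8E4M3FN.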
-std::optional<PrimitiveType> IsSaturatingCastToF8(HloInstruction* instr) {
- HloInstruction *op, *clamp_lower, *clamp_upper;
- if (Match(instr,
- m::Convert(
- &op,
- m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), m::Op(),
- m::Broadcast(m::ConstantScalar(&clamp_upper))))) &&
- ((op->shape().element_type() == F8E4M3FN &&
- clamp_lower->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e4m3fn>::lowest())) &&
- clamp_upper->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e4m3fn>::max()))) ||
- (op->shape().element_type() == F8E5M2 &&
- clamp_lower->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e5m2>::lowest())) &&
- clamp_upper->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e5m2>::max()))))) {
- return op->shape().element_type();
- }
- return std::nullopt;
-}
-
-// Returns whether the HLO Computation applied by `op` calculates the largest
-// element.
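-// Illustratively, this accepts reduce(x, c) where the reduce computation is
-// maximum(parameter(0), parameter(1)) and c is a scalar constant <= 0.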
-bool AppliesMaxReduce(HloInstruction* op) {
- HloComputation* reduce_comp = op->to_apply();
- HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
- return ShapeUtil::IsScalar(op->shape()) &&
- ShapeUtil::IsScalar(op->operand(1)->shape()) &&
- op->operand(1)->IsConstant() &&
- op->operand(1)->literal().GetAsDouble({}) <= 0. &&
- reduce_comp_root->opcode() == HloOpcode::kMaximum &&
- reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
- reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
-}
-
-// Recursively captures and serializes the graph of pointwise operations
-// operating on the convolution.
-void CaptureConvGraphRecursive(HloInstruction* instr,
- std::vector<HloInstruction*>& operands,
- std::vector<HloInstruction*>& aux_outputs,
- GraphString& graph_string,
- absl::flat_hash_set<int>& visited_instrs,
- HloInstruction*& final_instr) {
- // Avoid visiting the same instruction more than once.
- if (!visited_instrs.emplace(instr->unique_id()).second) {
- return;
- }
- final_instr = instr;
-
- // Copy the current state in case fusion turns out to be unsuccessful or
- // unfavorable.
- GraphString init_graph_string = graph_string;
- std::vector<HloInstruction*> init_operands = operands,
- init_aux_outputs = aux_outputs;
- // The loop adds each user of `instr` that supports fusion into the
- // cuDNN convolution Custom Call to GraphString. Most ops following the
- // convolution describe a linear sequence that generates a single return
- // tensor. The identification of one of these linear ops is followed by a
- // recursive call of CaptureConvGraphRecursive to match and potentially fuse
- // its users. The calculation of the scalar maximum of the absolute value
- // (Amax) of a preceding op is considered a nonlinear user as it adds a
- // return value to the convolution. The users of a nonlinear op are
- // not considered for fusion into the Custom Call. The numbers of linear and
- // nonlinear users of `instr` are stored in `num_linear_users` and
- // `num_nonlinear_users`.
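- // For example (illustrative), in a chain conv -> scale -> relu with an amax
- // reduce of the relu, the scale and relu are linear users handled via
- // recursion, while the amax is recorded as an auxiliary output and its users
- // are not visited.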
- int num_linear_users = 0, num_nonlinear_users = 0;
- for (HloInstruction* user : instr->users()) {
- HloInstruction *op, *operand0, *operand1;
- // Add
- if (Match(user, m::AddAnyOrder(&op, m::Op(&operand0), m::Op(&operand1)))) {
- if (graph_string.AppendOp("add", op, {operand0, operand1})) {
- operands.push_back(operand0 == instr ? operand1 : operand0);
- num_linear_users++;
- CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
- visited_instrs, final_instr);
- }
- continue;
- }
- // Scale
- if (Match(user, m::MultiplyAnyOrder(&op, m::Op(&operand0),
- m::Broadcast(m::Op(&operand1)))) &&
- ShapeUtil::IsScalar(operand1->shape())) {
- if (graph_string.AppendOp("scale", op, {operand0, operand1})) {
- operands.push_back(operand1);
- num_linear_users++;
- CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
- visited_instrs, final_instr);
- }
- continue;
- }
- // Inverse Scale
- if (Match(user, m::Divide(&op, m::Op(&operand0),
- m::Broadcast(m::Op(&operand1)))) &&
- ShapeUtil::IsScalar(operand1->shape())) {
- if (graph_string.AppendOp("invscale", op, {operand0, operand1})) {
- operands.push_back(operand1);
- num_linear_users++;
- CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
- visited_instrs, final_instr);
- }
- continue;
- }
- // ReLU
- if (Match(user, m::MaximumAnyOrder(&op, m::Op(&operand0),
- m::Broadcast(m::ConstantScalar(0))))) {
- if (graph_string.AppendOp("relu", op, {operand0})) {
- num_linear_users++;
- CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
- visited_instrs, final_instr);
- }
- continue;
- }
- // Maximum of the absolute value (Amax) following ReLU (elided Abs) -- not
- // a linear user
- if (Match(user, m::Reduce(&op, m::Op(&operand0), m::Op())) &&
- graph_string.OpInGraph(operand0->unique_id(), "relu") &&
- AppliesMaxReduce(op)) {
- if (graph_string.AppendOp("amax", op, {operand0})) {
- aux_outputs.emplace_back(op);
- num_nonlinear_users++;
- }
- continue;
- }
-
- // The following patterns match the user of `user`.
- if (!user->users().empty()) {
- HloInstruction* users_user = user->users()[0];
- // Convert with Clamp to FP8 types
- std::optional<PrimitiveType> f8_type = IsSaturatingCastToF8(users_user);
- if (f8_type.has_value()) {
- graph_string.ChangeDataType(f8_type.value());
- num_linear_users++;
- CaptureConvGraphRecursive(users_user, operands, aux_outputs,
- graph_string, visited_instrs, final_instr);
- continue;
- }
- // Maximum of the absolute value (Amax) -- not a linear user
- if (Match(users_user,
- m::Reduce(&op, m::Abs(m::Op(&operand0)), m::Op())) &&
- AppliesMaxReduce(op)) {
- if (graph_string.AppendOp("amax", op, {operand0})) {
- aux_outputs.emplace_back(op);
- num_nonlinear_users++;
- }
- continue;
- }
- }
- }
- // Do not fuse into the cuDNN convolution Custom Call when there is more than
- // one linear or nonlinear user, or when the number of users eligible for
- // fusion is less than the total number of users.
- if (num_linear_users > 1 || num_nonlinear_users > 1 ||
- num_linear_users + num_nonlinear_users < instr->user_count()) {
- graph_string = init_graph_string;
- operands = init_operands;
- aux_outputs = init_aux_outputs;
- final_instr = instr;
- }
-}
-
-// Captures in a GraphString the subgraph of pointwise operations operating on
-// the convolution that will be fused into the cuDNN convolution Custom Call.
-absl::StatusOr<
- std::tuple<std::vector<HloInstruction*>, std::vector<HloInstruction*>,
- GraphString, HloInstruction*>>
-CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution,
- HloInstruction* wide_input, HloInstruction* wide_filter,
- HloInstruction* input_scale, HloInstruction* filter_scale,
- bool x_mult_scale, bool w_mult_scale) {
- GraphString graph_string;
- graph_string.AppendOp("conv", instr);
-
- // Shift the scaling of the input and filter to the output of the convolution.
- HloInstruction *input_scaled_conv, *filter_scaled_conv;
- if (input_scale) {
- TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(0, wide_input));
- HloInstruction* bcast_input_scale = instr->AddInstruction(
- HloInstruction::CreateBroadcast(instr->shape(), input_scale, {}));
- input_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
- instr->shape(),
- x_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide, instr,
- bcast_input_scale));
- TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(input_scaled_conv));
- }
- if (filter_scale) {
- TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(1, wide_filter));
- HloInstruction* bcast_filter_scale = instr->AddInstruction(
- HloInstruction::CreateBroadcast(instr->shape(), filter_scale, {}));
- filter_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
- instr->shape(),
- w_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide,
- input_scale ? input_scaled_conv : instr, bcast_filter_scale));
- TF_RETURN_IF_ERROR((input_scale ? input_scaled_conv : instr)
- ->ReplaceAllUsesWith(filter_scaled_conv));
- }
-
- std::vector<HloInstruction*> operands, aux_outputs;
- absl::flat_hash_set<int> visited_instrs;
- HloInstruction* final_instr;
- CaptureConvGraphRecursive(instr, operands, aux_outputs, graph_string,
- visited_instrs, final_instr);
- return std::make_tuple(operands, aux_outputs, graph_string, final_instr);
-}
-
-// Matches convolutions operating on FP8 inputs and filters and rewrites them
-// into a ForwardGraph Custom Call. For scaled FP8 convolutions on Hopper
-// systems, the following steps are elided and folded into the Custom Call:
-//
-// 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32.
-// 2. Optionally unscale the filter and input by multiplying or dividing by
-// scalars.
-// 3. Evaluate the convolution based on the scaled filter and input.
-// 4. Apply a series of elementwise transformations, where a transformation can
-// be adding a matrix bias, applying a ReLU activation, or
-// multiplying or dividing by a broadcast scalar.
-// 5. Optionally calculate the maximum of the absolute value of the result.
-// 6. Optionally cast the output back to FP8.
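-//
-// A schematic (not exact) sketch of the HLO being matched:
-//   x = multiply(convert(f8_input), broadcast(input_scale))
-//   w = multiply(convert(f8_filter), broadcast(filter_scale))
-//   c = get-tuple-element(custom-call(x, w), 0)  // kCudnnConvForwardCallTarget
-//   out = convert(clamp(lo, relu(add(c, broadcast(bias))), hi))  // back to f8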
-absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
- se::CudaComputeCapability cc,
- se::dnn::VersionInfo dnn_version,
- int32_t toolkit_version) {
- bool changed = false;
-
- if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) {
- return false;
- }
- if (toolkit_version < 12000) {
- return false;
- }
- if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) {
- return false;
- }
- for (auto instr : comp->MakeInstructionPostOrder()) {
- HloInstruction *convolution, *gte, *input, *filter,
- *input_scale = nullptr, *filter_scale = nullptr,
- *input_scale_op = nullptr, *filter_scale_op = nullptr,
- *wide_input = nullptr, *wide_filter = nullptr;
-
- auto conv_operand_maybe_scaled = [](HloInstruction** operand,
- HloInstruction** wide_operand,
- HloInstruction** scale_op,
- HloInstruction** scale) {
- return m::AnyOf<HloInstruction>(
- m::Op(operand).WithPredicate(IsF8Type),
- m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
- m::Divide(
- scale_op,
- m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
- m::Broadcast(m::Op(scale).WithPredicate(IsScalar))),
- m::MultiplyAnyOrder(
- scale_op,
- m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
- m::Broadcast(m::Op(scale).WithPredicate(IsScalar))));
- };
-
- // TODO(philipphack): Consider allowing ops between dequantization and
- // convolution.
- auto pattern = m::GetTupleElement(
- >e,
- m::CustomCall(
- &convolution,
- conv_operand_maybe_scaled(&input, &wide_input, &input_scale_op,
- &input_scale),
- conv_operand_maybe_scaled(&filter, &wide_filter, &filter_scale_op,
- &filter_scale))
- .WithPredicate(IsConvCustomCall),
- 0);
- if (Match(instr, pattern)) {
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("F8GraphConv: ", convolution->ToString());
- })) {
- continue;
- }
-
- std::vector<HloInstruction*> operands, aux_outputs;
- GraphString graph_string;
- HloInstruction* final_instr;
-
- TF_ASSIGN_OR_RETURN(
- std::tie(operands, aux_outputs, graph_string, final_instr),
- CaptureConvGraph(
- instr, convolution, wide_input, wide_filter, input_scale,
- filter_scale,
- input_scale_op ? input_scale_op->opcode() == HloOpcode::kMultiply
- : false,
- filter_scale_op
- ? filter_scale_op->opcode() == HloOpcode::kMultiply
- : false));
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- convolution->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
-
- config.set_serialized_graph(graph_string.Graph());
- operands.insert(operands.begin(), input);
- operands.insert(operands.begin() + 1, filter);
-
- std::vector<Shape> output_shapes;
- output_shapes.emplace_back(ShapeUtil::ChangeElementType(
- ShapeUtil::GetTupleElementShape(convolution->shape(), 0),
- final_instr->shape().element_type()));
- for (HloInstruction* aux_output : aux_outputs) {
- output_shapes.emplace_back(aux_output->shape());
- }
- output_shapes.emplace_back(
- ShapeUtil::GetTupleElementShape(convolution->shape(), 1));
-
- HloInstruction* new_convolution =
- comp->AddInstruction(convolution->CloneWithNewOperands(
- ShapeUtil::MakeTupleShape(output_shapes), operands));
-
- new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget);
- TF_RETURN_IF_ERROR(new_convolution->set_backend_config(gpu_config));
- TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
- MakeGetTupleElementHlo(new_convolution, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte));
-
- for (int i = 0; i < aux_outputs.size(); ++i) {
- TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
- MakeGetTupleElementHlo(new_convolution, i + 1));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(aux_outputs[i], new_gte));
- }
-
- changed = true;
- }
- }
- return changed;
-}
-
-absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
- bool changed = false;
- for (auto instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv = nullptr;
- HloInstruction* gte = nullptr;
- HloInstruction* addend = nullptr;
-
- auto pattern = m::AddAnyOrder(
- m::GetTupleElement(>e,
- m::Op(&conv)
- .WithPredicate(IsNonDepthwiseConvCustomCall)
- .WithOneUse(),
- 0)
- .WithOneUse(),
- m::Op(&addend));
- if (!Match(instr, pattern)) {
- continue;
- }
-
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
- })) {
- continue;
- }
-
- // If it's a vanilla forward conv, upgrade it to a bias-activation conv. We
- // only want to do this if the fusion will succeed, but we're guaranteed
- // that it will, because the only reason we'll bail at this point is if
- // !can_accept_bias && !can_accept_side_input, and our shiny new
- // bias-activation conv will be able to accept both.
- if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
- }
-
- // Can't fuse bias or side-input if the conv already has a relu (or other
- // activation), because bias and side-input are added before the activation
- // is applied.
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
- if (config.activation_mode() != se::dnn::kNone) {
- continue;
- }
-
- // Does `conv` already have a (nonzero) bias? Does it already have a
- // side_input?
- bool can_accept_bias =
- Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
- bool can_accept_side_input = conv->operand_count() < 4;
-
- // The addend can be fused as a bias if
- // - it is 1D broadcasted in the output feature dimension, and
- // - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
- // convs, and conv_ty for floating-point convs)
- PrimitiveType conv_ty = gte->shape().element_type();
- PrimitiveType bias_ty =
- primitive_util::IsFloatingPointType(conv_ty) ? conv_ty : F32;
- bool addend_may_be_rank1_bias =
- addend->opcode() == HloOpcode::kBroadcast &&
- addend->dimensions().size() == 1 &&
- addend->dimensions(0) ==
- conv->convolution_dimension_numbers().output_feature_dimension() &&
- IsLosslesslyConvertibleTo(addend, bias_ty);
-
- bool addend_may_be_rank0_bias = addend->opcode() == HloOpcode::kBroadcast &&
- addend->dimensions().empty() &&
- IsLosslesslyConvertibleTo(addend, bias_ty);
-
- absl::InlinedVector<HloInstruction*, 4> new_operands(
- conv->operands().begin(), conv->operands().end());
- if (can_accept_bias && addend_may_be_rank1_bias) {
- new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
- &addend->operand(0)->metadata());
- } else if (can_accept_bias && addend_may_be_rank0_bias) {
- new_operands[2] = MakeBroadcastHlo(
- MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
- &addend->operand(0)->metadata()),
- /*broadcast_dimensions=*/{},
- /*result_shape_bounds=*/
- {gte->shape().dimensions(conv->convolution_dimension_numbers()
- .output_feature_dimension())});
- } else if (can_accept_side_input) {
- CHECK_EQ(new_operands.size(), 3);
- new_operands.push_back(addend);
- config.set_side_input_scale(1);
- } else {
- // Can't fuse; this op already has a bias and a side-input.
- continue;
- }
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
- TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
- changed = true;
- }
- return changed;
-}
-
-// custom-call(..., alpha * side_input) ->
-// custom-call(..., side_input, backend_config={alpha}).
-//
-// We also have to support the more complicated case of
-//
-// custom-call(..., reshape(side_input * alpha)) -->
-// custom-call(..., reshape(side_input), backend_config={alpha}),
-//
-// where `reshape` can be an arbitrary chain of reshapes+transposes. This idiom
-// is created by the ReshapeMover pass.
-absl::StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv;
- HloInstruction* side_input;
- auto pattern = m::Op(&conv)
- .WithPredicate(IsConvCustomCall)
- .WithOperand(3, m::Op(&side_input));
- if (!Match(instr, pattern)) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
- if (config.side_input_scale() != 1) {
- continue;
- }
-
- // Given side_input, pattern match the following (working from bottom up).
- //
- // before_reshape = multiply(base, broadcast(alpha))
- // side_input = chain_of_reshapes_and_transposes(before_reshape)
- //
- // where alpha is a scalar constant.
- //
- // alpha is f32 except for f64 convs, where it's f64. See
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- HloInstruction* before_reshape = side_input;
- while (before_reshape->opcode() == HloOpcode::kReshape ||
- before_reshape->opcode() == HloOpcode::kTranspose) {
- before_reshape = before_reshape->mutable_operand(0);
- }
-
- PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
- PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
- HloInstruction* base;
- HloInstruction* alpha;
- if (!Match(
- before_reshape,
- m::MultiplyAnyOrder(
- m::Op(&base),
- m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
- [&](const HloInstruction* instr) {
- return IsLosslesslyConvertibleTo(instr, alpha_ty);
- }))))) {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
- })) {
- continue;
- }
-
- // Rewrite conv's operand 3 to
- //
- // chain_of_reshapes_and_transposes(before_reshape).
- //
- // and store alpha in the conv's backend config.
- //
- // We're going to do something bad here: We aren't going to check that the
- // chain of reshapes/transposes has one use, so we're potentially
- // duplicating all these instructions (once with alpha and once without).
- //
- // This is justified because
- //
- // - duplicating reshapes/transposes shouldn't be "that bad" -- these
- // instructions can usually be fused, and
- //
- // - *not* fusing alpha can be catastrophic. For s8->s8 convolutions, the
- // side-input must be s8. But the product side_input * alpha is f32, so
- // we can only see that side-input is s8 if we fuse alpha. IOW not fusing
- // alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
- // slower than some extra transposes.
-
- // Recursively clone chain_of_reshapes_and_transposes until we get to
- // `before_reshape`, at which point we skip the multiply(base, alpha) and
- // just return base.
- std::function<HloInstruction*(const HloInstruction*)> clone =
- [&](const HloInstruction* instr) {
- if (instr == before_reshape) {
- return base;
- }
- CHECK(instr->opcode() == HloOpcode::kReshape ||
- instr->opcode() == HloOpcode::kTranspose)
- << "Must be reshape or transpose: " << instr->ToString();
- return comp->AddInstruction(instr->CloneWithNewOperands(
- instr->shape(), {clone(instr->operand(0))}));
- };
- absl::InlinedVector<HloInstruction*, 4> new_operands(
- conv->operands().begin(), conv->operands().end());
- new_operands[3] = clone(side_input);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-
- TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
- config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
- TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
-
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
- changed = true;
- }
- return changed;
-}
-
-absl::StatusOr<bool> FuseElu(HloComputation* comp,
- se::GpuComputeCapability cc) {
- if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
- cc)) {
- return false;
- }
-
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction *gte1, *gte2, *gte3;
- HloInstruction* conv;
- HloInstruction* expm1;
-
- if (!Match(instr,
- m::Select(m::Compare(m::GetTupleElement(>e1, m::Op()),
- m::Broadcast(m::ConstantEffectiveScalar(0)))
- .WithComparisonDirection(ComparisonDirection::kGt)
- .WithOneUse(),
- m::GetTupleElement(
- >e2,
- m::Op(&conv)
- .WithPredicate(IsNonDepthwiseConvCustomCall)
- .WithOneUse(),
- /*tuple_index=*/0)
- // TODO(jlebar): Why only fp16?
- .WithElementType(F16),
- m::Op(&expm1)
- .WithOpcode(HloOpcode::kExpm1)
- .WithOperand(0, m::GetTupleElement(>e3, m::Op()))
- .WithOneUse()))) {
- continue;
- }
-
- // The three GTEs should be the same, and these should be the only uses.
- if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
- continue;
- }
-
- if (!IsSuitableForCudnnRuntimeFusion(conv)) {
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
- if (config.activation_mode() != se::dnn::kNone) {
- continue;
- }
-
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseElu: ", conv->ToString());
- })) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
- config.set_activation_mode(se::dnn::kElu);
- TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
- changed = true;
- }
- return changed;
-}
-
-absl::StatusOr<bool> FuseRelu(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* gte;
- HloInstruction* conv;
- if (!Match(instr,
- m::MaximumAnyOrder(
- m::Broadcast(m::ConstantEffectiveScalar(0)),
- m::GetTupleElement(
- >e, m::Op(&conv)
- .WithPredicate(IsNonDepthwiseConvCustomCall)
- .WithOneUse())
- .WithOneUse()))) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
- if (config.activation_mode() != se::dnn::kNone) {
- continue;
- }
-
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseRelu: ", conv->ToString());
- })) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
- config.set_activation_mode(se::dnn::kRelu);
- TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
- changed = true;
- }
- return changed;
-}
-
-absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
- se::GpuComputeCapability cc) {
- if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
- cc)) {
- return false;
- }
-
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction *gte, *conv;
- if (!Match(
- instr,
- m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
- m::GetTupleElement(
- >e, m::Op(&conv)
- .WithPredicate(IsNonDepthwiseConvCustomCall)
- .WithOneUse())
- // TODO(jlebar): Why only fp16?
- .WithElementType(F16)
- .WithOneUse(),
- m::Broadcast(m::ConstantEffectiveScalar(6))))) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
- if (config.activation_mode() != se::dnn::kNone) {
- continue;
- }
-
- if (!IsSuitableForCudnnRuntimeFusion(conv)) {
- continue;
- }
-
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseRelu6: ", conv->ToString());
- })) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
- config.set_activation_mode(se::dnn::kRelu6);
- TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
- changed = true;
- }
- return changed;
-}
-
-absl::StatusOr<bool> FuseLeakyRelu(HloComputation* comp,
- se::GpuComputeCapability cc) {
- if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
- cc)) {
- return false;
- }
-
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction *gte1, *gte2, *gte3, *conv, *alpha;
- if (!Match(instr,
- m::Select(
- m::Compare(m::GetTupleElement(>e1, m::Op()),
- m::Broadcast(m::ConstantEffectiveScalar(0)))
- .WithComparisonDirection(ComparisonDirection::kGt)
- .WithOneUse(),
- m::GetTupleElement(
- >e2, m::Op(&conv)
- .WithPredicate(IsNonDepthwiseConvCustomCall)
- .WithOneUse())
- // TODO(jlebar): Why only fp16?
- .WithElementType(F16),
- m::Multiply(m::GetTupleElement(>e3, m::Op()),
- m::Broadcast(m::ConstantEffectiveScalar(&alpha)))
- .WithOneUse()))) {
- continue;
- }
-
- // The three GTEs should be the same, and these should be the only uses.
- if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
- if (config.activation_mode() != se::dnn::kNone) {
- continue;
- }
-
- if (!IsSuitableForCudnnRuntimeFusion(conv)) {
- continue;
- }
-
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseLeakyRelu: ", conv->ToString());
- })) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
- config.set_activation_mode(se::dnn::kLeakyRelu);
- TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
- config.set_leakyrelu_alpha(alpha_f64.GetFirstElement<double>());
- TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
- changed = true;
- }
- return changed;
-}
-
-absl::StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* gte = nullptr;
- HloInstruction* conv = nullptr;
-
- auto f32_convertible_to_f16_pat =
- m::Op().WithElementType(F32).WithPredicate(
- IsLosslesslyConvertibleToF16);
- if (!MatchAndLogIfFailed(
- instr, "f16 conv",
- m::Convert(
- m::GetTupleElement(
-                    &gte,
- m::Op(&conv)
- .WithPredicate(IsConvCustomCall)
- .WithOperand(0, f32_convertible_to_f16_pat)
- .WithOperand(1, f32_convertible_to_f16_pat)
- .WithOperandIfPresent(2, f32_convertible_to_f16_pat)
- .WithOperandIfPresent(3, f32_convertible_to_f16_pat),
- 0)
- .WithOneUse())
- .WithElementType(F16),
- VLOG_IS_ON(3),
- m::Op().WithOperand(0, m::GetTupleElement(m::Op().WithPredicate(
- IsConvCustomCall))))) {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvertToF16: ", conv->ToString());
- })) {
- continue;
- }
-
- VLOG(2) << "Matched fp16 conv: " << conv->ToString();
-
- // In fp16 convs, all operands, including `bias`, must be fp16. This is
- // different from int8 convs, where the bias is fp32. See table of
- // supported datatypes at
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- absl::InlinedVector<HloInstruction*, 4> new_operands;
- for (HloInstruction* operand : conv->operands()) {
- new_operands.push_back(
- MakeConvertToHlo(operand, F16, &operand->metadata()));
- }
-
- Shape new_shape = conv->shape();
- new_shape.mutable_tuple_shapes(0)->set_element_type(F16);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(new_shape, new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
- changed = true;
- }
- return changed;
-}
-
-absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,
- se::GpuComputeCapability cc) {
- if (IsROCm(cc)) return false;
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* gte = nullptr;
- HloInstruction* conv = nullptr;
-
- auto conv_pattern =
- m::Op(&conv)
- .WithPredicate(IsConvCustomCall)
- .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
- .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));
-
- PrimitiveType conv_output_ty;
- if (MatchAndLogIfFailed(
- instr, "s8->s8 conv",
- m::Convert(m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(-128)),
- m::GetTupleElement(
-                                  &gte,
- conv_pattern.WithOperandIfPresent(
- 3, m::Op().WithPredicate(
- IsLosslesslyConvertibleToS8)),
- 0)
- .WithOneUse(),
- m::Broadcast(m::ConstantEffectiveScalar(127))))
- .WithElementType(S8),
- VLOG_IS_ON(3),
- m::Convert(m::Clamp(m::Op(),
- m::GetTupleElement(
- m::Op().WithPredicate(IsConvCustomCall)),
- m::Op()))
- .WithElementType(S8))) {
- conv_output_ty = S8;
- } else if (MatchAndLogIfFailed(
- instr, "s8->f32 conv",
-                       m::GetTupleElement(&gte,
- conv_pattern.WithOperandIfPresent(
- 3, m::Op().WithElementType(F32)),
- 0)
- .WithElementType(F32),
- VLOG_IS_ON(3),
- m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
- .WithElementType(F32))) {
- conv_output_ty = F32;
- } else {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvertToS8: ", conv->ToString());
- })) {
- continue;
- }
-
- absl::InlinedVector<HloInstruction*, 4> new_operands(
- conv->operands().begin(), conv->operands().end());
- new_operands[0] =
- MakeConvertToHlo(new_operands[0], S8, &new_operands[0]->metadata());
- new_operands[1] =
- MakeConvertToHlo(new_operands[1], S8, &new_operands[1]->metadata());
- // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn. See
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- if (new_operands.size() >= 4) {
- // side-input always matches conv output type. We checked in the patterns
- // above that it's losslessly-convertible to this type.
- new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty,
- &new_operands[3]->metadata());
- }
-
- Shape new_shape = conv->shape();
- new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(new_shape, new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
- changed = true;
- }
- return changed;
-}
-
-absl::Status CheckNoIllegalIntegerConvs(HloComputation* comp) {
- auto is_integral_not_s8 = [](const Shape& s) {
- return primitive_util::IsIntegralType(s.element_type()) &&
- s.element_type() != S8;
- };
-
- std::vector<HloInstruction*> bad_convs;
- for (HloInstruction* instr : comp->instructions()) {
- if (!IsConvCustomCall(instr)) {
- continue;
- }
- if (is_integral_not_s8(instr->shape().tuple_shapes(0)) ||
- is_integral_not_s8(instr->operand(0)->shape()) ||
- is_integral_not_s8(instr->operand(1)->shape()) ||
- (instr->operand_count() >= 4 &&
- is_integral_not_s8(instr->operand(3)->shape()))) {
- bad_convs.push_back(instr);
- }
- }
-
- if (bad_convs.empty()) {
- return absl::OkStatus();
- }
-
- return Unimplemented(
- R"(
-Can't lower one or more integer convolutions to idioms supported by CuDNN.
-
-CuDNN integer convolutions must have:
-
- - s8 input and filter,
- - f32 bias (if present),
- - s8 or f32 output, and
- - s8 side_input (if present) if output is s8.
-
-For each of the unsupported convs below, we weren't able to lower one of the
-operands or the output to the appropriate type.
-
-See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:
-
-https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
-https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
-
-Unsupported convs:
-%s
-
-******* Full HLO module *******
-%s
-)",
- absl::StrJoin(bad_convs, "\n",
- [](std::string* out, HloInstruction* instr) {
- absl::StrAppend(out, " - ", instr->ToString());
- }),
- comp->parent()->ToString());
-}
-
-void VlogStats(HloModule* module) {
- if (!VLOG_IS_ON(1)) {
- return;
- }
-
- VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
- absl::flat_hash_map<std::string, int> stats;
- for (HloComputation* comp : module->MakeNonfusionComputations()) {
- for (HloInstruction* instr : comp->instructions()) {
- if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
- continue;
- }
-
- VLOG(3) << instr->ToString();
-
- if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
- ++stats["01 non-fused forward convs"];
- } else if (instr->custom_call_target() ==
- kCudnnConvBiasActivationForwardCallTarget) {
- ++stats["02 fused forward convs"];
- }
-
- PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
- PrimitiveType conv_out_ty = instr->shape().tuple_shapes(0).element_type();
- if (conv_in_ty == F32) {
- ++stats["10 f32 convs"];
- } else if (conv_in_ty == F16) {
- ++stats["11 f16 convs"];
- } else if (conv_in_ty == S8) {
- if (conv_out_ty == S8) {
- ++stats["12 s8->s8 convs"];
- } else if (conv_out_ty == F32) {
- ++stats["13 s8->f32 convs"];
- } else {
- LOG(ERROR) << "Unexpected conv: " << instr->ToString();
- }
- }
-
- if (instr->operand_count() > 2) {
- ++stats["20 convs with bias"];
- if (Match(instr->operand(2),
- m::Broadcast(m::ConstantEffectiveScalar(0)))) {
- ++stats["21 convs with 0 bias"];
- }
- }
- if (instr->operand_count() > 3) {
- ++stats["22 convs with side-input"];
- }
-
- auto gpu_config = instr->backend_config<GpuBackendConfig>();
- if (!gpu_config.ok()) {
- LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString();
- continue;
- }
- const CudnnConvBackendConfig& config =
- gpu_config->cudnn_conv_backend_config();
- if (config.conv_result_scale() != 1) {
- ++stats["30 convs with result scale"];
- }
- if (config.side_input_scale() != 0 && config.side_input_scale() != 1) {
- ++stats["31 convs with side-input scale"];
- }
- ++stats[absl::StrCat(
- "32 convs with activation mode ",
- se::dnn::ActivationMode_Name(config.activation_mode()))];
- }
- }
-
- std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
- stats.end());
- absl::c_sort(stats_sorted);
- for (const auto& kv : stats_sorted) {
- VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
- absl::string_view(kv.first).substr(3));
- }
-}
-
-} // namespace
-
-absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool any_changed = false;
-
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- bool changed = false;
- // Rewrite FP8 convolutions and supported adjacent pointwise ops into a
- // ForwardGraph Custom Call.
- if (!IsROCm(compute_capability_)) {
- auto cc = std::get<se::CudaComputeCapability>(compute_capability_);
- TF_ASSIGN_OR_RETURN(
- changed, F8GraphConv(comp, cc, dnn_version_, toolkit_version_));
- if (changed) {
- return changed;
- }
- }
- // Fuse "inside out" starting with the operations closest to the conv.
- TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp));
- any_changed |= changed;
-
- TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
- any_changed |= changed;
-
- // s8 convs' bias and side-input appear before conversion to s8.
- //
- // Run FuseBiasOrSideInput twice, so we get both the bias and the side
- // input, if both are present.
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
- any_changed |= changed;
-
- // Relu might appear before or after convert-to-f16/s8, so we check in both
- // cases.
- TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
- any_changed |= changed;
-
- TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
- any_changed |= changed;
-
- TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_));
- any_changed |= changed;
-
- // f16 convs' bias+side-input can appear before or after conversion to f16.
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
- any_changed |= changed;
-
- TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
- any_changed |= changed;
-
- // Check that we don't have any convs outputting integer types other than
- // s8 - cudnn does not support these. They should have been transformed to
- // int8->int8 or int8->float above.
- TF_RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp));
- }
-
- VlogStats(module);
-
- return any_changed;
-}
-} // namespace gpu
-} // namespace xla
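For orientation only, and not part of this change: the sketch below shows one way to drive the Run() method defined above directly, using the constructor and Run() signature from the deleted header that follows in this diff. The compute capability, cuDNN version, and toolkit version values are placeholders, and `module` is assumed to be an already-parsed HloModule.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/dnn.h"

namespace se = stream_executor;

// Illustrative sketch: run CudnnFusedConvRewriter over `module` for an
// Ampere-class GPU. Returns true if any convolution custom-call was rewritten.
absl::StatusOr<bool> RewriteFusedConvsForExample(xla::HloModule* module) {
  xla::gpu::CudnnFusedConvRewriter rewriter(
      se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
      se::dnn::VersionInfo(8, 9, 0),  // placeholder cuDNN version
      /*toolkit_version=*/12030);     // placeholder CUDA toolkit version
  // Empty execution-thread filter: rewrite every non-fusion computation.
  return rewriter.Run(module, /*execution_threads=*/{});
}

In the real compiler pipeline this pass runs after GpuConvRewriter, as the header comment notes; invoking it standalone like this is mainly useful for experiments and tests.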
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h
deleted file mode 100644
index 906a67a..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h
+++ /dev/null
@@ -1,135 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites custom-calls targeting cudnnConvolutionForward to
-// cudnnConvolutionBiasActivationForward by fusing operations following forward
-// convolution. This transform must run after GpuConvRewriter.
-//
-// Semantics of underlying cudnn ops:
-//
-// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#cudnnconvolutionforward
-// https://docs.nvidia.com/deeplearning/cudnn/latest/developer/misc.html#scaling-parameters
-//
-// ## Floating-point convs
-//
-// A "complete" fused floating-point conv has the form
-//
-// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
-//
-// which we fuse to
-//
-// cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
-//
-// You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
-// get a fused convolution. alpha1/2 must be broadcasts of scalar constants.
-//
-// f16 convs accumulate in f32. We represent this in HLO as an f32 convolution
-// whose inputs can be converted to f16 without loss of precision and whose
-// output is immediately converted to f16. A fused f16 conv must follow one of
-// the following idioms.
-//
-// 1. convert_f16(conv_f32(x_f32, w_f32)) +
-// side_input_f16 + broadcast(bias_f16)
-//
-// 2. convert_f16(conv_f32(x_f32, w_f32) +
-// side_input_f32 + broadcast(bias_f32))
-//
-// (These are not strictly mathematically equivalent, but cudnn doesn't tell us
-// which one it does, and we deem them "close enough".)
-//
-// The foo_f32 HLOs must all be losslessly-convertible to f16. Some valid
-// examples:
-//
-// - foo_f32 = convert_f32(foo_f16)
-// - foo_f32 = an f32 constant whose values all fit within f16
-// - foo_f32 = broadcast/transpose/reshape(one of the above)
-//
-// If you have a relu, it can appear before or after the convert_f16.
-//
-// Note that here `bias` must be losslessly-convertible to f16; this is
-// different than for s8 convolutions, where bias is f32.
-//
-// ## Integer convs
-//
-// In pure HLO, a "complete" integer conv is spelled as one of the following
-// `result`s.
-//
-// base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
-// alpha2_f32 * side_input +
-// bias_f32
-//
-// result_f32 = max(base, 0)
-// result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
-// result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
-//
-// The foo_s32 HLOs must be losslessly-convertible to s8.  In the `result_s8`
-// case, side_input should be an f32 HLO that's losslessly-convertible to s8;
-// otherwise, it should be losslessly-convertible to f32.
-//
-// In the `result_s8` case where there's no bias, side-input, or alpha1, you can
-// skip the convert_f32 on conv.
-//
-// If you have an integer convolution that doesn't fit one of these idioms, this
-// pass returns an error -- cudnn will not be able to run it.
-class CudnnFusedConvRewriter : public HloModulePass {
- public:
- CudnnFusedConvRewriter(se::CudaComputeCapability cc,
- se::dnn::VersionInfo dnn_version,
- int32_t toolkit_version)
- : compute_capability_(cc),
- dnn_version_(dnn_version),
- toolkit_version_(toolkit_version) {}
- CudnnFusedConvRewriter(se::RocmComputeCapability cc,
- se::dnn::VersionInfo dnn_version,
- int32_t toolkit_version)
- : compute_capability_(cc),
- dnn_version_(dnn_version),
- toolkit_version_(toolkit_version) {}
-
- absl::string_view name() const override {
- return "cudnn-fused-convolution-rewriter";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const se::GpuComputeCapability compute_capability_;
- const se::dnn::VersionInfo dnn_version_;
- const int32_t toolkit_version_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
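As a concrete illustration of the integer-conv idioms spelled out in the comment above, here is a hedged sketch of the `result_s8_option2` form, with alpha1, alpha2, and side_input omitted (which the pass allows). It is not part of the original change; shapes and names are illustrative, written as an HLO string in the same style as the tests later in this diff.

// Illustrative only: the `result_s8_option2` idiom from the comment above.
// The fused f16 idiom is analogous, with a convert-to-f16 in place of the
// clamp + convert-to-s8 at the end.
constexpr const char* kS8FusedConvIdiomExample = R"(
  HloModule S8IdiomExample

  ENTRY S8IdiomExample {
    input_s8  = s8[1,3,3,64] parameter(0)
    filter_s8 = s8[3,3,64,64] parameter(1)
    bias      = f32[64] parameter(2)

    input_s32  = s32[1,3,3,64] convert(input_s8)
    filter_s32 = s32[3,3,64,64] convert(filter_s8)
    conv = s32[1,3,3,64] convolution(input_s32, filter_s32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

    conv_f32 = f32[1,3,3,64] convert(conv)
    broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
    base = f32[1,3,3,64] add(conv_f32, broadcasted_bias)

    zeros  = f32[1,3,3,64] broadcast(f32[] constant(0)), dimensions={}
    lowers = f32[1,3,3,64] broadcast(f32[] constant(-128)), dimensions={}
    uppers = f32[1,3,3,64] broadcast(f32[] constant(127)), dimensions={}
    relu = f32[1,3,3,64] maximum(zeros, base)
    clamped = f32[1,3,3,64] clamp(lowers, relu, uppers)
    ROOT result_s8 = s8[1,3,3,64] convert(clamped)
  }
)";

Per the contract described in the comment above, a graph of this shape is what the rewriter folds into a single cudnnConvolutionBiasActivationForward custom call with an s8 output.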
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
deleted file mode 100644
index 0a58ecf..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
+++ /dev/null
@@ -1,3170 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-
-#include <array>
-#include <memory>
-#include <string>
-#include <string_view>
-#include <thread> // NOLINT
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#elif TENSORFLOW_USE_ROCM
-#include "rocm/rocm_config.h"
-#endif // GOOGLE_CUDA
-
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/convert_mover.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/hlo_constant_folding.h"
-#include "xla/service/hlo_pass_fix.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/reshape_mover.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// TODO(b/210165681): The tests in this file are fragile to HLO op names.
-
-namespace m = match;
-
-using ::testing::HasSubstr;
-using ::testing::Not;
-
-// TODO: Use constexpr vector once XLA is compiled with C++20.
-const auto* kf16f32f64 = new std::vector<std::string>({"f16", "f32", "f64"});
-const auto* kf16f32 = new std::vector<std::string>({"f16", "f32"});
-
-class CudnnFusedConvRewriterHloTest : public HloTestBase {
- public:
- bool IsCuda() {
- return std::holds_alternative<se::CudaComputeCapability>(
- backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .gpu_compute_capability());
- }
- se::CudaComputeCapability GetCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
- stream_executor::dnn::VersionInfo GetDnnVersion() {
- return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
- }
-
- int32_t GetToolkitVersion() const {
-#if GOOGLE_CUDA
- return CUDA_VERSION;
-#elif TENSORFLOW_USE_ROCM
- return TF_ROCM_VERSION;
-#endif
- return 0;
- }
-
- CudnnFusedConvRewriterHloTest()
- : HloTestBase(/*verifier_layout_sensitive=*/false,
- /*allow_mixed_precision_in_hlo_verifier=*/false,
- /*instruction_can_change_layout_func=*/{}) {}
-};
-
-class CudnnFusedConvRewriterTest : public GpuCodegenTest {
- public:
- bool IsCuda() {
- return std::holds_alternative<se::CudaComputeCapability>(
- backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .gpu_compute_capability());
- }
- se::CudaComputeCapability GetCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
- stream_executor::dnn::VersionInfo GetDnnVersion() {
- return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
- }
-
- int32_t GetToolkitVersion() const {
-#if GOOGLE_CUDA
- return CUDA_VERSION;
-#elif TENSORFLOW_USE_ROCM
- return TF_ROCM_VERSION;
-#endif
- return 0;
- }
-
- protected:
- std::string GetOptimizedHlo(absl::string_view hlo_string) {
- // cudnn_vectorize_convolutions transforms convolutions, making it hard to
-    // match them here in this test.  What's worse, the transforms it does
-    // depend on the GPU that's available!  So just disable them for this
- // function that gets the optimized HLO. When we actually run the module
- // we'll still have this pass enabled.
- HloModuleConfig config = GetModuleConfigForTest();
- DebugOptions debug_opts = config.debug_options();
- debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- config.set_debug_options(debug_opts);
-
- auto result = backend().compiler()->RunHloPasses(
- ParseAndReturnVerifiedModule(hlo_string, config).value(),
- backend().default_stream_executor(), backend().memory_allocator());
- if (!result.status().ok()) {
- TF_EXPECT_OK(result.status())
- << "HLO compilation failed: " << result.status();
- return "";
- }
- HloPrintOptions print_opts;
- print_opts.set_print_operand_shape(false);
- return (*result)->ToString(print_opts);
- }
-
- void TestMatchWithAllTypes(absl::string_view hlo_string) {
- for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
- const std::string hlo_with_new_type =
- absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
- std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
- EXPECT_THAT(optimized_hlo_string,
- Not(HasSubstr(kCudnnConvForwardCallTarget)))
- << optimized_hlo_string;
- EXPECT_THAT(optimized_hlo_string,
- HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
-
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_with_new_type));
- DebugOptions debug_opts = module->config().debug_options();
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- module->mutable_config().set_debug_options(debug_opts);
- EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01}))
- << optimized_hlo_string;
- }
- }
-
- void TestClamp(absl::string_view pre_hlo_string,
- absl::string_view post_hlo_string) {
- std::string alpha_conv_scalar, alpha_side_input_scalar;
- std::string elementwise_type;
-
- std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
- EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
- EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
- EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01}))
- << pre_hlo_string;
-
- absl::StatusOr<bool> filecheck_result =
- RunFileCheck(optimized_hlo_string, post_hlo_string);
- ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
- EXPECT_TRUE(*filecheck_result);
- }
-
- void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
- for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
- const std::string hlo_with_new_type =
- absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
- std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
- SCOPED_TRACE(optimized_hlo_string);
- EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
- EXPECT_THAT(optimized_hlo_string,
- Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
- }
- }
-
- void TestF8(std::string pre_hlo_string, std::string custom_call_string,
- std::string serialized_graph_string) {
- if (!IsCuda()) return;
- if (GetCudaComputeCapability().IsAtLeast(
- se::CudaComputeCapability::HOPPER)) {
- // On Hopper and newer architectures, test numerical correctness and
- // verify the HLO of the Custom Call with operand and return layouts and
- // the serialized graph based on the full compiler pipeline.
- std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
- EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
- EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
- EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.15, 0.15}))
- << pre_hlo_string;
-
- absl::StatusOr<bool> filecheck_result =
- RunFileCheck(optimized_hlo_string, custom_call_string);
- ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
- EXPECT_TRUE(*filecheck_result);
-
- filecheck_result =
- RunFileCheck(optimized_hlo_string, serialized_graph_string);
- ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
- EXPECT_TRUE(*filecheck_result);
- } else {
- // On older architectures, disregard layout information and only verify
- // the basic configuration of the convolution Custom Call using the number
- // of operands and the window_size and serialized graph attributes based
- // on the GpuConvRewriter and CudnnFusedConvRewriter passes.
- std::string::size_type p0 = custom_call_string.find(':');
- std::string::size_type p1 = custom_call_string.find("custom-call");
- custom_call_string.erase(p0 + 1, p1 - p0 - 2);
- p0 = custom_call_string.find(", dim_labels");
- custom_call_string.erase(p0);
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(pre_hlo_string));
- TF_ASSERT_OK_AND_ASSIGN(
- bool changed, RunHloPass(GpuConvRewriter(GetCudaComputeCapability()),
- module.get()));
- EXPECT_TRUE(changed);
- RunAndFilecheckHloRewrite(
- module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
- CudnnFusedConvRewriter(
- se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
- GetDnnVersion(), GetToolkitVersion()),
- custom_call_string);
- RunAndFilecheckHloRewrite(
- module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
- CudnnFusedConvRewriter(
- se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
- GetDnnVersion(), GetToolkitVersion()),
- serialized_graph_string);
- }
- }
-
- void TestF8Parameterized(std::string template_pre_hlo_string,
- std::string template_custom_call_string,
- std::string template_serialized_graph_string) {
- std::array<absl::string_view, 2> types = {"f8e4m3fn", "f8e5m2"};
- std::array<absl::string_view, 2> clamp_lower = {"-448.", "-57344."};
- std::array<absl::string_view, 2> clamp_upper = {"448.", "57344."};
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- for (int i = 0; i < 2; ++i) {
- replacements["<<InputType>>"] = types[i];
- for (int j = 0; j < 2; ++j) {
- replacements["<<FilterType>>"] = types[j];
- for (int k = 0; k < 2; ++k) {
- replacements["<<OutputType>>"] = types[k];
- replacements["<<ClampLower>>"] = clamp_lower[k];
- replacements["<<ClampUpper>>"] = clamp_upper[k];
- TestF8(absl::StrReplaceAll(template_pre_hlo_string, replacements),
- absl::StrReplaceAll(template_custom_call_string, replacements),
- absl::StrReplaceAll(template_serialized_graph_string,
- replacements));
- }
- }
- }
- }
-};
-
-#if GOOGLE_CUDA
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900)
-#define MAYBE_SKIP_TEST(CAUSE) \
- do { \
- if (absl::string_view(CAUSE) == "F8") \
- GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; \
- } while (0)
-#else
-#define MAYBE_SKIP_TEST(CAUSE)
-#endif
-#else
-#define MAYBE_SKIP_TEST(CAUSE) \
- do { \
- GTEST_SKIP() << "ROCm does not support " CAUSE " fusion"; \
- } while (0)
-#endif
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) {
- // max(0, conv(x, w));
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
-
- input = TYPE[1,17,9,9] parameter(0)
- filter = TYPE[3,3,17,32] parameter(1)
-
- conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseReluWithDepthwiseConv) {
- // max(0, conv(x, w));
- TestNotMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
-
- input = TYPE[1,17,9,9] parameter(0)
- filter = TYPE[3,3,1,17] parameter(1)
-
- conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
- ROOT relu = TYPE[1,17,9,9] maximum(zeros, conv)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBias) {
- // max(0, conv(x, w) + bias);
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- bias = TYPE[64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, Test3D) {
- // max(0, conv(x, w) + bias);
- std::string body = R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,5,7,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,5,7,64] parameter(0)
- filter = TYPE[3,3,3,64,64] parameter(1)
- bias = TYPE[64] parameter(2)
-
- conv = TYPE[1,3,5,7,64] convolution(input, filter), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f, feature_group_count=1
- broadcasted_bias = TYPE[1,3,5,7,64] broadcast(bias), dimensions={4}
- add1 = TYPE[1,3,5,7,64] add(conv, broadcasted_bias)
- )";
-
- std::string relu = R"(
- ROOT relu = TYPE[1,3,5,7,64] maximum(zeros, add1)
- })";
-
- std::string elu = R"(
- cmp = pred[1,3,5,7,64] compare(add1, zeros), direction=GT
- expm1 = TYPE[1,3,5,7,64] exponential-minus-one(add1)
- ROOT elu = TYPE[1,3,5,7,64] select(cmp, add1, expm1)
- })";
-
- TestMatchWithAllTypes(body + relu);
- if (!IsCuda()) TestMatchWithAllTypes(body + elu);
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBiasMultiCall) {
- // max(0, conv(x, w) + bias);
- std::string code = R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,<<<format>>>,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,<<<format>>>,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- bias = TYPE[64] parameter(2)
-
- conv = TYPE[1,<<<format>>>,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- broadcasted_bias = TYPE[1,<<<format>>>,64] broadcast(bias), dimensions={3}
- add1 = TYPE[1,<<<format>>>,64] add(conv, broadcasted_bias)
- ROOT relu = TYPE[1,<<<format>>>,64] maximum(zeros, add1)
- })";
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<<format>>>"] = "3,3";
- TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
- replacements["<<<format>>>"] = "5,5";
- TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
- replacements["<<<format>>>"] = "3,3";
- TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBiasNoRelu) {
- // conv(x, w) + bias;
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- bias = TYPE[64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- ROOT add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseBiasWithDepthwiseConv) {
- // conv(x, w) + bias;
- TestNotMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,1,64] parameter(1)
- bias = TYPE[64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestElu) {
- // sum = conv(x, w) + bias
- // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- bias = TYPE[64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
- expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
- ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) {
- // sum = conv(x, w) + bias
- // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
- TestNotMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,1,64] parameter(1)
- bias = TYPE[64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
- expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
- ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestRelu6) {
- if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
- se::CudaComputeCapability::AMPERE)) {
-    GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended on "
-                    "Nvidia Ampere+ GPUs.";
- }
- // sum = conv(x, w) + bias
- // clamp(0, sum, 6);
- TestMatchWithAllTypes(R"(
- HloModule Test
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
- six = TYPE[] constant(6)
- sixes = TYPE[1,3,3,64] broadcast(six), dimensions={}
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- bias = TYPE[64] parameter(2)
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- ROOT relu6 = TYPE[1,3,3,64] clamp(zeros, sum, sixes)
- })");
-}
-
-// At time of writing, cudnn runtime fusion cannot handle f16 convs with an odd
-// number of input/output channels. Check that we don't try to run this conv
-// with runtime fusion (or, if we do, that it works!).
-TEST_F(CudnnFusedConvRewriterTest, TestRelu6OddChannels) {
- if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
- se::CudaComputeCapability::AMPERE)) {
-    GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended on "
-                    "Nvidia Ampere+ GPUs.";
- }
- TestMatchWithAllTypes(R"(
- HloModule Test
- ENTRY Test {
- zeros = TYPE[1,384,1024,32] broadcast(TYPE[] constant(0)), dimensions={}
- sixes = TYPE[1,384,1024,32] broadcast(TYPE[] constant(6)), dimensions={}
- input = TYPE[1,769,2049,3] parameter(0)
- filter = TYPE[32,3,3,3] parameter(1)
- bias = TYPE[32] parameter(2)
- conv = TYPE[1,384,1024,32] convolution(input, filter), window={size=3x3 stride=2x2}, dim_labels=b01f_o01i->b01f
- broadcasted_bias = TYPE[1,384,1024,32] broadcast(bias), dimensions={3}
- sum = add(conv, broadcasted_bias)
- ROOT relu6 = clamp(zeros, sum, sixes)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) {
- if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
- se::CudaComputeCapability::AMPERE)) {
- GTEST_SKIP()
-        << "Conv-Bias-LeakyRelu fusion is supported and recommended on "
-           "Nvidia Ampere+ GPUs.";
- }
- // sum = conv(x, w) + bias
- // select(compare(sum, 0, GT), sum, multiply(sum, alpha));
- TestMatchWithAllTypes(R"(
- HloModule Test
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
- alpha = TYPE[] constant(0.2)
- alphas = TYPE[1,3,3,64] broadcast(alpha), dimensions={}
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- bias = TYPE[64] parameter(2)
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
- mul = TYPE[1,3,3,64] multiply(sum, alphas)
- ROOT elu = TYPE[1,3,3,64] select(cmp, sum, mul)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) {
- // max(0, conv(x, w) + side_input);
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- side_input = TYPE[1,3,3,64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- add1 = TYPE[1,3,3,64] add(conv, side_input)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseSideInputWithDepthwiseConv) {
- // max(0, conv(x, w) + side_input);
- TestNotMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,1,64] parameter(1)
- side_input = TYPE[1,3,3,64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
- add1 = TYPE[1,3,3,64] add(conv, side_input)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) {
- // max(0, conv(x, w) + side_input + bias);
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- side_input = TYPE[1,3,3,64] parameter(2)
- bias = TYPE[64] parameter(3)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
- add2 = TYPE[1,3,3,64] add(add1, side_input)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) {
- // max(0, 0.999994934 * conv(x, w));
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
- alpha_conv_scalar = TYPE[] constant(0.999994934)
-
- input = TYPE[1,17,9,9] parameter(0)
- filter = TYPE[3,3,17,32] parameter(1)
-
- conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
- scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
- ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseScaledDepthwiseConv) {
- // max(0, 0.999994934 * conv(x, w));
- TestNotMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
- alpha_conv_scalar = TYPE[] constant(0.999994934)
-
- input = TYPE[1,17,9,9] parameter(0)
- filter = TYPE[3,3,1,17] parameter(1)
-
- conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
- alpha_conv = TYPE[1,17,9,9] broadcast(alpha_conv_scalar), dimensions={}
- scaled_conv = TYPE[1,17,9,9] multiply(conv, alpha_conv)
- ROOT relu = TYPE[1,17,9,9] maximum(zeros, scaled_conv)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) {
- EXPECT_TRUE(RunAndCompare(R"(
- HloModule Test
-
- ENTRY Test {
- zero = f32[] constant(inf)
- zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
- alpha_conv_scalar = f32[] constant(0.999994934)
-
- input = f32[1,17,9,9] parameter(0)
- filter = f32[3,3,17,32] parameter(1)
-
- conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
- scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv)
- ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv)
- })",
- ErrorSpec{0.01}));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvAndScaledSideInput) {
- // max(0, conv(x, w) + 0.899994934 * side_input);
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
- alpha_side_input_scalar = TYPE[] constant(0.899994934)
- alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- side_input = TYPE[1,3,3,64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
- add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseDepthwiseConvWithScaledSideInput) {
- // max(0, conv(x, w) + 0.899994934 * side_input);
- TestNotMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
- alpha_side_input_scalar = TYPE[] constant(0.899994934)
- alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,1,64] parameter(1)
- side_input = TYPE[1,3,3,64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
- scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
- add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) {
- // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
- alpha_conv_scalar = TYPE[] constant(0.999994934)
- alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
- alpha_side_input_scalar = TYPE[] constant(0.899994934)
- alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- side_input = TYPE[1,3,3,64] parameter(2)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
- scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
- add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) {
- // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
- TestMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- zero = TYPE[] constant(0)
- zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
- alpha_conv_scalar = TYPE[] constant(0.999994934)
- alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
- alpha_side_input_scalar = TYPE[] constant(0.899994934)
- alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
- input = TYPE[1,3,3,64] parameter(0)
- filter = TYPE[3,3,64,64] parameter(1)
- side_input = TYPE[1,3,3,64] parameter(2)
- bias = TYPE[64] parameter(3)
-
- conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
- scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
- scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
- broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
- add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
- ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) {
- // max(0.1, conv(x, w)) shouldn't match.
- TestNotMatchWithAllTypes(R"(
- HloModule Test
-
- ENTRY Test {
- point_one = TYPE[] constant(0.1)
- point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
-
- input = TYPE[1,17,9,9] parameter(0)
- filter = TYPE[3,3,17,32] parameter(1)
-
- conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
- })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
- const char* kHloString = R"(
- HloModule Test
-
- ENTRY Test {
- zero = f32[] constant(0)
- zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
-
- input = f32[1,17,9,9] parameter(0)
- filter = f32[3,3,17,32] parameter(1)
-
- conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
- ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
- })";
-
- const std::string optimized_hlo_string =
- backend()
- .compiler()
- ->RunHloPasses(
- ParseAndReturnVerifiedModule(kHloString, GetModuleConfigForTest())
- .value(),
- backend().default_stream_executor(), backend().memory_allocator())
- .value()
- ->ToString();
- EXPECT_THAT(optimized_hlo_string,
- ::testing::ContainsRegex(
- R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
-  // The convolution below would crash if feature_group_count is not preserved.
- const char* kHloString = R"(
- HloModule jaxpr_computation__6.19
-
- primitive_computation__1.4 {
- parameter.5 = f32[] parameter(0)
- parameter.6 = f32[] parameter(1)
- ROOT add.7 = f32[] add(parameter.5, parameter.6)
- }
-
- ENTRY jaxpr_computation__7.8 {
- parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1)
- parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0)
- convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53
- constant.13 = f32[] constant(0)
- broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={}
- maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14)
- ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4
- }
- )";
- EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01}));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- ROOT conv_a = f8e4m3fn[1,16,6,6] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f8e4m3fn]conv();"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_f32 = f32[1,128,6,6] convert(input)
- filter_f32 = f32[3,3,128,16] convert(filter)
- z_scale = f32[] parameter(2)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
- ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE_UID:[0-9]+]]:[f8e4m3fn]scale([[CONV_UID]]);"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledOutputF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_f32 = f32[1,128,6,6] convert(input)
- filter_f32 = f32[3,3,128,16] convert(filter)
- z_scale = f32[] parameter(2)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- conv_a_scaled = f32[1,16,6,6] divide(conv_a, z_scale_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
- ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]invscale([[CONV_UID]]);"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8Parameterized) {
- MAYBE_SKIP_TEST("F8");
- TestF8Parameterized(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = <<InputType>>[1,128,6,6] parameter(0)
- filter = <<FilterType>>[3,3,128,16] parameter(1)
- input_scale = f32[] parameter(2)
- input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
- filter_scale = f32[] parameter(3)
- filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
- input_f32 = f32[1,128,6,6] convert(input)
- input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
- filter_f32 = f32[3,3,128,16] convert(filter)
- filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
- z_scale = f32[] parameter(4)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
- c1 = f32[] constant(<<ClampLower>>)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(<<ClampUpper>>)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
- ROOT conv_f8 = <<OutputType>>[1,16,6,6] convert(conv_a_clamped)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (<<OutputType>>[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[<<OutputType>>]scale([[SCALE1_UID]]);"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_scale = f32[] parameter(2)
- input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
- filter_scale = f32[] parameter(3)
- filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
- input_f32 = f32[1,128,6,6] convert(input)
- input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
- filter_f32 = f32[3,3,128,16] convert(filter)
- filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
- bias = f32[1,16,6,6] parameter(4)
- z_scale = f32[] parameter(5)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- conv_a_bias = f32[1,16,6,6] add(conv_a, bias)
- conv_a_scaled = f32[1,16,6,6] multiply(conv_a_bias, z_scale_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
- ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]], /*index=5*/[[OPERAND5:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[ADD_UID:[0-9]+]]:[f32]add([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[ADD_UID]]);"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_f32 = f32[1,128,6,6] convert(input)
- filter_f32 = f32[3,3,128,16] convert(filter)
- z_scale = f32[] parameter(2)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- c = f32[] constant(0)
- c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
- relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
- ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[RELU_UID:[0-9]+]]:[f32]relu([[CONV_UID]]);[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);"
- )");
-}
-
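-// In addition to the scaled and clamped F8 output, the module below computes
-// the absolute maximum (amax) of the conv result via abs + reduce(maximum).
-// The rewriter should return it as a second f32[] output of the custom call
-// and record an amax node in the serialized graph.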
-TEST_F(CudnnFusedConvRewriterTest, TestConvAmaxF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] maximum(a, b)
- }
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_scale = f32[] parameter(2)
- input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
- filter_scale = f32[] parameter(3)
- filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
- input_f32 = f32[1,128,6,6] convert(input)
- input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
- filter_f32 = f32[3,3,128,16] convert(filter)
- filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
- conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- z_scale = f32[] parameter(4)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
- conv_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
- abs_conv_a = f32[1,16,6,6] abs(conv_a)
- c0 = f32[] constant(-inf)
- amax = f32[] reduce(abs_conv_a, c0), dimensions={0,1,2,3}, to_apply=apply
- ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(conv_a_clamped_f8, amax)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[SCALE1_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[SCALE1_UID]]);"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvReluAmaxF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] maximum(a, b)
- }
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_scale = f32[] parameter(2)
- input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
- filter_scale = f32[] parameter(3)
- filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
- input_f32 = f32[1,128,6,6] convert(input)
- input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
- filter_f32 = f32[3,3,128,16] convert(filter)
- filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
- c = f32[] constant(0)
- c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
- z_scale = f32[] parameter(4)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
- relu_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
- abs_relu_a = f32[1,16,6,6] abs(relu_a)
- c0 = f32[] constant(-inf)
- amax = f32[] reduce(abs_relu_a, c0), dimensions={0,1,2,3}, to_apply=apply
- ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(relu_a_clamped_f8, amax)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[RELU_UID:[0-9]+]]:[f32]relu([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[RELU_UID]]);"
- )");
-}
-
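-// The conv result below feeds two independent scale/clamp/convert chains.
-// Since the F8 epilogue can only be fused for a single user, the custom call
-// is expected to stay in f32 and the serialized graph to contain only the
-// convolution itself.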
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputMultipleUsersF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_f32 = f32[1,128,6,6] convert(input)
- filter_f32 = f32[3,3,128,16] convert(filter)
- z_scale0 = f32[] parameter(2)
- z_scale0_bcast = f32[1,16,6,6] broadcast(z_scale0), dimensions={}
- z_scale1 = f32[] parameter(3)
- z_scale1_bcast = f32[1,16,6,6] broadcast(z_scale1), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- conv_a_scaled0 = f32[1,16,6,6] multiply(conv_a, z_scale0_bcast)
- conv_a_scaled1 = f32[1,16,6,6] multiply(conv_a, z_scale1_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- conv_a_clamped0 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled0, c2_bcast)
- conv_a_clamped1 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled1, c2_bcast)
- conv_a_convert0 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped0)
- conv_a_convert1 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped1)
- ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f8e4m3fn[1,16,6,6]) tuple(conv_a_convert0, conv_a_convert1)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
- )");
-}
-
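-// Here the conv result also feeds a cosine, which is not a supported user of
-// the F8 fusion, so the custom call is expected to remain f32 with only the
-// convolution in the serialized graph.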
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputUnsupportedUserF8) {
- MAYBE_SKIP_TEST("F8");
- TestF8(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- input = f8e4m3fn[1,128,6,6] parameter(0)
- filter = f8e4m3fn[3,3,128,16] parameter(1)
- input_f32 = f32[1,128,6,6] convert(input)
- filter_f32 = f32[3,3,128,16] convert(filter)
- z_scale = f32[] parameter(2)
- z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
- conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- conv_a_cos = f32[1,16,6,6] cosine(conv_a)
- conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
- c1 = f32[] constant(-448.)
- c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
- c2 = f32[] constant(448.)
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
- conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
- conv_a_convert = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
- ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[1,16,6,6]) tuple(conv_a_convert, conv_a_cos)
-
- })",
- // custom_call
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )",
- // serialized_graph
- R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) {
- MAYBE_SKIP_TEST("I8");
- // clamp(conv(x, w)) converted back to int8_t.
- TestClamp(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- zero = s8[] constant(0)
- zeros = s8[1,32,9,9] broadcast(zero), dimensions={}
-
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
-
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
- lower = s32[] constant(-128)
- lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
- upper = s32[] constant(127)
- uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
-
- clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)
-
- ROOT convert = s8[1,32,9,9] convert(clamp)
- })",
- // post_hlo
- R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (s8[1,9,9,32]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[fusion_2_1:%[^ ]+]], [[fusion_1_2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
- MAYBE_SKIP_TEST("I8");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
-
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
-
- ROOT convert = f32[1,32,9,9] convert(conv)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvForwardCallTarget}), 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
- MAYBE_SKIP_TEST("I8");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))
-
- conv = s32[1,32,9,9] convolution(input, filter),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_f32 = f32[1,32,9,9] convert(conv)
- ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
- add(add(conv_f32, bias), side_input),
- f32[1,32,9,9] broadcast(f32[] constant(127))))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1),
- m::Parameter(2), m::Parameter(3)),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
- MAYBE_SKIP_TEST("I8");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-
- conv = s32[1,32,9,9] convolution(input, filter),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
- conv,
- s32[1,32,9,9] broadcast(s32[] constant(127))))
- zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
- ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), //
- m::Parameter(1), //
- m::Broadcast(
- m::ConstantEffectiveScalar(0).WithElementType(F32))),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
- MAYBE_SKIP_TEST("I8");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- side_input_f32 = f32[1,32,9,9] parameter(3)
-
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_f32 = f32[1,32,9,9] convert(conv)
- sum1 = add(conv_f32, bias_broadcast)
- ROOT sum2 = add(sum1, side_input_f32)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1),
- m::Parameter(2), m::Parameter(3)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-// The ReshapeMover pass changes
-// reshape(side_input) * alpha -->
-// reshape(side_input * alpha).
-// Make sure we can pattern-match this.
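-// In the module below the pattern shows up as
-//   reshape(multiply(convert(s8[2592] parameter(3)), 0.25))
-// and the rewriter is expected to feed the reshaped s8 parameter to the
-// custom call directly while folding the 0.25 into side_input_scale.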
-TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) {
- MAYBE_SKIP_TEST("I8");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
- side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))
-
- conv = s32[1,32,9,9] convolution(input, filter),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
- add(add(f32[1,32,9,9] convert(conv), bias), side_input),
- f32[1,32,9,9] broadcast(f32[] constant(127))))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- HloPassFix<HloPassPipeline> simplify("simplify");
- simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
- simplify.AddPass<ReshapeMover>();
- simplify.AddPass<ConvertMover>();
- TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), //
- m::Parameter(1), //
- m::Parameter(2), //
- m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.side_input_scale(), 0.25);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
- MAYBE_SKIP_TEST("I8");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
- alpha = f32[] constant(42)
- alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- convert = f32[1,32,9,9] convert(conv)
- ROOT root = multiply(convert, alpha_broadcast)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.conv_result_scale(), 42);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- zero = f32[] constant(0)
- zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias_broadcast)
- ROOT relu = maximum(sum, zeros)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias)
- relu = maximum(sum, zeros)
- not_relu = minimum(sum, zeros)
- ROOT root = tuple(relu, not_relu)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::MaximumAnyOrder(
- m::Broadcast(m::ConstantEffectiveScalar(0)),
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})),
- m::Minimum())));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
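-// Elu is matched from the select(compare(sum, 0, GT), sum, expm1(sum))
-// pattern built around the conv+bias below and should be fused as the kElu
-// activation.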
-TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,16,9,9] parameter(0)
- filters = f16[3,3,16,32] parameter(1)
- bias = f16[32] parameter(2)
- bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
- zero = f16[] constant(0)
- zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
- conv = f16[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias_broadcast)
- cmp = compare(sum, zeros), direction=GT
- expm1 = exponential-minus-one(sum)
- ROOT elu = select(cmp, sum, expm1)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- DebugOptions debug_opts = m->config().debug_options();
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- m->mutable_config().set_debug_options(debug_opts);
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- // elu fusion is only active on Ampere+.
- CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kElu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,16,9,9] parameter(0)
- filters = f16[3,3,16,32] parameter(1)
- bias = f16[32] parameter(2)
- bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
- zero = f16[] constant(0)
- zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
- conv = f16[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias_broadcast)
- cmp = compare(sum, zeros), direction=GT
- expm1 = exponential-minus-one(sum)
- elu = select(cmp, sum, expm1)
- not_elu = minimum(sum, zeros)
- ROOT root = tuple(elu, not_elu)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- DebugOptions debug_opts = m->config().debug_options();
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- m->mutable_config().set_debug_options(debug_opts);
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- auto gte_pattern =
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F16, {1, 32, 9, 9});
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Select(m::Compare(gte_pattern,
- m::Broadcast(m::ConstantEffectiveScalar(0)))
- .WithComparisonDirection(ComparisonDirection::kGt),
- gte_pattern,
- m::Op()
- .WithPredicate(HloPredicateIsOp<HloOpcode::kExpm1>)
- .WithOperand(0, gte_pattern)),
- m::Minimum())));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) {
- const std::string module_str = R"(
- HloModule Test
- ENTRY Test {
- inputs = f16[1,18,9,9] parameter(0)
- filters = f16[3,3,18,32] parameter(1)
- bias = f16[32] parameter(2)
- bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
- zero = f16[] constant(0)
- zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
- sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
- conv = f16[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias_broadcast)
- ROOT relu = clamp(zeros, sum, sixes)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- DebugOptions debug_opts = m->config().debug_options();
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- m->mutable_config().set_debug_options(debug_opts);
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- // relu6 fusion is only enabled on Ampere+.
- CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu6);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) {
- const std::string module_str = R"(
- HloModule Test
- ENTRY Test {
- inputs = f16[1,18,9,9] parameter(0)
- filters = f16[3,3,18,32] parameter(1)
- bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
- zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
- sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
- conv = f16[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias)
- relu = clamp(zeros, sum, sixes)
- not_relu = minimum(sum, zeros)
- ROOT root = tuple(relu, not_relu)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- DebugOptions debug_opts = m->config().debug_options();
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- m->mutable_config().set_debug_options(debug_opts);
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F16, {1, 32, 9, 9}),
- m::Broadcast(m::ConstantEffectiveScalar(6))),
- m::Minimum())));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
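-// Leaky relu is matched from the select(compare(sum, 0, GT), sum,
-// multiply(sum, 0.2)) pattern below and should be fused as the kLeakyRelu
-// activation.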
-TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) {
- const std::string module_str = R"(
- HloModule Test
- ENTRY Test {
- inputs = f16[1,16,9,9] parameter(0)
- filters = f16[3,3,16,32] parameter(1)
- bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
- zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
- alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
- conv = f16[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias)
- cmp = compare(sum, zeros), direction=GT
- mul = multiply(sum, alphas)
- ROOT leaky_relu = select(cmp, sum, mul)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- DebugOptions debug_opts = m->config().debug_options();
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- m->mutable_config().set_debug_options(debug_opts);
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- // Leaky-relu fusion is only enabled on Ampere+.
- CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kLeakyRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) {
- const std::string module_str = R"(
- HloModule Test
- ENTRY Test {
- inputs = f16[1,16,9,9] parameter(0)
- filters = f16[3,3,16,32] parameter(1)
- bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
- zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
- alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
- conv = f16[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias)
- cmp = compare(sum, zeros), direction=GT
- mul = multiply(sum, alphas)
- leaky_relu = select(cmp, sum, mul)
- not_leaky_relu = minimum(sum, zeros)
- ROOT root = tuple(leaky_relu, not_leaky_relu)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- DebugOptions debug_opts = m->config().debug_options();
- debug_opts.set_xla_gpu_use_runtime_fusion(true);
- m->mutable_config().set_debug_options(debug_opts);
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- auto gte_pattern =
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F16, {1, 32, 9, 9});
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Select(m::Compare(gte_pattern,
- m::Broadcast(m::ConstantEffectiveScalar(0)))
- .WithComparisonDirection(ComparisonDirection::kGt)
- .WithOneUse(),
- gte_pattern,
- m::Multiply(gte_pattern,
- m::Broadcast(m::ConstantEffectiveScalar()))),
- m::Minimum())));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(multiply(alpha, conv), bias)
- ROOT root = tuple(conv, sum)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
- m::MultiplyAnyOrder(
- m::Broadcast(m::Parameter(3)),
- m::GetTupleElement(m::CustomCall(&conv2), 0))))));
- EXPECT_EQ(conv1, conv2);
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv1->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = tuple(conv, add(conv, bias))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
- m::GetTupleElement(m::CustomCall(&conv2), 0)))));
- EXPECT_EQ(conv1, conv2);
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv1->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- side_input = f32[1,32,9,9] parameter(2)
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
- ROOT root = add(relu, side_input)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::AddAnyOrder(
- m::Parameter(2),
- m::GetTupleElement(
- m::CustomCall(&conv, m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::ConstantEffectiveScalar(0))),
- 0))));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
- ROOT root = add(relu, bias)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::AddAnyOrder(
- m::Broadcast(m::Parameter(2)),
- m::GetTupleElement(m::CustomCall(
- &conv, m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::ConstantEffectiveScalar(0)))))));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- side_input = f32[1,32,9,9] parameter(2)
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = tuple(conv, add(conv, side_input))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::AddAnyOrder(m::Parameter(2),
- m::GetTupleElement(m::CustomCall(&conv2), 0)))));
- EXPECT_EQ(conv1, conv2);
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv1->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
- filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
- EXPECT_EQ(conv1, conv2);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_s8 = s8[1,32,9,9] convert(clamp(
- f32[1,32,9,9] broadcast(f32[] constant(-128)),
- conv,
- f32[1,32,9,9] broadcast(f32[] constant(127))))
- ROOT root = tuple(conv, conv_s8)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::Convert(m::Clamp(m::Op(), //
- m::GetTupleElement(m::CustomCall(&conv2), 0),
- m::Op())))));
- EXPECT_EQ(conv1, conv2);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) {
- MAYBE_SKIP_TEST("I8");
- const std::string_view module_str = R"(
- HloModule Test
-
- ENTRY test_entry {
- inputs = s8[1, 17, 9, 9] parameter(0)
- filters = s8[3, 3, 17, 32] parameter(1)
- mult_op = f32[1, 32, 9, 9] parameter(2)
- conv = s32[1, 32, 9, 9] convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
- ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
- SCOPED_TRACE(m->ToString());
- HloInstruction* conv1 = nullptr;
- // Checks that it removed the Convert inside multiply around conv.
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
- m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) {
- MAYBE_SKIP_TEST("I8");
- const std::string_view module_str = R"(
- HloModule Test
-
- ENTRY test_entry {
- inputs = s8[1, 17, 9, 9] parameter(0)
- filters = s8[3, 3, 17, 32] parameter(1)
- mult_op = f32[1, 32, 9, 9] parameter(2)
- conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
- ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
- SCOPED_TRACE(m->ToString());
- HloInstruction* conv1 = nullptr;
- // Checks that it removed the Convert inside multiply around conv.
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
- m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) {
- MAYBE_SKIP_TEST("I8");
- const std::string_view module_str = R"(
- HloModule Test
-
- ENTRY test_entry {
- inputs = f32[1, 17, 9, 9] parameter(0)
- filters = f32[3, 3, 17, 32] parameter(1)
- mult_op = s8[1, 32, 9, 9] parameter(2)
- conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
- ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
- SCOPED_TRACE(m->ToString());
- HloInstruction* conv1 = nullptr;
- // Checks that it removed the Convert inside multiply around conv.
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
- m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDueToMultipleUsers) {
- const std::string_view module_str = R"(
- HloModule Test
-
- ENTRY test_entry {
- inputs = f32[1, 17, 9, 9] parameter(0)
- filters = f32[3, 3, 17, 32] parameter(1)
- mult_op = s8[1, 32, 9, 9] parameter(2)
- sub_op = s8[1, 32, 9, 9] parameter(3)
- conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
- another = subtract(s8[1, 32, 9, 9] convert(conv), sub_op)
- ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
- SCOPED_TRACE(m->ToString());
- HloInstruction* conv1 = nullptr;
- // Checks that the Convert inside the multiply is kept, since conv has
- // multiple users.
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Multiply(
- m::Convert(m::GetTupleElement(m::CustomCall(&conv1))),
- m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, bias_broadcast)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- side_input = f32[1,32,9,9] parameter(2)
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, side_input)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::ConstantEffectiveScalar(0))
- .WithShape(F32, {32}),
- m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- side_input = f32[1,32,9,9] parameter(2)
- side_input_scale = f32[] constant(42)
- side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
- side_input_product = multiply(side_input, side_input_scale_broadcast)
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, side_input_product)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::ConstantEffectiveScalar(0))
- .WithShape(F32, {32}),
- m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 42);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- side_input = f32[1,32,9,9] parameter(3)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, side_input)
- ROOT sum2 = add(sum, bias_broadcast)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2),
- m::Parameter(3)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, bias)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,17,9,9] parameter(0)
- filters = f16[3,3,17,32] parameter(1)
- bias = f16[32] parameter(2)
- side_input = f16[1,32,9,9] parameter(3)
-
- inputs_f32 = f32[1,17,9,9] convert(inputs)
- filters_f32 = f32[3,3,17,32] convert(filters)
- bias_f32 = f32[32] convert(bias)
- bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
- side_input_f32 = f32[1,32,9,9] convert(side_input)
- conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Parameter(1), m::Parameter(2),
- m::Parameter(3)),
- 0)
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-// We should be able to lower this to an f16 convolution even though the
-// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
-TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
- filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
- bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
- side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))
-
- conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_f16 = f16[1,32,9,9] convert(conv_f32)
- ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
- .WithElementType(F16),
- m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
- .WithElementType(F16),
- m::Parameter(2), m::Reshape(m::Parameter(3))),
- 0)
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,17,9,9] parameter(0)
- filters = f16[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- side_input = f16[1,32,9,9] parameter(3)
-
- inputs_f32 = f32[1,17,9,9] convert(inputs)
- filters_f32 = f32[3,3,17,32] convert(filters)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- side_input_f32 = f32[1,32,9,9] convert(side_input)
- conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- // fp16 convs only support fp16 biases. Because the bias here is fp32, it
- // can't be fused into an fp16 conv, so we end up with an fp32 conv instead.
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Convert(m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Convert(m::Parameter(0)).WithElementType(F32),
- m::Convert(m::Parameter(1)).WithElementType(F32),
- m::Parameter(2),
- m::Convert(m::Parameter(3)).WithElementType(F32)),
- 0))
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,2,2,2] parameter(0)
- filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
- bias = f16[2] parameter(1)
- bias_f32 = f32[2] convert(bias)
- side_input_f32 = f32[1,2,2,2] constant({{
- {{0.5, 0.25}, {0.125, 0.0625}},
- {{0.5, 0.25}, {0.125, 0.0625}}
- }})
-
- inputs_f32 = f32[1,2,2,2] convert(inputs)
- bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
- conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
- window={size=1x1}, dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph, and fold
- // convert back into constants.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
- HloConstantFolding constant_folding;
- TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), m::Constant().WithElementType(F16),
- m::Parameter(1), m::Constant().WithElementType(F16)),
- 0)
- .WithShape(F16, {1, 2, 2, 2})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,2,2,2] parameter(0)
- filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
- bias = f16[2] parameter(1)
- bias_f32 = f32[2] convert(bias)
- side_input_f32 = f32[1,2,2,2] constant({{
- {{0.1, 0.2}, {0.3, 0.4}},
- {{0.5, 0.6}, {0.7, 0.8}}
- }})
-
- inputs_f32 = f32[1,2,2,2] convert(inputs)
- bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
- conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
- window={size=1x1}, dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph, and fold
- // convert back into constants.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
- HloConstantFolding constant_folding;
- TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- // This doesn't get transformed into an f16 conv because the filters constant
- // is not losslessly convertible to f16.
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Convert(m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Convert(m::Parameter(0)).WithElementType(F32),
- m::Constant().WithElementType(F32),
- m::Convert(m::Parameter(1)).WithElementType(F32),
- m::Constant().WithElementType(F32)),
- 0)
- .WithShape(F32, {1, 2, 2, 2}))
- .WithElementType(F16)));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) {
- MAYBE_SKIP_TEST("I8");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
- zero = s32[] constant(0)
- zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
- relu = maximum(conv, zeros)
-
- lower = s32[] constant(-128)
- lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
- upper = s32[] constant(127)
- uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
-
- clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)
-
- ROOT convert = s8[1,32,9,9] convert(clamp)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), //
- m::Parameter(1), //
- m::Broadcast(m::ConstantEffectiveScalar(0))
- .WithShape(F32, {32})),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- conv->backend_config<GpuBackendConfig>());
- const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) {
- MAYBE_SKIP_TEST("F64");
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = f64[1,17,9,9] parameter(0)
- filter = f64[3,3,17,32] parameter(1)
- bias = f64[1,32,9,9] broadcast(f64[32] convert(f32[32] parameter(2))), dimensions={1}
- conv = f64[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- ROOT root = f64[1,32,9,9] add(conv, bias)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter{GetCudaComputeCapability()};
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
- GetToolkitVersion()};
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
- m::Parameter(0), //
- m::Parameter(1), //
- m::Convert(m::Parameter(2)).WithShape(F64, {32})),
- 0)
- .WithShape(F64, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
- MAYBE_SKIP_TEST("I8");
- // clamp(max(0, conv(x, w)+bias)); for int8_t
- TestClamp(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- zero = f32[] constant(0)
- zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
-
- input = s8[1,3,3,64] parameter(0)
- filter = s8[3,3,64,64] parameter(1)
- bias = f32[64] parameter(2)
-
- inputs32 = s32[1,3,3,64] convert(input)
- filters32 = s32[3,3,64,64] convert(filter)
-
- conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
- convfloat = f32[1,3,3,64] convert(conv)
- broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
- relu = f32[1,3,3,64] maximum(zeros, add1)
-
- lower = f32[] constant(-128)
- lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
- upper = f32[] constant(127)
- uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
- clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
-
- ROOT convert = s8[1,3,3,64] convert(clamp)
- })",
- // post_hlo
- R"(
-// CHECK: [[cudnn_conv_bias_activation_7_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
- )");
-}
-
-// Disabled per b/190854862 or nvbugs/3326122.
-TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) {
- MAYBE_SKIP_TEST("I8");
- // max(0, convert<float>(conv<int32_t>(int8_x, int8_w)) + float_bias);
- // int8_t to float via bias.
- TestClamp(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- zero = f32[] constant(0)
- zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
-
- input = s8[1,3,3,64] parameter(0)
- filter = s8[3,3,64,64] parameter(1)
- bias = f32[64] parameter(2)
-
- inputs32 = s32[1,3,3,64] convert(input)
- filters32 = s32[3,3,64,64] convert(filter)
-
- conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
- convfloat = f32[1,3,3,64] convert(conv)
- broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
- ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
- })",
- // post_hlo
- R"(
- ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> f32[1,3,3,64] {
- ; CHECK: [[custom_call_0:%[^ ]+]]{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call([[input_1:%[^ ]+]], [[copy_2:%[^ ]+]]{{(\.[0-9])?}}, [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
- ; CHECK-NEXT: ROOT [[get_tuple_element_4:%[^ ]+]]{{(\.[0-9])?}} = f32[1,3,3,64]{3,2,1,0} get-tuple-element([[custom_call_0]]{{(\.[0-9])?}}), index=0
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest,
- TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) {
- MAYBE_SKIP_TEST("I8");
- // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
- // convert<int32_t>(int8_side_input) + bias)); for int8_t
- TestClamp(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- zero = f32[] constant(0)
- zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
- alpha_conv_scalar = f32[] constant(0.999994934)
- alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
- alpha_side_input_scalar = f32[] constant(0.899994934)
- alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
- input = s8[1,3,3,64] parameter(0)
- filter = s8[3,3,64,64] parameter(1)
- side_input = s8[1,3,3,64] parameter(2)
- bias = f32[64] parameter(3)
-
- inputs32 = s32[1,3,3,64] convert(input)
- filters32 = s32[3,3,64,64] convert(filter)
- side_input_f32 = f32[1,3,3,64] convert(side_input)
-
- conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
- convfloat = f32[1,3,3,64] convert(conv)
- scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
- scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
- broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
- add2 = f32[1,3,3,64] add(add1, scaled_side_input)
- relu = f32[1,3,3,64] maximum(zeros, add2)
-
- lower = f32[] constant(-128)
- lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
- upper = f32[] constant(127)
- uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
- clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
-
- ROOT convert = s8[1,3,3,64] convert(clamp)
- })",
- // post_hlo
- R"(
-// CHECK: [[cudnn_conv_bias_activation_11_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest,
- TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) {
- MAYBE_SKIP_TEST("I8");
- // From:
- // convert<int8_t>(clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
- //                              float_side_input + bias)));
- // To:
- // convert<int8_t>(clamp(conv(int8_x, int8_w, float_alpha_side,
- //                            float_side_input, float_bias)));
- TestClamp(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- zero = f32[] constant(0)
- zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
- alpha_conv_scalar = f32[] constant(0.999994934)
- alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
- alpha_side_input_scalar = f32[] constant(0.899994934)
- alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
- input = s8[1,3,3,64] parameter(0)
- filter = s8[3,3,64,64] parameter(1)
- side_input = f32[1,3,3,64] parameter(2)
- bias = f32[64] parameter(3)
-
- inputs32 = s32[1,3,3,64] convert(input)
- filters32 = s32[3,3,64,64] convert(filter)
-
- conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
- convfloat = f32[1,3,3,64] convert(conv)
- scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
- scaled_side_input = f32[1,3,3,64] multiply(side_input, alpha_side_input)
- broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
- add2 = f32[1,3,3,64] add(add1, scaled_side_input)
- relu = f32[1,3,3,64] maximum(zeros, add2)
-
- lower = f32[] constant(-128)
- lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
- upper = f32[] constant(127)
- uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
- clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
-
- ROOT convert = s8[1,3,3,64] convert(clamp)
- })",
- // post_hlo
- R"(
-// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest,
- TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) {
- MAYBE_SKIP_TEST("I8");
- // From:
- // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
- //              convert<float>(int8_side_input) + bias));
- // To:
- // clamp(conv(int8_x, int8_w, float_alpha_side,
- //            convert<float>(int8_side_input), float_bias));
- TestClamp(
- // pre_hlo
- R"(
- HloModule Test
-
- ENTRY Test {
- zero = f32[] constant(0)
- zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
- alpha_conv_scalar = f32[] constant(0.999994934)
- alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
- alpha_side_input_scalar = f32[] constant(0.899994934)
- alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
- input = s8[1,3,3,64] parameter(0)
- filter = s8[3,3,64,64] parameter(1)
- side_input = s8[1,3,3,64] parameter(2)
- bias = f32[64] parameter(3)
-
- inputs32 = s32[1,3,3,64] convert(input)
- filters32 = s32[3,3,64,64] convert(filter)
- side_input_f32 = f32[1,3,3,64] convert(side_input)
-
- conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
- convfloat = f32[1,3,3,64] convert(conv)
- scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
- scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
- broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
- add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
- add2 = f32[1,3,3,64] add(add1, scaled_side_input)
- relu = f32[1,3,3,64] maximum(zeros, add2)
-
- lower = f32[] constant(-128)
- lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
- upper = f32[] constant(127)
- uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
- ROOT clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
- })",
- // post_hlo
- R"(
-// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[fusion_1_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
- )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) {
- MAYBE_SKIP_TEST("I8");
- // Check that integer convolution without clamp to int8_t is not allowed.
- // convert<int8_t>(custom_call<int32_t>(int32_x, int32_w,
- // cudnnConvolutionForward))
- const std::string module_str = absl::StrFormat(R"(
- HloModule Test
-
- ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
- zero = s8[] constant(0)
- zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
- input = s8[1,17,9,9]{3,2,1,0} parameter(0)
- filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
- custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
- get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
- convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
- ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
- })");
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
- GetDnnVersion(), GetToolkitVersion())
- .Run(m.get())
- .ok());
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) {
- MAYBE_SKIP_TEST("I8");
- // Although the bias and so on are fused with the forward convolution, the
- // rewrite is still not allowed if the output is not clamped/converted to
- // int8_t:
- // max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias); for
- // int8_t
-
- const std::string module_str = absl::StrFormat(R"(
- HloModule Test
-
- ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
- zero = s8[] constant(0)
- zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
- input = s8[1,17,9,9]{3,2,1,0} parameter(0)
- filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
- custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
- get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
- convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
- ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
- })");
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
- GetDnnVersion(), GetToolkitVersion())
- .Run(m.get())
- .ok());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
deleted file mode 100644
index 23b238d..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ /dev/null
@@ -1,1781 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_mha_rewriter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <numeric>
-#include <optional>
-#include <queue>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/types.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#endif
-
-namespace xla {
-namespace gpu {
-namespace {
-namespace m = match;
-
-// A struct that contains all the matched nodes
-// and results from pattern matching forward graph
-struct MatchFwdResult {
- HloInstruction* matched_bmm_1 = nullptr;
- HloInstruction* matched_bmm_2 = nullptr;
- HloInstruction* matched_bias = nullptr;
- HloInstruction* matched_scale = nullptr;
- HloInstruction* matched_softmax_input = nullptr;
- HloInstruction* matched_reduce_sum = nullptr;
-
- double matched_dropout_rate = 0.0;
- bool need_canonicalization = false;
- bool is_training = false;
- // We use this to keep track of whether the bias is being
- // applied to the bmm1 is a causal mask, cuDNN can generate causal mask inside
- // the attention kernel to save I/O.
- bool is_causal_mask = false;
- bool has_match = false;
- std::string matched_custom_call_name;
-};
-
-// A struct that contains all the matched nodes
-// and results from pattern matching backward graph
-struct MatchBwdResult {
- HloInstruction* matched_bmm_1_grad_1 = nullptr;
- HloInstruction* matched_bmm_1_grad_2 = nullptr;
-
- HloInstruction* matched_bmm_2_grad_1 = nullptr;
- HloInstruction* matched_bmm_2_grad_2 = nullptr;
- HloInstruction* matched_dbias = nullptr;
- // We use this to keep track of all gradient bmms that need
- // canonicalization.
- bool bmm_1_grad_1_need_canonicalization = false;
- bool bmm_1_grad_2_need_canonicalization = false;
- bool bmm_2_grad_1_need_canonicalization = false;
- bool bmm_2_grad_2_need_canonicalization = false;
-
- bool has_match = false;
- std::string matched_custom_call_name;
-};
-
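-// Helpers that optionally match a single wrapping op: each accepts either
-// `wrapper(pattern)` or the bare `pattern`, sharing the subpattern between
-// the two alternatives.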
-template <typename Pattern>
-auto OptionalReshape(Pattern pattern) {
- auto shared = m::SharedSubpattern(pattern);
- return m::AnyOf<HloInstruction>(m::Reshape(shared), shared);
-}
-
-template <typename Pattern>
-auto OptionalConvert(Pattern pattern) {
- auto shared = m::SharedSubpattern(pattern);
- return m::AnyOf<HloInstruction>(m::Convert(shared), shared);
-}
-
-template <typename Pattern>
-auto OptionalBitcast(Pattern pattern) {
- auto shared = m::SharedSubpattern(pattern);
- return m::AnyOf<HloInstruction>(m::Bitcast(shared), shared);
-}
-
-template <typename Pattern>
-auto OptionalBroadcast(Pattern pattern) {
- auto shared = m::SharedSubpattern(pattern);
- return m::AnyOf<HloInstruction>(m::Broadcast(shared), shared);
-}
-
-bool IsBatchedMatmul(const HloInstruction* instr) {
- if (instr->opcode() != HloOpcode::kDot) return false;
- if (Cast<HloDotInstruction>(instr)->sparse_operands()) return false;
- const DotDimensionNumbers& dot_dims = instr->dot_dimension_numbers();
- bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
- !dot_dims.rhs_batch_dimensions().empty();
- return is_batch_dot;
-}
-
-// We need to check whether the current gemm shares an operand with a forward
-// fMHA call: when matching backward gemms, the only way to be sure a gemm is
-// a backward gemm is to see that it shares an operand (i.e. the Q, K, V or
-// activation tensors) with the forward fMHA call. We can also use this
-// function to infer whether a gemm is a forward fMHA gemm. We check this by
-// doing a BFS over each operand's users to see if any of them is a forward
-// fMHA custom call. The traversal continues through shape ops (bitcast,
-// reshape, transpose) until we either reach a forward fMHA call or run out of
-// shape ops on the path, which means the current node never shares an operand
-// with a forward fMHA call.
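-// For example, a bmm2 gradient gemm reads the same V tensor as the forward
-// call, possibly behind a bitcast/reshape/transpose; the BFS from that
-// operand walks through those shape ops and reaches the forward fMHA
-// custom-call, whereas a gemm from an unrelated layer never does.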
-bool IsSharingOperandWithFwdMha(HloInstruction* gemm) {
- for (int64_t i = 0; i < gemm->operands().size(); i++) {
- std::queue<HloInstruction*> visit_list;
- visit_list.push(gemm->mutable_operand(i));
- while (!visit_list.empty()) {
- HloInstruction* current_instr = visit_list.front();
- for (auto user : current_instr->users()) {
- switch (user->opcode()) {
- case HloOpcode::kBitcast:
- case HloOpcode::kReshape:
- case HloOpcode::kTranspose: {
- visit_list.push(user);
- break;
- }
- case HloOpcode::kCustomCall: {
- if (IsFwdCustomCallTofMHA(*user)) {
- return true;
- }
- } break;
- default:
- break;
- }
- }
- visit_list.pop();
- }
- }
- return false;
-}
-// When we reach a gemm instruction, it can be one of three cases:
-// 1. one of the 2 gemms in a forward fMHA call
-// 2. one of the 4 gemms in a backward fMHA call
-// 3. a gemm from some other, unrelated layer
-// Case 3 is easily ruled out by the pattern matcher.
-// However, cases 1 and 2 have very similar bmm-bmm structures.
-// To match case 1 exactly, a forward gemm must have the following properties:
-// - It is a batched matmul.
-// - None of its operands is a forward fMHA call, which would make it a
-//   backward gemm.
-// - It does not directly or indirectly share an operand with any other fMHA
-//   call, which would also make it a backward gemm.
-bool IsFirstFwdMatmul(HloInstruction* gemm) {
- return IsBatchedMatmul(gemm) && !IsFwdCustomCallTofMHA(*gemm->operand(0)) &&
- !IsFwdCustomCallTofMHA(*gemm->operand(1)) &&
- !IsSharingOperandWithFwdMha(gemm);
-}
-
-bool IsScalar(const HloInstruction* instr) {
- return ShapeUtil::IsEffectiveScalar(instr->shape());
-}
-
-bool IsReduceMax(const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kReduce &&
- instr->to_apply()->root_instruction()->opcode() == HloOpcode::kMaximum;
-}
-
-bool IsReduceSum(const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kReduce &&
- instr->to_apply()->root_instruction()->opcode() == HloOpcode::kAdd;
-}
-
-// Set up subpatterns for re-use.
-// Matches softmax sub-pattern ->
-// divide(exp(Subtract(producer, reduce_max(producer))),
-// broadcast(reduce_add(exp(Subtract(...))))). There might be reshape and
-// convert nodes between reduce and Subtract.
-// TODO TJ: Make this more general to any patterns that have this structure
-// once the cudnn runner supports generic cudnnOpGraphs.
-//   producer
-//   |    \
-//   |   reduce
-//   |     |
-//   |   broadcast
-//   |    /
-//     root
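-// i.e. the unfused softmax computes
-//   softmax(x) = exp(x - reduce_max(x)) / reduce_sum(exp(x - reduce_max(x)))
-// with both reductions broadcast back to the shape of x.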
-auto GetUnfusedReduceMaxSumSoftmaxPattern(
- HloInstruction** softmax_input = nullptr,
- HloInstruction** softmax_reduce_sum = nullptr,
- HloInstruction** softmax_reduce_sum_bcast = nullptr) {
- // The reduce-max part of the softmax
- // reduce_max and subtract will always have exactly 1 user
- // in both training and inference
- // softmax_input should always have exactly 2 users
- auto unfused_softmax_max_subpattern = m::SharedSubpattern(
- m::Subtract(
- m::Op(),
- m::Broadcast(OptionalConvert(
- m::Op()
- .WithPredicate(IsReduceMax)
- .WithOneUse()
- .WithOperand(0, OptionalBitcast(OptionalConvert(
- m::Op(softmax_input).WithNumUser(2)))))))
- .WithOneUse());
- // The reduce-add part of the softmax
- // reduce_sum and reduce_sum_broadcast should have 2 users in training
- // and 1 user in inference
- auto unfused_softmax_sum_subpattern = m::SharedSubpattern(m::Divide(
- OptionalBitcast(m::Exp(unfused_softmax_max_subpattern)),
- m::Broadcast(
- softmax_reduce_sum_bcast,
- OptionalConvert(
- m::Op(softmax_reduce_sum)
- .WithOperand(0, OptionalBitcast(OptionalConvert(
- m::Exp(unfused_softmax_max_subpattern))))
- .WithPredicate(IsReduceSum)
- .WithAtMostNumUser(2)))
- .WithAtMostNumUser(2)));
- return unfused_softmax_sum_subpattern;
-}
-
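-// Returns the value of a scalar constant as a double, or nullopt if the
-// instruction is not an effective scalar or its element type is unsupported.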
-std::optional<double> GetConstantValue(const HloInstruction* inst) {
- if (!IsScalar(inst)) {
- return std::nullopt;
- }
- switch (inst->shape().element_type()) {
- case F16:
- return static_cast<float>(inst->literal().GetFirstElement<half>());
- case BF16:
- return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
- case F32:
- return inst->literal().GetFirstElement<float>();
- case F64:
- return inst->literal().GetFirstElement<double>();
- default:
- return std::nullopt;
- }
-}
-
-double GetDropoutRateFromHlo(HloInstruction* dropout) {
- std::optional<double> dropout_rate_inv;
- dropout_rate_inv = GetConstantValue(dropout);
- if (!dropout_rate_inv.has_value()) {
- return 0.0;
- }
- // In dropout, inputs are divided by (1 - rate), so the constant in the
- // dropout node is 1 / (1 - rate); we invert it and subtract the result
- // from 1 here to get the actual dropout rate.
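- // For example, a dropout-scale constant of 1.25 (= 1 / (1 - 0.2))
- // corresponds to a dropout rate of 1 - 1/1.25 = 0.2.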
- return (1.0 - (1.0 / *dropout_rate_inv));
-}
-
-bool IsComputeCapabilityAndCudnnSupported(
- stream_executor::CudaComputeCapability cc,
- stream_executor::dnn::VersionInfo cudnn_version,
- stream_executor::dnn::VersionInfo supported_cudnn_version) {
- // Enforce capability minor == 0 because hardware with a non-zero minor
- // number typically has insufficient shared memory for cuDNN FMHA.
- if (cc.IsAtLeastAmpere() && cc.minor == 0 &&
- cudnn_version >= supported_cudnn_version) {
- return true;
- }
- VLOG(2) << absl::StrFormat(
- "CudnnFusedMHARewriter did not run. Unsupported compute "
- "capability(%s; major should be >= 8, minor should be 0) or cudnn version"
- "(%s; should be >= %s)",
- cc.ToString(), cudnn_version.ToString(),
- supported_cudnn_version.ToString());
- return false;
-}
-
-bool IsSupportedPrimitiveType(const HloInstruction* bmm) {
- PrimitiveType dtype = bmm->shape().element_type();
- return dtype == BF16 || dtype == F16;
-}
-
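-// Gathers the sizes of the given dimension numbers from `dimensions`, e.g.
-// dimensions = {2, 12, 128, 64} with dim_nums = {0, 1} yields {2, 12}.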
-std::vector<int64_t> GetDimensionVector(absl::Span<const int64_t> dimensions,
- absl::Span<const int64_t> dim_nums) {
- std::vector<int64_t> vec(dim_nums.size());
- for (int i = 0; i < dim_nums.size(); i++) {
- vec[i] = dimensions.at(dim_nums.at(i));
- }
- return vec;
-}
-
-struct QKVLayout {
- int64_t batch;
- int64_t num_heads;
- int64_t seqlen_q;
- int64_t seqlen_kv;
- int64_t hidden_dim;
-};
-
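-// Derives (batch, num_heads, seqlen_q, seqlen_kv, hidden_dim) from the two
-// batched matmuls: for bmm1 the lhs non-contracting dimension gives seqlen_q,
-// the rhs non-contracting dimension gives seqlen_kv, and the contracting
-// dimension gives hidden_dim. bmm2 must agree on all of these, otherwise
-// nullopt is returned.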
-absl::StatusOr<std::optional<QKVLayout>> GetQKVLayout(
- HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization) {
- // get layout from bmm1
- const DotDimensionNumbers& bmm1_dnums = bmm_1->dot_dimension_numbers();
- TF_ASSIGN_OR_RETURN(
- std::vector<int64_t> bmm1_s_q_dims,
- GetNonContractingDims(bmm_1->operand(0)->shape(),
- bmm1_dnums.lhs_batch_dimensions(),
- bmm1_dnums.lhs_contracting_dimensions()));
-
- TF_ASSIGN_OR_RETURN(
- std::vector<int64_t> bmm1_s_kv_dims,
- GetNonContractingDims(bmm_1->operand(1)->shape(),
- bmm1_dnums.rhs_batch_dimensions(),
- bmm1_dnums.rhs_contracting_dimensions()));
-
- std::vector<int64_t> bmm1_bh =
- GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
- bmm1_dnums.lhs_batch_dimensions());
-
- std::vector<int64_t> bmm1_s_q = GetDimensionVector(
- bmm_1->operand(0)->shape().dimensions(), bmm1_s_q_dims);
-
- std::vector<int64_t> bmm1_s_kv = GetDimensionVector(
- bmm_1->operand(1)->shape().dimensions(), bmm1_s_kv_dims);
-
- std::vector<int64_t> bmm1_d =
- GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
- bmm1_dnums.lhs_contracting_dimensions());
-
- TF_RET_CHECK(bmm1_bh.size() == 2);
- TF_RET_CHECK(bmm1_s_q.size() == 1);
- TF_RET_CHECK(bmm1_s_kv.size() == 1);
- TF_RET_CHECK(bmm1_d.size() == 1);
-
- // get layout from bmm2
- const DotDimensionNumbers& bmm2_dnums = bmm_2->dot_dimension_numbers();
- TF_ASSIGN_OR_RETURN(
- std::vector<int64_t> bmm2_lhs_non_contracting_dims,
- GetNonContractingDims(bmm_2->operand(0)->shape(),
- bmm2_dnums.lhs_batch_dimensions(),
- bmm2_dnums.lhs_contracting_dimensions()));
-
- TF_ASSIGN_OR_RETURN(
- std::vector<int64_t> bmm2_rhs_non_contracting_dims,
- GetNonContractingDims(bmm_2->operand(1)->shape(),
- bmm2_dnums.rhs_batch_dimensions(),
- bmm2_dnums.rhs_contracting_dimensions()));
-
- std::vector<int64_t> bmm2_bh =
- GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
- bmm2_dnums.lhs_batch_dimensions());
-
- std::vector<int64_t> bmm2_s_kv =
- GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
- bmm2_dnums.lhs_contracting_dimensions());
-
- std::vector<int64_t> bmm2_s_q =
- need_canonicalization
- ? GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
- bmm2_rhs_non_contracting_dims)
- : GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
- bmm2_lhs_non_contracting_dims);
-
- std::vector<int64_t> bmm2_d =
- need_canonicalization
- ? GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
- bmm2_lhs_non_contracting_dims)
- : GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
- bmm2_rhs_non_contracting_dims);
-
- TF_RET_CHECK(bmm2_bh.size() == 2);
- TF_RET_CHECK(bmm2_s_q.size() == 1);
- TF_RET_CHECK(bmm2_s_kv.size() == 1);
- TF_RET_CHECK(bmm2_d.size() == 1);
-
- // check if bhsd is correct between bmm1 and bmm2
- if (bmm1_bh[0] != bmm2_bh[0] || bmm1_bh[1] != bmm2_bh[1] ||
- bmm1_s_q[0] != bmm2_s_q[0] || bmm1_s_kv[0] != bmm2_s_kv[0] ||
- bmm1_d[0] != bmm2_d[0]) {
- return std::nullopt;
- }
-
- QKVLayout qkv_layout;
- qkv_layout.batch = bmm1_bh[0];
- qkv_layout.num_heads = bmm1_bh[1];
- qkv_layout.seqlen_q = bmm1_s_q[0];
- qkv_layout.seqlen_kv = bmm1_s_kv[0];
- qkv_layout.hidden_dim = bmm1_d[0];
- return qkv_layout;
-}
-
-absl::StatusOr<bool> IsFlashAttention(
- QKVLayout qkv_layout, bool is_training,
- stream_executor::CudaComputeCapability cc,
- stream_executor::dnn::VersionInfo cudnn_version) {
- int64_t s_q = qkv_layout.seqlen_q;
- int64_t s_kv = qkv_layout.seqlen_kv;
- int64_t hidden_dim = qkv_layout.hidden_dim;
- // start with most relaxed constraint
- bool is_seqlen_supported = (!is_training || (s_q % 2 == 0 && s_kv % 2 == 0));
- bool is_hidden_dim_supported = hidden_dim <= 128 && hidden_dim % 8 == 0;
- bool is_flash_attention = is_seqlen_supported && is_hidden_dim_supported;
- if (!is_flash_attention) return false;
-
- // going backwards to check compatibility
- if ((is_training && (s_q < 64 || s_kv < 64)) &&
- !IsComputeCapabilityAndCudnnSupported(
- cc, cudnn_version, stream_executor::dnn::VersionInfo(9, 0, 0))) {
- VLOG(2) << "Flash attention training with seq < 64 not supported cuDNN < "
- "9.0.0.";
- return false;
- }
-
- if ((hidden_dim != 64 && hidden_dim != 128) &&
- !IsComputeCapabilityAndCudnnSupported(
- cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 6))) {
- VLOG(2) << "Flash attention head dim != 64 or 128 not supported with cuDNN "
- "< 8.9.6.";
- return false;
- }
-
- if ((is_training && s_kv % 64 != 0) &&
- !IsComputeCapabilityAndCudnnSupported(
- cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 5))) {
- VLOG(2) << "Flash attention training with seq kv % 64 != 0 not supported "
- "with cuDNN < 8.9.5.";
- return false;
- }
-
- if (!IsComputeCapabilityAndCudnnSupported(
- cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 4))) {
- VLOG(2) << "Require cuDNN 8.9.4 to run flash attention.";
- return false;
- }
- return is_flash_attention;
-}
-
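-// Matches a causal mask materialized as
-//   select(compare(iota, iota), broadcast(constant), broadcast(constant)),
-// either directly, rematerialized behind a broadcast/bitcast, folded into a
-// minimum in the backward pass, or threaded through a while-loop parameter.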
-bool IsCausalMaskPattern(HloInstruction* mask) {
- auto causal_mask =
- m::Select(m::Compare(m::Iota(), m::Iota()), m::Broadcast(m::Constant()),
- m::Broadcast(m::Constant()));
- auto causal_mask_pattern_fwd_remat =
- m::Broadcast(OptionalBitcast(causal_mask));
- auto causal_mask_pattern_bwd = m::Broadcast(m::Convert(OptionalBitcast(
- m::Minimum(m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))));
- HloInstruction* param = nullptr;
- HloInstruction* gte = nullptr;
- auto causal_mask_pattern_fwd = m::Broadcast(
- OptionalBitcast(m::GetTupleElement(>e, m::Parameter(¶m))));
- auto causal_mask_pattern = m::AnyOf<HloInstruction>(
- causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd,
- causal_mask_pattern_bwd);
- if (Match(mask, causal_mask_pattern)) {
- if (param != nullptr && param->parent()->IsWhileBodyComputation()) {
- // We need to look outside the while-loop body to find the real mask.
- auto while_instr = param->parent()->WhileCallInstruction();
- auto mask_index = gte->tuple_index();
- auto actual_mask =
- while_instr->mutable_operand(0)->mutable_operand(mask_index);
- auto causal_mask_pattern_fwd =
- OptionalBitcast(m::Convert(m::MinimumAnyOrder(
- m::Op(),
- OptionalBitcast(m::MinimumAnyOrder(
- m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))));
- return Match(actual_mask, causal_mask_pattern_fwd);
- }
- return true;
- }
- return false;
-}
-
-MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result,
- int64_t bmm2_operand_position,
- HloInstruction* instr) {
- // Matches the dropout-softmax subpattern.
- // The softmax output is a divide. Dropout can take multiple forms; we
- // capture three forms here based on heuristics.
- // Form 1 -> softmax - mul - select(dropout) - BMM2
- MatchFwdResult match_result = previous_result;
- HloInstruction* softmax_reduce_sum;
- HloInstruction* softmax_reduce_sum_bcast;
- HloInstruction* bmm_2;
- HloInstruction* softmax_input;
- HloInstruction* dropout = nullptr;
- auto dropout_softmax_pattern_form_1 = m::Select(
- m::Op(),
- OptionalConvert(m::MultiplyAnyOrder(
- OptionalBitcast(OptionalReshape(
- OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
- &softmax_input, &softmax_reduce_sum,
- &softmax_reduce_sum_bcast)))),
- m::Broadcast(
- OptionalConvert(m::Constant(&dropout).WithPredicate(IsScalar))))),
- m::Op());
-
- // Form 2 -> softmax - mul - BMM2
- // /
- // /
- // select(dropout)
- auto dropout_softmax_pattern_form_2 =
- OptionalBitcast(OptionalBitcast(OptionalConvert(m::MultiplyAnyOrder(
- OptionalReshape(OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
- &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast))),
- m::Broadcast(
- OptionalConvert(OptionalBitcast(OptionalReshape(m::Select(
- m::Op(),
- m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)),
- m::Op())))))))));
-
- // Form3 -> softmax - mul(dropout) - mul(scale) - BMM2
- auto dropout_softmax_pattern_form_3 = m::MultiplyAnyOrder(
- m::MultiplyAnyOrder(
- OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
- &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast)),
- m::Op()),
- m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)));
-
- // Try matching BMM1 - (Scale) - (Bias) - Softmax - (Dropout) - BMM2.
- // Dropout with a non-zero drop rate shows up as
- // select(divide(softmax_output, broadcast(1-dropout_rate))).
- auto softmax_dropout_bmm2_pattern =
- m::Op(&bmm_2)
- .WithPredicate(IsBatchedMatmul)
- .WithOperand(bmm2_operand_position,
- m::AnyOf<HloInstruction>(
- OptionalBitcast(OptionalConvert(
- GetUnfusedReduceMaxSumSoftmaxPattern(
- &softmax_input, &softmax_reduce_sum,
- &softmax_reduce_sum_bcast))),
- dropout_softmax_pattern_form_1,
- dropout_softmax_pattern_form_2,
- dropout_softmax_pattern_form_3));
-
- if (!Match(instr, softmax_dropout_bmm2_pattern) ||
- !IsSupportedPrimitiveType(bmm_2)) {
- match_result.has_match = false;
- return match_result;
- }
- if (softmax_reduce_sum->users()[0]->opcode() == HloOpcode::kConvert) {
- softmax_reduce_sum = softmax_reduce_sum->users()[0];
- }
- match_result.is_training = softmax_reduce_sum->user_count() == 2 &&
- softmax_reduce_sum_bcast->user_count() == 2;
- match_result.matched_bmm_2 = bmm_2;
- if (dropout) {
- match_result.matched_dropout_rate = GetDropoutRateFromHlo(dropout);
- }
- match_result.matched_softmax_input = softmax_input;
- match_result.matched_reduce_sum = softmax_reduce_sum;
- match_result.has_match = true;
- return match_result;
-}
-
-MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
- HloInstruction* softmax_input,
- bool has_dropout) {
- MatchFwdResult match_result = previous_result;
- HloInstruction* bmm_1;
- HloInstruction* bias = nullptr;
- HloInstruction* scale = nullptr;
- // The bmm1/scale/bias add should have 2 users if it is connected to softmax;
- // otherwise it should have exactly 1 user.
- auto first_bmm_pattern =
- m::SharedSubpattern(m::Op(&bmm_1).WithPredicate(IsBatchedMatmul));
- auto unfused_scaled_bmm_subpattern = m::MultiplyAnyOrder(
- OptionalConvert(first_bmm_pattern.WithOneUse()),
- OptionalConvert(
- m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar))));
- if (Match(softmax_input,
- OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
- first_bmm_pattern, unfused_scaled_bmm_subpattern))))) {
- // bmm1 - (scale) - softmax
- match_result.matched_bmm_1 = bmm_1;
- match_result.matched_scale = scale;
- match_result.matched_custom_call_name =
- has_dropout ? kCudnnfMHASoftmaxDropoutCallTarget
- : kCudnnfMHASoftmaxCallTarget;
- match_result.has_match = true;
- } else if (Match(softmax_input,
- OptionalBitcast(m::AddAnyOrder(
- OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
- unfused_scaled_bmm_subpattern.WithOneUse(),
- first_bmm_pattern.WithOneUse()))),
- m::Op(&bias))))) {
- // bmm1 - (scale) - bias - softmax
- match_result.matched_bmm_1 = bmm_1;
- match_result.matched_scale = scale;
- match_result.matched_custom_call_name =
- has_dropout ? kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget
- : kCudnnfMHAScaleBiasSoftmaxCallTarget;
- match_result.is_causal_mask |= IsCausalMaskPattern(bias);
- if (!match_result.is_causal_mask &&
- bias->opcode() == HloOpcode::kBroadcast) {
- // we can take the bias before broadcast
- auto dims = Cast<HloBroadcastInstruction>(bias)->dimensions();
- if (dims == std::vector<int64_t>{2, 3} ||
- dims == std::vector<int64_t>{0, 2, 3} ||
- dims == std::vector<int64_t>{1, 2, 3}) {
- // shapes [1, 1, s, s], [b, 1, s, s], [1, h, s, s] are supported
- HloInstruction* bias_bc = bias->mutable_operand(0);
- // bitcast bias_before_broadcast to be 4D
- std::vector<int64_t> bitcast_dims(bias->shape().rank(), 1);
- for (int dim : dims) {
- bitcast_dims[dim] = bias->shape().dimensions()[dim];
- }
- bias = bias_bc->AddInstruction(HloInstruction::CreateBitcast(
- ShapeUtil::MakeShape(bias->shape().element_type(), bitcast_dims),
- bias_bc));
- }
- }
- match_result.matched_bias = bias;
- match_result.has_match = true;
- } else {
- match_result.has_match = false;
- }
- return match_result;
-}
-
-// We will try to match all the patterns below:
-// BMM1 - Scale - bias - Softmax - Dropout - BMM2
-// BMM1 - Scale - bias - Softmax - BMM2
-// BMM1 - Softmax - Dropout - BMM2
-// BMM1 - Softmax - BMM2
-MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
- // We need to match 2 general cases:
- // 1. bmm1 --> (intermediate nodes) --> bmm2 <-- V matrix
- // 2. V matrix --> bmm2 <-- (intermediate nodes) <-- bmm1
- // to determine if we need to canonicalize bmm2.
- // So we go through both of bmm2's operands and see which one matches our
- // desired patterns; if operand 1 consumes them, then we need to canonicalize.
- MatchFwdResult match_result;
- for (auto bmm2_operand_pos : {0, 1}) {
- if (bmm2_operand_pos == 1) {
- match_result.need_canonicalization = true;
- }
-
- bool has_dropout = false;
- // We first check if bmm2 is connected to a softmax or dropout.
- // If so, we set softmax input and dropout rate to their corresponding
- // values.
- match_result =
- MatchSoftmaxDropoutBmm(match_result, bmm2_operand_pos, instr);
- if (!match_result.has_match) {
- continue;
- }
- has_dropout = match_result.matched_dropout_rate > 0.0;
- match_result = MatchBmm1UnfusedBiasSoftmaxBmm2(
- match_result, match_result.matched_softmax_input, has_dropout);
- if (match_result.has_match) {
- return match_result;
- }
- }
- // Didn't find any match
- match_result.need_canonicalization = false;
- return match_result;
-}
-
-bool IsBmm2GradGemm2(HloInstruction* instr) {
- // Check whether the input bmm is bmm2's gradient gemm2; it must have either:
- // 1. exactly 1 user, in the dropout case, or
- // 2. exactly 2 users, in all other cases.
- return (instr->user_count() == 1) || (instr->user_count() == 2);
-}
-
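-// bmm1's gradient gemm1 reuses the forward Q tensor (bmm1's operand 0), so we
-// look for another batched matmul among that tensor's users and note whether
-// it needs canonicalization.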
-MatchBwdResult MatchBmm1GradGemm1(MatchBwdResult previous_result,
- HloInstruction* bmm_1) {
- MatchBwdResult match_result = previous_result;
- match_result.has_match = false;
- const HloInstruction* q_tensor = bmm_1->operand(0);
- for (int64_t i = 0; i < q_tensor->user_count(); i++) {
- HloInstruction* q_tensor_user_i = q_tensor->users()[i];
- if (IsBatchedMatmul(q_tensor_user_i) && q_tensor_user_i != bmm_1) {
- match_result.matched_bmm_1_grad_1 = q_tensor_user_i;
- // Check for canonicalization.
- if (match_result.matched_bmm_1_grad_1->operand_index(q_tensor) != 1) {
- match_result.bmm_1_grad_1_need_canonicalization = true;
- }
- match_result.has_match = true;
- }
- }
- return match_result;
-}
-
-MatchBwdResult MatchBmm1GradGemm2(MatchBwdResult previous_result,
- HloInstruction* fwd_fmha_call) {
- HloInstruction* bmm_1_grad_2 = nullptr;
- MatchBwdResult match_result = previous_result;
- match_result.has_match = false;
- // bmm1 gradient gemm2 shares the same input d_s as bmm1 gradient gemm1.
- // Check whether bmm1 grad gemm1 needs canonicalization; if not, the shared
- // input is its first operand.
- int64_t d_s_index = match_result.bmm_1_grad_1_need_canonicalization ? 1 : 0;
- HloInstruction* d_s_user_0 = match_result.matched_bmm_1_grad_1;
-
- HloInstruction* d_s = d_s_user_0->mutable_operand(d_s_index);
- if (d_s->opcode() == HloOpcode::kBitcast && d_s->user_count() == 1) {
- d_s = d_s->mutable_operand(0);
- }
-
- auto bmm_1_grad_2_it = std::find_if(
- d_s->users().begin(), d_s->users().end(), [&](HloInstruction* instr) {
- return instr != match_result.matched_bmm_1_grad_1 &&
- instr->opcode() == HloOpcode::kDot;
- });
- if (bmm_1_grad_2_it != d_s->users().end()) {
- bmm_1_grad_2 = *bmm_1_grad_2_it;
- } else {
- return match_result;
- }
-
- match_result.matched_bmm_1_grad_2 = bmm_1_grad_2;
-
- if (match_result.matched_bmm_1_grad_2->operand_index(d_s) != 0) {
- match_result.bmm_1_grad_2_need_canonicalization = true;
- }
- match_result.has_match = true;
- return match_result;
-}
-
-MatchBwdResult MatchBmm2GradGemm1(HloInstruction* fwd_fmha_call) {
- HloInstruction* bmm_2_grad_1 = nullptr;
- MatchBwdResult matched_result;
- // The second GTE of the forward MHA call is the input of bmm2's gradient
- // gemm 1; we check whether the current gemm satisfies this condition.
- int64_t activation_out_gte_index = 1;
- if (fwd_fmha_call->user_count() < 2 ||
- fwd_fmha_call->users()[activation_out_gte_index]->opcode() !=
- HloOpcode::kGetTupleElement ||
- fwd_fmha_call->users()[activation_out_gte_index]->user_count() > 1 ||
- !IsBatchedMatmul(
- fwd_fmha_call->users()[activation_out_gte_index]->users()[0])) {
- matched_result.has_match = false;
- return matched_result;
- }
- // Found fmha->GTE->gemm, assign it to bmm_2_grad_1 and check to see if it
- // needs canonicalization.
- bmm_2_grad_1 = fwd_fmha_call->users()[activation_out_gte_index]->users()[0];
- matched_result.matched_bmm_2_grad_1 = bmm_2_grad_1;
- if (bmm_2_grad_1->operand_index(
- fwd_fmha_call->users()[activation_out_gte_index]) != 0) {
- matched_result.bmm_2_grad_1_need_canonicalization = true;
- }
-
- matched_result.has_match = true;
- return matched_result;
-}
-
-MatchBwdResult MatchBmm2GradGemm2(MatchBwdResult previous_result,
- HloInstruction* fwd_fmha_call,
- bool v_transposed) {
- MatchBwdResult match_result = previous_result;
- match_result.has_match = false;
- // If v tensor is transposed by forward fmha call, then we need to take fmha v
- // input's producer's producer.
- const HloInstruction* v_tensor = v_transposed
- ? fwd_fmha_call->operand(2)->operand(0)
- : fwd_fmha_call->operand(2);
- for (int64_t i = 0; i < v_tensor->user_count(); i++) {
- HloInstruction* v_tensor_user_i = v_tensor->users()[i];
- if (IsBatchedMatmul(v_tensor_user_i) && IsBmm2GradGemm2(v_tensor_user_i)) {
- match_result.matched_bmm_2_grad_2 = v_tensor_user_i;
- // Check for canonicalization.
- if (match_result.matched_bmm_2_grad_2->operand_index(v_tensor) != 1) {
- match_result.bmm_2_grad_2_need_canonicalization = true;
- }
- match_result.has_match = true;
- }
- }
- return match_result;
-}
-
-MatchBwdResult MatchDbias(MatchBwdResult previous_result,
- HloInstruction* d_intermediate,
- const absl::flat_hash_set<HloInstruction*> users) {
- MatchBwdResult match_result = previous_result;
- auto user_count = d_intermediate->user_count();
- HloInstruction* dbias_user = nullptr;
- HloInstruction* dbias = nullptr;
- for (auto user : d_intermediate->users()) {
- if (users.contains(user)) {
- user_count -= 1;
- } else {
- dbias_user = user;
- }
- }
- auto ConsumeExtraConvert = [](HloInstruction* instr) {
- Match(instr->users()[0], m::Convert(&instr, m::Op()).WithOneUse());
- return true;
- };
- // user_count == 1 && (reduce-> {convert} ->bitcast)
- match_result.has_match =
- user_count == 1 &&
- Match(dbias_user, m::Reduce(&dbias, m::Op(), m::Op()).WithOneUse()) &&
- dbias->shape().rank() == 3 && ConsumeExtraConvert(dbias);
- if (match_result.has_match) {
- // cuDNN only supports dbias for [1, h, s, s]
- // make sure reduce dimension is on batch dim
- auto reduce_dim = dbias->dimensions();
- if (reduce_dim.size() == 1 && reduce_dim[0] == 0) {
- match_result.matched_dbias = dbias;
- } else {
- match_result.has_match = false;
- }
- }
- return match_result;
-}
-
-MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result,
- HloInstruction* fwd_fmha_call) {
- MatchBwdResult match_result = previous_result;
- bool is_bmm1_grad1_canonicalized =
- match_result.bmm_1_grad_1_need_canonicalization;
- match_result.has_match = false;
- bool has_scale = false;
- bool has_dropout = false;
- // Backward dropout pattern
- // select(mask, bmm2_grad2, broadcast())
- auto bwd_dropout_pattern_form_1 = m::SharedSubpattern(
- OptionalBitcast(OptionalReshape(OptionalConvert(m::Select(
- m::Op(), m::Op().WithPredicate([&](const HloInstruction* instr) {
- return instr == match_result.matched_bmm_2_grad_2;
- }),
- m::Broadcast(
- OptionalConvert(m::Constant().WithPredicate(IsScalar))))))));
-
- // multiply(bmm2_grad2, broadcast(select(mask, broadcast(), op())))
- auto bwd_dropout_pattern_form_2 =
- m::SharedSubpattern(OptionalBitcast(m::MultiplyAnyOrder(
- OptionalConvert(
- m::Op().WithPredicate([&](const HloInstruction* instr) {
- return instr == match_result.matched_bmm_2_grad_2;
- })),
- m::Broadcast(OptionalConvert(OptionalBitcast(OptionalReshape(
- m::Select(m::Op(),
- m::Broadcast(OptionalConvert(
- m::Constant().WithPredicate(IsScalar))),
- m::Op()))))))));
- auto bwd_dropout_pattern_form_3 = OptionalConvert(m::MultiplyAnyOrder(
- m::MultiplyAnyOrder(
- m::Op().WithPredicate([&](const HloInstruction* instr) {
- return instr == match_result.matched_bmm_2_grad_2;
- }),
- m::Broadcast(m::Constant().WithPredicate(IsScalar))),
- m::Op()));
- auto bwd_dropout_pattern = m::AnyOf<HloInstruction>(
- bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2,
- bwd_dropout_pattern_form_3);
- // Backward softmax pattern
- HloInstruction* bwd_softmax_input = nullptr;
- HloInstruction* exp_1;
- HloInstruction* exp_2;
- HloInstruction* d_softmax;
-
- // d_softmax = exp * (dy / s_b - sum(dy * exp * 1 / s^2))
-  // There can be at most 3 users of d_softmax: bmm1_grad_1, bmm1_grad_2, and
-  // dbias.
- auto bwd_softmax_pattern = OptionalBitcast(OptionalConvert(
- m::MultiplyAnyOrder(
- &d_softmax,
- m::AddAnyOrder(
- m::Divide().WithOneUse(),
- m::Broadcast(OptionalBitcast(OptionalConvert(
- m::Negate(
- OptionalBitcast(
- m::Op()
- .WithPredicate(IsReduceSum)
- .WithOneUse()
- .WithOperand(
- 0, OptionalBitcast(
- m::MultiplyAnyOrder(
- m::MultiplyAnyOrder(
- m::Op(&bwd_softmax_input),
- m::Broadcast())
- .WithOneUse(),
- m::Exp(&exp_2, m::Op()))
- .WithOneUse()))))
- .WithOneUse())))),
- m::Exp(&exp_1, m::Op()))
- .WithAtMostNumUser(3)));
-
- // Backward scale input pattern
- HloInstruction* bwd_scale_input = nullptr;
- HloInstruction* bwd_scale = nullptr;
-
- auto bwd_scale_pattern =
- m::MultiplyAnyOrder(&bwd_scale, m::Op(&bwd_scale_input),
- m::Broadcast(m::Constant().WithPredicate(IsScalar)))
- .WithNumUser(2);
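-  // If bmm_1_grad_1 still needs canonicalization, its operands are in swapped
-  // order, so the intermediate (d_softmax) input is operand 1 rather than
-  // operand 0.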
- int intermediate_input_pos = is_bmm1_grad1_canonicalized ? 1 : 0;
-
- HloInstruction* intermediate_input =
- match_result.matched_bmm_1_grad_1->mutable_operand(
- intermediate_input_pos);
-
- has_scale = Match(intermediate_input, bwd_scale_pattern);
-
- if (has_scale) {
- intermediate_input = bwd_scale_input;
- }
-
- if (!Match(intermediate_input, bwd_softmax_pattern) || exp_1 != exp_2) {
- return match_result;
- }
- has_dropout = Match(bwd_softmax_input, bwd_dropout_pattern);
-  // If there is no dropout and the softmax input is not coming from the bmm2
-  // gradient gemm 2, then it is not a pattern that we care about.
- if (!has_dropout &&
- !Match(bwd_softmax_input,
- OptionalConvert((OptionalBitcast(
- m::Op().WithPredicate([&](const HloInstruction* instr) {
- return instr == match_result.matched_bmm_2_grad_2;
- })))))) {
- return match_result;
- }
-
- if (has_dropout) {
- // has bias
- if (fwd_fmha_call->custom_call_target() ==
- kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget)
- match_result.matched_custom_call_name =
- kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
- // no bias
- if (fwd_fmha_call->custom_call_target() ==
- kCudnnfMHASoftmaxDropoutCallTarget)
- match_result.matched_custom_call_name =
- kCudnnfMHASoftmaxDropoutBackwardCallTarget;
- } else {
- // has bias
- if (fwd_fmha_call->custom_call_target() ==
- kCudnnfMHAScaleBiasSoftmaxCallTarget)
- match_result.matched_custom_call_name =
- kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget;
- // no bias
- if (fwd_fmha_call->custom_call_target() == kCudnnfMHASoftmaxCallTarget)
- match_result.matched_custom_call_name =
- kCudnnfMHASoftmaxBackwardCallTarget;
- }
- // try to pattern match dbias
- HloInstruction* dS = d_softmax;
- if (dS->users()[0]->opcode() == HloOpcode::kConvert) {
- dS = dS->users()[0];
- }
- if (has_scale) {
-    // bmm1-(scale)-(bias)-softmax pattern: the users could be dbias or the
-    // backward scale.
- if (dS->user_count() == 1) {
- // no dbias
- match_result.has_match = true;
- } else if (dS->user_count() == 2) {
- match_result = MatchDbias(match_result, dS, {bwd_scale});
- } else {
- match_result.has_match = false;
- }
- } else {
-    // bmm1-(bias)-softmax pattern: besides bmm1_grad_1 and bmm1_grad_2, the
-    // only other possible user is dbias.
- if (dS->user_count() == 2) {
- match_result.has_match = true;
- } else if (dS->user_count() == 3) {
- match_result = MatchDbias(match_result, dS,
- {match_result.matched_bmm_1_grad_1,
- match_result.matched_bmm_1_grad_2});
- } else {
- match_result.has_match = false;
- }
- }
- return match_result;
-}
-// First, we look for the bmm2 gradient gemm 1, which takes the activation
-// output from a forward fmha call.
-// Second, we look for the bmm2 gradient gemm 2, which takes the v tensor as
-// an input. We take the v tensor from the third operand of the forward fmha
-// call. If the forward call is canonicalized, we skip the additional
-// transpose in between.
-// Then we look for the bmm1 gradient gemm 1 by searching for gemms that
-// share the q tensor with the current fmha call.
-MatchBwdResult MatchBackwardBmms(HloInstruction* fwd_fmha_call,
- HloInstruction* bmm_1, bool v_transposed) {
- MatchBwdResult matched_result = MatchBmm2GradGemm1(fwd_fmha_call);
- if (!matched_result.has_match) {
- return matched_result;
- }
-
- matched_result =
- MatchBmm2GradGemm2(matched_result, fwd_fmha_call, v_transposed);
- if (!matched_result.has_match) {
- return matched_result;
- }
-
- matched_result = MatchBmm1GradGemm1(matched_result, bmm_1);
- if (!matched_result.has_match) {
- return matched_result;
- }
-
- matched_result = MatchBmm1GradGemm2(matched_result, fwd_fmha_call);
- if (!matched_result.has_match) {
- return matched_result;
- }
- return matched_result;
-}
-// We will match the backward graphs for all forward patterns defined in
-// MatchFwdMHAPatternsForCanonicalization
-MatchBwdResult MatchBwdMHAPatternsForCanonicalization(
- HloInstruction* fwd_fmha_call, HloInstruction* bmm_1, bool v_transposed) {
- MatchBwdResult match_result =
- MatchBackwardBmms(fwd_fmha_call, bmm_1, v_transposed);
- if (!match_result.has_match) {
- return match_result;
- }
- match_result = MatchBwdBmmSoftmaxDropoutBmm(match_result, fwd_fmha_call);
- return match_result;
-}
-
-absl::StatusOr<bool> IsMHABlockSupported(
- HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization,
- bool is_training, bool is_causal_mask, std::string& custom_call_name,
- const DebugOptions& debug_options,
- stream_executor::CudaComputeCapability cc,
- stream_executor::dnn::VersionInfo cudnn_version) {
- if (MHACallHasDropout(custom_call_name) &&
- !debug_options.xla_gpu_fused_attention_use_cudnn_rng()) {
- VLOG(3) << "Using CUDNN RNG for fused attention dropout is not enabled.\n";
- return false;
- }
-
- // cuDNN 8.8 currently only supports BF16 and F16 data types.
- if (!IsSupportedPrimitiveType(bmm_1) || !IsSupportedPrimitiveType(bmm_2)) {
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Unsupported primitive type for cuDNN MHA fusion:\n"
- << bmm_1->ToString() << "\nOR\n"
- << bmm_2->ToString() << "\n"
- << "BF16 and F16 are the supported Dtypes.";
- }
- return false;
- }
-
- if (bmm_1->shape().rank() != 4 || bmm_2->shape().rank() != 4) {
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Unsupported bmm rank for cuDNN MHA fusion:\n"
- << bmm_1->ToString() << "\nOR\n"
- << bmm_2->ToString() << "\n"
- << "Only bmm with rank 4 is supported.";
- }
- return false;
- }
-
-  // Get the batch size, number of heads, sequence length, and hidden dim from
-  // bmm1 and bmm2, and make sure they match between the two.
- TF_ASSIGN_OR_RETURN(std::optional<QKVLayout> qkv_layout,
- GetQKVLayout(bmm_1, bmm_2, need_canonicalization));
- if (!qkv_layout.has_value()) {
- VLOG(2) << "bmm1 and bmm2 have different qkv layout.";
- return false;
- }
-
- // check if matched attention block is supported by cuDNN flash attention.
- TF_ASSIGN_OR_RETURN(
- bool is_flash_attention,
- IsFlashAttention(qkv_layout.value(), is_training, cc, cudnn_version));
- if (is_flash_attention) {
- if (is_causal_mask) {
-      // If the bias is a causal mask, remove the bias from the custom call
-      // name.
- if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget) {
- custom_call_name = kCudnnfMHASoftmaxDropoutCallTarget;
- } else if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxCallTarget) {
- custom_call_name = kCudnnfMHASoftmaxCallTarget;
- }
- }
- }
- return is_flash_attention;
-}
-
-absl::StatusOr<HloInstruction*> CanonicalizeBatchedGemmForcuDNNFMHA(
- HloInstruction* bmm, HloComputation* comp) {
- if (VLOG_IS_ON(3)) {
-    VLOG(3) << "Before FMHA Dot Canonicalization: \n"
- << comp->parent()->ToString();
- }
- HloInstruction* lhs_bmm = bmm->mutable_operand(0);
- HloInstruction* rhs_bmm = bmm->mutable_operand(1);
- const DotDimensionNumbers& dnums = bmm->dot_dimension_numbers();
-
- int64_t rank = bmm->shape().dimensions_size();
- std::vector<int64_t> perm(rank);
- std::iota(perm.begin(), perm.end(), 0);
- // Swap the non-contracting dims of BMM shape. By contract, the
- // non-contracting dims in the output are the last two dimensions.
- std::swap(perm[rank - 1], perm[rank - 2]);
-
- DotDimensionNumbers new_dnums = dnums;
- std::swap(*new_dnums.mutable_lhs_contracting_dimensions(),
- *new_dnums.mutable_rhs_contracting_dimensions());
- std::swap(*new_dnums.mutable_lhs_batch_dimensions(),
- *new_dnums.mutable_rhs_batch_dimensions());
- auto original_bmm_shape = bmm->shape();
- HloInstruction* new_dot = comp->AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(original_bmm_shape.element_type(),
- Permute(original_bmm_shape.dimensions(), perm)),
- /* lhs */ rhs_bmm, /* rhs */ lhs_bmm, new_dnums,
- bmm->precision_config()));
-
- TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
- bmm, HloInstruction::CreateTranspose(original_bmm_shape, new_dot, perm)));
- if (VLOG_IS_ON(2)) {
-    VLOG(2) << "After FMHA Dot Canonicalization: \n"
- << comp->parent()->ToString();
- }
- return new_dot;
-}
-
-absl::StatusOr<HloInstruction*> ChangeCheckedDimToFastest(
- HloComputation* comp, HloInstruction* bmm, bool is_lhs,
- bool should_contracting_be_fastest) {
- const DotDimensionNumbers& dot_dims_bmm = bmm->dot_dimension_numbers();
- DotDimensionNumbers new_dot_dims_bmm = dot_dims_bmm;
- int64_t bmm_operand = is_lhs ? 0 : 1;
- absl::Span<const int64_t> contracting_dims =
- is_lhs ? dot_dims_bmm.lhs_contracting_dimensions()
- : dot_dims_bmm.rhs_contracting_dimensions();
- absl::Span<const int64_t> batch_dims =
- is_lhs ? dot_dims_bmm.lhs_batch_dimensions()
- : dot_dims_bmm.rhs_batch_dimensions();
- absl::Span<const int64_t> lhs_minor_to_major_bmm =
- bmm->operand(0)->shape().layout().minor_to_major();
- absl::Span<const int64_t> rhs_minor_to_major_bmm =
- bmm->operand(1)->shape().layout().minor_to_major();
-
- absl::Span<const int64_t>& minor_to_major_to_check =
- is_lhs ? lhs_minor_to_major_bmm : rhs_minor_to_major_bmm;
-
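-  // The rewriter only handles gemms with a single contracting and a single
-  // non-contracting dimension per operand.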
- CHECK_EQ(contracting_dims.size(), 1);
- TF_ASSIGN_OR_RETURN(std::vector<int64_t> non_contracting_dims,
- GetNonContractingDims(bmm->operand(bmm_operand)->shape(),
- batch_dims, contracting_dims));
- CHECK_EQ(non_contracting_dims.size(), 1);
- HloInstruction* operand_bmm = bmm->mutable_operand(bmm_operand);
- int64_t hidden_dim = should_contracting_be_fastest ? contracting_dims[0]
- : non_contracting_dims[0];
- int64_t minor_dim = minor_to_major_to_check[0];
- // If the hidden dim of the target operand is not the fastest moving
- // dimension, make it so.
- if (minor_dim != hidden_dim) {
- std::vector<int64_t> perm(bmm->shape().dimensions_size());
- std::iota(perm.begin(), perm.end(), 0);
- std::swap(perm[hidden_dim], perm[minor_dim]);
-
- if (is_lhs) {
- new_dot_dims_bmm.set_lhs_contracting_dimensions(0,
- non_contracting_dims[0]);
- } else {
- new_dot_dims_bmm.set_rhs_contracting_dimensions(0,
- non_contracting_dims[0]);
- }
-
- operand_bmm = comp->AddInstruction(
- HloInstruction::CreateTranspose(
- ShapeUtil::MakeShapeWithDenseLayout(
- bmm->shape().element_type(),
- Permute(operand_bmm->shape().dimensions(), perm),
- minor_to_major_to_check),
- operand_bmm, perm),
- &operand_bmm->metadata());
- *((DynCast<HloDotInstruction>(bmm))->mutable_dot_dimension_numbers()) =
- new_dot_dims_bmm;
- }
- return operand_bmm;
-}
-
-absl::StatusOr<HloInstruction*> FuseFwdMultiHeadedAttentionBlock(
- HloComputation* comp, HloInstruction* bmm_1, HloInstruction* bmm_2,
- HloInstruction* bias, HloInstruction* scale, HloInstruction* reduce_sum,
- HloInstruction* softmax_input, double dropout_rate,
- std::string& custom_call_name, stream_executor::CudaComputeCapability cc,
- bool is_training, bool& changed, bool& v_transposed, bool is_causal_mask) {
- double scale_value = 1.0;
- HloInstruction* lhs_bmm1;
- HloInstruction* rhs_bmm1;
- HloInstruction* rhs_bmm2;
- DotDimensionNumbers orig_bmm1_dot_dim = bmm_1->dot_dimension_numbers();
- DotDimensionNumbers orig_bmm2_dot_dim = bmm_2->dot_dimension_numbers();
-
- TF_ASSIGN_OR_RETURN(rhs_bmm1, ChangeCheckedDimToFastest(
- comp, bmm_1, false /*is_lhs*/,
- true /*should_contracting_be_fastest*/));
- TF_ASSIGN_OR_RETURN(lhs_bmm1, ChangeCheckedDimToFastest(
- comp, bmm_1, true /*is_lhs*/,
- true /*should_contracting_be_fastest*/));
-
- TF_ASSIGN_OR_RETURN(rhs_bmm2, ChangeCheckedDimToFastest(
- comp, bmm_2, false /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
-
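-  // ChangeCheckedDimToFastest returns a newly added transpose when it has to
-  // re-layout the operand, so getting back a different instruction here means
-  // the V tensor was transposed.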
- if (rhs_bmm2 != bmm_2->mutable_operand(1)) {
- v_transposed = true;
- }
-
- GpuBackendConfig gpu_config;
- CudnnfMHABackendConfig& fmha_config =
- *gpu_config.mutable_cudnn_fmha_backend_config();
-
- *fmha_config.mutable_bmm1_dot_dimension_numbers() =
- bmm_1->dot_dimension_numbers();
- *fmha_config.mutable_bmm2_dot_dimension_numbers() =
- bmm_2->dot_dimension_numbers();
-
- TF_RET_CHECK((dropout_rate >= 0.0 && dropout_rate <= 1.0));
- // Restore original DotDimensionNumbers.
- *((DynCast<HloDotInstruction>(bmm_1))->mutable_dot_dimension_numbers()) =
- orig_bmm1_dot_dim;
- *((DynCast<HloDotInstruction>(bmm_2))->mutable_dot_dimension_numbers()) =
- orig_bmm2_dot_dim;
-
-  // If a scale node was matched, extract its constant value.
-  if (scale != nullptr) {
-    std::optional<double> value = GetConstantValue(scale);
-    TF_RET_CHECK(value.has_value());
-    scale_value = *value;
-  }
-
- fmha_config.set_fmha_scale(scale_value);
- fmha_config.set_dropout_rate(dropout_rate);
-  // Set to an arbitrary seed for now; the seed is not exposed to XLA in the
-  // HLO graph.
-  // TODO: Find a way to compute the original seed from the dropout keys.
- fmha_config.set_seed(42);
-
- *fmha_config.mutable_intermediate_tensor_shape() = bmm_1->shape().ToProto();
- {
- auto* algorithm = fmha_config.mutable_algorithm();
- algorithm->set_algo_id(0); // engine id
- algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
- std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
- std::vector<int64_t> knob_vals = {1, 0};
- for (int i = 0; i < knob_ids.size(); ++i) {
- (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
- }
- algorithm->set_is_cudnn_frontend(true);
- algorithm->mutable_workspace_size()->set_value(0);
- }
-  // Set the mask type here to choose whether the causal mask is generated
-  // inside the cuDNN attention kernel or not.
- fmha_config.set_mask_type(is_causal_mask ? CudnnfMHABackendConfig::CAUSAL
- : CudnnfMHABackendConfig::NO_MASK);
-
- const Shape& output_shape = bmm_2->shape();
-
- Shape call_shape;
- // Activation output is used by backward gemm.
- HloInstruction* activation_output = nullptr;
-
- // Output Order: {O, Fwd act*, workspace}
- std::vector<Shape> output_shapes = {output_shape};
- if (is_training) {
- activation_output = bmm_2->mutable_operand(0);
-    // Sometimes the activation output is a bitcast; the actual activation is
-    // the other user of the producer of bmm_2's first operand.
- if (activation_output->user_count() < 2 &&
- activation_output->opcode() == HloOpcode::kBitcast) {
- HloInstruction* producer = activation_output->mutable_operand(0);
- TF_RET_CHECK(producer->user_count() == 2);
- HloInstruction* bmm2_grad2_user =
- producer->users()[0] == activation_output ? producer->users()[1]
- : producer->users()[0];
- // might be (transpose) - bmm2_grad2
- if (IsBatchedMatmul(bmm2_grad2_user)) {
- activation_output = producer;
- } else if (bmm2_grad2_user->opcode() == HloOpcode::kTranspose) {
- activation_output = bmm2_grad2_user;
- } else {
- return Internal("Unexpected activation patterns");
- }
- }
-    // For flash attention, the softmax stats must be output to the backward
-    // pass.
- TF_RET_CHECK(reduce_sum != nullptr);
- output_shapes.push_back(
- ShapeUtil::MakeShape(F32, reduce_sum->shape().dimensions()));
- }
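-  // Reserved placeholder for workspace.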
- output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
- call_shape = ShapeUtil::MakeTupleShape(output_shapes);
-
- // Input Order: {Q, K, V, bias*}
- std::vector<HloInstruction*> operands = {lhs_bmm1, rhs_bmm1, rhs_bmm2};
- if (!is_causal_mask && bias != nullptr) {
- HloInstruction* original_bias;
- HloInstruction* original_broadcast;
-    // There are cases where the bias is up-cast to a wider float type; we
-    // need to take the original bias node and broadcast it without the
-    // convert.
- if (Match(bias, m::Broadcast(
- &original_broadcast,
- m::Convert(
- m::Op(&original_bias)
- .WithPredicate([](const HloInstruction* instr) {
- return instr->shape().element_type() == F16 ||
- instr->shape().element_type() == BF16;
- }))
- .WithPredicate([](const HloInstruction* instr) {
- return instr->shape().element_type() == F32 ||
- instr->shape().element_type() == F64;
- })))) {
- absl::Span<const int64_t> original_bcast_dims =
- (DynCast<HloBroadcastInstruction>(original_broadcast))->dimensions();
- // This is to deal with cases like paxml where an extra dimension of 1 is
- // added to the left of the tensor.
-      // TODO: Make this logic more generic.
- absl::Span<const int64_t> original_broadcast_shape_dims =
- original_broadcast->shape().dimensions();
- int64_t starting_index = original_broadcast_shape_dims.size() == 5 &&
- original_broadcast_shape_dims[0] == 1
- ? 1
- : 0;
- std::vector<int64_t> bcast_dimensions;
- for (auto& dim : original_bcast_dims) {
- bcast_dimensions.push_back(dim - starting_index);
- }
-
- const Shape& bcast_shape = bmm_1->shape();
- bias = comp->AddInstruction(HloInstruction::CreateBroadcast(
- bcast_shape, original_bias, bcast_dimensions));
- }
- operands.push_back(bias);
- }
-
- HloInstruction* fmha_call =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- call_shape, operands, absl::string_view(custom_call_name)));
- TF_RETURN_IF_ERROR(fmha_call->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(SetFMHAInstructionName(bmm_1->GetModule(), fmha_call));
-
- TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
- bmm_2,
- HloInstruction::CreateGetTupleElement(bmm_2->shape(), fmha_call, 0)));
-
- if (activation_output) {
- HloInstruction* activation_gte =
- comp->AddInstruction(HloInstruction::CreateGetTupleElement(
- activation_output->shape(), fmha_call, 1));
- TF_RETURN_IF_ERROR(comp->ReplaceInstructionWithDifferentShape(
- activation_output, activation_gte,
- /*preserve_sharding=*/false,
- /*relay_control_dependency=*/false,
- /*remove_unused_operands=*/false)
- .status());
- }
-
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "After CudnnFusedMHARewriter: \n" << comp->parent()->ToString();
- }
- changed = true;
- return fmha_call;
-}
-
-absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
- HloComputation* comp, HloInstruction* bmm_1_grad_1,
- HloInstruction* bmm_1_grad_2, HloInstruction* bmm_2_grad_1,
- HloInstruction* bmm_2_grad_2, HloInstruction* fwd_fmha_call,
- HloInstruction* dbias, HloInstruction* bias,
- std::string& bwd_custom_call_name) {
- HloInstruction* rhs_bmm1_grad_gemm1;
- HloInstruction* lhs_bmm1_grad_gemm2;
- HloInstruction* rhs_bmm2_grad_gemm2;
- HloInstruction* d_output_grad;
-
- DotDimensionNumbers orig_bmm1_grad1_config =
- bmm_1_grad_1->dot_dimension_numbers();
- DotDimensionNumbers orig_bmm1_grad2_config =
- bmm_1_grad_2->dot_dimension_numbers();
- DotDimensionNumbers orig_bmm2_grad1_config =
- bmm_2_grad_1->dot_dimension_numbers();
- DotDimensionNumbers orig_bmm2_grad2_config =
- bmm_2_grad_2->dot_dimension_numbers();
-
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- fwd_fmha_call->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& fwd_config =
- gpu_config.cudnn_fmha_backend_config();
- bool is_causal_mask =
- fwd_config.mask_type() == CudnnfMHABackendConfig::CAUSAL;
- CudnnfMHABackendConfig bwd_fmha_config;
- // Q tensor
- TF_ASSIGN_OR_RETURN(
- rhs_bmm1_grad_gemm1,
- ChangeCheckedDimToFastest(comp, bmm_1_grad_1, false /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
- // K tensor
- TF_ASSIGN_OR_RETURN(
- lhs_bmm1_grad_gemm2,
- ChangeCheckedDimToFastest(comp, bmm_1_grad_2, false /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
-
- // Forward activation
- // softmax_stats
- HloInstruction* fwd_act;
- int64_t fwd_act_index = 1;
- fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
- fwd_fmha_call->shape().tuple_shapes(fwd_act_index), fwd_fmha_call,
- fwd_act_index));
-
- // V tensor
- TF_ASSIGN_OR_RETURN(
- rhs_bmm2_grad_gemm2,
- ChangeCheckedDimToFastest(comp, bmm_2_grad_2, false /*is_lhs*/,
- true /*should_contracting_be_fastest*/));
- // d output to bmm2_grad2
-  // Since d_o is the input of 2 bmms, we set the dimension numbers using the
-  // constraint that the contracting dimension of the lhs of bmm_2_grad_2
-  // needs to be the fastest moving dimension.
- TF_ASSIGN_OR_RETURN(
- d_output_grad,
- ChangeCheckedDimToFastest(comp, bmm_2_grad_2, true /*is_lhs*/,
- true /*should_contracting_be_fastest*/));
- // d output to bmm2_grad1
-  // We don't use this value, but we call this to make sure the dot dimension
-  // numbers are set correctly.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * bmm_2_grad_1_rhs,
- ChangeCheckedDimToFastest(comp, bmm_2_grad_1, false /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
- (void)bmm_2_grad_1_rhs;
- // Operand order: {Q, K, V, Fwd act, d_o, bias*, O*}
- std::vector<HloInstruction*> operands = {
- rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, fwd_act,
- d_output_grad};
-
-  if (!is_causal_mask && bias) {
-    operands.push_back(bias);
-  }
-  // For flash attention, add the forward output to the operand list.
-  HloInstruction* fwd_output = nullptr;
- for (auto user : fwd_fmha_call->users()) {
- if (user->opcode() == HloOpcode::kGetTupleElement &&
- user->tuple_index() == 0) {
- fwd_output = user;
- }
- }
-  // We should always be able to find the forward output GTE.
- TF_RET_CHECK(fwd_output != nullptr);
-  // Check that dO and O have the same layout, as required by cuDNN.
- TF_RET_CHECK(fwd_output->shape() == d_output_grad->shape());
- operands.push_back(fwd_output);
-
- *bwd_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
- bmm_1_grad_1->dot_dimension_numbers();
- *bwd_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
- bmm_1_grad_2->dot_dimension_numbers();
- *bwd_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
- bmm_2_grad_1->dot_dimension_numbers();
- *bwd_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
- bmm_2_grad_2->dot_dimension_numbers();
-
- // Restore original DotDimensionNumbers
- *((DynCast<HloDotInstruction>(bmm_1_grad_1))
- ->mutable_dot_dimension_numbers()) = orig_bmm1_grad1_config;
- *((DynCast<HloDotInstruction>(bmm_1_grad_2))
- ->mutable_dot_dimension_numbers()) = orig_bmm1_grad2_config;
- *((DynCast<HloDotInstruction>(bmm_2_grad_1))
- ->mutable_dot_dimension_numbers()) = orig_bmm2_grad1_config;
- *((DynCast<HloDotInstruction>(bmm_2_grad_2))
- ->mutable_dot_dimension_numbers()) = orig_bmm2_grad2_config;
-
- bwd_fmha_config.set_fmha_scale(fwd_config.fmha_scale());
- bwd_fmha_config.set_dropout_rate(fwd_config.dropout_rate());
-  // Reuse the forward call's seed; the seed is not exposed to XLA in the HLO
-  // graph.
-  // TODO: Find a way to compute the original seed from the dropout keys.
- bwd_fmha_config.set_seed(fwd_config.seed());
- bwd_fmha_config.set_mask_type(is_causal_mask
- ? CudnnfMHABackendConfig::CAUSAL
- : CudnnfMHABackendConfig::NO_MASK);
-
- *bwd_fmha_config.mutable_intermediate_tensor_shape() =
- fwd_config.intermediate_tensor_shape();
- {
- auto* algorithm = bwd_fmha_config.mutable_algorithm();
- algorithm->set_algo_id(0); // engine id
- algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
- std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
- std::vector<int64_t> knob_vals = {1, 0};
- for (int i = 0; i < knob_ids.size(); ++i) {
- (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
- }
- algorithm->set_is_cudnn_frontend(true);
- algorithm->mutable_workspace_size()->set_value(0);
- }
-
- // Output order:
- // {dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), dbias*, workspace}
- std::vector<Shape> output_shapes = {
- bmm_1_grad_2->shape(), bmm_1_grad_1->shape(), bmm_2_grad_1->shape()};
-
- if (dbias) {
-    // The cuDNN kernel only outputs dbias in the shape [1, num_heads, seq,
-    // seq], so we add a leading dimension of 1 to the existing dbias shape.
- std::vector<int64_t> dbias_shape_vector =
- SpanToVector(dbias->shape().dimensions());
- dbias_shape_vector.insert(dbias_shape_vector.begin(), 1);
- Shape cudnn_dbias_shape =
- ShapeUtil::MakeShape(dbias->shape().element_type(), dbias_shape_vector);
- output_shapes.push_back(cudnn_dbias_shape);
- }
- // Reserved placeholder for workspace
- output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
- Shape call_shape = ShapeUtil::MakeTupleShape(output_shapes);
- HloInstruction* fmha_bwd_call =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- call_shape, operands, absl::string_view(bwd_custom_call_name)));
- GpuBackendConfig bwd_gpu_config;
- *bwd_gpu_config.mutable_cudnn_fmha_backend_config() = bwd_fmha_config;
- TF_RETURN_IF_ERROR(fmha_bwd_call->set_backend_config(bwd_gpu_config));
- TF_RETURN_IF_ERROR(
- SetFMHAInstructionName(bmm_1_grad_1->GetModule(), fmha_bwd_call));
-
- // Q gradient
- TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
- bmm_1_grad_2, HloInstruction::CreateGetTupleElement(bmm_1_grad_2->shape(),
- fmha_bwd_call, 0)));
- // K gradient
- TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
- bmm_1_grad_1, HloInstruction::CreateGetTupleElement(bmm_1_grad_1->shape(),
- fmha_bwd_call, 1)));
- // V gradient
- TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
- bmm_2_grad_1, HloInstruction::CreateGetTupleElement(bmm_2_grad_1->shape(),
- fmha_bwd_call, 2)));
-
- if (dbias) {
-    // Reshape the fmha dbias output to the original user's input shape.
-    // If the reshape doesn't involve physical data movement, the algebraic
-    // simplifier can change it to a no-op bitcast.
- Shape original_shape = dbias->shape();
- HloInstruction* dbias_user = dbias->users()[0];
- HloInstruction* cudnn_dbias_output =
- comp->AddInstruction(HloInstruction::CreateGetTupleElement(
- output_shapes[3], fmha_bwd_call, 3));
- HloInstruction* reshape_dbias = comp->AddInstruction(
- HloInstruction::CreateReshape(original_shape, cudnn_dbias_output));
- TF_RETURN_IF_ERROR(dbias_user->ReplaceOperandWith(
- dbias_user->operand_index(dbias), reshape_dbias));
-
- TF_RETURN_IF_ERROR(
- comp->ReplaceInstructionWithDifferentShape(dbias, cudnn_dbias_output));
- }
- return true;
-}
-
-absl::Status RestoreFwdGraph(
- HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2,
- HloInstruction* activation, HloInstruction* original_bmm2_producer0,
- HloInstruction* original_bmm2_producer1,
- std::vector<HloInstruction*>& original_activation_producers,
- bool bmm_2_need_canonicalization) {
-  // If the backward pattern is not matched, we need to restore the original
-  // graph structure by replacing the new GTEs added by the forward FMHA call
-  // with the cloned old activation and bmm2.
- HloInstruction* output_gte = fwd_fmha_call->users()[0];
- HloInstruction* activation_gte = fwd_fmha_call->users()[1];
- std::string suffix = "fmha_no_match_clone";
- HloInstruction* cloned_activation =
- comp->AddInstruction(activation->CloneWithNewOperands(
- activation->shape(), original_activation_producers, suffix));
-
-  // Since the old activation is detached by the forward FMHA rewrite, we need
-  // to use the newly cloned activation.
- HloInstruction* lhs = activation == original_bmm2_producer0
- ? cloned_activation
- : original_bmm2_producer0;
- HloInstruction* rhs = activation == original_bmm2_producer0
- ? original_bmm2_producer1
- : cloned_activation;
- HloInstruction* cloned_bmm2 = comp->AddInstruction(
- bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));
- if (bmm_2_need_canonicalization) {
- TF_RET_CHECK(output_gte->users()[0]->opcode() == HloOpcode::kTranspose);
- TF_RETURN_IF_ERROR(
- comp->ReplaceInstruction(output_gte->users()[0], cloned_bmm2));
- } else {
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
- }
- TF_RETURN_IF_ERROR(
- comp->ReplaceInstruction(activation_gte, cloned_activation));
- return absl::OkStatus();
-}
-} // namespace
-
-absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool any_changed = false;
-  // We use this set to keep track of all already-matched attention blocks.
- absl::flat_hash_set<HloInstruction*> matched_bmm1;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- const DebugOptions& debug_options =
- comp->parent()->config().debug_options();
- const se::dnn::VersionInfo cudnn_version =
- GetDnnVersionInfoOrDefault(stream_executor_, cudnn_version_);
-#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
- // CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware.
- // Some cuDNN versions work with CUDA 11, but it is impractical for us to
- // test those combinations so just disable them.
- return false;
-#endif
- if (!debug_options.xla_gpu_enable_cudnn_fmha() ||
- !IsComputeCapabilityAndCudnnSupported(
- compute_capability_, cudnn_version,
- stream_executor::dnn::VersionInfo(8, 9, 4))) {
- return false;
- }
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- bool v_transposed = false;
- bool changed = false;
- MatchFwdResult matched_result =
- MatchFwdMHAPatternsForCanonicalization(instr);
- if (!matched_result.has_match) {
- continue;
- }
-      // We check the validity of the bmms here, before canonicalization, so
-      // we don't modify the graph if MHA fusion is not possible.
- TF_ASSIGN_OR_RETURN(
- bool is_mha_module_supported,
- IsMHABlockSupported(
- matched_result.matched_bmm_1, matched_result.matched_bmm_2,
- matched_result.need_canonicalization, matched_result.is_training,
- matched_result.is_causal_mask,
- matched_result.matched_custom_call_name, debug_options,
- compute_capability_, cudnn_version));
-
- if (!is_mha_module_supported) continue;
-      // If we have an activation with more than 1 user in non-training mode,
-      // we cannot rewrite the graph, so skip processing the rest.
- HloInstruction* activation =
- matched_result.need_canonicalization
- ? matched_result.matched_bmm_2->mutable_operand(1)
- : matched_result.matched_bmm_2->mutable_operand(0);
- if (!matched_result.is_training && activation->user_count() > 1) {
- VLOG(2)
- << "Activation: " << activation->ToString()
-            << " cannot have more than 1 user in non-training mode. Skipping.";
- continue;
- }
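-      // Save the original bmm2 operands and the activation's producers so the
-      // forward graph can be restored if the backward match fails.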
- HloInstruction* original_bmm2_producer0 =
- matched_result.matched_bmm_2->mutable_operand(0);
- HloInstruction* original_bmm2_producer1 =
- matched_result.matched_bmm_2->mutable_operand(1);
-
- HloInstruction* original_bmm2 = matched_result.matched_bmm_2;
- std::vector<HloInstruction*> original_activation_producers;
- for (HloInstruction* operand : activation->mutable_operands()) {
- original_activation_producers.push_back(operand);
- }
- // We make sure no attention block is matched and replaced twice here
- if (!matched_bmm1.insert(matched_result.matched_bmm_1).second) {
- continue;
- }
- // If we need to canonicalize the bmm, we will assign the newly
- // canonicalized bmm to bmm_2.
- if (matched_result.need_canonicalization) {
- TF_ASSIGN_OR_RETURN(matched_result.matched_bmm_2,
- CanonicalizeBatchedGemmForcuDNNFMHA(
- matched_result.matched_bmm_2, comp));
- }
-
-      // Fuse the bmms and intermediate nodes into an fMHA call; the fused
-      // call will replace bmm_2.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * fwd_fmha_call,
- FuseFwdMultiHeadedAttentionBlock(
- comp, matched_result.matched_bmm_1, matched_result.matched_bmm_2,
- matched_result.matched_bias, matched_result.matched_scale,
- matched_result.matched_reduce_sum,
- matched_result.matched_softmax_input,
- matched_result.matched_dropout_rate,
- matched_result.matched_custom_call_name, compute_capability_,
- matched_result.is_training, changed, v_transposed,
- matched_result.is_causal_mask));
- any_changed |= changed;
- if (matched_result.is_training) {
- MatchBwdResult matched_bwd_result =
- MatchBwdMHAPatternsForCanonicalization(
- fwd_fmha_call, matched_result.matched_bmm_1, v_transposed);
- if (!matched_bwd_result.has_match) {
- VLOG(2) << "Backward pattern not matching, skipping.";
-          // Restore the fwd graph if the bwd pattern match failed.
- TF_RETURN_IF_ERROR(
- RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
- original_bmm2_producer0, original_bmm2_producer1,
- original_activation_producers,
- matched_result.need_canonicalization));
- continue;
- }
- if (matched_bwd_result.matched_dbias &&
- !(compute_capability_.IsAtLeastHopper() &&
- compute_capability_.minor == 0 &&
- cudnn_version >= stream_executor::dnn::VersionInfo(8, 9, 6))) {
-          VLOG(2) << "Flash attention dbias requires cuDNN 8.9.6 and Hopper.";
-          // Restore the fwd graph since dbias is not supported here.
- TF_RETURN_IF_ERROR(
- RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
- original_bmm2_producer0, original_bmm2_producer1,
- original_activation_producers,
- matched_result.need_canonicalization));
- continue;
- }
- // Canonicalize gemms
- if (matched_bwd_result.bmm_1_grad_1_need_canonicalization) {
- TF_ASSIGN_OR_RETURN(
- matched_bwd_result.matched_bmm_1_grad_1,
- CanonicalizeBatchedGemmForcuDNNFMHA(
- matched_bwd_result.matched_bmm_1_grad_1, comp));
- }
- if (matched_bwd_result.bmm_1_grad_2_need_canonicalization) {
- TF_ASSIGN_OR_RETURN(
- matched_bwd_result.matched_bmm_1_grad_2,
- CanonicalizeBatchedGemmForcuDNNFMHA(
- matched_bwd_result.matched_bmm_1_grad_2, comp));
- }
- if (matched_bwd_result.bmm_2_grad_1_need_canonicalization) {
- TF_ASSIGN_OR_RETURN(
- matched_bwd_result.matched_bmm_2_grad_1,
- CanonicalizeBatchedGemmForcuDNNFMHA(
- matched_bwd_result.matched_bmm_2_grad_1, comp));
- }
- if (matched_bwd_result.bmm_2_grad_2_need_canonicalization) {
- TF_ASSIGN_OR_RETURN(
- matched_bwd_result.matched_bmm_2_grad_2,
- CanonicalizeBatchedGemmForcuDNNFMHA(
- matched_bwd_result.matched_bmm_2_grad_2, comp));
- }
-
-        // Fuse the corresponding gradient graph into a fused fMHA backward
-        // call.
- TF_ASSIGN_OR_RETURN(
- changed,
- FuseBwdMultiHeadedAttentionBlock(
- comp, matched_bwd_result.matched_bmm_1_grad_1,
- matched_bwd_result.matched_bmm_1_grad_2,
- matched_bwd_result.matched_bmm_2_grad_1,
- matched_bwd_result.matched_bmm_2_grad_2, fwd_fmha_call,
- matched_bwd_result.matched_dbias, matched_result.matched_bias,
- matched_bwd_result.matched_custom_call_name));
- any_changed |= changed;
- }
- }
- }
-
- return any_changed;
-}
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h
deleted file mode 100644
index f0aa687..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-class CudnnFusedMHARewriter : public HloModulePass {
- public:
- explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
- se::StreamExecutor* stream_executor)
- : compute_capability_(cc), stream_executor_(stream_executor) {}
-
- explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
- se::dnn::VersionInfo cudnn_version)
- : compute_capability_(cc), cudnn_version_(cudnn_version) {}
-
- absl::string_view name() const override {
- return "cudnn-fused-multi-headed-attention-rewriter";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const se::CudaComputeCapability compute_capability_;
- se::StreamExecutor* stream_executor_ = nullptr;
- const se::dnn::VersionInfo cudnn_version_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
deleted file mode 100644
index 897136b..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
+++ /dev/null
@@ -1,3183 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_mha_rewriter.h"
-
-#include <cstddef>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/strings/string_view.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/computation_layout.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h"
-#include "xla/service/hlo_cse.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/hlo_verifier.h"
-#include "xla/service/layout_normalization.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/reshape_decomposer.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/test_helpers.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
-#endif
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = xla::match;
-
-class CudnnFusedMhaRewriterTestHloTest : public HloTestBase {
- public:
- se::CudaComputeCapability GetCudaComputeCapability() {
-    // Fake a supported compute capability to run tests; we don't run any
-    // kernels in these tests, so they should be safe to run anywhere.
- return se::CudaComputeCapability(8, 0);
- }
-
- se::CudaComputeCapability GetRealCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
-
- se::dnn::VersionInfo GetCudnnVersion() {
-    // Fake a supported cuDNN version to run tests; we don't run any kernels
-    // in these tests, so they should be safe to run anywhere.
- return se::dnn::VersionInfo(8, 9, 4);
- }
-
- CudnnFusedMhaRewriterTestHloTest()
- : HloTestBase(/*verifier_layout_sensitive=*/false,
- /*allow_mixed_precision_in_hlo_verifier=*/false,
- /*instruction_can_change_layout_func=*/{}) {
-#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
- skip_reason_ = "cuDNN fused MHA requires CUDA 12 or later.";
- return;
-#endif
- }
-
- protected:
- size_t CountFusedAttentionCall(HloModule* module, bool is_backward = false) {
- return absl::c_count_if(module->entry_computation()->instructions(),
- [&](const HloInstruction* instr) {
- if (is_backward) {
- return IsBwdCustomCallTofMHA(*instr);
- } else {
- return IsFwdCustomCallTofMHA(*instr);
- }
- });
- }
-
- DebugOptions GetDebugOptionsForTest() override {
- auto debug_options = HloTestBase::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_cudnn_fmha(true);
- debug_options.set_xla_gpu_fused_attention_use_cudnn_rng(true);
- return debug_options;
- }
-
- HloModuleConfig GetModuleConfig() {
- DebugOptions debug_options = GetDebugOptionsForTest();
- HloModuleConfig config_with_fmha;
- config_with_fmha.set_debug_options(debug_options);
- return config_with_fmha;
- }
-
-  // Centralize skip checks in the constructor. Unfortunately we cannot call
-  // GTEST_SKIP from the constructor. Instead, we set `skip_reason_` (if
-  // needed) and then check it from all test fixtures.
-  // An alternative would be to use the SetUp() override, but for this to be
-  // correct we'd have to ensure that all the parents' SetUp() methods are
-  // called, which is error prone.
- std::optional<absl::string_view> skip_reason_;
-};
-
-constexpr absl::string_view
- hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-
-region_0.7 {
- Arg_0.8 = bf16[] parameter(0)
- Arg_1.9 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
- Arg_0.20 = f32[] parameter(0)
- Arg_1.21 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.6 {
- Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
- Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
- Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
- dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
- constant = bf16[] constant(-inf)
- reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
- broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
- subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
- exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
- convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
- constant.1 = f32[] constant(0)
- reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
- convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
- broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
- divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
- ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-})";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1SoftmaxBmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- TF_ASSERT_OK_AND_ASSIGN(
- auto m, ParseAndReturnVerifiedModule(
- hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
- EXPECT_TRUE(result);
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
- .WithShape(BF16, {16, 16, 256, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
- 2);
-}
-
-constexpr absl::string_view
- hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-
-region_0.7 {
- Arg_0.8 = bf16[] parameter(0)
- Arg_1.9 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
- Arg_0.20 = f32[] parameter(0)
- Arg_1.21 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.6 {
- Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
- Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
- Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
- dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
- constant = bf16[] constant(-inf)
- reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
- broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
- subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
- exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
- convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
- constant.1 = f32[] constant(0)
- reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
- convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
- broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
- divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
- ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-})";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1SoftmaxBmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- TF_ASSERT_OK_AND_ASSIGN(
- auto m, ParseAndReturnVerifiedModule(
- hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
- EXPECT_TRUE(result);
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
- .WithShape(BF16, {16, 16, 256, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(config.bmm1_dot_dimension_numbers().lhs_contracting_dimensions()[0],
- 2);
- EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
- 2);
-}
-
-constexpr absl::string_view
- hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-
-region_0.7 {
- Arg_0.8 = bf16[] parameter(0)
- Arg_1.9 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
- Arg_0.20 = f32[] parameter(0)
- Arg_1.21 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.6 {
- Arg_2.3 = bf16[16,16,256,64]{2,3,1,0} parameter(2)
- Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
- Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
- dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
- constant = bf16[] constant(-inf)
- reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
- broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
- subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
- exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
- convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
- constant.1 = f32[] constant(0)
- reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
- convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
- broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
- divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
- ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-})";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1SoftmaxBmm2Pattern_bmm2_non_contracting_dim_not_most_minor) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- TF_ASSERT_OK_AND_ASSIGN(
- auto m, ParseAndReturnVerifiedModule(
- hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
- EXPECT_TRUE(result);
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
- .WithShape(BF16, {16, 16, 256, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(config.bmm2_dot_dimension_numbers().lhs_contracting_dimensions()[0],
- 3);
- EXPECT_EQ(config.bmm2_dot_dimension_numbers().rhs_contracting_dimensions()[0],
- 3);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1CombinedMaskBiasSoftmaxBmm2) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_,
-entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
-
-region_0.32.clone {
- Arg_0.0 = f32[] parameter(0)
- Arg_1.0 = f32[] parameter(1)
- ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
-}
-
-region_1.44 {
- Arg_0.45 = f32[] parameter(0)
- Arg_1.46 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-
-ENTRY main.61 {
- Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
- transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
- Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
- transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
- Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
- transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
- Arg_4.5 = pred[16,1,256,256]{3,2,1,0} parameter(4), sharding={replicated}
- bitcast.35 = pred[16,256,256]{2,1,0} bitcast(Arg_4.5)
- convert.49 = s32[16,256,256]{2,1,0} convert(bitcast.35)
- constant.5 = s32[] constant(0)
- broadcast.10 = s32[16,256,256]{2,1,0} broadcast(constant.5), dimensions={}
- compare = pred[16,256,256]{2,1,0} compare(convert.49, broadcast.10), direction=GT
- constant.7 = bf16[] constant(0)
- broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.7), dimensions={}
- constant.9 = bf16[] constant(-9.999e+09)
- broadcast.13 = bf16[16,256,256]{2,1,0} broadcast(constant.9), dimensions={}
- select = bf16[16,256,256]{2,1,0} select(compare, broadcast.12, broadcast.13)
- convert.51 = f32[16,256,256]{2,1,0} convert(select)
- broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.51), dimensions={0,2,3}
- Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
- bitcast.52 = bf16[16,256,256]{2,1,0} bitcast(Arg_3.4)
- convert.52 = f32[16,256,256]{2,1,0} convert(bitcast.52)
- broadcast.15 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.52), dimensions={1,2,3}
- add.1 = f32[16,16,256,256]{3,2,1,0} add(broadcast.14, broadcast.15)
- dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- convert.55 = f32[16,16,256,256]{3,2,1,0} convert(dot.2)
- add.18 = f32[16,16,256,256]{3,2,1,0} add(convert.55, add.1)
- constant.11 = f32[] constant(-inf)
- reduce.36 = f32[16,16,256]{2,1,0} reduce(add.18, constant.11), dimensions={3}, to_apply=region_0.32.clone
- broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
- subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.18, broadcast.17)
- exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
- constant.14 = f32[] constant(0)
- reduce.48 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.44
- broadcast.18 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.48), dimensions={0,1,2}
- divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.18)
- convert.68 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
- dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.68), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Transpose(
- m::Transpose(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
- 0)))
- .WithShape(BF16, {16, 256, 16, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- EXPECT_EQ(fmha->operands().size(), 4);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1UnfusedSoftmaxBmm2) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->f16[2,6,40,64]{3,2,1,0}}
-
-region_0.7 {
- Arg_0.8 = f16[] parameter(0)
- Arg_1.9 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
- Arg_0.20 = f32[] parameter(0)
- Arg_1.21 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.31 {
- Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
- dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- constant = f16[] constant(-inf)
- reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
- broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
- subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
- exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
- convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
- constant.1 = f32[] constant(0)
- reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
- convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
- broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
- divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
- Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
- ROOT dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
- .WithShape(F16, {2, 6, 40, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_FLOAT_EQ(config.fmha_scale(), 1.0);
- EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0);
- EXPECT_EQ(fmha->operands().size(), 3);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1ConvertedMaskAddedAfterFirstGemmSoftmaxBmm2) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
-
-region_0.27.clone {
- Arg_0.0 = f32[] parameter(0)
- Arg_1.0 = f32[] parameter(1)
- ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
-}
-
-region_1.39 {
- Arg_0.40 = f32[] parameter(0)
- Arg_1.41 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.40, Arg_1.41)
-}
-
-ENTRY main.56 {
- Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
- transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
- Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
- transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
- Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
- transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
- dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- convert.47 = f32[16,16,256,256]{3,2,1,0} convert(dot)
- Arg_3.4 = pred[16,1,256,256]{3,2,1,0} parameter(3), sharding={replicated}
- bitcast.37 = pred[16,256,256]{2,1,0} bitcast(Arg_3.4)
- convert.42 = s32[16,256,256]{2,1,0} convert(bitcast.37)
- constant.6 = s32[] constant(0)
- broadcast.9 = s32[16,256,256]{2,1,0} broadcast(constant.6), dimensions={}
- compare = pred[16,256,256]{2,1,0} compare(convert.42, broadcast.9), direction=GT
- constant.8 = bf16[] constant(0)
- broadcast.11 = bf16[16,256,256]{2,1,0} broadcast(constant.8), dimensions={}
- constant.10 = bf16[] constant(-9.999e+09)
- broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.10), dimensions={}
- select = bf16[16,256,256]{2,1,0} select(compare, broadcast.11, broadcast.12)
- convert.48 = f32[16,256,256]{2,1,0} convert(select)
- broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.48), dimensions={0,2,3}
- add.2 = f32[16,16,256,256]{3,2,1,0} add(convert.47, broadcast.14)
- constant.13 = f32[] constant(-inf)
- reduce.31 = f32[16,16,256]{2,1,0} reduce(add.2, constant.13), dimensions={3}, to_apply=region_0.27.clone
- broadcast.16 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.31), dimensions={0,1,2}
- subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.2, broadcast.16)
- exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
- constant.14 = f32[] constant(0)
- reduce.43 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.39
- broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.43), dimensions={0,1,2}
- divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.17)
- convert.63 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
- dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.63), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Transpose(
- m::Transpose(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
- 0)))
- .WithShape(BF16, {16, 256, 16, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- EXPECT_EQ(fmha->operands().size(), 4);
-}
-
-// Negative test: the rewriter should leave patterns with unsupported dimensions unchanged.
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1Bmm2Pattern_bmm1_contracting_dim_not_equal_64) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-ENTRY main.6 {
- Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
- Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
- Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
- dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
- ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(&fmha, m::Dot(m::Parameter(0), m::Parameter(1)),
- m::Parameter(2))
- .WithShape(BF16, {16, 16, 256, 64})));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1Bmm2Pattern_bmm2_rhs_non_contracting_dim_not_equal_64) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0})->bf16[16,16,256,32]{3,2,1,0}}
-ENTRY main.6 {
- Arg_2.3 = bf16[16,16,256,32]{3,2,1,0} parameter(2)
- Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
- Arg_1.2 = bf16[16,16,256,64]{3,2,1,0} parameter(1)
- dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
- ROOT dot.1 = bf16[16,16,256,32]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(&fmha, m::Op(), m::Parameter(2))
- .WithShape(BF16, {16, 16, 256, 32})));
-}
-
-// Check that canonicalization does not kick in if the MHA pattern is unsupported.
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1Bmm2PatternUncanonicalized_bmm1_contracting_dim_not_equal_64) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,64,256]{3,2,1,0}}
-
-ENTRY main.6 {
- Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
- Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
- Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
- dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
- ROOT dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(Arg_2.3, dot.0), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
-
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(&fmha, m::Parameter(2), m::Op())
- .WithShape(BF16, {16, 16, 64, 256})));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxDropoutBmm2) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
-
-region_0.34 {
- Arg_0.35 = bf16[] parameter(0)
- Arg_1.36 = bf16[] parameter(1)
- ROOT maximum.37 = bf16[] maximum(Arg_0.35, Arg_1.36)
-}
-
-region_1.46 {
- Arg_0.47 = f32[] parameter(0)
- Arg_1.48 = f32[] parameter(1)
- ROOT add.49 = f32[] add(Arg_0.47, Arg_1.48)
-}
-
-ENTRY main.82 {
- Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
- copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
- transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
- Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
- copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
- transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
- Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
- copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
- transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
- dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
- reshape.31 = bf16[16,256,256]{2,1,0} reshape(Arg_3.4)
- broadcast.32 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.31), dimensions={1,2,3}
- add.33 = bf16[16,16,256,256]{3,2,1,0} add(dot, broadcast.32)
- constant.21 = bf16[] constant(-inf)
- reduce.38 = bf16[16,16,256]{2,1,0} reduce(add.33, constant.21), dimensions={3}, to_apply=region_0.34
- broadcast.42 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.38), dimensions={0,1,2}
- subtract.43 = bf16[16,16,256,256]{3,2,1,0} subtract(add.33, broadcast.42)
- exponential.44 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.43)
- convert.45 = f32[16,16,256,256]{3,2,1,0} convert(exponential.44)
- constant.9 = f32[] constant(0)
- reduce.50 = f32[16,16,256]{2,1,0} reduce(convert.45, constant.9), dimensions={3}, to_apply=region_1.46
- convert.1 = bf16[16,16,256]{2,1,0} convert(reduce.50)
- broadcast.55 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.1), dimensions={0,1,2}
- divide.56 = bf16[16,16,256,256]{3,2,1,0} divide(exponential.44, broadcast.55)
- constant.18 = u32[1]{0} constant({255383827})
- constant.17 = u32[1]{0} constant({267815257})
- constant.2 = u32[1]{0} constant({0})
- constant.19 = u32[1]{0} constant({3213575472})
- custom-call.26 = (u32[1]{0}, u32[1]{0}) custom-call(constant.18, constant.17, constant.2, constant.19), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
- get-tuple-element.27 = u32[1]{0} get-tuple-element(custom-call.26), index=0
- reshape.58 = u32[] reshape(get-tuple-element.27)
- broadcast.62 = u32[32768]{0} broadcast(reshape.58), dimensions={}
- get-tuple-element.28 = u32[1]{0} get-tuple-element(custom-call.26), index=1
- reshape.59 = u32[] reshape(get-tuple-element.28)
- broadcast.63 = u32[32768]{0} broadcast(reshape.59), dimensions={}
- iota.57 = u32[65536]{0} iota(), iota_dimension=0
- slice.60 = u32[32768]{0} slice(iota.57), slice={[0:32768]}
- slice.61 = u32[32768]{0} slice(iota.57), slice={[32768:65536]}
- custom-call.64 = (u32[32768]{0}, u32[32768]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[32768]{0}, u32[32768]{0}, u32[32768]{0}, u32[32768]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\000\000\000\000\000\000"
- get-tuple-element.65 = u32[32768]{0} get-tuple-element(custom-call.64), index=0
- get-tuple-element.66 = u32[32768]{0} get-tuple-element(custom-call.64), index=1
- concatenate.67 = u32[65536]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
- constant.15 = u32[] constant(9)
- broadcast.3 = u32[65536]{0} broadcast(constant.15), dimensions={}
- shift-right-logical.0 = u32[65536]{0} shift-right-logical(concatenate.67, broadcast.3)
- constant.13 = u32[] constant(1065353216)
- broadcast.11 = u32[65536]{0} broadcast(constant.13), dimensions={}
- or.0 = u32[65536]{0} or(shift-right-logical.0, broadcast.11)
- bitcast-convert.0 = f32[65536]{0} bitcast-convert(or.0)
- constant.3 = f32[] constant(-1)
- broadcast.17 = f32[65536]{0} broadcast(constant.3), dimensions={}
- add.1 = f32[65536]{0} add(bitcast-convert.0, broadcast.17)
- broadcast.18 = f32[65536]{0} broadcast(constant.9), dimensions={}
- maximum.0 = f32[65536]{0} maximum(add.1, broadcast.18)
- constant.7 = f32[] constant(0.9)
- broadcast.19 = f32[65536]{0} broadcast(constant.7), dimensions={}
- compare.0 = pred[65536]{0} compare(maximum.0, broadcast.19), direction=LT
- constant = bf16[] constant(1.109)
- broadcast.20 = bf16[65536]{0} broadcast(constant), dimensions={}
- constant.4 = bf16[] constant(0)
- broadcast.21 = bf16[65536]{0} broadcast(constant.4), dimensions={}
- select.1 = bf16[65536]{0} select(compare.0, broadcast.20, broadcast.21)
- reshape.19 = bf16[16,16,256]{2,1,0} reshape(select.1)
- broadcast.9 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.19), dimensions={0,1,3}
- multiply.79 = bf16[16,16,256,256]{3,2,1,0} multiply(divide.56, broadcast.9)
- dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, multiply.79), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- transpose.81 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
- ROOT copy.3 = bf16[16,256,16,64]{3,2,1,0} copy(transpose.81)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
- m::CustomCall(
- &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
- 0))))
- .WithShape(BF16, {16, 256, 16, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(fmha->operands().size(), 4);
- EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16Bmm1ScaleBiasSoftmaxDropoutForm2Bmm2) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0})->bf16[32,40,60,64]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
-
-region_0.29 {
- Arg_0.30 = bf16[] parameter(0)
- Arg_1.31 = bf16[] parameter(1)
- ROOT maximum.32 = bf16[] maximum(Arg_0.30, Arg_1.31)
-}
-
-region_1.41 {
- Arg_0.42 = f32[] parameter(0)
- Arg_1.43 = f32[] parameter(1)
- ROOT add.44 = f32[] add(Arg_0.42, Arg_1.43)
-}
-
-ENTRY main.79 {
- Arg_2.3 = bf16[32,40,60,64]{3,2,1,0} parameter(2), sharding={replicated}
- copy = bf16[32,40,60,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
- transpose.2 = bf16[32,60,64,40]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
- constant.19 = u32[1]{0} constant({2718843009})
- constant.18 = u32[1]{0} constant({1272950319})
- constant.2 = u32[1]{0} constant({0})
- constant.20 = u32[1]{0} constant({2711844646})
- custom-call.54 = (u32[1]{0}, u32[1]{0}) custom-call(constant.19, constant.18, constant.2, constant.20), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
- get-tuple-element.55 = u32[1]{0} get-tuple-element(custom-call.54), index=0
- reshape.58 = u32[] reshape(get-tuple-element.55)
- broadcast.62 = u32[1536000]{0} broadcast(reshape.58), dimensions={}
- get-tuple-element.56 = u32[1]{0} get-tuple-element(custom-call.54), index=1
- reshape.59 = u32[] reshape(get-tuple-element.56)
- broadcast.63 = u32[1536000]{0} broadcast(reshape.59), dimensions={}
- iota.57 = u32[3072000]{0} iota(), iota_dimension=0
- slice.60 = u32[1536000]{0} slice(iota.57), slice={[0:1536000]}
- slice.61 = u32[1536000]{0} slice(iota.57), slice={[1536000:3072000]}
- custom-call.64 = (u32[1536000]{0}, u32[1536000]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000p\027\000\000\000\000\000"
- get-tuple-element.65 = u32[1536000]{0} get-tuple-element(custom-call.64), index=0
- get-tuple-element.66 = u32[1536000]{0} get-tuple-element(custom-call.64), index=1
- concatenate.67 = u32[3072000]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
- constant.16 = u32[] constant(9)
- broadcast.2 = u32[3072000]{0} broadcast(constant.16), dimensions={}
- shift-right-logical.0 = u32[3072000]{0} shift-right-logical(concatenate.67, broadcast.2)
- constant.14 = u32[] constant(1065353216)
- broadcast.6 = u32[3072000]{0} broadcast(constant.14), dimensions={}
- or.0 = u32[3072000]{0} or(shift-right-logical.0, broadcast.6)
- bitcast-convert.0 = f32[3072000]{0} bitcast-convert(or.0)
- constant.3 = f32[] constant(-1)
- broadcast.8 = f32[3072000]{0} broadcast(constant.3), dimensions={}
- add.1 = f32[3072000]{0} add(bitcast-convert.0, broadcast.8)
- constant.10 = f32[] constant(0)
- broadcast.10 = f32[3072000]{0} broadcast(constant.10), dimensions={}
- maximum.0 = f32[3072000]{0} maximum(add.1, broadcast.10)
- constant.8 = f32[] constant(0.9)
- broadcast.12 = f32[3072000]{0} broadcast(constant.8), dimensions={}
- compare.0 = pred[3072000]{0} compare(maximum.0, broadcast.12), direction=LT
- reshape.18 = pred[32,60,40,40]{3,2,1,0} reshape(compare.0)
- Arg_0.1 = bf16[32,40,60,64]{3,2,1,0} parameter(0), sharding={replicated}
- copy.1 = bf16[32,40,60,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
- transpose = bf16[32,60,40,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
- Arg_1.2 = bf16[32,40,60,64]{3,2,1,0} parameter(1), sharding={replicated}
- copy.2 = bf16[32,40,60,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
- transpose.1 = bf16[32,60,64,40]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
- dot = bf16[32,60,40,40]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.25 = bf16[] constant(1)
- broadcast.26 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.25), dimensions={}
- add.28 = bf16[32,60,40,40]{3,2,1,0} add(dot, broadcast.26)
- constant.24 = bf16[] constant(-inf)
- reduce.33 = bf16[32,60,40]{2,1,0} reduce(add.28, constant.24), dimensions={3}, to_apply=region_0.29
- broadcast.37 = bf16[32,60,40,40]{3,2,1,0} broadcast(reduce.33), dimensions={0,1,2}
- subtract.38 = bf16[32,60,40,40]{3,2,1,0} subtract(add.28, broadcast.37)
- exponential.39 = bf16[32,60,40,40]{3,2,1,0} exponential(subtract.38)
- convert.40 = f32[32,60,40,40]{3,2,1,0} convert(exponential.39)
- reduce.45 = f32[32,60,40]{2,1,0} reduce(convert.40, constant.10), dimensions={3}, to_apply=region_1.41
- convert.0 = bf16[32,60,40]{2,1,0} convert(reduce.45)
- broadcast.50 = bf16[32,60,40,40]{3,2,1,0} broadcast(convert.0), dimensions={0,1,2}
- divide.51 = bf16[32,60,40,40]{3,2,1,0} divide(exponential.39, broadcast.50)
- constant = bf16[] constant(1.109)
- broadcast.1 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant), dimensions={}
- multiply = bf16[32,60,40,40]{3,2,1,0} multiply(divide.51, broadcast.1)
- constant.4 = bf16[] constant(0)
- broadcast.5 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.4), dimensions={}
- select.76 = bf16[32,60,40,40]{3,2,1,0} select(reshape.18, multiply, broadcast.5)
- dot.1 = bf16[32,60,64,40]{3,2,1,0} dot(transpose.2, select.76), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- transpose.78 = bf16[32,40,60,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
- ROOT copy.3 = bf16[32,40,60,64]{3,2,1,0} copy(transpose.78)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
- m::CustomCall(
- &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
- 0))))
- .WithShape(BF16, {32, 40, 60, 64})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
- EXPECT_EQ(fmha->operands().size(), 4);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1Bmm2) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0})}
-
-ENTRY main.17 {
- Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
- copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
- transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
- Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
- copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
- transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
- Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
- copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
- transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
- dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- transpose.7 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
- Arg_3.4 = bf16[16,256,16,64]{3,2,1,0} parameter(3), sharding={replicated}
- copy.3 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_3.4), sharding={replicated}
- transpose.4 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.3), dimensions={0,2,1,3}
- dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.4, transpose.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- copy.4 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_1.2), sharding={replicated}
- transpose.12 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.4), dimensions={0,2,1,3}
- dot.4 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.15 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.4), dimensions={0,2,1,3}
- dot.3 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.13 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.3), dimensions={0,2,1,3}
- copy.5 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_3.4), sharding={replicated}
- transpose.8 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.5), dimensions={0,2,3,1}
- dot.10 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.8, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.11 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.10), dimensions={0,3,1,2}
- tuple.16 = (bf16[16,256,16,64]{1,3,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{1,3,2,0}) tuple(transpose.7, transpose.15, transpose.13, transpose.11)
- get-tuple-element = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=0
- copy.6 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element)
- get-tuple-element.1 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=1
- copy.7 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.1)
- get-tuple-element.2 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=2
- copy.8 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.2)
- get-tuple-element.3 = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=3
- copy.9 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.3)
- ROOT tuple = (bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}) tuple(copy.6, copy.7, copy.8, copy.9)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- const auto status = RunHloPass(&fusedMhaRewriter, m.get());
- const bool changed = status.value();
- EXPECT_EQ(changed, false);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16MiniT5xTest) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__lambda_, entry_computation_layout={(bf16[12,512,32,64]{3,2,1,0},bf16[12,512,2,32,64]{4,3,2,1,0},f32[12,512]{1,0},f32[12,512]{1,0})->(bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true}
-
-region_0.51 {
- Arg_0.52 = bf16[] parameter(0)
- Arg_1.53 = bf16[] parameter(1)
- ROOT maximum.54 = bf16[] maximum(Arg_0.52, Arg_1.53)
-}
-
-region_1.63 {
- Arg_0.64 = f32[] parameter(0)
- Arg_1.65 = f32[] parameter(1)
- ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
-}
-
-region_3.99 {
- Arg_0.100 = bf16[] parameter(0)
- Arg_1.101 = bf16[] parameter(1)
- ROOT add.102 = bf16[] add(Arg_0.100, Arg_1.101)
-}
-
-ENTRY main.129 {
- Arg_1.2 = bf16[12,512,2,32,64]{4,3,2,1,0} parameter(1), sharding={replicated}
- copy = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
- slice.42 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy), slice={[0:12], [0:512], [1:2], [0:32], [0:64]}
- reshape.44 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.42)
- transpose.5 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.44), dimensions={0,2,3,1}
- Arg_0.1 = bf16[12,512,32,64]{3,2,1,0} parameter(0), sharding={replicated}
- copy.1 = bf16[12,512,32,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
- constant.5 = bf16[] constant(0.125)
- broadcast.6 = bf16[12,512,32,64]{3,1,2,0} broadcast(constant.5), dimensions={}
- multiply.45 = bf16[12,512,32,64]{3,1,2,0} multiply(copy.1, broadcast.6)
- transpose = bf16[12,32,512,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
- copy.2 = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
- slice.41 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy.2), slice={[0:12], [0:512], [0:1], [0:32], [0:64]}
- reshape.43 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.41)
- transpose.1 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.43), dimensions={0,2,3,1}
- dot = bf16[12,32,512,512]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_2.3 = f32[12,512]{1,0} parameter(2), sharding={replicated}
- constant.14 = f32[] constant(0)
- broadcast.19 = f32[12,512]{1,0} broadcast(constant.14), dimensions={}
- compare.24 = pred[12,512]{1,0} compare(Arg_2.3, broadcast.19), direction=GT
- broadcast.30 = pred[12,512,512]{2,1,0} broadcast(compare.24), dimensions={0,1}
- Arg_3.4 = f32[12,512]{1,0} parameter(3), sharding={replicated}
- compare.25 = pred[12,512]{1,0} compare(Arg_3.4, broadcast.19), direction=GT
- broadcast.33 = pred[12,512,512]{2,1,0} broadcast(compare.25), dimensions={0,2}
- and.34 = pred[12,512,512]{2,1,0} and(broadcast.30, broadcast.33)
- convert.4 = s32[12,512,512]{2,1,0} convert(and.34)
- constant.16 = s32[] constant(0)
- broadcast.21 = s32[12,512,512]{2,1,0} broadcast(constant.16), dimensions={}
- compare.0 = pred[12,512,512]{2,1,0} compare(convert.4, broadcast.21), direction=GT
- constant.20 = bf16[] constant(0)
- broadcast.22 = bf16[12,512,512]{2,1,0} broadcast(constant.20), dimensions={}
- constant.11 = bf16[] constant(-9.999e+09)
- broadcast.23 = bf16[12,512,512]{2,1,0} broadcast(constant.11), dimensions={}
- select.0 = bf16[12,512,512]{2,1,0} select(compare.0, broadcast.22, broadcast.23)
- broadcast.49 = bf16[12,32,512,512]{3,2,1,0} broadcast(select.0), dimensions={0,2,3}
- add.50 = bf16[12,32,512,512]{3,2,1,0} add(dot, broadcast.49)
- constant.22 = bf16[] constant(-inf)
- reduce.55 = bf16[12,32,512]{2,1,0} reduce(add.50, constant.22), dimensions={3}, to_apply=region_0.51
- broadcast.59 = bf16[12,32,512,512]{3,2,1,0} broadcast(reduce.55), dimensions={0,1,2}
- subtract.60 = bf16[12,32,512,512]{3,2,1,0} subtract(add.50, broadcast.59)
- exponential.61 = bf16[12,32,512,512]{3,2,1,0} exponential(subtract.60)
- convert.62 = f32[12,32,512,512]{3,2,1,0} convert(exponential.61)
- reduce.67 = f32[12,32,512]{2,1,0} reduce(convert.62, constant.14), dimensions={3}, to_apply=region_1.63
- convert.5 = bf16[12,32,512]{2,1,0} convert(reduce.67)
- broadcast.72 = bf16[12,32,512,512]{3,2,1,0} broadcast(convert.5), dimensions={0,1,2}
- divide.73 = bf16[12,32,512,512]{3,2,1,0} divide(exponential.61, broadcast.72)
- dot.1 = bf16[12,32,64,512]{3,2,1,0} dot(transpose.5, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- convert.6 = f32[12,32,64,512]{3,2,1,0} convert(dot.1)
- reduce.83 = f32[] reduce(convert.6, constant.14), dimensions={0,3,1,2}, to_apply=region_1.63
- convert.84 = bf16[] convert(reduce.83)
- constant.2 = bf16[] constant(0.0007935)
- multiply.86 = bf16[] multiply(convert.84, constant.2)
- broadcast.9 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.2), dimensions={}
- dot.2 = bf16[12,32,512,512]{3,2,1,0} dot(broadcast.9, transpose.5), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- divide.109 = bf16[12,32,512,512]{3,2,1,0} divide(dot.2, broadcast.72)
- constant.10 = bf16[] constant(1)
- broadcast.24 = bf16[12,32,512]{2,1,0} broadcast(constant.10), dimensions={}
- multiply.4 = bf16[12,32,512]{2,1,0} multiply(convert.5, convert.5)
- divide.0 = bf16[12,32,512]{2,1,0} divide(broadcast.24, multiply.4)
- broadcast.96 = bf16[12,32,512,512]{3,2,1,0} broadcast(divide.0), dimensions={0,1,2}
- multiply.97 = bf16[12,32,512,512]{3,2,1,0} multiply(dot.2, broadcast.96)
- multiply.98 = bf16[12,32,512,512]{3,2,1,0} multiply(multiply.97, exponential.61)
- reduce.103 = bf16[12,32,512]{2,1,0} reduce(multiply.98, constant.20), dimensions={3}, to_apply=region_3.99
- negate.0 = bf16[12,32,512]{2,1,0} negate(reduce.103)
- broadcast.10 = bf16[12,32,512,512]{3,2,1,0} broadcast(negate.0), dimensions={0,1,2}
- add.118 = bf16[12,32,512,512]{3,2,1,0} add(divide.109, broadcast.10)
- multiply.119 = bf16[12,32,512,512]{3,2,1,0} multiply(add.118, exponential.61)
- transpose.9 = bf16[12,32,512,64]{2,3,1,0} transpose(reshape.43), dimensions={0,2,1,3}
- copy.3 = bf16[12,32,512,64]{3,2,1,0} copy(transpose.9)
- dot.4 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, copy.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- broadcast.12 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.5), dimensions={}
- multiply.3 = bf16[12,32,512,64]{3,2,1,0} multiply(dot.4, broadcast.12)
- transpose.11 = bf16[12,512,32,64]{3,1,2,0} transpose(multiply.3), dimensions={0,2,1,3}
- broadcast.7 = bf16[12,32,64,512]{3,2,1,0} broadcast(constant.2), dimensions={}
- dot.90 = bf16[12,32,64,512]{3,2,1,0} dot(broadcast.7, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.91 = bf16[12,512,32,64]{1,3,2,0} transpose(dot.90), dimensions={0,3,1,2}
- reshape.92 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.91)
- pad.93 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.92, constant.20), padding=0_0x0_0x1_0x0_0x0_0
- dot.3 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- copy.4 = bf16[12,32,512,64]{2,3,1,0} copy(dot.3)
- transpose.121 = bf16[12,512,32,64]{1,3,2,0} transpose(copy.4), dimensions={0,2,1,3}
- reshape.124 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.121)
- pad.125 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.124, constant.20), padding=0_0x0_0x0_1x0_0x0_0
- add.126 = bf16[12,512,2,32,64]{1,4,3,0,2} add(pad.93, pad.125)
- tuple.128 = (bf16[], bf16[12,512,32,64]{3,1,2,0}, bf16[12,512,2,32,64]{1,4,3,0,2}) tuple(multiply.86, transpose.11, add.126)
- get-tuple-element = bf16[] get-tuple-element(tuple.128), index=0
- get-tuple-element.1 = bf16[12,512,32,64]{3,1,2,0} get-tuple-element(tuple.128), index=1
- copy.5 = bf16[12,512,32,64]{3,2,1,0} copy(get-tuple-element.1)
- get-tuple-element.2 = bf16[12,512,2,32,64]{1,4,3,0,2} get-tuple-element(tuple.128), index=2
- copy.6 = bf16[12,512,2,32,64]{4,3,2,1,0} copy(get-tuple-element.2)
- ROOT tuple = (bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0}) tuple(get-tuple-element, copy.5, copy.6)
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- AlgebraicSimplifierOptions alg_sim_options;
- alg_sim_options.set_supports_non_canonical_dots(false);
- alg_sim_options.set_is_layout_sensitive(true);
- alg_sim_options.set_enable_conv_operand_swap(false);
- AlgebraicSimplifier alge_simp{alg_sim_options};
- ReshapeDecomposer reshape_decomposer;
- LayoutNormalization layout_normalizer;
- HloCSE cse{/*is_layout_sensitive=*/true};
- TF_ASSERT_OK(RunHloPass(&reshape_decomposer, m.get()).status());
- TF_ASSERT_OK(RunHloPass(&layout_normalizer, m.get()).status());
- TF_ASSERT_OK(RunHloPass(&cse, m.get()).status());
- TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
-
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
- CudnnFusedMHATransposeFusion fmha_transpose_fusion;
-
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
- TF_ASSERT_OK(RunHloPass(&fmha_transpose_fusion, m.get()).status());
-
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- EXPECT_EQ(CountFusedAttentionCall(m.get()), 1);
- EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 1);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- ActivationHasMoreThan1UserShouldNotLower) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule test
-
-%region_50.2457 (Arg_0.2458: bf16[], Arg_1.2459: bf16[]) -> bf16[] {
- %Arg_0.2458 = bf16[] parameter(0)
- %Arg_1.2459 = bf16[] parameter(1)
- ROOT %maximum.2 = bf16[] maximum(bf16[] %Arg_0.2458, bf16[] %Arg_1.2459)
-}
-
-%region_36.2316 (Arg_0.2317: f32[], Arg_1.2318: f32[]) -> f32[] {
- %Arg_0.2317 = f32[] parameter(0)
- %Arg_1.2318 = f32[] parameter(1)
- ROOT %add.342 = f32[] add(f32[] %Arg_0.2317, f32[] %Arg_1.2318)
-}
-
-ENTRY main {
- %transpose.482 = bf16[4,5,64]{2,1,0} parameter(0)
- %transpose.484 = bf16[4,64,5]{2,1,0} parameter(1)
- %dot.20 = bf16[4,5,5]{2,1,0} dot(bf16[4,5,64]{2,1,0} %transpose.482, bf16[4,64,5]{2,1,0} %transpose.484), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
- %constant.2515 = bf16[] constant(0.125)
- %broadcast.789 = bf16[4,5,5]{2,1,0} broadcast(bf16[] %constant.2515), dimensions={}
- %multiply.267 = bf16[4,5,5]{2,1,0} multiply(bf16[4,5,5]{2,1,0} %dot.20, bf16[4,5,5]{2,1,0} %broadcast.789)
- %constant.287 = f32[] constant(-1)
- %broadcast.792 = bf16[4,5,5]{2,1,0} parameter(3)
- %add.348 = bf16[4,5,5]{2,1,0} add(bf16[4,5,5]{2,1,0} %multiply.267, bf16[4,5,5]{2,1,0} %broadcast.792)
- %constant.2510 = bf16[] constant(-inf)
- %reduce.2550 = bf16[4,5]{1,0} reduce(bf16[4,5,5]{2,1,0} %add.348, bf16[] %constant.2510), dimensions={2}, to_apply=%region_50.2457
- %broadcast.793 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %reduce.2550), dimensions={0,1}
- %subtract.81 = bf16[4,5,5]{2,1,0} subtract(bf16[4,5,5]{2,1,0} %add.348, bf16[4,5,5]{2,1,0} %broadcast.793)
- %exponential.21 = bf16[4,5,5]{2,1,0} exponential(bf16[4,5,5]{2,1,0} %subtract.81)
- %convert.180 = f32[4,5,5]{2,1,0} convert(bf16[4,5,5]{2,1,0} %exponential.21)
- %constant.2509 = f32[] constant(0)
- %reduce.2558 = f32[4,5]{1,0} reduce(f32[4,5,5]{2,1,0} %convert.180, f32[] %constant.2509), dimensions={2}, to_apply=%region_36.2316
- %convert.182 = bf16[4,5]{1,0} convert(f32[4,5]{1,0} %reduce.2558)
- %broadcast.794 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %convert.182), dimensions={0,1}
- %divide.25 = bf16[4,5,5]{2,1,0} divide(bf16[4,5,5]{2,1,0} %exponential.21, bf16[4,5,5]{2,1,0} %broadcast.794)
- %transpose.481 = bf16[4,64,5]{2,1,0} parameter(2)
- %dot.21 = bf16[4,64,5]{2,1,0} dot(bf16[4,64,5]{2,1,0} %transpose.481, bf16[4,5,5]{2,1,0} %divide.25), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
- ROOT %tuple.2668 = (bf16[4,5,5]{2,1,0}, bf16[4,64,5]{2,1,0}) tuple(bf16[4,5,5]{2,1,0} %divide.25, bf16[4,64,5]{2,1,0} %dot.21)
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- HloVerifier verifier(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/true);
- ASSERT_IS_OK(verifier.Run(m.get()).status());
-
- EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxBmm2ShouldNotBeLowered) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
- Arg_0.22 = f16[] parameter(0)
- Arg_1.23 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
- Arg_0.34 = f32[] parameter(0)
- Arg_1.35 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
- Arg_0.56 = f16[] parameter(0)
- Arg_1.57 = f16[] parameter(1)
- ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
- constant.18 = pred[2,6,128,128]{3,2,1,0} constant({...})
- Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
- dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.22 = f16[] constant(2)
- broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
- multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
- constant.19 = f16[] constant(1)
- broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
- add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
- constant.21 = f16[] constant(0)
- broadcast.23 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.21), dimensions={}
- select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.18, add.3, broadcast.23)
- constant.15 = f16[] constant(-inf)
- reduce.25 = f16[2,6,128]{2,1,0} reduce(select.1, constant.15), dimensions={3}, to_apply=region_0.21
- broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
- subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.17)
- exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
- constant.17 = f32[] constant(0)
- reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
- convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
- broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
- Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
- dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
- broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
- multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
- broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
- multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
- broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.59), dimensions={0,1,2}
- add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
- multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
- select.3 = f16[2,6,128,128]{3,2,1,0} select(constant.18, multiply.8, broadcast.23)
- multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(select.3, broadcast.24)
- dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- HloVerifier verifier(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/true);
- ASSERT_IS_OK(verifier.Run(m.get()).status());
-
- // The backward pattern in the graph is not a valid fmha pattern, so we
- // expect no rewrite to happen.
- EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
- EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2ShouldNotLower) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.38 {
- Arg_0.39 = f16[] parameter(0)
- Arg_1.40 = f16[] parameter(1)
- ROOT maximum.1 = f16[] maximum(Arg_0.39, Arg_1.40)
-}
-
-region_1.50 {
- Arg_0.51 = f32[] parameter(0)
- Arg_1.52 = f32[] parameter(1)
- ROOT add.2 = f32[] add(Arg_0.51, Arg_1.52)
-}
-
-region_2.99 {
- Arg_0.100 = f16[] parameter(0)
- Arg_1.101 = f16[] parameter(1)
- ROOT add.3 = f16[] add(Arg_0.100, Arg_1.101)
-}
-
-ENTRY main.126 {
- constant.6 = u32[1]{0} constant({2718843009})
- constant.8 = u32[1]{0} constant({1272950319})
- constant.10 = u32[1]{0} constant({0})
- constant.12 = u32[1]{0} constant({2711844646})
- custom-call.65 = (u32[1]{0}, u32[1]{0}) custom-call(constant.6, constant.8, constant.10, constant.12), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
- get-tuple-element.66 = u32[1]{0} get-tuple-element(custom-call.65), index=0
- bitcast.343 = u32[] bitcast(get-tuple-element.66)
- broadcast.27 = u32[98304]{0} broadcast(bitcast.343), dimensions={}
- get-tuple-element.67 = u32[1]{0} get-tuple-element(custom-call.65), index=1
- bitcast.344 = u32[] bitcast(get-tuple-element.67)
- broadcast.28 = u32[98304]{0} broadcast(bitcast.344), dimensions={}
- iota.68 = u32[196608]{0} iota(), iota_dimension=0
- slice = u32[98304]{0} slice(iota.68), slice={[0:98304]}
- slice.1 = u32[98304]{0} slice(iota.68), slice={[98304:196608]}
- custom-call.75 = (u32[98304]{0}, u32[98304]{0}) custom-call(broadcast.27, broadcast.28, slice, slice.1), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[98304]{0}, u32[98304]{0}, u32[98304]{0}, u32[98304]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\001\000\000\000\000\000"
- get-tuple-element.76 = u32[98304]{0} get-tuple-element(custom-call.75), index=0
- get-tuple-element.77 = u32[98304]{0} get-tuple-element(custom-call.75), index=1
- concatenate.2 = u32[196608]{0} concatenate(get-tuple-element.76, get-tuple-element.77), dimensions={0}
- constant.56 = u32[] constant(9)
- broadcast.63 = u32[196608]{0} broadcast(constant.56), dimensions={}
- shift-right-logical.3 = u32[196608]{0} shift-right-logical(concatenate.2, broadcast.63)
- constant.57 = u32[] constant(1065353216)
- broadcast.64 = u32[196608]{0} broadcast(constant.57), dimensions={}
- or.3 = u32[196608]{0} or(shift-right-logical.3, broadcast.64)
- bitcast-convert.3 = f32[196608]{0} bitcast-convert(or.3)
- constant.58 = f32[] constant(-1)
- broadcast.65 = f32[196608]{0} broadcast(constant.58), dimensions={}
- add.10 = f32[196608]{0} add(bitcast-convert.3, broadcast.65)
- constant.48 = f32[] constant(0)
- broadcast.66 = f32[196608]{0} broadcast(constant.48), dimensions={}
- maximum.4 = f32[196608]{0} maximum(add.10, broadcast.66)
- constant.59 = f32[] constant(0.9)
- broadcast.67 = f32[196608]{0} broadcast(constant.59), dimensions={}
- compare.3 = pred[196608]{0} compare(maximum.4, broadcast.67), direction=LT
- bitcast.308 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
- constant.44 = pred[2,6,128,128]{3,2,1,0} constant({...})
- Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
- dot.34 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.55 = f16[] constant(2)
- broadcast.61 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.55), dimensions={}
- multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(dot.34, broadcast.61)
- constant.52 = f16[] constant(1)
- broadcast.39 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.52), dimensions={}
- add.6 = f16[2,6,128,128]{3,2,1,0} add(multiply.8, broadcast.39)
- constant.54 = f16[] constant(0)
- broadcast.52 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.54), dimensions={}
- select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.44, add.6, broadcast.52)
- constant.41 = f16[] constant(-inf)
- reduce.42 = f16[2,6,128]{2,1,0} reduce(select.1, constant.41), dimensions={3}, to_apply=region_0.38
- broadcast.42 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.42), dimensions={0,1,2}
- subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.42)
- exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
- reduce.54 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.48), dimensions={3}, to_apply=region_1.50
- convert.9 = f16[2,6,128]{2,1,0} convert(reduce.54)
- broadcast.68 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.68)
- constant.60 = f16[] constant(1.1113)
- broadcast.69 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.60), dimensions={}
- multiply.20 = f16[2,6,128,128]{3,2,1,0} multiply(divide.5, broadcast.69)
- select.8 = f16[2,6,128,128]{3,2,1,0} select(bitcast.308, multiply.20, broadcast.52)
- Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.88 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- bitcast.248 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
- Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
- dot.91 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- select.6 = f16[2,6,128,128]{3,2,1,0} select(bitcast.248, dot.91, broadcast.52)
- multiply.17 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.69)
- divide.4 = f16[2,6,128,128]{3,2,1,0} divide(multiply.17, broadcast.68)
- broadcast.55 = f16[2,6,128]{2,1,0} broadcast(constant.52), dimensions={}
- multiply.11 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.55, multiply.11)
- broadcast.56 = f16[2,6,128]{2,1,0} broadcast(constant.60), dimensions={}
- multiply.12 = f16[2,6,128]{2,1,0} multiply(divide.3, broadcast.56)
- broadcast.58 = f16[2,6,128,128]{3,2,1,0} broadcast(multiply.12), dimensions={0,1,2}
- multiply.13 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.58)
- multiply.14 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.13, exponential.1)
- reduce.103 = f16[2,6,128]{2,1,0} reduce(multiply.14, constant.54), dimensions={3}, to_apply=region_2.99
- broadcast.62 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.103), dimensions={0,1,2}
- add.9 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.62)
- multiply.18 = f16[2,6,128,128]{3,2,1,0} multiply(add.9, exponential.1)
- select.7 = f16[2,6,128,128]{3,2,1,0} select(constant.44, multiply.18, broadcast.52)
- multiply.19 = f16[2,6,128,128]{3,2,1,0} multiply(select.7, broadcast.61)
- dot.124 = f16[2,6,128,64]{3,2,1,0} dot(multiply.19, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.19), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.125 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.88, dot.124, dot, dot.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- HloVerifier verifier(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/true);
- ASSERT_IS_OK(verifier.Run(m.get()).status());
-
- // The backward pattern in the graph is not a valid fmha pattern, so we
- // expect no rewrite to happen.
- EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
- EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- F16TrainingBmm1ScaleBiasSoftmaxBmm2QTranspose) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
- Arg_0.22 = f16[] parameter(0)
- Arg_1.23 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
- Arg_0.34 = f32[] parameter(0)
- Arg_1.35 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
- Arg_0.56 = f16[] parameter(0)
- Arg_1.57 = f16[] parameter(1)
- ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
- Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
- dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.22 = f16[] constant(2)
- broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
- multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
- constant.19 = f16[] constant(1)
- broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
- add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
- constant.21 = f16[] constant(0)
- constant.15 = f16[] constant(-inf)
- reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
- broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
- subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
- exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
- constant.17 = f32[] constant(0)
- reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
- convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
- broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
- Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
- dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
- broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
- multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
- broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
- multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
- negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
- broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
- add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
- multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
- multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
- dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
- .WithShape(F16, {2, 6, 128, 64}),
- m::GetTupleElement(
- m::CustomCall(&fmha,
- {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 0)
- .WithShape(F16, {2, 6, 128, 64}),
- m::Transpose(
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 1))
- .WithShape(F16, {2, 6, 64, 128}),
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
- .WithShape(F16, {2, 6, 128, 64}))));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(fmha->operands().size(), 7);
- EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- F16Bmm1UnfusedSoftmaxBmm2IncorrectBmm1NumUsers) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
-
-region_0.7 {
- Arg_0.8 = f16[] parameter(0)
- Arg_1.9 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
- Arg_0.20 = f32[] parameter(0)
- Arg_1.21 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.31 {
- Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
- dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- // extra user of bmm1
- neg.1 = f16[2,6,40,40]{3,2,1,0} negate(dot)
- constant = f16[] constant(-inf)
- reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
- broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
- subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
- exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
- convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
- constant.1 = f32[] constant(0)
- reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
- convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
- broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
- divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
- Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::Dot(), m::Negate())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- F16Bmm1UnfusedSoftmaxBmm2IncorrectSoftmaxNumUsers) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
-
-region_0.7 {
- Arg_0.8 = f16[] parameter(0)
- Arg_1.9 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
- Arg_0.20 = f32[] parameter(0)
- Arg_1.21 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.31 {
- Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
- dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- constant = f16[] constant(-inf)
- reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
- broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
- subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
- // extra user of softmax sub node
- neg.1 = f16[2,6,40,40]{3,2,1,0} negate(subtract.1)
- exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
- convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
- constant.1 = f32[] constant(0)
- reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
- convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
- broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
- divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
- Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
- ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::Dot(), m::Negate())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- F16TrainingBmm1ScaleBiasSoftmaxBmm2IncorrectSoftmaxBwdNumUsers) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
- Arg_0.22 = f16[] parameter(0)
- Arg_1.23 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
- Arg_0.34 = f32[] parameter(0)
- Arg_1.35 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
- Arg_0.56 = f16[] parameter(0)
- Arg_1.57 = f16[] parameter(1)
- ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
- Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
- dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.22 = f16[] constant(2)
- broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
- multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
- constant.19 = f16[] constant(1)
- broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
- add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
- constant.21 = f16[] constant(0)
- constant.15 = f16[] constant(-inf)
- reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
- broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
- subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
- exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
- constant.17 = f32[] constant(0)
- reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
- convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
- broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
- Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
- dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
- // extra user of softmax bwd divide node
- neg.1 = f16[2,6,128,128]{3,2,1,0} negate(divide.4)
- broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
- multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
- broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
- multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
- negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
- broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
- add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
- multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
- multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
- dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
- m::Negate())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1SoftmaxBmm2IncorrectRank) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule reproducer, entry_computation_layout={(f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, /*index=5*/f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0})->f16[8,16,2,5,64]{4,3,2,1,0}}
-
-region_0.36 {
- Arg_0.37 = f16[] parameter(0)
- Arg_1.38 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.37, Arg_1.38)
-}
-
-region_1.48 {
- Arg_0.49 = f32[] parameter(0)
- Arg_1.50 = f32[] parameter(1)
- ROOT add.1 = f32[] add(Arg_0.49, Arg_1.50)
-}
-
-ENTRY main {
- arg2.3 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(2), parameter_replication={false}
- bitcast.31 = f16[640,128]{1,0} bitcast(arg2.3)
- arg5.6 = f32[128,2,64]{2,1,0} parameter(5), parameter_replication={false}
- convert.3 = f16[128,2,64]{2,1,0} convert(arg5.6)
- bitcast.36 = f16[128,128]{1,0} bitcast(convert.3)
- dot = f16[640,128]{1,0} dot(bitcast.31, bitcast.36), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
- bitcast.39 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot)
- transpose.27 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.39), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
- arg6.7 = f32[2,64]{1,0} parameter(6), parameter_replication={false}
- convert.4 = f16[2,64]{1,0} convert(arg6.7)
- broadcast.9 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.4), dimensions={3,5}
- add.2 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.27, broadcast.9)
- bitcast.49 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.2)
- arg0.1 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(0), parameter_replication={false}
- bitcast.53 = f16[640,128]{1,0} bitcast(arg0.1)
- arg3.4 = f32[128,2,64]{2,1,0} parameter(3), parameter_replication={false}
- convert.5 = f16[128,2,64]{2,1,0} convert(arg3.4)
- bitcast.58 = f16[128,128]{1,0} bitcast(convert.5)
- dot.1 = f16[640,128]{1,0} dot(bitcast.53, bitcast.58), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
- bitcast.61 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.1)
- transpose.28 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} transpose(bitcast.61), dimensions={0,1,2,4,5,3}, frontend_attributes={grad_x="false",grad_y="false"}
- arg4.5 = f32[2,64]{1,0} parameter(4), parameter_replication={false}
- convert.6 = f16[2,64]{1,0} convert(arg4.5)
- broadcast.10 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(convert.6), dimensions={3,4}
- add.3 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} add(transpose.28, broadcast.10)
- constant.29 = f16[] constant(0.125)
- broadcast.11 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(constant.29), dimensions={}
- multiply = f16[1,8,16,2,64,5]{5,4,3,2,1,0} multiply(add.3, broadcast.11)
- bitcast.74 = f16[8,16,2,64,5]{4,3,2,1,0} bitcast(multiply)
- dot.6 = f16[8,16,2,5,5]{4,3,2,1,0} dot(bitcast.49, bitcast.74), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
- constant.35 = f16[] constant(-inf)
- reduce.1 = f16[8,16,2,5]{3,2,1,0} reduce(dot.6, constant.35), dimensions={3}, to_apply=region_0.36
- broadcast.12 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(reduce.1), dimensions={0,1,2,4}
- subtract.2 = f16[8,16,2,5,5]{4,3,2,1,0} subtract(dot.6, broadcast.12)
- exponential.2 = f16[8,16,2,5,5]{4,3,2,1,0} exponential(subtract.2)
- convert.7 = f32[8,16,2,5,5]{4,3,2,1,0} convert(exponential.2)
- constant.34 = f32[] constant(0)
- reduce.3 = f32[8,16,2,5]{3,2,1,0} reduce(convert.7, constant.34), dimensions={3}, to_apply=region_1.48
- convert.8 = f16[8,16,2,5]{3,2,1,0} convert(reduce.3)
- broadcast.13 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(convert.8), dimensions={0,1,2,4}
- divide.2 = f16[8,16,2,5,5]{4,3,2,1,0} divide(exponential.2, broadcast.13)
- bitcast.98 = f16[8,16,2,5,5]{3,4,2,1,0} bitcast(divide.2)
- arg1.2 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(1), parameter_replication={false}
- bitcast.102 = f16[640,128]{1,0} bitcast(arg1.2)
- arg7.8 = f32[128,2,64]{2,1,0} parameter(7), parameter_replication={false}
- convert.9 = f16[128,2,64]{2,1,0} convert(arg7.8)
- bitcast.107 = f16[128,128]{1,0} bitcast(convert.9)
- dot.3 = f16[640,128]{1,0} dot(bitcast.102, bitcast.107), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
- bitcast.110 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.3)
- transpose.30 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.110), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
- arg8.9 = f32[2,64]{1,0} parameter(8), parameter_replication={false}
- convert.10 = f16[2,64]{1,0} convert(arg8.9)
- broadcast.14 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.10), dimensions={3,5}
- add.4 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.30, broadcast.14)
- bitcast.120 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.4)
- ROOT dot.7 = f16[8,16,2,5,64]{4,3,2,1,0} dot(bitcast.98, bitcast.120), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
-} // main
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- const auto status_or = RunHloPass(&fusedMhaRewriter, m.get());
- TF_ASSERT_OK(status_or.status());
- EXPECT_FALSE(status_or.value());
-
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot()));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- F16TrainingBmm1ScaleBiasSoftmaxBmm2NonContractingDimNotDivisibleBy64) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,100]{3,2,1,0},f16[2,6,64,100]{3,2,1,0},f16[2,6,100,64]{3,2,1,0},f16[2,6,100,64]{3,2,1,0})->(f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
- Arg_0.22 = f16[] parameter(0)
- Arg_1.23 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
- Arg_0.34 = f32[] parameter(0)
- Arg_1.35 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
- Arg_0.56 = f16[] parameter(0)
- Arg_1.57 = f16[] parameter(1)
- ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
- Arg_0.1 = f16[2,6,64,100]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,100]{3,2,1,0} parameter(1), sharding={replicated}
- dot.17 = f16[2,6,100,100]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.22 = f16[] constant(2)
- broadcast.24 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.22), dimensions={}
- multiply.2 = f16[2,6,100,100]{3,2,1,0} multiply(dot.17, broadcast.24)
- constant.19 = f16[] constant(1)
- broadcast.13 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.19), dimensions={}
- add.3 = f16[2,6,100,100]{3,2,1,0} add(multiply.2, broadcast.13)
- constant.21 = f16[] constant(0)
- constant.15 = f16[] constant(-inf)
- reduce.25 = f16[2,6,100]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
- broadcast.17 = f16[2,6,100,100]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
- subtract.1 = f16[2,6,100,100]{3,2,1,0} subtract(add.3, broadcast.17)
- exponential.1 = f16[2,6,100,100]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,100,100]{3,2,1,0} convert(exponential.1)
- constant.17 = f32[] constant(0)
- reduce.37 = f32[2,6,100]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
- convert.9 = f16[2,6,100]{2,1,0} convert(reduce.37)
- broadcast.26 = f16[2,6,100,100]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = f16[2,6,100,100]{3,2,1,0} divide(exponential.1, broadcast.26)
- Arg_2.3 = f16[2,6,100,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.46 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = f16[2,6,100,64]{3,2,1,0} parameter(3), sharding={replicated}
- dot.49 = f16[2,6,100,100]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = f16[2,6,100,100]{3,2,1,0} divide(dot.49, broadcast.26)
- broadcast.20 = f16[2,6,100]{2,1,0} broadcast(constant.19), dimensions={}
- multiply.3 = f16[2,6,100]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = f16[2,6,100]{2,1,0} divide(broadcast.20, multiply.3)
- broadcast.21 = f16[2,6,100,100]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = f16[2,6,100,100]{3,2,1,0} multiply(dot.49, broadcast.21)
- multiply.5 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.59 = f16[2,6,100]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
- negate.2 = f16[2,6,100]{2,1,0} negate(reduce.59)
- broadcast.25 = f16[2,6,100,100]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
- add.5 = f16[2,6,100,100]{3,2,1,0} add(divide.4, broadcast.25)
- multiply.8 = f16[2,6,100,100]{3,2,1,0} multiply(add.5, exponential.1)
- multiply.9 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.8, broadcast.24)
- dot.80 = f16[2,6,100,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = f16[2,6,64,100]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.81 = (f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- const auto status_or = RunHloPass(&fusedMhaRewriter, m.get());
- TF_ASSERT_OK(status_or.status());
- EXPECT_FALSE(status_or.value());
-
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm2Grad1IncorrectPattern) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
- Arg_0.22 = f16[] parameter(0)
- Arg_1.23 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
- Arg_0.34 = f32[] parameter(0)
- Arg_1.35 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
- Arg_0.56 = f16[] parameter(0)
- Arg_1.57 = f16[] parameter(1)
- ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
- Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
- dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.22 = f16[] constant(2)
- broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
- multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
- constant.19 = f16[] constant(1)
- broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
- add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
- constant.21 = f16[] constant(0)
- constant.15 = f16[] constant(-inf)
- reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
- broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
- subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
- exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
- constant.17 = f32[] constant(0)
- reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
- convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
- broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
- Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
- dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
- dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
- broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
- multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
- broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
- multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
- negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
- broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
- add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
- multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
- multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
- dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  // Add another user of the ds node multiply.9 here; neg.1 should not be pattern matched as bmm2grad1.
- neg.1 = f16[2,6,128,128]{3,2,1,0} negate(multiply.9)
- dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
- m::Negate())));
-}
-
-// flash attention
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- FlashAttentionBF16TrainingBmm1CausalMaskSoftmaxBmm2Pattern) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-region_0.32 {
- Arg_0.33 = bf16[] parameter(0)
- Arg_1.34 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
-}
-region_1.44 {
- Arg_0.45 = f32[] parameter(0)
- Arg_1.46 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-region_2.66 {
- Arg_0.67 = bf16[] parameter(0)
- Arg_1.68 = bf16[] parameter(1)
- ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
-}
-ENTRY main.92 {
- Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
- dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.17 = bf16[] constant(2)
- broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
- multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
- iota.2 = s32[2048,2048]{1,0} iota(), iota_dimension=0
- iota.5 = s32[2048,2048]{1,0} iota(), iota_dimension=1
- compare.1 = pred[2048,2048]{1,0} compare(iota.2, iota.5), direction=LT
- constant.6 = bf16[] constant(-2.366e+38)
- broadcast.16 = bf16[2048,2048]{1,0} broadcast(constant.6), dimensions={}
- constant.16 = bf16[] constant(0)
- broadcast.17 = bf16[2048,2048]{1,0} broadcast(constant.16), dimensions={}
- select.2 = bf16[2048,2048]{1,0} select(compare.1, broadcast.16, broadcast.17)
- broadcast.19 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(select.2), dimensions={2,3}
- add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, broadcast.19)
- constant.10 = bf16[] constant(-inf)
- reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
- broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
- subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
- exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
- constant.14 = f32[] constant(0)
- reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
- convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
- broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
- Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
- dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
- dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
- constant.15 = bf16[] constant(1)
- broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
- multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
- broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
- multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
- negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
- broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
- add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
- multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
- multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
- dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- const HloInstruction* fwd_fmha;
- const HloInstruction* bwd_fmha;
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(
- m::CustomCall(&fwd_fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
- .WithShape(BF16, {2, 6, 2048, 128}),
- m::GetTupleElement(
- m::CustomCall(&bwd_fmha, {kCudnnfMHASoftmaxBackwardCallTarget}),
- 0)
- .WithShape(BF16, {2, 6, 2048, 128}),
- m::Transpose(
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
- .WithShape(BF16, {2, 6, 128, 2048}),
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
- .WithShape(BF16, {2, 6, 2048, 128}))));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fwd_fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(fwd_fmha->operands().size(), 3);
- EXPECT_EQ(bwd_fmha->operands().size(), 6);
- EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
- EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- FlashAttentionBF16TrainingBmm1BiasSoftmaxBmm2Pattern) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-region_0.32 {
- Arg_0.33 = bf16[] parameter(0)
- Arg_1.34 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
-}
-region_1.44 {
- Arg_0.45 = f32[] parameter(0)
- Arg_1.46 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-region_2.66 {
- Arg_0.67 = bf16[] parameter(0)
- Arg_1.68 = bf16[] parameter(1)
- ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
-}
-ENTRY main.92 {
- Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
- dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.17 = bf16[] constant(2)
- broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
- multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
- // bias
- Arg_4.5 = bf16[2,6,2048,2048]{3,2,1,0} parameter(4), sharding={replicated}
- add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, Arg_4.5)
- constant.10 = bf16[] constant(-inf)
- constant.16 = bf16[] constant(0)
- reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
- broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
- subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
- exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
- constant.14 = f32[] constant(0)
- reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
- convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
- broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
- Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
- dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
- dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
- constant.15 = bf16[] constant(1)
- broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
- multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
- broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
- multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
- negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
- broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
- add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
- multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
- multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
- dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
- .WithShape(BF16, {2, 6, 2048, 128}),
- m::GetTupleElement(
- m::CustomCall(&fmha,
- {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 0)
- .WithShape(BF16, {2, 6, 2048, 128}),
- m::Transpose(
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 1))
- .WithShape(BF16, {2, 6, 128, 2048}),
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
- .WithShape(BF16, {2, 6, 2048, 128}))));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(fmha->operands().size(), 7);
- EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
- EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- FlashAttentionBF16TrainingBmm1SoftmaxBmm2Pattern) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-region_0.32 {
- Arg_0.33 = bf16[] parameter(0)
- Arg_1.34 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
-}
-region_1.44 {
- Arg_0.45 = f32[] parameter(0)
- Arg_1.46 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-region_2.66 {
- Arg_0.67 = bf16[] parameter(0)
- Arg_1.68 = bf16[] parameter(1)
- ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
-}
-ENTRY main.92 {
- Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
- dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- constant.17 = bf16[] constant(2)
- broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
- multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
- constant.10 = bf16[] constant(-inf)
- constant.16 = bf16[] constant(0)
- reduce.36 = bf16[2,6,2048]{2,1,0} reduce(multiply.2, constant.10), dimensions={3}, to_apply=region_0.32
- broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
- subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(multiply.2, broadcast.21)
- exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
- convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
- constant.14 = f32[] constant(0)
- reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
- convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
- broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
- divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
- Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
- dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
- dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
- constant.15 = bf16[] constant(1)
- broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
- multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
- divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
- broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
- multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
- multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
- reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
- negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
- broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
- add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
- multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
- multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
- dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- HloDCE dce;
- TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
- .WithShape(BF16, {2, 6, 2048, 128}),
- m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), 0)
- .WithShape(BF16, {2, 6, 2048, 128}),
- m::Transpose(
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
- .WithShape(BF16, {2, 6, 128, 2048}),
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
- .WithShape(BF16, {2, 6, 2048, 128}))));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(fmha->operands().size(), 6);
- EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
- EXPECT_FLOAT_EQ(config.fmha_scale(), 2);
- EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
-}
-
-// GPT3 pattern
-TEST_F(CudnnFusedMhaRewriterTestHloTest, FlashAttentionBF16TrainingGPT3_5B) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={((s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}))->(s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0})}
-add {
- x = bf16[] parameter(0)
- y = bf16[] parameter(1)
- ROOT add.580 = bf16[] add(x, y)
-}
-
-region_20.962 {
- Arg_0.963 = f32[] parameter(0)
- Arg_1.964 = f32[] parameter(1)
- ROOT add.579 = f32[] add(Arg_0.963, Arg_1.964)
-}
-
-region_39.1120 {
- Arg_0.1121 = f32[] parameter(0)
- Arg_1.1122 = f32[] parameter(1)
- ROOT maximum.21 = f32[] maximum(Arg_0.1121, Arg_1.1122)
-}
-
-main {
- param.3 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) parameter(0)
- get-tuple-element.31 = s32[] get-tuple-element(param.3), index=0
- constant.1961 = s32[] constant(1)
- add.581 = s32[] add(get-tuple-element.31, constant.1961)
- get-tuple-element.32 = bf16[24,32,2048,2048]{2,1,3,0} get-tuple-element(param.3), index=25
- bitcast.187 = bf16[24,2048,32,2048]{3,2,1,0} bitcast(get-tuple-element.32)
- constant.1977 = s32[] constant(23)
- subtract.221 = s32[] subtract(constant.1977, get-tuple-element.31)
- constant.1980 = s32[] constant(0)
- compare.210 = pred[] compare(subtract.221, constant.1980), direction=LT
- constant.1979 = s32[] constant(47)
- subtract.222 = s32[] subtract(constant.1979, get-tuple-element.31)
- select.372 = s32[] select(compare.210, subtract.222, subtract.221)
- dynamic-slice.324 = bf16[1,2048,32,2048]{3,2,1,0} dynamic-slice(bitcast.187, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,2048,32,2048}
- bitcast.756 = bf16[2048,32,2048]{2,1,0} bitcast(dynamic-slice.324)
- convert.282 = f32[2048,32,2048]{2,1,0} convert(bitcast.756)
- constant.1991 = bf16[] constant(1)
- broadcast.1270 = bf16[32,2048]{1,0} broadcast(constant.1991), dimensions={}
- get-tuple-element.33 = bf16[32,2048]{1,0} get-tuple-element(param.3), index=27
- subtract.229 = bf16[32,2048]{1,0} subtract(broadcast.1270, get-tuple-element.33)
- convert.285 = f32[32,2048]{1,0} convert(subtract.229)
- broadcast.1228 = f32[2048,32,2048]{2,1,0} broadcast(convert.285), dimensions={1,2}
- multiply.656 = f32[2048,32,2048]{2,1,0} multiply(convert.282, broadcast.1228)
- bitcast.367 = f32[32,2048,2048]{1,0,2} bitcast(multiply.656)
- constant.1968 = f32[] constant(0)
- reduce.84 = f32[] reduce(bitcast.367, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
- all-reduce.230 = f32[] all-reduce(reduce.84), channel_id=278, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
- broadcast.1221 = f32[32,2048,4096]{2,1,0} broadcast(convert.285), dimensions={0,1}
- reduce.85 = f32[] reduce(broadcast.1221, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
- all-reduce.14 = f32[] all-reduce(reduce.85), channel_id=49, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=region_20.962
- constant.2005 = f32[] constant(1)
- maximum.24 = f32[] maximum(all-reduce.14, constant.2005)
- divide.96 = f32[] divide(all-reduce.230, maximum.24)
- broadcast.1223 = f32[2048,32,2048]{2,1,0} broadcast(divide.96), dimensions={}
- subtract.219 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1223)
- multiply.644 = f32[2048,32,2048]{2,1,0} multiply(subtract.219, broadcast.1228)
- multiply.645 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, multiply.644)
- bitcast.271 = f32[32,2048,2048]{1,0,2} bitcast(multiply.645)
- reduce.86 = f32[] reduce(bitcast.271, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
- all-reduce.231 = f32[] all-reduce(reduce.86), channel_id=279, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
- divide.99 = f32[] divide(all-reduce.231, maximum.24)
- rsqrt.16 = f32[] rsqrt(divide.99)
- multiply.650 = f32[] multiply(rsqrt.16, constant.1968)
- divide.100 = f32[] divide(multiply.650, maximum.24)
- constant.1974 = f32[] constant(2)
- multiply.652 = f32[] multiply(divide.100, constant.1974)
- broadcast.1227 = f32[2048,32,2048]{2,1,0} broadcast(multiply.652), dimensions={}
- multiply.653 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, broadcast.1227)
- multiply.654 = f32[2048,32,2048]{2,1,0} multiply(multiply.653, broadcast.1228)
- negate.56 = f32[2048,32,2048]{2,1,0} negate(multiply.654)
- bitcast.321 = f32[32,2048,2048]{1,0,2} bitcast(negate.56)
- reduce.87 = f32[] reduce(bitcast.321, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
- all-reduce.232 = f32[] all-reduce(reduce.87), channel_id=280, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
- divide.101 = f32[] divide(all-reduce.232, maximum.24)
- broadcast.1229 = f32[32,2048]{1,0} broadcast(divide.101), dimensions={}
- multiply.655 = f32[32,2048]{1,0} multiply(broadcast.1229, convert.285)
- broadcast.1230 = f32[2048,32,2048]{2,1,0} broadcast(multiply.655), dimensions={1,2}
- add.582 = f32[2048,32,2048]{2,1,0} add(multiply.654, broadcast.1230)
- broadcast.1236 = f32[2048,32,2048]{2,1,0} broadcast(constant.1968), dimensions={}
- compare.208 = pred[2048,32,2048]{2,1,0} compare(multiply.656, broadcast.1236), direction=GE
- abs.22 = f32[2048,32,2048]{2,1,0} abs(multiply.656)
- bitcast.373 = f32[32,2048,2048]{1,0,2} bitcast(abs.22)
- constant.1989 = f32[] constant(-inf)
- reduce.88 = f32[] reduce(bitcast.373, constant.1989), dimensions={0,1,2}, to_apply=region_39.1120
- all-reduce.233 = f32[] all-reduce(reduce.88), channel_id=281, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_39.1120
- broadcast.1233 = f32[2048,32,2048]{2,1,0} broadcast(all-reduce.233), dimensions={}
- compare.207 = pred[2048,32,2048]{2,1,0} compare(abs.22, broadcast.1233), direction=EQ
- convert.286 = f32[2048,32,2048]{2,1,0} convert(compare.207)
- bitcast.393 = f32[32,2048,2048]{1,0,2} bitcast(convert.286)
- reduce.89 = f32[] reduce(bitcast.393, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
- all-reduce.234 = f32[] all-reduce(reduce.89), channel_id=282, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
- divide.103 = f32[] divide(constant.1968, all-reduce.234)
- broadcast.1238 = f32[2048,32,2048]{2,1,0} broadcast(divide.103), dimensions={}
- select.370 = f32[2048,32,2048]{2,1,0} select(compare.207, broadcast.1238, broadcast.1236)
- select.369 = f32[2048,32,2048]{2,1,0} select(compare.208, select.370, broadcast.1236)
- constant.1976 = pred[] constant(false)
- broadcast.1237 = pred[2048,32,2048]{2,1,0} broadcast(constant.1976), dimensions={}
- compare.209 = pred[2048,32,2048]{2,1,0} compare(compare.208, broadcast.1237), direction=EQ
- select.371 = f32[2048,32,2048]{2,1,0} select(compare.209, select.370, broadcast.1236)
- negate.57 = f32[2048,32,2048]{2,1,0} negate(select.371)
- add.583 = f32[2048,32,2048]{2,1,0} add(select.369, negate.57)
- multiply.658 = f32[2048,32,2048]{2,1,0} multiply(add.583, broadcast.1228)
- add.585 = f32[2048,32,2048]{2,1,0} add(add.582, multiply.658)
- convert.287 = bf16[2048,32,2048]{2,1,0} convert(add.585)
- get-tuple-element.34 = bf16[32,2048,2048]{1,0,2} get-tuple-element(param.3), index=1
- bitcast.1652 = bf16[2048,32,2048]{2,1,0} bitcast(get-tuple-element.34)
- get-tuple-element.35 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=22
- bitcast.461 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.35)
- dynamic-slice.325 = bf16[1,1024,3,16,128]{4,3,2,1,0} dynamic-slice(bitcast.461, select.372, constant.1980, constant.1980, constant.1980, /*index=5*/constant.1980), dynamic_slice_sizes={1,1024,3,16,128}
- bitcast.485 = bf16[3,1024,16,128]{3,2,0,1} bitcast(dynamic-slice.325)
- all-gather.7 = bf16[3,4096,16,128]{3,2,0,1} all-gather(bitcast.485), channel_id=60, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
- bitcast.1420 = bf16[6144,4096]{0,1} bitcast(all-gather.7)
- bitcast.500 = f32[32,2048,2048]{1,0,2} bitcast(convert.282)
- reduce.90 = f32[32,2048]{1,0} reduce(bitcast.500, constant.1968), dimensions={2}, to_apply=region_20.962
- all-reduce.23 = f32[32,2048]{1,0} all-reduce(reduce.90), channel_id=58, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- constant.1983 = f32[] constant(0.000244140625)
- broadcast.1243 = f32[32,2048]{1,0} broadcast(constant.1983), dimensions={}
- multiply.660 = f32[32,2048]{1,0} multiply(all-reduce.23, broadcast.1243)
- broadcast.1242 = f32[2048,32,2048]{2,1,0} broadcast(multiply.660), dimensions={1,2}
- subtract.224 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1242)
- multiply.661 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, subtract.224)
- bitcast.527 = f32[32,2048,2048]{1,0,2} bitcast(multiply.661)
- reduce.91 = f32[32,2048]{1,0} reduce(bitcast.527, constant.1968), dimensions={2}, to_apply=region_20.962
- all-reduce.24 = f32[32,2048]{1,0} all-reduce(reduce.91), channel_id=59, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- multiply.662 = f32[32,2048]{1,0} multiply(all-reduce.24, broadcast.1243)
- constant.1990 = f32[] constant(1e-05)
- broadcast.1264 = f32[32,2048]{1,0} broadcast(constant.1990), dimensions={}
- add.587 = f32[32,2048]{1,0} add(multiply.662, broadcast.1264)
- bitcast.1447 = f32[1,32,2048]{2,1,0} bitcast(add.587)
- rsqrt.20 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1447)
- bitcast.1892 = f32[32,2048]{1,0} bitcast(rsqrt.20)
- broadcast.1337 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1892), dimensions={1,2}
- multiply.754 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1337)
- convert.314 = bf16[2048,32,2048]{2,1,0} convert(multiply.754)
- get-tuple-element.36 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=20
- dynamic-slice.326 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.36, select.372, constant.1980), dynamic_slice_sizes={1,2048}
- broadcast.1266 = bf16[1,2048]{1,0} broadcast(constant.1991), dimensions={}
- add.588 = bf16[1,2048]{1,0} add(dynamic-slice.326, broadcast.1266)
- bitcast.1992 = bf16[2048]{0} bitcast(add.588)
- broadcast.1338 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1992), dimensions={0}
- multiply.755 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, broadcast.1338)
- get-tuple-element.37 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=19
- dynamic-slice.327 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.37, select.372, constant.1980), dynamic_slice_sizes={1,2048}
- bitcast.1998 = bf16[2048]{0} bitcast(dynamic-slice.327)
- broadcast.1339 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1998), dimensions={0}
- add.640 = bf16[2048,32,2048]{2,1,0} add(multiply.755, broadcast.1339)
- bitcast.2003 = bf16[32,2048,2048]{1,0,2} bitcast(add.640)
- all-gather.8 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=61, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.597 = bf16[4096,65536]{1,0} bitcast(all-gather.8)
- dot.42 = bf16[6144,65536]{1,0} dot(bitcast.1420, bitcast.597), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitcast.623 = bf16[3,16,128,32,2048]{4,3,2,1,0} bitcast(dot.42)
- transpose.112 = bf16[3,32,16,128,2048]{4,3,2,1,0} transpose(bitcast.623), dimensions={0,3,1,2,4}
- get-tuple-element.38 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=21
- dynamic-slice.328 = bf16[1,3,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.38, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,3,16,128}
- bitcast.626 = bf16[3,16,128]{2,1,0} bitcast(dynamic-slice.328)
- broadcast.1250 = bf16[3,32,16,128,2048]{4,3,2,1,0} broadcast(bitcast.626), dimensions={0,2,3}
- add.591 = bf16[3,32,16,128,2048]{4,3,2,1,0} add(transpose.112, broadcast.1250)
- slice.87 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[2:3], [0:32], [0:16], [0:128], [0:2048]}
- bitcast.1280 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.87)
- slice.88 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[0:1], [0:32], [0:16], [0:128], [0:2048]}
- constant.2007 = bf16[] constant(0.08838)
- broadcast.1251 = bf16[1,32,16,128,2048]{4,3,2,1,0} broadcast(constant.2007), dimensions={}
- multiply.666 = bf16[1,32,16,128,2048]{4,3,2,1,0} multiply(slice.88, broadcast.1251)
- bitcast.1330 = bf16[32,16,128,2048]{3,2,1,0} bitcast(multiply.666)
- transpose.113 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1330), dimensions={0,1,3,2}
- slice.89 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[1:2], [0:32], [0:16], [0:128], [0:2048]}
- bitcast.647 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.89)
- dot.43 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.113, bitcast.647), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- convert.291 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.43)
- get-tuple-element.39 = bf16[32,1,2048,2048]{3,2,0,1} get-tuple-element(param.3), index=26
- bitcast.651 = bf16[1,32,2048,2048]{3,2,1,0} bitcast(get-tuple-element.39)
- iota.38 = s32[2048,2048]{1,0} iota(), iota_dimension=0
- iota.39 = s32[2048,2048]{1,0} iota(), iota_dimension=1
- compare.211 = pred[2048,2048]{1,0} compare(iota.38, iota.39), direction=LT
- constant.1987 = bf16[] constant(-2.366e+38)
- broadcast.1252 = bf16[2048,2048]{1,0} broadcast(constant.1987), dimensions={}
- constant.2006 = bf16[] constant(0)
- broadcast.1253 = bf16[2048,2048]{1,0} broadcast(constant.2006), dimensions={}
- select.373 = bf16[2048,2048]{1,0} select(compare.211, broadcast.1252, broadcast.1253)
- broadcast.1254 = bf16[1,32,2048,2048]{3,2,1,0} broadcast(select.373), dimensions={2,3}
- minimum.5 = bf16[1,32,2048,2048]{3,2,1,0} minimum(bitcast.651, broadcast.1254)
- bitcast.673 = bf16[32,2048,2048]{2,1,0} bitcast(minimum.5)
- convert.292 = f32[32,2048,2048]{2,1,0} convert(bitcast.673)
- broadcast.1256 = f32[32,16,2048,2048]{3,2,1,0} broadcast(convert.292), dimensions={0,2,3}
- add.593 = f32[32,16,2048,2048]{3,2,1,0} add(convert.291, broadcast.1256)
- reduce.92 = f32[32,16,2048]{2,1,0} reduce(add.593, constant.1989), dimensions={3}, to_apply=region_39.1120
- broadcast.1258 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.92), dimensions={0,1,2}
- subtract.226 = f32[32,16,2048,2048]{3,2,1,0} subtract(add.593, broadcast.1258)
- exponential.8 = f32[32,16,2048,2048]{3,2,1,0} exponential(subtract.226)
- reduce.93 = f32[32,16,2048]{2,1,0} reduce(exponential.8, constant.1968), dimensions={3}, to_apply=region_20.962
- broadcast.1309 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.93), dimensions={0,1,2}
- divide.109 = f32[32,16,2048,2048]{3,2,1,0} divide(exponential.8, broadcast.1309)
- convert.306 = bf16[32,16,2048,2048]{3,2,1,0} convert(divide.109)
- dot.44 = bf16[32,16,128,2048]{3,2,1,0} dot(bitcast.1280, convert.306), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- transpose.116 = bf16[32,2048,16,128]{3,2,1,0} transpose(dot.44), dimensions={0,3,1,2}
- bitcast.711 = bf16[65536,2048]{1,0} bitcast(transpose.116)
- get-tuple-element.40 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=24
- dynamic-slice.329 = bf16[1,1024,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.40, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,16,128}
- bitcast.724 = bf16[1024,16,128]{2,1,0} bitcast(dynamic-slice.329)
- all-gather.9 = bf16[4096,16,128]{2,1,0} all-gather(bitcast.724), channel_id=62, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
- bitcast.729 = bf16[2048,4096]{0,1} bitcast(all-gather.9)
- dot.57 = bf16[65536,4096]{0,1} dot(bitcast.711, bitcast.729), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitcast.733 = bf16[32,2048,4096]{1,0,2} bitcast(dot.57)
- reduce-scatter = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.733), channel_id=322, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
- bitcast.763 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter)
- get-tuple-element.41 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=23
- dynamic-slice.330 = bf16[1,1024]{1,0} dynamic-slice(get-tuple-element.41, select.372, constant.1980), dynamic_slice_sizes={1,1024}
- bitcast.748 = bf16[1024]{0} bitcast(dynamic-slice.330)
- collective-permute.1 = bf16[1024]{0} collective-permute(bitcast.748), channel_id=64, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
- all-gather.10 = bf16[2048]{0} all-gather(collective-permute.1), channel_id=65, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={0}, use_global_device_ids=true
- broadcast.1261 = bf16[2048,32,2048]{2,1,0} broadcast(all-gather.10), dimensions={0}
- add.596 = bf16[2048,32,2048]{2,1,0} add(bitcast.763, broadcast.1261)
- add.597 = bf16[2048,32,2048]{2,1,0} add(add.596, bitcast.756)
- convert.295 = f32[2048,32,2048]{2,1,0} convert(add.597)
- bitcast.774 = f32[32,2048,2048]{1,0,2} bitcast(convert.295)
- reduce.94 = f32[32,2048]{1,0} reduce(bitcast.774, constant.1968), dimensions={2}, to_apply=region_20.962
- all-reduce.26 = f32[32,2048]{1,0} all-reduce(reduce.94), channel_id=66, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- multiply.668 = f32[32,2048]{1,0} multiply(all-reduce.26, broadcast.1243)
- broadcast.1263 = f32[2048,32,2048]{2,1,0} broadcast(multiply.668), dimensions={1,2}
- subtract.228 = f32[2048,32,2048]{2,1,0} subtract(convert.295, broadcast.1263)
- multiply.669 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, subtract.228)
- bitcast.809 = f32[32,2048,2048]{1,0,2} bitcast(multiply.669)
- reduce.95 = f32[32,2048]{1,0} reduce(bitcast.809, constant.1968), dimensions={2}, to_apply=region_20.962
- all-reduce.27 = f32[32,2048]{1,0} all-reduce(reduce.95), channel_id=67, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- multiply.670 = f32[32,2048]{1,0} multiply(all-reduce.27, broadcast.1243)
- add.598 = f32[32,2048]{1,0} add(multiply.670, broadcast.1264)
- bitcast.1148 = f32[1,32,2048]{2,1,0} bitcast(add.598)
- rsqrt.19 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1148)
- bitcast.1602 = f32[32,2048]{1,0} bitcast(rsqrt.19)
- broadcast.1329 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1602), dimensions={1,2}
- multiply.750 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1329)
- convert.312 = bf16[2048,32,2048]{2,1,0} convert(multiply.750)
- get-tuple-element.42 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=18
- dynamic-slice.331 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.42, select.372, constant.1980), dynamic_slice_sizes={1,2048}
- add.599 = bf16[1,2048]{1,0} add(dynamic-slice.331, broadcast.1266)
- bitcast.1609 = bf16[2048]{0} bitcast(add.599)
- broadcast.1330 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1609), dimensions={0}
- multiply.745 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, broadcast.1330)
- get-tuple-element.43 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=17
- dynamic-slice.332 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.43, select.372, constant.1980), dynamic_slice_sizes={1,2048}
- bitcast.1615 = bf16[2048]{0} bitcast(dynamic-slice.332)
- broadcast.1331 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1615), dimensions={0}
- add.636 = bf16[2048,32,2048]{2,1,0} add(multiply.745, broadcast.1331)
- bitcast.1620 = bf16[32,2048,2048]{1,0,2} bitcast(add.636)
- all-gather.12 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=69, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.877 = bf16[65536,4096]{0,1} bitcast(all-gather.12)
- get-tuple-element.44 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=15
- dynamic-slice.333 = bf16[1,1024,8192]{2,1,0} dynamic-slice(get-tuple-element.44, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
- bitcast.890 = bf16[1024,8192]{1,0} bitcast(dynamic-slice.333)
- all-gather.11 = bf16[4096,8192]{1,0} all-gather(bitcast.890), channel_id=68, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
- dot.45 = bf16[65536,8192]{1,0} dot(bitcast.877, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- get-tuple-element.45 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=14
- dynamic-slice.334 = bf16[1,8192]{1,0} dynamic-slice(get-tuple-element.45, select.372, constant.1980), dynamic_slice_sizes={1,8192}
- bitcast.906 = bf16[8192]{0} bitcast(dynamic-slice.334)
- broadcast.1269 = bf16[65536,8192]{1,0} broadcast(bitcast.906), dimensions={1}
- add.601 = bf16[65536,8192]{1,0} add(dot.45, broadcast.1269)
- bitcast.997 = bf16[32,2048,8192]{2,1,0} bitcast(add.601)
- broadcast.1333 = bf16[2048,32,2048]{2,1,0} broadcast(subtract.229), dimensions={1,2}
- multiply.746 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1652, broadcast.1333)
- bitcast.1739 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.746)
- all-gather.14 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=71, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.934 = bf16[65536,4096]{0,1} bitcast(all-gather.14)
- get-tuple-element.46 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=16
- bitcast.935 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.46)
- dynamic-slice.335 = bf16[1,1024,8192]{2,1,0} dynamic-slice(bitcast.935, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
- bitcast.947 = bf16[8192,1024]{0,1} bitcast(dynamic-slice.335)
- all-gather.13 = bf16[8192,4096]{0,1} all-gather(bitcast.947), channel_id=70, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
- dot.46 = bf16[65536,8192]{1,0} dot(bitcast.934, all-gather.13), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- bitcast.1092 = bf16[32,2048,8192]{2,1,0} bitcast(dot.46)
- broadcast.1335 = bf16[32,2048,8192]{2,1,0} broadcast(subtract.229), dimensions={0,1}
- multiply.703 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.1092, broadcast.1335)
- multiply.685 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.703)
- constant.2002 = bf16[] constant(0.5)
- broadcast.1288 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2002), dimensions={}
- multiply.686 = bf16[32,2048,8192]{2,1,0} multiply(multiply.685, broadcast.1288)
- broadcast.1287 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1991), dimensions={}
- multiply.700 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, bitcast.997)
- multiply.693 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.700)
- constant.1998 = bf16[] constant(0.04468)
- broadcast.1282 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1998), dimensions={}
- multiply.694 = bf16[32,2048,8192]{2,1,0} multiply(multiply.693, broadcast.1282)
- add.605 = bf16[32,2048,8192]{2,1,0} add(bitcast.997, multiply.694)
- constant.2010 = bf16[] constant(0.7969)
- broadcast.1324 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2010), dimensions={}
- multiply.695 = bf16[32,2048,8192]{2,1,0} multiply(add.605, broadcast.1324)
- tanh.7 = bf16[32,2048,8192]{2,1,0} tanh(multiply.695)
- subtract.231 = bf16[32,2048,8192]{2,1,0} subtract(broadcast.1287, tanh.7)
- multiply.691 = bf16[32,2048,8192]{2,1,0} multiply(multiply.686, subtract.231)
- multiply.737 = bf16[32,2048,8192]{2,1,0} multiply(multiply.691, tanh.7)
- add.630 = bf16[32,2048,8192]{2,1,0} add(multiply.691, multiply.737)
- multiply.738 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1324)
- constant.2011 = bf16[] constant(0.03564)
- broadcast.1326 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2011), dimensions={}
- multiply.739 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1326)
- constant.2012 = bf16[] constant(3)
- broadcast.1327 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2012), dimensions={}
- multiply.740 = bf16[32,2048,8192]{2,1,0} multiply(multiply.700, broadcast.1327)
- multiply.741 = bf16[32,2048,8192]{2,1,0} multiply(multiply.739, multiply.740)
- add.632 = bf16[32,2048,8192]{2,1,0} add(multiply.738, multiply.741)
- add.637 = bf16[32,2048,8192]{2,1,0} add(tanh.7, broadcast.1287)
- multiply.747 = bf16[32,2048,8192]{2,1,0} multiply(add.637, broadcast.1288)
- multiply.743 = bf16[32,2048,8192]{2,1,0} multiply(multiply.703, multiply.747)
- add.635 = bf16[32,2048,8192]{2,1,0} add(add.632, multiply.743)
- bitcast.1629 = bf16[65536,8192]{1,0} bitcast(add.635)
- dot.47 = bf16[65536,4096]{0,1} dot(bitcast.1629, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- bitcast.1130 = bf16[32,2048,4096]{1,0,2} bitcast(dot.47)
- reduce-scatter.1 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1130), channel_id=323, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
- bitcast.1766 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.1)
- multiply.712 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1766, broadcast.1330)
- convert.299 = f32[2048,32,2048]{2,1,0} convert(multiply.712)
- multiply.707 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, convert.299)
- bitcast.1135 = f32[32,2048,2048]{1,0,2} bitcast(multiply.707)
- reduce.96 = f32[32,2048]{1,0} reduce(bitcast.1135, constant.1968), dimensions={2}, to_apply=region_20.962
- all-reduce.29 = f32[32,2048]{1,0} all-reduce(reduce.96), channel_id=73, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- bitcast.1140 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.29)
- divide.105 = f32[1,32,2048]{2,1,0} divide(rsqrt.19, bitcast.1148)
- constant.2008 = f32[] constant(-0.5)
- broadcast.1313 = f32[1,32,2048]{2,1,0} broadcast(constant.2008), dimensions={}
- multiply.708 = f32[1,32,2048]{2,1,0} multiply(divide.105, broadcast.1313)
- multiply.709 = f32[1,32,2048]{2,1,0} multiply(bitcast.1140, multiply.708)
- constant.2009 = f32[] constant(0.00048828125)
- broadcast.1315 = f32[1,32,2048]{2,1,0} broadcast(constant.2009), dimensions={}
- multiply.710 = f32[1,32,2048]{2,1,0} multiply(multiply.709, broadcast.1315)
- bitcast.1235 = f32[32,2048]{1,0} bitcast(multiply.710)
- broadcast.1296 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1235), dimensions={1,2}
- multiply.717 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1296)
- multiply.718 = f32[2048,32,2048]{2,1,0} multiply(convert.299, broadcast.1329)
- add.617 = f32[2048,32,2048]{2,1,0} add(multiply.717, multiply.718)
- negate.58 = f32[2048,32,2048]{2,1,0} negate(multiply.717)
- bitcast.1189 = f32[32,2048,2048]{1,0,2} bitcast(negate.58)
- reduce.97 = f32[32,2048]{1,0} reduce(bitcast.1189, constant.1968), dimensions={2}, to_apply=region_20.962
- negate.59 = f32[2048,32,2048]{2,1,0} negate(multiply.718)
- bitcast.1203 = f32[32,2048,2048]{1,0,2} bitcast(negate.59)
- reduce.98 = f32[32,2048]{1,0} reduce(bitcast.1203, constant.1968), dimensions={2}, to_apply=region_20.962
- add.613 = f32[32,2048]{1,0} add(reduce.97, reduce.98)
- all-reduce.274 = f32[32,2048]{1,0} all-reduce(add.613), channel_id=335, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- multiply.719 = f32[32,2048]{1,0} multiply(all-reduce.274, broadcast.1243)
- broadcast.1297 = f32[2048,32,2048]{2,1,0} broadcast(multiply.719), dimensions={1,2}
- add.618 = f32[2048,32,2048]{2,1,0} add(add.617, broadcast.1297)
- convert.301 = bf16[2048,32,2048]{2,1,0} convert(add.618)
- add.619 = bf16[2048,32,2048]{2,1,0} add(bitcast.1652, convert.301)
- add.616 = bf16[2048,32,2048]{2,1,0} add(convert.287, add.619)
- bitcast.2063 = bf16[32,2048,2048]{1,0,2} bitcast(add.619)
- all-gather.15 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2063), channel_id=76, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.1263 = bf16[65536,4096]{0,1} bitcast(all-gather.15)
- bitcast.1269 = bf16[4096,2048]{1,0} bitcast(all-gather.9)
- dot.48 = bf16[65536,2048]{1,0} dot(bitcast.1263, bitcast.1269), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitcast.1381 = bf16[32,2048,16,128]{3,2,1,0} bitcast(dot.48)
- transpose.122 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,1,3}
- dot.49 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.122, bitcast.1280), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- convert.303 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.49)
- broadcast.1298 = f32[32,16,2048]{2,1,0} broadcast(constant.2005), dimensions={}
- multiply.720 = f32[32,16,2048]{2,1,0} multiply(reduce.93, reduce.93)
- divide.106 = f32[32,16,2048]{2,1,0} divide(broadcast.1298, multiply.720)
- broadcast.1299 = f32[32,16,2048,2048]{3,2,1,0} broadcast(divide.106), dimensions={0,1,2}
- multiply.721 = f32[32,16,2048,2048]{3,2,1,0} multiply(convert.303, broadcast.1299)
- multiply.722 = f32[32,16,2048,2048]{3,2,1,0} multiply(multiply.721, exponential.8)
- reduce.99 = f32[32,16,2048]{2,1,0} reduce(multiply.722, constant.1968), dimensions={3}, to_apply=region_20.962
- negate.61 = f32[32,16,2048]{2,1,0} negate(reduce.99)
- broadcast.1305 = f32[32,16,2048,2048]{3,2,1,0} broadcast(negate.61), dimensions={0,1,2}
- divide.108 = f32[32,16,2048,2048]{3,2,1,0} divide(convert.303, broadcast.1309)
- add.622 = f32[32,16,2048,2048]{3,2,1,0} add(broadcast.1305, divide.108)
- multiply.724 = f32[32,16,2048,2048]{3,2,1,0} multiply(add.622, exponential.8)
- convert.305 = bf16[32,16,2048,2048]{3,2,1,0} convert(multiply.724)
- dot.50 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.113), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- bitcast.1934 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.50)
- pad.6 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1934, constant.2006), padding=1_1x0_0x0_0x0_0x0_0
- transpose.120 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.647), dimensions={0,1,3,2}
- dot.51 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.120), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- broadcast.1307 = bf16[32,16,2048,128]{3,2,1,0} broadcast(constant.2007), dimensions={}
- multiply.725 = bf16[32,16,2048,128]{3,2,1,0} multiply(dot.51, broadcast.1307)
- bitcast.1941 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(multiply.725)
- pad.7 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1941, constant.2006), padding=0_2x0_0x0_0x0_0x0_0
- add.638 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(pad.6, pad.7)
- transpose.123 = bf16[32,16,128,2048]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,3,1}
- dot.89 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.306, transpose.123), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- bitcast.1949 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.89)
- pad.8 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1949, constant.2006), padding=2_0x0_0x0_0x0_0x0_0
- add.639 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(add.638, pad.8)
- transpose.127 = bf16[32,2048,3,16,128]{4,3,2,1,0} transpose(add.639), dimensions={1,3,0,2,4}
- bitcast.1416 = bf16[65536,6144]{1,0} bitcast(transpose.127)
- dot.52 = bf16[65536,4096]{0,1} dot(bitcast.1416, bitcast.1420), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitcast.1424 = bf16[32,2048,4096]{1,0,2} bitcast(dot.52)
- reduce-scatter.2 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1424), channel_id=324, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
- bitcast.1851 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.2)
- multiply.732 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1851, broadcast.1338)
- convert.308 = f32[2048,32,2048]{2,1,0} convert(multiply.732)
- multiply.727 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, convert.308)
- bitcast.1434 = f32[32,2048,2048]{1,0,2} bitcast(multiply.727)
- reduce.100 = f32[32,2048]{1,0} reduce(bitcast.1434, constant.1968), dimensions={2}, to_apply=region_20.962
- all-reduce.33 = f32[32,2048]{1,0} all-reduce(reduce.100), channel_id=78, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- bitcast.1439 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.33)
- divide.110 = f32[1,32,2048]{2,1,0} divide(rsqrt.20, bitcast.1447)
- multiply.728 = f32[1,32,2048]{2,1,0} multiply(divide.110, broadcast.1313)
- multiply.729 = f32[1,32,2048]{2,1,0} multiply(bitcast.1439, multiply.728)
- multiply.730 = f32[1,32,2048]{2,1,0} multiply(multiply.729, broadcast.1315)
- bitcast.1485 = f32[32,2048]{1,0} bitcast(multiply.730)
- broadcast.1321 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1485), dimensions={1,2}
- multiply.734 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1321)
- multiply.735 = f32[2048,32,2048]{2,1,0} multiply(convert.308, broadcast.1337)
- add.625 = f32[2048,32,2048]{2,1,0} add(multiply.734, multiply.735)
- negate.62 = f32[2048,32,2048]{2,1,0} negate(multiply.734)
- bitcast.1491 = f32[32,2048,2048]{1,0,2} bitcast(negate.62)
- reduce.101 = f32[32,2048]{1,0} reduce(bitcast.1491, constant.1968), dimensions={2}, to_apply=region_20.962
- negate.63 = f32[2048,32,2048]{2,1,0} negate(multiply.735)
- bitcast.1505 = f32[32,2048,2048]{1,0,2} bitcast(negate.63)
- reduce.102 = f32[32,2048]{1,0} reduce(bitcast.1505, constant.1968), dimensions={2}, to_apply=region_20.962
- add.626 = f32[32,2048]{1,0} add(reduce.101, reduce.102)
- all-reduce.275 = f32[32,2048]{1,0} all-reduce(add.626), channel_id=336, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
- multiply.736 = f32[32,2048]{1,0} multiply(all-reduce.275, broadcast.1243)
- broadcast.1323 = f32[2048,32,2048]{2,1,0} broadcast(multiply.736), dimensions={1,2}
- add.628 = f32[2048,32,2048]{2,1,0} add(add.625, broadcast.1323)
- convert.309 = bf16[2048,32,2048]{2,1,0} convert(add.628)
- add.629 = bf16[2048,32,2048]{2,1,0} add(add.616, convert.309)
- bitcast.1525 = bf16[32,2048,2048]{1,0,2} bitcast(add.629)
- get-tuple-element.47 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=2
- reduce.103 = bf16[8192]{0} reduce(add.635, constant.2006), dimensions={0,1}, to_apply=add
- all-reduce.36 = bf16[8192]{0} all-reduce(reduce.103), channel_id=81, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- bitcast.1583 = bf16[1,8192]{1,0} bitcast(all-reduce.36)
- dynamic-update-slice.28 = bf16[24,8192]{1,0} dynamic-update-slice(get-tuple-element.47, bitcast.1583, select.372, constant.1980)
- get-tuple-element.48 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=3
- all-gather.16 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=82, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.1625 = bf16[4096,65536]{1,0} bitcast(all-gather.16)
- dot.53 = bf16[4096,8192]{1,0} dot(bitcast.1625, bitcast.1629), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- reduce-scatter.3 = bf16[1024,8192]{1,0} reduce-scatter(dot.53), channel_id=325, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={0}, to_apply=add
- bitcast.1634 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.3)
- dynamic-update-slice.29 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(get-tuple-element.48, bitcast.1634, select.372, constant.1980, constant.1980)
- get-tuple-element.49 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=4
- collective-permute.2 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.49), channel_id=85, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
- all-gather.17 = bf16[24,2048]{0,1} all-gather(collective-permute.2), channel_id=86, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
- bitcast.1649 = bf16[2048,24]{1,0} bitcast(all-gather.17)
- reduce.104 = bf16[2048]{0} reduce(bitcast.1739, constant.2006), dimensions={0,1}, to_apply=add
- all-reduce.38 = bf16[2048]{0} all-reduce(reduce.104), channel_id=84, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- bitcast.1671 = bf16[2048,1]{1,0} bitcast(all-reduce.38)
- dynamic-update-slice.30 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1649, bitcast.1671, constant.1980, select.372)
- constant.2013 = s32[8]{0} constant({0, 2048, 0, 2048, 1024, 3072, 1024, 3072})
- partition-id.3 = u32[] partition-id()
- dynamic-slice.336 = s32[1]{0} dynamic-slice(constant.2013, partition-id.3), dynamic_slice_sizes={1}
- constant.2014 = s32[8]{0} constant({0, 2048, 0, 2048, 0, 2048, 0, 2048})
- dynamic-slice.337 = s32[1]{0} dynamic-slice(constant.2014, partition-id.3), dynamic_slice_sizes={1}
- subtract.232 = s32[1]{0} subtract(dynamic-slice.336, dynamic-slice.337)
- bitcast.2087 = s32[] bitcast(subtract.232)
- dynamic-slice.338 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.30, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
- bitcast.1695 = bf16[24,1024]{0,1} bitcast(dynamic-slice.338)
- collective-permute.9 = bf16[24,1024]{0,1} collective-permute(bitcast.1695), channel_id=109, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
- get-tuple-element.50 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=5
- bitcast.1698 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.50)
- multiply.748 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.747)
- multiply.749 = bf16[32,2048,8192]{2,1,0} multiply(multiply.748, broadcast.1335)
- bitcast.1735 = bf16[8192,65536]{0,1} bitcast(multiply.749)
- all-gather.18 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=87, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.1743 = bf16[65536,4096]{0,1} bitcast(all-gather.18)
- dot.54 = bf16[8192,4096]{0,1} dot(bitcast.1735, bitcast.1743), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- reduce-scatter.4 = bf16[8192,1024]{0,1} reduce-scatter(dot.54), channel_id=326, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
- bitcast.1748 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.4)
- dynamic-update-slice.31 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(bitcast.1698, bitcast.1748, select.372, constant.1980, constant.1980)
- bitcast.1758 = bf16[24,8192,1024]{1,2,0} bitcast(dynamic-update-slice.31)
- get-tuple-element.51 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=6
- collective-permute.3 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.51), channel_id=90, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
- all-gather.19 = bf16[24,2048]{0,1} all-gather(collective-permute.3), channel_id=91, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
- bitcast.1763 = bf16[2048,24]{1,0} bitcast(all-gather.19)
- reduce.105 = bf16[2048]{0} reduce(reduce-scatter.1, constant.2006), dimensions={0,1}, to_apply=add
- all-reduce.40 = bf16[2048]{0} all-reduce(reduce.105), channel_id=89, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- bitcast.1779 = bf16[2048,1]{1,0} bitcast(all-reduce.40)
- dynamic-update-slice.32 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1763, bitcast.1779, constant.1980, select.372)
- dynamic-slice.339 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.32, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
- bitcast.1794 = bf16[24,1024]{0,1} bitcast(dynamic-slice.339)
- collective-permute.10 = bf16[24,1024]{0,1} collective-permute(bitcast.1794), channel_id=110, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
- get-tuple-element.52 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=7
- collective-permute.4 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.52), channel_id=93, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
- all-gather.20 = bf16[24,2048]{0,1} all-gather(collective-permute.4), channel_id=94, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
- bitcast.1801 = bf16[2048,24]{1,0} bitcast(all-gather.20)
- multiply.751 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, bitcast.1766)
- bitcast.1817 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.751)
- reduce.106 = bf16[2048]{0} reduce(bitcast.1817, constant.2006), dimensions={0,1}, to_apply=add
- all-reduce.41 = bf16[2048]{0} all-reduce(reduce.106), channel_id=92, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- bitcast.1826 = bf16[2048,1]{1,0} bitcast(all-reduce.41)
- dynamic-update-slice.33 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1801, bitcast.1826, constant.1980, select.372)
- dynamic-slice.340 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.33, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
- bitcast.1841 = bf16[24,1024]{0,1} bitcast(dynamic-slice.340)
- collective-permute.11 = bf16[24,1024]{0,1} collective-permute(bitcast.1841), channel_id=111, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
- get-tuple-element.53 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=8
- collective-permute.5 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.53), channel_id=96, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
- all-gather.21 = bf16[24,2048]{0,1} all-gather(collective-permute.5), channel_id=97, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
- bitcast.1848 = bf16[2048,24]{1,0} bitcast(all-gather.21)
- reduce.107 = bf16[2048]{0} reduce(reduce-scatter.2, constant.2006), dimensions={0,1}, to_apply=add
- all-reduce.42 = bf16[2048]{0} all-reduce(reduce.107), channel_id=95, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- bitcast.1864 = bf16[2048,1]{1,0} bitcast(all-reduce.42)
- dynamic-update-slice.34 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1848, bitcast.1864, constant.1980, select.372)
- dynamic-slice.341 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.34, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
- bitcast.1879 = bf16[24,1024]{0,1} bitcast(dynamic-slice.341)
- collective-permute.12 = bf16[24,1024]{0,1} collective-permute(bitcast.1879), channel_id=112, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
- get-tuple-element.54 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=9
- collective-permute.6 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.54), channel_id=99, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
- all-gather.22 = bf16[24,2048]{0,1} all-gather(collective-permute.6), channel_id=100, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
- bitcast.1886 = bf16[2048,24]{1,0} bitcast(all-gather.22)
- multiply.753 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, bitcast.1851)
- bitcast.1905 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.753)
- reduce.108 = bf16[2048]{0} reduce(bitcast.1905, constant.2006), dimensions={0,1}, to_apply=add
- all-reduce.43 = bf16[2048]{0} all-reduce(reduce.108), channel_id=98, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- bitcast.1914 = bf16[2048,1]{1,0} bitcast(all-reduce.43)
- dynamic-update-slice.35 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1886, bitcast.1914, constant.1980, select.372)
- dynamic-slice.342 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.35, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
- bitcast.1929 = bf16[24,1024]{0,1} bitcast(dynamic-slice.342)
- collective-permute.13 = bf16[24,1024]{0,1} collective-permute(bitcast.1929), channel_id=113, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
- get-tuple-element.55 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=10
- bitcast.1979 = bf16[3,32,2048,16,128]{4,2,3,1,0} bitcast(add.639)
- reduce.109 = bf16[3,16,128]{2,1,0} reduce(bitcast.1979, constant.2006), dimensions={1,2}, to_apply=add
- all-reduce.44 = bf16[3,16,128]{2,1,0} all-reduce(reduce.109), channel_id=101, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- bitcast.1963 = bf16[1,3,16,128]{3,2,1,0} bitcast(all-reduce.44)
- dynamic-update-slice.36 = bf16[24,3,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.55, bitcast.1963, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
- get-tuple-element.56 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=11
- bitcast.1974 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.56)
- transpose.130 = bf16[3,16,128,32,2048]{4,3,2,1,0} transpose(add.639), dimensions={0,2,4,1,3}
- bitcast.1983 = bf16[6144,65536]{1,0} bitcast(transpose.130)
- all-gather.23 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=102, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.2007 = bf16[65536,4096]{0,1} bitcast(all-gather.23)
- dot.55 = bf16[6144,4096]{0,1} dot(bitcast.1983, bitcast.2007), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitcast.2011 = bf16[3,16,128,4096]{2,1,0,3} bitcast(dot.55)
- reduce-scatter.5 = bf16[3,16,128,1024]{2,1,0,3} reduce-scatter(bitcast.2011), channel_id=327, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={3}, to_apply=add
- bitcast.2015 = bf16[1,1024,3,16,128]{4,3,2,1,0} bitcast(reduce-scatter.5)
- dynamic-update-slice.37 = bf16[24,1024,3,16,128]{4,3,2,1,0} dynamic-update-slice(bitcast.1974, bitcast.2015, select.372, constant.1980, constant.1980, /*index=5*/constant.1980, constant.1980)
- bitcast.2025 = bf16[24,3,1024,16,128]{4,3,1,2,0} bitcast(dynamic-update-slice.37)
- get-tuple-element.57 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=12
- reduce.110 = bf16[2048]{0} reduce(bitcast.2063, constant.2006), dimensions={0,1}, to_apply=add
- all-reduce.46 = bf16[2048]{0} all-reduce(reduce.110), channel_id=104, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- dynamic-slice.343 = bf16[1024]{0} dynamic-slice(all-reduce.46, bitcast.2087), dynamic_slice_sizes={1024}
- bitcast.2046 = bf16[1,1024]{1,0} bitcast(dynamic-slice.343)
- collective-permute.7 = bf16[1,1024]{1,0} collective-permute(bitcast.2046), channel_id=105, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
- dynamic-update-slice.38 = bf16[24,1024]{1,0} dynamic-update-slice(get-tuple-element.57, collective-permute.7, select.372, constant.1980)
- get-tuple-element.58 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=13
- bitcast.2066 = bf16[2048,65536]{1,0} bitcast(add.619)
- transpose.133 = bf16[16,32,2048,128]{3,2,1,0} transpose(dot.44), dimensions={1,0,3,2}
- bitcast.2072 = bf16[32,2048,16,128]{3,1,0,2} bitcast(transpose.133)
- all-gather.24 = bf16[32,2048,32,128]{3,1,0,2} all-gather(bitcast.2072), channel_id=106, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
- bitcast.2073 = bf16[32,32,2048,128]{3,2,1,0} bitcast(all-gather.24)
- transpose.134 = bf16[32,2048,32,128]{3,2,1,0} transpose(bitcast.2073), dimensions={1,2,0,3}
- bitcast.2077 = bf16[65536,4096]{1,0} bitcast(transpose.134)
- dot.56 = bf16[2048,4096]{1,0} dot(bitcast.2066, bitcast.2077), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitcast.2081 = bf16[2048,32,128]{2,1,0} bitcast(dot.56)
- all-reduce.47 = bf16[2048,32,128]{2,1,0} all-reduce(bitcast.2081), channel_id=107, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
- constant.2015 = s32[8]{0} constant({0, 0, 16, 16, 0, 0, 16, 16})
- dynamic-slice.344 = s32[1]{0} dynamic-slice(constant.2015, partition-id.3), dynamic_slice_sizes={1}
- bitcast.2095 = s32[] bitcast(dynamic-slice.344)
- dynamic-slice.345 = bf16[1024,16,128]{2,1,0} dynamic-slice(all-reduce.47, bitcast.2087, bitcast.2095, constant.1980), dynamic_slice_sizes={1024,16,128}
- bitcast.2102 = bf16[1,1024,16,128]{3,2,1,0} bitcast(dynamic-slice.345)
- collective-permute.8 = bf16[1,1024,16,128]{3,2,1,0} collective-permute(bitcast.2102), channel_id=108, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
- dynamic-update-slice.39 = bf16[24,1024,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.58, collective-permute.8, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
- ROOT tuple.2 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) tuple(add.581, bitcast.1525, dynamic-update-slice.28, dynamic-update-slice.29, collective-permute.9, /*index=5*/bitcast.1758, collective-permute.10, collective-permute.11, collective-permute.12, collective-permute.13, /*index=10*/dynamic-update-slice.36, bitcast.2025, dynamic-update-slice.38, dynamic-update-slice.39, get-tuple-element.45, /*index=15*/get-tuple-element.44, get-tuple-element.46, get-tuple-element.43, get-tuple-element.42, get-tuple-element.37, /*index=20*/get-tuple-element.36, get-tuple-element.38, get-tuple-element.35, get-tuple-element.41, get-tuple-element.40, /*index=25*/get-tuple-element.32, get-tuple-element.39, get-tuple-element.33)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- HloInstruction* fwd_instruction = nullptr;
- HloInstruction* bwd_instruction = nullptr;
- SCOPED_TRACE(m->ToString());
- for (HloInstruction* instr :
- m->entry_computation()->MakeInstructionPostOrder()) {
- if (instr->opcode() == HloOpcode::kCustomCall &&
- instr->custom_call_target() == kCudnnfMHASoftmaxCallTarget) {
- fwd_instruction = instr;
- }
- if (instr->opcode() == HloOpcode::kCustomCall &&
- instr->custom_call_target() == kCudnnfMHASoftmaxBackwardCallTarget) {
- bwd_instruction = instr;
- }
- }
- EXPECT_NE(fwd_instruction, nullptr);
- EXPECT_NE(bwd_instruction, nullptr);
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fwd_instruction->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
- BF16TrainingBmm2CanonicalizationRestoreFwdGraph) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- const char* module_str = R"(
-HloModule pjit__unnamed_function_, entry_computation_layout={(bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,4,256,256]{3,2,1,0})->(bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false,false}, num_partitions=4
-
-region_0.6 {
- Arg_0.7 = bf16[] parameter(0)
- Arg_1.8 = bf16[] parameter(1)
- ROOT maximum.5 = bf16[] maximum(Arg_0.7, Arg_1.8)
-}
-
-region_1.10 {
- Arg_0.11 = f32[] parameter(0)
- Arg_1.12 = f32[] parameter(1)
- ROOT add.14 = f32[] add(Arg_0.11, Arg_1.12)
-}
-
-add.clone {
- x.1 = u32[] parameter(0)
- y.1 = u32[] parameter(1)
- ROOT add.15 = u32[] add(x.1, y.1)
-}
-
-region_2.65 {
- Arg_0.66 = bf16[] parameter(0)
- Arg_1.67 = bf16[] parameter(1)
- ROOT add.16 = bf16[] add(Arg_0.66, Arg_1.67)
-}
-
-ENTRY main.164_spmd {
- param = bf16[2,256,4,64]{3,2,1,0} parameter(2), sharding={devices=[2,1,2,1]<=[4]}
- transpose.26 = bf16[2,4,64,256]{3,2,1,0} transpose(param), dimensions={0,2,3,1}
- param.1 = bf16[2,256,4,64]{3,2,1,0} parameter(0), sharding={devices=[2,1,2,1]<=[4]}
- transpose.27 = bf16[2,4,256,64]{3,2,1,0} transpose(param.1), dimensions={0,2,1,3}
- constant.46 = bf16[] constant(0.5)
- broadcast.126 = bf16[2,4,256,64]{3,2,1,0} broadcast(constant.46), dimensions={}
- multiply.34 = bf16[2,4,256,64]{3,2,1,0} multiply(transpose.27, broadcast.126)
- param.2 = bf16[2,256,4,64]{3,2,1,0} parameter(1), sharding={devices=[2,1,2,1]<=[4]}
- transpose.29 = bf16[2,4,64,256]{3,2,1,0} transpose(param.2), dimensions={0,2,3,1}
- dot.12 = bf16[2,4,256,256]{3,2,1,0} dot(multiply.34, transpose.29), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- param.3 = bf16[2,4,256,256]{3,2,1,0} parameter(4), sharding={devices=[2,2,1,1]<=[4]}
- add.17 = bf16[2,4,256,256]{3,2,1,0} add(dot.12, param.3)
- constant.47 = bf16[] constant(-inf)
- reduce.4 = bf16[2,4,256]{2,1,0} reduce(add.17, constant.47), dimensions={3}, to_apply=region_0.6
- broadcast.127 = bf16[2,4,256,256]{3,2,1,0} broadcast(reduce.4), dimensions={0,1,2}
- subtract.14 = bf16[2,4,256,256]{3,2,1,0} subtract(add.17, broadcast.127)
- exponential.2 = bf16[2,4,256,256]{3,2,1,0} exponential(subtract.14)
- convert.46 = f32[2,4,256,256]{3,2,1,0} convert(exponential.2)
- constant.48 = f32[] constant(0)
- reduce.5 = f32[2,4,256]{2,1,0} reduce(convert.46, constant.48), dimensions={3}, to_apply=region_1.10
- convert.47 = bf16[2,4,256]{2,1,0} convert(reduce.5)
- broadcast.128 = bf16[2,4,256,256]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2}
- divide.7 = bf16[2,4,256,256]{3,2,1,0} divide(exponential.2, broadcast.128)
- broadcast.129 = f32[4096]{0} broadcast(constant.48), dimensions={}
- constant.50 = u32[] constant(0)
- broadcast.131 = u32[8192]{0} broadcast(constant.50), dimensions={}
- broadcast.133 = u32[4096]{0} broadcast(constant.50), dimensions={}
- iota.3 = u32[8192]{0} iota(), iota_dimension=0
- slice.14 = u32[4096]{0} slice(iota.3), slice={[0:4096]}
- slice.15 = u32[4096]{0} slice(iota.3), slice={[4096:8192]}
- custom-call.3 = (u32[4096]{0}, u32[4096]{0}) custom-call(broadcast.133, broadcast.133, slice.14, slice.15), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[4096]{0}, u32[4096]{0}, u32[4096]{0}, u32[4096]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\020\000\000\000\000\000\000"
- get-tuple-element.6 = u32[4096]{0} get-tuple-element(custom-call.3), index=0
- constant.115 = u32[1]{0} constant({0})
- constant.52 = u32[4]{0} constant({0, 0, 1, 1})
- partition-id = u32[] partition-id()
- dynamic-slice.21 = u32[1]{0} dynamic-slice(constant.52, partition-id), dynamic_slice_sizes={1}
- constant.116 = u32[1]{0} constant({1})
- clamp.3 = u32[1]{0} clamp(constant.115, dynamic-slice.21, constant.116)
- convert.48 = s32[1]{0} convert(clamp.3)
- constant.117 = s32[1]{0} constant({2048})
- multiply.35 = s32[1]{0} multiply(convert.48, constant.117)
- bitcast.105 = s32[] bitcast(multiply.35)
- dynamic-slice.22 = u32[2048]{0} dynamic-slice(get-tuple-element.6, bitcast.105), dynamic_slice_sizes={2048}
- constant.58 = s32[4]{0} constant({0, 0, 1, 1})
- dynamic-slice.23 = s32[1]{0} dynamic-slice(constant.58, partition-id), dynamic_slice_sizes={1}
- multiply.36 = s32[1]{0} multiply(dynamic-slice.23, constant.117)
- bitcast.108 = s32[] bitcast(multiply.36)
- dynamic-update-slice.2 = u32[8192]{0} dynamic-update-slice(broadcast.131, dynamic-slice.22, bitcast.108)
- get-tuple-element.7 = u32[4096]{0} get-tuple-element(custom-call.3), index=1
- dynamic-slice.24 = u32[2048]{0} dynamic-slice(get-tuple-element.7, bitcast.105), dynamic_slice_sizes={2048}
- constant.65 = s32[] constant(4096)
- add.18 = s32[] add(bitcast.108, constant.65)
- dynamic-update-slice.3 = u32[8192]{0} dynamic-update-slice(dynamic-update-slice.2, dynamic-slice.24, add.18)
- all-reduce = u32[8192]{0} all-reduce(dynamic-update-slice.3), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.clone
- constant.118 = s32[1]{0} constant({4096})
- multiply.37 = s32[1]{0} multiply(dynamic-slice.23, constant.118)
- bitcast.119 = s32[] bitcast(multiply.37)
- dynamic-slice.25 = u32[4096]{0} dynamic-slice(all-reduce, bitcast.119), dynamic_slice_sizes={4096}
- constant.69 = u32[] constant(9)
- broadcast.134 = u32[4096]{0} broadcast(constant.69), dimensions={}
- shift-right-logical.6 = u32[4096]{0} shift-right-logical(dynamic-slice.25, broadcast.134)
- constant.70 = u32[] constant(1065353216)
- broadcast.135 = u32[4096]{0} broadcast(constant.70), dimensions={}
- or.5 = u32[4096]{0} or(shift-right-logical.6, broadcast.135)
- bitcast-convert.5 = f32[4096]{0} bitcast-convert(or.5)
- constant.71 = f32[] constant(-1)
- broadcast.136 = f32[4096]{0} broadcast(constant.71), dimensions={}
- add.19 = f32[4096]{0} add(bitcast-convert.5, broadcast.136)
- maximum.6 = f32[4096]{0} maximum(broadcast.129, add.19)
- constant.72 = f32[] constant(0.5)
- broadcast.137 = f32[4096]{0} broadcast(constant.72), dimensions={}
- compare.4 = pred[4096]{0} compare(maximum.6, broadcast.137), direction=LT
- bitcast.135 = pred[2,8,256]{2,1,0} bitcast(compare.4)
- convert.49 = bf16[2,8,256]{2,1,0} convert(bitcast.135)
- constant.80 = s32[] constant(0)
- constant.78 = s32[4]{0} constant({0, 4, 0, 4})
- dynamic-slice.26 = s32[1]{0} dynamic-slice(constant.78, partition-id), dynamic_slice_sizes={1}
- bitcast.181 = s32[] bitcast(dynamic-slice.26)
- dynamic-slice.27 = bf16[2,4,256]{2,1,0} dynamic-slice(convert.49, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
- broadcast.139 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.27), dimensions={0,1,3}
- multiply.38 = bf16[2,4,256,256]{3,2,1,0} multiply(divide.7, broadcast.139)
- constant.93 = bf16[] constant(2)
- broadcast.141 = bf16[2,4,256,256]{3,2,1,0} broadcast(constant.93), dimensions={}
- multiply.39 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.38, broadcast.141)
- dot.13 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.26, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- transpose.31 = bf16[4,2,64,256]{3,2,1,0} transpose(dot.13), dimensions={1,0,2,3}
- bitcast.154 = bf16[2,256,4,64]{1,3,0,2} bitcast(transpose.31)
- all-gather = bf16[2,256,8,64]{1,3,0,2} all-gather(bitcast.154), channel_id=2, replica_groups={{0,1},{2,3}}, dimensions={2}, use_global_device_ids=true
- bitcast.155 = bf16[8,2,64,256]{3,2,1,0} bitcast(all-gather)
- transpose.32 = bf16[2,8,64,256]{3,2,1,0} transpose(bitcast.155), dimensions={1,0,2,3}
- bitcast.157 = bf16[2,256,8,64]{1,3,2,0} bitcast(transpose.32)
- all-gather.1 = bf16[4,256,8,64]{1,3,2,0} all-gather(bitcast.157), channel_id=3, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true
- bitcast.236 = bf16[4,8,64,256]{3,2,1,0} bitcast(all-gather.1)
- transpose.38 = bf16[4,256,8,64]{3,2,1,0} transpose(bitcast.236), dimensions={0,3,1,2}
- param.4 = bf16[2,256,4,64]{3,2,1,0} parameter(3), sharding={devices=[2,1,2,1]<=[4]}
- transpose.33 = bf16[2,4,256,64]{3,2,1,0} transpose(param.4), dimensions={0,2,1,3}
- dot.14 = bf16[2,4,256,256]{3,2,1,0} dot(transpose.33, transpose.26), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- broadcast.142 = bf16[4096]{0} broadcast(constant.93), dimensions={}
- constant.95 = bf16[] constant(0)
- broadcast.143 = bf16[4096]{0} broadcast(constant.95), dimensions={}
- select.4 = bf16[4096]{0} select(compare.4, broadcast.142, broadcast.143)
- bitcast.176 = bf16[2,8,256]{2,1,0} bitcast(select.4)
- dynamic-slice.28 = bf16[2,4,256]{2,1,0} dynamic-slice(bitcast.176, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
- broadcast.145 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.28), dimensions={0,1,3}
- multiply.40 = bf16[2,4,256,256]{3,2,1,0} multiply(dot.14, broadcast.145)
- divide.8 = bf16[2,4,256,256]{3,2,1,0} divide(multiply.40, broadcast.128)
- constant.106 = bf16[] constant(1)
- broadcast.146 = bf16[2,4,256]{2,1,0} broadcast(constant.106), dimensions={}
- multiply.41 = bf16[2,4,256]{2,1,0} multiply(convert.47, convert.47)
- divide.9 = bf16[2,4,256]{2,1,0} divide(broadcast.146, multiply.41)
- broadcast.147 = bf16[2,4,256,256]{3,2,1,0} broadcast(divide.9), dimensions={0,1,2}
- multiply.42 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.40, broadcast.147)
- multiply.43 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.42, exponential.2)
- reduce.6 = bf16[2,4,256]{2,1,0} reduce(multiply.43, constant.95), dimensions={3}, to_apply=region_2.65
- negate.4 = bf16[2,4,256]{2,1,0} negate(reduce.6)
- broadcast.148 = bf16[2,4,256,256]{3,2,1,0} broadcast(negate.4), dimensions={0,1,2}
- add.20 = bf16[2,4,256,256]{3,2,1,0} add(divide.8, broadcast.148)
- multiply.44 = bf16[2,4,256,256]{3,2,1,0} multiply(add.20, exponential.2)
- transpose.34 = bf16[2,4,256,64]{3,2,1,0} transpose(param.2), dimensions={0,2,1,3}
- dot.15 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, transpose.34), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- multiply.45 = bf16[2,4,256,64]{3,2,1,0} multiply(dot.15, broadcast.126)
- transpose.39 = bf16[2,256,4,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
- dot.16 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, multiply.34), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.40 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.16), dimensions={0,2,1,3}
- transpose.36 = bf16[2,4,64,256]{3,2,1,0} transpose(param.4), dimensions={0,2,3,1}
- dot.11 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.36, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.41 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.11), dimensions={0,3,1,2}
- ROOT tuple.2 = (bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}) tuple(transpose.38, transpose.39, transpose.40, transpose.41)
-} // main.164_spmd
-)";
- // The dropout bwd pattern is not supported, so the fwd graph should not be lowered either.
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
- SCOPED_TRACE(m->ToString());
- // Check that the fwd graph has been restored with the cloned activation.
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Transpose(), m::Transpose(), m::Transpose(),
- m::Transpose(m::Dot(
- m::Op(), m::Op().WithPredicate([](const HloInstruction* instr) {
- return instr->name() == "multiply.39.fmha_no_match_clone";
- }))))));
-}
-
-constexpr absl::string_view hlo_head_dim_not_multiple_of_64 = R"(
-HloModule jit__reference, entry_computation_layout={(f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0})->f16[4,48,1024,16]{3,2,1,0}}
-
-region_0.26 {
- Arg_0.27 = f32[] parameter(0)
- Arg_1.28 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(Arg_0.27, Arg_1.28)
-}
-
-region_1.37 {
- Arg_0.38 = f32[] parameter(0)
- Arg_1.39 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.38, Arg_1.39)
-}
-
-ENTRY main.49 {
- iota.2 = s32[1024,1024]{1,0} iota(), iota_dimension=0
- iota.3 = s32[1024,1024]{1,0} iota(), iota_dimension=1
- compare = pred[1024,1024]{1,0} compare(iota.2, iota.3), direction=GE
- broadcast.4 = pred[4,48,1024,1024]{3,2,1,0} broadcast(compare), dimensions={2,3}
- Arg_0.1 = f16[4,48,1024,16]{3,2,1,0} parameter(0)
- Arg_1.2 = f16[4,48,1024,16]{3,2,1,0} parameter(1)
- dot.9 = f16[4,48,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- constant.4 = f16[] constant(0.5)
- broadcast.6 = f16[4,48,1024,1024]{3,2,1,0} broadcast(constant.4), dimensions={}
- multiply = f16[4,48,1024,1024]{3,2,1,0} multiply(dot.9, broadcast.6)
- convert.1 = f32[4,48,1024,1024]{3,2,1,0} convert(multiply)
- constant.7 = f32[] constant(-inf)
- reduce.30 = f32[4,48,1024]{2,1,0} reduce(convert.1, constant.7), dimensions={3}, to_apply=region_0.26
- broadcast.8 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.30), dimensions={0,1,2}
- subtract = f32[4,48,1024,1024]{3,2,1,0} subtract(convert.1, broadcast.8)
- exponential = f32[4,48,1024,1024]{3,2,1,0} exponential(subtract)
- constant.6 = f32[] constant(0)
- reduce.41 = f32[4,48,1024]{2,1,0} reduce(exponential, constant.6), dimensions={3}, to_apply=region_1.37
- broadcast.9 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.41), dimensions={0,1,2}
- divide = f32[4,48,1024,1024]{3,2,1,0} divide(exponential, broadcast.9)
- convert.2 = f16[4,48,1024,1024]{3,2,1,0} convert(divide)
- Arg_2.3 = f16[4,48,1024,16]{3,2,1,0} parameter(2)
- ROOT dot.48 = f16[4,48,1024,16]{3,2,1,0} dot(convert.2, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-} // main.49
-)";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, HeadDimNotMultipleOf64) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- TF_ASSERT_OK_AND_ASSIGN(
- auto m, ParseAndReturnVerifiedModule(hlo_head_dim_not_multiple_of_64,
- GetModuleConfig()));
- CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
- GetCudnnVersion()};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
- // A head dim that is not a multiple of 64 should not be lowered with cuDNN < 8.9.6
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot()));
-
- // It should be lowered with cuDNN >= 8.9.6
- CudnnFusedMHARewriter fusedMhaRewriterWithcuDNN8907{
- GetCudaComputeCapability(), se::dnn::VersionInfo(8, 9, 7)};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriterWithcuDNN8907, m.get()).status());
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
- .WithShape(F16, {4, 48, 1024, 16})));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(config.fmha_scale(), 0.5);
- EXPECT_EQ(config.dropout_rate(), 0.0);
-}
-
-constexpr absl::string_view hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})->(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true,true}
-
-region_0.14 {
- Arg_0.15 = bf16[] parameter(0)
- Arg_1.16 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(Arg_0.15, Arg_1.16)
-}
-
-region_1.27 {
- Arg_0.28 = f32[] parameter(0)
- Arg_1.29 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0.28, Arg_1.29)
-}
-
-region_2.56 {
- Arg_0.57 = bf16[] parameter(0)
- Arg_1.58 = bf16[] parameter(1)
- ROOT add.1 = bf16[] add(Arg_0.57, Arg_1.58)
-}
-
-ENTRY main.87 {
- Arg_2.3 = bf16[2,1024,4,64]{3,2,1,0} parameter(2)
- transpose.12 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
- Arg_0.1 = bf16[2,1024,4,64]{3,2,1,0} parameter(0)
- transpose.13 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
- Arg_1.2 = bf16[2,1024,4,64]{3,2,1,0} parameter(1)
- transpose.15 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
- dot = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.13, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- Arg_4.5 = bf16[4,1024,1024]{2,1,0} parameter(4)
- broadcast.9 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(Arg_4.5), dimensions={1,2,3}
- add.2 = bf16[2,4,1024,1024]{3,2,1,0} add(dot, broadcast.9)
- constant.10 = bf16[] constant(-inf)
- reduce.18 = bf16[2,4,1024]{2,1,0} reduce(add.2, constant.10), dimensions={3}, to_apply=region_0.14
- broadcast.10 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(reduce.18), dimensions={0,1,2}
- subtract = bf16[2,4,1024,1024]{3,2,1,0} subtract(add.2, broadcast.10)
- exponential = bf16[2,4,1024,1024]{3,2,1,0} exponential(subtract)
- convert.5 = f32[2,4,1024,1024]{3,2,1,0} convert(exponential)
- constant.9 = f32[] constant(0)
- reduce.31 = f32[2,4,1024]{2,1,0} reduce(convert.5, constant.9), dimensions={3}, to_apply=region_1.27
- convert.6 = bf16[2,4,1024]{2,1,0} convert(reduce.31)
- broadcast.11 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(convert.6), dimensions={0,1,2}
- divide.2 = bf16[2,4,1024,1024]{3,2,1,0} divide(exponential, broadcast.11)
- dot.1 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.12, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
- transpose.22 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
- Arg_3.4 = bf16[2,1024,4,64]{3,2,1,0} parameter(3)
- transpose.17 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,1,3}
- dot.2 = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.17, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- divide.3 = bf16[2,4,1024,1024]{3,2,1,0} divide(dot.2, broadcast.11)
- constant.0 = bf16[] constant(1)
- broadcast.13 = bf16[2,4,1024]{2,1,0} broadcast(constant.0), dimensions={}
- multiply.2 = bf16[2,4,1024]{2,1,0} multiply(convert.6, convert.6)
- divide.4 = bf16[2,4,1024]{2,1,0} divide(broadcast.13, multiply.2)
- broadcast.14 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(divide.4), dimensions={0,1,2}
- multiply.3 = bf16[2,4,1024,1024]{3,2,1,0} multiply(dot.2, broadcast.14)
- multiply.4 = bf16[2,4,1024,1024]{3,2,1,0} multiply(multiply.3, exponential)
- constant.8 = bf16[] constant(0)
- reduce.60 = bf16[2,4,1024]{2,1,0} reduce(multiply.4, constant.8), dimensions={3}, to_apply=region_2.56
- negate.1 = bf16[2,4,1024]{2,1,0} negate(reduce.60)
- broadcast.15 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(negate.1), dimensions={0,1,2}
- add.3 = bf16[2,4,1024,1024]{3,2,1,0} add(divide.3, broadcast.15)
- multiply.5 = bf16[2,4,1024,1024]{3,2,1,0} multiply(add.3, exponential)
- transpose.18 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,1,3}
- dot.4 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.18), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.23 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.4), dimensions={0,2,1,3}
- dot.3 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.13), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.24 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.3), dimensions={0,2,1,3}
- transpose.20 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,3,1}
- dot.49 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.20, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
- transpose.25 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.49), dimensions={0,3,1,2}
- reduce.81 = bf16[4,1024,1024]{2,1,0} reduce(multiply.5, constant.8), dimensions={0}, to_apply=region_2.56
- ROOT tuple = (bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0}) tuple(transpose.22, transpose.23, transpose.24, transpose.25, reduce.81)
-} // main.87
-)";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxBmm2PatternDbias) {
- if (skip_reason_) GTEST_SKIP() << *skip_reason_;
- TF_ASSERT_OK_AND_ASSIGN(
- auto m,
- ParseAndReturnVerifiedModule(hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias));
- // dbias requires cuDNN 8.9.6 and Hopper
- CudnnFusedMHARewriter fusedMhaRewriter{se::CudaComputeCapability(9, 0),
- se::dnn::VersionInfo(8, 9, 6)};
- TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- const HloInstruction* fmha;
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Transpose(
- m::Transpose(m::GetTupleElement(
- m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
- 0)))
- .WithShape(BF16, {2, 1024, 4, 64}),
- m::Transpose(
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 0))
- .WithShape(BF16, {2, 1024, 4, 64}),
- m::Transpose(
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 1))
- .WithShape(BF16, {2, 1024, 4, 64}),
- m::Transpose(
- m::Transpose(m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 2)))
- .WithShape(BF16, {2, 1024, 4, 64}),
- m::Reshape(
- m::GetTupleElement(
- m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
- 3))
- .WithShape(BF16, {4, 1024, 1024}))));
- TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
- EXPECT_EQ(fmha->operands().size(), 4);
- EXPECT_EQ(fmha->operand(3)->shape(),
- ShapeUtil::MakeShape(BF16, {1, 4, 1024, 1024}));
- EXPECT_EQ(config.fmha_scale(), 1.0);
- EXPECT_EQ(config.dropout_rate(), 0.0);
- EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
-}
-} // anonymous namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
deleted file mode 100644
index 665cc0b..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
+++ /dev/null
@@ -1,668 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <iterator>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-namespace m = match;
-
-bool IsFMHACustomCall(const HloInstruction* instr) {
- return IsCustomCallTofMHA(*instr);
-}
-
-bool IsFwdFMHACustomCall(const HloInstruction* instr) {
- return IsFwdCustomCallTofMHA(*instr);
-}
-
-bool IsBwdFMHACustomCall(const HloInstruction* instr) {
- return IsBwdCustomCallTofMHA(*instr);
-}
-
-absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
- HloInstruction* fmha, int64_t operand_index, bool is_lhs,
- bool should_contracting_be_fastest) {
- HloInstruction* transpose_arg = fmha->mutable_operand(operand_index);
- HloInstruction* transpose_arg_operand = transpose_arg->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- CudnnfMHABackendConfig config = gpu_config.cudnn_fmha_backend_config();
- CudnnfMHABackendConfig& new_fmha_config =
- *gpu_config.mutable_cudnn_fmha_backend_config();
-
- std::vector<int64_t> inverse_perm =
- InversePermutation(transpose_arg->dimensions());
- DotDimensionNumbers new_bmm_dot_dims;
- if (IsFwdCustomCallTofMHA(*fmha)) {
- if (operand_index == 0 || operand_index == 1) {
- new_bmm_dot_dims = config.bmm1_dot_dimension_numbers();
- } else {
- new_bmm_dot_dims = config.bmm2_dot_dimension_numbers();
- }
- } else {
- switch (operand_index) {
- case 0:
- // Q
- new_bmm_dot_dims = config.bmm1_grad_gemm1_dot_dimension_numbers();
- break;
- case 1:
- // K
- new_bmm_dot_dims = config.bmm1_grad_gemm2_dot_dimension_numbers();
- break;
- case 2:
- // V
- new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
- break;
- case 3:
- // Forward activation
- new_bmm_dot_dims = config.bmm2_grad_gemm1_dot_dimension_numbers();
- break;
- case 4:
- // Output gradient
- new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
- break;
- default:
- return Internal("Invalid operand index.");
- }
- }
- absl::Span<const int64_t> checked_dims;
- std::vector<int64_t> checked_dims_vec;
-
- // `should_contracting_be_fastest` indicates whether the contracting dim is
- // the head dim. cuDNN requires the head dim to be the fastest-moving dim.
- // Fwd bmm1 and bwd bmm2grad1 should set this value to true.
- if (should_contracting_be_fastest) {
- checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
- : new_bmm_dot_dims.rhs_contracting_dimensions();
- } else {
- absl::Span<const int64_t> batch_dims =
- is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
- : new_bmm_dot_dims.rhs_batch_dimensions();
- absl::Span<const int64_t> contracting_dims =
- is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
- : new_bmm_dot_dims.rhs_contracting_dimensions();
-
- TF_ASSIGN_OR_RETURN(checked_dims_vec,
- GetNonContractingDims(transpose_arg->shape(),
- batch_dims, contracting_dims));
- checked_dims = checked_dims_vec;
- }
-
- int64_t checked_dims_bmm_size = checked_dims.size();
- std::vector<int64_t> new_bmm_checked_dims(checked_dims_bmm_size);
- for (int i = 0; i < checked_dims_bmm_size; i++) {
- auto itr =
- std::find(inverse_perm.begin(), inverse_perm.end(), checked_dims[i]);
- if (itr == inverse_perm.end()) {
- return Internal("Invalid inverse perm");
- }
- new_bmm_checked_dims[i] = std::distance(inverse_perm.begin(), itr);
- }
- // We want to make sure that making the transpose's argument an input to
- // the fmha doesn't break the cuDNN constraint that the head dim of the
- // corresponding BMM operand is the fastest-moving dimension.
- // One exception is the forward activation, which is exempt from this
- // constraint since it does not have a head dim.
- absl::Span<const int64_t> minor_to_major_bmm =
- transpose_arg_operand->shape().layout().minor_to_major();
- if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) &&
- !(IsBwdCustomCallTofMHA(*fmha) && operand_index == 3)) {
- return false;
- }
- if (should_contracting_be_fastest) {
- if (is_lhs) {
- new_bmm_dot_dims.clear_lhs_contracting_dimensions();
- *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
- new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
- } else {
- new_bmm_dot_dims.clear_rhs_contracting_dimensions();
- *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
- new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
- }
- }
- auto& batch_dims = is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
- : new_bmm_dot_dims.rhs_batch_dimensions();
- int64_t batch_dims_bmm_size = batch_dims.size();
- std::vector<int64_t> new_bmm_batch_dims(batch_dims_bmm_size);
- for (int i = 0; i < batch_dims_bmm_size; i++) {
- auto itr =
- std::find(inverse_perm.begin(), inverse_perm.end(), batch_dims[i]);
- if (itr == inverse_perm.end()) {
- return Internal("Invalid inverse perm");
- }
- new_bmm_batch_dims[i] = std::distance(inverse_perm.begin(), itr);
- }
-
- if (is_lhs) {
- new_bmm_dot_dims.clear_lhs_batch_dimensions();
- *new_bmm_dot_dims.mutable_lhs_batch_dimensions() = {
- new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
-
- } else {
- new_bmm_dot_dims.clear_rhs_batch_dimensions();
- *new_bmm_dot_dims.mutable_rhs_batch_dimensions() = {
- new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
- }
-
- if (!should_contracting_be_fastest) {
- // Given the non-contracting dimensions, we can use the same function,
- // GetNonContractingDims, to find the new contracting dims. Simply pass the
- // non-contracting dimensions as the second argument.
- TF_ASSIGN_OR_RETURN(
- std::vector<int64_t> new_bmm_contracting_dims,
- GetNonContractingDims(transpose_arg_operand->shape(),
- new_bmm_batch_dims, new_bmm_checked_dims));
- if (is_lhs) {
- new_bmm_dot_dims.clear_lhs_contracting_dimensions();
- *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
- new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
-
- } else {
- new_bmm_dot_dims.clear_rhs_contracting_dimensions();
- *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
- new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
- }
- }
- if (IsFwdCustomCallTofMHA(*fmha)) {
- if (operand_index == 0 || operand_index == 1) {
- // Q or K
- *new_fmha_config.mutable_bmm1_dot_dimension_numbers() = new_bmm_dot_dims;
- } else {
- // V
- *new_fmha_config.mutable_bmm2_dot_dimension_numbers() = new_bmm_dot_dims;
- }
- } else {
- switch (operand_index) {
- case 0:
- // Q
- *new_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
- new_bmm_dot_dims;
- break;
- case 1:
- // K
- *new_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
- new_bmm_dot_dims;
- break;
- case 2:
- // V
- *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
- new_bmm_dot_dims;
- break;
- case 3:
- // Forward activation
- *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
- new_bmm_dot_dims;
- break;
- case 4: {
- // Output gradient
- *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
- new_bmm_dot_dims;
- DotDimensionNumbers bmm2_grad_gemm1_dot_dims =
- config.bmm2_grad_gemm1_dot_dimension_numbers();
- absl::Span<const int64_t> bmm2_grad_gemm1_contracting_dims =
- bmm2_grad_gemm1_dot_dims.rhs_contracting_dimensions();
- CHECK_EQ(bmm2_grad_gemm1_contracting_dims.size(), 1);
- absl::Span<const int64_t> transpose_permutation =
- transpose_arg->dimensions();
- auto itr = std::find(transpose_permutation.begin(),
- transpose_permutation.end(),
- bmm2_grad_gemm1_contracting_dims[0]);
- if (itr == transpose_permutation.end()) {
- return Internal(
- "bmm2 gradient gemm1 contracting dimension not found.");
- }
- int64_t index = std::distance(transpose_permutation.begin(), itr);
- std::vector<int64_t> new_bmm2_grad_gemm1_rhs_contracting_dims = {index};
- // Find the new batch dimensions by passing the new contracting
- // dimensions and the lhs contracting dimension of bmm2_grad_gemm2
- // (which is the non-contracting dimension of the rhs of
- // bmm2_grad_gemm1) to GetNonContractingDims.
- TF_ASSIGN_OR_RETURN(
- std::vector<int64_t> new_bmm2_grad_gemm1_rhs_batch_dims,
- GetNonContractingDims(
- transpose_arg_operand->shape(),
- new_bmm2_grad_gemm1_rhs_contracting_dims,
- new_bmm_dot_dims.lhs_contracting_dimensions()));
- bmm2_grad_gemm1_dot_dims.clear_rhs_contracting_dimensions();
- bmm2_grad_gemm1_dot_dims.clear_rhs_batch_dimensions();
- *bmm2_grad_gemm1_dot_dims.mutable_rhs_contracting_dimensions() = {
- new_bmm2_grad_gemm1_rhs_contracting_dims.begin(),
- new_bmm2_grad_gemm1_rhs_contracting_dims.end()};
- *bmm2_grad_gemm1_dot_dims.mutable_rhs_batch_dimensions() = {
- new_bmm2_grad_gemm1_rhs_batch_dims.begin(),
- new_bmm2_grad_gemm1_rhs_batch_dims.end()};
- *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
- bmm2_grad_gemm1_dot_dims;
- break;
- }
- default:
- return Internal("Invalid operand index.");
- }
- }
-
- TF_RETURN_IF_ERROR(fmha->set_backend_config(gpu_config));
-
- TF_RETURN_IF_ERROR(fmha->ReplaceOperandWithDifferentShape(
- operand_index, transpose_arg_operand));
-
- return true;
-}
-
-/* Let's say A is transposed to B with perm {3, 0, 2, 1} as shown below:
-A[16, 256, 32, 64]
- |
- |
- | Transpose with perm = {3, 0, 2, 1}
- |
- \/
-B[64, 16, 32, 256]
-
-The inverse perm to obtain A from B would be {1, 3, 2, 0}. That is
-B[64, 16, 32, 256]
- |
- |
- | Transpose' with inv_perm = {1, 3, 2, 0}
- |
- \/
-A[16, 256, 32, 64]
-
-Now, let's say B is the lhs of a BatchedMatmul and the lhs_contracting
-dim is 3 (i.e. the dim of size 256). In order to now make A the lhs of the
-BatchedMatmul (thus consuming the Transpose from A->B), we need to find
-the dimension number in A that corresponds to dimension number 3 in B.
-This can be done by finding the index of dim num 3 in inv_perm. That
-would be 1. Hence, dim num 3 in B is equivalent to dim num 1 in A. Thus
-the new lhs_contracting dim, if A were to be the new lhs, would be 1.
-
-Similarly, we need to find corresponding batch dimensions as well.
-*/
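A minimal standalone sketch of the lookup described above (illustrative only; NewContractingDim is a hypothetical helper, not part of this pass). It mirrors the InversePermutation + std::find steps used in FuseArgPrologueTransposeWithcuDNNFMHA and reproduces the worked example:

#include <algorithm>
#include <cstdint>
#include <vector>

// Returns the dimension of the transpose operand (A) that corresponds to
// dimension `transposed_dim` of the transpose result (B), where
// B[i] = A[perm[i]]. Hypothetical helper for illustration only.
int64_t NewContractingDim(const std::vector<int64_t>& perm,
                          int64_t transposed_dim) {
  // Inverse permutation: inverse_perm[perm[i]] = i.
  std::vector<int64_t> inverse_perm(perm.size());
  for (int64_t i = 0; i < static_cast<int64_t>(perm.size()); ++i) {
    inverse_perm[perm[i]] = i;
  }
  // The new dimension is the position of `transposed_dim` in inverse_perm.
  auto it =
      std::find(inverse_perm.begin(), inverse_perm.end(), transposed_dim);
  return std::distance(inverse_perm.begin(), it);
}

// NewContractingDim({3, 0, 2, 1}, /*transposed_dim=*/3) == 1: B's contracting
// dim 3 (size 256) corresponds to A's dim 1 (size 256), as in the example
// above.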
-absl::StatusOr<bool> FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction *transpose_arg0, *transpose_arg0_operand;
- HloInstruction *transpose_arg1, *transpose_arg1_operand;
- HloInstruction *transpose_arg2, *transpose_arg2_operand;
- HloInstruction *transpose_arg3, *transpose_arg3_operand;
- HloInstruction *transpose_arg4, *transpose_arg4_operand;
-
- HloInstruction* fmha;
-
- // Arg0 is common between forward and backward fmha calls, so we match
- // either of these.
- auto pattern_arg0 =
- m::Op(&fmha)
- .WithPredicate(IsFMHACustomCall)
- .WithOperand(0, m::Transpose(&transpose_arg0,
- m::Op(&transpose_arg0_operand)));
- if (Match(instr, pattern_arg0)) {
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 0: \n"
- << comp->parent()->ToString();
- }
- if (IsFwdFMHACustomCall(fmha)) {
- // Q tensor in forward graph is lhs with constraint on contracting dim.
- TF_ASSIGN_OR_RETURN(changed,
- FuseArgPrologueTransposeWithcuDNNFMHA(
- fmha, 0, true /*is_lhs*/,
- true /*should_contracting_be_fastest*/));
- } else {
- // Q tensor in backward graph is rhs with constraint on non-contracting
- // dim.
- TF_ASSIGN_OR_RETURN(changed,
- FuseArgPrologueTransposeWithcuDNNFMHA(
- fmha, 0, false /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
- }
-
- if (changed && VLOG_IS_ON(2)) {
- VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 0: \n"
- << comp->parent()->ToString();
- }
- }
-
- // Arg1 is common between forward and backward fmha calls, so we match
- // either of these.
- auto pattern_arg1 =
- m::Op(&fmha)
- .WithPredicate(IsFMHACustomCall)
- .WithOperand(1, m::Transpose(&transpose_arg1,
- m::Op(&transpose_arg1_operand)));
- if (Match(instr, pattern_arg1)) {
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 1: \n"
- << comp->parent()->ToString();
- }
- if (IsFwdFMHACustomCall(fmha)) {
- // K tensor in forward graph is rhs with constraint on contracting dim.
- TF_ASSIGN_OR_RETURN(changed,
- FuseArgPrologueTransposeWithcuDNNFMHA(
- fmha, 1, false /*is_lhs*/,
- true /*should_contracting_be_fastest*/));
- } else {
- // K tensor in backward graph is rhs with constraint on non-contracting
- // dim.
- TF_ASSIGN_OR_RETURN(changed,
- FuseArgPrologueTransposeWithcuDNNFMHA(
- fmha, 1, false /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
- }
-
- if (changed && VLOG_IS_ON(2)) {
- VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 1: \n"
- << comp->parent()->ToString();
- }
- }
-
- // Arg2 is common between forward and backward fmha calls, so we match
- // either of these.
- auto pattern_arg2 =
- m::Op(&fmha)
- .WithPredicate(IsFMHACustomCall)
- .WithOperand(2, m::Transpose(&transpose_arg2,
- m::Op(&transpose_arg2_operand)));
- if (Match(instr, pattern_arg2)) {
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 2: \n"
- << comp->parent()->ToString();
- }
- if (IsFwdFMHACustomCall(fmha)) {
- // V tensor in forward graph is rhs with constraint on non-contracting
- // dim.
- TF_ASSIGN_OR_RETURN(changed,
- FuseArgPrologueTransposeWithcuDNNFMHA(
- fmha, 2, false /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
- } else {
- // V tensor in backward graph is rhs with constraint on contracting dim.
- TF_ASSIGN_OR_RETURN(changed,
- FuseArgPrologueTransposeWithcuDNNFMHA(
- fmha, 2, false /*is_lhs*/,
- true /*should_contracting_be_fastest*/));
- }
-
- if (changed && VLOG_IS_ON(2)) {
- VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 2: \n"
- << comp->parent()->ToString();
- }
- }
-
- // We only care about arg3 of backward
- auto pattern_arg3 =
- m::Op(&fmha)
- .WithPredicate(IsBwdFMHACustomCall)
- .WithOperand(3, m::Transpose(&transpose_arg3,
- m::Op(&transpose_arg3_operand)));
- if (Match(instr, pattern_arg3)) {
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 3: \n"
- << comp->parent()->ToString();
- }
- // Forward activation tensor in backward graph is lhs with constraint on
- // non-contracting dim.
- TF_ASSIGN_OR_RETURN(changed,
- FuseArgPrologueTransposeWithcuDNNFMHA(
- fmha, 3, true /*is_lhs*/,
- false /*should_contracting_be_fastest*/));
-
- if (changed && VLOG_IS_ON(2)) {
- VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 3: \n"
- << comp->parent()->ToString();
- }
- }
-
- // We only care about arg4 of backward
- auto pattern_arg4 =
- m::Op(&fmha)
- .WithPredicate(IsBwdFMHACustomCall)
- .WithOperand(4, m::Transpose(&transpose_arg4,
- m::Op(&transpose_arg4_operand)));
- if (Match(instr, pattern_arg4)) {
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 4: \n"
- << comp->parent()->ToString();
- }
- // D_output tensor in backward graph is lhs with constraint on
- // contracting dim.
- // Make sure we don't change the layout of dO in the flash attention case,
- // as dO should have the same layout as O.
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- fmha->backend_config<GpuBackendConfig>());
- if (changed && VLOG_IS_ON(2)) {
- VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 4: \n"
- << comp->parent()->ToString();
- }
- }
- }
- return changed;
-}
-
-/* Let's say FMHA_out is transposed to result with perm {1, 2, 0, 3} as shown
-below:
-FMHA_out[b0, b1, n, m]{}
- |
- |
- Transpose with perm = {1, 2, 0, 3}
- |
- \/
-result[b1, n, b0, m]{1, 0, 3, 2}
-The goal is to find the minor_to_major of 'FMHA_out' such that its physical
-layout matches the physical layout of 'result', thus eliminating the need for an
-explicit transpose. cuDNN can perform an implicit transpose by knowing the
-corresponding strides (inferred from the corresponding minor_to_major).
-
-In order to find the required minor_to_major of 'FMHA_out', we first determine
-the inverse perm to obtain 'FMHA_out' from 'result'. The function
-"ShapeUtil::PermuteDimensions" generates a transposed shape such that the
-physical layout of the transposed shape is equivalent to the input shape.
-Calling this function with 'result' shape as the input shape and the inverse
-perm as the permutation will generate an output shape whose dimensions match
-'FMHA_out' dimensions but the physical layout is equivalent to 'result'. This is
-exactly what we want.
-
-The FMHA output should have exactly one gte instruction for a given tuple index
-so that we can safely fuse the transpose following that gte into the FMHA.
-
-FMHA_out = gte(FMHA, index=0)
-FMHA_out_t = transpose(FMHA_out)
-use(FMHA_out_t)
-
-after fusion:
-
-FMHA_out_t = gte(FMHA, index=0)
-use(FMHA_out_t)
-*/
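A small sketch of the shape bookkeeping described above (illustrative only; RecoverFmhaOutShape and ShapeSketch are hypothetical names, not part of this pass). The actual code delegates this to ShapeUtil::PermuteDimensions and then checks that the minor-most dimension of the expected shape is the last one before fusing:

#include <cstdint>
#include <vector>

struct ShapeSketch {
  std::vector<int64_t> dims;            // logical dimension sizes
  std::vector<int64_t> minor_to_major;  // layout, minor-most dimension first
};

// `perm` is the transpose permutation, i.e. result[j] = fmha_out[perm[j]].
ShapeSketch RecoverFmhaOutShape(const std::vector<int64_t>& perm,
                                const std::vector<int64_t>& result_dims,
                                const std::vector<int64_t>& result_m2m) {
  const int64_t rank = static_cast<int64_t>(perm.size());
  std::vector<int64_t> inverse_perm(rank);
  for (int64_t i = 0; i < rank; ++i) inverse_perm[perm[i]] = i;
  ShapeSketch out;
  out.dims.resize(rank);
  out.minor_to_major.resize(rank);
  for (int64_t i = 0; i < rank; ++i) {
    // Size of FMHA_out dimension i.
    out.dims[i] = result_dims[inverse_perm[i]];
    // FMHA_out dimension sitting at minor position i of the result layout.
    out.minor_to_major[i] = perm[result_m2m[i]];
  }
  return out;
}

// For perm = {1, 2, 0, 3} and result[b1, n, b0, m]{1, 0, 3, 2} this yields
// FMHA_out dims [b0, b1, n, m] with minor_to_major {2, 1, 3, 0}; since the
// minor-most dimension (2) is not the last one (3), the fusion would be
// skipped for this particular example.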
-
-absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
- bool changed = false;
-
- auto only_one_gte_with_spec_index = [](const HloInstruction* instr,
- int64_t index) {
- int count = 0;
- for (auto user : instr->users()) {
- if (user->opcode() == HloOpcode::kGetTupleElement &&
- user->tuple_index() == index) {
- count += 1;
- }
- }
- return count == 1;
- };
-
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* fmha;
- HloInstruction* transpose;
- HloInstruction* gte;
- auto fwd_tuple_elem =
- m::GetTupleElement(>e,
- m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0)
- .WithOneUser();
- // Note that we don't match any specific tuple index in matcher for
- // backward.
- auto bwd_tuple_elem =
- m::GetTupleElement(>e,
- m::Op(&fmha).WithPredicate(IsBwdFMHACustomCall))
- .WithOneUser();
- auto fwd_pattern = m::Transpose(&transpose, fwd_tuple_elem);
- auto bwd_pattern = m::Transpose(&transpose, bwd_tuple_elem);
-
- if (Match(instr, fwd_pattern)) {
- // Check that only one gte with this tuple index exists.
- int64_t tuple_index = gte->tuple_index();
- if (!only_one_gte_with_spec_index(fmha, tuple_index)) continue;
-
- std::vector<int64_t> inverse_perm =
- InversePermutation(transpose->dimensions());
-
- auto expected_fmha_shape =
- ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
-
- // cuDNN requires the last dimension of the output to be the fastest
- // moving.
- if (expected_fmha_shape.layout().minor_to_major()[0] !=
- expected_fmha_shape.dimensions_size() - 1) {
- VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
- "the fastest moving. The last dimension is dim: "
- << expected_fmha_shape.dimensions_size() - 1
- << " but upon fusion with the transpose, the fmha output shape "
- "would have been "
- << expected_fmha_shape.ToString(true)
- << " and the fastest moving "
- "dimension would be dim: "
- << expected_fmha_shape.layout().minor_to_major()[0];
- continue;
- }
- Shape call_shape = fmha->shape();
- *call_shape.mutable_tuple_shapes(0) = expected_fmha_shape;
- HloInstruction* new_fmha_custom_call =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- call_shape, fmha->operands(),
- absl::string_view(fmha->custom_call_target())));
-
- TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
- fmha->backend_config<GpuBackendConfig>());
- TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
- TF_RETURN_IF_ERROR(
- SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
- new_fmha_custom_call->set_metadata(fmha->metadata());
-
- auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
- new_fmha_custom_call->shape().tuple_shapes(0), new_fmha_custom_call,
- 0));
- TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
- instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
- TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
-
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "After forward FuseEpilogueTransposeWithcuDNNFMHA: \n"
- << comp->parent()->ToString();
- }
- changed |= true;
- } else if (Match(instr, bwd_pattern)) {
- // Check that only one gte with this tuple index exists.
- int64_t operand_tuple_idx = gte->tuple_index();
- if (!only_one_gte_with_spec_index(fmha, operand_tuple_idx)) continue;
-
- std::vector<int64_t> inverse_perm =
- InversePermutation(transpose->dimensions());
-
- auto expected_fmha_shape =
- ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
-
- // cuDNN requires the last dimension of the output to be the fastest
- // moving.
- if (expected_fmha_shape.layout().minor_to_major()[0] !=
- expected_fmha_shape.dimensions_size() - 1) {
- VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
- "the fastest moving. The last dimension is dim: "
- << expected_fmha_shape.dimensions_size() - 1
- << " but upon fusion with the transpose, the fmha output shape "
- "would have been "
- << expected_fmha_shape.ToString(true)
- << " and the fastest moving "
- "dimension would be dim: "
- << expected_fmha_shape.layout().minor_to_major()[0];
- continue;
- }
- Shape call_shape = fmha->shape();
- *call_shape.mutable_tuple_shapes(operand_tuple_idx) = expected_fmha_shape;
- HloInstruction* new_fmha_custom_call =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- call_shape, fmha->operands(),
- absl::string_view(fmha->custom_call_target())));
-
- TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
- fmha->backend_config<GpuBackendConfig>());
- TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
- TF_RETURN_IF_ERROR(
- SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
- new_fmha_custom_call->set_metadata(fmha->metadata());
- TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
-
- auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
- new_fmha_custom_call->shape().tuple_shapes(operand_tuple_idx),
- new_fmha_custom_call, operand_tuple_idx));
- TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
- instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
-
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "After backward FuseEpilogueTransposeWithcuDNNFMHA: \n"
- << comp->parent()->ToString();
- }
- changed |= true;
- }
- }
- return changed;
-}
-} // namespace
-
-absl::StatusOr<bool> CudnnFusedMHATransposeFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool any_changed = false;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- bool changed = false;
- TF_ASSIGN_OR_RETURN(changed, FusePrologueTransposeWithcuDNNFMHA(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseEpilogueTransposeWithcuDNNFMHA(comp));
- any_changed |= changed;
- }
-
- return any_changed;
-}
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h b/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h
deleted file mode 100644
index 94ec229..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-class CudnnFusedMHATransposeFusion : public HloModulePass {
- public:
- CudnnFusedMHATransposeFusion() = default;
-
- absl::string_view name() const override {
- return "cudnn-fused-multi-headed-attention-transpose-fusion";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
deleted file mode 100644
index 1806796..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
+++ /dev/null
@@ -1,734 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/primitive_util.h"
-#include "xla/service/dump.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cudnn_support_utils.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/kernel_reuse_cache.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/triton_fusion_analysis.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/cuda/cuda_dnn.h"
-#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-namespace fe = cudnn_frontend;
-namespace graph = fe::graph;
-
-inline std::optional<fe::PointwiseMode_t> GetElementwiseMode(
- const HloInstruction& instruction) {
- const HloOpcode opcode = instruction.opcode();
- using m = fe::PointwiseMode_t;
- switch (opcode) {
- case HloOpcode::kAbs:
- return m::ABS;
- case HloOpcode::kAdd:
- return m::ADD;
- case HloOpcode::kCeil:
- return m::CEIL;
- case HloOpcode::kCompare:
- switch (instruction.comparison_direction()) {
- case Comparison::Direction::kEq:
- return m::CMP_EQ;
- case Comparison::Direction::kNe:
- return m::CMP_NEQ;
- case Comparison::Direction::kGe:
- return m::CMP_GE;
- case Comparison::Direction::kGt:
- return m::CMP_GT;
- case Comparison::Direction::kLe:
- return m::CMP_LE;
- case Comparison::Direction::kLt:
- return m::CMP_LT;
- }
- break;
- case HloOpcode::kConvert:
- return m::IDENTITY;
- case HloOpcode::kCos:
- return m::COS;
- case HloOpcode::kDivide:
- return m::DIV;
- case HloOpcode::kExp:
- return m::EXP;
- case HloOpcode::kFloor:
- return m::FLOOR;
- case HloOpcode::kLog:
- return m::LOG;
- case HloOpcode::kMaximum:
- return m::MAX;
- case HloOpcode::kMinimum:
- return m::MIN;
- case HloOpcode::kMultiply:
- return m::MUL;
- case HloOpcode::kNegate:
- return m::NEG;
- case HloOpcode::kPower:
- return m::POW;
- case HloOpcode::kRsqrt:
- return m::RSQRT;
-#if CUDNN_VERSION >= 90100
- case HloOpcode::kSelect:
- return m::BINARY_SELECT;
-#endif // CUDNN_VERSION
- case HloOpcode::kSin:
- return m::SIN;
- case HloOpcode::kSqrt:
- return m::SQRT;
- case HloOpcode::kSubtract:
- return m::SUB;
- case HloOpcode::kTan:
- return m::TAN;
- case HloOpcode::kTanh:
- return m::TANH_FWD;
- default:
- return std::nullopt;
- }
-}
-
-inline std::optional<fe::DataType_t> ToCudnnDataType(const PrimitiveType type) {
- using t = fe::DataType_t;
- switch (type) {
- case PrimitiveType::F32:
- return t::FLOAT;
- case PrimitiveType::F16:
- return t::HALF;
- case PrimitiveType::BF16:
- return t::BFLOAT16;
- case PrimitiveType::S32:
- return t::INT32;
- case PrimitiveType::S8:
- return t::INT8;
- case PrimitiveType::PRED:
- return t::INT8;
- case PrimitiveType::F8E5M2:
- return t::FP8_E5M2;
- case PrimitiveType::F8E4M3FN:
- return t::FP8_E4M3;
- default:
- return std::nullopt;
- }
-}
-
-inline std::optional<fe::DataType_t> GetComputeDataType(
- const PrimitiveType type) {
- fe::DataType_t compute_dtype = fe::DataType_t::FLOAT;
- if (primitive_util::IsIntegralType(type)) {
-#if CUDNN_VERSION >= 90100
- compute_dtype = fe::DataType_t::INT32;
-#else
- VLOG(3) << "Integer math requires cuDNN 9.1+.";
- return std::nullopt;
-#endif // CUDNN_VERSION
- }
- return compute_dtype;
-}
-
-int FusionLevel(const HloInstruction& hlo) {
- return hlo.GetModule()
- ->config()
- .debug_options()
- .xla_gpu_cudnn_gemm_fusion_level();
-};
-
-// Extracts dimensions and strides from HLO tensors in the format expected by
-// cuDNN.
-class GemmDimensionAdapter {
- explicit GemmDimensionAdapter(const HloDotInstruction& dot,
- TritonFusionAnalysis analysis)
- : analysis_(std::move(analysis)), dot_(dot) {};
-
- public:
- const TritonFusionAnalysis analysis_;
-
- static absl::StatusOr<std::optional<GemmDimensionAdapter>> Create(
- const HloComputation& computation) {
- const HloInstruction* maybe_dot =
- hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot);
- if (maybe_dot == nullptr) {
- VLOG(3) << "Not a GEMM fusion.";
- return std::nullopt;
- }
- const HloDotInstruction* dot = DynCast<HloDotInstruction>(
- hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot));
- if (absl::c_any_of(dot->precision_config().operand_precision(),
- [](int x) { return x != PrecisionConfig::DEFAULT; })) {
- VLOG(3) << "Non-default precision is not supported.";
- return std::nullopt;
- }
- TF_ASSIGN_OR_RETURN(auto analysis,
- TritonFusionAnalysis::Execute(computation));
- return GemmDimensionAdapter{*dot, std::move(analysis)};
- }
-
- bool DimensionsAndStrides(const HloInstruction& hlo,
- const TritonFusionAnalysis::Scope scope,
- std::vector<int64_t>& dimensions,
- std::vector<int64_t>& strides) {
- const DotDimensionNumbers& dims = dot_.dot_dimension_numbers();
- // GEMM fusions require a specific canonical order of dimensions.
- constexpr int kBatchDimensionIndex = 0;
- constexpr int kOutputLHSNonContractingDimensionIndex = 1;
- std::vector<int64_t> dim_indices;
- int lhs_noncontracting_index = -1;
- switch (scope) {
- case TritonFusionAnalysis::Scope::LHS:
- lhs_noncontracting_index =
- GetNonContractingDims(dot_.operand(0)->shape(),
- dims.lhs_batch_dimensions(),
- dims.lhs_contracting_dimensions())
- .value()[0];
- dim_indices = {
- dims.lhs_batch_dimensions().empty() ? -1
- : dims.lhs_batch_dimensions(0),
- lhs_noncontracting_index, dims.lhs_contracting_dimensions(0)};
- break;
- case TritonFusionAnalysis::Scope::RHS:
- dim_indices = {dims.rhs_batch_dimensions().empty()
- ? -1
- : dims.rhs_batch_dimensions(0),
- dims.rhs_contracting_dimensions(0),
- GetNonContractingDims(dot_.operand(1)->shape(),
- dims.rhs_batch_dimensions(),
- dims.rhs_contracting_dimensions())
- .value()[0]};
- break;
- case TritonFusionAnalysis::Scope::OUTPUT:
- lhs_noncontracting_index = dot_.shape().rank() - 2;
- dim_indices = {dims.lhs_batch_dimensions().empty() ? -1 : 0,
- lhs_noncontracting_index, dot_.shape().rank() - 1};
- break;
- case TritonFusionAnalysis::Scope::META:
- LOG(FATAL) << "Unsupported scope.";
- }
- dimensions.reserve(dim_indices.size());
- strides.reserve(dim_indices.size());
- for (const int index : dim_indices) {
- const auto* spec = analysis_.IterSpec(scope, &hlo, index);
- if (spec == nullptr) {
- dimensions.push_back(1);
- strides.push_back(strides.empty() ? 1 : strides.back());
- continue;
- } else {
- if (spec->size() == 1) {
- // The dimension is not split, nothing to do.
- } else if (spec->size() == 2) {
- if (FusionLevel(hlo) < 3) {
- return false;
- }
- if (!dims.lhs_batch_dimensions().empty()) {
- VLOG(8) << "Noncontracting dimension split is not compatible with "
- "batch dimensions.";
- return false;
- }
- if (index != lhs_noncontracting_index) {
- VLOG(8) << "Only LHS noncontracting dimension can be split.";
- return false;
- }
- switch (scope) {
- case TritonFusionAnalysis::Scope::LHS:
- lhs_noncontracting_split_ = spec->back().count;
- break;
- case TritonFusionAnalysis::Scope::OUTPUT:
- if (lhs_noncontracting_split_ != spec->back().count) {
- VLOG(8) << "Output non-contracting dimension has to be split "
- "the same way as the LHS input one if it is split.";
- return false;
- }
- break;
- default:
- VLOG(8) << "Only LHS noncontracting dimension can be split.";
- return false;
- }
- // Assign the major part of the noncontracting dimension to the
- // unused batch one.
- CHECK_EQ(dimensions[kBatchDimensionIndex], 1);
- dimensions[kBatchDimensionIndex] = spec->back().count;
- strides[kBatchDimensionIndex] = spec->back().stride;
- } else {
- VLOG(8) << "The dimension is split multiple times.";
- return false;
- }
- dimensions.push_back(spec->front().count);
- strides.push_back(spec->front().stride);
- }
- }
- if (lhs_noncontracting_split_ > 1 &&
- scope == TritonFusionAnalysis::Scope::OUTPUT &&
- dimensions[kBatchDimensionIndex] == 1) {
- // LHS input noncontracting dimension is split but the corresponding
- // output one is not. Assign part of the output one to the unused batch
- // dimension.
- dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_;
- dimensions[kOutputLHSNonContractingDimensionIndex] /=
- lhs_noncontracting_split_;
- strides[kBatchDimensionIndex] =
- strides[kOutputLHSNonContractingDimensionIndex] *
- dimensions[kOutputLHSNonContractingDimensionIndex];
- }
- return true;
- }
-
- private:
- int64_t lhs_noncontracting_split_ = 1;
- const HloDotInstruction& dot_;
-};
-
-template <PrimitiveType XlaT, typename T>
-std::shared_ptr<graph::Tensor_attributes> LiteralToCudnnTensor(
- const HloInstruction& hlo, graph::Graph& graph) {
- using NativeT = typename primitive_util::PrimitiveTypeToNative<XlaT>::type;
- return graph.tensor(T(hlo.literal().GetFirstElement<NativeT>()));
-}
-
-std::optional<std::shared_ptr<graph::Tensor_attributes>>
-HandleConstantHloToCudnnGraph(const HloInstruction& hlo, graph::Graph& graph) {
- CHECK(hlo.IsConstant()) << "HLO is not a constant: " << hlo.ToShortString();
- if (!ShapeUtil::IsScalar(hlo.shape())) {
- VLOG(3) << "Currently only fusing scalars into the graph is supported";
- return std::nullopt;
- }
- PrimitiveType constant_type = hlo.shape().element_type();
- switch (constant_type) {
- case BF16:
- return LiteralToCudnnTensor<BF16, __nv_bfloat16>(hlo, graph);
- case F32:
- return LiteralToCudnnTensor<F32, float>(hlo, graph);
- case S32:
- return LiteralToCudnnTensor<S32, int>(hlo, graph);
- default:
- VLOG(3) << "Unsupported constant type: "
- << PrimitiveType_Name(constant_type);
- return std::nullopt;
- }
-}
-
-std::optional<std::shared_ptr<graph::Tensor_attributes>>
-HandleClampToCudnnGraph(
- const HloInstruction& hlo, graph::Graph& graph,
- absl::flat_hash_map<const HloInstruction*,
- std::shared_ptr<graph::Tensor_attributes>>
- hlo_to_cudnn,
- fe::DataType_t data_type, fe::DataType_t compute_dtype) {
- CHECK(hlo.opcode() == HloOpcode::kClamp)
- << "HLO is not a clamp: " << hlo.ToShortString();
- CHECK(hlo.operands().size() == 3)
- << "Clamp requires 3 operands: " << hlo.ToShortString();
- // clamp = max(lower, min(value, upper));
- const auto min_attrs = graph::Pointwise_attributes()
- .set_mode(fe::PointwiseMode_t::MIN)
- .set_compute_data_type(compute_dtype);
- std::shared_ptr<graph::Tensor_attributes> min_tensor = graph.pointwise(
- hlo_to_cudnn[hlo.operand(1)], hlo_to_cudnn[hlo.operand(2)], min_attrs);
- min_tensor->set_data_type(data_type).set_name(std::string(hlo.name()));
- const auto max_attrs = graph::Pointwise_attributes()
- .set_mode(fe::PointwiseMode_t::MAX)
- .set_compute_data_type(compute_dtype);
- return graph.pointwise(min_tensor, hlo_to_cudnn[hlo.operand(0)], max_attrs);
-}
-
-// Traverses fusion computations and creates cuDNN graphs out of them.
-absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
- const HloFusionInstruction& fusion) {
- const HloComputation& computation = *fusion.fused_instructions_computation();
- VLOG(5) << fusion.ToString();
- VLOG(5) << computation.ToString();
- graph::Graph graph;
- std::vector<HloInstruction*> instructions =
- computation.MakeInstructionPostOrder();
- absl::flat_hash_map<const HloInstruction*,
- std::shared_ptr<graph::Tensor_attributes>>
- hlo_to_cudnn;
- TF_ASSIGN_OR_RETURN(std::optional<GemmDimensionAdapter> adapter,
- GemmDimensionAdapter::Create(computation));
- if (!adapter.has_value()) {
- return std::nullopt;
- }
- auto add_parameter = [&](const HloInstruction& parameter,
- std::vector<int64_t>& dimensions,
- std::vector<int64_t> strides) {
- const std::optional<fe::DataType_t> data_type =
- ToCudnnDataType(parameter.shape().element_type());
- if (!data_type.has_value()) {
- VLOG(3) << "Unsupported data type.";
- return false;
- }
- hlo_to_cudnn[¶meter] = graph.tensor(
- graph::Tensor_attributes()
- .set_dim(dimensions)
- .set_stride(strides)
- .set_data_type(*data_type)
- .set_name(std::string(parameter.name()))
- .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number())));
- return true;
- };
- for (const TritonFusionAnalysis::Scope scope :
- {TritonFusionAnalysis::Scope::LHS, TritonFusionAnalysis::Scope::RHS,
- TritonFusionAnalysis::Scope::OUTPUT}) {
- for (const HloInstruction* parameter :
- adapter->analysis_.ScopeParameters(scope)) {
- std::vector<int64_t> dimensions;
- std::vector<int64_t> strides;
- if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions,
- strides)) {
- VLOG(3) << "Unsupported dimensions.";
- return std::nullopt;
- }
- if (!add_parameter(*parameter, dimensions, strides)) {
- return std::nullopt;
- }
- }
- }
-
- for (const HloInstruction* hlo : instructions) {
- VLOG(5) << hlo->ToShortString();
- auto operand = [&hlo_to_cudnn, &hlo](int i) {
- return hlo_to_cudnn[hlo->operand(i)];
- };
- const auto data_type = ToCudnnDataType(hlo->shape().element_type());
- if (!data_type.has_value()) {
- VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type();
- return std::nullopt;
- }
- if (hlo->opcode() == HloOpcode::kParameter) {
- CHECK(hlo_to_cudnn.contains(hlo));
- continue;
- } else if (hlo->opcode() == HloOpcode::kCustomCall) {
- if (hlo->user_count() != 1 ||
- !IsWorkspaceAllocationRoot(*hlo->users()[0])) {
- VLOG(3) << "Custom calls are only expected to be used for workspace "
- "allocation.";
- return std::nullopt;
- }
- continue;
- } else if (hlo->opcode() == HloOpcode::kTuple) {
- if (!IsWorkspaceAllocationRoot(*hlo)) {
- VLOG(3) << "Tuples are only expected at outputs for workspace "
- "allocation.";
- return std::nullopt;
- }
- continue;
- } else if (FusionLevel(fusion) >= 2 &&
- hlo->opcode() == HloOpcode::kConstant) {
- if (const auto const_tensor = HandleConstantHloToCudnnGraph(*hlo, graph);
- const_tensor.has_value()) {
- hlo_to_cudnn[hlo] = const_tensor.value();
- } else {
- return std::nullopt;
- }
- } else if (hlo->opcode() == HloOpcode::kReshape ||
- hlo->opcode() == HloOpcode::kBitcast ||
- hlo->opcode() == HloOpcode::kTranspose ||
- hlo->opcode() == HloOpcode::kCopy ||
- (FusionLevel(fusion) >= 2 &&
- hlo->opcode() == HloOpcode::kBroadcast)) {
- // All these are accounted for separately as transformations of strides.
- hlo_to_cudnn[hlo] = operand(0);
- } else if (hlo->IsElementwise()) {
- const auto compute_dtype =
- GetComputeDataType(hlo->shape().element_type());
- if (!compute_dtype.has_value()) {
- return std::nullopt;
- }
- if (hlo->opcode() == HloOpcode::kClamp) {
- const auto clamp =
- HandleClampToCudnnGraph(*hlo, graph, hlo_to_cudnn,
- data_type.value(), compute_dtype.value());
- if (!clamp.has_value()) {
- return std::nullopt;
- }
- hlo_to_cudnn[hlo] = clamp.value();
- } else {
- const auto mode = GetElementwiseMode(*hlo);
- if (!mode.has_value()) {
- VLOG(3) << "Unsupported elementwise operation.";
- return std::nullopt;
- }
- const auto attrs = graph::Pointwise_attributes()
- .set_mode(mode.value())
- .set_compute_data_type(compute_dtype.value());
- if (hlo->operand_count() == 1) {
- hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
- // Sets the dimensions for unary ops whose operands are broadcast so
- // that cuDNN can infer its inputs' shapes. A constant has dimension
- // [1] while cuDNN requires it to have dimension [1,1,1]. Not setting
- // the output shape of the unary op results in the rejection of the
- // cuDNN graph.
- if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
- const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
- std::vector<int64_t> dimensions;
- std::vector<int64_t> strides;
- if (!scope.has_value()) {
- LOG(FATAL) << "No scope for instruction: "
- << hlo->ToShortString();
- }
- if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
- strides)) {
- VLOG(3) << "Unsupported hlo for querying dimensions: "
- << hlo->ToShortString();
- } else {
- hlo_to_cudnn[hlo]->set_dim(dimensions);
- }
- }
- } else if (hlo->operand_count() == 2) {
- hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
- } else if (hlo->operand_count() == 3) {
- if (hlo->opcode() != HloOpcode::kSelect) {
- VLOG(3) << "Unexpected ternary operation: " << hlo->ToString();
- return std::nullopt;
- }
- // Operand order for select differs between HLO and cuDNN.
- hlo_to_cudnn[hlo] =
- graph.pointwise(operand(1), operand(2), operand(0), attrs);
- } else {
- VLOG(3) << "Unimplemented elementwise operation.";
- return std::nullopt;
- }
- }
- } else if (hlo->opcode() == HloOpcode::kDot) {
- const auto compute_dtype =
- GetComputeDataType(hlo->shape().element_type());
- if (!compute_dtype.has_value()) {
- return std::nullopt;
- }
- hlo_to_cudnn[hlo] =
- graph.matmul(operand(0), operand(1),
- graph::Matmul_attributes().set_compute_data_type(
- compute_dtype.value()));
- } else {
- VLOG(3) << "Unimplemented operation.";
- return std::nullopt;
- }
- if (hlo_to_cudnn[hlo] == nullptr) {
- VLOG(3) << "Creation of the operation failed.";
- return std::nullopt;
- }
- hlo_to_cudnn[hlo]
- ->set_data_type(data_type.value())
- .set_name(std::string(hlo->name()));
- }
- const HloInstruction* output = instructions.back();
- if (instructions.back()->shape().IsTuple()) {
- output = instructions.back()->operand(0);
- }
- std::vector<int64_t> dimensions;
- std::vector<int64_t> strides;
- if (!adapter->DimensionsAndStrides(
- *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) {
- VLOG(3) << "Unsupported dimensions.";
- return std::nullopt;
- }
- hlo_to_cudnn[output]
- ->set_output(true)
- .set_dim(dimensions)
- .set_stride(strides)
- .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count()));
- if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) {
- json dump;
- graph.serialize(dump);
- DumpToFileInDirOrStdout(
- /*module=*/*fusion.GetModule(),
- /*file_prefix=*/"",
- /*file_suffix=*/
- absl::StrCat("cudnn_fusion_", fusion.name(), ".json"),
- /*contents=*/dump.dump(1));
- }
-
- return se::gpu::CudnnGraph(std::move(graph));
-}
-
-// Creates a cuDNN graph, queries cuDNN whether it is supported.
-absl::StatusOr<se::gpu::CudnnGraph> PrepareGraph(
- se::dnn::DnnSupport& dnn_support, const HloFusionInstruction& hlo) {
- TF_ASSIGN_OR_RETURN(std::optional<se::gpu::CudnnGraph> graph,
- HloFusionToCuDnnGraph(hlo));
- if (!graph.has_value()) {
- return absl::InternalError("Construction of cuDNN graph failed.");
- }
- TF_RETURN_IF_ERROR(graph->Prepare(dnn_support));
- return *graph;
-}
-
-absl::StatusOr<HloInstruction*> AddWorkspace(HloInstruction& fusion,
- const int64_t workspace_size) {
- HloComputation* computation = fusion.fused_instructions_computation();
- HloInstruction* custom_call =
- computation->AddInstruction(HloInstruction::CreateCustomCall(
- ShapeUtil::MakeShape(S8, {workspace_size}), {},
- kWorkspaceAllocationCustomCallTarget));
- HloInstruction* output_tuple =
- computation->AddInstruction(HloInstruction::CreateTuple(
- {computation->root_instruction(), custom_call}));
- computation->set_root_instruction(output_tuple, true);
- HloInstruction* new_fusion = fusion.parent()->AddInstruction(
- fusion.CloneWithNewShape(output_tuple->shape()));
- TF_RETURN_IF_ERROR(fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(new_fusion, 0))));
- TF_RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion));
- return new_fusion;
-}
-
-class CuDnnFusionVisitor : public DfsHloRewriteVisitor {
- public:
- explicit CuDnnFusionVisitor(se::dnn::DnnSupport& dnn_support,
- BinaryMap& compilation_results)
- : dnn_support_(dnn_support), compilation_results_(compilation_results) {}
-
- absl::Status HandleFusion(HloInstruction* hlo) override {
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- hlo->backend_config<GpuBackendConfig>());
- const auto& fusion_backend_config = gpu_config.fusion_backend_config();
- if (fusion_backend_config.kind() != kCuDnnFusionKind) {
- return absl::OkStatus();
- }
- int64_t plan_id = -1;
- if (fusion_backend_config.has_cudnn_fusion_config()) {
- plan_id = fusion_backend_config.cudnn_fusion_config().plan_id();
- }
-
- VLOG(4) << "Processing " << hlo->ToString();
- VLOG(4) << "Plan ID: " << plan_id;
-
- auto add_workspace = [&](const int64_t workspace_size) {
- if (workspace_size > 0) {
- TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size));
- SetVisited(*hlo);
- }
- return absl::OkStatus();
- };
- const std::string fingerprint_without_workspace =
- GetComputationFingerprint(hlo->fused_instructions_computation(), {});
- auto workspace_size_it =
- workspace_sizes_.find(fingerprint_without_workspace);
- if (workspace_size_it == workspace_sizes_.cend()) {
- TF_ASSIGN_OR_RETURN(
- se::gpu::CudnnGraph graph,
- PrepareGraph(dnn_support_, *DynCast<HloFusionInstruction>(hlo)));
-
- if (plan_id >= 0) {
- // Build single plan with given ID.
- if (plan_id >= graph.Graph().get_execution_plan_count()) {
- return absl::InternalError("cuDNN graph plan does not exist.");
- }
- TF_RETURN_IF_ERROR(graph.Build(dnn_support_, plan_id));
- } else {
-        // When no plan_id is provided, build plans one by one until the
-        // first one succeeds.
- for (plan_id = 0; plan_id < graph.Graph().get_execution_plan_count();
- ++plan_id) {
- VLOG(7) << "Trying plan ID " << plan_id;
- if (graph.Build(dnn_support_, plan_id).ok()) {
- VLOG(7) << "Successfully built plan ID " << plan_id;
- break;
- }
- }
- if (plan_id == graph.Graph().get_execution_plan_count()) {
- return absl::InternalError("No cuDNN plans can be built.");
- }
- }
- const int64_t workspace_size = graph.Graph().get_workspace_size();
- workspace_sizes_.insert(workspace_size_it,
- {fingerprint_without_workspace, workspace_size});
- TF_RETURN_IF_ERROR(add_workspace(workspace_size));
-
- std::vector<uint8_t> serialized_graph;
- RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph));
-      // Recompute the fingerprint with the potential workspace included so
-      // that the compilation result is keyed by the same fingerprint the
-      // emitter computes.
- compilation_results_[GetComputationFingerprint(
- hlo->fused_instructions_computation(), {})] =
- std::string(reinterpret_cast<char*>(serialized_graph.data()),
- serialized_graph.size());
- } else {
- VLOG(4) << "Cache hit.";
- TF_RETURN_IF_ERROR(add_workspace(workspace_size_it->second));
- }
- auto cudnn_config = gpu_config.mutable_fusion_backend_config()
- ->mutable_cudnn_fusion_config();
- cudnn_config->set_plan_id(plan_id);
- TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
-
- MarkAsChanged();
- return absl::OkStatus();
- }
-
- private:
- se::dnn::DnnSupport& dnn_support_;
- // <HLO computation fingerprint, serialized compiled cuDNN graph>.
- BinaryMap& compilation_results_;
- absl::flat_hash_map<std::string, int64_t> workspace_sizes_;
-};
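-// Reading aid for CuDnnFusionVisitor's two caches (a restatement, not extra
-// logic): workspace_sizes_ is keyed by the fingerprint of the fused
-// computation *before* a workspace tuple is appended, so structurally
-// identical fusions share one graph construction, while compilation_results_
-// is keyed by the fingerprint recomputed *after* AddWorkspace so that the
-// emitter, which sees the rewritten computation, looks up the serialized
-// graph under the same key.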
-
-} // namespace
-
-absl::StatusOr<bool> CuDnnFusionCompiler::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_SCOPED_LOGGING_TIMER("cuDNN fusion compiler");
- return CuDnnFusionVisitor(dnn_support_, compilation_results_)
- .RunOnModule(module, execution_threads);
-}
-
-int CuDnnFusionCompiler::GetAvailablePlanCount(
- se::StreamExecutor& stream_exec, const HloFusionInstruction& hlo) {
- auto graph = PrepareGraph(*stream_exec.AsDnn(), hlo);
- if (!graph.ok()) {
- return 0;
- }
- constexpr int64_t kMaxPlans = 10;
- return std::min(graph->Graph().get_execution_plan_count(), kMaxPlans);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h
deleted file mode 100644
index f34bbb1..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-
-namespace xla {
-namespace gpu {
-
-// Converts HLO fusions with cuDNN backend config to cuDNN graphs,
-// compiles them using a cuDNN handle and serializes them.
-class CuDnnFusionCompiler : public HloModulePass {
- public:
- explicit CuDnnFusionCompiler(se::StreamExecutor& stream_exec,
- BinaryMap& compilation_results)
- : dnn_support_(*stream_exec.AsDnn()),
- compilation_results_(compilation_results) {}
-
- absl::string_view name() const override { return "cudnn-fusion-compiler"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- static int GetAvailablePlanCount(se::StreamExecutor& stream_exec,
- const HloFusionInstruction& hlo);
-
- private:
- se::dnn::DnnSupport& dnn_support_;
- BinaryMap& compilation_results_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc
deleted file mode 100644
index 5e78f48..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc
+++ /dev/null
@@ -1,1553 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_norm_rewriter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <cstdlib>
-#include <functional>
-#include <iterator>
-#include <limits>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "google/protobuf/wrappers.pb.h"
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/types.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/protobuf/dnn.pb.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#endif
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-namespace m = match;
-
-// Traverses the graph upward starting at instr and returns the
-// first instruction that is not a convert, bitcast or reshape.
-const HloInstruction* SkipUnaryOps(const HloInstruction* instr) {
- while (instr->opcode() == HloOpcode::kConvert ||
- instr->opcode() == HloOpcode::kBitcast ||
- instr->opcode() == HloOpcode::kReshape) {
- instr = instr->operand(0);
- }
- return instr;
-}
-
-// Recursively traverses the graph downward starting at instr and stores in
-// instrs the users that are not a convert, bitcast or reshape.
-void SkipUnaryOpsTopDownRecursive(HloInstruction* instr,
- std::vector<HloInstruction*>& instrs) {
- if (instr->opcode() == HloOpcode::kConvert ||
- instr->opcode() == HloOpcode::kBitcast ||
- instr->opcode() == HloOpcode::kReshape) {
- for (HloInstruction* user : instr->users()) {
- SkipUnaryOpsTopDownRecursive(user, instrs);
- }
- } else {
- instrs.emplace_back(instr);
- }
-}
-
-// Holds auxiliary information about individual layer norm patterns rewritten
-// into a cuDNN Custom Call.
-struct NormMetadata {
- // Transposes applied to the input and output of the forward layer norm to
- // order the normalization and non-normalization dimensions as required by
- // cuDNN. Nullptr if no transposes were inserted.
- HloInstruction *x_transpose, *y_transpose;
- // The reduction and non-reduction dimensions of the input into the forward
- // layer norm before the potential application of transposes and adjusted for
- // the removal of any degenerate dimensions in the input to the norm.
- std::vector<int64_t> norm_dims_adjusted, non_norm_dims_adjusted;
-};
-
-// Map from the instruction pointer of a layer norm Custom Call to its metadata.
-using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
-
-// Captures an HloInstruction pointer across multiple pattern matches and
-// verifies that all captures target the identical instruction.
-//
-// Example:
-// Pattern cos(x) / sin(x) with cos and sin intended to operate on the same
-// HloInstruction:
-// UniqueHloInstruction x;
-// bool m = Match(
-//       instr,
-//       m::Divide(m::Cos(m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
-//                 m::Sin(m::Op().WithPredicate(x.GetCaptureOrVerifyFn()))));
-// m is true and x.Instr() returns an HloInstruction pointer to the operand of
-// cosine and sine iff HloInstruction *instr points to a division of a cosine by
-// a sine that operate on the same instruction.
-class UniqueHloInstruction {
- public:
- UniqueHloInstruction()
- : is_set_(false), instr_(nullptr), capture_or_verify_() {}
- HloInstruction* Instr() const { return instr_; }
- void SetInstr(HloInstruction* instr) {
- is_set_ = true;
- instr_ = instr;
- }
-
- // Stores instr when invoked the first time. Otherwise, compares instr to the
- // stored value and sets the stored value to nullptr if the comparison fails.
- bool CaptureOrVerify(HloInstruction* instr) {
- if (is_set_ && instr != instr_) {
- instr_ = nullptr;
- }
- if (!is_set_) {
- is_set_ = true;
- instr_ = instr;
- }
- return instr_;
- }
-
- // Returns a std::function for capturing or verifying an instruction using
- // WithPredicate.
- std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
- if (!capture_or_verify_) {
- capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
- return CaptureOrVerify(const_cast<HloInstruction*>(instr));
- };
- }
- return capture_or_verify_;
- }
-
- private:
- bool is_set_;
- HloInstruction* instr_;
- std::function<bool(const HloInstruction*)> capture_or_verify_;
-};
-
-// Returns an architecture-specific constant for the calculation of an upper
-// bound for the size of the scratch space for layer norm kernels.
-absl::StatusOr<int64_t> CConstant(
- se::CudaComputeCapability cuda_compute_capability) {
- if (cuda_compute_capability.major == se::CudaComputeCapability::AMPERE) {
- return 32 * 128;
- } else if (cuda_compute_capability.major ==
- se::CudaComputeCapability::HOPPER) {
- return 32 * 144;
- }
- return xla::Internal("Norm kernels require Ampere or Hopper architecture.");
-}
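-// Minimal standalone sketch of how the constant above enters the scratch
-// space bound assembled in MatchLayerNorm further below; the sizes are
-// hypothetical and the snippet is an illustration, not part of the rewriter:
-//
-//   #include <cstdint>
-//   #include <iostream>
-//
-//   int main() {
-//     const std::int64_t c_constant = 32 * 128;  // Ampere value above.
-//     const std::int64_t rows = 1024;  // hypothetical non-normalized rows
-//     const std::int64_t workspace_size =
-//         (2 * c_constant * (4 + 256)) + (2 * rows * 4) + 64;
-//     std::cout << workspace_size << " bytes\n";  // prints "2138176 bytes"
-//   }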
-
-// Returns whether the element type of instr is compatible with layer norm
-// kernels.
-bool CompatibleElementType(const HloInstruction* instr) {
- PrimitiveType element_type = instr->shape().element_type();
- return element_type == BF16 || element_type == F16 || element_type == F32;
-}
-
-// Returns the dimensions associated with shape, adjusted for the removal of any
-// degenerate dimensions in shape. Specifically, for each dimension d in
-// dimensions, returns the new index of d if all dimensions of size 1 are
-// removed from shape. If d has size 1, it is not included in the returned
-// vector.
-std::vector<int64_t> AdjustedDimensions(const Shape& shape,
- absl::Span<const int64_t> dimensions) {
- absl::flat_hash_map<int64_t, int64_t> dimension_map;
- for (int64_t dimension = 0, non_degen_dimension = 0; dimension < shape.rank();
- ++dimension) {
- if (shape.dimensions(dimension) > 1) {
- dimension_map.insert({dimension, non_degen_dimension});
- non_degen_dimension++;
- }
- }
- std::vector<int64_t> adjusted_dimensions;
- for (int64_t dimension : dimensions) {
- auto non_degenerate_dimension = dimension_map.find(dimension);
- if (non_degenerate_dimension != dimension_map.end()) {
- adjusted_dimensions.emplace_back(non_degenerate_dimension->second);
- }
- }
- return adjusted_dimensions;
-}
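-// Illustrative example (hypothetical shape): for f32[2,1,3,1] the
-// non-degenerate dimensions 0 and 2 are renumbered to 0 and 1, so
-// AdjustedDimensions(shape, {2, 3}) returns {1}; dimension 3 has size 1 and
-// is therefore dropped.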
-
-// Returns the dimensions of broadcast or reduction instructions, adjusted for
-// the removal of any degenerate dimensions in the output or input.
-std::vector<int64_t> AdjustedDimensions(const HloInstruction* instr) {
- Shape shape;
- if (instr->opcode() == HloOpcode::kBroadcast) {
- shape = instr->shape();
- } else if (instr->opcode() == HloOpcode::kReduce) {
- shape = instr->operand(0)->shape();
- } else {
- return {};
- }
- return AdjustedDimensions(shape, instr->dimensions());
-}
-
-// Returns whether the HLO Computation applied by instr calculates the sum of
-// the elements. When provided, compares reduce_dims to the dimensions of the
-// reduction.
-bool AppliesAddReduce(const HloInstruction* instr,
- absl::Span<const int64_t> reduce_dims = {}) {
- if (instr->opcode() != HloOpcode::kReduce) {
- return false;
- }
-
- // Verify the dimensions of the reduction.
- if (!reduce_dims.empty() && AdjustedDimensions(instr) != reduce_dims) {
- return false;
- }
-
- HloComputation* reduce_comp = instr->to_apply();
- HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
- return instr->operand_count() == 2 &&
- instr->operand(1)->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsScalar(instr->operand(1)->shape()) &&
- instr->operand(1)->literal().GetAsDouble({}) == 0. &&
- reduce_comp_root->opcode() == HloOpcode::kAdd &&
- reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
- reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
-}
-
-// Returns whether instr multiplies the result of a reduction by one over the
-// number of reduced elements.
-bool CalculatesExpectation(const HloInstruction* instr) {
- instr = SkipUnaryOps(instr);
- if (instr->opcode() != HloOpcode::kMultiply) {
- return false;
- }
- bool bcast_operand = instr->operand(0)->opcode() != HloOpcode::kBroadcast;
- const HloInstruction *broadcast = instr->operand(bcast_operand),
- *reduce = SkipUnaryOps(instr->operand(!bcast_operand));
- if (reduce->opcode() != HloOpcode::kReduce ||
- broadcast->opcode() != HloOpcode::kBroadcast ||
- broadcast->operand(0)->opcode() != HloOpcode::kConstant) {
- return false;
- }
-
- float actual_r_nelems =
- broadcast->operand(0)->literal().GetAsDouble({}).value();
- int64_t nelems = 1;
- for (int64_t norm_dim : reduce->dimensions()) {
- nelems *= reduce->operand(0)->shape().dimensions()[norm_dim];
- }
-  // The absolute value of the difference between the actual scaling factor
-  // and the reference value must not exceed a prescribed threshold.
- float r_nelems = 1. / static_cast<float>(nelems);
- float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
- return abs(actual_r_nelems - r_nelems) <
- ((actual_r_nelems + r_nelems) * numerical_epsilon);
-}
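-// Illustrative tolerance check (hypothetical values): for a reduction over
-// 768 elements the reference factor is 1/768 ~= 0.00130208, and a broadcasted
-// constant c is accepted iff |c - 1/768| < (c + 1/768) * eps with
-// eps = 2^-7 = 0.0078125 (the bfloat16 machine epsilon), i.e. roughly a 1.6%
-// relative tolerance; the BF16-rounded value 0.00130463 passes.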
-
-// Returns whether target can be reached from instr by recursively traversing
-// the graph across converts, bitcasts and reshapes.
-bool FindTargetRecursive(
- const HloInstruction* instr, const HloInstruction* target,
- absl::flat_hash_set<const HloInstruction*>& visited_instrs,
- const HloInstruction* transpose) {
- visited_instrs.emplace(instr);
- const absl::flat_hash_set<HloOpcode> supported_ops = {
- HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
- if (instr == target) {
- return true;
- }
- // Look for target among the users of instr.
- for (HloInstruction* user : instr->users()) {
- if ((supported_ops.contains(user->opcode()) || user == transpose) &&
- !visited_instrs.contains(user)) {
- return FindTargetRecursive(user, target, visited_instrs, transpose);
- }
- }
- // Ascend the graph if target is not found and instr is a convert, bitcast
- // or reshape.
- if (supported_ops.contains(instr->opcode())) {
- return FindTargetRecursive(instr->operand(0), target, visited_instrs,
- transpose);
- }
- return false;
-}
-
-bool FindTarget(const HloInstruction* custom_call, const HloInstruction* instr,
- const HloInstruction* target,
- const NormMetadataMap& norm_metadata) {
- absl::flat_hash_set<const HloInstruction*> visited_instrs;
- auto custom_call_metadata = norm_metadata.find(custom_call);
- if (custom_call_metadata == norm_metadata.end()) {
- return false;
- }
- return FindTargetRecursive(instr, target, visited_instrs,
- custom_call_metadata->second.x_transpose);
-}
-
-// Maps the dimension numbers in dimensions from shape original_shape to shape
-// reshaped_shape, assuming that the shapes are related through a strict
-// reshape. Returns an empty vector if a dimension mapping is not found.
-std::vector<int64_t> MapDimensions(const Shape& original_shape,
- const Shape& reshaped_shape,
- const absl::Span<const int64_t> dimensions) {
- auto dimension_product =
- [](const Shape& shape,
- absl::Span<const int64_t> product_dimensions) -> int64_t {
- int64_t product = 1;
- for (int64_t product_dimension : product_dimensions) {
- product *= shape.dimensions(product_dimension);
- }
- return product;
- };
- // Construct the dimension mapping.
- absl::flat_hash_map<int64_t, std::vector<int64_t>> dimensions_map;
- std::vector<int64_t> original_dimensions, reshaped_dimensions;
- for (int64_t original_dimension = 0, reshaped_dimension = 0;
- original_dimension < original_shape.rank(); ++original_dimension) {
- original_dimensions.emplace_back(original_dimension);
- while ((reshaped_dimensions.empty() ||
- dimension_product(reshaped_shape, reshaped_dimensions) <
- dimension_product(original_shape, original_dimensions)) &&
- reshaped_dimension < reshaped_shape.rank()) {
- reshaped_dimensions.emplace_back(reshaped_dimension++);
- }
-
- // Many-to-many dimension mappings are not supported.
- if (original_dimensions.size() > 1 && reshaped_dimensions.size() > 1) {
- return {};
- }
-
- if (dimension_product(original_shape, original_dimensions) ==
- dimension_product(reshaped_shape, reshaped_dimensions)) {
- std::vector<int64_t> original_dimensions_in_dimensions;
- std::set_intersection(
- original_dimensions.begin(), original_dimensions.end(),
- dimensions.begin(), dimensions.end(),
- std::back_inserter(original_dimensions_in_dimensions));
- // The unique mapping of dimensions requires either all or none of the
- // entries of original_dimensions to be an element of dimensions.
- if (!original_dimensions_in_dimensions.empty() &&
- original_dimensions_in_dimensions.size() !=
- original_dimensions.size()) {
- return {};
- }
- for (int64_t dimension : original_dimensions) {
- dimensions_map.insert({dimension, reshaped_dimensions});
- }
- original_dimensions.clear();
- reshaped_dimensions.clear();
- }
- }
-
-  // Map the dimension numbers to the reshaped shape.
- std::vector<int64_t> mapped_dimensions;
- for (int64_t dimension : dimensions) {
- auto mapped_dimension = dimensions_map.find(dimension);
- if (mapped_dimension == dimensions_map.end()) {
- return {};
- }
- mapped_dimensions.insert(mapped_dimensions.end(),
- mapped_dimension->second.begin(),
- mapped_dimension->second.end());
- }
-
- // Eliminate duplicates in the mapped dimension numbers.
- mapped_dimensions.erase(
- std::unique(mapped_dimensions.begin(), mapped_dimensions.end()),
- mapped_dimensions.end());
- return mapped_dimensions;
-}
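-// Illustrative example (hypothetical shapes): across the reshape
-// f32[2,4,6,8] -> f32[8,48], original dimensions {0, 1} collapse into
-// reshaped dimension 0 and {2, 3} into reshaped dimension 1, so
-// MapDimensions(original, reshaped, {2, 3}) returns {1}, whereas requesting
-// {1, 2} returns {} because dimension 1 maps to reshaped dimension 0 whose
-// other constituent, dimension 0, is not requested.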
-
-// Recursively traverses the graph across converts, bitcasts and reshapes,
-// starting from instr, and returns the first addition-reduction identified.
-// Returns nullptr if no addition-reduction is found.
-HloInstruction* FindAddReduceRecursive(
- HloInstruction* instr, const Shape& orig_instr_shape,
- const absl::Span<const int64_t> reduce_dims,
- absl::flat_hash_set<HloInstruction*>& visited_instrs) {
- visited_instrs.emplace(instr);
- const absl::flat_hash_set<HloOpcode> supported_ops = {
- HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
- // Look for a reduction among the users of instr.
- for (HloInstruction* user : instr->users()) {
- if (user->opcode() == HloOpcode::kReduce) {
- std::vector<int64_t> mapped_reduce_dims =
- MapDimensions(orig_instr_shape, instr->shape(), reduce_dims);
- if (!mapped_reduce_dims.empty() &&
- AppliesAddReduce(user, mapped_reduce_dims)) {
- return user;
- }
- }
- if (supported_ops.contains(user->opcode()) &&
- !visited_instrs.contains(user)) {
- return FindAddReduceRecursive(user, orig_instr_shape, reduce_dims,
- visited_instrs);
- }
- }
- // Ascend the graph if the addition-reduction is not found and instr is a
- // convert, bitcast or reshape.
- if (supported_ops.contains(instr->opcode())) {
- return FindAddReduceRecursive(instr->mutable_operand(0), orig_instr_shape,
- reduce_dims, visited_instrs);
- }
- return nullptr;
-}
-
-HloInstruction* FindAddReduce(HloInstruction* instr,
- const absl::Span<const int64_t> reduce_dims) {
- absl::flat_hash_set<HloInstruction*> visited_instrs;
- return FindAddReduceRecursive(instr, instr->shape(), reduce_dims,
- visited_instrs);
-}
-
-// Type conversion from and to any of BF16, FP16 and FP32.
-template <typename Pattern>
-auto SupportedConvert(Pattern pattern) {
- auto supported_convert = [](const HloInstruction* instr) -> bool {
- return CompatibleElementType(instr) &&
- CompatibleElementType(instr->operand(0));
- };
- return m::Convert(pattern).WithPredicate(supported_convert);
-}
-
-// Bitcast or reshape adding or removing degenerate dimensions.
-template <typename Pattern>
-auto SupportedBitcastOrReshape(Pattern pattern) {
- auto supported_bitcast_or_reshape = [](const HloInstruction* instr) -> bool {
- return ShapeUtil::Equal(
- ShapeUtil::DropDegenerateDimensions(instr->shape()),
- ShapeUtil::DropDegenerateDimensions(instr->operand(0)->shape()));
- };
- return m::AnyOf<HloInstruction>(
- m::Bitcast(pattern).WithPredicate(supported_bitcast_or_reshape),
- m::Reshape(pattern).WithPredicate(supported_bitcast_or_reshape));
-}
-
-// Matches pattern, SupportedConvert(pattern),
-// SupportedBitcastOrReshape(pattern),
-// SupportedConvert(SupportedBitcastOrReshape(pattern)) and
-// SupportedBitcastOrReshape(SupportedConvert(pattern)).
-template <typename Pattern>
-auto OptionalSupportedTransform(Pattern pattern) {
- auto shared_subpattern = m::SharedSubpattern(pattern);
- return m::AnyOf<HloInstruction>(
- SupportedConvert(SupportedBitcastOrReshape(shared_subpattern)),
- SupportedBitcastOrReshape(SupportedConvert(shared_subpattern)),
- SupportedConvert(shared_subpattern),
- SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
-}
-
-// Bitcast or reshape with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern>
-auto BitcastOrReshape(Pattern pattern) {
- return OptionalSupportedTransform(
- m::AnyOf<HloInstruction>(m::Bitcast(pattern), m::Reshape(pattern)));
-}
-
-// Transpose with optional supported type conversion and/or addition or removal
-// of degenerate dimensions.
-template <typename Pattern>
-auto Transpose(Pattern pattern) {
- return OptionalSupportedTransform(m::Transpose(pattern));
-}
-
-// Rsqrt with optional supported type conversion and/or addition or removal of
-// degenerate dimensions.
-template <typename Pattern>
-auto Rsqrt(HloInstruction** rsqrt, Pattern pattern) {
- return OptionalSupportedTransform(m::Rsqrt(rsqrt, pattern));
-}
-
-// AddAnyOrder with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto AddAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
- return OptionalSupportedTransform(m::AddAnyOrder(pattern0, pattern1));
-}
-
-// Subtract with optional supported type conversion and/or addition or removal
-// of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto Subtract(Pattern0 pattern0, Pattern1 pattern1) {
- return OptionalSupportedTransform(m::Subtract(pattern0, pattern1));
-}
-
-// Capturing subtract with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto Subtract(HloInstruction** subtract, Pattern0 pattern0, Pattern1 pattern1) {
- return OptionalSupportedTransform(m::Subtract(subtract, pattern0, pattern1));
-}
-
-// Multiply with optional supported type conversion and/or addition or removal
-// of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto MultiplyAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
- return OptionalSupportedTransform(m::MultiplyAnyOrder(pattern0, pattern1));
-}
-
-// Capturing multiply with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto MultiplyAnyOrder(HloInstruction** multiply, Pattern0 pattern0,
- Pattern1 pattern1) {
- return OptionalSupportedTransform(
- m::MultiplyAnyOrder(multiply, pattern0, pattern1));
-}
-
-// Multiplication of pattern by itself with optional supported type conversion
-// and/or addition or removal of degenerate dimensions.
-template <typename Pattern>
-auto Square(Pattern pattern) {
- return MultiplyAnyOrder(pattern, pattern)
- .WithPredicate([](const HloInstruction* instr) {
- return instr->unique_operands().size() == 1;
- });
-}
-
-// Multiplication of the square of pattern by pattern with optional supported
-// type conversion and/or addition or removal of degenerate dimensions. The root
-// instruction of pattern cannot be a multiplication.
-template <typename Pattern>
-auto Cube(Pattern pattern) {
- auto unique_cube = [](const HloInstruction* instr) -> bool {
- bool square_operand = instr->operand(0)->opcode() != HloOpcode::kMultiply;
- return instr->operand(!square_operand)->opcode() != HloOpcode::kMultiply &&
- instr->operand(square_operand)->operand(0) ==
- instr->operand(!square_operand);
- };
- return MultiplyAnyOrder(Square(pattern), pattern).WithPredicate(unique_cube);
-}
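-// Illustrative matches (hypothetical HLO): Square matches multiply(%x, %x)
-// but not multiply(%x, %y); Cube matches multiply(multiply(%x, %x), %x) with
-// the operands in either order, while multiply(multiply(%x, %x), %y) and a
-// product of two multiplies are rejected by the unique_cube predicate.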
-
-// Addition-reduction of pattern with optional supported type conversion and/or
-// addition or removal of degenerate dimensions.
-template <typename Pattern>
-auto AddReduce(Pattern pattern) {
- return OptionalSupportedTransform(
- m::Reduce(pattern, m::Op())
- .WithPredicate([](const HloInstruction* instr) {
- return AppliesAddReduce(instr);
- }));
-}
-
-// Capturing addition-reduction of pattern with optional supported type
-// conversion and/or addition or removal of degenerate dimensions.
-template <typename Pattern>
-auto AddReduce(HloInstruction** reduction, Pattern pattern) {
- return OptionalSupportedTransform(
- m::Reduce(reduction, pattern, m::Op())
- .WithPredicate([](const HloInstruction* instr) {
- return AppliesAddReduce(instr);
- }));
-}
-
-// Negated addition-reduction.
-template <typename Pattern>
-auto NegateAddReduce(HloInstruction** reduction, Pattern pattern) {
- return m::AnyOf<HloInstruction>(AddReduce(reduction, m::Negate(pattern)),
- m::Negate(AddReduce(reduction, pattern)));
-}
-
-// Expected value, or mean, with optional broadcast.
-template <typename Pattern>
-auto Expectation(Pattern pattern) {
- auto shared_subpattern =
- MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
- .WithPredicate([](const HloInstruction* instr) {
- return CalculatesExpectation(instr);
- });
- return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
- shared_subpattern);
-}
-
-// Expected value, or mean, with optional broadcast.
-template <typename Pattern>
-auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
- auto shared_subpattern = OptionalSupportedTransform(
- m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
- .WithPredicate([](const HloInstruction* instr) {
- return CalculatesExpectation(instr);
- })
- .WithPredicate(expectation->GetCaptureOrVerifyFn()));
- return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
- shared_subpattern);
-}
-
-// Expected value, or mean, with optional broadcast.
-template <typename Pattern>
-auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
- Pattern pattern) {
- auto shared_subpattern = OptionalSupportedTransform(
- m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
- AddReduce(reduce, pattern))
- .WithPredicate([](const HloInstruction* instr) {
- return CalculatesExpectation(instr);
- })
- .WithPredicate(expectation->GetCaptureOrVerifyFn()));
- return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
- shared_subpattern);
-}
-
-// Variance, expressed as expectation(X^2) - expectation(X)^2 or
-// expectation((X - expectation(X))^2).
-auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
- UniqueHloInstruction* x) {
- return m::AnyOf<HloInstruction>(
- Subtract(
- Expectation(Square(OptionalSupportedTransform(
- m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
- Square(Expectation(expectation,
- OptionalSupportedTransform(m::Op().WithPredicate(
- x->GetCaptureOrVerifyFn())))))
- .WithPredicate(variance->GetCaptureOrVerifyFn()),
- Expectation(
- Square(Subtract(
- OptionalSupportedTransform(
- m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
- Expectation(expectation,
- OptionalSupportedTransform(m::Op().WithPredicate(
- x->GetCaptureOrVerifyFn()))))))
- .WithPredicate(variance->GetCaptureOrVerifyFn()));
-}
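-// Both accepted forms follow from the same identity, restated here:
-// Var(X) = E[(X - E[X])^2] = E[X^2 - 2 X E[X] + E[X]^2] = E[X^2] - E[X]^2,
-// so the matcher recognizes either the two-pass E[X^2] - E[X]^2 formulation
-// or the centered E[(X - E[X])^2] formulation.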
-
-// Reciprocal of the square root of variance + epsilon with optional broadcast.
-auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
- UniqueHloInstruction* variance,
- UniqueHloInstruction* expectation,
- UniqueHloInstruction* epsilon) {
- auto shared_subpattern = m::SharedSubpattern(Rsqrt(
- norm_factor, AddAnyOrder(Variance(variance, expectation, x),
- m::Broadcast(m::ConstantScalar().WithPredicate(
- epsilon->GetCaptureOrVerifyFn())))));
- return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
- shared_subpattern);
-}
-
-// Any order of p0 * p1 * p2.
-template <typename P0, typename P1, typename P2>
-auto MultiplyMultiplyAnyOrder(P0 p0, P1 p1, P2 p2) {
- return m::AnyOf<HloInstruction>(
- MultiplyAnyOrder(p0, MultiplyAnyOrder(p1, p2)),
- MultiplyAnyOrder(p1, MultiplyAnyOrder(p0, p2)),
- MultiplyAnyOrder(p2, MultiplyAnyOrder(p0, p1)));
-}
-
-// Any order of p0 + p1 + p2.
-template <typename P0, typename P1, typename P2>
-auto AddAddAnyOrder(P0 p0, P1 p1, P2 p2) {
- return m::AnyOf<HloInstruction>(AddAnyOrder(p0, AddAnyOrder(p1, p2)),
- AddAnyOrder(p1, AddAnyOrder(p0, p2)),
- AddAnyOrder(p2, AddAnyOrder(p0, p1)));
-}
-
-// Any order of p0 * (p1 + p2).
-template <typename P0, typename P1, typename P2>
-auto MultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2) {
- return m::AnyOf<HloInstruction>(
- MultiplyAnyOrder(p0, AddAnyOrder(p1, p2)),
- AddAnyOrder(MultiplyAnyOrder(p0, p1), MultiplyAnyOrder(p0, p2)));
-}
-
-// Any order of p0 - p1 + p2.
-template <typename P0, typename P1, typename P2>
-auto SubtractAddAnyOrder(P0 p0, P1 p1, P2 p2) {
- return m::AnyOf<HloInstruction>(AddAnyOrder(Subtract(p0, p1), p2),
- AddAnyOrder(Subtract(p2, p1), p0),
- Subtract(AddAnyOrder(p0, p2), p1));
-}
-
-// Any order of (p0 - p1) * p2 * p3 + p4.
-template <typename P0, typename P1, typename P2, typename P3, typename P4>
-auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {
- return m::AnyOf<HloInstruction>(
- SubtractAddAnyOrder(MultiplyMultiplyAnyOrder(p0, p2, p3),
- MultiplyMultiplyAnyOrder(p1, p2, p3), p4),
- AddAnyOrder(MultiplyMultiplyAnyOrder(Subtract(p0, p1), p2, p3), p4));
-}
-
-// Expectation fused into a layer norm Custom Call.
-auto FusedExpectation(UniqueHloInstruction* custom_call) {
- auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
- m::CustomCall({kCudnnNormCallTarget})
- .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
- 1));
- return m::AnyOf<HloInstruction>(shared_subpattern,
- BitcastOrReshape(shared_subpattern));
-}
-
-// Expectation fused into a layer norm Custom Call.
-auto FusedExpectation(UniqueHloInstruction* fused_expectation,
- UniqueHloInstruction* custom_call) {
- auto shared_subpattern = m::SharedSubpattern(
- m::GetTupleElement(
- m::CustomCall({kCudnnNormCallTarget})
- .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
- 1)
- .WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
- return m::AnyOf<HloInstruction>(shared_subpattern,
- BitcastOrReshape(shared_subpattern));
-}
-
-// Norm factor fused into a layer norm Custom Call.
-auto FusedNormFactor(UniqueHloInstruction* custom_call) {
- auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
- m::CustomCall({kCudnnNormCallTarget})
- .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
- 2));
- return m::AnyOf<HloInstruction>(shared_subpattern,
- BitcastOrReshape(shared_subpattern));
-}
-
-// Norm factor fused into a layer norm Custom Call.
-auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
- UniqueHloInstruction* custom_call) {
- auto shared_subpattern = m::SharedSubpattern(
- m::GetTupleElement(
- m::CustomCall({kCudnnNormCallTarget})
- .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
- 2)
- .WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
- return m::AnyOf<HloInstruction>(shared_subpattern,
- BitcastOrReshape(shared_subpattern));
-}
-
-// Derivative of the norm factor w.r.t. variance + epsilon,
-// d(norm_factor)/d(variance + epsilon)
-// = d((variance + epsilon)^-1/2)/d(variance + epsilon)
-// = -1/2 * norm_factor^3.
-// Forwards custom_call to FusedNormFactor for verification.
-auto DNormFactor(UniqueHloInstruction* custom_call) {
- return MultiplyAnyOrder(m::Broadcast(m::ConstantScalar(-0.5)),
- Cube(FusedNormFactor(custom_call)));
-}
-
-// Zero-centered input of the layer norm, X - expectation(X). Verifies that
-// custom_call is a forward layer norm fusing X. Forwards custom_call to
-// FusedExpectation for verification.
-auto XCenter(UniqueHloInstruction* x, UniqueHloInstruction* custom_call,
- const NormMetadataMap& norm_metadata) {
- auto capture_or_verify_x =
- [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
- return x->CaptureOrVerify(
- FindTarget(custom_call->Instr(), instr->operand(0),
- custom_call->Instr()->operand(0), norm_metadata)
- ? custom_call->Instr()->mutable_operand(0)
- : nullptr);
- };
- return Subtract(m::Op(), m::Broadcast(FusedExpectation(custom_call)))
- .WithPredicate(capture_or_verify_x);
-}
-
-// Zero-centered input of the layer norm, X - expectation(X). Captures X in x if
-// custom_call is a forward layer norm fusing X. Forwards custom_call to
-// FusedExpectation for comparison.
-auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
- UniqueHloInstruction* fused_expectation,
- UniqueHloInstruction* custom_call,
- const NormMetadataMap& norm_metadata) {
- auto capture_or_verify_x =
- [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
- return x->CaptureOrVerify(
- FindTarget(custom_call->Instr(), instr->operand(0),
- custom_call->Instr()->operand(0), norm_metadata)
- ? custom_call->Instr()->mutable_operand(0)
- : nullptr);
- };
- return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
- custom_call)))
- .WithPredicate(x_center->GetCaptureOrVerifyFn())
- .WithPredicate(capture_or_verify_x);
-}
-
-// Addition-reduction of the product of XCenter, the broadcasted scale and DY,
-// XCenter * scale * DY. Captures the scale in scale if custom_call is a forward
-// layer norm fusing the scale. Forwards custom_call to XCenter for comparison.
-auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
- UniqueHloInstruction* dy, UniqueHloInstruction* x,
- HloInstruction** reduce, const NormMetadataMap& norm_metadata) {
- auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
- const HloInstruction* instr) -> bool {
- return scale->CaptureOrVerify(FindTarget(custom_call->Instr(), instr,
- custom_call->Instr()->operand(1),
- norm_metadata)
- ? custom_call->Instr()->mutable_operand(1)
- : nullptr);
- };
- return AddReduce(
- reduce, MultiplyMultiplyAnyOrder(
- XCenter(x, custom_call, norm_metadata),
- m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
- m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
-}
-
-// Product of XCenter and the scaled and broadcasted product of F0 and
-// d(norm_factor)/d(variance + epsilon), XCenter * F0 * DNormFactor * 2 /
-// nelems. Forwards custom_call to XCenter, F0 and DNormFactor for capture or
-// verification.
-auto F1(UniqueHloInstruction* x, UniqueHloInstruction* x_center,
- UniqueHloInstruction* fused_expectation,
- UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
- UniqueHloInstruction* dy, HloInstruction** reduce,
- const NormMetadataMap& norm_metadata) {
- auto broadcasts_two_over_nelems = [](const HloInstruction* instr) -> bool {
- const HloInstruction* multiply = SkipUnaryOps(instr->operand(0));
- bool bcast_operand =
- multiply->operand(0)->opcode() != HloOpcode::kBroadcast;
-
- // The captured scalar must be two over the number of elements in the
- // broadcasted dimensions.
- float actual_two_over_nelems = multiply->operand(bcast_operand)
- ->operand(0)
- ->literal()
- .GetAsDouble({})
- .value();
- int64_t nelems = 1;
- for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
- if (!absl::c_linear_search(instr->dimensions(), i)) {
- nelems *= instr->shape().dimensions()[i];
- }
- }
-    // The absolute value of the difference between the actual scaling factor
-    // and the reference value must not exceed a prescribed threshold.
- float two_over_nelems = 2. / static_cast<float>(nelems);
- float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
- return abs(actual_two_over_nelems - two_over_nelems) <
- ((actual_two_over_nelems + two_over_nelems) * numerical_epsilon);
- };
-
- return MultiplyAnyOrder(
- XCenter(x_center, x, fused_expectation, custom_call, norm_metadata),
- m::Broadcast(
- MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
- MultiplyAnyOrder(DNormFactor(custom_call),
- F0(custom_call, scale, dy, x,
- reduce, norm_metadata))))
- .WithPredicate(broadcasts_two_over_nelems));
-}
-
-// Product of the norm factor, scale and DY, NormFactor * scale * DY. Captures
-// the scale in scale if custom_call is a forward layer norm fusing the scale.
-// Forwards custom_call to FusedNormFactor for comparison.
-auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
- UniqueHloInstruction* dy, UniqueHloInstruction* custom_call,
- const NormMetadataMap& norm_metadata) {
- auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
- const HloInstruction* instr) -> bool {
- return scale->CaptureOrVerify(
- FindTarget(custom_call->Instr(), instr->operand(0),
- custom_call->Instr()->operand(1), norm_metadata)
- ? custom_call->Instr()->mutable_operand(1)
- : nullptr);
- };
- return MultiplyAnyOrder(
- m::Broadcast(
- BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
- MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
- m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
-}
-
-class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
- public:
- explicit CudnnNormRewriterVisitor(
- const se::CudaComputeCapability cuda_compute_capability)
- : cuda_compute_capability_(cuda_compute_capability) {}
-
- absl::Status HandleAdd(HloInstruction* instr) override {
- TF_RETURN_IF_ERROR(MatchLayerNorm(instr));
- TF_RETURN_IF_ERROR(MatchLayerNormGradient(instr));
- return absl::OkStatus();
- }
-
- absl::Status HandleSubtract(HloInstruction* instr) override {
- return MatchLayerNorm(instr);
- }
-
- // Matches and rewrites layer norm patterns,
- // Y = (X - expectation(X))/sqrt(variance(X) + epsilon) * scale + bias,
- // into Custom Calls to cuDNN.
- absl::Status MatchLayerNorm(HloInstruction* instr) {
- UniqueHloInstruction x, expectation, variance, epsilon;
- HloInstruction *scale, *bias, *reduce, *norm_factor, *broadcast_scale,
- *broadcast_bias;
- if (Match(
- instr,
- SubtractMultiplyAddAnyOrder(
- OptionalSupportedTransform(
- m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
- Expectation(&expectation, &reduce,
- OptionalSupportedTransform(m::Op().WithPredicate(
- x.GetCaptureOrVerifyFn()))),
- NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
- m::Broadcast(&broadcast_scale, m::Op(&scale)),
- m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
-#if CUDNN_VERSION < 8905
- // Layer norm kernels are available with cuDNN 8.9.5 and above.
- VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
- return absl::OkStatus();
-#endif // CUDNN_VERSION < 8905
-
- if (!instr->GetModule()
- ->config()
- .debug_options()
- .xla_gpu_enable_cudnn_layer_norm()) {
- VLOG(1) << "Layer norm Custom Calls disabled.";
- return absl::OkStatus();
- }
-
- // Layer norm kernels require Ampere or Hopper architectures.
- if (cuda_compute_capability_.major != se::CudaComputeCapability::AMPERE &&
- cuda_compute_capability_.major != se::CudaComputeCapability::HOPPER) {
- VLOG(1) << "Layer norm Custom Calls require Ampere or Hopper "
- "architectures.";
- return absl::OkStatus();
- }
-
- // Verify the uniqueness of the inputs.
- if (!x.Instr() || !expectation.Instr() || !variance.Instr() ||
- !epsilon.Instr()) {
- VLOG(1) << "Layer norm operands not unique.";
- return absl::OkStatus();
- }
-
- // Verify the input and output layouts.
- // TODO(philipphack): Consider supporting more general cases.
- if (!LayoutUtil::IsMonotonicWithDim0Major(x.Instr()->shape().layout()) ||
- !LayoutUtil::IsMonotonicWithDim0Major(scale->shape().layout()) ||
- !LayoutUtil::IsMonotonicWithDim0Major(bias->shape().layout()) ||
- !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) {
-        VLOG(1) << "Layer norm input and/or output layouts not supported.";
- return absl::OkStatus();
- }
-
- // Verify the element types. The element types of input and output and the
- // shapes of scale and bias must match.
- if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
- !CompatibleElementType(bias) ||
- !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
- !ShapeUtil::Equal(scale->shape(), bias->shape())) {
- VLOG(1) << "Layer norm input types or shapes not supported.";
- return absl::OkStatus();
- }
-
- // Verify that the shapes of scale and bias are compatible with the
- // operation. The adjusted norm dimensions are the dimensions of the
- // reduction after removing any degenerate dimensions from the input of
- // the reduction.
- std::vector<int64_t> norm_dims(reduce->dimensions().begin(),
- reduce->dimensions().end());
- std::vector<int64_t> norm_dims_adjusted = AdjustedDimensions(reduce);
- if (norm_dims_adjusted.size() !=
- ShapeUtil::DropDegenerateDimensions(scale->shape())
- .dimensions_size()) {
- VLOG(1) << "Layer norm input dimensions not supported.";
- return absl::OkStatus();
- }
-
- // Verify the broadcasts of scale and bias.
- if (!ShapeUtil::EqualIgnoringElementType(
- ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
- ShapeUtil::DropDegenerateDimensions(broadcast_scale->shape())) ||
- !ShapeUtil::EqualIgnoringElementType(
- ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
- ShapeUtil::DropDegenerateDimensions(broadcast_bias->shape())) ||
- norm_dims_adjusted != AdjustedDimensions(broadcast_scale) ||
- norm_dims_adjusted != AdjustedDimensions(broadcast_bias)) {
- VLOG(1) << "Layer norm operand broadcast not supported.";
- return absl::OkStatus();
- }
-
- // If necessary, transpose the input so that the dimensions not being
- // normalized are the leading dimensions.
- std::vector<int64_t> non_norm_dims;
- for (int64_t x_dim = 0; x_dim < x.Instr()->shape().rank(); ++x_dim) {
- if (std::find(norm_dims.begin(), norm_dims.end(), x_dim) ==
- norm_dims.end()) {
- non_norm_dims.emplace_back(x_dim);
- }
- }
- std::vector<int64_t> non_norm_dims_adjusted =
- AdjustedDimensions(x.Instr()->shape(), non_norm_dims);
-
- std::vector<int64_t> x_transpose_order = non_norm_dims;
- x_transpose_order.insert(x_transpose_order.end(), norm_dims.begin(),
- norm_dims.end());
-
- bool apply_transpose = false;
- for (int i = 0; i < x_transpose_order.size(); ++i) {
- if (x_transpose_order[i] != i) {
- apply_transpose = true;
- break;
- }
- }
-
- std::optional<HloInstruction*> x_transpose;
- // The transpose applied to the output is the inverse of the transpose
- // applied to the input.
- std::vector<int64_t> y_transpose_order(x_transpose_order.size());
- if (apply_transpose) {
- for (int k = 0; k < x_transpose_order.size(); ++k) {
- y_transpose_order[x_transpose_order[k]] = k;
- }
- TF_ASSIGN_OR_RETURN(x_transpose,
- MakeTransposeHlo(x.Instr(), x_transpose_order));
- }
-
- // Combine the dimensions not normalized into the first dimension of the
- // input as required by cuDNN.
- std::vector<int64_t> reshaped_dims = {1};
- for (auto non_norm_dim : non_norm_dims) {
- reshaped_dims[0] *= x.Instr()->shape().dimensions(non_norm_dim);
- }
- for (auto norm_dim : norm_dims) {
- reshaped_dims.emplace_back(x.Instr()->shape().dimensions(norm_dim));
- }
- // cuDNN requires tensors to have at least four dimensions.
- while (reshaped_dims.size() < 4) {
- reshaped_dims.emplace_back(1);
- }
-
- Shape reshaped_shape = ShapeUtil::MakeShape(
- x.Instr()->shape().element_type(), reshaped_dims);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * x_reshape,
- MakeReshapeHlo(reshaped_shape, x_transpose.value_or(x.Instr())));
-
- // Reshape the scale and bias. The first dimension corresponds to the
- // non-normalization dimension of the norm input and must have size 1.
- std::vector<int64_t> reshaped_scale_dims = reshaped_dims;
- reshaped_scale_dims[0] = 1;
-
- Shape scale_bias_shape = ShapeUtil::MakeShape(
- scale->shape().element_type(), reshaped_scale_dims);
- TF_ASSIGN_OR_RETURN(HloInstruction * scale_reshape,
- MakeReshapeHlo(scale_bias_shape, scale));
- TF_ASSIGN_OR_RETURN(HloInstruction * bias_reshape,
- MakeReshapeHlo(scale_bias_shape, bias));
- GpuBackendConfig gpu_backend_config;
- CudnnNormBackendConfig& backend_config =
- *gpu_backend_config.mutable_cudnn_norm_backend_config();
- backend_config.set_epsilon(
- epsilon.Instr()->literal().GetAsDouble({}).value());
- backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_INFER);
- auto* algorithm = backend_config.mutable_algorithm();
- algorithm->set_algo_id(0);
- algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
- algorithm->set_is_cudnn_frontend(true);
-
- // Set the workspace size to its upper bound.
- // TODO(philipphack): Consider autotuning the norm kernels.
- TF_ASSIGN_OR_RETURN(const int64_t c_constant,
- CConstant(cuda_compute_capability_));
- const int64_t workspace_size =
- (2 * c_constant * (4 + 256)) + (2 * reshaped_dims[0] * 4) + 64;
- algorithm->mutable_workspace_size()->set_value(workspace_size);
-
- // The output of the Custom Call is a tuple, the second element of which
- // describes the scratch space.
- Shape custom_call_shape = ShapeUtil::MakeTupleShape(
- {x_reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})});
-
- HloInstruction* custom_call =
- instr->AddInstruction(HloInstruction::CreateCustomCall(
- custom_call_shape, {x_reshape, scale_reshape, bias_reshape},
- kCudnnNormCallTarget));
- TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
-
- TF_ASSIGN_OR_RETURN(HloInstruction * gte,
- MakeGetTupleElementHlo(custom_call, 0));
- TF_ASSIGN_OR_RETURN(
- HloInstruction * y_reshape,
- MakeReshapeHlo(x_transpose.value_or(instr)->shape(), gte));
-
- std::optional<HloInstruction*> y_transpose;
- if (apply_transpose) {
- TF_ASSIGN_OR_RETURN(y_transpose,
- MakeTransposeHlo(y_reshape, y_transpose_order));
- }
- TF_RETURN_IF_ERROR(
- ReplaceInstruction(instr, y_transpose.value_or(y_reshape)));
-
- // Store metadata for potential use in the backward graph.
- norm_metadata_.insert(
- {custom_call,
- NormMetadata({x_transpose.value_or(nullptr),
- y_transpose.value_or(nullptr), norm_dims_adjusted,
- non_norm_dims_adjusted})});
-
- VLOG(1) << "Layer norm rewritten into Custom Call.";
-
- // The layer norm training graph separately contains the norm factor
- // divided by the sum of variance and epsilon.
- for (HloInstruction* user : norm_factor->users()) {
- if (user->opcode() == HloOpcode::kDivide &&
- user->operand_index(norm_factor) == 0) {
- TF_ASSIGN_OR_RETURN(bool changed,
- MatchNormFactor(user, custom_call, variance,
- expectation, epsilon));
- if (changed) {
- break;
- }
- }
- }
- }
-
- return absl::OkStatus();
- }
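-
-  // Illustrative result of MatchLayerNorm (hypothetical shapes; a sketch of
-  // the rewrite, not guaranteed output): an f32[4,1024] input normalized over
-  // dimension 1 is reshaped to f32[4,1024,1,1] and handed to
-  //   (f32[4,1024,1,1], u8[workspace]) custom-call(x, scale, bias),
-  //       custom_call_target=kCudnnNormCallTarget
-  // after which tuple element 0 is reshaped (and transposed, if a transpose
-  // was inserted) back to the original output shape.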
-
- // The layer norm training graph separately contains the expectation as well
- // as the norm factor and its cube, (variance + epsilon)^-1/2 and (variance +
- // epsilon)^-3/2. When identified in the graph, these quantities are fused
- // into the layer norm Custom Call.
- absl::StatusOr<bool> MatchNormFactor(HloInstruction* instr,
- HloInstruction* custom_call,
- UniqueHloInstruction& variance,
- UniqueHloInstruction& expectation,
- UniqueHloInstruction& epsilon) {
- HloInstruction* gte = custom_call->users()[0];
- if (Match(instr,
- m::Divide(
- m::Op(),
- AddAnyOrder(
- m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
- m::Broadcast(m::ConstantScalar().WithPredicate(
- epsilon.GetCaptureOrVerifyFn())))))) {
- // Verify the uniqueness of the operands.
- if (!variance.Instr() || !epsilon.Instr()) {
- VLOG(1) << "Layer norm operands not unique.";
- return false;
- }
-
- // Verify the element types.
- if (!CompatibleElementType(instr) ||
- !CompatibleElementType(expectation.Instr())) {
- VLOG(1) << "Layer norm input types not compatible.";
- return false;
- }
-
- // Retrieve metadata of the forward layer norm.
- auto norm_metadata = norm_metadata_.extract(custom_call);
- if (!norm_metadata) {
- VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
- return false;
- }
-
- // The shape of the expectation and norm factor return values of the
- // Custom Call is [nelems, 1, 1, 1], where nelems is the
- // number of elements in the expectation and norm factor shapes.
- auto make_compatible_shape = [](Shape shape) -> Shape {
- return ShapeUtil::MakeShape(shape.element_type(),
- {ShapeUtil::ElementsIn(shape), 1, 1, 1});
- };
-
- Shape expectation_shape =
- make_compatible_shape(expectation.Instr()->shape());
- Shape norm_factor_shape = make_compatible_shape(instr->shape());
-
- // The augmented Custom Call additionally returns the expectation and the
- // norm factor.
- std::vector<Shape> tuple_shapes = custom_call->shape().tuple_shapes();
- tuple_shapes.insert(tuple_shapes.begin() + 1,
- {expectation_shape, norm_factor_shape});
-
- Shape custom_call_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
-
- HloInstruction* new_custom_call = instr->AddInstruction(
- custom_call->CloneWithNewShape(custom_call_shape));
-
- TF_ASSIGN_OR_RETURN(
- GpuBackendConfig gpu_backend_config,
- custom_call->backend_config<xla::gpu::GpuBackendConfig>());
- CudnnNormBackendConfig& backend_config =
- *gpu_backend_config.mutable_cudnn_norm_backend_config();
- backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_TRAIN);
-
- // Update the workspace size.
- TF_ASSIGN_OR_RETURN(const int64_t c_constant,
- CConstant(cuda_compute_capability_));
- const int64_t workspace_size = (2 * c_constant * (4 + 256)) + 32;
- backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
- workspace_size);
- TF_RETURN_IF_ERROR(
- new_custom_call->set_backend_config(gpu_backend_config));
-
- auto replace_with_new_cc = [new_custom_call, this](
- HloInstruction* old_instr,
- int tuple_index) -> absl::Status {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_gte,
- MakeGetTupleElementHlo(new_custom_call, tuple_index));
- HloInstruction* new_instr = new_gte;
- if (!ShapeUtil::Equal(new_gte->shape(), old_instr->shape())) {
- TF_ASSIGN_OR_RETURN(new_instr,
- MakeReshapeHlo(old_instr->shape(), new_gte));
- }
- if (old_instr->opcode() != HloOpcode::kDivide) {
- // Replace the result of the layer norm or the expectation.
- TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
- } else {
- // Replace the norm factor, (variance + epsilon)^-1/2.
- TF_RETURN_IF_ERROR(
- ReplaceInstruction(old_instr->mutable_operand(0), new_instr));
-          // Also replace the norm factor raised to the power of three,
-          // (variance + epsilon)^-1/2 / (variance + epsilon) =
-          // ((variance + epsilon)^-1/2)^3.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_multiply0,
- MakeBinaryHlo(HloOpcode::kMultiply, new_instr, new_instr));
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_multiply1,
- MakeBinaryHlo(HloOpcode::kMultiply, new_multiply0, new_instr));
- TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_multiply1));
- }
- return absl::OkStatus();
- };
-
- // Replace the result of the original Custom Call as well as the
- // expectation and the norm factor with the augmented Custom Call.
- TF_RETURN_IF_ERROR(replace_with_new_cc(gte, 0));
- TF_RETURN_IF_ERROR(replace_with_new_cc(expectation.Instr(), 1));
- TF_RETURN_IF_ERROR(replace_with_new_cc(instr, 2));
-
- // Update the Custom Call associated with the metadata of the forward
- // norm.
- norm_metadata.key() = new_custom_call;
- norm_metadata_.insert(std::move(norm_metadata));
-
- VLOG(1)
- << "Expectation and norm factor fused into layer norm Custom Call.";
- }
-
- return true;
- }
-
- // Matches and rewrites the backward graph of layer norm patterns into Custom
- // Calls to cuDNN when the associated forward graph has been rewritten into a
- // cuDNN Custom Call. The gradients are
- // DX = F1 + F2 - AddReduce(F1 + F2) / nelems,
- // Dscale = AddReduce(DY * XCenter * NormFactor),
- // Dbias = AddReduce(DY),
- // with
- // F0 = XCenter * scale * DY,
- // F1 = XCenter * F0 * DNormFactor * 2 / nelems,
- // F2 = NormFactor * scale * DY,
- // XCenter = X - expectation(X),
- // NormFactor = (variance(X) + epsilon)^-1/2 and
- // DNormFactor = -1/2 * NormFactor^3.
- absl::Status MatchLayerNormGradient(HloInstruction* instr) {
- UniqueHloInstruction fwd_custom_call, x, x_center, scale, dy,
- fused_expectation, fused_norm_factor;
- HloInstruction *broadcast, *scalar, *dscale, *dbias, *reduce0, *reduce1,
- *reduce2, *reduce3;
- if (Match(instr,
- AddAddAnyOrder(
- m::Broadcast(
- &broadcast,
- MultiplyAddAnyOrder(
- m::Broadcast(m::ConstantScalar(&scalar)),
- NegateAddReduce(&reduce0,
- F1(&x, &x_center, &fused_expectation,
- &fwd_custom_call, &scale, &dy,
- &reduce2, norm_metadata_)),
- NegateAddReduce(
- &reduce1, F2(&fused_norm_factor, &scale, &dy,
- &fwd_custom_call, norm_metadata_)))),
- F2(&fused_norm_factor, &scale, &dy, &fwd_custom_call,
- norm_metadata_),
- F1(&x, &x_center, &fused_expectation, &fwd_custom_call,
- &scale, &dy, &reduce3, norm_metadata_)))) {
- // Skip initial convert, if present.
- if (instr->user_count() == 1 &&
- instr->users()[0]->opcode() == HloOpcode::kConvert &&
- CompatibleElementType(instr->users()[0])) {
- instr = instr->users()[0];
- }
-
- // Verify the uniqueness of the captured Custom Call and inputs.
- if (!fwd_custom_call.Instr() || !x.Instr() || !dy.Instr() ||
- !x_center.Instr() || !scale.Instr() || !fused_expectation.Instr() ||
- !fused_norm_factor.Instr()) {
- VLOG(1) << "Layer norm gradient inputs not unique.";
- return absl::OkStatus();
- }
-
- // Retrieve metadata of the forward layer norm.
- auto norm_metadata = norm_metadata_.find(fwd_custom_call.Instr());
- if (norm_metadata == norm_metadata_.end()) {
- VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
- return absl::OkStatus();
- }
-
- // Verify the dimensions of reductions in the backward graph.
- if (AdjustedDimensions(reduce0) !=
- norm_metadata->second.norm_dims_adjusted ||
- AdjustedDimensions(reduce1) !=
- norm_metadata->second.norm_dims_adjusted ||
- AdjustedDimensions(reduce2) !=
- norm_metadata->second.norm_dims_adjusted ||
- AdjustedDimensions(reduce3) !=
- norm_metadata->second.norm_dims_adjusted) {
- VLOG(1) << "Unexpected reductions dimensions in layer norm gradient.";
- return absl::OkStatus();
- }
-
-      // The captured scalar must be one over the number of elements in the
-      // dimensions introduced by the broadcast, i.e. the normalized
-      // dimensions.
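-      // For example, for an f32[2,4] gradient normalized over dimension 1,
-      // the broadcast introduces the four-element dimension, so the captured
-      // scalar is expected to be approximately 1/4 = 0.25.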
- float actual_r_nelems = scalar->literal().GetAsDouble({}).value();
- int64_t nelems = 1;
- for (int i = 0; i < broadcast->shape().dimensions_size(); ++i) {
- if (!absl::c_linear_search(broadcast->dimensions(), i)) {
- nelems *= broadcast->shape().dimensions()[i];
- }
- }
-      // The absolute value of the difference between the actual scaling
-      // factor and the reference value must not exceed a prescribed relative
-      // threshold.
-      float r_nelems = 1. / static_cast<float>(nelems);
-      float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
-      if (!(std::abs(actual_r_nelems - r_nelems) <
-            ((actual_r_nelems + r_nelems) * numerical_epsilon))) {
- VLOG(1)
- << "Layer norm backward broadcast operand outside expected range.";
- return absl::OkStatus();
- }
-
-      // Identify Dscale = AddReduce(DY * XCenter * NormFactor), where factor0
-      // and factor1 are XCenter and DY in either order.
- auto find_dscale =
- [&fused_norm_factor, &norm_metadata](
- const UniqueHloInstruction& factor0,
- const UniqueHloInstruction& factor1) -> HloInstruction* {
- for (HloInstruction* factor0_user : factor0.Instr()->users()) {
- std::vector<HloInstruction*> users;
- SkipUnaryOpsTopDownRecursive(factor0_user, users);
- // One of the users of factor0 must be a chained multiplication by the
- // fused norm factor and factor1.
- for (HloInstruction* user : users) {
- if (Match(user,
- MultiplyAnyOrder(
- m::Op(), MultiplyAnyOrder(
- m::Broadcast(BitcastOrReshape(m::Op().Is(
- fused_norm_factor.Instr()))),
- m::Op().Is(factor1.Instr()))))) {
- // Dscale is an addition-reduction of the product.
- for (HloInstruction* multiply_user : user->users()) {
- if (AppliesAddReduce(
- multiply_user,
- norm_metadata->second.non_norm_dims_adjusted)) {
- return multiply_user;
- }
- }
- }
- }
- }
- return nullptr;
- };
- if (!(dscale = find_dscale(x_center, dy)) &&
- !(dscale = find_dscale(dy, x_center))) {
- VLOG(1) << "Unable to identify Dscale in graph.";
- return absl::OkStatus();
- }
-
- // Find Dbias, i.e. an addition-reduction of DY, starting from DY.
- // Rewriting proceeds without fusing Dbias if unsuccessful.
- dbias = FindAddReduce(dy.Instr(),
- norm_metadata->second.non_norm_dims_adjusted);
-
- // Verify the input and output layouts.
- // TODO(philipphack): Consider supporting more general cases.
- if (!LayoutUtil::IsMonotonicWithDim0Major(dy.Instr()->shape().layout()) ||
- !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()) ||
- !LayoutUtil::IsMonotonicWithDim0Major(dscale->shape().layout()) ||
- (dbias &&
- !LayoutUtil::IsMonotonicWithDim0Major(dbias->shape().layout()))) {
- VLOG(1) << "Layer norm input and/or output layouts nor supported.";
- return absl::OkStatus();
- }
-
- // The types of X and DX must match.
- if (x.Instr()->shape().element_type() != instr->shape().element_type()) {
- VLOG(1) << "The types of X and DX must match.";
- return absl::OkStatus();
- }
-
- // The types and shapes of scale, Dscale and Dbias (if present) must
- // match.
- if (!ShapeUtil::Equal(
- ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
- ShapeUtil::DropDegenerateDimensions(dscale->shape())) ||
- (dbias &&
- !ShapeUtil::Equal(
- ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
- ShapeUtil::DropDegenerateDimensions(dbias->shape())))) {
- VLOG(1) << "Backward layer norm types not supported.";
- return absl::OkStatus();
- }
-
- // Verify the element types.
- if (!CompatibleElementType(dy.Instr())) {
- VLOG(1) << "Backward layer norm types not supported.";
- return absl::OkStatus();
- }
-
- // cuDNN requires the byte size of the element type of X to be at least
- // that of DY and scale.
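-      // For example, an f32 X (4 bytes) paired with an f16 DY (2 bytes)
-      // passes this check, whereas an f16 X paired with an f32 DY does not.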
- if (ShapeUtil::ByteSizeOfPrimitiveType(
- x.Instr()->shape().element_type()) <
- ShapeUtil::ByteSizeOfPrimitiveType(
- dy.Instr()->shape().element_type()) ||
- ShapeUtil::ByteSizeOfPrimitiveType(
- x.Instr()->shape().element_type()) <
- ShapeUtil::ByteSizeOfPrimitiveType(
- scale.Instr()->shape().element_type())) {
- VLOG(1) << "Backward layer norm types not supported.";
- return absl::OkStatus();
- }
-
- // Transpose DY applying the stored transpose order of X from the forward
- // graph.
- HloInstruction* transposed_dy = dy.Instr();
- if (norm_metadata->second.x_transpose) {
- TF_ASSIGN_OR_RETURN(
- transposed_dy,
- MakeTransposeHlo(dy.Instr(),
- norm_metadata->second.x_transpose->dimensions()));
- }
- TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_dy,
- MakeReshapeHlo(x.Instr()->shape(), transposed_dy));
-
- Shape dx_shape = ShapeUtil::MakeShape(instr->shape().element_type(),
- x.Instr()->shape().dimensions());
-
- Shape dscale_dbias_shape = ShapeUtil::MakeShape(
- dscale->shape().element_type(), scale.Instr()->shape().dimensions());
-
- GpuBackendConfig gpu_backend_config;
- CudnnNormBackendConfig& backend_config =
- *gpu_backend_config.mutable_cudnn_norm_backend_config();
- backend_config.set_kind(CudnnNormBackendConfig::LAYER_BWD);
- auto* algorithm = backend_config.mutable_algorithm();
- algorithm->set_algo_id(0);
- algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
- algorithm->set_is_cudnn_frontend(true);
-
- // Set the workspace size to its upper bound.
- // TODO(philipphack): Consider autotuning the norm kernels.
- TF_ASSIGN_OR_RETURN(const int64_t c_constant,
- CConstant(cuda_compute_capability_));
- const int64_t workspace_size =
- (2 * c_constant * (4 + 256)) +
- (2 * x.Instr()->shape().dimensions(0) * 4) + 64;
- algorithm->mutable_workspace_size()->set_value(workspace_size);
-
- // The output of the Custom Call is a tuple. The output shape of Dscale
- // and Dbias is that of scale.
- Shape custom_call_shape = ShapeUtil::MakeTupleShape(
- {dx_shape, dscale_dbias_shape, dscale_dbias_shape,
- ShapeUtil::MakeShape(U8, {workspace_size})});
-
- HloInstruction* custom_call =
- instr->AddInstruction(HloInstruction::CreateCustomCall(
- custom_call_shape,
- {x.Instr(), scale.Instr(), reshaped_dy, fused_expectation.Instr(),
- fused_norm_factor.Instr()},
- kCudnnNormCallTarget));
- TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
-
- auto replace_with_cc = [custom_call, norm_metadata, transposed_dy, this](
- HloInstruction* old_instr,
- int tuple_index) -> absl::Status {
- TF_ASSIGN_OR_RETURN(HloInstruction * gte,
- MakeGetTupleElementHlo(custom_call, tuple_index));
- HloInstruction* new_instr;
- // Transpose DX applying the stored transpose order of Y from the
- // forward graph.
- if (tuple_index == 0 && norm_metadata->second.y_transpose) {
- TF_ASSIGN_OR_RETURN(new_instr,
- MakeReshapeHlo(transposed_dy->shape(), gte));
- TF_ASSIGN_OR_RETURN(
- new_instr,
- MakeTransposeHlo(
- new_instr, norm_metadata->second.y_transpose->dimensions()));
- } else {
- TF_ASSIGN_OR_RETURN(new_instr,
- MakeReshapeHlo(old_instr->shape(), gte));
- }
- TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
- return absl::OkStatus();
- };
-
- TF_RETURN_IF_ERROR(replace_with_cc(instr, 0));
- TF_RETURN_IF_ERROR(replace_with_cc(dscale, 1));
- if (dbias) {
- TF_RETURN_IF_ERROR(replace_with_cc(dbias, 2));
- }
- VLOG(1) << "Gradients w.r.t. x"
- << (dbias ? ", scale and bias" : " and scale")
- << " rewritten into layer norm backward Custom Call.";
- }
-
- return absl::OkStatus();
- }
-
- private:
- se::CudaComputeCapability cuda_compute_capability_;
- NormMetadataMap norm_metadata_;
-};
-
-absl::StatusOr<bool> RunOnComputation(
- HloComputation* computation,
- se::CudaComputeCapability cuda_compute_capability) {
- CudnnNormRewriterVisitor visitor(cuda_compute_capability);
- TF_RETURN_IF_ERROR(computation->Accept(&visitor));
- return visitor.changed();
-}
-
-} // anonymous namespace
-
-CudnnNormRewriter::CudnnNormRewriter(
- se::CudaComputeCapability cuda_compute_capability)
- : cuda_compute_capability_(cuda_compute_capability) {}
-
-absl::StatusOr<bool> CudnnNormRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(
- bool result, RunOnComputation(computation, cuda_compute_capability_));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h
deleted file mode 100644
index 7b3ef8d..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites norm patterns into Custom Calls to the cuDNN library. Currently, the
-// forward and backward passes of layer norm patterns are implemented.
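-// A minimal usage sketch (the pipeline name and surrounding setup below are
-// illustrative assumptions, not taken from this file): the pass is typically
-// added to an HloPassPipeline and run over the module, e.g.
-//   HloPassPipeline pipeline("example-gpu-norm-pipeline");
-//   pipeline.AddPass<CudnnNormRewriter>(cuda_compute_capability);
-//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));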
-class CudnnNormRewriter : public HloModulePass {
- public:
- explicit CudnnNormRewriter(se::CudaComputeCapability cuda_compute_capability);
- absl::string_view name() const override { return "norm-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::CudaComputeCapability cuda_compute_capability_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc
deleted file mode 100644
index 754563a..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc
+++ /dev/null
@@ -1,1798 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <string>
-
-#include <gtest/gtest.h>
-#include "xla/error_spec.h"
-#include "xla/stream_executor/device_description.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#endif
-
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class CudnnNormRewriterTest : public GpuCodegenTest {
- public:
- se::CudaComputeCapability GetCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
-
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_cudnn_layer_norm(true);
- return debug_options;
- }
-
- protected:
- void TestNorm(std::string hlo_text, std::string optimized_hlo) {
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, optimized_hlo);
- }
-};
-
-// The following tests evaluate LayerNormXDY configurations, with X the rank of
-// the input and Y the dimensions that are normalized.
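-// For example, LayerNorm2D1 normalizes a rank-2 input over dimension 1, and
-// LayerNorm4D12 normalizes a rank-4 input over dimensions 1 and 2.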
-TEST_F(CudnnNormRewriterTest, LayerNorm2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4] parameter(0)
- input_square = f32[2,4] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2] multiply(input_square_sum, r_nelems_bcast)
- input_sum = f32[2] reduce(input, c0),dimensions={1}, to_apply=apply
- input_mean = f32[2] multiply(input_sum, r_nelems_bcast)
- input_mean_square = f32[2] multiply(input_mean, input_mean)
- variance = f32[2] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
- norm_factor = f32[2] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
- input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
- input_center = f32[2,4] subtract(input, input_mean_bcast)
- norm = f32[2,4] multiply(norm_factor_bcast, input_center)
- scale = f32[4] parameter(1)
- scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
- norm_scale = f32[2,4] multiply(norm, scale_bcast)
- bias = f32[4] parameter(2)
- bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
- ROOT out = f32[2,4] add(norm_scale, bias_broadcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
- r_nelems = f32[] constant(0.125)
- r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
- input_sum = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
- input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
- input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
- variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
- norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[8] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[8] parameter(2)
- bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
- ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[1,4,6,8] parameter(0)
- input_square = f32[1,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[1,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
- r_nelems = f32[] constant(0.125)
- r_nelems_bcast = f32[1,4,6] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[1,4,6] multiply(input_square_sum, r_nelems_bcast)
- input_sum = f32[1,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
- input_mean = f32[1,4,6] multiply(input_sum, r_nelems_bcast)
- input_mean_square = f32[1,4,6] multiply(input_mean, input_mean)
- variance = f32[1,4,6] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[1,4,6] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[1,4,6] add(variance, epsilon_bcast)
- norm_factor = f32[1,4,6] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[1,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
- input_mean_bcast = f32[1,4,6,8] broadcast(input_mean), dimensions={0,1,2}
- input_center = f32[1,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[1,4,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[8] parameter(1)
- scale_bcast = f32[1,4,6,8] broadcast(scale), dimensions={3}
- norm_scale = f32[1,4,6,8] multiply(norm, scale_bcast)
- bias = f32[8] parameter(2)
- bias_bcast = f32[1,4,6,8] broadcast(bias), dimensions={3}
- ROOT out = f32[1,4,6,8] add(norm_scale, bias_bcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[1,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[1,4,6,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[24,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: ROOT [[GTE_BITCAST:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} bitcast([[GTE]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
- r_nelems = f32[] constant(0.166667)
- r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,4,8] multiply(input_square_sum, r_nelems_bcast)
- reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
- input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,4,8] multiply(input_mean, input_mean)
- variance = f32[2,4,8] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[6] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[6] parameter(2)
- bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
- ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,4,6,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,1,6,8] parameter(0)
- input_square = f32[2,1,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,1,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
- r_nelems = f32[] constant(0.166667)
- r_nelems_bcast = f32[2,1,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,1,8] multiply(input_square_sum, r_nelems_bcast)
- reduce = f32[2,1,8] reduce(input, c0), dimensions={2}, to_apply=apply
- input_mean = f32[2,1,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,1,8] multiply(input_mean, input_mean)
- variance = f32[2,1,8] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,1,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,1,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,1,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,1,6,8] broadcast(norm_factor), dimensions={0,1,3}
- input_mean_bcast = f32[2,1,6,8] broadcast(input_mean), dimensions={0,1,3}
- input_center = f32[2,1,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,1,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[6] parameter(1)
- scale_bcast = f32[2,1,6,8] broadcast(scale), dimensions={2}
- norm_scale = f32[2,1,6,8] multiply(norm, scale_bcast)
- bias = f32[6] parameter(2)
- bias_broadcast = f32[2,1,6,8] broadcast(bias), dimensions={2}
- ROOT out = f32[2,1,6,8] add(norm_scale, bias_broadcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,1,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,1,6,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,6]{3,2,1,0} transpose([[P0]]), dimensions={1,0,3,2}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
- r_nelems = f32[] constant(0.041667)
- r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
- reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
- input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,8] multiply(input_mean, input_mean)
- variance = f32[2,8] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[4,6] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[4,6] parameter(2)
- bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
- ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> f32[2,4,6,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,1,8] parameter(0)
- input_square = f32[2,4,1,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
- reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
- input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,8] multiply(input_mean, input_mean)
- variance = f32[2,8] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
- input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
- input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
- scale = f32[4,1] parameter(1)
- scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
- norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
- bias = f32[4,1] parameter(2)
- bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
- ROOT out = f32[2,4,1,8] add(norm_scale, bias_broadcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> f32[2,4,1,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,2,2,2] parameter(0)
- input_square = f32[2,2,2,2] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,2,2] reduce(input_square, c0), dimensions={3}, to_apply=apply
- r_nelems = f32[] constant(0.5)
- r_nelems_bcast = f32[2,2,2] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,2,2] multiply(input_square_sum, r_nelems_bcast)
- input_sum = f32[2,2,2] reduce(input, c0), dimensions={3}, to_apply=apply
- input_mean = f32[2,2,2] multiply(input_sum, r_nelems_bcast)
- input_mean_square = f32[2,2,2] multiply(input_mean, input_mean)
- variance = f32[2,2,2] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,2,2] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,2,2] add(variance, epsilon_bcast)
- norm_factor = f32[2,2,2] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,2,2,2] broadcast(norm_factor), dimensions={0,1,2}
- input_mean_bcast = f32[2,2,2,2] broadcast(input_mean), dimensions={0,1,2}
- input_center = f32[2,2,2,2] subtract(input, input_mean_bcast)
- norm = f32[2,2,2,2] multiply(norm_factor_bcast, input_center)
- scale = f32[2] parameter(1)
- scale_bcast = f32[2,2,2,2] broadcast(scale), dimensions={2}
- norm_scale = f32[2,2,2,2] multiply(norm, scale_bcast)
- bias = f32[2] parameter(2)
- bias_bcast = f32[2,2,2,2] broadcast(bias), dimensions={3}
- ROOT out = f32[2,2,2,2] add(norm_scale, bias_bcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,2,2,2], {{.*}}: f32[2], {{.*}}: f32[2]) -> f32[2,2,2,2] {
-; CHECK-NOT: custom_call_target="__cudnn$norm"
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f16[2,4,6,8] parameter(0)
- input_f32 = f32[2,4,6,8] convert(input)
- input_square = f32[2,4,6,8] multiply(input_f32, input_f32)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
- r_nelems = f32[] constant(0.125)
- r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
- input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply
- input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
- input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
- variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
- norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
- input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[8] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[8] parameter(2)
- bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
- ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
-; CHECK-NOT: custom_call_target="__cudnn$norm"
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4] parameter(0)
- input_square = f32[2,4] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast)
- reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
- input_mean = f32[2] multiply(reduce,r_nelems_bcast)
- input_mean_square = f32[2] multiply(input_mean,input_mean)
- variance = f32[2] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
- norm_factor = f32[2] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
- input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
- input_center = f32[2,4] subtract(input,input_mean_bcast)
- norm = f32[2,4] multiply(norm_factor_bcast,input_center)
- scale = f32[4] parameter(1)
- scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
- norm_scale = f32[2,4] multiply(norm,scale_bcast)
- bias = f32[4] parameter(2)
- bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
- norm_scale_bias = f32[2,4] add(norm_scale, bias_broadcast)
- norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
- ROOT out = (f32[2,4], f32[2], f32[2], f32[2]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> (f32[2,4], f32[2], f32[2], f32[2]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
-; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE1]])
-; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE2]])
-; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2]{0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
- r_nelems = f32[] constant(0.125)
- r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
- reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
- input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
- variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
- norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[8] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
- norm_scale = f32[2,4,6,8] multiply(norm,scale_bcast)
- bias = f32[8] parameter(2)
- bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
- norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
- norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
- ROOT out = (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
-; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE1]])
-; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE2]])
-; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,4,6]{2,1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
- r_nelems = f32[] constant(0.041667)
- r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
- reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
- input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,8] multiply(input_mean, input_mean)
- variance = f32[2,8] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
- scale = f32[4,6] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[4,6] parameter(2)
- bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
- norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
- norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
- ROOT out = (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
-; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
-; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,1,8] parameter(0)
- input_square = f32[2,4,1,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
- reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
- input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,8] multiply(input_mean, input_mean)
- variance = f32[2,8] subtract(input_square_mean, input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
- input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
- input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
- scale = f32[4,1] parameter(1)
- scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
- norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
- bias = f32[4,1] parameter(2)
- bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
- norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_broadcast)
- norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
- ROOT out = (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK: }
-; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
-; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
-; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4] parameter(0)
- input_square = f32[2,4] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
- reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast)
- input_mean = f32[2] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2] multiply(input_mean,input_mean)
- variance = f32[2] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
- norm_factor = f32[2] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
- input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
- input_center = f32[2,4] subtract(input, input_mean_bcast)
- norm = f32[2,4] multiply(input_center, norm_factor_bcast)
- scale = f32[4] parameter(1)
- scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
- norm_scale = f32[2,4] multiply(norm, scale_bcast)
- bias = f32[4] parameter(2)
- bias_bcast = f32[2,4] broadcast(bias), dimensions={1}
- norm_scale_bias = f32[2,4] add(norm_scale, bias_bcast)
- doutput = f32[2,4] parameter(3)
- dbias = f32[4] reduce(doutput, c0), dimensions={0}, to_apply=apply
- norm_doutput = f32[2,4] multiply(norm, doutput)
- dscale = f32[4] reduce(norm_doutput, c0), dimensions={0}, to_apply=apply
- scale_doutput = f32[2,4] multiply(scale_bcast, doutput)
- input_center_scale_doutput = f32[2,4] multiply(input_center, scale_doutput)
- f0 = f32[2] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
- norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
- c1 = f32[] constant(-0.5)
- c1_bcast = f32[2] broadcast(c1), dimensions={}
- dnorm_factor = f32[2] multiply(norm_factor_cube, c1_bcast)
- f0_dnorm_factor = f32[2] multiply(f0, dnorm_factor)
- c2 = f32[] constant(0.5)
- c2_bcast = f32[2] broadcast(c2), dimensions={}
- f0_dnorm_factor_scaled = f32[2] multiply(f0_dnorm_factor, c2_bcast)
- f0_dnorm_factor_scaled_bcast = f32[2,4] broadcast(f0_dnorm_factor_scaled), dimensions={0}
- f1 = f32[2,4] multiply(input_center, f0_dnorm_factor_scaled_bcast)
- minus_f1 = f32[2,4] negate(f1)
- minus_f1_sum = f32[2] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
- f2 = f32[2,4] multiply(norm_factor_bcast, scale_doutput)
- minus_f2 = f32[2,4] negate(f2)
- minus_f2_sum = f32[2] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
- minus_f1_f2_sum = f32[2] add(minus_f1_sum, minus_f2_sum)
- minus_f1_f2_sum_scaled = f32[2] multiply(minus_f1_f2_sum, r_nelems_bcast)
- minus_f1_f2_sum_scaled_bcast = f32[2,4] broadcast(minus_f1_f2_sum_scaled), dimensions={0}
- f1_f2 = f32[2,4] add(f1, f2)
- dinput = f32[2,4] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
- ROOT out = (f32[2,4], f32[2,4], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> (f32[2,4], f32[2,4], f32[4], f32[4]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
-; CHECK: }
-; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P3]])
-; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0
-; CHECK-DAG: "kind":"LAYER_BWD"
-; CHECK: }
-; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE3]])
-; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
-; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
-; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
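A note for readers tracing the hand-written backward graphs in these tests: the symbols below are this note's own shorthand (x for input, mu for input_mean, sigma^2 + epsilon for variance_plus_epsilon, r for norm_factor, gamma for scale, dy for doutput, and N for the number of normalized elements, so r_nelems = 1/N and c2 = 2/N). The HLO computes the layer-norm input gradient as

\[
f_{2,i} = r\,\gamma_i\,dy_i,\qquad
f_{1,i} = -\frac{(x_i-\mu)\,r^{3}}{N}\sum_j (x_j-\mu)\,\gamma_j\,dy_j,\qquad
dx_i = f_{1,i} + f_{2,i} - \frac{1}{N}\sum_k\bigl(f_{1,k}+f_{2,k}\bigr),
\]

with \(r = (\sigma^2+\varepsilon)^{-1/2}\). The rewriter is expected to replace this subgraph with a second __cudnn$norm custom call of kind LAYER_BWD that consumes the input, scale, and doutput bitcasts together with the two statistics outputs of the forward LAYER_FWD_TRAIN call, and whose first three tuple elements map to dinput, dscale, and dbias, which is what the CHECK patterns verify.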
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
- reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
- r_nelems = f32[] constant(0.125)
- r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,4,6] multiply(input_square_sum,r_nelems_bcast)
- input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,4,6] multiply(input_mean,input_mean)
- variance = f32[2,4,6] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
- norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
- scale = f32[8] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[8] parameter(2)
- bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
- norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
- doutput = f32[2,4,6,8] parameter(3)
- dbias = f32[8] reduce(doutput, c0), dimensions={0,1,2}, to_apply=apply
- norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
- dscale = f32[8] reduce(norm_doutput, c0), dimensions={0,1,2}, to_apply=apply
- scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
- input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
- f0 = f32[2,4,6] reduce(input_center_scale_doutput, c0), dimensions={3}, to_apply=apply
- norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
- c1 = f32[] constant(-0.5)
- c1_bcast = f32[2,4,6] broadcast(c1), dimensions={}
- dnorm_factor = f32[2,4,6] multiply(norm_factor_cube, c1_bcast)
- f0_dnorm_factor = f32[2,4,6] multiply(f0, dnorm_factor)
- c2 = f32[] constant(0.25)
- c2_bcast = f32[2,4,6] broadcast(c2), dimensions={}
- f0_dnorm_factor_scaled = f32[2,4,6] multiply(f0_dnorm_factor, c2_bcast)
- f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,2}
- f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
- minus_f1 = f32[2,4,6,8] negate(f1)
- minus_f1_sum = f32[2,4,6] reduce(minus_f1, c0), dimensions={3}, to_apply=apply
- f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
- minus_f2 = f32[2,4,6,8] negate(f2)
- minus_f2_sum = f32[2,4,6] reduce(minus_f2, c0), dimensions={3}, to_apply=apply
- minus_f1_f2_sum = f32[2,4,6] add(minus_f1_sum, minus_f2_sum)
- minus_f1_f2_sum_scaled = f32[2,4,6] multiply(minus_f1_f2_sum, r_nelems_bcast)
- minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,2}
- f1_f2 = f32[2,4,6,8] add(f1, f2)
- dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
- ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) tuple(norm_scale_bias, dinput, dscale, dbias)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
-; CHECK: }
-; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
-; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P3]])
-; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0
-; CHECK-DAG: "kind":"LAYER_BWD"
-; CHECK: }
-; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE3]])
-; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE4]])
-; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE5]])
-; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[8]{0}, f32[8]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
- reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
- r_nelems = f32[] constant(0.166667)
- r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,4,8] multiply(input_square_sum,r_nelems_bcast)
- input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,4,8] multiply(input_mean,input_mean)
- variance = f32[2,4,8] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
- scale = f32[6] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[6] parameter(2)
- bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
- norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
- doutput = f32[2,4,6,8] parameter(3)
- dbias = f32[6] reduce(doutput, c0), dimensions={0,1,3}, to_apply=apply
- norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
- dscale = f32[6] reduce(norm_doutput, c0), dimensions={0,1,3}, to_apply=apply
- scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
- input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
- f0 = f32[2,4,8] reduce(input_center_scale_doutput, c0), dimensions={2}, to_apply=apply
- norm_factor_cube = f32[2,4,8] divide(norm_factor, variance_plus_epsilon)
- c1 = f32[] constant(-0.5)
- c1_bcast = f32[2,4,8] broadcast(c1), dimensions={}
- dnorm_factor = f32[2,4,8] multiply(norm_factor_cube, c1_bcast)
- f0_dnorm_factor = f32[2,4,8] multiply(f0, dnorm_factor)
- c2 = f32[] constant(0.333333)
- c2_bcast = f32[2,4,8] broadcast(c2), dimensions={}
- f0_dnorm_factor_scaled = f32[2,4,8] multiply(f0_dnorm_factor, c2_bcast)
- f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,3}
- f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
- minus_f1 = f32[2,4,6,8] negate(f1)
- minus_f1_sum = f32[2,4,8] reduce(minus_f1, c0), dimensions={2}, to_apply=apply
- f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
- minus_f2 = f32[2,4,6,8] negate(f2)
- minus_f2_sum = f32[2,4,8] reduce(minus_f2, c0), dimensions={2}, to_apply=apply
- minus_f1_f2_sum = f32[2,4,8] add(minus_f1_sum, minus_f2_sum)
- minus_f1_f2_sum_scaled = f32[2,4,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
- minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,3}
- f1_f2 = f32[2,4,6,8] add(f1, f2)
- dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
- ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) tuple(norm_scale_bias, dinput, dscale, dbias)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
-; CHECK: }
-; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
-; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2}
-; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
-; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0
-; CHECK-DAG: "kind":"LAYER_BWD"
-; CHECK: }
-; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
-; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
-; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]])
-; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]])
-; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
- reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
- r_nelems = f32[] constant(0.041667)
- r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast)
- input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,8] multiply(input_mean,input_mean)
- variance = f32[2,8] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
- scale = f32[4,6] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[4,6] parameter(2)
- bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
- norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
- doutput = f32[2,4,6,8] parameter(3)
- dbias = f32[4,6] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
- norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
- dscale = f32[4,6] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
- scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
- input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
- f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
- norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
- c1 = f32[] constant(-0.5)
- c1_bcast = f32[2,8] broadcast(c1), dimensions={}
- dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
- f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
- c2 = f32[] constant(0.083333)
- c2_bcast = f32[2,8] broadcast(c2), dimensions={}
- f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
- f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
- f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
- minus_f1 = f32[2,4,6,8] negate(f1)
- minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
- f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
- minus_f2 = f32[2,4,6,8] negate(f2)
- minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
- minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
- minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
- minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
- f1_f2 = f32[2,4,6,8] add(f1, f2)
- dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
- ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) tuple(norm_scale_bias, dinput, dscale, dbias)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
-; CHECK: }
-; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
-; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2}
-; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
-; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0
-; CHECK-DAG: "kind":"LAYER_BWD"
-; CHECK: }
-; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
-; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
-; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]])
-; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]])
-; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,1,8] parameter(0)
- input_square = f32[2,4,1,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
- reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast)
- input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,8] multiply(input_mean,input_mean)
- variance = f32[2,8] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
- input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
- input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,1,8] multiply(input_center, norm_factor_bcast)
- scale = f32[4,1] parameter(1)
- scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
- norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
- bias = f32[4,1] parameter(2)
- bias_bcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
- norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_bcast)
- doutput = f32[2,4,1,8] parameter(3)
- dbias = f32[4,1] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
- norm_doutput = f32[2,4,1,8] multiply(norm, doutput)
- dscale = f32[4,1] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
- scale_doutput = f32[2,4,1,8] multiply(scale_bcast, doutput)
- input_center_scale_doutput = f32[2,4,1,8] multiply(input_center, scale_doutput)
- f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
- norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
- c1 = f32[] constant(-0.5)
- c1_bcast = f32[2,8] broadcast(c1), dimensions={}
- dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
- f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
- c2 = f32[] constant(0.5)
- c2_bcast = f32[2,8] broadcast(c2), dimensions={}
- f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
- f0_dnorm_factor_scaled_bcast = f32[2,4,1,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
- f1 = f32[2,4,1,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
- minus_f1 = f32[2,4,1,8] negate(f1)
- minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
- f2 = f32[2,4,1,8] multiply(norm_factor_bcast, scale_doutput)
- minus_f2 = f32[2,4,1,8] negate(f2)
- minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
- minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
- minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
- minus_f1_f2_sum_scaled_bcast = f32[2,4,1,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
- f1_f2 = f32[2,4,1,8] add(f1, f2)
- dinput = f32[2,4,1,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
- ROOT out = (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) tuple(norm_scale_bias, dinput, dscale, dbias)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1], {{.*}}: f32[2,4,1,8]) -> (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
-; CHECK: }
-; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(3)
-; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P3]]), dimensions={2,0,3,1}
-; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
-; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0
-; CHECK-DAG: "kind":"LAYER_BWD"
-; CHECK: }
-; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG: [[FUSION0:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=0
-; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=1
-; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE4]])
-; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE5]])
-; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-// TODO(b/343124533) Reenable when fixed
-TEST_F(CudnnNormRewriterTest,
- DISABLED_LayerNormTrainBackward4D1DoutputReshapeSplit) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
- reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
- input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
- variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
- scale = f32[4] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[4] parameter(2)
- bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
- norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
- doutput = f32[2,4,48] parameter(3)
- dbias = f32[4] reduce(doutput, c0), dimensions={0,2}, to_apply=apply
- doutput_bitcast = f32[2,4,6,8] reshape(doutput)
- norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
- dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
- scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
- input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
- f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
- norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
- c1 = f32[] constant(-0.5)
- c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
- dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
- f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
- c2 = f32[] constant(0.5)
- c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
- f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
- f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
- f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
- minus_f1 = f32[2,4,6,8] negate(f1)
- minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
- f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
- minus_f2 = f32[2,4,6,8] negate(f2)
- minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
- minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
- minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
- minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
- f1_f2 = f32[2,4,6,8] add(f1, f2)
- dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
- ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,48]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
-; CHECK: }
-; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,48]{2,1,0} parameter(3)
-; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
-; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0
-; CHECK-DAG: "kind":"LAYER_BWD"
-; CHECK: }
-; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
-; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
-; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
-; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
-; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-// TODO(b/343124533) Reenable when fixed
-TEST_F(CudnnNormRewriterTest,
- DISABLED_LayerNormTrainBackward4D1DoutputReshapeCombine) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
- GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
- if (!(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::AMPERE) &&
- !(GetCudaComputeCapability().major ==
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP()
- << "Layer norm kernels require Ampere or Hopper architectures.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a,b)
- }
-
- ENTRY test {
- input = f32[2,4,6,8] parameter(0)
- input_square = f32[2,4,6,8] multiply(input, input)
- c0 = f32[] constant(0)
- input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
- reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
- r_nelems = f32[] constant(0.25)
- r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
- input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
- input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
- input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
- variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
- epsilon = f32[] constant(0.001)
- epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
- variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
- norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
- norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
- input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
- input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
- norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
- scale = f32[4] parameter(1)
- scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
- norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
- bias = f32[4] parameter(2)
- bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
- norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
- doutput = f32[2,4,6,2,2,2] parameter(3)
- dbias = f32[4] reduce(doutput, c0), dimensions={0,2,3,4,5}, to_apply=apply
- doutput_bitcast = f32[2,4,6,8] reshape(doutput)
- norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
- dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
- scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
- input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
- f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
- norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
- c1 = f32[] constant(-0.5)
- c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
- dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
- f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
- c2 = f32[] constant(0.5)
- c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
- f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
- f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
- f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
- minus_f1 = f32[2,4,6,8] negate(f1)
- minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
- f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
- minus_f2 = f32[2,4,6,8] negate(f2)
- minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
- minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
- minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
- minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
- f1_f2 = f32[2,4,6,8] add(f1, f2)
- dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
- ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
- })";
-
- const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,6,2,2,2]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0.001
-; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
-; CHECK: }
-; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,2,2,2]{5,4,3,2,1,0} parameter(3)
-; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
-; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK: custom_call_target="__cudnn$norm",
-; CHECK: backend_config={
-; CHECK-DAG: "epsilon":0
-; CHECK-DAG: "kind":"LAYER_BWD"
-; CHECK: }
-; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
-; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
-; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
-; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
-; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
- )";
-
- TestNorm(hlo_text, optimized_hlo);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc
deleted file mode 100644
index ed83a62..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc
+++ /dev/null
@@ -1,528 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <optional>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/functional/bind_front.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/cudnn_support_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-// Creates and returns an HLO that zero-pads one or more dimensions in the given
-// instruction so that its shape is equal to the given shape.
-//
-// Padding is added to the end of each relevant dimension.
-//
-// If the instruction already has the given shape, simply returns it without an
-// intervening pad.
-static HloInstruction* PadInstruction(HloInstruction* instr,
- const Shape& new_shape) {
- HloComputation* comp = instr->parent();
-
- const Shape& shape = instr->shape();
- PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank());
-
- bool added_padding = false;
- for (int64_t dim = 0; dim < shape.rank(); ++dim) {
- if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
- continue;
- }
- CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim));
- pad_config.mutable_dimensions(dim)->set_edge_padding_high(
- new_shape.dimensions(dim) - shape.dimensions(dim));
- added_padding = true;
- }
- if (!added_padding) {
- return instr;
- }
-
- auto* zero = comp->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
- return comp->AddInstruction(
- HloInstruction::CreatePad(new_shape, instr, zero, pad_config),
- &instr->metadata());
-}
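As a quick illustration of the padding-config arithmetic above (the shapes and the helper name below are hypothetical, not taken from this change), growing a dimension only ever adds trailing zero padding:

#include <cstdint>
#include <vector>

// Hypothetical sketch mirroring PadInstruction's logic: only
// edge_padding_high is ever set, by the difference of the dimension sizes.
std::vector<int64_t> TrailingPadAmounts(const std::vector<int64_t>& old_dims,
                                        const std::vector<int64_t>& new_dims) {
  std::vector<int64_t> high(old_dims.size(), 0);
  for (size_t i = 0; i < old_dims.size(); ++i) {
    high[i] = new_dims[i] - old_dims[i];  // new_dims[i] >= old_dims[i]
  }
  return high;
}

// e.g. {10, 3, 32, 32} -> {10, 4, 32, 32} yields {0, 1, 0, 0}, i.e. an HLO
// pad(operand, zero) with padding configuration 0_0x0_1x0_0x0_0.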
-
-// Modifies the given convolution to have the given input and result shapes.
-static absl::Status PadConv(HloCustomCallInstruction* conv,
- absl::Span<const Shape> new_input_shapes,
- const Shape& new_result_shape) {
- CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
- << "conv must use 0 scratch bytes, i.e. this pass must be run "
- "before CudnnConvAlgorithmPicker.";
- std::vector<HloInstruction*> new_operands;
- new_operands.reserve(conv->operand_count());
- for (int i = 0; i < conv->operand_count(); ++i) {
- new_operands.push_back(
- PadInstruction(conv->mutable_operand(i), new_input_shapes[i]));
- }
- const Shape& result_shape = conv->shape().tuple_shapes(0);
-
- bool changed = false;
- for (int i = 0; i < conv->operand_count(); ++i) {
- changed |= (new_operands[i] != conv->mutable_operand(i));
- }
- CHECK(changed) << "We should have had to pad at least one input operand.";
-
- auto add = [&](std::unique_ptr<HloInstruction> new_instr) {
- return conv->parent()->AddInstruction(std::move(new_instr));
- };
-
- Shape new_conv_shape = ShapeUtil::MakeTupleShape(
- {new_result_shape, ShapeUtil::MakeShape(U8, {0})});
- auto* new_conv =
- add(conv->CloneWithNewOperands(new_conv_shape, new_operands));
-
- // Clone conv's name to new_conv. This is safe because we're going to remove
- // conv below.
- new_conv->SetAndSanitizeName(conv->name());
-
- VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with "
- << new_conv->ToString();
-
- // Slice the new conv result if necessary, keeping in mind that new_conv
- // has tuple shape (new_result_shape, u8[0]).
- if (!ShapeUtil::Equal(result_shape, new_result_shape)) {
- std::vector<int64_t> start_indices(result_shape.dimensions_size(), 0);
- std::vector<int64_t> end_indices(result_shape.dimensions().begin(),
- result_shape.dimensions().end());
- std::vector<int64_t> strides(result_shape.dimensions_size(), 1);
-
- auto* new_conv_result = add(
- HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0));
- auto* empty_temp_buffer =
- add(HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({})));
- auto* sliced_result = add(HloInstruction::CreateSlice(
- result_shape, new_conv_result, start_indices, end_indices, strides));
- new_conv =
- add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer}));
- }
-
- return conv->parent()->ReplaceInstruction(conv, new_conv);
-}
-
-static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
- HloComputation* comp) {
- std::vector<HloCustomCallInstruction*> convs;
- for (HloInstruction* instr : comp->instructions()) {
- if (IsCustomCallToDnnConvolution(*instr)) {
- convs.push_back(Cast<HloCustomCallInstruction>(instr));
- }
- }
- return convs;
-}
-
-// This is the main function of the transform. It runs on a given custom-call
-// node to a cuDNN convolution, calls resolve_pad_shapes to resolve the desired
-// input/output feature map shapes, and adds the necessary padding and slicing
-// nodes around it.
-//
-// resolve_pad_shapes takes conv, a custom-call instruction to a cuDNN
-// convolution that may need padding, figures out the desired padded input and
-// output tensor shapes, and stores them in new_input_shapes and
-// new_result_shape. Note that new_input_shapes is a vector because there are
-// multiple input tensors. In addition to the status, it returns true if
-// padding is necessary and false otherwise.
-static absl::StatusOr<bool> ResolveAndPad(
- HloCustomCallInstruction* conv,
- std::function<absl::StatusOr<bool>(HloCustomCallInstruction* conv,
- std::vector<Shape>* new_input_shapes,
- Shape* new_result_shape)>
- resolve_pad_shapes) {
- std::vector<Shape> new_input_shapes;
- Shape new_result_shape;
- TF_ASSIGN_OR_RETURN(bool result, resolve_pad_shapes(conv, &new_input_shapes,
- &new_result_shape));
- if (result) {
- TF_RETURN_IF_ERROR(PadConv(conv, new_input_shapes, new_result_shape));
- return true;
- }
- return false;
-}
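The pass's Run method lies outside this hunk, so the actual call sites of ResolveAndPad are not visible here. A hedged sketch of how a caller might drive it with the two resolvers defined elsewhere in this file (the use of absl::bind_front is suggested by the include list, but this particular wiring is an assumption, not the deleted pass's real driver):

// Sketch only: pad every relevant conv in a computation, trying the integer
// resolver first and the tensor-core resolver only if nothing was padded.
absl::StatusOr<bool> PadConvsInComputationSketch(
    HloComputation* comp, const se::CudaComputeCapability& cc) {
  bool changed = false;
  for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
    TF_ASSIGN_OR_RETURN(
        bool local_changed,
        ResolveAndPad(conv,
                      absl::bind_front(
                          TryResolvePaddedShapesForIntegerConvolution,
                          /*pad_to=*/4, cc)));
    if (!local_changed) {
      // PadConv replaces the instruction when padding happens, so only touch
      // `conv` again if the first resolver left it unchanged.
      TF_ASSIGN_OR_RETURN(
          local_changed,
          ResolveAndPad(conv, TryResolvePaddedShapesForTensorCore));
    }
    changed |= local_changed;
  }
  return changed;
}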
-
-// Adds padding to cudnn convolutions to make them run faster on GPUs with
-// tensor cores.
-//
-// - f16 convolutions are padded to have input/output channel dimensions that
-// are multiples of 8, so that we can use tensor cores.
-//
-// - f16 convolutions with 3 input channels and 32 or 64 output channels are
-// padded to 4 input channels. There's a special-cased cudnn algorithm just
-// for this.
-//
-// Don't run this pass on GPUs without tensor cores -- it will make them slower!
-//
-// TODO(jlebar): Also pad dots.
-static absl::StatusOr<bool> TryResolvePaddedShapesForTensorCore(
- HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
- Shape* new_result_shape_ptr) {
- TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
- const auto& dnums = conv->convolution_dimension_numbers();
- auto* lhs = conv->mutable_operand(0);
- auto* rhs = conv->mutable_operand(1);
- const Shape& result_shape = conv->shape().tuple_shapes(0);
-
- // Nothing to do on non-f16 convolutions.
- if (result_shape.element_type() != PrimitiveType::F16) {
- return false;
- }
-
- // When the convolution is grouped, the operand shapes must stay consistent
- // with the group size, so we cannot pad them independently.
- if (conv->feature_group_count() > 1 || conv->batch_group_count() > 1) {
- VLOG(2) << "Do not pad grouped convolution.";
- return false;
- }
-
- // TODO(timshen): Don't skip forward-activation convs if we find a benchmark
- // where there's a speedup.
- if (kind == CudnnConvKind::kForwardActivation) {
- return false;
- }
-
- Shape new_lhs_shape = lhs->shape();
- Shape new_rhs_shape = rhs->shape();
- Shape& new_result_shape = *new_result_shape_ptr;
- new_result_shape = conv->shape().tuple_shapes(0);
-
- // new_{input,filter,output}_shape each point to the appropriate one of
- // new_{lhs,rhs,result}_shape.
- Shape* new_input_shape;
- Shape* new_filter_shape;
- Shape* new_output_shape;
- std::tie(new_input_shape, new_filter_shape, new_output_shape) = [&] {
- switch (kind) {
- case CudnnConvKind::kForward:
- case CudnnConvKind::kForwardActivation:
- case CudnnConvKind::kForwardGraph:
- return std::make_tuple(&new_lhs_shape, &new_rhs_shape,
- &new_result_shape);
- case CudnnConvKind::kBackwardInput:
- return std::make_tuple(&new_result_shape, &new_rhs_shape,
- &new_lhs_shape);
- case CudnnConvKind::kBackwardFilter:
- return std::make_tuple(&new_lhs_shape, &new_result_shape,
- &new_rhs_shape);
- }
- }();
-
- // If there are 3 input features and 32 or 64 output features, pad the input
- // features to 4. Otherwise, try padding to multiples of 8 and check that
- // this doesn't make any of the conv buffers too much larger.
- auto input_features =
- new_input_shape->dimensions(dnums.input_feature_dimension());
- auto output_features =
- new_output_shape->dimensions(dnums.output_feature_dimension());
- if (input_features == 3 && (output_features == 32 || output_features == 64)) {
- new_input_shape->set_dimensions(dnums.input_feature_dimension(), 4);
- new_filter_shape->set_dimensions(dnums.kernel_input_feature_dimension(), 4);
- } else {
- auto pad_dim = [](Shape* s, int64_t dim) {
- s->set_dimensions(dim, RoundUpTo<int64_t>(s->dimensions(dim), 8));
- };
- pad_dim(new_input_shape, dnums.input_feature_dimension());
- pad_dim(new_filter_shape, dnums.kernel_input_feature_dimension());
- pad_dim(new_filter_shape, dnums.kernel_output_feature_dimension());
- pad_dim(new_output_shape, dnums.output_feature_dimension());
-
- // We won't pad a conv if doing so increases the total number of bytes in
- // the lhs, rhs, or result by more than this amount.
- //
- // TODO(jlebar): This number was tuned experimentally. It represents a
- // compromise on our current benchmarks; it speeds some up significantly,
- // and doesn't slow any down. But we can observe by changing this value
- // that there's additional room for speedups. Achieving those speedups
- // without also slowing other things down will likely require a more
- // sophisticated heuristic, possibly some form of auto-tuning.
- static constexpr double kMaxBytesTouchedBound = 1.35;
-
- // Check that padding wouldn't increase the total bytes read/written by this
- // operation too much.
- auto check_size_increase = [&](const Shape& old_shape,
- const Shape& new_shape) {
- int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
- int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
- if (new_bytes <= old_bytes * kMaxBytesTouchedBound) {
- return true;
- }
- VLOG(3)
- << "Not padding convolution; doing so would change input / result "
- "shape from "
- << ShapeUtil::HumanString(old_shape) << " to "
- << ShapeUtil::HumanString(new_shape) << ", a size increase of "
- << new_bytes / static_cast<double>(old_bytes) << "x > "
- << kMaxBytesTouchedBound << "x: " << conv->ToString();
- return false;
- };
-
- if (!check_size_increase(lhs->shape(), new_lhs_shape) ||
- !check_size_increase(rhs->shape(), new_rhs_shape) ||
- !check_size_increase(result_shape, new_result_shape)) {
- return false;
- }
- }
-
- if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
- ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) {
- VLOG(3) << "No need to pad features of " << conv->ToString();
- return false;
- }
-
- new_input_shapes_ptr->push_back(new_lhs_shape);
- new_input_shapes_ptr->push_back(new_rhs_shape);
- return true;
-}
-
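A worked instance of the f16 heuristic above may help: feature dims are rounded up to a multiple of 8 (or 3 input features become 4), and the padding is kept only if the growth stays within the 1.35x bytes-touched bound. The sketch below is a standalone illustration using hypothetical stand-ins for the XLA helpers, not the real API; it mirrors the PadF16ForwardConvInputChannels test below (f16[10,20,30,41] padded to 48 features).

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical stand-ins for the XLA helpers used above; illustration only.
int64_t RoundUpTo(int64_t value, int64_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

int64_t NumElements(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // f16 forward conv input f16[10,20,30,41]: round the feature dim (41) up to
  // a multiple of 8, then compare the growth against the 1.35x bound. Element
  // count is proportional to bytes here because the element type is unchanged.
  std::vector<int64_t> old_input = {10, 20, 30, 41};
  std::vector<int64_t> new_input = old_input;
  new_input[3] = RoundUpTo(new_input[3], 8);  // 41 -> 48

  constexpr double kMaxBytesTouchedBound = 1.35;
  double growth =
      static_cast<double>(NumElements(new_input)) / NumElements(old_input);
  std::cout << "growth = " << growth  // ~1.17
            << (growth <= kMaxBytesTouchedBound ? " -> pad\n" : " -> skip\n");
  return 0;
}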
-// Adds padding to cudnn integer convolutions to make input and output feature
-// maps multiples of pad_to (usually 4 or 32).
-absl::StatusOr<bool> TryResolvePaddedShapesForIntegerConvolution(
- int pad_to, const se::CudaComputeCapability& compute_capability,
- HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
- Shape* new_result_shape_ptr) {
- TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
- const Shape& input_shape = conv->operand(0)->shape();
- const Shape& kernel_shape = conv->operand(1)->shape();
- const Shape& result_shape = conv->shape().tuple_shapes(0);
-
- // Integer convolution only
- if (!primitive_util::IsIntegralType(input_shape.element_type())) {
- return false;
- }
-
- // kForward and kForwardActivation only
- if (kind != CudnnConvKind::kForward &&
- kind != CudnnConvKind::kForwardActivation) {
- return false;
- }
-
- const auto& dnums = conv->convolution_dimension_numbers();
- std::vector<Shape>& new_input_shapes = *new_input_shapes_ptr;
- for (auto operand : conv->operands()) {
- new_input_shapes.push_back(operand->shape());
- }
- Shape& new_result_shape = *new_result_shape_ptr;
- new_result_shape = conv->shape().tuple_shapes(0);
-
- // The input/kernel/output might already be vectorized (i.e. cudnn layout
- // NCHW_VECT_C). If so, we pad the features dim so that
- // size(features_dim) * size(vect_dim) is a multiple of pad_to.
- std::optional<int64_t> input_vect_dim;
- std::optional<int64_t> kernel_vect_dim;
- std::optional<int64_t> result_vect_dim;
- std::tie(input_vect_dim, kernel_vect_dim, result_vect_dim) =
- FindVectorizedFeatureDims(dnums, input_shape, kernel_shape, result_shape);
-
- int64_t input_vect_size =
- input_vect_dim.has_value() ? input_shape.dimensions(*input_vect_dim) : 1;
- int64_t kernel_vect_size = kernel_vect_dim.has_value()
- ? kernel_shape.dimensions(*kernel_vect_dim)
- : 1;
- int64_t result_vect_size = result_vect_dim.has_value()
- ? result_shape.dimensions(*result_vect_dim)
- : 1;
- if (pad_to % input_vect_size != 0 || pad_to % kernel_vect_size != 0 ||
- pad_to % result_vect_size != 0) {
- // If the conv is already vectorized but pad_to is not a multiple of the
- // vector size, we choose not to pad. This is a weird case, because the
- // only useful vector sizes in cudnn (as of writing) are 4 and 32, and those
- // are also the only pad_to cases.
- return false;
- }
-
- // Check that cudnn supports our desired integer padding/vectorization.
- TF_ASSIGN_OR_RETURN(bool cudnn_supports,
- CudnnSupportsOptimizedIntegerConvolution(
- compute_capability, *conv, pad_to));
- if (!cudnn_supports) {
- return false;
- }
-
- // Pad the features to multiples of pad_to.
- {
- auto pad_dim = [&](Shape* s, int64_t dim, int64_t cur_vect_size) {
- CHECK_EQ(pad_to % cur_vect_size, 0);
- s->set_dimensions(
- dim, RoundUpTo<int64_t>(s->dimensions(dim), pad_to / cur_vect_size));
- };
-
- switch (kind) {
- case CudnnConvKind::kForward:
- CHECK_EQ(new_input_shapes.size(), 2);
- // Input feature maps
- pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
- input_vect_size);
- // Kernel for the input feature maps
- pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
- kernel_vect_size);
- // Kernel for the output feature maps. In the NCHW_VECT_C layout, only the
- // kernel input feature dim is vectorized, so this has cur_vect_size 1.
- pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
- /*cur_vect_size=*/1);
- // Output feature maps
- pad_dim(&new_result_shape, dnums.output_feature_dimension(),
- result_vect_size);
- break;
- case CudnnConvKind::kForwardActivation:
- CHECK(new_input_shapes.size() == 3 || new_input_shapes.size() == 4);
- // Input feature maps
- pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
- input_vect_size);
- // Kernel for the input feature maps
- pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
- kernel_vect_size);
- // Kernel for the output feature maps. In the NCHW_VECT_C layout, only the
- // kernel input feature dim is vectorized, so this has cur_vect_size 1.
- pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
- /*cur_vect_size=*/1);
-
- // Bias. This is a 1D vector of length output-depth, and it's unclear if
- // we *have* to pad it. But hey, we might as well. cur_vect_size 1
- // because NCHW_VECT_C doesn't apply here (there is no channels
- // dimension!).
- pad_dim(&new_input_shapes[2], /*dim=*/0, /*cur_vect_size=*/1);
-
- if (new_input_shapes.size() == 4) {
- // Optional side input. Same layout as result, so gets padded the
- // same.
- pad_dim(&new_input_shapes[3], dnums.output_feature_dimension(),
- result_vect_size);
- }
- // Output feature maps
- pad_dim(&new_result_shape, dnums.output_feature_dimension(),
- result_vect_size);
- break;
- default:
- CHECK(false);
- }
-
- // We won't pad a conv if doing so increases the total number of bytes in
- // the lhs, rhs, or result by a factor of this much or more.
- //
- // Note: It's important that this bound is exclusive. It's a performance
- // regression to pad and increase input/output size by 2x, so we only pad
- // strictly less than 2x.
- //
- // TODO(jlebar): This number was tuned experimentally, but without much
- // experimental evidence.
- static constexpr double kMaxBytesTouchedBound = 2;
-
- // Check that padding wouldn't increase the total bytes read/written by this
- // operation too much.
- auto check_size_increase = [&](const Shape& old_shape,
- const Shape& new_shape) {
- int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
- int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
- if (new_bytes < old_bytes * kMaxBytesTouchedBound) {
- return true;
- }
- VLOG(3)
- << "Not padding convolution; doing so would change input / result "
- "shape from "
- << ShapeUtil::HumanString(old_shape) << " to "
- << ShapeUtil::HumanString(new_shape) << ", a size increase of "
- << new_bytes / static_cast<double>(old_bytes)
- << "x >= " << kMaxBytesTouchedBound << "x: " << conv->ToString();
- return false;
- };
-
- // Check size increase only on the input and output. No need to check the
- // filter, since that's determined by the input/output. The bias (if
- // present) is tiny (1D array of length output-depth), so padding doesn't
- // matter. And the side-input, if present, is the same shape as the input.
- if (!check_size_increase(conv->operand(0)->shape(), new_input_shapes[0]) ||
- !check_size_increase(result_shape, new_result_shape)) {
- return false;
- }
- }
-
- bool changed = false;
- for (int64_t i = 0; i < conv->operand_count(); ++i) {
- changed |=
- !ShapeUtil::Equal(conv->operand(i)->shape(), new_input_shapes[i]);
- }
- if (!changed) {
- VLOG(3) << "No need to pad features of " << conv->ToString();
- }
-
- return changed;
-}
-
-absl::StatusOr<bool> CudnnPadForConvolutions::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
- // On Turing and later (sm75+), pad the feature dims to multiples of 32 if
- // possible, because that lets us use the fast int8x32 data type.
- bool local_changed = false;
- if (compute_capability_.IsAtLeast(7, 5)) {
- TF_ASSIGN_OR_RETURN(
- local_changed,
- ResolveAndPad(conv, absl::bind_front(
- TryResolvePaddedShapesForIntegerConvolution,
- 32, compute_capability_)));
- }
- if (!local_changed) {
- TF_ASSIGN_OR_RETURN(
- local_changed,
- ResolveAndPad(conv, absl::bind_front(
- TryResolvePaddedShapesForIntegerConvolution,
- 4, compute_capability_)));
- }
- changed |= local_changed;
- }
- if (compute_capability_.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
- for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
- TF_ASSIGN_OR_RETURN(
- bool local_changed,
- ResolveAndPad(conv, TryResolvePaddedShapesForTensorCore));
- changed |= local_changed;
- }
- }
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
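The integer-convolution path above boils down to one piece of arithmetic: when a feature dim is already vectorized (NCHW_VECT_C), it is rounded up so that size(features_dim) * size(vect_dim) becomes a multiple of pad_to, and the padding is abandoned if the input or output would grow to 2x or more (the bound is exclusive). A hedged standalone sketch of that arithmetic, with hypothetical helper names rather than the XLA API; the numbers mirror the PadInt8x4To32 and NoPadToInt8x32ExcessiveBlowup tests below.

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the rounding helper; illustration only.
int64_t RoundUpTo(int64_t value, int64_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

// Pad a (possibly vectorized) feature dim so that
// padded_features * vect_size is a multiple of pad_to.
int64_t PadVectorizedFeatures(int64_t features, int64_t vect_size,
                              int64_t pad_to) {
  return RoundUpTo(features, pad_to / vect_size);
}

int main() {
  // s8[10,20,30,41,4] with pad_to=32: 41 features (vect 4) pad up to 48,
  // and 168 kernel output features (vect 1) pad up to 192.
  std::cout << PadVectorizedFeatures(41, /*vect_size=*/4, /*pad_to=*/32)
            << "\n";  // 48
  std::cout << PadVectorizedFeatures(168, /*vect_size=*/1, /*pad_to=*/32)
            << "\n";  // 192
  // Growing 4 input channels to 32 is an 8x increase, which is not strictly
  // below the exclusive 2x bound, so the pass declines to pad.
  constexpr double kMaxBytesTouchedBound = 2;
  double growth = 32.0 / 4.0;
  std::cout << (growth < kMaxBytesTouchedBound ? "pad\n" : "skip\n");
  return 0;
}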
diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h b/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h
deleted file mode 100644
index be7fae2..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_
-#define XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// Two zero-paddings for CuDNN thunking are done in this transform: padding for
-// tensor cores and padding for integer convolutions. This transform also
- // adds slice instructions to remove unnecessary output features.
-class CudnnPadForConvolutions : public HloModulePass {
- public:
- explicit CudnnPadForConvolutions(se::CudaComputeCapability compute_capability)
- : compute_capability_(compute_capability) {}
-
- absl::string_view name() const override {
- return "cudnn_pad_for_convolutions";
- }
- // Run PadForConvolutions on the given module and return whether any change was made.
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const se::CudaComputeCapability compute_capability_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc b/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc
deleted file mode 100644
index 2bae239..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc
+++ /dev/null
@@ -1,456 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = xla::match;
-
-class CudnnPadForConvolutionsTest : public HloTestBase {};
-
-TEST_F(CudnnPadForConvolutionsTest, DoNotPadF16ForwardConvWhenGrouped) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = f16[704,48,1,49]{3,2,1,0} parameter(0)
- filter = f16[44,768,1,50]{3,2,1,0} parameter(1)
- ROOT result = (f16[1,128,48,768]{3,2,1,0}, u8[0]{0})
- custom-call(input, filter)
- , window={size=1x50 pad=0_0x64_64}
- , dim_labels=fb01_io01->01bf
- , feature_group_count=16
- , custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvInputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = f16[10,20,30,41] parameter(0)
- filter = f16[2,2,41,40] parameter(1)
- ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
-
- SCOPED_TRACE(module->ToString());
-
- EXPECT_THAT(
- root,
- GmockMatch(m::CustomCall(
- {kCudnnConvForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
- m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 48, 40}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvOutputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- output = f16[10,20,30,41] parameter(0)
- filter = f16[2,2,40,41] parameter(1)
- ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convBackwardInput"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(m::CustomCall(
- {kCudnnConvBackwardInputCallTarget},
- m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
- m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 40, 48}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvOutputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = f16[10,20,30,40] parameter(0)
- filter = f16[2,2,40,41] parameter(1)
- ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvForwardCallTarget}, m::Parameter(0),
- m::Pad(m::Parameter(1), m::Op())))),
- m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvInputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- output = f16[10,20,30,40] parameter(0)
- filter = f16[2,2,41,40] parameter(1)
- result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convBackwardInput"
- ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root,
- GmockMatch(m::GetTupleElement(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvBackwardInputCallTarget}, m::Parameter(0),
- m::Pad(m::Parameter(1), m::Op())))),
- m::Op()))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvInputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = f16[10,20,30,41] parameter(0)
- output = f16[10,20,30,40] parameter(1)
- result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convBackwardFilter"
- ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root,
- GmockMatch(m::GetTupleElement(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvBackwardFilterCallTarget},
- m::Pad(m::Parameter(0), m::Op()), m::Parameter(1)))),
- m::Op()))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvOutputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = f16[10,20,30,40] parameter(0)
- output = f16[10,20,30,41] parameter(1)
- result = (f16[2,2,40,41], u8[0]) custom-call(input, output),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convBackwardFilter"
- ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root,
- GmockMatch(m::GetTupleElement(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvBackwardFilterCallTarget}, m::Parameter(0),
- m::Pad(m::Parameter(1), m::Op())))),
- m::Op()))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInputFeatures3To4) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = f16[10,20,30,3] parameter(0)
- filter = f16[2,2,3,32] parameter(1)
- ROOT result = (f16[10,20,30,32], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
-
- SCOPED_TRACE(module->ToString());
- EXPECT_THAT(
- root,
- GmockMatch(m::CustomCall(
- {kCudnnConvForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 4}),
- m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 4, 32}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvInputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,41] parameter(0)
- filter = s8[2,2,41,40] parameter(1)
- ROOT result = (f32[10,20,30,40], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
-
- SCOPED_TRACE(module->ToString());
- EXPECT_THAT(
- root,
- GmockMatch(m::CustomCall(
- {kCudnnConvForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 44}),
- m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 44, 40}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvOutputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,40] parameter(0)
- filter = s8[2,2,40,41] parameter(1)
- ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvForwardCallTarget}, m::Parameter(0),
- m::Pad(m::Parameter(1), m::Op())))),
- m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInt8To32OnSm75) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,40] parameter(0)
- filter = s8[2,2,40,41] parameter(1)
- ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 64}),
- m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 64, 64})))),
- m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32OnSm70) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,40] parameter(0)
- filter = s8[2,2,40,41] parameter(1)
- ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvForwardCallTarget}, m::Parameter(0),
- m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
- m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32FloatOutputSm75) {
- // This test checks that the padding pass correctly calls
- // CudnnSupportsOptimizedIntegerConvolution() which should reject this
- // convolution because its output type is f32. The conv should instead be
- // padded for int8x4, which does support float outputs.
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,38] parameter(0)
- filter = s8[2,2,38,41] parameter(1)
- ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 40}),
- m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
- m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadInt8UnsupportedFilterTypeOutputSm75) {
- // This test checks that the padding pass correctly calls
- // CudnnSupportsOptimizedIntegerConvolution() which should reject this
- // convolution because the kernel type is f32.
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,38] parameter(0)
- filter = f32[2,2,38,41] parameter(1)
- ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadToInt8x32ExcessiveBlowup) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[128,4,48,48] parameter(0)
- filter = s8[64,4,3,3] parameter(1)
- ROOT result = (f32[128,64,48,48], u8[0]) custom-call(input, filter),
- window={size=3x3}, dim_labels=bf01_io01->bf01,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,41,4] parameter(0)
- filter = s8[2,2,41,4,168] parameter(1)
- ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Slice(m::GetTupleElement(
- m::CustomCall({kCudnnConvForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op())
- .WithShape(S8, {10, 20, 30, 48, 4}),
- m::Pad(m::Parameter(1), m::Op())
- .WithShape(S8, {2, 2, 48, 4, 192})))
- .WithShape(S8, {10, 20, 30, 48, 4})),
- m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32BiasActivation) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,41,4] parameter(0)
- filter = s8[2,2,41,4,168] parameter(1)
- bias = f32[10] parameter(2)
- side_input = s8[10,20,30,42,4] parameter(3)
- ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter, bias, side_input),
- window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
- custom_call_target="__cudnn$convBiasActivationForward"
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Slice(
- m::GetTupleElement(
- m::CustomCall(
- {kCudnnConvBiasActivationForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op())
- .WithShape(S8, {10, 20, 30, 48, 4}),
- m::Pad(m::Parameter(1), m::Op())
- .WithShape(S8, {2, 2, 48, 4, 192}),
- m::Pad(m::Parameter(2), m::Op()).WithShape(F32, {32}),
- m::Pad(m::Parameter(3), m::Op())
- .WithShape(S8, {10, 20, 30, 48, 4})))
- .WithShape(S8, {10, 20, 30, 48, 4})),
- m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest,
- PadIntFusedForwardConvInputAndOutputChannels) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule Test
-
- ENTRY %Test (input: s8[1,3,3,2], filter: s8[3,3,2,5], side_input: s8[1,3,3,5], bias: s8[5]) -> f32[1,3,3,5] {
- %input = s8[1,3,3,3]{3,2,1,0} parameter(0)
- %filter = s8[3,3,2,5]{3,2,1,0} parameter(1)
- %bias = s8[5]{0} parameter(3)
- %convert = f32[5]{0} convert(s8[5]{0} %bias)
- %side_input = f32[1,3,3,5]{3,2,1,0} parameter(2)
- %custom-call.1 = (f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,3,3,3]{3,2,1,0} %input, s8[3,3,2,5]{3,2,1,0} %filter, f32[5]{0} %convert, f32[1,3,3,5]{3,2,1,0} %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"activationMode\":\"2\",\"convResultScale\":1,\"sideInputScale\":1}"
- ROOT %get-tuple-element.1 = f32[1,3,3,5]{3,2,1,0} get-tuple-element((f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) %custom-call.1), index=0
- })")
- .value();
- EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::GetTupleElement(m::Tuple(
- m::Slice(m::GetTupleElement(m::CustomCall(
- {kCudnnConvBiasActivationForwardCallTarget},
- m::Pad(m::Parameter(0), m::Op()),
- m::Pad(m::Parameter(1), m::Op()),
- m::Pad(m::Convert(m::Parameter(3)), m::Op()),
- m::Pad(m::Parameter(2), m::Op())))),
- m::Op()))));
-}
-
-} // anonymous namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc b/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc
deleted file mode 100644
index c8f87f7..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc
+++ /dev/null
@@ -1,482 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_simplify_padding.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-namespace {
-namespace m = ::xla::match;
-
- // If exactly one index of `vals` is false, returns that index. If 0 or more
-// than one index is false, returns nullopt.
-std::optional<int64_t> FindFalseIndex(absl::Span<const bool> vals) {
- std::optional<int64_t> missing_dim;
- for (int i = 0; i < vals.size(); i++) {
- if (vals[i]) {
- continue;
- }
- if (missing_dim.has_value()) {
- VLOG(2) << "Multiple dimensions are missing from conv dnums; can't "
- "determine which is vect_c dimension";
- return std::nullopt;
- }
- missing_dim = i;
- }
- return missing_dim;
-}
-
-// Finds the vect_c dimension in the convolution's output.
-//
- // The vect_c dimension is the output dimension that's not mentioned in
- // `dnums`. If there is no such dimension, or more than one, returns nullopt.
-std::optional<int64_t> FindOutputVectCDim(HloInstruction* conv) {
- const ConvolutionDimensionNumbers& dnums =
- conv->convolution_dimension_numbers();
- int64_t num_dims = conv->shape().tuple_shapes(0).dimensions_size();
- absl::InlinedVector<bool, 5> seen_dims(num_dims);
- seen_dims[dnums.output_batch_dimension()] = true;
- seen_dims[dnums.output_feature_dimension()] = true;
- for (int64_t d : dnums.output_spatial_dimensions()) {
- seen_dims[d] = true;
- }
- return FindFalseIndex(seen_dims);
-}
-
-// Finds the vect_c dimension in the convolution's kernel.
-std::optional<int64_t> FindKernelVectCDim(HloInstruction* conv) {
- const ConvolutionDimensionNumbers& dnums =
- conv->convolution_dimension_numbers();
- int64_t num_dims = conv->operand(1)->shape().dimensions_size();
- absl::InlinedVector<bool, 5> seen_dims(num_dims);
- seen_dims[dnums.kernel_input_feature_dimension()] = true;
- seen_dims[dnums.kernel_output_feature_dimension()] = true;
- for (int64_t d : dnums.kernel_spatial_dimensions()) {
- seen_dims[d] = true;
- }
- return FindFalseIndex(seen_dims);
-}
-
-// Attempts to count the number of output features at the end of conv that are
-// guaranteed to be 0.
-//
-// This is the same as counting the number of values o at the end of the kernel
-// for which kernel[i,o,h,w] is 0 for all values i,h,w.
-std::optional<int64_t> NumTrailingZeroOutputFeatures(HloInstruction* conv) {
- const ConvolutionDimensionNumbers& dnums =
- conv->convolution_dimension_numbers();
- int64_t feature_dim = dnums.kernel_output_feature_dimension();
- const HloInstruction* weights = conv->operand(1);
-
- // If the filter is reordered for an int8x32 NCHW_VECT_C convolution, find the
- // original, un-reordered filter and check *it* for trailing zero output
- // features.
- auto backend_config = conv->backend_config<GpuBackendConfig>();
- if (backend_config.ok() &&
- backend_config->cudnn_conv_backend_config().reordered_int8_nchw_vect()) {
- VLOG(2) << "Matched int8x32 convolution with filter reordering";
-
- // Try to set weights to the original, un-reordered value.
- const HloInstruction *reshape, *transpose;
- bool matched =
- Match(weights, m::Reshape(m::Transpose(
- &transpose, m::Reshape(&reshape, m::Op(&weights)))));
-
- // Verify some properties of the reshape-transpose-reshape pattern.
- // If these don't hold, it means that some pass (e.g. constant folding)
- // has modified the filter, making it infeasible to get the original,
- // un-reordered value.
- if (!matched || feature_dim != 0 || transpose->shape().rank() != 8) {
- VLOG(2) << "The filter output feature dimension cannot be determined, as "
- "the reordering sequence is modified";
- return std::nullopt;
- }
-
- // Calculate the actual output feature dimension before the transpose.
- // For example: the input filter [I, O, H, W] will get reshaped to
- // [I/32, 8, 4, O/8, 4, 2, H, W], transposed in a way that is compatible
- // with cuDNN INT8x32_CONFIG convolutions (see 'cudnn_support_utils.h') and
- // reshaped again to [O, I/32, H, W, 32]. Although the output feature
- // dimension index is 0 in the reordered filter, we need its index in the
- // original shape (which is 1 in this example).
- const auto& transpose_dimensions =
- Cast<HloTransposeInstruction>(transpose)->dimensions();
-
- // Calculate combined dimensions size before the first appearing output
- // component [O/8], which appears in position 3 of the transpose.
- int64_t preceding_size = 1;
- for (int64_t i = transpose_dimensions.at(3) - 1; i >= 0; --i) {
- preceding_size *= reshape->shape().dimensions(i);
- }
-
- // Skip dimensions in front until the output features dimension is found.
- int64_t accumulated_size = 1;
- for (int64_t size : weights->shape().dimensions()) {
- if (accumulated_size < preceding_size) {
- accumulated_size *= size;
- ++feature_dim;
- } else {
- break;
- }
- }
- // Sanity check; if this condition doesn't hold, something is off.
- if (accumulated_size != preceding_size) {
- VLOG(2) << "Something is really wrong here, I give up";
- return std::nullopt;
- }
- VLOG(2) << "Computed output feature dimension: " << feature_dim;
- }
-
- VLOG(2) << "Computing NumTrailingZeroOutputFeatures of " << conv->ToString()
- << "\nwith weights " << weights->ToString();
- if (Match(weights, m::Pad(m::Op(), m::ConstantEffectiveScalar(0)))) {
- const PaddingConfig::PaddingConfigDimension& padding_config =
- weights->padding_config().dimensions(feature_dim);
- // The last N output feature weights are all 0.
- VLOG(2) << "Success: Weights is a pad; padding on output feature dim is "
- << padding_config.edge_padding_high();
- return padding_config.edge_padding_high();
- } else if (const HloInstruction * pad; Match(
- weights, m::Reshape(m::Pad(&pad, m::Op(),
- m::ConstantEffectiveScalar(0))))) {
- // Check that the reshape merely adds a VECT_C to the kernel input features.
- // That is, we reshape from [I,O,H,W] (in some order) to [I/k,k,O,H,W] (in
- // the same order) for some constant k (probably 32). Then check how much
- // the pad adds to the O dimension.
- std::optional<int64_t> vect_c_dim = FindKernelVectCDim(conv);
- if (!vect_c_dim.has_value()) {
- VLOG(2) << "fail: Can't find vect_c dimension in conv.";
- return std::nullopt;
- }
- if (*vect_c_dim != dnums.kernel_input_feature_dimension() + 1) {
- VLOG(2) << "fail: vect_c dim is in the wrong place; should be right "
- "after kernel input feature dims in conv.";
- return std::nullopt;
- }
- absl::InlinedVector<int64_t, 5> expected_pad_dim_sizes(
- weights->shape().dimensions().begin(),
- weights->shape().dimensions().end());
- expected_pad_dim_sizes[dnums.kernel_input_feature_dimension()] *=
- weights->shape().dimensions(*vect_c_dim);
- expected_pad_dim_sizes.erase(expected_pad_dim_sizes.begin() + *vect_c_dim);
- if (pad->shape().dimensions() != expected_pad_dim_sizes) {
- VLOG(2) << "fail: Reshape doesn't simply merge vect_c dimension into "
- "input features dim "
- << weights->ToString() << " but expected dims "
- << absl::StrJoin(expected_pad_dim_sizes, ",");
- return std::nullopt;
- }
-
- // If the filter dnums are e.g. [I,O,H,W] then after reshape they are
- // [I/k,k,O,H,W] and the new index of O is one greater than before the
- // reshape (which we know only adds the I/k and k dims, which we also know
- // are contiguous). OTOH if the O comes before the I in the original, then
- // the index of O doesn't change after the reshape.
- int64_t feature_dim_before_reshape = feature_dim;
- if (dnums.kernel_output_feature_dimension() >
- dnums.kernel_input_feature_dimension()) {
- feature_dim_before_reshape--;
- }
- const PaddingConfig::PaddingConfigDimension& padding_config =
- pad->padding_config().dimensions(feature_dim_before_reshape);
-
- // The last N output feature weights are all 0.
- VLOG(2) << "Success: Weights is a reshape of a pad; padding on output "
- "feature dim is "
- << padding_config.edge_padding_high();
- return padding_config.edge_padding_high();
- } else if (Match(weights, m::Constant())) {
- // Iterate backwards over `weights` to find the index of the first nonzero
- // value.
- //
- // TODO(jlebar): This may be slow, because it iterates over potentially the
- // whole constant and does a multi_index -> linear_index conversion for each
- // element. If necessary we could rewrite this by using linear indices, but
- // we'd have to be careful of the fact that literals can have arbitrary
- // layouts, so you can't just iterate over the literal's bytes.
- const Literal& lit = weights->literal();
- const auto& dims = weights->shape().dimensions();
- absl::InlinedVector<int64_t, 5> multi_index;
- for (int64_t dim : dims) {
- multi_index.push_back(dim - 1);
- }
- // This iterates through the literal with feature_dim as the most
- // major dimension looking for the final non-zero feature.
- auto decrement_multi_index = [&] {
- for (int i = 0; i < multi_index.size(); ++i) {
- if (i != feature_dim) {
- int64_t& idx = multi_index[i];
- --idx;
- if (idx == -1) {
- idx = dims[i] - 1;
- } else {
- return true;
- }
- }
- }
- int64_t& idx = multi_index[feature_dim];
- --idx;
- return idx != -1;
- };
- do {
- if (!lit.IsZero(multi_index)) {
- break;
- }
- } while (decrement_multi_index());
-
- // The iteration stops at the last feature index holding a non-zero value
- // (or at -1 if the constant is all zeros); the first trailing zero feature
- // is therefore the next index (or 0).
- int64_t first_trailing_zero_feature = multi_index[feature_dim] + 1;
-
- if (first_trailing_zero_feature == 0) {
- VLOG(2) << "Weights constant is entirely zero.";
- } else {
- VLOG(2) << "First nonzero index in weights constant is "
- << absl::StrJoin(multi_index, ",");
- }
- int64_t ret =
- std::max<int64_t>(0, weights->shape().dimensions(feature_dim) -
- first_trailing_zero_feature);
- VLOG(2) << "Success: weights is a constant; num zero trailing output "
- "features is "
- << ret;
- return ret;
- }
- return std::nullopt;
-}
-
-absl::StatusOr<bool> TrySimplifyPadding(HloInstruction* instr) {
- // Match one of the following patterns.
- // conv -> slice -> pad
- // conv -> reshape -> slice-> pad
- // conv -> transpose -> reshape -> slice -> pad
- //
- // where `pad` (the root of the pattern) is `instr`.
- HloInstruction* conv;
- HloInstruction* transpose = nullptr; // optional
- HloInstruction* reshape = nullptr; // optional
- HloInstruction* slice;
- HloInstruction* pad;
- auto conv_matcher = m::GetTupleElement(
- m::CustomCall(&conv).WithPredicate([](const HloInstruction* instr) {
- return instr->custom_call_target() == kCudnnConvForwardCallTarget ||
- instr->custom_call_target() ==
- kCudnnConvBiasActivationForwardCallTarget;
- }),
- 0);
- auto pad_matcher = m::Pad(m::Op(), m::ConstantEffectiveScalar(0));
- if (!MatchAndLogIfFailed(instr, "conv-slice-pad",
- m::Pad(&pad, m::Slice(&slice, conv_matcher),
- m::ConstantEffectiveScalar(0)),
- VLOG_IS_ON(3), pad_matcher) &&
- !MatchAndLogIfFailed(
- instr, "conv-reshape-slice-pad",
- m::Pad(&pad, m::Slice(&slice, m::Reshape(&reshape, conv_matcher)),
- m::ConstantEffectiveScalar(0)),
- VLOG_IS_ON(3), pad_matcher) &&
- !MatchAndLogIfFailed(
- instr, "conv-transpose-reshape-slice-pad",
- m::Pad(&pad,
- m::Slice(&slice,
- m::Reshape(&reshape,
- m::Transpose(&transpose, conv_matcher))),
- m::ConstantEffectiveScalar(0)),
- VLOG_IS_ON(3), pad_matcher)) {
- return false;
- }
-
- VLOG(2) << "Found pattern to attempt to simplify:\n"
- << "conv: " << conv->ToString() //
- << "\ntranspose: "
- << (transpose != nullptr ? transpose->ToString() : "(null)")
- << "\nreshape: "
- << (reshape != nullptr ? reshape->ToString() : "(null)")
- << "\nslice: " << slice->ToString() //
- << "\npad: " << pad->ToString();
-
- // Now check that we can merge the slice into the pad, because the slice is
- // slicing off elements that we know are 0 and the pad is just adding those 0s
- // back.
- //
- // First, we have to check whether any of the output features at the end of
- // the conv are known to be 0.
- std::optional<int64_t> num_known_zero_output_features =
- NumTrailingZeroOutputFeatures(conv);
- if (!num_known_zero_output_features.has_value() ||
- *num_known_zero_output_features == 0) {
- VLOG(2) << "fail: Didn't find any known-zero output features";
- return false;
- }
-
- // We now know that the last num_known_zero_output_features output features
- // of the conv are zero. Check whether the optional-reshape +
- // optional-transpose + slice + pad combination merely removes and re-adds
- // those zero features. If so, we can merge the slice into the pad.
- const auto& dnums = conv->convolution_dimension_numbers();
- int64_t output_feature_dim;
- if (reshape == nullptr) {
- CHECK_EQ(transpose, nullptr);
- output_feature_dim = dnums.output_feature_dimension();
- } else {
- std::optional<int64_t> vect_c_dim_before_transpose =
- FindOutputVectCDim(conv);
- if (!vect_c_dim_before_transpose.has_value()) {
- VLOG(2) << "Couldn't find vect_c output dim in conv.";
- return false;
- }
-
- // If there's no transpose, check that the vect_c dim is immediately after
- // the feature dim. OTOH if there is a transpose, check that the transpose
- // moves the vect_c dim immediately after the feature dim.
- int64_t feature_dim_after_transpose;
- int64_t vect_c_dim_after_transpose;
- if (transpose == nullptr) {
- feature_dim_after_transpose = dnums.output_feature_dimension();
- vect_c_dim_after_transpose = *vect_c_dim_before_transpose;
- } else {
- const auto& transpose_dims = transpose->dimensions();
- feature_dim_after_transpose = std::distance(
- transpose->dimensions().begin(),
- absl::c_find(transpose_dims, dnums.output_feature_dimension()));
- vect_c_dim_after_transpose = std::distance(
- transpose->dimensions().begin(),
- absl::c_find(transpose_dims, *vect_c_dim_before_transpose));
- }
- if (vect_c_dim_after_transpose != feature_dim_after_transpose + 1) {
- VLOG(2) << "fail: after transpose (if present), vect_c dim must appear "
- "immediately after output feature dim: Computed "
- "vect_c_dim_after_transpose to be "
- << vect_c_dim_after_transpose;
- return false;
- }
-
- // Now check that the reshape merges the feature + vect_c dims and
- // doesn't do anything else.
- absl::InlinedVector<int64_t, 5> expected_reshape_dim_sizes(
- reshape->operand(0)->shape().dimensions().begin(),
- reshape->operand(0)->shape().dimensions().end());
- expected_reshape_dim_sizes[feature_dim_after_transpose] *=
- expected_reshape_dim_sizes[vect_c_dim_after_transpose];
- expected_reshape_dim_sizes.erase(expected_reshape_dim_sizes.begin() +
- vect_c_dim_after_transpose);
- if (reshape->shape().dimensions() != expected_reshape_dim_sizes) {
- VLOG(2) << "fail: Reshape doesn't merge vect_c with feature dimension.";
- return false;
- }
-
- output_feature_dim = feature_dim_after_transpose;
- }
-
- // Check that `slice` slices only the output feature dimension.
- if (!absl::c_all_of(slice->slice_starts(), [](auto v) { return v == 0; }) ||
- !absl::c_all_of(slice->slice_strides(), [](auto v) { return v == 1; })) {
- VLOG(2) << "fail: Slice doesn't start at the front or has stride != 1.";
- return false;
- }
-
- // We're only allowed to slice the feature dim.
- for (int64_t dim = 0; dim < slice->slice_limits().size(); dim++) {
- if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
- (dim != output_feature_dim &&
- slice->slice_limits(dim) !=
- slice->operand(0)->shape().dimensions(dim))) {
- VLOG(2) << "fail: Slice removes something other than the features dim.";
- return false;
- }
- }
- int64_t num_sliced_from_feature_dim =
- slice->operand(0)->shape().dimensions(output_feature_dim) -
- slice->slice_limits(output_feature_dim);
-
- // If we slice off more than the known-zero output features, then we need to
- // keep the slice -- it's "doing something".
- if (num_sliced_from_feature_dim > *num_known_zero_output_features) {
- VLOG(2) << "fail: Slice removes " << num_sliced_from_feature_dim
- << " features from the conv, but only "
- << *num_known_zero_output_features
- << " features in the conv are known to be zero.";
- return false;
- }
-
- // Check if we can merge the slice into the pad.
- if (pad->padding_config().dimensions(output_feature_dim).interior_padding() !=
- 0) {
- VLOG(2)
- << "fail: Can't merge slice into pad because pad adds interior padding "
- "in feature dimension.";
- return false;
- }
-
- // Okay! If we got here, it's legal to fold the slice into the pad. We pad
- // less, because we know that the sliced-off elements are all 0. Ideally, the
- // pad becomes a nop and gets eliminated by algsimp later.
- VLOG(1) << "Eliminating " << num_sliced_from_feature_dim
- << " elements of padding from conv " << conv->name();
- PaddingConfig new_padding_config = pad->padding_config();
- PaddingConfig::PaddingConfigDimension* new_pad_feature_dim =
- new_padding_config.mutable_dimensions(output_feature_dim);
- // This is safe even if the new edge_padding_high is negative -- negative
- // padding is allowed.
- new_pad_feature_dim->set_edge_padding_high(
- new_pad_feature_dim->edge_padding_high() - num_sliced_from_feature_dim);
- TF_ASSIGN_OR_RETURN(HloInstruction * new_pad,
- MakePadHlo(slice->mutable_operand(0),
- pad->mutable_operand(1), new_padding_config));
- TF_RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad));
- return true;
-}
-
-} // anonymous namespace
-
-absl::StatusOr<bool> CudnnSimplifyPadding::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- TF_ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr));
- changed |= c;
- }
- }
- return changed;
-}
-
-} // namespace xla::gpu
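The decision made by TrySimplifyPadding above is also easy to state numerically: if the conv has N trailing output features known to be zero, the slice removes at most N of them, and the following pad has no interior padding on that dim, then the pad's edge_padding_high simply shrinks by the amount sliced (negative values are legal in HLO). A standalone sketch of that bookkeeping, with hypothetical names rather than the real pass API; the numbers mirror the PaddedWeights test below (4 known-zero features, slice removes 4, pad of 5 becomes a pad of 1).

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical sketch of the slice-into-pad merge decision; not the XLA API.
std::optional<int64_t> MergedEdgePaddingHigh(int64_t known_zero_output_features,
                                             int64_t num_sliced_from_feature_dim,
                                             int64_t pad_edge_padding_high,
                                             int64_t pad_interior_padding) {
  // The slice may only remove features we can prove are zero.
  if (num_sliced_from_feature_dim > known_zero_output_features) {
    return std::nullopt;
  }
  // Interior padding on the feature dim would shift element positions, so the
  // merge is only valid without it.
  if (pad_interior_padding != 0) {
    return std::nullopt;
  }
  // Pad less by the amount sliced off; negative edge padding is legal in HLO.
  return pad_edge_padding_high - num_sliced_from_feature_dim;
}

int main() {
  std::optional<int64_t> merged = MergedEdgePaddingHigh(
      /*known_zero_output_features=*/4, /*num_sliced_from_feature_dim=*/4,
      /*pad_edge_padding_high=*/5, /*pad_interior_padding=*/0);
  std::cout << (merged.has_value() ? std::to_string(*merged) : "no merge")
            << "\n";  // prints 1
  return 0;
}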
diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h b/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h
deleted file mode 100644
index 5811d26..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
-#define XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// Simplifies or eliminates padding introduced by CudnnPadForConvolutions and
-// CudnnVectorizeConvolutions.
-//
-// CudnnVectorizeConvolutions will generate code that does the following.
-// - pad input and output features to a multiple of 32 (or 4),
-// - reshape input from [N,C,H,W] to [N,C/32,H,W,32] and reshape kernel from
-// [I,O,H,W] to [I/32,32,O,H,W],
-// - run the conv,
-// - reshape output from [N,C/32,H,W,32] to [N,C,H,W], and finally
-// - slice off the padding on the C channel.
-//
-// But if this is followed by another convolution (very common), then the slice
-// is immediately followed by another pad. This may be redundant; we know that
-// the trailing channels sliced off from the first conv are 0.
-//
-// Ideally we can eliminate the whole reshape+slice+pad+reshape sequence between
-// the two convolutions.
-//
-// Specifically, this pass tries to merge the slice at the end of the sequence
-// above into the pad from the next convolution (when we can prove that the
-// sliced-off elements are all 0). We then rely on algsimp to remove the pad if
-// it's a nop and then to merge and eliminate the remaining reshapes.
-//
-// This pass should run after CudnnVectorizeConvolutions and there should be no
-// simplification passes in between that modify the reshape-transpose-reshape
-// introduced by int8x32 convolution filter reordering.
-class CudnnSimplifyPadding : public HloModulePass {
- public:
- CudnnSimplifyPadding() = default;
-
- absl::string_view name() const override { return "cudnn_simplify_padding"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc
deleted file mode 100644
index a0e527a..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc
+++ /dev/null
@@ -1,771 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_simplify_padding.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/functional/function_ref.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "xla/literal.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/call_inliner.h"
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-#include "xla/service/hlo_pass_fix.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/reshape_mover.h"
-#include "xla/service/tuple_simplifier.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class CudnnSimplifyPaddingTest : public HloTestBase {
- protected:
- // Runs the whole relevant pass pipeline starting at CudnnPadForConvolutions.
- // This lets us test that we're matching the patterns that actually get
- // generated by padding+vectorization.
- absl::StatusOr<bool> RunEndToEnd(std::pair<int, int> compute_capability,
- HloModule* module) {
- se::CudaComputeCapability cc{compute_capability.first,
- compute_capability.second};
-
- TF_RETURN_IF_ERROR(
- RunHloPass(CudnnPadForConvolutions(cc), module).status());
-
- TF_RETURN_IF_ERROR(
- RunHloPass(CudnnVectorizeConvolutions(
- cc, /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0}),
- module)
- .status());
- VLOG(1) << "after vectorizing convs:\n" << module->ToString();
-
- TF_RETURN_IF_ERROR(RunHloPass(CallInliner(), module).status());
- VLOG(1) << "after inliner:\n" << module->ToString();
-
- TF_RETURN_IF_ERROR(RunHloPass(TupleSimplifier(), module).status());
- VLOG(1) << "after tuple simplifier:\n" << module->ToString();
-
- TF_ASSIGN_OR_RETURN(bool changed,
- RunHloPass(CudnnSimplifyPadding(), module));
- VLOG(1) << "after simplify_padding:\n" << module->ToString();
-
- {
- // reshape-mover expects to be run alongside algsimp.
- HloPassFix<HloPassPipeline> pipeline("reshape-mover and algsimp");
- pipeline.AddPass<ReshapeMover>();
- pipeline.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions());
- TF_RETURN_IF_ERROR(RunHloPass(pipeline, module).status());
- }
- VLOG(1) << "after reshape mover + algsimp:\n" << module->ToString();
-
- return changed;
- }
-
- absl::StatusOr<bool> RunJustThisPass(HloModule* module) {
- TF_ASSIGN_OR_RETURN(bool changed,
- RunHloPass(CudnnSimplifyPadding(), module));
- VLOG(1) << "after simplify_padding:\n" << module->ToString();
-
- // I know the name says "just this pass", but you really want algsimp too,
- // otherwise the resulting patterns are ugly/hard to match.
- TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
- AlgebraicSimplifierOptions()),
- module)
- .status());
- return changed;
- }
-};
-
-void ExpectOnlyPadsOneDim(int64_t dim, int64_t padding_high,
- const PaddingConfig& p) {
- SCOPED_TRACE(p.DebugString());
- for (int i = 0; i < p.dimensions_size(); ++i) {
- SCOPED_TRACE(absl::StrCat("dimension ", i));
- EXPECT_EQ(p.dimensions(i).edge_padding_low(), 0);
- if (i == dim) {
- EXPECT_EQ(p.dimensions(i).edge_padding_high(), padding_high);
- } else {
- EXPECT_EQ(p.dimensions(i).edge_padding_high(), 0);
- }
- }
-}
-
-template <typename NativeT>
-void SetConstantValue(
- HloInstruction* instr,
- absl::FunctionRef<NativeT(absl::Span<const int64_t>, NativeT)> value_fn) {
- Literal new_literal = instr->literal().Clone();
- new_literal.MutableEachCell<int8_t>(value_fn);
- TF_EXPECT_OK(instr->parent()->ReplaceWithNewInstruction(
- instr, HloInstruction::CreateConstant(std::move(new_literal))));
-}
-
-TEST_F(CudnnSimplifyPaddingTest, EndToEnd) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- conv1 = (s8[10,20,30,190], u8[0]) custom-call(
- s8[10,20,30,63] parameter(0), s8[3,5,63,190] parameter(1),
- f32[10] parameter(2), s8[10,20,30,190] parameter(3)),
- window={size=3x5}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convBiasActivationForward"
- conv1_result = get-tuple-element(conv1), index=0
- ROOT conv2 = (s8[10,20,30,29], u8[0]) custom-call(
- conv1_result, s8[3,5,190,29] parameter(4),
- f32[10] parameter(5), s8[10,20,30,29] parameter(6)),
- window={size=3x5}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convBiasActivationForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- // conv2 should be fed directly from conv1, without any intervening
- // reshapes/pads.
- EXPECT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Slice(m::Reshape(m::GetTupleElement(m::CustomCall(
- {"__cudnn$convBiasActivationForward"},
- m::GetTupleElement(
- m::CustomCall({"__cudnn$convBiasActivationForward"}), 0),
- m::Op(), m::Op(), m::Op())))),
- m::Op())));
-}
-
-TEST_F(CudnnSimplifyPaddingTest, EndToEndNCHW) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- conv1 = (s8[1,64,480,400], u8[0]) custom-call(
- s8[1,112,480,400] parameter(0), s8[3,3,112,64] parameter(1),
- f32[64] parameter(2)),
- window={size=3x3}, dim_labels=bf01_01io->bf01,
- custom_call_target="__cudnn$convBiasActivationForward"
- conv1_result = get-tuple-element(conv1), index=0
- convert = f32[1,64,480,400] convert(conv1_result)
- constant = f32[] constant(0.349002093)
- broadcast = f32[1,64,480,400] broadcast(constant)
- ROOT multiply = f32[1,64,480,400] multiply(convert, broadcast)
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
-  // The CudnnSimplifyPadding pass itself does not change anything here.
- EXPECT_FALSE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- // The reshape introduced by CudnnVectorizeConvolutions should have been moved
- // to the root.
- EXPECT_THAT(root, GmockMatch(m::Reshape(m::Multiply())));
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedWeights) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
- const HloInstruction* pad = nullptr;
- ASSERT_THAT(root,
- GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
- m::ConstantScalar(0))));
-
- ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
-}
-
-// This is similar to PaddedWeights, except that only 3 elements of the weights
-// are padded with 0 while we slice off 4 elements from the output features. As
-// a result, not all of the sliced elements are 0, and we can't merge the slice
-// into the pad that follows.
-TEST_F(CudnnSimplifyPaddingTest, PaddedWeightsNotPaddedEnough) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_3
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNCHW) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
- weights = s8[2,32,64,3,3] reshape(weights_p)
- conv = (s8[10,2,32,10,10], u8[0]) custom-call(
- s8[10,2,32,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=bf?01_i?o01->bf?01,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_5x0_0x0_0
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
- const HloInstruction* pad = nullptr;
- ASSERT_THAT(
- root, GmockMatch(
- m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
- m::ConstantScalar(0))));
-
- ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/1, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNHWC) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights_p = pad(s8[3,3,64,60] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
- weights = s8[3,3,2,32,64] reshape(weights_p)
- conv = (s8[10,10,10,2,32], u8[0]) custom-call(
- s8[10,10,10,2,32] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f?_01i?o->b01f?,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,60] slice(s8[10,10,10,64] reshape(conv_result)), slice={[0:10], [0:10], [0:10], [0:60]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
- const HloInstruction* pad = nullptr;
- ASSERT_THAT(
- root, GmockMatch(
- m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
- m::ConstantScalar(0))));
-
- ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedTransposedAndReshapedOutput) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
- weights = s8[2,32,64,3,3] reshape(weights_p)
- conv = (s8[10,2,10,10,32], u8[0]) custom-call(
- s8[10,2,10,10,32] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=bf01?_i?o01->bf01?,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
- slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
- const HloInstruction* pad = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Pad(
- &pad,
- m::Reshape(m::Transpose(m::GetTupleElement(m::CustomCall(), 0))),
- m::ConstantScalar(0))));
-
- ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/2, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeight) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(0),
- s8[3,3,10,10] constant({...})
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- // Set the constant's value. (The HLO text above sets it to all 0s.)
- {
- HloInstruction* weights = nullptr;
- ASSERT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
- m::Op(), m::Constant(&weights)))),
- m::Op())));
- SetConstantValue<int8_t>(
- weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
- if (dims[3] < 6) return 1;
- return 0;
- });
- }
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
- const HloInstruction* pad = nullptr;
- ASSERT_THAT(root,
- GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
- m::ConstantScalar(0))));
-
- ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeightIsNotLargeEnough) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(0),
- s8[3,3,10,10] constant({...})
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- // Set the constant's value. (The HLO text above sets it to all 0s.)
- {
- HloInstruction* weights = nullptr;
- ASSERT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
- m::Op(), m::Constant(&weights)))),
- m::Op())));
- SetConstantValue<int8_t>(
- weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
- // The sixth feature dimension (i.e. index 5) is only partially 0.
- if (dims[3] < 5 /*|| (dims[3] == 5 && dims[2] > 1)*/) return 0;
- return 1;
- });
- }
-
- // Some of the values sliced off are not 0, so we can't merge the slice into
- // the pad.
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, ReshapeDoesntMergeVectCDim) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
- weights = s8[2,64,3,3,32] reshape(weights_p)
- conv = (s8[10,2,10,10,32], u8[0]) custom-call(
- s8[10,2,10,10,32] parameter(1),
- weights_p
- ), window={size=3x3}, dim_labels=bf01?_io01?->bf01?,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInOutput) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
- weights = s8[2,64,3,3,32] reshape(weights_p)
- conv = (s8[10,2,10,10,4,8], u8[0]) custom-call(
- s8[10,2,10,10,32] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=bf01?_io01?->bf01??,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- conv_transposed = s8[10,2,4,8,10,10] transpose(conv_result), dimensions={0,1,4,5,2,3}
- slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInKernel) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
- weights = s8[2,64,3,3,4,8] reshape(weights_p)
- conv = (s8[10,2,10,10,32], u8[0]) custom-call(
- s8[10,2,10,10,32] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=bf01?_io01??->bf01?,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
- slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginning) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,9,10,6] slice(conv_result), slice={[0:10], [1:10], [0:10], [0:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginningOfFeatureDim) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,5] slice(conv_result), slice={[0:10], [0:10], [0:10], [1:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceHasStride) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,3] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6:2]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PadAddsInteriorPadding) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5_1
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceMoreElementsThanPad) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
- conv = (s8[10,10,10,10], u8[0]) custom-call(
- s8[10,10,10,10] parameter(1),
- weights
- ), window={size=3x3}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- conv_result = get-tuple-element(conv), index=0
- slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
- ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_2
- }
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
- const HloInstruction* slice = nullptr;
- // The pass creates a pad with negative padding; this is simplified by algsimp
- // into a slice.
- ASSERT_THAT(root, GmockMatch(m::Slice(
- &slice, m::GetTupleElement(m::CustomCall(), 0))));
- for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
- SCOPED_TRACE(i);
- EXPECT_EQ(slice->slice_starts(i), 0);
- EXPECT_EQ(slice->slice_strides(i), 1);
- if (i != 3) {
- EXPECT_EQ(slice->slice_limits(i), 10);
- } else {
- EXPECT_EQ(slice->slice_limits(i), 8);
- }
- }
-}
-
-TEST_F(CudnnSimplifyPaddingTest, NoChangeOnNonTrivialConstants) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule jit_outer
-
-ENTRY main.26 {
- reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
- constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
- { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- }, {
- { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- }, {
- { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
- cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
- get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
- slice.2 = f32[1,5,1,12]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:1], [0:12]}
- constant.0 = f32[] constant(0)
- ROOT pad.1 = f32[1,5,3,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x2_0x0_0
-}
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, NoChangeOnComplexSlices) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule jit_outer
-
-ENTRY main.26 {
- reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
- constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
- { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- }, {
- { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- }, {
- { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
- cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
- get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
- slice.2 = f32[1,5,5,4]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [2:6]}
- constant.0 = f32[] constant(0)
- ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_8
-}
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, ScanOrderFeatureDimLast) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule jit_outer
-
-ENTRY main.26 {
- reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
- constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
- { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- }, {
- { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 } },
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
- }, {
- { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
- { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
- cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
- get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
- slice.2 = f32[1,5,5,6]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [0:6]}
- constant.0 = f32[] constant(0)
- ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_6
-}
- )")
- .value();
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputFirst) {
- // Tests feature-dimension calculation for the filter-reordering transpose (oi01).
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
- s8[1,112,80,80] parameter(0), s8[63,112,3,3] parameter(1)),
- window={size=3x3}, dim_labels=bf01_oi01->bf01,
- custom_call_target="__cudnn$convForward"
- gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
- const.0 = s8[] constant(0)
- ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputLast) {
- // Tests feature-dimension calculation for the filter-reordering transpose (01io).
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
- s8[1,112,80,80] parameter(0), s8[3,3,112,63] parameter(1)),
- window={size=3x3}, dim_labels=bf01_01io->bf01,
- custom_call_target="__cudnn$convForward"
- gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
- const.0 = s8[] constant(0)
- ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-}
-
-} // anonymous namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc
deleted file mode 100644
index 3846f01..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc
+++ /dev/null
@@ -1,650 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/client/xla_builder.h"
-#include "xla/client/xla_computation.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/cudnn_support_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Finds convolutions that this pass may be able to transform, namely int8_t
-// cudnn forward or forward-bias-activation convolutions.
-//
-// cudnn as of v8.2 supports the following data type combinations for forward
-// and forward-bias-activation convolutions. We have to make sure we only
-// vectorize to one of these supported configs.
-//
-// in out
-// int8x1 int8x1
-// int8x1 float
-// int8x1 int32_t
-//
-// int8x4 int8x4
-// int8x4 float
-//
-// int8x32 int8x32
-//
-// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward
-//
-// For now we restrict ourselves to only the int8xN -> int8xN cases. We could
-// allow the int8x4 -> float case in the future if desirable.
-static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
- HloComputation* comp) {
- std::vector<HloCustomCallInstruction*> convs;
- for (HloInstruction* instr : comp->instructions()) {
- if (instr->opcode() != HloOpcode::kCustomCall ||
- (instr->custom_call_target() != kCudnnConvForwardCallTarget &&
- instr->custom_call_target() !=
- kCudnnConvBiasActivationForwardCallTarget) ||
- instr->operand_count() < 2) {
- continue;
- }
-
- PrimitiveType input_ty = instr->operand(0)->shape().element_type();
- PrimitiveType output_ty = instr->shape().tuple_shapes(0).element_type();
- if (input_ty == output_ty && (input_ty == S8 || input_ty == U8)) {
- convs.push_back(Cast<HloCustomCallInstruction>(instr));
- }
- }
- return convs;
-}
-
-// Converts an XlaBuilder into an HloComputation in the same module as
-// `sibling_computation`.
-//
-// Yes, we serialize/deserialize as a proto. :)
-static absl::StatusOr<HloComputation*> BuilderToHloComputation(
- XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
- TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
- TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
- HloModuleConfig config(program_shape);
- TF_ASSIGN_OR_RETURN(auto new_module,
- HloModule::CreateFromProto(comp.proto(), config));
-
- HloModule* dest_module = sibling_computation->parent();
- HloCloneContext context(dest_module);
- return dest_module->DeepCloneComputation(new_module->entry_computation(),
- &context);
-}
-
-// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
-// after `dim`.
-static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
- XlaBuilder& b = *instr.builder();
- Shape shape = b.GetShape(instr).value();
- DimensionVector new_dims(shape.dimensions().begin(),
- shape.dimensions().end());
- CHECK_EQ(new_dims[dim] % vect_size, 0);
- new_dims[dim] /= vect_size;
- new_dims.insert(new_dims.begin() + dim + 1, vect_size);
- return Reshape(instr, new_dims);
-}
-
-// Reshapes `shape` so that there's an extra dimension of size `vect_size` right
-// after `dim`.
-//
-// For example given shape=s8[10, 32, 20], dim=1, vect_size=4, returns
-// s8[10, 8, 4, 20].
-static Shape SplitShapeAtDim(Shape shape, int64_t dim, int64_t vect_size) {
- DimensionVector new_dims(shape.dimensions().begin(),
- shape.dimensions().end());
- CHECK_EQ(new_dims[dim] % vect_size, 0);
- new_dims[dim] /= vect_size;
- new_dims.insert(new_dims.begin() + dim + 1, vect_size);
- return ShapeUtil::MakeShape(shape.element_type(), new_dims);
-}
-
-// Transposes dimension `src` to right before `dst`.
-static XlaOp MoveDim(XlaOp instr, int64_t src, int64_t dst) {
- XlaBuilder& b = *instr.builder();
- int64_t rank = b.GetShape(instr)->dimensions_size();
-
- DimensionVector idxs(rank);
- absl::c_iota(idxs, 0);
- if (src < dst) {
- idxs.insert(idxs.begin() + dst, src);
- idxs.erase(idxs.begin() + src);
- } else {
- idxs.erase(idxs.begin() + src);
- idxs.insert(idxs.begin() + dst, src);
- }
- return Transpose(instr, idxs);
-}
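// Illustrative sketch of the permutation MoveDim builds (hypothetical values
// rank = 5, src = 1, dst = 4; not taken from this file):
//
//   DimensionVector idxs(5);
//   absl::c_iota(idxs, 0);             // {0, 1, 2, 3, 4}
//   idxs.insert(idxs.begin() + 4, 1);  // {0, 1, 2, 3, 1, 4}
//   idxs.erase(idxs.begin() + 1);      // {0, 2, 3, 1, 4}
//
// Transpose(instr, {0, 2, 3, 1, 4}) then places old dimension 1 immediately
// before old dimension 4.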
-
-// Reshapes instr so that dimension `vect_dim` has size `vect_size`, by stealing
-// elements from `dim`.
-//
-// Requires that this is possible without merging and re-splitting the two
-// dimensions. I.e. there should be some amount of dim that we can "split off"
-// and add to vect_dim to get it to have size vect_size.
-static XlaOp RevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
- int64_t vect_size) {
- XlaBuilder& b = *instr.builder();
- Shape shape = b.GetShape(instr).value();
- auto size = [&](int64_t d) { return shape.dimensions(d); };
-
- CHECK_LE(size(vect_dim), vect_size);
- CHECK_EQ(vect_size % size(vect_dim), 0);
-
- int64_t split_factor = vect_size / size(vect_dim);
- CHECK_EQ(size(dim) % split_factor, 0);
-
- // Split dim into [C, split_factor].
- instr = SplitAtDim(instr, dim, split_factor);
-
- // SplitAtDim may have added a dimension before vect_dim.
- if (vect_dim > dim) {
- vect_dim++;
- }
-
- // Move the split_factor dimension to right before vect_dim.
- instr = MoveDim(instr, dim + 1, vect_dim);
-
- // Moving the split_factor dimension may have *removed* a dimension before
- // vect_dim.
- if (vect_dim > dim) {
- vect_dim--;
- }
-
- // Collapse the split_factor dimension into vect_dim.
- return Collapse(instr, {vect_dim, vect_dim + 1});
-}
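// Illustrative walk-through (hypothetical shapes; not taken from this file):
// widening an int8x4 layout to int8x32 on s8[10,8,20,30,4] with dim = 1 (the
// C/k dim), vect_dim = 4 (the k dim), and vect_size = 32:
//
//   split_factor = 32 / 4 = 8
//   SplitAtDim(instr, 1, 8)   // -> s8[10,1,8,20,30,4], vect_dim becomes 5
//   MoveDim(instr, 2, 5)      // -> s8[10,1,20,30,8,4], vect_dim becomes 4
//   Collapse(instr, {4, 5})   // -> s8[10,1,20,30,32]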
-
-// Inverse of RevectorizeInstr. Reshapes instr so that dimension `vect_dim` has
-// size `vect_size`, moving excess elements into `dim`.
-static XlaOp UnrevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
- int64_t orig_vect_size) {
- XlaBuilder& b = *instr.builder();
- Shape shape = b.GetShape(instr).value();
- auto size = [&](int64_t d) { return shape.dimensions(d); };
-
- CHECK_GE(size(vect_dim), orig_vect_size);
- CHECK_EQ(size(vect_dim) % orig_vect_size, 0);
-
- // Split vect_dim into [C, orig_vect_size].
- instr = SplitAtDim(instr, vect_dim, orig_vect_size);
-
- // SplitAtDim may have added a dimension before dim.
- if (dim > vect_dim) {
- dim++;
- }
-
- // Move the `C` dimension to right after `dim`. Take into account that
- // SplitAtDim may have added a dimension before dim.
- instr = MoveDim(instr, vect_dim, dim + 1);
-
- // MoveDim may have *removed* a dimension before dim.
- if (dim > vect_dim) {
- dim--;
- }
-
- // Collapse the `C` and `dim` dimensions.
- return Collapse(instr, {dim, dim + 1});
-}
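// Continuing the sketch above in reverse (hypothetical shapes): restoring the
// int8x4 layout from s8[10,1,20,30,32] with dim = 1, vect_dim = 4, and
// orig_vect_size = 4:
//
//   SplitAtDim(instr, 4, 4)   // -> s8[10,1,20,30,8,4]
//   MoveDim(instr, 4, 2)      // -> s8[10,1,8,20,30,4]
//   Collapse(instr, {1, 2})   // -> s8[10,8,20,30,4]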
-
-// Adds a vectorized-feature dimension to dnums right after the current feature
-// dimension.
-//
-// ConvolutionDimensionNumbers doesn't represent the vectorized-feature
-// dimension explicitly, because the whole concept of a vectorized-feature
-// dimension is specific to cudnn. Rather, the vectorized-feature dimension is
-// implicit; it's the first dimension that *doesn't* appear in the dnums.
-//
-// This function "makes room" in dnums for the new vectorized dimension by
-// incrementing any dimensions which appear after the feature dim. The implicit
-// vector dim is then in this "empty" spot.
-static ConvolutionDimensionNumbers VectorizeDnums(
- ConvolutionDimensionNumbers dnums, bool reordered_filter) {
- int64_t input_vect_dim = dnums.input_feature_dimension();
- if (dnums.input_batch_dimension() > input_vect_dim) {
- dnums.set_input_batch_dimension(dnums.input_batch_dimension() + 1);
- }
- for (int64_t& d : *dnums.mutable_input_spatial_dimensions()) {
- if (d > input_vect_dim) {
- ++d;
- }
- }
-
- if (!reordered_filter) {
- int64_t kernel_vect_dim = dnums.kernel_input_feature_dimension();
- if (dnums.kernel_output_feature_dimension() > kernel_vect_dim) {
- dnums.set_kernel_output_feature_dimension(
- dnums.kernel_output_feature_dimension() + 1);
- }
- for (int64_t& d : *dnums.mutable_kernel_spatial_dimensions()) {
- if (d > kernel_vect_dim) {
- ++d;
- }
- }
- }
-
- int64_t output_vect_dim = dnums.output_feature_dimension();
- if (dnums.output_batch_dimension() > output_vect_dim) {
- dnums.set_output_batch_dimension(dnums.output_batch_dimension() + 1);
- }
- for (int64_t& d : *dnums.mutable_output_spatial_dimensions()) {
- if (d > output_vect_dim) {
- ++d;
- }
- }
-
- return dnums;
-}
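// Illustrative sketch (hypothetical NCHW dnums; only the input side is shown,
// the kernel and output sides follow the same rule):
//
//   ConvolutionDimensionNumbers dnums;
//   dnums.set_input_batch_dimension(0);
//   dnums.set_input_feature_dimension(1);
//   dnums.add_input_spatial_dimensions(2);
//   dnums.add_input_spatial_dimensions(3);
//   dnums = VectorizeDnums(dnums, /*reordered_filter=*/false);
//   // Spatial dims are now {3, 4}; dimension 2 is the implicit vector dim.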
-
-// Reorders the convolution's filter and bias (if present) according to
-// cudnnReorderFilterAndBias. Also marks that the filter + bias are reordered
-// in the conv's backend-config.
-absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv,
- XlaOp* operands) {
- bool has_bias = conv->operand_count() > 2;
- VLOG(1) << "Reordering filter" << (has_bias ? " and bias" : "")
- << " (replacement for cudnnReorderFilterAndBias)";
-
- auto builder = operands->builder();
- ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
-
- // Update convolution backend config.
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- conv->backend_config<GpuBackendConfig>());
- CudnnConvBackendConfig& config =
- *gpu_config.mutable_cudnn_conv_backend_config();
- config.set_reordered_int8_nchw_vect(true);
- TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
-
- // Reorder the filter.
- TF_ASSIGN_OR_RETURN(Shape filter_shape, builder->GetShape(operands[1]));
- TF_ASSIGN_OR_RETURN(auto reorder, CudnnInferTransposeForFilterReordering(
- filter_shape, dnums));
- XlaOp reshape = Reshape(reorder.transpose_shape, operands[1]);
- XlaOp transpose = Transpose(reshape, reorder.permutation);
- operands[1] = Reshape(reorder.result_shape, transpose);
-
- // The reshape-transpose-reshape we did above makes sure the resulting filter
- // has dimension numbers corresponding to "oihw?", so update them.
- dnums.set_kernel_output_feature_dimension(0);
- dnums.set_kernel_input_feature_dimension(1);
- dnums.set_kernel_spatial_dimensions(0, 2);
- dnums.set_kernel_spatial_dimensions(1, 3);
- conv->set_convolution_dimension_numbers(dnums);
-
- if (has_bias) {
- // Reorder the bias.
- TF_ASSIGN_OR_RETURN(Shape bias_shape, builder->GetShape(operands[2]));
- TF_ASSIGN_OR_RETURN(reorder,
- CudnnInferTransposeForBiasReordering(bias_shape));
- reshape = Reshape(reorder.transpose_shape, operands[2]);
- transpose = Transpose(reshape, reorder.permutation);
- operands[2] = Reshape(reorder.result_shape, transpose);
- }
- return absl::OkStatus();
-}
-
-// Tries to vectorize an already-vectorized convolution.
-//
-// That is, given a convolution of shape [N, C/k, H, W, k], changes it to have
-// shape [N, C/vect_size, H, W, vect_size]. Similarly changes the filter from
-// [H, W, I/k, k, O] to [H, W, I/vect_size, vect_size, O].
-//
-// (The dimensions can appear in any order; the convolution's dnums determine
-// which dimension is N/C/etc.)
-static absl::StatusOr<bool> TryRevectorizeConv(
- const se::CudaComputeCapability& compute_capability,
- const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
- int vect_size) {
- const Shape& input_shape = conv->operand(0)->shape();
- const Shape& kernel_shape = conv->operand(1)->shape();
- const Shape& output_shape = conv->shape().tuple_shapes(0);
- const ConvolutionDimensionNumbers* dnums =
- &conv->convolution_dimension_numbers();
-
- // Find the vectorized-features dim in the input/kernel/output.
- std::optional<int64_t> input_vect_dim;
- std::optional<int64_t> kernel_vect_dim;
- std::optional<int64_t> output_vect_dim;
- std::tie(input_vect_dim, kernel_vect_dim, output_vect_dim) =
- FindVectorizedFeatureDims(*dnums, input_shape, kernel_shape,
- output_shape);
-
- if (!input_vect_dim.has_value() || !kernel_vect_dim.has_value() ||
- !output_vect_dim.has_value()) {
- return false;
- }
-
- int64_t input_feat_size =
- input_shape.dimensions(dnums->input_feature_dimension());
- int64_t output_feat_size =
- output_shape.dimensions(dnums->output_feature_dimension());
- int64_t input_vect_size = input_shape.dimensions(*input_vect_dim);
- int64_t output_vect_size = output_shape.dimensions(*output_vect_dim);
- if (vect_size % input_vect_size != 0 || vect_size % output_vect_size != 0 ||
- input_feat_size % (vect_size / input_vect_size) != 0 ||
- output_feat_size % (vect_size / output_vect_size) != 0) {
- return false;
- }
-
- // If this is an integer convolution, check that we only vectorize when cuDNN
- // supports the vectorized implementation.
- if (primitive_util::IsIntegralType(input_shape.element_type())) {
- TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
- CudnnSupportsOptimizedIntegerConvolution(
- compute_capability, *conv, vect_size));
- if (!supported_target_vectorization) {
- VLOG(3) << "Skipping re-vectorization of conv to vector size: "
- << vect_size << ": " << conv->ToString();
- return false;
- }
- }
-
- VLOG(1) << "Re-vectorizing conv channels from "
- << input_shape.dimensions(*input_vect_dim) << " to " << vect_size
- << ": " << conv->ToString();
-
- // We use XlaBuilder because it's a lot easier to get these tricky
- // reshape/transposes correct using that API.
- XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
- b.SetOpMetadata(conv->metadata());
-
- XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
- absl::InlinedVector<XlaOp, 4> new_operands = {
- RevectorizeInstr(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
- dnums->input_feature_dimension(), *input_vect_dim,
- vect_size),
- RevectorizeInstr(filter, dnums->kernel_input_feature_dimension(),
- *kernel_vect_dim, vect_size),
- };
- if (conv->operand_count() > 2) {
- // Bias, if present. This is passed through unmodified.
- new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
- }
- if (conv->operand_count() > 3) {
- new_operands.push_back(RevectorizeInstr(
- Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
- dnums->input_feature_dimension(), *input_vect_dim, vect_size));
- }
-
- if (conv->operand_count() > 4) {
- return InvalidArgument(
- "Don't understand a conv with more than 4 arguments: %s",
- conv->ToString());
- }
-
- // Reorder filter and bias for the int8x32 convolutions. This requires cudnn
- // >= 8.3.0.
- //
- // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
- const auto& debug_options = conv->GetModule()->config().debug_options();
- bool use_reordering =
- input_shape.element_type() == xla::S8 && vect_size == 32 &&
- debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
- cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
- if (use_reordering) {
- // Reordering helper supports vector sizes of 4 and 32, so an additional
- // reshape-transpose-reshape is not necessary in these cases.
- int64_t kernel_vect_size = kernel_shape.dimensions(*kernel_vect_dim);
- if (kernel_vect_size == 4 || kernel_vect_size == 32) {
- new_operands[1] = filter;
- }
- TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
- dnums = &conv->convolution_dimension_numbers();
- }
-
- // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
- // value in the tuple represents the convolution's scratch memory.
- DimensionVector new_output_dims(output_shape.dimensions().begin(),
- output_shape.dimensions().end());
- new_output_dims[dnums->output_feature_dimension()] /=
- (vect_size / output_vect_size);
- new_output_dims[*output_vect_dim] = vect_size;
- XlaOp new_conv = CustomCallWithConvDnums(
- &b, conv->custom_call_target(), new_operands,
- ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(output_shape.element_type(), new_output_dims),
- ShapeUtil::MakeShape(U8, {0})}),
- /*operand_shapes_with_layout=*/{},
- /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
- /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
- /*window=*/conv->window(),
- /*dnums=*/*dnums);
-
- XlaOp new_conv_result = GetTupleElement(new_conv, 0);
- XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
-
- XlaOp new_conv_result_unrevectorized = UnrevectorizeInstr(
- new_conv_result, dnums->output_feature_dimension(), *output_vect_dim,
- /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));
-
- TF_ASSIGN_OR_RETURN(
- HloComputation * new_conv_comp,
- BuilderToHloComputation(
- b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
- conv->parent()));
-
- // Set the name on the new conv. This is purely cosmetic, but we attempt to
- // preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
- auto new_conv_comp_instrs = new_conv_comp->instructions();
- auto new_conv_it =
- absl::c_find_if(new_conv_comp_instrs, [](HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kCustomCall;
- });
- if (new_conv_it != new_conv_comp_instrs.end()) {
- new_conv_comp->parent()->SetAndUniquifyInstrName(*new_conv_it,
- conv->name());
- }
-
- // Replace the old conv with a call to the computation we just created.
- VLOG(1) << "Re-vectorized conv to " << new_conv_comp->ToString();
- TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
- conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
- new_conv_comp)));
-
- return true;
-}
-
-// Tries to vectorize a convolution.
-//
-// Given a convolution of dimensions [N, C, H, W], tries to convert it to have
-// shape [N, C/vect_size, H, W, vect_size]. Similarly, given a kernel of shape
-// [H, W, I, O], tries to convert it to [H, W, I/vect_size, vect_size, O].
-//
-// This requires that C be a multiple of vect_size. CudnnPadForConvolutions can
-// add padding to make this true.
-static absl::StatusOr<bool> TryVectorizeConv(
- const se::CudaComputeCapability& compute_capability,
- const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
- int64_t vect_size) {
- const Shape& input_shape = conv->operand(0)->shape();
- const Shape& output_shape = conv->shape().tuple_shapes(0);
- const ConvolutionDimensionNumbers* dnums =
- &conv->convolution_dimension_numbers();
- int64_t in_channels =
- input_shape.dimensions(dnums->input_feature_dimension());
- int64_t out_channels =
- output_shape.dimensions(dnums->output_feature_dimension());
-
- if (in_channels % vect_size != 0 || out_channels % vect_size != 0) {
- return false;
- }
-
- if (input_shape.dimensions_size() >
- 2 + dnums->input_spatial_dimensions_size()) {
- // Conv already has an extra dimension, which we assume is the vectorized
- // features dim.
- return false;
- }
-
- // If this is an integer convolution, check that we only vectorize when cuDNN
- // supports the vectorized implementation.
- if (primitive_util::IsIntegralType(input_shape.element_type())) {
- TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
- CudnnSupportsOptimizedIntegerConvolution(
- compute_capability, *conv, vect_size));
- if (!supported_target_vectorization) {
- VLOG(3) << "Skipping vectorization of conv to vector size: " << vect_size
- << ": " << conv->ToString();
- return false;
- }
- }
-
- VLOG(1) << "Vectorizing conv channels by " << vect_size << ": "
- << conv->ToString();
-
- // We use XlaBuilder because it's a lot easier to get these tricky
- // reshape/transposes correct using that API.
- XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
- b.SetOpMetadata(conv->metadata());
-
- XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
- absl::InlinedVector<XlaOp, 4> new_operands = {
- SplitAtDim(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
- dnums->input_feature_dimension(), vect_size),
- SplitAtDim(filter, dnums->kernel_input_feature_dimension(), vect_size),
- };
- if (conv->operand_count() > 2) {
- // Bias, if present. This is passed through unmodified.
- new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
- }
- if (conv->operand_count() > 3) {
- // Handle side input, which has same shape as the output.
- new_operands.push_back(
- SplitAtDim(Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
- dnums->output_feature_dimension(), vect_size));
- }
- if (conv->operand_count() > 4) {
- return InvalidArgument(
- "Don't understand a conv with more than 4 arguments: %s",
- conv->ToString());
- }
-
- // Reorder filter and bias for the int8x32 convolutions. This requires cudnn
- // >= 8.3.0.
- //
- // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
- const auto& debug_options = conv->GetModule()->config().debug_options();
- bool use_reordering =
- input_shape.element_type() == xla::S8 && vect_size == 32 &&
- debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
- cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
- if (use_reordering) {
- new_operands[1] = filter;
- TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
- dnums = &conv->convolution_dimension_numbers();
- }
-
- // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
- // value in the tuple represents the convolution's scratch memory.
- Shape new_output_shape = SplitShapeAtDim(
- output_shape, dnums->output_feature_dimension(), vect_size);
- XlaOp new_conv = CustomCallWithConvDnums(
- &b, conv->custom_call_target(), new_operands,
- ShapeUtil::MakeTupleShape(
- {new_output_shape, ShapeUtil::MakeShape(U8, {0})}),
- /*operand_shapes_with_layout=*/{},
- /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
- /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
- /*window=*/conv->window(),
- /*dnums=*/VectorizeDnums(*dnums, use_reordering));
-
- XlaOp new_conv_result = GetTupleElement(new_conv, 0);
- XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
-
- // Reshape back to the original shape.
- XlaOp conv_result_collapsed =
- Collapse(new_conv_result, {dnums->output_feature_dimension(),
- dnums->output_feature_dimension() + 1});
-
- TF_ASSIGN_OR_RETURN(
- HloComputation * new_conv_comp,
- BuilderToHloComputation(
- b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
- conv->parent()));
-
- // Create a tuple and replace the old conv with it!
- VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
- TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
- conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
- new_conv_comp)));
- return true;
-}
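// Illustrative end-to-end shapes for a hypothetical NHWC conv vectorized with
// vect_size = 32 (mirroring the VectorizeTo32 unit test): the input
// s8[10,20,30,64] is split to s8[10,20,30,2,32], the filter s8[2,2,64,128] to
// s8[2,2,2,32,128], the custom-call emits s8[10,20,30,4,32], and the trailing
// Collapse restores the caller-visible shape s8[10,20,30,128].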
-
-} // namespace
-
-absl::StatusOr<bool> CudnnVectorizeConvolutions::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
- // Try to (re)vectorize to int8x32 if this is an sm75+ GPU. If we can't,
- // fall back to int8x4.
- bool local_changed = false;
- if (compute_capability_.IsAtLeast(7, 5)) {
- TF_ASSIGN_OR_RETURN(
- local_changed,
- TryRevectorizeConv(compute_capability_, cudnn_version_, conv, 32));
- if (!local_changed) {
- TF_ASSIGN_OR_RETURN(
- local_changed,
- TryVectorizeConv(compute_capability_, cudnn_version_, conv, 32));
- }
- }
- if (!local_changed) {
- TF_ASSIGN_OR_RETURN(
- local_changed,
- TryVectorizeConv(compute_capability_, cudnn_version_, conv, 4));
- }
- changed |= local_changed;
- }
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h
deleted file mode 100644
index 43165f2..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_
-#define XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-// Changes the shape of cudnn convolutions to allow faster "vectorized"
-// algorithms.
-//
-// On sm61+ will convert int8_t convolutions from
-//
-// - [N, C, H, W] to [N, C/4, H, W, 4],
-//
-// assuming C is divisible by 4.
-//
-// On sm75+ will convert int8_t convolutions from
-//
-// - [N, C, H, W] to [N, C/32, H, W, 32],
-// - [N, C/4, H, W, 4] to [N, C/32, H, W, 32], and
-// - [N, C, H, W] to [N, C/4, H, W, 4] (same as sm61+),
-//
-// assuming C is divisible by 4 or 32.
-//
-// This pass will not pad the channel dim to a multiple of 4 or 32, so you
-// should run CudnnPadForConvolutions before this.
-class CudnnVectorizeConvolutions : public HloModulePass {
- public:
- explicit CudnnVectorizeConvolutions(
- se::CudaComputeCapability compute_capability,
- se::dnn::VersionInfo cudnn_version)
- : compute_capability_(compute_capability),
- cudnn_version_(cudnn_version) {}
-
- absl::string_view name() const override {
- return "cudnn_vectorize_convolutions";
- }
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const se::CudaComputeCapability compute_capability_;
- const se::dnn::VersionInfo cudnn_version_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_
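// Illustrative usage sketch (assumed pipeline wiring; not part of this header):
//
//   HloPassPipeline pipeline("int8-conv-vectorization");
//   // Pad channel counts to multiples of 4/32 first, as the comment above requires.
//   pipeline.AddPass<CudnnPadForConvolutions>(compute_capability);
//   pipeline.AddPass<CudnnVectorizeConvolutions>(compute_capability,
//                                                se::dnn::VersionInfo(8, 3, 0));
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));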
diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc
deleted file mode 100644
index aa15fc7..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc
+++ /dev/null
@@ -1,758 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-
-#include <cstdint>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/status/statusor.h"
-#include "xla/service/call_inliner.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class CudnnVectorizeConvolutionsTest : public HloTestBase {
- protected:
- // Runs this pass and some cleanup to make pattern-matching easier.
- absl::StatusOr<bool> Run(std::pair<int, int> compute_capability,
- HloModule* module) {
- CudnnVectorizeConvolutions pass(
- se::CudaComputeCapability{compute_capability.first,
- compute_capability.second},
- se::dnn::VersionInfo(8, 3, 0));
- TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module));
-
- CallInliner inliner;
- TF_RETURN_IF_ERROR(RunHloPass(&inliner, module).status());
-
- return changed;
- }
-};
-
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,40] parameter(0)
- filter = s8[2,2,40,44] parameter(1)
- ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward",
- backend_config="{bar: 0}"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 10, 4}),
- m::Reshape(m::Parameter(1))
- .WithShape(S8, {2, 2, 10, 4, 44}))
- .WithConvDnums("b01f?_01i?o->b01f?"))
- .WithShape(S8, {10, 20, 30, 11, 4})),
- m::Op())));
-
- EXPECT_EQ(conv->raw_backend_config_string(), "{bar: 0}");
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4UnsupportedFilterType) {
- // This test checks that the vectorize pass correctly calls
- // CudnnSupportsOptimizedIntegerConvolution(), which should reject this
- // convolution because its filter type is f32.
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,40] parameter(0)
- filter = f32[2,2,40,44] parameter(1)
- ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward",
- backend_config="{bar: 0}"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4NCHW) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,48,20,30] parameter(0)
- filter = s8[48,44,2,2] parameter(1)
- ROOT result = (s8[10,44,20,30], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=bf01_io01->bf01,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 12, 4, 20, 30}),
- m::Reshape(m::Parameter(1))
- .WithShape(S8, {12, 4, 44, 2, 2}))
- .WithConvDnums("bf?01_i?o01->bf?01"))
- .WithShape(S8, {10, 11, 4, 20, 30})),
- m::Op())));
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, IncrementAllDnums) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[16,16,16,16] parameter(0)
- filter = s8[16,16,3,3] parameter(1)
- ROOT result = (s8[16,16,16,16], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=fb01_i01o->fb01,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {4, 4, 16, 16, 16}),
- m::Reshape(m::Parameter(1))
- .WithShape(S8, {4, 4, 16, 3, 3}))
- .WithConvDnums("f?b01_i?01o->f?b01"))
- .WithShape(S8, {4, 4, 16, 16, 16})),
- m::Op())));
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, FilterDnums) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[1,20,9,9] parameter(0)
- filter = s8[3,3,20,32] parameter(1)
- ROOT result = (s8[1,32,9,9], u8[0]) custom-call(s8[1,20,9,9] input, s8[3,3,20,32] filter),
- window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {1, 5, 4, 9, 9}),
- m::Reshape(m::Parameter(1))
- .WithShape(S8, {3, 3, 5, 4, 32}))
- .WithConvDnums("bf?01_01i?o->bf?01"))
- .WithShape(S8, {1, 8, 4, 9, 9})),
- m::Op())));
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,41] parameter(0)
- filter = s8[2,2,41,44] parameter(1)
- ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- CudnnVectorizeConvolutions pass(
- /*compute_capability=*/{7, 5},
- /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0});
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-
- SCOPED_TRACE(module->ToString());
- EXPECT_FALSE(changed);
-}
-
-// Don't vectorize int8_t -> int32_t into int8x4 or int8x32; this is not
-// supported in cudnn.
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsS32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,41] parameter(0)
- filter = s8[2,2,41,44] parameter(1)
- ROOT result = (s32[10,20,30,44], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- SCOPED_TRACE(module->ToString());
- EXPECT_FALSE(changed);
-}
-
-// Don't vectorize int8_t -> float into int8x4 or int8x32. Vectorizing to
-// int8x4 *is* allowed by cudnn, but we don't do it at the moment.
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsF32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,41] parameter(0)
- filter = s8[2,2,41,44] parameter(1)
- ROOT result = (f32[10,20,30,44], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- SCOPED_TRACE(module->ToString());
- EXPECT_FALSE(changed);
-}
-
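-// With 64 input features and compute capability 7.5, the pass vectorizes to
-// int8x32, reordering the filter and setting reordered_int8_nchw_vect on the
-// backend config.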
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,64] parameter(0)
- filter = s8[2,2,64,128] parameter(1)
- ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 2, 32}),
- m::Reshape(
- m::Transpose(
- m::Reshape(m::Parameter(1))
- .WithShape(S8, {2, 2, 2, 8, 4, 16, 4, 2}))
- .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
- .WithPredicate([](const HloInstruction* instr) {
- return absl::c_equal(
- instr->dimensions(),
- std::vector<int64_t>{2, 0, 1, 5, 7, 3, 6,
- 4});
- }))
- .WithShape(S8, {128, 2, 2, 2, 32})))
- .WithShape(S8, {10, 20, 30, 4, 32})),
- m::Op())));
-
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
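-// Bias and side-input operands are revectorized along with the input and
-// filter: the bias is reordered via reshape/transpose and the side input is
-// reshaped to the same int8x32 layout as the input.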
-TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,64] parameter(0)
- filter = s8[2,2,64,128] parameter(1)
- bias = f32[128] parameter(2)
- side_input = s8[10,20,30,64] parameter(3)
-
- ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter, bias, side_input),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 2, 32}),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
- .WithShape(S8, {128, 2, 2, 2, 32}),
- m::Reshape(
- m::Transpose(m::Reshape(m::Parameter(2))
- .WithShape(F32, {4, 4, 2, 4}))
- .WithShape(F32, {4, 2, 4, 4})
- .WithPredicate([](const HloInstruction* instr) {
- return absl::c_equal(
- instr->dimensions(),
- std::vector<int64_t>{0, 2, 1, 3});
- }))
- .WithShape(F32, {128}),
- m::Reshape(m::Parameter(3))
- .WithShape(S8, {10, 20, 30, 2, 32})))
- .WithShape(S8, {10, 20, 30, 4, 32})),
- m::Op())));
-
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, InputNHWC_OutputNCHW) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,64] parameter(0)
- filter = s8[2,2,64,128] parameter(1)
- bias = f32[128] parameter(2)
- side_input = s8[10,128,20,30] parameter(3)
-
- ROOT result = (s8[10,128,20,30], u8[0]) custom-call(input, filter, bias, side_input),
- window={size=2x2}, dim_labels=b01f_01io->bf01,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 2, 32}),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
- .WithShape(S8, {128, 2, 2, 2, 32}),
- m::Reshape(
- m::Transpose(m::Reshape(m::Parameter(2))
- .WithShape(F32, {4, 4, 2, 4}))
- .WithShape(F32, {4, 2, 4, 4})
- .WithPredicate([](const HloInstruction* instr) {
- return absl::c_equal(
- instr->dimensions(),
- std::vector<int64_t>{0, 2, 1, 3});
- }))
- .WithShape(F32, {128}),
- m::Reshape(m::Parameter(3))
- .WithShape(S8, {10, 4, 32, 20, 30})))
- .WithShape(S8, {10, 4, 32, 20, 30})),
- m::Op())));
-
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,64] parameter(0)
- filter = s8[2,2,64,128] parameter(1)
- ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- root,
- GmockMatch(m::Tuple(
- m::Reshape(m::GetTupleElement(
- m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 16, 4}),
- m::Reshape(m::Parameter(1))
- .WithShape(S8, {2, 2, 16, 4, 128})))
- .WithShape(S8, {10, 20, 30, 32, 4})),
- m::Op())));
-
- EXPECT_FALSE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
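-// An already int8x4-vectorized convolution is re-vectorized to int8x32 by
-// regrouping the existing vector dimension together with part of the feature
-// dimension.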
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,16,4] parameter(0)
- filter = s8[3,5,16,192,4] parameter(1)
- bias = f32[64] parameter(2)
- side_input = s8[10,20,30,16,4] parameter(3)
- ROOT result = (s8[10,20,30,48,4], u8[0]) custom-call(input, filter, bias, side_input),
- window={size=3x5}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- auto conv_pat =
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 2, 8, 4}))
- .WithShape(S8, {10, 20, 30, 2, 8, 4}))
- .WithShape(S8, {10, 20, 30, 2, 32}),
- m::Reshape(
- m::Transpose(m::Reshape(m::Parameter(1))
- .WithShape(S8, {3, 5, 2, 8, 24, 4, 2, 4}))
- .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
- .WithPredicate([](const HloInstruction* instr) {
- return absl::c_equal(
- instr->dimensions(),
- std::vector<int64_t>{2, 0, 1, 4, 6, 3, 5, 7});
- }))
- .WithShape(S8, {192, 2, 3, 5, 32}),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
- .WithShape(S8, {10, 20, 30, 2, 8, 4}))
- .WithShape(S8, {10, 20, 30, 2, 8, 4}))
- .WithShape(S8, {10, 20, 30, 2, 32}))
- .WithConvDnums("b01f?_oi01?->b01f?"))
- .WithShape(S8, {10, 20, 30, 6, 32});
- ASSERT_THAT(root, GmockMatch(m::Tuple(
- m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
- S8, {10, 20, 30, 6, 8, 4}))
- .WithShape(S8, {10, 20, 30, 6, 8, 4}))
- .WithShape(S8, {10, 20, 30, 48, 4}),
- m::Op())));
-
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,16,20,30,4] parameter(0)
- filter = s8[16,128,2,2,4] parameter(1)
- bias = f32[64] parameter(2)
- side_input = s8[10,16,20,30,4] parameter(3)
- ROOT result = (s8[10,32,20,30,4], u8[0]) custom-call(input, filter, bias, side_input),
- window={size=2x2}, dim_labels=bf01_io01->bf01,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- auto conv_pat =
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 2, 8, 20, 30, 4}))
- .WithShape(S8, {10, 2, 20, 30, 8, 4}))
- .WithShape(S8, {10, 2, 20, 30, 32}),
- m::Reshape(
- m::Transpose(m::Reshape(m::Parameter(1))
- .WithShape(S8, {2, 8, 16, 4, 2, 2, 2, 4}))
- .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
- .WithPredicate([](const HloInstruction* instr) {
- return absl::c_equal(
- instr->dimensions(),
- std::vector<int64_t>{0, 5, 6, 2, 4, 1, 3, 7});
- }))
- .WithShape(S8, {128, 2, 2, 2, 32}),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
- .WithShape(S8, {10, 2, 8, 20, 30, 4}))
- .WithShape(S8, {10, 2, 20, 30, 8, 4}))
- .WithShape(S8, {10, 2, 20, 30, 32}))
- .WithConvDnums("bf01_oi01->bf01"))
- .WithShape(S8, {10, 4, 20, 30, 32});
- ASSERT_THAT(root, GmockMatch(m::Tuple(
- m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
- S8, {10, 4, 20, 30, 8, 4}))
- .WithShape(S8, {10, 4, 8, 20, 30, 4}))
- .WithShape(S8, {10, 32, 20, 30, 4}),
- m::Op())));
-
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[4,10,20,30,16] parameter(0)
- filter = s8[4,3,5,16,192] parameter(1)
- bias = f32[64] parameter(2)
- side_input = s8[4,10,20,30,16] parameter(3)
- ROOT result = (s8[4,10,20,30,48], u8[0]) custom-call(input, filter, bias, side_input),
- window={size=3x5}, dim_labels=?b01f_?01io->?b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- auto conv_pat =
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
- .WithShape(S8, {4, 10, 20, 30, 2, 8}))
- .WithShape(S8, {8, 4, 10, 20, 30, 2}))
- .WithShape(S8, {32, 10, 20, 30, 2}),
- m::Reshape(
- m::Transpose(m::Reshape(m::Parameter(1))
- .WithShape(S8, {4, 3, 5, 2, 8, 24, 4, 2}))
- .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
- .WithPredicate([](const HloInstruction* instr) {
- return absl::c_equal(
- instr->dimensions(),
- std::vector<int64_t>{3, 1, 2, 5, 7, 4, 6, 0});
- }))
- .WithShape(S8, {192, 2, 3, 5, 32}),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
- .WithShape(S8, {4, 10, 20, 30, 2, 8}))
- .WithShape(S8, {8, 4, 10, 20, 30, 2}))
- .WithShape(S8, {32, 10, 20, 30, 2}))
- .WithConvDnums("?b01f_oi01->?b01f"))
- .WithShape(S8, {32, 10, 20, 30, 6});
- ASSERT_THAT(root, GmockMatch(m::Tuple(
- m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
- S8, {8, 4, 10, 20, 30, 6}))
- .WithShape(S8, {4, 10, 20, 30, 6, 8}))
- .WithShape(S8, {4, 10, 20, 30, 48}),
- m::Op())));
-
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorize4To32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,16,4] parameter(0)
- filter = s8[2,2,16,128,4] parameter(1)
- bias = f32[10] parameter(2)
- side_input = s8[10,20,30,16,4] parameter(3)
- ROOT result = (s8[10,20,30,32,4], u8[0]) custom-call(input, filter, bias, side_input),
- window={size=2x2}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
- EXPECT_FALSE(changed);
-}
-
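-// An int8x16-vectorized convolution is widened to int8x32 by folding part of
-// the feature dimension into the vector dimension.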
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize16To32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,4,16] parameter(0)
- filter = s8[3,5,4,192,16] parameter(1)
- ROOT result = (s8[10,20,30,12,16], u8[0]) custom-call(input, filter),
- window={size=3x5}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- auto filter_pat =
- m::Reshape(
- m::Transpose(
- m::Reshape(m::Parameter(1)).WithShape(S8, {3, 5, 2, 2, 192, 16}))
- .WithShape(S8, {3, 5, 2, 192, 2, 16}))
- .WithShape(S8, {3, 5, 2, 192, 32});
- auto conv_pat =
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(
- m::Transpose(m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 2, 2, 16}))
- .WithShape(S8, {10, 20, 30, 2, 2, 16}))
- .WithShape(S8, {10, 20, 30, 2, 32}),
- m::Reshape(
- m::Transpose(m::Reshape(filter_pat)
- .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
- .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
- .WithShape(S8, {192, 2, 3, 5, 32}))
- .WithConvDnums("b01f_oi01->b01f"))
- .WithShape(S8, {10, 20, 30, 6, 32});
- ASSERT_THAT(root, GmockMatch(m::Tuple(
- m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
- S8, {10, 20, 30, 6, 2, 16}))
- .WithShape(S8, {10, 20, 30, 6, 2, 16}))
- .WithShape(S8, {10, 20, 30, 12, 16}),
- m::Op())));
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeMixedTo32) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- input = s8[10,20,30,8,8] parameter(0)
- filter = s8[3,5,2,192,32] parameter(1)
- ROOT result = (s8[10,20,30,96,2], u8[0]) custom-call(input, filter),
- window={size=3x5}, dim_labels=b01f_01io->b01f,
- custom_call_target="__cudnn$convForward"
- })")
- .value();
- TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
- EXPECT_TRUE(changed);
-
- SCOPED_TRACE(module->ToString());
- auto* root = module->entry_computation()->root_instruction();
-
- const HloInstruction* conv = nullptr;
- auto conv_pat =
- m::GetTupleElement(
- m::CustomCall(
- &conv, {kCudnnConvForwardCallTarget},
- m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
- .WithShape(S8, {10, 20, 30, 2, 4, 8}))
- .WithShape(S8, {10, 20, 30, 2, 4, 8}))
- .WithShape(S8, {10, 20, 30, 2, 32}),
- m::Reshape(
- m::Transpose(m::Reshape(m::Parameter(1))
- .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
- .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
- .WithShape(S8, {192, 2, 3, 5, 32}))
- .WithConvDnums("b01f_oi01->b01f"))
- .WithShape(S8, {10, 20, 30, 6, 32});
- ASSERT_THAT(root, GmockMatch(m::Tuple(
- m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
- S8, {10, 20, 30, 6, 16, 2}))
- .WithShape(S8, {10, 20, 30, 6, 16, 2}))
- .WithShape(S8, {10, 20, 30, 96, 2}),
- m::Op())));
- EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
- ->cudnn_conv_backend_config()
- .reordered_int8_nchw_vect());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc
deleted file mode 100644
index 387a3f4..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc
+++ /dev/null
@@ -1,272 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_workspace_rewriter.h"
-
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/cuda/cuda_dnn.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Creates a cuDNN graph from an fMHA HloCustomCall instruction.
-absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
- se::dnn::DnnSupport& dnn_support, HloCustomCallInstruction* custom_call) {
- if (IsFwdCustomCallTofMHA(*custom_call)) {
- TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind,
- xla::gpu::GetCudnnfMHAKind(custom_call));
- std::optional<Shape> mask_shape, bias_shape;
- {
- bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax ||
- kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout;
-
- if (has_bias) {
- const HloInstruction* bias = custom_call->operand(3);
- bias_shape = bias->shape();
- }
- }
-
- TF_ASSIGN_OR_RETURN(
- const auto gpu_config,
- custom_call->backend_config<xla::gpu::GpuBackendConfig>());
- const xla::gpu::CudnnfMHABackendConfig& config =
- gpu_config.cudnn_fmha_backend_config();
- Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
- absl::InlinedVector<Shape, 2> output_shapes = {
- ShapeUtil::GetSubshape(custom_call->shape(), {0})};
-
- bool has_activation =
- xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3;
- if (has_activation) {
- output_shapes.push_back(
- ShapeUtil::GetSubshape(custom_call->shape(), {1}));
- }
-
- Shape q_shape = custom_call->operand(0)->shape();
- Shape k_shape = custom_call->operand(1)->shape();
- Shape v_shape = custom_call->operand(2)->shape();
- TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
- AsCudnnFmhaMaskKind(config.mask_type()));
- GpufMHADescriptor descriptor = {kind,
- config,
- cudnn_mask_type,
- q_shape,
- k_shape,
- v_shape,
- intermediate_tensor_shape,
- output_shapes,
- config.bmm1_dot_dimension_numbers(),
- config.bmm2_dot_dimension_numbers(),
- mask_shape,
- bias_shape};
-
- TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config,
- GpufMHAConfig::For(descriptor));
- TF_ASSIGN_OR_RETURN(
- se::dnn::FMHAMaskKind dnn_mask_type,
- GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
- TF_ASSIGN_OR_RETURN(
- se::gpu::CudnnGraph graph,
- se::gpu::GetCudnnFlashAttentionOperationGraph(
- dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1,
- fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias,
- fmha_config.activation, static_cast<float>(*fmha_config.fmha_scale),
- fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
- fmha_config.dropout_rate, dnn_mask_type));
- return std::move(graph);
- } else {
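- // Backward fMHA custom call: recover the operand and output shapes by
- // position and build the flash-attention backward graph.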
- TF_ASSIGN_OR_RETURN(
- auto gpu_config,
- custom_call->backend_config<xla::gpu::GpuBackendConfig>());
- xla::gpu::CudnnfMHABackendConfig& config =
- *gpu_config.mutable_cudnn_fmha_backend_config();
-
- int input_index = 0;
- Shape bmm1_grad_gemm1_rhs_shape =
- custom_call->operand(input_index++)->shape();
- Shape bmm1_grad_gemm2_rhs_shape =
- custom_call->operand(input_index++)->shape();
- Shape bmm2_grad_gemm2_rhs_shape =
- custom_call->operand(input_index++)->shape();
- Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());
- input_index++;
- Shape d_output_shape = custom_call->operand(input_index++)->shape();
-
- TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind,
- GetCudnnfMHAKind(custom_call));
- std::optional<Shape> mask_shape;
-
- bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax ||
- kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout);
- std::optional<Shape> bias_shape;
- if (has_bias) {
- bias_shape = custom_call->operand(input_index++)->shape();
- }
-
- std::optional<Shape> fwd_output_shape =
- custom_call->operand(input_index++)->shape();
- if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING ||
- config.mask_type() ==
- xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) {
- // Skip the q_seqlen and kv_seqlen operands.
- input_index += 2;
- }
- TF_RET_CHECK(input_index == custom_call->operand_count());
-
- int output_index = 0;
- Shape d_bmm1_lhs_shape =
- ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
- Shape d_bmm1_rhs_shape =
- ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
- Shape d_bmm2_rhs_shape =
- ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
- std::optional<Shape> d_s_shape;
- std::optional<Shape> d_bias_shape;
- bool has_dbias = custom_call->shape().tuple_shapes().size() == 5;
- if (has_dbias) {
- d_bias_shape =
- ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
- }
- // The last one is the workspace.
- TF_RET_CHECK(output_index ==
- custom_call->shape().tuple_shapes().size() - 1);
- TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
- AsCudnnFmhaMaskKind(config.mask_type()));
-
- const bool force_deterministic =
- RequireDeterminism(custom_call->GetModule()->config());
- // Set the force_deterministic attribute before building the descriptor.
- config.set_force_deterministic(force_deterministic);
- TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));
-
- GpufMHABackwardDescriptor descriptor = {
- kind,
- config,
- cudnn_mask_type,
- bmm1_grad_gemm1_rhs_shape,
- bmm1_grad_gemm2_rhs_shape,
- bmm2_grad_gemm1_lhs_shape,
- bmm2_grad_gemm2_rhs_shape,
- d_output_shape,
- d_bmm1_lhs_shape,
- d_bmm1_rhs_shape,
- d_bmm2_rhs_shape,
- config.bmm1_grad_gemm1_dot_dimension_numbers(),
- config.bmm1_grad_gemm2_dot_dimension_numbers(),
- config.bmm2_grad_gemm1_dot_dimension_numbers(),
- config.bmm2_grad_gemm2_dot_dimension_numbers(),
- d_s_shape,
- fwd_output_shape,
- mask_shape,
- d_bias_shape,
- bias_shape,
- force_deterministic};
-
- TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config,
- GpufMHABackwardConfig::For(descriptor));
- TF_ASSIGN_OR_RETURN(
- se::dnn::FMHAMaskKind dnn_mask_type,
- GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
-
- TF_ASSIGN_OR_RETURN(
- se::gpu::CudnnGraph graph,
- se::gpu::GetCudnnFlashAttentionBackwardOperationGraph(
- dnn_support, fmha_config.bmm1_grad_gemm1_rhs,
- fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs,
- fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output,
- fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs,
- fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate,
- fmha_config.seed, *fmha_config.fmha_scale,
- fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
- fmha_config.bias != std::nullopt, dnn_mask_type,
- force_deterministic));
- return std::move(graph);
- }
-}
-
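-// Visits fMHA custom calls and patches the workspace tuple element in place
-// with the size reported by the freshly built cuDNN graph.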
-class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor {
- public:
- explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport& dnn_support)
- : dnn_support_(dnn_support) {}
-
- absl::Status HandleCustomCall(HloInstruction* hlo) override {
- if (!IsCustomCallTofMHA(*hlo)) {
- // Leave cuDNN custom calls other than fMHA untouched.
- return absl::OkStatus();
- }
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- hlo->backend_config<GpuBackendConfig>());
-
- TF_ASSIGN_OR_RETURN(
- se::gpu::CudnnGraph graph,
- HloCustomCallToCuDnnGraph(dnn_support_,
- DynCast<HloCustomCallInstruction>(hlo)));
- auto workspace = graph.Graph().get_workspace_size();
- if (workspace != 0) {
- // Rewrite the custom call so its workspace tuple element has the correct size.
- VLOG(4) << "Rewriting: " << hlo->ToString();
- Shape* shape = hlo->mutable_shape();
- shape->mutable_tuple_shapes(shape->tuple_shapes_size() - 1)
- ->set_dimensions(0, workspace);
- MarkAsChanged();
- }
- return absl::OkStatus();
- }
-
- private:
- se::dnn::DnnSupport& dnn_support_;
-};
-
-} // namespace
-
-absl::StatusOr<bool> CuDnnWorkspaceRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_SCOPED_LOGGING_TIMER("cuDNN workspace rewriter");
- return CuDnnCustomCallVisitor(dnn_support_)
- .RunOnModule(module, execution_threads);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h
deleted file mode 100644
index de81d6d..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites cuDNN custom calls to have the correct workspace size by building
-// and serializing the cuDNN graph so it can be reused later.
-class CuDnnWorkspaceRewriter : public HloModulePass {
- public:
- explicit CuDnnWorkspaceRewriter(se::StreamExecutor& stream_exec)
- : dnn_support_(*stream_exec.AsDnn()) {}
-
- absl::string_view name() const override { return "cudnn-workspace-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::dnn::DnnSupport& dnn_support_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/cusolver_rewriter.cc b/third_party/xla/xla/service/gpu/cusolver_rewriter.cc
deleted file mode 100644
index ddfda66..0000000
--- a/third_party/xla/xla/service/gpu/cusolver_rewriter.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cusolver_rewriter.h"
-
-#include <cstdint>
-#include <functional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/literal.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/cusolver_context.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
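-// Swaps the two minor-most dimensions of the layout so the matrix is stored in
-// Fortran (column-major) order, as cuSolver expects.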
-void SetFortranLayout(Shape* shape) {
- LayoutUtil::SetToDefaultLayout(shape);
- int n = shape->mutable_layout()->minor_to_major_size();
- CHECK_GE(n, 2);
- std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
- shape->mutable_layout()->mutable_minor_to_major()->at(1));
-}
-
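-// Replaces a Cholesky HLO with a kCusolverCholeskyCallTarget custom call plus
-// a select that substitutes NaNs for batch elements whose Potrf info output is
-// non-zero.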
-absl::StatusOr<HloInstruction*> CreateCholesky(GpuSolverContext* context,
- HloInstruction* operand,
- const CholeskyOptions& options,
- const OpMetadata& metadata) {
- HloComputation* computation = operand->parent();
-
- Shape a_shape = operand->shape();
- int ndim = a_shape.dimensions_size();
- CHECK_GE(ndim, 2);
- int64_t n = a_shape.dimensions(ndim - 1);
-
- std::vector<int64_t> batch_dims(a_shape.dimensions().begin(),
- a_shape.dimensions().end() - 2);
- std::vector<int64_t> batch_dim_ids(batch_dims.size());
- absl::c_iota(batch_dim_ids, 0);
- int64_t batch_size = absl::c_accumulate(batch_dims, 1, std::multiplies<>{});
-
- // Find the workspace size.
- se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
- : se::blas::UpperLower::kUpper;
- int64_t workspace_size; // Number of elements of size a_shape.element_type()
- TF_ASSIGN_OR_RETURN(
- workspace_size,
- context->PotrfBufferSize(a_shape.element_type(), uplo, n, n, batch_size));
-
- // TODO(phawkins): Ideally we would relax this constraint. What we actually
- // want is that:
- // a) the batch dimensions are major, in no particular order.
- // b) the two minor dimensions are in Fortran (column-major) order.
-
- SetFortranLayout(&a_shape);
-
- // This call returns a tuple of (cholesky_result, workspace, info) where:
- // * cholesky_result is the result of the Cholesky decomposition,
- // * workspace is temporary scratch memory used by cuSolver.
- // * info contains the Potrf success/failure status.
- // Currently we have no meaningful way to report an error, so we simply
- // discard the success/failure information. Obviously this is suboptimal.
- Shape info_shape = ShapeUtil::MakeShape(S32, batch_dims);
- Shape call_shape = ShapeUtil::MakeTupleShape(
- {a_shape,
- ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
- info_shape});
-
- HloInstruction* custom_call =
- computation->AddInstruction(HloInstruction::CreateCustomCall(
- call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
- custom_call->set_metadata(metadata);
- TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
- HloInstruction* out = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(a_shape, custom_call, 0));
- HloInstruction* info = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(info_shape, custom_call, 2));
-
- // If info was non-zero, indicating that the Cholesky decomposition failed,
- // returns an array full of NaNs for the corresponding batch element.
- HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
- HloInstruction* zeros =
- computation->AddInstruction(HloInstruction::CreateBroadcast(
- info_shape, zero, /*broadcast_dimensions=*/{}));
- HloInstruction* ok = computation->AddInstruction(
- HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, batch_dims),
- info, zeros, ComparisonDirection::kEq));
- ok = computation->AddInstruction(HloInstruction::CreateBroadcast(
- ShapeUtil::MakeShape(PRED, a_shape.dimensions()), ok,
- /*broadcast_dimensions=*/batch_dim_ids));
-
- TF_ASSIGN_OR_RETURN(Literal nan_literal,
- LiteralUtil::NanValue(a_shape.element_type()));
- HloInstruction* nan = computation->AddInstruction(
- HloInstruction::CreateConstant(std::move(nan_literal)));
- HloInstruction* nans =
- computation->AddInstruction(HloInstruction::CreateBroadcast(
- a_shape, nan, /*broadcast_dimensions=*/{}));
-
- HloInstruction* select =
- computation->AddInstruction(HloInstruction::CreateTernary(
- a_shape, HloOpcode::kSelect, ok, out, nans));
- return select;
-}
-
-// Tries to rewrite a single Cholesky decomposition into a cuSolver custom call.
-absl::StatusOr<bool> RunOnInstruction(GpuSolverContext* context,
- HloInstruction* instruction) {
- if (instruction->opcode() != HloOpcode::kCholesky) {
- return false;
- }
-
- TF_ASSIGN_OR_RETURN(
- HloInstruction * custom_call,
- CreateCholesky(context, instruction->mutable_operand(0),
- instruction->cholesky_options(), instruction->metadata()));
-
- VLOG(1) << "Replacing " << instruction->ToString() << " with "
- << custom_call->ToString();
-
- TF_RETURN_IF_ERROR(
- instruction->parent()->ReplaceInstruction(instruction, custom_call));
- return true;
-}
-
-} // namespace
-
-// Rewrites the Cholesky decompositions in the given computation into calls to
-// cuSolver. Returns true if it made any changes.
-absl::StatusOr<bool> GpusolverRewriter::RunOnComputation(
- HloComputation* computation) {
- std::vector<HloInstruction*> cusolver_calls;
- for (auto* hlo : computation->instructions()) {
- if (hlo->opcode() == HloOpcode::kCholesky) {
- cusolver_calls.push_back(hlo);
- }
- }
-
- if (cusolver_calls.empty()) {
- return false;
- }
-
- TF_ASSIGN_OR_RETURN(GpuSolverContext context, GpuSolverContext::Create());
-
- bool changed = false;
- for (HloInstruction* instruction : cusolver_calls) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(&context, instruction));
- changed |= result;
- }
- return changed;
-}
-
-GpusolverRewriter::GpusolverRewriter() = default;
-
-absl::StatusOr<bool> GpusolverRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cusolver_rewriter.h b/third_party/xla/xla/service/gpu/cusolver_rewriter.h
deleted file mode 100644
index fd1d84d..0000000
--- a/third_party/xla/xla/service/gpu/cusolver_rewriter.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
-#define XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver.
-class GpusolverRewriter : public HloModulePass {
- public:
- GpusolverRewriter();
- absl::string_view name() const override { return "gpusolver-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc
deleted file mode 100644
index d5114bc..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc
+++ /dev/null
@@ -1,220 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_autotuner.h"
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <tuple>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/time/time.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/kernels/custom_kernel.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
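-// Extracts the fusion instruction into a standalone module and points its
-// custom fusion config at the given kernel index.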
-absl::StatusOr<std::unique_ptr<HloModule>> ExtractFusionModule(
- HloInstruction* fusion_instruction, int64_t kernel_index) {
- std::unique_ptr<HloModule> hlo_module =
- ExtractInstructionIntoNewModule(*fusion_instruction);
-
- HloInstruction* instruction =
- hlo_module->entry_computation()->root_instruction();
- GpuBackendConfig gpu_config =
- instruction->backend_config<GpuBackendConfig>().value();
- gpu_config.mutable_fusion_backend_config()
- ->mutable_custom_fusion_config()
- ->set_kernel_index(kernel_index);
- TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
-
- return hlo_module;
-}
-
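-// Compiles the extracted fusion once per candidate kernel and profiles each
-// executable, returning (kernel index, duration) pairs.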
-absl::StatusOr<std::vector<std::tuple<int, absl::Duration>>> ProfileKernels(
- std::vector<CustomKernel>& kernels, HloInstruction* fusion_instruction,
- AutotunerCompileUtil& compile_util, const AutotuneConfig& autotune_config,
- const DebugOptions& debug_options) {
- se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
- std::vector<std::tuple<int, absl::Duration>> results;
- for (int i = 0; i < kernels.size(); ++i) {
- TF_ASSIGN_OR_RETURN(absl::StatusOr<std::unique_ptr<Executable>> executable,
- compile_util.Compile([&](const DebugOptions& opt) {
- return ExtractFusionModule(fusion_instruction, i);
- }));
-
- se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator();
- std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
- if (allocator == nullptr) {
- owned_allocator =
- std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
- allocator = owned_allocator.get();
- }
- TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream());
-
- TF_ASSIGN_OR_RETURN(auto rz_buffers,
- RedzoneBuffers::FromInstruction(
- *fusion_instruction, autotune_config, debug_options,
- RedzoneBuffers::kAllInputs));
-
- std::optional<ScopedShapedBuffer> reference_buffer;
- std::optional<AutotunerCompileUtil::ProfilingOutput> profiling_output;
- TF_ASSIGN_OR_RETURN(profiling_output, compile_util.ProfileExecutable(
- executable->get(), stream,
- rz_buffers.input_buffers(),
- rz_buffers.input_shapes()));
- results.push_back({i, profiling_output->duration});
- }
- return results;
-}
-
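-// Returns the index of the kernel with the smallest measured duration.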
-absl::StatusOr<int> FindFastestKernel(
- const std::vector<std::tuple<int, absl::Duration>>& results) {
- auto iter = absl::c_min_element(
- results, [](const std::tuple<int, absl::Duration>& lhs,
- const std::tuple<int, absl::Duration>& rhs) {
- return std::get<1>(lhs) < std::get<1>(rhs);
- });
- if (iter == results.end()) {
- return absl::InternalError("Failed to find fastest kernel.");
- }
- return std::get<0>(*iter);
-}
-
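-// Writes the chosen kernel index into the fusion instruction's backend config.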
-absl::Status UpdateFusionInstructionKernelIndex(
- HloInstruction* fusion_instruction, int kernel_index) {
- GpuBackendConfig gpu_config =
- fusion_instruction->backend_config<GpuBackendConfig>().value();
- gpu_config.mutable_fusion_backend_config()
- ->mutable_custom_fusion_config()
- ->set_kernel_index(kernel_index);
- TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config));
-
- return absl::OkStatus();
-}
-
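-// Looks up the custom kernel fusion in the default registry and loads the
-// candidate kernels for the fusion computation.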
-absl::StatusOr<std::vector<CustomKernel>> LoadKernels(
- const HloInstruction* fusion_instruction,
- const AutotuneConfig& autotune_config) {
- auto config = fusion_instruction->backend_config<GpuBackendConfig>()
- ->fusion_backend_config()
- .custom_fusion_config();
- auto* registry = CustomKernelFusionRegistry::Default();
- auto* custom_kernel_fusion = registry->Lookup(config.name());
-
- // If the custom fusion is not found, some of the build targets might not be
- // statically linked into the binary.
- if (custom_kernel_fusion == nullptr) {
- return absl::InternalError(
- absl::StrCat("Custom kernel fusion ", config.name(),
- " not found in a default registry."));
- }
-
- se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
- if (!stream_exec->SynchronizeAllActivity()) {
- return Internal("Failed to synchronize GPU for autotuning.");
- }
- se::DeviceDescription device_description =
- stream_exec->GetDeviceDescription();
-
- // Load custom kernels that can implement a fusion computation.
- TF_ASSIGN_OR_RETURN(
- std::vector<CustomKernel> kernels,
- custom_kernel_fusion->LoadKernels(
- device_description,
- fusion_instruction->fused_instructions_computation()));
-
- return kernels;
-}
-
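-// Profiles every kernel that can implement the fusion and writes the index of
-// the fastest one into the fusion's backend config. Returns true if the index
-// changed.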
-absl::StatusOr<bool> AutotuneCustomKernelFusion(
- HloInstruction* fusion_instruction, const AutotuneConfig& autotune_config,
- AutotunerCompileUtil& compile_util, const DebugOptions& debug_options) {
- int previous_kernel_index =
- fusion_instruction->backend_config<GpuBackendConfig>()
- ->fusion_backend_config()
- .custom_fusion_config()
- .kernel_index();
-
- TF_ASSIGN_OR_RETURN(std::vector<CustomKernel> kernels,
- LoadKernels(fusion_instruction, autotune_config));
-
- std::vector<std::tuple<int, absl::Duration>> results;
- TF_ASSIGN_OR_RETURN(results,
- ProfileKernels(kernels, fusion_instruction, compile_util,
- autotune_config, debug_options));
-
- TF_ASSIGN_OR_RETURN(int fastest_kernel_index, FindFastestKernel(results));
-
- TF_RETURN_IF_ERROR(UpdateFusionInstructionKernelIndex(fusion_instruction,
- fastest_kernel_index));
-
- return previous_kernel_index != fastest_kernel_index;
-}
-} // namespace
-
-absl::StatusOr<bool> CustomKernelFusionAutotuner::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- const DebugOptions& debug_options = module->config().debug_options();
- TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> compile_util,
- AutotunerCompileUtil::Create(config_, debug_options));
- TF_RET_CHECK(compile_util.has_value());
-
- bool hlo_changed = false;
- for (const HloComputation* computation : module->computations()) {
- if (computation->IsFusionComputation()) {
- TF_ASSIGN_OR_RETURN(
- bool instruction_changed,
- AutotuneCustomKernelFusion(computation->FusionInstruction(), config_,
- compile_util.value(), debug_options));
- if (instruction_changed) {
- hlo_changed = true;
- }
- }
- }
-
- return hlo_changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h b/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h
deleted file mode 100644
index f6cd0c0..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
-#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/xla.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// Finds the best custom kernel for each custom kernel fusion.
-class CustomKernelFusionAutotuner : public HloModulePass {
- public:
- explicit CustomKernelFusionAutotuner(const AutotuneConfig& config)
- : config_(config) {}
-
- absl::string_view name() const override {
- return "custom_kernel-fusion-autotuner";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const AutotuneConfig config_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc
deleted file mode 100644
index aa6c1d2..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_autotuner.h"
-
-#include <memory>
-#include <string>
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class CustomKernelFusionAutotunerTest : public HloTestBase {
- public:
- CustomKernelFusionAutotunerTest()
- : HloTestBase(/*verifier_layout_sensitive=*/false,
- /*allow_mixed_precision_in_hlo_verifier=*/true) {}
-
- void SetUp() override { HloTestBase::SetUp(); }
-
- void TearDown() override { HloTestBase::TearDown(); }
-};
-
-TEST_F(CustomKernelFusionAutotunerTest,
- CustomKernelFusionAutotunerPassSucceeds) {
- const std::string hlo_string = R"(
- HloModule extracted
-
- cutlass_gemm {
- p0 = f32[15,19]{1,0} parameter(0)
- p1 = f32[19,17]{1,0} parameter(1)
- ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
- ENTRY region_198.14436 {
- p.0 = f32[15,19]{1,0} parameter(0)
- p.1 = f32[19,17]{1,0} parameter(1)
- ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom, calls=cutlass_gemm, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":0}},"force_earliest_schedule":false}
- }
- )";
- std::unique_ptr<HloModule> hlo_module =
- ParseAndReturnVerifiedModule(hlo_string).value();
-
- HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
- DebugOptions debug_options;
- AutotuneConfig autotune_config =
- AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
- backend().memory_allocator()},
- debug_options};
- pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
- ASSERT_TRUE(pipeline.Run(hlo_module.get()).ok());
-}
-
-TEST_F(CustomKernelFusionAutotunerTest,
- CustomKernelFusionAutotunerPassUpdatesKernelIndex) {
- const std::string hlo_string = R"(
- HloModule extracted
-
- cutlass_gemm {
- p0 = f32[15,19]{1,0} parameter(0)
- p1 = f32[19,17]{1,0} parameter(1)
- ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1},
- rhs_contracting_dims={0}
- }
-
- ENTRY region_198.14436 {
- p.0 = f32[15,19]{1,0} parameter(0)
- p.1 = f32[19,17]{1,0} parameter(1)
- ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom,
- calls=cutlass_gemm,
- backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":-1}},"force_earliest_schedule":false}
- }
- )";
-
- HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
- DebugOptions debug_options;
- AutotuneConfig autotune_config =
- AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
- backend().memory_allocator()},
- debug_options};
- pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
-
- std::string expected = R"(
- CHECK: "kernel_index":0
- )";
- RunAndFilecheckHloRewrite(hlo_string, std::move(pipeline), expected);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc
deleted file mode 100644
index 814ccf0..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc
+++ /dev/null
@@ -1,240 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-CustomKernelFusionRewriter::CustomKernelFusionRewriter(
- const se::DeviceDescription* device,
- const CustomKernelFusionPatternRegistry* patterns)
- : device_(device), patterns_(patterns) {}
-
-// Returns the set of instructions that have users outside of a matched pattern
-// and have a replacement that must be applied after building a new custom
-// fusion instruction. Only the root instruction can have external users
-// without requiring a replacement, as the fusion itself is its replacement. If
-// an instruction has external users and does not have a replacement, returns
-// an empty optional.
-static std::optional<absl::flat_hash_set<HloInstruction*>>
-GetPatternReplacements(const CustomKernelFusionPattern::Match& match) {
- absl::flat_hash_set<HloInstruction*> requires_replacement;
- absl::flat_hash_set<HloInstruction*> instructions_set(
- match.instructions().begin(), match.instructions().end());
-
- for (HloInstruction* instr : match.instructions()) {
- for (HloInstruction* user : instr->users()) {
- if (instr == match.root() || instructions_set.contains(user)) continue;
-
- if (match.HasReplacement(instr)) {
- requires_replacement.insert(instr);
- continue;
- }
-
- VLOG(3) << "Custom kernel fusion intermediate result " << instr->name()
- << " has users outside of a matched pattern: " << user->name();
- return std::nullopt;
- }
- }
-
- return requires_replacement;
-}
-
-// Returns the operands defined outside of the matched pattern; these become
-// the parameters (captures) of the custom kernel fusion computation.
-static absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
- const CustomKernelFusionPattern::Match& match) {
- absl::InlinedVector<HloInstruction*, 4> captures;
-
- absl::flat_hash_set<HloInstruction*> instructions_set(
- match.instructions().begin(), match.instructions().end());
-
- for (HloInstruction* instr : match.instructions()) {
- for (HloInstruction* operand : instr->operands()) {
- if (!instructions_set.contains(operand) &&
- absl::c_find(captures, operand) == captures.end()) {
- captures.emplace_back(operand);
- }
- }
- }
-
- return captures;
-}
-
-// Creates custom kernel fusion computation and moves all matched instructions
-// into it.
-static absl::StatusOr<HloComputation*> CreateFusionBody(
- HloModule* module, const CustomKernelFusionPattern::Match& match,
- absl::Span<HloInstruction* const> captures) {
- HloComputation::Builder builder(match.config().name());
-
- // A mapping from original instructions to instructions in the fusion body.
- absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
-
- auto mapped_operands = [&](HloInstruction* instr) {
- absl::InlinedVector<HloInstruction*, 4> operands;
- for (HloInstruction* operand : instr->operands()) {
- operands.push_back(instr_mapping.at(operand));
- }
- return operands;
- };
-
- // For every captured value create a parameter instruction in the computation
- // body and set up instruction mapping.
- for (const HloInstruction* capture : captures) {
- int64_t index = instr_mapping.size();
- instr_mapping[capture] =
- builder.AddInstruction(HloInstruction::CreateParameter(
- index, capture->shape(), absl::StrCat("p", index)));
- }
-
- // TODO(ezhulenev): Instructions in the pattern must be topologically sorted,
- // otherwise we'll get a crash! Figure out how to do it!
- for (HloInstruction* instr : match.instructions()) {
- instr_mapping[instr] = builder.AddInstruction(
- instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
- }
-
- HloInstruction* root = builder.last_added_instruction();
-
-  // If the custom kernel fusion requires a workspace, we add a custom call
-  // that allocates it and return a tuple of the "real" result and workspace.
- if (match.workspace_size_bytes() > 0) {
- auto workspace_shape =
- ShapeUtil::MakeShape(PrimitiveType::U8, {match.workspace_size_bytes()});
- HloInstruction* workspace =
- builder.AddInstruction(HloInstruction::CreateCustomCall(
- workspace_shape, {}, CustomKernelFusionPattern::kWorkspace, "",
- CustomCallApiVersion::API_VERSION_TYPED_FFI));
- builder.AddInstruction(HloInstruction::CreateTuple({root, workspace}));
- }
-
- return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
-}
-
-static absl::StatusOr<HloInstruction*> CreateFusionInstruction(
- HloModule* module, const CustomKernelFusionPattern::Match& match,
- absl::Span<HloInstruction* const> captures, HloComputation* body) {
- // We'll be replacing the root operation of a custom kernel fusion with a
- // fusion instruction calling fusion computation.
- HloInstruction* root = match.root();
- HloComputation* parent = root->parent();
-
- // Add a fusion operation calling outlined fusion computation.
- HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
- body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
- captures, body));
- module->SetAndUniquifyInstrName(fusion, match.config().name());
-
-  // Set the backend config to the matched custom fusion config.
- GpuBackendConfig gpu_config;
- FusionBackendConfig& backend_config =
- *gpu_config.mutable_fusion_backend_config();
- backend_config.set_kind("__custom_fusion");
- *backend_config.mutable_custom_fusion_config() = match.config();
- backend_config.mutable_custom_fusion_config()->set_kernel_index(0);
- TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
-
-  // If we don't have a workspace, return the constructed fusion instruction.
- if (match.workspace_size_bytes() == 0) return fusion;
-
-  // Otherwise, return the result element corresponding to the original value.
- return parent->AddInstruction(
- HloInstruction::CreateGetTupleElement(fusion, 0));
-}
-
-absl::StatusOr<bool> CustomKernelFusionRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- std::vector<CustomKernelFusionPattern::Match> matches;
-
- // Collect all potential custom fusion matches in the module.
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instr : computation->instructions()) {
- auto matched = patterns_->Match(*device_, instr);
- matches.insert(matches.end(), matched.begin(), matched.end());
- }
- }
-
- if (matches.empty()) return false;
-
- for (const CustomKernelFusionPattern::Match& match : matches) {
- VLOG(2) << "Matched custom kernel fusion " << match.config().name()
- << "; root instruction: " << match.instructions().back()->name();
-
-    auto replacements = GetPatternReplacements(match);
-    if (!replacements.has_value()) continue;
-
- auto captures = GetPatternCaptures(match);
-
- TF_ASSIGN_OR_RETURN(HloComputation * fusion_body,
- CreateFusionBody(module, match, captures));
- TF_ASSIGN_OR_RETURN(
- HloInstruction * fusion,
- CreateFusionInstruction(module, match, captures, fusion_body));
-
- VLOG(2) << "Added a fusion instruction: " << fusion->name()
- << " for custom kernel fusion " << match.config().name()
- << " (instruction count = " << match.instructions().size() << ")";
-
-    for (HloInstruction* instr : *replacements) {
- VLOG(2) << "Replace matched instruction: " << instr->name()
- << " with a pattern replacement";
-
- TF_ASSIGN_OR_RETURN(
- HloInstruction * replacement,
- match.BuildReplacement(instr, Cast<HloFusionInstruction>(fusion)));
-
- TF_RETURN_IF_ERROR(
- instr->ReplaceAllUsesWith(replacement, match.config().name()));
-
- VLOG(2) << "Replaced instruction: " << instr->name()
- << " with: " << replacement->name();
- }
-
- VLOG(2) << "Replace custom kernel fusion root instruction "
-            << match.root()->name() << " with " << fusion->name();
- HloComputation* parent = match.root()->parent();
- TF_RETURN_IF_ERROR(parent->ReplaceInstruction(match.root(), fusion));
- }
-
- return true;
-}
-
-} // namespace xla::gpu
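For reference, here is a minimal standalone sketch of the capture-collection step performed by GetPatternCaptures above, written with plain standard-library containers instead of the XLA HLO types (the Instr struct and GetCaptures helper are hypothetical, for illustration only): operands produced outside the matched instruction set are collected in first-use order, deduplicated, and later become the fusion parameters p0, p1, ...

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Instr {
  std::string name;
  std::vector<const Instr*> operands;
};

// Mirrors the capture rule: operands not produced inside the matched set,
// kept unique and in the order they are first seen.
std::vector<const Instr*> GetCaptures(const std::vector<const Instr*>& matched) {
  std::set<const Instr*> in_pattern(matched.begin(), matched.end());
  std::vector<const Instr*> captures;
  for (const Instr* instr : matched) {
    for (const Instr* operand : instr->operands) {
      if (in_pattern.count(operand) == 0 &&
          std::find(captures.begin(), captures.end(), operand) == captures.end()) {
        captures.push_back(operand);
      }
    }
  }
  return captures;
}

int main() {
  Instr p0{"p0", {}}, p1{"p1", {}};
  Instr dot{"dot", {&p0, &p1}};
  for (const Instr* c : GetCaptures({&dot})) {
    std::cout << c->name << "\n";  // prints p0 then p1 -> fusion parameters p0, p1
  }
}
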
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h b/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h
deleted file mode 100644
index cb19d91..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_
-#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla::gpu {
-
-// Pattern-matches HLO instructions to custom kernel fusions (hand-written CUDA
-// C++ kernels, e.g. custom GEMMs implemented with CUTLASS) and rewrites them
-// into fusion instructions and fusion computations.
-//
-// Example: pattern matching dot operation into CUTLASS gemm
-//
-// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-// %p0 = f16[15,19]{1,0} parameter(0)
-// %p1 = f16[19,17]{1,0} parameter(1)
-// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
-// lhs_contracting_dims={1}, rhs_contracting_dims={0}
-// }
-//
-// After the pass:
-//
-//   %cutlass_gemm (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-// %p0 = f16[15,19]{1,0} parameter(0)
-// %p1 = f16[19,17]{1,0} parameter(1)
-// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
-// lhs_contracting_dims={1}, rhs_contracting_dims={0}
-// }
-//
-// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-// %p0 = f16[15,19]{1,0} parameter(0)
-// %p1 = f16[19,17]{1,0} parameter(1)
-// ROOT %r = f16[15,17]{1,0} fusion(%p0, %p1), kind=kCustom,
-//                               calls=%cutlass_gemm,
-// backend_config={kind: "__custom_fusion",
-// custom_fusion_config: {"name":"cutlass_gemm"}}
-// }
-//
-class CustomKernelFusionRewriter : public HloModulePass {
- public:
- explicit CustomKernelFusionRewriter(
- const se::DeviceDescription* device,
- const CustomKernelFusionPatternRegistry* patterns =
- CustomKernelFusionPatternRegistry::Default());
-
- absl::string_view name() const override {
- return "custom-kernel-fusion-rewriter";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const se::DeviceDescription* device_;
- const CustomKernelFusionPatternRegistry* patterns_;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc
deleted file mode 100644
index f2c824c..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla::gpu {
-
-//===----------------------------------------------------------------------===//
-// Simple pattern matchers for testing the custom kernel fusion rewriter.
-//===----------------------------------------------------------------------===//
-
-struct SimpleGemmPattern : public CustomKernelFusionPattern {
- explicit SimpleGemmPattern(int64_t workspace = 0) : workspace(workspace) {}
-
- std::optional<Match> TryMatch(const se::DeviceDescription& device,
- HloInstruction* instr) const override {
- if (auto* dot = DynCast<HloDotInstruction>(instr)) {
- CustomFusionConfig config;
- config.set_name("simple_gemm");
- return Match{config, {instr}, workspace};
- }
- return std::nullopt;
- }
-
- int64_t workspace;
-};
-
-//===----------------------------------------------------------------------===//
-
-class CustomKernelFusionRewriterTest : public HloTestBase {};
-
-TEST_F(CustomKernelFusionRewriterTest, SimpleGemm) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
- %p0 = f16[15,19]{1,0} parameter(0)
- %p1 = f16[19,17]{1,0} parameter(1)
- ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %simple_gemm {{.*}} {
- ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
- ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
- ; CHECK: ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
- ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ; CHECK: }
-
- ; CHECK: ENTRY %main {{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[15,17]{1,0} fusion
- ; CHECK: kind=kCustom, calls=%simple_gemm,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- CustomKernelFusionPatternRegistry patterns;
- patterns.Emplace<SimpleGemmPattern>();
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- CustomKernelFusionRewriter pass(&device, &patterns);
- RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
-}
-
-TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
- %p0 = f16[15,19]{1,0} parameter(0)
- %p1 = f16[19,17]{1,0} parameter(1)
- ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %simple_gemm {{.*}} {
- ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
- ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
- ; CHECK: [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
- ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ; CHECK: [[WORKSPACE:%[^ ]+]] = u8[1024]{0} custom-call(),
- ; CHECK: custom_call_target="__custom_kernel_fusion$workspace"
- ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0})
- ; CHECK: tuple([[DOT]], [[WORKSPACE]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main {{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0}) fusion
- ; CHECK: kind=kCustom, calls=%simple_gemm,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT {{.*}} get-tuple-element([[FUSION]]), index=0
- ; CHECK: }
- )";
-
- CustomKernelFusionPatternRegistry patterns;
- patterns.Emplace<SimpleGemmPattern>(1024);
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- CustomKernelFusionRewriter pass(&device, &patterns);
- RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc
index 93c5b15..45b9704 100644
--- a/third_party/xla/xla/service/gpu/determinism_test.cc
+++ b/third_party/xla/xla/service/gpu/determinism_test.cc
@@ -22,7 +22,7 @@
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_timer.h"
diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc b/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc
deleted file mode 100644
index 38920ee..0000000
--- a/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc
+++ /dev/null
@@ -1,136 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_dimension_sorter.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/permutation_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Sort contracting dimensions of a dot() instruction preserving lhs-rhs pairs.
-absl::Status SortDotDimensions(HloDotInstruction* dot) {
- const DotDimensionNumbers& dims = dot->dot_dimension_numbers();
- DotDimensionNumbers new_dims(dims);
- new_dims.clear_lhs_contracting_dimensions();
- new_dims.clear_rhs_contracting_dimensions();
- const bool sort_by_lhs =
- DistinctNumbersAreConsecutiveIfSorted(dims.lhs_contracting_dimensions());
- // Sort lhs and rhs by sort_key using the fact that
- // sort_key is guaranteed to have only distinct consecutive numbers.
- const absl::Span<const int64_t>& sort_key =
- sort_by_lhs ? dims.lhs_contracting_dimensions()
- : dims.rhs_contracting_dimensions();
- std::vector<int64_t> permutation;
- for (const int64_t a : sort_key) {
- permutation.push_back(a - *absl::c_min_element(sort_key));
- }
- const std::vector<int64_t> sorted_lhs =
- Permute(dims.lhs_contracting_dimensions(), permutation);
- *new_dims.mutable_lhs_contracting_dimensions() = {sorted_lhs.begin(),
- sorted_lhs.end()};
- const std::vector<int64_t> sorted_rhs =
- Permute(dims.rhs_contracting_dimensions(), permutation);
- *new_dims.mutable_rhs_contracting_dimensions() = {sorted_rhs.begin(),
- sorted_rhs.end()};
- std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
- dot->shape(), dot->mutable_operand(0), dot->mutable_operand(1), new_dims,
- dot->precision_config(), {dot->sparsity().begin(), dot->sparsity().end()},
- absl::MakeSpan(dot->operands()).subspan(HloDotInstruction::kOperands));
- dot->SetupDerivedInstruction(new_dot.get());
-
- VLOG(3) << "Sorted dot() dimensions:\n"
- << "\t before: " << dot->ToString() << "\n"
- << "\t after: " << new_dot->ToString();
- return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
-}
-
-} // namespace
-
-absl::StatusOr<bool> DotDimensionSorter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- std::vector<HloInstruction*> dots_to_process;
- for (const HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* instr : computation->instructions()) {
- if (instr->opcode() != HloOpcode::kDot) {
- continue;
- }
- // TODO(b/265688934): should non-default layouts be expected here at all?
- if ((instr->operand(0)->shape().has_layout() &&
- !LayoutUtil::IsMonotonicWithDim0Major(
- instr->operand(0)->shape().layout())) ||
- (instr->operand(1)->shape().has_layout() &&
- !LayoutUtil::IsMonotonicWithDim0Major(
- instr->operand(1)->shape().layout()))) {
- continue;
- }
- const DotDimensionNumbers& dims = instr->dot_dimension_numbers();
- if (dims.lhs_contracting_dimensions_size() == 0) {
- continue;
- }
- const bool cons_lhs = DistinctNumbersAreConsecutiveIfSorted(
- dims.lhs_contracting_dimensions());
- const bool cons_rhs = DistinctNumbersAreConsecutiveIfSorted(
- dims.rhs_contracting_dimensions());
- const bool sorted_lhs =
- absl::c_is_sorted(dims.lhs_contracting_dimensions());
- const bool sorted_rhs =
- absl::c_is_sorted(dims.rhs_contracting_dimensions());
- // The side to be sorted has to be consecutive and not sorted yet;
- // the other side should not get worsened.
- // TODO(b/265688934): we may still want to change which one is sorted
- // if this reduces the amount of transposed data.
- if ((cons_lhs && !sorted_lhs && !cons_rhs) ||
- (cons_rhs && !sorted_rhs && !cons_lhs) ||
- (cons_lhs && !sorted_lhs && cons_rhs && !sorted_rhs)) {
- dots_to_process.push_back(instr);
- }
- }
- }
- if (dots_to_process.empty()) {
- return false;
- }
- for (HloInstruction* dot : dots_to_process) {
- TF_RETURN_IF_ERROR(SortDotDimensions(Cast<HloDotInstruction>(dot)));
- }
- return true;
-}
-
-} // namespace gpu
-} // namespace xla
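As a sanity check of the permutation arithmetic in SortDotDimensions above, here is a standalone sketch in plain C++ (it does not call the real xla::Permute helper) using the {3,2} / {2,1} contracting dimensions from the deleted tests; because the resulting permutation {1,0} is its own inverse, either permutation convention produces the same sorted dimensions.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> lhs = {3, 2};  // consecutive when sorted, but not sorted
  std::vector<int64_t> rhs = {2, 1};

  // permutation[i] = sort_key[i] - min(sort_key); the sort key here is the lhs.
  int64_t min_key = *std::min_element(lhs.begin(), lhs.end());
  std::vector<int64_t> permutation;
  for (int64_t a : lhs) permutation.push_back(a - min_key);  // {1, 0}

  // Apply the same permutation to both sides so lhs/rhs pairs stay aligned.
  auto permute = [&](const std::vector<int64_t>& dims) {
    std::vector<int64_t> out(dims.size());
    for (std::size_t i = 0; i < dims.size(); ++i) out[permutation[i]] = dims[i];
    return out;
  };
  for (int64_t d : permute(lhs)) std::cout << d << " ";  // 2 3
  std::cout << "\n";
  for (int64_t d : permute(rhs)) std::cout << d << " ";  // 1 2
  std::cout << "\n";
}
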
diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter.h b/third_party/xla/xla/service/gpu/dot_dimension_sorter.h
deleted file mode 100644
index 5eadeb1..0000000
--- a/third_party/xla/xla/service/gpu/dot_dimension_sorter.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_
-#define XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Sorts contracting dimensions of dot() operands when this reduces the
-// number of transposes. Example:
-// dot(p0, p1), lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1} ->
-// dot(p0, p1), lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-// The first case gets transposes inserted by dot_decomposer; the second one
-// does not and is thus generally more efficient.
-
-// TODO(b/265688934): do the same for batch dimensions?
-
-class DotDimensionSorter : public HloModulePass {
- public:
- absl::string_view name() const override { return "dot_dimension_sorter"; }
-
- // Run the pass on computations in 'module'.
- // Returns whether the 'module' was changed.
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_
diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc b/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc
deleted file mode 100644
index fedd1ea..0000000
--- a/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc
+++ /dev/null
@@ -1,191 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_dimension_sorter.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class WithoutDotDimensionSorterTest : public GpuCodegenTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
- // The pass is disabled here to preserve suboptimal dimension order in
- // 1) UnsortedDimsCreateTransposes to reveal the transposes.
- // 2) DimOrderCanBeChanged for the comparison of ordered vs unordered.
- // The pass does not touch SortedDimsDoNotCreateTransposes anyway because
- // the dimensions are already ordered there.
- debug_options.add_xla_disable_hlo_passes("dot_dimension_sorter");
- return debug_options;
- }
-};
-
-TEST_F(WithoutDotDimensionSorterTest, UnsortedDimsCreateTransposes) {
- const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[1,14,9,32] parameter(0)
- p1 = f16[12,9,32] parameter(1)
- ROOT _ = f16[1,14,12] dot(p0, p1),
- lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK: transpose
-)");
-}
-
-TEST_F(WithoutDotDimensionSorterTest, SortedDimsDoNotCreateTransposes) {
- const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[1,14,9,32] parameter(0)
- p1 = f16[12,9,32] parameter(1)
- ROOT _ = f16[1,14,12] dot(p0, p1),
- lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-}
-)";
-
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT: transpose
-)");
-}
-
-TEST_F(WithoutDotDimensionSorterTest, DimOrderCanBeChanged) {
- const char* hlo_text_ref = R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[1,14,9,32] parameter(0)
- p1 = f16[12,9,32] parameter(1)
- ROOT _ = f16[1,14,12] dot(p0, p1),
- lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
- const char* hlo_text_modified = R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[1,14,9,32] parameter(0)
- p1 = f16[12,9,32] parameter(1)
- ROOT _ = f16[1,14,12] dot(p0, p1),
- lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-}
-)";
-
- EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_modified,
- ErrorSpec{1e-5, 1e-3},
- /*run_hlo_passes=*/true));
-}
-
-using DotDimensionSorterTest = GpuCodegenTest;
-
-TEST_F(DotDimensionSorterTest, SortContractingDims) {
- const char* module_string = R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[1,144,96,32] parameter(0)
- p1 = f16[122,96,32] parameter(1)
- ROOT _ = f16[1,144,122] dot(p0, p1),
- lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(module_string));
- const auto& dims =
- module->entry_computation()->root_instruction()->dot_dimension_numbers();
-
- EXPECT_EQ(dims.lhs_contracting_dimensions(0), 3);
- EXPECT_EQ(dims.lhs_contracting_dimensions(1), 2);
-
- EXPECT_EQ(dims.rhs_contracting_dimensions(0), 2);
- EXPECT_EQ(dims.rhs_contracting_dimensions(1), 1);
-
- TF_ASSERT_OK_AND_ASSIGN(bool modified,
- DotDimensionSorter().Run(module.get()));
- EXPECT_TRUE(modified);
- const auto& dims2 =
- module->entry_computation()->root_instruction()->dot_dimension_numbers();
-
- EXPECT_EQ(dims2.lhs_contracting_dimensions(0), 2);
- EXPECT_EQ(dims2.lhs_contracting_dimensions(1), 3);
-
- EXPECT_EQ(dims2.rhs_contracting_dimensions(0), 1);
- EXPECT_EQ(dims2.rhs_contracting_dimensions(1), 2);
-}
-
-TEST_F(DotDimensionSorterTest, NothingToReorder) {
- const char* module_string = R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[1,144,96,32] parameter(0)
- p1 = f16[122,96,32] parameter(1)
- ROOT _ = f16[1,144,122] dot(p0, p1),
- lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(module_string));
-
- TF_ASSERT_OK_AND_ASSIGN(bool modified,
- DotDimensionSorter().Run(module.get()));
- EXPECT_FALSE(modified);
-}
-
-TEST_F(DotDimensionSorterTest, SparseDotSortContractingDims) {
- const char* module_string = R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[1,144,96,16] parameter(0)
- p1 = f16[122,96,32] parameter(1)
- meta = u16[1,144,96,2] parameter(2)
- ROOT _ = f16[1,144,122] dot(p0, p1, meta), sparsity=L.3@2:4,
- lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(module_string));
- TF_ASSERT_OK_AND_ASSIGN(bool modified,
- DotDimensionSorter().Run(module.get()));
- EXPECT_TRUE(modified);
- HloDotInstruction* dot = DynCast<HloDotInstruction>(
- module->entry_computation()->root_instruction());
- EXPECT_TRUE(dot != nullptr && dot->sparse_operands() == 1);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/dot_operand_converter.cc
deleted file mode 100644
index 2a298e6..0000000
--- a/third_party/xla/xla/service/gpu/dot_operand_converter.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_operand_converter.h"
-
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/errors.h"
-
-namespace xla::gpu {
-
-bool DotOperandConverter::InstructionMatchesPattern(
- HloInstruction* instruction) {
- if (instruction->opcode() != HloOpcode::kDot) {
- return false;
- }
- HloInstruction* lhs = instruction->mutable_operand(0);
- HloInstruction* rhs = instruction->mutable_operand(1);
-
- PrimitiveType lhs_type = lhs->shape().element_type();
- PrimitiveType rhs_type = rhs->shape().element_type();
-
- if (lhs_type == rhs_type) {
- return false;
- }
-
- // Exclude conversions between FP8 types.
- absl::flat_hash_set<PrimitiveType> non_converting = {F8E4M3FN, F8E5M2};
- if (non_converting.contains(lhs_type) && non_converting.contains(rhs_type)) {
- return false;
- }
-
- PrimitiveType desired_type =
- ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
-
- return desired_type == lhs_type || desired_type == rhs_type;
-}
-
-absl::StatusOr<HloInstruction*> DotOperandConverter::ExpandInstruction(
- HloInstruction* instruction) {
- HloInstruction* lhs = instruction->mutable_operand(0);
- HloInstruction* rhs = instruction->mutable_operand(1);
-
- // Find the higher precision type among the two operands, and add a convert
-  // instruction to convert the lower-precision operand to that type.
- PrimitiveType desired_type =
- ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
- int operand_index = desired_type == lhs->shape().element_type() ? 1 : 0;
- HloInstruction* inst_to_replace =
- desired_type == lhs->shape().element_type() ? rhs : lhs;
- auto upcast_shape = inst_to_replace->shape();
- upcast_shape.set_element_type(desired_type);
- auto* convert_inst = instruction->AddInstruction(
- HloInstruction::CreateConvert(upcast_shape, inst_to_replace));
- TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape(
- operand_index, convert_inst));
- return nullptr;
-}
-
-} // namespace xla::gpu
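A tiny standalone sketch of the operand-selection rule above, using an ad-hoc precision ranking in place of ShapeUtil::HigherPrecisionElementType (the Rank and OperandToConvert helpers are hypothetical and only cover the types exercised by the deleted tests):

#include <iostream>
#include <map>
#include <string>

// Hypothetical precision ranking standing in for the real element-type rules.
int Rank(const std::string& type) {
  static const std::map<std::string, int> kRank = {{"s8", 0}, {"bf16", 1}, {"f32", 2}};
  return kRank.at(type);
}

// Returns the index of the operand that gets wrapped in a convert.
int OperandToConvert(const std::string& lhs, const std::string& rhs) {
  const std::string& desired = Rank(lhs) >= Rank(rhs) ? lhs : rhs;
  return desired == lhs ? 1 : 0;
}

int main() {
  std::cout << OperandToConvert("s8", "bf16") << "\n";  // 0: upcast the s8 lhs
  std::cout << OperandToConvert("bf16", "s8") << "\n";  // 1: upcast the s8 rhs
}
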
diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.h b/third_party/xla/xla/service/gpu/dot_operand_converter.h
deleted file mode 100644
index d277a24..0000000
--- a/third_party/xla/xla/service/gpu/dot_operand_converter.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_
-#define XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_
-
-#include <utility>
-
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/op_expander_pass.h"
-#include "xla/util.h"
-
-namespace xla::gpu {
-
-// Converts the lower-precision dot() operand to the higher-precision type.
-class DotOperandConverter : public OpExpanderPass {
- public:
- explicit DotOperandConverter(HloPredicate extra_filter = nullptr)
- : OpExpanderPass(std::move(extra_filter)) {}
-
- absl::string_view name() const override { return "operand_converter"; }
-
- protected:
- bool InstructionMatchesPattern(HloInstruction* instruction) override;
-
- absl::StatusOr<HloInstruction*> ExpandInstruction(
- HloInstruction* instruction) override;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_
diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc b/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc
deleted file mode 100644
index 63b0017..0000000
--- a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_operand_converter.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/utils/hlo_matchers.h"
-#include "xla/primitive_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace op = ::xla::testing::opcode_matchers;
-
-class DotOperandConverterTest : public HloTestBase {
- public:
- void TestConvert(bool left_less_precise, PrimitiveType lhs_type,
- PrimitiveType rhs_type, PrimitiveType result_type) {
- absl::string_view module_tmpl = R"(
- HloModule module
-
- ENTRY main {
- p0 = $0[2,3]{1,0} parameter(0)
- p1 = $1[3,2]{1,0} parameter(1)
- ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
- rhs_contracting_dims={0}
- })";
- auto module_string = absl::Substitute(
- module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type),
- primitive_util::LowercasePrimitiveTypeName(rhs_type),
- primitive_util::LowercasePrimitiveTypeName(result_type));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(module_string));
- TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
- DotOperandConverter().Run(module.get()));
- EXPECT_TRUE(upcasted);
- if (left_less_precise) {
- auto original_lhs = op::Parameter(0);
- auto upcasted_lhs =
- AllOf(op::Convert(original_lhs),
- op::Shape(absl::Substitute(
- "$0[2,3]{1,0}",
- primitive_util::LowercasePrimitiveTypeName(rhs_type))));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- AllOf(op::Dot(upcasted_lhs, op::Parameter(1)),
- op::Shape(absl::Substitute(
- "$0[2,2]{1,0}",
- primitive_util::LowercasePrimitiveTypeName(result_type)))));
- } else {
- auto original_rhs = op::Parameter(1);
- auto upcasted_rhs =
- AllOf(op::Convert(original_rhs),
- op::Shape(absl::Substitute(
- "$0[3,2]{1,0}",
- primitive_util::LowercasePrimitiveTypeName(lhs_type))));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- AllOf(op::Dot(op::Parameter(0), upcasted_rhs),
- op::Shape(absl::Substitute(
- "$0[2,2]{1,0}",
- primitive_util::LowercasePrimitiveTypeName(result_type)))));
- }
- }
-};
-
-TEST_F(DotOperandConverterTest, ConvertsLeftAndRight) {
- TestConvert(/*left_less_precise=*/true, S8, BF16, F32);
- TestConvert(/*left_less_precise=*/false, BF16, S8, F32);
-}
-
-TEST_F(DotOperandConverterTest, NoConvertHappensWithSameTypes) {
- absl::string_view module_string = R"(
- HloModule module
-
- ENTRY main {
- p0 = s8[2,3]{1,0} parameter(0)
- p1 = s8[3,2]{1,0} parameter(1)
- ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
- rhs_contracting_dims={0}
- })";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(module_string));
- TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
- DotOperandConverter().Run(module.get()));
- EXPECT_FALSE(upcasted);
-}
-
-TEST_F(DotOperandConverterTest, NoConvertFromF8toF8) {
- absl::string_view module_string = R"(
- HloModule module
-
- ENTRY main {
- p0 = f8e4m3fn[2,3]{1,0} parameter(0)
- p1 = f8e5m2[3,2]{1,0} parameter(1)
- ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
- rhs_contracting_dims={0}
- })";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(module_string));
- TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
- DotOperandConverter().Run(module.get()));
- EXPECT_FALSE(upcasted);
-}
-
-TEST_F(DotOperandConverterTest, CompilerOptimizesUsingDotOperandConverter) {
- absl::string_view module_string = R"(
- HloModule module
-
- ENTRY main {
- p0 = s8[2,3]{1,0} parameter(0)
- p1 = bf16[3,2]{1,0} parameter(1)
- ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
- rhs_contracting_dims={0}
- })";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- GetOptimizedModule(module_string));
-}
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc b/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc
deleted file mode 100644
index 0f41091..0000000
--- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc
+++ /dev/null
@@ -1,110 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_sparsity_rewriter.h"
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class SparseDotRewriterImpl : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleDot(HloInstruction* instr) override {
- // Only handle sparse dots with a single RHS sparse descriptor.
- HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
- if (dot->sparse_operands() != 1 || dot->sparsity().front().index() != 1) {
- return absl::OkStatus();
- }
-
- HloInstruction* lhs = dot->mutable_operand(0);
- HloInstruction* rhs = dot->mutable_operand(1);
- HloInstruction* meta = dot->mutable_operand(2);
-
- // Swap LHS and RHS in the attributes.
- DotDimensionNumbers dnums = dot->dot_dimension_numbers();
- std::swap(*dnums.mutable_lhs_batch_dimensions(),
- *dnums.mutable_rhs_batch_dimensions());
- std::swap(*dnums.mutable_lhs_contracting_dimensions(),
- *dnums.mutable_rhs_contracting_dimensions());
-
- PrecisionConfig precision_config = dot->precision_config();
- std::swap(precision_config.mutable_operand_precision()->at(0),
- precision_config.mutable_operand_precision()->at(1));
-
- SparsityDescriptor sparsity = dot->sparsity().front();
- sparsity.set_index(0);
-
- // Create new dot with LHS and RHS swapped.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_dot,
- MakeDotHlo(rhs, lhs, dnums, precision_config,
- dot->shape().element_type(), {std::move(sparsity)}, {meta}));
- dot->SetupDerivedInstruction(new_dot);
-
- // Result dimensions: <batch>, <rhs_noncontracting>, <lhs_noncontracting>
- int batch_dims = dnums.lhs_batch_dimensions().size();
- int new_lhs_noncontracting = rhs->shape().rank() - batch_dims -
- dnums.lhs_contracting_dimensions().size();
- int new_rhs_noncontracting = lhs->shape().rank() - batch_dims -
- dnums.rhs_contracting_dimensions().size();
-
- int rank = dot->shape().rank();
- DimensionVector dimensions(rank);
- for (int i = 0; i < batch_dims; ++i) {
- dimensions[i] = i;
- }
- for (int i = 0; i < new_lhs_noncontracting; ++i) {
- dimensions[i + batch_dims] = i + batch_dims + new_rhs_noncontracting;
- }
- for (int i = 0; i < new_rhs_noncontracting; ++i) {
- dimensions[i + batch_dims + new_lhs_noncontracting] = i + batch_dims;
- }
-
- // Transpose the result.
- TF_ASSIGN_OR_RETURN(HloInstruction * transpose,
- MakeTransposeHlo(new_dot, dimensions));
- transpose->set_metadata(dot->metadata());
- *transpose->mutable_shape()->mutable_layout() = dot->shape().layout();
-
- return ReplaceInstruction(dot, transpose);
- }
-};
-
-} // namespace
-
-absl::StatusOr<bool> DotSparsityRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- return SparseDotRewriterImpl().RunOnModule(module, execution_threads);
-}
-
-} // namespace gpu
-} // namespace xla
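A standalone arithmetic sketch of the transpose permutation computed above (plain C++, not the XLA API; the TransposeDims helper is hypothetical), plugging in the shapes from the deleted unit test: two batch dimensions and, after the LHS/RHS swap, one noncontracting dimension from each side, which reproduces the expected permutation {0, 1, 3, 2}.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> TransposeDims(int batch_dims, int new_lhs_noncontracting,
                                   int new_rhs_noncontracting) {
  int rank = batch_dims + new_lhs_noncontracting + new_rhs_noncontracting;
  std::vector<int64_t> dimensions(rank);
  for (int i = 0; i < batch_dims; ++i) dimensions[i] = i;
  for (int i = 0; i < new_lhs_noncontracting; ++i)
    dimensions[i + batch_dims] = i + batch_dims + new_rhs_noncontracting;
  for (int i = 0; i < new_rhs_noncontracting; ++i)
    dimensions[i + batch_dims + new_lhs_noncontracting] = i + batch_dims;
  return dimensions;
}

int main() {
  // lhs = f16[4,2,16,8,64], rhs = f16[2,4,8,32,128]: 2 batch dims and, after
  // the swap, 1 noncontracting dim from each side -> result rank 4.
  for (int64_t d : TransposeDims(/*batch_dims=*/2, /*new_lhs_noncontracting=*/1,
                                 /*new_rhs_noncontracting=*/1)) {
    std::cout << d << " ";  // prints: 0 1 3 2
  }
  std::cout << "\n";
}
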
diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h b/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h
deleted file mode 100644
index b422197..0000000
--- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_
-#define XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Make sure sparse dot requirements are met (sparse operand is LHS).
-class DotSparsityRewriter : public HloModulePass {
- public:
- absl::string_view name() const override { return "dot_sparsity_rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc b/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc
deleted file mode 100644
index c608f8d..0000000
--- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_sparsity_rewriter.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-
-class DotSparsityRewriterTest : public HloTestBase {
- public:
- DotSparsityRewriterTest() : HloTestBase(/*verifier_layout_sensitive=*/true) {}
-};
-
-TEST_F(DotSparsityRewriterTest, SparseDotRhsToLhs) {
- const char* module_string = R"(
-HloModule m
-
-ENTRY e {
- lhs = f16[4,2,16,8,64] parameter(0)
- rhs = f16[2,4,8,32,128] parameter(1)
- meta = u16[2,4,8,4,128] parameter(2)
- ROOT dot = f16[4,2,16,128] dot(lhs, rhs, meta),
- lhs_contracting_dims={3,4}, rhs_contracting_dims={2,3},
- lhs_batch_dims={0,1}, rhs_batch_dims={1,0}, sparsity=R.3@2:4
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(module_string));
- TF_ASSERT_OK_AND_ASSIGN(bool modified,
- DotSparsityRewriter().Run(module.get()));
- EXPECT_TRUE(modified);
-
- const HloTransposeInstruction* transpose = DynCast<HloTransposeInstruction>(
- module->entry_computation()->root_instruction());
- ASSERT_TRUE(transpose != nullptr);
- EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 1, 3, 2));
-
- const HloDotInstruction* dot =
- DynCast<HloDotInstruction>(transpose->operand(0));
- ASSERT_TRUE(dot != nullptr);
-
- const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
- EXPECT_EQ(dnums.lhs_contracting_dimensions(0), 2);
- EXPECT_EQ(dnums.lhs_contracting_dimensions(1), 3);
- EXPECT_EQ(dnums.rhs_contracting_dimensions(0), 3);
- EXPECT_EQ(dnums.rhs_contracting_dimensions(1), 4);
- EXPECT_EQ(dnums.lhs_batch_dimensions(0), 1);
- EXPECT_EQ(dnums.lhs_batch_dimensions(1), 0);
- EXPECT_EQ(dnums.rhs_batch_dimensions(0), 0);
- EXPECT_EQ(dnums.rhs_batch_dimensions(1), 1);
-
- EXPECT_EQ(dot->sparse_operands(), 1);
- EXPECT_EQ(dot->sparsity().front().index(), 0);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc
deleted file mode 100644
index 9cd9113..0000000
--- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc
+++ /dev/null
@@ -1,569 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/double_buffer_loop_unrolling.h"
-
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instruction_utils.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/flatten_call_graph.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-void SetChannelIdForNewCollective(HloInstruction* new_instr,
- const HloModule* module) {
-  // These maps track old->new channel ids for async collectives wrapped in
-  // the form of HloAsyncInstruction; the start and done need to have the
-  // same unique channel id.
- absl::flat_hash_map<int64_t, int64_t> old_to_new_channel_id_map;
- absl::flat_hash_map<int64_t, HloComputation*> channel_id_comp_map;
- if (new_instr->IsAsynchronous() && hlo_query::IsCollectiveCommunicationOp(
- new_instr->async_wrapped_opcode())) {
- HloInstruction* wrapped_instr =
- DynCast<HloAsyncInstruction>(new_instr)->async_wrapped_instruction();
- int64_t old_channel_id = *wrapped_instr->channel_id();
- int64_t new_channel_id = old_to_new_channel_id_map[old_channel_id];
- if (old_to_new_channel_id_map.find(old_channel_id) ==
- old_to_new_channel_id_map.end()) {
- new_channel_id = hlo_query::NextChannelId(*module);
- VLOG(2) << "Generated new channel id " << new_channel_id;
- old_to_new_channel_id_map[old_channel_id] = new_channel_id;
- }
-
- VLOG(2) << "Setting channel id to " << new_channel_id;
-
- wrapped_instr->set_channel_id(new_channel_id);
- if (channel_id_comp_map.find(new_channel_id) == channel_id_comp_map.end()) {
- channel_id_comp_map[new_channel_id] =
- new_instr->async_wrapped_computation();
- } else {
- channel_id_comp_map[new_channel_id]->AddAsyncStart(new_instr);
- }
- } else if (hlo_query::IsCollectiveCommunicationOp(new_instr->opcode()) ||
- hlo_query::IsAsyncCollectiveStartOp(new_instr)) {
- new_instr->set_channel_id(hlo_query::NextChannelId(*module));
- }
-}
-
-using Interval = std::pair<int64_t, int64_t>;
-
-// Parses a string of the format `{{a,b},{c,d},{e,f}...}` to a vector of pairs.
-absl::StatusOr<std::vector<Interval>> ParseVectorOfPairs(
- absl::string_view str) {
- TF_ASSIGN_OR_RETURN(std::vector<ReplicaGroup> replica_groups,
- ParseReplicaGroupsOnly(str));
- std::vector<Interval> res;
- res.reserve(replica_groups.size());
- for (const ReplicaGroup& replica_group : replica_groups) {
- TF_RET_CHECK(replica_group.replica_ids_size() == 2);
- int64_t a = replica_group.replica_ids(0);
- int64_t b = replica_group.replica_ids(1);
- res.emplace_back(a, b);
- }
- return res;
-}
-
-// This function fixes the `_xla_send_recv_validation` attribute for peeled
-// instructions. When the loop trip count is odd, the peeled instructions are
-// moved before the loop. The collectives in these instructions correspond to
-// the first iteration of the original loop. We have to run this peeled
-// collective for all those devices that had the 0-th iteration as a valid
-// iteration.
-absl::Status SetSendRecvValidationForPeeledInstr(HloInstruction* new_instr,
- HloInstruction* old_instr) {
- TF_RET_CHECK(
- new_instr->opcode() == old_instr->opcode() &&
- "cloned instruction and original instruction have different opcodes");
- if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
- HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
- HloOpcode::kRecv>(old_instr)) {
- return absl::OkStatus();
- }
-
- const auto& attribute_map = new_instr->frontend_attributes().map();
- if (!attribute_map.contains(kSendRecvValidationAttr)) {
- return absl::OkStatus();
- }
-
- VLOG(3) << "Original send-recv iterations: "
- << attribute_map.at(kSendRecvValidationAttr);
-
- TF_ASSIGN_OR_RETURN(
- auto send_recv_validation_attr,
- ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
-
- uint64_t n_pairs = send_recv_validation_attr.size();
- if (n_pairs == 0) {
- return absl::OkStatus();
- }
- std::vector<Interval> send_recv_validation_attr_updated(n_pairs, {1, 0});
- // Check which of the attributes have iteration number zero as valid
- // iteration. For all those, set the peeled instruction to run.
- for (std::uint64_t i = 0; i < send_recv_validation_attr.size(); i++) {
- if (send_recv_validation_attr[i].first <= 0 &&
- send_recv_validation_attr[i].second >= 0) {
- send_recv_validation_attr_updated[i] = {0, 0};
- }
- }
-
- hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
- /*instr=*/new_instr, /*attr_name=*/kSendRecvValidationAttr,
- /*intervals=*/send_recv_validation_attr_updated);
- return absl::OkStatus();
-}
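// Standalone sketch (plain C++, not the XLA API) of the rule implemented just
// above: a device whose original _xla_send_recv_validation interval contains
// iteration 0 runs the peeled copy exactly once ({0,0}); all other devices get
// the empty interval {1,0}. The PeeledIntervals helper below is hypothetical,
// for illustration only.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using Interval = std::pair<int64_t, int64_t>;

std::vector<Interval> PeeledIntervals(const std::vector<Interval>& original) {
  std::vector<Interval> updated(original.size(), {1, 0});  // {1,0} == never run
  for (std::size_t i = 0; i < original.size(); ++i) {
    if (original[i].first <= 0 && original[i].second >= 0) {
      updated[i] = {0, 0};  // iteration 0 was valid -> run the peeled copy once
    }
  }
  return updated;
}

int main() {
  for (auto [lo, hi] : PeeledIntervals({{0, 3}, {2, 5}})) {
    std::cout << "{" << lo << "," << hi << "} ";  // prints: {0,0} {1,0}
  }
  std::cout << "\n";
}
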
-
-// This function fixes the `_xla_send_recv_validation` attribute for the two new
-// collectives inside the loop. The calculation of the new valid iterations
-// depends on whether the loop was peeled or not.
-//
-// If the loop was not peeled, then
-// - iteration 0 of the new loop corresponds to iteration 0,1 of the old loop.
-// - iteration 1 of the new loop corresponds to iteration 2,3 of the old loop.
-// - and so on...
-// If the loop was peeled, then the first iteration runs before the loop. So,
-// - iteration 0 of the new loop corresponds to iteration 1,2 of the old loop.
-// - iteration 1 of the new loop corresponds to iteration 3,4 of the old loop.
-// - and so on...
-//
-// Consider the case when the loop was peeled, and the original attribute for
-// some device was {4,7}. Consider that the two new collectives are
-// `collective.1` and `collective.2` (they execute in this order inside the new
-// loop). In the old loop, iterations 4,5,6,7 were valid. In the new
-// loop,
-// - collective.2 in iteration 1 of new loop runs 4th iteration of old loop.
-// - collective.1 in iteration 2 of new loop runs 5th iteration of old loop.
-// - collective.2 in iteration 2 of new loop runs 6th iteration of old loop.
-// - collective.1 in iteration 3 of new loop runs 7th iteration of old loop.
-// So, the updated attribute for that device are {1,2} for `collective.2` and
-// {2,3} for `collective.1`.
-//
-// In a similar fashion we can generalize the computation of new values based on
-// the values of the old attribute as done in the logic below.
-absl::Status SetSendRecvValidation(HloInstruction* cp1, HloInstruction* cp2,
- bool is_peeled) {
- TF_RET_CHECK(
- cp2->opcode() == cp1->opcode() &&
- "cloned instruction and original instruction have different opcodes");
- if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
- HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
- HloOpcode::kRecv>(cp1)) {
- return absl::OkStatus();
- }
- const auto& attribute_map = cp2->frontend_attributes().map();
- if (!attribute_map.contains(kSendRecvValidationAttr)) {
- return absl::OkStatus();
- }
- VLOG(3) << "Original send-recv iterations: "
- << attribute_map.at(kSendRecvValidationAttr);
-
- TF_ASSIGN_OR_RETURN(
- auto send_recv_validation_attr,
- ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
-
- if (send_recv_validation_attr.size() == 0) {
- return absl::OkStatus();
- }
-
- std::vector<Interval> send_recv_iterations_new_instr1,
- send_recv_iterations_new_instr2;
- send_recv_iterations_new_instr1.reserve(send_recv_validation_attr.size());
- send_recv_iterations_new_instr2.reserve(send_recv_validation_attr.size());
- for (const Interval& pair : send_recv_validation_attr) {
- int64_t a = pair.first;
- int64_t b = pair.second;
- if (is_peeled) {
- send_recv_iterations_new_instr1.emplace_back(
- std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
- send_recv_iterations_new_instr2.emplace_back(
- std::max(0.0, std::floor((a - 1) / 2.0)),
- std::max(0.0, std::floor((b - 2) / 2.0)));
- } else {
- send_recv_iterations_new_instr1.emplace_back(std::floor((a + 1) / 2.0),
- std::floor(b / 2.0));
- send_recv_iterations_new_instr2.emplace_back(
- std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
- }
- }
-
- hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
- /*instr=*/cp1, /*attr_name=*/kSendRecvValidationAttr,
- /*intervals=*/send_recv_iterations_new_instr1);
- hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
- /*instr=*/cp2, /*attr_name=*/kSendRecvValidationAttr,
- /*intervals=*/send_recv_iterations_new_instr2);
-
- VLOG(3) << "Updated send-recv iterations for " << cp1->name() << " : "
- << cp1->frontend_attributes().map().at(kSendRecvValidationAttr);
- VLOG(3) << "Updated send-recv iterations for " << cp2->name() << " : "
- << cp2->frontend_attributes().map().at(kSendRecvValidationAttr);
- return absl::OkStatus();
-}
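-
-// Worked example of the interval arithmetic above (illustrative only, but
-// consistent with the expectations in double_buffer_loop_unrolling_test.cc):
-// With is_peeled=true and an original interval {4,7}:
-//   cp1 gets {floor(4/2), floor((7-1)/2)} = {2,3}
-//   cp2 gets {floor((4-1)/2), floor((7-2)/2)} = {1,2}
-// With is_peeled=false and an original interval {0,6}:
-//   cp1 gets {floor((0+1)/2), floor(6/2)} = {0,3}
-//   cp2 gets {floor(0/2), floor((6-1)/2)} = {0,2}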
-
-// Handle control predecessors/successors for every old-new instruction pair.
-// For every new instruction, we find the relevant predecessor/successor
-// relationships of the old instruction and we reconstruct them by looking up
-// new (already created) predecessors/successors.
-//
-// When rewiring dependencies from the output of the original body to the input
-// of the cloned body, we skip collectives and ops in
-// `skip_control_dep_injection`.
-absl::Status HandleControlDependencies(
- const HloComputation* while_body,
- const absl::flat_hash_map<HloInstruction*, HloInstruction*>& old_to_new_map,
- HloInstruction::InstructionVector* old_loop_roots,
- HloInstruction* input_parameter,
- const absl::flat_hash_set<HloInstruction*>& skip_control_dep_injection) {
- for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
- if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
- HloInstruction* new_instr = old_to_new_map.at(old_instr);
- VLOG(2) << "Processing control predecessors for "
- << new_instr->ToString();
- std::vector<HloInstruction*> new_control_pred;
- new_control_pred.reserve(old_instr->control_predecessors().size());
- for (HloInstruction* pred : old_instr->control_predecessors()) {
- if (!old_to_new_map.contains(pred)) {
- continue;
- }
- new_control_pred.push_back(old_to_new_map.at(pred));
- }
-
- TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
- for (HloInstruction* new_pred : new_control_pred) {
- TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
- VLOG(2) << "Adding " << new_pred->ToString()
- << " to control dependency of " << new_instr->ToString();
- }
- }
- }
- for (HloInstruction* input_consumer : input_parameter->users()) {
- for (HloInstruction* old_input : input_consumer->users()) {
- if (old_to_new_map.find(old_input) != old_to_new_map.end()) {
- HloInstruction* new_input = old_to_new_map.at(old_input);
- if (skip_control_dep_injection.find(old_input) ==
- skip_control_dep_injection.end() &&
- !IsCollective(old_input)) {
- for (HloInstruction* old_root : *old_loop_roots) {
- TF_RETURN_IF_ERROR(old_root->AddControlDependencyTo(new_input));
- }
- }
- }
- }
- }
-
- return absl::OkStatus();
-}
-
-absl::StatusOr<bool> FullyUnroll(HloInstruction* while_instr,
- HloModule* module) {
- HloComputation* while_body = while_instr->while_body();
- bool changed = false;
- VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
-
- auto loop_roots = while_body->root_instruction()->mutable_operands();
- HloInstruction* input_parameter = while_body->parameter_instruction(0);
- VLOG(2) << "Processing input parameter " << input_parameter->ToString();
-
- absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
- absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
- std::string clone_suffix = "full_unroll_clone";
-
- TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
- while_instr->backend_config<WhileLoopBackendConfig>());
- std::vector<HloInstruction*> ops_to_clone;
- ops_to_clone.reserve(while_body->MakeInstructionPostOrder().size());
-
- // Pre-loop prep.
- HloInstruction* old_input_parameter = input_parameter;
- HloInstruction* new_input_parameter = while_body->root_instruction();
- absl::flat_hash_set<HloInstruction*> seen_ops;
- for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
- if (seen_ops.contains(old_instr)) {
- continue;
- }
- ops_to_clone.push_back(old_instr);
- seen_ops.insert(old_instr);
- }
-
- int n = config.known_trip_count().n();
- while (--n) {
- std::vector<HloInstruction*> new_ops_to_clone;
- old_to_new_map[old_input_parameter] = new_input_parameter;
- for (HloInstruction* old_instr : ops_to_clone) {
- if (old_to_new_map.contains(old_instr)) {
- continue;
- }
- VLOG(2) << "Cloning instruction " << old_instr->ToString();
- std::vector<HloInstruction*> new_operands;
- for (HloInstruction* old_operand : old_instr->mutable_operands()) {
- new_operands.push_back(old_to_new_map[old_operand]);
- }
- HloInstruction* new_instr =
- while_body->AddInstruction(old_instr->CloneWithNewOperands(
- old_instr->shape(), new_operands, clone_suffix));
-
- // If this is an elementwise binary instruction with a constant operand, we
- // won't inject a control dependency at the end, to allow more
- // constant-folding opportunities.
- if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
- skip_control_dep_injection.insert(old_instr);
- }
- SetChannelIdForNewCollective(new_instr, module);
- old_to_new_map[old_instr] = new_instr;
- new_ops_to_clone.push_back(new_instr);
- VLOG(2) << "Added instruction " << new_instr->ToString();
- }
-
- while_body->set_root_instruction(
- old_to_new_map[while_body->root_instruction()]);
- VLOG(2) << "Replaced with new root "
- << while_body->root_instruction()->ToString();
-
- TF_RETURN_IF_ERROR(HandleControlDependencies(
- while_body, old_to_new_map, &loop_roots, old_input_parameter,
- skip_control_dep_injection));
-
- // Inductive step update, clean/update necessary buffers to prepare them for
- // the next unrolling iteration.
- old_to_new_map.clear();
- skip_control_dep_injection.clear();
- loop_roots = while_body->root_instruction()->mutable_operands();
- old_input_parameter = new_input_parameter;
- new_input_parameter = while_body->root_instruction();
- ops_to_clone = std::move(new_ops_to_clone);
- changed = true;
- }
-
- WhileLoopBackendConfig new_config;
- new_config.mutable_known_trip_count()->set_n(1);
- TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
-
- return changed;
-}
-
-absl::Status PeelInstructionsForOddTripCount(HloModule* module,
- HloInstruction* while_instr) {
- std::string suffix = "peeled_double_buffer";
- absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
- HloComputation* while_body = while_instr->while_body();
- HloInstruction* input_parameter = while_body->parameter_instruction(0);
- HloInstruction* input_tuple = while_instr->mutable_operand(0);
-
- auto old_loop_roots = while_body->root_instruction()->mutable_operands();
- HloComputation* parent_comp = while_instr->parent();
- old_to_new_map[input_parameter] = input_tuple;
-
- for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
- if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
- continue;
- }
- VLOG(2) << "Peeling instruction " << old_instr->ToString();
- std::vector<HloInstruction*> new_operands(old_instr->operand_count());
- for (int64_t i = 0; i < old_instr->operand_count(); i++) {
- new_operands[i] = old_to_new_map[old_instr->mutable_operand(i)];
- }
- HloInstruction* new_instr =
- parent_comp->AddInstruction(old_instr->CloneWithNewOperands(
- old_instr->shape(), new_operands, suffix));
-
- SetChannelIdForNewCollective(new_instr, module);
- TF_CHECK_OK(SetSendRecvValidationForPeeledInstr(new_instr, old_instr));
- old_to_new_map[old_instr] = new_instr;
- VLOG(2) << "Added instruction " << new_instr->ToString()
- << " to parent computation.";
- }
-
- std::vector<HloInstruction*> new_roots;
- for (HloInstruction* instr : old_loop_roots) {
- new_roots.push_back(old_to_new_map[instr]);
- }
- TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(
- 0, old_to_new_map[while_body->root_instruction()]));
- VLOG(2) << "Replaced with new input tuple "
- << while_instr->operand(0)->ToString();
-
- // Handle existing control dependencies.
- for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
- if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
- HloInstruction* new_instr = old_to_new_map[old_instr];
- VLOG(2) << "Processing control predecessors for peeled instruction "
- << new_instr->ToString();
- std::vector<HloInstruction*> new_control_pred;
- new_control_pred.reserve(old_instr->control_predecessors().size());
- for (HloInstruction* pred : old_instr->control_predecessors()) {
- new_control_pred.push_back(old_to_new_map[pred]);
- }
-
- TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
- for (HloInstruction* new_pred : new_control_pred) {
- TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
- VLOG(2) << "Adding " << new_pred->ToString()
- << " to control dependency of peeled instruction: "
- << new_instr->ToString();
- }
- }
- }
- return absl::OkStatus();
-}
-
-// TODO(olechwierowicz): Extract common logic of this and `FullyUnroll` to
-// a separate function.
-absl::StatusOr<bool> DoubleBufferingUnroll(HloInstruction* while_instr,
- HloModule* module) {
- TF_ASSIGN_OR_RETURN(auto config,
- while_instr->backend_config<WhileLoopBackendConfig>());
-
- CHECK(config.has_known_trip_count())
- << "Only loops with known trip count are supported.";
- int64_t exact_trip_count = config.known_trip_count().n();
- VLOG(2) << "Processing while loop " << while_instr->ToString()
- << " with trip count: " << exact_trip_count;
-
- HloComputation* while_body = while_instr->while_body();
-
- VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
-
- auto old_loop_roots = while_body->root_instruction()->mutable_operands();
- HloInstruction* input_parameter = while_body->parameter_instruction(0);
- VLOG(2) << "Processing input parameter " << input_parameter->ToString();
- absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
- absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
-
- bool is_peeled = exact_trip_count % 2;
- if (is_peeled) {
- VLOG(2) << "Found loops with odd trip count, 1 iteration will be peeled "
- "outside of the main body.";
- TF_RETURN_IF_ERROR(PeelInstructionsForOddTripCount(module, while_instr));
- exact_trip_count -= 1;
- }
-
- std::string suffix = "double_buffer_clone";
- old_to_new_map[input_parameter] = while_body->root_instruction();
- for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
- if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
- continue;
- }
- VLOG(2) << "Cloning instruction " << old_instr->ToString();
- std::vector<HloInstruction*> new_operands;
- for (HloInstruction* old_operand : old_instr->mutable_operands()) {
- new_operands.push_back(old_to_new_map[old_operand]);
- }
- HloInstruction* new_instr =
- while_body->AddInstruction(old_instr->CloneWithNewOperands(
- old_instr->shape(), new_operands, suffix));
-
- // If this is an elementwise binary instruction with a constant operand, we
- // won't inject a control dependency at the end, to allow more
- // constant-folding opportunities.
- if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
- skip_control_dep_injection.insert(old_instr);
- }
- SetChannelIdForNewCollective(new_instr, module);
- TF_CHECK_OK(SetSendRecvValidation(old_instr, new_instr, is_peeled));
- old_to_new_map[old_instr] = new_instr;
- VLOG(2) << "Added instruction " << new_instr->ToString();
- }
-
- while_body->set_root_instruction(
- old_to_new_map[while_body->root_instruction()]);
- VLOG(2) << "Replaced with new root "
- << while_body->root_instruction()->ToString();
-
- // Handle existing control dependencies.
- TF_RETURN_IF_ERROR(HandleControlDependencies(while_body, old_to_new_map,
- &old_loop_roots, input_parameter,
- skip_control_dep_injection));
-
- WhileLoopBackendConfig new_config;
- new_config.mutable_known_trip_count()->set_n(exact_trip_count / 2);
- TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
- return true; // changed
-}
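-
-// Illustrative trip-count bookkeeping for the double-buffering path (a sketch,
-// not part of the pass itself): a loop with known_trip_count n=11 first has
-// one iteration peeled outside the loop (leaving 10 iterations), then the body
-// is unrolled by 2 and the backend config is updated to known_trip_count n=5.
-// An even count such as n=10 skips the peeling and goes straight to n=5.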
-
-} // namespace
-
-absl::StatusOr<bool> DoubleBufferLoopUnrolling::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- std::vector<HloInstruction*> while_instrs;
- for (auto comp : module->MakeNonfusionComputations()) {
- absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
- HloPredicateIsOp<HloOpcode::kWhile>);
- }
- VLOG(2) << "Processing " << while_instrs.size() << " while loops.";
-
- for (HloInstruction* while_instr : while_instrs) {
- TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
- while_instr->backend_config<WhileLoopBackendConfig>());
- if (!config.has_known_trip_count() || config.known_trip_count().n() == 1) {
- VLOG(2) << while_instr->ToString()
- << " does not have a known trip count greater than 1; skipping loop "
- "unrolling for now";
- continue;
- }
-
- bool loop_changed = false;
- if (unroll_strategy_ == UnrollStrategy::kFullUnroll) {
- TF_ASSIGN_OR_RETURN(loop_changed, FullyUnroll(while_instr, module));
- } else if (unroll_strategy_ == UnrollStrategy::kDoubleBuffer) {
- TF_ASSIGN_OR_RETURN(loop_changed, DoubleBufferingUnroll(while_instr, module));
- } else {
- LOG(FATAL) << absl::StrCat("Unhandled unrolling strategy: ",
- unroll_strategy_);
- }
- changed |= loop_changed;
- }
-
- VLOG(2) << "LoopDoubleBufferTransformer output: " << module->ToString();
-
- // Run necessary cleanup to ensure LoopDoubleBufferTransformer behaves
- // correctly.
- if (changed) {
- // The call graph will not be flat if one of the loops that was unrolled
- // contains any kind of call to another computation---since the call will
- // be duplicated, thereby adding a second callsite for that computation.
- TF_RETURN_IF_ERROR(
- FlattenCallGraph().Run(module, execution_threads).status());
- }
-
- return changed;
-}
-
-} // end namespace gpu
-} // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h b/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h
deleted file mode 100644
index 120070d..0000000
--- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_
-#define XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// With `kDoubleBuffer` strategy:
-// This pass performs the unrolling-by-2 loop transformation
-// to effectively achieve double buffering between inputs and outputs
-// of previously rolled iterations.
-// This pass only runs on loops with known trip counts.
-// For even number of iterations, unrolling-by-2 will be done directly.
-// For odd number of iterations, the first iteration of the loop will be
-// peeled outside of the while loop to make the trip count an even number,
-// then proceed to unroll by 2.
-// It also updates the trip count property of the loop to the correct one
-// (n/2).
-//
-// With `kFullUnroll` strategy:
-// This pass will perform the full unroll of the loop with the same strategy
-// that is used with `kDoubleBuffer` but while loop trip count times.
-// It updates the trip count of the while loop to 1, and relies on other
-// passes (like `WhileLoopSimplifier`) to simplify/get rid of the while loop
-// eventually.
-//
-// Note that this pass will flatten the call graph if any loop has been
-// unrolled.
-// TODO(olechwierowicz): Rename the loop unroller to something more generic like
-// 'DoubleBufferLoopUnrolling'.
-class DoubleBufferLoopUnrolling : public HloModulePass {
- public:
- enum class UnrollStrategy { kDoubleBuffer, kFullUnroll };
-
- explicit DoubleBufferLoopUnrolling(
- UnrollStrategy unroll_strategy = UnrollStrategy::kDoubleBuffer)
- : unroll_strategy_(unroll_strategy) {}
- ~DoubleBufferLoopUnrolling() override = default;
-
- absl::string_view name() const override {
- return "loop-double-buffer-transformer";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- UnrollStrategy unroll_strategy_;
-};
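-
-// Example usage (an illustrative sketch only; it assumes the standard
-// HloPassPipeline API rather than anything defined in this header):
-//
-//   HloPassPipeline pipeline("double-buffer-unrolling");
-//   pipeline.AddPass<DoubleBufferLoopUnrolling>(
-//       DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-//   TF_RETURN_IF_ERROR(pipeline.Run(module).status());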
-
-} // end namespace gpu
-} // end namespace xla
-
-#endif // XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_
diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc b/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc
deleted file mode 100644
index 8fed319..0000000
--- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc
+++ /dev/null
@@ -1,1243 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/double_buffer_loop_unrolling.h"
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-
-#include "absl/container/flat_hash_set.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/tuple_simplifier.h"
-#include "xla/test.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using tsl::testing::IsOkAndHolds;
-
-int64_t CountInstructions(const HloComputation& computation, HloOpcode opcode) {
- int64_t count = 0;
- for (const auto& instruction : computation.instructions()) {
- if (instruction->opcode() == opcode) {
- count++;
- }
- }
- return count;
-}
-
-int64_t CountInstructions(const HloModule& module, HloOpcode opcode) {
- int64_t count = 0;
- for (const auto& computation : module.computations()) {
- count += CountInstructions((*computation), opcode);
- }
- return count;
-}
-
-class GpuLoopDoubleBufferTransformerTest : public HloTestBase {
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_while_loop_double_buffering(true);
- return debug_options;
- }
-};
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollOddTripCountTest) {
- const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=3
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
- TupleSimplifier tuple_simp;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
- EXPECT_TRUE(changed);
- TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
- EXPECT_TRUE(changed);
- HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile);
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- while_instruction->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- EXPECT_EQ(exact_trip_count, 1);
- EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
- HloOpcode::kAllGatherStart),
- 11);
- EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 11);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollEvenTripCountTest) {
- const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=3
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- // Start all-gather communication
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
- TupleSimplifier tuple_simp;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
- EXPECT_TRUE(changed);
- TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile);
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- while_instruction->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- EXPECT_EQ(exact_trip_count, 1);
- EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
- HloOpcode::kAllGatherStart),
- 10);
- EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 10);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopEvenTripCount) {
- const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=3
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- // Start all-gather communication
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- // Intertwined with the all-gather communication, an operation happens which
- // depends on param_1, but crucially has a different output shape (which
- // excludes reusing param_1's buffer for its output).
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- // The all-gather communication finishes
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer;
- TupleSimplifier tuple_simp;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
- EXPECT_TRUE(changed);
- TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile);
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- while_instruction->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- // We expect that after unrolling, the total trip count is half of the
- // original count.
- EXPECT_EQ(exact_trip_count, 5);
- // We expect that after unrolling, there should be 2 all-gather-starts,
- // both in the while body.
- EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
- HloOpcode::kAllGatherStart),
- 2);
- EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 2);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopOddTripCount) {
- const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=3
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- // Start all-gather communication
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- // Intertwined with the all-gather communication, an operation happens which
- // depends on param_1, but crucially has a different output shape (which
- // excludes reusing param_1's buffer for its output).
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- // The all-gather communication finishes
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer;
- TupleSimplifier tuple_simp;
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
- // We expect that for the while loop, no further copy needs to be added to the
- // module.
- HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile);
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- while_instruction->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- // We expect that after unrolling, the total trip count is half of the
- // original count.
- EXPECT_EQ(exact_trip_count, 5);
-
- // We expect that after unrolling, there should be 3 all-gather-starts:
- // 1 in the parent computation and 2 in the while body.
- EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
- HloOpcode::kAllGatherStart),
- 2);
- EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 3);
-
- // We expect that after unrolling, the third operand of the input tuple should
- // be the peeled allgather done.
- EXPECT_EQ(while_instruction->operand(0)->operand(2)->opcode(),
- HloOpcode::kAllGatherDone);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- UnrolledLoopNoControlDepsForConstantAdd) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- c2 = f32[] constant(2)
- add = f32[] add(c2, param_0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(add, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer;
- TupleSimplifier tuple_simp;
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
- HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile);
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- while_instruction->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- // We expect that after unrolling, the total trip count is half of the
- // original count.
- EXPECT_EQ(exact_trip_count, 5);
-
- // We expect that after unrolling, there should be 4 adds
- EXPECT_EQ(
- CountInstructions((*while_instruction->while_body()), HloOpcode::kAdd),
- 4);
-
- // We expect that after unrolling, the first operand of the output tuple
- // should not have any control dependency since it's an elementwise add with a
- // constant operand.
- EXPECT_EQ(while_instruction->while_body()
- ->root_instruction()
- ->operand(0)
- ->control_predecessors()
- .size(),
- 0);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- UnrolledLoopNoControlDepsForCollective) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
- one = s32[] constant(1)
- all-reduce-done = f32[] all-reduce-done(all-reduce-start)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer;
- TupleSimplifier tuple_simp;
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
- HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile);
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- while_instruction->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- // We expect that after unrolling, the total trip count is half of the
- // original count.
- EXPECT_EQ(exact_trip_count, 5);
-
- // We expect that after unrolling, there should be 2 all-reduce-starts
- EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
- HloOpcode::kAllReduceStart),
- 2);
- absl::flat_hash_set<int64_t> channel_ids;
- hlo_query::ForEachInstructionWithOpcode(
- *while_instruction->while_body(), HloOpcode::kAllReduceStart,
- [&channel_ids](HloInstruction* ar) {
- // We expect that after unrolling, all-reduces should not have any
- // control deps.
- EXPECT_EQ(ar->control_predecessors().size(), 0);
- channel_ids.insert(*(ar->channel_id()));
- });
- // We expect that the 2 all-reduces will have different channel ids.
- EXPECT_EQ(channel_ids.size(), 2);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- FullyUnrolledLoopNoControlDepsForCollective) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
- one = s32[] constant(1)
- all-reduce-done = f32[] all-reduce-done(all-reduce-start)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
- TupleSimplifier tuple_simp;
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
- HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile);
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- while_instruction->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- EXPECT_EQ(exact_trip_count, 1);
-
- // We expect that after unrolling, there should be 10 all-reduce-starts
- EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
- HloOpcode::kAllReduceStart),
- 10);
- absl::flat_hash_set<int64_t> channel_ids;
- hlo_query::ForEachInstructionWithOpcode(
- *while_instruction->while_body(), HloOpcode::kAllReduceStart,
- [&channel_ids](HloInstruction* ar) {
- // We expect that after unrolling, all-reduces should not have any
- // control deps.
- EXPECT_EQ(ar->control_predecessors().size(), 0);
- channel_ids.insert(*(ar->channel_id()));
- });
- // We expect that the 10 all-reduces will have different channel ids.
- EXPECT_EQ(channel_ids.size(), 10);
-}
-
-// The following 2 tests also address the regression described here:
-// https://github.com/openxla/xla/issues/6353
-TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopRemainsFlattened) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_while_loop_remains_flattened
-
-condition_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-
-condition {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (s32[]) parameter(0)
- ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
-}
-
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer;
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- absl::flat_hash_set<const HloComputation*> while_loops_callees;
-
- hlo_query::ForEachInstructionWithOpcode(
- *module, HloOpcode::kWhile,
- [&while_loops_callees](HloInstruction* instr) {
- EXPECT_TRUE(
- while_loops_callees.insert(instr->while_condition()).second);
- EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
- });
-
- // We expect that the nested while loop has been duplicated, along with its
- // associated computations.
- EXPECT_EQ(while_loops_callees.size(), 6);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- NestedWhileLoopRemainsFlattenedOddTripCount) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_while_loop_remains_flattened
-
-condition_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-
-condition {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (s32[]) parameter(0)
- ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
-}
-
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer;
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- absl::flat_hash_set<const HloComputation*> while_loops_callees;
-
- hlo_query::ForEachInstructionWithOpcode(
- *module, HloOpcode::kWhile,
- [&while_loops_callees](HloInstruction* instr) {
- EXPECT_TRUE(
- while_loops_callees.insert(instr->while_condition()).second);
- EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
- });
-
- // We expect that the nested while loop has been duplicated, along with its
- // associated computations.
- EXPECT_EQ(while_loops_callees.size(), 8);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- NestedWhileLoopRemainsFlattenedWhenFullyUnrolled) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_while_loop_remains_flattened
-
-condition_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-
-condition {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (s32[]) parameter(0)
- ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
-}
-
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- absl::flat_hash_set<const HloComputation*> while_loops_callees;
-
- hlo_query::ForEachInstructionWithOpcode(
- *module, HloOpcode::kWhile,
- [&while_loops_callees](HloInstruction* instr) {
- EXPECT_TRUE(
- while_loops_callees.insert(instr->while_condition()).second);
- EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
- });
-
- hlo_query::ForEachInstructionWithOpcode(
- *module->entry_computation(), HloOpcode::kWhile,
- [](HloInstruction* instr) {
- TF_ASSERT_OK_AND_ASSIGN(
- WhileLoopBackendConfig config,
- instr->backend_config<WhileLoopBackendConfig>());
- int64_t exact_trip_count = config.known_trip_count().n();
- EXPECT_EQ(exact_trip_count, 1);
- });
-
- // We expect that the nested while loop has been fully duplicated 10
- // times. The one outer while loop still remains, so there are 11 while
- // instructions in total, and we check that there are 22 distinct
- // computations: one body and one condition per while loop.
- EXPECT_EQ(while_loops_callees.size(), 22);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreUnrolled) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_are_unrolled
-condition_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-condition {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body {
- input_tuple = (s32[]) parameter(0)
- ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
-}
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer;
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- int64_t num_whiles = 0;
- hlo_query::ForEachInstructionWithOpcode(
- *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
- EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
- ->known_trip_count()
- .n(),
- 5);
- ++num_whiles;
- });
- // We expect the number of while loops to be 4 in total after unrolling.
- EXPECT_EQ(num_whiles, 4);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreFullyUnrolled) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_are_unrolled
-condition_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-condition {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body {
- input_tuple = (s32[]) parameter(0)
- ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
-}
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- int64_t num_whiles = 0;
- hlo_query::ForEachInstructionWithOpcode(
- *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
- EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
- ->known_trip_count()
- .n(),
- 1);
- ++num_whiles;
- });
- EXPECT_EQ(num_whiles, 12);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithCollectivePermute) {
- const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
- frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
- VLOG(0) << module->ToString();
- EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
- // CHECK: %body {{.+}} {
- // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"}
- // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
- // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
- // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"}
- // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
- // CHECK: }
- // CHECK: ENTRY %main {{.+}} {
- // CHECK-NOT: collective-permute
- // CHECK: }
- )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- WhileLoopWithCollectivePermutePeeled) {
- const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(15)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,0}},
- frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10},{4,11},{5,12},{6,13},{7,14}}"}
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
-}
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
- VLOG(0) << module->ToString();
-
- EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
- // CHECK: %body
- // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}"}
- // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
- // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]])
- // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}"}
- // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
- // CHECK: ENTRY %main {{.+}} {
- // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}"}
- // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
- // CHECK: %[[while:.+]] = {{.+}} while({{.+}} %[[out_peeled]])
- // CHECK: }
- )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- WhileLoopWithCollectivePermuteBackwardCycle) {
- const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(14)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
- frontend_attributes={_xla_send_recv_validation="{{7,13},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}}"}
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"14"}}
-}
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
- // CHECK: %body
- // CHECK: %[[cp1:.+]] = f32[] collective-permute(f32[] %param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}"}
- // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
- // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
- // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}"}
- // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
- // CHECK: ENTRY %main
- // CHECK-NOT: collective-permute
- // CHECK: }
- )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- WhileLoopWithCollectivePermuteBackwardCyclePeeled) {
- const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(15)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
- frontend_attributes={_xla_send_recv_validation="{{7,14},{6,13},{5,12},{4,11},{3,10},{2,9},{1,8},{0,7}}"}
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
-}
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
- // CHECK: %body
- // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3}{{[}]}}"}
- // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
- // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
- // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}"}
- // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
- // CHECK: }
- // CHECK: ENTRY %main
- // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}"}
- // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
- // CHECK: ROOT {{.+}} = {{.+}} while({{.+}} %[[out_peeled]])
- // CHECK: }
- )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- WhileLoopWithCollectivePermuteStartDone) {
- const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- collective-permute-start = (f32[], f32[], u32[], u32[]) collective-permute-start(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
- frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
- collective-permute = f32[] collective-permute-done(collective-permute-start)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
- // CHECK: %body
- // CHECK: %[[cp_start1:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"}
- // CHECK: %[[cp1:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start1]])
- // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
- // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
- // CHECK: %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"}
- // CHECK: %[[cp2:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start2]])
- // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
- // CHECK: }
- // CHECK: ENTRY %main
- // CHECK-NOT: collective-permute
- // CHECK: }
- )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithRecvDone) {
- const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- after-all.0 = token[] after-all()
- recv.0 = (f32[], u32[], token[]) recv(after-all.0), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
- _xla_send_recv_pipeline="0",
- _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
- }
- recv-done.0 = (f32[], token[]) recv-done(recv.0), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- recv-data = f32[] get-tuple-element(recv-done.0), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(recv-data, cond_plus_1)
-}
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
- // CHECK: %body
- // CHECK: %[[recv1:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"
- // CHECK: %[[recv2:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"
- // CHECK: ENTRY %main
- // CHECK-NOT: recv
- // CHECK: }
- )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithSendDone) {
- const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
- input_tuple = (f32[], s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=1
- trip_count = s32[] constant(10)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
- Arg_1 = f32[] parameter(1)
- Arg_0 = f32[] parameter(0)
- ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- after-all.0 = token[] after-all()
- send.0 = (f32[], u32[], token[]) send(param_0, after-all.0), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
- _xla_send_recv_pipeline="0",
- _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
- }
- send-done.0 = token[] send-done(send.0), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(param_0, cond_plus_1)
-}
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
- EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
- // CHECK: %body
- // CHECK: %[[send1:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"
- // CHECK: %[[send2:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"
- // CHECK: ENTRY %main
- // CHECK-NOT: send
- // CHECK: }
- )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
- WhileLoopWithTripCount1ShouldBeSkipped) {
- const char* const kModuleString = R"(
-HloModule loop_unrolling_skipped
-condition_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(0)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-condition {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- trip_count = s32[] constant(0)
- ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body {
- input_tuple = (s32[]) parameter(0)
- ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"1"}}
-}
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"1"}}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
- ParseAndReturnVerifiedModule(kModuleString));
- DoubleBufferLoopUnrolling double_buffer(
- DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
- // The processing of the loop should be completely skipped.
- EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(false));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc
deleted file mode 100644
index 0919241..0000000
--- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc
+++ /dev/null
@@ -1,513 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <iterator>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/ffi/ffi_api.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/custom_call_target_registry.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_constants.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// A dataflow path flowing from a definition to a user.
-using DefUseDataflowPath = absl::InlinedVector<HloInstruction*, 2>;
-
-// All dataflow paths flowing from a definition to all users. Each user will
-// have a separate entry in the vector.
-using DefUseDataflowPaths = absl::InlinedVector<DefUseDataflowPath, 4>;
-
-// A dataflow path flowing from a user to a definition.
-using UseDefDataflowPath = absl::InlinedVector<HloInstruction*, 4>;
-
-// All dataflow paths flowing from a user to all definitions of its operands.
-using UseDefDataflowPaths = absl::InlinedVector<HloInstruction*, 8>;
-
-using DataflowPathView = absl::Span<HloInstruction* const>;
-using DataflowPathsView = absl::Span<DataflowPathView>;
-
-using InstructionSet = absl::flat_hash_set<HloInstruction*>;
-
-bool IsNoOp(const HloInstruction* hlo) {
- return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
- HloOpcode::kGetTupleElement>(hlo);
-}
-
-bool IsCustomCall(const HloInstruction* hlo, absl::string_view platform_name) {
- auto* custom_call = DynCast<HloCustomCallInstruction>(hlo);
- if (custom_call == nullptr) return false;
-
- // TODO(vuson): properly handle token by following
- // `LhloDialectEmitter::EmitCustomCallOp`'s `CreateOperands` logic for
- // `LhloDialectEmitter::EmitFusionOp`'s `RewriteFusionOperand`
- if (custom_call->shape().IsTuple() &&
- absl::c_any_of(
- custom_call->shape().tuple_shapes(),
- [&](const Shape& sub_shape) { return sub_shape.IsToken(); }))
- return false;
-
- const std::string call_target_name = custom_call->custom_call_target();
-
- bool is_ffi_custom_call =
- custom_call->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI;
-
- void* call_target = CustomCallTargetRegistry::Global()->Lookup(
- call_target_name, std::string(platform_name));
-
- absl::StatusOr<ffi::HandlerRegistration> handler_registration =
- ffi::FindHandler(call_target_name, platform_name);
-
- // At least one implementation should be available at run time.
- bool found_custom_call = !is_ffi_custom_call && call_target != nullptr;
- bool found_ffi_handler = is_ffi_custom_call && handler_registration.ok();
-
- return found_custom_call || found_ffi_handler;
-}
-
-// Returns true if the slice is 128-byte-aligned. The slice starting
-// address is determined by the product of all non-sliced dimensions and an
-// offset defined by `slice_starts` of the slice op.
-//
-// For dynamic cases, we don't have info about the start indices, so we have to
-// be conservative and only accept sliced shapes whose product of all
-// non-sliced dimensions is a multiple of `kXlaAllocatedBufferAlignBytes`.
-bool IsAlignedSlice(const HloInstruction* slice) {
- DCHECK(slice->opcode() == HloOpcode::kSlice ||
- slice->opcode() == HloOpcode::kDynamicSlice ||
- slice->opcode() == HloOpcode::kDynamicUpdateSlice)
- << "Unknown slice operation: " << slice->ToString();
-
- if (!IsContiguousSlice(*slice)) return false;
-
- auto [full_shape, slice_shape] = [&] {
- if (auto* dus = DynCast<HloDynamicUpdateSliceInstruction>(slice)) {
- return std::make_pair(dus->shape(), dus->update()->shape());
- }
- return std::make_pair(slice->operand(0)->shape(), slice->shape());
- }();
-
- auto strides = ShapeUtil::ByteStrides(slice_shape);
- if (!strides.has_value()) return false;
-
- for (auto dim : slice_shape.layout().minor_to_major()) {
- if ((strides.value()[dim] % kXlaAllocatedBufferAlignBytes) == 0) {
- return true;
- }
- if (slice_shape.dimensions(dim) < full_shape.dimensions(dim)) {
- return (slice->opcode() == HloOpcode::kSlice &&
- (((*strides)[dim] * slice->slice_starts(dim)) %
- kXlaAllocatedBufferAlignBytes ==
- 0));
- }
- }
- return true;
-}
-
-UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) {
- UseDefDataflowPaths sliced_operand_paths;
-
- // This set is used to avoid duplicates in the matched results. It contains
- // the matched instructions that we have seen so far.
- InstructionSet processed_instrs;
-
- const auto& aliasing_pairs =
- Cast<HloCustomCallInstruction>(instr)->output_to_operand_aliasing();
- absl::flat_hash_set<int64_t> aliased_operands;
- for (const auto& pair : aliasing_pairs) {
- aliased_operands.insert(pair.second.first);
- }
-
- for (const auto* operand : instr->operands()) {
- // output_to_operand_aliasing means the operand is to be materialized, which
- // is against the whole idea of address computation fusion. Skip this
- // operand.
- if (aliased_operands.contains(instr->operand_index(operand))) continue;
- UseDefDataflowPath maybe_sliced_operand_path;
- bool slice_found = false;
- // TODO: currently HloFindIf exits upon encountering the first node that
- // matches. This works well if each operand has only one data flow (i.e. it
- // only flows through unary ops). We might want to keep searching until the
- // queue is empty: if the operand is a tuple, it might have a different data
- // flow for each element.
- auto maybe_slice_instr =
- HloBfsFindIf({operand}, [&](const HloInstruction* cur) {
- // If the node is a match that has been processed, stop the traversal.
- if (processed_instrs.contains(cur)) return true;
-
- maybe_sliced_operand_path.push_back(const_cast<HloInstruction*>(cur));
-
- if (IsOpcodeAnyOf<HloOpcode::kDynamicSlice, HloOpcode::kSlice>(cur)) {
- if (IsAlignedSlice(cur)) {
- slice_found = true;
- return slice_found;
- }
- }
-
- return !IsNoOp(cur);
- });
-
- if (maybe_slice_instr == std::nullopt) continue;
-
- if (slice_found || processed_instrs.contains(maybe_slice_instr.value())) {
- // Even in the case of stopping at a match that has been processed, we
- // still need to add instructions encountered in the sliced operand path
- // during the latest traversal.
- sliced_operand_paths.insert(sliced_operand_paths.end(),
- maybe_sliced_operand_path.rbegin(),
- maybe_sliced_operand_path.rend());
- processed_instrs.insert(maybe_sliced_operand_path.begin(),
- maybe_sliced_operand_path.end());
- }
- }
-
- sliced_operand_paths.push_back(const_cast<HloInstruction*>(instr));
- return sliced_operand_paths;
-}
-
-// Each user of `instr` that goes into a DUS will have an entry in the returned
-// vector.
-// Each entry contains the sliced path for that user, i.e. the sequence of ops
-// following the dataflow from the user itself to the DUS (inclusive).
-DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) {
- DefUseDataflowPaths sliced_user_paths;
- // This set is used to avoid duplicates in the matched results. It contains
- // the matched instructions that we have seen so far.
- InstructionSet processed_instrs;
-
- auto traverse_hlo_and_collect = [&](HloInstruction* start) {
- DefUseDataflowPath maybe_sliced_user_path;
- bool dus_found = false;
- auto maybe_dus_instr = HloBfsFindIf(
- {start},
- [&](const HloInstruction* cur) {
- // If the node is a match that has been processed, stop the
- // traversal.
- if (processed_instrs.contains(cur)) return true;
- maybe_sliced_user_path.push_back(const_cast<HloInstruction*>(cur));
- if (const auto slice_instr =
- DynCast<HloDynamicUpdateSliceInstruction>(cur)) {
- if (IsAlignedSlice(slice_instr)) {
- dus_found = true;
- return true;
- }
- }
- return cur->user_count() > 1 || !IsNoOp(cur);
- },
- /*visit_operands=*/false);
- if (maybe_dus_instr == std::nullopt) return;
- if (dus_found || processed_instrs.contains(maybe_dus_instr.value())) {
- // Even in the case of stopping at a match that has been processed, we
- // still need to add instructions encountered in the sliced user path
- // during the latest traversal.
- processed_instrs.insert(maybe_sliced_user_path.begin(),
- maybe_sliced_user_path.end());
- sliced_user_paths.push_back(std::move(maybe_sliced_user_path));
- }
- };
-
- if (instr->shape().IsTuple()) {
- for (auto* user : instr->users()) {
- if (DynCast<HloGetTupleElementInstruction>(user)) {
- traverse_hlo_and_collect(user);
- }
- }
- } else {
- if (instr->user_count() == 1) {
- traverse_hlo_and_collect(instr->users().front());
- }
- }
-
- return sliced_user_paths;
-}
-
-absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
- DataflowPathView matches) {
- absl::InlinedVector<HloInstruction*, 4> captures;
-
- InstructionSet matched_instrs(matches.begin(), matches.end());
-
- for (HloInstruction* instr : matches) {
- for (HloInstruction* operand : instr->operands()) {
- if (!matched_instrs.contains(operand) &&
- absl::c_find(captures, operand) == captures.end()) {
- captures.emplace_back(operand);
- }
- }
- }
-
- return captures;
-}
-
-absl::Status CreateRootTuple(
- HloInstruction* hero, HloComputation::Builder& builder,
- DataflowPathsView sliced_user_paths,
- absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
- instr_mapping) {
- unsigned tuple_size = hero->shape().tuple_shapes_size();
-
- std::vector<HloInstruction*> sliced_elems(tuple_size, nullptr);
- for (auto& sliced_user_path : sliced_user_paths) {
- auto gte = Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
- sliced_elems[gte->tuple_index()] = sliced_user_path.back();
- }
-
- std::vector<HloInstruction*> elements;
- for (size_t i = 0; i < tuple_size; ++i) {
- if (sliced_elems[i] != nullptr) {
- elements.push_back(instr_mapping[sliced_elems[i]]);
- continue;
- }
- auto* gte = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(instr_mapping[hero], i));
- if (hero->shape().tuple_shapes(i).IsTuple()) {
- instr_mapping[gte] = gte;
- TF_RETURN_IF_ERROR(CreateRootTuple(gte, builder, {}, instr_mapping));
- elements.push_back(builder.last_added_instruction());
- } else {
- elements.push_back(gte);
- }
- }
- if (elements.size() > 1)
- builder.AddInstruction(HloInstruction::CreateTuple(elements));
-
- return absl::OkStatus();
-}
-
-absl::StatusOr<HloComputation*> CreateFusionBody(
- HloModule* module, DataflowPathView sliced_operand_paths,
- DataflowPathsView sliced_user_paths, DataflowPathView captures) {
- HloComputation::Builder builder("address-computation");
-
- // A mapping from original instructions to instructions in the fusion body.
- absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
-
- auto mapped_operands = [&](HloInstruction* instr) {
- absl::InlinedVector<HloInstruction*, 4> operands;
- for (HloInstruction* operand : instr->operands()) {
- operands.push_back(instr_mapping.at(operand));
- }
- return operands;
- };
-
- // For every captured value create a parameter instruction in the computation
- // body and set up instruction mapping.
- for (const HloInstruction* capture : captures) {
- int64_t index = instr_mapping.size();
- instr_mapping[capture] =
- builder.AddInstruction(HloInstruction::CreateParameter(
- index, capture->shape(), absl::StrCat("p", index)));
- }
-
- // Instructions in the pattern are already topologically sorted, as we
- // visited them by following the use-def path and then reversed the list.
- HloInstruction* hero;
- for (HloInstruction* instr : sliced_operand_paths) {
- instr_mapping[instr] = builder.AddInstruction(
- instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
- hero = instr;
- }
-
- for (auto& sliced_user_path : sliced_user_paths) {
- for (HloInstruction* instr : sliced_user_path) {
- instr_mapping[instr] = builder.AddInstruction(
- instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
- }
- }
-
- // If the hero returns a tuple, create a root tuple so that a buffer is
- // assigned for each of its elements. Skip nil (zero-element) tuples.
- if (hero->shape().IsTuple() && hero->shape().tuple_shapes_size() > 0) {
- TF_RETURN_IF_ERROR(
- CreateRootTuple(hero, builder, sliced_user_paths, instr_mapping));
- }
-
- return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
-}
-
-absl::StatusOr<HloInstruction*> CreateFusionInstruction(
- HloModule* module, HloInstruction* orig, DataflowPathView captures,
- HloComputation* body, bool dynamic) {
- HloComputation* parent = orig->parent();
-
- // Add a fusion operation calling the outlined fusion computation.
- HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
- body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
- captures, body));
- module->SetAndUniquifyInstrName(fusion, "address_computation");
-
- // We don't need to set/update output_to_operand_aliasing for the new fusion
- // instruction because all buffers are already assigned at this point.
-
- // Set the backend config to the matched custom fusion config.
- GpuBackendConfig gpu_config;
- FusionBackendConfig& backend_config =
- *gpu_config.mutable_fusion_backend_config();
- backend_config.set_kind("__custom_fusion");
- CustomFusionConfig config;
- config.set_name(dynamic ? "dynamic_address_computation"
- : "address_computation");
- *backend_config.mutable_custom_fusion_config() = config;
- TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
-
- return fusion;
-}
-
-} // namespace
-
-absl::StatusOr<bool> DynamicSliceFusionRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- absl::flat_hash_map<HloInstruction*,
- std::pair<UseDefDataflowPaths, DefUseDataflowPaths>>
- matches;
-
- // Collect all potential custom call matches in the non-fusion computations.
- for (HloComputation* computation : module->computations()) {
- if (computation->IsFusionComputation()) continue;
- for (HloInstruction* instr : computation->instructions()) {
- if (IsLegacyCublasMatmul(*instr) ||
- (IsCustomCall(instr, platform_name_))) {
- UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr);
- bool has_sliced_operand_paths = sliced_operand_paths.size() > 1;
-
- DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr);
- bool has_sliced_user_paths = absl::c_any_of(
- sliced_user_paths,
- [&](auto& sliced_user_path) { return !sliced_user_path.empty(); });
-
- if (absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) {
- return DynCast<HloDynamicUpdateSliceInstruction>(
- sliced_user_path.back()) == nullptr;
- })) {
- return absl::InternalError(
- "Expect sliced user path to end with a DUS.");
- }
-
- if (has_sliced_operand_paths || has_sliced_user_paths) {
- matches[instr] = std::make_pair(std::move(sliced_operand_paths),
- std::move(sliced_user_paths));
- }
- }
- }
- }
-
- if (matches.empty()) return false;
-
- for (auto& [hero, paths] : matches) {
- auto& [sliced_operand_paths, sliced_user_paths] = paths;
- std::vector<HloInstruction*> matched_instrs;
- absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs));
-
- std::vector<DataflowPathView> sliced_user_paths_view;
- for (auto& sliced_user_path : sliced_user_paths) {
- absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs));
- DataflowPathView sliced_user_path_view{&sliced_user_path.front(),
- sliced_user_path.size()};
- sliced_user_paths_view.push_back(std::move(sliced_user_path_view));
- }
-
- auto captures = GetPatternCaptures(matched_instrs);
-
- TF_ASSIGN_OR_RETURN(
- HloComputation * fusion_body,
- CreateFusionBody(module, sliced_operand_paths,
- DataflowPathsView(sliced_user_paths_view), captures));
-
- bool has_dynamic_slices = absl::c_any_of(matched_instrs, [&](auto* instr) {
- return DynCast<HloDynamicIndexInstruction>(instr) != nullptr;
- });
- TF_ASSIGN_OR_RETURN(
- HloInstruction * fusion,
- CreateFusionInstruction(module, hero, captures, fusion_body,
- has_dynamic_slices));
-
- HloComputation* parent = hero->parent();
- if (fusion->shape().IsTuple()) {
- TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape(
- const_cast<HloInstruction*>(hero), fusion));
- for (auto& sliced_user_path : sliced_user_paths) {
- auto old_gte =
- Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
- HloInstruction* gte =
- parent->AddInstruction(HloInstruction::CreateGetTupleElement(
- fusion, old_gte->tuple_index()));
- TF_RETURN_IF_ERROR(
- parent->ReplaceInstruction(sliced_user_path.back(), gte));
- }
- } else {
- auto* instr_to_be_replaced = const_cast<HloInstruction*>(hero);
- if (sliced_user_paths.empty()) {
- // The only case where a tuple-shaped original hero op is fused into a
- // non-tuple-shaped fusion is when only one element of the original tuple
- // is used. In that case, we need to replace that single get-tuple-element
- // (instead of the hero op) with the fusion instruction.
- if (hero->shape().IsTuple()) {
- if (hero->user_count() != 1 ||
- !DynCast<HloGetTupleElementInstruction>(hero->users().front())) {
- return absl::InternalError(
- "Expect a single get-tuple-element user of the original "
- "tuple-shaped hero op when address computation fusion does "
- "not return a tuple");
- }
- instr_to_be_replaced = hero->users().front();
- }
- } else {
- instr_to_be_replaced = sliced_user_paths.front().back();
- }
- TF_RETURN_IF_ERROR(
- parent->ReplaceInstruction(instr_to_be_replaced, fusion));
- }
- }
-
- return true;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h b/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h
deleted file mode 100644
index 15da284..0000000
--- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_
-#define XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_
-
-#include <string>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Pattern matches (slice(s) + custom call) to custom address computation
-// fusions and rewrites them into fusion instructions and fusion computations.
-//
-// Example:
-//
-// ENTRY %main {
-// %p0 = bf16[2,8,8]{2,1,0} parameter(0)
-// %p1 = bf16[2,8,8]{2,1,0} parameter(1)
-// %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-// %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
-// %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-// %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
-// ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
-// custom_call_target="__cublas$gemm"
-// }
-//
-// After the pass:
-//
-// %address_computation {
-// %p0 = bf16[2,8,8]{2,1,0} parameter(0)
-// %p1 = bf16[2,8,8]{2,1,0} parameter(1)
-// %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-// %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
-// %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-// %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
-// ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
-// custom_call_target="__cublas$gemm"
-// }
-//
-// ENTRY %main {
-// %p0 = bf16[2,8,8]{2,1,0} parameter(0)
-// %p1 = bf16[2,8,8]{2,1,0} parameter(1)
-// ROOT %fusion.2 = bf16[8,8]{1,0} fusion(%p0, %p1),
-// kind=kCustom, calls=%address_computation,
-// backend_config={"fusion_backend_config":{
-// "kind":"__custom_fusion",
-// "custom_fusion_config":{"name":"address_computation"}
-// }}
-// }
-//
-class DynamicSliceFusionRewriter : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "address-computation-fusion-rewriter";
- }
-
- explicit DynamicSliceFusionRewriter(std::string platform_name)
- : platform_name_(std::move(platform_name)) {}
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- std::string platform_name_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc
deleted file mode 100644
index a539fb5..0000000
--- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc
+++ /dev/null
@@ -1,1755 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
-
-#include <cstddef>
-#include <optional>
-
-#include "absl/status/status.h"
-#include "xla/client/lib/constants.h"
-#include "xla/client/xla_builder.h"
-#include "xla/ffi/ffi.h"
-#include "xla/ffi/ffi_api.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/custom_call_target_registry.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/gpu/gpu_types.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla::gpu {
-
-class DynamicSliceFusionRewriterTest : public HloTestBase {};
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemm) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWithWorkspace) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
- ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
- ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
- ; CHECK: tuple([[DOT]], [[WORKSPACE]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWorkspaceIgnored) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- ROOT %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
- ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
- ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
- ; CHECK: tuple([[DOT]], [[WORKSPACE]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotRoot) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandHasMultipleUsers) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[4,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[2:3], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %bitcast.41)
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[2:3], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[P0]], [[P1]])
- ; CHECK-DAG: kind=kCustom, calls=%address-computation,
- ; CHECK-DAG: backend_config={
- ; CHECK-DAG: "kind":"__custom_fusion",
- ; CHECK-DAG: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK-DAG: }
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[B0]])
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsHaveMultipleUsers) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
-
- ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation{{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
- ; CHECK: %address-computation{{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmSlicingNotParameter) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[4,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.12 = f16[2,8,8]{2,1,0} slice(%p0), slice={[0:2], [0:8], [0:8]}
- %slice.13 = f16[1,8,8]{2,1,0} slice(%slice.12), slice={[1:2], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[2,8,8]{2,1,0} slice([[P0]]), slice={[0:2], [0:8], [0:8]}
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[S0]], [[P1]])
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotContiguousSlice) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,4,6]{2,1,0} slice(%p0), slice={[1:2], [0:4], [0:6]}
- %bitcast.41 = f16[4,6]{1,0} bitcast(%slice.13)
- %slice.14 = f16[1,6,4]{2,1,0} slice(%p1), slice={[1:2], [0:6], [0:4]}
- %bitcast.42 = f16[6,4]{1,0} bitcast(%slice.14)
-
- ROOT %custom-call.1 = f16[4,4]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
- std::nullopt);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
- %add.0 = f16[1,8,8]{2,1,0} add(%slice.13, %slice.14)
- %bitcast.41 = f16[8,8]{1,0} bitcast(%add.0)
- %slice.15 = f16[1,8,8]{2,1,0} slice(%p1), slice={[0:1], [0:8], [0:8]}
- %slice.16 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %add.1 = f16[1,8,8]{2,1,0} add(%slice.15, %slice.16)
- %bitcast.42 = f16[8,8]{1,0} bitcast(%add.1)
-
- ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
- std::nullopt);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmDuplicateOperand) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main {
- %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
- %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
- %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
- %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0}
- %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240),
- custom_call_target="__cublas$gemm",
- backend_config={
- "gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"20000",
- "rhs_stride":"10000",
- "grad_x":false,
- "grad_y":false
- }
- }
- %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0
- %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]}
- ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26),
- custom_call_target="__cublas$gemm",
- backend_config={
- "gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"10000",
- "rhs_stride":"10000",
- "grad_x":false,
- "grad_y":false
- }
- }
- })";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
- ; CHECK: [[S0:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[0:100], [0:100]}
- ; CHECK-NOT: slice
- ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) custom-call([[S0]], [[S0]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %p1 = f16[2,8,8]{2,1,0} parameter(0)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder2) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %p1 = f16[2,8,8]{2,1,0} parameter(1)
- %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
- %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
- ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[0:1], [0:8], [0:8]}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
- ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandAliasingOutput) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
- %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
- %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
- %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0}
- %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[16:116], [0:100]}
- %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]}
- ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34),
- custom_call_target="__cublas$gemm",
- output_to_operand_aliasing={{0}: (2, {})},
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":1,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"10000",
- "rhs_stride":"10000",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P2:%[^ ]+]] = f32[100,100]{1,0} parameter(2)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f32[100,100]{1,0} parameter(1)
- ; CHECK-DAG: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
- ; CHECK-DAG: [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[16:116], [0:100]}
- ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P1]], [[S1]], [[P2]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[P:%[^ ]+]] = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
- ; CHECK: [[GTE0:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=0
- ; CHECK: [[GTE1:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=1
- ; CHECK: [[CONCAT:%[^ ]+]] = f32[200,100]{1,0} concatenate([[GTE0]], [[GTE1]]), dimensions={0}
- ; CHECK: [[S:%[^ ]+]] = f32[100,100]{1,0} slice([[CONCAT]]), slice={[99:199], [0:100]}
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[CONCAT]], [[GTE0]], [[S]])
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsFromSameSlice) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[2,8,8]{2,1,0} parameter(0)
- %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
- %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
- %bitcast.42 = f16[8,8]{0,1} bitcast(%slice.13)
-
- ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{0,1} bitcast([[S0]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]])
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-static absl::Status Memcpy(se::Stream* stream, ffi::AnyBuffer src,
- ffi::AnyBuffer dst) {
- se::DeviceMemoryBase dst_mem = dst.device_memory();
- se::DeviceMemoryBase src_mem = src.device_memory();
- return stream->MemcpyD2D(&dst_mem, src_mem, src_mem.size());
-}
-
-XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
- ffi::Ffi::Bind()
- .Ctx<ffi::Stream>()
- .Arg<ffi::AnyBuffer>() // src
- .Arg<ffi::AnyBuffer>() // dst
-);
-XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", "gpu",
- kMemcpy);
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCall) {
- XlaBuilder b(TestName());
- CustomCall(&b, "__xla_test$$memcpy",
- /*operands=*/
- {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
- {128}, {1})},
- ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"",
- /*has_side_effect=*/false,
- /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
- /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
- /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
- TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
- xla::HloModuleConfig hlo_config(
- xla::ProgramShape(computation.proto().host_program_shape()),
- /*ignore_layouts=*/false);
- DebugOptions debug_options = GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
- hlo_config.set_debug_options(debug_options);
- TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
- computation.proto(), hlo_config));
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
- ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
- ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
- ; CHECK: custom_call_target="__xla_test$$memcpy",
- ; CHECK: api_version=API_VERSION_TYPED_FFI
- ; CHECK: }
-
- ; CHECK: ENTRY %{{.*}} {
- ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42)
- ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
- expected);
-}
-
-void Callback_Void(se::gpu::GpuStreamHandle stream, void** buffers,
- const char* /*opaque*/, size_t /*opaque_len*/) {}
-
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, "gpu");
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCallLegacy) {
- XlaBuilder b(TestName());
- CustomCall(&b, "Callback_Void",
- /*operands=*/
- {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
- {128}, {1})},
- ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
- TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
- xla::HloModuleConfig hlo_config(
- xla::ProgramShape(computation.proto().host_program_shape()),
- /*ignore_layouts=*/false);
- DebugOptions debug_options = GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
- hlo_config.set_debug_options(debug_options);
- TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
- computation.proto(), hlo_config));
- // TF_ASSERT_OK_AND_ASSIGN(
- // HloSchedule schedule,
- // ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
- // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
- // }));
- // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
- ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
- ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
- ; CHECK: custom_call_target="Callback_Void"
- ; CHECK: }
-
- ; CHECK: ENTRY %{{.*}} {
- ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42)
- ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
- expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, TupleSliceCustomCallLegacy) {
- XlaBuilder b(TestName());
- CustomCall(
- &b, "Callback_Void",
- /*operands=*/
- {
- Tuple(&b,
- {
- Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
- {0, 0}, {4, 8}, {1, 1}),
- Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
- }),
- Tuple(&b,
- {
- Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
- Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
- }),
- },
- ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
- TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
- xla::HloModuleConfig hlo_config(
- xla::ProgramShape(computation.proto().host_program_shape()),
- /*ignore_layouts=*/false);
- DebugOptions debug_options = GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
- hlo_config.set_debug_options(debug_options);
- TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
- computation.proto(), hlo_config));
- // TF_ASSERT_OK_AND_ASSIGN(
- // HloSchedule schedule,
- // ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
- // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
- // }));
- // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
- ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
- ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
- ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
- ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[T0]], [[P2]]),
- ; CHECK: custom_call_target="Callback_Void"
- ; CHECK: }
-
- ; CHECK: ENTRY %{{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion(
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
- expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, TupledOutputCustomCallLegacy) {
- XlaBuilder b(TestName());
- auto custom_call = CustomCall(
- &b, "Callback_Void",
- /*operands=*/
- {
- Tuple(&b,
- {
- Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
- {0, 0}, {4, 8}, {1, 1}),
- Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
- }),
- Tuple(&b,
- {
- Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
- Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
- }),
- },
- ShapeUtil::MakeTupleShape({
- ShapeUtil::MakeShape(F32, {8}),
- ShapeUtil::MakeTupleShape({
- ShapeUtil::MakeShape(F32, {128}),
- ShapeUtil::MakeShape(F32, {256}),
- }),
- ShapeUtil::MakeShape(F32, {1024}),
- ShapeUtil::MakeShape(F32, {4, 8}),
- }),
- /*opaque=*/"");
- Tuple(&b, {GetTupleElement(GetTupleElement(custom_call, 1), 0),
- GetTupleElement(custom_call, 2)});
- TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
- xla::HloModuleConfig hlo_config(
- xla::ProgramShape(computation.proto().host_program_shape()),
- /*ignore_layouts=*/false);
- DebugOptions debug_options = GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
- hlo_config.set_debug_options(debug_options);
- TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
- computation.proto(), hlo_config));
- // TF_ASSERT_OK_AND_ASSIGN(
- // HloSchedule schedule,
- // ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
- // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
- // }));
- // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
- const char* expected = R"(
- ; CHECK: %address-computation {{.*}} {
- ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
- ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
- ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
- ; CHECK: [[CC:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) custom-call([[T0]], [[P2]]),
- ; CHECK: custom_call_target="Callback_Void"
- ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[8]{0} get-tuple-element([[CC]]), index=0
- ; CHECK-DAG: [[GTE1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[CC]]), index=1
- ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE1]]), index=0
- ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[256]{0} get-tuple-element([[GTE1]]), index=1
- ; CHECK-DAG: [[T1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) tuple([[GTE2]], [[GTE3]])
- ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1024]{0} get-tuple-element([[CC]]), index=2
- ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,8]{1,0} get-tuple-element([[CC]]), index=3
- ; CHECK: ROOT {{.*}} = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) tuple([[GTE0]], [[T1]], [[GTE4]], [[GTE5]])
- ; CHECK: }
-
- ; CHECK: ENTRY %{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK-DAG: [[GTE6:%[^ ]+]] = f32[1024]{0} get-tuple-element([[FUSION]]), index=2
- ; CHECK-DAG: [[GTE7:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[FUSION]]), index=1
- ; CHECK-DAG: [[GTE8:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE7]]), index=0
- ; CHECK: ROOT {{.*}} = (f32[128]{0}, f32[1024]{0}) tuple([[GTE8]], [[GTE6]])
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
- expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, UnalignedSlice) {
- XlaBuilder b(TestName());
- CustomCall(
- &b, "Callback_Void",
- /*operands=*/
- {Slice(Broadcast(ConstantR0WithType(&b, S32, 42), {17}), {1}, {17}, {1})},
- ShapeUtil::MakeShape(S32, {16}), /*opaque=*/"");
- TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
- xla::HloModuleConfig hlo_config(
- xla::ProgramShape(computation.proto().host_program_shape()),
- /*ignore_layouts=*/false);
- DebugOptions debug_options = GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
- hlo_config.set_debug_options(debug_options);
- TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
- computation.proto(), hlo_config));
- // TF_ASSERT_OK_AND_ASSIGN(
- // HloSchedule schedule,
- // ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
- // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
- // }));
- // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
- std::nullopt);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemm) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY main.9 {
- p0 = f16[2,8,8]{2,1,0} parameter(0)
- p1 = f16[2,8,8]{2,1,0} parameter(1)
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
- slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
- ROOT custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWithWorkspace) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY main.9 {
- p0 = f16[2,8,8]{2,1,0} parameter(0)
- p1 = f16[2,8,8]{2,1,0} parameter(1)
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
- slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
- ROOT custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- }
- )";
-
- const char* expected = R"(
- ; CHECK: address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
- ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
- ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
- ; CHECK: tuple([[DOT]], [[WORKSPACE]])
- ; CHECK: }
-
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWorkspaceIgnored) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY main.9 {
- p0 = f16[2,8,8]{2,1,0} parameter(0)
- p1 = f16[2,8,8]{2,1,0} parameter(1)
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
- slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
- custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- ROOT get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
- }
- )";
-
- const char* expected = R"(
- ; CHECK: address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
- ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
- ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
- ; CHECK: tuple([[DOT]], [[WORKSPACE]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmNotRoot) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY main.9 {
- p0 = f16[2,8,8]{2,1,0} parameter(0)
- p1 = f16[2,8,8]{2,1,0} parameter(1)
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
- slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
- custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- ROOT res = f16[8,8]{1,0} add(custom-call.1, custom-call.1)
- }
- )";
-
- const char* expected = R"(
- ; CHECK: address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemm) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY main.9 {
- p0 = f16[1,8,8]{2,1,0} parameter(0)
- p1 = f16[1,8,8]{2,1,0} parameter(1)
- p2 = f16[4,8,8]{2,1,0} parameter(2)
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- bitcast.41 = f16[8,8]{1,0} bitcast(p0)
- bitcast.42 = f16[8,8]{1,0} bitcast(p1)
-
- custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
- ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
- }
- )";
-
- const char* expected = R"(
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
- ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(3)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(4)
- ; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]),
- ; CHECK-DAG: custom_call_target="__cublas$gemm"
- ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
- ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmNotRoot) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY main.9 {
- p0 = f16[2,8,8]{2,1,0} parameter(0)
- p1 = f16[2,8,8]{2,1,0} parameter(1)
- p2 = f16[4,8,8]{2,1,0} parameter(2)
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
- slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
- custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
- dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
- ROOT res = f16[4,8,8]{2,1,0} log(dus)
- }
- )";
-
- const char* expected = R"(
- ; CHECK: address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
- ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
- ; CHECK-DAG: custom_call_target="__cublas$gemm"
- ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
- ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} log([[FUSION]])
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWithWorkspace) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY main.9 {
- p0 = f16[2,8,8]{2,1,0} parameter(0)
- p1 = f16[2,8,8]{2,1,0} parameter(1)
- p2 = f16[4,8,8]{2,1,0} parameter(2)
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
- slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
- bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
- custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
-
- get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
- bitcast.43 = f16[1,8,8]{2,1,0} bitcast(get-tuple-element.0)
- dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
- get-tuple-element.1 = s8[256]{0} get-tuple-element(custom-call.1), index=1
- ROOT tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(dus, get-tuple-element.1)
- }
- )";
-
- const char* expected = R"(
- ; CHECK: address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
- ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
- ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
- ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
- ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
- ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
- ; CHECK: custom_call_target="__cublas$gemm"
- ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
- ; CHECK: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
- ; CHECK: [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
- ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
- ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
- ; CHECK: tuple([[DUS]], [[WORKSPACE]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: [[DUS_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
- ; CHECK: [[WORKSPACE_MAIN:%[^ ]+]] = s8[256]{0} get-tuple-element([[FUSION]]), index=1
- ; CHECK: ROOT {{.*}} = (f16[4,8,8]{2,1,0}, s8[256]{0})
- ; CHECK: tuple([[DUS_MAIN]], [[WORKSPACE_MAIN]])
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) {
- const char* hlo = R"(
- HloModule test
-
- ENTRY %main.9 {
- %p0 = f16[8,8]{1,0} parameter(0)
- %p1 = f16[8,8]{1,0} parameter(1)
- %p2 = f16[4,8,8]{2,1,0} parameter(2)
- %c1_s32 = s32[] constant(1)
- %c0_s32 = s32[] constant(0)
-
- %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%p0, %p1),
- custom_call_target="__cublas$gemm",
- backend_config={"gemm_backend_config":{
- "alpha_real":1,
- "beta":0,
- "dot_dimension_numbers":{
- "lhs_contracting_dimensions":["1"],
- "rhs_contracting_dimensions":["0"],
- "lhs_batch_dimensions":[],
- "rhs_batch_dimensions":[]
- },
- "alpha_imag":0,
- "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
- "epilogue":"DEFAULT",
- "lhs_stride":"64",
- "rhs_stride":"64",
- "grad_x":false,
- "grad_y":false
- }}
- %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
- %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0)
- ROOT %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32)
- })";
-
- const char* expected = R"(
- ; CHECK: address-computation {{.*}} {
- ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
- ; CHECK-DAG: [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
- ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
- ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(3)
- ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(4)
- ; CHECK-DAG: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[P0]], [[P1]]),
- ; CHECK-DAG: custom_call_target="__cublas$gemm"
- ; CHECK-DAG: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
- ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
- ; CHECK-DAG: [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
- ; CHECK-DAG: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
- ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
- ; CHECK: tuple([[DUS]], [[WORKSPACE]])
- ; CHECK: }
-
- ; CHECK: ENTRY %main{{.*}} {
- ; CHECK: [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
- ; CHECK: kind=kCustom, calls=%address-computation,
- ; CHECK: backend_config={
- ; CHECK: "kind":"__custom_fusion",
- ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
- ; CHECK: }
- ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
- ; CHECK: }
- )";
-
- auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusion_merger.cc b/third_party/xla/xla/service/gpu/fusion_merger.cc
deleted file mode 100644
index 3703c98..0000000
--- a/third_party/xla/xla/service/gpu/fusion_merger.cc
+++ /dev/null
@@ -1,327 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusion_merger.h"
-
-#include <optional>
-#include <string>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/model/gpu_performance_model.h"
-#include "xla/service/gpu/model/gpu_performance_model_base.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
-
-namespace xla {
-namespace gpu {
-
-// For each fusion F, attempts to fuse F into *all* of F's users; if F cannot
-// be fused into every one of them, it is not fused at all.
-class FusionInstructionMerger {
- public:
- explicit FusionInstructionMerger(
- HloComputation* computation, const se::DeviceDescription& gpu_device_info,
- HloCostAnalysis::ShapeSizeFunction shape_size_function)
- : computation_(computation),
- shape_size_function_(shape_size_function),
- gpu_device_info_(gpu_device_info),
- dump_fusion_visualization_(computation->parent()
- ->config()
- .debug_options()
- .xla_dump_fusion_visualization()) {}
-
- absl::Status Run();
-
- bool changed() const { return changed_; }
-
- private:
- FusionDecision ShouldFuse(HloInstruction* producer);
- absl::Status FuseIntoAllUsers(HloInstruction* producer);
-
- HloComputation* computation_;
- HloCostAnalysis::ShapeSizeFunction shape_size_function_;
-  // Many cheap checks can prevent fusion merging - postpone the full HLO cost
-  // analysis of the computation so that it may not be needed at all.
- std::optional<GpuHloCostAnalysis> cost_analysis_;
- FusionInfoCache fusion_info_cache_;
- const se::DeviceDescription& gpu_device_info_;
- bool changed_ = false;
- bool dump_fusion_visualization_ = false;
-
- // Fusion instruction merge stats.
- int total_visited_ = 0;
- int total_merged_ = 0;
- int num_fail_no_users_ = 0;
- int num_fail_not_loop_fusion_ = 0;
- int num_fail_merge_all_users_ = 0;
- int num_fail_inefficient_fusion_emitter_ = 0;
- int num_fail_fusion_too_large_ = 0;
- int num_fail_uncoalesced_read_ = 0;
- int num_fail_slower_if_fused_ = 0;
-
- FusionInstructionMerger(const FusionInstructionMerger&) = delete;
- FusionInstructionMerger& operator=(const FusionInstructionMerger&) = delete;
-};
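
A minimal standalone sketch of the all-or-nothing merge policy described in the class comment above, using toy types and a hypothetical fusibility callback rather than the XLA HLO classes:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct ToyFusion {
  std::string name;
  std::vector<ToyFusion*> users;
};

// Mirrors the policy: merge the producer only if it can be fused into *every*
// user; a single incompatible user blocks the merge entirely.
bool ShouldMergeIntoAllUsers(
    const ToyFusion& producer,
    const std::function<bool(const ToyFusion&, const ToyFusion&)>& fusible) {
  if (producer.users.empty()) return false;  // "fusion has no users"
  for (const ToyFusion* user : producer.users) {
    if (!fusible(producer, *user)) return false;
  }
  return true;
}

int main() {
  ToyFusion b{"B", {}}, c{"C", {}};
  ToyFusion a{"A", {&b, &c}};
  // Pretend consumer C rejects fusion for some reason (e.g. budget).
  auto fusible = [](const ToyFusion&, const ToyFusion& user) {
    return user.name != "C";
  };
  // Prints "false": because A cannot fuse into C, it is not fused into B either.
  std::cout << std::boolalpha << ShouldMergeIntoAllUsers(a, fusible) << "\n";
}
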
-
-absl::Status FusionInstructionMerger::FuseIntoAllUsers(
- HloInstruction* producer) {
-  // Merge fused instructions from 'producer' into each user.
- std::vector<HloInstruction*> users = producer->users();
- for (HloInstruction* user : users) {
- if (dump_fusion_visualization_) {
- RegisterFusionState(
- *computation_,
- absl::StrCat("About to fuse |", producer->name(), "| into |",
- user->name(), "| inside FusionMerger"),
- /*consumer=*/*user,
- /*producer=*/producer);
- }
-
- TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(user));
-
- // Wrap consumers which are not fusions first.
- HloInstruction* consumer = user;
- if (consumer->opcode() != HloOpcode::kFusion) {
- consumer = computation_->AddInstruction(HloInstruction::CreateFusion(
- user->shape(), ChooseFusionKind(*producer, *user), user));
- TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer));
- }
-
- consumer->MergeFusionInstruction(producer);
- TF_RETURN_IF_ERROR(cost_analysis_->RevisitInstruction(consumer));
- fusion_info_cache_.Invalidate(consumer);
-
- if (dump_fusion_visualization_) {
- RegisterFusionState(*computation_,
- absl::StrCat("Fused |", producer->name(), "| into |",
- user->name(), "| inside FusionMerger"),
- *consumer);
- }
-
- changed_ = true;
- }
-
- CHECK_EQ(0, producer->user_count()) << producer->ToString();
- TF_RETURN_IF_ERROR(computation_->RemoveInstruction(producer));
- TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(producer));
- fusion_info_cache_.Invalidate(producer);
- VLOG(2) << "Merged fusion instruction: " << producer->name()
- << " into users { "
- << absl::StrJoin(users, ", ",
- [](std::string* out, HloInstruction* user) {
- absl::StrAppend(out, user->name());
- })
- << " }";
- return absl::OkStatus();
-}
-
-absl::Status FusionInstructionMerger::Run() {
- for (HloInstruction* producer : computation_->MakeInstructionPostOrder()) {
- if (producer->opcode() != HloOpcode::kFusion) {
- continue;
- }
- FusionDecision should_fuse = ShouldFuse(producer);
- if (should_fuse) {
- TF_RETURN_IF_ERROR(FuseIntoAllUsers(producer));
- ++total_merged_;
- } else {
- VLOG(3) << "Not fusing fusion |" << producer->name()
-              << "| with all of its users due to: " << should_fuse.Explain();
- if (dump_fusion_visualization_ && !producer->users().empty()) {
- RegisterFusionState(
- *computation_,
- absl::StrCat(
- "Not fusing fusion |", producer->name(),
- "| into all of its users due to: ", should_fuse.Explain()),
- // Just pick any consumer, since we are trying to merge into all.
- /*consumer=*/*producer->users()[0],
- /*producer=*/producer);
- }
- }
- }
-
- VLOG(1) << "FusionInstructionMerger EXIT"
- << " computation: " << computation_->name()
- << " total_visited: " << total_visited_
- << " total_merged: " << total_merged_ << " merge failures { "
- << " no_users: " << num_fail_no_users_
- << " not_loop_fusion: " << num_fail_not_loop_fusion_
- << " merge_all_users: " << num_fail_merge_all_users_
- << " uncoalesced_read: " << num_fail_uncoalesced_read_
- << " inefficient_fusion_emitter: "
- << num_fail_inefficient_fusion_emitter_
- << " slower_if_fused: " << num_fail_slower_if_fused_
- << " fusion_too_large: " << num_fail_fusion_too_large_ << " }";
- return absl::OkStatus();
-}
-
-bool TransposesMostData(const HloInstruction& fusion) {
- float score = 0;
-
- for (const HloInstruction* instr : fusion.fused_instructions()) {
- if (IsPhysicallyTransposing(*instr)) {
- score += 1.0 * ShapeUtil::ElementsInRecursive(instr->shape()) /
- ShapeUtil::ElementsInRecursive(fusion.shape());
- if (score >= 0.5) {
- VLOG(3) << fusion.ToString() << " transpose ratio exceeds " << score;
- return true;
- }
- }
- }
-
- return false;
-}
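
For illustration only, the same ratio check with invented element counts (the 0.5 threshold mirrors TransposesMostData above; the sizes are hypothetical):

#include <iostream>

int main() {
  // A fusion whose output has 8192 elements and that contains one physically
  // transposing instruction over 5120 elements.
  const double fusion_elements = 8192.0;
  const double transposing_elements = 5120.0;
  const double score = transposing_elements / fusion_elements;  // 0.625
  std::cout << "score = " << score << ": "
            << (score >= 0.5 ? "transposes most data, avoid reduction users\n"
                             : "below threshold\n");
}
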
-
-FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) {
- ++total_visited_;
-
- VLOG(4) << "Considering producer " << producer->name();
-
- // Skip 'producer' instruction if there are no users into which we can
- // merge.
- if (producer->users().empty()) {
- ++num_fail_no_users_;
- return "fusion has no users";
- }
-
- // Skip 'producer' instruction if it is not a loop fusion. Library fusion
- // instructions match specific patterns, so they shouldn't be further fused.
- // Input fusion instructions need to be rooted at a particular HLO (e.g.
- // kReduce), so they shouldn't be further fused either.
- if (!producer->IsLoopFusion()) {
- ++num_fail_not_loop_fusion_;
- return "not a loop fusion";
- }
-
- auto producer_hero = GetRealHeroForMultiOutputFusion(*producer);
-
- bool has_reduction_user = false;
- for (const HloInstruction* user : producer->users()) {
- if (user->opcode() == HloOpcode::kBitcast) {
- ++num_fail_merge_all_users_;
- return "not fusing bitcast ops";
- }
- if (user->IsCustomFusion()) {
- ++num_fail_merge_all_users_;
- return "not fusing custom fusions";
- }
- auto consumer_hero = GetRealHeroForMultiOutputFusion(*user);
- if (auto compatible =
- FusionHeroesAreCompatible(producer_hero, consumer_hero);
- !compatible) {
- return compatible;
- }
- FusionDecision fusible = IsProducerConsumerFusible(*producer, *user);
- if (!fusible) {
- ++num_fail_merge_all_users_;
- VLOG(9) << user->ToString();
- return fusible;
- }
- if (IsInputFusibleReduction(*user)) {
- has_reduction_user = true;
- }
- }
-
- // We do not want to worsen reduction's memory access pattern by connecting
- // it to a producer which transposes most data.
- if (has_reduction_user && TransposesMostData(*producer)) {
- ++num_fail_uncoalesced_read_;
- return "would read mostly uncoalesced";
- }
-
- for (const HloInstruction* user : producer->users()) {
-    // Skip the 'producer' instruction if merging it into at least one of the
-    // users would make the fusion use too much shared memory or registers.
- FusionDecision fits = FusionFitsInBudget(
- *user, *producer, gpu_device_info_,
- /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
- if (!fits) {
- ++num_fail_fusion_too_large_;
- return fits;
- }
- }
-
- if (!cost_analysis_) {
- VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
- cost_analysis_.emplace(
- GpuHloCostAnalysis::Options{shape_size_function_,
- /*per_second_rates=*/{},
- /*count_multiple_input_accesses=*/true},
- gpu_device_info_);
- TF_CHECK_OK(computation_->Accept(&cost_analysis_.value()));
- }
-
- for (const HloInstruction* user : producer->users()) {
- if (cost_analysis_->ProducerConsumerMergedTooLarge(*producer, *user)) {
- ++num_fail_inefficient_fusion_emitter_;
- return FusionDecision{} << "if merged with " << user->name()
- << " will generate huge IR";
- }
- }
-
- GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
- producer, gpu_device_info_, &*cost_analysis_,
- GpuPerformanceModelOptions::Default(), producer->users());
- if (t.time_fused > t.time_unfused) {
- ++num_fail_slower_if_fused_;
- return "will execute slower if fused";
- }
-
- return {};
-}
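
The lazy construction of cost_analysis_ above is an instance of a common pattern: defer the expensive analysis until a candidate has survived every cheap check. A standalone sketch of that pattern, with toy types standing in for GpuHloCostAnalysis:

#include <initializer_list>
#include <iostream>
#include <optional>

struct ExpensiveAnalysis {
  ExpensiveAnalysis() { std::cout << "running expensive analysis once\n"; }
  bool TooCostly(int candidate) const { return candidate > 100; }
};

int main() {
  std::optional<ExpensiveAnalysis> analysis;
  for (int candidate : {-1, 7, 250}) {
    if (candidate < 0) continue;        // cheap check rejects; analysis never built
    if (!analysis) analysis.emplace();  // built lazily, at most once
    std::cout << candidate
              << (analysis->TooCostly(candidate) ? ": reject\n" : ": accept\n");
  }
}
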
-
-absl::StatusOr<bool> FusionMerger::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- VLOG(1) << "FusionMerger for module: " << module->name();
- for (auto* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- VLOG(9) << "Before running FusionInstructionMerger for computation: "
- << computation->name();
- XLA_VLOG_LINES(9, computation->ToString());
-
- FusionInstructionMerger fusion_merger(computation, gpu_device_info_,
- shape_size_function_);
- TF_RETURN_IF_ERROR(fusion_merger.Run());
- changed |= fusion_merger.changed();
-
- VLOG(9) << "After running FusionInstructionMerger for computation: "
- << computation->name() << " changed: " << changed;
- XLA_VLOG_LINES(9, computation->ToString());
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusion_merger.h b/third_party/xla/xla/service/gpu/fusion_merger.h
deleted file mode 100644
index acbc93e..0000000
--- a/third_party/xla/xla/service/gpu/fusion_merger.h
+++ /dev/null
@@ -1,85 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSION_MERGER_H_
-#define XLA_SERVICE_GPU_FUSION_MERGER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// An HLO pass that attempts to merge fusion instructions to reduce memory
-// bandwidth requirements and kernel launch overhead.
-//
-// Consider the example below. On the left-hand side, op A is the producer and
-// ops B and C are its consumers. FusionMerger duplicates producer ops and fuses
-// them into all consumers. The result is depicted on the right-hand side below.
-//
-//        p                    p
-//        |                  /   \
-//        v                 /     \
-//        A            +fusion+  +fusion+
-//      /   \          |  A'  |  |  A"  |
-//      |   |          |  |   |  |  |   |
-//      v   v          |  v   |  |  v   |
-//      B   C          |  B   |  |  C   |
-//                     +------+  +------+
-//
-// Op A has been cloned twice and fused with B and C. The kernel launch overhead
-// is reduced from 3 to 2. The memory bandwidth requirements may be reduced.
-// We trade 1 read of input(A) + 1 write and 2 reads of output(A) for 2 reads of
-// input(A). In general the achievable savings in memory bandwidth depend on
-// the difference between the memory read and written and on the number of
-// consumers. The FusionMerger pass takes this into account when making
-// fusion decisions.
-//
-// The pass traverses the HLO module in post-order (defs before uses).
-// Fusion instructions are merged into their users if some conditions are met:
-// * The result of merging the fusion instruction into its users would not
-// increase bytes transferred.
-// * Producer ops are fusible with _all_ consumers. If a producer is not
-//   fusible with even one of its consumers, it is not fused at all.
-// * Producers are kLoop fusion ops.
-//
-// None of these restrictions are necessary for correctness. In fact, lifting
-// the latter two could be beneficial.
-
-class FusionMerger : public HloModulePass {
- public:
- explicit FusionMerger(const se::DeviceDescription& d,
- HloCostAnalysis::ShapeSizeFunction f)
- : gpu_device_info_(d), shape_size_function_(f) {}
- absl::string_view name() const override { return "fusion_merger"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::DeviceDescription gpu_device_info_;
- HloCostAnalysis::ShapeSizeFunction shape_size_function_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSION_MERGER_H_
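As a reference for the pass declared above, here is a minimal sketch of constructing and running FusionMerger standalone, mirroring the fixture in the removed test file below. The RTX A6000 device info and the 8-byte pointer size are the values those tests use and are illustrative defaults rather than requirements; the include path assumes the new transforms/ location used elsewhere in this change.

#include <cstdint>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/transforms/fusion_merger.h"  // new location after this change
#include "xla/service/hlo_cost_analysis.h"
#include "xla/shape.h"
#include "xla/shape_util.h"

// Runs a single FusionMerger sweep over `module` and reports whether any
// producer fusion was merged into its consumers.
absl::StatusOr<bool> RunFusionMergerOnce(xla::HloModule* module) {
  // Shape-size function identical to the one in FusionMergerTest.
  xla::HloCostAnalysis::ShapeSizeFunction shape_size =
      [](const xla::Shape& shape) {
        constexpr int64_t kPointerSize = 8;
        return xla::ShapeUtil::ByteSizeOf(shape, kPointerSize);
      };
  xla::gpu::FusionMerger merger(
      xla::gpu::TestGpuDeviceInfo::RTXA6000DeviceInfo(), shape_size);
  // Single-argument overload pulled in via `using HloPassInterface::Run;`.
  return merger.Run(module);
}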
diff --git a/third_party/xla/xla/service/gpu/fusion_merger_test.cc b/third_party/xla/xla/service/gpu/fusion_merger_test.cc
deleted file mode 100644
index de45a4b..0000000
--- a/third_party/xla/xla/service/gpu/fusion_merger_test.cc
+++ /dev/null
@@ -1,1162 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusion_merger.h"
-
-#include <cstdint>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class FusionMergerTest : public HloTestBase {
- HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
- return [&](const Shape& shape) {
- constexpr int64_t kPointerSize = 8;
- return ShapeUtil::ByteSizeOf(shape, kPointerSize);
- };
- }
-
- public:
- FusionMerger fusion_merger_{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
- ShapeSizeBytesFunction()};
- FusionMergerTest() : HloTestBase() {}
-};
-
-// Tests that we can merge a fusion instruction that is below threshold.
-//
-// Computation after fusion merger pass (Fusion2 is merged into Fusion0 and
-// Fusion1):
-// Param
-// / | \
-// Fusion3 Fusion0 Fusion1
-// \ | /
-// Tuple
-//
-TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule MergeSharedFusionInstruction
-
-comp.3 {
- constant.param_0 = f32[4]{0} parameter(0)
- param.param_1.2 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(1)
- get-tuple-element.6 = f32[4]{0} get-tuple-element(param.param_1.2), index=0
- ROOT add.7 = f32[4]{0} add(constant.param_0, get-tuple-element.6)
-}
-
-comp.2 {
- param.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
- get-tuple-element.4 = f32[4]{0} get-tuple-element(param.param_1.1), index=1
- get-tuple-element.5 = f32[4]{0} get-tuple-element(param.param_1.1), index=2
- ROOT add.6 = f32[4]{0} add(get-tuple-element.4, get-tuple-element.5)
-}
-
-comp.1 {
- add.1.param_1.1 = f32[4]{0} parameter(1)
- constant.param_1.3 = f32[4]{0} parameter(0)
- add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
- ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
-}
-
-comp {
- add.1.param_1 = f32[4]{0} parameter(1)
- constant.param_1.1 = f32[4]{0} parameter(0)
- multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
- ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
-}
-
-ENTRY MergeSharedFusionInstruction.Computation0 {
- constant = f32[4]{0} constant({1, 1, 1, 1})
- param = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
- fusion.3 = f32[4]{0} fusion(constant, param), kind=kLoop, calls=comp.3
- fusion.4 = f32[4]{0} fusion(param), kind=kLoop, calls=comp.2
- fusion.5 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp.1
- fusion.6 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp
- ROOT tuple = (f32[4]{0}, f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.5, fusion.6)
-})")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-
- auto* root = module->entry_computation()->root_instruction();
- EXPECT_EQ(HloOpcode::kTuple, root->opcode());
- // Check operand 0 (not merged). Should have 4 instructions.
- auto* operand0 = root->operand(0);
- EXPECT_EQ(HloOpcode::kFusion, operand0->opcode());
- EXPECT_EQ(4, operand0->fused_instruction_count());
- // Check operand 1 (should have merged in its operand fusion instruction).
- auto* operand1 = root->operand(1);
- EXPECT_EQ(HloOpcode::kFusion, operand1->opcode());
- EXPECT_EQ(7, operand1->fused_instruction_count());
- // Check operand 2 (should have merged in its operand fusion instruction).
- auto* operand2 = root->operand(2);
- EXPECT_EQ(HloOpcode::kFusion, operand2->opcode());
- EXPECT_EQ(7, operand2->fused_instruction_count());
-}
-
-TEST_F(FusionMergerTest, MoreMemoryAccessIfFused) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f32add {
- x = f32[] parameter(0)
- y = f32[] parameter(1)
- ROOT _ = f32[] add(x, y)
-}
-
-comp0 {
- p = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0)
- gte0 = f32[100000000] get-tuple-element(p), index=0
- gte1 = f32[100000000] get-tuple-element(p), index=1
- add.9 = f32[100000000] add(gte0, gte1)
- gte2 = f32[100000000] get-tuple-element(p), index=2
- add.10 = f32[100000000] add(add.9, gte2)
- gte3 = f32[100000000] get-tuple-element(p), index=3
- add.11 = f32[100000000] add(add.10, gte3)
- p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1)
- gte4 = f32[100000000] get-tuple-element(p1), index=0
- gte5 = f32[100000000] get-tuple-element(p1), index=1
- add.12 = f32[100000000] add(gte4, gte5)
- gte6 = f32[100000000] get-tuple-element(p1), index=2
- add.13 = f32[100000000] add(add.12, gte6)
- gte7 = f32[100000000] get-tuple-element(p1), index=3
- add.14 = f32[100000000] add(add.13, gte7)
- ROOT r = f32[100000000] add(add.14, add.11)
-}
-
-comp1 {
- p = f32[100000000] parameter(0)
- c0 = f32[] constant(0)
- ROOT r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
-}
-
-comp2 {
- p = f32[100000000] parameter(0)
- c0 = f32[] constant(0)
- r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
- ROOT n = f32[] negate(r)
-}
-
-ENTRY m.Computation2 {
- p0 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0)
- p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1)
- fusion.0 = f32[100000000] fusion(p0, p1), kind=kLoop, calls=comp0
- fusion.1 = f32[] fusion(fusion.0), kind=kLoop, calls=comp1
- fusion.2 = f32[] fusion(fusion.0), kind=kLoop, calls=comp2
- ROOT tuple = (f32[], f32[]) tuple(fusion.1, fusion.2)
-}
-)")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, LessMemoryAccessIfFused) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-comp.2 {
- state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
- get-tuple-element.5 = f32[4]{0} get-tuple-element(state.param_1.1), index=0
- get-tuple-element.6 = f32[4]{0} get-tuple-element(state.param_1.1), index=1
- add.7 = f32[4]{0} add(get-tuple-element.5, get-tuple-element.6)
- get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=2
- ROOT add.8 = f32[4]{0} add(add.7, get-tuple-element.7)
-}
-
-comp.1 {
- add.1.param_1.1 = f32[4]{0} parameter(1)
- constant.param_1.3 = f32[4]{0} parameter(0)
- add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
- ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
-}
-
-comp {
- add.1.param_1 = f32[4]{0} parameter(1)
- constant.param_1.1 = f32[4]{0} parameter(0)
- multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
- ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
-}
-
-ENTRY m.Computation2 {
- constant = f32[4]{0} constant({1, 1, 1, 1})
- state = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
- fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2
- fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1
- fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp
- ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4)
-})")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-// Check that we're willing to merge f1_computation into f2_computation, even
-// though f2 is an input fusion node.
-TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- f1_computation {
- f1_p0 = f32[32]{0} parameter(0)
- ROOT f1_root = f32[32]{0} add(f1_p0, f1_p0)
- }
-
- add_computation {
- add_lhs = f32[] parameter(0)
- add_rhs = f32[] parameter(1)
- ROOT add_root = f32[] add(add_lhs, add_rhs)
- }
-
- f2_computation {
- f2_p0 = f32[32]{0} parameter(0)
- f2_mul = f32[32]{0} multiply(f2_p0, f2_p0)
- f2_zero = f32[] constant(0)
- ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
- to_apply=add_computation
- }
-
- ENTRY entry {
- p0 = f32[32]{0} parameter(0)
- f1 = f32[32]{0} fusion(p0), kind=kLoop, calls=f1_computation
- ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
- })")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter())));
-}
-
-TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule jit_matmul.36
-
- max (parameter.13: f32[], parameter.14: f32[]) -> f32[] {
- parameter.13 = f32[] parameter(0)
- parameter.14 = f32[] parameter(1)
- ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14)
- }
-
- add (parameter.29: f32[], parameter.30: f32[]) -> f32[] {
- parameter.29 = f32[] parameter(0)
- parameter.30 = f32[] parameter(1)
- ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30)
- }
-
- fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] {
- param_1.4 = f32[200,200,200]{2,1,0} parameter(0)
- param_2.1 = f32[200,200]{1,0} parameter(1)
- broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2}
- subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3)
- exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0)
- constant.27 = f32[] constant(0)
- ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add
- }
-
- fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] {
- param_1.9 = f32[200,200]{1,0} parameter(1)
- broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1}
- param_0.7 = f32[200,200]{1,0} parameter(0)
- broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2}
- ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8)
- }
-
- ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] {
- parameter.2 = f32[200,200]{1,0} parameter(1)
- parameter.1 = f32[200,200]{1,0} parameter(0)
- fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3
- constant.11 = f32[] constant(-inf)
- reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max
- ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1
- })")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Fusion(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
- // TODO(b/247762001): the case here does not represent the problem -
- // profiling shows that it works faster if merged (even on larger dimensions).
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- f1_computation {
- f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
- add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
- // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
- ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
- }
-
- add_computation {
- add_lhs = f32[] parameter(0)
- add_rhs = f32[] parameter(1)
- ROOT add_root = f32[] add(add_lhs, add_rhs)
- }
-
- f2_computation {
- f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
- f2_zero = f32[] constant(0)
- ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
- to_apply=add_computation
- }
-
- ENTRY entry {
- p0 = f32[16,16,256]{0,1,2} parameter(0)
- f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
- ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
- })")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeReduceNotTooUnfriendlyLayouts) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- f1_computation {
- f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
- slice1 = f32[5,16,256]{0,1,2} slice(f1_p0), slice={[0:5], [0:16], [0:256]}
-    // Here the copy changes the layout of only part of the data.
- f1_copy = f32[5,16,256]{2,1,0} copy(slice1)
- slice2 = f32[11,16,256]{0,1,2} slice(f1_p0), slice={[0:11], [0:16], [0:256]}
- bitcast = f32[11,16,256]{2,1,0} bitcast(slice2)
- ROOT f1_root = f32[16,16,256]{2,1,0} concatenate(f1_copy, bitcast), dimensions={0}
- }
-
- add_computation {
- add_lhs = f32[] parameter(0)
- add_rhs = f32[] parameter(1)
- ROOT add_root = f32[] add(add_lhs, add_rhs)
- }
-
- f2_computation {
- f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
- f2_zero = f32[] constant(0)
- ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
- to_apply=add_computation
- }
-
- ENTRY entry {
- p0 = f32[16,16,256]{0,1,2} parameter(0)
- f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
- ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
- })")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-// Check that we limit the number of operands to fusions we create.
-TEST_F(FusionMergerTest, AvoidsLargeFusion) {
- constexpr int64_t kNumParams = MaxOperandsAndOutputsPerFusion() + 1;
-
- // Compute
- // p0 + p1 + p2 + ... + pn,
-  // using so many parameters that they do not fit into one fusion.
- auto module = CreateNewVerifiedModule();
- HloComputation::Builder b(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
-
- std::vector<HloInstruction*> entry_params;
-
- for (int64_t i = 0; i < kNumParams; ++i) {
- entry_params.push_back(
- b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
- }
- auto make_fusion = [&](absl::Span<HloInstruction* const> params) {
- // Build a fusion computation for calculating the sum of all parameters.
- HloComputation::Builder sub_builder("subcomp");
- HloInstruction* sum = nullptr;
- for (int64_t i = 0; i < params.size(); ++i) {
- auto p = sub_builder.AddInstruction(
- HloInstruction::CreateParameter(i, shape, "p"));
- if (sum == nullptr) {
- sum = p;
- } else {
- sum = sub_builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, p));
- }
- }
- HloComputation* subcomp =
- module->AddEmbeddedComputation(sub_builder.Build());
- return HloInstruction::CreateFusion(
- shape, HloInstruction::FusionKind::kLoop, params, subcomp);
- };
- auto fusion = b.AddInstruction(
- make_fusion(absl::MakeSpan(entry_params)
- .subspan(0, MaxOperandsAndOutputsPerFusion())));
- b.AddInstruction(make_fusion({entry_params.back(), fusion}));
- module->AddEntryComputation(b.Build());
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-// TODO(b/119692968): Remove this test once fusion emitter is fixed.
-TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f1 {
- Arg_0.5 = f32[200000] parameter(0)
- slice.7 = f32[100000] slice(Arg_0.5), slice={[0:199999:2]}
- slice.8 = f32[100000] slice(Arg_0.5), slice={[1:200000:2]}
- add.9 = f32[100000] add(slice.7, slice.8)
- slice.10 = f32[50000] slice(add.9), slice={[0:99999:2]}
- slice.11 = f32[50000] slice(add.9), slice={[1:100000:2]}
- add.12 = f32[50000] add(slice.10, slice.11)
- slice.13 = f32[25000] slice(add.12), slice={[0:49999:2]}
- slice.14 = f32[25000] slice(add.12), slice={[1:50000:2]}
- add.15 = f32[25000] add(slice.13, slice.14)
- slice.16 = f32[12500] slice(add.15), slice={[0:24999:2]}
- slice.17 = f32[12500] slice(add.15), slice={[1:25000:2]}
- add.18 = f32[12500] add(slice.16, slice.17)
- slice.19 = f32[6250] slice(add.18), slice={[0:12499:2]}
- slice.20 = f32[6250] slice(add.18), slice={[1:12500:2]}
- add.21 = f32[6250] add(slice.19, slice.20)
- slice.22 = f32[3125] slice(add.21), slice={[0:6249:2]}
- slice.23 = f32[3125] slice(add.21), slice={[1:6250:2]}
- ROOT add.24 = f32[3125] add(slice.22, slice.23)
-}
-
-f2 {
- Arg_0 = f32[3125] parameter(0)
- slice.25 = f32[1562] slice(Arg_0), slice={[0:3124:2]}
- slice.26 = f32[1562] slice(Arg_0), slice={[1:3125:2]}
- add.27 = f32[1562] add(slice.25, slice.26)
- slice.28 = f32[781] slice(add.27), slice={[0:1561:2]}
- slice.29 = f32[781] slice(add.27), slice={[1:1562:2]}
- add.30 = f32[781] add(slice.28, slice.29)
- slice.31 = f32[390] slice(add.30), slice={[0:780:2]}
- slice.32 = f32[390] slice(add.30), slice={[1:781:2]}
- add.33 = f32[390] add(slice.31, slice.32)
- slice.34 = f32[195] slice(add.33), slice={[0:389:2]}
- slice.35 = f32[195] slice(add.33), slice={[1:390:2]}
- add.36 = f32[195] add(slice.34, slice.35)
- slice.37 = f32[97] slice(add.36), slice={[0:194:2]}
- slice.38 = f32[97] slice(add.36), slice={[1:195:2]}
- add.39 = f32[97] add(slice.37, slice.38)
- slice.40 = f32[48] slice(add.39), slice={[0:96:2]}
- slice.41 = f32[48] slice(add.39), slice={[1:97:2]}
- ROOT add.42 = f32[48] add(slice.40, slice.41)
-}
-
-ENTRY e {
- p0 = f32[200000] parameter(0)
- f1 = f32[3125] fusion(p0), kind=kLoop, calls=f1
- ROOT r = f32[48] fusion(f1), kind=kLoop, calls=f2
-})")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeSliceIntoReusingConsumer) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f1 {
- p01 = s8[1000000] parameter(0)
- ROOT s0 = s8[10] slice(p01), slice={[0:10]}
-}
-
-f2 {
- p02 = s8[10] parameter(0)
- ROOT b0 = s8[10,1000000] broadcast(p02), dimensions={0}
-}
-
-ENTRY e {
- p0 = s8[1000000] parameter(0)
- f1 = s8[10] fusion(p0), kind=kLoop, calls=f1
- ROOT r = s8[10,1000000] fusion(f1), kind=kLoop, calls=f2
-})")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- %f_a (p: f32[]) -> f32[1024,1024,1024] {
- %p = f32[] parameter(0)
- %b = f32[1024,1024,1024] broadcast(%p), dimensions={}
- ROOT %t = f32[1024,1024,1024] tanh(%b)
- }
-
- %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
- %p = f32[1024,1024,1024] parameter(0)
- ROOT %t = f32[1024,1024,1024] tanh(%p)
- }
-
- %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
- %p = f32[1024,1024,1024] parameter(0)
- ROOT %t = f32[1024,1024,1024] tanh(%p)
- }
-
- ENTRY entry {
- p0 = f32[] parameter(0)
- f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_a
- f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_b
- f3 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
- ROOT f4 = f32[1024,1024,1024] add(f2, f3)
- })")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
- %p = f32[1024,1024,1024] parameter(0)
- ROOT %t = f32[1024,1024,1024] tanh(%p)
- }
-
- %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
- %p = f32[1024,1024,1024] parameter(0)
- ROOT %t = f32[1024,1024,1024] add(%p, %p)
- }
-
- ENTRY entry {
- p0 = f32[1024,1024,1024] parameter(0)
- f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
- ROOT f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
- })")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillNotMergeExpensiveFusionsWithReusingConsumer) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- %f_b {
- %p = f32[1024,1024,1024] parameter(0)
- %t1 = f32[1024,1024,1024] tanh(%p)
- %t2 = f32[1024,1024,1024] tanh(%t1)
- %t3 = f32[1024,1024,1024] tanh(%t2)
- %t4 = f32[1024,1024,1024] tanh(%t3)
- %t5 = f32[1024,1024,1024] tanh(%t4)
- %t6 = f32[1024,1024,1024] tanh(%t5)
- %t7 = f32[1024,1024,1024] tanh(%t6)
- %t8 = f32[1024,1024,1024] tanh(%t7)
- ROOT %t9 = f32[1024,1024,1024] tanh(%t8)
- }
-
- %f_c {
- %p = f32[1024,1024,1024] parameter(0)
- ROOT %t = f32[1024,1024,1024,2048] broadcast(%p), dimensions={0,1,2}
- }
-
- ENTRY entry {
- p0 = f32[1024,1024,1024] parameter(0)
- f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
- ROOT f2 = f32[1024,1024,1024,2048] fusion(f1), kind=kLoop, calls=%f_c
- })")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, NoMergeWithBitcast) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f32add {
- x.634 = f32[] parameter(0)
- y.635 = f32[] parameter(1)
- ROOT add.636 = f32[] add(x.634, y.635)
-}
-
-fused_computation.103 {
- param_0.310 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
- param_1.420 = f32[8,512]{1,0} parameter(1)
- bitcast.1144 = f32[1,8,512]{2,1,0} bitcast(param_1.420)
- convert.252 = f16[1,8,512]{2,1,0} convert(bitcast.1144)
- bitcast.1143 = f16[8,512]{1,0} bitcast(convert.252)
- broadcast.481 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1143), dimensions={1,2}
- divide.15 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.310, broadcast.481)
- ROOT bitcast.1142 = f16[8,512,1536]{1,2,0} bitcast(divide.15)
-}
-
-fused_computation.105 {
- param_1.426 = f16[8,1536,512]{2,1,0} parameter(1)
- bitcast.1896 = f16[1,8,1536,512]{3,2,1,0} bitcast(param_1.426)
- transpose.238 = f16[1,8,512,1536]{2,3,1,0} transpose(bitcast.1896), dimensions={0,1,3,2}
- param_0.315 = f16[8,512]{1,0} parameter(0)
- broadcast.482 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.315), dimensions={1,2}
- subtract.22 = f16[1,8,512,1536]{2,3,1,0} subtract(transpose.238, broadcast.482)
- ROOT exponential.15 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.22)
-}
-
-fused_computation.104 {
- param_0.1000 = f16[8,1536,512]{2,1,0} parameter(0)
- convert.652 = f32[8,1536,512]{2,1,0} convert(param_0.1000)
- constant_752 = f32[] constant(-0)
- ROOT reduce.232 = f32[8,512]{1,0} reduce(convert.652, constant_752),
- dimensions={1}, to_apply=f32add
-}
-
-ENTRY entry {
- p0 = f16[8,1536,512]{2,1,0} parameter(0)
- p1 = f16[8,512]{1,0} parameter(1)
- fusion.105 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.105
- bitcast.1787 = f16[8,1536,512]{2,1,0} bitcast(fusion.105)
- fusion.104 = f32[8,512]{1,0} fusion(bitcast.1787), kind=kInput, calls=fused_computation.104
- ROOT fusion.103 = f16[8,512,1536]{1,2,0} fusion(fusion.105, fusion.104), kind=kLoop, calls=fused_computation.103
-}
- )")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, CostBasedMerge) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-fused_computation.45 {
- param_1.194 = f16[8,1536,512]{2,1,0} parameter(1)
- bitcast.1042 = f16[1,8,512,1536]{2,3,1,0} bitcast(param_1.194)
- param_0.135 = f16[8,512]{1,0} parameter(0)
- broadcast.391 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.135), dimensions={1,2}
- subtract.6 = f16[1,8,512,1536]{2,3,1,0} subtract(bitcast.1042, broadcast.391)
- ROOT exponential.11 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.6)
-}
-
-f32add {
- x.634 = f32[] parameter(0)
- y.635 = f32[] parameter(1)
- ROOT add.636 = f32[] add(x.634, y.635)
-}
-
-fused_computation.44 {
- param_0.869 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
- convert.221 = f32[1,8,512,1536]{2,3,1,0} convert(param_0.869)
- transpose.212 = f32[1,8,1536,512]{3,2,1,0} transpose(convert.221), dimensions={0,1,3,2}
- bitcast.1041 = f32[8,1536,512]{2,1,0} bitcast(transpose.212)
- constant_429 = f32[] constant(0)
- ROOT reduce.149 = f32[8,512]{1,0} reduce(bitcast.1041, constant_429), dimensions={1}, to_apply=f32add
-}
-
-fused_computation.43 {
- param_0.130 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
- param_1.188 = f32[8,512]{1,0} parameter(1)
- bitcast.1040 = f32[1,8,512]{2,1,0} bitcast(param_1.188)
- convert.220 = f16[1,8,512]{2,1,0} convert(bitcast.1040)
- bitcast.1039 = f16[8,512]{1,0} bitcast(convert.220)
- broadcast.390 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1039), dimensions={1,2}
- divide.11 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.130, broadcast.390)
- ROOT bitcast.1038 = f16[8,512,1536]{1,2,0} bitcast(divide.11)
-}
-
-ENTRY entry {
- p0 = f16[8,1536,512]{2,1,0} parameter(0)
- p1 = f16[8,512]{1,0} parameter(1)
- fusion.45 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.45
- fusion.44 = f32[8,512]{1,0} fusion(fusion.45), kind=kInput, calls=fused_computation.44
- ROOT fusion.43 = f16[8,512,1536]{1,2,0} fusion(fusion.45, fusion.44), kind=kLoop, calls=fused_computation.43
-}
- )")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-// Outputs of fusions 66 and 67 here are heavily reused by fusion 59 - so
-// it is better to not merge here.
-TEST_F(FusionMergerTest, CostBasedNoMerge) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-add_float_.56 {
- x.57 = f32[] parameter(0)
- y.58 = f32[] parameter(1)
- ROOT add.59 = f32[] add(x.57, y.58)
-}
-
-fused_computation.66 {
- constant.635 = f32[] constant(0)
- broadcast.257 = f32[459,3]{1,0} broadcast(constant.635), dimensions={}
- constant.641 = f32[] constant(1)
- broadcast.256 = f32[459,3]{1,0} broadcast(constant.641), dimensions={}
- broadcast.255 = f32[459]{0} broadcast(constant.635), dimensions={}
- iota.28 = f32[459]{0} iota(), iota_dimension=0
- constant.629 = f32[] constant(1.49891067)
- broadcast.253 = f32[459]{0} broadcast(constant.629), dimensions={}
- multiply.39 = f32[459]{0} multiply(iota.28, broadcast.253)
- constant.633 = f32[] constant(-1)
- broadcast.252 = f32[459]{0} broadcast(constant.633), dimensions={}
- add.31 = f32[459]{0} add(multiply.39, broadcast.252)
- ceil.11 = f32[459]{0} ceil(add.31)
- constant.630 = f32[] constant(685)
- broadcast.251 = f32[459]{0} broadcast(constant.630), dimensions={}
- clamp.49 = f32[459]{0} clamp(broadcast.255, ceil.11, broadcast.251)
- subtract.11 = f32[459]{0} subtract(clamp.49, multiply.39)
- broadcast.249 = f32[459,3]{1,0} broadcast(subtract.11), dimensions={0}
- iota.26 = f32[459,3]{1,0} iota(), iota_dimension=1
- add.30 = f32[459,3]{1,0} add(broadcast.249, iota.26)
- abs.3 = f32[459,3]{1,0} abs(add.30)
- subtract.10 = f32[459,3]{1,0} subtract(broadcast.256, abs.3)
- maximum.6 = f32[459,3]{1,0} maximum(broadcast.257, subtract.10)
- ROOT reduce.3 = f32[459]{0} reduce(maximum.6, constant.635), dimensions={1}, to_apply=add_float_.56
-}
-
-fused_computation.67 {
- constant.684 = f32[] constant(0)
- broadcast.296 = f32[1130,3]{1,0} broadcast(constant.684), dimensions={}
- constant.685 = f32[] constant(1)
- broadcast.295 = f32[1130,3]{1,0} broadcast(constant.685), dimensions={}
- broadcast.294 = f32[1130]{0} broadcast(constant.684), dimensions={}
- iota.41 = f32[1130]{0} iota(), iota_dimension=0
- constant.675 = f32[] constant(1.34513271)
- broadcast.293 = f32[1130]{0} broadcast(constant.675), dimensions={}
- multiply.47 = f32[1130]{0} multiply(iota.41, broadcast.293)
- constant.677 = f32[] constant(-1)
- broadcast.290 = f32[1130]{0} broadcast(constant.677), dimensions={}
- add.39 = f32[1130]{0} add(multiply.47, broadcast.290)
- ceil.15 = f32[1130]{0} ceil(add.39)
- constant.676 = f32[] constant(1517)
- broadcast.289 = f32[1130]{0} broadcast(constant.676), dimensions={}
- clamp.53 = f32[1130]{0} clamp(broadcast.294, ceil.15, broadcast.289)
- subtract.19 = f32[1130]{0} subtract(clamp.53, multiply.47)
- broadcast.287 = f32[1130,3]{1,0} broadcast(subtract.19), dimensions={0}
- iota.39 = f32[1130,3]{1,0} iota(), iota_dimension=1
- add.38 = f32[1130,3]{1,0} add(broadcast.287, iota.39)
- abs.7 = f32[1130,3]{1,0} abs(add.38)
- subtract.18 = f32[1130,3]{1,0} subtract(broadcast.295, abs.7)
- maximum.10 = f32[1130,3]{1,0} maximum(broadcast.296, subtract.18)
- ROOT reduce.4 = f32[1130]{0} reduce(maximum.10, constant.684), dimensions={1}, to_apply=add_float_.56
-}
-
-fused_computation.59 {
- constant.532 = f32[] constant(0)
- broadcast.316 = f32[1130,3]{1,0} broadcast(constant.532), dimensions={}
- constant.663 = f32[] constant(1)
- broadcast.315 = f32[1130,3]{1,0} broadcast(constant.663), dimensions={}
- broadcast.314 = f32[1130]{0} broadcast(constant.532), dimensions={}
- iota.47 = f32[1130]{0} iota(), iota_dimension=0
- constant.579 = f32[] constant(1.34513271)
- broadcast.311 = f32[1130]{0} broadcast(constant.579), dimensions={}
- multiply.51 = f32[1130]{0} multiply(iota.47, broadcast.311)
- constant.578 = f32[] constant(-1)
- broadcast.310 = f32[1130]{0} broadcast(constant.578), dimensions={}
- add.43 = f32[1130]{0} add(multiply.51, broadcast.310)
- ceil.17 = f32[1130]{0} ceil(add.43)
- constant.576 = f32[] constant(1517)
- broadcast.309 = f32[1130]{0} broadcast(constant.576), dimensions={}
- clamp.55 = f32[1130]{0} clamp(broadcast.314, ceil.17, broadcast.309)
- subtract.24 = f32[1130]{0} subtract(clamp.55, multiply.51)
- broadcast.306 = f32[1130,3]{1,0} broadcast(subtract.24), dimensions={0}
- iota.45 = f32[1130,3]{1,0} iota(), iota_dimension=1
- add.42 = f32[1130,3]{1,0} add(broadcast.306, iota.45)
- abs.9 = f32[1130,3]{1,0} abs(add.42)
- subtract.23 = f32[1130,3]{1,0} subtract(broadcast.315, abs.9)
- maximum.12 = f32[1130,3]{1,0} maximum(broadcast.316, subtract.23)
- param_2.183 = f32[1130]{0} parameter(2)
- broadcast.172 = f32[1130,3]{1,0} broadcast(param_2.183), dimensions={0}
- divide.3 = f32[1130,3]{1,0} divide(maximum.12, broadcast.172)
- bitcast.53 = f32[3390]{0} bitcast(divide.3)
- broadcast.171 = f32[3390,1377]{1,0} broadcast(bitcast.53), dimensions={0}
- broadcast.276 = f32[459,3]{1,0} broadcast(constant.532), dimensions={}
- broadcast.275 = f32[459,3]{1,0} broadcast(constant.663), dimensions={}
- broadcast.274 = f32[459]{0} broadcast(constant.532), dimensions={}
- iota.35 = f32[459]{0} iota(), iota_dimension=0
- constant.614 = f32[] constant(1.49891067)
- broadcast.273 = f32[459]{0} broadcast(constant.614), dimensions={}
- multiply.43 = f32[459]{0} multiply(iota.35, broadcast.273)
- broadcast.272 = f32[459]{0} broadcast(constant.578), dimensions={}
- add.35 = f32[459]{0} add(multiply.43, broadcast.272)
- ceil.13 = f32[459]{0} ceil(add.35)
- constant.611 = f32[] constant(685)
- broadcast.269 = f32[459]{0} broadcast(constant.611), dimensions={}
- clamp.51 = f32[459]{0} clamp(broadcast.274, ceil.13, broadcast.269)
- subtract.15 = f32[459]{0} subtract(clamp.51, multiply.43)
- broadcast.267 = f32[459,3]{1,0} broadcast(subtract.15), dimensions={0}
- iota.33 = f32[459,3]{1,0} iota(), iota_dimension=1
- add.34 = f32[459,3]{1,0} add(broadcast.267, iota.33)
- abs.5 = f32[459,3]{1,0} abs(add.34)
- subtract.14 = f32[459,3]{1,0} subtract(broadcast.275, abs.5)
- maximum.8 = f32[459,3]{1,0} maximum(broadcast.276, subtract.14)
- param_1.177 = f32[459]{0} parameter(1)
- broadcast.170 = f32[459,3]{1,0} broadcast(param_1.177), dimensions={0}
- divide.2 = f32[459,3]{1,0} divide(maximum.8, broadcast.170)
- bitcast.52 = f32[1377]{0} bitcast(divide.2)
- broadcast.169 = f32[3390,1377]{1,0} broadcast(bitcast.52), dimensions={1}
- multiply.15 = f32[3390,1377]{1,0} multiply(broadcast.171, broadcast.169)
- bitcast.61 = f32[1130,3,459,3]{3,2,1,0} bitcast(multiply.15)
- transpose.68 = f32[459,1130,3,3]{2,0,3,1} transpose(bitcast.61), dimensions={2,0,3,1}
- copy.1 = f32[459,1130,3,3]{3,2,1,0} copy(transpose.68)
- bitcast.50 = f32[1130,459,9]{2,1,0} bitcast(copy.1)
- broadcast.168 = f32[1130,459,6,9]{3,2,1,0} broadcast(bitcast.50), dimensions={0,1,3}
- param_0.171 = u8[1,688,1520,6]{3,2,1,0} parameter(0)
- bitcast.49 = u8[688,1520,1,6]{3,1,0,2} bitcast(param_0.171)
- convert.175 = f32[688,1520,1,6]{3,1,0,2} convert(bitcast.49)
- broadcast.167 = f32[459,1130,1]{2,1,0} broadcast(clamp.51), dimensions={0}
- broadcast.166 = f32[459,1130,1]{2,1,0} broadcast(clamp.55), dimensions={1}
- concatenate.3 = f32[459,1130,2]{2,1,0} concatenate(broadcast.167, broadcast.166), dimensions={2}
- convert.174 = s32[459,1130,2]{2,1,0} convert(concatenate.3)
- bitcast.48 = s32[518670,2]{1,0} bitcast(convert.174)
- gather.1 = f32[518670,3,3,1,6]{2,1,4,0,3} gather(convert.175, bitcast.48), offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={3,3,1,6}
- transpose.69 = f32[1,518670,6,3,3]{4,3,2,1,0} transpose(gather.1), dimensions={3,0,4,1,2}
- bitcast.47 = f32[1130,459,6,9]{3,2,1,0} bitcast(transpose.69)
- multiply.14 = f32[1130,459,6,9]{3,2,1,0} multiply(broadcast.168, bitcast.47)
- reduce.2 = f32[1130,459,6]{2,1,0} reduce(multiply.14, constant.532), dimensions={3}, to_apply=add_float_.56
- convert.173 = f16[1130,459,6]{2,1,0} convert(reduce.2)
- bitcast.46 = f16[1,459,1130,6]{3,2,1,0} bitcast(convert.173)
- constant.533 = f16[] constant(0)
- pad.9 = f16[1,480,1130,6]{3,2,1,0} pad(bitcast.46, constant.533), padding=0_0x0_21x0_0x0_0
- pad.8 = f16[1,480,1152,6]{3,2,1,0} pad(pad.9, constant.533), padding=0_0x0_0x0_22x0_0
- constant.532f16 = f16[] constant(0)
- ROOT pad.7 = f16[1,485,1157,6]{3,2,1,0} pad(pad.8, constant.532f16), padding=0_0x2_3x2_3x0_0
-}
-
-ENTRY e {
- arg0.1 = u8[1,688,1520,6]{3,2,1,0} parameter(0), parameter_replication={false}
- fusion.66 = f32[459]{0} fusion(), kind=kLoop, calls=fused_computation.66
- fusion.67 = f32[1130]{0} fusion(), kind=kLoop, calls=fused_computation.67
- ROOT fusion.59 = f16[1,485,1157,6]{2,1,3,0} fusion(arg0.1, fusion.66, fusion.67), kind=kLoop, calls=fused_computation.59
-}
- )")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, NoMergeBecauseTooManyBasicBlockSplits) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-region_6.97 {
- Arg_0.98 = pred[] parameter(0)
- Arg_1.99 = pred[] parameter(1)
- ROOT or.100 = pred[] or(Arg_0.98, Arg_1.99)
-}
-
-region_4.50 {
- Arg_0.51 = f64[] parameter(0)
- Arg_1.52 = f64[] parameter(1)
- ROOT add.53 = f64[] add(Arg_0.51, Arg_1.52)
-}
-
-f2 {
- param_0 = s64[1]{0} parameter(0)
- constant_70 = f64[] constant(0)
- convert.41.clone.1 = f64[1]{0} convert(param_0)
- ROOT pad.99.clone.1 = f64[3]{0} pad(convert.41.clone.1, constant_70), padding=0_2
-}
-
-f1 {
- param_0.361 = pred[5]{0} parameter(0)
- broadcast.107 = pred[10,5]{1,0} broadcast(param_0.361), dimensions={1}
- param_6.244 = pred[5]{0} parameter(6)
- broadcast.111.clone.1 = pred[10,5]{1,0} broadcast(param_6.244), dimensions={1}
- param_1.450 = f64[10,5]{1,0} parameter(1)
- constant_294_clone_1 = f64[] constant(1)
- broadcast.153.clone.1 = f64[10,5]{1,0} broadcast(constant_294_clone_1), dimensions={}
- compare.22.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.153.clone.1), direction=GE
- constant_75_clone_1 = f64[] constant(-1)
- broadcast.109.clone.1 = f64[10,5]{1,0} broadcast(constant_75_clone_1), dimensions={}
- add.34.clone.1 = f64[10,5]{1,0} add(param_1.450, broadcast.109.clone.1)
- param_5.322 = f64[10,5,4]{1,0,2} parameter(5)
- slice.45.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [3:4]}
- bitcast.94.clone.1 = f64[10,5]{1,0} bitcast(slice.45.clone.1)
- divide.7.clone.1 = f64[10,5]{1,0} divide(add.34.clone.1, bitcast.94.clone.1)
- add.33.clone.1 = f64[10,5]{1,0} add(divide.7.clone.1, broadcast.153.clone.1)
- constant_70 = f64[] constant(0)
- broadcast.157.clone.1 = f64[10,5]{1,0} broadcast(constant_70), dimensions={}
- compare.26.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.157.clone.1), direction=LE
- slice.46.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [0:1]}
- bitcast.93.clone.1 = f64[10,5]{1,0} bitcast(slice.46.clone.1)
- divide.6.clone.1 = f64[10,5]{1,0} divide(param_1.450, bitcast.93.clone.1)
- broadcast.295.clone.1 = f64[10,5,3]{1,0,2} broadcast(param_1.450), dimensions={0,1}
- param_4.368 = f64[10,5,2]{1,0,2} parameter(4)
- pad.103.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_70), padding=0_0x0_0x1_0
- compare.121.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.103.clone.1), direction=GE
- pad.102.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_294_clone_1), padding=0_0x0_0x0_1
- compare.120.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.102.clone.1), direction=LT
- and.39.clone.1 = pred[10,5,3]{1,0,2} and(compare.121.clone.1, compare.120.clone.1)
- transpose.9 = pred[3,10,5]{2,1,0} transpose(and.39.clone.1), dimensions={2,0,1}
- constant_296_clone_1 = pred[] constant(false)
- reduce.91.clone.1 = pred[10,5]{1,0} reduce(transpose.9, constant_296_clone_1), dimensions={0}, to_apply=region_6.97
- broadcast.294.clone.1 = pred[10,5,3]{1,0,2} broadcast(reduce.91.clone.1), dimensions={0,1}
- pad.99.clone.1 = f64[3]{0} parameter(3)
- broadcast.292.clone.1 = f64[3]{0} broadcast(constant_70), dimensions={}
- compare.117.clone.1 = pred[3]{0} compare(pad.99.clone.1, broadcast.292.clone.1), direction=NE
- broadcast.290.clone.1 = pred[10,5,3]{1,0,2} broadcast(compare.117.clone.1), dimensions={2}
- select.67.clone.1 = pred[10,5,3]{1,0,2} select(broadcast.294.clone.1, and.39.clone.1, broadcast.290.clone.1)
- convert.40.clone.1 = f64[10,5,3]{1,0,2} convert(select.67.clone.1)
- broadcast.288.clone.1 = f64[10,5,3,3]{1,0,2,3} broadcast(convert.40.clone.1), dimensions={0,1,2}
- param_2.361 = f64[10,5,4,3]{1,0,2,3} parameter(2)
- slice.114.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [1:4], [0:3]}
- multiply.53.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.114.clone.1)
- transpose.10 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.53.clone.1), dimensions={3,2,0,1}
- reduce.90.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.10, constant_70), dimensions={1}, to_apply=region_4.50
- transpose.11 = f64[10,5,3]{1,0,2} transpose(reduce.90.clone.1), dimensions={1,2,0}
- slice.28.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [0:1]}
- bitcast.99.clone.1 = f64[10,5]{1,0} bitcast(slice.28.clone.1)
- slice.108.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [0:3], [0:3]}
- multiply.49.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.108.clone.1)
- transpose.12 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.49.clone.1), dimensions={3,2,0,1}
- reduce.82.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.12, constant_70), dimensions={1}, to_apply=region_4.50
- transpose.13 = f64[10,5,3]{1,0,2} transpose(reduce.82.clone.1), dimensions={1,2,0}
- slice.107.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [0:1]}
- bitcast.240.clone.1 = f64[10,5]{1,0} bitcast(slice.107.clone.1)
- subtract.27.clone.1 = f64[10,5]{1,0} subtract(bitcast.99.clone.1, bitcast.240.clone.1)
- slice.27.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [2:3]}
- bitcast.98.clone.1 = f64[10,5]{1,0} bitcast(slice.27.clone.1)
- slice.26.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [2:3]}
- bitcast.97.clone.1 = f64[10,5]{1,0} bitcast(slice.26.clone.1)
- add.36.clone.1 = f64[10,5]{1,0} add(bitcast.97.clone.1, bitcast.98.clone.1)
- slice.24.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [1:2]}
- bitcast.95.clone.1 = f64[10,5]{1,0} bitcast(slice.24.clone.1)
- slice.121.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [1:2]}
- bitcast.274.clone.1 = f64[10,5]{1,0} bitcast(slice.121.clone.1)
- subtract.26.clone.1 = f64[10,5]{1,0} subtract(bitcast.95.clone.1, bitcast.274.clone.1)
- divide.21 = f64[10,5]{1,0} divide(subtract.26.clone.1, subtract.27.clone.1)
- constant_77_clone_1 = f64[] constant(2)
- broadcast.117.clone.1 = f64[10,5]{1,0} broadcast(constant_77_clone_1), dimensions={}
- multiply.37.clone.1 = f64[10,5]{1,0} multiply(divide.21, broadcast.117.clone.1)
- subtract.25.clone.1 = f64[10,5]{1,0} subtract(add.36.clone.1, multiply.37.clone.1)
- subtract.24.clone.1 = f64[10,5]{1,0} subtract(param_1.450, bitcast.274.clone.1)
- divide.9.clone.1 = f64[10,5]{1,0} divide(subtract.24.clone.1, subtract.26.clone.1)
- clamp.7.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.9.clone.1, broadcast.153.clone.1)
- multiply.36.clone.1 = f64[10,5]{1,0} multiply(subtract.25.clone.1, clamp.7.clone.1)
- subtract.23.clone.1 = f64[10,5]{1,0} subtract(bitcast.98.clone.1, multiply.36.clone.1)
- compare.13.clone.1 = pred[10,5]{1,0} compare(subtract.23.clone.1, broadcast.157.clone.1), direction=GE
- negate.19.clone.1 = f64[10,5]{1,0} negate(divide.21)
- multiply.35.clone.1 = f64[10,5]{1,0} multiply(negate.19.clone.1, clamp.7.clone.1)
- multiply.34.clone.1 = f64[10,5]{1,0} multiply(multiply.35.clone.1, broadcast.117.clone.1)
- negate.18.clone.1 = f64[10,5]{1,0} negate(subtract.23.clone.1)
- multiply.33.clone.1 = f64[10,5]{1,0} multiply(subtract.23.clone.1, subtract.23.clone.1)
- subtract.22.clone.1 = f64[10,5]{1,0} subtract(divide.21, subtract.23.clone.1)
- constant_78_clone_1 = f64[] constant(4)
- broadcast.113.clone.1 = f64[10,5]{1,0} broadcast(constant_78_clone_1), dimensions={}
- multiply.32.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.113.clone.1)
- multiply.31.clone.1 = f64[10,5]{1,0} multiply(multiply.32.clone.1, multiply.35.clone.1)
- subtract.21.clone.1 = f64[10,5]{1,0} subtract(multiply.33.clone.1, multiply.31.clone.1)
- compare.12.clone.1 = pred[10,5]{1,0} compare(subtract.21.clone.1, broadcast.157.clone.1), direction=GT
- constant_79_clone_1 = f64[] constant(2.2250738585072014e-308)
- broadcast.112.clone.1 = f64[10,5]{1,0} broadcast(constant_79_clone_1), dimensions={}
- maximum.18.clone.1 = f64[10,5]{1,0} maximum(broadcast.112.clone.1, subtract.21.clone.1)
- sqrt.1.clone.1 = f64[10,5]{1,0} sqrt(maximum.18.clone.1)
- select.47.clone.1 = f64[10,5]{1,0} select(compare.12.clone.1, sqrt.1.clone.1, broadcast.157.clone.1)
- add.35.clone.1 = f64[10,5]{1,0} add(negate.18.clone.1, select.47.clone.1)
- select.46.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, multiply.34.clone.1, add.35.clone.1)
- subtract.20.clone.1 = f64[10,5]{1,0} subtract(negate.18.clone.1, select.47.clone.1)
- multiply.30.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.117.clone.1)
- select.45.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, subtract.20.clone.1, multiply.30.clone.1)
- divide.8.clone.1 = f64[10,5]{1,0} divide(select.46.clone.1, select.45.clone.1)
- clamp.6.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.8.clone.1, broadcast.153.clone.1)
- multiply.29.clone.1 = f64[10,5]{1,0} multiply(subtract.27.clone.1, clamp.6.clone.1)
- add.32.clone.1 = f64[10,5]{1,0} add(multiply.29.clone.1, bitcast.240.clone.1)
- select.44.clone.1 = f64[10,5]{1,0} select(compare.26.clone.1, divide.6.clone.1, add.32.clone.1)
- select.43.clone.1 = f64[10,5]{1,0} select(compare.22.clone.1, add.33.clone.1, select.44.clone.1)
- select.42.clone.1 = f64[10,5]{1,0} select(broadcast.111.clone.1, param_1.450, select.43.clone.1)
- select.41 = f64[10,5]{1,0} select(broadcast.107, select.42.clone.1, broadcast.157.clone.1)
- ROOT tuple.14 = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) tuple(select.41, select.42.clone.1, clamp.6.clone.1, subtract.25.clone.1, bitcast.97.clone.1, multiply.37.clone.1, bitcast.98.clone.1, divide.21)
-}
-
-ENTRY e {
- p3 = s64[1]{0} parameter(3)
- f2 = f64[3]{0} fusion(p3), kind=kLoop, calls=f2
-
- p0 = pred[5]{0} parameter(0)
- p1 = f64[10,5]{1,0} parameter(1)
- p2 = f64[10,5,4,3]{1,0,2,3} parameter(2)
- p4 = f64[10,5,2]{1,0,2} parameter(4)
- p5 = f64[10,5,4]{1,0,2} parameter(5)
- p6 = pred[5]{0} parameter(6)
- ROOT ret = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) fusion(p0, p1, p2, f2, p4, p5, p6), kind=kLoop, calls=f1
-}
- )")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, CommonElementwiseUsedParameter) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- p {
- p0 = f32[10000000] parameter(0)
- p1 = f32[10000000] parameter(1)
- p2 = f32[10000000] parameter(2)
- p3 = f32[10000000] parameter(3)
- a0 = f32[10000000] add(p1, p2)
- a1 = f32[10000000] add(a0, p3)
- ROOT _ = add(p0, a1)
- }
-
- c1 {
- p0 = f32[10000000] parameter(0)
- p1 = f32[10000000] parameter(1)
- ROOT _ = add(p0, p1)
- }
-
- c2 {
- p0 = f32[10000000] parameter(0)
- p1 = f32[10000000] parameter(1)
- ROOT _ = multiply(p0, p1)
- }
-
- ENTRY entry {
- p0 = f32[10000000] parameter(0)
- p1 = f32[10000000] parameter(1)
- p2 = f32[10000000] parameter(2)
- p3 = f32[10000000] parameter(3)
- f = f32[10000000] fusion(p0, p1, p2, p3), kind=kLoop, calls=p
- f1 = f32[10000000] fusion(p0, f), kind=kLoop, calls=c1
- f2 = f32[10000000] fusion(p1, f), kind=kLoop, calls=c2
- ROOT _ = (f32[10000000], f32[10000000]) tuple(f1, f2)
- }
- )")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, IncompatibleNonTrivialHeroes) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fused_computation {
- param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
- param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
- s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
- t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
- sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
- exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
- ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
- }
-
- fused_computation.2 {
- param_0.2 = f32[32,16,18]{2,1,0} parameter(0)
- s.2 = f32[32,16,18]{2,1,0} sqrt(param_0.2)
- ROOT t.2 = f32[32,18,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
- }
-
- ENTRY main {
- p = f32[18,16,32]{2,1,0} parameter(0)
- p2 = f32[32,16,18]{2,1,0} parameter(1)
- fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
- ROOT fusion2 = f32[32,18,16]{2,1,0} fusion(fusion), kind=kInput, calls=fused_computation.2
- }
- )")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, DoNotMergeDUSFusions) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- %fused_computation (param_0: f32[8], param_1.2: f32[], param_2.3: f32[8]) -> f32[8] {
- %param_0 = f32[8]{0} parameter(0)
- %param_2.3 = f32[8]{0} parameter(2)
- %slice.2 = f32[5]{0} slice(f32[8]{0} %param_2.3), slice={[0:5]}
- %param_1.2 = f32[] parameter(1)
- %broadcast.2 = f32[5]{0} broadcast(f32[] %param_1.2), dimensions={}
- %add.2 = f32[5]{0} add(f32[5]{0} %slice.2, f32[5]{0} %broadcast.2)
- %two.1 = s32[] constant(2)
- ROOT %dynamic-update-slice.2 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0, f32[5]{0} %add.2, s32[] %two.1)
- }
-
- %fused_computation.1 (param_0.1: f32[8], param_1.4: f32[6], param_2.6: f32[]) -> f32[8] {
- %param_0.1 = f32[8]{0} parameter(0)
- %param_1.4 = f32[6]{0} parameter(1)
- %param_2.6 = f32[] parameter(2)
- %broadcast.3 = f32[6]{0} broadcast(f32[] %param_2.6), dimensions={}
- %add.3 = f32[6]{0} add(f32[6]{0} %param_1.4, f32[6]{0} %broadcast.3)
- %three.1 = s32[] constant(3)
- ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[6]{0} %add.3, s32[] %three.1)
- }
-
- ENTRY %Test (parameter: f32[8]) -> f32[8] {
- %parameter = f32[8]{0} parameter(0)
- %slice.1 = f32[6]{0} slice(f32[8]{0} %parameter), slice={[0:6]}
- %one = f32[] constant(1)
- %fusion.1 = f32[8]{0} fusion(f32[8]{0} %parameter, f32[6]{0} %slice.1, f32[] %one), kind=kLoop, calls=%fused_computation.1
- ROOT %fusion = f32[8]{0} fusion(f32[8]{0} %fusion.1, f32[] %one, f32[8]{0} %parameter), kind=kLoop, calls=%fused_computation
- }
- )")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, MergeDUSFusionWithElementwiseFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- %fused_computation {
- %param_0 = f32[1,8]{1,0} parameter(0)
- %bitcast = f32[8]{0} bitcast(%param_0)
- ROOT %neg = f32[8]{0} negate(%bitcast)
- }
-
- %fused_computation.1 {
- %param_0.1 = f32[8]{0} parameter(0)
- %param_1.4 = f32[5]{0} parameter(1)
- %three.1 = s32[] constant(3)
- %exp = f32[5]{0} exponential(%param_1.4)
- ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[5]{0} %exp, s32[] %three.1)
- }
-
- ENTRY %Test {
- %parameter = f32[5]{0} parameter(0)
- %parameter.1 = f32[1,8]{1,0} parameter(1)
- %fusion = f32[8]{0} fusion(f32[1,8]{1,0} %parameter.1), kind=kLoop, calls=%fused_computation
- ROOT %fusion.1 = f32[8]{0} fusion(f32[8]{0} %fusion, f32[5]{0} %parameter), kind=kLoop, calls=%fused_computation.1
- }
- )")
- .value();
- EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, DoNotMergeTwoReduces) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add.13235 = f32[] add(p0, p1)
- }
-
- ENTRY main {
- p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
- c0 = f32[] constant(0)
- r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
- ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
- }
- )")
- .value();
- EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc
index 2a184c0..4fc4af0 100644
--- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc
+++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc
@@ -19,14 +19,14 @@
#include <utility>
#include "xla/service/cpu_gpu_shape_verifier.h"
-#include "xla/service/gpu/fusion_merger.h"
-#include "xla/service/gpu/horizontal_input_fusion.h"
-#include "xla/service/gpu/horizontal_loop_fusion.h"
-#include "xla/service/gpu/instruction_fusion.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/multi_output_fusion.h"
-#include "xla/service/gpu/priority_fusion.h"
-#include "xla/service/gpu/variadic_op_splitter.h"
+#include "xla/service/gpu/transforms/fusion_merger.h"
+#include "xla/service/gpu/transforms/horizontal_input_fusion.h"
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
+#include "xla/service/gpu/transforms/priority_fusion.h"
+#include "xla/service/gpu/transforms/variadic_op_splitter.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_cse.h"
#include "xla/service/hlo_dce.h"
@@ -63,8 +63,8 @@
shape_size_bytes_function,
/*per_second_rates=*/{},
/*count_multiple_input_accesses=*/true};
- fusion.AddPass<GpuPriorityFusion>(thread_pool, gpu_device_info,
- std::move(cost_analysis_options));
+ fusion.AddPass<PriorityFusion>(thread_pool, gpu_device_info,
+ std::move(cost_analysis_options));
} else {
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false,
gpu_device_info);
@@ -77,8 +77,7 @@
fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
fusion.AddPass<HloDCE>();
- fusion.AddPass<GpuMultiOutputFusion>(gpu_device_info,
- shape_size_bytes_function);
+ fusion.AddPass<MultiOutputFusion>(gpu_device_info, shape_size_bytes_function);
fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
fusion.AddPass<HloDCE>();
@@ -88,8 +87,8 @@
HloPassPipeline HorizontalFusionPipeline(
const se::DeviceDescription& gpu_device_info) {
HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
- horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
- horizontal_fusion.AddPass<GpuHorizontalInputFusion>(gpu_device_info);
+ horizontal_fusion.AddPass<HorizontalLoopFusion>();
+ horizontal_fusion.AddPass<HorizontalInputFusion>(gpu_device_info);
horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
horizontal_fusion.AddPass<HloDCE>();
diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/fusion_wrapper.cc
deleted file mode 100644
index 2cb8471..0000000
--- a/third_party/xla/xla/service/gpu/fusion_wrapper.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusion_wrapper.h"
-
-#include <functional>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "tsl/platform/errors.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> FusionWrapper::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- auto instructions = module->entry_computation()->MakeInstructionPostOrder();
- bool changed = false;
-
- std::function<absl::Status(HloInstruction*)> handle_instruction;
- handle_instruction = [&](HloInstruction* instruction) -> absl::Status {
- switch (instruction->opcode()) {
- case HloOpcode::kConditional:
- case HloOpcode::kWhile:
- for (auto* computation : instruction->called_computations()) {
- for (auto* inner_instruction :
- computation->MakeInstructionPostOrder()) {
- TF_RETURN_IF_ERROR(handle_instruction(inner_instruction));
- }
- }
- break;
- case HloOpcode::kAbs:
- case HloOpcode::kAdd:
- case HloOpcode::kAnd:
- case HloOpcode::kAtan2:
- case HloOpcode::kBitcastConvert:
- case HloOpcode::kBroadcast:
- case HloOpcode::kCeil:
- case HloOpcode::kCbrt:
- case HloOpcode::kClamp:
- case HloOpcode::kClz:
- case HloOpcode::kCompare:
- case HloOpcode::kComplex:
- case HloOpcode::kConcatenate:
- case HloOpcode::kConvert:
- case HloOpcode::kCopy:
- case HloOpcode::kCos:
- case HloOpcode::kDivide:
- case HloOpcode::kDot:
- case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- case HloOpcode::kErf:
- case HloOpcode::kExp:
- case HloOpcode::kExpm1:
- case HloOpcode::kFloor:
- case HloOpcode::kGather:
- case HloOpcode::kImag:
- case HloOpcode::kIota:
- case HloOpcode::kIsFinite:
- case HloOpcode::kLog:
- case HloOpcode::kLog1p:
- case HloOpcode::kMap:
- case HloOpcode::kMaximum:
- case HloOpcode::kMinimum:
- case HloOpcode::kMultiply:
- case HloOpcode::kNegate:
- case HloOpcode::kNot:
- case HloOpcode::kOr:
- case HloOpcode::kPad:
- case HloOpcode::kPopulationCount:
- case HloOpcode::kPower:
- case HloOpcode::kReal:
- case HloOpcode::kReshape:
- case HloOpcode::kReduce:
- case HloOpcode::kReducePrecision:
- case HloOpcode::kReduceWindow:
- case HloOpcode::kRemainder:
- case HloOpcode::kReverse:
- case HloOpcode::kRoundNearestAfz:
- case HloOpcode::kRoundNearestEven:
- case HloOpcode::kRsqrt:
- case HloOpcode::kScatter:
- case HloOpcode::kSelect:
- case HloOpcode::kShiftLeft:
- case HloOpcode::kShiftRightLogical:
- case HloOpcode::kShiftRightArithmetic:
- case HloOpcode::kSign:
- case HloOpcode::kSin:
- case HloOpcode::kSlice:
- case HloOpcode::kSqrt:
- case HloOpcode::kSubtract:
- case HloOpcode::kStochasticConvert:
- case HloOpcode::kTan:
- case HloOpcode::kTanh:
- case HloOpcode::kTranspose:
- case HloOpcode::kXor: {
- auto* computation = instruction->parent();
- auto* fusion_instruction =
- computation->AddInstruction(HloInstruction::CreateFusion(
- instruction->shape(),
- ChooseFusionKind(*instruction, *instruction), instruction));
- const absl::string_view wrapped_opcode =
- HloOpcodeString(instruction->opcode());
- module->SetAndUniquifyInstrName(
- fusion_instruction, absl::StrCat("wrapped_", wrapped_opcode));
- module->SetAndUniquifyComputationName(
- fusion_instruction->fused_instructions_computation(),
- absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
- if (module->has_schedule()) {
- module->schedule().replace_instruction(computation, instruction,
- fusion_instruction);
- }
- TF_RETURN_IF_ERROR(
- fusion_instruction->CopyAllControlDepsFrom(instruction));
- TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
- TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
- TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
- changed = true;
- break;
- }
- default:
- break;
- }
- return absl::OkStatus();
- };
-
- for (auto* instruction : instructions) {
- TF_RETURN_IF_ERROR(handle_instruction(instruction));
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper.h b/third_party/xla/xla/service/gpu/fusion_wrapper.h
deleted file mode 100644
index fc46692..0000000
--- a/third_party/xla/xla/service/gpu/fusion_wrapper.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSION_WRAPPER_H_
-#define XLA_SERVICE_GPU_FUSION_WRAPPER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Wraps leftover unfused instructions in the entry computation that have no
-// LHLO equivalent in fusions containing just that instruction.
-class FusionWrapper : public HloModulePass {
- public:
- absl::string_view name() const override { return "fusion-wrapper"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSION_WRAPPER_H_
diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc
deleted file mode 100644
index 397fe75..0000000
--- a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusion_wrapper.h"
-
-#include <optional>
-
-#include <gtest/gtest.h>
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class FusionWrapperTest : public HloTestBase {};
-
-TEST_F(FusionWrapperTest, SimpleOp) {
- RunAndFilecheckHloRewrite(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- p0 = f16[30,41] parameter(0)
- p1 = f16[30,41] parameter(1)
- ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
- })",
- FusionWrapper(), R"(
-// CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] {
-// CHECK: %param_0 = f16[30,41]{1,0} parameter(0)
-// CHECK: %param_1 = f16[30,41]{1,0} parameter(1)
-// CHECK: ROOT %result.1 = f16[60,41]{1,0} concatenate(%param_0, %param_1), dimensions={0}
-// CHECK: }
-
-// CHECK: ENTRY %TestComputation (p0: f16[30,41], p1: f16[30,41]) -> f16[60,41] {
-// CHECK: %p0 = f16[30,41]{1,0} parameter(0)
-// CHECK: %p1 = f16[30,41]{1,0} parameter(1)
-// CHECK: ROOT %wrapped_concatenate = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%wrapped_concatenate_computation
-// CHECK: })");
-}
-
-TEST_F(FusionWrapperTest, Scatter) {
- RunAndFilecheckHloRewrite(R"(
- HloModule ScatterIntoScalar
-
- update_s32 {
- lhs = s32[] parameter(0)
- ROOT rhs = s32[] parameter(1)
- }
-
- ENTRY main {
- parameter.1 = s32[] parameter(0)
- parameter.2 = s32[0]{0} parameter(1)
- parameter.3 = s32[] parameter(2)
- ROOT scatter_ScatterIntoScalar = s32[] scatter(parameter.1, parameter.2, parameter.3),
- update_window_dims={},
- inserted_window_dims={},
- scatter_dims_to_operand_dims={},
- index_vector_dim=0,
- to_apply=update_s32
- })",
- FusionWrapper(), R"(
-// CHECK: wrapped_scatter_computation
-// CHECK: %[[param_0:.*]] = s32[] parameter(0)
-// CHECK: %[[param_1:.*]] = s32[0]{0} parameter(1)
-// CHECK: %[[param_2:.*]] = s32[] parameter(2)
-// CHECK: ROOT %{{.*}} = s32[] scatter(%[[param_0]], %[[param_1]], %[[param_2]])
-
-// CHECK: ENTRY
-// CHECK: %[[p0:.*]] = s32[] parameter(0)
-// CHECK: %[[p1:.*]] = s32[0]{0} parameter(1)
-// CHECK: %[[p2:.*]] = s32[] parameter(2)
-// CHECK: ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%wrapped_scatter_computation
-// CHECK: })");
-}
-
-TEST_F(FusionWrapperTest, ControlDependency) {
- RunAndFilecheckHloRewrite(R"(
- HloModule TestModule
-
- fusion {
- ROOT param = f32[] parameter(0)
- }
-
- ENTRY main {
- param = f32[] parameter(0)
- fusion = f32[] fusion(param), kind=kLoop, calls=fusion
- constant_one = f32[] constant(1)
- ROOT add = f32[] add(param, constant_one), control-predecessors={fusion}
- })",
- FusionWrapper(), R"(
-// CHECK: ROOT %wrapped_add = f32[] fusion(%param.1, %constant_one),
-// CHECK-SAME: control-predecessors={%fusion})");
-}
-
-TEST_F(FusionWrapperTest, While) {
- RunAndFilecheckHloRewrite(R"(
- HloModule While
-
- %body {
- %parameter.5 = (f32[5]{0}) parameter(0)
- %constant_8 = f32[] constant(0)
- %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
- ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
- }
-
- %cond {
- %parameter.12 = (f32[5]{0}) parameter(0)
- ROOT %constant_1 = pred[] constant(false)
- }
-
- ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
- %parameter.1 = f32[5]{0} parameter(0)
- %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
- %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
- ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
- })",
- FusionWrapper(), R"(
-// CHECK: %wrapped_broadcast_computation {{.*}} {
-// CHECK: %param_0.1 = f32[] parameter(0)
-// CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={}
-// CHECK: }
-// CHECK: %body {{.*}} {
-// CHECK: %parameter.5 = (f32[5]{0}) parameter(0)
-// CHECK: %constant_8 = f32[] constant(0)
-// CHECK: %wrapped_broadcast = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%wrapped_broadcast_computation
-// CHECK: ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast)
-// CHECK: }
-// CHECK: %cond {{.*}} {
-// CHECK: %parameter.12 = (f32[5]{0}) parameter(0)
-// CHECK: ROOT %constant_1 = pred[] constant(false)
-// CHECK: }
-// CHECK: %wrapped_copy_computation {{.*}} {
-// CHECK: %param_0 = f32[5]{0} parameter(0)
-// CHECK: ROOT %copy.0 = f32[5]{0} copy(%param_0)
-// CHECK: }
-// CHECK: ENTRY %main {{.*}} {
-// CHECK: %parameter.1 = f32[5]{0} parameter(0)
-// CHECK: %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation
-// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy)
-// CHECK: ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body
-// CHECK: })");
-}
-
-TEST_F(FusionWrapperTest, WhileInFusion) {
- RunAndFilecheckHloRewrite(R"(
- HloModule While
-
- %body {
- %parameter.5 = (f32[5]{0}) parameter(0)
- %constant_8 = f32[] constant(0)
- %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
- ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
- }
-
- %cond {
- %parameter.12 = (f32[5]{0}) parameter(0)
- ROOT %constant_1 = pred[] constant(false)
- }
-
- %fusion {
- %parameter.1 = f32[5]{0} parameter(0)
- %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
- %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
- ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
- }
-
- ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
- %parameter.1 = f32[5]{0} parameter(0)
- ROOT %fusion = (f32[5]{0}) fusion(f32[5]{0} %parameter.1), kind=kLoop, calls=%fusion
- })",
- FusionWrapper(),
- // No change
- std::nullopt);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
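
In condensed form, the core step of the FusionWrapper pass removed above takes a lone fusible instruction in a computation and replaces it with a single-instruction fusion named after its opcode. The sketch below restates that step using only the HLO APIs that appear in the deleted file; it assumes `module` and an `instruction` selected by the opcode switch are in scope, and it is not a drop-in replacement for the pass:

    // Wrap `instruction` into a one-instruction fusion, preserving naming,
    // control dependencies, and uses (condensed from the deletion above).
    HloComputation* computation = instruction->parent();
    HloInstruction* fusion_instruction =
        computation->AddInstruction(HloInstruction::CreateFusion(
            instruction->shape(),
            ChooseFusionKind(*instruction, *instruction), instruction));
    const absl::string_view wrapped_opcode =
        HloOpcodeString(instruction->opcode());
    module->SetAndUniquifyInstrName(
        fusion_instruction, absl::StrCat("wrapped_", wrapped_opcode));
    module->SetAndUniquifyComputationName(
        fusion_instruction->fused_instructions_computation(),
        absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
    TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
    TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
    TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
    TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
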
diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD
index e936cdf..188eaed 100644
--- a/third_party/xla/xla/service/gpu/fusions/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/BUILD
@@ -9,49 +9,6 @@
)
cc_library(
- name = "in_place_dynamic_update_slice",
- srcs = ["in_place_dynamic_update_slice.cc"],
- hdrs = ["in_place_dynamic_update_slice.h"],
- deps = [
- ":fusion_emitter",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:ir_emitter",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/service/llvm_ir:dynamic_update_slice_util",
- "//xla/service/llvm_ir:fused_ir_emitter",
- "//xla/service/llvm_ir:ir_array",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:ir_headers",
- "@llvm-project//mlir:IR",
- ],
-)
-
-xla_cc_test(
- name = "in_place_dynamic_update_slice_test",
- srcs = ["in_place_dynamic_update_slice_test.cc"],
- deps = [
- ":fusions",
- ":in_place_dynamic_update_slice",
- "//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu/model:affine_map_printer",
- "//xla/service/gpu/model:indexing_test_utils",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "in_place_dynamic_update_slice_mlir",
srcs = ["in_place_dynamic_update_slice_mlir.cc"],
hdrs = ["in_place_dynamic_update_slice_mlir.h"],
@@ -122,6 +79,7 @@
deps = [
":fusion_emitter",
"//xla:shape_util",
+ "//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:attribute_map",
@@ -147,6 +105,10 @@
"//xla/service/gpu/runtime:dynamic_slice_thunk",
"//xla/service/gpu/runtime:gemm_thunk",
"//xla/service/gpu/runtime:kernel_thunk",
+ "//xla/service/gpu/runtime:nccl_all_reduce_thunk",
+ "//xla/service/gpu/runtime:nccl_api",
+ "//xla/service/gpu/runtime:nccl_clique_key",
+ "//xla/service/gpu/runtime:nccl_collective_thunk",
"//xla/service/gpu/runtime:thunk",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
@@ -167,6 +129,12 @@
xla_test(
name = "dynamic_slice_fusion_test",
srcs = if_cuda_is_configured(["dynamic_slice_fusion_test.cc"]),
+ backend_tags = {
+ "gpu": [
+ "multi_gpu",
+ "no_oss",
+ ],
+ },
backends = ["gpu"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
@@ -184,11 +152,14 @@
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
- "//xla/service/gpu:dynamic_slice_fusion_rewriter",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/stream_executor",
"//xla/stream_executor:device_description",
"//xla/stream_executor/gpu:gpu_types_header",
+ "//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
+ "//xla/tests:test_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@local_tsl//tsl/platform:errors",
@@ -244,23 +215,16 @@
hdrs = ["fusions.h"],
visibility = ["//xla/service/gpu:__subpackages__"],
deps = [
- ":concatenate",
":concatenate_mlir",
":copy",
":cudnn",
":custom",
":fusion_emitter",
- ":in_place_dynamic_update_slice",
":in_place_dynamic_update_slice_mlir",
- ":input_slices",
":input_slices_mlir",
- ":loop",
":loop_mlir",
- ":reduction",
":reduction_mlir",
- ":scatter",
":scatter_mlir",
- ":transpose",
":transpose_mlir",
":triton",
"//xla:shape_util",
@@ -270,6 +234,13 @@
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu/fusions/legacy:concatenate",
+ "//xla/service/gpu/fusions/legacy:in_place_dynamic_update_slice",
+ "//xla/service/gpu/fusions/legacy:input_slices",
+ "//xla/service/gpu/fusions/legacy:loop",
+ "//xla/service/gpu/fusions/legacy:reduction",
+ "//xla/service/gpu/fusions/legacy:scatter",
+ "//xla/service/gpu/fusions/legacy:transpose",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
@@ -291,8 +262,8 @@
"//xla/service:gpu_plugin",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
"//xla/service/gpu/model:affine_map_printer",
"//xla/stream_executor:device_description",
"//xla/tests:filecheck",
@@ -320,53 +291,22 @@
)
cc_library(
- name = "loop",
- srcs = ["loop.cc"],
- hdrs = ["loop.h"],
- deps = [
- ":fusion_emitter",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:gpu_fusible",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:hlo_traversal",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:ir_emitter",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu:parallel_loop_emitter",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/service/llvm_ir:fused_ir_emitter",
- "//xla/service/llvm_ir:ir_array",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/numeric:bits",
- "@com_google_absl//absl/status",
- "@llvm-project//llvm:ir_headers",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:macros",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "loop_mlir",
srcs = ["loop_mlir.cc"],
hdrs = ["loop_mlir.h"],
deps = [
- ":loop",
"//xla:shape_util",
"//xla:status_macros",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
+ "//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
@@ -401,17 +341,17 @@
srcs = ["scatter_mlir.cc"],
hdrs = ["scatter_mlir.h"],
deps = [
- ":loop",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:scatter_simplifier",
+ "//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
@@ -456,15 +396,14 @@
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
- "//xla/mlir/utils:type_util",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
"//xla/service/gpu/fusions/mlir:type_util",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@@ -497,102 +436,6 @@
],
)
-xla_cc_test(
- name = "loop_test",
- srcs = ["loop_test.cc"],
- deps = [
- ":fusion_emitter",
- ":fusions",
- "//xla:status_macros",
- "//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu/model:affine_map_printer",
- "//xla/service/gpu/model:indexing_test_utils",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/status:statusor",
- "@com_google_googletest//:gtest",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "scatter",
- srcs = ["scatter.cc"],
- hdrs = ["scatter.h"],
- deps = [
- ":fusion_emitter",
- ":loop",
- "//xla:shape_util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:ir_emitter",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu:parallel_loop_emitter",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/service/llvm_ir:fused_ir_emitter",
- "//xla/service/llvm_ir:ir_array",
- "//xla/service/llvm_ir:llvm_util",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/types:span",
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:ir_headers",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "scatter_test",
- srcs = ["scatter_test.cc"],
- deps = [
- ":fusions",
- ":scatter",
- "//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu/model:affine_map_printer",
- "//xla/service/gpu/model:indexing_test_utils",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "tiling_util",
- srcs = ["tiling_util.cc"],
- hdrs = ["tiling_util.h"],
- visibility = ["//xla/service/gpu:__subpackages__"],
- deps = [
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:target_util",
- "//xla/service/llvm_ir:ir_array",
- "//xla/service/llvm_ir:kernel_support_library",
- "//xla/service/llvm_ir:llvm_loop",
- "//xla/service/llvm_ir:llvm_util",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:ir_headers",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
cc_library(
name = "triton",
srcs = ["triton.cc"],
@@ -694,14 +537,13 @@
"//xla/hlo/ir:hlo",
"//xla/service:dump",
"//xla/service:executable",
- "//xla/service:hlo_module_config",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
- "//xla/service/gpu:cudnn_fusion_compiler",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:stream_executor_util",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/tests:gpu_codegen_test",
+ "//xla/service/gpu/transforms:cudnn_fusion_compiler",
"//xla/stream_executor:dnn",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:stream_executor_memory_allocator",
@@ -739,87 +581,12 @@
)
cc_library(
- name = "reduction",
- srcs = ["reduction.cc"],
- hdrs = ["reduction.h"],
- deps = [
- ":fusion_emitter",
- ":reduction_base",
- ":thunk_util",
- ":tiling_util",
- "//xla:shape_util",
- "//xla:status_macros",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:buffer_assignment",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:hlo_traversal",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:ir_emitter",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu:kernel_arguments",
- "//xla/service/gpu:kernel_reuse_cache",
- "//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu:parallel_loop_emitter",
- "//xla/service/gpu:reduction_utils",
- "//xla/service/gpu:target_util",
- "//xla/service/gpu/runtime:kernel_thunk",
- "//xla/service/gpu/runtime:thunk",
- "//xla/service/llvm_ir:fused_ir_emitter",
- "//xla/service/llvm_ir:ir_array",
- "//xla/service/llvm_ir:kernel_support_library",
- "//xla/service/llvm_ir:llvm_loop",
- "//xla/service/llvm_ir:llvm_util",
- "//xla/service/llvm_ir:loop_emitter",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/container:node_hash_map",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:ir_headers",
- "@llvm-project//mlir:Support",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:status",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "reduction_test",
- srcs = ["reduction_test.cc"],
- deps = [
- ":fusion_emitter",
- ":reduction",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/service/gpu/model:indexing_test_utils",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
- "@com_google_googletest//:gtest",
- "@llvm-project//mlir:IR",
- ],
-)
-
-cc_library(
name = "reduction_base",
srcs = ["reduction_base.cc"],
hdrs = ["reduction_base.h"],
+ visibility = ["//xla/service/gpu/fusions:__subpackages__"],
deps = [
":fusion_emitter",
- ":tiling_util",
"//xla:shape_util",
"//xla:union_find",
"//xla:util",
@@ -863,11 +630,11 @@
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:reduction_utils",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
"//xla/service/gpu/fusions/mlir:type_util",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@@ -895,6 +662,7 @@
":mlir_emitter_test_base",
":reduction_mlir",
"//xla:error_spec",
+ "//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu/model:indexing_analysis",
"//xla/service/gpu/model:indexing_test_utils",
"//xla/tests:xla_internal_test_main",
@@ -907,59 +675,12 @@
)
cc_library(
- name = "concatenate",
- srcs = ["concatenate.cc"],
- hdrs = ["concatenate.h"],
- deps = [
- ":fusion_emitter",
- "//xla:shape_util",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:ir_emitter",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu:parallel_loop_emitter",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/service/llvm_ir:fused_ir_emitter",
- "//xla/service/llvm_ir:ir_array",
- "//xla/service/llvm_ir:loop_emitter",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/status",
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:ir_headers",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "concatenate_test",
- srcs = ["concatenate_test.cc"],
- deps = [
- ":concatenate",
- ":fusions",
- "//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu/model:affine_map_printer",
- "//xla/service/gpu/model:indexing_test_utils",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest",
- "@llvm-project//mlir:IR",
- ],
-)
-
-cc_library(
name = "concatenate_mlir",
srcs = ["concatenate_mlir.cc"],
hdrs = ["concatenate_mlir.h"],
deps = [
- ":concatenate",
- ":loop",
"//xla/hlo/ir:hlo",
+ "//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
@@ -995,94 +716,6 @@
)
cc_library(
- name = "transpose",
- srcs = ["transpose.cc"],
- hdrs = ["transpose.h"],
- deps = [
- ":fusion_emitter",
- ":tiling_util",
- "//xla:permutation_util",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:ir_emitter",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu:target_util",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/service/llvm_ir:fused_ir_emitter",
- "//xla/service/llvm_ir:ir_array",
- "//xla/service/llvm_ir:llvm_util",
- "//xla/service/llvm_ir:loop_emitter",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:ir_headers",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "transpose_test",
- srcs = ["transpose_test.cc"],
- deps = [
- ":fusions",
- ":transpose",
- "//xla:status_macros",
- "//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu/model:indexing_test_utils",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/status:statusor",
- "@com_google_googletest//:gtest",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
- name = "input_slices",
- srcs = ["input_slices.cc"],
- hdrs = ["input_slices.h"],
- deps = [
- ":fusion_emitter",
- "//xla:shape_util",
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:elemental_ir_emitter",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu:ir_emitter",
- "//xla/service/gpu:ir_emitter_context",
- "//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu:parallel_loop_emitter",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/service/llvm_ir:fused_ir_emitter",
- "//xla/service/llvm_ir:ir_array",
- "//xla/service/llvm_ir:kernel_support_library",
- "//xla/service/llvm_ir:llvm_loop",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@llvm-project//llvm:ir_headers",
- "@llvm-project//mlir:IR",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "input_slices_mlir",
srcs = ["input_slices_mlir.cc"],
hdrs = ["input_slices_mlir.h"],
@@ -1093,10 +726,10 @@
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@@ -1124,21 +757,3 @@
"@com_google_googletest//:gtest",
],
)
-
-xla_cc_test(
- name = "input_slices_test",
- srcs = ["input_slices_test.cc"],
- deps = [
- ":fusions",
- ":input_slices",
- "//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:hlo_fusion_analysis",
- "//xla/service/gpu/model:affine_map_printer",
- "//xla/service/gpu/model:indexing_test_utils",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_googletest//:gtest",
- "@llvm-project//mlir:IR",
- ],
-)
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.cc b/third_party/xla/xla/service/gpu/fusions/concatenate.cc
deleted file mode 100644
index 55a45dc..0000000
--- a/third_party/xla/xla/service/gpu/fusions/concatenate.cc
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/concatenate.h"
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/status/status.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Value.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/loop_emitter.h"
-#include "xla/shape.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) {
- const HloInstruction& concat = analysis.fusion_hero(0).instruction();
- int64_t dim = concat.concatenate_dimension();
- auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
- return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim);
- };
- HloInstruction* operand = *absl::c_max_element(concat.operands(), less);
- return operand->shape();
-}
-
-ConcatenateFusion::ConcatenateFusion(const HloFusionAnalysis& analysis)
- : analysis_(analysis) {}
-
-std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const {
- return std::nullopt;
-}
-
-std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const {
- return GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1,
- GetLargestConcatOperandShape(analysis_),
- ctx);
-}
-
-absl::Status ConcatenateFusion::EmitKernel(
- IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
- GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
- FusedIrEmitter fused_emitter(elemental_emitter);
- for (int i = 0; i < fusion.fused_parameters().size(); i++) {
- fused_emitter.BindGenerator(
- *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
- return inputs[i].EmitReadArrayElement(index, builder);
- });
- }
-
- llvm::Type* index_type =
- GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
-
- const HloInstruction& concat = analysis_.fusion_hero(0).instruction();
- int64_t concat_dim = concat.concatenate_dimension();
- int64_t operand_offset = 0;
-
- // Emit the slices that correspond to the operands of the concat hero.
- for (const HloInstruction* operand : concat.operands()) {
- llvm_ir::BodyEmitter body_emitter =
- [&](const llvm_ir::IrArray::Index& operand_index) -> absl::Status {
- // Bind concat to generate the current operand.
- TF_ASSIGN_OR_RETURN(auto operand_generator,
- fused_emitter.GetGenerator(*operand));
- fused_emitter.BindGenerator(concat, [&](llvm_ir::IrArray::Index) {
- return operand_generator(operand_index);
- });
-
- // Create the index of the slice corresponding to the current operand.
- llvm_ir::IrArray::Index result_index = operand_index.AddOffsetToDim(
- llvm::ConstantInt::get(index_type, operand_offset), concat_dim,
- builder);
- operand_offset += operand->shape().dimensions(concat_dim);
-
- // Generate and write out the slice for each root.
- for (const auto& [output, root] :
- llvm::zip_equal(outputs, analysis_.fusion_roots())) {
- llvm_ir::IrArray::Index root_index = result_index.SourceIndexOfBitcast(
- concat.shape(), root.shape(), builder);
- TF_ASSIGN_OR_RETURN(auto generator,
- fused_emitter.GetGenerator(root.instruction()));
- TF_ASSIGN_OR_RETURN(llvm::Value * value, generator(root_index));
- output.EmitWriteArrayElement(root_index, value, builder);
- }
- return absl::OkStatus();
- };
-
- ParallelLoopEmitter emitter(body_emitter, operand->shape(), launch_dims,
- builder);
- TF_RETURN_IF_ERROR(emitter.EmitLoop(fusion.name(), index_type));
- }
-
- return absl::OkStatus();
-}
-
-LaunchDimensions ConcatenateFusion::launch_dimensions() const {
- return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_),
- analysis_.device_info());
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.h b/third_party/xla/xla/service/gpu/fusions/concatenate.h
deleted file mode 100644
index e838b29..0000000
--- a/third_party/xla/xla/service/gpu/fusions/concatenate.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_
-#define XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_
-
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-
-namespace xla {
-namespace gpu {
-
-const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis);
-
-// Emits a kernel for the given hlo instruction where each thread produces
-// one element of each concat operand.
-class ConcatenateFusion : public KernelFusionEmitterBase {
- public:
- explicit ConcatenateFusion(const HloFusionAnalysis& analysis);
- LaunchDimensions launch_dimensions() const override;
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const override;
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const override;
-
- protected:
- absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const override;
-
- private:
- const HloFusionAnalysis& analysis_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc
index f2cecc5..d7ebabf 100644
--- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc
@@ -35,10 +35,9 @@
#include "mlir/IR/ValueRange.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/concatenate.h"
-#include "xla/service/gpu/fusions/loop.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
+#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/indexing_analysis.h"
@@ -52,6 +51,16 @@
using mlir::Value;
using mlir::ValueRange;
+const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) {
+ const HloInstruction& concat = analysis.fusion_hero(0).instruction();
+ int64_t dim = concat.concatenate_dimension();
+ auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
+ return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim);
+ };
+ HloInstruction* operand = *absl::c_max_element(concat.operands(), less);
+ return operand->shape();
+}
+
// Computes the unroll factor that divides concat dimension of all operands.
int ComputeUnrollFactor(const HloFusionAnalysis& analysis,
int unroll_factor_for_the_largest_shape) {
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc
index 92aff94..30ca0f6 100644
--- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc
@@ -52,7 +52,7 @@
thread_id_printer_.SetSymbolName(1, "unroll_id");
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirConcatenateFusion fusion(analysis);
constexpr auto kIndexing = R"(
@@ -102,9 +102,9 @@
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0)>
- // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 200)>
- // CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 600)>
+ // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0)
+ // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 200)
+ // CHECK-DAG: #[[MAP_3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 600)
// CHECK-LABEL: fused_computation
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}},
@@ -152,7 +152,7 @@
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 + 64)>
+ // CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 64)
// CHECK-LABEL: fused_computation
// CHECK-DAG: %[[C_63:.*]] = arith.constant 63
@@ -254,9 +254,9 @@
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK-DAG: affine_map<(d0, d1) -> (d1 * 128 + d0)>
- // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)>
- // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)>
+ // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0)
+ // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)
+ // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)
// CHECK-LABEL: fused_computation
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc
deleted file mode 100644
index 8192b33..0000000
--- a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/concatenate.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class ConcatenateTest : public HloTestBase {
- public:
- void SetUp() override {
- HloTestBase::SetUp();
- printer_ =
- AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
- {"chunk_id", "unroll_id"});
- }
-
- protected:
- DebugOptions GetDebugOptionsForTest() override {
- auto opts = HloTestBase::GetDebugOptionsForTest();
- opts.set_xla_gpu_mlir_emitter_level(0);
- return opts;
- }
- AffineMapPrinter printer_;
- mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(ConcatenateTest, ThreadIndexing) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fused_computation {
- param0 = f32[200] parameter(0)
- param1 = f32[400] parameter(1)
- param2 = f32[300] parameter(2)
- ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0}
- }
- ENTRY main {
- param0 = f32[200] parameter(0)
- param1 = f32[400] parameter(1)
- param2 = f32[300] parameter(2)
- ROOT fusion = f32[900] fusion(param0, param1, param2),
- calls=fused_computation, kind=kLoop
- }
- )")
- .value();
-
- stream_executor::DeviceDescription device_info =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis_fused = AnalyzeFusion(*root, device_info);
-
- auto emitter =
- GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
- auto fusion = dynamic_cast<ConcatenateFusion*>(emitter.get());
- ASSERT_NE(fusion, nullptr);
-
- constexpr auto kIndexing = R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
- (bl_x * 128 + th_x)
- domain:
- th_x in [0, 127]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 3]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- bl_x * 128 + th_x in [0, 399]
- )";
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kIndexing));
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kIndexing));
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kIndexing));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
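
The thread-indexing expectation in the ConcatenateTest removed above also gives a concrete worked example of what GetLargestConcatOperandShape (now defined directly in concatenate_mlir.cc) feeds into: for operands f32[200], f32[400], and f32[300] concatenated along dimension 0, the largest operand has 400 elements, so launch dimensions are derived from that 400-element shape. At 128 threads per block this yields 4 blocks, matching the bl_x in [0, 3] and bl_x * 128 + th_x in [0, 399] bounds in the expected indexing string.
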
diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
index caa24b1..172c53a 100644
--- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
@@ -31,11 +31,10 @@
#include "xla/primitive_util.h"
#include "xla/service/dump.h"
#include "xla/service/executable.h"
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/hlo_module_config.h"
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/stream_executor/dnn.h"
@@ -88,17 +87,49 @@
}
};
-TEST_F(CuDnnFusionTest, DumpingWorks) {
- HloModuleConfig config;
- DebugOptions options = GetDebugOptionsForTest();
- std::string output_directory;
- if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) {
- output_directory = tsl::testing::TmpDir();
+class CuDnnFusionFileCheckTest : public CuDnnFusionTest {
+ public:
+ CuDnnFusionFileCheckTest() {
+ if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory_)) {
+ output_directory_ = tsl::testing::TmpDir();
+ }
}
- options.set_xla_dump_to(output_directory);
- config.set_debug_options(options);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions options = CuDnnFusionTest::GetDebugOptionsForTest();
+ options.set_xla_dump_to(output_directory_);
+ return options;
+ }
+
+ absl::StatusOr<bool> RunCuDnnFileCheck(absl::string_view hlo,
+ absl::string_view pattern) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ const std::string root_name(
+ module->entry_computation()->root_instruction()->name());
+ BinaryMap dnn_compiled_graphs;
+ CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(),
+ dnn_compiled_graphs);
+ // Run filecheck even if CuDnnFusionCompiler failed.
+ cudnn_compiler.Run(module.get()).IgnoreError();
+ std::string dump;
+ TF_RETURN_IF_ERROR(tsl::ReadFileToString(
+ tsl::Env::Default(),
+ tsl::io::JoinPath(
+ output_directory_,
+ FilenameFor(*module, /*prefix=*/"",
+ /*suffix=*/
+ absl::StrCat("cudnn_fusion_", root_name, ".json"))),
+ &dump));
+ return RunFileCheck(dump, pattern);
+ }
+
+ private:
+ std::string output_directory_;
+};
+
+TEST_F(CuDnnFusionFileCheckTest, F32DotGraphIsConvertedCorrectly) {
+ EXPECT_TRUE(*RunCuDnnFileCheck(R"(
fd0 {
p0 = f32[64,64] parameter(0)
p1 = f32[64,64] parameter(1)
@@ -111,20 +142,7 @@
ROOT d0 = f32[64,64] fusion(p0, p1), kind=kCustom, calls=fd0,
backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}}
})",
- config));
- BinaryMap dnn_compiled_graphs;
- CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(),
- dnn_compiled_graphs);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, cudnn_compiler.Run(module.get()));
- EXPECT_TRUE(changed);
- std::string dump;
- TF_EXPECT_OK(tsl::ReadFileToString(
- tsl::Env::Default(),
- tsl::io::JoinPath(output_directory,
- FilenameFor(*module, /*prefix=*/"",
- /*suffix=*/"cudnn_fusion_d0.json")),
- &dump));
- EXPECT_TRUE(*RunFileCheck(dump, R"(
+ R"(
CHECK: "nodes": [
CHECK: "inputs": {
CHECK: "A": "p0",
diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc
index 3a95abf..31cd030 100644
--- a/third_party/xla/xla/service/gpu/fusions/custom.cc
+++ b/third_party/xla/xla/service/gpu/fusions/custom.cc
@@ -58,12 +58,16 @@
#include "xla/service/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/service/gpu/runtime/gemm_thunk.h"
#include "xla/service/gpu/runtime/kernel_thunk.h"
+#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h"
+#include "xla/service/gpu/runtime/nccl_api.h"
+#include "xla/service/gpu/runtime/nccl_collective_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
+#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
@@ -755,6 +759,119 @@
return result;
}
+absl::StatusOr<FusionEmissionResult> EmitCollective(
+ IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor,
+ const HloFusionInstruction& fusion_instr, const HloInstruction* instr) {
+ if (instr->opcode() != HloOpcode::kReduceScatter) {
+ return absl::UnimplementedError(
+ "Dynamic slice fusion with collectives only works for reduce-scatter "
+ "instruction");
+ }
+
+ const BufferAssignment& buffer_assignment =
+ ir_emitter_context.buffer_assignment();
+
+ std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>>
+ offset_buffer_indices(2, std::nullopt);
+ std::vector<std::optional<Shape>> orig_shapes(2, std::nullopt);
+ std::vector<std::optional<Shape>> sliced_shapes(2, std::nullopt);
+ std::vector<std::optional<uint64_t>> offset_byte_sizes(2, std::nullopt);
+
+ std::vector<HloInstruction*> slice_instrs(2, nullptr);
+ std::vector<std::optional<BufferAllocation::Slice>> arguments;
+
+ // Collect slice information for inputs.
+ unsigned arg_idx = 0;
+ TF_ASSIGN_OR_RETURN(arguments.emplace_back(),
+ GetOperandSlice(buffer_assignment, adaptor, fusion_instr,
+ *instr->operand(arg_idx), slice_instrs,
+ /*shape_idx=*/{}, arg_idx));
+ TF_RETURN_IF_ERROR(CollectSliceInfo(
+ buffer_assignment, fusion_instr,
+ absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
+ orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++));
+
+ // Collect slice information for outputs.
+ TF_ASSIGN_OR_RETURN(
+ arguments.emplace_back(),
+ GetResultSlice(buffer_assignment, adaptor, fusion_instr, *instr,
+ slice_instrs, /*shape_idx=*/{}, arg_idx));
+ TF_RETURN_IF_ERROR(CollectSliceInfo(
+ buffer_assignment, fusion_instr,
+ absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
+ orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx));
+
+ if (absl::c_all_of(slice_instrs, [&](HloInstruction* slice_instr) {
+ return slice_instr &&
+ slice_instr->opcode() != HloOpcode::kDynamicUpdateSlice;
+ })) {
+ return absl::InternalError(
+ "DynamicSliceFusion with reduce-scatter expects a dynamic-update-slice "
+ "operation.");
+ }
+
+ // Provide fake allocations for inputs and outputs.
+ std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
+ unsigned fake_arg_idx = 0;
+ int64_t operand_byte_size =
+ ShapeUtil::ByteSizeOf(instr->operand(fake_arg_idx)->shape());
+ fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+ /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
+ BufferAllocation::Slice slice_operand(fake_allocations[fake_arg_idx].get(), 0,
+ operand_byte_size);
+ fake_arg_idx++;
+ TF_RET_CHECK(instr->shape().IsArray() &&
+ "The output is not expected to be a tuple.");
+ int64_t out_fake_byte_size =
+ ShapeUtil::ByteSizeOf(instr->shape()); // TODO: we don't need this
+ fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+ /*index=*/fake_arg_idx, out_fake_byte_size, /*color=*/0);
+ BufferAllocation::Slice slice_out_fake(fake_allocations[fake_arg_idx].get(),
+ 0, out_fake_byte_size);
+
+ // Generate the hero thunk and wrap it in a dynamic-slice thunk.
+ ThunkSequence seq;
+ auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr);
+ std::vector<NcclCollectiveThunk::Buffer> buffers;
+ const Shape& src_shape = instr->operand(0)->shape();
+ const Shape& dst_shape = instr->shape();
+ buffers.push_back(NcclCollectiveThunk::Buffer{
+ ShapeUtil::ElementsIn(src_shape), slice_operand, slice_out_fake,
+ src_shape.layout().memory_space(), dst_shape.layout().memory_space(),
+ nullptr, nullptr});
+
+ if (instr->opcode() == HloOpcode::kReduceScatter) {
+ int64_t replica_count = instr->GetModule()->config().replica_count();
+ int64_t partition_count = instr->GetModule()->config().num_partitions();
+ auto rs = static_cast<const HloReduceScatterInstruction*>(instr);
+ TF_RETURN_IF_ERROR(NcclReduceScatterStartThunk::CheckImplementable(
+ rs, replica_count, partition_count));
+
+ // TODO: Add special handling for the degenerate case where no communication
+ // is needed: just copy.
+ auto rs_start_thunk = std::make_unique<NcclReduceScatterStartThunk>(
+ thunk_info, NcclApi::Default(), rs, buffers);
+ auto rs_done = std::make_unique<NcclCollectiveDoneThunk>(
+ /*kind=*/Thunk::kNcclReduceScatterDone,
+ /*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(rs),
+ /*async_events=*/rs_start_thunk->async_events(),
+ /*async_stream_kind=*/AsyncStreamKind::kCollective);
+ seq.emplace_back(std::move(rs_start_thunk));
+ seq.emplace_back(std::move(rs_done));
+ } else {
+ return absl::InternalError("Expected reduce-scatter hero instruction");
+ }
+
+ std::unique_ptr<Thunk> thunk = std::make_unique<DynamicSliceThunk>(
+ thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
+ std::move(arguments), std::move(fake_allocations),
+ std::move(offset_buffer_indices), std::move(orig_shapes),
+ std::move(sliced_shapes), std::move(offset_byte_sizes));
+ FusionEmissionResult result;
+ result.thunks.push_back(std::move(thunk));
+ return result;
+}
+
} // namespace
absl::StatusOr<FusionEmissionResult> CustomFusion::Emit(
@@ -807,6 +924,16 @@
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
const HloFusionAdaptor& adaptor = analysis_.fusion();
+ // Only reduce-scatter is supported for now.
+ auto maybe_collective =
+ HloBfsFindIf(/*roots=*/adaptor.GetRoots(), /*fusion=*/adaptor,
+ /*visit=*/[](HloInstructionAdaptor node) -> bool {
+ return node.opcode() == HloOpcode::kReduceScatter;
+ });
+ if (maybe_collective != std::nullopt) {
+ return EmitCollective(ir_emitter_context, adaptor, fusion,
+ &maybe_collective->instruction());
+ }
auto maybe_custom_call_adaptor = HloBfsFindIf(
adaptor.GetRoots(), adaptor,
[](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
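
Stripped of the slice bookkeeping, the EmitCollective path added above follows a wrap-the-hero-thunk pattern: build the NCCL start/done pair for the reduce-scatter, then hand the whole sequence to a DynamicSliceThunk together with the collected offsets and shapes. A condensed sketch using the same APIs as the code above, with error handling and the fake-allocation setup omitted:

    // rs is the HloReduceScatterInstruction hero; buffers, arguments,
    // fake_allocations, offsets, and shapes are collected as in the code above.
    ThunkSequence seq;
    auto rs_start = std::make_unique<NcclReduceScatterStartThunk>(
        thunk_info, NcclApi::Default(), rs, buffers);
    auto rs_done = std::make_unique<NcclCollectiveDoneThunk>(
        Thunk::kNcclReduceScatterDone,
        Thunk::ThunkInfo::WithProfileAnnotation(rs),
        rs_start->async_events(), AsyncStreamKind::kCollective);
    seq.emplace_back(std::move(rs_start));
    seq.emplace_back(std::move(rs_done));

    auto thunk = std::make_unique<DynamicSliceThunk>(
        thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
        std::move(arguments), std::move(fake_allocations),
        std::move(offset_buffer_indices), std::move(orig_shapes),
        std::move(sliced_shapes), std::move(offset_byte_sizes));
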
diff --git a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
index f53dc13..f835514 100644
--- a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
@@ -16,7 +16,9 @@
#include <cstddef>
#include <cstdint>
#include <functional>
+#include <string>
#include <utility>
+#include <vector>
#include "absl/status/status.h"
#include "xla/client/lib/constants.h"
@@ -24,9 +26,11 @@
#include "xla/error_spec.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/custom_call_target_registry.h"
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
@@ -34,6 +38,7 @@
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/stream.h"
+#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
@@ -80,6 +85,36 @@
config.set_debug_options(debug_options);
return config;
}
+
+ HloModuleConfig GetModuleConfigWithDeterministicOps() {
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_deterministic_ops(true);
+ HloModuleConfig config;
+ config.set_debug_options(debug_options);
+ return config;
+ }
+
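+  // Returns the fusion computations whose custom fusion config is named
+  // "address_computation" or "dynamic_address_computation".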
+ std::vector<HloComputation*> GetAddressComputations(const HloModule& module) {
+ std::vector<HloComputation*> computations;
+ for (auto computation : module.computations()) {
+ if (!computation->IsFusionComputation()) {
+ continue;
+ }
+ auto backend_config = computation->FusionInstruction()
+ ->backend_config<xla::gpu::GpuBackendConfig>();
+ if (backend_config.ok()) {
+ const FusionBackendConfig& fusion_backend_config =
+ backend_config.value().fusion_backend_config();
+ const std::string name =
+ fusion_backend_config.custom_fusion_config().name();
+ if (name == "dynamic_address_computation" ||
+ name == "address_computation") {
+ computations.push_back(computation);
+ }
+ }
+ }
+ return computations;
+ }
};
TEST_F(DynamicSliceFusionTest, CublasGemmSimple) {
@@ -237,8 +272,10 @@
backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
})";
- EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
- /*run_hlo_passes=*/false));
+ EXPECT_TRUE(RunAndCompareTwoModules(
+ hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+ GetModuleConfigWithDeterministicOps(), error_spec,
+ /*run_hlo_passes=*/false));
}
TEST_F(DynamicSliceFusionTest, ContiguousSlice) {
@@ -1327,8 +1364,10 @@
backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
})";
- EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
- /*run_hlo_passes=*/false));
+ EXPECT_TRUE(RunAndCompareTwoModules(
+ hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+ GetModuleConfigWithDeterministicOps(), error_spec,
+ /*run_hlo_passes=*/false));
}
TEST_F(DynamicSliceFusionTest, DynamicContiguousSlice) {
@@ -2156,8 +2195,10 @@
backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
})";
- EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
- /*run_hlo_passes=*/false));
+ EXPECT_TRUE(RunAndCompareTwoModules(
+ hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+ GetModuleConfigWithDeterministicOps(), error_spec,
+ /*run_hlo_passes=*/false));
}
TEST_F(DynamicSliceFusionTest, CublasGemmDUSWorkspaceIgnored) {
@@ -2241,8 +2282,10 @@
backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
})";
- EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
- /*run_hlo_passes=*/false));
+ EXPECT_TRUE(RunAndCompareTwoModules(
+ hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+ GetModuleConfigWithDeterministicOps(), error_spec,
+ /*run_hlo_passes=*/false));
}
TEST_F(DynamicSliceFusionTest, CublasGemmDUSOffsetS32NotConstant) {
@@ -2435,8 +2478,10 @@
backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
})";
- EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
- /*run_hlo_passes=*/false));
+ EXPECT_TRUE(RunAndCompareTwoModules(
+ hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+ GetModuleConfigWithDeterministicOps(), error_spec,
+ /*run_hlo_passes=*/false));
}
TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) {
@@ -2445,9 +2490,7 @@
&b, "__xla_test$$memcpy",
/*operands=*/
{DynamicSlice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"),
- {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "start0"),
- Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start1")},
- {2, 128})},
+ {ConstantR0(&b, 2), ConstantR0(&b, 0)}, {2, 128})},
ShapeUtil::MakeShape(F32, {2, 128}), /*opaque=*/"",
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
@@ -2464,7 +2507,6 @@
hlo_config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));
-
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));
DynamicSliceFusionRewriter pass(PLATFORM);
@@ -2502,11 +2544,7 @@
DynamicSlice(
Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}),
"p0"),
- {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}),
- "start0"),
- Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}),
- "start1")},
- {3, 128}),
+ {ConstantR0(&b, 20), ConstantR0(&b, 0)}, {3, 128}),
}),
},
ShapeUtil::MakeTupleShape({
@@ -2545,6 +2583,15 @@
DynamicSliceFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);
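+  // After the rewrite, the slice, dynamic-slice, and custom-call should live
+  // inside the address-computation fusion, with no slice or dynamic-slice ops
+  // left in the entry computation.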
+ EXPECT_TRUE(*RunFileCheck(hlo_opt->ToString(), R"(
+ // CHECK: %address-computation{{.+}} {
+ // CHECK: {{.+}} = {{.+}} slice
+ // CHECK: {{.+}} = {{.+}} dynamic-slice
+ // CHECK: {{.+}} = {{.+}} custom-call
+ // CHECK: ENTRY {{.+}} {
+ // CHECK-NOT: {{.+}} = {{.+}} slice
+ // CHECK-NOT: {{.+}} = {{.+}} dynamic-slice
+ )"));
EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
error_spec, /*run_hlo_passes=*/false));
@@ -2754,6 +2801,199 @@
/*run_hlo_passes=*/false));
}
+TEST_F(DynamicSliceFusionTest, ReduceScatterDUSConstant) {
+  // The DUS offset is a constant.
+ const char* hlo_ref = R"(
+ HloModule test, replica_count=2
+
+ add.clone {
+ x.1 = f16[] parameter(0)
+ y.1 = f16[] parameter(1)
+ ROOT add.462 = f16[] add(x.1, y.1)
+ }
+
+ ENTRY %main.9 {
+ param_0 = f16[128,128]{1,0} parameter(0)
+ param_1 = f16[128,128]{1,0} parameter(1)
+ constant_20 = u32[] constant(20)
+ constant_0 = u32[] constant(0)
+ reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone
+ ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, constant_20, constant_0)
+ })";
+
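+  // Pre-fused equivalent: the reduce-scatter and dynamic-update-slice are
+  // wrapped in a "dynamic_address_computation" custom fusion.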
+ const char* hlo_opt = R"(
+ HloModule test, replica_count=2
+
+ %add {
+ %param_0 = f16[] parameter(0)
+ %param_1 = f16[] parameter(1)
+ ROOT %add.1 = f16[] add(%param_0, %param_1)
+ }
+
+ %address-computation {
+ %p1 = f16[128,128]{1,0} parameter(1)
+ %p0 = f16[128,128]{1,0} parameter(0)
+ %reduce-scatter.1 = f16[64,128]{1,0} reduce-scatter(%p0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+ %p2 = u32[] parameter(2)
+ %p3 = u32[] parameter(3)
+ ROOT %loop_dynamic_update_slice_fusion.1 = f16[128,128]{1,0} dynamic-update-slice(%p1, %reduce-scatter.1, %p2, %p3)
+ }
+
+ ENTRY %main.9 {
+ %param_0.1 = f16[128,128]{1,0} parameter(0)
+ %param_1.1 = f16[128,128]{1,0} parameter(1)
+ %constant_20 = u32[] constant(20)
+ %constant_0 = u32[] constant(0)
+ ROOT %address_computation = f16[128,128]{1,0} fusion(%param_0.1, %param_1.1, %constant_20, %constant_0), kind=kCustom, calls=%address-computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}},"force_earliest_schedule":false}
+ })";
+
+ ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+ EXPECT_TRUE(RunAndCompareTwoModulesReplicated(hlo_ref, hlo_opt, true, true,
+ error_spec));
+}
+
+TEST_F(DynamicSliceFusionTest, ReduceScatterDUSParameterOffset) {
+  // The DUS offset is a parameter, which forces a device-to-host copy of the
+  // offset value.
+ const char* hlo_ref = R"(
+ HloModule test, replica_count=2
+
+ add.clone {
+ x.1 = f16[] parameter(0)
+ y.1 = f16[] parameter(1)
+ ROOT add.462 = f16[] add(x.1, y.1)
+ }
+
+ ENTRY %main.9 {
+ param_0 = f16[128,128]{1,0} parameter(0)
+ param_1 = f16[128,128]{1,0} parameter(1)
+ param_2 = u32[] parameter(2)
+ constant_0 = u32[] constant(0)
+ reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone
+ ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, param_2, constant_0)
+ })";
+
+ const char* hlo_opt = R"(
+ HloModule test, replica_count=2
+
+ %add {
+ %param_0 = f16[] parameter(0)
+ %param_1 = f16[] parameter(1)
+ ROOT %add.1 = f16[] add(f16[] %param_0, f16[] %param_1)
+ }
+
+ %address-computation {
+ %p1 = f16[128,128]{1,0} parameter(1)
+ %p0 = f16[128,128]{1,0} parameter(0)
+ %reduce-scatter.1 = f16[64,128]{1,0} reduce-scatter(%p0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+ %p2 = u32[] parameter(2)
+ %p3 = u32[] parameter(3)
+ ROOT %loop_dynamic_update_slice_fusion.1 = f16[128,128]{1,0} dynamic-update-slice(%p1, %reduce-scatter.1, %p2, %p3)
+ }
+
+ ENTRY %main.9 {
+ %param_0 = f16[128,128]{1,0} parameter(0)
+ %param_1 = f16[128,128]{1,0} parameter(1)
+ %param_2 = u32[] parameter(2)
+ %constant_0 = u32[] constant(0)
+ ROOT %address_computation = f16[128,128]{1,0} fusion(%param_0, %param_1, %param_2, %constant_0), kind=kCustom, calls=%address-computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}},"force_earliest_schedule":false}
+ })";
+
+ ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+ EXPECT_TRUE(RunAndCompareTwoModulesReplicated(hlo_ref, hlo_opt, true, true,
+ error_spec));
+}
+
+TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) {
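+  // The DUS offset is derived from the while-loop induction variable.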
+ const char* hlo_ref = R"(
+ HloModule jit_scan, replica_count=2
+
+ %add {
+ %param_0 = f32[] parameter(0)
+ %param_1 = f32[] parameter(1)
+ ROOT %add.1 = f32[] add(%param_0, %param_1)
+ }
+
+ %region_0.14 {
+ %arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+ %get-tuple-element.16 = s32[] get-tuple-element(%arg_tuple.15), index=0
+ %constant.21 = s32[] constant(1)
+ %add.37 = s32[] add(%get-tuple-element.16, %constant.21)
+ %get-tuple-element.20 = f32[128,128]{1,0} get-tuple-element(%arg_tuple.15), index=4
+ %get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(%arg_tuple.15), index=2
+ %reduce-scatter.1 = f32[64,128]{1,0} reduce-scatter(%get-tuple-element.20), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+ %reshape.32 = f32[1,64,128]{2,1,0} reshape(%reduce-scatter.1)
+ %constant.23 = s32[] constant(0)
+ %compare.33 = pred[] compare(%get-tuple-element.16, %constant.23), direction=LT
+ %constant.22 = s32[] constant(128)
+ %add.34 = s32[] add(%get-tuple-element.16, %constant.22)
+ %select.35 = s32[] select(%compare.33, %add.34, %get-tuple-element.16)
+ %dynamic-update-slice.36 = f32[128,128,128]{2,1,0} dynamic-update-slice(%get-tuple-element.18, %reshape.32, %select.35, %constant.23, %constant.23)
+ %get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(%arg_tuple.15), index=3
+ ROOT %tuple.38 = tuple(%add.37, %get-tuple-element.20, %dynamic-update-slice.36, %get-tuple-element.19, %get-tuple-element.20)
+ }
+
+ %region_1.39 {
+ %arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+ %get-tuple-element.41 = s32[] get-tuple-element(%arg_tuple.40), index=0
+ %constant.46 = s32[] constant(128)
+ ROOT %compare.47 = pred[] compare(%get-tuple-element.41, %constant.46), direction=LT
+ }
+
+ ENTRY %main.55 {
+ %constant.4 = s32[] constant(0)
+ %Arg_1.2 = f32[128,128]{1,0} parameter(1)
+ %constant.5 = f32[] constant(0)
+ %broadcast.6 = f32[128,128,128]{2,1,0} broadcast(%constant.5), dimensions={}
+ %Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2)
+ %Arg_0.1 = f32[128,128]{1,0} parameter(0)
+ %tuple.7 = tuple(%constant.4, %Arg_1.2, %broadcast.6, %Arg_2.3, %Arg_0.1)
+ %while.48 = while(%tuple.7), condition=%region_1.39, body=%region_0.14
+ %get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(%while.48), index=1
+ %get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(%while.48), index=2
+ ROOT %tuple.54 = tuple(%get-tuple-element.50, %get-tuple-element.51)
+ })";
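+  // Compile the same HLO twice, with dynamic-slice fusion disabled (reference)
+  // and enabled, then compare the resulting fusions and the numerical outputs.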
+ DebugOptions debugoptions = GetDebugOptionsForTest();
+
+ HloModuleConfig ref_config;
+ debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(false);
+ ref_config.set_debug_options(debugoptions);
+ TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
+ ParseAndReturnVerifiedModule(hlo_ref, ref_config));
+ TF_ASSERT_OK_AND_ASSIGN(auto ref_module_opt,
+ GetOptimizedModule(std::move(ref_module)));
+
+ HloModuleConfig opt_config;
+ debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(true);
+ opt_config.set_debug_options(debugoptions);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_with_address_computation_flag,
+                          ParseAndReturnVerifiedModule(hlo_ref, opt_config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module_with_address_computation,
+      GetOptimizedModule(std::move(module_with_address_computation_flag)));
+
+ std::vector<HloComputation*> address_computations_opt =
+      GetAddressComputations(*module_with_address_computation);
+ std::vector<HloComputation*> address_computations_ref =
+ GetAddressComputations(*ref_module_opt);
+ EXPECT_EQ(address_computations_ref.size(), 0);
+ ASSERT_EQ(address_computations_opt.size(), 1);
+
+  // Check that the reduce-scatter is emitted inside the fusion in the
+  // optimized module, not outside it.
+ EXPECT_TRUE(*RunFileCheck(address_computations_opt[0]->ToString(), R"(
+ // CHECK: {{.+}} = {{.*}}reduce-scatter({{.+}})
+ // CHECK: {{.+}} = {{.*}}dynamic-update-slice({{.+}})
+ )"));
+ EXPECT_TRUE(*RunFileCheck(
+ address_computations_opt[0]->FusionInstruction()->parent()->ToString(),
+ "// CHECK-NOT: {{.+}} = {{.*}}reduce-scatter"));
+
+ ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
+ EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
+      std::move(ref_module_opt), std::move(module_with_address_computation),
+ false, true, error));
+}
+
} // namespace
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc
index 03425b5..d662d25 100644
--- a/third_party/xla/xla/service/gpu/fusions/fusions.cc
+++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc
@@ -27,24 +27,24 @@
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusions/concatenate.h"
#include "xla/service/gpu/fusions/concatenate_mlir.h"
#include "xla/service/gpu/fusions/copy.h"
#include "xla/service/gpu/fusions/cudnn.h"
#include "xla/service/gpu/fusions/custom.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h"
#include "xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h"
-#include "xla/service/gpu/fusions/input_slices.h"
#include "xla/service/gpu/fusions/input_slices_mlir.h"
-#include "xla/service/gpu/fusions/loop.h"
+#include "xla/service/gpu/fusions/legacy/concatenate.h"
+#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h"
+#include "xla/service/gpu/fusions/legacy/input_slices.h"
+#include "xla/service/gpu/fusions/legacy/loop.h"
+#include "xla/service/gpu/fusions/legacy/reduction.h"
+#include "xla/service/gpu/fusions/legacy/scatter.h"
+#include "xla/service/gpu/fusions/legacy/transpose.h"
#include "xla/service/gpu/fusions/loop_mlir.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/reduction.h"
#include "xla/service/gpu/fusions/reduction_mlir.h"
-#include "xla/service/gpu/fusions/scatter.h"
#include "xla/service/gpu/fusions/scatter_mlir.h"
-#include "xla/service/gpu/fusions/transpose.h"
#include "xla/service/gpu/fusions/transpose_mlir.h"
#include "xla/service/gpu/fusions/triton.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc
deleted file mode 100644
index 464de3c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc
+++ /dev/null
@@ -1,105 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h"
-
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/dynamic_update_slice_util.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr int kDUSUpdateIndex = 1;
-
-} // namespace
-
-LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const {
- const auto& update_shape = dus_ops_.front().GetOperand(1).shape();
- return CalculateLaunchDimensions(update_shape, analysis_.device_info());
-}
-
-std::optional<IndexingMap>
-InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* mlir_context) const {
- if (hero_operand_index != kDUSUpdateIndex) {
- return std::nullopt;
- }
- auto launch_dims = launch_dimensions();
- // It is guaranteed that all DUS ops have the same output shape at this point.
- const auto& update_shape =
- dus_ops_.front().GetOperand(kDUSUpdateIndex).shape();
- return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1,
- update_shape, mlir_context);
-}
-
-absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel(
- IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
- // In case a dynamic slice update's output is bitcasted, we need to ensure we
- // write to the output array using the shape and layout of the dynamic slice
- // update. This cast is known to be safe to do iff, in the case the output of
- // the dynamic slice update is bitcasted, that bitcast is either the fusion's
- // output, or has a single user and is part of the fusion's tuple output.
- // This condition should be enforced explicitly in the
- // 'CanEmitFusedDynamicUpdateSliceInPlaceForGpu' matcher.
- for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
- output = output.CastToShape(op.shape(), builder);
- }
-
- auto* fused_computation = fusion.fused_instructions_computation();
- GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
- FusedIrEmitter fused_emitter(elemental_emitter);
- for (auto [index, input] : llvm::enumerate(inputs)) {
- auto fused_operand = fused_computation->parameter_instruction(index);
- fused_emitter.BindGenerator(
- *fused_operand, [input = input, builder,
- fused_operand](const llvm_ir::IrArray::Index& index) {
- return input.EmitReadArrayElement(index, builder,
- fused_operand->name());
- });
- }
-
- std::vector<std::pair<const HloInstruction*, const llvm_ir::IrArray>>
- dus_and_output_array;
- dus_and_output_array.reserve(dus_ops_.size());
-
- for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
- dus_and_output_array.push_back(std::make_pair(&op.instruction(), output));
- }
-
- return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
- fused_computation, dus_and_output_array, &fused_emitter, launch_dims,
- builder);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h
deleted file mode 100644
index cfac87d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h
+++ /dev/null
@@ -1,98 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
-#define XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-
-// Fusion node where the root is either:
-// 1. a dynamic-update-slice op
-// 2. a bitcast of a dynamic-update-slice op
-// 3. a tuple op returning the result of several dynamic-update-slice ops
-// 4. a tuple op returning the result of several bitcast
-// dynamic-update-slice ops
-//
-// Additionally, all the dynamic-update-slice ops have exactly one user. The
-// fusion parameter that they update can have users (in addition to the
-// dynamic-update-slice op) that read in either
-// a. a dynamic-slice corresponding exactly to the slice of the parameter that
-// is updated by the dynamic-update-slice op
-// b. a dynamic-slice reading in a single element anywhere in the parameter.
-// This is only allowed if the dynamic-update-slice op updates a single
-// element
-//
-// In both cases, the additional users must not flow into any other output
-// than the dynamic-slice-update corresponding to that particular slice of the
-// parameter.
-//
-// The assumption is that each op's input (i.e. array to update) shares the
-// same slice as its output. In this case, we have a special algorithm that
-// modifies the output in place without touching the un-updated elements. The
-// update slice is assumed to be the exact same for all the
-// dynamic-update-slice ops.
-class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase {
- public:
- explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis)
- : analysis_(analysis),
- dus_ops_(
- GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
- LaunchDimensions launch_dimensions() const override;
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const override {
- // The mapping cannot be statically computed in general, since the offsets
- // are unknown.
- return std::nullopt;
- }
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* mlir_context) const override;
-
- protected:
- absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const override;
-
- const HloFusionAnalysis& analysis_;
- std::vector<HloInstructionAdaptor> dus_ops_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
index b0da3ef..f18173b 100644
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
@@ -56,7 +56,7 @@
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirInPlaceDynamicUpdateSliceFusion fusion(analysis);
auto thread_id_update_indexing = fusion.ComputeThreadIdToInputIndexing(
@@ -100,8 +100,8 @@
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 6)>
- // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 6)>
+ // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 6), domain: d0 in [0, 29]
+ // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 6), domain: d0 in [0, 29]
// CHECK: func.func @fused_computation
// CHECK-SAME: %arg0: tensor<20x30xf32>
// CHECK-SAME: %arg1: tensor<5x6xf32>
@@ -112,8 +112,8 @@
// CHECK-DAG: %[[C_15:.*]] = arith.constant 15
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0
// CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x
- // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 29])
- // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 29])
+ // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]])
+ // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]])
// CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0
// CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1
// CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]]
@@ -151,8 +151,8 @@
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
- // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 3)>
+ // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 3), domain: d0 in [0, 5]
+ // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 3), domain: d0 in [0, 5]
// CHECK: func.func @fused_computation
// CHECK-SAME: %arg0: tensor<7x8xf32>
// CHECK-SAME: %arg1: tensor<2x3xf32>
@@ -162,8 +162,8 @@
// CHECK-DAG: %[[C_5:.*]] = arith.constant 5
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0
// CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x
- // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 5])
- // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 5])
+ // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]])
+ // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]])
// CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0
// CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1
// CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]]
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc
deleted file mode 100644
index e48cee0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase {
- public:
- void SetUp() override {
- HloTestBase::SetUp();
- printer_ =
- AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
- {"chunk_id", "unroll_id"});
- }
-
- protected:
- DebugOptions GetDebugOptionsForTest() override {
- auto opts = HloTestBase::GetDebugOptionsForTest();
- opts.set_xla_gpu_mlir_emitter_level(0);
- return opts;
- }
- AffineMapPrinter printer_;
- mlir::MLIRContext mlir_context_;
- stream_executor::DeviceDescription device_info_ =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
-};
-
-TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fused_computation {
- in = f32[20,30] parameter(0)
- updates = f32[5,6] parameter(1)
- i0 = s32[] parameter(2)
- i1 = s32[] parameter(3)
- ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1)
- }
- ENTRY entry {
- in = f32[20,30] parameter(0)
- updates = f32[5,6] parameter(1)
- i0 = s32[] constant(2)
- i1 = s32[] constant(3)
- ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation
- }
- )"));
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis_fused = AnalyzeFusion(*root, device_info_);
-
- auto emitter =
- GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
- auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
- ASSERT_NE(fusion, nullptr);
-
- auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_);
- EXPECT_THAT(thread_id_update_indexing->ToString(printer_),
- MatchIndexingString(R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
- th_x floordiv 6, th_x mod 6)
- domain:
- th_x in [0, 29]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 0]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- )"));
- auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
- EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt));
-}
-
-TEST_F(InPlaceDynamicUpdateSliceFusionTest, ProduceConsumerFusion) {
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- fused_computation.1 {
- param_0 = bf16[1,2,5,1,2] parameter(0)
- bitcast = bf16[1,5,1,2,2] bitcast(param_0)
- param_1 = bf16[1,1,1,2,2] parameter(1)
- param_2 = s32[] parameter(2)
- param_3 = s32[] parameter(3)
- ROOT dynamic-update-slice = bf16[1,5,1,2,2] dynamic-update-slice(bitcast, param_1, param_2, param_3, param_2, param_2, param_2)
- }
-
- ENTRY entry_computation {
- param_0.2 = bf16[1,2,5,1,2] parameter(3)
- param_1.2 = bf16[1,1,1,2,2] parameter(0)
- param_2.2 = s32[] parameter(1)
- param_3.2 = s32[] parameter(2)
- fusion = bf16[1,5,1,2,2] fusion(param_0.2, param_1.2, param_2.2, param_3.2), kind=kLoop, calls=fused_computation.1
- ROOT bitcast.1 = bf16[1,2,5,1,2] bitcast(fusion)
- }
- )"));
-
- auto* root = module->entry_computation()->root_instruction();
-
- auto analysis_fused =
- AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info_);
-
- auto emitter =
- GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
-
- auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
-
- ASSERT_NE(fusion, nullptr);
- EXPECT_EQ(fusion->launch_dimensions().launch_bound(), 4 /* update size */);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/input_slices.cc
deleted file mode 100644
index 75ffe4b..0000000
--- a/third_party/xla/xla/service/gpu/fusions/input_slices.cc
+++ /dev/null
@@ -1,220 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/input_slices.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Value.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/elemental_ir_emitter.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/kernel_support_library.h"
-#include "xla/service/llvm_ir/llvm_loop.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Emits code for slices based on the below structure. An if statement with
-// a guarding condition is generated for each ROOT slice.
-//
-// Pseudo code:
-//
-// Compute values of slice input operands
-//
-// Compute guarding_cond0
-// if (guarding_cond0) {
-// Write to output of slice0
-// }
-//
-// Compute guarding_cond1
-// if (guarding_cond1) {
-// Write to output of slice1
-// }
-//
-absl::Status EmitElementForInputFusibleSlices(
- ElementalIrEmitter& elemental_emitter,
- const HloComputation* fused_computation,
- const std::vector<llvm_ir::IrArray>& inputs,
- const std::vector<llvm_ir::IrArray>& outputs,
- const llvm_ir::IrArray::Index& index, llvm::IRBuilder<>* builder) {
- VLOG(10) << "Emitting slice input fusion for "
- << fused_computation->ToString();
-
- HloInstruction* slice_or_tuple = fused_computation->root_instruction();
- auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
- if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
- return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
- }
- CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
- return slice_or_tuple->operands();
- }();
-
- // Emit input operand values of slices.
- std::vector<llvm::Value*> input_ir_values;
- FusedIrEmitter fused_emitter(elemental_emitter);
- for (int i = 0; i < fused_computation->num_parameters(); i++) {
- fused_emitter.BindGenerator(
- *fused_computation->parameter_instruction(i),
- [&inputs, i, builder](llvm_ir::IrArray::Index index) {
- return inputs[i].EmitReadArrayElement(index, builder);
- });
- }
- for (const HloInstruction* slice : slice_instructions) {
- auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0));
- input_ir_values.push_back(input_generator(index).value());
- }
-
- // Emit for slice_instructions.
- KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
- for (int64_t i = 0; i < slice_instructions.size(); ++i) {
- HloInstruction* slice = slice_instructions[i];
-
- // guarding_cond := index >= start && index < limit, for each dim.
- std::vector<llvm::Value*> index_within_ranges;
- for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
- CHECK_EQ(slice->slice_strides(dim), 1);
- auto larger_or_equal_than_start = builder->CreateICmpSGE(
- index.multidim()[dim],
- index.GetConstantWithIndexType(slice->slice_starts(dim)));
- llvm::Value* smaller_than_limit = builder->CreateICmpSLT(
- index.multidim()[dim],
- index.GetConstantWithIndexType(slice->slice_limits(dim)));
- llvm::Value* within_range =
- builder->CreateAnd(larger_or_equal_than_start, smaller_than_limit);
- index_within_ranges.push_back(within_range);
- }
- llvm::Value* guarding_cond = builder->CreateAnd(index_within_ranges);
-
- auto emit_slice_elem_func = [&] {
- const std::vector<llvm::Value*>& src_multidim = index.multidim();
- std::vector<llvm::Value*> dst_multidim(src_multidim.size());
- for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
- dst_multidim[dim] = builder->CreateSub(
- src_multidim[dim],
- index.GetConstantWithIndexType(slice->slice_starts(dim)));
- }
- const llvm_ir::IrArray& src_ir_array = outputs[i];
- llvm_ir::IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
- index.GetType());
- src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
- builder);
- };
-
- ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func);
- }
- return absl::OkStatus();
-}
-
-// Gets the input shape of the ROOT slices, which will be used as the kernel
-// launch dims. The slice input fusion requires the input shapes of the ROOT
-// slices to be the same although the (slice) output shapes can be different.
-//
-// Returns the input shape of the ROOT slices if all the input shapes of ROOT
-// slices are the same and the slices are non-strided. Otherwise, returns
-// FailedPrecondition.
-absl::StatusOr<Shape> GetConsistentInputShapeForRootSlices(
- const HloComputation* fused_computation) {
- const HloInstruction& root = *fused_computation->root_instruction();
- if (root.opcode() == HloOpcode::kSlice) {
- return root.operands()[0]->shape();
- }
-
- CHECK_EQ(root.opcode(), HloOpcode::kTuple);
- const Shape& first_slice_operand_shape =
- root.operands()[0]->operands()[0]->shape();
- for (size_t i = 1; i < root.operands().size(); ++i) {
- const HloInstruction* slice = root.operands()[i];
- const Shape& operand_shape = slice->operands()[0]->shape();
- if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
- operand_shape)) {
- return FailedPrecondition(
- "Fused slices do not have the same input shape, fused computation = "
- "%s.",
- root.parent()->name());
- }
- }
-
- return first_slice_operand_shape;
-}
-
-} // namespace
-
-LaunchDimensions InputSlicesFusion::launch_dimensions() const {
- const auto& root = analysis_.fusion_root(0).instruction();
- const auto& shape = root.operand(0)->shape();
- return CalculateLaunchDimensions(shape, analysis_.device_info(),
- {unroll_factor_});
-}
-
-std::optional<IndexingMap> InputSlicesFusion::ComputeThreadIdToOutputIndexing(
- int64_t output_id, mlir::MLIRContext* ctx) const {
- // The mapping here is trivial and the same for all outputs - slice offsets
- // are applied in the indexing from slice outputs to slice inputs.
- auto launch_dims = launch_dimensions();
- // The implementation requires the shapes and layouts to be the same, but we
- // still use the requested output's shape for clarity.
- const auto& shape = analysis_.fusion_root(output_id).shape();
- return GetDefaultThreadIdIndexingMap(launch_dims, unroll_factor_, shape, ctx);
-}
-
-absl::Status InputSlicesFusion::EmitKernel(
- IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
- TF_ASSIGN_OR_RETURN(Shape element_shape,
- GetConsistentInputShapeForRootSlices(
- fusion.fused_instructions_computation()));
- LaunchDimensionsConfig launch_config;
- launch_config.unroll_factor = unroll_factor_;
- GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
- return ParallelLoopEmitter(
- [&](const llvm_ir::IrArray::Index index) -> absl::Status {
- return EmitElementForInputFusibleSlices(
- elemental_emitter, fusion.fused_instructions_computation(),
- inputs, outputs, index, builder);
- },
- element_shape, launch_dims, builder, launch_config)
- .EmitLoop(
- fusion.name(),
- GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder));
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.h b/third_party/xla/xla/service/gpu/fusions/input_slices.h
deleted file mode 100644
index fa80438..0000000
--- a/third_party/xla/xla/service/gpu/fusions/input_slices.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_
-#define XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// Generates code for input-fusible slices.
-//
-// Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes
-// of all ROOT slices need to be the same while their output shapes can be
-// different. On the other hand, the input ranges of slices can be
-// overlapping. Further generalization/specialization when the needs are seen
-// in the future.
-class InputSlicesFusion : public KernelFusionEmitterBase {
- public:
- explicit InputSlicesFusion(const HloFusionAnalysis& analysis)
- : analysis_(analysis),
- unroll_factor_(CeilOfRatio(
- 8, analysis.input_output_info().smallest_output_dtype_bits)) {}
- LaunchDimensions launch_dimensions() const override;
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t output_id, mlir::MLIRContext* ctx) const override;
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const override {
- // TODO(b/319081342): Implement this.
- return std::nullopt;
- }
-
- protected:
- absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const override;
-
- private:
- const HloFusionAnalysis& analysis_;
- const int unroll_factor_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc
index 5297bf7..d2739e0 100644
--- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc
@@ -38,9 +38,9 @@
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/indexing_analysis.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc
index abeb57a..4aec42e 100644
--- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc
@@ -45,7 +45,7 @@
.value();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
auto emitter = GetEmitter(analysis);
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc
deleted file mode 100644
index 689727a..0000000
--- a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc
+++ /dev/null
@@ -1,104 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/input_slices.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class InputSlicesTest : public HloTestBase {
- public:
- void SetUp() override {
- HloTestBase::SetUp();
- printer_ =
- AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
- {"chunk_id", "unroll_id"});
- }
-
- protected:
- DebugOptions GetDebugOptionsForTest() override {
- auto opts = HloTestBase::GetDebugOptionsForTest();
- opts.set_xla_gpu_mlir_emitter_level(0);
- return opts;
- }
- AffineMapPrinter printer_;
- mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(InputSlicesTest, ThreadIndexing) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fused_computation {
- %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
- slice0 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[1:3],[0:3],[2:7]}
- slice1 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[0:2],[0:3],[2:7]}
- ROOT tuple = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) tuple(slice0, slice1)
- }
-
- ENTRY entry {
- %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
- ROOT %fusion = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) fusion(%input), kind=kLoop, calls=fused_computation
- })")
- .value();
-
- stream_executor::DeviceDescription device_info =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis_fused = AnalyzeFusion(*root, device_info);
-
- auto emitter =
- GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
- auto fusion = dynamic_cast<InputSlicesFusion*>(emitter.get());
- ASSERT_NE(fusion, nullptr);
-
- auto thread_id_to_output_indexing =
- fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_);
- EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
- MatchIndexingString(R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0,
- ((bl_x * 128 + th_x) floordiv 3) mod 2,
- (bl_x * 128 + th_x) mod 3,
- (bl_x * 128 + th_x) floordiv 6)
- domain:
- th_x in [0, 127]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 1]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- bl_x * 128 + th_x in [0, 29]
- )"));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD
new file mode 100644
index 0000000..8250669
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD
@@ -0,0 +1,157 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//xla/tests:build_defs.bzl", "xla_test")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = [":friends"],
+ licenses = ["notice"],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//xla:friends",
+ ],
+)
+
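+# TableGen sources shared by the generated dialect, op, attribute, and type
+# libraries below.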
+td_library(
+ name = "xla_gpu_td_files",
+ srcs = glob(["*.td"]),
+ includes = ["."],
+ deps = [
+ "@llvm-project//mlir:BuiltinDialectTdFiles",
+ "@llvm-project//mlir:CallInterfacesTdFiles",
+ "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
+ "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:SideEffectInterfacesTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "xla_gpu_dialect_inc_gen",
+ strip_include_prefix = ".",
+ tbl_outs = [
+ (
+ ["-gen-dialect-decls"],
+ "xla_gpu_dialect.h.inc",
+ ),
+ (
+ ["-gen-dialect-defs"],
+ "xla_gpu_dialect.cc.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "xla_gpu_dialect.td",
+ deps = [":xla_gpu_td_files"],
+)
+
+gentbl_cc_library(
+ name = "xla_gpu_ops_inc_gen",
+ strip_include_prefix = ".",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "xla_gpu_ops.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "xla_gpu_ops.cc.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "xla_gpu_ops.td",
+ deps = [":xla_gpu_td_files"],
+)
+
+gentbl_cc_library(
+ name = "xla_gpu_attrs_inc_gen",
+ strip_include_prefix = ".",
+ tbl_outs = [
+ (
+ [
+ "-gen-attrdef-decls",
+ ],
+ "xla_gpu_attrs.h.inc",
+ ),
+ (
+ [
+ "-gen-attrdef-defs",
+ ],
+ "xla_gpu_attrs.cc.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "xla_gpu_attrs.td",
+ deps = [":xla_gpu_td_files"],
+)
+
+gentbl_cc_library(
+ name = "xla_gpu_types_inc_gen",
+ strip_include_prefix = ".",
+ tbl_outs = [
+ (
+ [
+ "-gen-typedef-decls",
+ "-typedefs-dialect=xla_gpu",
+ ],
+ "xla_gpu_types.h.inc",
+ ),
+ (
+ [
+ "-gen-typedef-defs",
+ "-typedefs-dialect=xla_gpu",
+ ],
+ "xla_gpu_types.cc.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "xla_gpu_types.td",
+ deps = [":xla_gpu_td_files"],
+)
+
+cc_library(
+ name = "xla_gpu",
+ srcs = [
+ "xla_gpu_attrs.cc",
+ "xla_gpu_dialect.cc",
+ "xla_gpu_ops.cc",
+ "xla_gpu_types.cc",
+ ],
+ hdrs = [
+ "xla_gpu_ops.h",
+ ],
+ deps = [
+ ":xla_gpu_attrs_inc_gen",
+ ":xla_gpu_dialect_inc_gen",
+ ":xla_gpu_ops_inc_gen",
+ ":xla_gpu_types_inc_gen",
+ "//xla/service/gpu/model:indexing_analysis",
+ "@com_google_absl//absl/strings:str_format",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:BytecodeOpInterface",
+ "@llvm-project//mlir:CallOpInterfaces",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:InferTypeOpInterface",
+ "@llvm-project//mlir:InliningUtils",
+ "@llvm-project//mlir:SideEffectInterfaces",
+ "@llvm-project//mlir:Support",
+ ],
+)
+
+xla_test(
+ name = "xla_gpu_ops_test",
+ srcs = ["xla_gpu_ops_test.cc"],
+ backends = ["gpu"],
+ deps = [
+ ":xla_gpu",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD
new file mode 100644
index 0000000..381d5a3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD
@@ -0,0 +1,16 @@
+load("//xla:lit.bzl", "lit_test_suite")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ licenses = ["notice"],
+)
+
+lit_test_suite(
+ name = "tests",
+ srcs = glob(["*.mlir"]),
+ cfg = "//xla:lit.cfg.py",
+ tools = [
+ "//xla/service/gpu/fusions/tools:mlir_fusions_opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir
new file mode 100644
index 0000000..34065f9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir
@@ -0,0 +1,188 @@
+// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s
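+// Canonicalization tests for xla_gpu.apply_indexing: unused dimensions and
+// symbols are removed, constant operands and results are folded, and chained
+// indexing maps are composed.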
+
+#map0 = #xla_gpu.indexing_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2),
+ domain: s0 in [-10, 10], s1 in [0, 2]>
+func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) {
+ %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1]
+ func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<()[s0] -> (s0 + 1, s0 mod 2),
+// CHECK-SAME: domain: s0 in [-10, 10]>
+
+// CHECK-LABEL: func.func @simplify_apply_indexing
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2),
+ domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]>
+func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index,
+ %d2: index, %s0: index, %s1: index) -> (index, index, index) {
+ %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1]
+ func.return %0#0, %0#1, %0#2 : index, index, index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1),
+// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3], s0 in [-11, 11]>
+
+// CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME: (%[[ARG_0]], %[[ARG_2]])
+// CHECK-SAME: [%[[ARG_3]]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0),
+ domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]>
+func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index)
+ -> (index, index, index, index, index) {
+ %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+ func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+
+// CHECK-LABEL: func.func @fold_indexing_map_results
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
+
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+
+// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]]
+// CHECK: return %[[NEW_RESULT]], %[[C4]], %[[ARG_1]], %[[C1]], %[[ARG_2]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0),
+ domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]>
+func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) {
+ %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+ func.return %0#2 : index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2),
+// CHECK-SAME: domain: d0 in [0, 2]>
+
+// CHECK-LABEL: func.func @remove_unused_results
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
+
+// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]])
+// CHECK: return %[[NEW_RESULT]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3),
+ domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]>
+func.func @fold_operands(%d0: index) -> index {
+ %d1 = arith.constant 1 : index
+ %s0 = arith.constant 2 : index
+ %s1 = arith.constant 3 : index
+ %0 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0, %s1]
+ func.return %0 : index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 3),
+// CHECK-SAME: domain: d0 in [0, 10]>
+
+// CHECK-LABEL: func.func @fold_operands
+// CHECK-SAME: %[[ARG_0:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]])
+
+// -----
+
+func.func @fold_operands_and_results(%arg0: index, %arg1: index)
+ -> (index, index) {
+ %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (0, d1),
+ domain: d0 in [0, 4], d1 in [0, 5]>(%arg0, %arg1)
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @fold_operands_and_results
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: return %[[C0]], %[[ARG_1]] : index, index
+
+// -----
+
+func.func @fold_sequence(%arg0: index, %arg1: index) -> index {
+ %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+ domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1)
+ %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 100 + 42),
+ domain: d0 in [0, 10000]>(%0)
+ func.return %1 : index
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42),
+// CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4]>
+// CHECK-LABEL: func.func @fold_sequence
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME: (%[[ARG0]], %[[ARG1]])
+
+// -----
+
+func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index {
+ %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+ domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1)
+ %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<()[s0] -> (s0 mod 100 + 42),
+ domain: s0 in [0, 10000]>(%0)
+ func.return %1 : index
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42),
+// CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4]>
+// CHECK-LABEL: func.func @fold_sequence_sym
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME: (%[[ARG0]], %[[ARG1]])
+
+// -----
+
+func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index {
+ %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+ domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1)
+ %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+ domain: d0 in [0, 4], d1 in [0, 10000]>(%arg1, %0)
+ func.return %1 : index
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1),
+// CHECK-SAME: domain: d0 in [0, 4], d1 in [0, 5]>
+// CHECK-LABEL: func.func @fold_sequence_shared_operands
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME: (%[[ARG1]], %[[ARG0]])
+
+// -----
+
+func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index)
+ -> (tensor<2x3xf32>) {
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
+ ^bb0(%current : f32):
+ xla_gpu.yield %current : f32
+ }
+ return %ret : tensor<2x3xf32>
+}
+// CHECK-LABEL: func.func @atomic_rmw_empty
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>
+// CHECK: return %[[ARG0]]
+
+
+// -----
+
+func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index)
+ -> (tensor<2x3xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
+ ^bb0(%current : f32):
+ xla_gpu.yield %cst : f32
+ }
+ return %ret : tensor<2x3xf32>
+}
+// CHECK-LABEL: func.func @atomic_rmw_cst
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant
+// CHECK-NEXT: atomic_rmw
+// CHECK: xla_gpu.yield %[[CST]]
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir
new file mode 100644
index 0000000..cd2e09f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1, d2)[s0] -> (d0),
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [1, 2],
+// CHECK-SAME: d1 in [5, 8],
+// CHECK-SAME: d2 in [10, 12],
+// CHECK-SAME: s0 in [0, 32],
+// CHECK-SAME: d0 mod 2 in [0, 1],
+// CHECK-SAME: d0 + s0 in [1, 10]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ d2 in [10, 12],
+ s0 in [0, 32],
+ d0 mod 2 in [0, 1],
+ d0 + s0 in [1, 10]
+ >
+
+func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>)
+// CHECK-LABEL: @indexing_map_attr
+// CHECK: !xla_gpu.indexed_vector<64x64x32xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [1, 2]
+// CHECK-SAME: d1 in [5, 8]
+// CHECK-SAME: s0 in [0, 10]
+// CHECK-SAME: s1 in [0, 5]
+// CHECK-SAME: s2 in [0, 32]
+// CHECK-SAME: d0 mod 2 in [0, 1]
+// CHECK-SAME: d0 + s0 in [1, 10]
+// CHECK-SAME: d1 + s1 + s2 in [1, 32]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ s0 in [0, 10],
+ s1 in [0, 5],
+ s2 in [0, 32],
+ d0 mod 2 in [0, 1],
+ d0 + s0 in [1, 10],
+ d1 + s1 + s2 in [1, 32]
+ >
+func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
+// CHECK-LABEL: @more_range_vars
+// CHECK: !xla_gpu.indexed_vector<100x32xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0)[s0] -> (d0)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [0, 100]
+// CHECK-SAME: s0 in [-3, -1]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0),
+ domain:
+ d0 in [0, 100],
+ s0 in [-3, -1]
+ >
+func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @indexing_map_small
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1, d2)[s0] -> (d0)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [1, 2]
+// CHECK-SAME: d1 in [5, 8]
+// CHECK-SAME: d2 in [10, 12]
+// CHECK-SAME: s0 in [0, 32]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ d2 in [10, 12],
+ s0 in [0, 32]
+ >
+func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
+// CHECK-LABEL: @no_constraints
+// CHECK: !xla_gpu.indexed_vector<32xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: ()[s0] -> (s0)
+// CHECK-SAME: domain:
+// CHECK-SAME: s0 in [3, 5]
+// CHECK-SAME: s0 mod 2 in [0, 1]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<()[s0] -> (s0),
+ domain:
+ s0 in [3, 5],
+ s0 mod 2 in [0, 1]
+ >
+func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @no_dimensions
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0) -> (d0)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [3, 5]
+// CHECK-SAME: d0 mod 2 in [0, 1]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0) -> (d0),
+ domain:
+ d0 in [3, 5],
+ d0 mod 2 in [0, 1]
+ >
+func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @no_symbols
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: () -> ()
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<() -> ()>
+func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @empty
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir
rename to third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir
new file mode 100644
index 0000000..999a6de
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir
@@ -0,0 +1,236 @@
+// RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics
+
+#map0 = #xla_gpu.indexing_map<
+ (d0, d1)[s0] -> (d0, d1 + s0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ s0 in [0, 32]
+>
+func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
+ // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}}
+ %0:2 = xla_gpu.apply_indexing #map0 (%d0)
+ func.return %0#0, %0#1 : index, index
+}
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+ (d0, d1)[s0] -> (d0, d1 + s0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ s0 in [0, 32],
+ d0 mod 2 in [0, 1],
+ d0 + s0 in [1, 10]
+>
+func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) {
+ // expected-error @+1 {{apply indexing op cannot have any constraints}}
+ %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+ func.return %0#0, %0#1 : index, index
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>,
+ %init: f32) -> (f32) {
+ // expected-error @+1 {{mismatch in number of loop-carried values and results}}
+ %sum:2 = "xla_gpu.loop"(%init) <{
+ indexing_map_attr = #map,
+ operandSegmentSizes = array<i32: 0, 1>
+ }> ({
+ ^bb0(%i: index, %j: index, %iter: f32):
+ %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+ %add = arith.addf %iter, %t : f32
+ xla_gpu.yield %add : f32
+ }) : (f32) -> (f32, f32)
+ func.return %sum#0 : f32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0] -> (s0, s0), domain: s0 in [0, 1024]>
+func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>,
+ %init: f32) -> (f32) {
+ // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}}
+ %sum = "xla_gpu.loop"(%init) <{
+ indexing_map_attr = #map,
+ operandSegmentSizes = array<i32: 0, 1>
+ }> ({
+ ^bb0(%i: index, %j: index, %iter: f32):
+ %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+ %add = arith.addf %iter, %t : f32
+ xla_gpu.yield %add : f32
+ }) : (f32) -> (f32)
+ func.return %sum : f32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) {
+ // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}}
+ %sum = "xla_gpu.loop"(%init) <{
+ indexing_map_attr = #map,
+ operandSegmentSizes = array<i32: 0, 1>
+ }> ({
+ ^bb0(%i: index, %j: index, %iter: f32):
+ %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+ %add = arith.addf %iter, %t : f32
+ xla_gpu.yield %add : f32
+ }) : (f32) -> (i32)
+ func.return %sum : i32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) {
+ // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}}
+ %sum = xla_gpu.loop ()[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+ %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+ %add = arith.addf %sum_, %t : f32
+ xla_gpu.yield %add : f32
+ } {xla.range = [0 : index, 42 : index]}
+ func.return %sum : f32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+func.func @indices_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> {
+ // expected-error @+1 {{number of indices must match number of dimensions of indexing map}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{must have thread_id dimension in both indexing maps}}
+ %0 = xla_gpu.materialize @exp(%input) at #map() : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{must have thread_id dimension in both indexing maps}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]>
+func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{constraints of indexing maps must be equal for the thread_id dimension}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{number of symbols in both indexing_maps must match}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{domain of symbols of indexing_maps must match}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]>
+func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]>
+func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]>
+func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]>
+func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{vector mapping indices must not depend on the block ID}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]>
+func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]>
+func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+ // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}}
+ %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir
new file mode 100644
index 0000000..c4378a3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir_fusions_opt %s --split-input-file | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s
+
+func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) {
+ %shared1 = xla_gpu.allocate_shared : tensor<2xf32>
+ %shared2 = xla_gpu.allocate_shared : tensor<2xf32>
+ %sync:2 = xla_gpu.sync_threads %shared1, %shared2
+ : tensor<2xf32>, tensor<2xf32>
+ return %sync#0, %sync#1 : tensor<2xf32>, tensor<2xf32>
+}
+// CHECK-LABEL: @shared_and_sync
+// CHECK-NEXT: allocate_shared
+// CHECK-NEXT: allocate_shared
+// CHECK-NEXT: sync_threads
+// CHECK-NEXT: return
+
+// -----
+
+func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index)
+ -> (tensor<2x3xf32>) {
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
+ ^bb0(%current : f32):
+ %c42 = arith.constant 42.0 : f32
+ %add = arith.addf %current, %c42 : f32
+ xla_gpu.yield %add : f32
+ }
+ return %ret : tensor<2x3xf32>
+}
+// CHECK-LABEL: @atomic_rmw
+// CHECK: xla_gpu.atomic_rmw
+
+// -----
+
+func.func private @add(%a: f32, %b: f32) -> f32 {
+ %ret = arith.addf %a, %b : f32
+ return %ret : f32
+}
+
+func.func @caller(%a: f32, %b: f32) -> f32 {
+ %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
+ %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
+ %ret = arith.addf %c, %d : f32
+ return %ret : f32
+}
+// CHECK-LABEL: @caller
+// CHECK: %[[C:.*]] = xla_gpu.pure_call @add
+// CHECK: %[[D:.*]] = xla_gpu.pure_call @add
+// CHECK: arith.addf %[[C]], %[[D]]
+
+// CHECK-CSE: @caller
+// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add
+// CHECK-CSE: arith.addf %[[C]], %[[C]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+(d0, d1)[s0] -> (d0, d1 + s0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ s0 in [0, 32]
+>
+func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
+ %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+ func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1)[s0] -> (d0, d1 + s0)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [1, 2]
+// CHECK-SAME: d1 in [5, 8]
+// CHECK-SAME: s0 in [0, 32]
+// CHECK-SAME: >
+
+// CHECK-LABEL: @apply_indexing
+// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
+// CHECK-SAME: (%[[d0]], %[[d1]])[%[[s0]]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+(d0, d1) -> (d0, d1),
+ domain:
+ d0 in [0, 2],
+ d1 in [1, 3]
+>
+func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) {
+ %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)
+ func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1) -> (d0, d1)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [0, 2]
+// CHECK-SAME: d1 in [1, 3]
+// CHECK-SAME: >
+
+// CHECK-LABEL: @apply_indexing_no_symbols
+// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
+// CHECK-SAME: (%[[d0]], %[[d1]])
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+ ()[s0] -> (s0, s0),
+ domain:
+ s0 in [2, 4]
+>
+func.func @apply_indexing_no_dims(%s0: index) -> (index, index) {
+ %0:2 = xla_gpu.apply_indexing #map0 [%s0]
+ func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: ()[s0] -> (s0, s0)
+// CHECK-SAME: domain:
+// CHECK-SAME: s0 in [2, 4]
+// CHECK-SAME: >
+
+// CHECK-LABEL: @apply_indexing_no_dims
+// CHECK: (%[[s0:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]]]
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) {
+ %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+ %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+ %add = arith.addf %sum_, %t : f32
+ xla_gpu.yield %add : f32
+ } {xla.range = [0 : index, 42 : index]}
+ func.return %sum : f32
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map
+// CHECK: %0 = xla_gpu.loop (%{{.*}})[%[[I:.*]], %[[J:.*]]] in #[[$MAP]]
+// CHECK-SAME: iter_args(%[[SUM_ITER:.*]] = %{{.*}}) -> (f32) {
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[%[[I]], %[[J]]]
+// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[EXTRACTED]] : f32
+// CHECK: xla_gpu.yield %[[ADD]] : f32
+// CHECK: } {xla.range = [0 : index, 42 : index]}
+
+// -----
+
+func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> {
+ %0 = xla_gpu.materialize @exp(%input) at #map(%i, %j) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+ %1 = xla_gpu.insert %0 into %output at #map1(%i, %j) : !xla_gpu.indexed_vector<32x64xf32, #map1> -> tensor<32x64xf32> into tensor<32x64xf32>
+ func.return %1 : tensor<32x64xf32>
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)
+// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1)
+// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+// CHECK-LABEL: @materialize_and_insert
+// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at #[[$MAP]](%{{.*}}, %{{.*}})
+// CHECK: xla_gpu.insert %[[MATERIALIZED]] into %{{.*}} at #[[$MAP1]](%{{.*}}, %{{.*}})
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc
new file mode 100644
index 0000000..a3220b0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc
@@ -0,0 +1,231 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/str_format.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+#define GET_ATTRDEF_LIST
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc"
+
+namespace xla {
+namespace gpu {
+
+using llvm::ParseResult;
+using llvm::SmallVector;
+using mlir::AffineExpr;
+using mlir::ArrayRef;
+using mlir::AsmParser;
+using mlir::AsmPrinter;
+using mlir::failure;
+using mlir::success;
+
+ParseResult ParseInterval(AsmParser& parser, Interval& interval) {
+ // ParseResult converts to `true` if parsing failed.
+ return failure(parser.parseLSquare() || parser.parseInteger(interval.lower) ||
+ parser.parseComma() || parser.parseInteger(interval.upper) ||
+ parser.parseRSquare());
+}
+
+void PrintDimVars(AsmPrinter& p, ArrayRef<DimVar> dim_vars) {
+ int index = 0;
+ llvm::interleaveComma(dim_vars, p, [&](const DimVar& dim_var) {
+ p << "d" << index++ << " in " << dim_var.bounds;
+ });
+}
+
+ParseResult ParseDimVars(AsmParser& parser, ArrayRef<std::string> dim_names,
+ SmallVector<DimVar>& dim_vars) {
+ dim_vars.reserve(dim_names.size());
+ for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) {
+ if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") ||
+ ParseInterval(parser, dim_vars.emplace_back().bounds)) {
+ return failure();
+ }
+ if (index < dim_names.size() - 1 && parser.parseComma()) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+void PrintRangeVars(AsmPrinter& p, ArrayRef<RangeVar> range_vars) {
+ int index = 0;
+ llvm::interleaveComma(range_vars, p, [&](const RangeVar& range_var) {
+ p << "s" << index++ << " in " << range_var.range;
+ });
+}
+
+ParseResult ParseRangeVars(AsmParser& parser,
+ ArrayRef<std::string> range_symbol_names,
+ SmallVector<RangeVar>& range_vars) {
+ range_vars.reserve(range_symbol_names.size());
+ for (const auto& [index, range_symbol_name] :
+ llvm::enumerate(range_symbol_names)) {
+ if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") ||
+ ParseInterval(parser, range_vars.emplace_back().range)) {
+ return failure();
+ }
+ if (index < range_symbol_names.size() - 1 && parser.parseComma()) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+void PrintConstraints(AsmPrinter& p,
+ ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
+ llvm::interleaveComma(constraints, p, [&](const auto& constraint) {
+ p << constraint.first << " in " << constraint.second;
+ });
+}
+
+ParseResult ParseConstraints(
+ AsmParser& parser,
+ ArrayRef<std::pair<llvm::StringRef, AffineExpr>> symbolSet,
+ SmallVector<std::pair<AffineExpr, Interval>>& constraints) {
+ // Constraints can only be present if the map has at least one dimension or
+ // symbol, so every remaining constraint is introduced by a leading comma.
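+ // For example, the remaining input at this point might look like the
+ // constraint tails exercised in ir/tests/indexing_map_attr.mlir:
+ //   , d0 mod 2 in [0, 1], d0 + s0 in [1, 10]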
+ while (succeeded(parser.parseOptionalComma())) {
+ auto& constraint = constraints.emplace_back();
+ if (parser.parseAffineExpr(symbolSet, constraint.first) ||
+ parser.parseKeyword("in") || ParseInterval(parser, constraint.second)) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) {
+ mlir::AffineMap map;
+ if (parser.parseLess() || parser.parseAffineMap(map)) {
+ return {};
+ }
+
+ // Store owning std::strings so that the StringRefs handed to the parsers
+ // below (for dims, range vars, and constraints) remain valid.
+ SmallVector<std::string> dim_strings(map.getNumDims());
+ SmallVector<std::string> symbol_strings(map.getNumSymbols());
+ SmallVector<std::pair<llvm::StringRef, AffineExpr>> symbolSet;
+ symbolSet.reserve(map.getNumDims() + map.getNumSymbols());
+ for (int i = 0; i < map.getNumDims(); ++i) {
+ dim_strings[i] = absl::StrFormat("d%d", i);
+ symbolSet.push_back(
+ {dim_strings[i], mlir::getAffineDimExpr(i, parser.getContext())});
+ }
+ for (int i = 0; i < map.getNumSymbols(); ++i) {
+ symbol_strings[i] = absl::StrFormat("s%d", i);
+ symbolSet.push_back(
+ {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())});
+ }
+ if (map.getNumDims() + map.getNumSymbols() > 0) {
+ if (parser.parseComma() || parser.parseKeyword("domain") ||
+ parser.parseColon()) {
+ return {};
+ }
+ }
+
+ SmallVector<DimVar> dim_vars;
+ if (map.getNumDims() > 0) {
+ if (ParseDimVars(parser, dim_strings, dim_vars)) {
+ return {};
+ }
+ }
+
+ SmallVector<RangeVar> range_vars;
+ if (map.getNumSymbols() > 0) {
+ if (!dim_vars.empty() && parser.parseComma()) {
+ return {};
+ }
+ if (ParseRangeVars(parser, symbol_strings, range_vars)) {
+ return {};
+ }
+ }
+
+ SmallVector<std::pair<AffineExpr, Interval>> constraints;
+ if (ParseConstraints(parser, symbolSet, constraints) ||
+ parser.parseGreater()) {
+ return {};
+ }
+ return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars,
+ constraints);
+}
+
+void IndexingMapAttr::print(mlir::AsmPrinter& printer) const {
+ printer << "<";
+ printer.printStrippedAttrOrType(getMap());
+ if (getDimVars().size() + getRangeVars().size() + getConstraints().size() >
+ 0) {
+ printer << ", domain: ";
+ }
+ PrintDimVars(printer, getDimVars());
+ if (!getDimVars().empty() &&
+ getRangeVars().size() + getConstraints().size() > 0) {
+ printer << ", ";
+ }
+ PrintRangeVars(printer, getRangeVars());
+ if (!getRangeVars().empty() && !getConstraints().empty()) {
+ printer << ", ";
+ }
+ PrintConstraints(printer, getConstraints());
+ printer << ">";
+}
+
+IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context,
+ const IndexingMap& indexing_map) {
+ llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints;
+ for (auto& constraint : indexing_map.GetConstraints()) {
+ constraints.push_back({constraint.first, constraint.second});
+ }
+ return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(),
+ indexing_map.GetRangeVars(), constraints);
+}
+
+mlir::LogicalResult IndexingMapAttr::verify(
+ mlir::function_ref<mlir::InFlightDiagnostic()> emitError,
+ mlir::AffineMap map, ArrayRef<DimVar> dim_vars,
+ ArrayRef<RangeVar> range_vars,
+ ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
+ if (map.getNumDims() != dim_vars.size()) {
+ return emitError()
+ << "dim size must match the number of dimensions in the affine map";
+ }
+ if (map.getNumSymbols() != range_vars.size()) {
+ return emitError()
+ << "range size must match the number of symbols in the affine map";
+ }
+ return mlir::success();
+}
+
+IndexingMap IndexingMapAttr::getIndexingMap() {
+ return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{},
+ getConstraints());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td
new file mode 100644
index 0000000..19dd24f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td
@@ -0,0 +1,64 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
+
+include "mlir/IR/AttrTypeBase.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td"
+
+class XLAGPU_Attr<string name, list<Trait> traits = []> :
+ AttrDef<XlaGpuDialect, name, traits> {
+}
+
+def XLAGPU_AffineMapParameter :
+ AttrOrTypeParameter<"::mlir::AffineMap", ""> {
+}
+
+def XLAGPU_DimVarsParameter : ArrayRefParameter<"::xla::gpu::DimVar",
+ "DimVarArray"> {
+}
+
+def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar",
+ "RangeVarArray"> {
+}
+
+def XLAGPU_ConstraintsParameter :
+ ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>",
+ "ContraintsArray"> {
+}
+
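+// Sketch of the attribute's textual form, mirroring the round-trip tests in
+// ir/tests/indexing_map_attr.mlir (e.g. @indexing_map_small):
+//   #xla_gpu.indexing_map<(d0)[s0] -> (d0),
+//     domain: d0 in [0, 100], s0 in [-3, -1]>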
+def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> {
+ let summary = "An Attribute representing an indexing map.";
+ let mnemonic = "indexing_map";
+ let description = [{This attribute stores an indexing map. See
+ https://openxla.org/xla/indexing for more details.
+ }];
+ let parameters = (ins XLAGPU_AffineMapParameter:$map,
+ XLAGPU_DimVarsParameter:$dim_vars,
+ XLAGPU_RangeVarsParameter:$range_vars,
+ XLAGPU_ConstraintsParameter:$constraints);
+ let hasCustomAssemblyFormat = 1;
+ let builders = [
+ AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>,
+ ];
+ let genVerifyDecl = 1;
+ let extraClassDeclaration = [{
+ // Returns the indexing map constructed from IndexingMapAttr.
+ xla::gpu::IndexingMap getIndexingMap();
+ }];
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc
new file mode 100644
index 0000000..57d2d70
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc
@@ -0,0 +1,136 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
+#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
+#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
+#include "mlir/Transforms/InliningUtils.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc"
+#undef GET_ATTRDEF_CLASSES
+#define GET_TYPEDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc"
+#undef GET_TYPEDEF_CLASSES
+
+namespace xla {
+namespace gpu {
+namespace {
+
+struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+ // Returns true if the given operation 'callable', that implements the
+ // 'CallableOpInterface', can be inlined into the position given call
+ // operation 'call', that is registered to the current dialect and implements
+ // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the
+ // given 'callable' is set to be cloned during the inlining process, or false
+ // if the region is set to be moved in-place (i.e. no duplicates would be
+ // created).
+ bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable,
+ bool wouldBeCloned) const final {
+ if (!wouldBeCloned) {
+ // If no duplicate would be created, 'call' is likely the only caller of
+ // 'callable'.
+ return true;
+ }
+ // Otherwise, inline only if the called function is small. We could
+ // theoretically also inline if there is no other caller in the function
+ // that contains the callee that has a call path to the callable, but that
+ // is more expensive to check.
+ auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(callable);
+ if (!func_op) {
+ return false;
+ }
+ auto region = func_op.getCallableRegion();
+ if (!region) {
+ return false;
+ }
+
+ // If callee and caller call the same third function, inline. We have no
+ // guarantee that the indices are the same, but there is a good chance they
+ // are (or if the callee gets inlined as well, there will be CSE
+ // opportunities).
+ // This is duct tape to work around the limitations of our partitioner.
+ // Ideally, the partitioner would be aware of the actual indexing and create
+ // the partitions based on it (i.e., the case where the indices are the same
+ // would never happen).
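+ // For example, if the region containing this call and the callee both
+ // contain an xla_gpu.pure_call to the same function (say @exp; the name is
+ // purely illustrative), the callee is inlined regardless of its size.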
+ llvm::SmallDenseSet<llvm::StringRef> callee_calls;
+ for (auto call : region->getOps<PureCallOp>()) {
+ callee_calls.insert(call.getCallee());
+ }
+ for (auto call : call->getParentRegion()->getOps<PureCallOp>()) {
+ if (callee_calls.contains(call.getCallee())) {
+ return true;
+ }
+ }
+
+ constexpr int kMaxOperationsToInline = 8;
+ int num_ops = 0;
+ region->front().walk([&](mlir::Operation* op) { ++num_ops; });
+
+ // Otherwise, only inline callees that are small (at most
+ // kMaxOperationsToInline operations).
+ return num_ops <= kMaxOperationsToInline;
+ }
+ // Returns true if the given operation 'op', which is registered to this
+ // dialect, can be inlined into the given region, false otherwise.
+ // 'wouldBeCloned' is true if 'op' would be cloned during the inlining
+ // process, and false if the operation would be moved in place (i.e. no
+ // duplicates would be created). 'valueMapping' contains any remapped values
+ // from within the 'src' region; it can be used to examine which values may
+ // replace the operands of 'op'.
+ bool isLegalToInline(mlir::Operation* op, mlir::Region* dest,
+ bool wouldBeCloned,
+ mlir::IRMapping& valueMapping) const final {
+ // We allow any op from the xla_gpu dialect to be inlined.
+ return true;
+ }
+};
+
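+// Registers a printer alias for IndexingMapAttr so that a printed module
+// defines the attribute once, e.g. as `#indexing_map = #xla_gpu.indexing_map<...>`,
+// and reuses the alias at each use site. The exact alias spelling and
+// numbering are chosen by MLIR's printer; this is only a sketch of the effect.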
+struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface {
+ using OpAsmDialectInterface::OpAsmDialectInterface;
+ AliasResult getAlias(mlir::Attribute attr,
+ mlir::raw_ostream& os) const final {
+ if (llvm::isa<IndexingMapAttr>(attr)) {
+ os << "indexing_map";
+ return AliasResult::FinalAlias;
+ }
+ return AliasResult::NoAlias;
+ }
+};
+
+} // namespace
+
+void XlaGpuDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc"
+#undef GET_OP_LIST
+ >();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc"
+ >();
+#undef GET_ATTRDEF_LIST
+ addInterfaces<XlaGpuInlinerInterface, XlaGpuOpAsmDialectInterface>();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc"
+#undef GET_TYPEDEF_LIST
+ >();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td
new file mode 100644
index 0000000..9a5c539
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td
@@ -0,0 +1,33 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
+
+include "mlir/IR/DialectBase.td"
+
+def XlaGpuDialect : Dialect {
+ let name = "xla_gpu";
+
+ let description = [{
+ This dialect contains ops required for lowering HLO to LLVM.
+ }];
+
+ let cppNamespace = "::xla::gpu";
+ let useDefaultAttributePrinterParser = 1;
+ let useDefaultTypePrinterParser = 1;
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc
new file mode 100644
index 0000000..bc59b7c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc
@@ -0,0 +1,841 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h" // IWYU pragma: keep
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
+#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h" // IWYU pragma: keep
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using llvm::ArrayRef;
+using mlir::AffineExpr;
+using mlir::AffineMap;
+using mlir::Block;
+using mlir::failure;
+using mlir::getAffineConstantExpr;
+using mlir::getAffineDimExpr;
+using mlir::getAffineSymbolExpr;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpAsmParser;
+using mlir::OpAsmPrinter;
+using mlir::OpBuilder;
+using mlir::OperationState;
+using mlir::PatternRewriter;
+using mlir::RankedTensorType;
+using mlir::Region;
+using mlir::SmallVector;
+using mlir::success;
+using mlir::Type;
+using mlir::TypeRange;
+using mlir::Value;
+using mlir::ValueRange;
+
+namespace arith = mlir::arith;
+
+} // namespace
+
+LogicalResult PureCallOp::verifySymbolUses(
+ mlir::SymbolTableCollection& symbolTable) {
+ auto callee = getCalleeAttr();
+ auto function =
+ symbolTable.lookupNearestSymbolFrom<mlir::func::FuncOp>(*this, callee);
+ if (!function) {
+ return emitError("'f' attribute refers to an undefined function: ")
+ << callee;
+ }
+
+ int func_arg_count = function.getFunctionType().getNumInputs();
+ int arg_count = getOperands().size();
+
+ if (arg_count != func_arg_count) {
+ return emitError() << "argument count mismatch: 'operands' has "
+ << arg_count << " arguments, but '" << callee
+ << "' expects " << func_arg_count;
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AllocateSharedOp
+//===----------------------------------------------------------------------===//
+
+void AllocateSharedOp::getAsmResultNames(
+ llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+ setNameFn(getResult(), "shmem");
+}
+
+//===----------------------------------------------------------------------===//
+// ApplyIndexingOp
+//===----------------------------------------------------------------------===//
+
+void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
+ ValueRange dims, ValueRange symbols,
+ const IndexingMap& indexing_map) {
+ SmallVector<Value, 4> operands;
+ operands.reserve(dims.size() + symbols.size());
+ operands.append(dims.begin(), dims.end());
+ operands.append(symbols.begin(), symbols.end());
+ build(builder, result, operands, indexing_map);
+}
+
+void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
+ ValueRange operands,
+ const IndexingMap& indexing_map) {
+ SmallVector<Type, 2> result_types(indexing_map.GetAffineMap().getNumResults(),
+ builder.getIndexType());
+ IndexingMapAttr indexing_map_attr =
+ IndexingMapAttr::get(builder.getContext(), indexing_map);
+ build(builder, result, result_types, operands, indexing_map_attr);
+}
+
+void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
+ ValueRange operands, AffineMap affine_map,
+ ArrayRef<DimVar> dim_vars,
+ ArrayRef<RangeVar> range_vars) {
+ IndexingMap indexing_map(affine_map, dim_vars, range_vars, {});
+ build(builder, result, operands, indexing_map);
+}
+
+// Parses a comma-separated list of operands, ex: %d1, %d2.
+mlir::ParseResult parseOperands(
+ OpAsmParser& parser,
+ SmallVector<OpAsmParser::UnresolvedOperand, 4>* operands) {
+ OpAsmParser::UnresolvedOperand operand;
+ return parser.parseCommaSeparatedList(
+ [&]() { return parser.parseOperand(operands->emplace_back()); });
+}
+
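+// Custom assembly format handled by parse()/print() below, e.g. (cf.
+// ir/tests/ops.mlir):
+//   xla_gpu.apply_indexing #map (%d0, %d1)[%s0]
+// Dimension operands appear in parentheses and symbol operands in square
+// brackets; either group is omitted when empty.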
+mlir::ParseResult ApplyIndexingOp::parse(OpAsmParser& parser,
+ OperationState& result) {
+ mlir::Builder& builder = parser.getBuilder();
+ auto index_type = builder.getIndexType();
+
+ IndexingMapAttr indexing_map_attr;
+ if (parser.parseAttribute(indexing_map_attr, "indexing_map_attr",
+ result.attributes)) {
+ return failure();
+ }
+
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (parseOperands(parser, &operands) || parser.parseRParen()) {
+ return failure();
+ }
+ }
+ if (succeeded(parser.parseOptionalLSquare())) {
+ if (parseOperands(parser, &operands) || parser.parseRSquare()) {
+ return failure();
+ }
+ }
+ if (parser.resolveOperands(operands, index_type, result.operands) ||
+ parser.parseOptionalAttrDict(result.attributes)) {
+ return failure();
+ }
+ auto map = indexing_map_attr.getMap();
+ result.addTypes(SmallVector<Type, 2>(map.getNumResults(), index_type));
+ return success();
+}
+
+void ApplyIndexingOp::print(OpAsmPrinter& p) {
+ AffineMap affine_map = getIndexingMapAttr().getMap();
+ p << " " << getIndexingMapAttr();
+
+ auto operands = getOperands();
+ unsigned num_dimensions = affine_map.getNumDims();
+ if (num_dimensions > 0) {
+ p << '(';
+ auto dimension_operands = operands.slice(0, num_dimensions);
+ llvm::interleaveComma(dimension_operands, p);
+ p << ')';
+ }
+
+ unsigned num_symbols = affine_map.getNumSymbols();
+ if (num_symbols > 0) {
+ p << '[';
+ auto symbol_operands = operands.slice(num_dimensions, num_symbols);
+ llvm::interleaveComma(symbol_operands, p);
+ p << ']';
+ }
+
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elidedAttrs=*/{"indexing_map_attr"});
+}
+
+LogicalResult ApplyIndexingOp::verify() {
+ auto affine_map = getIndexingMapAttr().getMap();
+ unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols();
+ if (getOperands().size() != num_variables) {
+ return emitOpError(
+ "operand count must match the number of dimensions and symbols in the "
+ "affine map");
+ }
+ if (!getIndexingMapAttr().getConstraints().empty()) {
+ return emitOpError("apply indexing op cannot have any constraints");
+ }
+ return success();
+}
+
+IndexingMap ApplyIndexingOp::getIndexingMap() {
+ return getIndexingMapAttr().getIndexingMap();
+}
+
+namespace {
+
+// Simplifies the indexing map and removes unused variables.
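+// For example (see ir/tests/canonicalize.mlir, @simplify_apply_indexing), the
+// map ()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2) with s0 in [-10, 10]
+// and s1 in [0, 2] simplifies to ()[s0] -> (s0 + 1, s0 mod 2), and the
+// operand for the now-unused s1 is dropped from the op.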
+struct SimplifyIndexingMap : public mlir::OpRewritePattern<ApplyIndexingOp> {
+ using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+ PatternRewriter& rewriter) const override {
+ IndexingMap indexing_map = indexing_op.getIndexingMap();
+ bool is_simplified = indexing_map.Simplify();
+
+ // Remove unused symbols.
+ auto unused_symbols_bit_vector = indexing_map.RemoveUnusedVars();
+ bool symbols_removed = unused_symbols_bit_vector.count() != 0;
+
+ if (!is_simplified && !symbols_removed) {
+ return rewriter.notifyMatchFailure(indexing_op,
+ "IndexingMap stayed unchanged");
+ }
+ if (!unused_symbols_bit_vector.empty()) {
+ SmallVector<Value, 4> operands;
+ operands.reserve(unused_symbols_bit_vector.size() -
+ unused_symbols_bit_vector.count());
+ for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) {
+ if (!unused_symbols_bit_vector[i]) {
+ operands.push_back(indexing_op.getOperand(i));
+ }
+ }
+ rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, operands,
+ indexing_map);
+ } else {
+ rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
+ indexing_op, indexing_op.getOperands(), indexing_map);
+ }
+ return success();
+ }
+};
+
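+// Composes a chain of apply_indexing ops into a single op, e.g. (cf.
+// ir/tests/canonicalize.mlir, @fold_sequence; SSA names illustrative, map
+// domains elided):
+//   %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), ...>(%a, %b)
+//   %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 100 + 42), ...>(%0)
+// folds to
+//   %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), ...>(%a, %b)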
+struct FoldApplyIndexingSequence
+ : public mlir::OpRewritePattern<ApplyIndexingOp> {
+ using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+ PatternRewriter& rewriter) const override {
+ MLIRContext* ctx = indexing_op.getContext();
+ int num_dims = indexing_op.getAffineMap().getNumDims();
+ int num_syms = indexing_op.getAffineMap().getNumSymbols();
+ mlir::DenseMap<Value, AffineExpr> operand_exprs;
+ for (auto& operand : indexing_op->getOpOperands()) {
+ int operand_number = operand.getOperandNumber();
+ operand_exprs[operand.get()] =
+ operand_number < num_dims
+ ? getAffineDimExpr(operand_number, ctx)
+ : getAffineSymbolExpr(operand_number - num_dims, ctx);
+ }
+
+ auto this_map = indexing_op.getIndexingMap();
+
+ SmallVector<Value> added_dim_args;
+ SmallVector<Value> added_sym_args;
+ auto new_dim_vars = this_map.GetDimVars();
+ auto new_sym_vars = this_map.GetRangeVars();
+
+ mlir::DenseMap<AffineExpr, AffineExpr> replacements;
+ for (auto& operand : indexing_op->getOpOperands()) {
+ if (auto producer = operand.get().getDefiningOp<ApplyIndexingOp>()) {
+ auto producer_map = producer.getIndexingMap();
+ int producer_result_id =
+ mlir::cast<mlir::OpResult>(operand.get()).getResultNumber();
+ int num_producer_dims = producer.getAffineMap().getNumDims();
+ SmallVector<AffineExpr> producer_dim_replacements;
+ SmallVector<AffineExpr> producer_sym_replacements;
+ for (auto& producer_operand : producer->getOpOperands()) {
+ int producer_operand_number = producer_operand.getOperandNumber();
+ bool is_dim = producer_operand_number < num_producer_dims;
+ auto& replacement_expr = operand_exprs[producer_operand.get()];
+ if (!replacement_expr) {
+ if (is_dim) {
+ int dim_num = producer_operand_number;
+ replacement_expr =
+ getAffineDimExpr(num_dims + added_dim_args.size(), ctx);
+ added_dim_args.push_back(producer_operand.get());
+ new_dim_vars.push_back(producer_map.GetDimVars(dim_num));
+ } else {
+ int sym_num = producer_operand_number -
+ producer.getAffineMap().getNumDims();
+ replacement_expr =
+ getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx);
+ added_sym_args.push_back(producer_operand.get());
+ new_sym_vars.push_back(producer_map.GetRangeVar(sym_num));
+ }
+ }
+
+ if (is_dim) {
+ producer_dim_replacements.push_back(replacement_expr);
+ } else {
+ producer_sym_replacements.push_back(replacement_expr);
+ }
+ }
+
+ replacements[operand_exprs[operand.get()]] =
+ producer.getAffineMap()
+ .getResult(producer_result_id)
+ .replaceDimsAndSymbols(producer_dim_replacements,
+ producer_sym_replacements);
+ }
+ }
+
+ if (replacements.empty()) {
+ return rewriter.notifyMatchFailure(indexing_op,
+ "No apply_indexing sequences found");
+ }
+
+ int new_num_operands = indexing_op->getNumOperands() +
+ added_dim_args.size() + added_sym_args.size();
+ auto new_affine_map = indexing_op.getAffineMap().replace(
+ replacements, num_dims + added_dim_args.size(),
+ num_syms + added_sym_args.size());
+ IndexingMap new_indexing_map(new_affine_map, new_dim_vars, new_sym_vars,
+ /*rt_vars=*/{});
+ if (!new_indexing_map.Simplify()) {
+ return rewriter.notifyMatchFailure(
+ indexing_op, "Folded indexing map was not simplified");
+ }
+ SmallVector<Value> new_operands;
+ new_operands.reserve(new_num_operands);
+
+ auto begin = indexing_op.getOperands().begin();
+ new_operands.append(begin, begin + num_dims);
+ new_operands.append(added_dim_args);
+ new_operands.append(begin + num_dims, begin + num_dims + num_syms);
+ new_operands.append(added_sym_args);
+
+ rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, new_operands,
+ new_indexing_map);
+ return success();
+ }
+};
+
+// Folds constants into the indexing map.
+struct FoldApplyIndexingOperands
+ : public mlir::OpRewritePattern<ApplyIndexingOp> {
+ using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+ PatternRewriter& rewriter) const override {
+ IndexingMap indexing_map = indexing_op.getIndexingMap();
+ AffineMap affine_map = indexing_map.GetAffineMap();
+
+ MLIRContext* ctx = affine_map.getContext();
+ unsigned num_operands = indexing_op->getNumOperands();
+ unsigned num_dims = affine_map.getNumDims();
+ unsigned num_symbols = affine_map.getNumSymbols();
+
+ SmallVector<std::optional<int64_t>> constant_values(num_operands,
+ std::nullopt);
+ int num_constants = 0;
+ for (auto& operand : indexing_op->getOpOperands()) {
+ if (auto constant =
+ operand.get().getDefiningOp<arith::ConstantIndexOp>()) {
+ constant_values[operand.getOperandNumber()] = constant.value();
+ ++num_constants;
+ }
+ }
+ if (num_constants == 0) {
+ return rewriter.notifyMatchFailure(indexing_op,
+ "No constant operands found");
+ }
+ SmallVector<AffineExpr, 2> dim_replacements, symbol_replacements;
+ dim_replacements.reserve(num_dims);
+ symbol_replacements.reserve(num_symbols);
+
+ unsigned new_num_operands = indexing_op->getNumOperands() - num_constants;
+ SmallVector<Value, 4> new_operands;
+ new_operands.reserve(new_num_operands);
+ SmallVector<DimVar, 2> new_dim_vars;
+ new_dim_vars.reserve(num_dims);
+ SmallVector<RangeVar, 2> new_range_vars;
+ new_range_vars.reserve(num_symbols);
+
+ unsigned new_num_dims = 0;
+ unsigned new_num_symbols = 0;
+ for (auto [operand, constant_value] :
+ llvm::zip(indexing_op->getOpOperands(), constant_values)) {
+ unsigned operand_id = operand.getOperandNumber();
+ if (constant_value.has_value()) {
+ if (operand_id < num_dims) {
+ dim_replacements.push_back(
+ getAffineConstantExpr(*constant_value, ctx));
+ } else {
+ symbol_replacements.push_back(
+ getAffineConstantExpr(*constant_value, ctx));
+ }
+ } else {
+ new_operands.push_back(operand.get());
+ if (operand_id < num_dims) {
+ dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx));
+ new_dim_vars.push_back(indexing_map.GetDimVars(operand_id));
+ } else {
+ symbol_replacements.push_back(
+ getAffineSymbolExpr(new_num_symbols++, ctx));
+ new_range_vars.push_back(
+ indexing_map.GetRangeVar(operand_id - num_dims));
+ }
+ }
+ }
+ rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
+ indexing_op, new_operands,
+ affine_map.replaceDimsAndSymbols(dim_replacements, symbol_replacements,
+ new_num_dims, new_num_symbols),
+ new_dim_vars, new_range_vars);
+ return success();
+ }
+};
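As a hedged illustration of this pattern (names, bounds, and the exact attribute formatting are illustrative):

```mlir
// Before: %c8 = arith.constant 8 : index feeds the d1 position.
#map = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 16 + d1), domain: d0 in [0, 31], d1 in [0, 15]>
%r = xla_gpu.apply_indexing #map(%tid, %c8)

// After FoldApplyIndexingOperands: the constant is inlined into the map and
// its operand disappears.
#folded = #xla_gpu.indexing_map<(d0) -> (d0 * 16 + 8), domain: d0 in [0, 31]>
%r = xla_gpu.apply_indexing #folded(%tid)
```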
+
+// Folds constant and dim/symbol expression results.
+struct FoldApplyIndexingResults
+ : public mlir::OpRewritePattern<ApplyIndexingOp> {
+ using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+ PatternRewriter& rewriter) const override {
+ mlir::Location loc = indexing_op.getLoc();
+ IndexingMap indexing_map = indexing_op.getIndexingMap();
+ if (indexing_map.IsKnownEmpty()) {
+ return rewriter.notifyMatchFailure(indexing_op,
+ "Domain of the indexing map is empty");
+ }
+ AffineMap* affine_map = &indexing_map.GetMutableAffineMap();
+ unsigned num_results = affine_map->getNumResults();
+ SmallVector<AffineExpr, 4> new_exprs;
+ new_exprs.reserve(num_results);
+ SmallVector<Value, 4> new_values;
+ new_values.reserve(num_results);
+ for (mlir::OpResult opresult : indexing_op->getOpResults()) {
+ if (opresult.use_empty()) {
+ new_values.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ continue;
+ }
+
+ unsigned id = opresult.getResultNumber();
+ AffineExpr result_expr = affine_map->getResult(id);
+ if (auto const_expr =
+ mlir::dyn_cast<mlir::AffineConstantExpr>(result_expr)) {
+ new_values.push_back(rewriter.create<arith::ConstantIndexOp>(
+ loc, const_expr.getValue()));
+ continue;
+ }
+ if (auto dim_expr = mlir::dyn_cast<mlir::AffineDimExpr>(result_expr)) {
+ new_values.push_back(indexing_op.getOperand(dim_expr.getPosition()));
+ continue;
+ }
+ if (auto symbol_expr =
+ mlir::dyn_cast<mlir::AffineSymbolExpr>(result_expr)) {
+ new_values.push_back(indexing_op.getOperand(
+ indexing_map.GetDimVarsCount() + symbol_expr.getPosition()));
+ continue;
+ }
+ new_exprs.push_back(result_expr);
+ new_values.push_back(Value{});
+ }
+ if (new_exprs.size() == num_results) {
+ return rewriter.notifyMatchFailure(
+ indexing_op, "No constant or dim/symbol expression found");
+ }
+ *affine_map =
+ AffineMap::get(affine_map->getNumDims(), affine_map->getNumSymbols(),
+ new_exprs, affine_map->getContext());
+ auto new_indexing_op = rewriter.create<ApplyIndexingOp>(
+ loc, indexing_op.getOperands(), indexing_map);
+ for (int new_result_id = 0, new_indexing_op_result_id = 0;
+ new_result_id < new_values.size(); ++new_result_id) {
+ auto& new_value = new_values[new_result_id];
+ if (new_value) continue;
+ new_value = new_indexing_op.getResult(new_indexing_op_result_id++);
+ }
+ rewriter.replaceOp(indexing_op, new_values);
+ return success();
+ }
+};
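A sketch of the effect of this pattern, under the same illustrative syntax assumptions as above:

```mlir
// Before: the second result is the bare dimension d0.
#map = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 4, d0), domain: d0 in [0, 15]>
%a, %b = xla_gpu.apply_indexing #map(%tid)

// After FoldApplyIndexingResults: only the non-trivial result is still
// computed; all uses of %b are rewired to %tid.
#reduced = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 4), domain: d0 in [0, 15]>
%a = xla_gpu.apply_indexing #reduced(%tid)
```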
+
+} // namespace
+
+void ApplyIndexingOp::getCanonicalizationPatterns(
+ mlir::RewritePatternSet& results, MLIRContext* context) {
+ results.add<FoldApplyIndexingOperands, FoldApplyIndexingResults,
+ SimplifyIndexingMap, FoldApplyIndexingSequence>(context);
+}
+
+mlir::LogicalResult ApplyIndexingOp::fold(
+ FoldAdaptor adaptor, llvm::SmallVectorImpl<mlir::OpFoldResult>& results) {
+ auto map = getAffineMap();
+ for (auto expr : map.getResults()) {
+ if (auto dim = mlir::dyn_cast<mlir::AffineDimExpr>(expr)) {
+ results.push_back(getOperand(dim.getPosition()));
+ } else if (auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(expr)) {
+ results.push_back(getOperand(map.getNumDims() + sym.getPosition()));
+ } else {
+ results.clear();
+ return failure();
+ }
+ }
+ return success();
+}
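For reference, a hypothetical case where this folder fires: every result of the map is a bare dim or symbol, so the op folds to its operands (syntax illustrative):

```mlir
#id = #xla_gpu.indexing_map<(d0)[s0] -> (d0, s0), domain: d0 in [0, 15], s0 in [0, 7]>
%a, %b = xla_gpu.apply_indexing #id(%x)[%y]
// fold() returns {%x, %y}; users of %a and %b end up using %x and %y directly.
```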
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+void AtomicRMWOp::getAsmResultNames(
+ llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+ setNameFn(getResult(), "atomic_rmw");
+}
+
+void AtomicRMWOp::build(OpBuilder& builder, OperationState& result,
+ Value tensor, ValueRange ivs) {
+ OpBuilder::InsertionGuard g(builder);
+ result.addOperands(tensor);
+ result.addOperands(ivs);
+ result.addTypes(tensor.getType());
+
+ auto tensor_type = llvm::cast<RankedTensorType>(tensor.getType());
+ Region* body = result.addRegion();
+ builder.createBlock(body);
+ body->addArgument(tensor_type.getElementType(), tensor.getLoc());
+}
+
+mlir::OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
+ auto* body = getBody();
+ if (&body->front() == body->getTerminator() &&
+ body->front().getOperand(0) == body->getArgument(0)) {
+ return getOperand(0);
+ }
+ return {};
+}
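A sketch of the no-op case this folder matches (value names and tensor shape are illustrative):

```mlir
// The body only yields the current value unchanged, so the op folds to %tensor.
%same = xla_gpu.atomic_rmw %tensor[%i] : tensor<16xf32> {
  ^bb0(%current: f32):
    xla_gpu.yield %current : f32
}
```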
+
+//===----------------------------------------------------------------------===//
+// PureCallOp
+//===----------------------------------------------------------------------===//
+
+void PureCallOp::getAsmResultNames(
+ llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+ for (auto result : getResults()) {
+ setNameFn(result, "pure_call");
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// SyncThreadsOp
+//===----------------------------------------------------------------------===//
+
+void SyncThreadsOp::getAsmResultNames(
+ llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+ for (auto result : getResults()) {
+ setNameFn(result, "synced_tensor");
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+void LoopOp::build(OpBuilder& builder, OperationState& result,
+ IndexingMapAttr indexing_map_attr, ValueRange dims,
+ ValueRange inits, BodyBuilderFn bodyBuilder) {
+ OpBuilder::InsertionGuard guard(builder);
+
+ int64_t num_ivs = indexing_map_attr.getRangeVars().size();
+ result.addOperands(dims);
+ result.addOperands(inits);
+ result.addTypes(TypeRange(inits));
+ Block* body_block = builder.createBlock(result.addRegion());
+ // Add induction variables block args.
+ for (int i = 0; i < num_ivs; ++i) {
+ body_block->addArgument(builder.getIndexType(), result.location);
+ }
+ // Add iteration arguments block args.
+ for (auto init_type : TypeRange(inits)) {
+ body_block->addArguments(init_type, result.location);
+ }
+
+ mlir::OperationName opname(LoopOp::getOperationName(), builder.getContext());
+ result.addAttribute(LoopOp::getIndexingMapAttrAttrName(opname),
+ indexing_map_attr);
+ result.addAttribute(
+ LoopOp::getOperandSegmentSizesAttrName(opname),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(dims.size()),
+ static_cast<int32_t>(inits.size())}));
+ if (bodyBuilder) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(body_block);
+ bodyBuilder(builder, result.location,
+ body_block->getArguments().take_front(num_ivs),
+ body_block->getArguments().drop_front(num_ivs));
+ }
+}
+
+void LoopOp::build(OpBuilder& builder, OperationState& result,
+ const IndexingMap& indexing_map, ValueRange dims,
+ ValueRange inits, BodyBuilderFn bodyBuilder) {
+ build(builder, result,
+ IndexingMapAttr::get(builder.getContext(), indexing_map), dims, inits,
+ bodyBuilder);
+}
+
+mlir::ParseResult LoopOp::parse(OpAsmParser& parser, OperationState& result) {
+ SmallVector<OpAsmParser::Argument, 4> region_args, ivs, iter_args;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> dim_operands;
+
+ // Parse the dimension values.
+ OpBuilder b(parser.getContext());
+ Type index_type = b.getIndexType();
+ if (parser.parseOperandList(dim_operands, OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(dim_operands, index_type, result.operands))
+ return failure();
+ // Parse the induction variables.
+ if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Square))
+ return failure();
+ for (auto iv : ivs) {
+ region_args.push_back(iv);
+ region_args.back().type = index_type;
+ }
+
+ // Parse the indexing map attribute.
+ IndexingMapAttr indexing_map_attr;
+ if (parser.parseKeyword("in") ||
+ parser.parseAttribute(indexing_map_attr, "indexing_map_attr",
+ result.attributes)) {
+ return failure();
+ }
+
+ // Parse the arguments.
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> init_operands;
+ if (parser.parseKeyword("iter_args") ||
+ parser.parseAssignmentList(iter_args, init_operands) ||
+ parser.parseArrowTypeList(result.types) ||
+ parser.resolveOperands(init_operands, result.types, parser.getNameLoc(),
+ result.operands))
+ return failure();
+
+ for (auto [index, iter_arg] : llvm::enumerate(iter_args)) {
+ region_args.push_back(iter_arg);
+ region_args.back().type = result.types[index];
+ }
+
+ if (region_args.size() != result.types.size() + ivs.size()) {
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of induction variables + "
+ "loop-carried values and the number of results");
+ }
+
+ // Parse the body region.
+ Region* body = result.addRegion();
+ if (parser.parseRegion(*body, region_args)) return failure();
+ LoopOp::ensureTerminator(*body, b, result.location);
+
+ // Parse the optional attribute list
+ result.addAttribute(
+ LoopOp::getOperandSegmentSizeAttr(),
+ b.getDenseI32ArrayAttr({static_cast<int32_t>(dim_operands.size()),
+ static_cast<int32_t>(iter_args.size())}));
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+
+ return success();
+}
+
+void LoopOp::print(OpAsmPrinter& p) {
+ p << " (" << getDims() << ")[" << getInductionVars() << "] in "
+ << getIndexingMapAttr() << " iter_args(";
+ llvm::interleaveComma(
+ llvm::zip(getRegionIterArgs(), getInits()), p,
+ [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); });
+ p << ") -> (" << getInits().getTypes() << ") ";
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elidedAttrs=*/{
+ getIndexingMapAttrAttrName(),
+ getOperandSegmentSizesAttrName(),
+ });
+}
+
+LogicalResult LoopOp::verify() {
+ if (getInits().size() != getNumResults()) {
+ return emitOpError("mismatch in number of loop-carried values and results");
+ }
+ IndexingMap indexing_map = getIndexingMap();
+ if (indexing_map.GetRangeVarsCount() != getNumInductionVars()) {
+ return emitOpError() << "mismatch in number of induction variables "
+ << getNumInductionVars()
+ << " and RangeVars in the indexing map "
+ << indexing_map.ToString();
+ }
+ if (indexing_map.GetDimVarsCount() != getDims().size()) {
+ return emitOpError() << "mismatch in number of dims operands "
+ << getDims().size()
+ << " and DimVars in the indexing map "
+ << indexing_map.ToString();
+ }
+ for (auto [bb_arg, result_type, init] :
+ llvm::zip(getRegionIterArgs(), getResultTypes(), getInits())) {
+ if (bb_arg.getType() != result_type || init.getType() != result_type) {
+ return emitOpError() << "block iter arg type = " << bb_arg.getType()
+ << ", result type = " << result_type
+ << " and init operand type = " << init.getType()
+ << " should match";
+ }
+ }
+ return success();
+}
+
+IndexingMap LoopOp::getIndexingMap() {
+ return getIndexingMapAttr().getIndexingMap();
+}
+
+//===----------------------------------------------------------------------===//
+// MaterializeOp
+//===----------------------------------------------------------------------===//
+
+VariableConstraints GetConstraintsForVariables(const IndexingMap& map) {
+ VariableConstraints result;
+ result.constraints_for_dims.resize(map.GetDimensionCount());
+ result.constraints_for_symbols.resize(map.GetSymbolCount());
+ for (const auto& constraint : map.GetConstraints()) {
+ constraint.first.walk([&](mlir::AffineExpr leaf) {
+ if (auto dim = mlir::dyn_cast<mlir::AffineDimExpr>(leaf)) {
+ result.constraints_for_dims[dim.getPosition()].push_back(constraint);
+ } else if (auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(leaf)) {
+ result.constraints_for_symbols[sym.getPosition()].push_back(constraint);
+ }
+ });
+ }
+ return result;
+}
+
+LogicalResult MaterializeOp::verify() {
+ IndexingMap map_in = getMap().getIndexingMap();
+ IndexingMap map_out =
+ getResult().getType().getIndexingMapAttr().getIndexingMap();
+ if (getIndices().size() != map_in.GetDimVarsCount()) {
+ return emitOpError() << "number of indices must match number of dimensions "
+ "of indexing map";
+ }
+
+ // The thread dimension must have the same domain (range and constraints)
+ if (map_in.GetDimVarsCount() == 0 || map_out.GetDimVarsCount() == 0) {
+ return emitOpError()
+ << "must have thread_id dimension in both indexing maps";
+ }
+ if (map_in.GetDimVars(0) != map_out.GetDimVars(0)) {
+ return emitOpError() << "thread_id dimension must have the same bounds in "
+ "both indexing maps";
+ }
+
+ auto variable_constraints_in = GetConstraintsForVariables(map_in);
+ auto variable_constraints_out = GetConstraintsForVariables(map_out);
+ if (variable_constraints_in.constraints_for_dims[0] !=
+ variable_constraints_out.constraints_for_dims[0]) {
+ return emitOpError() << "constraints of indexing maps must be equal for "
+ << "the thread_id dimension";
+ }
+
+ // The two maps must have the same symbols and they must have the same domain
+ if (map_in.GetRangeVarsCount() != map_out.GetRangeVarsCount()) {
+ return emitOpError()
+ << "number of symbols in both indexing_maps must match";
+ }
+ for (auto const& [range_in, range_out] :
+ llvm::zip(map_in.GetRangeVars(), map_out.GetRangeVars())) {
+ if (range_in.range != range_out.range) {
+ return emitOpError() << "domain of symbols of indexing_maps must match";
+ }
+ }
+ if (variable_constraints_in.constraints_for_symbols !=
+ variable_constraints_out.constraints_for_symbols) {
+ return emitOpError()
+ << "constraints of indexing maps must be equal for all symbols";
+ }
+
+ // The vector mapping indices must not depend on the block ID
+ if (map_out.GetDimVarsCount() > 1) {
+ for (auto expr : map_out.GetAffineMap().getResults()) {
+ if (expr.isFunctionOfDim(1)) {
+ return emitOpError() << "vector mapping indices must not depend on the "
+ << "block ID";
+ }
+ }
+ }
+ // If there are constraints on the block ID, they must be the same in both
+ // maps
+ if (map_in.GetDimVarsCount() > 1 && map_out.GetDimVarsCount() > 1) {
+ if (variable_constraints_in.constraints_for_dims[1] !=
+ variable_constraints_out.constraints_for_dims[1]) {
+ return emitOpError() << "constraints of indexing maps must be equal for "
+ << "the block_id dimension";
+ }
+ } else if (map_in.GetDimVarsCount() > 1 &&
+ !variable_constraints_in.constraints_for_dims[1].empty()) {
+ return emitOpError() << "constraints of indexing maps must be equal for "
+ << "the block_id dimension";
+ } else if (map_out.GetDimVarsCount() > 1 &&
+ !variable_constraints_out.constraints_for_dims[1].empty()) {
+ return emitOpError() << "constraints of indexing maps must be equal for "
+ << "the block_id dimension";
+ }
+
+ return success();
+}
+
+} // namespace gpu
+} // namespace xla
+
+#define GET_OP_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc"
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h
new file mode 100644
index 0000000..e3b8fd6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h
@@ -0,0 +1,56 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_
+#define XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_
+
+#include <utility>
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep
+#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep
+#include "mlir/IR/Attributes.h" // IWYU pragma: keep
+#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep
+#include "mlir/IR/Dialect.h" // IWYU pragma: keep
+#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep
+#include "mlir/IR/OpDefinition.h" // IWYU pragma: keep
+#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
+#include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep
+#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep
+#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep
+#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc"
+#include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc"
+#undef GET_ATTRDEF_CLASSES
+#define GET_TYPEDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc"
+#undef GET_TYPEDEF_CLASSES
+#define GET_OP_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h.inc"
+#undef GET_OP_CLASSES
+
+namespace xla::gpu {
+
+struct VariableConstraints {
+ llvm::SmallVector<llvm::SmallVector<std::pair<mlir::AffineExpr, Interval>>>
+ constraints_for_dims;
+ llvm::SmallVector<llvm::SmallVector<std::pair<mlir::AffineExpr, Interval>>>
+ constraints_for_symbols;
+};
+VariableConstraints GetConstraintsForVariables(const IndexingMap& map);
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td
new file mode 100644
index 0000000..9eb246f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td
@@ -0,0 +1,367 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_types.td"
+
+class XLAGPU_Op<string mnemonic, list<Trait> traits = []> :
+ Op<XlaGpuDialect, mnemonic, traits> {
+}
+
+def XLAGPU_AllocateSharedOp : XLAGPU_Op<"allocate_shared", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Allocates a shared memory tile.";
+
+ let description = [{
+ Allocates a shared memory tensor. The tensor is shared among all threads in
+ a block.
+
+ ```mlir
+ %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
+ ```
+ }];
+
+ let results = (outs AnyStaticShapeTensor:$result);
+
+ let assemblyFormat = "attr-dict `:` type($result)";
+}
+
+def XLAGPU_SyncThreadsOp : XLAGPU_Op<"sync_threads", [
+  TypesMatchWith<"result types match operand types",
+ "operands", "results", "$_self">,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Synchronizes threads.";
+
+ let description = [{
+ Synchronizes threads, taking any number of distributed tensors and returning
+ the synchronized state.
+ }];
+
+ let arguments = (ins Variadic<AnyRankedTensor>:$operands);
+ let results = (outs Variadic<AnyRankedTensor>:$results);
+
+ let assemblyFormat = "operands attr-dict `:` type($operands)";
+}
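A hedged usage sketch matching the assembly format above (value name and tensor shape are illustrative):

```mlir
%synced = xla_gpu.sync_threads %shared : tensor<32x32xf32>
```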
+
+def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw",
+ [Pure,
+ TypesMatchWith<"result type matches type of dest",
+ "input", "result", "$_self">,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Atomically updates an element of a tensor.";
+
+ let description = [{
+ Reads an element from a tensor, computes the updated value for it, and
+ writes back the result.
+ }];
+
+ let arguments = (ins AnyRankedTensor:$input, Variadic<Index>:$indices);
+ let results = (outs AnyRankedTensor:$result);
+ // The region takes the current value in the tensor as an argument and yields
+ // the updated value.
+ let regions = (region SizedRegion<1>:$computation);
+
+ let skipDefaultBuilders = 1;
+ let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>];
+
+ let extraClassDeclaration = [{
+ mlir::Block* getBody() { return &getComputation().front(); }
+ mlir::OpBuilder getBodyBuilder() {
+ return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
+ }
+ // The value stored in tensor[ivs].
+ mlir::Value getCurrentValue() {
+ return getRegion().getArgument(0);
+ }
+ }];
+ let hasFolder = 1;
+
+ let assemblyFormat = [{
+ $input `[` $indices `]` `:` type($input) $computation attr-dict
+ }];
+}
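An illustrative sketch of the op in the custom assembly above; the region receives the current element and yields the updated one (names and types are made up):

```mlir
%updated = xla_gpu.atomic_rmw %tensor[%i, %j] : tensor<32x32xf32> {
  ^bb0(%current: f32):
    %sum = arith.addf %current, %delta : f32
    xla_gpu.yield %sum : f32
}
```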
+
+def XLAGPU_YieldOp : XLAGPU_Op<"yield", [
+ ParentOneOf<["::xla::gpu::AtomicRMWOp", "::xla::gpu::LoopOp"]>,
+ ReturnLike, Terminator]> {
+  let summary = "Terminator for atomic_rmw and loop regions.";
+ let arguments = (ins AnyType:$result);
+
+ let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+ let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
+def XLAGPU_PureCallOp : XLAGPU_Op<"pure_call",
+ [Pure, CallOpInterface, DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Function call without side effects.";
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+ let results = (outs Variadic<AnyType>);
+ let builders = [
+ OpBuilder<(ins "mlir::func::FuncOp":$callee, CArg<"mlir::ValueRange", "{}">:$operands), [{
+ $_state.addOperands(operands);
+ $_state.addAttribute("callee", mlir::SymbolRefAttr::get(callee));
+ $_state.addTypes(callee.getFunctionType().getResults());
+ }]>];
+ let assemblyFormat = [{
+ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ operand_range getArgOperands() {
+ return getOperands();
+ }
+
+ mlir::MutableOperandRange getArgOperandsMutable() {
+ return getOperandsMutable();
+ }
+
+ mlir::CallInterfaceCallable getCallableForCallee() {
+ return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
+ }
+
+ void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+ (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
+ }
+ }];
+}
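A hedged usage sketch; `@exp_impl` is a hypothetical function symbol:

```mlir
%y = xla_gpu.pure_call @exp_impl(%x) : (f32) -> f32
```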
+
+def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
+ [Pure,
+ TypesMatchWith<"result type matches type of operands",
+ "operands", "results", "$_self">]> {
+ let summary = "Performs a full warp shuffle and reduces the values";
+ let description = [{
+ This op performs a full warp shuffle and reduces the results using the given
+ function. The function is invoked with the operands from the low lanes,
+ followed by the operands from the high lanes. For example:
+
+ ```
+ shuffle_reduce @argmax(%value, %idx) : (f32, index)
+ ```
+
+    This will perform shuffles with distances 16, 8, 4, 2 and 1, and will
+    invoke @argmax five times. The invocations in the first round (for each
+    lane i) are
+
+ ```
+ @argmax(%value[i], %idx[i], %value[16+i], %idx[16+i])
+ ```
+ }];
+ let builders = [
+ OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
+ $_state.addOperands(operands);
+ $_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer));
+ $_state.addAttribute("max_distance",
+ mlir::IntegerAttr::get(
+ mlir::IntegerType::get(reducer.getContext(), 64),
+ max_distance));
+ $_state.addTypes(reducer.getFunctionType().getResults());
+ }]>];
+ let arguments = (ins FlatSymbolRefAttr:$reducer,
+ Variadic<AnyType>:$operands,
+ I64Attr:$max_distance);
+ let results = (outs Variadic<AnyType>:$results);
+
+ let assemblyFormat = [{
+ $reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands)
+ }];
+}
+
+def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert",
+ [Pure,
+ TypesMatchWith<"result type matches type of operands",
+ "dest", "result", "$_self">,
+ TypesMatchWith<"value type matches element type of dest",
+ "dest", "value",
+ "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
+ let summary = "Inserts a value into a tensor if a condition holds";
+ let arguments = (ins I1:$condition, AnyType:$value,
+ AnyStaticShapeTensor:$dest, Variadic<Index>:$indices);
+ let results = (outs AnyStaticShapeTensor:$result);
+
+ let assemblyFormat = [{
+ $value `into` $dest `[` $indices `]` `if` $condition attr-dict `:` type($dest)
+ }];
+}
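A usage sketch under the assembly format above (all names are illustrative):

```mlir
%out = xla_gpu.predicated_insert %value into %dest[%i, %j] if %in_bounds : tensor<16x16xf32>
```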
+
+def XLAGPU_PredicatedExtractOp : XLAGPU_Op<"predicated_extract",
+ [Pure,
+ TypesMatchWith<"fallback type matches element type of src",
+ "src", "fallback",
+ "::llvm::cast<mlir::TensorType>($_self).getElementType()">,
+ TypesMatchWith<"result type matches element type of src",
+ "src", "result",
+ "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
+  let summary = "Extracts a value from a tensor if a condition holds";
+ let arguments = (ins I1:$condition, AnyType:$fallback,
+ AnyStaticShapeTensor:$src, Variadic<Index>:$indices);
+ let results = (outs AnyType:$result);
+
+ let assemblyFormat = [{
+ $src `[` $indices `]` `if` $condition `else` $fallback attr-dict `:` type($src)
+ }];
+}
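And the matching extract sketch, which returns %fallback when the condition is false (names illustrative):

```mlir
%elem = xla_gpu.predicated_extract %src[%i, %j] if %in_bounds else %fallback : tensor<16x16xf32>
```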
+
+def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> {
+  let summary = "Applies an indexing map to a list of SSA values";
+  let description = [{
+    The `apply_indexing` operation applies an indexing map to a list of SSA
+    values. The number of dimension and symbol arguments must be equal to the
+    respective number of dimensional and symbolic inputs in the affine map.
+    The affine map may have multiple results, and the operation returns one
+    SSA value per result. The operands and results must all have `index` type.
+
+ Example:
+
+ ```mlir
+    #map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 floordiv 8 + d1 floordiv 128, s0),
+                                 domain:
+                                 d0 in [0, 10],
+                                 d1 in [0, 11],
+                                 s0 in [11, 32]
+                                >
+    %results:2 = xla_gpu.apply_indexing #map(%0, %1)[%2]
+ ```
+ }];
+ let arguments = (ins Variadic<Index>:$operands,
+ XLAGPU_IndexingMapAttr:$indexing_map_attr);
+ let results = (outs Variadic<Index>);
+
+ let builders = [
+ OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols,
+ "const IndexingMap&":$indexing_map)>,
+ OpBuilder<(ins "mlir::ValueRange":$operands,
+ "const IndexingMap&":$indexing_map)>,
+ OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map,
+ "llvm::ArrayRef<DimVar>":$dim_vars,
+ "llvm::ArrayRef<RangeVar>":$range_vars)>,
+ ];
+ let extraClassDeclaration = [{
+ // Returns the indexing map constructed from IndexingMapAttr.
+ xla::gpu::IndexingMap getIndexingMap();
+ // Extracts the affine map from the attribute.
+ mlir::AffineMap getAffineMap() { return getIndexingMapAttr().getMap(); }
+ }];
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def LoopOp : XLAGPU_Op<"loop", [
+ AttrSizedOperandSegments, Pure,
+ SingleBlockImplicitTerminator<"xla::gpu::YieldOp">
+ ]> {
+ let summary = "Loop nest that iterates over all feasible values of RangeVars.";
+  let description = [{
+    The loop iterates over all feasible values of the RangeVars of the given
+    indexing map. The `dims` operands are bound to the map's DimVars, and
+    `iter_args` carries the loop-carried values, as in `scf.for`.
+
+ ```mlir
+ #map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1),
+ domain:
+ d0 in [0, 3],
+ s0 in [0, 1024],
+ s1 in [0, 32]
+ >
+ // Initial sum set to 0.
+ %sum_0 = arith.constant 0.0 : f32
+ %dim = arith.constant 1 : index
+ // iter_args binds initial values to the loop's region arguments.
+ %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_iter = %sum_0) -> (f32) {
+ %t = tensor.extract %buffer[%i, %j] : tensor<1024x32xf32>
+ %sum_next = arith.addf %sum_iter, %t : f32
+ // Yield current iteration sum to next iteration %sum_iter or to %sum
+ // if final iteration.
+      xla_gpu.yield %sum_next : f32
+ }
+ ```
+ }];
+ let arguments = (ins XLAGPU_IndexingMapAttr:$indexing_map_attr,
+ Variadic<Index>:$dims,
+ Variadic<AnyType>:$inits);
+ let results = (outs Variadic<AnyType>);
+ let regions = (region SizedRegion<1>:$region);
+
+ let builders = [
+ OpBuilder<(ins "IndexingMapAttr":$indexing_map_attr,
+ "mlir::ValueRange":$dims, "mlir::ValueRange":$inits,
+ CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange, mlir::ValueRange)>",
+ "nullptr">)>,
+ OpBuilder<(ins "const IndexingMap&":$indexing_map,
+ "mlir::ValueRange":$dims, "mlir::ValueRange":$inits,
+ CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange, mlir::ValueRange)>",
+ "nullptr">)>
+ ];
+
+ let extraClassDeclaration = [{
+ using BodyBuilderFn =
+ llvm::function_ref<void(mlir::OpBuilder&, mlir::Location,
+ mlir::ValueRange, mlir::ValueRange)>;
+
+ // Returns the indexing map constructed from IndexingMapAttr.
+ xla::gpu::IndexingMap getIndexingMap();
+ int64_t getNumInductionVars() {
+ return getBody()->getNumArguments() - getNumResults();
+ }
+ mlir::BlockArgument getInductionVar(int64_t index) {
+ return getBody()->getArgument(index);
+ }
+ mlir::Block::BlockArgListType getInductionVars() {
+ return getBody()->getArguments().take_front(getNumInductionVars());
+ }
+ mlir::Block::BlockArgListType getRegionIterArgs() {
+ return getBody()->getArguments().drop_front(getNumInductionVars());
+ }
+ }];
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def XLAGPU_MaterializeOp : XLAGPU_Op<"materialize", [AttrSizedOperandSegments]> {
+ let summary = "Reads a tensor into registers";
+ let arguments = (ins Variadic<AnyType>:$input,
+ Variadic<Index>:$indices,
+ FlatSymbolRefAttr:$callee,
+ XLAGPU_IndexingMapAttr:$map);
+ let results = (outs XLAGPU_IndexedVectorType:$result);
+ let hasVerifier = 1;
+ let assemblyFormat = [{
+ $callee `(` $input `)` `at` $map `(` $indices `)` attr-dict `:` functional-type($input, results)
+ }];
+}
+
+def XLAGPU_InsertOp : XLAGPU_Op<"insert", []> {
+ let summary = "Inserts an indexed vector into a tensor";
+ let arguments = (ins XLAGPU_IndexedVectorType:$source,
+ Variadic<Index>:$indices,
+ AnyRankedTensor:$dest,
+ XLAGPU_IndexingMapAttr:$map);
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ $source `into` $dest `at` $map `(` $indices `)` attr-dict `:` type($source) `->` type($dest) `into` type($result)
+ }];
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc
new file mode 100644
index 0000000..2d9076d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+#include <gtest/gtest.h>
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+
+class XLAGPUOpsTest : public HloTestBase {
+ public:
+ mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) {
+ auto map = IndexingMap(
+ ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_),
+ /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}},
+ /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{});
+ map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_),
+ Interval{0, 1});
+ map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_),
+ Interval{0, 2});
+ map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3});
+ map.AddConstraint(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4});
+ map.AddConstraint(ParseAffineExpr("d1 mod 32", &mlir_context_),
+ Interval{0, 6});
+
+ auto constraints_for_variables = GetConstraintsForVariables(map);
+ EXPECT_THAT(constraints_for_variables.constraints_for_dims[0],
+ UnorderedElementsAre());
+ EXPECT_THAT(
+ constraints_for_variables.constraints_for_dims[1],
+ UnorderedElementsAre(
+ Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}),
+ Pair(ParseAffineExpr("d1 mod 32", &mlir_context_), Interval{0, 6})));
+ EXPECT_THAT(
+ constraints_for_variables.constraints_for_symbols[0],
+ UnorderedElementsAre(
+ Pair(ParseAffineExpr("s0 mod 4", &mlir_context_), Interval{0, 1}),
+ Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3})));
+ EXPECT_THAT(
+ constraints_for_variables.constraints_for_symbols[1],
+ UnorderedElementsAre(
+ Pair(ParseAffineExpr("s1 mod 4", &mlir_context_), Interval{0, 2}),
+ Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}),
+ Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4})));
+}
+
+TEST_F(XLAGPUOpsTest, GetConstraintsForVariablesEmpty) {
+ auto map = IndexingMap(
+ ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_),
+ /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}},
+ /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{});
+ auto constraints_for_variables = GetConstraintsForVariables(map);
+ EXPECT_THAT(constraints_for_variables.constraints_for_dims,
+ ElementsAre(IsEmpty(), IsEmpty()));
+ EXPECT_THAT(constraints_for_variables.constraints_for_symbols,
+ ElementsAre(IsEmpty(), IsEmpty()));
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc
new file mode 100644
index 0000000..1c1b218
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc
@@ -0,0 +1,57 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+
+#include "mlir/IR/Attributes.h" // IWYU pragma: keep
+#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep
+#include "mlir/IR/Dialect.h" // IWYU pragma: keep
+#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc"
+#include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc"
+#undef GET_ATTRDEF_CLASSES
+#define GET_TYPEDEF_LIST
+#define GET_TYPEDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc"
+
+namespace xla {
+namespace gpu {
+
+mlir::Type IndexedVectorType::parse(mlir::AsmParser& parser) {
+ mlir::SmallVector<int64_t, 4> shape;
+ mlir::Type type;
+ IndexingMapAttr indexing_map_attr;
+ if (parser.parseLess() ||
+ parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
+ parser.parseType(type) || parser.parseComma() ||
+ parser.parseAttribute(indexing_map_attr) || parser.parseGreater()) {
+ return {};
+ }
+ return IndexedVectorType::get(parser.getContext(), shape, type,
+ indexing_map_attr);
+}
+
+void IndexedVectorType::print(mlir::AsmPrinter& printer) const {
+ printer << "<";
+ printer.printDimensionList(getShape());
+ printer << "x" << getElementType() << ", " << getIndexingMapAttr() << ">";
+}
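Given this parser/printer pair, the textual form of the type looks roughly like the following; `#map` and `!my_vec` are illustrative aliases, and the indexing-map attribute syntax mirrors the examples in xla_gpu_ops.td:

```mlir
#map = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), domain: d0 in [0, 31], d1 in [0, 63]>
!my_vec = !xla_gpu.indexed_vector<32x64xf32, #map>
```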
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td
new file mode 100644
index 0000000..5d73344
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td
@@ -0,0 +1,50 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td"
+
+class XLAGPU_Type<string name, string typeMnemonic, list<Trait> traits = []>
+ : TypeDef<XlaGpuDialect, name, traits> {
+ let mnemonic = typeMnemonic;
+}
+
+def XLAGPU_IndexedVectorType : XLAGPU_Type<"IndexedVector", "indexed_vector",
+ [ShapedTypeInterface, ValueSemantics]> {
+ let summary = "Vector type with a specified layout";
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ "mlir::Type":$elementType,
+ XLAGPU_IndexingMapAttr:$indexing_map_attr
+ );
+ let hasCustomAssemblyFormat = 1;
+ let extraClassDeclaration = [{
+ IndexedVectorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
+ mlir::Type elementType) const {
+ return IndexedVectorType::get(getContext(), shape.value_or(getShape()),
+ elementType, getIndexingMapAttr());
+ }
+
+ bool hasRank() const { return true; }
+ }];
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD
new file mode 100644
index 0000000..98d8ade
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD
@@ -0,0 +1,406 @@
+load("//xla:xla.bzl", "xla_cc_test")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//xla/service/gpu/fusions:__pkg__"],
+ licenses = ["notice"],
+)
+
+cc_library(
+ name = "in_place_dynamic_update_slice",
+ srcs = ["in_place_dynamic_update_slice.cc"],
+ hdrs = ["in_place_dynamic_update_slice.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:ir_emitter",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/llvm_ir:dynamic_update_slice_util",
+ "//xla/service/llvm_ir:fused_ir_emitter",
+ "//xla/service/llvm_ir:ir_array",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:IR",
+ ],
+)
+
+xla_cc_test(
+ name = "in_place_dynamic_update_slice_test",
+ srcs = ["in_place_dynamic_update_slice_test.cc"],
+ deps = [
+ ":in_place_dynamic_update_slice",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions",
+ "//xla/service/gpu/model:affine_map_printer",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "loop",
+ srcs = ["loop.cc"],
+ hdrs = ["loop.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:ir_emitter",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu:parallel_loop_emitter",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/llvm_ir:fused_ir_emitter",
+ "//xla/service/llvm_ir:ir_array",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/numeric:bits",
+ "@com_google_absl//absl/status",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:macros",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "loop_test",
+ srcs = ["loop_test.cc"],
+ deps = [
+ "//xla:status_macros",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:affine_map_printer",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "scatter",
+ srcs = ["scatter.cc"],
+ hdrs = ["scatter.h"],
+ deps = [
+ ":loop",
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:ir_emitter",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu:parallel_loop_emitter",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/llvm_ir:fused_ir_emitter",
+ "//xla/service/llvm_ir:ir_array",
+ "//xla/service/llvm_ir:llvm_util",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "scatter_test",
+ srcs = ["scatter_test.cc"],
+ deps = [
+ ":scatter",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions",
+ "//xla/service/gpu/model:affine_map_printer",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "tiling_util",
+ srcs = ["tiling_util.cc"],
+ hdrs = ["tiling_util.h"],
+ visibility = ["//xla/service/gpu:__subpackages__"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:target_util",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/llvm_ir:ir_array",
+ "//xla/service/llvm_ir:kernel_support_library",
+ "//xla/service/llvm_ir:llvm_loop",
+ "//xla/service/llvm_ir:llvm_util",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "reduction",
+ srcs = ["reduction.cc"],
+ hdrs = ["reduction.h"],
+ deps = [
+ ":tiling_util",
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:buffer_assignment",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:ir_emitter",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu:kernel_arguments",
+ "//xla/service/gpu:kernel_reuse_cache",
+ "//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu:parallel_loop_emitter",
+ "//xla/service/gpu:reduction_utils",
+ "//xla/service/gpu:target_util",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/fusions:reduction_base",
+ "//xla/service/gpu/fusions:thunk_util",
+ "//xla/service/gpu/runtime:kernel_thunk",
+ "//xla/service/gpu/runtime:thunk",
+ "//xla/service/llvm_ir:fused_ir_emitter",
+ "//xla/service/llvm_ir:ir_array",
+ "//xla/service/llvm_ir:kernel_support_library",
+ "//xla/service/llvm_ir:llvm_loop",
+ "//xla/service/llvm_ir:llvm_util",
+ "//xla/service/llvm_ir:loop_emitter",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:Support",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "reduction_test",
+ srcs = ["reduction_test.cc"],
+ deps = [
+ ":reduction",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//mlir:IR",
+ ],
+)
+
+cc_library(
+ name = "concatenate",
+ srcs = ["concatenate.cc"],
+ hdrs = ["concatenate.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:ir_emitter",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu:parallel_loop_emitter",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/llvm_ir:fused_ir_emitter",
+ "//xla/service/llvm_ir:ir_array",
+ "//xla/service/llvm_ir:loop_emitter",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/status",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "concatenate_test",
+ srcs = ["concatenate_test.cc"],
+ deps = [
+ ":concatenate",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions",
+ "//xla/service/gpu/model:affine_map_printer",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//mlir:IR",
+ ],
+)
+
+cc_library(
+ name = "transpose",
+ srcs = ["transpose.cc"],
+ hdrs = ["transpose.h"],
+ deps = [
+ ":tiling_util",
+ "//xla:permutation_util",
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:ir_emitter",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu:target_util",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/llvm_ir:fused_ir_emitter",
+ "//xla/service/llvm_ir:ir_array",
+ "//xla/service/llvm_ir:llvm_util",
+ "//xla/service/llvm_ir:loop_emitter",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "transpose_test",
+ srcs = ["transpose_test.cc"],
+ deps = [
+ ":transpose",
+ "//xla:status_macros",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "input_slices",
+ srcs = ["input_slices.cc"],
+ hdrs = ["input_slices.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:elemental_ir_emitter",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:ir_emitter",
+ "//xla/service/gpu:ir_emitter_context",
+ "//xla/service/gpu:launch_dimensions",
+ "//xla/service/gpu:parallel_loop_emitter",
+ "//xla/service/gpu/fusions:fusion_emitter",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/llvm_ir:fused_ir_emitter",
+ "//xla/service/llvm_ir:ir_array",
+ "//xla/service/llvm_ir:kernel_support_library",
+ "//xla/service/llvm_ir:llvm_loop",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:ir_headers",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "input_slices_test",
+ srcs = ["input_slices_test.cc"],
+ deps = [
+ ":input_slices",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions",
+ "//xla/service/gpu/model:affine_map_printer",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//mlir:IR",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/README.md b/third_party/xla/xla/service/gpu/fusions/legacy/README.md
new file mode 100644
index 0000000..0fa6bb9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/README.md
@@ -0,0 +1,8 @@
+# Deprecated emitters
+
+The emitters in this directory are deprecated. Please do not add any new
+features. If you believe you need to add a feature, please reach out and
+describe your use case.
+
+These emitters have more modern MLIR-based equivalents in the directory above
+this one.
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc
new file mode 100644
index 0000000..8bb0e04
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc
@@ -0,0 +1,137 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/concatenate.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/loop_emitter.h"
+#include "xla/shape.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) {
+ const HloInstruction& concat = analysis.fusion_hero(0).instruction();
+ int64_t dim = concat.concatenate_dimension();
+ auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
+ return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim);
+ };
+ HloInstruction* operand = *absl::c_max_element(concat.operands(), less);
+ return operand->shape();
+}
+
+ConcatenateFusion::ConcatenateFusion(const HloFusionAnalysis& analysis)
+ : analysis_(analysis) {}
+
+std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const {
+ return std::nullopt;
+}
+
+std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const {
+ return GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1,
+ GetLargestConcatOperandShape(analysis_),
+ ctx);
+}
+
+absl::Status ConcatenateFusion::EmitKernel(
+ IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
+ GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+ FusedIrEmitter fused_emitter(elemental_emitter);
+ for (int i = 0; i < fusion.fused_parameters().size(); i++) {
+ fused_emitter.BindGenerator(
+ *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
+ return inputs[i].EmitReadArrayElement(index, builder);
+ });
+ }
+
+ llvm::Type* index_type =
+ GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
+
+ const HloInstruction& concat = analysis_.fusion_hero(0).instruction();
+ int64_t concat_dim = concat.concatenate_dimension();
+ int64_t operand_offset = 0;
+
+ // Emit the slices that correspond to the operands of the concat hero.
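+  // Illustrative walk-through (shapes borrowed from the accompanying test,
+  // for illustration only): for concatenate(f32[200], f32[400], f32[300])
+  // along dimension 0, the loop below runs once per operand with
+  // operand_offset = 0, 200 and 600, so element i of operand k is written to
+  // output index i + operand_offset.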
+ for (const HloInstruction* operand : concat.operands()) {
+ llvm_ir::BodyEmitter body_emitter =
+ [&](const llvm_ir::IrArray::Index& operand_index) -> absl::Status {
+ // Bind concat to generate the current operand.
+ TF_ASSIGN_OR_RETURN(auto operand_generator,
+ fused_emitter.GetGenerator(*operand));
+ fused_emitter.BindGenerator(concat, [&](llvm_ir::IrArray::Index) {
+ return operand_generator(operand_index);
+ });
+
+ // Create the index of the slice corresponding to the current operand.
+ llvm_ir::IrArray::Index result_index = operand_index.AddOffsetToDim(
+ llvm::ConstantInt::get(index_type, operand_offset), concat_dim,
+ builder);
+ operand_offset += operand->shape().dimensions(concat_dim);
+
+ // Generate and write out the slice for each root.
+ for (const auto& [output, root] :
+ llvm::zip_equal(outputs, analysis_.fusion_roots())) {
+ llvm_ir::IrArray::Index root_index = result_index.SourceIndexOfBitcast(
+ concat.shape(), root.shape(), builder);
+ TF_ASSIGN_OR_RETURN(auto generator,
+ fused_emitter.GetGenerator(root.instruction()));
+ TF_ASSIGN_OR_RETURN(llvm::Value * value, generator(root_index));
+ output.EmitWriteArrayElement(root_index, value, builder);
+ }
+ return absl::OkStatus();
+ };
+
+ ParallelLoopEmitter emitter(body_emitter, operand->shape(), launch_dims,
+ builder);
+ TF_RETURN_IF_ERROR(emitter.EmitLoop(fusion.name(), index_type));
+ }
+
+ return absl::OkStatus();
+}
+
+LaunchDimensions ConcatenateFusion::launch_dimensions() const {
+ return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_),
+ analysis_.device_info());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h
new file mode 100644
index 0000000..be0465b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h
@@ -0,0 +1,67 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_
+
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+
+namespace xla {
+namespace gpu {
+
+const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis);
+
+// Emits a kernel for the given hlo instruction where each thread produces
+// one element of each concat operand.
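+//
+// For illustration only (shapes borrowed from the accompanying test): for
+// concatenate(f32[200], f32[400], f32[300]) along dimension 0, the launch
+// grid is sized for the largest operand (f32[400]); each thread handles at
+// most one element per operand (threads past a smaller operand's size skip
+// it) and writes that element at the operand's offset within the
+// concatenated f32[900] output.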
+class ConcatenateFusion : public KernelFusionEmitterBase {
+ public:
+ explicit ConcatenateFusion(const HloFusionAnalysis& analysis);
+ LaunchDimensions launch_dimensions() const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const override;
+
+ protected:
+ absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const override;
+
+ private:
+ const HloFusionAnalysis& analysis_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc
new file mode 100644
index 0000000..ee63bda
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/concatenate.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class ConcatenateTest : public HloTestBase {
+ public:
+ void SetUp() override {
+ HloTestBase::SetUp();
+ printer_ =
+ AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+ {"chunk_id", "unroll_id"});
+ }
+
+ protected:
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.set_xla_gpu_mlir_emitter_level(0);
+ return opts;
+ }
+ AffineMapPrinter printer_;
+ mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(ConcatenateTest, ThreadIndexing) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation {
+ param0 = f32[200] parameter(0)
+ param1 = f32[400] parameter(1)
+ param2 = f32[300] parameter(2)
+ ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0}
+ }
+ ENTRY main {
+ param0 = f32[200] parameter(0)
+ param1 = f32[400] parameter(1)
+ param2 = f32[300] parameter(2)
+ ROOT fusion = f32[900] fusion(param0, param1, param2),
+ calls=fused_computation, kind=kLoop
+ }
+ )")
+ .value();
+
+ stream_executor::DeviceDescription device_info =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+ auto emitter =
+ GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+ auto fusion = dynamic_cast<ConcatenateFusion*>(emitter.get());
+ ASSERT_NE(fusion, nullptr);
+
+ constexpr auto kIndexing = R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
+ (bl_x * 128 + th_x)
+ domain:
+ th_x in [0, 127]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 3]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ bl_x * 128 + th_x in [0, 399]
+ )";
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kIndexing));
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kIndexing));
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kIndexing));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc
new file mode 100644
index 0000000..38a3e5b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc
@@ -0,0 +1,105 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h"
+
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/dynamic_update_slice_util.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr int kDUSUpdateIndex = 1;
+
+} // namespace
+
+LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const {
+ const auto& update_shape = dus_ops_.front().GetOperand(1).shape();
+ return CalculateLaunchDimensions(update_shape, analysis_.device_info());
+}
+
+std::optional<IndexingMap>
+InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* mlir_context) const {
+ if (hero_operand_index != kDUSUpdateIndex) {
+ return std::nullopt;
+ }
+ auto launch_dims = launch_dimensions();
+ // It is guaranteed that all DUS ops have the same output shape at this point.
+ const auto& update_shape =
+ dus_ops_.front().GetOperand(kDUSUpdateIndex).shape();
+ return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1,
+ update_shape, mlir_context);
+}
+
+absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel(
+ IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
+  // If a dynamic-update-slice's output is bitcast, we need to ensure we
+  // write to the output array using the shape and layout of the
+  // dynamic-update-slice. This cast is known to be safe iff, whenever the
+  // output of the dynamic-update-slice is bitcast, that bitcast is either
+  // the fusion's output, or has a single user and is part of the fusion's
+  // tuple output. This condition should be enforced explicitly by the
+  // 'CanEmitFusedDynamicUpdateSliceInPlaceForGpu' matcher.
+ for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
+ output = output.CastToShape(op.shape(), builder);
+ }
+
+ auto* fused_computation = fusion.fused_instructions_computation();
+ GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+ FusedIrEmitter fused_emitter(elemental_emitter);
+ for (auto [index, input] : llvm::enumerate(inputs)) {
+ auto fused_operand = fused_computation->parameter_instruction(index);
+ fused_emitter.BindGenerator(
+ *fused_operand, [input = input, builder,
+ fused_operand](const llvm_ir::IrArray::Index& index) {
+ return input.EmitReadArrayElement(index, builder,
+ fused_operand->name());
+ });
+ }
+
+ std::vector<std::pair<const HloInstruction*, const llvm_ir::IrArray>>
+ dus_and_output_array;
+ dus_and_output_array.reserve(dus_ops_.size());
+
+ for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
+ dus_and_output_array.push_back(std::make_pair(&op.instruction(), output));
+ }
+
+ return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
+ fused_computation, dus_and_output_array, &fused_emitter, launch_dims,
+ builder);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h
new file mode 100644
index 0000000..db12c3c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h
@@ -0,0 +1,98 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// Fusion node where the root is either:
+// 1. a dynamic-update-slice op
+// 2. a bitcast of a dynamic-update-slice op
+// 3. a tuple op returning the result of several dynamic-update-slice ops
+// 4. a tuple op returning the result of several bitcast
+// dynamic-update-slice ops
+//
+// Additionally, all the dynamic-update-slice ops have exactly one user. The
+// fusion parameter that they update can have users (in addition to the
+// dynamic-update-slice op) that read in either
+// a. a dynamic-slice corresponding exactly to the slice of the parameter that
+// is updated by the dynamic-update-slice op
+// b. a dynamic-slice reading in a single element anywhere in the parameter.
+// This is only allowed if the dynamic-update-slice op updates a single
+// element
+//
+// In both cases, the additional users must not flow into any other output
+// than the dynamic-update-slice corresponding to that particular slice of
+// the parameter.
+//
+// The assumption is that each op's input (i.e. array to update) shares the
+// same slice as its output. In this case, we have a special algorithm that
+// modifies the output in place without touching the un-updated elements. The
+// update slice is assumed to be the exact same for all the
+// dynamic-update-slice ops.
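+//
+// A minimal example of case 1, borrowed from the accompanying test and shown
+// here for illustration only:
+//
+//   fused_computation {
+//     in = f32[20,30] parameter(0)
+//     updates = f32[5,6] parameter(1)
+//     i0 = s32[] parameter(2)
+//     i1 = s32[] parameter(3)
+//     ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1)
+//   }
+//
+// Here the launch grid covers only the update shape (f32[5,6]), and the
+// kernel writes just the updated elements of the in-place output buffer.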
+class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase {
+ public:
+ explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis)
+ : analysis_(analysis),
+ dus_ops_(
+ GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
+ LaunchDimensions launch_dimensions() const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const override {
+ // The mapping cannot be statically computed in general, since the offsets
+ // are unknown.
+ return std::nullopt;
+ }
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* mlir_context) const override;
+
+ protected:
+ absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const override;
+
+ const HloFusionAnalysis& analysis_;
+ std::vector<HloInstructionAdaptor> dus_ops_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc
new file mode 100644
index 0000000..c4fd277
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc
@@ -0,0 +1,144 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase {
+ public:
+ void SetUp() override {
+ HloTestBase::SetUp();
+ printer_ =
+ AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+ {"chunk_id", "unroll_id"});
+ }
+
+ protected:
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.set_xla_gpu_mlir_emitter_level(0);
+ return opts;
+ }
+ AffineMapPrinter printer_;
+ mlir::MLIRContext mlir_context_;
+ stream_executor::DeviceDescription device_info_ =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+};
+
+TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation {
+ in = f32[20,30] parameter(0)
+ updates = f32[5,6] parameter(1)
+ i0 = s32[] parameter(2)
+ i1 = s32[] parameter(3)
+ ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1)
+ }
+ ENTRY entry {
+ in = f32[20,30] parameter(0)
+ updates = f32[5,6] parameter(1)
+ i0 = s32[] constant(2)
+ i1 = s32[] constant(3)
+ ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation
+ }
+ )"));
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis_fused = HloFusionAnalysis::Create(*root, device_info_);
+
+ auto emitter =
+ GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+ auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
+ ASSERT_NE(fusion, nullptr);
+
+ auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_);
+ EXPECT_THAT(thread_id_update_indexing->ToString(printer_),
+ MatchIndexingString(R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+ th_x floordiv 6, th_x mod 6)
+ domain:
+ th_x in [0, 29]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 0]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ )"));
+ auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
+ EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt));
+}
+
+TEST_F(InPlaceDynamicUpdateSliceFusionTest, ProducerConsumerFusion) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ fused_computation.1 {
+ param_0 = bf16[1,2,5,1,2] parameter(0)
+ bitcast = bf16[1,5,1,2,2] bitcast(param_0)
+ param_1 = bf16[1,1,1,2,2] parameter(1)
+ param_2 = s32[] parameter(2)
+ param_3 = s32[] parameter(3)
+ ROOT dynamic-update-slice = bf16[1,5,1,2,2] dynamic-update-slice(bitcast, param_1, param_2, param_3, param_2, param_2, param_2)
+ }
+
+ ENTRY entry_computation {
+ param_0.2 = bf16[1,2,5,1,2] parameter(3)
+ param_1.2 = bf16[1,1,1,2,2] parameter(0)
+ param_2.2 = s32[] parameter(1)
+ param_3.2 = s32[] parameter(2)
+ fusion = bf16[1,5,1,2,2] fusion(param_0.2, param_1.2, param_2.2, param_3.2), kind=kLoop, calls=fused_computation.1
+ ROOT bitcast.1 = bf16[1,2,5,1,2] bitcast(fusion)
+ }
+ )"));
+
+ auto* root = module->entry_computation()->root_instruction();
+
+ auto analysis_fused =
+ HloFusionAnalysis::Create(*root->operand(0), *root, device_info_);
+
+ auto emitter =
+ GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+
+ auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
+
+ ASSERT_NE(fusion, nullptr);
+ EXPECT_EQ(fusion->launch_dimensions().launch_bound(), 4 /* update size */);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc
new file mode 100644
index 0000000..d336f92
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc
@@ -0,0 +1,220 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/input_slices.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/elemental_ir_emitter.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/kernel_support_library.h"
+#include "xla/service/llvm_ir/llvm_loop.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Emits code for slices based on the below structure. An if statement with
+// a guarding condition is generated for each ROOT slice.
+//
+// Pseudo code:
+//
+// Compute values of slice input operands
+//
+// Compute guarding_cond0
+// if (guarding_cond0) {
+// Write to output of slice0
+// }
+//
+// Compute guarding_cond1
+// if (guarding_cond1) {
+// Write to output of slice1
+// }
+//
+absl::Status EmitElementForInputFusibleSlices(
+ ElementalIrEmitter& elemental_emitter,
+ const HloComputation* fused_computation,
+ const std::vector<llvm_ir::IrArray>& inputs,
+ const std::vector<llvm_ir::IrArray>& outputs,
+ const llvm_ir::IrArray::Index& index, llvm::IRBuilder<>* builder) {
+ VLOG(10) << "Emitting slice input fusion for "
+ << fused_computation->ToString();
+
+ HloInstruction* slice_or_tuple = fused_computation->root_instruction();
+ auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
+ if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
+ return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
+ }
+ CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
+ return slice_or_tuple->operands();
+ }();
+
+ // Emit input operand values of slices.
+ std::vector<llvm::Value*> input_ir_values;
+ FusedIrEmitter fused_emitter(elemental_emitter);
+ for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ fused_emitter.BindGenerator(
+ *fused_computation->parameter_instruction(i),
+ [&inputs, i, builder](llvm_ir::IrArray::Index index) {
+ return inputs[i].EmitReadArrayElement(index, builder);
+ });
+ }
+ for (const HloInstruction* slice : slice_instructions) {
+ auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0));
+ input_ir_values.push_back(input_generator(index).value());
+ }
+
+ // Emit for slice_instructions.
+ KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
+ for (int64_t i = 0; i < slice_instructions.size(); ++i) {
+ HloInstruction* slice = slice_instructions[i];
+
+ // guarding_cond := index >= start && index < limit, for each dim.
+ std::vector<llvm::Value*> index_within_ranges;
+ for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
+ CHECK_EQ(slice->slice_strides(dim), 1);
+ auto larger_or_equal_than_start = builder->CreateICmpSGE(
+ index.multidim()[dim],
+ index.GetConstantWithIndexType(slice->slice_starts(dim)));
+ llvm::Value* smaller_than_limit = builder->CreateICmpSLT(
+ index.multidim()[dim],
+ index.GetConstantWithIndexType(slice->slice_limits(dim)));
+ llvm::Value* within_range =
+ builder->CreateAnd(larger_or_equal_than_start, smaller_than_limit);
+ index_within_ranges.push_back(within_range);
+ }
+ llvm::Value* guarding_cond = builder->CreateAnd(index_within_ranges);
+
+ auto emit_slice_elem_func = [&] {
+ const std::vector<llvm::Value*>& src_multidim = index.multidim();
+ std::vector<llvm::Value*> dst_multidim(src_multidim.size());
+ for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
+ dst_multidim[dim] = builder->CreateSub(
+ src_multidim[dim],
+ index.GetConstantWithIndexType(slice->slice_starts(dim)));
+ }
+ const llvm_ir::IrArray& src_ir_array = outputs[i];
+ llvm_ir::IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
+ index.GetType());
+ src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
+ builder);
+ };
+
+ ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func);
+ }
+ return absl::OkStatus();
+}
+
+// Gets the input shape of the ROOT slices, which will be used as the kernel
+// launch dims. The slice input fusion requires the input shapes of the ROOT
+// slices to be the same although the (slice) output shapes can be different.
+//
+// Returns the input shape of the ROOT slices if all the input shapes of ROOT
+// slices are the same and the slices are non-strided. Otherwise, returns
+// FailedPrecondition.
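+//
+// For illustration only: if every ROOT slice reads from an f32[2,3,5,7]
+// operand, that shape is returned; if any slice reads from an operand with a
+// different shape, a FailedPrecondition error is returned instead.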
+absl::StatusOr<Shape> GetConsistentInputShapeForRootSlices(
+ const HloComputation* fused_computation) {
+ const HloInstruction& root = *fused_computation->root_instruction();
+ if (root.opcode() == HloOpcode::kSlice) {
+ return root.operands()[0]->shape();
+ }
+
+ CHECK_EQ(root.opcode(), HloOpcode::kTuple);
+ const Shape& first_slice_operand_shape =
+ root.operands()[0]->operands()[0]->shape();
+ for (size_t i = 1; i < root.operands().size(); ++i) {
+ const HloInstruction* slice = root.operands()[i];
+ const Shape& operand_shape = slice->operands()[0]->shape();
+ if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
+ operand_shape)) {
+ return FailedPrecondition(
+ "Fused slices do not have the same input shape, fused computation = "
+ "%s.",
+ root.parent()->name());
+ }
+ }
+
+ return first_slice_operand_shape;
+}
+
+} // namespace
+
+LaunchDimensions InputSlicesFusion::launch_dimensions() const {
+ const auto& root = analysis_.fusion_root(0).instruction();
+ const auto& shape = root.operand(0)->shape();
+ return CalculateLaunchDimensions(shape, analysis_.device_info(),
+ {unroll_factor_});
+}
+
+std::optional<IndexingMap> InputSlicesFusion::ComputeThreadIdToOutputIndexing(
+ int64_t output_id, mlir::MLIRContext* ctx) const {
+ // The mapping here is trivial and the same for all outputs - slice offsets
+ // are applied in the indexing from slice outputs to slice inputs.
+ auto launch_dims = launch_dimensions();
+ // The implementation requires the shapes and layouts to be the same, but we
+ // still use the requested output's shape for clarity.
+ const auto& shape = analysis_.fusion_root(output_id).shape();
+ return GetDefaultThreadIdIndexingMap(launch_dims, unroll_factor_, shape, ctx);
+}
+
+absl::Status InputSlicesFusion::EmitKernel(
+ IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
+ TF_ASSIGN_OR_RETURN(Shape element_shape,
+ GetConsistentInputShapeForRootSlices(
+ fusion.fused_instructions_computation()));
+ LaunchDimensionsConfig launch_config;
+ launch_config.unroll_factor = unroll_factor_;
+ GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+ return ParallelLoopEmitter(
+ [&](const llvm_ir::IrArray::Index index) -> absl::Status {
+ return EmitElementForInputFusibleSlices(
+ elemental_emitter, fusion.fused_instructions_computation(),
+ inputs, outputs, index, builder);
+ },
+ element_shape, launch_dims, builder, launch_config)
+ .EmitLoop(
+ fusion.name(),
+ GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h
new file mode 100644
index 0000000..e653224
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h
@@ -0,0 +1,79 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+// Generates code for input-fusible slices.
+//
+// Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes
+// of all ROOT slices need to be the same while their output shapes can be
+// different. On the other hand, the input ranges of the slices may overlap.
+// Further generalization/specialization can be added when the need arises.
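+//
+// A small example of the expected fusion form, borrowed from the accompanying
+// test and shown here for illustration only:
+//
+//   fused_computation {
+//     input = f32[2,3,5,7] parameter(0)
+//     slice0 = f32[1,2,3,5] slice(input), slice={[0:1],[1:3],[0:3],[2:7]}
+//     slice1 = f32[1,2,3,5] slice(input), slice={[0:1],[0:2],[0:3],[2:7]}
+//     ROOT tuple = (f32[1,2,3,5], f32[1,2,3,5]) tuple(slice0, slice1)
+//   }
+//
+// Both slices read the same f32[2,3,5,7] input, so the kernel is launched
+// over that input shape and each output is written under its own guarding
+// condition.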
+class InputSlicesFusion : public KernelFusionEmitterBase {
+ public:
+ explicit InputSlicesFusion(const HloFusionAnalysis& analysis)
+ : analysis_(analysis),
+ unroll_factor_(CeilOfRatio(
+ 8, analysis.input_output_info().smallest_output_dtype_bits)) {}
+ LaunchDimensions launch_dimensions() const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t output_id, mlir::MLIRContext* ctx) const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const override {
+ // TODO(b/319081342): Implement this.
+ return std::nullopt;
+ }
+
+ protected:
+ absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const override;
+
+ private:
+ const HloFusionAnalysis& analysis_;
+ const int unroll_factor_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc
new file mode 100644
index 0000000..bb9f510
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/input_slices.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class InputSlicesTest : public HloTestBase {
+ public:
+ void SetUp() override {
+ HloTestBase::SetUp();
+ printer_ =
+ AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+ {"chunk_id", "unroll_id"});
+ }
+
+ protected:
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.set_xla_gpu_mlir_emitter_level(0);
+ return opts;
+ }
+ AffineMapPrinter printer_;
+ mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(InputSlicesTest, ThreadIndexing) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation {
+ %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
+ slice0 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[1:3],[0:3],[2:7]}
+ slice1 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[0:2],[0:3],[2:7]}
+ ROOT tuple = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) tuple(slice0, slice1)
+ }
+
+ ENTRY entry {
+ %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
+ ROOT %fusion = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) fusion(%input), kind=kLoop, calls=fused_computation
+ })")
+ .value();
+
+ stream_executor::DeviceDescription device_info =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+ auto emitter =
+ GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+ auto fusion = dynamic_cast<InputSlicesFusion*>(emitter.get());
+ ASSERT_NE(fusion, nullptr);
+
+ auto thread_id_to_output_indexing =
+ fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_);
+ EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+ MatchIndexingString(R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0,
+ ((bl_x * 128 + th_x) floordiv 3) mod 2,
+ (bl_x * 128 + th_x) mod 3,
+ (bl_x * 128 + th_x) floordiv 6)
+ domain:
+ th_x in [0, 127]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 1]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ bl_x * 128 + th_x in [0, 29]
+ )"));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc
new file mode 100644
index 0000000..e6ce5f1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc
@@ -0,0 +1,132 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/loop.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <optional>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/numeric/bits.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/macros.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+const Shape& GetElementShape(const HloFusionAnalysis& analysis) {
+ const Shape* shape = &analysis.fusion_root(0).shape();
+ while (shape->IsTuple()) {
+ shape = &shape->tuple_shapes(0);
+ }
+ return *shape;
+}
+
+} // namespace
+
+LoopFusion::LoopFusion(const HloFusionAnalysis& analysis)
+ : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {}
+
+std::optional<IndexingMap> LoopFusion::ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const {
+ auto launch_dims = launch_dimensions();
+ return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor,
+ GetElementShape(analysis_), ctx);
+}
+
+std::optional<IndexingMap> LoopFusion::ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const {
+ std::optional<IndexingMap> thread_id_to_output_indexing =
+ ComputeThreadIdToOutputIndexing(root_index, ctx);
+ if (!thread_id_to_output_indexing.has_value()) {
+ return std::nullopt;
+ }
+ const HloInstruction* fusion_root =
+ &analysis_.fusion_root(root_index).instruction();
+ auto output_to_input_indexing =
+ ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx);
+ IndexingMapSet output_to_input_indexing_set =
+ output_to_input_indexing.indexing_maps[hero_operand_index];
+ // Since we are computing the indexing for a non-fusion op, there is only one
+ // indexing map per operand.
+ CHECK_EQ(output_to_input_indexing_set.size(), 1);
+ IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps(
+ *thread_id_to_output_indexing, *output_to_input_indexing_set.begin());
+ thread_id_to_input_indexing_map.Simplify();
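+  // Illustrative example (mirrors the broadcast case in the accompanying
+  // test): for an f32[20] -> f32[10,20,30] broadcast with dimensions={1},
+  // the thread-to-output map yields a 3D output index and the output-to-input
+  // map keeps only dimension 1, so the composed map simplifies to
+  //   (th_x, ..., bl_x, ...) -> (((bl_x * 128 + th_x) floordiv 30) mod 20).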
+ return thread_id_to_input_indexing_map;
+}
+
+absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const {
+ GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+ FusedIrEmitter fused_emitter(elemental_emitter);
+ for (int i = 0; i < fusion.fused_parameters().size(); i++) {
+ fused_emitter.BindGenerator(
+ *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
+ return inputs[i].EmitReadArrayElement(index, builder);
+ });
+ }
+ TF_ASSIGN_OR_RETURN(
+ auto element_generator,
+ fused_emitter.GetGenerator(*fusion.fused_expression_root()));
+
+ llvm::Type* index_type =
+ GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
+
+ return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder,
+ config_)
+ .EmitLoop(fusion.name(), index_type);
+}
+
+LaunchDimensions LoopFusion::launch_dimensions() const {
+ return CalculateLaunchDimensions(GetElementShape(analysis_),
+ analysis_.device_info(), config_);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.h b/third_party/xla/xla/service/gpu/fusions/legacy/loop.h
new file mode 100644
index 0000000..30e5007
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop.h
@@ -0,0 +1,65 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// Generic loop fusion.
+class LoopFusion : public KernelFusionEmitterBase {
+ public:
+ explicit LoopFusion(const HloFusionAnalysis& analysis);
+ LaunchDimensions launch_dimensions() const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const override;
+
+ protected:
+ absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const override;
+
+ private:
+ const HloFusionAnalysis& analysis_;
+ LaunchDimensionsConfig config_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc
new file mode 100644
index 0000000..a05508c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc
@@ -0,0 +1,222 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/statusor.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class LoopTest : public HloTestBase {
+ public:
+ void SetUp() override {
+ HloTestBase::SetUp();
+
+ printer_ =
+ AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+ {"chunk_id", "unroll_id"});
+ }
+
+ protected:
+ stream_executor::DeviceDescription device_info_ =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ AffineMapPrinter printer_;
+ mlir::MLIRContext mlir_context_;
+};
+
+absl::StatusOr<std::unique_ptr<KernelFusionInterface>> GetFusion(
+ const HloFusionAnalysis& analysis) {
+ auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
+ auto fusion = dynamic_cast<KernelFusionInterface*>(emitter.get());
+ TF_RET_CHECK(fusion != nullptr);
+
+ emitter.release();
+ return std::unique_ptr<KernelFusionInterface>{fusion};
+}
+
+TEST_F(LoopTest, ThreadIndexingUnrolled) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ neg {
+ %input = f32[100,200,300] parameter(0)
+ ROOT neg = f32[100,200,300] negate(%input)
+ }
+
+ ENTRY entry {
+ %input = f32[100,200,300] parameter(0)
+ ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
+ auto thread_id_to_output_indexing =
+ loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
+ &mlir_context_);
+
+ EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+ MatchIndexingString(R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+ (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000,
+ ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
+ ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
+ )
+ domain:
+ th_x in [0, 127]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 1007]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 11]
+ unroll_id in [0, 3]
+ bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999]
+)"));
+}
+
+TEST_F(LoopTest, ThreadIndexingNotUnrolled) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ neg {
+ %input = f32[20] parameter(0)
+ ROOT neg = f32[20] negate(%input)
+ }
+
+ ENTRY entry {
+ %input = f32[20] parameter(0)
+ ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
+ auto thread_id_to_output_indexing =
+ loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
+ &mlir_context_);
+ EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+ MatchIndexingString(R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
+ domain:
+ th_x in [0, 19]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 0]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ )"));
+ auto thread_id_to_input_indexing =
+ loop_fusion->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
+ EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
+ MatchIndexingString(R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
+ domain:
+ th_x in [0, 19]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 0]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ )"));
+}
+
+TEST_F(LoopTest, Broadcast) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ bcast {
+ %input = f32[20] parameter(0)
+ ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1}
+ }
+
+ ENTRY entry {
+ %input = f32[20] parameter(0)
+ ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
+ auto thread_id_to_output_indexing =
+ loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
+ &mlir_context_);
+ EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+ MatchIndexingString(R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+ (bl_x * 128 + th_x) floordiv 600,
+ ((bl_x * 128 + th_x) floordiv 30) mod 20,
+ (bl_x * 128 + th_x) mod 30)
+ domain:
+ th_x in [0, 127]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 46]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ bl_x * 128 + th_x in [0, 5999]
+ )"));
+ auto thread_id_to_input_indexing =
+ loop_fusion->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
+ EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
+ MatchIndexingString(R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
+ (((bl_x * 128 + th_x) floordiv 30) mod 20)
+ domain:
+ th_x in [0, 127]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 46]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ bl_x * 128 + th_x in [0, 5999]
+ )"));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc
new file mode 100644
index 0000000..e009ea1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc
@@ -0,0 +1,1330 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/reduction.h"
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/fusions/reduction_base.h"
+#include "xla/service/gpu/fusions/thunk_util.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/ir_emitter_nested.h"
+#include "xla/service/gpu/kernel_arguments.h"
+#include "xla/service/gpu/kernel_reuse_cache.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/service/gpu/runtime/kernel_thunk.h"
+#include "xla/service/gpu/runtime/thunk.h"
+#include "xla/service/gpu/target_util.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/kernel_support_library.h"
+#include "xla/service/llvm_ir/llvm_loop.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/service/llvm_ir/loop_emitter.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using TypedPointer = std::pair<llvm::Value* const, llvm::Type* const>;
+
+// Fusion root -> array of indexes, one per reduction output.
+using ReductionOutputMap =
+ ConstHloInstructionMap<absl::Span<llvm_ir::IrArray const>>;
+
+using ExtraOutputGensMap = ConstHloInstructionMap<llvm_ir::ElementGenerator>;
+
+int GetNumOutputs(const Shape& shape) {
+ if (shape.IsTuple()) {
+ return shape.tuple_shapes_size();
+ }
+ return 1;
+}
+
+const Shape& OutputShape(const Shape& output_shape, int output_index) {
+ CHECK(output_index == 0 || output_shape.IsTuple());
+ return output_shape.IsTuple() ? output_shape.tuple_shapes(output_index)
+ : output_shape;
+}
+
+llvm::Type* GetIndexType(const HloFusionInstruction& fusion,
+ const Tiling& tiling, llvm::IRBuilder<>* builder) {
+ return GetIndexTypeForKernel(
+ &fusion, tiling.GetNumThreadsPerBlock() * tiling.GetNumBlocks(), builder);
+}
+
+llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input,
+ llvm::Type* element_type, llvm::Twine name) {
+ return builder->CreateAddrSpaceCast(
+ input,
+ llvm::PointerType::get(element_type,
+ /*AddressSpace=*/0),
+ name);
+}
+
+class ReductionEmitter {
+ public:
+ ReductionEmitter(const HloFusionAnalysis& analysis,
+ const ReductionInfo& reduction_codegen_info,
+ IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ llvm::IRBuilder<>* builder)
+ : builder_(builder),
+ elemental_emitter_(ir_emitter_context, builder_),
+ analysis_(analysis),
+ reduction_codegen_info_(reduction_codegen_info),
+ ir_emitter_context_(ir_emitter_context),
+ fusion_(fusion),
+ index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(),
+ elemental_emitter_.builder())) {
+ for (auto hero : analysis.fusion_heroes()) {
+ if (hero.opcode() == HloOpcode::kReduce) {
+ for (int i = 0; i < hero.instruction().operand_count() / 2; ++i) {
+ CHECK(LayoutUtil::IsMonotonicWithDim0Major(
+ hero.instruction().operand(i)->shape().layout()))
+ << "reduction-layout-normalizer must run before code generation";
+ }
+ }
+ }
+ }
+
+ absl::StatusOr<FusionEmissionResult> EmitInitializers();
+ absl::Status EmitKernel(const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs);
+
+ private:
+ friend class ReductionGroupEmitter;
+
+ absl::StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkForFusion(
+ const LaunchDimensions& launch_dimensions,
+ absl::string_view discriminator,
+ std::function<absl::Status(std::vector<llvm_ir::IrArray>,
+ std::vector<llvm_ir::IrArray>)>
+ kernel_builder_fn);
+
+ absl::StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
+ const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice,
+ int output_index);
+
+ absl::Status EmitIRForReduction(
+ absl::Span<const HloInstruction* const> instr_index_group,
+ FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
+ const Shape& input_shape);
+
+ void MaybeEmitFenceForAMDGPU();
+ void EmitSyncThreads();
+
+ int ReducedDimensionSize() const {
+ return reduction_codegen_info_.GetTiling().GetShape()[2];
+ }
+
+ llvm::IRBuilder<>* builder_;
+ GpuElementalIrEmitter elemental_emitter_;
+ const HloFusionAnalysis& analysis_;
+ const ReductionInfo& reduction_codegen_info_;
+ IrEmitterContext& ir_emitter_context_;
+ const HloFusionInstruction& fusion_;
+ llvm::Type* index_ty_;
+};
+
+class ReductionGroupEmitter {
+ public:
+ struct ReductionCalculationState {
+ std::optional<llvm_ir::SharedMemoryTile> shared_cache;
+ llvm::Value* initial_value;
+ llvm::AllocaInst* partial_result_address;
+ llvm::AllocaInst* input_address;
+ llvm_ir::ElementGenerator input_gen;
+ };
+
+ ReductionGroupEmitter(
+ ReductionEmitter& reduction_emitter,
+ absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
+ const ReductionOutputMap& result_ir_arrays,
+ FusedIrEmitter& fused_emitter);
+
+ const ReductionCalculationState& GetCalculationStateFor(
+ const HloInstruction* instruction, int operand_idx) const {
+ const ReductionOpState& op_state = state_.at(instruction);
+ CHECK_LT(operand_idx, op_state.size());
+ return op_state[operand_idx];
+ }
+
+ void SetCalculationStateFor(
+ const ReductionCalculationState& calculation_state,
+ const HloInstruction* instruction, int operand_idx) {
+ ReductionOpState& op_state = state_[instruction];
+ CHECK_EQ(operand_idx, op_state.size());
+ op_state.push_back(calculation_state);
+ }
+
+ void EmitReductionOutputForRowReduction(
+ const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction,
+ const std::vector<const HloInstruction*>& roots) const;
+
+ void EmitReductionOutputForColumnReduction(
+ const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction,
+ const std::vector<const HloInstruction*>& roots) const;
+
+ void EmitFullWarpShuffleDownLoopForReduce(
+ const HloComputation* reducer,
+ absl::Span<TypedPointer const> partial_result_addresses,
+ int threads_per_block, int num_results_per_warp) const;
+
+ void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction,
+ const std::vector<const HloInstruction*>& roots,
+ absl::Span<TypedPointer const> values) const;
+
+ llvm_ir::IrArray::Index GetOutputIndexForReduction(
+ const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction, const HloInstruction* root,
+ int output_idx) const;
+
+ void GenerateElementForReducer(const HloReduceInstruction* reduction,
+ const llvm_ir::IrArray::Index& index) const;
+
+ absl::Status EmitExtraOutputsForReduce(
+ const Shape& reduction_operand_shape,
+ const llvm_ir::IrArray::Index& index,
+ const ExtraOutputGensMap& extra_output_gens);
+
+ private:
+ ReductionEmitter& reduction_emitter_;
+ const ReductionOutputMap& result_ir_arrays_;
+
+ // One state per reduction operand.
+ using ReductionOpState = absl::InlinedVector<ReductionCalculationState, 2>;
+
+ // HloInstruction -> operand_idx -> cache
+ absl::flat_hash_map<const HloInstruction*, ReductionOpState> state_;
+};
+
+// Creates accumulator allocas, populates them with initial values, generates
+// __shared__ caches, and returns the populated object.
+ReductionGroupEmitter::ReductionGroupEmitter(
+ ReductionEmitter& reduction_emitter,
+ absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
+ const ReductionOutputMap& result_ir_arrays, FusedIrEmitter& fused_emitter)
+ : reduction_emitter_(reduction_emitter),
+ result_ir_arrays_(result_ir_arrays) {
+ const ReductionInfo& reduction_info =
+ reduction_emitter_.reduction_codegen_info_;
+ VLOG(10) << "Emit prologue for reduction: "
+ << reduction_emitter_.fusion_.ToString();
+
+ auto* builder = reduction_emitter_.builder_;
+ for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) {
+ for (int op_result_idx = 0;
+ op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) {
+ Shape result_shape = OutputShape(reduce_hlo->shape(), op_result_idx);
+
+ llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
+ result_shape.element_type(), builder->GetInsertBlock()->getModule());
+ llvm::AllocaInst* reduction_input_address =
+ llvm_ir::EmitAllocaAtFunctionEntry(
+ element_type, "reduction_input_address", builder);
+
+ llvm::AllocaInst* result_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ element_type, "partial_reduction_result", builder);
+
+ const HloInstruction* init_value =
+ reduce_hlo->init_values()[op_result_idx];
+
+ // Initialize the partial result with the initial value of the reduction.
+ llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(
+ *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty()))
+ .value();
+
+ builder->CreateStore(init_ir_value, result_address);
+ const Tiling& tiling = reduction_info.GetTiling();
+ auto shared_cache = [&]() -> std::optional<llvm_ir::SharedMemoryTile> {
+ auto* module = reduction_emitter.ir_emitter_context_.llvm_module();
+ if (reduction_info.IsRowReduction()) {
+ // Multi-row reductions do not use shared memory.
+ if (RowReductionGetRowsPerWarp(
+ reduction_emitter_.ReducedDimensionSize()) > 1) {
+ return std::nullopt;
+ }
+ // Allocate one shared memory element per warp.
+ auto block_size = tiling.GetThreadsPerBlock();
+ CHECK_EQ(block_size[ReductionDimensions::kRowMinorReducedDimension] %
+ WarpSize(),
+ 0);
+ return llvm_ir::AllocateSharedMemoryTile(
+ module, element_type,
+ {block_size[ReductionDimensions::kRowKeptDimension],
+ block_size[ReductionDimensions::kRowMinorReducedDimension] /
+ WarpSize()},
+ "shared_cache");
+ }
+ const auto& num_threads = tiling.GetThreadsPerBlock();
+ int n = num_threads[ReductionDimensions::kColReducedDimension];
+ CHECK_EQ(n, num_threads[ReductionDimensions::kColMinorKeptDimension]);
+ // The "+1" is used to avoid bank conflicts.
+ return llvm_ir::AllocateSharedMemoryTile(module, element_type,
+ {n, n + 1}, "shared_cache");
+ }();
+
+ llvm_ir::ElementGenerator input_gen =
+ *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]);
+ SetCalculationStateFor({shared_cache, init_ir_value, result_address,
+ reduction_input_address, input_gen},
+ reduce_hlo, op_result_idx);
+ }
+ }
+}
+
+void ReductionEmitter::MaybeEmitFenceForAMDGPU() {
+ auto* module = builder_->GetInsertBlock()->getModule();
+ if (IsAMDGPU(module) &&
+ ir_emitter_context_.rocm_compute_capability().fence_before_barrier()) {
+ builder_->CreateFence(
+ llvm::AtomicOrdering::SequentiallyConsistent,
+ builder_->getContext().getOrInsertSyncScopeID("workgroup"));
+ }
+}
+
+void ReductionEmitter::EmitSyncThreads() {
+ MaybeEmitFenceForAMDGPU();
+ EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder_);
+}
+
+// Builds a thunk that calls a new or reused kernel for a fusion operation.
+//
+// The caller must specify the same launch dimensions for fusions which have
+// the same computation.
+//
+// If a given fusion is implemented using multiple kernels, then for each
+// kernel we should provide a discriminator, such as "init" and "impl".
+//
+// The builder_fn is only invoked if the kernel couldn't be reused.
+//
+// This is the typical usage pattern of this method:
+//
+// ```
+// auto builder_fn = [](std::vector<llvm_ir::IrArray> inputs,
+// std::vector<llvm_ir::IrArray> outputs) { ... };
+// TF_ASSIGN_OR_RETURN(
+// auto thunk,
+// BuildKernelThunkForFusion(..., launch_dimensions, builder_fn));
+// AddThunkToThunkSequence(std::move(thunk))
+// ```
+absl::StatusOr<std::unique_ptr<Thunk>>
+ReductionEmitter::BuildKernelThunkForFusion(
+ const LaunchDimensions& launch_dimensions, absl::string_view discriminator,
+ std::function<absl::Status(std::vector<llvm_ir::IrArray>,
+ std::vector<llvm_ir::IrArray>)>
+ kernel_builder_fn) {
+ const HloComputation* fused_computation =
+ fusion_.fused_instructions_computation();
+ std::string suggested_kernel_name = std::string(fusion_.name());
+
+ TF_ASSIGN_OR_RETURN(auto kernel_arguments,
+ KernelArguments::Create(
+ ir_emitter_context_.buffer_assignment(), &fusion_));
+
+ auto [status_or_entry, cached] =
+ ir_emitter_context_.kernel_cache().GetWithStatus(
+ fused_computation, kernel_arguments.args(), discriminator,
+ [&]() -> absl::StatusOr<KernelReuseCache::Entry> {
+ llvm::Function* kernel;
+ std::vector<llvm_ir::IrArray> input_arrays;
+ std::vector<llvm_ir::IrArray> output_arrays;
+ TF_ASSIGN_OR_RETURN(
+ std::tie(kernel, input_arrays, output_arrays),
+ BuildKernelPrototype(ir_emitter_context_, suggested_kernel_name,
+ kernel_arguments.args(),
+ fusion_.operand_count(), launch_dimensions,
+ builder_));
+ TF_RETURN_IF_ERROR(kernel_builder_fn(input_arrays, output_arrays));
+ // Shared memory is allocated statically.
+ return {{kernel->getName().str(), launch_dimensions,
+ /*cluster_dim=*/std::nullopt,
+ /*shmem_bytes=*/0}};
+ });
+ TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);
+ if (cached) {
+ VLOG(3) << "Reuse: " << suggested_kernel_name << " -> "
+ << entry->kernel_name;
+ }
+
+ return std::make_unique<KernelThunk>(
+ &fusion_, entry->kernel_name, kernel_arguments.args(), launch_dimensions,
+ entry->cluster_dim, entry->shmem_bytes);
+}
+
+absl::Status ReductionGroupEmitter::EmitExtraOutputsForReduce(
+ const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index,
+ const ExtraOutputGensMap& extra_output_gens) {
+ if (extra_output_gens.empty()) {
+ return absl::OkStatus();
+ }
+
+ auto* builder = reduction_emitter_.builder_;
+  // Compute all extra output values before writing them. This avoids
+  // overwriting aliased input/output buffers before all reads have occurred.
+ std::vector<std::pair<const HloInstruction*, llvm::Value*>>
+ extra_output_ir_values;
+ extra_output_ir_values.reserve(extra_output_gens.size());
+
+ auto get_index = [&](const HloInstruction* instr) {
+ const Shape& s = instr->shape();
+ return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s)
+ ? index
+ : index.SourceIndexOfBitcast(reduction_operand_shape, s,
+ builder);
+ };
+
+ for (const auto& [instr, generator] : extra_output_gens) {
+ TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
+ generator(get_index(instr)));
+ extra_output_ir_values.emplace_back(instr, extra_output_ir_value);
+ }
+
+ for (const auto& [instr, generator] : extra_output_ir_values) {
+ absl::Span<llvm_ir::IrArray const> result_ir = result_ir_arrays_.at(instr);
+ CHECK_EQ(result_ir.size(), 1);
+ result_ir[0].EmitWriteArrayElement(get_index(instr), generator, builder);
+ }
+ return absl::OkStatus();
+}
+
+absl::StatusOr<std::unique_ptr<Thunk>>
+ReductionEmitter::BuildFusedInitializerThunk(const HloInstruction* fusion_root,
+ BufferAllocation::Slice dest_slice,
+ int output_index) {
+ const HloReduceInstruction* reduce =
+ DynCast<HloReduceInstruction>(fusion_root);
+ TF_RET_CHECK(reduce);
+
+ const HloInstruction* init_value = reduce->init_values()[0];
+ TF_ASSIGN_OR_RETURN(
+ std::optional<std::unique_ptr<Thunk>> constant_init_thunk,
+ BuildConstantInitializerThunk(ir_emitter_context_, fusion_root,
+ init_value, dest_slice));
+ if (constant_init_thunk) {
+ return *std::move(constant_init_thunk);
+ }
+
+ const Shape& dest_shape = fusion_root->shape();
+
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ dest_shape, ir_emitter_context_.gpu_device_info());
+ const HloComputation* fused_computation =
+ fusion_.fused_instructions_computation();
+
+ auto builder_fn = [&](std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs) -> absl::Status {
+ FusedIrEmitter fused_emitter(elemental_emitter_);
+ for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ fused_emitter.BindGenerator(
+ *fused_computation->parameter_instruction(i),
+ [builder = builder_,
+ input = inputs[i]](llvm_ir::IrArray::Index index) {
+ return input.EmitReadArrayElement(index, builder);
+ });
+ }
+ HloInstruction* instr = fused_computation->root_instruction();
+ if (instr->opcode() == HloOpcode::kTuple) {
+ instr = instr->mutable_operand(output_index);
+ } else {
+ CHECK_EQ(0, output_index);
+ }
+ TF_RET_CHECK(instr->shape().IsArray());
+ TF_ASSIGN_OR_RETURN(auto generator,
+ fused_emitter.GetGenerator(*instr->operand(1)));
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]},
+ launch_dimensions, builder_)
+ .EmitLoop(fusion_.name()));
+ return absl::OkStatus();
+ };
+
+ return BuildKernelThunkForFusion(launch_dimensions,
+ /*discriminator=*/
+ absl::StrCat("init_", output_index),
+ builder_fn);
+}
+
+// Emits a shuffle-down reduction for `partial_result_address` using the
+// reduction computation `reducer`, and writes the result back into
+// `partial_result_address`.
+//
+// Multiple `partial_result_address` inputs occur for variadic reductions:
+// each one receives its corresponding output value.
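+//
+// Conceptually (an illustrative sketch, not the IR that is emitted;
+// `shuffle_down` stands in for the target's warp shuffle-down intrinsic),
+// each step halves the number of lanes contributing to a result:
+//
+//   for (int d = 16 / num_results_per_warp; d >= 1; d /= 2) {
+//     partial = reducer(partial, shuffle_down(partial, d));
+//   }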
+void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce(
+ const HloComputation* reducer,
+ absl::Span<TypedPointer const> partial_result_addresses,
+ int threads_per_block, int num_results_per_warp) const {
+  // This only works when the block size is a multiple of 32 threads.
+  // We check this here because a mistake in the number of threads per
+  // block is very hard to detect.
+ CHECK_EQ(threads_per_block % 32, 0);
+ CHECK_EQ(WarpSize() % num_results_per_warp, 0);
+
+ auto* builder = reduction_emitter_.builder_;
+ for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) {
+ absl::InlinedVector<llvm::Value*, 2> reduction_params;
+
+ for (auto acc : partial_result_addresses) {
+ reduction_params.push_back(acc.first);
+ }
+
+ for (auto [partial_result_address, element_type] :
+ partial_result_addresses) {
+ int bit_width = llvm_ir::GetSizeInBits(element_type);
+ llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
+ element_type, "result_from_other_lane", builder);
+
+ reduction_params.push_back(result_from_other_lane);
+
+ // Bitcast cannot be applied to aggregate types (even packed ones), so
+ // we bitcast addresses of load/store to intN* of the same bit-width.
+ llvm::Type* shuffled_value_type = element_type->isStructTy()
+ ? builder->getIntNTy(bit_width)
+ : element_type;
+
+ llvm::Value* partial_result =
+ builder->CreateLoad(shuffled_value_type, partial_result_address,
+ "partial_reduction_result");
+ builder->CreateStore(
+ EmitFullWarpShuffleDown(
+ partial_result, builder->getInt32(distance), builder,
+ reduction_emitter_.ir_emitter_context_.gpu_device_info()),
+ result_from_other_lane);
+ }
+
+ absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
+ CallNestedComputationWithScalarAddrs(
+ builder, reduction_emitter_.ir_emitter_context_, *reducer,
+ reduction_params);
+ TF_CHECK_OK(returned_scalars.status());
+
+ for (int i = 0; i < returned_scalars->size(); i++) {
+ builder->CreateStore(/*Val=*/returned_scalars->at(i),
+ /*Ptr=*/partial_result_addresses[i].first);
+ }
+ }
+}
+
+llvm_ir::IrArray::Index ReductionGroupEmitter::GetOutputIndexForReduction(
+ const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction, const HloInstruction* root,
+ int output_idx) const {
+ auto* builder = reduction_emitter_.builder_;
+ auto* index_ty = reduction_emitter_.index_ty_;
+
+ // 1d or 2d output index (for row/column reduction).
+ auto projected_index = [&]() -> llvm_ir::IrArray::Index {
+ const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+ const auto& offset = tiling_kernel_info.tile_origin;
+ const auto& shape = reduction_info.GetTiling().GetXlaShape();
+ const auto& thread_ids = tiling_kernel_info.thread_id_info.thread_ids;
+ if (reduction_info.IsRowReduction()) {
+ constexpr int kDim = ReductionDimensions::kRowKeptDimension;
+ return {{builder->CreateAdd(offset[kDim], thread_ids[kDim])},
+ {shape.dimensions(kDim)},
+ index_ty};
+ }
+ auto* major_idx = offset[ReductionDimensions::kColMajorKeptDimension];
+ auto* minor_idx = builder->CreateAdd(
+ offset[ReductionDimensions::kColMinorKeptDimension],
+ thread_ids[ReductionDimensions::kColReducedDimension]);
+ return {{major_idx, minor_idx},
+ ShapeUtil::DeleteDimension(
+ ReductionDimensions::kColReducedDimension, shape),
+ index_ty};
+ }();
+
+ auto physical_shape = ShapeUtil::DeleteDimensions(
+ reduction->dimensions(), reduction->operand(output_idx)->shape());
+ auto physical_index =
+ projected_index.SourceIndexOfBitcast(physical_shape, builder);
+ return llvm_ir::IrArray::Index(physical_index.multidim(),
+ OutputShape(reduction->shape(), output_idx),
+ index_ty)
+ .SourceIndexOfBitcast(OutputShape(root->shape(), output_idx), builder);
+}
+
+void ReductionGroupEmitter::WriteReductionOutput(
+ const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction,
+ const std::vector<const HloInstruction*>& roots,
+ const absl::Span<TypedPointer const> values) const {
+ auto* builder = reduction_emitter_.builder_;
+ const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+ const HloComputation* reducer = reduction->to_apply();
+ for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) {
+ auto [output_ptr, type] = typed_ptr;
+ for (auto root : roots) {
+ llvm_ir::IrArray::Index output_index =
+ GetOutputIndexForReduction(tiling_kernel_info, reduction, root, oidx);
+
+ llvm::Value* output_address =
+ result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress(
+ output_index, builder, "output_element_address");
+ if (reduction_info.IsRaceFree()) {
+ FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_);
+ llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output");
+ fused_emitter.BindGenerator(
+ *reduction,
+ [&](const llvm_ir::IrArray::Index& index) { return loaded; });
+ llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root);
+ llvm::Value* generated = *gen(output_index);
+ builder->CreateStore(generated, output_address);
+ } else {
+ CHECK_EQ(values.size(), 1);
+ CHECK_EQ(roots.size(), 1);
+ CHECK_EQ(reduction, root)
+ << "output fusion is not allowed for racing reductions";
+ TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
+ builder, reduction_emitter_.ir_emitter_context_, *reducer,
+ output_address, output_ptr, type));
+ }
+ }
+ }
+}
+
+void ReductionGroupEmitter::EmitReductionOutputForRowReduction(
+ const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction,
+ const std::vector<const HloInstruction*>& roots) const {
+ const HloComputation* reducer = reduction->to_apply();
+ const auto& thread_id_info = tiling_kernel_info.thread_id_info;
+ const auto& thread_ids = thread_id_info.thread_ids;
+ auto* thread_id_x =
+ thread_ids[ReductionDimensions::kRowMinorReducedDimension];
+ auto constant = [&](uint64_t c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
+ };
+
+ auto* builder = reduction_emitter_.builder_;
+ auto is_zero = [&](llvm::Value* value) {
+ return builder->CreateICmpEQ(value, constant(0));
+ };
+
+ int num_outputs = reducer->num_parameters() / 2;
+ absl::InlinedVector<TypedPointer, 2> current_outputs;
+ for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+ const auto& state = GetCalculationStateFor(reduction, output_idx);
+ current_outputs.push_back(
+ {state.partial_result_address,
+ state.partial_result_address->getAllocatedType()});
+ }
+
+ const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+ const Tiling& tiling = reduction_info.GetTiling();
+ int num_rows_per_warp =
+ RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize());
+ EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs),
+ tiling.GetNumThreadsPerBlock(),
+ num_rows_per_warp);
+
+ KernelSupportLibrary ksl(builder);
+ llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize()));
+
+ auto emit_write_output = [&](llvm::Value* write_condition,
+ const absl::Span<TypedPointer const> values) {
+ ksl.If("reduction_write_output", write_condition, [&] {
+ WriteReductionOutput(tiling_kernel_info, reduction, roots, values);
+ });
+ };
+
+ // The major kept dimension and vector dimension are not tiled, so they're
+ // always in bounds.
+ llvm::Value* is_in_bounds_y = builder->CreateICmpULT(
+ thread_ids[ReductionDimensions::kRowKeptDimension],
+ tiling_kernel_info
+ .output_tile_bounds[ReductionDimensions::kRowKeptDimension]);
+
+ ksl.If("thread_in_bounds", is_in_bounds_y, [&] {
+ if (num_rows_per_warp > 1) {
+ llvm::Value* is_writing_thread = is_zero(builder->CreateAnd(
+ thread_id_x,
+ constant(reduction_emitter_.ReducedDimensionSize() - 1)));
+ emit_write_output(is_writing_thread, current_outputs);
+ return;
+ }
+
+ ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
+ for (int oidx = 0; oidx < num_outputs; oidx++) {
+ auto& state = GetCalculationStateFor(reduction, oidx);
+ state.shared_cache->Store(
+ builder->CreateLoad(current_outputs[oidx].second,
+ current_outputs[oidx].first),
+ {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
+ warp_id},
+ builder);
+ }
+ });
+
+    // TODO(cheshire): Don't we want to sync once for the whole output,
+    // rather than once per reduction?
+ reduction_emitter_.EmitSyncThreads();
+ ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
+ absl::InlinedVector<TypedPointer, 2> selected_values;
+ for (int oidx = 0; oidx < num_outputs; oidx++) {
+ auto& state = GetCalculationStateFor(reduction, oidx);
+ llvm::Value* block_accum_addr = state.shared_cache->Address(
+ {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
+ thread_id_info.lane_id},
+ builder);
+
+ llvm::Type* element_type =
+ state.partial_result_address->getAllocatedType();
+
+        // Ensure the initial value address is in the generic address space,
+        // not in scratch.
+ llvm::Value* initial_value_addr =
+ CastSharedToGlobal(builder,
+ llvm_ir::EmitAllocaAtFunctionEntry(
+ element_type, "initial_value_addr", builder),
+ element_type, /*name=*/"");
+ builder->CreateStore(state.initial_value, initial_value_addr);
+
+ llvm::Value* warp_exists = builder->CreateICmpULT(
+ thread_id_x,
+ constant(tiling.GetThreadsPerBlock()
+ [ReductionDimensions::kRowMinorReducedDimension] /
+ WarpSize()));
+
+ llvm::Value* selected_value = builder->CreateSelect(
+ warp_exists, block_accum_addr, initial_value_addr);
+
+ selected_values.push_back({selected_value, element_type});
+ }
+
+      // If only one warp produces the output element, we don't need to emit
+      // an inter-warp reduce. In our tiling, DimX is the minor reduced
+      // dimension. The major reduced dimension is always emitted as a loop.
+      // TODO(b/241414088): If only one warp is present, then inter-warp
+      // communication via shared memory and synchronization via a barrier are
+      // also unnecessary and should be removed.
+ if (tiling.GetThreadsPerBlock()
+ [ReductionDimensions::kRowMinorReducedDimension] > WarpSize()) {
+ EmitFullWarpShuffleDownLoopForReduce(
+ reducer, absl::MakeSpan(selected_values),
+ tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1);
+ }
+
+ emit_write_output(is_zero(thread_id_x), selected_values);
+ });
+ });
+}
+
+// Same arguments as EmitReductionOutputForRowReduction.
+void ReductionGroupEmitter::EmitReductionOutputForColumnReduction(
+ const TilingKernelInfo& tiling_kernel_info,
+ const HloReduceInstruction* reduction,
+ const std::vector<const HloInstruction*>& roots) const {
+ auto* builder = reduction_emitter_.builder_;
+ KernelSupportLibrary ksl(builder);
+ const HloComputation* reducer = reduction->to_apply();
+ const auto& thread_id_info = tiling_kernel_info.thread_id_info;
+ const auto& thread_ids = thread_id_info.thread_ids;
+
+ auto constant = [&](uint64_t c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
+ };
+ auto is_zero = [&](llvm::Value* value) {
+ return builder->CreateICmpEQ(value, constant(0));
+ };
+ const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+ const Tiling& tiling = reduction_info.GetTiling();
+ int num_outputs = reducer->num_parameters() / 2;
+
+ auto* kept_index = thread_ids[ReductionDimensions::kColMinorKeptDimension];
+ auto* reduced_index = thread_ids[ReductionDimensions::kColReducedDimension];
+
+ // Store the transpose in shared memory.
+ for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+ const auto& state = GetCalculationStateFor(reduction, output_idx);
+ auto* current_output_value =
+ builder->CreateLoad(state.partial_result_address->getAllocatedType(),
+ state.partial_result_address);
+ state.shared_cache->Store(current_output_value, {kept_index, reduced_index},
+ builder);
+ }
+
+ reduction_emitter_.EmitSyncThreads();
+
+ // Get transposed element from shared memory.
+ absl::InlinedVector<TypedPointer, 2> shmem_transposed_addrs;
+ for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+ const auto& state = GetCalculationStateFor(reduction, output_idx);
+ auto* shmem_transposed_addr =
+ state.shared_cache->Address({reduced_index, kept_index}, builder);
+ shmem_transposed_addrs.push_back(
+ {shmem_transposed_addr, state.shared_cache->GetElementType()});
+ }
+
+ EmitFullWarpShuffleDownLoopForReduce(reducer,
+ absl::MakeSpan(shmem_transposed_addrs),
+ tiling.GetNumThreadsPerBlock(),
+ /*num_results_per_warp=*/1);
+
+  // Some warps in the block are completely outside the bounds of the
+  // tensor, so they should not write any output at all.
+ llvm::Value* has_output = builder->CreateAnd(
+ builder->CreateICmpULT(
+ reduced_index,
+ tiling_kernel_info
+ .output_tile_bounds[ReductionDimensions::kColMinorKeptDimension]),
+ builder->CreateICmpULT(
+ kept_index,
+ tiling_kernel_info
+ .output_tile_bounds[ReductionDimensions::kColReducedDimension]));
+
+ ksl.If("reduction_write_output",
+ builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
+ WriteReductionOutput(tiling_kernel_info, reduction, roots,
+ shmem_transposed_addrs);
+ });
+}
+
+// Generate a single element of the tile (update the accumulator state) for a
+// given reducer.
+void ReductionGroupEmitter::GenerateElementForReducer(
+ const HloReduceInstruction* reduction,
+ const llvm_ir::IrArray::Index& index) const {
+ HloComputation* reducer = reduction->to_apply();
+ auto* builder = reduction_emitter_.builder_;
+ CHECK_EQ(reducer->num_parameters() % 2, 0);
+
+ absl::InlinedVector<llvm::Value*, 2> reduction_accumulators;
+ absl::InlinedVector<llvm::Value*, 2> reduction_input_value;
+ for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) {
+ const auto& state = GetCalculationStateFor(reduction, red_idx);
+
+ llvm::AllocaInst* input_address = state.input_address;
+ auto input_index =
+ index.SourceIndexOfBitcast(reduction->operand(0)->shape(), builder);
+ llvm::Value* const input_ir_value = *state.input_gen(input_index);
+ builder->CreateStore(input_ir_value, input_address);
+ reduction_accumulators.push_back(state.partial_result_address);
+ reduction_input_value.push_back(input_address);
+ }
+
+ absl::InlinedVector<llvm::Value*, 4> reduction_params;
+ for (llvm::Value* acc : reduction_accumulators) {
+ reduction_params.push_back(acc);
+ }
+ for (llvm::Value* value : reduction_input_value) {
+ reduction_params.push_back(value);
+ }
+
+  // Emit a call to the variadic reducer. Since it may return a tuple, we
+  // can't return the result directly as a value. Instead, before the call
+  // we create N allocas (N = number of elements in the tuple), one for each
+  // returned value, and pass the N pointers as the last parameters. The
+  // called computation writes into those pointers, so the returned values
+  // end up on the stack (along with pointers to them).
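+  //
+  // For a two-output (variadic) reducer, the call is conceptually (an
+  // illustrative sketch, not the exact emitted calling convention):
+  //
+  //   reducer(acc0, acc1, in0, in1, /*outputs=*/ret0, ret1);
+  //   *acc0 = *ret0; *acc1 = *ret1;
+  //
+  // where every argument is a pointer to a scalar.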
+ absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
+ CallNestedComputationWithScalarAddrs(
+ builder, reduction_emitter_.ir_emitter_context_, *reducer,
+ reduction_params);
+ TF_CHECK_OK(returned_scalars.status());
+
+ for (int i = 0; i < returned_scalars->size(); i++) {
+ builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]);
+ }
+}
+
+// Emits code for the reductions in `instr_index_group`.
+absl::Status ReductionEmitter::EmitIRForReduction(
+ absl::Span<const HloInstruction* const> instr_index_group,
+ FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
+ const Shape& input_shape) {
+ ExtraOutputGensMap extra_output_gens;
+ absl::flat_hash_map<const HloReduceInstruction*,
+ std::vector<const HloInstruction*>>
+ heroes_to_roots;
+ // Keep a list of deduplicated heroes separate from heroes_to_roots to make
+ // the CodeGen deterministic.
+ std::vector<const HloReduceInstruction*> heroes;
+
+ for (const HloInstruction* hlo : instr_index_group) {
+ auto& hero = FindNonTrivialHero(*hlo);
+ if (IsRealReductionHero(*hlo, hero)) {
+ auto reduction = Cast<HloReduceInstruction>(&hero);
+ if (heroes_to_roots.find(reduction) == heroes_to_roots.end()) {
+ heroes.push_back(reduction);
+ }
+ heroes_to_roots[reduction].push_back(hlo);
+ } else {
+ extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo);
+ }
+ }
+
+  CHECK(!heroes.empty()) << "Expected at least one reduce instruction.";
+ const Tiling& tiling = reduction_codegen_info_.GetTiling();
+ CHECK_EQ(tiling.GetNumThreadsPerBlock() % WarpSize(), 0);
+ ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays,
+ fused_emitter);
+
+ TF_ASSIGN_OR_RETURN(
+ TilingKernelInfo tiling_kernel_info,
+ EmitTilingKernel(
+ builder_, tiling, index_ty_,
+ [&](const TilingThreadIdInfo& thread_id_info,
+ const llvm_ir::IrArray::Index& tile_index,
+ absl::Span<llvm::Value* const> tile_dimensions) {
+ auto emit_element =
+ [&](absl::Span<llvm::Value* const> index_in_tile) {
+ auto index = tile_index.AddOffset(index_in_tile, builder_);
+
+ // Emit code to generate the input and perform the reduction
+ // computation for each reduction instruction.
+ for (const HloReduceInstruction* reduce : heroes) {
+ group_emitter.GenerateElementForReducer(reduce, index);
+ }
+
+ // Emit code to generate the output for the non-reduction
+ // instructions in the fusion, if any.
+ TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce(
+ ShapeUtil::MakeShape(
+ F32, reduction_codegen_info_.GetTiling().GetShape()),
+ index, extra_output_gens));
+ };
+ EmitTile(builder_, reduction_codegen_info_.GetTiling(),
+ thread_id_info, tile_dimensions, emit_element);
+ }));
+
+ KernelSupportLibrary ksl(builder_);
+ for (auto reduce : heroes) {
+ if (reduction_codegen_info_.IsRowReduction()) {
+ group_emitter.EmitReductionOutputForRowReduction(
+ tiling_kernel_info, reduce, heroes_to_roots[reduce]);
+ } else {
+ group_emitter.EmitReductionOutputForColumnReduction(
+ tiling_kernel_info, reduce, heroes_to_roots[reduce]);
+ }
+ }
+
+ return absl::OkStatus();
+}
+
+absl::StatusOr<FusionEmissionResult> ReductionEmitter::EmitInitializers() {
+ FusionEmissionResult result;
+ if (reduction_codegen_info_.IsRaceFree()) {
+ return result;
+ }
+  // We need to get the dest slice by traversing the slice assigned to the
+  // fusion, because instructions inside the fusion don't have buffer
+  // assignments.
+ //
+ // The order of fusion roots is determined by its position in the result
+ // tuple. For example, in the following fused computation
+ //
+ // %fused_computation {
+ // %a = ...
+  //   %b = ...
+ // ROOT %root = tuple(%a, %b)
+ // }
+ //
+ // The fusion root with index = 0 is %a, and the fusion root %b has index 1.
+ // Therefore we can get the ordered slices by calling ForEachSubshape on the
+ // result shape.
+ std::vector<BufferAllocation::Slice> slices;
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) {
+ if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) {
+ return absl::OkStatus();
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ BufferAllocation::Slice slice,
+ ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_,
+ index));
+ slices.push_back(slice);
+ return absl::OkStatus();
+ }));
+
+ absl::Span<HloInstructionAdaptor const> fusion_roots =
+ analysis_.fusion_roots();
+ for (int i = 0; i < fusion_roots.size(); ++i) {
+ const HloInstruction* fusion_root = &fusion_roots[i].instruction();
+
+ if (IsReductionFromOrToContiguousDimensions(*fusion_root)) {
+ TF_ASSIGN_OR_RETURN(
+ result.thunks.emplace_back(),
+ BuildFusedInitializerThunk(fusion_root, slices[i], i));
+ }
+ }
+ return result;
+}
+
+absl::Status ReductionEmitter::EmitKernel(
+ const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs) {
+ const HloComputation* fused_computation =
+ fusion_.fused_instructions_computation();
+ FusedIrEmitter fused_emitter(elemental_emitter_);
+ for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
+ fused_emitter.BindGenerator(
+ *fused_operand, [builder = builder_, input = inputs[i],
+ fused_operand](const llvm_ir::IrArray::Index& index) {
+ return input.EmitReadArrayElement(index, builder,
+ fused_operand->name());
+ });
+ }
+
+ // Get outputs.
+ ReductionOutputMap result_ir_arrays;
+
+ int ir_arrays_idx = 0;
+ for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) {
+ int get_num_results = GetNumOutputs(root.shape());
+ result_ir_arrays[&root.instruction()] =
+ absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results);
+ ir_arrays_idx += get_num_results;
+ }
+
+ KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll);
+
+  // Use the raw block_id_y to select the i-th parallel reduction to run.
+  // Using block_id_y instead of block_id_x simplifies the index calculation
+  // for reduction code generation, since block_id_y is orthogonal to the
+  // indices used within the reductions.
+ const auto& instr_index_groups =
+ reduction_codegen_info_.GetGroups().grouped_roots;
+ Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape();
+
+ llvm::Value* block_id_y = gpu::EmitCallToTargetIntrinsic(
+ gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_);
+ llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
+ llvm::cast<llvm::Instruction>(block_id_y),
+ builder_->GetInsertBlock()->getModule());
+ block_id_y = builder_->CreateZExtOrTrunc(block_id_y, builder_->getInt32Ty());
+ block_id_y->setName("block.id.y");
+ for (int i = 0; i < instr_index_groups.size(); ++i) {
+ TF_RETURN_IF_ERROR(ksl.IfWithStatus(
+ absl::StrCat("reduce-group-", i),
+ builder_->CreateICmpEQ(block_id_y, builder_->getInt32(i)), [&] {
+ return EmitIRForReduction(instr_index_groups[i], fused_emitter,
+ result_ir_arrays, reduce_operand_shape);
+ }));
+ }
+
+ return absl::OkStatus();
+}
+
+} // namespace
+
+absl::StatusOr<FusionEmissionResult> ReductionFusion::EmitInitializers(
+ IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion) const {
+ llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext());
+ return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
+ fusion, &builder)
+ .EmitInitializers();
+}
+
+absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const {
+ return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
+ fusion, builder)
+ .EmitKernel(launch_dims, inputs, outputs);
+}
+
+int ReductionInfo::GetRowsPerWarp() const {
+ if (!is_row_reduction_) return 1;
+ return RowReductionGetRowsPerWarp(
+ tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]);
+}
+
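+// For example (illustrative numbers only): with 3 disjoint reduce groups and
+// 120 tiling blocks, launch_dimensions() below describes a (120, 3, 1) grid
+// of (num_threads_per_block, 1, 1) thread blocks.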
+LaunchDimensions ReductionInfo::launch_dimensions() const {
+ size_t blocks_y = groups_.grouped_roots.size();
+ return {se::BlockDim(/*x=*/tiling_.GetNumBlocks(),
+ /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
+ se::ThreadDim(/*x=*/tiling_.GetNumThreadsPerBlock(),
+ /*y=*/1, /*z=*/1)};
+}
+
+ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) {
+ auto* hero_reduction = analysis.FindHeroReduction();
+ CHECK_NE(hero_reduction, nullptr);
+ Shape input_shape = hero_reduction->operand(0)->shape();
+ ReductionDimensions reduction_dimensions =
+ GetReductionKindAndContiguousComponents(*hero_reduction);
+ auto shape = reduction_dimensions.dimensions;
+ VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
+ << " " << shape[0] << " " << shape[1] << " " << shape[2];
+ Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions);
+
+ int64_t num_threads_y =
+ reduction_dimensions.is_row_reduction ? 1 : WarpSize();
+ int64_t rows_per_warp =
+ reduction_dimensions.is_row_reduction
+ ? RowReductionGetRowsPerWarp(
+ shape[ReductionDimensions::kRowMinorReducedDimension])
+ : 1;
+ int64_t num_threads_x = [&] {
+ if (reduction_dimensions.is_row_reduction) {
+ if (rows_per_warp > 1) {
+ return shape[ReductionDimensions::kRowMinorReducedDimension];
+ }
+ int64_t max_block_size =
+ MinThreadsXRowReduction(hero_reduction->GetModule()->config());
+ return std::min(
+ max_block_size,
+ RoundUpTo(
+ CeilOfRatio(shape[ReductionDimensions::kRowMinorReducedDimension],
+ reduction_tiling
+ [ReductionDimensions::kRowMinorReducedDimension]),
+ WarpSize()));
+ }
+ return WarpSize();
+ }();
+
+ // If we're limited by the size of the x dimension, add additional parallelism
+ // in the y dimension. The code generator doesn't currently support
+ // parallelizing the z dimension (major reduced dimensions). The general
+ // recommendation is to use between 128 and 512 threads, so we just go for
+ // 256. See https://forums.developer.nvidia.com/t/55529
+ constexpr int64_t kThreadsPerBlockTarget = 256;
+ if (reduction_dimensions.is_row_reduction &&
+ num_threads_x * 2 <= kThreadsPerBlockTarget) {
+ int64_t kept_size =
+ reduction_dimensions.dimensions[ReductionDimensions::kRowKeptDimension];
+ // Increase the size of the y dimension as long as there's remaining
+ // parallelism.
+ if (kept_size * num_threads_x <= kThreadsPerBlockTarget) {
+ num_threads_y = kept_size;
+      // num_threads_x is a power of two, but it may be less than 32. If
+      // num_threads_y is also small, we may have to increase it so that the
+      // total number of threads per block is a multiple of 32.
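+      // For example (illustrative numbers only): with num_threads_x = 8 and
+      // kept_size = 5, we start at num_threads_y = 5 and bump it to 8, giving
+      // 8 * 8 = 64 threads, a multiple of 32.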
+ while ((num_threads_x * num_threads_y) % 32) ++num_threads_y;
+ } else {
+ num_threads_y = kThreadsPerBlockTarget / num_threads_x;
+ }
+ }
+
+ int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x,
+ reduction_tiling);
+
+ absl::InlinedVector<int64_t, 4> num_threads{1, num_threads_y, num_threads_x};
+ absl::InlinedVector<int64_t, 4> tiled_shape{shape[0], shape[1],
+ shape[2] / vector_size};
+ absl::InlinedVector<int64_t, 4> tile_per_thread{
+ reduction_tiling[0], reduction_tiling[1],
+ std::max<int64_t>(reduction_tiling[2] / vector_size, 1)};
+ if (rows_per_warp > 1) {
+    // If a warp produces more than one row, the reduced dimension is small
+    // and cannot be tiled: we already have more threads in a warp than the
+    // size of the reduced dimension. The code generator doesn't currently
+    // support tiling the kept dimension, because it just uses the thread ID
+    // as the coordinate.
+ tile_per_thread[2] = 1;
+ }
+ if (vector_size != 1) {
+ num_threads.push_back(1); // The vector dimension is a loop.
+ tiled_shape.push_back(vector_size);
+ tile_per_thread.push_back(vector_size);
+ }
+
+ Tiling tiling(tiled_shape, tile_per_thread, num_threads,
+ /*loops_to_unroll=*/{false, false, true, false});
+ bool reduction_is_race_free = ReductionIsRaceFree(
+ hero_reduction->GetModule()->config(), reduction_dimensions);
+ return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction,
+ reduction_is_race_free,
+ GroupDisjointReductions(analysis, /*for_mlir=*/false),
+ hero_reduction);
+}
+
+std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const {
+ if (!groups_.is_reduction_root[root_index]) {
+ auto map = ComposeIndexingMaps(
+ GetIndexingMapForTiling(tiling_, ctx),
+ GetBitcastMap(tiling_.GetXlaShape(),
+ analysis_.fusion_root(root_index).shape(), ctx));
+ AddGroupIdConstraint(map, root_index, groups_);
+ return map;
+ }
+ const auto& hero = analysis_.fusion_hero(root_index).instruction();
+
+ auto block_offsets = GetBlockOffsetsForTiling(tiling_, ctx);
+ auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx),
+ tiling_.GetThreadsPerBlock());
+
+ auto physical_shape =
+ ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape());
+ std::vector<DimVar> dimension_ranges{
+ {{0, tiling_.GetNumThreadsPerBlock() - 1}},
+ {},
+ {},
+ {{0, tiling_.GetNumBlocks() - 1}},
+ {{0, static_cast<int64_t>(groups_.grouped_roots.size() - 1)}},
+ {},
+ };
+
+ constexpr int kRowKept = ReductionDimensions::kRowKeptDimension;
+ constexpr int kRowMinorReduced =
+ ReductionDimensions::kRowMinorReducedDimension;
+
+ constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension;
+ constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension;
+ constexpr int kColReduced = ReductionDimensions::kColReducedDimension;
+
+ auto map = [&]() {
+ if (is_row_reduction_) {
+ IndexingMap linear_index(
+ mlir::AffineMap::get(
+ 6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept],
+ ctx),
+ dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});
+ int rows_per_warp = GetRowsPerWarp();
+ if (rows_per_warp > 1) {
+ linear_index.AddConstraint(
+ thread_ids[kRowMinorReduced] % (WarpSize() / rows_per_warp),
+ {0, 0});
+ } else {
+ linear_index.AddConstraint(thread_ids[kRowMinorReduced], {0, 0});
+ }
+ return ComposeIndexingMaps(
+ linear_index, GetBitcastMap(ShapeUtil::MakeShape(
+ PRED, {tiling_.GetShape()[kRowKept]}),
+ physical_shape, ctx));
+ }
+
+ mlir::SmallVector<mlir::AffineExpr> projected_dims{
+ block_offsets.getResult(kColMajorKept),
+ block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]};
+ std::vector<RangeVar> range_vars;
+ if (thread_ids.size() == 4) {
+ int vector_size = tiling_.GetThreadTileSize().back();
+ range_vars.push_back({0, vector_size - 1});
+ projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx));
+ }
+ IndexingMap projected_index(
+ mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx),
+ dimension_ranges, range_vars, /*rt_vars=*/{});
+
+ projected_index.AddConstraint(
+ mlir::getAffineDimExpr(
+ KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) %
+ WarpSize(),
+ {0, 0});
+ if (!is_row_reduction_) {
+ projected_index.AddConstraint(
+ projected_index.GetAffineMap().getResult(1),
+ {0, tiling_.GetShape()[ReductionDimensions::kColMinorKeptDimension] -
+ 1});
+ }
+
+ return ComposeIndexingMaps(
+ projected_index,
+ GetBitcastMap(ShapeUtil::DeleteDimension(
+ ReductionDimensions::kColReducedDimension,
+ tiling_.GetXlaShape()),
+ physical_shape, ctx));
+ }();
+
+ AddGroupIdConstraint(map, root_index, groups_);
+ map.Simplify();
+ return map;
+}
+
+std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const {
+ const auto& hero = analysis_.fusion_hero(root_index).instruction();
+ if (groups_.is_reduction_root[root_index] &&
+ hero_operand_index >= hero.operand_count() / 2) {
+ // We don't have indexing for the init values.
+ return std::nullopt;
+ }
+ if (!groups_.is_reduction_root[root_index]) {
+ return ComposeIndexingMaps(
+ *ComputeThreadIdToOutputIndexing(root_index, ctx),
+ *ComputeOutputToInputIndexing(
+ &analysis_.fusion_root(root_index).instruction(), 0, ctx)
+ .indexing_maps[hero_operand_index]
+ .begin());
+ }
+
+ auto map = ComposeIndexingMaps(
+ GetIndexingMapForTiling(tiling_, ctx),
+ GetBitcastMap(tiling_.GetXlaShape(),
+ hero.operand(hero_operand_index)->shape(), ctx));
+ AddGroupIdConstraint(map, root_index, groups_);
+ map.Simplify();
+ return map;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h
new file mode 100644
index 0000000..131b4ec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h
@@ -0,0 +1,190 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_
+
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/fusions/reduction_base.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+
+namespace xla {
+namespace gpu {
+
+class ReductionInfo {
+ public:
+ static ReductionInfo Create(const HloFusionAnalysis& analysis);
+
+ const Tiling& GetTiling() const { return tiling_; }
+ const ReductionGroups& GetGroups() const { return groups_; }
+ Shape GetReduceOperandShape() const {
+ return first_reduce_->operand(0)->shape();
+ }
+
+ bool IsRowReduction() const { return is_row_reduction_; }
+ bool IsRaceFree() const { return is_race_free_; }
+ int GetRowsPerWarp() const;
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const;
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const;
+
+ LaunchDimensions launch_dimensions() const;
+
+ private:
+ ReductionInfo(const HloFusionAnalysis& analysis, Tiling tiling,
+ bool is_row_reduction, bool is_race_free,
+ ReductionGroups groups, const HloInstruction* first_reduce)
+ : analysis_(analysis),
+ tiling_(tiling),
+ is_row_reduction_(is_row_reduction),
+ is_race_free_(is_race_free),
+ groups_(std::move(groups)),
+ first_reduce_(first_reduce) {}
+
+ const HloFusionAnalysis& analysis_;
+ Tiling tiling_;
+ bool is_row_reduction_;
+ bool is_race_free_;
+ ReductionGroups groups_;
+ const HloInstruction* first_reduce_;
+};
+
+// Generates code for reduction to contiguous dimensions.
+//
+// Row reduction uses the following algorithm described in CUDA-like
+// pseudocode:
+//
+// ```
+// __global__ void reduce(int num_rows, float* in, float* out) {
+//   __shared__ float cache[32];
+//   int offset = blockDim.x * blockIdx.x + threadIdx.x;
+//   if (offset >= num_rows) return;
+//   int tile_bound = std::min(offset + kTileSizeX, num_rows);
+//   float accum = 0;
+//   for (int i = offset; i < tile_bound; i += blockDim.x) {
+//     accum += in[i];
+//   }
+//   accum = warp_reduce(accum);
+//   if (threadIdx.x % WarpSize == 0) {
+//     cache[threadIdx.x / WarpSize] = accum;
+//   }
+//   __syncthreads();
+//   if (threadIdx.x / WarpSize == 0) {
+//     bool warp_exists = threadIdx.x < (blockDim.x / WarpSize);
+//     float block_accum = warp_exists ? cache[threadIdx.x % WarpSize] : 0;
+//     block_accum = warp_reduce(block_accum);
+//     if (threadIdx.x == 0) {
+//       *out += block_accum;
+//     }
+//   }
+// }
+// ```
+//
+// Column reduction uses the following algorithm:
+//
+// ```
+// void reduce(float** in, float* out) {
+//   __shared__ float cache[32][33];
+//   int2 thread_id = GetThreadId();
+//   int block_id = GetBlockId();
+//   int tile_size = 128;
+//
+//   float accum = 0;
+//   for (int i = 0; i < tile_size; i++) {
+//     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
+//   }
+//   cache[thread_id.x][thread_id.y] = accum;
+//
+//   __syncthreads();
+//   accum = cache[thread_id.y][thread_id.x];
+//   accum = warp_reduce(accum);  // Sum all the values of `accum` in the
+//                                // same warp.
+//
+//   if (thread_id.y % 32 == 0) {
+//     out[block_id * 32 + thread_id.x] = accum;
+//   }
+// }
+// ```
+//
+// Moreover, a heuristic is implemented to divide the reduce instructions
+// into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
+// for details about the heuristic). Reduce instructions in the same group
+// run sequentially, while different groups run in parallel.
+//
+// We use the raw block_id_y to select the reduce group to execute, without
+// complicating the index calculation in the code generation of the reduce
+// instructions. In other words, each block_id_y is assigned to a group, so
+// different groups can run in parallel.
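+//
+// For two disjoint groups, the emitted kernel is conceptually (an
+// illustrative sketch, not the actual IR; emit_group_N stands for the code
+// generated for group N):
+//
+//   if (blockIdx.y == 0) { emit_group_0(); }
+//   if (blockIdx.y == 1) { emit_group_1(); }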
+class ReductionFusion : public KernelFusionEmitterBase {
+ public:
+ explicit ReductionFusion(const HloFusionAnalysis& analysis)
+ : analysis_(analysis), reduction_info_(ReductionInfo::Create(analysis)) {}
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const override {
+ return reduction_info_.ComputeThreadIdToOutputIndexing(root_index, ctx);
+ }
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const override {
+ return reduction_info_.ComputeThreadIdToInputIndexing(
+ root_index, hero_operand_index, ctx);
+ }
+
+ LaunchDimensions launch_dimensions() const override {
+ return reduction_info_.launch_dimensions();
+ }
+
+ const ReductionInfo& reduction_info() const { return reduction_info_; }
+
+ protected:
+ absl::StatusOr<FusionEmissionResult> EmitInitializers(
+ IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion) const override;
+
+ absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const override;
+
+ private:
+ const HloFusionAnalysis& analysis_;
+ ReductionInfo reduction_info_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc
new file mode 100644
index 0000000..cc0faf4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc
@@ -0,0 +1,176 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/legacy/reduction.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::SizeIs;
+
+class ReductionTest : public HloTestBase {
+ protected:
+ stream_executor::DeviceDescription device_info_ =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(ReductionTest, ThreadIndexingRowReduction) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+ }
+
+ fusion {
+ %input = f32[100,64,512] parameter(0)
+ %c0 = f32[] constant(0)
+ ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add
+ }
+
+ ENTRY entry {
+ %input = f32[100,64,512] parameter(0)
+ ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+ ReductionFusion fusion(analysis);
+
+ EXPECT_THAT(
+ fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
+ d3 floordiv 8,
+ (d3 mod 8) * 8 + d0 floordiv 32,
+ (d0 mod 32) * 2 + s2 * 64 + s3
+ )
+ domain:
+ d0 in [0, 255]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 799]
+ d4 in [0, 0]
+ d5 in [0, 0]
+ s0 in [0, 0]
+ s1 in [0, 0]
+ s2 in [0, 7]
+ s3 in [0, 1]
+ )"));
+ EXPECT_THAT(
+ fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5) -> (
+ d3 floordiv 8,
+ (d3 mod 8) * 8 + d0 floordiv 32
+ )
+ domain:
+ d0 in [0, 224]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 799]
+ d4 in [0, 0]
+ d5 in [0, 0]
+ d0 mod 32 in [0, 0]
+ )"));
+}
+
+TEST_F(ReductionTest, TwoGroups) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+ }
+ fusion {
+ %p0 = f32[2] parameter(0)
+ %p1 = f32[2] parameter(1)
+ %c0 = f32[] constant(-inf)
+ %r0 = f32[] reduce(%p0, %c0), dimensions={0}, to_apply=add
+ %c1 = f32[] constant(inf)
+ %r1 = f32[] reduce(%p1, %c1), dimensions={0}, to_apply=add
+ ROOT %tuple = (f32[], f32[]) tuple(%r0, %r1)
+ }
+ ENTRY entry {
+ %p0 = f32[2] parameter(0)
+ %p1 = f32[2] parameter(1)
+ ROOT %fusion = (f32[], f32[]) fusion(%p0, %p1), kind=kInput, calls=fusion
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+ ReductionFusion fusion(analysis);
+
+ EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots,
+ ElementsAre(ElementsAre(&analysis.fusion_root(0).instruction()),
+ ElementsAre(&analysis.fusion_root(1).instruction())));
+}
+
+TEST_F(ReductionTest, OneGroup) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ %add {
+ %p0 = c128[] parameter(0)
+ %p1 = c128[] parameter(1)
+ ROOT %add.35 = c128[] add(c128[] %p0, c128[] %p1)
+ }
+ %fusion {
+ %p0 = c128[1,2] parameter(0)
+ %c0 = c128[] constant((0, 0))
+ %reduce = c128[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
+ %real = f64[] real(c128[] %reduce)
+ %imag = f64[] imag(c128[] %reduce)
+ %negate = f64[] negate(f64[] %imag)
+ ROOT %tuple.29 = (f64[], f64[]) tuple(f64[] %real, f64[] %negate)
+ }
+ ENTRY entry {
+ %p0 = c128[1,2] parameter(0)
+ ROOT %fusion = (f64[], f64[]) fusion(%p0), kind=kInput, calls=fusion
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+ ReductionFusion fusion(analysis);
+
+ EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, SizeIs(2));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc
new file mode 100644
index 0000000..0798788
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc
@@ -0,0 +1,294 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/scatter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/fusions/legacy/loop.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/ir_emitter_nested.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+ScatterFusion::ScatterFusion(const HloFusionAnalysis& analysis)
+ : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {
+ CHECK_EQ(analysis.fusion_root_count(), 1);
+ CHECK_EQ(analysis.fusion_root(0).opcode(), HloOpcode::kScatter);
+}
+
+LaunchDimensions ScatterFusion::launch_dimensions() const {
+ const auto& updates_shape =
+ analysis_.fusion_root(0).instruction().operands().back()->shape();
+ return CalculateLaunchDimensions(updates_shape, analysis_.device_info());
+}
+
+absl::Status ScatterFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const {
+ GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+ // Spin up a new fused emitter for the scatter kernel and emit it.
+ FusedIrEmitter scatter_fused_emitter(elemental_emitter);
+ auto* fused_computation = fusion.fused_instructions_computation();
+ for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ auto fused_operand = fused_computation->parameter_instruction(i);
+ scatter_fused_emitter.BindGenerator(
+ *fused_operand, [builder, &input = inputs[i],
+ fused_operand](llvm_ir::IrArray::Index index) {
+ return input.EmitReadArrayElement(index, builder,
+ fused_operand->name());
+ });
+ }
+
+ auto* root = fused_computation->root_instruction();
+ const xla::ScatterDimensionNumbers& scatter_dims =
+ Cast<HloScatterInstruction>(root)->scatter_dimension_numbers();
+
+ std::string name = llvm_ir::IrName(root);
+ const Shape& operand_shape = root->operand(0)->shape();
+ const Shape& scatter_indices_shape = root->operand(1)->shape();
+ const Shape& updates_shape = root->operand(2)->shape();
+ const HloComputation& update_computation = *root->called_computations()[0];
+
+ TF_ASSIGN_OR_RETURN(auto scatter_indices_gen,
+ scatter_fused_emitter.GetGenerator(*root->operand(1)));
+ TF_ASSIGN_OR_RETURN(auto updates_gen,
+ scatter_fused_emitter.GetGenerator(*root->operand(2)));
+
+ auto loop_body_emitter =
+ [&](const llvm_ir::IrArray::Index& index) -> absl::Status {
+ std::vector<llvm::Value*> raw_window_multidim;
+ std::vector<llvm::Value*> input_scatter_multidim;
+ std::vector<int64_t> raw_window_bounds;
+
+ auto get_i64_array = [](absl::Span<const int64_t> container) {
+ return llvm::ArrayRef<int64_t>{container.data(),
+ static_cast<size_t>(container.size())};
+ };
+
+ llvm::ArrayRef<int64_t> update_window_dims =
+ get_i64_array(scatter_dims.update_window_dims());
+ // Partition the index into window indices and scatter indices.
+ for (int64_t i = 0, e = index.size(); i != e; ++i) {
+      // For window indices, also remember the window size; this comes in
+      // handy later.
+ if (llvm::is_contained(update_window_dims, i)) {
+ raw_window_multidim.push_back(index[i]);
+ raw_window_bounds.push_back(updates_shape.dimensions(i));
+ } else {
+ input_scatter_multidim.push_back(index[i]);
+ }
+ }
+ DCHECK_EQ(raw_window_multidim.size(),
+ scatter_dims.update_window_dims_size());
+
+ // Apply inserted_window_dims to the window dimensions.
+ int64_t raw_window_multidim_idx = 0;
+ llvm::SmallVector<llvm::Value*> input_window_multidim;
+ llvm::SmallVector<int64_t> input_window_bounds;
+ const int64_t rank = operand_shape.rank();
+ input_window_bounds.reserve(rank);
+ input_window_multidim.reserve(rank);
+
+ llvm::ArrayRef<int64_t> inserted_window_dims =
+ get_i64_array(scatter_dims.inserted_window_dims());
+ for (int64_t i = 0; i != rank; ++i) {
+ if (llvm::is_contained(inserted_window_dims, i)) {
+ input_window_bounds.push_back(1); // Trivial dimension.
+ input_window_multidim.push_back(index.GetConstantWithIndexType(0));
+ } else {
+ input_window_bounds.push_back(
+ raw_window_bounds[raw_window_multidim_idx]);
+ input_window_multidim.push_back(
+ raw_window_multidim[raw_window_multidim_idx]);
+ ++raw_window_multidim_idx;
+ }
+ }
+ DCHECK_EQ(input_window_multidim.size(), operand_shape.rank());
+
+ // Insert a 1 dimension at the end if index_vector_dim requests one.
+ Shape scatter_indices_shape_fixed = scatter_indices_shape;
+ if (scatter_dims.index_vector_dim() == scatter_indices_shape.rank()) {
+ scatter_indices_shape_fixed.add_dimensions(1);
+ scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
+ scatter_dims.index_vector_dim());
+ }
+
+ // Now load the indices corresponding to the current window from
+ // scatter_indices.
+ std::vector<llvm::Value*> raw_scatter_index_multidim =
+ input_scatter_multidim;
+ raw_scatter_index_multidim.insert(
+ raw_scatter_index_multidim.begin() + scatter_dims.index_vector_dim(),
+ nullptr);
+
+ llvm::ArrayRef<int64_t> scatter_dims_to_operand_dims =
+ get_i64_array(scatter_dims.scatter_dims_to_operand_dims());
+ llvm::Value* is_in_bounds = builder->getTrue();
+ for (int64_t i = 0, e = scatter_dims_to_operand_dims.size(); i != e; ++i) {
+    // Our index is stored along index_vector_dim; insert that into the lookup
+    // index into scatter_indices.
+ raw_scatter_index_multidim[scatter_dims.index_vector_dim()] =
+ index.GetConstantWithIndexType(i);
+ llvm_ir::IrArray::Index raw_scatter_index_index(
+ raw_scatter_index_multidim, scatter_indices_shape_fixed,
+ index.GetType());
+
+ int64_t operand_dim = scatter_dims_to_operand_dims[i];
+ if (operand_dim > rank) {
+ return absl::OutOfRangeError(
+ "The provided scatter_dims_to_operand_dims was out of range.");
+ }
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value* const loaded_scatter_index,
+ scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
+ scatter_indices_shape_fixed, scatter_indices_shape, builder)));
+ // And add the index to our window index. This yields the output index.
+ llvm::Value* casted_scatter_index = builder->CreateIntCast(
+ loaded_scatter_index, index.GetType(),
+ /*isSigned=*/ShapeUtil::ElementIsSigned(scatter_indices_shape));
+ llvm::Value* dim_offset = builder->CreateAdd(
+ input_window_multidim[operand_dim], casted_scatter_index);
+ input_window_multidim[operand_dim] = dim_offset;
+
+ // Also do the bounds check now.
+ int64_t max_index = operand_shape.dimensions(operand_dim) -
+ input_window_bounds[operand_dim] + 1;
+ // is_in_bounds = index >= 0 && index < dim_size-window_size+1
+ // --> index u< dim_size-window_size+1
+ is_in_bounds = builder->CreateAnd(
+ is_in_bounds,
+ builder->CreateICmpULT(casted_scatter_index,
+ index.GetConstantWithIndexType(max_index)));
+ }
+
+ llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
+ is_in_bounds, "scatter.in_bounds", builder, /*emit_else=*/false);
+ llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block,
+ builder);
+    // All done; now read the update value for the current index and apply it
+    // to the calculated location in the output (atomically, unless the
+    // indices are known to be unique).
+ llvm_ir::IrArray::Index input_window_index(
+ input_window_multidim, outputs.back().GetShape(), index.GetType());
+ llvm::Value* output_address =
+ outputs.back().EmitArrayElementAddress(input_window_index, builder);
+ llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(updates_shape.element_type(),
+ ir_emitter_context.llvm_module()),
+ "input_address", builder);
+ TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
+ builder->CreateStore(input_ir_value, input_address);
+
+ if (root->unique_indices()) {
+ return CallNestedComputation(
+ builder, ir_emitter_context, update_computation,
+ {output_address, input_address}, output_address);
+ }
+ return EmitAtomicOperationForNestedComputation(
+ builder, ir_emitter_context, update_computation, output_address,
+ input_address, outputs.back().GetElementLlvmType());
+ };
+
+ // Launch a kernel that reads every element in the updates tensor. We could
+ // also do one kernel per window instead if bounds checks turn out to be a
+ // bottleneck.
+ auto index_type =
+ GetIndexTypeForKernel(root, launch_dims.launch_bound(), builder);
+ return ParallelLoopEmitter(loop_body_emitter, updates_shape, launch_dims,
+ builder)
+ .EmitLoop(name, index_type);
+}
+
+std::optional<IndexingMap> ScatterFusion::ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const {
+ const auto* scatter =
+ DynCast<HloScatterInstruction>(&analysis_.fusion_hero(0).instruction());
+ int64_t scatter_operand_count = scatter->scatter_operand_count();
+  // Scatter operands are packed in the following way:
+ // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`.
+ // Operand ID scatter_operand_count for `scatter indices`.
+ // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for
+ // `scatter updates`.
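+  // For example, with scatter_operand_count == 2, the scatter operands are
+  // operand IDs 0 and 1, the scatter indices are operand ID 2, and the
+  // scatter updates are operand IDs 3 and 4.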
+
+ // For scatter operands we do not know the thread ID indexing.
+ if (hero_operand_index < scatter_operand_count) {
+ return std::nullopt;
+ }
+ // Compute thread id mapping based on the first update operand.
+ Shape scatter_update_shape = scatter->scatter_updates().front()->shape();
+ IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap(
+ launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx);
+
+  // For scatter indices, we project the indexing for scatter updates and take
+  // only the first result of the affine map, because they coincide.
+ if (hero_operand_index == scatter_operand_count) {
+ Shape scatter_indices_shape = scatter->scatter_indices()->shape();
+ CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString();
+ // Create a map from scatter update to scatter indices.
+ IndexingMap updates_to_indices_map{
+ mlir::AffineMap::get(
+ /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1,
+ {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)},
+ ctx),
+ DimVarsFromTensorSizes(scatter_update_shape.dimensions()),
+ RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}),
+ /*rt_vars=*/{}};
+ auto scatter_indices_map = scatter_update_map * updates_to_indices_map;
+ scatter_indices_map.Simplify();
+ return scatter_indices_map;
+ }
+ return scatter_update_map;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h
new file mode 100644
index 0000000..862d0b3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h
@@ -0,0 +1,71 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_
+
+#include <optional>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// A scatter, implemented as a loop over the updates. All scatters are in-place.
+class ScatterFusion : public KernelFusionEmitterBase {
+ public:
+ explicit ScatterFusion(const HloFusionAnalysis& analysis);
+
+ LaunchDimensions launch_dimensions() const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const override {
+ // The kernel iterates over updates, whose correspondence to output
+ // elements cannot be computed statically.
+ return std::nullopt;
+ }
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const override;
+
+ protected:
+ absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const override;
+
+ private:
+ const HloFusionAnalysis& analysis_;
+ LaunchDimensionsConfig config_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc
new file mode 100644
index 0000000..71eea76
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc
@@ -0,0 +1,224 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/scatter.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class ScatterFusionTest : public HloTestBase {
+ public:
+ void SetUp() override {
+ HloTestBase::SetUp();
+ printer_ =
+ AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+ {"chunk_id", "unroll_id", "index_id"});
+ }
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.set_xla_gpu_mlir_emitter_level(0);
+ return opts;
+ }
+
+ protected:
+ AffineMapPrinter printer_;
+ mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(ScatterFusionTest, ScatterFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ add (lhs: f32[], rhs: f32[]) -> f32[] {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT sum = f32[] add(lhs, rhs)
+ }
+
+ fused_computation {
+ %input = f32[2,9] parameter(0)
+ %indices = s32[3] parameter(1)
+ %updates = f32[3,9] parameter(2)
+ ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
+ to_apply=add,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ }
+
+ ENTRY entry {
+ %input = f32[2,9] parameter(0)
+ %indices = s32[3] parameter(1)
+ %updates = f32[3,9] parameter(2)
+ ROOT %fusion = f32[2,9] fusion(%input, %indices, %updates), kind=kLoop, calls=fused_computation
+ })")
+ .value();
+
+ stream_executor::DeviceDescription device_info =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+ auto emitter =
+ GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+ auto scatter_fusion = dynamic_cast<ScatterFusion*>(emitter.get());
+ ASSERT_NE(scatter_fusion, nullptr);
+ EXPECT_EQ(scatter_fusion->launch_dimensions().launch_bound(),
+ 3 * 9 /* updates size */);
+}
+
+TEST_F(ScatterFusionTest, ThreadIdIndexing) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ computation {
+ %p0 = f32[] parameter(0)
+ %p1 = f32[] parameter(1)
+ %p2 = f32[] parameter(2)
+ %p3 = f32[] parameter(3)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %p2, f32[] %p3)
+ }
+ scatter {
+ %operand0 = f32[300,200] parameter(0)
+ %operand1 = f32[300,200] parameter(1)
+ %indices = s32[42,1] parameter(2)
+ %update.1 = f32[42,10,20] parameter(3)
+    %update.2 = f32[42,10,20] parameter(4)
+
+ ROOT %scatter = (f32[300,200], f32[300,200]) scatter(
+ f32[300,200] %operand0,
+ f32[300,200] %operand1,
+ s32[42,1] %indices,
+ f32[42,10,20] %update.1,
+ f32[42,10,20] %update.2
+ ),
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ to_apply=computation
+ }
+ ENTRY entry {
+ %operand0 = f32[300,200] parameter(0)
+ %operand1 = f32[300,200] parameter(1)
+ %indices = s32[42,1] parameter(2)
+ %update.1 = f32[42,10,20] parameter(3)
+    %update.2 = f32[42,10,20] parameter(4)
+ ROOT %fusion = (f32[300,200], f32[300,200]) fusion(
+ %operand0, %operand1, %indices, %update.1, %update.2),
+ kind=kLoop, calls=scatter
+ }
+ )"));
+ stream_executor::DeviceDescription device_info =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+ auto emitter =
+ GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+ auto fusion = dynamic_cast<ScatterFusion*>(emitter.get());
+ ASSERT_NE(fusion, nullptr);
+
+ constexpr auto kUpdatesIndexing = R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+ (bl_x * 128 + th_x) floordiv 200,
+ ((bl_x * 128 + th_x) floordiv 20) mod 10,
+ (bl_x * 128 + th_x) mod 20
+ )
+ domain:
+ th_x in [0, 127]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 65]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ bl_x * 128 + th_x in [0, 8399]
+ )";
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kUpdatesIndexing));
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kUpdatesIndexing));
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kUpdatesIndexing));
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kUpdatesIndexing));
+
+ constexpr auto kIndicesIndexing = R"(
+ (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] ->
+ ((bl_x * 128 + th_x) floordiv 200, 0)
+ domain:
+ th_x in [0, 127]
+ th_y in [0, 0]
+ th_z in [0, 0]
+ bl_x in [0, 65]
+ bl_y in [0, 0]
+ bl_z in [0, 0]
+ chunk_id in [0, 0]
+ unroll_id in [0, 0]
+ index_id in [0, 0]
+ bl_x * 128 + th_x in [0, 8399]
+ )";
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kIndicesIndexing));
+ EXPECT_THAT(
+ fusion
+ ->ComputeThreadIdToInputIndexing(
+ /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_)
+ ->ToString(printer_),
+ MatchIndexingString(kIndicesIndexing));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc
new file mode 100644
index 0000000..a1a7acb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc
@@ -0,0 +1,351 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/target_util.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/kernel_support_library.h"
+#include "xla/service/llvm_ir/llvm_loop.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using mlir::AffineExpr;
+using mlir::AffineMap;
+using mlir::MLIRContext;
+
+void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling,
+ int dim, absl::InlinedVector<llvm::Value*, 4> tile_idx,
+ absl::Span<llvm::Value* const> tile_dimensions,
+ llvm::IRBuilder<>* b, const TileElementGenerator& emit_elem) {
+ llvm::Type* index_ty = thread_id_info.thread_id->getType();
+ auto constant = [&](int64_t val) {
+ return llvm::ConstantInt::get(index_ty, val);
+ };
+
+ auto recurse = [&] {
+ if (dim == tile_idx.size() - 1) {
+ emit_elem(tile_idx);
+ } else {
+ EmitTileRec(thread_id_info, tiling, dim + 1, tile_idx, tile_dimensions, b,
+ emit_elem);
+ }
+ };
+
+ bool unroll = tiling.GetLoopsToUnroll()[dim];
+ KernelSupportLibrary ksl(b, unroll ? llvm_ir::UnrollMode::kFullyUnroll
+ : llvm_ir::UnrollMode::kDefaultUnroll);
+
+ if (tiling.GetBlockTileSize()[dim] == 1) {
+ tile_idx[dim] = constant(0);
+ recurse();
+ } else if (unroll) {
+ // TODO(jreiffers): Check if this unrolling does anything useful.
+ int64_t stride = tiling.GetThreadsPerBlock()[dim];
+ int64_t dim_size = tiling.GetThreadTileSize()[dim];
+
+ auto make_loop = [&](bool emit_bounds_checks) {
+ auto body = [&, emit_bounds_checks](llvm::Value* i) {
+ tile_idx[dim] = b->CreateAdd(i, thread_id_info.thread_ids[dim]);
+ if (emit_bounds_checks) {
+ auto* in_bounds =
+ b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]);
+ ksl.If("x_in_tile", in_bounds, recurse);
+ } else {
+ recurse();
+ }
+ };
+ return [&, body] {
+ ksl.For(absl::StrCat("loop", dim), constant(0),
+ constant(dim_size * stride), constant(stride), body);
+ };
+ };
+ if (stride > 1 && dim_size > 1) {
+ // Most tiles will be full, so we emit a single bounds check for those.
+ auto* is_full_tile = b->CreateICmpEQ(
+ constant(tiling.GetBlockTileSize()[dim]), tile_dimensions[dim]);
+ ksl.If("is_full_tile", is_full_tile, make_loop(false), make_loop(true));
+ } else {
+ make_loop(true)();
+ }
+ } else {
+    // All dimensions are strided (thread 0 processes elements 0, num_threads,
+    // 2 * num_threads, ...; thread 1 processes elements 1, num_threads + 1,
+    // 2 * num_threads + 1, and so on).
+ ksl.For(absl::StrCat("loop", dim), /*start=*/thread_id_info.thread_ids[dim],
+ /*end=*/tile_dimensions[dim],
+ /*step=*/tiling.GetThreadsPerBlock()[dim], [&](llvm::Value* i) {
+ tile_idx[dim] = i;
+ recurse();
+ });
+ }
+}
+
+} // namespace
+
+void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
+ const TilingThreadIdInfo& thread_id_info,
+ absl::Span<llvm::Value* const> tile_dimensions,
+ const TileElementGenerator& emit_elem_function) {
+ absl::InlinedVector<llvm::Value*, 4> tile_idx(tiling.GetShape().size());
+ EmitTileRec(thread_id_info, tiling, 0, tile_idx, tile_dimensions, builder,
+ emit_elem_function);
+}
+
+namespace {
+
+// Emits current block id.
+llvm::Value* EmitBlockId(llvm::IRBuilder<>* builder, int32_t num_blocks,
+ llvm::Type* index_ty) {
+ llvm::Value* block_id =
+ EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, builder);
+ if (num_blocks != 0) {
+ llvm_ir::AddRangeMetadata(0, num_blocks,
+ llvm::cast<llvm::Instruction>(block_id),
+ builder->GetInsertBlock()->getModule());
+ }
+ auto ret = builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true);
+ ret->setName("block.id.x");
+ return ret;
+}
+
+// Emits current thread id with the given type.
+//
+// Sets the return value range to [0, threads_per_block).
+llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block,
+ llvm::Type* index_ty) {
+ // Calculate (y, x) coordinates respectively in the 2D view of thread block,
+ // defined by (num_thread_y, num_thread_x) from thread_id.
+ llvm::CallInst* thread_id =
+ EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id,
+ builder->GetInsertBlock()->getModule());
+ auto ret = builder->CreateIntCast(thread_id, index_ty, /*isSigned=*/true);
+ ret->setName("thread.id.x");
+ return ret;
+}
+
+// Emits the LLVM values for the linear thread id, the per-dimension thread
+// ids, the block id, and the lane id.
+absl::StatusOr<TilingThreadIdInfo> EmitThreadIdInfo(llvm::IRBuilder<>* builder,
+ const Tiling& tiling,
+ llvm::Type* index_ty) {
+ auto constant = [&](uint64_t c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+ int64_t num_blocks = tiling.GetNumBlocks();
+ if (num_blocks > (int64_t)std::numeric_limits<uint32_t>::max()) {
+ return FailedPrecondition(
+ "Number of physical blocks (%d) does not fit in an i32 in tiling "
+ "scheme: %s",
+ num_blocks, tiling.ToString());
+ }
+
+ TilingThreadIdInfo info;
+ info.thread_id =
+ EmitThreadId(builder, tiling.GetNumThreadsPerBlock(), index_ty);
+ info.block_id = EmitBlockId(builder, num_blocks, index_ty);
+
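+  // For example, with threads_per_block = {4, 32} the thread strides are
+  // {32, 1}, so thread_ids[0] = thread_id / 32 and
+  // thread_ids[1] = thread_id % 32.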
+ for (auto [dim, stride] : llvm::enumerate(tiling.GetThreadStrides())) {
+ int64_t size = tiling.GetThreadsPerBlock()[dim];
+ if (size == 1) {
+ info.thread_ids.emplace_back(constant(0));
+ } else {
+ auto& dim_id = info.thread_ids.emplace_back(info.thread_id);
+ if (stride > 1) {
+ dim_id = builder->CreateUDiv(dim_id, constant(stride));
+ }
+ if (dim) {
+ dim_id = builder->CreateURem(dim_id, constant(size));
+ }
+ dim_id->setName(absl::StrCat("thread.id.", dim));
+ }
+ }
+
+ info.lane_id =
+ builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id");
+ return info;
+}
+
+AffineMap GetTilingAffineMap(llvm::ArrayRef<AffineExpr> exprs,
+ int64_t num_symbols) {
+ return AffineMap::get(
+ /*dimCount=*/6, /*symbolCount=*/num_symbols, exprs,
+ exprs[0].getContext());
+}
+
+} // namespace
+
+absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
+ llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
+ const TileGenerator& tile_element_generator) {
+ absl::Span<const int64_t> dims_in_elems = tiling.GetShape();
+ const auto& block_counts = tiling.GetBlockCounts();
+ auto constant = [&](uint64_t c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+
+ TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info,
+ EmitThreadIdInfo(builder, tiling, index_ty));
+
+ KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
+
+ const llvm_ir::IrArray::Index block_coords(
+ thread_id_info.block_id,
+ ShapeUtil::MakeShape(PRED /*arbitrary*/, block_counts), builder);
+
+ absl::InlinedVector<llvm::Value*, 4> tile_dimensions;
+ for (int i = 0; i < block_counts.size(); ++i) {
+ int64_t block_tile_size = tiling.GetBlockTileSize()[i];
+ if (dims_in_elems[i] % block_tile_size == 0) {
+ // The block tile size evenly divides the tiled shape -> no need to emit
+ // the bounds check.
+ tile_dimensions.push_back(constant(block_tile_size));
+ } else {
+ // Only the last tile in each dimension may not have full size.
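+      // For example, a dimension of 100 elements with a block tile size of 32
+      // gives 4 blocks; the last block covers only 100 - 3 * 32 = 4 elements.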
+ llvm::Value* is_last =
+ builder->CreateICmpEQ(block_coords[i], constant(block_counts[i] - 1));
+ int64_t partial_row =
+ dims_in_elems[i] - (block_counts[i] - 1) * block_tile_size;
+ tile_dimensions.push_back(builder->CreateSelect(
+ is_last, constant(partial_row), constant(block_tile_size),
+ absl::StrCat("tile_bound.", i)));
+ }
+ }
+
+ llvm_ir::IrArray::Index tile_offset = [&] {
+ std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
+ llvm::Type* index_ty = block_coords.GetType();
+ for (int i = 0; i < block_counts.size(); ++i) {
+ elem_multi_index[i] = builder->CreateMul(
+ block_coords[i],
+ llvm::ConstantInt::get(index_ty, tiling.GetBlockTileSize()[i]),
+ absl::StrCat("tile_origin.", i));
+ }
+ return llvm_ir::IrArray::Index(elem_multi_index, tiling.GetShape(),
+ index_ty);
+ }();
+
+ tile_element_generator(thread_id_info, tile_offset, tile_dimensions);
+ return {{tile_dimensions, tile_offset, thread_id_info}};
+}
+
+AffineMap GetBlockOffsetsForTiling(
+ absl::Span<const int64_t> num_blocks,
+ absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
+ MLIRContext* mlir_context) {
+ auto offsets =
+ DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), num_blocks);
+ for (auto&& [offset, tile_size] : llvm::zip(offsets, tile_sizes_per_block)) {
+ offset = offset * tile_size;
+ }
+ return GetTilingAffineMap(offsets, rank);
+}
+
+AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
+ MLIRContext* mlir_context) {
+ return GetBlockOffsetsForTiling(tiling.GetBlockCounts(),
+ tiling.GetBlockTileSize(),
+ tiling.GetShape().size(), mlir_context);
+}
+
+AffineMap GetThreadOffsetsForTiling(
+ absl::Span<const int64_t> num_threads,
+ absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
+ MLIRContext* mlir_context) {
+ auto offsets =
+ DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), num_threads);
+ for (int dim = 0; dim < rank; ++dim) {
+ if (tile_sizes_per_thread[dim] > 1) {
+ offsets[dim] = offsets[dim] +
+ getAffineSymbolExpr(dim, mlir_context) * num_threads[dim];
+ }
+ }
+ return GetTilingAffineMap(offsets, rank);
+}
+
+AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
+ MLIRContext* mlir_context) {
+ return GetThreadOffsetsForTiling(tiling.GetThreadsPerBlock(),
+ tiling.GetThreadTileSize(),
+ tiling.GetShape().size(), mlir_context);
+}
+
+IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
+ MLIRContext* mlir_context) {
+ return GetIndexingMapForTiling(
+ GetBlockOffsetsForTiling(tiling, mlir_context),
+ GetThreadOffsetsForTiling(tiling, mlir_context),
+ tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(),
+ tiling.GetThreadTileSize(), tiling.GetShape());
+}
+
+IndexingMap GetIndexingMapForTiling(AffineMap block_offsets,
+ AffineMap thread_offsets,
+ int64_t threads_per_block,
+ int64_t num_blocks,
+ absl::Span<const int64_t> thread_tile_sizes,
+ absl::Span<const int64_t> tiled_shape) {
+ auto* mlir_context = block_offsets.getContext();
+ llvm::SmallVector<AffineExpr, 4> offsets;
+ offsets.reserve(block_offsets.getNumResults());
+ for (auto [block, thread] :
+ llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) {
+ offsets.push_back(block + thread);
+ }
+ std::vector<DimVar> dimension_ranges{
+ {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {},
+ };
+ auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(),
+ block_offsets.getNumSymbols(), offsets,
+ mlir_context);
+ IndexingMap map{affine_map, dimension_ranges,
+ RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}};
+ for (int i = 0; i < tiled_shape.size(); ++i) {
+ map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1});
+ }
+ return map;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h
new file mode 100644
index 0000000..de367e3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h
@@ -0,0 +1,215 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_
+
+#include <cstdint>
+#include <functional>
+#include <string>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+// Describes tiling used by the kernel.
+//
+// Used by reduction and transpose emitters.
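+//
+// For example (illustrative numbers):
+//   Tiling({1024, 512}, /*tile_sizes=*/{1, 4}, /*num_threads=*/{1, 128})
+// yields block tile sizes {1, 512}, block counts {1024, 1}, 128 threads per
+// block and 1024 blocks in total.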
+class Tiling {
+ public:
+ Tiling(absl::Span<const int64_t> shape, absl::Span<const int64_t> tile_sizes,
+ absl::Span<const int64_t> num_threads,
+ // By default, don't unroll anything.
+ absl::InlinedVector<bool, 4> loops_to_unroll = {})
+ : shape_{shape.begin(), shape.end()},
+ tile_sizes_per_thread_{tile_sizes.begin(), tile_sizes.end()},
+ tile_sizes_per_block_(shape.size()),
+ num_threads_{num_threads.begin(), num_threads.end()},
+ num_blocks_(shape.size()),
+ loops_to_unroll_(loops_to_unroll) {
+ for (int64_t i = 0; i < shape.size(); ++i) {
+ tile_sizes_per_block_[i] = tile_sizes[i] * num_threads[i];
+ CHECK_NE(tile_sizes_per_block_[i], 0);
+ num_blocks_[i] = CeilOfRatio(shape[i], tile_sizes_per_block_[i]);
+ CHECK_NE(num_blocks_[i], 0);
+ }
+ if (loops_to_unroll_.empty()) loops_to_unroll_.resize(shape.size());
+ }
+ Tiling() = default;
+
+ std::string ToString() const {
+ return absl::StrJoin(
+ {absl::StrFormat("shape = {%s}", absl::StrJoin(shape_, ", ")),
+ absl::StrFormat("tile_sizes = {%s}",
+ absl::StrJoin(tile_sizes_per_thread_, ", ")),
+ absl::StrFormat("num_threads = {%s}",
+ absl::StrJoin(num_threads_, ", "))},
+ ", ");
+ }
+
+ // Number of elements in each dimension.
+ const absl::InlinedVector<int64_t, 4>& GetShape() const { return shape_; }
+ xla::Shape GetXlaShape(PrimitiveType element_type = F32) const {
+ return ShapeUtil::MakeShape(element_type, shape_);
+ }
+
+ const absl::InlinedVector<int64_t, 4>& GetBlockCounts() const {
+ return num_blocks_;
+ }
+
+ // Tile size for each thread.
+ //
+  // Equal to the number of loop iterations each thread performs.
+ const absl::InlinedVector<int64_t, 4>& GetThreadTileSize() const {
+ return tile_sizes_per_thread_;
+ }
+
+ // Tile size for an entire thread block.
+ const absl::InlinedVector<int64_t, 4>& GetBlockTileSize() const {
+ return tile_sizes_per_block_;
+ }
+
+ const absl::InlinedVector<int64_t, 4>& GetThreadsPerBlock() const {
+ return num_threads_;
+ }
+
+ // Returns the strides of the thread index dimensions wrt. the linear thread
+ // id.
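+  // For example, with num_threads = {4, 32}, the returned strides are {32, 1}.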
+ absl::InlinedVector<int64_t, 4> GetThreadStrides() const {
+ return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_threads_));
+ }
+
+ int64_t GetNumThreadsPerBlock() const { return Product(num_threads_); }
+
+ int64_t GetNumBlocks() const { return Product(num_blocks_); }
+
+ const absl::InlinedVector<bool, 4>& GetLoopsToUnroll() const {
+ return loops_to_unroll_;
+ }
+
+ private:
+ // The number of elements in each dimension.
+ absl::InlinedVector<int64_t, 4> shape_;
+
+ // The number of elements for each dimension of a tile.
+ absl::InlinedVector<int64_t, 4> tile_sizes_per_thread_;
+ absl::InlinedVector<int64_t, 4> tile_sizes_per_block_;
+
+ absl::InlinedVector<int64_t, 4> num_threads_;
+ absl::InlinedVector<int64_t, 4> num_blocks_;
+
+ absl::InlinedVector<bool, 4> loops_to_unroll_;
+};
+
+struct TilingThreadIdInfo {
+ llvm::Value* thread_id;
+
+ absl::InlinedVector<llvm::Value*, 4> thread_ids;
+
+ // Lane id: `thread_id % WarpSize`
+ llvm::Value* lane_id;
+
+ // Block id.
+ llvm::Value* block_id;
+};
+
+struct TilingKernelInfo {
+ // Tiling bounds.
+ absl::InlinedVector<llvm::Value*, 4> output_tile_bounds;
+
+ // Starting tile, as calculated from block id only.
+ llvm_ir::IrArray::Index tile_origin;
+
+ // Thread meta-info.
+ TilingThreadIdInfo thread_id_info;
+};
+
+// A function to generate the code to emit the entire tile.
+//
+// tile_start_index: Absolute coordinates of the start of the tile in the
+//   input.
+// tile_dimensions: Size of the tile.
+using TileGenerator =
+ std::function<void(const TilingThreadIdInfo& thread_id_info,
+ const llvm_ir::IrArray::Index& tile_start_index,
+ absl::Span<llvm::Value* const> tile_dimensions)>;
+
+// A function object to generate code to process one element in a tile.
+//
+// index_in_tile: the current coordinates within the tile. To get the global
+// coordinates, use `tile_start_index.AddOffset(index_in_tile, ...)`.
+using TileElementGenerator =
+ std::function<void(absl::Span<llvm::Value* const> index_in_tile)>;
+
+// Emits code to iterate through a tile with given tile dimensions and generate
+// elements using the callback.
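+//
+// Illustrative usage (assuming `builder`, `tiling`, `thread_id_info` and
+// `tile_dims` are already set up):
+//
+//   EmitTile(builder, tiling, thread_id_info, tile_dims,
+//            [&](absl::Span<llvm::Value* const> index_in_tile) {
+//              // Emit the IR for one element of the tile here.
+//            });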
+void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
+ const TilingThreadIdInfo& thread_id_info,
+ absl::Span<llvm::Value* const> tile_dimensions,
+ const TileElementGenerator& emit_elem_function);
+
+// Emits a kernel for the hlo instruction using the given kernel mapping
+// scheme.
+absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
+ llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
+ const TileGenerator& tile_element_generator);
+
+// Creates an indexing map from thread and block IDs to elements of the tiled
+// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2
+// are thread indices (currently only 0 is used), dimensions 3 to 5 are block
+// indices (currently only 3 is used).
+mlir::AffineMap GetBlockOffsetsForTiling(
+ absl::Span<const int64_t> num_blocks,
+ absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
+ mlir::MLIRContext* mlir_context);
+mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
+ mlir::MLIRContext* mlir_context);
+mlir::AffineMap GetThreadOffsetsForTiling(
+ absl::Span<const int64_t> num_threads,
+ absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
+ mlir::MLIRContext* mlir_context);
+mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
+ mlir::MLIRContext* mlir_context);
+
+// Convenience functions combining the two functions above
+// (`GetBlockOffsetsForTiling` + `GetThreadOffsetsForTiling`). Also sets up
+// the ranges of dimensions and symbols.
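+//
+// For a 1D tiling (illustrative), the combined map is roughly
+//   (th_x, _, _, bl_x, _, _)[s0] ->
+//       (bl_x * block_tile_size + th_x + s0 * num_threads)
+// with th_x in [0, threads_per_block - 1], bl_x in [0, num_blocks - 1] and
+// s0 in [0, thread_tile_size - 1].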
+IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
+ mlir::MLIRContext* mlir_context);
+IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets,
+ mlir::AffineMap thread_offsets,
+ int64_t threads_per_block,
+ int64_t num_blocks,
+ absl::Span<const int64_t> thread_tile_sizes,
+ absl::Span<const int64_t> tiled_shape);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc
new file mode 100644
index 0000000..d6cbdec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc
@@ -0,0 +1,366 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/transpose.h"
+
+#include <array>
+#include <cstdint>
+#include <optional>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "mlir/IR/AffineMap.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/target_util.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/service/llvm_ir/loop_emitter.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info,
+ const TransposeDescription& tiled_transpose) {
+ constexpr int kNumRows = 4;
+ static_assert(WarpSize() % kNumRows == 0);
+
+ // 3D view over the output shape.
+ absl::InlinedVector<int64_t, 3> transposed_dims = tiled_transpose.dimensions;
+ absl::InlinedVector<int64_t, 3> permutation = tiled_transpose.permutation;
+
+ // Note: the supported permutations are their own inverses. Therefore we
+ // always use the permutation, even when we want the inverse.
+ CHECK((permutation == absl::InlinedVector<int64_t, 3>{0, 2, 1}) ||
+ (permutation == absl::InlinedVector<int64_t, 3>{2, 1, 0}));
+
+ absl::InlinedVector<int64_t, 4> input_dims{transposed_dims[permutation[0]],
+ transposed_dims[permutation[1]],
+ transposed_dims[permutation[2]]};
+
+ // We tile along the minor dimensions pre- and post-transpose.
+ absl::InlinedVector<int64_t, 4> tile_sizes{1, 1, 1};
+ tile_sizes[permutation[2]] = WarpSize() / kNumRows;
+ absl::InlinedVector<int64_t, 4> num_threads{1, 1, WarpSize()};
+ num_threads[permutation[2]] = kNumRows;
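+  // For example, with permutation {0, 2, 1}, WarpSize() == 32 and kNumRows ==
+  // 4, this yields tile_sizes = {1, 8, 1} and num_threads = {1, 4, 32}.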
+
+ auto capability = gpu_device_info.gpu_compute_capability();
+ std::visit(
+ [&](const auto& capability) {
+ if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
+ stream_executor::RocmComputeCapability>) {
+ // kNumRows = 8 works well on MI300 with wavefront size 64.
+ if (capability.gfx9_mi300()) {
+ tile_sizes[permutation[2]] = gpu_device_info.threads_per_warp() / 8;
+ num_threads[permutation[2]] = 8;
+ }
+ }
+ },
+ capability);
+
+ return Tiling(input_dims, tile_sizes, num_threads);
+}
+
+void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
+ IrEmitterContext& ir_emitter_context) {
+ auto* module = builder->GetInsertBlock()->getModule();
+ if (IsAMDGPU(module) &&
+ ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
+ builder->CreateFence(
+ llvm::AtomicOrdering::SequentiallyConsistent,
+ builder->getContext().getOrInsertSyncScopeID("workgroup"));
+ }
+}
+
+void EmitSyncThreads(llvm::IRBuilder<>* builder,
+ IrEmitterContext& ir_emitter_context) {
+ MaybeEmitFenceForAMDGPU(builder, ir_emitter_context);
+ EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder);
+}
+
+llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index,
+ absl::Span<const int64_t> permutation) {
+ return llvm_ir::IrArray::Index{Permute(index.multidim(), permutation),
+ Permute(index.dims(), permutation),
+ index.GetType()};
+}
+
+} // namespace
+
+TransposeFusion::TransposeFusion(const se::DeviceDescription& gpu_device_info,
+ const HloFusionAnalysis& analysis)
+ : analysis_(analysis),
+ tiling_(
+ ComputeTransposeTiling(gpu_device_info, analysis.tiled_transpose())) {
+ for (auto [root, hero] :
+ llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) {
+ if (auto transpose = GetDescriptionForTiledTransposeEmitter(
+ root.instruction(), hero.instruction())) {
+ permutation_ = transpose->permutation;
+ break;
+ }
+ }
+}
+
+absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const {
+ const auto& hlo_roots = analysis_.fusion_roots();
+ GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+ FusedIrEmitter fused_emitter(elemental_emitter);
+ for (auto [i, input] : llvm::enumerate(inputs)) {
+ HloInstruction* fused_operand = fusion.fused_parameter(i);
+ fused_emitter.BindGenerator(
+ *fused_operand, [input = input, builder,
+ fused_operand](const llvm_ir::IrArray::Index& index) {
+ return input.EmitReadArrayElement(index, builder,
+ fused_operand->name());
+ });
+ }
+
+ absl::flat_hash_map<const HloInstruction*,
+ std::vector<std::pair<int64_t, const HloInstruction*>>>
+ transposes_to_roots;
+ // Keep a list of deduplicated transpose heroes separate from
+ // transposes_to_roots to make the CodeGen deterministic.
+ std::vector<TransposeDescription> transposes;
+ transposes.reserve(hlo_roots.size());
+ std::vector<std::pair<int64_t, const HloInstruction*>> extra_outputs;
+
+ for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) {
+ const auto& hero = analysis_.fusion_hero(output_idx).instruction();
+ auto transpose_descr =
+ GetDescriptionForTiledTransposeEmitter(root.instruction(), hero);
+ if (transpose_descr.has_value()) {
+ auto iterator_inserted = transposes_to_roots.insert(std::make_pair(
+ &hero, std::vector<std::pair<int64_t, const HloInstruction*>>{
+ {output_idx, &root.instruction()}}));
+ if (iterator_inserted.second) {
+ transposes.push_back(*transpose_descr);
+ } else {
+ iterator_inserted.first->second.push_back(
+ {output_idx, &root.instruction()});
+ }
+ } else {
+ extra_outputs.push_back({output_idx, &root.instruction()});
+ }
+ }
+
+ absl::flat_hash_map<const HloInstruction*, llvm_ir::SharedMemoryTile> tiles;
+ absl::InlinedVector<int64_t, 3> permutation;
+ for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) {
+ permutation = tr.permutation;
+ auto tile_size = tiling_.GetBlockTileSize();
+ ++tile_size.back(); // Prevent bank conflicts.
+ auto* module = ir_emitter_context.llvm_module();
+ tiles[tr.instr] = llvm_ir::AllocateSharedMemoryTile(
+ module,
+ llvm_ir::PrimitiveTypeToIrType(tr.instr->shape().element_type(),
+ module),
+ tile_size, absl::StrCat("tr_tile_", tile_idx));
+ }
+
+ auto tile_generator = [&](const TilingThreadIdInfo& thread_id_info,
+ const llvm_ir::IrArray::Index& tile_start_index,
+ absl::Span<llvm::Value* const> tile_dimensions) {
+ // Copy input parameter values to shared memory buffers:
+ // tile[thread_id_y, thread_id_x] = input[index]
+ EmitTile(builder, tiling_, thread_id_info, tile_dimensions,
+ [&](absl::Span<llvm::Value* const> index_in_tile) {
+ auto index = tile_start_index.AddOffset(index_in_tile, builder);
+ for (const auto& tr : transposes) {
+ auto input_gen =
+ *fused_emitter.GetGenerator(*tr.instr->operand(0));
+ auto input_index = index.SourceIndexOfBitcast(
+ tr.instr->operand(0)->shape(), builder);
+ llvm::Value* value = *input_gen(input_index);
+ tiles[tr.instr].Store(value, index_in_tile, builder);
+ }
+
+              // Compute all extra output values before writing them. This
+              // avoids overwriting aliased input/output values before all
+              // reads have occurred.
+ std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
+ llvm::Value*>>
+ scheduled_writes;
+ for (const auto& [output_idx, root] : extra_outputs) {
+ auto extra_output_index =
+ index.SourceIndexOfBitcast(root->shape(), builder);
+ auto output_gen = *fused_emitter.GetGenerator(*root);
+ llvm::Value* output_value = *output_gen(extra_output_index);
+ scheduled_writes.emplace_back(
+ outputs[output_idx], extra_output_index, output_value);
+ }
+
+ for (const auto& [output, idx, value] : scheduled_writes) {
+ output.EmitWriteArrayElement(idx, value, builder);
+ }
+ });
+
+ EmitSyncThreads(builder, ir_emitter_context);
+
+ auto output_tile_index = PermuteIndex(tile_start_index, permutation);
+ auto transposed_tile_dimensions = Permute(tile_dimensions, permutation);
+
+ EmitTile(
+ builder, tiling_, thread_id_info, transposed_tile_dimensions,
+ /*emit_elem_function=*/
+ [&](absl::Span<llvm::Value* const> index_in_tile) {
+ auto index = output_tile_index.AddOffset(index_in_tile, builder);
+ for (const auto& tr : transposes) {
+ llvm::Value* loaded = tiles[tr.instr].Load(
+ Permute(index_in_tile, permutation), builder);
+
+ FusedIrEmitter fused_emitter(elemental_emitter);
+ fused_emitter.BindGenerator(
+ *tr.instr,
+ [&](const llvm_ir::IrArray::Index&) { return loaded; });
+ for (int64_t i = 0;
+ i < fusion.fused_instructions_computation()->num_parameters();
+ ++i) {
+ llvm_ir::IrArray ir_array = inputs[i];
+ HloInstruction* fused_operand = fusion.fused_parameter(i);
+ fused_emitter.BindGenerator(
+ *fused_operand, [=](const llvm_ir::IrArray::Index& index) {
+ return ir_array.EmitReadArrayElement(index, builder,
+ fused_operand->name());
+ });
+ }
+
+            // Emit the code that follows the real hero.
+            // Compute all output values before writing them. This avoids
+            // overwriting aliased input/output values before all reads
+            // have occurred.
+ std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
+ llvm::Value*>>
+ scheduled_writes;
+ for (const auto& [output_idx, root] :
+ transposes_to_roots[tr.instr]) {
+ TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen,
+ fused_emitter.GetGenerator(*root));
+
+              // Both the emission and the write should use the index as
+              // transformed by the computation.
+ auto untiled_index =
+ index.SourceIndexOfBitcast(root->shape(), builder);
+ TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index));
+ scheduled_writes.emplace_back(outputs[output_idx], untiled_index,
+ generated);
+ }
+ for (const auto& [output, idx, value] : scheduled_writes) {
+ output.EmitWriteArrayElement(idx, value, builder);
+ }
+ }
+ return absl::OkStatus();
+ });
+ };
+
+ llvm::Type* index_type =
+ GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
+ return EmitTilingKernel(builder, tiling_, index_type, tile_generator)
+ .status();
+}
+
+LaunchDimensions TransposeFusion::launch_dimensions() const {
+ return LaunchDimensions(tiling_.GetNumBlocks(),
+ tiling_.GetNumThreadsPerBlock());
+}
+
+std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const {
+ const auto& hero = analysis_.fusion_hero(root_index);
+ if (hero.opcode() != HloOpcode::kTranspose) {
+    // The shapes of non-transpose roots are bitcast-compatible with the input
+    // shape of the transpose heroes.
+ auto map = ComposeIndexingMaps(
+ GetIndexingMapForTiling(tiling_, ctx),
+ GetBitcastMap(tiling_.GetXlaShape(),
+ analysis_.fusion_root(root_index).shape(), ctx));
+ map.Simplify();
+ return map;
+ }
+
+ // The block offsets are permuted, but the thread offsets remain the same.
+ auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx)
+ .getSubMap(std::vector<unsigned>{permutation_.begin(),
+ permutation_.end()});
+ auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx);
+ auto permuted_tiled_shape =
+ ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_));
+
+ auto map = ComposeIndexingMaps(
+ GetIndexingMapForTiling(
+ block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(),
+ tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(),
+ permuted_tiled_shape.dimensions()),
+ GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx));
+ map.Simplify();
+ return map;
+}
+
+std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const {
+ const auto& hero = analysis_.fusion_hero(root_index).instruction();
+ if (hero.opcode() != HloOpcode::kTranspose) {
+ auto map = ComposeIndexingMaps(
+ *ComputeThreadIdToOutputIndexing(root_index, ctx),
+ *ComputeOutputToInputIndexing(
+ &analysis_.fusion_root(root_index).instruction(), 0, ctx)
+ .indexing_maps[hero_operand_index]
+ .begin());
+ map.Simplify();
+ return map;
+ }
+
+ auto map = ComposeIndexingMaps(
+ GetIndexingMapForTiling(tiling_, ctx),
+ GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx));
+ map.Simplify();
+ return map;
+}
+
+} // namespace gpu
+} // namespace xla
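Both `scheduled_writes` blocks in EmitKernel above follow the same discipline: compute every value first, then flush the writes, so aliased input/output buffers are never clobbered before all reads have happened. A minimal standalone C++ sketch of that read-then-write pattern (illustrative only; `ReverseAliased` is not part of this change):

  #include <cstddef>
  #include <utility>
  #include <vector>

  // Reverses `buf` into itself. A naive loop writing buf[i] = buf[n-1-i] as it
  // goes would clobber elements it still needs to read; staging the writes,
  // like the emitter's `scheduled_writes`, keeps the aliased case correct.
  void ReverseAliased(std::vector<int>& buf) {
    std::vector<std::pair<std::size_t, int>> scheduled_writes;
    scheduled_writes.reserve(buf.size());
    for (std::size_t i = 0; i < buf.size(); ++i) {
      scheduled_writes.emplace_back(i, buf[buf.size() - 1 - i]);  // read phase
    }
    for (const auto& [idx, value] : scheduled_writes) {
      buf[idx] = value;  // write phase
    }
  }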
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h
new file mode 100644
index 0000000..3366130
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h
@@ -0,0 +1,91 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
+// algorithm to improve the memory access patterns for the input parameters
+// with a shape that is a 0-2-1 transpose of the output tensor shape. The
+// caller is responsible for making sure that it is safe to apply the shared
+// memory transpose on the input parameters.
+//
+// For the purpose of tiling, the output tensors have a logical shape of three
+// components 0-2-1 while the relevant input parameters have a logical shape
+// of three components 0-1-2 in the order major to minor. The x- and y-
+// dimensions of the tensors are tiled in square tiles with an edge length
+// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
+// transposes one tile: each thread copies kTileSize/kNumRows elements from
+// the input to a shared memory tile, then the otherwise "regular HLO kernel"
+// reads from the shared memory instead of the original input.
+//
+// This is similar to the following CUDA algorithm in TensorFlow:
+// https://goo.gl/MStRV6.
+//
+// `kTileSize` should usually be the same as the warp size. We currently
+// choose 32 for `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8
+// for `kNumRows`.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more
+// efficient to launch fewer blocks so each transposes many tiles.
+class TransposeFusion : public KernelFusionEmitterBase {
+ public:
+ explicit TransposeFusion(const se::DeviceDescription& gpu_device_info,
+ const HloFusionAnalysis& analysis);
+ LaunchDimensions launch_dimensions() const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+ int64_t root_index, mlir::MLIRContext* ctx) const override;
+
+ std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+ int64_t root_index, int64_t hero_operand_index,
+ mlir::MLIRContext* ctx) const override;
+
+ protected:
+ absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+ const HloFusionInstruction& fusion,
+ const LaunchDimensions& launch_dims,
+ std::vector<llvm_ir::IrArray> inputs,
+ std::vector<llvm_ir::IrArray> outputs,
+ llvm::IRBuilder<>* builder) const override;
+
+ private:
+ const HloFusionAnalysis& analysis_;
+ Tiling tiling_;
+ absl::InlinedVector<int64_t, 3> permutation_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_
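The algorithm comment above can be made concrete with a host-side analogue. This is a sketch only (dense row-major buffers assumed; `Transpose021` and its names are illustrative, not the emitter's code): it moves each tile of the two minor dimensions through a padded buffer the way the emitted kernel moves it through shared memory, with the extra column mirroring the `++tile_size.back()` padding in transpose.cc and the spot between the two inner loop nests corresponding to the EmitSyncThreads barrier.

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  constexpr int64_t kTileSize = 32;

  // out[b][x][y] = in[b][y][x]; both buffers dense row-major, `out` pre-sized
  // to d0 * d1 * d2 elements.
  void Transpose021(const std::vector<float>& in, std::vector<float>& out,
                    int64_t d0, int64_t d1, int64_t d2) {
    float tile[kTileSize][kTileSize + 1];  // +1 column avoids bank conflicts.
    for (int64_t b = 0; b < d0; ++b) {
      for (int64_t y0 = 0; y0 < d1; y0 += kTileSize) {
        for (int64_t x0 = 0; x0 < d2; x0 += kTileSize) {
          int64_t y_end = std::min(y0 + kTileSize, d1);
          int64_t x_end = std::min(x0 + kTileSize, d2);
          // Phase 1: read the input tile along its minor dimension.
          for (int64_t y = y0; y < y_end; ++y) {
            for (int64_t x = x0; x < x_end; ++x) {
              tile[y - y0][x - x0] = in[(b * d1 + y) * d2 + x];
            }
          }
          // (A workgroup barrier goes here on the GPU; see EmitSyncThreads.)
          // Phase 2: write the output tile along its minor dimension.
          for (int64_t x = x0; x < x_end; ++x) {
            for (int64_t y = y0; y < y_end; ++y) {
              out[(b * d2 + x) * d1 + y] = tile[y - y0][x - x0];
            }
          }
        }
      }
    }
  }

Reading phase 1 along the input's minor dimension and writing phase 2 along the output's minor dimension is what keeps both global-memory accesses coalesced on the GPU.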
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc
new file mode 100644
index 0000000..f33ae04
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc
@@ -0,0 +1,346 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/transpose.h"
+
+#include <memory>
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/statusor.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class TransposeTest : public HloTestBase {
+ protected:
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.set_xla_gpu_mlir_emitter_level(0);
+ return opts;
+ }
+ stream_executor::DeviceDescription device_info_ =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+};
+
+absl::StatusOr<std::unique_ptr<TransposeFusion>> GetTransposeFusion(
+ const HloFusionAnalysis& analysis) {
+ auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
+ auto fusion = dynamic_cast<TransposeFusion*>(emitter.get());
+ TF_RET_CHECK(fusion != nullptr);
+
+ emitter.release();
+ return std::unique_ptr<TransposeFusion>{fusion};
+}
+
+TEST_F(TransposeTest, ThreadIndexing021) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fusion {
+ %input = f32[100,32,64] parameter(0)
+ ROOT transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
+ }
+
+ ENTRY entry {
+ %input = f32[100,32,64] parameter(0)
+ ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+ mlir::MLIRContext mlir_context;
+
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ d3 floordiv 2,
+ d0 floordiv 32 + s1 * 4,
+ (d3 mod 2) * 32 + d0 mod 32
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 199]
+ d4 in [0, 0]
+ d5 in [0, 0]
+
+ s0 in [0, 0]
+ s1 in [0, 7]
+ s2 in [0, 0]
+ )"));
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ d3 floordiv 2,
+ (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
+ d0 mod 32
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 199]
+ d4 in [0, 0]
+ d5 in [0, 0]
+
+ s0 in [0, 0]
+ s1 in [0, 7]
+ s2 in [0, 0]
+ )"));
+}
+
+TEST_F(TransposeTest, ThreadIndexing201) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fusion {
+ %input = f32[100,64,32] parameter(0)
+ ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1}
+ }
+
+ ENTRY entry {
+ %input = f32[100,64,32] parameter(0)
+ ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+ mlir::MLIRContext mlir_context;
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ d3 floordiv 2,
+ (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
+ d0 mod 32
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 199]
+ d4 in [0, 0]
+ d5 in [0, 0]
+
+ s0 in [0, 0]
+ s1 in [0, 7]
+ s2 in [0, 0]
+ )"));
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ d0 floordiv 32 + s1 * 4,
+ d3 floordiv 2,
+ (d3 mod 2) * 32 + d0 mod 32
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 199]
+ d4 in [0, 0]
+ d5 in [0, 0]
+
+ s0 in [0, 0]
+ s1 in [0, 7]
+ s2 in [0, 0]
+ )"));
+}
+
+TEST_F(TransposeTest, ThreadIndexingPartialBlock) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ fused_computation {
+ %p0 = f64[24,2,6,4] parameter(0)
+ ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0}
+ }
+
+ ENTRY main {
+ %p0 = f64[24,2,6,4] parameter(0)
+ ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput,
+ calls=%fused_computation
+ }
+ )")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+ mlir::MLIRContext mlir_context;
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ d0 floordiv 32 + s0 * 4,
+ d3,
+ (d0 floordiv 4) mod 8,
+ d0 mod 4
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 1]
+ d4 in [0, 0]
+ d5 in [0, 0]
+ s0 in [0, 5]
+ s1 in [0, 0]
+ s2 in [0, 0]
+ d0 mod 32 in [0, 23]
+ )"));
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ s0,
+ d0 floordiv 32,
+ d3,
+ d0 mod 32
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 1]
+ d4 in [0, 0]
+ d5 in [0, 0]
+ s0 in [0, 5]
+ s1 in [0, 0]
+ s2 in [0, 0]
+ d0 mod 32 in [0, 23]
+ )"));
+}
+
+TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fusion {
+ %input = f32[100,32,64] parameter(0)
+ %transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
+ %bitcast = f32[100,2048] bitcast(%input)
+ ROOT %tuple = (f32[100,64,32], f32[100,2048]) tuple(%transpose, %bitcast)
+ }
+
+ ENTRY entry {
+ %input = f32[100,32,64] parameter(0)
+ ROOT %fusion = (f32[100,64,32], f32[100,2048]) fusion(%input), kind=kInput, calls=fusion
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+ mlir::MLIRContext mlir_context;
+
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+ fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString());
+}
+
+TEST_F(TransposeTest, ThreadIndexingSideOutput) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fusion {
+ %input0 = f32[100,32,64] parameter(0)
+ %input1 = f32[100,32] parameter(1)
+ %transpose = f32[100,64,32] transpose(%input0), dimensions={0,2,1}
+ %broadcast = f32[100,32,64] broadcast(%input1), dimensions={0,1}
+ ROOT %tuple = (f32[100,64,32], f32[100,32,64]) tuple(%transpose, %broadcast)
+ }
+
+ ENTRY entry {
+ %input0 = f32[100,32,64] parameter(0)
+ %input1 = f32[100,32] parameter(1)
+ ROOT %fusion = (f32[100,64,32], f32[100,32,64]) fusion(%input0, %input1), kind=kInput, calls=fusion
+ })")
+ .value();
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+ mlir::MLIRContext mlir_context;
+  // Check that the side output `%broadcast` gets the correct input indexing,
+  // which should correspond to `%input1` with shape [100,32].
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ d3 floordiv 2,
+ d0 floordiv 32 + s1 * 4
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 199]
+ d4 in [0, 0]
+ d5 in [0, 0]
+
+ s0 in [0, 0]
+ s1 in [0, 7]
+ s2 in [0, 0]
+ )"));
+ EXPECT_THAT(
+ fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(),
+ MatchIndexingString(R"(
+ (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+ d3 floordiv 2,
+ d0 floordiv 32 + s1 * 4,
+ (d3 mod 2) * 32 + d0 mod 32
+ )
+ domain:
+ d0 in [0, 127]
+ d1 in [0, 0]
+ d2 in [0, 0]
+ d3 in [0, 199]
+ d4 in [0, 0]
+ d5 in [0, 0]
+
+ s0 in [0, 0]
+ s1 in [0, 7]
+ s2 in [0, 0]
+ )"));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
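To make the indexing strings above a bit more concrete, here is a small self-contained check (illustrative only, not part of the test suite) that evaluates the ThreadIndexing021 input map at one sample point: thread d0 = 37 in block d3 = 5 with symbol s1 = 2 reads input element (2, 9, 37) of the f32[100,32,64] parameter, well inside its bounds. The `d0 mod 32 in [0, 23]` constraint in ThreadIndexingPartialBlock appears to be the analogous masking: threads whose lane index exceeds 23 fall outside the 24-wide tile edge and are excluded.

  #include <cassert>
  #include <cstdint>

  int main() {
    // Input map from ThreadIndexing021 (floordiv/mod on non-negative values,
    // so C++ '/' and '%' agree with the affine semantics):
    //   (d0, ..., d3, ...)[..., s1, ...] ->
    //     (d3 floordiv 2, d0 floordiv 32 + s1 * 4, (d3 mod 2) * 32 + d0 mod 32)
    const int64_t d0 = 37, d3 = 5, s1 = 2;
    const int64_t i0 = d3 / 2;                   // 2
    const int64_t i1 = d0 / 32 + s1 * 4;         // 1 + 8 = 9
    const int64_t i2 = (d3 % 2) * 32 + d0 % 32;  // 32 + 5 = 37
    assert(i0 == 2 && i1 == 9 && i2 == 37);      // Within f32[100,32,64].
    return 0;
  }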
diff --git a/third_party/xla/xla/service/gpu/fusions/loop.cc b/third_party/xla/xla/service/gpu/fusions/loop.cc
deleted file mode 100644
index 522dc1d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/loop.cc
+++ /dev/null
@@ -1,295 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/loop.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <optional>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/numeric/bits.h"
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/macros.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-const Shape& GetElementShape(const HloFusionAnalysis& analysis) {
- const Shape* shape = &analysis.fusion_root(0).shape();
- while (shape->IsTuple()) {
- shape = &shape->tuple_shapes(0);
- }
- return *shape;
-}
-
-// Computes the maximum valid unroll factor for a given instruction.
-int ComputeMaxUnrollFactor(int64_t num_elements) {
- constexpr int kMaxUnrollFactor = 4;
- for (int i = kMaxUnrollFactor; i > 1; i /= 2) {
- if (num_elements % i == 0) {
- return i;
- }
- }
- return 1;
-}
-
-// Determines if we enable the row optimized codegen. When we have a fusion with
-// only pointwise operations, scalar broadcasting and row broadcasting, we can
-// trigger a kernel that vectorizes the row loads. This speeds up the kernel, in
-// particular on A100. The int is the number of inputs with rank `out_rank`. Its
-// value is only defined if row vectorization is enabled.
-std::pair<bool /*enabled*/, int> RowVectorizationEnabled(
- const HloFusionAdaptor& fusion, int64_t out_rank) {
- auto roots = fusion.GetRoots();
- const auto is_row_major = [](const HloInstruction* instr) {
-    // Only tested when the inputs are row-major, so only enable that case.
-    // It might also work when only the inner dimension is contiguous.
- return LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout());
- };
- bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() &&
- is_row_major(&roots[0].instruction());
- if (!row_vectorized) {
- return {false, 0};
- }
-
- // Check that the operations in the fusion are supported. Each
- // supported operation (or category) must be manually vetted as XLA
- // only unrolls and relies on LLVM to vectorize. But this is brittle.
- // Currently tested and supported operations:
- // Elementwise, scalar and row broadcasting.
- //
- // We also detect at the same time if there is a row broadcasting
- // operation.
- int num_big_inputs = 0;
- bool some_row_broadcasting = false;
- HloBfsConsumersFirstTraversal(
- roots, fusion,
- [&](auto node) -> TraversalResult {
- if (!row_vectorized) {
- return TraversalResult::kInterrupt;
- }
-
- if (node.instruction().IsElementwise()) {
- return TraversalResult::kAdvance;
- }
-
- switch (node.opcode()) {
- case HloOpcode::kConstant:
- return TraversalResult::kSkip;
- case HloOpcode::kParameter:
- return TraversalResult::kAdvance;
- case HloOpcode::kBroadcast: {
- auto dims = node.instruction().dimensions();
- if (dims.empty()) {
- return TraversalResult::kAdvance;
- }
-
- if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) {
- some_row_broadcasting = true;
- return TraversalResult::kAdvance;
- }
- TF_FALLTHROUGH_INTENDED;
- }
- default:
- VLOG(2) << "Row vectorization not enabled due to: "
- << node.ToString();
- row_vectorized = false;
- return TraversalResult::kInterrupt;
- }
- });
- if (row_vectorized) {
- for (const HloInstruction* argument : fusion.GetParameters()) {
- if (argument->shape().rank() == out_rank) {
- ++num_big_inputs;
- }
- if (!is_row_major(argument)) {
- row_vectorized = false;
- }
- };
- }
- // Trigger only when there is a row broadcasting.
- return std::make_pair(row_vectorized && some_row_broadcasting,
- num_big_inputs);
-}
-
-} // namespace
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
- const HloFusionAnalysis& analysis) {
- return ComputeLoopFusionConfig(analysis, GetElementShape(analysis));
-}
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
- const HloFusionAnalysis& analysis, const Shape& element_shape) {
- int unroll_factor = 1;
- // Unrolling is good to read large inputs with small elements
- // due to vector loads, but increases the register pressure when one
- // thread has to produce multiple output elements.
- // Therefore for fusions with small outputs prefer to use one thread
- // per output element = no unroll.
-  // We call a fusion 'small' if it uses fewer threads than the GPU has.
- int64_t num_elements = ShapeUtil::ElementsIn(element_shape);
- int64_t n_threads_max = analysis.device_info().threads_per_core_limit() *
- analysis.device_info().core_count();
- if (num_elements >= n_threads_max &&
- !MayPreventVectorization(analysis.fusion())) {
- unroll_factor = ComputeMaxUnrollFactor(num_elements);
- }
- // CHECK that unroll_factor is a power-of-2, as needed by the logic below.
- CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
- // Ensure a single thread writes to a byte containing multiple values by
- // setting unroll_factor to an appropriate number. Setting unroll_factor is
- // safe even if the new unroll_factor doesn't divide the number of elements,
- // as the parallel loop emitter will insert a bounds check in this case to
- // ensure the out-of-bounds element is not computed and written. Setting
- // unroll_factor is safe even if MayPreventVectorization returns false, as
- // the MayPreventVectorization check is an optimization, not a correctness
- // requirement.
- unroll_factor = std::max(
- unroll_factor,
- CeilOfRatio(8, analysis.input_output_info().smallest_output_dtype_bits));
- CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
- VLOG(2) << "Unroll factor: " << unroll_factor;
-
- bool row_vectorized;
- int num_big_inputs;
- std::tie(row_vectorized, num_big_inputs) =
- RowVectorizationEnabled(analysis.fusion(), element_shape.rank());
- bool few_waves = !HloAnyOf(analysis.fusion(), [&](auto instr) {
- if (instr.opcode() == HloOpcode::kParameter ||
- instr.opcode() == HloOpcode::kConstant ||
- HloInstruction::IsOpElementwise(instr.opcode())) {
- return false;
- }
- if (auto broadcast =
- DynCast<HloBroadcastInstruction>(&instr.instruction())) {
- if (broadcast->dimensions().empty() ||
- // More than 3 big inputs cause a speed regression.
- (row_vectorized && num_big_inputs <= 3)) {
- return false;
- }
- }
- VLOG(2) << "few_waves not enabled due to: "
- << instr.instruction().ToString();
- return true;
- });
-
- LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
- row_vectorized};
-  // Check that the shape is supported.
- if (launch_config.row_vectorized &&
- ThreadsPerBlockRowVectorized(element_shape, analysis.device_info(),
- launch_config) <= 0) {
- VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
- launch_config.row_vectorized = false;
- launch_config.few_waves = false;
- }
- return launch_config;
-}
-
-LoopFusion::LoopFusion(const HloFusionAnalysis& analysis)
- : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {}
-
-std::optional<IndexingMap> LoopFusion::ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const {
- auto launch_dims = launch_dimensions();
- return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor,
- GetElementShape(analysis_), ctx);
-}
-
-std::optional<IndexingMap> LoopFusion::ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const {
- std::optional<IndexingMap> thread_id_to_output_indexing =
- ComputeThreadIdToOutputIndexing(root_index, ctx);
- if (!thread_id_to_output_indexing.has_value()) {
- return std::nullopt;
- }
- const HloInstruction* fusion_root =
- &analysis_.fusion_root(root_index).instruction();
- auto output_to_input_indexing =
- ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx);
- IndexingMapSet output_to_input_indexing_set =
- output_to_input_indexing.indexing_maps[hero_operand_index];
- // Since we are computing the indexing for a non-fusion op, there is only one
- // indexing map per operand.
- CHECK_EQ(output_to_input_indexing_set.size(), 1);
- IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps(
- *thread_id_to_output_indexing, *output_to_input_indexing_set.begin());
- thread_id_to_input_indexing_map.Simplify();
- return thread_id_to_input_indexing_map;
-}
-
-absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const {
- GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
- FusedIrEmitter fused_emitter(elemental_emitter);
- for (int i = 0; i < fusion.fused_parameters().size(); i++) {
- fused_emitter.BindGenerator(
- *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
- return inputs[i].EmitReadArrayElement(index, builder);
- });
- }
- TF_ASSIGN_OR_RETURN(
- auto element_generator,
- fused_emitter.GetGenerator(*fusion.fused_expression_root()));
-
- llvm::Type* index_type =
- GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
-
- return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder,
- config_)
- .EmitLoop(fusion.name(), index_type);
-}
-
-LaunchDimensions LoopFusion::launch_dimensions() const {
- return CalculateLaunchDimensions(GetElementShape(analysis_),
- analysis_.device_info(), config_);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/loop.h b/third_party/xla/xla/service/gpu/fusions/loop.h
deleted file mode 100644
index 2d23c302..0000000
--- a/third_party/xla/xla/service/gpu/fusions/loop.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_LOOP_H_
-#define XLA_SERVICE_GPU_FUSIONS_LOOP_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-
-// Generic loop fusion.
-class LoopFusion : public KernelFusionEmitterBase {
- public:
- explicit LoopFusion(const HloFusionAnalysis& analysis);
- LaunchDimensions launch_dimensions() const override;
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const override;
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const override;
-
- protected:
- absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const override;
-
- private:
- const HloFusionAnalysis& analysis_;
- LaunchDimensionsConfig config_;
-};
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
- const HloFusionAnalysis& analysis);
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
- const HloFusionAnalysis& analysis, const Shape& shape);
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_LOOP_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc
index 9db9173..4c6bdac 100644
--- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc
@@ -35,9 +35,9 @@
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/launch_dimensions.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h
index 029c67b..ecb1591 100644
--- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h
@@ -22,9 +22,9 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/loop.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h"
+#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/indexing_map.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
index 357ef65..eda89f8 100644
--- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
@@ -46,7 +46,7 @@
thread_id_printer_.SetSymbolName(1, "unroll_id");
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirLoopFusion fusion(analysis);
auto thread_id_to_output_indexing =
fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_);
@@ -88,7 +88,7 @@
thread_id_printer_.SetSymbolName(1, "unroll_id");
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirLoopFusion fusion(analysis);
auto thread_id_to_output_indexing =
@@ -140,7 +140,7 @@
thread_id_printer_.SetSymbolName(1, "unroll_id");
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirLoopFusion fusion(analysis);
auto thread_id_to_output_indexing =
@@ -196,10 +196,10 @@
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1 * 1024 + d0)>
- // CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768)>
- // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16)>
- // CHECK: #[[MAP3:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)>
+ // CHECK-DAG: #[[MAP0:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 1024 + d0)
+ // CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768)
+ // CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16)
+ // CHECK-DAG: #[[MAP3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)
// CHECK: func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16>
// CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index
// CHECK: %[[THREAD_ID:.*]] = gpu.thread_id
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc
deleted file mode 100644
index 69c41ec..0000000
--- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc
+++ /dev/null
@@ -1,222 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/statusor.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class LoopTest : public HloTestBase {
- public:
- void SetUp() override {
- HloTestBase::SetUp();
-
- printer_ =
- AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
- {"chunk_id", "unroll_id"});
- }
-
- protected:
- stream_executor::DeviceDescription device_info_ =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
- AffineMapPrinter printer_;
- mlir::MLIRContext mlir_context_;
-};
-
-absl::StatusOr<std::unique_ptr<KernelFusionInterface>> GetFusion(
- const HloFusionAnalysis& analysis) {
- auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
- auto fusion = dynamic_cast<KernelFusionInterface*>(emitter.get());
- TF_RET_CHECK(fusion != nullptr);
-
- emitter.release();
- return std::unique_ptr<KernelFusionInterface>{fusion};
-}
-
-TEST_F(LoopTest, ThreadIndexingUnrolled) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- neg {
- %input = f32[100,200,300] parameter(0)
- ROOT neg = f32[100,200,300] negate(%input)
- }
-
- ENTRY entry {
- %input = f32[100,200,300] parameter(0)
- ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
- auto thread_id_to_output_indexing =
- loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
- &mlir_context_);
-
- EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
- MatchIndexingString(R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
- (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000,
- ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
- ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
- )
- domain:
- th_x in [0, 127]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 1007]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 11]
- unroll_id in [0, 3]
- bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999]
-)"));
-}
-
-TEST_F(LoopTest, ThreadIndexingNotUnrolled) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- neg {
- %input = f32[20] parameter(0)
- ROOT neg = f32[20] negate(%input)
- }
-
- ENTRY entry {
- %input = f32[20] parameter(0)
- ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
- auto thread_id_to_output_indexing =
- loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
- &mlir_context_);
- EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
- MatchIndexingString(R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
- domain:
- th_x in [0, 19]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 0]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- )"));
- auto thread_id_to_input_indexing =
- loop_fusion->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
- EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
- MatchIndexingString(R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
- domain:
- th_x in [0, 19]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 0]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- )"));
-}
-
-TEST_F(LoopTest, Broadcast) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- bcast {
- %input = f32[20] parameter(0)
- ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1}
- }
-
- ENTRY entry {
- %input = f32[20] parameter(0)
- ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
- auto thread_id_to_output_indexing =
- loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
- &mlir_context_);
- EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
- MatchIndexingString(R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
- (bl_x * 128 + th_x) floordiv 600,
- ((bl_x * 128 + th_x) floordiv 30) mod 20,
- (bl_x * 128 + th_x) mod 30)
- domain:
- th_x in [0, 127]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 46]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- bl_x * 128 + th_x in [0, 5999]
- )"));
- auto thread_id_to_input_indexing =
- loop_fusion->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
- EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
- MatchIndexingString(R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
- (((bl_x * 128 + th_x) floordiv 30) mod 20)
- domain:
- th_x in [0, 127]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 46]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- bl_x * 128 + th_x in [0, 5999]
- )"));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD
index 5231483..08a159f 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD
@@ -1,4 +1,3 @@
-load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("//xla:xla.bzl", "xla_cc_test")
package(
@@ -76,7 +75,7 @@
"//xla/mlir_hlo:type_conversion",
"//xla/service:algorithm_util",
"//xla/service/gpu:hlo_traversal",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor:device_description",
@@ -119,7 +118,7 @@
"//xla/mlir_hlo",
"//xla/service:hlo_parser",
"//xla/service/gpu:launch_dimensions",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor:launch_dim",
@@ -153,7 +152,6 @@
deps = [
":computation_partitioner",
":elemental_hlo_to_mlir",
- ":passes",
":type_util",
"//xla:shape_util",
"//xla:status_macros",
@@ -172,7 +170,8 @@
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:target_util",
"//xla/service/gpu/fusions:fusion_emitter",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
+ "//xla/service/gpu/fusions/transforms:passes",
"//xla/service/gpu/model:indexing_analysis",
"//xla/service/gpu/runtime:kernel_thunk",
"//xla/service/llvm_ir:llvm_util",
@@ -261,95 +260,6 @@
],
)
-gentbl_cc_library(
- name = "passes_inc_gen",
- tbl_outs = [
- (
- [
- "-gen-pass-decls",
- "-name=GpuFusionTransforms",
- ],
- "passes.h.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "passes.td",
- visibility = ["//visibility:private"],
- deps = ["@llvm-project//mlir:PassBaseTdFiles"],
-)
-
-cc_library(
- name = "passes",
- srcs = [
- "convert_xla_gpu_pure_call_ops.cc",
- "erase_dead_functions.cc",
- "expand_float_ops.cc",
- "flatten_tensors.cc",
- "lower_tensors.cc",
- "lower_to_llvm.cc",
- "lower_xla_gpu_to_scf.cc",
- "merge_pointers_to_same_slice.cc",
- "optimize_loops.cc",
- "propagate_slice_indices.cc",
- "simplify_affine.cc",
- "simplify_arith.cc",
- "unswitch_loops.cc",
- "vectorize_loads_stores.cc",
- ],
- hdrs = ["passes.h"],
- deps = [
- ":passes_inc_gen",
- "//xla:shape_util",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/mlir_hlo",
- "//xla/mlir_hlo:map_mhlo_to_scalar_op",
- "//xla/service/gpu:ir_emission_utils",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
- "//xla/service/gpu/model:indexing_analysis",
- "//xla/stream_executor:device_description",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/log:check",
- "@com_google_absl//absl/strings",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:AffineDialect",
- "@llvm-project//mlir:AffineToStandard",
- "@llvm-project//mlir:AffineUtils",
- "@llvm-project//mlir:ArithDialect",
- "@llvm-project//mlir:ArithToLLVM",
- "@llvm-project//mlir:ArithTransforms",
- "@llvm-project//mlir:CallOpInterfaces",
- "@llvm-project//mlir:ComplexDialect",
- "@llvm-project//mlir:ComplexToLLVM",
- "@llvm-project//mlir:ControlFlowToLLVM",
- "@llvm-project//mlir:DataLayoutInterfaces",
- "@llvm-project//mlir:DialectUtils",
- "@llvm-project//mlir:FuncDialect",
- "@llvm-project//mlir:FuncToLLVM",
- "@llvm-project//mlir:GPUDialect",
- "@llvm-project//mlir:GPUToNVVMTransforms",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:LLVMCommonConversion",
- "@llvm-project//mlir:LLVMDialect",
- "@llvm-project//mlir:MathDialect",
- "@llvm-project//mlir:MathToLLVM",
- "@llvm-project//mlir:MathTransforms",
- "@llvm-project//mlir:NVVMDialect",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:SCFDialect",
- "@llvm-project//mlir:SCFToControlFlow",
- "@llvm-project//mlir:SCFUtils",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TensorDialect",
- "@llvm-project//mlir:TransformUtils",
- "@llvm-project//mlir:VectorDialect",
- "@llvm-project//mlir:VectorToLLVM",
- "@llvm-project//mlir:VectorTransforms",
- ],
-)
-
cc_library(
name = "type_util",
srcs = ["type_util.cc"],
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc
deleted file mode 100644
index bb1270e..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <utility>
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DEF_CONVERTPURECALLOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-struct RewriteCall : mlir::OpRewritePattern<PureCallOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- PureCallOp op, mlir::PatternRewriter& rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
- op, op.getResultTypes(), op.getOperands(), op->getAttrs());
- return mlir::success();
- }
-};
-
-class ConvertPureCallOpsPass
- : public impl::ConvertPureCallOpsPassBase<ConvertPureCallOpsPass> {
- public:
- void runOnOperation() override {
- auto* ctx = &getContext();
- mlir::RewritePatternSet patterns(ctx);
- patterns.add<RewriteCall>(ctx);
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<::mlir::Pass> CreateConvertPureCallOpsPass() {
- return std::make_unique<ConvertPureCallOpsPass>();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
index 59471a3..6817336 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
@@ -73,8 +73,8 @@
#include "xla/mlir_hlo/mhlo/utils/type_conversion.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/type_util.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/model/indexing_analysis.h"
@@ -284,43 +284,45 @@
PrimitiveTypeToMlirType(instr->shape().element_type(), b);
int concat_dim =
Cast<HloConcatenateInstruction>(instr)->concatenate_dimension();
- int64_t offset = 0;
- IfOp outermost_if = nullptr;
SmallVector<Value, 3> operand_indices = indices;
- for (auto [index, operand] : llvm::enumerate(instr->operands())) {
- int64_t limit = offset + operand->shape().dimensions(concat_dim);
- auto ins = b.create<CmpIOp>(CmpIPredicate::ult, indices[concat_dim],
- b.create<ConstantIndexOp>(limit));
-
- auto generate_operand = [&, index = index]() {
- operand_indices[concat_dim] = b.create<arith::SubIOp>(
- indices[concat_dim], b.create<ConstantIndexOp>(offset));
- TF_ASSIGN_OR_RETURN(auto operand,
- operand_provider(instr, index, operand_indices));
- b.create<YieldOp>(operand);
- return absl::OkStatus();
- };
-
- if (index < instr->operand_count() - 1) {
- auto if_op =
- b.create<IfOp>(mlir::TypeRange{result_element_type}, ins, true, true);
- if (outermost_if == nullptr) {
- outermost_if = if_op;
- } else {
- b.create<YieldOp>(if_op.getResults());
- }
-
- b.setInsertionPointToStart(if_op.getBody(0));
- TF_RETURN_IF_ERROR(generate_operand());
- b.setInsertionPointToStart(if_op.getBody(1));
- } else {
- TF_RETURN_IF_ERROR(generate_operand());
- }
- offset = limit;
+ SmallVector<int64_t, 3> offsets{0};
+ for (auto* operand : instr->operands()) {
+ offsets.push_back(offsets.back() + operand->shape().dimensions(concat_dim));
}
- b.setInsertionPointAfter(outermost_if);
- return outermost_if.getResults();
+ std::function<absl::StatusOr<SmallVector<Value, 1>>(int64_t, int64_t)>
+ generate_concat;
+ generate_concat = [&](int64_t begin,
+ int64_t end) -> absl::StatusOr<SmallVector<Value, 1>> {
+ // If there's just one operand in the range, emit it.
+ if (begin == end - 1) {
+ operand_indices[concat_dim] = b.create<arith::SubIOp>(
+ indices[concat_dim], b.create<ConstantIndexOp>(offsets[begin]));
+ TF_ASSIGN_OR_RETURN(auto operand,
+ operand_provider(instr, begin, operand_indices));
+ return operand;
+ }
+
+ int64_t mid = (begin + end) / 2; // No risk of overflow.
+ auto if_op = b.create<IfOp>(
+ mlir::TypeRange{result_element_type},
+ b.create<CmpIOp>(CmpIPredicate::ult, indices[concat_dim],
+ b.create<ConstantIndexOp>(offsets[mid])),
+ true, true);
+
+ b.setInsertionPointToStart(if_op.getBody(0));
+ TF_ASSIGN_OR_RETURN(auto left_val, generate_concat(begin, mid));
+ b.create<YieldOp>(left_val);
+
+ b.setInsertionPointToStart(if_op.getBody(1));
+ TF_ASSIGN_OR_RETURN(auto right_val, generate_concat(mid, end));
+ b.create<YieldOp>(right_val);
+ b.setInsertionPointAfter(if_op);
+
+ return if_op.getResults();
+ };
+
+ return generate_concat(0, instr->operand_count());
}
absl::StatusOr<SmallVector<Value, 1>> EmitDynamicSlice(
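The rewritten concatenate emission above no longer builds a linear chain of comparisons; it precomputes prefix offsets along the concatenation dimension and recursively halves the operand range, emitting one scf.if per split. A plain C++ sketch of that lookup strategy (`FindConcatOperand` is illustrative, not the MLIR emission itself):

  #include <cstdint>
  #include <utility>
  #include <vector>

  // Returns {operand id, index local to that operand} for a position along the
  // concatenation dimension, using the same divide-and-conquer split over
  // prefix offsets as generate_concat.
  std::pair<int64_t, int64_t> FindConcatOperand(
      const std::vector<int64_t>& sizes, int64_t index) {
    std::vector<int64_t> offsets{0};
    for (int64_t size : sizes) offsets.push_back(offsets.back() + size);
    int64_t begin = 0;
    int64_t end = static_cast<int64_t>(sizes.size());
    while (begin != end - 1) {
      int64_t mid = (begin + end) / 2;
      if (index < offsets[mid]) {
        end = mid;    // the "then" branch of the emitted scf.if
      } else {
        begin = mid;  // the "else" branch
      }
    }
    return {begin, index - offsets[begin]};
  }

For the seven operands of the new ConcatenateMany test later in this diff, the operand sizes 1 through 7 give prefix offsets 1, 3, 6, 10, 15, 21 (and 28), which are exactly the constants the test's CHECK-DAG lines expect, and the lookup needs at most three comparisons instead of six.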
@@ -665,9 +667,10 @@
return b.createOrFold<mlir::affine::AffineApplyOp>(expr, args);
}
-SmallVector<Value, 3> ApplyIndexing(const IndexingMap& map, ValueRange dims,
+SmallVector<Value, 3> ApplyIndexing(IndexingMap map, ValueRange dims,
ValueRange symbols,
ImplicitLocOpBuilder& b) {
+ map.ClearConstraints();
SmallVector<Value, 3> results;
for (unsigned int i = 0; i < map.GetAffineMap().getNumResults(); ++i) {
SmallVector<Value, 1> result;
@@ -1465,6 +1468,8 @@
.Convert();
}
+} // namespace
+
void GetLoopBoundsFromIndexingMap(ImplicitLocOpBuilder& b,
const IndexingMap& indexing_map,
SmallVectorImpl<Value>* lbs,
@@ -1479,8 +1484,6 @@
}
}
-} // namespace
-
absl::Status SubgraphToMlirFunction(
const PartitionedComputation& computation,
const PartitionedComputation::Subgraph& subgraph, mlir::func::FuncOp& func,
@@ -1513,20 +1516,6 @@
namespace {
-bool IsSymbolConstrained(const IndexingMap& map, int symbol_id) {
- for (const auto& [expr, _] : map.GetConstraints()) {
- bool result = false;
- expr.walk([&](mlir::AffineExpr leaf) {
- auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(leaf);
- if (sym && sym.getPosition() == symbol_id) {
- result = true;
- }
- });
- if (result) return true;
- }
- return false;
-}
-
ValueRange EmitLoopNestImpl(
ImplicitLocOpBuilder& b, ValueRange dim_values, ValueRange iter_args_inits,
const IndexingMap& indexing_map,
@@ -1623,7 +1612,7 @@
sym_index >= 0 && cumulative_loop_size < 64; --sym_index) {
auto& bound = indexing_map.GetSymbolBound(sym_index);
cumulative_loop_size *= bound.GetLoopTripCount();
- if (!IsSymbolConstrained(indexing_map, sym_index)) continue;
+ if (!indexing_map.IsSymbolConstrained(sym_index)) continue;
IndexingMap peeled_map = indexing_map;
if (bound.upper == bound.lower) continue;
@@ -1632,7 +1621,7 @@
peeled_map.Simplify();
// If the symbol is still constrained, peeling does not help.
- if (IsSymbolConstrained(peeled_map, sym_index)) continue;
+ if (peeled_map.IsSymbolConstrained(sym_index)) continue;
auto first_results = EmitLoopNestImpl(b, dim_values, iter_args_inits,
peeled_map, create_body, vectorize);
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
index 1f52109..1a97b57 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
@@ -94,7 +94,7 @@
mlir::ImplicitLocOpBuilder& b);
// Creates an `apply_indexing` op for the given map.
-llvm::SmallVector<mlir::Value, 3> ApplyIndexing(const IndexingMap& map,
+llvm::SmallVector<mlir::Value, 3> ApplyIndexing(IndexingMap map,
mlir::ValueRange dims,
mlir::ValueRange symbols,
mlir::ImplicitLocOpBuilder& b);
@@ -148,6 +148,13 @@
mlir::Block& src_block,
mlir::ValueRange mapped_args);
+// Populates `lbs`, `ubs` and `steps` with the loop bounds from `indexing_map`.
+void GetLoopBoundsFromIndexingMap(mlir::ImplicitLocOpBuilder& b,
+ const IndexingMap& indexing_map,
+ llvm::SmallVectorImpl<mlir::Value>* lbs,
+ llvm::SmallVectorImpl<mlir::Value>* ubs,
+ llvm::SmallVectorImpl<mlir::Value>* steps);
+
} // namespace mlir_converter
} // namespace gpu
} // namespace xla
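The header change above also exports GetLoopBoundsFromIndexingMap, previously file-local in the .cc. As a rough illustration of the mapping it is expected to perform — each symbol range of the indexing map becomes a unit-stride loop — here is a plain-C++ sketch over integer intervals; the struct names and helper are hypothetical stand-ins, not the MLIR builder code:

  #include <cstdint>
  #include <vector>

  struct Interval { int64_t lower, upper; };   // Inclusive, as in the map.
  struct LoopBounds { int64_t lb, ub, step; };

  // Each symbol range [lower, upper] becomes a loop
  // for (i = lower; i < upper + 1; i += 1),
  // matching the scf.for loops the emitter builds around the map's symbols.
  std::vector<LoopBounds> GetLoopBounds(const std::vector<Interval>& symbols) {
    std::vector<LoopBounds> bounds;
    bounds.reserve(symbols.size());
    for (const Interval& s : symbols) {
      bounds.push_back({s.lower, s.upper + 1, 1});
    }
    return bounds;
  }

In the real function the bounds are emitted as MLIR values (hence the SmallVectorImpl<mlir::Value> out-parameters) rather than plain integers.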
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
index d7bbbb0..eab1568 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
@@ -36,8 +36,8 @@
#include "mlir/Transforms/Passes.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
@@ -235,10 +235,10 @@
// CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][]
// CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]]
// CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]])
- // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 4)>
- // CHECK-SAME: (%[[Y]] in [0, 2])
- // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0 - 3)>
- // CHECK-SAME: (%[[Z]] in [0, 7])[%[[I]] in [0, 6]]
+ // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 4), domain: d0 in [0, 2]>(%[[Y]])
+ // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 3),
+ // CHECK-SAME: d0 in [0, 7], s0 in [0, 6]>(%[[Z]])[%[[I]]]
// CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]]
// CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]]
// CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]],
@@ -285,8 +285,8 @@
// If symbol rescaling wasn't working we would have a
// `s0 floordiv <base_dilation>` in the map:
// CHECK: %[[K:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)>
- // CHECK-SAME: (%[[X]] in [0, 18])[%[[I]] in [0, 3]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+ // CHECK-SAME: d0 in [0, 18], s0 in [0, 3]>(%[[X]])[%[[I]]]
// CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]]
)"));
@@ -333,6 +333,79 @@
)"));
}
+TEST_F(ElementalHloToMlirTest, ConcatenateMany) {
+ TF_EXPECT_OK(Run(R"(
+ ENTRY main {
+ p0 = f32[10,1,30] parameter(0)
+ p1 = f32[10,2,30] parameter(1)
+ p2 = f32[10,3,30] parameter(2)
+ p3 = f32[10,4,30] parameter(3)
+ p4 = f32[10,5,30] parameter(4)
+ p5 = f32[10,6,30] parameter(5)
+ p6 = f32[10,7,30] parameter(6)
+ ROOT r = f32[10,28,30] concatenate(p0, p1, p2, p3, p4, p5, p6),
+ dimensions={1}
+ })",
+ R"(
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+ // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+ // CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
+ // CHECK-DAG: %[[C21:.*]] = arith.constant 21 : index
+ // CHECK: %[[P0TO2:.*]] = arith.cmpi ult, %[[I:.*]], %[[C6]]
+ // CHECK: %[[CONCAT:.*]] = scf.if %[[P0TO2]] -> (f32)
+ // CHECK-NEXT: %[[P0:.*]] = arith.cmpi ult, %[[I]], %[[C1]]
+ // CHECK-NEXT: scf.if %[[P0]]
+ // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[I]], {{.*}}] : tensor<10x1x30xf32>
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: %[[P1:.*]] = arith.cmpi ult, %[[I]], %[[C3]]
+ // CHECK-NEXT: scf.if %[[P1]]
+ // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C1]]
+ // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x2x30xf32>
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C3]]
+ // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x3x30xf32>
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: }
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: }
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: %[[P3TO4:.*]] = arith.cmpi ult, %[[I]], %[[C15]]
+ // CHECK-NEXT: scf.if %[[P3TO4]]
+ // CHECK-NEXT: %[[P3:.*]] = arith.cmpi ult, %[[I]], %[[C10]]
+ // CHECK-NEXT: scf.if %[[P3]]
+ // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C6]]
+ // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x4x30xf32>
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C10]]
+ // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x5x30xf32>
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: }
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: %[[P5:.*]] = arith.cmpi ult, %[[I]], %[[C21]]
+ // CHECK-NEXT: scf.if %[[P5]]
+ // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C15]]
+ // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x6x30xf32>
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C21]]
+ // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x7x30xf32>
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: }
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: }
+ // CHECK-NEXT: yield
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return %[[CONCAT]]
+ )"));
+}
+
TEST_F(ElementalHloToMlirTest, ConcatenateUnsigned) {
TF_EXPECT_OK(Run(R"(
ENTRY main {
@@ -433,7 +506,7 @@
// CHECK-DAG: %[[C4:.*]] = arith.constant 4
// CHECK-DAG: %[[C7:.*]] = arith.constant 7
// CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7])
+ // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]])
// CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]]
// CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]]
// CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]]
@@ -445,11 +518,9 @@
// CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]]
// CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]]
// CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)>
- // CHECK-SAME: (%[[X]] in [1, 7])
+ // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]])
// CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: <(d0) -> (d0 - 4)>
- // CHECK-SAME: (%[[Y]] in [4, 7])
+ // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]])
// CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]]
// CHECK: scf.yield %[[VAL]]
// CHECK: } else {
@@ -477,7 +548,7 @@
// CHECK-DAG: %[[C4:.*]] = arith.constant 4
// CHECK-DAG: %[[C7:.*]] = arith.constant 7
// CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7])
+ // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]])
// CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]]
// CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]]
// CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]]
@@ -489,11 +560,9 @@
// CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]]
// CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]]
// CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)>
- // CHECK-SAME: (%[[X]] in [1, 7])
+ // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]])
// CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: <(d0) -> (d0 - 4)>
- // CHECK-SAME: (%[[Y]] in [4, 7])
+ // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]])
// CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]]
// CHECK: scf.yield %[[VAL]]
// CHECK: } else {
@@ -810,11 +879,11 @@
// CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
// CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) {
// CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)>
- // CHECK-SAME: (%[[W]] in [0, 5])[%[[X]] in [0, 2]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+ // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)>
- // CHECK-SAME: (%[[H]] in [0, 7])[%[[Y]] in [0, 4]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+ // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
// CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
// CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -856,11 +925,11 @@
// CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
// CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) {
// CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)>
- // CHECK-SAME: (%[[W]] in [0, 2])[%[[X]] in [0, 2]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+ // CHECK-SAME: d0 in [0, 2], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)>
- // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+ // CHECK-SAME: d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
// CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
// CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -903,21 +972,21 @@
// CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) {
// CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) {
// CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
- // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 7])[%[[X]] in [0, 2]]
+ // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index
// CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index
// CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1
- // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 11])[%[[Y]] in [0, 4]]
+ // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index
// CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index
// CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1
// CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) {
// CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 1)>
- // CHECK-SAME: (%[[W]] in [0, 7])[%[[X]] in [0, 2]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 1),
+ // CHECK-SAME: d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 2)>
- // CHECK-SAME: (%[[H]] in [0, 11])[%[[Y]] in [0, 4]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 2),
+ // CHECK-SAME: d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
// CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
// CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -957,17 +1026,17 @@
// CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) {
// CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) {
// CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
- // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 12])[%[[X]] in [0, 2]]
+ // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index
- // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 18])[%[[Y]] in [0, 4]]
+ // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index
// CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) {
// CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)>
- // CHECK-SAME: (%[[W]] in [0, 12])[%[[X]] in [0, 2]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2),
+ // CHECK-SAME: d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)>
- // CHECK-SAME: (%[[H]] in [0, 18])[%[[Y]] in [0, 4]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2),
+ // CHECK-SAME: d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
// CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
// CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1009,11 +1078,11 @@
// CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
// CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) {
// CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)>
- // CHECK-SAME: (%[[W]] in [0, 3])[%[[X]] in [0, 2]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2),
+ // CHECK-SAME: d0 in [0, 3], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)>
- // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2),
+ // CHECK-SAME: d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
// CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
// CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1055,17 +1124,14 @@
// CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
// CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) {
// CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)>
- // CHECK-SAME: (%[[W]] in [0, 5])
- // CHECK-SAME: [%[[X]] in [0, 2]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+ // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)>
- // CHECK-SAME: (%[[H]] in [0, 7])
- // CHECK-SAME: [%[[Y]] in [0, 4]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+ // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0)>
- // CHECK-SAME: (%[[O]] in [0, 15])
- // CHECK-SAME: [%[[I]] in [0, 1]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0),
+ // CHECK-SAME: d0 in [0, 15], s0 in [0, 1]>(%[[O]])[%[[I]]]
// CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32>
// CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32>
// CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1109,13 +1175,11 @@
// CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) {
// CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) {
// CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)>
- // CHECK-SAME: (%[[W]] in [0, 5])
- // CHECK-SAME: [%[[X]] in [0, 2]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+ // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]]
// CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)>
- // CHECK-SAME: (%[[H]] in [0, 7])
- // CHECK-SAME: [%[[Y]] in [0, 4]]
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+ // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]]
// CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
// CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
// CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1581,8 +1645,8 @@
// CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}
// CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]]
// CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing
- // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)>
- // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9])
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1),
+ // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]])
// CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]]
// CHECK: return %[[A]], %[[B]]
)"));
@@ -1605,8 +1669,8 @@
// CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}
// CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0
// CHECK: %[[IDX:.*]] =
- // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)>
- // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9])
+ // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1),
+ // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]])
// CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1
// CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]])
// CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc b/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc
deleted file mode 100644
index 012201a..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <queue>
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Interfaces/CallInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_ERASEDEADFUNCTIONSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-struct CallInfo {
- PureCallOp call;
- int count;
-};
-
-llvm::DenseSet<mlir::func::FuncOp> FindLiveFunctions(mlir::ModuleOp module) {
- std::queue<mlir::func::FuncOp> worklist;
- llvm::DenseSet<mlir::func::FuncOp> live_funcs;
- module.walk([&](mlir::func::FuncOp func) {
- if (!func.isPrivate()) {
- worklist.push(func);
- live_funcs.insert(func);
- }
- });
-
- mlir::SymbolTableCollection symbol_table;
- while (!worklist.empty()) {
- auto func = worklist.front();
- worklist.pop();
- func.walk([&](mlir::CallOpInterface call) {
- auto callee =
- mlir::cast<mlir::func::FuncOp>(call.resolveCallable(&symbol_table));
- if (live_funcs.insert(callee).second) {
- worklist.push(callee);
- }
- });
- }
- return live_funcs;
-}
-
-class EraseDeadFunctionsPass
- : public impl::EraseDeadFunctionsPassBase<EraseDeadFunctionsPass> {
- public:
- void runOnOperation() override {
- // Find live functions and erase dead ones.
- auto live = FindLiveFunctions(getOperation());
- getOperation().walk([&](mlir::func::FuncOp func) {
- if (!live.contains(func)) {
- func.erase();
- }
- });
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-CreateEraseDeadFunctionsPass() {
- return std::make_unique<EraseDeadFunctionsPass>();
-}
-
-} // namespace gpu
-} // namespace xla
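The deleted erase_dead_functions.cc above implemented dead-function elimination as a reachability walk starting from the non-private functions. For readers skimming the removal, a compact sketch of that worklist algorithm on a toy call graph, using std::map/std::string stand-ins instead of the MLIR symbol-table machinery:

  #include <map>
  #include <queue>
  #include <set>
  #include <string>
  #include <vector>

  // Sketch of the reachability walk the deleted pass performed, keyed by
  // function name. `callees[f]` lists the functions f calls; `public_roots`
  // are the non-private functions that must be kept.
  std::set<std::string> FindLiveFunctions(
      const std::map<std::string, std::vector<std::string>>& callees,
      const std::vector<std::string>& public_roots) {
    std::set<std::string> live(public_roots.begin(), public_roots.end());
    std::queue<std::string> worklist;
    for (const auto& root : public_roots) worklist.push(root);
    while (!worklist.empty()) {
      std::string func = worklist.front();
      worklist.pop();
      auto it = callees.find(func);
      if (it == callees.end()) continue;
      for (const auto& callee : it->second) {
        if (live.insert(callee).second) worklist.push(callee);  // Newly reached.
      }
    }
    return live;  // Everything outside this set would be erased.
  }

The deleted pass did the same BFS over call ops resolved through the symbol table and then erased every func.func not in the live set.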
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc
deleted file mode 100644
index 001df2c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc
+++ /dev/null
@@ -1,675 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <algorithm>
-#include <array>
-#include <cassert>
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "absl/log/check.h"
-#include "llvm/ADT/APFloat.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-namespace ma = ::mlir::arith;
-
-using ma::SelectOp;
-using mlir::Value;
-
-#define GEN_PASS_DEF_EXPANDFLOATOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-// Wraps a Value to provide operator overloading for more readable expressions.
-struct Val {
- Value value;
- mlir::ImplicitLocOpBuilder* b;
-
- operator Value() const { return value; } // NOLINT
-
- Val operator+(int64_t rhs) const { return Binop<ma::AddIOp>(rhs); }
- Val operator+(Value rhs) const { return Binop<ma::AddIOp>(rhs); }
- Val operator-(int64_t rhs) const { return Binop<ma::SubIOp>(rhs); }
- Val operator-(Value rhs) const { return Binop<ma::SubIOp>(rhs); }
- Val operator*(int64_t rhs) const { return Binop<ma::MulIOp>(rhs); }
- Val operator*(Value rhs) const { return Binop<ma::MulIOp>(rhs); }
- Val operator&(Value rhs) const { return Binop<ma::AndIOp>(rhs); }
- Val operator&(int64_t rhs) const { return Binop<ma::AndIOp>(rhs); }
- Val operator|(Value rhs) const { return Binop<ma::OrIOp>(rhs); }
- Val operator|(int64_t rhs) const { return Binop<ma::OrIOp>(rhs); }
- Val operator^(Value rhs) const { return Binop<ma::XOrIOp>(rhs); }
- Val shl(Value rhs) const { return Binop<ma::ShLIOp>(rhs); }
- Val shl(int64_t rhs) const { return Binop<ma::ShLIOp>(rhs); }
- Val shrui(Value rhs) const { return Binop<ma::ShRUIOp>(rhs); }
- Val shrui(int64_t rhs) const { return Binop<ma::ShRUIOp>(rhs); }
-
- Val cmp(ma::CmpIPredicate pred, Value rhs) const {
- return {b->create<ma::CmpIOp>(pred, value, rhs), b};
- }
- Val cmp(ma::CmpIPredicate pred, int64_t rhs) const {
- return cmp(pred, MakeConstant(rhs));
- }
- Val operator==(Value rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
- Val operator==(int64_t rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
- Val operator!=(int64_t rhs) const { return cmp(ma::CmpIPredicate::ne, rhs); }
-
- Val MakeConstant(int64_t c) const {
- return {b->create<ma::ConstantIntOp>(c, value.getType()), b};
- }
-
- private:
- template <typename Op>
- Val Binop(Value rhs) const {
- return {b->create<Op>(value, rhs), b};
- }
-
- template <typename Op>
- Val Binop(int64_t rhs) const {
- return Binop<Op>(MakeConstant(rhs));
- }
-};
-
-template <typename OpTy, ma::CmpFPredicate pred>
-struct RewriteToCmpSelect : public mlir::OpRewritePattern<OpTy> {
- using mlir::OpRewritePattern<OpTy>::OpRewritePattern;
-
- RewriteToCmpSelect(mlir::MLIRContext* context, bool include_f32)
- : mlir::OpRewritePattern<OpTy>(context), include_f32(include_f32) {}
-
- mlir::LogicalResult matchAndRewrite(
- OpTy op, mlir::PatternRewriter& rewriter) const override {
- if (op.getType().isF32() && !include_f32) {
- return rewriter.notifyMatchFailure(op, "not rewriting f32 min/max");
- }
-
- auto lhs_is_nan = rewriter.create<ma::CmpFOp>(
- op.getLoc(), ma::CmpFPredicate::UNE, op.getLhs(), op.getLhs());
- auto rhs_is_not_nan = rewriter.create<ma::CmpFOp>(
- op.getLoc(), ma::CmpFPredicate::OEQ, op.getRhs(), op.getRhs());
-
- auto return_lhs =
- rewriter.create<ma::CmpFOp>(op.getLoc(), pred, op.getLhs(), op.getRhs())
- .getResult();
-
- // logic: isNaN(lhs) || (!isNan(rhs) && return_lhs) ? lhs : rhs
- return_lhs = rewriter.create<ma::OrIOp>(
- op.getLoc(), lhs_is_nan,
- rewriter.create<ma::AndIOp>(op.getLoc(), rhs_is_not_nan, return_lhs));
-
- rewriter.replaceOpWithNewOp<SelectOp>(op, op.getResult().getType(),
- return_lhs, op.getLhs(), op.getRhs());
- return mlir::success();
- }
-
- bool include_f32;
-};
-
-struct RewriteErf32Pattern : public mlir::OpRewritePattern<mlir::math::ErfOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- mlir::math::ErfOp op, mlir::PatternRewriter& rewriter) const override {
- if (!op.getType().isF32()) {
- return rewriter.notifyMatchFailure(op, "not an f32 erf");
- }
-
- static const std::array<float, 5> kAlpha{
- 0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f,
- 0.18520832239976145f, 1.128379143519084f};
-
- static const std::array<float, 7> kBeta{-1.1791602954361697e-7,
- 0.000023547966471313185f,
- 0.0010179625278914885f,
- 0.014070470171167667f,
- 0.11098505178285362f,
- 0.49746925110067538f,
- 1.0f};
-
- // We clamp x to be within [-c;c] where c = erfinv(1-2^-23), outside of
- // which x should be +/-1.
- constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f;
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto c = [&](float v) -> Value {
- return b.create<ma::ConstantFloatOp>(llvm::APFloat(v),
- rewriter.getF32Type());
- };
-
- auto poly = [&](auto x, auto coefficients) -> Value {
- auto r = c(coefficients[0]);
- for (int i = 1; i < coefficients.size(); ++i) {
- r = b.create<mlir::math::FmaOp>(r, x, c(coefficients[i]));
- }
- return r;
- };
-
- Value x = op.getOperand();
- x = b.create<ma::MaximumFOp>(x, c(-kErfInvOneMinusHalfULP));
- x = b.create<ma::MinimumFOp>(x, c(kErfInvOneMinusHalfULP));
- Value x2 = b.create<ma::MulFOp>(x, x);
-
- rewriter.replaceOpWithNewOp<ma::DivFOp>(
- op, b.create<ma::MulFOp>(x, poly(x2, kAlpha)), poly(x2, kBeta));
-
- return mlir::success();
- }
-};
-
-int GetSignificandBits(mlir::FloatType ty) {
- return llvm::APFloat::semanticsPrecision(ty.getFloatSemantics()) - 1;
-}
-
-int GetExponentBias(mlir::FloatType ty) {
- return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
-}
-
-Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
- auto ty = mlir::cast<mlir::FloatType>(value.getType());
- if (mlir::LLVM::isCompatibleOuterType(ty)) {
- value = b.create<mlir::math::AbsFOp>(value);
- Value inf = b.create<ma::ConstantFloatOp>(
- llvm::APFloat::getInf(ty.getFloatSemantics()), ty);
- return b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, value, inf);
- }
-
- assert(ty.getIntOrFloatBitWidth() == 8);
- if (!ty.isFloat8E5M2()) {
- // F8E5M2 is the only 8 bit float with infinities.
- return b.create<ma::ConstantIntOp>(false, b.getI1Type());
- }
- Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
- return (bits & 0x7F) == 0x7C;
-}
-
-Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
- auto ty = value.getType();
- if (mlir::LLVM::isCompatibleOuterType(ty)) {
- return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value);
- }
-
- assert(ty.getIntOrFloatBitWidth() == 8);
- Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
- if (ty.isFloat8E5M2() || ty.isFloat8E4M3FN()) {
- return (bits & 0x7F) == 0x7F;
- }
- return bits == 0x80;
-}
-
-Value EmitReducePrecision(Value value, int exponent_bits, int mantissa_bits,
- mlir::ImplicitLocOpBuilder& b) {
- mlir::mhlo::ReducePrecisionOp::Properties properties;
- properties.exponent_bits = b.getI32IntegerAttr(exponent_bits);
- properties.mantissa_bits = b.getI32IntegerAttr(mantissa_bits);
- return mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType<
- mlir::mhlo::ReducePrecisionOp>(
- b.getLoc(), value.getType(), {value.getType()},
- mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), &b);
-}
-
-Value EmitF16ToF8e5m2(Value in, mlir::ImplicitLocOpBuilder& b) {
- Val in_bits{b.create<ma::BitcastOp>(b.getI16Type(), in), &b};
- // Use this method of checking for NaN because it's the same as what's used
- // in the reduce precision lowering.
- Value is_nan = (in_bits & 32767).cmp(ma::CmpIPredicate::ugt, 31744);
-
- Value value = EmitReducePrecision(in, 5, 2, b);
- value = b.create<ma::BitcastOp>(b.getI16Type(), value);
- value = b.create<ma::ShRUIOp>(value,
- b.create<ma::ConstantIntOp>(8, b.getI16Type()));
- value = b.create<ma::TruncIOp>(b.getI8Type(), value);
- // When the input is NaN, just truncating can turn a NaN into an inf if the
- // mantissa becomes 0.
- value = b.create<ma::SelectOp>(
- is_nan, b.create<ma::ConstantIntOp>(0x7F, value.getType()), value);
- return b.create<ma::BitcastOp>(b.getFloat8E5M2Type(), value);
-}
-
-Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
- mlir::ImplicitLocOpBuilder& b) {
- using ma::CmpIPredicate;
-
- // This is a port of ConvertImpl in
- // https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h
- auto from_ty = mlir::cast<mlir::FloatType>(value.getType());
- if (to_ty == b.getFloat8E5M2Type() && from_ty == b.getF16Type()) {
- return EmitF16ToF8e5m2(value, b);
- }
-
- int from_mantissa = GetSignificandBits(from_ty);
- int from_bias = GetExponentBias(from_ty);
- int from_min_exp =
- llvm::APFloat::semanticsMinExponent(from_ty.getFloatSemantics());
- int from_max_exp =
- llvm::APFloat::semanticsMaxExponent(from_ty.getFloatSemantics());
- auto from_int_ty = b.getIntegerType(from_ty.getIntOrFloatBitWidth());
-
- int to_mantissa = GetSignificandBits(to_ty);
- int to_bias = GetExponentBias(to_ty);
- int to_min_exp =
- llvm::APFloat::semanticsMinExponent(to_ty.getFloatSemantics());
- int to_max_exp =
- llvm::APFloat::semanticsMaxExponent(to_ty.getFloatSemantics());
- auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth());
-
- mlir::IntegerType wide_int_ty;
- if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) {
- wide_int_ty = b.getI16Type();
- } else {
- wide_int_ty = b.getIntegerType(
- std::max(from_int_ty.getWidth(), to_int_ty.getWidth()));
- }
- auto convert_int = [&](mlir::Type ty, Value v) -> Val {
- if (v.getType() == ty) {
- return {v, &b};
- }
- if (ty.getIntOrFloatBitWidth() < v.getType().getIntOrFloatBitWidth()) {
- return {b.create<ma::TruncIOp>(ty, v), &b};
- }
- return {b.create<ma::ExtUIOp>(ty, v), &b};
- };
-
- int64_t exp_offset = to_bias - from_bias;
- int digit_shift = to_mantissa - from_mantissa;
-
- Val from_bits{
- b.create<ma::BitcastOp>(
- b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value),
- &b};
-
- auto cst = [&](mlir::Type ty, int64_t n) -> Val {
- return {b.create<ma::ConstantIntOp>(n, ty), &b};
- };
-
- // Shift bits to destination type, without sign bit.
- Val from_sign_bit =
- from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0;
-
- from_bits =
- from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1);
-
- Value result_is_inf = IsInf(value, b);
- Value input_is_nan = IsNaN(value, b);
-
- auto cst_bits = [&](llvm::APFloat f) {
- return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())),
- f.bitcastToAPInt().getZExtValue());
- };
- Value to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics()));
- Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics()));
- Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics()));
-
- auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
- assert(bits.value.getType() == roundoff.value.getType());
- // Round to nearest even by adding a bias term.
- // Consider a bit pattern
- // FFF...FLRTT...T,
- // where bits RTT...T need to be rounded-off. We add a bias term to the
- // bit pattern s.t. a carry is introduced to round up only if
- // - L is 1, R is 1, OR
- // - L is 0, R is 1, any T is one.
- // We do this by adding L to a bit pattern consisting of all T = 1.
- Val rounded = (bits.shrui(roundoff) & 1) +
- (bits.MakeConstant(1).shl(roundoff - 1) - 1);
- Val bias{b.create<SelectOp>(roundoff == 0, roundoff, rounded), &b};
- return bits + bias;
- };
-
- // Happy path: no subnormals, infinities or NaNs.
- Value result;
- {
- // Round the mantissa if it is shrinking.
- Val rounded_from_bits = convert_int(wide_int_ty, from_bits);
- if (digit_shift < 0) {
- rounded_from_bits = round_bits_to_nearest_even(
- from_bits, from_bits.MakeConstant(-digit_shift)) &
- ~((1ll << (-digit_shift)) - 1);
- }
-
- // Re-bias the exponent.
- rounded_from_bits = rounded_from_bits + (exp_offset << from_mantissa);
-
- // Check for overflows by aligning the significands. We always align the
- // narrower significand to the wider significand.
- int64_t to_highest = llvm::APFloat::getLargest(to_ty.getFloatSemantics())
- .bitcastToAPInt()
- .getZExtValue();
- int64_t aligned_highest = to_highest;
- if (digit_shift < 0) {
- aligned_highest <<= -digit_shift;
- // Shift down, all dropped bits should already be zero.
- result = rounded_from_bits.shrui(-digit_shift);
- } else {
- // Shift up, inserting zeros in the newly created digits.
- rounded_from_bits = rounded_from_bits.shl(digit_shift);
- result = rounded_from_bits;
- }
- result = convert_int(to_int_ty, result);
-
- // `From` supports larger values than `To`, we may overflow.
- if (std::make_pair(to_max_exp, to_mantissa) <
- std::make_pair(from_max_exp, from_mantissa)) {
- result = b.create<SelectOp>(
- rounded_from_bits.cmp(CmpIPredicate::ugt, aligned_highest), to_inf,
- result);
- }
- }
-
- auto i32_ty = b.getI32Type();
- Val biased_from_exp = convert_int(i32_ty, from_bits.shrui(from_mantissa));
-
- if (to_min_exp < from_min_exp) {
- // `To` supports more exponents near zero which means that some subnormal
- // values in `From` may become normal.
-
- // Subnormals.
- Val bits = convert_int(wide_int_ty, from_bits);
-
- // Determine exponent in target type.
- Value normalization_factor =
- convert_int(i32_ty,
- b.create<mlir::math::CountLeadingZerosOp>(from_bits)) -
- (from_int_ty.getWidth() - from_mantissa - 1);
-
- Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor;
- // If the result is subnormal, adjust the subnormal bits to account for
- // the difference in exponent bias.
- Value subnormal_bits = bits;
- if (exp_offset < wide_int_ty.getWidth()) {
- subnormal_bits = bits.shl(exp_offset);
- }
-
- // Result is normal. Shift the mantissa to account for the number of
- // leading zero digits, and clear the hidden bit.
- // Insert the exponent bits.
- Value normal_bits =
- (bits.shl(convert_int(wide_int_ty, normalization_factor)) &
- ~(1 << from_mantissa)) |
- convert_int(wide_int_ty, biased_exponent).shl(from_mantissa);
-
- Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0);
- bits.value =
- b.create<SelectOp>(biased_exp_sle_zero, subnormal_bits, normal_bits);
- if (digit_shift > 0) {
- bits = bits.shl(digit_shift);
- } else {
- bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift));
- bits = bits.shrui(-digit_shift);
- }
- bits = convert_int(to_int_ty, bits);
-
- result = b.create<SelectOp>(biased_from_exp == 0, bits, result);
- } else if (to_min_exp > from_min_exp) {
- // `To` supports fewer exponents near zero which means that some values in
- // `From` may become subnormal.
- Val unbiased_exp = biased_from_exp - from_bias;
- Val biased_to_exp = unbiased_exp + to_bias;
- // Subnormals and zero.
- // Round and shift mantissa down.
- Val from_has_leading_one = biased_from_exp != 0;
- Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one);
- from_has_leading_one = convert_int(from_int_ty, from_has_leading_one);
- Val exponent_shift_i32 =
- (from_has_leading_one_i32 - biased_to_exp) - digit_shift;
- // Insert the implicit leading 1 bit on the mantissa for normalized
- // inputs.
- Val rounded_from_bits = (from_bits & ((1ll << from_mantissa) - 1)) |
- from_has_leading_one.shl(from_mantissa);
-
- // NOTE: we need to round again from the original from_bits,
- // otherwise the lower precision bits may already be lost. There is
- // an edge-case where rounding to a normalized value would normally
- // round down, but for a subnormal, we need to round up.
- Val exponent_shift_from_ty = convert_int(from_int_ty, exponent_shift_i32);
- Val exponent_shift_to_ty = convert_int(to_int_ty, exponent_shift_i32);
- Val positive_bits = convert_int(
- to_int_ty,
- round_bits_to_nearest_even(rounded_from_bits, exponent_shift_from_ty)
- .shrui(exponent_shift_from_ty));
- // To avoid UB, limit rounding and shifting to the full mantissa plus
- // leading 1.
- positive_bits.value = b.create<SelectOp>(
- exponent_shift_i32.cmp(CmpIPredicate::sle, from_mantissa + 1),
- positive_bits, to_zero);
-
- Val negative_bits = convert_int(to_int_ty, rounded_from_bits)
- .shl(to_zero - exponent_shift_to_ty);
- Value bits =
- b.create<SelectOp>(exponent_shift_i32.cmp(CmpIPredicate::sgt, 0),
- positive_bits, negative_bits);
- result = b.create<SelectOp>(biased_to_exp.cmp(CmpIPredicate::sle, 0), bits,
- result);
- }
-
- // Handle types with no unsigned zero.
- auto is_nuz = [](mlir::FloatType ty) {
- return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() ||
- ty.isFloat8E5M2FNUZ();
- };
-
- if (is_nuz(to_ty)) {
- // Clear the sign bit if the result is zero (the output has no negative
- // zero).
- Val result_is_non_zero = Val{result, &b} != 0;
- from_sign_bit = from_sign_bit & result_is_non_zero;
- } else if (is_nuz(from_ty)) {
- // Clear the sign bit if the input is NaN (it's positive but encoded as
- // negative 0).
- from_sign_bit = from_sign_bit ^ input_is_nan;
- }
-
- result = b.create<SelectOp>(result_is_inf, to_inf, result);
- result = b.create<SelectOp>(from_bits == 0, to_zero, result);
- result = b.create<SelectOp>(input_is_nan, to_nan, result);
-
- Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));
-
- // Insert sign bit.
- result = b.create<SelectOp>(from_sign_bit, neg_result, result);
- result = b.create<ma::BitcastOp>(to_ty, result);
- return result;
-}
-
-struct RewriteTruncFPattern : public mlir::OpRewritePattern<ma::TruncFOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- ma::TruncFOp op, mlir::PatternRewriter& rewriter) const override {
- using FloatValue = mlir::TypedValue<mlir::FloatType>;
- auto src = mlir::cast<FloatValue>(op.getOperand());
- auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
- if (dst_ty.getWidth() != 8) {
- return rewriter.notifyMatchFailure(op, "not an 8 bit truncf");
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
- return mlir::success();
- }
-};
-
-struct RewriteExtFPattern : public mlir::OpRewritePattern<ma::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- ma::ExtFOp op, mlir::PatternRewriter& rewriter) const override {
- using FloatValue = mlir::TypedValue<mlir::FloatType>;
- auto src = mlir::cast<FloatValue>(op.getOperand());
- auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
- if (src.getType().getWidth() != 8) {
- return rewriter.notifyMatchFailure(op, "not an 8 bit extf");
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
- return mlir::success();
- }
-};
-
-// Lowering for cmpf : f8 for float to pred conversions.
-struct RewriteF8Cst : public mlir::OpRewritePattern<ma::CmpFOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- ma::CmpFOp op, mlir::PatternRewriter& rewriter) const override {
- using FloatValue = mlir::TypedValue<mlir::FloatType>;
- auto lhs = mlir::cast<FloatValue>(op.getLhs());
- auto rhs = mlir::cast<FloatValue>(op.getRhs());
-
- if (lhs.getType().getWidth() != 8) {
- return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf");
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- // Skip the f32 conversion if we're comparing UNE.cst.
- llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics());
- if (op.getPredicate() == ma::CmpFPredicate::UNE &&
- mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) {
- Val int_value{b.create<ma::BitcastOp>(rewriter.getI8Type(), lhs), &b};
- int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue();
- // If we're comparing to +-0, compare the absolute values.
- if (rhs_cst.isZero() &&
- (lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) {
- int_value = int_value & 0x7f;
- constant &= 0x7f;
- }
- auto cst = b.create<ma::ConstantIntOp>(constant, rewriter.getI8Type());
- rewriter.replaceOpWithNewOp<ma::CmpIOp>(op, ma::CmpIPredicate::ne,
- int_value, cst);
- return mlir::success();
- }
-
- auto lhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), lhs);
- auto rhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), rhs);
- rewriter.replaceOpWithNewOp<ma::CmpFOp>(op, op->getResultTypes(),
- mlir::ValueRange{lhs_ext, rhs_ext},
- op->getAttrs());
- return mlir::success();
- }
-};
-
-struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- mlir::math::AbsFOp op, mlir::PatternRewriter& rewriter) const override {
- using FloatValue = mlir::TypedValue<mlir::FloatType>;
- auto src = mlir::cast<FloatValue>(op.getOperand());
- // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16.
- // Once that's removed, remove the code for BF16 here.
- if (src.getType().getWidth() != 8 && !src.getType().isBF16()) {
- return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf");
- }
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
- Val value{b.create<ma::BitcastOp>(i_ty, src), &b};
- if (src.getType().getWidth() == 8) {
- value = value & 0x7f;
- } else {
- CHECK(src.getType().isBF16());
- value = value & 0x7fff;
- }
- rewriter.replaceOpWithNewOp<ma::BitcastOp>(op, src.getType(), value);
- return mlir::success();
- }
-};
-
-template <typename Op>
-struct RewriteIToFpPattern : public mlir::OpRewritePattern<Op> {
- using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- Op op, mlir::PatternRewriter& rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() != 8) {
- return rewriter.notifyMatchFailure(op, "not an f8 itofp");
- }
- Value to_float =
- rewriter.create<Op>(op.getLoc(), rewriter.getF32Type(), op.getIn());
- rewriter.replaceOpWithNewOp<ma::TruncFOp>(op, op.getType(), to_float);
- return mlir::success();
- }
-};
-
-template <typename Op>
-struct RewriteFpToIPattern : public mlir::OpRewritePattern<Op> {
- using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- Op op, mlir::PatternRewriter& rewriter) const override {
- if (op.getIn().getType().getIntOrFloatBitWidth() != 8) {
- return rewriter.notifyMatchFailure(op, "not an f8 fptoi");
- }
- Value to_f32 = rewriter.create<ma::ExtFOp>(
- op.getLoc(), rewriter.getF32Type(), op.getIn());
- rewriter.replaceOpWithNewOp<Op>(op, op.getType(), to_f32);
- return mlir::success();
- }
-};
-
-class ExpandFloatOpsPass
- : public impl::ExpandFloatOpsPassBase<ExpandFloatOpsPass> {
- public:
- using ExpandFloatOpsPassBase::ExpandFloatOpsPassBase;
- void runOnOperation() override {
- mlir::RewritePatternSet patterns(&getContext());
- patterns.add<RewriteToCmpSelect<ma::MinimumFOp, ma::CmpFPredicate::OLE>>(
- &getContext(), /*include_f32=*/pre_ampere_);
- patterns.add<RewriteToCmpSelect<ma::MaximumFOp, ma::CmpFPredicate::OGE>>(
- &getContext(), /*include_f32=*/pre_ampere_);
- patterns.add<RewriteTruncFPattern, RewriteExtFPattern, RewriteAbsFPattern,
- RewriteF8Cst, RewriteIToFpPattern<ma::SIToFPOp>,
- RewriteIToFpPattern<ma::UIToFPOp>,
- RewriteFpToIPattern<ma::FPToSIOp>,
- RewriteFpToIPattern<ma::FPToUIOp>>(&getContext());
- mlir::populatePolynomialApproximateTanhPattern(patterns);
- patterns.add<RewriteErf32Pattern>(&getContext());
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere) {
- return createExpandFloatOpsPass(ExpandFloatOpsPassOptions{pre_ampere});
-}
-
-} // namespace gpu
-} // namespace xla
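One detail worth keeping in mind from the removed expand_float_ops.cc is the rounding-bias trick documented in round_bits_to_nearest_even. A self-contained C++ version of that trick on uint64_t bit patterns (illustrative only; the pass itself operated on MLIR values via the Val wrapper):

  #include <cassert>
  #include <cstdint>

  // For a pattern FFF...FLRTT...T where the low `roundoff` bits RTT...T are
  // about to be dropped, adding L to a field of all-ones T bits produces a
  // carry into the kept bits exactly when rounding up is required, with ties
  // going to even. Assumes the top bits have headroom (true for mantissas).
  uint64_t RoundBitsToNearestEven(uint64_t bits, int roundoff) {
    assert(roundoff >= 0 && roundoff < 64);
    if (roundoff == 0) return bits;
    uint64_t l = (bits >> roundoff) & 1;  // Last kept bit.
    uint64_t bias = l + ((uint64_t{1} << (roundoff - 1)) - 1);
    return bits + bias;  // Caller then shifts right by `roundoff`.
  }

The carry fires when the dropped bits are above the halfway point, or exactly at it with L = 1, which is precisely round-to-nearest-even.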
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc b/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc
deleted file mode 100644
index 99a7ecb..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc
+++ /dev/null
@@ -1,452 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/IR/Visitors.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/shape_util.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DEF_FLATTENTENSORSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-using mlir::Location;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpRewritePattern;
-using mlir::PatternRewriter;
-using mlir::RankedTensorType;
-using mlir::SmallVector;
-using mlir::Type;
-using mlir::TypedValue;
-using mlir::TypeRange;
-using mlir::UnrealizedConversionCastOp;
-using mlir::Value;
-using mlir::ValueRange;
-using mlir::func::FuncOp;
-using mlir::func::ReturnOp;
-using mlir::scf::ForOp;
-using mlir::scf::IfOp;
-using mlir::tensor::ExtractOp;
-using mlir::tensor::InsertOp;
-
-RankedTensorType GetFlattenedType(RankedTensorType tensor_type) {
- return RankedTensorType::get({tensor_type.getNumElements()},
- tensor_type.getElementType());
-}
-
-bool HasOnlyFlatTensorsOrScalars(TypeRange types) {
- return llvm::all_of(types, [](Type ty) {
- auto tensor_type = mlir::dyn_cast<RankedTensorType>(ty);
- if (!tensor_type) return true;
- return tensor_type.getRank() < 2;
- });
-}
-
-struct RewriteFunctionSignatures : OpRewritePattern<FuncOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(FuncOp op,
- PatternRewriter& rewriter) const override {
- auto input_types = op.getFunctionType().getInputs();
- auto result_types = op.getFunctionType().getResults();
- if (HasOnlyFlatTensorsOrScalars(input_types) &&
- HasOnlyFlatTensorsOrScalars(result_types)) {
- return rewriter.notifyMatchFailure(op, "nothing to flatten");
- }
-
- auto loc = op.getLoc();
- mlir::Block* entry_block = &op.getBody().front();
- SmallVector<Type> new_result_types;
- SmallVector<Value> new_results;
-
- // If some results are tensors, we need to flatten them.
- auto terminator = entry_block->getTerminator();
- rewriter.setInsertionPoint(terminator);
-
- for (Value result : terminator->getOperands()) {
- auto tensor_type = mlir::dyn_cast<RankedTensorType>(result.getType());
- if (!tensor_type) {
- new_result_types.push_back(result.getType());
- new_results.push_back(result);
- continue;
- }
- auto new_result_type = GetFlattenedType(tensor_type);
- new_result_types.push_back(new_result_type);
-
- Value result_1d =
- rewriter
- .create<UnrealizedConversionCastOp>(loc, new_result_type, result)
- .getResult(0);
- new_results.push_back(result_1d);
- }
- rewriter.replaceOpWithNewOp<ReturnOp>(terminator, new_results);
-
- // Cast all function arguments to the original type.
- SmallVector<Type> new_operand_types(input_types);
- rewriter.setInsertionPointToStart(entry_block);
- for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) {
- if (auto tensor_type = mlir::dyn_cast<RankedTensorType>(operand_type)) {
- if (tensor_type.getRank() > 1) {
- mlir::BlockArgument func_argument = op.getArgument(index);
- auto cast_to_orig_type = rewriter.create<UnrealizedConversionCastOp>(
- loc, operand_type, func_argument);
- func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0),
- cast_to_orig_type);
- operand_type = GetFlattenedType(tensor_type);
- }
- }
- }
- // Replace the function arguments with the new types.
- for (auto [arg, arg_type] :
- llvm::zip(entry_block->getArguments(), new_operand_types)) {
- arg.setType(arg_type);
- }
- // Update function signature.
- op.setType(rewriter.getFunctionType(new_operand_types, new_result_types));
- return mlir::success();
- }
-};
-
-// Returns the linearized index, if the rank is greater than 1. Otherwise,
-// returns nullptr.
-Value LinearizeIndex(TypedValue<mlir::RankedTensorType> tensor,
- ValueRange indices, PatternRewriter& rewriter) {
- if (tensor.getType().getRank() < 2) {
- return nullptr;
- }
- auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
- if (auto encoding = tensor.getType().getEncoding()) {
- *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
- mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
- }
- auto linear_shape =
- ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
- auto linearized_map =
- GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
- mlir::SmallVector<Value> result;
- rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
- ValueRange{}, linearized_map);
- return result.front();
-}
-
-struct RewriteTensorExtract : OpRewritePattern<ExtractOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp op,
- PatternRewriter& rewriter) const override {
- auto tensor = op.getTensor();
- auto tensor_type = tensor.getType();
- auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
- if (linear_index == nullptr) {
- return rewriter.notifyMatchFailure(op, "the tensor is already flat");
- }
- auto tensor_1D = rewriter
- .create<UnrealizedConversionCastOp>(
- op.getLoc(), GetFlattenedType(tensor_type), tensor)
- .getResult(0);
- rewriter.replaceOpWithNewOp<ExtractOp>(op, tensor_1D, linear_index);
- return mlir::success();
- }
-};
-
-struct RewriteTensorInsert : OpRewritePattern<InsertOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertOp op,
- PatternRewriter& rewriter) const override {
- auto tensor = op.getDest();
- auto tensor_type = tensor.getType();
- auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
- if (linear_index == nullptr) {
- return rewriter.notifyMatchFailure(op, "the tensor is already flat");
- }
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto tensor_1D = b.create<UnrealizedConversionCastOp>(
- GetFlattenedType(tensor_type), tensor)
- .getResult(0);
- auto new_insert =
- b.create<InsertOp>(op.getScalar(), tensor_1D, linear_index);
- auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
- tensor_type, new_insert.getResult());
- rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
- return mlir::success();
- }
-};
-
-struct RewriteAtomicRMW : OpRewritePattern<AtomicRMWOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(AtomicRMWOp op,
- PatternRewriter& rewriter) const override {
- auto tensor = op.getInput();
- auto tensor_type = tensor.getType();
- auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
- if (linear_index == nullptr) {
- return rewriter.notifyMatchFailure(op, "the tensor is already flat");
- }
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto tensor_1D = b.create<UnrealizedConversionCastOp>(
- GetFlattenedType(tensor_type), tensor)
- .getResult(0);
- auto new_atomic_rmw = b.create<AtomicRMWOp>(tensor_1D, linear_index);
- rewriter.inlineRegionBefore(op.getRegion(),
- &new_atomic_rmw.getRegion().front());
- auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
- tensor_type, new_atomic_rmw.getResult());
- rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
- return mlir::success();
- }
-};
-
-// Checks that the value is produced by an unrealized conversion cast from 1D
-// tensor to ND. Returns the 1D tensor if so.
-std::optional<Value> GetDelinearizedTensor(Value value) {
- auto tensor_type = mlir::dyn_cast<RankedTensorType>(value.getType());
- if (!tensor_type || tensor_type.getRank() < 2) {
- return std::nullopt;
- }
- auto cast = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (!cast || cast->getNumResults() != 1 || cast->getNumOperands() != 1) {
- return std::nullopt;
- }
- auto type_before_linearization =
- mlir::dyn_cast<RankedTensorType>(cast->getOperand(0).getType());
- if (!type_before_linearization || type_before_linearization.getRank() != 1) {
- return std::nullopt;
- }
- return cast->getOperand(0);
-}
-
-struct RewriteForOp : public OpRewritePattern<ForOp> {
- using OpRewritePattern<ForOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ForOp op,
- PatternRewriter& rewriter) const override {
- llvm::SmallBitVector args_to_update(op.getNumResults(), false);
- mlir::SmallVector<Value> new_init_args;
- new_init_args.reserve(op.getNumResults());
- for (auto [index, arg] : llvm::enumerate(op.getInitArgs())) {
- auto delinearized_tensor = GetDelinearizedTensor(arg);
- if (!delinearized_tensor.has_value()) {
- new_init_args.push_back(arg);
- continue;
- }
- new_init_args.push_back(*delinearized_tensor);
- args_to_update.set(index);
- }
- if (args_to_update.none()) {
- return rewriter.notifyMatchFailure(op, "no args to update");
- }
- // Create new ForOp with updated init args.
- Location loc = op.getLoc();
- auto new_for_op =
- rewriter.create<ForOp>(loc, op.getLowerBound(), op.getUpperBound(),
- op.getStep(), new_init_args);
- new_for_op->setAttrs(op->getAttrs());
-
- // Insert casts for the block arguments.
- mlir::Block* new_body = new_for_op.getBody();
- mlir::Block* old_body = op.getBody();
- rewriter.setInsertionPoint(new_body, new_body->begin());
- SmallVector<Value, 4> updated_block_args{new_body->getArguments().begin(),
- new_body->getArguments().end()};
- for (auto [index, arg] :
- llvm::enumerate(new_body->getArguments().drop_front())) {
- if (!args_to_update.test(index)) continue;
- updated_block_args[index + 1] =
- rewriter
- .create<UnrealizedConversionCastOp>(
- loc, old_body->getArgument(index + 1).getType(), arg)
- .getResult(0);
- }
-
- // Move the body of the old ForOp to the new one.
- rewriter.mergeBlocks(old_body, new_body, updated_block_args);
-
- // Update the terminator.
- auto new_terminator =
- mlir::cast<mlir::scf::YieldOp>(new_body->getTerminator());
- rewriter.setInsertionPoint(new_terminator);
- for (auto&& [index, yielded_value] :
- llvm::enumerate(new_terminator.getResultsMutable())) {
- if (!args_to_update.test(index)) continue;
- yielded_value.assign(
- rewriter
- .create<UnrealizedConversionCastOp>(
- loc, new_init_args[index].getType(), yielded_value.get())
- .getResult(0));
- }
-
- // Cast back the results.
- rewriter.setInsertionPointAfter(new_for_op);
- SmallVector<Value> new_results(new_for_op.getResults());
- for (auto&& [index, result] : llvm::enumerate(new_results)) {
- if (!args_to_update.test(index)) continue;
- result = rewriter
- .create<UnrealizedConversionCastOp>(
- loc, op->getResult(index).getType(), result)
- .getResult(0);
- }
- rewriter.replaceOp(op, new_results);
- return mlir::success();
- }
-};
-
-struct RewriteIfOp : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter& rewriter) const override {
- auto result_types = op.getResultTypes();
- if (HasOnlyFlatTensorsOrScalars(result_types)) {
- return rewriter.notifyMatchFailure(op, "nothing to flatten");
- }
- mlir::scf::YieldOp then_yield = op.thenYield();
- SmallVector<Type> new_result_types;
- new_result_types.reserve(then_yield.getNumOperands());
- bool found_cast = false;
- for (auto& result : then_yield->getOpOperands()) {
- auto delinearized_tensor = GetDelinearizedTensor(result.get());
- if (!delinearized_tensor.has_value()) {
- new_result_types.push_back(result.get().getType());
- continue;
- }
- new_result_types.push_back(delinearized_tensor->getType());
- result.set(*delinearized_tensor);
- found_cast = true;
- }
- if (!found_cast) {
- return rewriter.notifyMatchFailure(op, "no cast found");
- }
- Location loc = op.getLoc();
- // Update the else branch if present.
- bool has_else_region = !op.getElseRegion().empty();
- if (has_else_region) {
- mlir::scf::YieldOp else_yield = op.elseYield();
- mlir::OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(else_yield);
- for (auto&& [result, type] :
- llvm::zip(else_yield->getOpOperands(), new_result_types)) {
- if (result.get().getType() == type) continue;
- result.set(
- rewriter.create<UnrealizedConversionCastOp>(loc, type, result.get())
- .getResult(0));
- }
- }
- // Create new IfOp and move the old op's regions to the new one.
- auto new_if_op = rewriter.create<IfOp>(loc, new_result_types,
- op.getCondition(), has_else_region);
- rewriter.inlineRegionBefore(op.getThenRegion(),
- &new_if_op.getThenRegion().back());
- rewriter.eraseBlock(&new_if_op.getThenRegion().back());
- if (has_else_region) {
- rewriter.inlineRegionBefore(op.getElseRegion(),
- &new_if_op.getElseRegion().back());
- rewriter.eraseBlock(&new_if_op.getElseRegion().back());
- }
-
- // Update the results.
- rewriter.setInsertionPointAfter(new_if_op);
- SmallVector<Value> new_results(new_if_op.getResults());
- for (auto&& [index, result] : llvm::enumerate(new_results)) {
- Type old_type = op->getResult(index).getType();
- if (result.getType() == old_type) continue;
- result =
- rewriter.create<UnrealizedConversionCastOp>(loc, old_type, result)
- .getResult(0);
- }
- rewriter.replaceOp(op, new_results);
- return mlir::success();
- }
-};
-
-class FlattenTensorsPass
- : public impl::FlattenTensorsPassBase<FlattenTensorsPass> {
- public:
- void runOnOperation() override {
- mlir::ModuleOp module = getOperation();
- MLIRContext* mlir_context = &getContext();
- mlir::RewritePatternSet patterns(mlir_context);
- // clang-format off
- patterns.add<
- RewriteAtomicRMW,
- RewriteForOp,
- RewriteFunctionSignatures,
- RewriteIfOp,
- RewriteTensorExtract,
- RewriteTensorInsert
- >(mlir_context);
- // clang-format on
- ApplyIndexingOp::getCanonicalizationPatterns(patterns, mlir_context);
- if (mlir::failed(
- mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
- signalPassFailure();
- return;
- }
- // Check that no unrealized_conversion_casts remain.
- bool module_has_casts = module
- .walk([](UnrealizedConversionCastOp op) {
- return mlir::WalkResult::interrupt();
- })
- .wasInterrupted();
- if (module_has_casts) {
- llvm::outs() << "FlattenTensorsPass failed to converge";
- signalPassFailure();
- return;
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass() {
- return std::make_unique<FlattenTensorsPass>();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD
deleted file mode 100644
index d618413..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD
+++ /dev/null
@@ -1,110 +0,0 @@
-load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
-
-package(
- # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
- default_visibility = [":friends"],
- licenses = ["notice"],
-)
-
-package_group(
- name = "friends",
- includes = [
- "//xla:friends",
- ],
-)
-
-td_library(
- name = "xla_gpu_td_files",
- srcs = glob(["*.td"]),
- includes = ["."],
- deps = [
- "@llvm-project//mlir:CallInterfacesTdFiles",
- "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:SideEffectInterfacesTdFiles",
- ],
-)
-
-gentbl_cc_library(
- name = "xla_gpu_dialect_inc_gen",
- strip_include_prefix = ".",
- tbl_outs = [
- (
- ["-gen-dialect-decls"],
- "xla_gpu_dialect.h.inc",
- ),
- (
- ["-gen-dialect-defs"],
- "xla_gpu_dialect.cc.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "xla_gpu_dialect.td",
- deps = [":xla_gpu_td_files"],
-)
-
-gentbl_cc_library(
- name = "xla_gpu_ops_inc_gen",
- strip_include_prefix = ".",
- tbl_outs = [
- (
- ["-gen-op-decls"],
- "xla_gpu_ops.h.inc",
- ),
- (
- ["-gen-op-defs"],
- "xla_gpu_ops.cc.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "xla_gpu_ops.td",
- deps = [":xla_gpu_td_files"],
-)
-
-gentbl_cc_library(
- name = "xla_gpu_attrs_inc_gen",
- strip_include_prefix = ".",
- tbl_outs = [
- (
- ["-gen-attrdef-decls"],
- "xla_gpu_attrs.h.inc",
- ),
- (
- ["-gen-attrdef-defs"],
- "xla_gpu_attrs.cc.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "xla_gpu_attrs.td",
- deps = [":xla_gpu_td_files"],
-)
-
-cc_library(
- name = "xla_gpu",
- srcs = [
- "xla_gpu_attrs.cc",
- "xla_gpu_dialect.cc",
- "xla_gpu_ops.cc",
- ],
- hdrs = [
- "xla_gpu_attrs.h",
- "xla_gpu_ops.h",
- ],
- deps = [
- ":xla_gpu_attrs_inc_gen",
- ":xla_gpu_dialect_inc_gen",
- ":xla_gpu_ops_inc_gen",
- "//xla/service/gpu/model:indexing_analysis",
- "@com_google_absl//absl/strings:str_format",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:ArithDialect",
- "@llvm-project//mlir:BytecodeOpInterface",
- "@llvm-project//mlir:CallOpInterfaces",
- "@llvm-project//mlir:FuncDialect",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:InferTypeOpInterface",
- "@llvm-project//mlir:InliningUtils",
- "@llvm-project//mlir:SideEffectInterfaces",
- "@llvm-project//mlir:Support",
- ],
-)
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc
deleted file mode 100644
index d38ed34..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc
+++ /dev/null
@@ -1,228 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h"
-
-#include <string>
-#include <utility>
-
-#include "absl/strings/str_format.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-#define GET_ATTRDEF_LIST
-#define GET_ATTRDEF_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc"
-
-namespace xla {
-namespace gpu {
-
-using llvm::ParseResult;
-using llvm::SmallVector;
-using mlir::AffineExpr;
-using mlir::ArrayRef;
-using mlir::AsmParser;
-using mlir::AsmPrinter;
-using mlir::failure;
-using mlir::success;
-
-ParseResult ParseInterval(AsmParser& parser, Interval& interval) {
- // ParseResult converts to `true` if parsing failed.
- return failure(parser.parseLSquare() || parser.parseInteger(interval.lower) ||
- parser.parseComma() || parser.parseInteger(interval.upper) ||
- parser.parseRSquare());
-}
-
-void PrintDimVars(AsmPrinter& p, ArrayRef<DimVar> dim_vars) {
- int index = 0;
- llvm::interleaveComma(dim_vars, p, [&](const DimVar& dim_var) {
- p << "d" << index++ << " in " << dim_var.bounds;
- });
-}
-
-ParseResult ParseDimVars(AsmParser& parser, ArrayRef<std::string> dim_names,
- SmallVector<DimVar>& dim_vars) {
- dim_vars.reserve(dim_names.size());
- for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) {
- if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") ||
- ParseInterval(parser, dim_vars.emplace_back().bounds)) {
- return failure();
- }
- if (index < dim_names.size() - 1 && parser.parseComma()) {
- return failure();
- }
- }
- return success();
-}
-
-void PrintRangeVars(AsmPrinter& p, ArrayRef<RangeVar> range_vars) {
- int index = 0;
- llvm::interleaveComma(range_vars, p, [&](const RangeVar& range_var) {
- p << "s" << index++ << " in " << range_var.range;
- });
-}
-
-ParseResult ParseRangeVars(AsmParser& parser,
- ArrayRef<std::string> range_symbol_names,
- SmallVector<RangeVar>& range_vars) {
- range_vars.reserve(range_symbol_names.size());
- for (const auto& [index, range_symbol_name] :
- llvm::enumerate(range_symbol_names)) {
- if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") ||
- ParseInterval(parser, range_vars.emplace_back().range)) {
- return failure();
- }
- if (index < range_symbol_names.size() - 1 && parser.parseComma()) {
- return failure();
- }
- }
- return success();
-}
-
-void PrintConstraints(AsmPrinter& p,
- ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
- llvm::interleaveComma(constraints, p, [&](const auto& constraint) {
- p << constraint.first << " in " << constraint.second;
- });
-}
-
-ParseResult ParseConstraints(
- AsmParser& parser,
- ArrayRef<std::pair<llvm::StringRef, AffineExpr>> symbolSet,
- SmallVector<std::pair<AffineExpr, Interval>>& constraints) {
- // Constraints can only exist if there is at least one symbol or dimension,
- // so every remaining constraint in the list is preceded by a comma.
- while (succeeded(parser.parseOptionalComma())) {
- auto& constraint = constraints.emplace_back();
- if (parser.parseAffineExpr(symbolSet, constraint.first) ||
- parser.parseKeyword("in") || ParseInterval(parser, constraint.second)) {
- return failure();
- }
- }
- return success();
-}
-
-mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) {
- mlir::AffineMap map;
- if (parser.parseLess() || parser.parseAffineMap(map)) {
- return {};
- }
-
- // Store real strings to back up StringRef throughout ParseConstraints.
- SmallVector<std::string> dim_strings(map.getNumDims());
- SmallVector<std::string> symbol_strings(map.getNumSymbols());
- SmallVector<std::pair<llvm::StringRef, AffineExpr>> symbolSet;
- symbolSet.reserve(map.getNumDims() + map.getNumSymbols());
- for (int i = 0; i < map.getNumDims(); ++i) {
- dim_strings[i] = absl::StrFormat("d%d", i);
- symbolSet.push_back(
- {dim_strings[i], mlir::getAffineDimExpr(i, parser.getContext())});
- }
- for (int i = 0; i < map.getNumSymbols(); ++i) {
- symbol_strings[i] = absl::StrFormat("s%d", i);
- symbolSet.push_back(
- {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())});
- }
- if (map.getNumDims() + map.getNumSymbols() > 0) {
- if (parser.parseComma() || parser.parseKeyword("domain") ||
- parser.parseColon()) {
- return {};
- }
- }
-
- SmallVector<DimVar> dim_vars;
- if (map.getNumDims() > 0) {
- if (ParseDimVars(parser, dim_strings, dim_vars)) {
- return {};
- }
- }
-
- SmallVector<RangeVar> range_vars;
- if (map.getNumSymbols() > 0) {
- if (!dim_vars.empty() && parser.parseComma()) {
- return {};
- }
- if (ParseRangeVars(parser, symbol_strings, range_vars)) {
- return {};
- }
- }
-
- SmallVector<std::pair<AffineExpr, Interval>> constraints;
- if (ParseConstraints(parser, symbolSet, constraints) ||
- parser.parseGreater()) {
- return {};
- }
- return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars,
- constraints);
-}
-
-void IndexingMapAttr::print(mlir::AsmPrinter& printer) const {
- printer << "<";
- printer.printStrippedAttrOrType(getMap());
- if (getDimVars().size() + getRangeVars().size() + getConstraints().size() >
- 0) {
- printer << ", domain: ";
- }
- PrintDimVars(printer, getDimVars());
- if (!getDimVars().empty() &&
- getRangeVars().size() + getConstraints().size() > 0) {
- printer << ", ";
- }
- PrintRangeVars(printer, getRangeVars());
- if (!getRangeVars().empty() && !getConstraints().empty()) {
- printer << ", ";
- }
- PrintConstraints(printer, getConstraints());
- printer << ">";
-}
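As a rough illustration of the syntax this parser/printer pair round-trips (the map, the bounds, and the constraint below are invented):

```mlir
#xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 15], s0 in [0, 3], d0 mod 4 in [0, 1]>
```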
-
-IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context,
- const IndexingMap& indexing_map) {
- llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints;
- for (auto& constraint : indexing_map.GetConstraints()) {
- constraints.push_back({constraint.first, constraint.second});
- }
- return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(),
- indexing_map.GetRangeVars(), constraints);
-}
-
-mlir::LogicalResult IndexingMapAttr::verify(
- mlir::function_ref<mlir::InFlightDiagnostic()> emitError,
- mlir::AffineMap map, ArrayRef<DimVar> dim_vars,
- ArrayRef<RangeVar> range_vars,
- ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
- if (map.getNumDims() != dim_vars.size()) {
- return emitError()
- << "dim size must match the number of dimensions in the affine map";
- }
- if (map.getNumSymbols() != range_vars.size()) {
- return emitError()
- << "range size must match the number of symbols in the affine map";
- }
- return mlir::success();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h
deleted file mode 100644
index fca9216..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_
-
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep
-
-namespace xla {
-namespace gpu {
-
-// Custom parser to parse IndexingMapAttr.
-mlir::FailureOr<mlir::Attribute> ParseIndexingMapAttr(mlir::AsmParser& parser);
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td
deleted file mode 100644
index 8c8f98c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
-
-include "mlir/IR/AttrTypeBase.td"
-include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td"
-
-class XLAGPU_Attr<string name, list<Trait> traits = []> :
- AttrDef<XlaGpuDialect, name, traits> {
-}
-
-def XLAGPU_AffineMapParameter :
- AttrOrTypeParameter<"::mlir::AffineMap", ""> {
-}
-
-def XLAGPU_DimVarsParameter : ArrayRefParameter<"::xla::gpu::DimVar",
- "DimVarArray"> {
-}
-
-def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar",
- "RangeVarArray"> {
-}
-
-def XLAGPU_ConstraintsParameter :
- ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>",
- "ContraintsArray"> {
-}
-
-def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> {
- let summary = "An Attribute representing an indexing map.";
- let mnemonic = "indexing_map";
- let description = [{This attribute stores an indexing map. See
- https://openxla.org/xla/indexing for more details.
- }];
- let parameters = (ins XLAGPU_AffineMapParameter:$map,
- XLAGPU_DimVarsParameter:$dim_vars,
- XLAGPU_RangeVarsParameter:$range_vars,
- XLAGPU_ConstraintsParameter:$constraints);
- let hasCustomAssemblyFormat = 1;
- let builders = [
- AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>,
- ];
- let genVerifyDecl = 1;
-}
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc
deleted file mode 100644
index 3dc60c9..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "llvm/ADT/TypeSwitch.h"
-#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
-#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
-#include "mlir/Transforms/InliningUtils.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#define GET_ATTRDEF_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc"
-#undef GET_ATTRDEF_CLASSES
-
-namespace xla {
-namespace gpu {
-namespace {
-
-struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface {
- using DialectInlinerInterface::DialectInlinerInterface;
- // Returns true if the given operation 'callable', that implements the
- // 'CallableOpInterface', can be inlined into the position given call
- // operation 'call', that is registered to the current dialect and implements
- // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the
- // given 'callable' is set to be cloned during the inlining process, or false
- // if the region is set to be moved in-place (i.e. no duplicates would be
- // created).
- bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable,
- bool wouldBeCloned) const final {
- if (!wouldBeCloned) {
- // If no duplicate would be created, 'call' is likely the only caller of
- // 'callable'.
- return true;
- }
- // Otherwise, inline only if the called function is small. We could
- // theoretically also inline if there is no other caller in the function
- // that contains the callee that has a call path to the callable, but that
- // is more expensive to check.
- auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(callable);
- if (!func_op) {
- return false;
- }
- auto region = func_op.getCallableRegion();
- if (!region) {
- return false;
- }
-
- // If callee and caller call the same third function, inline. We have no
- // guarantee that the indices are the same, but there is a good chance they
- // are (or if the callee gets inlined as well, there will be CSE
- // opportunities).
- // This is duct tape to work around the limitations of our partitioner.
- // Ideally, the partitioner would be aware of the actual indexing and create
- // the partitions based on it (i.e., the case where the indices are the same
- // would never happen).
- llvm::SmallDenseSet<llvm::StringRef> callee_calls;
- for (auto call : region->getOps<PureCallOp>()) {
- callee_calls.insert(call.getCallee());
- }
- for (auto call : call->getParentRegion()->getOps<PureCallOp>()) {
- if (callee_calls.contains(call.getCallee())) {
- return true;
- }
- }
-
- constexpr int kMaxOperationsToInline = 8;
- int num_ops = 0;
- region->front().walk([&](mlir::Operation* op) { ++num_ops; });
-
- // At this point the callee has multiple callers, so only inline it if its
- // body is small.
- return num_ops <= kMaxOperationsToInline;
- }
- // Returns true if the given operation 'op', that is registered to this
- // dialect, can be inlined into the given region, false otherwise.
- // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned
- // during the inlining process, or false if the operation is set to be moved
- // in-place (i.e. no duplicates would be created). 'valueMapping' contains any
- // remapped values from within the 'src' region. This can be used to examine
- // what values may potentially replace the operands to 'op'.
- bool isLegalToInline(mlir::Operation* op, mlir::Region* dest,
- bool wouldBeCloned,
- mlir::IRMapping& valueMapping) const final {
- // We allow any op from the xla_gpu dialect to be inlined.
- return true;
- }
-};
-
-struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface {
- using OpAsmDialectInterface::OpAsmDialectInterface;
- AliasResult getAlias(mlir::Attribute attr,
- mlir::raw_ostream& os) const final {
- if (llvm::isa<IndexingMapAttr>(attr)) {
- os << "indexing_map";
- return AliasResult::FinalAlias;
- }
- return AliasResult::NoAlias;
- }
-};
-
-} // namespace
-
-void XlaGpuDialect::initialize() {
- addOperations<
-#define GET_OP_LIST
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc"
-#undef GET_OP_LIST
- >();
- addAttributes<
-#define GET_ATTRDEF_LIST
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc"
- >();
-#undef GET_ATTRDEF_LIST
- addInterfaces<XlaGpuInlinerInterface, XlaGpuOpAsmDialectInterface>();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td
deleted file mode 100644
index 4400747..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td
+++ /dev/null
@@ -1,32 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
-
-include "mlir/IR/DialectBase.td"
-
-def XlaGpuDialect : Dialect {
- let name = "xla_gpu";
-
- let description = [{
- This dialect contains ops required for lowering HLO to LLVM.
- }];
-
- let cppNamespace = "::xla::gpu";
- let useDefaultAttributePrinterParser = 1;
-}
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
deleted file mode 100644
index dfa4d05..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
+++ /dev/null
@@ -1,655 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/STLFunctionalExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Builders.h" // IWYU pragma: keep
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
-#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h" // IWYU pragma: keep
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using llvm::ArrayRef;
-using mlir::AffineExpr;
-using mlir::AffineMap;
-using mlir::failure;
-using mlir::getAffineConstantExpr;
-using mlir::getAffineDimExpr;
-using mlir::getAffineSymbolExpr;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpBuilder;
-using mlir::OperationState;
-using mlir::PatternRewriter;
-using mlir::RankedTensorType;
-using mlir::Region;
-using mlir::SmallVector;
-using mlir::success;
-using mlir::Type;
-using mlir::Value;
-using mlir::ValueRange;
-
-namespace arith = mlir::arith;
-
-} // namespace
-
-LogicalResult PureCallOp::verifySymbolUses(
- mlir::SymbolTableCollection& symbolTable) {
- auto callee = getCalleeAttr();
- auto function =
- symbolTable.lookupNearestSymbolFrom<mlir::func::FuncOp>(*this, callee);
- if (!function) {
- return emitError("'callee' attribute refers to an undefined function: ")
- << callee;
- }
-
- int func_arg_count = function.getFunctionType().getNumInputs();
- int arg_count = getOperands().size();
-
- if (arg_count != func_arg_count) {
- return emitError() << "argument count mismatch: 'operands' has "
- << arg_count << " arguments, but '" << callee
- << "' expects " << func_arg_count;
- }
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// AllocateSharedOp
-//===----------------------------------------------------------------------===//
-
-void AllocateSharedOp::getAsmResultNames(
- llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
- setNameFn(getResult(), "shmem");
-}
-
-//===----------------------------------------------------------------------===//
-// ApplyIndexingOp
-//===----------------------------------------------------------------------===//
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
- ValueRange dims, ValueRange symbols,
- const IndexingMap& indexing_map) {
- SmallVector<Value, 4> operands;
- operands.reserve(dims.size() + symbols.size());
- operands.append(dims.begin(), dims.end());
- operands.append(symbols.begin(), symbols.end());
- build(builder, result, operands, indexing_map);
-}
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
- ValueRange operands,
- const IndexingMap& indexing_map) {
- build(builder, result, operands, indexing_map.GetAffineMap(),
- indexing_map.GetDimVars(), indexing_map.GetRangeVars());
-}
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
- ValueRange operands, AffineMap affine_map,
- ArrayRef<DimVar> dim_vars,
- ArrayRef<RangeVar> range_vars) {
- SmallVector<int64_t, 4> lower_bounds, upper_bounds;
- for (const DimVar& dim_var : dim_vars) {
- lower_bounds.push_back(dim_var.bounds.lower);
- upper_bounds.push_back(dim_var.bounds.upper);
- }
- for (const RangeVar& range_var : range_vars) {
- lower_bounds.push_back(range_var.range.lower);
- upper_bounds.push_back(range_var.range.upper);
- }
- build(builder, result, operands, affine_map, lower_bounds, upper_bounds);
-}
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
- ValueRange operands, AffineMap affine_map,
- ArrayRef<int64_t> lower_bounds,
- ArrayRef<int64_t> upper_bounds) {
- SmallVector<Type, 2> result_types(affine_map.getNumResults(),
- builder.getIndexType());
- build(builder, result, result_types, operands, affine_map, lower_bounds,
- upper_bounds);
-}
-
-// Parses a comma-separated list of entries `%operand in [lower_bound, upper_bound]`.
-// Adds the parsed elements to the provided containers.
-mlir::ParseResult parseOperandsWithBoundsList(
- mlir::OpAsmParser& parser,
- SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4>* operands,
- SmallVector<int64_t, 4>* lower_bounds,
- SmallVector<int64_t, 4>* upper_bounds) {
- int64_t lower_bound, upper_bound;
- mlir::OpAsmParser::UnresolvedOperand operand;
- if (parser.parseCommaSeparatedList([&]() {
- if (parser.parseOperand(operand) || parser.parseKeyword("in") ||
- parser.parseLSquare() || parser.parseInteger(lower_bound) ||
- parser.parseComma() || parser.parseInteger(upper_bound) ||
- parser.parseRSquare()) {
- return failure();
- }
- operands->push_back(operand);
- lower_bounds->push_back(lower_bound);
- upper_bounds->push_back(upper_bound);
- return success();
- })) {
- return failure();
- }
- return success();
-}
-
-mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser& parser,
- OperationState& result) {
- mlir::Builder& builder = parser.getBuilder();
- auto index_type = builder.getIndexType();
-
- mlir::AffineMapAttr affine_map_attr;
- if (parser.parseAttribute(affine_map_attr, "map", result.attributes)) {
- return failure();
- }
-
- SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> operands;
- SmallVector<int64_t, 4> lower_bounds, upper_bounds;
- if (succeeded(parser.parseOptionalLParen())) {
- if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds,
- &upper_bounds) ||
- parser.parseRParen()) {
- return failure();
- }
- }
- if (succeeded(parser.parseOptionalLSquare())) {
- if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds,
- &upper_bounds) ||
- parser.parseRSquare()) {
- return failure();
- }
- }
- if (parser.resolveOperands(operands, index_type, result.operands) ||
- parser.parseOptionalAttrDict(result.attributes)) {
- return failure();
- }
- result.addAttribute("lower_bounds",
- builder.getDenseI64ArrayAttr(lower_bounds));
- result.addAttribute("upper_bounds",
- builder.getDenseI64ArrayAttr(upper_bounds));
-
- auto map = affine_map_attr.getAffineMap();
- result.addTypes(SmallVector<Type, 2>(map.getNumResults(), index_type));
- return success();
-}
-
-void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) {
- mlir::AffineMapAttr affine_map_attr = getMapAttr();
- AffineMap affine_map = affine_map_attr.getAffineMap();
- p << " " << affine_map_attr;
-
- auto lower_bounds = getLowerBounds();
- auto upper_bounds = getUpperBounds();
- auto operands = getOperands();
- unsigned num_dimensions = affine_map.getNumDims();
- if (num_dimensions > 0) {
- p << '(';
- for (int dim_id = 0; dim_id < num_dimensions; ++dim_id) {
- p << operands[dim_id] << " in " << '[' << lower_bounds[dim_id] << ", "
- << upper_bounds[dim_id] << ']';
- if (dim_id != num_dimensions - 1) {
- p << ", ";
- }
- }
- p << ')';
- }
- unsigned num_symbols = affine_map.getNumSymbols();
- if (num_symbols > 0) {
- p << '[';
- for (int symbol_id = 0; symbol_id < num_symbols; ++symbol_id) {
- unsigned operand_id = num_dimensions + symbol_id;
- p << operands[operand_id] << " in " << '[' << lower_bounds[operand_id]
- << ", " << upper_bounds[operand_id] << ']';
- if (symbol_id != num_symbols - 1) {
- p << ", ";
- }
- }
- p << ']';
- }
- p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
- "map", "lower_bounds", "upper_bounds"});
-}
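A hedged sketch of the printed form of an `apply_indexing` op with one dimension and one symbol operand (names, map, and bounds are made up):

```mlir
%idx:2 = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 floordiv 4, d0 + s0)>(%tid in [0, 127])[%iv in [0, 3]]
```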
-
-LogicalResult ApplyIndexingOp::verify() {
- auto affine_map = getMapAttr().getAffineMap();
- unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols();
- if (getOperands().size() != num_variables ||
- getLowerBounds().size() != num_variables ||
- getUpperBounds().size() != num_variables) {
- return emitOpError(
- "operand, lower_bounds, upper_bounds count and affine map dimension "
- "and symbol count must match");
- }
- return success();
-}
-
-IndexingMap ApplyIndexingOp::getIndexingMap() {
- auto lower_bounds = getLowerBounds();
- auto upper_bounds = getUpperBounds();
-
- AffineMap affine_map = getAffineMap();
- unsigned num_dimensions = affine_map.getNumDims();
- std::vector<DimVar> dim_vars;
- dim_vars.reserve(num_dimensions);
- for (unsigned id = 0; id < num_dimensions; ++id) {
- dim_vars.push_back(DimVar{Interval{lower_bounds[id], upper_bounds[id]}});
- }
- unsigned num_symbols = affine_map.getNumSymbols();
- std::vector<RangeVar> range_vars;
- range_vars.reserve(num_symbols);
- for (unsigned id = num_dimensions; id < num_symbols + num_dimensions; ++id) {
- range_vars.push_back(
- RangeVar{Interval{lower_bounds[id], upper_bounds[id]}});
- }
- return IndexingMap(affine_map, std::move(dim_vars), std::move(range_vars),
- /*rt_vars=*/{});
-}
-
-namespace {
-
-// Simplifies the indexing map and removes unused variables.
-struct SimplifyIndexingMap : public mlir::OpRewritePattern<ApplyIndexingOp> {
- using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
- PatternRewriter& rewriter) const override {
- IndexingMap indexing_map = indexing_op.getIndexingMap();
- bool is_simplified = indexing_map.Simplify();
-
- // Remove unused dimensions and symbols.
- auto unused_symbols_bit_vector = indexing_map.RemoveUnusedVars();
- bool symbols_removed = unused_symbols_bit_vector.count() != 0;
-
- if (!is_simplified && !symbols_removed) {
- return rewriter.notifyMatchFailure(indexing_op,
- "IndexingMap stayed unchanged");
- }
- if (!unused_symbols_bit_vector.empty()) {
- SmallVector<Value, 4> operands;
- operands.reserve(unused_symbols_bit_vector.count());
- for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) {
- if (!unused_symbols_bit_vector[i]) {
- operands.push_back(indexing_op.getOperand(i));
- }
- }
- rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, operands,
- indexing_map);
- } else {
- rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
- indexing_op, indexing_op.getOperands(), indexing_map);
- }
- return success();
- }
-};
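A hypothetical before/after of this pattern when one operand is not used by the map:

```mlir
// Before: %j does not appear in any result expression.
%r = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 * 2)>(%i in [0, 7], %j in [0, 15])

// After: the unused variable and its operand are dropped.
%r = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 7])
```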
-
-struct FoldApplyIndexingSequence
- : public mlir::OpRewritePattern<ApplyIndexingOp> {
- using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
- PatternRewriter& rewriter) const override {
- MLIRContext* ctx = indexing_op.getContext();
- int num_dims = indexing_op.getAffineMap().getNumDims();
- int num_syms = indexing_op.getAffineMap().getNumSymbols();
- mlir::DenseMap<Value, AffineExpr> operand_exprs;
- for (auto& operand : indexing_op->getOpOperands()) {
- int operand_number = operand.getOperandNumber();
- operand_exprs[operand.get()] =
- operand_number < num_dims
- ? getAffineDimExpr(operand_number, ctx)
- : getAffineSymbolExpr(operand_number - num_dims, ctx);
- }
-
- auto this_map = indexing_op.getIndexingMap();
-
- SmallVector<Value> added_dim_args;
- SmallVector<Value> added_sym_args;
- auto new_dim_vars = this_map.GetDimVars();
- auto new_sym_vars = this_map.GetRangeVars();
-
- mlir::DenseMap<AffineExpr, AffineExpr> replacements;
- for (auto& operand : indexing_op->getOpOperands()) {
- if (auto producer = operand.get().getDefiningOp<ApplyIndexingOp>()) {
- auto producer_map = producer.getIndexingMap();
- int producer_result_id =
- mlir::cast<mlir::OpResult>(operand.get()).getResultNumber();
- int num_producer_dims = producer.getAffineMap().getNumDims();
- SmallVector<AffineExpr> producer_dim_replacements;
- SmallVector<AffineExpr> producer_sym_replacements;
- for (auto& producer_operand : producer->getOpOperands()) {
- int producer_operand_number = producer_operand.getOperandNumber();
- bool is_dim = producer_operand_number < num_producer_dims;
- auto& replacement_expr = operand_exprs[producer_operand.get()];
- if (!replacement_expr) {
- if (is_dim) {
- int dim_num = producer_operand_number;
- replacement_expr =
- getAffineDimExpr(num_dims + added_dim_args.size(), ctx);
- added_dim_args.push_back(producer_operand.get());
- new_dim_vars.push_back(producer_map.GetDimVars(dim_num));
- } else {
- int sym_num = producer_operand_number -
- producer.getAffineMap().getNumDims();
- replacement_expr =
- getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx);
- added_sym_args.push_back(producer_operand.get());
- new_sym_vars.push_back(producer_map.GetRangeVar(sym_num));
- }
- }
-
- if (is_dim) {
- producer_dim_replacements.push_back(replacement_expr);
- } else {
- producer_sym_replacements.push_back(replacement_expr);
- }
- }
-
- replacements[operand_exprs[operand.get()]] =
- producer.getAffineMap()
- .getResult(producer_result_id)
- .replaceDimsAndSymbols(producer_dim_replacements,
- producer_sym_replacements);
- }
- }
-
- if (replacements.empty()) {
- return rewriter.notifyMatchFailure(indexing_op,
- "No apply_indexing sequences found");
- }
-
- int new_num_operands = indexing_op->getNumOperands() +
- added_dim_args.size() + added_sym_args.size();
- auto new_affine_map = indexing_op.getAffineMap().replace(
- replacements, num_dims + added_dim_args.size(),
- num_syms + added_sym_args.size());
- IndexingMap new_indexing_map(new_affine_map, new_dim_vars, new_sym_vars,
- /*rt_vars=*/{});
- if (!new_indexing_map.Simplify()) {
- return rewriter.notifyMatchFailure(
- indexing_op, "Folded indexing map was not simplified");
- }
- SmallVector<Value> new_operands;
- new_operands.reserve(new_num_operands);
-
- auto begin = indexing_op.getOperands().begin();
- new_operands.append(begin, begin + num_dims);
- new_operands.append(added_dim_args);
- new_operands.append(begin + num_dims, begin + num_dims + num_syms);
- new_operands.append(added_sym_args);
-
- rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, new_operands,
- new_indexing_map);
- return success();
- }
-};
-
-// Folds constants into the indexing map.
-struct FoldApplyIndexingOperands
- : public mlir::OpRewritePattern<ApplyIndexingOp> {
- using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
- PatternRewriter& rewriter) const override {
- AffineMap affine_map = indexing_op.getAffineMap();
-
- MLIRContext* ctx = affine_map.getContext();
- unsigned num_operands = indexing_op->getNumOperands();
- unsigned num_dims = affine_map.getNumDims();
- unsigned num_symbols = affine_map.getNumSymbols();
-
- SmallVector<std::optional<int64_t>> constant_values(num_operands,
- std::nullopt);
- int num_constants = 0;
- SmallVector<int64_t> dim_id_map(num_dims, -1);
- SmallVector<int64_t> symbol_id_map(num_symbols, -1);
- for (auto& operand : indexing_op->getOpOperands()) {
- if (auto constant =
- operand.get().getDefiningOp<arith::ConstantIndexOp>()) {
- constant_values[operand.getOperandNumber()] = constant.value();
- ++num_constants;
- }
- }
- if (num_constants == 0) {
- return rewriter.notifyMatchFailure(indexing_op,
- "No constant operands found");
- }
- SmallVector<AffineExpr, 2> dim_replacements, symbol_replacements;
- dim_replacements.reserve(num_dims);
- symbol_replacements.reserve(num_symbols);
-
- unsigned new_num_operands = indexing_op->getNumOperands() - num_constants;
- SmallVector<Value, 4> new_operands;
- new_operands.reserve(new_num_operands);
- SmallVector<int64_t, 4> new_lbs, new_ubs;
- new_lbs.reserve(new_num_operands);
- new_ubs.reserve(new_num_operands);
-
- unsigned new_num_dims = 0;
- unsigned new_num_symbols = 0;
- for (auto [operand, constant_value, lb, ub] : llvm::zip(
- indexing_op->getOpOperands(), constant_values,
- indexing_op.getLowerBounds(), indexing_op.getUpperBounds())) {
- unsigned operand_id = operand.getOperandNumber();
- if (constant_value.has_value()) {
- if (operand_id < num_dims) {
- dim_replacements.push_back(
- getAffineConstantExpr(*constant_value, ctx));
- } else {
- symbol_replacements.push_back(
- getAffineConstantExpr(*constant_value, ctx));
- }
- } else {
- if (operand_id < num_dims) {
- dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx));
- } else {
- symbol_replacements.push_back(
- getAffineSymbolExpr(new_num_symbols++, ctx));
- }
- new_operands.push_back(operand.get());
- new_lbs.push_back(lb);
- new_ubs.push_back(ub);
- }
- }
- rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
- indexing_op, new_operands,
- affine_map.replaceDimsAndSymbols(dim_replacements, symbol_replacements,
- new_num_dims, new_num_symbols),
- new_lbs, new_ubs);
- return success();
- }
-};
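A sketch of the constant-folding rewrite, with invented values; the substituted map is cleaned up by the usual affine-expression simplification:

```mlir
// Before: the second dimension operand is a constant.
%c3 = arith.constant 3 : index
%r = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1 * 8)>(%i in [0, 7], %c3 in [3, 3])

// After: the constant is inlined into the map and its operand disappears.
%r = xla_gpu.apply_indexing affine_map<(d0) -> (d0 + 24)>(%i in [0, 7])
```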
-
-// Folds constant and dim/symbol expression results.
-struct FoldApplyIndexingResults
- : public mlir::OpRewritePattern<ApplyIndexingOp> {
- using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
- PatternRewriter& rewriter) const override {
- mlir::Location loc = indexing_op.getLoc();
- IndexingMap indexing_map = indexing_op.getIndexingMap();
- if (indexing_map.IsKnownEmpty()) {
- return rewriter.notifyMatchFailure(indexing_op,
- "Domain of the indexing map is empty");
- }
- AffineMap* affine_map = &indexing_map.GetMutableAffineMap();
- unsigned num_results = affine_map->getNumResults();
- SmallVector<AffineExpr, 4> new_exprs;
- new_exprs.reserve(num_results);
- SmallVector<Value, 4> new_values;
- new_values.reserve(num_results);
- for (mlir::OpResult opresult : indexing_op->getOpResults()) {
- if (opresult.use_empty()) {
- new_values.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
- continue;
- }
-
- unsigned id = opresult.getResultNumber();
- AffineExpr result_expr = affine_map->getResult(id);
- if (auto const_expr =
- mlir::dyn_cast<mlir::AffineConstantExpr>(result_expr)) {
- new_values.push_back(rewriter.create<arith::ConstantIndexOp>(
- loc, const_expr.getValue()));
- continue;
- }
- if (auto dim_expr = mlir::dyn_cast<mlir::AffineDimExpr>(result_expr)) {
- new_values.push_back(indexing_op.getOperand(dim_expr.getPosition()));
- continue;
- }
- if (auto symbol_expr =
- mlir::dyn_cast<mlir::AffineSymbolExpr>(result_expr)) {
- new_values.push_back(indexing_op.getOperand(
- indexing_map.GetDimVarsCount() + symbol_expr.getPosition()));
- continue;
- }
- new_exprs.push_back(result_expr);
- new_values.push_back(Value{});
- }
- if (new_exprs.size() == num_results) {
- return rewriter.notifyMatchFailure(
- indexing_op, "No constant or dim/symbol expression found");
- }
- *affine_map =
- AffineMap::get(affine_map->getNumDims(), affine_map->getNumSymbols(),
- new_exprs, affine_map->getContext());
- auto new_indexing_op = rewriter.create<ApplyIndexingOp>(
- loc, indexing_op.getOperands(), indexing_map);
- for (int new_result_id = 0, new_indexing_op_result_id = 0;
- new_result_id < new_values.size(); ++new_result_id) {
- auto& new_value = new_values[new_result_id];
- if (new_value) continue;
- new_value = new_indexing_op.getResult(new_indexing_op_result_id++);
- }
- rewriter.replaceOp(indexing_op, new_values);
- return success();
- }
-};
-
-} // namespace
-
-void ApplyIndexingOp::getCanonicalizationPatterns(
- mlir::RewritePatternSet& results, MLIRContext* context) {
- results.add<FoldApplyIndexingOperands, FoldApplyIndexingResults,
- SimplifyIndexingMap, FoldApplyIndexingSequence>(context);
-}
-
-mlir::LogicalResult ApplyIndexingOp::fold(
- FoldAdaptor adaptor, llvm::SmallVectorImpl<mlir::OpFoldResult>& results) {
- auto map = getAffineMap();
- for (auto expr : map.getResults()) {
- if (auto dim = mlir::dyn_cast<mlir::AffineDimExpr>(expr)) {
- results.push_back(getOperand(dim.getPosition()));
- } else if (auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(expr)) {
- results.push_back(getOperand(map.getNumDims() + sym.getPosition()));
- } else {
- results.clear();
- return failure();
- }
- }
- return success();
-}
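The folder only fires when every result of the map is a bare dimension or symbol; in that case the op disappears and its results are replaced by the matching operands, e.g.:

```mlir
// Both results simply forward operands, so the op folds away entirely:
// uses of %r#0 and %r#1 become uses of %i and %j.
%r:2 = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0, s0)>(%i in [0, 7])[%j in [0, 3]]
```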
-
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-void AtomicRMWOp::getAsmResultNames(
- llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
- setNameFn(getResult(), "atomic_rmw");
-}
-
-void AtomicRMWOp::build(OpBuilder& builder, OperationState& result,
- Value tensor, ValueRange ivs) {
- OpBuilder::InsertionGuard g(builder);
- result.addOperands(tensor);
- result.addOperands(ivs);
- result.addTypes(tensor.getType());
-
- auto tensor_type = llvm::cast<RankedTensorType>(tensor.getType());
- Region* body = result.addRegion();
- builder.createBlock(body);
- body->addArgument(tensor_type.getElementType(), tensor.getLoc());
-}
-
-mlir::OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
- auto* body = getBody();
- if (&body->front() == body->getTerminator() &&
- body->front().getOperand(0) == body->getArgument(0)) {
- return getOperand(0);
- }
- return {};
-}
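The fold recognizes a body that immediately yields the current value, i.e. a no-op update. A hypothetical instance that folds to `%t`:

```mlir
%same = xla_gpu.atomic_rmw %t[%i] : tensor<32xf32> {
  ^bb0(%cur: f32):
    xla_gpu.yield %cur : f32
}
```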
-
-//===----------------------------------------------------------------------===//
-// PureCallOp
-//===----------------------------------------------------------------------===//
-
-void PureCallOp::getAsmResultNames(
- llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
- for (auto result : getResults()) {
- setNameFn(result, "pure_call");
- }
-}
-
-//===----------------------------------------------------------------------===//
-// SyncThreadsOp
-//===----------------------------------------------------------------------===//
-
-void SyncThreadsOp::getAsmResultNames(
- llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
- for (auto result : getResults()) {
- setNameFn(result, "synced_tensor");
- }
-}
-
-} // namespace gpu
-} // namespace xla
-
-#define GET_OP_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc"
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h
deleted file mode 100644
index f43786f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_
-
-#include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep
-#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep
-#include "mlir/IR/Attributes.h" // IWYU pragma: keep
-#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep
-#include "mlir/IR/Dialect.h" // IWYU pragma: keep
-#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep
-#include "mlir/IR/OpDefinition.h" // IWYU pragma: keep
-#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
-#include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep
-#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep
-#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" // IWYU pragma: keep
-
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc"
-#define GET_OP_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc"
-#undef GET_OP_CLASSES
-#define GET_ATTRDEF_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc"
-#undef GET_ATTRDEF_CLASSES
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td
deleted file mode 100644
index c05f843..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td
+++ /dev/null
@@ -1,275 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
-
-include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/OpBase.td"
-include "mlir/IR/SymbolInterfaces.td"
-include "mlir/IR/OpAsmInterface.td"
-include "mlir/Interfaces/CallInterfaces.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td"
-
-class XLAGPU_Op<string mnemonic, list<Trait> traits = []> :
- Op<XlaGpuDialect, mnemonic, traits> {
-}
-
-def XLAGPU_AllocateSharedOp : XLAGPU_Op<"allocate_shared", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
- ]> {
- let summary = "Allocates a shared memory tile.";
-
- let description = [{
- Allocates a shared memory tensor. The tensor is shared among all threads in
- a block.
-
- ```mlir
- %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
- ```
- }];
-
- let results = (outs AnyStaticShapeTensor:$result);
-
- let assemblyFormat = "attr-dict `:` type($result)";
-}
-
-def XLAGPU_SyncThreadsOp : XLAGPU_Op<"sync_threads", [
- TypesMatchWith<"result type matches type of dest",
- "operands", "results", "$_self">,
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
- ]> {
- let summary = "Synchronizes threads.";
-
- let description = [{
- Synchronizes threads, taking any number of distributed tensors and returning
- the synchronized state.
- }];
-
- let arguments = (ins Variadic<AnyRankedTensor>:$operands);
- let results = (outs Variadic<AnyRankedTensor>:$results);
-
- let assemblyFormat = "operands attr-dict `:` type($operands)";
-}
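Unlike `allocate_shared`, this definition carries no example; a plausible one, mirroring the convention above (names are illustrative):

```mlir
%synced = xla_gpu.sync_threads %shmem : tensor<32x32xf32>
```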
-
-def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw",
- [Pure,
- TypesMatchWith<"result type matches type of dest",
- "input", "result", "$_self">,
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
- ]> {
- let summary = "Atomically updates an element of a tensor.";
-
- let description = [{
- Reads an element from a tensor, computes the updated value for it, and
- writes back the result.
- }];
-
- let arguments = (ins AnyRankedTensor:$input, Variadic<Index>:$indices);
- let results = (outs AnyRankedTensor:$result);
- // The region takes the current value in the tensor as an argument and yields
- // the updated value.
- let regions = (region SizedRegion<1>:$computation);
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>];
-
- let extraClassDeclaration = [{
- mlir::Block* getBody() { return &getComputation().front(); }
- mlir::OpBuilder getBodyBuilder() {
- return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
- }
- // The value stored in tensor[ivs].
- mlir::Value getCurrentValue() {
- return getRegion().getArgument(0);
- }
- }];
- let hasFolder = 1;
-
- let assemblyFormat = [{
- $input `[` $indices `]` `:` type($input) $computation attr-dict
- }];
-}
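A sketch of the op as implied by the assembly format above, with an additive update in the body (value names are invented):

```mlir
%new = xla_gpu.atomic_rmw %tensor[%i, %j] : tensor<16x16xf32> {
  ^bb0(%current: f32):
    %sum = arith.addf %current, %delta : f32
    xla_gpu.yield %sum : f32
}
```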
-
-def XLAGPU_YieldOp : XLAGPU_Op<"yield",
- [HasParent<"::xla::gpu::AtomicRMWOp">, Terminator]> {
- let summary = "Terminator for atomic_rmw ops.";
- let arguments = (ins AnyType:$result);
-
- let assemblyFormat = "$result attr-dict `:` type($result)";
-}
-
-def XLAGPU_PureCallOp : XLAGPU_Op<"pure_call",
- [Pure, CallOpInterface, DeclareOpInterfaceMethods<SymbolUserOpInterface>,
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
- ]> {
- let summary = "Function call without side effects.";
- let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
- let results = (outs Variadic<AnyType>);
- let builders = [
- OpBuilder<(ins "mlir::func::FuncOp":$callee, CArg<"mlir::ValueRange", "{}">:$operands), [{
- $_state.addOperands(operands);
- $_state.addAttribute("callee", mlir::SymbolRefAttr::get(callee));
- $_state.addTypes(callee.getFunctionType().getResults());
- }]>];
- let assemblyFormat = [{
- $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
- }];
-
- let extraClassDeclaration = [{
- operand_range getArgOperands() {
- return getOperands();
- }
-
- mlir::MutableOperandRange getArgOperandsMutable() {
- return getOperandsMutable();
- }
-
- mlir::CallInterfaceCallable getCallableForCallee() {
- return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
- }
-
- void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
- (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
- }
- }];
-}
-
-def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
- [Pure,
- TypesMatchWith<"result type matches type of operands",
- "operands", "results", "$_self">]> {
- let summary = "Performs a full warp shuffle and reduces the values";
- let description = [{
- This op performs a full warp shuffle and reduces the results using the given
- function. The function is invoked with the operands from the low lanes,
- followed by the operands from the high lanes. For example:
-
- ```
-    shuffle_reduce @argmax(%value, %idx) to 16 : (f32, index)
- ```
-
-  This will perform shuffles with distances 16, 8, 4, 2 and 1, invoking
-  @argmax five times. The first invocation will be
-
- ```
- @argmax(%value[i], %idx[i], %value[16+i], %idx[16+i])
- ```
- }];
- let builders = [
- OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
- $_state.addOperands(operands);
- $_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer));
- $_state.addAttribute("max_distance",
- mlir::IntegerAttr::get(
- mlir::IntegerType::get(reducer.getContext(), 64),
- max_distance));
- $_state.addTypes(reducer.getFunctionType().getResults());
- }]>];
- let arguments = (ins FlatSymbolRefAttr:$reducer,
- Variadic<AnyType>:$operands,
- I64Attr:$max_distance);
- let results = (outs Variadic<AnyType>:$results);
-
- let assemblyFormat = [{
- $reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands)
- }];
-}
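For intuition, the distance-halving pattern described above can be sketched on the host in plain C++. A minimal sketch, assuming a 32-lane warp, an `argmax`-style reducer over (value, index) pairs, and out-of-range peers acting as no-ops; all names here are invented for illustration:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <utility>

// Hypothetical reducer: keeps the (value, index) pair with the larger value.
static std::pair<float, int64_t> ArgMax(std::pair<float, int64_t> lo,
                                        std::pair<float, int64_t> hi) {
  return hi.first > lo.first ? hi : lo;
}

int main() {
  constexpr int kWarpSize = 32;
  std::array<std::pair<float, int64_t>, kWarpSize> lanes;
  for (int i = 0; i < kWarpSize; ++i) {
    lanes[i] = {static_cast<float>((i * 7) % 13), i};
  }

  // `shuffle_reduce ... to 16`: combine lane i with lane i + d for
  // d = 16, 8, 4, 2, 1; after the loop, lane 0 holds the reduced result.
  for (int distance = 16; distance > 0; distance /= 2) {
    for (int i = 0; i < kWarpSize; ++i) {
      if (i + distance < kWarpSize) {
        lanes[i] = ArgMax(lanes[i], lanes[i + distance]);
      }
    }
  }
  std::printf("max=%f at index=%lld\n", lanes[0].first,
              static_cast<long long>(lanes[0].second));
  return 0;
}
```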
-
-def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert",
- [Pure,
- TypesMatchWith<"result type matches type of operands",
- "dest", "result", "$_self">,
- TypesMatchWith<"value type matches element type of dest",
- "dest", "value",
- "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
- let summary = "Inserts a value into a tensor if a condition holds";
- let arguments = (ins I1:$condition, AnyType:$value,
- AnyStaticShapeTensor:$dest, Variadic<Index>:$indices);
- let results = (outs AnyStaticShapeTensor:$result);
-
- let assemblyFormat = [{
- $value `into` $dest `[` $indices `]` `if` $condition attr-dict `:` type($dest)
- }];
-}
-
-def XLAGPU_PredicatedExtractOp : XLAGPU_Op<"predicated_extract",
- [Pure,
- TypesMatchWith<"fallback type matches element type of src",
- "src", "fallback",
- "::llvm::cast<mlir::TensorType>($_self).getElementType()">,
- TypesMatchWith<"result type matches element type of src",
- "src", "result",
- "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
-  let summary = "Extracts a value from a tensor if a condition holds";
- let arguments = (ins I1:$condition, AnyType:$fallback,
- AnyStaticShapeTensor:$src, Variadic<Index>:$indices);
- let results = (outs AnyType:$result);
-
- let assemblyFormat = [{
- $src `[` $indices `]` `if` $condition `else` $fallback attr-dict `:` type($src)
- }];
-}
-
-def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> {
- let summary = "Applies indexing map to a list of SSA values";
- let description = [{
-    The `apply_indexing` operation applies an affine map to a list of SSA
-    values, yielding one SSA value per result of the map. The number of
-    dimension and symbol arguments must be equal to the respective number of
-    dimensional and symbolic inputs in the affine map. Since the affine map
-    can be multi-dimensional, the operation may return more than one value.
-    The operands and results must all have `index` type.
-
- Example:
-
- ```mlir
- #map = affine_map<(d0, d1)[s0] -> (d0 floordiv 8 + d1 floordiv 128, s0)>
-    %results:2 = xla_gpu.apply_indexing #map (%0 in [0, 10], %1 in [0, 11])[%2 in [11, 32]]
- ```
- }];
- let arguments = (ins Variadic<Index>:$operands,
- AffineMapAttr:$map,
- DenseI64ArrayAttr:$lower_bounds,
- DenseI64ArrayAttr:$upper_bounds);
- let results = (outs Variadic<Index>);
-
- let builders = [
- OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols,
- "const IndexingMap&":$indexing_map)>,
- OpBuilder<(ins "mlir::ValueRange":$operands,
- "const IndexingMap&":$indexing_map)>,
- OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map,
- "llvm::ArrayRef<DimVar>":$dim_vars,
- "llvm::ArrayRef<RangeVar>":$range_vars)>,
- OpBuilder<(ins "mlir::ValueRange":$operands,
- "mlir::AffineMap":$affine_map,
- "llvm::ArrayRef<int64_t>":$lower_bounds,
- "llvm::ArrayRef<int64_t>":$upper_bounds)>,
- ];
- let extraClassDeclaration = [{
- // Returns the indexing map constructed from affine_map and the bounds.
- xla::gpu::IndexingMap getIndexingMap();
- // Extracts the affine map from the attribute.
- mlir::AffineMap getAffineMap() { return getMapAttr().getAffineMap(); }
- }];
- let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
- let hasCanonicalizer = 1;
- let hasFolder = 1;
-}
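For intuition, the example map above can be evaluated on the host. A minimal plain-C++ sketch, assuming non-negative operands so that `floordiv` reduces to integer division; `ApplyExampleMap` is an invented name, not part of the dialect:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Evaluates (d0, d1)[s0] -> (d0 floordiv 8 + d1 floordiv 128, s0) for
// non-negative inputs.
static std::array<int64_t, 2> ApplyExampleMap(int64_t d0, int64_t d1,
                                              int64_t s0) {
  return {d0 / 8 + d1 / 128, s0};
}

int main() {
  const auto results = ApplyExampleMap(/*d0=*/9, /*d1=*/130, /*s0=*/11);
  std::printf("%lld %lld\n", static_cast<long long>(results[0]),
              static_cast<long long>(results[1]));  // prints "2 11"
  return 0;
}
```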
-
-#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc
deleted file mode 100644
index 929ee0e..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc
+++ /dev/null
@@ -1,1095 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cassert>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_cat.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DECL_LOWERTENSORSPASS
-#define GEN_PASS_DEF_LOWERTENSORSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-using mlir::failure;
-using mlir::Location;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpBuilder;
-using mlir::Operation;
-using mlir::success;
-using mlir::Type;
-using mlir::TypedValue;
-using mlir::TypeRange;
-using mlir::Value;
-using mlir::ValueRange;
-
-namespace arith = ::mlir::arith;
-namespace scf = ::mlir::scf;
-namespace ml = ::mlir::LLVM;
-
-Value GetDestinationBuffer(Value dest) {
- while (dest.getDefiningOp()) {
- int result_number = mlir::cast<mlir::OpResult>(dest).getResultNumber();
- if (auto insert = dest.getDefiningOp<mlir::tensor::InsertOp>()) {
- dest = insert.getDest();
- } else if (auto scf_if = dest.getDefiningOp<scf::IfOp>()) {
-      // Pick one of the branches; they're required to yield the same buffers.
- dest = scf_if.getThenRegion().front().getTerminator()->getOperand(
- result_number);
- } else if (auto scf_for = dest.getDefiningOp<scf::ForOp>()) {
- dest = scf_for.getInitArgs()[result_number];
- } else if (dest.getDefiningOp<mlir::UnrealizedConversionCastOp>() ||
- dest.getDefiningOp<AllocateSharedOp>()) {
- break;
- } else if (auto transfer_write =
- dest.getDefiningOp<mlir::vector::TransferWriteOp>()) {
- dest = transfer_write.getSource();
- } else {
- dest.getDefiningOp()->emitOpError("unsupported dest type");
- return nullptr;
- }
- }
- return dest;
-}
-
-template <typename Op>
-bool IsSupportedTransfer(Op op) {
- return !absl::c_linear_search(op.getInBoundsValues(), false) &&
- op.getVectorType().getRank() == 1 && !op.getMask() &&
- op.getPermutationMap().isMinorIdentity();
-}
-
-struct RewriteFunctionSignatures : mlir::OpRewritePattern<mlir::func::FuncOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override {
- auto is_tensor = [](Type ty) {
- return mlir::isa<mlir::RankedTensorType>(ty);
- };
- if (!llvm::any_of(op.getFunctionType().getInputs(), is_tensor)) {
- return rewriter.notifyMatchFailure(op,
- "the function has no input tensors");
- }
-
- bool some_tensor_result =
- llvm::any_of(op.getFunctionType().getResults(), is_tensor);
- bool all_tensor_results =
- llvm::all_of(op.getFunctionType().getResults(), is_tensor);
- if (some_tensor_result && !all_tensor_results) {
- op->emitOpError("function has a mix of tensor and non-tensor results");
- return failure();
- }
-
- TypeRange new_results = op.getFunctionType().getResults();
- if (some_tensor_result) {
- new_results = {};
- auto terminator = op.getFunctionBody().front().getTerminator();
- rewriter.setInsertionPoint(terminator);
- rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(terminator);
- }
-
- llvm::SmallVector<Type> new_operands(op.getFunctionType().getInputs());
- for (auto&& [index, operand] : llvm::enumerate(new_operands)) {
- if (is_tensor(operand)) {
- rewriter.setInsertionPointToStart(&op.getBody().front());
- auto cast = rewriter.create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(), operand, op.getArgument(index));
- op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast);
- operand = mlir::LLVM::LLVMPointerType::get(op.getContext());
- }
- }
-
- op.setFunctionType(rewriter.getFunctionType(new_operands, new_results));
- auto& entry = op->getRegion(0).front();
- for (auto [arg, arg_type] : llvm::zip(entry.getArguments(), new_operands)) {
- arg.setType(arg_type);
- }
-
- return success();
- }
-};
-
-Value GetLinearIndex(TypedValue<mlir::RankedTensorType> tensor,
- ValueRange indices, mlir::PatternRewriter& rewriter) {
- auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
- if (auto encoding = tensor.getType().getEncoding()) {
- *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
- mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
- }
- auto linear_shape =
- ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
- auto linearize_map =
- GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
- mlir::SmallVector<Value> result;
- rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
- ValueRange{}, linearize_map);
- CHECK_EQ(result.size(), 1);
- auto index = result.front();
- auto index_ty = rewriter.getIntegerType(
- mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp())
- .getTypeSizeInBits(index.getType()));
- return rewriter.create<mlir::arith::IndexCastUIOp>(tensor.getLoc(), index_ty,
- index);
-}
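`GetLinearIndex` flattens multi-dimensional indices into a linear element index according to the tensor's minor-to-major layout (taken from the encoding when present). A standalone sketch of that flattening, independent of the MLIR machinery; `LinearizeIndex` is an invented helper:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Flattens `indices` for a tensor with the given dimension sizes and
// minor-to-major layout (XLA-style: layout[0] is the fastest-varying dim).
static int64_t LinearizeIndex(const std::vector<int64_t>& dims,
                              const std::vector<int64_t>& minor_to_major,
                              const std::vector<int64_t>& indices) {
  int64_t linear = 0;
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    linear += indices[dim] * stride;
    stride *= dims[dim];
  }
  return linear;
}

int main() {
  // A 2x3 tensor with row-major layout {1, 0}: element (1, 2) -> 1 * 3 + 2.
  std::printf("%lld\n", static_cast<long long>(
                            LinearizeIndex({2, 3}, {1, 0}, {1, 2})));  // 5
  return 0;
}
```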
-
-std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
- mlir::ImplicitLocOpBuilder& b) {
- Value one = b.create<mlir::arith::ConstantIntOp>(1, linear_index.getType());
- Value is_low_nibble = b.create<mlir::arith::CmpIOp>(
- mlir::arith::CmpIPredicate::eq, one,
- b.create<mlir::arith::AndIOp>(linear_index, one));
- Value i8_index = b.create<mlir::arith::ShRUIOp>(linear_index, one);
- return {i8_index, is_low_nibble};
-}
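The split above is just `index / 2` for the byte and the low bit for the nibble selector; with the packing order used here, odd element indices land in the low nibble. A tiny standalone restatement (not the MLIR lowering itself):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Returns {byte_index, is_low_nibble} for a packed-i4 element index,
// mirroring the shift/and logic above.
static std::pair<int64_t, bool> I4IndexAndNibble(int64_t linear_index) {
  return {linear_index >> 1, (linear_index & 1) == 1};
}

int main() {
  assert((I4IndexAndNibble(5) == std::pair<int64_t, bool>{2, true}));
  assert((I4IndexAndNibble(4) == std::pair<int64_t, bool>{2, false}));
  return 0;
}
```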
-
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
- Value linear_index, mlir::PatternRewriter& rewriter,
- Type element_type = nullptr) {
- if (!element_type) {
- element_type = tensor.getType().getElementType();
- }
- auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
- auto tensor_ptr = rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- tensor.getLoc(), ptr, tensor)
- .getResult(0);
- mlir::LLVMTypeConverter converter(rewriter.getContext());
- auto llvm_element_type = converter.convertType(element_type);
- auto gep = rewriter.create<mlir::LLVM::GEPOp>(
- tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, linear_index);
- gep.setInbounds(true);
- return gep;
-}
-
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
- ValueRange indices,
- mlir::PatternRewriter& rewriter) {
- return CreateGep(tensor, GetLinearIndex(tensor, indices, rewriter), rewriter);
-}
-
-struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- mlir::tensor::ExtractOp op,
- mlir::PatternRewriter& rewriter) const override {
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto linear_index =
- GetLinearIndex(op.getTensor(), op.getIndices(), rewriter);
- Type element_type = op.getTensor().getType().getElementType();
- Value is_low_nibble = nullptr;
- if (element_type == rewriter.getI4Type()) {
- element_type = rewriter.getI8Type();
- std::tie(linear_index, is_low_nibble) =
- GetI4IndexAndNibble(linear_index, b);
- }
-
- auto gep = CreateGep(op.getTensor(), linear_index, rewriter, element_type);
- auto load =
- rewriter
- .create<mlir::LLVM::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
- .getResult();
-
- if (is_low_nibble) {
- auto high_value = b.create<mlir::arith::ShRUIOp>(
- load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
- load = b.create<mlir::arith::TruncIOp>(
- op.getType(),
- b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
- }
-
- rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
- op, op.getType(), load);
- return success();
- }
-};
-
-// Swaps pairs of values in the vector: [0, 1, 2, 3] -> [1, 0, 3, 2].
-Value PermutePairsInVector(Value vector, mlir::ImplicitLocOpBuilder& b) {
- // There is a `vector.extract_strided_slice` op that would be useful here, but
- // it actually requires the strides to be 1.
- auto ty = mlir::cast<mlir::VectorType>(vector.getType());
- int size = ty.getNumElements();
- Value result = vector;
- for (int i = 0; i < size; i += 2) {
- auto v0 = b.create<mlir::vector::ExtractOp>(vector, i);
- auto v1 = b.create<mlir::vector::ExtractOp>(vector, i + 1);
- result = b.create<mlir::vector::InsertOp>(v1, result, i);
- result = b.create<mlir::vector::InsertOp>(v0, result, i + 1);
- }
- return result;
-}
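The pair swap compensates for LLVM and XLA packing i4 elements into bytes in opposite nibble order. A plain-C++ restatement of the same reshuffle on already-unpacked elements, for illustration only:

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Swaps adjacent pairs: [a, b, c, d] -> [b, a, d, c]. With i4 data, this
// reorders the two nibbles of every byte after a load / before a store.
static void PermutePairs(std::vector<int>& v) {
  for (size_t i = 0; i + 1 < v.size(); i += 2) std::swap(v[i], v[i + 1]);
}

int main() {
  std::vector<int> v = {0, 1, 2, 3};
  PermutePairs(v);
  for (int x : v) std::printf("%d ", x);  // prints "1 0 3 2"
  std::printf("\n");
  return 0;
}
```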
-
-struct RewriteTransferRead
- : mlir::OpRewritePattern<mlir::vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- mlir::vector::TransferReadOp op,
- mlir::PatternRewriter& rewriter) const override {
- assert(IsSupportedTransfer(op));
-
- auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
- op.getSource());
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto linear_index = GetLinearIndex(source, op.getIndices(), rewriter);
-
- mlir::VectorType vector_type = op.getVectorType();
- if (vector_type.getElementType().isInteger(1)) {
- vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
- }
- mlir::Type gep_element_type = vector_type.getElementType();
- if (op.getVectorType().getElementType().isInteger(4)) {
- linear_index = b.create<arith::ShRUIOp>(
- linear_index,
- b.create<arith::ConstantIntOp>(1, linear_index.getType()));
- gep_element_type = b.getI8Type();
- }
- auto gep = CreateGep(source, linear_index, rewriter, gep_element_type);
-
- mlir::LLVMTypeConverter converter(b.getContext());
- auto llvm_vector_type = converter.convertType(vector_type);
- auto loaded =
- b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
-
- if (source.getType().getElementType().isInteger(1)) {
- Value zero = b.create<mlir::arith::ConstantOp>(
- mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
- loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
- } else if (source.getType().getElementType().isInteger(4)) {
- // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
- // elements.
- loaded = PermutePairsInVector(loaded, b);
- }
-
- rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
- op, op.getType(), loaded);
- return success();
- }
-};
-
-struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- mlir::tensor::InsertOp op,
- mlir::PatternRewriter& rewriter) const override {
- Value dest = GetDestinationBuffer(op.getDest());
- if (!dest) {
- return failure();
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
- auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
- auto element_type = tensor_dest.getType().getElementType();
- Value is_low_nibble = nullptr;
-
- if (element_type == rewriter.getI4Type()) {
- element_type = rewriter.getI8Type();
- std::tie(linear_index, is_low_nibble) =
- GetI4IndexAndNibble(linear_index, b);
- }
-
- auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
- auto scalar_value = op.getScalar();
-
- if (is_low_nibble) {
- Value current_value =
- b.create<mlir::LLVM::LoadOp>(gep.getElemType(), gep);
- auto ty = current_value.getType();
- scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);
- Value low_updated = b.create<mlir::arith::OrIOp>(
- b.create<mlir::arith::AndIOp>(
- current_value, b.create<mlir::arith::ConstantIntOp>(0xf0, ty)),
- scalar_value);
- Value high_updated = b.create<mlir::arith::OrIOp>(
- b.create<mlir::arith::AndIOp>(
- current_value, b.create<mlir::arith::ConstantIntOp>(0x0f, ty)),
- b.create<mlir::arith::ShLIOp>(
- scalar_value, b.create<mlir::arith::ConstantIntOp>(4, ty)));
- scalar_value = b.create<mlir::arith::SelectOp>(is_low_nibble, low_updated,
- high_updated);
- }
-
- mlir::LLVMTypeConverter converter(getContext());
- auto llvm_type = converter.convertType(scalar_value.getType());
- scalar_value = rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- gep.getLoc(), llvm_type, scalar_value)
- .getResult(0);
- rewriter.create<mlir::LLVM::StoreOp>(gep.getLoc(), scalar_value, gep);
-
- op.replaceAllUsesWith(op.getDest());
- op.erase();
- return success();
- }
-};
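The i4 branch above is a read-modify-write on a packed byte: load it, mask off the target nibble, and OR in the new value (shifted for the high nibble). A standalone sketch of that update, assuming the same packing order as in the extract path; `StoreI4` is an invented helper:

```cpp
#include <cstdint>
#include <cstdio>

// Writes a 4-bit value into element `index` of a packed-i4 buffer, where odd
// indices occupy the low nibble of byte index / 2 and even indices the high.
static void StoreI4(uint8_t* bytes, int64_t index, uint8_t value) {
  uint8_t& byte = bytes[index >> 1];
  const bool is_low_nibble = (index & 1) == 1;
  value &= 0x0f;
  byte = is_low_nibble ? static_cast<uint8_t>((byte & 0xf0) | value)
                       : static_cast<uint8_t>((byte & 0x0f) | (value << 4));
}

int main() {
  uint8_t buffer[2] = {0, 0};
  StoreI4(buffer, 0, 0xA);  // high nibble of byte 0
  StoreI4(buffer, 1, 0xB);  // low nibble of byte 0
  StoreI4(buffer, 2, 0xC);  // high nibble of byte 1
  std::printf("%02x %02x\n", buffer[0], buffer[1]);  // prints "ab c0"
  return 0;
}
```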
-
-struct RewriteTransferWrite
- : mlir::OpRewritePattern<mlir::vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- mlir::vector::TransferWriteOp op,
- mlir::PatternRewriter& rewriter) const override {
- assert(IsSupportedTransfer(op));
- Value dest = GetDestinationBuffer(op.getSource());
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
- auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
- auto element_type = tensor_dest.getType().getElementType();
-
- mlir::Value vector_value = op.getVector();
- if (op.getVectorType().getElementType().isInteger(1)) {
- vector_value = b.create<arith::ExtUIOp>(
- op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
- vector_value);
- }
- if (op.getVectorType().getElementType().isInteger(4)) {
- linear_index = b.create<arith::ShRUIOp>(
- linear_index,
- b.create<arith::ConstantIntOp>(1, linear_index.getType()));
- element_type = rewriter.getI8Type();
- // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
- // elements.
- vector_value = PermutePairsInVector(vector_value, b);
- }
- auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
-
- mlir::LLVMTypeConverter converter(getContext());
- auto llvm_type = converter.convertType(vector_value.getType());
- vector_value =
- b.create<mlir::UnrealizedConversionCastOp>(llvm_type, vector_value)
- .getResult(0);
- b.create<mlir::LLVM::StoreOp>(vector_value, gep);
-
- rewriter.replaceOp(op, mlir::ValueRange{op.getSource()});
- return success();
- }
-};
-
-struct RewriteCall : mlir::OpRewritePattern<mlir::func::CallOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- mlir::func::CallOp op, mlir::PatternRewriter& rewriter) const override {
- if (!llvm::any_of(op->getOperandTypes(), [](Type ty) {
- return mlir::isa<mlir::RankedTensorType>(ty);
- })) {
- return rewriter.notifyMatchFailure(op, "the call has no input tensors");
- }
-
- for (const auto&& [index, arg] : llvm::enumerate(op.getOperands())) {
- if (mlir::isa<mlir::RankedTensorType>(arg.getType())) {
- op.setOperand(
- index,
- rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(),
- mlir::LLVM::LLVMPointerType::get(op.getContext()), arg)
- .getResult(0));
- }
- }
- return success();
- }
-};
-
-mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
- const std::string& name_prefix,
- mlir::ShapedType shaped_ty,
- mlir::ModuleOp module, bool is_constant,
- int addr_space,
- mlir::ImplicitLocOpBuilder& b) {
- if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) {
- // The lowering to LLVM only works for 1d tensors or those with trailing
- // unit dimensions.
- value = elements.reshape(mlir::RankedTensorType::get(
- {elements.getNumElements()}, elements.getElementType()));
- }
-
- Type element_type = shaped_ty.getElementType();
- // Needed to support complex element type.
- mlir::LLVMTypeConverter converter(b.getContext());
- auto llvm_element_type = converter.convertType(element_type);
- auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type,
- shaped_ty.getNumElements());
- std::string name;
- int index = 0;
- do {
- name = absl::StrCat(name_prefix, index);
- ++index;
- } while (module.lookupSymbol(name));
- b.setInsertionPointToStart(module.getBody());
- return b.create<mlir::LLVM::GlobalOp>(
- array_ty, is_constant,
- /*linkage=*/mlir::LLVM::Linkage::Private, name, value, /*alignment=*/0,
- addr_space);
-}
-
-struct RewriteAllocateShared : mlir::OpRewritePattern<AllocateSharedOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- AllocateSharedOp op, mlir::PatternRewriter& rewriter) const override {
- auto module = op->getParentOfType<mlir::ModuleOp>();
- auto shaped_ty = mlir::cast<mlir::ShapedType>(op.getResult().getType());
- constexpr int kGPUSharedMemoryAddrSpace = 3;
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-
- auto global =
- CreateGlobalOp(mlir::Attribute{}, "shared_", shaped_ty, module,
- /*is_constant=*/false, kGPUSharedMemoryAddrSpace, b);
-
- rewriter.setInsertionPoint(op);
- auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
- rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
- op, op.getResult().getType(),
- rewriter
- .create<mlir::LLVM::AddrSpaceCastOp>(
- op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
- addr)
- .getResult());
- return success();
- }
-};
-
-struct RewriteNonScalarConstants
- : mlir::OpRewritePattern<mlir::arith::ConstantOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- mlir::arith::ConstantOp op,
- mlir::PatternRewriter& rewriter) const override {
- if (mlir::isa<mlir::VectorType>(op.getType())) {
- return rewriter.notifyMatchFailure(op, "the op is a vector constant");
- }
- auto shaped_ty = mlir::dyn_cast<mlir::ShapedType>(op.getValue().getType());
- // We only need to rewrite non-scalar constants.
- if (!shaped_ty || shaped_ty.getNumElements() < 2) {
- return rewriter.notifyMatchFailure(
- op, "the op is an effective scalar constant");
- }
-
- constexpr int kGPUGlobalMemoryAddrSpace = 0;
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto module = op->getParentOfType<mlir::ModuleOp>();
- auto global =
- CreateGlobalOp(op.getValue(), "global_cst_", shaped_ty, module,
- /*is_constant=*/true, kGPUGlobalMemoryAddrSpace, b);
-
- rewriter.setInsertionPoint(op);
- auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
- rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
- op, op.getResult().getType(),
- rewriter
- .create<mlir::LLVM::AddrSpaceCastOp>(
- op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
- addr)
- .getResult());
- return success();
- }
-};
-
-struct RewriteSyncThreads : mlir::OpRewritePattern<SyncThreadsOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- SyncThreadsOp op, mlir::PatternRewriter& rewriter) const override {
- rewriter.create<mlir::gpu::BarrierOp>(op.getLoc());
- rewriter.replaceOp(op, op.getOperands());
- return success();
- }
-};
-
-// TODO(jreiffers): Generalize this to support index switches with some used
-// results and upstream it as a canonicalization pattern.
-struct RemoveUnusedIndexSwitchResults
- : mlir::OpRewritePattern<scf::IndexSwitchOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(
- scf::IndexSwitchOp op, mlir::PatternRewriter& rewriter) const override {
- if (op->getNumResults() == 0 || !op->use_empty()) {
-      return rewriter.notifyMatchFailure(
-          op, "the op has no results or has users");
- }
-
- auto new_op = rewriter.create<scf::IndexSwitchOp>(
- op.getLoc(), mlir::TypeRange{}, op.getArg(), op.getCases(),
- op.getNumCases());
- for (int i = 0; i < op->getNumRegions(); ++i) {
- auto& old_region = op->getRegion(i);
- auto& new_region = new_op->getRegion(i);
- rewriter.mergeBlocks(&old_region.getBlocks().front(),
- &new_region.emplaceBlock());
- auto yield_op = new_region.getBlocks().front().getTerminator();
- rewriter.modifyOpInPlace(yield_op, [&]() { yield_op->setOperands({}); });
- }
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-bool IsAtomicIntegral(Type element_type) {
- if (!element_type.isInteger()) {
- return false;
- }
- unsigned element_bitwidth = element_type.getIntOrFloatBitWidth();
- return element_bitwidth == 32 || element_bitwidth == 64;
-}
-
-Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) {
- if (value.getType().isIntOrFloat() && ty.isIntOrFloat()) {
- return b.create<ml::BitcastOp>(ty, value);
- }
-
- mlir::LLVMTypeConverter converter(b.getContext());
- // If either type is a complex, we need to go through an alloca, since no
- // direct bitcast from a struct to an int is possible.
- Type llvm_input_ty = converter.convertType(value.getType());
- Type llvm_result_ty = converter.convertType(ty);
- Type ptr_ty = mlir::LLVM::LLVMPointerType::get(b.getContext());
-
- Value llvm_value =
- b.create<mlir::UnrealizedConversionCastOp>(llvm_input_ty, value)
- .getResult(0);
- Value alloca = b.create<ml::AllocaOp>(
- ptr_ty, llvm_input_ty, b.create<ml::ConstantOp>(b.getI32Type(), 1));
- b.create<ml::StoreOp>(llvm_value, alloca);
- auto result = b.create<ml::LoadOp>(llvm_result_ty, alloca).getResult();
- return b.create<mlir::UnrealizedConversionCastOp>(ty, result).getResult(0);
-}
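`CreateBitcast` goes through an alloca because LLVM has no direct bitcast between an aggregate (such as the struct used for complex) and an integer. On the host the equivalent reinterpretation is done with `memcpy`; a small sketch, assuming the usual two-float layout of `std::complex<float>`:

```cpp
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterprets the bits of a complex<float> (a two-float aggregate) as a
// 64-bit integer, the moral equivalent of the store/load through an alloca.
static uint64_t BitsOf(std::complex<float> value) {
  static_assert(sizeof(value) == sizeof(uint64_t), "unexpected layout");
  uint64_t bits = 0;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

int main() {
  std::printf("%016llx\n",
              static_cast<unsigned long long>(BitsOf({1.0f, -2.0f})));
  return 0;
}
```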
-
-class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
- public:
- RewriteAtomicRMW(mlir::MLIRContext* context, bool is_amd,
- const std::string& gpu_arch)
- : mlir::OpRewritePattern<AtomicRMWOp>(context),
- is_amd_(is_amd),
- gpu_arch_(gpu_arch) {}
-
- LogicalResult matchAndRewrite(
- AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override {
- if (failed(rewriteAsDirectAtomicRMW(op, rewriter))) {
- rewriteAsAtomicCAS(op, rewriter);
- }
- rewriter.replaceOp(op, op.getInput());
- return success();
- }
-
- private:
- // Returns atomic op modifier and the atomic bin op kind.
- std::optional<std::pair<Value, ml::AtomicBinOp>> GetAtomicModifierParameters(
- AtomicRMWOp op) const {
- Type element_type = op.getInput().getType().getElementType();
- auto& operations = op.getBody()->getOperations();
- auto terminator = op.getBody()->getTerminator();
- if (operations.size() > 2) {
- return std::nullopt;
- }
- // If the body contains only the terminator, then it is an atomic store.
- if (operations.size() == 1) {
- // TODO(b/336367145): Support complex<f32> atomic store.
- if (element_type.isF32() || IsAtomicIntegral(element_type)) {
- return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg);
- }
- return std::nullopt;
- }
- // Match the kind of the atomic op.
- mlir::Operation* modifier_op = &operations.front();
- std::optional<ml::AtomicBinOp> kind =
- llvm::TypeSwitch<Operation*, std::optional<ml::AtomicBinOp>>(
- modifier_op)
- // Floating-point operations.
- .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; })
- .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; })
- .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; })
- // Integer operations.
- .Case([&](arith::AddIOp op) {
- return IsAtomicIntegral(element_type)
- ? std::make_optional(ml::AtomicBinOp::add)
- : std::nullopt;
- })
- .Case([&](arith::MaxUIOp op) {
- return IsAtomicIntegral(element_type)
- ? std::make_optional(ml::AtomicBinOp::umax)
- : std::nullopt;
- })
- .Case([&](arith::MinUIOp op) {
- return IsAtomicIntegral(element_type)
- ? std::make_optional(ml::AtomicBinOp::umin)
- : std::nullopt;
- })
- .Case([&](arith::MaxSIOp op) {
- return IsAtomicIntegral(element_type)
- ? std::make_optional(ml::AtomicBinOp::max)
- : std::nullopt;
- })
- .Case([&](arith::MinSIOp op) {
- return IsAtomicIntegral(element_type)
- ? std::make_optional(ml::AtomicBinOp::min)
- : std::nullopt;
- })
- .Default([](Operation* op) { return std::nullopt; });
- if (!kind.has_value()) {
- return std::nullopt;
- }
- // Find the modifier arg that does not match the argument of `atomic_rmw`
- // body.
- Value block_arg = op.getBody()->getArgument(0);
- Value modifier_arg = modifier_op->getOperand(0) == block_arg
- ? modifier_op->getOperand(1)
- : modifier_op->getOperand(0);
- return std::make_pair(modifier_arg, *kind);
- }
-
- // Certain computations, such as floating-point addition and integer
- // maximization, can be simply implemented using an LLVM atomic instruction.
-  // If the "computation" region is of this kind, emits code to do that and
-  // returns success; otherwise, returns failure.
- LogicalResult rewriteAsDirectAtomicRMW(
- AtomicRMWOp op, mlir::PatternRewriter& rewriter) const {
- auto modifier_parameters = GetAtomicModifierParameters(op);
- if (!modifier_parameters.has_value()) {
- return failure();
- }
- Value modifier_arg = modifier_parameters->first;
- Type element_type = modifier_arg.getType();
- ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;
-
- Location loc = op.getLoc();
- llvm::StringRef sync_scope = is_amd_ ? "agent" : "";
- Value addr = CreateGep(op.getInput(), op.getIndices(), rewriter);
-
- switch (atomic_bin_op) {
- case ml::AtomicBinOp::xchg: {
- rewriter.create<ml::StoreOp>(
- loc, modifier_arg, addr,
- /*alignment=*/element_type.getIntOrFloatBitWidth() / 8,
- /*volatile*/ false, /*isNonTemporal=*/false,
- ml::AtomicOrdering::unordered);
- return success();
- }
- case ml::AtomicBinOp::add:
- case ml::AtomicBinOp::max:
- case ml::AtomicBinOp::min:
- case ml::AtomicBinOp::umax:
- case ml::AtomicBinOp::umin: {
- rewriter.create<ml::AtomicRMWOp>(loc, atomic_bin_op, addr, modifier_arg,
- ml::AtomicOrdering::seq_cst,
- sync_scope);
- return success();
- }
- case ml::AtomicBinOp::fadd: {
- // TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr.
- return is_amd_ ? emitAMDAtomicFAdd(loc, modifier_arg, addr, sync_scope,
- gpu_arch_, rewriter)
- : emitNVidiaAtomicFAdd(loc, modifier_arg, addr,
- sync_scope, gpu_arch_, rewriter);
- }
- case ml::AtomicBinOp::fmax: {
- return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr,
- sync_scope, rewriter);
- }
- default:
- return failure();
- }
- return success();
- }
-
- LogicalResult emitNVidiaAtomicFAdd(Location loc, Value modifier_arg,
- Value addr, llvm::StringRef sync_scope,
- llvm::StringRef cuda_arch,
- OpBuilder& b) const {
- se::CudaComputeCapability cuda_compute_capability(cuda_arch.str());
- Type element_type = modifier_arg.getType();
- // "atom.add.f64 requires sm_60 or higher."
- // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
- bool is_supported_f16_atomic =
- element_type.isF16() &&
- cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA);
- bool is_supported_bf16_atomic =
- element_type.isBF16() &&
- cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER);
- bool is_supported_f64_atomic =
- element_type.isF64() &&
- cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_);
- if (!element_type.isF32() && !is_supported_f16_atomic &&
- !is_supported_bf16_atomic && !is_supported_f64_atomic) {
- return failure();
- }
- b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
- ml::AtomicOrdering::seq_cst, sync_scope);
- return success();
- }
-
- LogicalResult emitAMDAtomicFAdd(Location loc, Value modifier_arg, Value addr,
- llvm::StringRef sync_scope,
- llvm::StringRef gcn_arch,
- OpBuilder& b) const {
- se::RocmComputeCapability rocm_compute_capability(gcn_arch.str());
- Type element_type = modifier_arg.getType();
- bool is_supported_f16_atomic =
- element_type.isF16() &&
- rocm_compute_capability.has_fp16_atomics_support();
- if (!element_type.isF32() && !is_supported_f16_atomic) {
- return failure();
- }
- constexpr int kGlobalMemory = 1;
- constexpr int kSharedMemory = 3;
- auto addr_type = mlir::cast<ml::LLVMPointerType>(addr.getType());
-    // Additions to shared memory are always atomic.
- if (addr_type.getAddressSpace() != kSharedMemory) {
- // The compiler will only generate a global_atomic_fadd if the pointer is
- // in global addrspace (1)
- addr = b.create<ml::AddrSpaceCastOp>(
- loc, ml::LLVMPointerType::get(b.getContext(), kGlobalMemory), addr);
- }
- b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
- ml::AtomicOrdering::seq_cst, sync_scope);
- return success();
- }
-
- LogicalResult rewriteAtomicFMaxAsIntAtomics(Location loc, Value modifier_arg,
- Value addr,
- llvm::StringRef sync_scope,
- OpBuilder& b) const {
- Type element_type = modifier_arg.getType();
- if (!element_type.isF32()) {
- return failure();
- }
- // Evaluating floating max using integer atomics has the limitation of not
- // propagating -NaNs. To handle this, we check if the update value is -NaN
- // and convert it to a positive one by dropping the sign-bit.
- Value current = b.create<ml::LoadOp>(loc, element_type, addr);
- Value current_is_nan =
- b.create<ml::FCmpOp>(loc, ml::FCmpPredicate::uno, current, current);
- auto is_current_nan =
- b.create<scf::IfOp>(loc, /*resultTypes=*/TypeRange{}, current_is_nan,
- /*addThenBlock=*/true, /*addElseBlock=*/true);
- auto if_current_nan_then_builder =
- OpBuilder::atBlockEnd(is_current_nan.thenBlock(), b.getListener());
- if_current_nan_then_builder.create<scf::YieldOp>(loc);
-
- auto if_current_nan_else_builder =
- OpBuilder::atBlockEnd(is_current_nan.elseBlock(), b.getListener());
- Value is_modifier_nan = if_current_nan_else_builder.create<ml::FCmpOp>(
- loc, ml::FCmpPredicate::uno, modifier_arg, modifier_arg);
- auto f32_nan = mlir::APFloat::getNaN(mlir::APFloat::IEEEsingle());
- Value nan = if_current_nan_else_builder.create<ml::ConstantOp>(
- loc, b.getF32Type(), f32_nan);
- Value no_negative_nan_source =
- if_current_nan_else_builder.create<ml::SelectOp>(loc, is_modifier_nan,
- nan, modifier_arg);
- Value current_less_than_modifier =
- if_current_nan_else_builder.create<ml::FCmpOp>(
- loc, ml::FCmpPredicate::ult, current, no_negative_nan_source);
-
-    // This check allows us to skip the atomic update altogether at the
- // expense of reading the value in memory for every update. Evaluated
- // against Waymo's benchmarks, adding the check achieves better overall
- // performance.
- auto if_need_update = if_current_nan_else_builder.create<scf::IfOp>(
- loc, /*resultTypes=*/TypeRange{}, current_less_than_modifier,
- /*withElseRegion=*/true,
- /*addElseBlock=*/false);
- if_current_nan_else_builder.create<scf::YieldOp>(loc);
-
- auto then_builder =
- OpBuilder::atBlockEnd(if_need_update.thenBlock(), b.getListener());
- Value source_float_as_int = then_builder.create<ml::BitcastOp>(
- loc, then_builder.getI32Type(), no_negative_nan_source);
- Value c0 = then_builder.create<ml::ConstantOp>(loc, b.getI32Type(), 0);
- Value is_not_negative = then_builder.create<ml::ICmpOp>(
- loc, ml::ICmpPredicate::sge, source_float_as_int, c0);
- then_builder.create<scf::IfOp>(
- loc, is_not_negative,
- [&](OpBuilder& nested_b, Location nested_loc) {
- // atomicMax((int *)address, __float_as_int(val))
- nested_b.create<ml::AtomicRMWOp>(
- loc, ml::AtomicBinOp::max, addr, source_float_as_int,
- ml::AtomicOrdering::seq_cst, sync_scope);
- nested_b.create<scf::YieldOp>(nested_loc);
- },
- [&](OpBuilder& nested_b, Location nested_loc) {
- // atomicMax((int *)address, __float_as_int(val))
- nested_b.create<ml::AtomicRMWOp>(
- loc, ml::AtomicBinOp::umin, addr, source_float_as_int,
- ml::AtomicOrdering::seq_cst, sync_scope);
- nested_b.create<scf::YieldOp>(nested_loc);
- });
- then_builder.create<scf::YieldOp>(loc);
- return success();
- }
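The sign-dependent choice between a signed max and an unsigned min above is the classic way to build a float atomic max on top of integer atomics: non-negative floats order like their bit patterns viewed as signed integers, while among negative floats the larger value has the smaller unsigned bit pattern. A host-side sketch of the same idea using `std::atomic` and C++20 `std::bit_cast` (NaN handling is omitted, since the pass deals with it separately):

```cpp
#include <atomic>
#include <bit>
#include <cstdint>
#include <cstdio>

// Emulates a float atomic max on top of integer compare-and-swap. Assumes
// `update` is not NaN.
static void AtomicFMax(std::atomic<int32_t>& storage, float update) {
  const int32_t bits = std::bit_cast<int32_t>(update);
  if (bits >= 0) {
    // Non-negative floats order like their bit patterns viewed as signed ints.
    int32_t current = storage.load();
    while (current < bits && !storage.compare_exchange_weak(current, bits)) {
    }
  } else {
    // Negative updates: the larger float has the smaller *unsigned* pattern,
    // and any stored positive value already beats a negative update.
    uint32_t current = static_cast<uint32_t>(storage.load());
    const uint32_t ubits = static_cast<uint32_t>(bits);
    while (ubits < current) {
      int32_t expected = static_cast<int32_t>(current);
      if (storage.compare_exchange_weak(expected, bits)) break;
      current = static_cast<uint32_t>(expected);
    }
  }
}

int main() {
  std::atomic<int32_t> cell(std::bit_cast<int32_t>(-3.5f));
  AtomicFMax(cell, 2.0f);
  AtomicFMax(cell, -1.0f);
  std::printf("%f\n", std::bit_cast<float>(cell.load()));  // prints 2.000000
  return 0;
}
```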
-
- // Implements atomic binary operations using atomic compare-and-swap
- // (atomicCAS) as follows:
- // 1. Reads the value from the memory pointed to by output_address and
- // records it as old_output.
- // 2. Uses old_output as one of the source operand to perform the binary
- // operation and stores the result in new_output.
- // 3. Calls atomicCAS which implements compare-and-swap as an atomic
- // operation. In particular, atomicCAS reads the value from the memory
- // pointed to by output_address, and compares the value with old_output.
- // If the two values equal, new_output is written to the same memory
- // location and true is returned to indicate that the atomic operation
- // succeeds. Otherwise, the new value read from the memory is returned. In
- // this case, the new value is copied to old_output, and steps 2. and 3.
- // are repeated until atomicCAS succeeds.
- //
- // On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers.
- // If the element type of the binary operation is 32 bits or 64 bits, the
- // integer type of the same size is used for the atomicCAS operation. On the
- // other hand, if the element type is smaller than 32 bits, int32_t is used
- // for the atomicCAS operation. In this case, atomicCAS reads and writes 32
- // bit values from the memory, which is larger than the memory size required
- // by the original atomic binary operation. We mask off the last two bits of
- // the output_address and use the result as an address to read the 32 bit
- // values from the memory. This can avoid out of bound memory accesses if
- // tensor buffers are 4 byte aligned and have a size of 4N, an assumption that
- // the runtime can guarantee.
- void rewriteAsAtomicCAS(AtomicRMWOp op,
- mlir::PatternRewriter& rewriter) const {
- Location loc = op.getLoc();
- auto input = op.getInput();
-
- // Use 32-bit atomic type for small input types.
- Type result_ty = op.getResult().getType().getElementType();
- int result_size;
- if (auto complex_ty = mlir::dyn_cast<mlir::ComplexType>(result_ty)) {
- result_size = complex_ty.getElementType().getIntOrFloatBitWidth() * 2;
- } else {
- result_size = result_ty.getIntOrFloatBitWidth();
- }
-
- bool small_type = result_size < 32;
- Type atomic_ty =
- mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size);
-
- // Calculate load address for the input.
- Value addr = CreateGep(input, op.getIndices(), rewriter);
- Value shift, mask;
- if (small_type) {
- // Update input pointer by discarding the last two bits - i.e. align to
- // 32-bit boundary for small input types (will not result in OOB, as the
- // input alignment is at least 32 bits).
- Type addr_int_ty = rewriter.getI64Type();
- Value addr_int = rewriter.create<ml::PtrToIntOp>(loc, addr_int_ty, addr);
- Value addr_offset = rewriter.create<ml::AndOp>(
- loc, addr_int, rewriter.create<ml::ConstantOp>(loc, addr_int_ty, 3));
- Value index = rewriter.create<ml::MulOp>(
- loc, addr_offset,
- rewriter.create<ml::ConstantOp>(loc, addr_int_ty, -1));
- addr =
- rewriter.create<ml::GEPOp>(loc, addr.getType(), rewriter.getI8Type(),
- addr, index, /*inbounds=*/true);
-
- // Calculate the bit shift (assume little-endianness).
- Value offset = rewriter.create<ml::TruncOp>(loc, atomic_ty, addr_offset);
- shift = rewriter.create<ml::MulOp>(
- loc, offset,
- rewriter.create<ml::ConstantOp>(loc, offset.getType(), 8));
-
- // Compose the update mask.
- Value bits_long = rewriter.create<ml::ConstantOp>(loc, atomic_ty, -1);
- Value bits_short = rewriter.create<ml::ZExtOp>(
- loc, atomic_ty,
- rewriter.create<ml::ConstantOp>(
- loc, rewriter.getIntegerType(result_size), -1));
- mask = rewriter.create<ml::XOrOp>(
- loc, bits_long, rewriter.create<ml::ShlOp>(loc, bits_short, shift));
- }
-
- // Load initial atomic value and create the loop.
- Value initial = rewriter.create<ml::LoadOp>(loc, atomic_ty, addr);
- rewriter.create<scf::WhileOp>(
- loc, TypeRange{atomic_ty}, ValueRange{initial},
- [&](mlir::OpBuilder& builder, Location loc, ValueRange values) {
- mlir::ImplicitLocOpBuilder b(loc, builder);
- Value old_value = values[0];
-
- // Convert atomic value to input value.
- Value input_value;
- if (small_type) {
- Value short_value =
- b.create<ml::TruncOp>(b.getIntegerType(result_size),
- b.create<ml::LShrOp>(old_value, shift));
- input_value = b.create<ml::BitcastOp>(result_ty, short_value);
- } else {
- input_value = CreateBitcast(b, old_value, result_ty);
- }
-
- // Perform computation on the loaded input value.
- rewriter.mergeBlocks(&op.getComputation().front(), b.getBlock(),
- {input_value});
- auto yield_op = b.getBlock()->getTerminator();
- Value result = yield_op->getOperand(0);
- rewriter.eraseOp(yield_op);
-
- // Convert resulting value to atomic value.
- Value new_value;
- if (small_type) {
- Value cast_value = b.create<ml::ZExtOp>(
- atomic_ty, b.create<ml::BitcastOp>(
- rewriter.getIntegerType(result_size), result));
- new_value =
- b.create<ml::OrOp>(b.create<ml::AndOp>(old_value, mask),
- b.create<ml::ShlOp>(cast_value, shift));
- } else {
- new_value = CreateBitcast(b, result, atomic_ty);
- }
-
- // Try saving the result atomically, retry if failed.
- Value cmpxchg = b.create<ml::AtomicCmpXchgOp>(
- loc, addr, old_value, new_value,
- /*success_ordering=*/ml::AtomicOrdering::seq_cst,
- /*failure_ordering=*/ml::AtomicOrdering::seq_cst);
- Value next = b.create<ml::ExtractValueOp>(cmpxchg, 0);
- Value ok = b.create<ml::ExtractValueOp>(cmpxchg, 1);
- Value low_bit = b.create<ml::ConstantOp>(b.getOneAttr(b.getI1Type()));
- Value not_ok = b.create<ml::XOrOp>(ok, low_bit);
- b.create<scf::ConditionOp>(not_ok, ValueRange{next});
- },
- [&](mlir::OpBuilder& b, Location loc, ValueRange values) {
- b.create<scf::YieldOp>(loc, values);
- });
- }
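The loop above can be mimicked on the host with `std::atomic`: pick the containing 32-bit word, mask and shift the small lane in and out, and retry the compare-and-swap until it sticks. A sketch under those assumptions (it illustrates the idea, not the generated LLVM IR):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>

// Applies `update` atomically to the 16-bit lane at byte offset `offset`
// (0 or 2) of a 32-bit word, mirroring the mask/shift/cmpxchg loop above.
template <typename Update>
void AtomicRmw16(std::atomic<uint32_t>& word, int offset, Update update) {
  const uint32_t shift = offset * 8;
  const uint32_t mask = ~(uint32_t{0xffff} << shift);
  uint32_t old_word = word.load();
  while (true) {
    const uint16_t old_value = static_cast<uint16_t>(old_word >> shift);
    const uint16_t new_value = update(old_value);
    const uint32_t new_word =
        (old_word & mask) | (uint32_t{new_value} << shift);
    // On failure, compare_exchange refreshes old_word with the value it saw,
    // so the loop simply retries against the newly observed word.
    if (word.compare_exchange_weak(old_word, new_word)) break;
  }
}

int main() {
  std::atomic<uint32_t> word(0x11112222u);
  AtomicRmw16(word, 0, [](uint16_t v) { return static_cast<uint16_t>(v + 1); });
  AtomicRmw16(word, 2, [](uint16_t v) { return static_cast<uint16_t>(v * 2); });
  std::printf("%08x\n", word.load());  // prints "22222223"
  return 0;
}
```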
-
- bool is_amd_;
- std::string gpu_arch_;
-};
-
-class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
- public:
- explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
- : LowerTensorsPassBase(options) {}
-
- void runOnOperation() override {
- MLIRContext* mlir_context = &getContext();
- mlir::RewritePatternSet tensor_patterns(mlir_context);
- tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
- tensor_patterns
- .add<RewriteAllocateShared, RewriteNonScalarConstants,
- RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
- RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
- getOperation(), std::move(tensor_patterns)))) {
- signalPassFailure();
- return;
- }
-
- mlir::RewritePatternSet function_patterns(mlir_context);
- function_patterns.add<RewriteFunctionSignatures, RewriteCall,
- RemoveUnusedIndexSwitchResults>(mlir_context);
- scf::ForOp::getCanonicalizationPatterns(function_patterns, mlir_context);
- scf::IfOp::getCanonicalizationPatterns(function_patterns, mlir_context);
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
- getOperation(), std::move(function_patterns)))) {
- signalPassFailure();
- return;
- }
-
- getOperation()->walk([this](mlir::LLVM::LoadOp load) {
- Value addr = load.getAddr();
- while (auto gep = addr.getDefiningOp<mlir::LLVM::GEPOp>()) {
- addr = gep.getBase();
- }
- if (addr.getDefiningOp<mlir::LLVM::AddrSpaceCastOp>() ||
- addr.getDefiningOp<mlir::LLVM::AddressOfOp>() ||
- addr.getDefiningOp<mlir::LLVM::AllocaOp>()) {
- // Shared memory, global constant or temporary - no need to annotate
- // anything.
- return;
- }
- if (auto base = mlir::dyn_cast<mlir::BlockArgument>(addr)) {
- if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(
- base.getOwner()->getParentOp())) {
- if (func.getArgAttr(base.getArgNumber(), "xla.invariant")) {
- load.setInvariant(true);
- }
- return;
- }
- }
- load.emitOpError("load op address is not (a GEP of) a function argument");
- signalPassFailure();
- });
- }
-};
-
-} // namespace
-
-std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
- bool is_amd_gpu, const std::string& gpu_arch) {
- LowerTensorsPassOptions options;
- options.is_amd_gpu_ = is_amd_gpu;
- options.gpu_arch_ = gpu_arch;
- return std::make_unique<LowerTensorsPass>(options);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
deleted file mode 100644
index 6e05ead..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <utility>
-
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
-#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
-#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // IWYU pragma: keep
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // IWYU pragma: keep
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_LOWERTOLLVMPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
- public:
- using LowerToLLVMPassBase::LowerToLLVMPassBase;
-
- void runOnOperation() override {
- // Populate type conversions.
- mlir::LowerToLLVMOptions llvm_opts(&getContext(),
- mlir::DataLayout(getOperation()));
- mlir::LLVMTypeConverter type_converter(getOperation().getContext(),
- llvm_opts);
- mlir::LLVMConversionTarget target(*getOperation().getContext());
-
- // Populate patterns.
- mlir::RewritePatternSet patterns(&getContext());
- mlir::populateAffineToStdConversionPatterns(patterns);
- mlir::populateSCFToControlFlowConversionPatterns(patterns);
- mlir::arith::populateArithExpandOpsPatterns(patterns);
- mlir::arith::populateArithToLLVMConversionPatterns(type_converter,
- patterns);
- mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
- mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
- mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
- mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter,
- patterns);
- mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns);
- mlir::populateMathToLLVMConversionPatterns(type_converter, patterns);
-
- // Setup target.
- mlir::configureGpuToNVVMConversionLegality(target);
- target.addIllegalDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
- mlir::complex::ComplexDialect,
- mlir::math::MathDialect>();
- target.addLegalOp<mlir::ModuleOp>();
-
- if (failed(
- applyFullConversion(getOperation(), target, std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass() {
- return std::make_unique<LowerToLLVMPass>();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc
deleted file mode 100644
index 9028480..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc
+++ /dev/null
@@ -1,205 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <utility>
-
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-using mlir::success;
-
-struct RewritePredicatedInsert : mlir::OpRewritePattern<PredicatedInsertOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- PredicatedInsertOp op, mlir::PatternRewriter& rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
- op, op.getCondition(),
- [&](mlir::OpBuilder& b, mlir::Location loc) {
- b.create<mlir::scf::YieldOp>(
- loc, b.create<mlir::tensor::InsertOp>(
- loc, op.getValue(), op.getDest(), op.getIndices())
- .getResult());
- },
- [&](mlir::OpBuilder& b, mlir::Location loc) {
- b.create<mlir::scf::YieldOp>(loc, op.getDest());
- });
- return success();
- }
-};
-
-struct RewritePredicatedExtract : mlir::OpRewritePattern<PredicatedExtractOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- PredicatedExtractOp op, mlir::PatternRewriter& rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
- op, op.getCondition(),
- [&](mlir::OpBuilder& b, mlir::Location loc) {
- b.create<mlir::scf::YieldOp>(
- loc, b.create<mlir::tensor::ExtractOp>(loc, op.getSrc(),
- op.getIndices())
- .getResult());
- },
- [&](mlir::OpBuilder& b, mlir::Location loc) {
- b.create<mlir::scf::YieldOp>(loc, op.getFallback());
- });
- return success();
- }
-};
-
-struct RewriteShuffleReduce : mlir::OpRewritePattern<ShuffleReduceOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override {
- int max_distance =
- mlir::cast<mlir::IntegerAttr>(op->getAttr("max_distance")).getInt();
- // TODO(jreiffers): Do this in a verifier.
- if (max_distance & (max_distance - 1) || max_distance >= WarpSize()) {
- return op->emitOpError("max_distance must be a power of 2 < WarpSize()");
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- mlir::ValueRange values = op.getOperands();
- for (int distance = max_distance; distance > 0; distance /= 2) {
- namespace ml = mlir::LLVM;
- auto shuffle_32 = [&](mlir::Value v) {
- return b
- .create<mlir::gpu::ShuffleOp>(v, distance, WarpSize(),
- mlir::gpu::ShuffleMode::DOWN)
- .getShuffleResult();
- };
-
- auto shuffle_int_or_float = [&](mlir::Value value) {
- auto ty = value.getType();
- int bit_width = ty.getIntOrFloatBitWidth();
- if (bit_width == 32) {
- return shuffle_32(value);
- }
- int n_shuffles = CeilOfRatio(bit_width, 32);
- auto int_ty = b.getIntegerType(bit_width);
- auto padded_int_ty = b.getIntegerType(n_shuffles * 32);
- value = b.create<mlir::arith::BitcastOp>(int_ty, value);
- value = b.create<mlir::arith::ExtUIOp>(padded_int_ty, value);
- if (n_shuffles > 1) {
- // Don't generate vectors if the size is 1.
- auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles);
- value = b.create<ml::BitcastOp>(vector_type, value);
- mlir::Value result_vec = b.create<ml::UndefOp>(vector_type);
- for (int i = 0; i < n_shuffles; ++i) {
- auto idx = b.create<mlir::arith::ConstantIntOp>(i, 32);
- result_vec = b.create<ml::InsertElementOp>(
- result_vec,
- shuffle_32(b.create<ml::ExtractElementOp>(value, idx)), idx);
- }
- value = b.create<ml::BitcastOp>(padded_int_ty, result_vec);
- } else {
- value = shuffle_32(value);
- }
- value = b.create<mlir::arith::TruncIOp>(int_ty, value);
- value = b.create<ml::BitcastOp>(ty, value);
- return value;
- };
-
- auto shuffle = [&](mlir::Value value) -> mlir::Value {
- if (mlir::isa<mlir::ComplexType>(value.getType())) {
- return b.create<mlir::complex::CreateOp>(
- value.getType(),
- shuffle_int_or_float(b.create<mlir::complex::ReOp>(value)),
- shuffle_int_or_float(b.create<mlir::complex::ImOp>(value)));
- }
- if (value.getType().isUnsignedInteger()) {
- auto ty = value.getType();
- auto signless_ty = b.getIntegerType(ty.getIntOrFloatBitWidth());
- value = b.create<mlir::UnrealizedConversionCastOp>(
- mlir::TypeRange{signless_ty}, value)
- .getResult(0);
- value = shuffle_int_or_float(value);
- value = b.create<mlir::UnrealizedConversionCastOp>(
- mlir::TypeRange{ty}, value)
- .getResult(0);
- return value;
- }
- return shuffle_int_or_float(value);
- };
-
- llvm::SmallVector<mlir::Value> args = values;
- for (auto value : values) {
- args.push_back(shuffle(value));
- }
- values = b.create<PureCallOp>(op.getResultTypes(),
- op.getReducerAttr().getAttr(), args)
- .getResults();
- }
- rewriter.replaceOp(op, values);
- return success();
- }
-};
-
-class LowerXlaGpuToScfPass
- : public impl::LowerXlaGpuToScfPassBase<LowerXlaGpuToScfPass> {
- public:
- void runOnOperation() override {
- mlir::RewritePatternSet patterns(&getContext());
- patterns.add<RewritePredicatedInsert, RewritePredicatedExtract,
- RewriteShuffleReduce>(&getContext());
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() {
- return std::make_unique<LowerXlaGpuToScfPass>();
-}
-
-} // namespace gpu
-} // namespace xla
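For reference, a minimal standalone C++ sketch of the shuffle-down tree reduction that RewriteShuffleReduce lowers to. The array-based warp model, the ShuffleDown helper, and the use of addition as the reducer are illustrative stand-ins, not XLA or CUDA APIs.

#include <array>
#include <cstdio>

constexpr int kWarpSize = 32;

// Stand-in for gpu.shuffle down: returns the value held by the lane `distance`
// steps above `lane`, or the lane's own value if that lane does not exist.
int ShuffleDown(const std::array<int, kWarpSize>& warp, int lane, int distance) {
  int src = lane + distance;
  return src < kWarpSize ? warp[src] : warp[lane];
}

int main() {
  std::array<int, kWarpSize> warp;
  for (int i = 0; i < kWarpSize; ++i) warp[i] = i + 1;  // Reduce 1..32.

  // Start at a power-of-two distance below the warp size and halve it each
  // round, combining every lane with the lane `distance` steps away.
  for (int distance = kWarpSize / 2; distance > 0; distance /= 2) {
    std::array<int, kWarpSize> next = warp;
    for (int lane = 0; lane < kWarpSize; ++lane) {
      next[lane] = warp[lane] + ShuffleDown(warp, lane, distance);
    }
    warp = next;
  }
  // Lane 0 now holds the full reduction: 1 + 2 + ... + 32 = 528.
  std::printf("%d\n", warp[0]);
  return 0;
}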
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc b/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc
deleted file mode 100644
index c1899d2..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <optional>
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "llvm/ADT/BitVector.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class MergePointersToSameSlicePass
- : public impl::MergePointersToSameSlicePassBase<
- MergePointersToSameSlicePass> {
- public:
- void runOnOperation() override;
-};
-
-struct PackedArgs {
- llvm::BitVector args_to_erase;
- // replacement_args[i] == i iff !args_to_erase[i].
- llvm::SmallVector<int> replacement_args;
-
- PackedArgs() = default;
- explicit PackedArgs(mlir::func::FuncOp func) {
- absl::flat_hash_map<int, std::optional<int>> slice_to_operand;
- args_to_erase.resize(func.getNumArguments());
- replacement_args.reserve(func.getNumArguments());
- for (int i = 0; i < func.getNumArguments(); ++i) {
- replacement_args.push_back(i);
- }
-
- for (auto [idx, operand] : llvm::enumerate(func.getArguments())) {
- auto slice_index = func.getArgAttr(idx, "xla.slice_index");
- if (!slice_index) {
- continue;
- }
-
- auto& target_index = slice_to_operand[static_cast<int>(
- mlir::cast<mlir::IntegerAttr>(slice_index).getInt())];
- if (target_index) {
- replacement_args[idx] = *target_index;
- args_to_erase[idx] = true;
- } else {
- target_index = idx;
- }
- }
- }
-
- void Pack(mlir::func::FuncOp op) {
- for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
- if (replacement_args[idx] != idx) {
- arg.replaceAllUsesWith(op.getArgument(replacement_args[idx]));
- }
- }
- op.eraseArguments(args_to_erase);
- for (int i = 0; i < op.getNumArguments(); ++i) {
- if (op.getArgAttr(i, "xla.slice_index")) {
- op.removeArgAttr(i, "xla.slice_index");
- op.setArgAttr(i, mlir::LLVM::LLVMDialect::getNoAliasAttrName(),
- mlir::UnitAttr::get(op->getContext()));
- }
- }
- }
-
- void Pack(mlir::func::CallOp op) { op->eraseOperands(args_to_erase); }
-};
-
-void MergePointersToSameSlicePass::runOnOperation() {
- mlir::func::FuncOp entry;
-
- absl::flat_hash_map<std::string, PackedArgs> args_to_pack;
- getOperation()->walk([&](mlir::func::FuncOp func) {
- args_to_pack[func.getName()] = PackedArgs(func);
- });
- getOperation()->walk([&](mlir::func::CallOp call) {
- args_to_pack[call.getCallee()].Pack(call);
- });
- getOperation()->walk([&](mlir::func::FuncOp func) {
- args_to_pack[func.getName()].Pack(func);
- });
-}
-
-} // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-CreateMergePointersToSameSlicePass() {
- return std::make_unique<MergePointersToSameSlicePass>();
-}
-
-} // namespace gpu
-} // namespace xla
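A minimal C++ sketch of the argument-packing bookkeeping that PackedArgs performs: arguments sharing a slice index are redirected to the first argument seen for that slice and marked for erasure. Plain integers stand in for MLIR block arguments and attributes.

#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
  // Slice index per function argument; -1 means "no xla.slice_index attribute".
  std::vector<int> slice_index = {0, 1, 0, -1, 1};

  std::vector<int> replacement(slice_index.size());
  std::vector<bool> erase(slice_index.size(), false);
  std::unordered_map<int, int> first_arg_for_slice;

  for (int i = 0; i < static_cast<int>(slice_index.size()); ++i) {
    replacement[i] = i;
    if (slice_index[i] < 0) continue;
    auto [it, inserted] = first_arg_for_slice.emplace(slice_index[i], i);
    if (!inserted) {                // Another argument already covers this slice:
      replacement[i] = it->second;  // redirect all uses to it...
      erase[i] = true;              // ...and erase this argument.
    }
  }

  for (int i = 0; i < static_cast<int>(replacement.size()); ++i) {
    std::printf("arg %d -> arg %d%s\n", i, replacement[i],
                erase[i] ? " (erased)" : "");
  }
  return 0;
}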
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
index 251d3ff..524d365 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
@@ -83,11 +83,11 @@
#include "xla/service/buffer_assignment.h"
#include "xla/service/dump.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
#include "xla/service/gpu/fusions/mlir/type_util.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/kernel_arguments.h"
#include "xla/service/gpu/kernel_reuse_cache.h"
@@ -222,7 +222,7 @@
absl::StatusOr<FusionEmissionResult> MlirFusionEmitterBase::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
- VLOG(5) << "Fusion: " << fusion.fused_instructions_computation()->ToString();
+ VLOG(4) << "Fusion: " << fusion.fused_instructions_computation()->ToString();
TF_ASSIGN_OR_RETURN(
auto args,
KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion));
@@ -305,13 +305,14 @@
mlir::PassManager pm(&mlir_context);
pm.addPass(CreateEraseDeadFunctionsPass());
pm.addPass(mlir::createCSEPass());
- pm.addPass(CreateLowerXlaGpuToScfPass());
+ pm.addNestedPass<mlir::func::FuncOp>(CreateLowerXlaGpuToScfPass());
pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
// CSE after inlining because inlining can introduce duplicates.
pm.addPass(mlir::createCSEPass());
}));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
+ pm.addNestedPass<mlir::func::FuncOp>(CreateLowerXlaGpuLoopsToScfPass());
pm.addPass(mlir::mhlo::createConvertToSignlessPass());
pm.addPass(CreatePropagateSliceIndicesPass());
// We need LICM before unswitching loops, because our loop unswitcher only
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc
deleted file mode 100644
index 6d5456f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc
+++ /dev/null
@@ -1,315 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <algorithm>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/log/check.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/IR/Visitors.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_OPTIMIZELOOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-mlir::Value GetSource(mlir::vector::TransferReadOp op) {
- return op.getSource();
-}
-
-bool DoIndicesDependOnInductionVar(mlir::ValueRange indices,
- mlir::scf::ForOp loop) {
- // We assume LICM ran, so we can just check if any index is defined in the
- // loop.
- return absl::c_any_of(indices, [&](mlir::Value v) {
- return v.getParentRegion() == &loop.getBodyRegion();
- });
-}
-
-bool CanReplaceInductionVar(mlir::ValueRange indices) {
- return absl::c_all_of(indices, [&](mlir::Value v) {
- if (mlir::isa<mlir::BlockArgument>(v)) {
- return true;
- }
- auto* op = v.getDefiningOp();
- return op &&
- mlir::isa<mlir::arith::ConstantOp, ApplyIndexingOp,
- mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
- mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(
- op) &&
- CanReplaceInductionVar(op->getOperands());
- });
-}
-
-llvm::SmallVector<mlir::Value> ReplaceInductionVar(
- mlir::Value induction_var, mlir::Value replacement,
- llvm::SmallVector<mlir::Value> indices,
- mlir::ImplicitLocOpBuilder& builder) {
- for (mlir::Value& index : indices) {
- if (mlir::isa<mlir::BlockArgument>(index)) {
- if (index == induction_var) {
- index = replacement;
- }
- continue;
- }
-
- auto* op = index.getDefiningOp();
- CHECK(op) << "Did CanReplaceInductionVar() fail?";
- if (mlir::isa<mlir::arith::ConstantOp>(op)) {
- continue;
- }
-
- CHECK(
- (mlir::isa<ApplyIndexingOp, mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
- mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(op)))
- << "Did CanReplaceInductionVar() fail?";
- auto replaced_args = ReplaceInductionVar(induction_var, replacement,
- op->getOperands(), builder);
- index = builder
- .create(builder.getLoc(), op->getName().getIdentifier(),
- replaced_args, op->getResultTypes(), op->getAttrs())
- ->getResult(0);
- }
- return indices;
-}
-
-mlir::Value GetSource(mlir::tensor::ExtractOp op) { return op.getTensor(); }
-
-// TODO(jreiffers): Use a shared memory queue for pipelining instead of
-// registers.
-template <typename Op>
-struct PipelineLoad : mlir::OpRewritePattern<Op> {
- using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- Op op, mlir::PatternRewriter& rewriter) const override {
- auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
- if (!loop) {
- return rewriter.notifyMatchFailure(op, "no loop found");
- }
-
- if (auto step = loop.getConstantStep();
- !step || step->getSExtValue() != 1) {
- return rewriter.notifyMatchFailure(op, "loop step is not 1");
- }
-
- llvm::APInt lb, ub;
- if (!mlir::matchPattern(loop.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
- !mlir::matchPattern(loop.getUpperBound(), mlir::m_ConstantInt(&ub))) {
- return rewriter.notifyMatchFailure(op, "bounds are not constants");
- }
- if (lb.getSExtValue() != 0) {
- return rewriter.notifyMatchFailure(op, "lower bound is not 0");
- }
-
- auto source = GetSource(op);
- if (!source.getParentRegion()->isProperAncestor(&loop.getBodyRegion())) {
- return rewriter.notifyMatchFailure(
- op, "source is not defined outside the loop");
- }
-
- if (!DoIndicesDependOnInductionVar(op.getIndices(), loop)) {
- // We don't run LICM between iterations, so this could happen.
- // Just hoist the load out of the loop.
- rewriter.moveOpBefore(op, loop);
- return mlir::success();
- }
-
- if (!CanReplaceInductionVar(op.getIndices())) {
- return rewriter.notifyMatchFailure(op, "unable to replace indices");
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- mlir::Value zero = b.create<mlir::arith::ConstantIndexOp>(0);
-
- b.setInsertionPoint(loop);
- auto first_args =
- ReplaceInductionVar(loop.getInductionVar(), zero, op.getOperands(), b);
- auto loaded_first =
- b.create<Op>(op->getResultTypes(), first_args, op->getAttrs());
- auto ub_minus_one =
- b.create<mlir::arith::ConstantIndexOp>(ub.getSExtValue() - 1);
-
- b.setInsertionPointToStart(loop.getBody());
-
- auto needs_load = b.create<mlir::arith::CmpIOp>(
- mlir::arith::CmpIPredicate::ult, loop.getInductionVar(), ub_minus_one);
- auto next_value =
- b.create<mlir::scf::IfOp>(op->getResultTypes(), needs_load, true, true);
- auto new_for =
- mlir::cast<mlir::scf::ForOp>(*loop.replaceWithAdditionalYields(
- rewriter, loaded_first->getResult(0),
- /*replaceInitOperandUsesInLoop=*/false,
- [&](mlir::OpBuilder&, mlir::Location,
- llvm::ArrayRef<mlir::BlockArgument>) {
- return llvm::SmallVector<mlir::Value>{next_value->getResult(0)};
- }));
- rewriter.replaceAllUsesWith(op, new_for.getRegionIterArgs().back());
-
- b.setInsertionPointToStart(next_value.thenBlock());
- auto yield = b.create<mlir::scf::YieldOp>(op->getResult(0));
-
- // We use this convoluted way to add 1 so folding works properly.
- auto plus_one_map = mlir::AffineMap::get(
- 1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1);
- b.setInsertionPoint(next_value);
- auto induction_plus_one =
- b.create<ApplyIndexingOp>(new_for.getInductionVar(), plus_one_map, 0,
- ub.getSExtValue() - 1)
- ->getResult(0);
-
- // Create the new apply_indexing ops outside the if, to improve CSE.
- rewriter.modifyOpInPlace(op, [&]() {
- op->setOperands(ReplaceInductionVar(
- new_for.getInductionVar(), induction_plus_one, op->getOperands(), b));
- });
- rewriter.moveOpBefore(op, yield);
-
- b.setInsertionPointToStart(next_value.elseBlock());
- b.create<mlir::scf::YieldOp>(new_for.getRegionIterArgs().back());
- return mlir::success();
- }
-};
-
-int GetUnrollingFactor(mlir::scf::ForOp op) {
- // We only unroll loops with a step of 1 and a lower bound of 0. That's the
- // only type we generate.
- if (auto step = op.getConstantStep(); !step || step->getSExtValue() != 1) {
- return 1;
- }
- llvm::APInt lb, ub;
- if (!mlir::matchPattern(op.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
- !mlir::matchPattern(op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
- return 1;
- }
- if (lb.getSExtValue() != 0) {
- return 1;
- }
-
- int64_t trip_count = ub.getSExtValue();
- constexpr int kMaxSize = 400; // Chosen empirically.
-
- // Get a rough estimate of the size of the loop body.
- int64_t size = 0;
- op.getBodyRegion().walk([&](mlir::Operation* op) {
- if (mlir::isa<mlir::func::CallOp, mlir::scf::ForOp>(op)) {
- size += kMaxSize;
- return;
- }
-
- int64_t this_size = 1;
- if (mlir::isa<mlir::math::MathDialect>(op->getDialect())) {
- // Integer instructions in math are ok, but many float ops lower to lots
- // of instructions.
- if (!op->getResultTypes().front().isIntOrIndex()) {
- namespace mm = mlir::math;
- // We err on the side of not unrolling, so we maintain a list of ops
- // known to be cheap.
- if (!mlir::isa<mm::AbsFOp, mm::CeilOp, mm::CopySignOp, mm::FloorOp,
- mm::FmaOp, mm::RoundEvenOp, mm::RoundOp, mm::RsqrtOp,
- mm::SqrtOp, mm::TruncOp>(op)) {
- this_size = 20; // Rough estimate.
- }
- }
- }
-
- if (!op->getResultTypes().empty()) {
- if (auto vector_ty =
- mlir::dyn_cast<mlir::VectorType>(op->getResultTypes().front())) {
- this_size *= vector_ty.getNumElements();
- }
- }
-
- size += this_size;
- });
-
- int factor = std::min(trip_count, kMaxSize / size);
- while (factor > 1 && trip_count % factor) {
- --factor;
- }
- return factor;
-}
-
-struct UnrollLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
- using mlir::OpRewritePattern<mlir::scf::ForOp>::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
- if (int factor = GetUnrollingFactor(op); factor > 1) {
- return mlir::loopUnrollByFactor(op, factor);
- }
- return rewriter.notifyMatchFailure(op, "loop can't be unrolled");
- }
-};
-
-class OptimizeLoopsPass
- : public impl::OptimizeLoopsPassBase<OptimizeLoopsPass> {
- public:
- void runOnOperation() override {
- // First unroll loops. If unrolling is possible, we prefer it.
- mlir::RewritePatternSet unroll_patterns(&getContext());
- unroll_patterns.add<UnrollLoops>(&getContext());
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
- getOperation(), std::move(unroll_patterns)))) {
- signalPassFailure();
- return;
- }
-
- // Then pipeline the remaining loops.
- mlir::RewritePatternSet patterns(&getContext());
- patterns.add<PipelineLoad<mlir::vector::TransferReadOp>,
- PipelineLoad<mlir::tensor::ExtractOp>>(&getContext());
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateOptimizeLoopsPass() {
- return std::make_unique<OptimizeLoopsPass>();
-}
-
-} // namespace gpu
-} // namespace xla
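A minimal C++ sketch of the unroll-factor heuristic in GetUnrollingFactor: cap the unrolled body at a fixed size budget and shrink the factor until it divides the trip count evenly. The body_size argument is a placeholder for the rough per-op size estimate that the pass computes by walking the loop body.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int PickUnrollFactor(int64_t trip_count, int64_t body_size) {
  constexpr int64_t kMaxSize = 400;  // Budget for the size of the unrolled body.
  if (body_size <= 0 || body_size > kMaxSize) return 1;
  int64_t factor = std::min(trip_count, kMaxSize / body_size);
  // Prefer factors that divide the trip count, so no epilogue loop is needed.
  while (factor > 1 && trip_count % factor != 0) --factor;
  return static_cast<int>(factor);
}

int main() {
  std::printf("%d\n", PickUnrollFactor(/*trip_count=*/12, /*body_size=*/70));  // 4
  std::printf("%d\n", PickUnrollFactor(/*trip_count=*/7, /*body_size=*/90));   // 1
  return 0;
}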
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.h b/third_party/xla/xla/service/gpu/fusions/mlir/passes.h
deleted file mode 100644
index bb0f1d4..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/passes.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_
-
-#include <memory>
-#include <optional>
-#include <string>
-
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DECL
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-// Returns the range of a given value, if it can be statically determined.
-std::optional<Interval> GetRange(mlir::Value value);
-
-// Returns the range for the induction variable, if it can be statically
-// determined.
-std::optional<Interval> GetIVRange(mlir::Value iv);
-
-std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
-std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere);
-std::unique_ptr<mlir::Pass> CreateConvertPureCallOpsPass();
-std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
-std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
- bool is_amd_gpu = false, const std::string& gpu_arch = "6.0");
-std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass();
-std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass();
-std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
-std::unique_ptr<mlir::Pass> CreateOptimizeLoopsPass();
-std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass();
-std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass();
-std::unique_ptr<mlir::Pass> CreateSimplifyArithPass();
-std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass();
-std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass();
-
-#define GEN_PASS_REGISTRATION
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.td b/third_party/xla/xla/service/gpu/fusions/mlir/passes.td
deleted file mode 100644
index 6785670..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/passes.td
+++ /dev/null
@@ -1,290 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_
-
-include "mlir/Pass/PassBase.td"
-
-def PropagateSliceIndicesPass :
- Pass<"xla-gpu-propagate-slice-indices", "mlir::ModuleOp"> {
- let summary = "Propagates slice indices from the entry function to all callees.";
-
- let description = [{
- Propagates xla.slice_index attributes from the function with the xla.entry
- attribute to all other functions.
- }];
-
- let dependentDialects = [
- "mlir::func::FuncDialect"
- ];
-
- let constructor = "CreatePropagateSliceIndicesPass()";
-}
-
-def ConvertPureCallOpsPass
- : Pass<"xla-gpu-convert-pure-call-ops", "mlir::func::FuncOp"> {
- let summary = "Converts xla_gpu.pure_call to func.call";
- let description = [{
- We use xla_gpu.pure_call ops for calls to enable CSE and other
- transformations (e.g. LICM). This pass rewrites our custom ops to standard
- ops.
- }];
- let dependentDialects = [
- "mlir::func::FuncDialect",
- "xla::gpu::XlaGpuDialect"
- ];
- let constructor = "CreateConvertPureCallOpsPass()";
-}
-
-def FlattenTensorsPass : Pass<"xla-gpu-flatten-tensors", "mlir::ModuleOp"> {
- let summary = "Flatten tensors.";
-
- let description = [{
- Linearizes all tensor loads and stores.
- }];
-
- let dependentDialects = [
- "mlir::func::FuncDialect",
- "mlir::tensor::TensorDialect",
- "xla::gpu::XlaGpuDialect",
- ];
- let constructor = "CreateFlattenTensorsPass()";
-}
-
-def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
- let summary = "Lowers tensors to llvm pointers and loads/stores.";
-
- let description = [{
- Lowers tensors to LLVM. We cannot use the memref lowerings because they
- are not compatible with XLA's ABI.
- }];
-
- let dependentDialects = [
- "mlir::LLVM::LLVMDialect",
- "mlir::func::FuncDialect",
- "mlir::gpu::GPUDialect",
- "mlir::scf::SCFDialect",
- "mlir::tensor::TensorDialect",
- "xla::gpu::XlaGpuDialect",
- ];
- let options = [
- Option<"is_amd_gpu_", "is_amd_gpu", "bool", /*default=*/"false",
- "True if AMD GPU.">,
- Option<"gpu_arch_", "gpu_arch", "std::string", /*default=*/"",
- "CUDA or ROCm compute capability.">,
- ];
- let constructor = "CreateLowerTensorsPass()";
-}
-
-def MergePointersToSameSlicePass :
- Pass<"xla-gpu-merge-pointers", "mlir::ModuleOp"> {
- let summary = "Merges pointers that share slices.";
-
- let description = [{
- When a function has multiple pointer arguments with the same slice index,
- merges them.
- }];
-
- let dependentDialects = [
- "mlir::func::FuncDialect"
- ];
-
- let constructor = "CreateMergePointersToSameSlicePass()";
-}
-
-def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::func::FuncOp"> {
- let summary = "Simplifies arith using XLA's range-aware simplifier.";
-
- let description = [{
- We often emit bounds checks that are statically known to be satisfied.
- This pass removes them.
- }];
-
- let dependentDialects = [
- "mlir::arith::ArithDialect",
- "mlir::func::FuncDialect",
- ];
-
- let constructor = "CreateSimplifyArithPass()";
-}
-
-def SimplifyAffinePass : Pass<"xla-gpu-simplify-affine", "mlir::ModuleOp"> {
- let summary = "Simplifies affine.apply using XLA's range-aware simplifier.";
-
- let description = [{
- The standard affine canonicalizer cannot simplify all expressions, since
- it is unaware of range information. This pass uses `xla.range` attributes
- on arguments and ops for simplification. It also lowers floordiv and mod
- to simpler expressions than lower-affine. This pass only works for
- expressions for which we can prove the LHS of mod and div is nonnegative.
- }];
-
- let dependentDialects = [
- "mlir::affine::AffineDialect", "mlir::func::FuncDialect",
- "mlir::scf::SCFDialect",
- ];
-
- let constructor = "CreateSimplifyAffinePass()";
-}
-
-def ExpandFloatOpsPass : Pass<"xla-gpu-expand-float-ops", "mlir::ModuleOp"> {
- let summary = "Expands float ops that are not natively supported.";
-
- let description = [{
- Not all float ops are natively supported, either because they don't exist
- in hardware or they are too inaccurate.
-
- This pass replaces these ops with alternative implementations.
- }];
-
- let dependentDialects = [
- "mlir::arith::ArithDialect", "mlir::math::MathDialect",
- "mlir::mhlo::MhloDialect"
- ];
-
- let options = [
- Option<"pre_ampere_", "pre-ampere", "bool", /*default=*/"false",
- "Rewrite ops that are not supported on architectures before Ampere">,
- ];
-}
-
-def LowerXlaGpuToScfPass :
- Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::ModuleOp"> {
- let summary = "Lowers xla_gpu to SCF.";
-
- let dependentDialects = [
- "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect",
- "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect",
- ];
-
- let constructor = "CreateLowerXlaGpuToScfPass()";
-}
-
-def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> {
- let summary = "Deletes unused functions";
-
- let description = [{
- Deletes functions that are not called.
- }];
-
- let dependentDialects = [
- "mlir::func::FuncDialect",
- "xla::gpu::XlaGpuDialect"
- ];
-
- let constructor = "CreateEraseDeadFunctionsPass()";
-}
-
-def LowerToLLVMPass :
- Pass<"xla-gpu-lower-to-llvm", "mlir::ModuleOp"> {
- let summary = "Lowers to LLVM.";
-
- let description = [{
- Lowers the remaining operations to LLVM.
- }];
-
- let dependentDialects = [
- "mlir::func::FuncDialect",
- "mlir::LLVM::LLVMDialect",
- "mlir::NVVM::NVVMDialect",
- ];
-
- let constructor = "CreateLowerToLLVMPass()";
-}
-
-def VectorizeLoadsAndStoresPass :
- Pass<"xla-gpu-vectorize-loads-stores", "mlir::func::FuncOp"> {
- let summary = "Vectorizes loads and stores.";
-
- let description = [{
- Rewrites tensor.extract and tensor.insert ops inside loops to their vector
- equivalents (vector.transfer_read and vector.transfer_write + vector.extract
- and vector.insert).
- }];
-
- let dependentDialects = [
- "mlir::vector::VectorDialect",
- ];
-
- let constructor = "CreateVectorizeLoadsAndStoresPass()";
-}
-
-def OptimizeLoopsPass :
- Pass<"xla-gpu-optimize-loops", "mlir::func::FuncOp"> {
- let summary = "Unrolls and pipelines loops.";
-
- let description = [{
- Unrolls loops with a small trip count. Pipelines loops with a large trip
- count.
- }];
-
- let dependentDialects = [
- "mlir::vector::VectorDialect",
- "xla::gpu::XlaGpuDialect",
- ];
-
- let constructor = "CreateOptimizeLoopsPass()";
-}
-
-def UnswitchLoopsPass :
- Pass<"xla-gpu-unswitch-loops", "mlir::func::FuncOp"> {
- let summary = "Swaps scf.if and scf.for.";
-
- let description = [{
- Extracts `scf.if` ops with conditions that are independent of the loop
- variable from `scf.for` by doing the following rewrite:
-
- Before:
-
- %cond = some_cond() : i1
- %results = scf.for {
- %some_val = scf.if %cond {
- } else {
- }
- scf.yield %some_val
- }
-
- After:
-
- %cond = some_cond() : i1
- %results = scf.if %cond {
- %results = scf.for {
- %some_val = scf.if %true {
- } else {
- }
- }
- yield %results
- } else {
- %results = scf.for {
- %some_val = scf.if %false {
- } else {
- }
- }
- yield %results
- }
-
- This only triggers if there is a single `scf.if` op in the loop body (and
- nothing else).
- }];
-
- let dependentDialects = [
- "mlir::func::FuncDialect", "mlir::scf::SCFDialect"
- ];
-
- let constructor = "CreateUnswitchLoopsPass()";
-}
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc b/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc
deleted file mode 100644
index 218b432f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class PropagateSliceIndicesPass
- : public impl::PropagateSliceIndicesPassBase<PropagateSliceIndicesPass> {
- public:
- void runOnOperation() override;
-};
-
-void PropagateSliceIndicesPass::runOnOperation() {
- mlir::func::FuncOp entry;
- for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
- if (func->getAttr("xla.entry")) {
- entry = func;
- break;
- }
- }
-
- if (!entry) {
- getOperation()->emitOpError("No entry function found.");
- signalPassFailure();
- return;
- }
-
- for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
- if (func.getNumArguments() == 0 || func == entry) {
- continue;
- }
-
- for (int i = 0; i < func.getNumArguments(); ++i) {
- if (mlir::isa<mlir::RankedTensorType>(func.getArgument(i).getType())) {
- if (auto index = entry.getArgAttr(i, "xla.slice_index")) {
- func.setArgAttr(i, "xla.slice_index", index);
- }
- if (auto invariant = entry.getArgAttr(i, "xla.invariant")) {
- func.setArgAttr(i, "xla.invariant", invariant);
- }
- } else {
- break;
- }
- }
- }
-}
-
-} // namespace
-
-std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass() {
- return std::make_unique<PropagateSliceIndicesPass>();
-}
-
-} // namespace gpu
-} // namespace xla
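A minimal C++ sketch of the propagation that PropagateSliceIndicesPass performs: per-argument metadata from the entry function is copied onto every other function's leading tensor arguments. The Arg struct and string-keyed attribute map below are illustrative stand-ins for MLIR argument attributes.

#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Arg {
  bool is_tensor;
  std::map<std::string, std::string> attrs;
};

void Propagate(const std::vector<Arg>& entry, std::vector<Arg>& callee) {
  for (size_t i = 0; i < callee.size() && i < entry.size(); ++i) {
    if (!callee[i].is_tensor) break;  // Only leading tensor args carry slice info.
    for (const char* key : {"xla.slice_index", "xla.invariant"}) {
      auto it = entry[i].attrs.find(key);
      if (it != entry[i].attrs.end()) callee[i].attrs[key] = it->second;
    }
  }
}

int main() {
  std::vector<Arg> entry = {
      {true, {{"xla.slice_index", "0"}}},
      {true, {{"xla.slice_index", "1"}, {"xla.invariant", "unit"}}}};
  std::vector<Arg> callee = {{true, {}}, {true, {}}};
  Propagate(entry, callee);
  std::printf("callee arg 1 slice index: %s\n",
              callee[1].attrs["xla.slice_index"].c_str());
  return 0;
}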
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc
deleted file mode 100644
index 7b23499..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc
+++ /dev/null
@@ -1,368 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstdint>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/base/optimization.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/LoopUtils.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using mlir::AffineBinaryOpExpr;
-using mlir::AffineConstantExpr;
-using mlir::AffineDimExpr;
-using mlir::AffineExpr;
-using mlir::AffineExprKind;
-using mlir::AffineMap;
-using mlir::AffineSymbolExpr;
-using mlir::ImplicitLocOpBuilder;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpRewritePattern;
-using mlir::PatternRewriter;
-using mlir::SmallVector;
-using mlir::Value;
-using mlir::ValueRange;
-using mlir::affine::AffineApplyOp;
-
-namespace arith = mlir::arith;
-
-#define GEN_PASS_DEF_SIMPLIFYAFFINEPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-int Distance(ImplicitLocOpBuilder& builder, Value a) {
- auto* block = builder.getInsertionBlock();
- auto* parent = a.getParentBlock();
- int distance = 0;
- while (block && block != parent) {
- ++distance;
- block = block->getParentOp()->getBlock();
- }
- return distance;
-}
-
-void CollectArgs(AffineExpr expr, AffineExprKind kind,
- llvm::SmallVector<AffineExpr>& ret) {
- if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
- if (bin_op.getKind() == kind) {
- CollectArgs(bin_op.getLHS(), kind, ret);
- CollectArgs(bin_op.getRHS(), kind, ret);
- return;
- }
- }
- ret.push_back(expr);
-}
-
-struct ExpressionEvaluator {
- ExpressionEvaluator(ImplicitLocOpBuilder& builder, unsigned dim_count,
- ValueRange operands)
- : builder(builder), operands(operands) {
- for (int i = 0; i < dim_count; ++i) {
- dim_distances.push_back(Distance(builder, operands[i]));
- }
- for (int i = dim_count; i < operands.size(); ++i) {
- sym_distances.push_back(Distance(builder, operands[i]));
- }
- }
-
- // Returns the distance (in basic blocks) from the insertion point to the
- // values used in the given expression.
- int ExprDistance(AffineExpr e, int depth = 0) {
- if (auto dim = mlir::dyn_cast<AffineDimExpr>(e)) {
- return dim_distances[dim.getPosition()];
- }
- if (auto sym = mlir::dyn_cast<AffineSymbolExpr>(e)) {
- return sym_distances[sym.getPosition()];
- }
- if (auto binop = mlir::dyn_cast<AffineBinaryOpExpr>(e)) {
- return std::min(ExprDistance(binop.getLHS(), depth + 1),
- ExprDistance(binop.getRHS(), depth + 1));
- }
- if (depth == 0) {
- // Top-level constant. Always add these last.
- return std::numeric_limits<int>::min();
- }
- // Nested constant. Ignore these for distances.
- return std::numeric_limits<int>::max();
- }
-
- Value EvaluateExpression(AffineExpr expr);
-
- template <typename Op>
- Value EvaluateAddMul(AffineExpr expr);
-
- ImplicitLocOpBuilder& builder;
- ValueRange operands;
- SmallVector<int> dim_distances;
- SmallVector<int> sym_distances;
-};
-
-template <typename Op>
-Value ExpressionEvaluator::EvaluateAddMul(AffineExpr expr) {
- llvm::SmallVector<AffineExpr> args;
- CollectArgs(expr, expr.getKind(), args);
- // Sort the args so that the ones that are closest to the insertion point
- // are evaluated last - this improves LICM.
- llvm::stable_sort(args, [&](AffineExpr a, AffineExpr b) {
- int dist_a = ExprDistance(a);
- int dist_b = ExprDistance(b);
- return dist_a > dist_b;
- });
-
- Value result = nullptr;
- for (auto arg : args) {
- Value arg_evaluated = EvaluateExpression(arg);
- if (result) {
- result = builder.create<Op>(result, arg_evaluated);
- } else {
- result = arg_evaluated;
- }
- }
-
- return result;
-}
-
-Value ExpressionEvaluator::EvaluateExpression(AffineExpr expr) {
- if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
- switch (expr.getKind()) {
- case AffineExprKind::Add:
- return EvaluateAddMul<arith::AddIOp>(expr);
- case AffineExprKind::Mul:
- return EvaluateAddMul<arith::MulIOp>(expr);
- case AffineExprKind::Mod:
- return builder.create<arith::RemUIOp>(
- EvaluateExpression(bin_op.getLHS()),
- EvaluateExpression(bin_op.getRHS()));
- case AffineExprKind::FloorDiv:
- return builder.create<arith::DivUIOp>(
- EvaluateExpression(bin_op.getLHS()),
- EvaluateExpression(bin_op.getRHS()));
- default:
- ABSL_UNREACHABLE();
- }
- }
- switch (expr.getKind()) {
- case AffineExprKind::Constant:
- return builder.create<arith::ConstantIndexOp>(
- mlir::cast<AffineConstantExpr>(expr).getValue());
- case AffineExprKind::DimId:
- return operands[mlir::cast<AffineDimExpr>(expr).getPosition()];
- case AffineExprKind::SymbolId:
- return operands[dim_distances.size() +
- mlir::cast<AffineSymbolExpr>(expr).getPosition()];
- default:
- ABSL_UNREACHABLE();
- }
-}
-
-bool IsLoweringSupported(AffineExpr expr, RangeEvaluator& range_evaluator) {
- auto bin_op = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
- if (!bin_op) {
- return true;
- }
- // Mod and div can be lowered if their LHS is >= 0 and their RHS is a
- // constant.
- if (expr.getKind() == AffineExprKind::Mod ||
- expr.getKind() == AffineExprKind::FloorDiv) {
- if (!range_evaluator.IsAlwaysPositiveOrZero(bin_op.getLHS()) ||
- !range_evaluator.ComputeExpressionRange(bin_op.getRHS()).IsPoint()) {
- return false;
- }
- }
- if (expr.getKind() == AffineExprKind::CeilDiv) {
- return false;
- }
- return IsLoweringSupported(bin_op.getLHS(), range_evaluator) &&
- IsLoweringSupported(bin_op.getRHS(), range_evaluator);
-}
-
-struct RewriteAffineApply : OpRewritePattern<mlir::affine::AffineApplyOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mlir::affine::AffineApplyOp op,
- PatternRewriter& rewriter) const override {
- AffineMap affine_map = op.getAffineMap();
- std::vector<DimVar> dim_ranges(affine_map.getNumDims());
- std::vector<RangeVar> symbol_ranges(affine_map.getNumSymbols());
-
- for (int i = 0; i < affine_map.getNumInputs(); ++i) {
- if (auto range = GetRange(op->getOperand(i))) {
- if (i >= dim_ranges.size()) {
- symbol_ranges[i - dim_ranges.size()] = RangeVar{*range};
- } else {
- dim_ranges[i] = DimVar{*range};
- }
- } else {
- return rewriter.notifyMatchFailure(op, "failed to deduce range");
- }
- }
-
- IndexingMap indexing_map(affine_map, std::move(dim_ranges),
- std::move(symbol_ranges),
- /*rt_vars=*/{});
- indexing_map.Simplify();
- auto result_expr = indexing_map.GetAffineMap().getResult(0);
-
- ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
- if (!IsLoweringSupported(result_expr, range_evaluator)) {
- return rewriter.notifyMatchFailure(op,
- "unable to lower the affine apply");
- }
- b.setInsertionPoint(op);
- auto result = ExpressionEvaluator(b, indexing_map.GetDimensionCount(),
- op->getOperands())
- .EvaluateExpression(result_expr);
- rewriter.replaceOp(op, result);
- return mlir::success();
- }
-};
-
-struct RewriteApplyIndexingOp : OpRewritePattern<ApplyIndexingOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ApplyIndexingOp op,
- PatternRewriter& rewriter) const override {
- auto indexing_map = op.getIndexingMap();
- indexing_map.Simplify();
- auto affine_map = indexing_map.GetAffineMap();
- int64_t dim_count = indexing_map.GetDimensionCount();
- auto operands = op->getOperands();
-
- ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
-
- b.setInsertionPoint(op);
- SmallVector<Value, 4> results;
- results.reserve(affine_map.getNumResults());
- for (unsigned i = 0; i < affine_map.getNumResults(); ++i) {
- AffineExpr result_expr = affine_map.getResult(i);
- // If the expression cannot be lowered, we convert it to affine.apply,
- // since it supports more expression types.
- if (IsLoweringSupported(result_expr, range_evaluator)) {
- results.push_back(ExpressionEvaluator(b, dim_count, operands)
- .EvaluateExpression(result_expr));
- } else {
- results.push_back(
- b.create<AffineApplyOp>(affine_map.getSubMap({i}), operands));
- }
- }
- rewriter.replaceOp(op, results);
- return mlir::success();
- }
-};
-
-struct SimplifyAffinePass
- : public impl::SimplifyAffinePassBase<SimplifyAffinePass> {
- public:
- void runOnOperation() override {
- MLIRContext* ctx = &getContext();
- mlir::RewritePatternSet patterns(ctx);
- patterns.add<RewriteAffineApply, RewriteApplyIndexingOp>(ctx);
- mlir::GreedyRewriteConfig config;
- // There's no point simplifying more than once.
- config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
- getOperation(), std::move(patterns), config))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::optional<Interval> GetRange(mlir::Value value) {
- auto attr_to_range = [](mlir::Attribute attr) -> std::optional<Interval> {
- if (!attr) {
- return std::nullopt;
- }
- auto values = llvm::to_vector(
- mlir::cast<mlir::ArrayAttr>(attr).getAsValueRange<mlir::IntegerAttr>());
- return {{values[0].getSExtValue(), values[1].getSExtValue()}};
- };
-
- if (auto apply = value.getDefiningOp<ApplyIndexingOp>()) {
- return apply.getIndexingMap().GetRangeEvaluator().ComputeExpressionRange(
- apply.getIndexingMap().GetAffineMap().getResult(
- mlir::cast<mlir::OpResult>(value).getResultNumber()));
- } else if (auto cst = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
- return {{cst.value(), cst.value()}};
- } else if (value.getDefiningOp()) {
- return attr_to_range(value.getDefiningOp()->getAttr("xla.range"));
- }
-
- auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(value);
- if (!bbarg) {
- return std::nullopt;
- }
-
- auto parent = bbarg.getParentBlock()->getParentOp();
- if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(parent)) {
- return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range"));
- }
- return GetIVRange(value);
-}
-
-std::optional<Interval> GetIVRange(mlir::Value iv) {
- auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(iv);
- if (!bbarg) {
- return std::nullopt;
- }
- auto parent = bbarg.getParentBlock()->getParentOp();
- if (auto for_op = mlir::dyn_cast<mlir::scf::ForOp>(parent)) {
- llvm::APInt lb, ub;
- if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) &&
- mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
- return {{lb.getSExtValue(), ub.getSExtValue() - 1}};
- }
- }
- return std::nullopt;
-}
-
-std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass() {
- return std::make_unique<SimplifyAffinePass>();
-}
-
-} // namespace gpu
-} // namespace xla
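SimplifyAffinePass only lowers mod and floordiv when the left-hand side is provably nonnegative and the right-hand side is a constant, because only then do they match the cheap remui/divui lowering it emits. A minimal C++ sketch of why that guard matters; FloorDiv here simply models affine floordiv semantics and is not an XLA helper.

#include <cstdio>

// Affine floordiv semantics (rounds toward negative infinity), assuming b > 0.
long long FloorDiv(long long a, long long b) {
  long long q = a / b;  // C++ integer division truncates toward zero.
  return (a % b != 0 && (a < 0) != (b < 0)) ? q - 1 : q;
}

int main() {
  // Nonnegative LHS: floordiv agrees with truncating division (what divui computes).
  std::printf("7 floordiv 3 = %lld, 7 / 3 = %lld\n", FloorDiv(7, 3), 7LL / 3);
  // Negative LHS: the two disagree, so the cheap lowering would be incorrect.
  std::printf("-7 floordiv 3 = %lld, -7 / 3 = %lld\n", FloorDiv(-7, 3), -7LL / 3);
  return 0;
}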
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc
deleted file mode 100644
index 77b1d7c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc
+++ /dev/null
@@ -1,344 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <cstdint>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DEF_SIMPLIFYARITHPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-using mlir::LogicalResult;
-using mlir::OpRewritePattern;
-using mlir::PatternRewriter;
-using mlir::arith::CmpIOp;
-using mlir::arith::CmpIPredicate;
-
-Interval::ComparisonResult EvaluateCmpI(CmpIPredicate pred, Interval lhs,
- Interval rhs) {
- switch (pred) {
- case CmpIPredicate::eq:
- return lhs.Eq(rhs);
- case CmpIPredicate::ne:
- return lhs.Ne(rhs);
- case CmpIPredicate::slt:
- case CmpIPredicate::ult:
- return lhs.Lt(rhs);
- case CmpIPredicate::sle:
- case CmpIPredicate::ule:
- return lhs.Le(rhs);
- case CmpIPredicate::sgt:
- case CmpIPredicate::ugt:
- return lhs.Gt(rhs);
- case CmpIPredicate::sge:
- case CmpIPredicate::uge:
- return lhs.Ge(rhs);
- }
-}
-
-struct RewriteCmpI : OpRewritePattern<CmpIOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CmpIOp op,
- PatternRewriter& rewriter) const override {
- auto rhs = GetRange(op.getRhs());
- auto lhs = GetRange(op.getLhs());
- if (!lhs || !rhs) {
- return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
- }
- Interval::ComparisonResult result =
- EvaluateCmpI(op.getPredicate(), *lhs, *rhs);
- if (result != std::nullopt) {
- rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(
- op, *result, rewriter.getI1Type());
- return mlir::success();
- }
- return rewriter.notifyMatchFailure(op, "not a constant result");
- }
-};
-
-struct RewriteMaxSi : OpRewritePattern<mlir::arith::MaxSIOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mlir::arith::MaxSIOp op,
- PatternRewriter& rewriter) const override {
- auto lhs = GetRange(op.getLhs());
- auto rhs = GetRange(op.getRhs());
- if (!lhs || !rhs) {
- return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
- }
- if (auto lhs_ge_rhs = lhs->Ge(*rhs); lhs_ge_rhs == true) {
- rewriter.replaceOp(op, op.getLhs());
- } else if (auto rhs_ge_lhs = rhs->Ge(*lhs); rhs_ge_lhs == true) {
- rewriter.replaceOp(op, op.getRhs());
- } else {
- return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
- }
- return mlir::success();
- }
-};
-
-struct RewriteMinSi : OpRewritePattern<mlir::arith::MinSIOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mlir::arith::MinSIOp op,
- PatternRewriter& rewriter) const override {
- auto lhs = GetRange(op.getLhs());
- auto rhs = GetRange(op.getRhs());
- if (!lhs || !rhs) {
- return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
- }
- if (auto lhs_le_rhs = lhs->Le(*rhs); lhs_le_rhs == true) {
- rewriter.replaceOp(op, op.getLhs());
- } else if (auto rhs_le_lhs = rhs->Le(*lhs); rhs_le_lhs == true) {
- rewriter.replaceOp(op, op.getRhs());
- } else {
- return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
- }
- return mlir::success();
- }
-};
-
-// Finds the narrowest value in a use-def chain of truncis/extuis.
-mlir::Value FindNarrowestValueInChain(mlir::Value value) {
- if (auto ext = value.getDefiningOp<mlir::arith::ExtUIOp>()) {
- return FindNarrowestValueInChain(ext.getOperand());
- }
- auto defining_op = value.getDefiningOp<mlir::arith::TruncIOp>();
- if (defining_op) {
- auto first_trunc = FindNarrowestValueInChain(defining_op.getOperand());
- if (first_trunc && first_trunc.getType().getIntOrFloatBitWidth() <=
- defining_op.getType().getIntOrFloatBitWidth()) {
- return first_trunc;
- }
- return defining_op;
- }
- return value;
-}
-
-// Rewrites trunc-bitwise to bitwise-trunc.
-//
-// For pred reductions, we generate code like this:
-//
-// %1 = arith.trunci %0 : i32 to i1
-// %2 = arith.ori %1, %x
-// %3 = arith.extui %2 : i1 to i32
-// %4 = gpu.shuffle %3
-//
-// By swapping the trunc with the or, we get a trunc-ext-shuffle sequence, which
-// can be rewritten to shuffle-trunc-ext. If there is another copy of the
-// pattern afterwards, we can push the truncs/exts further down.
-template <typename Op>
-struct RewriteTruncBitExt : OpRewritePattern<Op> {
- using OpRewritePattern<Op>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(Op op,
- PatternRewriter& rewriter) const override {
- mlir::Value lhs = FindNarrowestValueInChain(op.getLhs());
- mlir::Value rhs = FindNarrowestValueInChain(op.getRhs());
-
- if (lhs.getType() != rhs.getType()) {
- return rewriter.notifyMatchFailure(op, "mismatched narrowest types");
- }
-
- auto trunci_lhs = lhs.getDefiningOp<mlir::arith::TruncIOp>();
- auto trunci_rhs = rhs.getDefiningOp<mlir::arith::TruncIOp>();
- if (!trunci_lhs && !trunci_rhs) {
- return rewriter.notifyMatchFailure(
- op, "neither narrowest value is the result of a truncation");
- }
-
- auto wide_type =
- (trunci_lhs ? trunci_lhs : trunci_rhs).getOperand().getType();
- if (trunci_rhs && trunci_rhs.getOperand().getType() != wide_type) {
- return rewriter.notifyMatchFailure(op, "mismatched truncation types");
- }
-
- mlir::Value new_lhs = trunci_lhs ? trunci_lhs.getOperand()
- : rewriter.create<mlir::arith::ExtUIOp>(
- op.getLoc(), wide_type, lhs);
- mlir::Value new_rhs = trunci_rhs ? trunci_rhs.getOperand()
- : rewriter.create<mlir::arith::ExtUIOp>(
- op.getLoc(), wide_type, rhs);
- mlir::Value new_op = rewriter.create<Op>(op.getLoc(), new_lhs, new_rhs);
- rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, op.getType(),
- new_op);
-
- return mlir::success();
- }
-};
-
-// Rewrites trunc-ext-shuffle to shuffle-trunc-ext. This pattern is designed to
-// work together with RewriteTruncBitExt to optimize pred reductions.
-struct RewriteTruncExtShuffle : public OpRewritePattern<mlir::gpu::ShuffleOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mlir::gpu::ShuffleOp op,
- PatternRewriter& rewriter) const override {
- auto ext = op.getOperand(0).getDefiningOp<mlir::arith::ExtUIOp>();
- if (!ext) {
- return rewriter.notifyMatchFailure(op, "no ext");
- }
- auto trunc = ext.getOperand().getDefiningOp<mlir::arith::TruncIOp>();
- if (!trunc || trunc.getOperand().getType() != ext.getType()) {
- return rewriter.notifyMatchFailure(op, "no trunc or type mismatch");
- }
- rewriter.setInsertionPointAfter(op);
- auto new_trunc = rewriter.create<mlir::arith::TruncIOp>(
- op.getLoc(), trunc.getType(), op.getResult(0));
- auto new_ext = rewriter.create<mlir::arith::ExtUIOp>(
- op.getLoc(), ext.getType(), new_trunc.getResult());
- rewriter.modifyOpInPlace(op,
- [&]() { op->setOperand(0, trunc.getOperand()); });
- rewriter.replaceAllUsesExcept(op.getResult(0), new_ext, new_trunc);
- return mlir::success();
- }
-};
-
-void AnnotateRanges(mlir::func::FuncOp func) {
- func->walk([](mlir::Operation* op) {
- if (op->getNumResults() != 1) {
- return;
- }
-
- auto result = op->getResult(0);
- if (GetRange(result).has_value()) {
- return;
- }
-
- auto get_range = [](mlir::Value value) -> Interval {
- auto range = GetRange(value);
- if (range) {
- return *range;
- }
- return {std::numeric_limits<int64_t>::min(),
- std::numeric_limits<int64_t>::max()};
- };
-
- std::optional<Interval> out_range = std::nullopt;
- if (mlir::isa<mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
- mlir::arith::AddIOp, mlir::arith::MulIOp>(op)) {
- auto lhs_range = get_range(op->getOperand(0));
- auto rhs_range = get_range(op->getOperand(1));
- if (mlir::isa<mlir::arith::MaxSIOp>(op)) {
- out_range = lhs_range.max(rhs_range);
- } else if (mlir::isa<mlir::arith::MinSIOp>(op)) {
- out_range = lhs_range.min(rhs_range);
- } else if (mlir::isa<mlir::arith::AddIOp>(op)) {
- out_range = lhs_range + rhs_range;
- } else {
- out_range = lhs_range * rhs_range;
- }
- }
-
- if (out_range) {
- mlir::OpBuilder b(op);
- op->setAttr("xla.range",
- b.getIndexArrayAttr({out_range->lower, out_range->upper}));
- }
- });
-}
-
-// Pattern to refine the bounds of an indexing map if some of its operands have
-// known bounds, e.g. loop induction variables.
-struct RefineConstraints : public OpRewritePattern<ApplyIndexingOp> {
- using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
- PatternRewriter& rewriter) const override {
- // Right now, we only handle loop induction variables, but other rules might
- // be added.
- IndexingMap indexing_map = indexing_op.getIndexingMap();
- int64_t dim_count = indexing_map.GetDimensionCount();
- bool updated_bounds = false;
- for (mlir::OpOperand& operand : indexing_op->getOpOperands()) {
- auto range = GetIVRange(operand.get());
- if (!range) {
- continue;
- }
- auto operand_id = operand.getOperandNumber();
- Interval& current_interval =
- operand_id < dim_count
- ? indexing_map.GetMutableDimensionBound(operand_id)
- : indexing_map.GetMutableSymbolBound(operand_id - dim_count);
- if (!range->Contains(current_interval)) {
- current_interval = current_interval.Intersect(*range);
- updated_bounds = true;
- }
- }
- if (!updated_bounds) {
- return rewriter.notifyMatchFailure(indexing_op, "No bounds to refine");
- }
- indexing_map.Simplify();
- rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
- indexing_op, indexing_op.getOperands(), indexing_map);
- return mlir::success();
- }
-};
-
-class SimplifyArithPass
- : public impl::SimplifyArithPassBase<SimplifyArithPass> {
- public:
- void runOnOperation() override {
- auto ctx = &getContext();
- auto func = getOperation();
- mlir::RewritePatternSet patterns(ctx);
- AnnotateRanges(func);
- // clang-format off
- patterns.add<
- RefineConstraints,
- RewriteCmpI,
- RewriteMaxSi,
- RewriteMinSi,
- RewriteTruncBitExt<mlir::arith::AndIOp>,
- RewriteTruncBitExt<mlir::arith::OrIOp>,
- RewriteTruncExtShuffle
- >(ctx);
- // clang-format on
- if (mlir::failed(
- mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::Pass> CreateSimplifyArithPass() {
- return std::make_unique<SimplifyArithPass>();
-}
-
-} // namespace gpu
-} // namespace xla
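A minimal C++ sketch of the interval reasoning behind RewriteCmpI: when the ranges of both operands are known, some comparisons fold to a constant. The Interval struct and Lt helper below are simplified stand-ins for XLA's indexing-map interval machinery.

#include <cstdint>
#include <cstdio>
#include <optional>

struct Interval {
  int64_t lo, hi;  // Inclusive bounds.
};

// Returns true/false if `lhs < rhs` holds for every/no pair of values in the
// two intervals, and nullopt if the answer depends on the concrete values.
std::optional<bool> Lt(Interval lhs, Interval rhs) {
  if (lhs.hi < rhs.lo) return true;    // Always less.
  if (lhs.lo >= rhs.hi) return false;  // Never less.
  return std::nullopt;                 // Overlapping ranges: cannot fold.
}

int main() {
  auto print = [](std::optional<bool> r) {
    std::printf("%s\n", r ? (*r ? "true" : "false") : "unknown");
  };
  print(Lt({0, 3}, {4, 10}));   // true  -> the cmpi folds to a constant.
  print(Lt({5, 9}, {0, 5}));    // false -> the cmpi folds to a constant.
  print(Lt({0, 10}, {5, 15}));  // unknown -> the pattern does not apply.
  return 0;
}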
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD
deleted file mode 100644
index 69b4bd0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-load("//xla:lit.bzl", "lit_test_suite")
-load("//xla:xla.bzl", "xla_cc_binary")
-
-package(
- # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
- licenses = ["notice"],
-)
-
-xla_cc_binary(
- name = "mlir_fusions_opt",
- srcs = ["mlir_fusions_opt.cc"],
- deps = [
- "//xla/mlir_hlo",
- "//xla/service/gpu/fusions/mlir:passes",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
- "@llvm-project//mlir:AffineDialect",
- "@llvm-project//mlir:ArithDialect",
- "@llvm-project//mlir:ComplexDialect",
- "@llvm-project//mlir:DLTIDialect",
- "@llvm-project//mlir:FuncDialect",
- "@llvm-project//mlir:FuncExtensions",
- "@llvm-project//mlir:GPUDialect",
- "@llvm-project//mlir:LLVMDialect",
- "@llvm-project//mlir:MathDialect",
- "@llvm-project//mlir:MlirOptLib",
- "@llvm-project//mlir:NVVMDialect",
- "@llvm-project//mlir:SCFDialect",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TensorDialect",
- "@llvm-project//mlir:Transforms",
- "@llvm-project//mlir:VectorDialect",
- ],
-)
-
-lit_test_suite(
- name = "tests",
- srcs = glob(["*.mlir"]),
- cfg = "//xla:lit.cfg.py",
- tools = [
- ":mlir_fusions_opt",
- "@llvm-project//llvm:FileCheck",
- ],
-)
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir
deleted file mode 100644
index 17b0f8d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir
+++ /dev/null
@@ -1,179 +0,0 @@
-// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s
-
-#map0 = affine_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)>
-func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) {
- %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10], %s1 in [0, 2]]
- func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1, s0 mod 2)>
-
-// CHECK-LABEL: func.func @simplify_apply_indexing
-// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2)>
-func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index,
- %d2: index, %s0: index, %s1: index) -> (index, index, index) {
- %0:3 = xla_gpu.apply_indexing #map0
- (%d0 in [0, 1], %d1 in [0, 2], %d2 in [0, 3])
- [%s0 in [-11, 11], %s1 in [0, 3]]
- func.return %0#0, %0#1, %0#2 : index, index, index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1)>
-
-// CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims
-// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME: (%[[ARG_0]] in [0, 1], %[[ARG_2]] in [0, 3])
-// CHECK-SAME: [%[[ARG_3]] in [-11, 11]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0)>
-func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index)
- -> (index, index, index, index, index) {
- %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]]
- func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-
-// CHECK-LABEL: func.func @fold_indexing_map_results
-// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
-
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-
-// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]]
-// CHECK: return %[[NEW_RESULT]], %[[C4]], %[[ARG_1]], %[[C1]], %[[ARG_2]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)>
-func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) {
- %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]]
- func.return %0#2 : index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
-
-// CHECK-LABEL: func.func @remove_unused_results
-// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
-
-// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 2])
-// CHECK: return %[[NEW_RESULT]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)>
-func.func @fold_operands(%d0: index) -> index {
- %d1 = arith.constant 1 : index
- %s0 = arith.constant 2 : index
- %s1 = arith.constant 3 : index
- %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 10], %d1 in [0, 5])
- [%s0 in [-10, 10], %s1 in [0, 4]]
- func.return %0 : index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 3)>
-
-// CHECK-LABEL: func.func @fold_operands
-// CHECK-SAME: %[[ARG_0:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10])
-
-// -----
-
-func.func @fold_operands_and_results(%arg0: index, %arg1: index)
- -> (index, index) {
- %0:2 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (0, d1)>
- (%arg0 in [0, 4], %arg1 in [0, 5])
- return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @fold_operands_and_results
-// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
-// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
-// CHECK-NEXT: return %[[C0]], %[[ARG_1]] : index, index
-
-// -----
-
-func.func @fold_sequence(%arg0: index, %arg1: index) -> index {
- %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
- (%arg0 in [0, 5], %arg1 in [0, 4])
- %1 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 100 + 42)>
- (%0 in [0, 10000])
- func.return %1 : index
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)>
-// CHECK-LABEL: func.func @fold_sequence
-// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4])
-
-// -----
-
-func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index {
- %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
- (%arg0 in [0, 5], %arg1 in [0, 4])
- %1 = xla_gpu.apply_indexing affine_map<()[s0] -> (s0 mod 100 + 42)>
- [%0 in [0, 10000]]
- func.return %1 : index
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)>
-// CHECK-LABEL: func.func @fold_sequence_sym
-// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4])
-
-// -----
-
-func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index {
- %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
- (%arg0 in [0, 5], %arg1 in [0, 4])
- %1 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
- (%arg1 in [0, 4], %0 in [0, 10000])
- func.return %1 : index
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
-// CHECK-LABEL: func.func @fold_sequence_shared_operands
-// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME: (%[[ARG1]] in [0, 4], %[[ARG0]] in [0, 5])
-
-// -----
-
-func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index)
- -> (tensor<2x3xf32>) {
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
- ^bb0(%current : f32):
- xla_gpu.yield %current : f32
- }
- return %ret : tensor<2x3xf32>
-}
-// CHECK-LABEL: func.func @atomic_rmw_empty
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>
-// CHECK: return %[[ARG0]]
-
-
-// -----
-
-func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index)
- -> (tensor<2x3xf32>) {
- %cst = arith.constant 0.0 : f32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
- ^bb0(%current : f32):
- xla_gpu.yield %cst : f32
- }
- return %ret : tensor<2x3xf32>
-}
-// CHECK-LABEL: func.func @atomic_rmw_cst
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>
-// CHECK-NEXT: %[[CST:.*]] = arith.constant
-// CHECK-NEXT: atomic_rmw
-// CHECK: xla_gpu.yield %[[CST]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir
deleted file mode 100644
index ee2c2ae..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir
+++ /dev/null
@@ -1,140 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-flatten-tensors \
-// RUN: --verify-diagnostics | FileCheck %s
-
-func.func @tensor_extract(
- %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
- %arg1: index, %arg2: index) -> f32 {
- %v = tensor.extract %arg0[%arg1, %arg2]
- : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
- func.return %v : f32
-}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1 * 2 + d0)>
-
-// CHECK-LABEL: func.func @tensor_extract(
-// CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>,
-// CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index) -> f32 {
-// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
-// CHECK-SAME: in [0, 1], %[[J]] in [0, 2])
-// CHECK: tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32>
-
-// -----
-
-func.func @tensor_insert(
- %arg0: tensor<10x24xcomplex<f32>>) -> tensor<10x24xcomplex<f32>> {
- %c1 = arith.constant 1 : index
- %real = arith.constant 3.0 : f32
- %imag = arith.constant 2.0 : f32
- %complex = complex.create %real, %imag : complex<f32>
- %out = tensor.insert %complex into %arg0[%c1, %c1] : tensor<10x24xcomplex<f32>>
- func.return %out : tensor<10x24xcomplex<f32>>
-}
-// CHECK-LABEL: func.func @tensor_insert(
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<240xcomplex<f32>>) -> tensor<240xcomplex<f32>> {
-// CHECK: %[[INDEX:.*]] = arith.constant 25
-// CHECK: %[[COMPLEX:.*]] = complex.create
-// CHECK: tensor.insert %[[COMPLEX]] into %[[TENSOR]][%[[INDEX]]]
-// CHECK-SAME: : tensor<240xcomplex<f32>>
-
-// -----
-
-func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index)
- -> (tensor<2x4xf32>) {
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
- ^bb0(%current : f32):
- %c42 = arith.constant 1.0 : f32
- %add = arith.minimumf %current, %c42 : f32
- xla_gpu.yield %add : f32
- }
- return %ret : tensor<2x4xf32>
-}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
-
-// CHECK-LABEL: func.func @atomic_rmw(
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index,
-// CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> {
-// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME: (%[[I]] in [0, 1], %[[J]] in [0, 3])
-// CHECK: xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32>
-
-// -----
-
-func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>)
- -> (tensor<32x1024xf32>, tensor<64x8x4xf32>, f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c32 = arith.constant 32 : index
- %c64 = arith.constant 64 : index
- %c0_f32 = arith.constant 0.0 : f32
- %for:2 = scf.for %i = %c0 to %c64 step %c32 iter_args(%t0_ = %t0, %t1_ = %t1)
- -> (tensor<32x1024xf32>, tensor<64x8x4xf32>) {
- %update0 = tensor.insert %c0_f32 into %t0_[%c1, %i] : tensor<32x1024xf32>
- %update1 = tensor.insert %c0_f32 into %t1_[%i, %c1, %c1] : tensor<64x8x4xf32>
- scf.yield %update0, %update1 : tensor<32x1024xf32>, tensor<64x8x4xf32>
- } {some_attr}
- return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 + 1024)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 32 + 5)>
-// CHECK-LABEL: func.func @for_loop(
-// CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) {
-
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
-// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
-// CHECK-DAG: %[[F32:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]]
-// CHECK-SAME: step %[[C32]]
-// CHECK-SAME: iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]])
-// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]] in [0, 1023])
-// CHECK: %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]]
-// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]] in [0, 63])
-// CHECK: %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]]
-// CHECK: scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32>
-
-// -----
-
-#map = affine_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36)>
-#map1 = affine_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4)>
-#map2 = affine_map<(d0, d1) -> ((d1 * 128 + d0) mod 9)>
-func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>,
- %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>)
- -> tensor<4000x4x9xf32> {
- %c0 = arith.constant 0 : index
- %c3999 = arith.constant 3999 : index
- %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
- %bl_x = gpu.block_id x {xla.range = [0 : index, 393749 : index]}
- %0 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 393749])
- %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32>
- %1 = arith.index_cast %extracted : i32 to index
- %2 = arith.cmpi ule, %1, %c3999 : index
- %3 = scf.if %2 -> (tensor<4000x4x9xf32>) {
- %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 393749])
- %5 = xla_gpu.apply_indexing #map2(%th_x in [0, 127], %bl_x in [0, 393749])
- %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32>
- %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> {
- ^bb0(%arg4: f32):
- %6 = arith.addf %arg4, %elem : f32
- xla_gpu.yield %6 : f32
- }
- scf.yield %atomic_rmw : tensor<4000x4x9xf32>
- } else {
- scf.yield %arg3 : tensor<4000x4x9xf32>
- }
- return %3 : tensor<4000x4x9xf32>
-}
-// CHECK-LABEL: func.func @if_op
-// CHECK-NOT: builtin.unrealized_conversion_cast
-// CHECK: scf.if %{{.*}} -> (tensor<144000xf32>) {
-// CHECK-COUNT-2: scf.yield %{{.*}} : tensor<144000xf32>
-// CHECK: return %{{.*}} : tensor<144000xf32>
-
-// -----
-
-func.func @dangling_cast(%arg0: tensor<6xf32>, %arg1: index) -> i32 {
- %v = tensor.extract %arg0[%arg1] : tensor<6xf32>
- %cast = builtin.unrealized_conversion_cast %v : f32 to i32
- func.return %cast : i32
-}
-// CHECK: FlattenTensorsPass failed to converge
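
The flatten-tensors checks above rewrite multi-dimensional tensor accesses into a single linear index produced by xla_gpu.apply_indexing: the row-major 2x4 case uses the map d0 * 4 + d1, while the 2x3 tensor with layout [0, 1] (dimension 0 minor-most) uses d1 * 2 + d0. Below is a tiny C++ sketch of the same two linearizations, written only to make the index formulas concrete; the helper names are invented for this example.

#include <cstdint>
#include <iostream>

// Row-major 2x4: corresponds to affine_map<(d0, d1) -> (d0 * 4 + d1)>.
int64_t LinearizeRowMajor2x4(int64_t d0, int64_t d1) { return d0 * 4 + d1; }

// 2x3 tensor with layout [0, 1], i.e. dimension 0 is minor-most:
// corresponds to affine_map<(d0, d1) -> (d1 * 2 + d0)>.
int64_t LinearizeLayout01_2x3(int64_t d0, int64_t d1) { return d1 * 2 + d0; }

int main() {
  std::cout << LinearizeRowMajor2x4(1, 2) << "\n";   // Prints: 6
  std::cout << LinearizeLayout01_2x3(1, 2) << "\n";  // Prints: 5
  return 0;
}
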
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir
deleted file mode 100644
index c5cdeeb..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir
+++ /dev/null
@@ -1,136 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0, d1, d2)[s0] -> (d0),
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [1, 2],
-// CHECK-SAME: d1 in [5, 8],
-// CHECK-SAME: d2 in [10, 12],
-// CHECK-SAME: s0 in [0, 32],
-// CHECK-SAME: d0 mod 2 in [0, 1],
-// CHECK-SAME: d0 + s0 in [1, 10]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
- domain:
- d0 in [1, 2],
- d1 in [5, 8],
- d2 in [10, 12],
- s0 in [0, 32],
- d0 mod 2 in [0, 1],
- d0 + s0 in [1, 10]
- >
-
-func.func private @indexing_map_attr(tensor<32xf64, #map>)
-// CHECK-LABEL: @indexing_map_attr
-// CHECK: tensor<32xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [1, 2]
-// CHECK-SAME: d1 in [5, 8]
-// CHECK-SAME: s0 in [0, 10]
-// CHECK-SAME: s1 in [0, 5]
-// CHECK-SAME: s2 in [0, 32]
-// CHECK-SAME: d0 mod 2 in [0, 1]
-// CHECK-SAME: d0 + s0 in [1, 10]
-// CHECK-SAME: d1 + s1 + s2 in [1, 32]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),
- domain:
- d0 in [1, 2],
- d1 in [5, 8],
- s0 in [0, 10],
- s1 in [0, 5],
- s2 in [0, 32],
- d0 mod 2 in [0, 1],
- d0 + s0 in [1, 10],
- d1 + s1 + s2 in [1, 32]
- >
-func.func private @more_range_vars(tensor<32xf64, #map>)
-// CHECK-LABEL: @more_range_vars
-// CHECK: tensor<32xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0)[s0] -> (d0)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [0, 100]
-// CHECK-SAME: s0 in [-3, -1]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0),
- domain:
- d0 in [0, 100],
- s0 in [-3, -1]
- >
-func.func private @indexing_map_small(tensor<100xf64, #map>)
-// CHECK-LABEL: @indexing_map_small
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0, d1, d2)[s0] -> (d0)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [1, 2]
-// CHECK-SAME: d1 in [5, 8]
-// CHECK-SAME: d2 in [10, 12]
-// CHECK-SAME: s0 in [0, 32]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
- domain:
- d0 in [1, 2],
- d1 in [5, 8],
- d2 in [10, 12],
- s0 in [0, 32]
- >
-func.func private @no_constraints(tensor<32xf64, #map>)
-// CHECK-LABEL: @no_constraints
-// CHECK: tensor<32xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: ()[s0] -> (s0)
-// CHECK-SAME: domain:
-// CHECK-SAME: s0 in [3, 5]
-// CHECK-SAME: s0 mod 2 in [0, 1]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<()[s0] -> (s0),
- domain:
- s0 in [3, 5],
- s0 mod 2 in [0, 1]
- >
-func.func private @no_dimensions(tensor<100xf64, #map>)
-// CHECK-LABEL: @no_dimensions
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0) -> (d0)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [3, 5]
-// CHECK-SAME: d0 mod 2 in [0, 1]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0) -> (d0),
- domain:
- d0 in [3, 5],
- d0 mod 2 in [0, 1]
- >
-func.func private @no_symbols(tensor<100xf64, #map>)
-// CHECK-LABEL: @no_symbols
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: () -> ()
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<() -> ()>
-func.func private @empty(tensor<100xf64, #map>)
-// CHECK-LABEL: @empty
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir
deleted file mode 100644
index fbef7c0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)>
-func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
- // expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}}
- %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2])
- func.return %0#0, %0#1 : index, index
-}
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir
deleted file mode 100644
index 2125e6f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir
+++ /dev/null
@@ -1,854 +0,0 @@
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=6.0" \
-// RUN: | FileCheck %s
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=7.0" \
-// RUN: | FileCheck %s --check-prefix=CHECK-VOLTA
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=8.0" \
-// RUN: | FileCheck %s --check-prefix=CHECK-AMPERE
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=9.0" \
-// RUN: | FileCheck %s --check-prefix=CHECK-HOPPER
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx908:sramecc+:xnack" \
-// RUN: | FileCheck %s --check-prefix=CHECK-GFX908-MI100
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx90a:sramecc+:xnack" \
-// RUN: | FileCheck %s --check-prefix=CHECK-GFX90A-MI200
-
-module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
- func.func private @add(%arg0: f32, %arg1: f32) -> f32 {
- %sum = arith.addf %arg0, %arg1 : f32
- func.return %sum : f32
- }
-
- func.func private @tensorarg(%arg0: tensor<43xf32> {xla.invariant, xla.slice_index = 0}, %arg1: index) -> f32 {
- %v1 = arith.constant 2.0 : f32
- %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32>
- %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32
- func.return %sum : f32
- }
-
- func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, %arg1: index) -> f32 {
- %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32
- func.return %call : f32
- }
-
- func.func @stores(%arg0: tensor<17xf32> {xla.slice_index = 0}, %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
- %c17 = arith.constant 17 : index
- %c23 = arith.constant 23 : index
- %cst = arith.constant 3.0 : f32
- %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32>
- %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32>
- func.return %out2 : tensor<43xf32>
- }
-}
-
-// CHECK: func.func private @add(%{{.*}}: f32, %{{.*}}: f32) -> f32 {
-// CHECK-NEXT: arith.addf
-// CHECK-NEXT: return
-
-// CHECK: func.func private @tensorarg(%[[ARG0:.*]]: !llvm.ptr
-// CHECK-SAME: {xla.invariant, xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) -> f32 {
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00
-// CHECK-DAG: %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i32
-// CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX]]]
-// CHECK-DAG: %[[V2:.*]] = llvm.load %[[PTR]] invariant
-// CHECK: %[[RET:.*]] = call @add(%[[C2]], %[[V2]])
-// CHECK: return %[[RET]]
-
-// CHECK: func.func @tensorcall(%[[ARG0:.*]]: !llvm.ptr
-// CHECK-SAME: {xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index)
-// CHECK: %[[RET:.*]] = call @tensorarg(%[[ARG0]], %[[ARG1]])
-// CHECK: return %[[RET]]
-
-// CHECK: func.func @stores(
-// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {xla.slice_index = 0 : i64},
-// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {xla.slice_index = 1 : i64})
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 3.000000e+00 : f32
-// CHECK-NEXT: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[ARG1]][17]
-// CHECK-NEXT: llvm.store %[[CST]], %[[PTR1]]
-// CHECK-NEXT: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[ARG1]][23]
-// CHECK-NEXT: llvm.store %[[CST]], %[[PTR2]]
-// CHECK-NEXT: return
-
-// -----
-
-module {
- func.func @layout(
- %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
- %arg1: index, %arg2: index) -> f32 {
- %v = tensor.extract %arg0[%arg1, %arg2]
- : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
- func.return %v : f32
- }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1 * 2 + d0)>
-// CHECK-LABEL: @layout(
-// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
-// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index
-// CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME: (%[[X]] in [0, 1], %[[Y]] in [0, 2])
-// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64
-// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]]
-// CHECK: llvm.load %[[PTR]]
-
-// -----
-
-module {
- func.func @store_control_flow(
- %arg0: tensor<2xf32>,
- %arg1: index
- ) -> tensor<2xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %cst = arith.constant 0.0 : f32
- %cst2 = arith.constant 1.0 : f32
-
- %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> {
- %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32>
- scf.yield %new_out : tensor<2xf32>
- }
-
- %inbounds = arith.cmpi sle, %arg1, %c1 : index
- %result = scf.if %inbounds -> tensor<2xf32> {
- %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32>
- scf.yield %if : tensor<2xf32>
- } else {
- scf.yield %for : tensor<2xf32>
- }
- func.return %result : tensor<2xf32>
- }
-}
-
-// CHECK: @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
-// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i64
-// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]]
-// CHECK: llvm.store {{.*}}, %[[PTR]]
-// CHECK: %[[INBOUNDS:.*]] = arith.cmpi
-// CHECK: scf.if %[[INBOUNDS]] {
-// CHECK: llvm.store
-// CHECK-NEXT: }
-// CHECK-NEXT: return
-
-// -----
-
-module {
- func.func @large_tensor(
- %arg0: tensor<1024x1024x1024x6xf32>,
- %arg1: index) -> f32 {
- %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32>
- func.return %v : f32
- }
-}
-
-// CHECK: @large_tensor
-// CHECK: arith.index_castui {{.*}} : index to i64
-
-// -----
-
-module {
- func.func @extract_from_constant(%arg0: tensor<2x1xf32>,
- %arg1: index, %arg2: index) -> f32 {
- %cst = arith.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>
- %extracted = tensor.extract %arg0[%arg1, %arg2] : tensor<2x1xf32>
- %extracted_0 = tensor.extract %cst[%arg1, %arg2] : tensor<2x1xf32>
- %0 = arith.addf %extracted, %extracted_0 : f32
- return %0 : f32
- }
-}
-// CHECK: llvm.mlir.global private constant @global_cst_0(dense<
-// CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32>
-// CHECK: @extract_from_constant
-// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr
-// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f32
-// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[LOAD]] : f32
-// CHECK: return %[[ADD]] : f32
-
-// -----
-
-module {
- func.func @vector_constant() -> vector<2xindex> {
- %c1 = arith.constant dense<[1, 2]> : vector<2xindex>
- func.return %c1 : vector<2xindex>
- }
-}
-
-// Vector constants should not be rewritten.
-// CHECK: @vector_constant
-// CHECK-NEXT: arith.constant
-
-// -----
-
-module {
- func.func @complex_tensor_insert(
- %arg0: tensor<10xcomplex<f32>>) -> tensor<10xcomplex<f32>> {
- %c1 = arith.constant 1 : index
- %real = arith.constant 3.0 : f32
- %imag = arith.constant 2.0 : f32
- %complex = complex.create %real, %imag : complex<f32>
- %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex<f32>>
- func.return %out : tensor<10xcomplex<f32>>
- }
-}
-
-// CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr
-// CHECK: %[[C:.*]] = complex.create
-// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[C]] : complex<f32> to !llvm.struct<(f32, f32)>
-// CHECK: llvm.store %[[CAST]], %[[GEP]] : !llvm.struct<(f32, f32)>, !llvm.ptr
-
-// -----
-
-module {
- func.func @complex_tensor_extract(
- %arg0: tensor<10xcomplex<f32>>) -> complex<f32> {
- %c1 = arith.constant 1 : index
- %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex<f32>>
- func.return %v2 : complex<f32>
- }
-}
-
-// CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr
-// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
-// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)>
-// CHECK: builtin.unrealized_conversion_cast %[[LOAD]] : !llvm.struct<(f32, f32)> to complex<f32>
-
-// -----
-
-module {
- // This example is a bit silly; in real life there wouldn't be a loop (the
- // loop body would be executed by different threads). We're just doing it this
- // way so that control flow with shared memory is tested as well.
- func.func @transpose_shared(%in: tensor<32x32xf32>,
- %out: tensor<32x32xf32>) -> tensor<32x32xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c32 = arith.constant 32 : index
-
- %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
- %loaded_tile = scf.for %i = %c0 to %c32 step %c1
- iter_args(%tile = %shared) -> tensor<32x32xf32> {
- %inner_loaded_tile = scf.for %j = %c0 to %c32 step %c1
- iter_args(%inner_tile = %tile) -> tensor<32x32xf32> {
- %v = tensor.extract %in[%i, %j] : tensor<32x32xf32>
- %inserted = tensor.insert %v into %inner_tile[%i, %j]
- : tensor<32x32xf32>
- scf.yield %inserted : tensor<32x32xf32>
- }
- scf.yield %inner_loaded_tile : tensor<32x32xf32>
- }
-
- %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32>
- %written_tile = scf.for %i = %c0 to %c32 step %c1
- iter_args(%written = %out) -> tensor<32x32xf32> {
- %inner_written_tile = scf.for %j = %c0 to %c32 step %c1
- iter_args(%inner_written = %written) -> tensor<32x32xf32> {
- %v = tensor.extract %shared[%j, %i] : tensor<32x32xf32>
- %inserted = tensor.insert %v into %inner_written[%i, %j]
- : tensor<32x32xf32>
- scf.yield %inserted : tensor<32x32xf32>
- }
- scf.yield %inner_written_tile : tensor<32x32xf32>
- }
-
- return %written_tile : tensor<32x32xf32>
- }
-}
-
-// CHECK: llvm.mlir.global private @[[SHARED:shared_.*]]()
-// CHECK-SAME: {addr_space = 3 : i32} : !llvm.array<1024 x f32>
-// CHECK: @transpose_shared
-// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @[[SHARED]] : !llvm.ptr<3>
-// CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[ADDR]]
-// CHECK-SAME: : !llvm.ptr<3> to !llvm.ptr
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
-// CHECK: llvm.store {{.*}} %[[ELEM_ADDR]]
-// CHECK: gpu.barrier
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
-// CHECK: llvm.load %[[ELEM_ADDR]]
-
-// -----
-
-module {
- func.func @atomic_rmw_f32(%in: tensor<2x4xf32>, %i: index, %j: index)
- -> (tensor<2x4xf32>) {
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
- ^bb0(%current : f32):
- %c42 = arith.constant 1.0 : f32
- %add = arith.minimumf %current, %c42 : f32
- xla_gpu.yield %add : f32
- }
- return %ret : tensor<2x4xf32>
- }
-}
-
-// CHECK: @atomic_rmw_f32
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]]
-// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
-// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f32 to i32
-// CHECK-NEXT: llvm.cmpxchg %[[ADDR]], %[[VAR]], %[[RES]]
-
-// -----
-
-module {
- func.func @atomic_rmw_f16(%in: tensor<2x4xf16>, %i: index, %j: index)
- -> (tensor<2x4xf16>) {
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
- ^bb0(%current : f16):
- %c1 = arith.constant 1.0 : f16
- %add = arith.addf %current, %c1 : f16
- xla_gpu.yield %add : f16
- }
- return %ret : tensor<2x4xf16>
- }
-}
-
-// CHECK: @atomic_rmw_f16
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
-// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
-// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
-// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
-// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
-// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
-// CHECK-NEXT: %[[VAR_SHIFT:.*]] = llvm.lshr %[[VAR]], %{{.*}}
-// CHECK-NEXT: %[[VAR_TRUNC:.*]] = llvm.trunc %[[VAR_SHIFT]]
-// CHECK-NEXT: llvm.bitcast %[[VAR_TRUNC]] : i16 to f16
-// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
-// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
-// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
-// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
-// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
-// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
-
-// -----
-
-module {
- func.func @atomic_rmw_overwrite(%in: tensor<2x4xf16>, %i: index, %j: index)
- -> (tensor<2x4xf16>) {
- %c1 = arith.constant 1.0 : f16
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
- ^bb0(%current : f16):
- xla_gpu.yield %c1 : f16
- }
- return %ret : tensor<2x4xf16>
- }
-}
-// CHECK: @atomic_rmw_overwrite
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
-// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
-// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
-// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
-// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
-// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
-// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
-// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
-// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
-// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
-// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
-// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
-
-// -----
-
-module {
- func.func @shared_complex() -> tensor<10xcomplex<f32>> {
- %shared = xla_gpu.allocate_shared : tensor<10xcomplex<f32>>
- return %shared : tensor<10xcomplex<f32>>
- }
-}
-
-// CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>>
-// CHECK: @shared_complex
-
-// -----
-
-module {
- func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> {
- %v = tensor.extract %arg[%i] : tensor<10xi4>
- %r = tensor.insert %v into %arg[%j] : tensor<10xi4>
- return %r : tensor<10xi4>
- }
-}
-
-// CHECK: @i4_load_store
-// CHECK: llvm.getelementptr
-// CHECK-SAME: -> !llvm.ptr, i8
-// CHECK: llvm.load
-// CHECK: llvm.getelementptr
-// CHECK-SAME: -> !llvm.ptr, i8
-// CHECK: llvm.load
-// CHECK: llvm.store
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_overwrite(%in: tensor<2x4xi32>,
- %i: index, %j: index) -> (tensor<2x4xi32>) {
- %c2 = arith.constant 2 : i32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
- ^bb0(%current : i32):
- xla_gpu.yield %c2 : i32
- }
- return %ret : tensor<2x4xi32>
- }
-}
-// CHECK: @direct_atomic_rmw_overwrite
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.store %[[C2]], %[[ADDR]] atomic unordered {alignment = 4 : i64}
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_addi(%in: tensor<2x4xi32>,
- %i: index, %j: index) -> (tensor<2x4xi32>) {
- %c2 = arith.constant 2 : i32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
- ^bb0(%current : i32):
- %min = arith.addi %current, %c2 : i32
- xla_gpu.yield %c2 : i32
- }
- return %ret : tensor<2x4xi32>
- }
-}
-// CHECK: @direct_atomic_rmw_addi
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw add %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_maxsi(%in: tensor<2x4xi32>,
- %i: index, %j: index) -> (tensor<2x4xi32>) {
- %c2 = arith.constant 2 : i32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
- ^bb0(%current : i32):
- %min = arith.maxsi %current, %c2 : i32
- xla_gpu.yield %c2 : i32
- }
- return %ret : tensor<2x4xi32>
- }
-}
-// CHECK: @direct_atomic_rmw_maxsi
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw max %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_maxui(%in: tensor<2x4xi32>,
- %i: index, %j: index) -> (tensor<2x4xi32>) {
- %c2 = arith.constant 2 : i32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
- ^bb0(%current : i32):
- %min = arith.maxui %current, %c2 : i32
- xla_gpu.yield %c2 : i32
- }
- return %ret : tensor<2x4xi32>
- }
-}
-// CHECK: @direct_atomic_rmw_maxui
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw umax %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_minsi(%in: tensor<2x4xi32>,
- %i: index, %j: index) -> (tensor<2x4xi32>) {
- %c2 = arith.constant 2 : i32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
- ^bb0(%current : i32):
- %min = arith.minsi %current, %c2 : i32
- xla_gpu.yield %c2 : i32
- }
- return %ret : tensor<2x4xi32>
- }
-}
-// CHECK: @direct_atomic_rmw_minsi
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw min %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_minui(%in: tensor<2x4xi32>,
- %i: index, %j: index) -> (tensor<2x4xi32>) {
- %c2 = arith.constant 2 : i32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
- ^bb0(%current : i32):
- %min = arith.minui %current, %c2 : i32
- xla_gpu.yield %c2 : i32
- }
- return %ret : tensor<2x4xi32>
- }
-}
-// CHECK: @direct_atomic_rmw_minui
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw umin %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_fadd_f32(%in: tensor<2x4xf32>,
- %i: index, %j: index) -> (tensor<2x4xf32>) {
- %c2 = arith.constant 2.0 : f32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
- ^bb0(%current : f32):
- %min = arith.addf %current, %c2 : f32
- xla_gpu.yield %c2 : f32
- }
- return %ret : tensor<2x4xf32>
- }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
-// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
-// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-GFX908-MI100: %[[C2:.*]] = arith.constant 2
-// CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-GFX908-MI100: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
-// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
-
-// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
-// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
-// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_fadd_f16(%in: tensor<2x4xf16>,
- %i: index, %j: index) -> (tensor<2x4xf16>) {
- %c2 = arith.constant 2.0 : f16
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
- ^bb0(%current : f16):
- %min = arith.addf %current, %c2 : f16
- xla_gpu.yield %c2 : f16
- }
- return %ret : tensor<2x4xf16>
- }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-NOT: llvm.atomicrmw fadd
-
-// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
-// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
-// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
-
-// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
-// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
-// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<2x4xbf16>,
- %i: index, %j: index) -> (tensor<2x4xbf16>) {
- %c2 = arith.constant 2.0 : bf16
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xbf16> {
- ^bb0(%current : bf16):
- %min = arith.addf %current, %c2 : bf16
- xla_gpu.yield %c2 : bf16
- }
- return %ret : tensor<2x4xbf16>
- }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_bf16
-// CHECK-NOT: llvm.atomicrmw fadd
-
-// CHECK-HOPPER-LABEL: @direct_atomic_rmw_fadd_bf16
-// CHECK-HOPPER: %[[C2:.*]] = arith.constant 2
-// CHECK-HOPPER: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-HOPPER: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_fadd_f64(%in: tensor<2x4xf64>,
- %i: index, %j: index) -> (tensor<2x4xf64>) {
- %c2 = arith.constant 2.0 : f64
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf64> {
- ^bb0(%current : f64):
- %min = arith.addf %current, %c2 : f64
- xla_gpu.yield %c2 : f64
- }
- return %ret : tensor<2x4xf64>
- }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
-// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
-// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
-
-// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-GFX90A-MI200-NOT: llvm.atomicrmw fadd
-
-// -----
-
-module {
- func.func @direct_atomic_rmw_maximumf(%in: tensor<2x4xf32>,
- %i: index, %j: index) -> (tensor<2x4xf32>) {
- %c2 = arith.constant 2.0 : f32
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
- ^bb0(%current : f32):
- %min = arith.maximumf %current, %c2 : f32
- xla_gpu.yield %c2 : f32
- }
- return %ret : tensor<2x4xf32>
- }
-}
-// CHECK-LABEL: @direct_atomic_rmw_maximumf
-
-// CHECK: %[[MODIFIER:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: %[[CURRENT:.*]] = llvm.load %[[ADDR]] : !llvm.ptr -> f32
-// CHECK: %[[CURRENT_IS_NAN:.*]] = llvm.fcmp "uno" %[[CURRENT]], %[[CURRENT]] : f32
-// CHECK: scf.if %[[CURRENT_IS_NAN]] {
-// CHECK: } else {
-// CHECK: %[[MODIFIER_IS_NAN:.*]] = llvm.fcmp "uno" %[[MODIFIER]], %[[MODIFIER]] : f32
-// CHECK: %[[MODIFIER_OR_NAN:.*]] = llvm.select %[[MODIFIER_IS_NAN]], %[[NAN]], %[[MODIFIER]] : i1, f32
-// CHECK: %[[VAL_13:.*]] = llvm.fcmp "ult" %[[CURRENT]], %[[MODIFIER_OR_NAN]] : f32
-// CHECK: scf.if %[[VAL_13]] {
-// CHECK: %[[INT_MODIFIER_OR_NAN:.*]] = llvm.bitcast %[[MODIFIER_OR_NAN]] : f32 to i32
-// CHECK: %[[IS_POSITIVE:.*]] = llvm.icmp "sge" %[[INT_MODIFIER_OR_NAN]], %[[C0]] : i32
-// CHECK: scf.if %[[IS_POSITIVE]] {
-// CHECK: llvm.atomicrmw max %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
-// CHECK: } else {
-// CHECK: llvm.atomicrmw umin %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: return
-
-// -----
-
-module {
- func.func @atomic_rmw_c32(%in: tensor<2x4xcomplex<f32>>, %i: index, %j: index)
- -> (tensor<2x4xcomplex<f32>>) {
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xcomplex<f32>> {
- ^bb0(%current : complex<f32>):
- %a = complex.add %current, %current : complex<f32>
- xla_gpu.yield %a : complex<f32>
- }
- return %ret : tensor<2x4xcomplex<f32>>
- }
-}
-
-// CHECK-LABEL: @atomic_rmw_c32
-
-// CHECK: scf.while (%[[ITER_ARG:.*]] = %{{.*}}) : (i64) -> i64
-// CHECK: %[[TMP:.*]] = llvm.alloca
-// CHECK: llvm.store %[[ITER_ARG]], %[[TMP]]
-// CHECK: %[[LD:.*]] = llvm.load %[[TMP]] : {{.*}} -> !llvm.struct<(f32, f32)>
-// CHECK: builtin.unrealized_conversion_cast %[[LD]] : {{.*}} to complex<f32>
-
-// -----
-
-module {
- func.func @unused_index_switch_results(%i: index) -> index {
- %ret, %ret2 = scf.index_switch %i -> tensor<2x4xi32>, tensor<3xf32>
- case 0 {
- %x, %y = "dummy.op1"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
- scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
- }
- default {
- %x, %y = "dummy.op2"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
- scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
- }
- return %i : index
- }
-}
-
-// CHECK-LABEL: func.func @unused_index_switch_results
-// CHECK-SAME: (%[[I:.*]]: index)
-// CHECK-NEXT: scf.index_switch %[[I]]
-// CHECK-NEXT: case 0 {
-// CHECK-NEXT: "dummy.op1"
-// CHECK-NEXT: scf.yield
-// CHECK-NEXT: }
-// CHECK-NEXT: default {
-// CHECK-NEXT: "dummy.op2"
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[I]] : index
-
-// -----
-
-module {
- func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
- %c16 = arith.constant 16 : index
- %c22 = arith.constant 22 : index
- %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32>
- %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32>
- %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32>
- func.return %out2 : tensor<43xf32>
- }
-}
-
-// CHECK-LABEL: @transfer_write
-// CHECK: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
-// CHECK-NEXT: llvm.store %[[CST:.*]], %[[PTR1]]
-// CHECK-NEXT: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
-// CHECK-NEXT: llvm.store %[[CST]], %[[PTR2]]
-
-// -----
-
-module {
- func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> {
- %c16 = arith.constant 16 : index
- %c0 = arith.constant 0.0 : f32
- %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32>
- func.return %out : vector<2xf32>
- }
-}
-
-// CHECK-LABEL: @transfer_read
-// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
-// CHECK-NEXT: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xf32>
-
-// -----
-
-module {
- func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1},
- %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> {
- %c16 = arith.constant 16 : index
- %c22 = arith.constant 22 : index
- %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1>
- %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1>
- func.return %out2 : tensor<43xi1>
- }
-}
-
-// CHECK-LABEL: @transfer_write_i1
-// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr
-// CHECK-SAME: %[[V1:.*]]: vector<2xi1>, %[[V2:.*]]: vector<2xi1>)
-// CHECK-DAG: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
-// CHECK-DAG: %[[V1_EXT:.*]] = arith.extui %[[V1]]
-// CHECK: llvm.store %[[V1_EXT]], %[[PTR1]]
-// CHECK-DAG: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
-// CHECK-DAG: %[[V2_EXT:.*]] = arith.extui %[[V2]]
-// CHECK: llvm.store %[[V2_EXT]], %[[PTR2]]
-
-// -----
-
-module {
- func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> {
- %c16 = arith.constant 16 : index
- %false = arith.constant false
- %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1>
- func.return %out : vector<2xi1>
- }
-}
-
-// CHECK-LABEL: @transfer_read_i1
-// CHECK-DAG: %[[C0:.*]] = arith.constant dense<0> : vector<2xi8>
-// CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
-// CHECK: %[[LOADED:.*]] = llvm.load %[[PTR]] : !llvm.ptr
-// CHECK: %[[CAST:.*]] = arith.cmpi ne, %[[LOADED]], %[[C0]]
-// CHECK: return %[[CAST]] : vector<2xi1>
-
-// -----
-
-module {
- func.func @transfer_write_i4(%arg0: tensor<43xi4> {xla.slice_index = 1},
- %v1: vector<4xi4>) -> tensor<43xi4> {
- %c16 = arith.constant 16 : index
- %out = vector.transfer_write %v1, %arg0[%c16] : vector<4xi4>, tensor<43xi4>
- func.return %out : tensor<43xi4>
- }
-}
-
-// CHECK-LABEL: @transfer_write_i4
-// CHECK-SAME: , %[[V1:.*]]: vector<4xi4>
-// CHECK-DAG: %[[A0:.*]] = vector.extract %[[V1]][0]
-// CHECK-DAG: %[[A1:.*]] = vector.extract %[[V1]][1]
-// CHECK-DAG: %[[A2:.*]] = vector.extract %[[V1]][2]
-// CHECK-DAG: %[[A3:.*]] = vector.extract %[[V1]][3]
-// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1]
-// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0]
-// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3]
-// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2]
-
-module {
- func.func @transfer_read_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}) -> vector<4xi4> {
- %c16 = arith.constant 16 : index
- %c0 = arith.constant 0 : i4
- %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xi4>, vector<4xi4>
- func.return %out : vector<4xi4>
- }
-}
-
-// CHECK-LABEL: @transfer_read_i4
-// CHECK: %[[LOADED:.*]] = llvm.load
-// CHECK-DAG: %[[A0:.*]] = vector.extract %[[LOADED]][0]
-// CHECK-DAG: %[[A1:.*]] = vector.extract %[[LOADED]][1]
-// CHECK-DAG: %[[A2:.*]] = vector.extract %[[LOADED]][2]
-// CHECK-DAG: %[[A3:.*]] = vector.extract %[[LOADED]][3]
-// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1]
-// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0]
-// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3]
-// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2]
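
Several of the lower-tensors checks above (e.g. @atomic_rmw_f16 and @atomic_rmw_overwrite) verify the CAS-loop emulation emitted for sub-word atomics: the f16 value is extracted from its enclosing 32-bit word with lshr/trunc, updated, re-inserted with zext/shl/or, and committed via llvm.cmpxchg inside an scf.while loop. The following host-side C++ sketch reproduces only the shape of that loop; AtomicAdd16 and the integer payload are invented for illustration and are not part of the lowering.

#include <atomic>
#include <cstdint>
#include <iostream>

// Illustrative only: atomically add to a 16-bit slot inside a 32-bit word,
// following the same read / modify / re-insert / compare-exchange structure
// that the pass emits for f16 atomics.
void AtomicAdd16(std::atomic<uint32_t>& word, int shift, uint16_t value) {
  uint32_t old_word = word.load();
  while (true) {
    uint16_t current = static_cast<uint16_t>(old_word >> shift);
    uint16_t updated = static_cast<uint16_t>(current + value);
    uint32_t new_word = (old_word & ~(0xFFFFu << shift)) |
                        (static_cast<uint32_t>(updated) << shift);
    // On failure, compare_exchange_weak reloads old_word, like the scf.while
    // loop carrying the freshly observed value into the next iteration.
    if (word.compare_exchange_weak(old_word, new_word)) break;
  }
}

int main() {
  std::atomic<uint32_t> word{0};
  AtomicAdd16(word, 16, 3);  // Upper half-word.
  AtomicAdd16(word, 0, 5);   // Lower half-word.
  std::cout << std::hex << word.load() << "\n";  // Prints: 30005
  return 0;
}
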
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir
deleted file mode 100644
index 645430a..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir
+++ /dev/null
@@ -1,141 +0,0 @@
-// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf | FileCheck %s
-
-module {
- func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
- return %a, %b : f32, i32
- }
-
- func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
- %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32
- return %ret#0, %ret#1 : f32, i32
- }
-}
-
-// CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32)
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4
-// CHECK-DAG: %[[C32:.*]] = arith.constant 32
-// CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]]
-// CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]]
-// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
-// CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]]
-// CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]]
-// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @reducer(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
-// CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]]
-// CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]]
-// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @reducer(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
-// CHECK: return %[[AB1_0]], %[[AB1_1]]
-
-// -----
-
-module {
- func.func @reducer(%a: f64, %b: f64) -> f64 {
- return %a : f64
- }
-
- func.func @shuffler(%a: f64) -> f64 {
- %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64
- return %ret : f64
- }
-}
-
-// CHECK: @shuffler(%[[A:.*]]: f64
-// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
-// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
- func.func @reducer(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
- return %a : complex<f64>
- }
-
- func.func @shuffler(%a: complex<f64>) -> complex<f64> {
- %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex<f64>
- return %ret : complex<f64>
- }
-}
-
-// CHECK: @shuffler
-// CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
- func.func @reducer(%a: ui64, %b: ui64) -> ui64 {
- return %a : ui64
- }
-
- func.func @shuffler(%a: ui64) -> ui64 {
- %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64
- return %ret : ui64
- }
-}
-
-// CHECK: @shuffler
-// CHECK: unrealized_conversion_cast
-// CHECK-COUNT-2: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
- func.func @reducer(%a: i8, %b: i8) -> i8 {
- return %a : i8
- }
-
- func.func @shuffler_i8(%a: i8) -> i8 {
- %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8
- return %ret : i8
- }
-}
-
-// CHECK: @shuffler_i8(
-// CHECK-NOT: vector
-// CHECK-COUNT-1: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
- func.func @predicated_insert(
- %v: i32, %tensor: tensor<2xi32>, %index: index,
- %cond: i1) -> tensor<2xi32> {
- %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond
- : tensor<2xi32>
- return %ret : tensor<2xi32>
- }
-}
-
-// CHECK: @predicated_insert
-// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
-// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
-// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
-// CHECK-NEXT: %[[UPD:.*]] = tensor.insert %[[V]] into %[[TENSOR]][%[[INDEX]]]
-// CHECK-NEXT: yield %[[UPD]]
-// CHECK-NEXT: else
-// CHECK-NEXT: yield %[[TENSOR]]
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[RET]]
-
-// -----
-
-module {
- func.func @predicated_extract(
- %v: i32, %tensor: tensor<2xi32>, %index: index,
- %cond: i1) -> i32 {
- %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v
- : tensor<2xi32>
- return %ret : i32
- }
-}
-
-// CHECK: @predicated_extract
-// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
-// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
-// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
-// CHECK-NEXT: %[[VAL:.*]] = tensor.extract %[[TENSOR]][%[[INDEX]]]
-// CHECK-NEXT: yield %[[VAL]]
-// CHECK-NEXT: else
-// CHECK-NEXT: yield %[[V]]
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[RET]]
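
The shuffle_reduce checks above lower xla_gpu.shuffle_reduce to a butterfly of gpu.shuffle down ops with offsets 4, 2 and 1, combining each pair of lane values through the reducer. The sketch below simulates that pattern on a plain array on the host; it is not the GPU lowering itself, and addition stands in for the reducer function.

#include <cstdio>
#include <vector>

// Illustrative only: butterfly reduction over 8 "lanes". Each step, lane i
// reads lane i + offset (the gpu.shuffle down) and combines the two values;
// after offsets 4, 2, 1, lane 0 holds the full reduction.
float SimulatedShuffleReduce(std::vector<float> lanes) {
  for (int offset = static_cast<int>(lanes.size()) / 2; offset >= 1;
       offset /= 2) {
    for (int lane = 0; lane + offset < static_cast<int>(lanes.size()); ++lane) {
      lanes[lane] = lanes[lane] + lanes[lane + offset];  // Reducer: addition.
    }
  }
  return lanes[0];
}

int main() {
  std::printf("%g\n", SimulatedShuffleReduce({1, 2, 3, 4, 5, 6, 7, 8}));
  // Prints: 36
  return 0;
}
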
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc
deleted file mode 100644
index 8e9fb47..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
-#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Tools/mlir-opt/MlirOptMain.h"
-#include "mlir/Transforms/Passes.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-
-int main(int argc, char **argv) {
- mlir::DialectRegistry registry;
- registry.insert<mlir::DLTIDialect, mlir::tensor::TensorDialect,
- mlir::func::FuncDialect, mlir::affine::AffineDialect,
- mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
- mlir::math::MathDialect, mlir::scf::SCFDialect,
- mlir::mhlo::MhloDialect, mlir::LLVM::LLVMDialect,
- mlir::gpu::GPUDialect, mlir::mhlo::MhloDialect,
- mlir::vector::VectorDialect, xla::gpu::XlaGpuDialect,
- mlir::NVVM::NVVMDialect>();
- mlir::func::registerAllExtensions(registry);
- mlir::registerCanonicalizerPass();
- mlir::registerCSEPass();
- mlir::registerInliner();
- xla::gpu::registerGpuFusionTransformsPasses();
-
- return mlir::failed(
- MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry));
-}
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir
deleted file mode 100644
index c7f1507..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir
+++ /dev/null
@@ -1,96 +0,0 @@
-// RUN: mlir_fusions_opt %s --split-input-file | FileCheck %s
-// Verify the printed output can be parsed.
-// RUN: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s
-// Verify the generic form can be parsed.
-// RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s
-
-func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) {
- %shared1 = xla_gpu.allocate_shared : tensor<2xf32>
- %shared2 = xla_gpu.allocate_shared : tensor<2xf32>
- %sync:2 = xla_gpu.sync_threads %shared1, %shared2
- : tensor<2xf32>, tensor<2xf32>
- return %sync#0, %sync#1 : tensor<2xf32>, tensor<2xf32>
-}
-// CHECK-LABEL: @shared_and_sync
-// CHECK-NEXT: allocate_shared
-// CHECK-NEXT: allocate_shared
-// CHECK-NEXT: sync_threads
-// CHECK-NEXT: return
-
-// -----
-
-func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index)
- -> (tensor<2x3xf32>) {
- %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
- ^bb0(%current : f32):
- %c42 = arith.constant 42.0 : f32
- %add = arith.addf %current, %c42 : f32
- xla_gpu.yield %add : f32
- }
- return %ret : tensor<2x3xf32>
-}
-// CHECK-LABEL: @atomic_rmw
-// CHECK: xla_gpu.atomic_rmw
-
-// -----
-
-func.func private @add(%a: f32, %b: f32) -> f32 {
- %ret = arith.addf %a, %b : f32
- return %ret : f32
-}
-
-func.func @caller(%a: f32, %b: f32) -> f32 {
- %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
- %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
- %ret = arith.addf %c, %d : f32
- return %ret : f32
-}
-// CHECK-LABEL: @caller
-// CHECK: %[[C:.*]] = xla_gpu.pure_call @add
-// CHECK: %[[D:.*]] = xla_gpu.pure_call @add
-// CHECK: arith.addf %[[C]], %[[D]]
-
-// CHECK-CSE: @caller
-// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add
-// CHECK-CSE: arith.addf %[[C]], %[[C]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)>
-func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
- %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])[%s0 in [2, 4]]
- func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)>
-
-// CHECK-LABEL: @apply_indexing
-// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
-// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3])[%[[s0]] in [2, 4]]
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) {
- %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])
- func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @apply_indexing_no_symbols
-// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
-// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3])
-
-// -----
-
-#map0 = affine_map<()[s0] -> (s0, s0)>
-func.func @apply_indexing_no_dims(%s0: index) -> (index, index) {
- %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 4]]
- func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0, s0)>
-
-// CHECK-LABEL: @apply_indexing_no_dims
-// CHECK: (%[[s0:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 4]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir
deleted file mode 100644
index 6f903f3..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir
+++ /dev/null
@@ -1,182 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s
-
-#map = affine_map<(d0) -> (d0 floordiv 8)>
-#map1 = affine_map<(d0) -> (d0 mod 8)>
-#map2 = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)>
-module {
- func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>,
- %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>,
- %arg4: tensor<4x8x4096xbf16>, %arg5: tensor<4x8xf32>,
- %arg6: tensor<4x8x4096xf32>) -> (tensor<4x8x4096xf32>, f32) {
- %cst = arith.constant 1.000000e+00 : f32
- %cst_1 = arith.constant 1.000000e+00 : bf16
- %c2 = arith.constant 2 : index
- %c8 = arith.constant 8 : index
- %c32 = arith.constant 32 : index
- %c1 = arith.constant 1 : index
- %c0 = arith.constant 0 : index
- %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 255 : index]}
- %block_id_x = gpu.block_id x {xla.range = [0 : index, 31 : index]}
- %0 = gpu.lane_id
- %1 = arith.cmpi eq, %0, %c0 : index
- %2 = arith.divui %thread_id_x, %c32 : index
- %3 = arith.cmpi ult, %thread_id_x, %c8 : index
- %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 31])
- %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 31])
- %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32>
- %6 = arith.mulf %extracted, %cst : f32
- %7 = arith.addf %6, %cst : f32
- %8 = math.rsqrt %7 : f32
- %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) {
- %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
- %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
- %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
- %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
- %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
- %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16>
- %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
- %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32>
- %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) {
- %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
- %28 = vector.extract %25[%arg10] : f32 from vector<2xf32>
- %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16>
- %30 = arith.extf %29 : bf16 to f32
- %31 = vector.extract %21[%arg10] : bf16 from vector<2xbf16>
- %32 = arith.extf %31 : bf16 to f32
- %33 = arith.mulf %30, %32 : f32
- %34 = arith.mulf %33, %8 : f32
- %35 = vector.extract %19[%arg10] : bf16 from vector<2xbf16>
- %36 = arith.extf %35 : bf16 to f32
- %37 = arith.addf %36, %cst : f32
- %38 = arith.mulf %34, %37 : f32
- %39 = arith.addf %28, %38 : f32
- %40 = arith.mulf %39, %39 : f32
- %41 = arith.addf %arg12, %40 : f32
- %inserted = tensor.insert %39 into %arg11[%4, %5, %27] : tensor<4x8x4096xf32>
- scf.yield %inserted, %41 : tensor<4x8x4096xf32>, f32
- }
- scf.yield %26#0, %26#1 : tensor<4x8x4096xf32>, f32
- }
- return %9#0, %9#1 : tensor<4x8x4096xf32>, f32
- }
-}
-
-// CHECK-LABEL: @fully_unroll
-// CHECK-NOT: scf.for
-
-// -----
-
-module {
- func.func @unroll_by_factor(%arg0: f32) -> f32 {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c256 = arith.constant 256 : index
- %ret = scf.for %i = %c0 to %c256 step %c1 iter_args (%v = %arg0) -> (f32) {
- %exp = math.exp %v : f32
- %add = arith.addf %v, %exp : f32
- %log = math.log %add : f32
- scf.yield %log : f32
- }
- return %ret : f32
- }
-}
-
-// CHECK-LABEL: @unroll_by_factor
-// CHECK: %[[C8:.*]] = arith.constant 8 : index
-// CHECK: scf.for {{.*}} step %[[C8]]
-
-// -----
-
-module {
- func.func @do_not_unroll(%arg0: f32) -> f32 {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c31 = arith.constant 31 : index
- %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%v = %arg0) -> (f32) {
- %exp = math.exp %v : f32
- %add = arith.addf %v, %exp : f32
- %log = math.log %add : f32
- scf.yield %log : f32
- }
- return %ret : f32
- }
-}
-
-// CHECK-LABEL: @do_not_unroll
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: scf.for {{.*}} step %[[C1]]
-
-// -----
-
-module {
- func.func @pipeline_extract(%arg: tensor<31xf32>) -> f32 {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c31 = arith.constant 31 : index
- %cst = arith.constant 0.0 : f32
- %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%iter = %cst) -> (f32) {
- %val = tensor.extract %arg[%i] : tensor<31xf32>
- %log = math.log %val : f32
- %add = arith.addf %log, %iter : f32
- scf.yield %add : f32
- }
- return %ret : f32
- }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 1)>
-// CHECK-LABEL: @pipeline_extract
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
-// CHECK: %[[VAL0:.*]] = tensor.extract %[[ARG0:.*]][%[[C0]]]
-// CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
-// CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C30]]
-// CHECK-DAG: %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
-// CHECK: %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
-// CHECK-NEXT: tensor.extract %[[ARG0]][%[[NEXT_I]]]
-// CHECK-NEXT: yield
-// CHECK-NEXT: else
-// CHECK-NEXT: yield %[[VAL]]
-// CHECK: math.log %[[VAL]]
-// CHECK: %[[ADD:.*]] = arith.addf
-// CHECK: yield %[[ADD]], %[[NEXT_VAL]]
-
-// -----
-
-module {
- func.func @pipeline_transfer(%arg: tensor<34xf32>) -> vector<2xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c17 = arith.constant 17 : index
- %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32>
- %cst0 = arith.constant 0.0 : f32
- %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) {
- %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 15])
- %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32>
- %log = math.log %val : vector<2xf32>
- %add = arith.addf %log, %iter : vector<2xf32>
- scf.yield %add : vector<2xf32>
- }
- return %ret : vector<2xf32>
- }
-}
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 + 1)>
-// CHECK-LABEL: @pipeline_transfer
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[BASE0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[C0]]
-// CHECK: %[[VAL0:.*]] = vector.transfer_read %[[ARG0:.*]][%[[BASE0]]]
-// CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
-// CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C16]]
-// CHECK-DAG: %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]
-// CHECK-DAG: %[[NEXT_BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[NEXT_I]]
-// CHECK: %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
-// CHECK-NEXT: vector.transfer_read %[[ARG0]][%[[NEXT_BASE]]]
-// CHECK-NEXT: yield
-// CHECK-NEXT: else
-// CHECK-NEXT: yield %[[VAL]]
-// CHECK: math.log %[[VAL]]
-// CHECK: %[[ADD:.*]] = arith.addf
-// CHECK: yield %[[ADD]], %[[NEXT_VAL]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir
deleted file mode 100644
index ec1a726..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir
+++ /dev/null
@@ -1,145 +0,0 @@
-// RUN: mlir_fusions_opt --allow-unregistered-dialect %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s
-
-func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
- %0 = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
- %1 = gpu.block_id x {xla.range = [0 : index, 3071 : index]}
- scf.for %arg3 = %c0 to %c4 step %c1 {
- %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))>()[%1, %0, %arg3]
- %3 = arith.index_castui %2 : index to i64
- %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- %5 = llvm.load %4 invariant : !llvm.ptr -> f32
- %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- %9 = llvm.load %8 invariant : !llvm.ptr -> f32
- %10 = arith.cmpf oge, %5, %9 : f32
- %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
- llvm.store %10, %11 : i1, !llvm.ptr
- }
- return
-}
-
-// CHECK-LABEL: @op_and_for_ranges
-// CHECK-DAG: %[[C512:.*]] = arith.constant 512
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4
-// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
-// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
-// CHECK: scf.for %[[I:.*]] =
-// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
-// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
-// CHECK: %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
-// CHECK: arith.addi %[[OFFSET]], %[[I]]
-
-// -----
-
-func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
- %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
- return %0 : index
-}
-
-// CHECK-LABEL: @arg_ranges
-// CHECK-NEXT: %[[C100:.*]] = arith.constant 100
-// CHECK-NEXT: %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
-// CHECK-NEXT: return %[[RET]]
-
-// -----
-
-func.func @cant_lower(%arg0: index {xla.range = [-10 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
- %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
- return %0 : index
-}
-
-// CHECK-LABEL: @cant_lower
-// CHECK: affine.apply
-
-// -----
-
-func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
- %0 = gpu.thread_id x
- %1 = gpu.block_id x
- scf.for %i = %c0 to %c4 step %c1 {
- %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))>
- [%1 in [0, 3071], %0 in [0, 127], %i in [0, 3]]
- %3 = arith.index_castui %2 : index to i64
- %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- %5 = llvm.load %4 invariant : !llvm.ptr -> f32
- %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- %9 = llvm.load %8 invariant : !llvm.ptr -> f32
- %10 = arith.cmpf oge, %5, %9 : f32
- %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
- llvm.store %10, %11 : i1, !llvm.ptr
- }
- return
-}
-
-// CHECK-LABEL: @op_and_for_ranges
-// CHECK-DAG: %[[C512:.*]] = arith.constant 512
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4
-// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
-// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
-// CHECK: scf.for %[[I:.*]] =
-// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
-// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
-// CHECK: %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
-// CHECK: arith.addi %[[OFFSET]], %[[I]]
-
-// -----
-
-func.func @arg_ranges(%arg0: index, %arg1: index) -> index {
- %0 = xla_gpu.apply_indexing
- affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>
- [%arg0 in [0, 42], %arg1 in [0, 1000]]
- return %0 : index
-}
-
-// CHECK-LABEL: @arg_ranges
-// CHECK-NEXT: %[[C100:.*]] = arith.constant 100
-// CHECK-NEXT: %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
-// CHECK-NEXT: return %[[RET]]
-
-// -----
-
-func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) {
- %0:2 = xla_gpu.apply_indexing
- affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)>
- [%arg0 in [-10, 42], %arg1 in [0, 1000]]
- return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: @cant_lower
-// CHECK: affine.apply
-// CHECK-NEXT: arith.addi
-
-// -----
-
-func.func @order_summands(%arg1: index) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
- scf.for %arg2 = %c0 to %c4 step %c1 {
- scf.for %arg3 = %c0 to %c4 step %c1 {
- %0 = xla_gpu.apply_indexing
- affine_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)>
- [%arg2 in [0, 3], %arg1 in [0, 3], %arg3 in [0, 3]]
- "dummy.op"(%0) : (index) -> ()
- }
- }
- return
-}
-
-// CHECK-LABEL: @order_summands
-// CHECK-SAME: (%[[ARG1:.*]]: index)
-// CHECK: scf.for %[[ARG2:.*]] =
-// CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: arith.muli %[[ARG1]]
-// CHECK: arith.muli %[[ARG2]]
-// CHECK: arith.addi
-// CHECK: arith.addi %[[ARG1]], %[[ARG2]]
-// CHECK: arith.divui
-// CHECK: arith.addi
-// CHECK: arith.muli %[[ARG3]]
-// CHECK: arith.addi %5, %6 : index
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir
deleted file mode 100644
index ee2e0dd..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir
+++ /dev/null
@@ -1,292 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -cse -canonicalize | FileCheck %s
-
-module {
- func.func @unknown(%arg0: index {xla.range = [0 : index, 42 : index]}) -> i1 {
- %c12 = arith.constant 12 : index
- %eq = arith.cmpi eq, %arg0, %c12 : index
- return %eq : i1
- }
-}
-
-// CHECK: @unknown
-// CHECK: cmpi
-
-// -----
-
-module {
- func.func @true(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
- %c5 = arith.constant 5 : index
- %eq = arith.cmpi sge, %arg0, %c5 : index
- return %eq : i1
- }
-}
-
-// CHECK: @true
-// CHECK-NEXT: constant true
-// CHECK-NEXT: return
-
-// -----
-
-module {
- func.func @false(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
- %c5 = arith.constant 5 : index
- %eq = arith.cmpi slt, %arg0, %c5 : index
- return %eq : i1
- }
-}
-
-// CHECK: @false
-// CHECK-NEXT: constant false
-// CHECK-NEXT: return
-
-// -----
-
-module {
- func.func @rhs_range(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
- %c42 = arith.constant 64 : index
- %eq = arith.cmpi slt, %c42, %arg0 : index
- return %eq : i1
- }
-}
-
-// CHECK: @rhs_range
-// CHECK-NEXT: constant false
-// CHECK-NEXT: return
-
-// -----
-
-module {
- func.func @both_range(%arg0: index {xla.range = [12 : index, 42 : index]},
- %arg1: index {xla.range = [63 : index, 100 : index]}) -> i1 {
- %eq = arith.cmpi slt, %arg0, %arg1 : index
- return %eq : i1
- }
-}
-
-// CHECK-LABEL: @both_range
-// CHECK-NEXT: constant true
-// CHECK-NEXT: return
-
-// -----
-
-module {
- func.func @minsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
- %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
- %min = arith.minsi %arg0, %arg1 : index
- return %min : index
- }
-}
-
-// CHECK-LABEL: @minsi_lhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG0]]
-
-// -----
-
-module {
- func.func @minsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
- %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
- %min = arith.minsi %arg1, %arg0 : index
- return %min : index
- }
-}
-
-// CHECK-LABEL: @minsi_rhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG0]]
-
-// -----
-
-module {
- func.func @maxsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
- %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
- %min = arith.maxsi %arg1, %arg0 : index
- return %min : index
- }
-}
-
-// CHECK-LABEL: @maxsi_lhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG1]]
-
-// -----
-
-module {
- func.func @maxsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
- %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
- %min = arith.maxsi %arg0, %arg1 : index
- return %min : index
- }
-}
-
-// CHECK-LABEL: @maxsi_rhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG1]]
-
-// -----
-
-module {
- func.func @maxsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
- %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
- %add = arith.addi %arg0, %arg1 : index
- %min = arith.maxsi %add, %arg1 : index
- return %min : index
- }
-}
-
-// CHECK-LABEL: @maxsi_add
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG1]]
-// CHECK-NEXT: return %[[ADD]]
-
-// -----
-
-module {
- func.func @minsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
- %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
- %add = arith.addi %arg0, %arg1 : index
- %min = arith.minsi %add, %arg1 : index
- return %min : index
- }
-}
-
-// CHECK-LABEL: @minsi_add
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG1]]
-
-// -----
-
-module {
- func.func @pred_reduce(%in: i1) -> i1 {
- %c1_i32 = arith.constant 1 : i32
- %c2_i32 = arith.constant 2 : i32
- %c4_i32 = arith.constant 4 : i32
- %c8_i32 = arith.constant 8 : i32
- %c16_i32 = arith.constant 16 : i32
- %c32_i32 = arith.constant 32 : i32
- %0 = arith.extui %in : i1 to i32
- %shuffleResult, %valid = gpu.shuffle down %0, %c16_i32, %c32_i32 : i32
- %1 = arith.trunci %shuffleResult : i32 to i1
- %2 = arith.ori %in, %1 : i1
- %3 = arith.extui %2 : i1 to i32
- %shuffleResult_0, %valid_1 = gpu.shuffle down %3, %c8_i32, %c32_i32 : i32
- %4 = arith.trunci %shuffleResult_0 : i32 to i1
- %5 = arith.ori %2, %4 : i1
- %6 = arith.extui %5 : i1 to i32
- %shuffleResult_2, %valid_3 = gpu.shuffle down %6, %c4_i32, %c32_i32 : i32
- %7 = arith.trunci %shuffleResult_2 : i32 to i1
- %8 = arith.ori %5, %7 : i1
- %9 = arith.extui %8 : i1 to i32
- %shuffleResult_4, %valid_5 = gpu.shuffle down %9, %c2_i32, %c32_i32 : i32
- %10 = arith.trunci %shuffleResult_4 : i32 to i1
- %11 = arith.ori %8, %10 : i1
- %12 = arith.extui %11 : i1 to i32
- %shuffleResult_6, %valid_7 = gpu.shuffle down %12, %c1_i32, %c32_i32 : i32
- %13 = arith.trunci %shuffleResult_6 : i32 to i1
- %14 = arith.ori %11, %13 : i1
- return %14 : i1
- }
-}
-
-// CHECK-LABEL: @pred_reduce
-// CHECK-SAME: (%[[IN:.*]]: i1)
-// CHECK: %[[IN_EXT:.*]] = arith.extui %[[IN]]
-// CHECK-NEXT: %[[SHUFFLE0:.*]], {{.*}} = gpu.shuffle down %[[IN_EXT]]
-// CHECK-NEXT: %[[OR0:.*]] = arith.ori %[[IN_EXT]], %[[SHUFFLE0]]
-// CHECK-NEXT: %[[SHUFFLE1:.*]], {{.*}} = gpu.shuffle down %[[OR0]]
-// CHECK-NEXT: %[[OR1:.*]] = arith.ori %[[OR0]], %[[SHUFFLE1]]
-// CHECK-NEXT: %[[SHUFFLE2:.*]], {{.*}} = gpu.shuffle down %[[OR1]]
-// CHECK-NEXT: %[[OR2:.*]] = arith.ori %[[OR1]], %[[SHUFFLE2]]
-// CHECK-NEXT: %[[SHUFFLE3:.*]], {{.*}} = gpu.shuffle down %[[OR2]]
-// CHECK-NEXT: %[[OR3:.*]] = arith.ori %[[OR2]], %[[SHUFFLE3]]
-// CHECK-NEXT: %[[SHUFFLE4:.*]], {{.*}} = gpu.shuffle down %[[OR3]]
-// CHECK-NEXT: %[[OR4:.*]] = arith.ori %[[OR3]], %[[SHUFFLE4]]
-// CHECK-NEXT: %[[RET:.*]] = arith.trunci %[[OR4]]
-// CHECK-NEXT: return %[[RET]]
-
-// -----
-
-module {
- func.func @andi_no_trunc_arg(%a: i4, %b: i8) -> i4 {
- %lhs = arith.extui %a : i4 to i8
- %add = arith.andi %lhs, %b : i8
- %ret = arith.trunci %add : i8 to i4
- return %ret : i4
- }
-}
-
-// CHECK-LABEL: @andi_no_trunc_arg
-// CHECK-NEXT: extui
-// CHECK-NEXT: andi
-// CHECK-NEXT: trunci
-// CHECK-NEXT: return
-
-// -----
-
-module {
- func.func @ori_mismatched_narrowest(%a: i8, %b: i8) -> i8 {
- %0 = arith.trunci %a : i8 to i4
- %1 = arith.extui %0 : i4 to i8
- %ret = arith.ori %b, %1 : i8
- return %ret : i8
- }
-}
-
-// CHECK-LABEL: @ori_mismatched_narrowest
-// CHECK-NEXT: trunci
-// CHECK-NEXT: extui
-// CHECK-NEXT: ori
-// CHECK-NEXT: return
-
-// -----
-
-func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c3 = arith.constant 3 : index
- %c42_f32 = arith.constant 42.0 : f32
- %loop = scf.for %i = %c0 to %c3 step %c1
- iter_args(%in_ = %tensor) -> (tensor<100xf32>) {
- %0 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 4)> (%i in [0, 9])
- %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32>
- scf.yield %updated :tensor<100xf32>
- }
- func.return %loop : tensor<100xf32>
-}
-// CHECK-LABEL: func.func @refine_constraints
-// CHECK: %[[CST:.*]] = arith.constant 4.2
-// CHECK: scf.for
-// CHECK: tensor.insert %[[CST]]
-
-
-// -----
-
-#map = affine_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)>
-#map1 = affine_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)>
-func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>,
- %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
- %c73 = arith.constant 73 : index
- %c42_f32 = arith.constant 42.0 : f32
- %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
- %bl_x = gpu.block_id x {xla.range = [0 : index, 575 : index]}
- %0 = scf.for %i = %c0 to %c73 step %c1 iter_args(%arg3 = %arg1)
- -> (tensor<2400000x9xf32>) {
- %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3)
- -> (tensor<2400000x9xf32>) {
- %3 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 575])
- [%i in [0, 73], %j in [0, 3]]
- %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 575])
- [%j in [0, 3]]
- %inserted = tensor.insert %c42_f32 into %arg5[%3, %4]
- : tensor<2400000x9xf32>
- scf.yield %inserted : tensor<2400000x9xf32>
- }
- scf.yield %2 : tensor<2400000x9xf32>
- }
- return %0 : tensor<2400000x9xf32>
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768)>
-// CHECK-LABEL: func.func @refine_constraints_for_symbol
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir
deleted file mode 100644
index 1141d15..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir
+++ /dev/null
@@ -1,405 +0,0 @@
-// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file -xla-gpu-vectorize-loads-stores -canonicalize -cse | FileCheck %s
-
-#map = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
-module {
- func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c64 = arith.constant 64 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
- %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK-LABEL: @simple_read
-// CHECK-SAME: (%[[ARG0:.*]]: tensor
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
-// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] =
-// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 63])
-// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]]
-// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]]
-// CHECK-NEXT: vector.extract %[[V]][%[[J]]]
-// CHECK-NEXT: addf
-
-// -----
-
-module {
- func.func @simple_read_2d(%arg0: tensor<64x2xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c64 = arith.constant 64 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK-LABEL: @simple_read_2d
-// CHECK-SAME: (%[[ARG0:.*]]: tensor
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: scf.for %[[I:.*]] = %[[C0]]
-// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[C0]]]
-// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]]
-// CHECK-NEXT: vector.extract %[[V]][%[[J]]]
-
-// -----
-
-#map = affine_map<(d0)[s0] -> (d0 * 2 + s0 + 1)>
-module {
- func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c63 = arith.constant 63 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
- %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK-LABEL: @misaligned_indexing_map
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-#map = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
-module {
- func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c63 = arith.constant 63 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
- %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK-LABEL: @misaligned_indexing_map_2
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-module {
- func.func @misaligned_shape(%arg0: tensor<64x3xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c64 = arith.constant 64 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %extracted = tensor.extract %arg0[%i, %j] : tensor<64x3xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK-LABEL: @misaligned_shape
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-#map = affine_map<(d0)[s0] -> (d0 + s0 * 2)>
-module {
- func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c63 = arith.constant 63 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
- %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK-LABEL: @wrong_stride
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-// We could vectorize this as a float vector load of double the size, but we
-// don't currently.
-module {
- func.func @simple_read_complex(%arg0: tensor<64x2xcomplex<f32>>, %i: index) -> (complex<f32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex<f32>
- %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex<f32> {
- %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xcomplex<f32>>
- %added = complex.add %iter, %extracted : complex<f32>
- scf.yield %added : complex<f32>
- }
- return %loop : complex<f32>
- }
-}
-
-// CHECK-LABEL: @simple_read_complex
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-// This is vectorizable, but not currently supported.
-module {
- func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c64 = arith.constant 64 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %extracted = tensor.extract %arg0[%j, %i]
- : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK-LABEL: @layout
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-module {
- func.func @simple_write(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 2 : index
- %cst = arith.constant 0.0 : f32
- %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
- %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
- scf.yield %inserted : tensor<16x4xf32>
- }
- return %loop : tensor<16x4xf32>
- }
-}
-
-// CHECK-LABEL: @simple_write
-// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[I:.*]]: index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[V:.*]] = scf.for
-// CHECK-NEXT: vector.insert
-// CHECK-NEXT: scf.yield
-// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[I]], %[[C0]]]
-// CHECK-NEXT: return %[[WRITTEN]]
-
-// -----
-
-module {
- func.func @write_with_use(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 2 : index
- %cst = arith.constant 0.0 : f32
- %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
- %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
- "dummy.op1"(%inserted) : (tensor<16x4xf32>) -> ()
- scf.yield %inserted : tensor<16x4xf32>
- }
- return %loop : tensor<16x4xf32>
- }
-}
-
-// CHECK-LABEL: @write_with_use
-// CHECK-NOT: transfer_write
-
-// -----
-
-module {
- func.func @write_not_to_iter_arg(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 2 : index
- %cst = arith.constant 0.0 : f32
- %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
- %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
- scf.yield %inserted : tensor<16x4xf32>
- }
- return %loop : tensor<16x4xf32>
- }
-}
-
-// CHECK-LABEL: @write_not_to_iter_arg
-// CHECK-NOT: transfer_write
-
-// -----
-
-module {
- func.func @write_not_yielded(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 2 : index
- %cst = arith.constant 0.0 : f32
- %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
- %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
- scf.yield %arg0 : tensor<16x4xf32>
- }
- return %loop : tensor<16x4xf32>
- }
-}
-
-// CHECK-LABEL: @write_not_yielded
-// CHECK-NOT: transfer_write
-
-// -----
-
-#map = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)>
-module {
- func.func @multiple(%arg0: tensor<32x4096xf32>, %arg1: tensor<4096xbf16>,
- %arg2: tensor<32xf32>, %arg3: tensor<32x4096xf32>,
- %arg4: index) -> (tensor<32x4096xf32>, f32) {
- %cst = arith.constant 1.000000e+00 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c8 = arith.constant 8 : index
- %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32>
- %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) {
- %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) {
- %2 = xla_gpu.apply_indexing #map(%j in [0, 1], %arg4 in [0, 255])[%i in [0, 7]]
- %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32>
- %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16>
- %3 = arith.extf %extracted3 : bf16 to f32
- %4 = arith.addf %extracted2, %3 : f32
- %5 = arith.addf %extracted1, %4 : f32
- %6 = arith.addf %iter3, %5 : f32
- %inserted = tensor.insert %5 into %iter2[%i, %2] : tensor<32x4096xf32>
- scf.yield %inserted, %6 : tensor<32x4096xf32>, f32
- }
- scf.yield %1#0, %1#1 : tensor<32x4096xf32>, f32
- }
- return %0#0, %0#1 : tensor<32x4096xf32>, f32
- }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 * 2 + s0 * 512)>
-// CHECK-LABEL: @multiple
-// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: scf.for %[[I:.*]] = %[[C0]]
-// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 255])[%[[I]] in [0, 7]]
-// CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]]
-// CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]]
-// CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>)
-// CHECK-DAG: vector.extract %[[READ1]][%[[J]]]
-// CHECK-DAG: vector.extract %[[READ2]][%[[J]]]
-// CHECK: extf
-// CHECK-NEXT: addf
-// CHECK-NEXT: %[[TO_INSERT:.*]] = arith.addf
-// CHECK-NEXT: %[[TO_YIELD:.*]] = arith.addf
-// CHECK-NEXT: %[[V_NEXT:.*]] = vector.insert %[[TO_INSERT]], %[[V]] [%[[J]]]
-// CHECK-NEXT: scf.yield %[[TO_YIELD]], %[[V_NEXT]]
-// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[I]], %[[BASE]]]
-// CHECK: scf.yield %[[WRITTEN]], %[[INNER]]#0
-
-// -----
-
-#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0)>
-module {
- func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c63 = arith.constant 63 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
- %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> ((d0 mod 16) * 4)>
-// CHECK-LABEL: @remainder_with_modulo
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: scf.for %[[I:.*]] = %[[C0]]
-// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
-// CHECK: vector.transfer_read {{.*}}[%[[BASE]]]
-
-// -----
-
-#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0)>
-module {
- func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c63 = arith.constant 63 : index
- %cst = arith.constant 0.0 : f32
- %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
- %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
- %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
- %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
- %added = arith.addf %iter1, %extracted : f32
- scf.yield %added : f32
- }
- scf.yield %inner : f32
- }
- return %outer : f32
- }
-}
-
-// CHECK-LABEL: @remainder_with_modulo_misaligned
-// CHECK-NOT: vector.transfer_read
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc
deleted file mode 100644
index 7d963f3..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <utility>
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_UNSWITCHLOOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class UnswitchLoopsPass
- : public impl::UnswitchLoopsPassBase<UnswitchLoopsPass> {
- public:
- void runOnOperation() override;
-};
-
-struct UnswitchLoop : mlir::OpRewritePattern<mlir::scf::ForOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
- if (op.getBody()->getOperations().size() != 2) {
- return rewriter.notifyMatchFailure(
- op, "loop body is not a single instruction");
- }
- auto if_op = mlir::dyn_cast<mlir::scf::IfOp>(op.getBody()->front());
- if (!if_op) {
- return rewriter.notifyMatchFailure(op, "no if found inside the loop");
- }
- if (mlir::matchPattern(if_op.getCondition(), mlir::m_Constant())) {
- return rewriter.notifyMatchFailure(op, "condition is a constant");
- }
-
- auto true_cst = rewriter.create<mlir::arith::ConstantOp>(
- op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
- auto false_cst = rewriter.create<mlir::arith::ConstantOp>(
- op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
- rewriter.setInsertionPoint(op);
- mlir::IRMapping mapping;
- mapping.map(if_op.getCondition(), false_cst);
- auto false_branch_loop = op->clone(mapping);
- auto new_if = rewriter.create<mlir::scf::IfOp>(
- op.getLoc(), op.getResultTypes(), if_op.getCondition(), true, true);
- rewriter.replaceAllUsesWith(op.getResults(), new_if.getResults());
-
- auto then_builder = new_if.getThenBodyBuilder(rewriter.getListener());
- auto then_yield =
- then_builder.create<mlir::scf::YieldOp>(op.getLoc(), op.getResults());
- rewriter.moveOpBefore(op, then_yield);
- rewriter.modifyOpInPlace(if_op, [&]() { if_op->setOperand(0, true_cst); });
-
- auto else_builder = new_if.getElseBodyBuilder(rewriter.getListener());
- else_builder.insert(false_branch_loop);
- else_builder.create<mlir::scf::YieldOp>(op.getLoc(),
- false_branch_loop->getResults());
-
- return mlir::success();
- }
-};
-
-void UnswitchLoopsPass::runOnOperation() {
- mlir::RewritePatternSet patterns(&getContext());
- patterns.add<UnswitchLoop>(&getContext());
- mlir::scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
- mlir::scf::IfOp::getCanonicalizationPatterns(patterns, &getContext());
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- signalPassFailure();
- }
-}
-
-} // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateUnswitchLoopsPass() {
- return std::make_unique<UnswitchLoopsPass>();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc b/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc
deleted file mode 100644
index 0007984..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc
+++ /dev/null
@@ -1,359 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstdint>
-#include <memory>
-#include <numeric>
-#include <optional>
-#include <utility>
-
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_VECTORIZELOADSANDSTORESPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-// Tries to find the stride of a symbol or dimension in an affine expression.
-// Returns std::nullopt if the stride could not be determined.
-//
-// Note: this function only attempts to handle the cases where the stride is
-// known to be 0 or 1.
-//
-// Example: the stride of `d0` in `(d0 + d1)` is 1.
-// Example: the stride of `d0` in `d0 * 2` is unknown (nullopt).
-std::optional<int> GetStride(mlir::AffineExpr expr,
- mlir::AffineExpr dim_or_sym) {
- if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
- auto lhs_stride = GetStride(binop.getLHS(), dim_or_sym);
- auto rhs_stride = GetStride(binop.getRHS(), dim_or_sym);
-
- if (binop.getKind() == mlir::AffineExprKind::Add) {
- if (lhs_stride && rhs_stride) {
- return *lhs_stride + *rhs_stride;
- }
- return std::nullopt;
- }
- // Just return 0 if the expression doesn't occur on either side.
- if (lhs_stride == 0 && rhs_stride == 0) {
- return 0;
- }
- // Otherwise, we don't know the stride.
- return std::nullopt;
- }
- return expr == dim_or_sym ? 1 : 0;
-}
-
-int64_t GetAlignmentOfRemainder(mlir::AffineExpr expr,
- mlir::AffineExpr dim_or_sym) {
- if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
- auto lhs_align = GetAlignmentOfRemainder(binop.getLHS(), dim_or_sym);
- auto rhs_align = GetAlignmentOfRemainder(binop.getRHS(), dim_or_sym);
-
- std::optional<int64_t> rhs_cst = std::nullopt;
- if (binop.getRHS().getKind() == mlir::AffineExprKind::Constant) {
- rhs_cst = binop.getRHS().cast<mlir::AffineConstantExpr>().getValue();
- }
-
- switch (binop.getKind()) {
- case mlir::AffineExprKind::Add:
- if (binop.getLHS() == dim_or_sym) return rhs_align;
- if (binop.getRHS() == dim_or_sym) return lhs_align;
- return std::gcd(lhs_align, rhs_align);
- case mlir::AffineExprKind::Mul:
- return lhs_align * rhs_align;
- case mlir::AffineExprKind::FloorDiv:
- case mlir::AffineExprKind::CeilDiv:
- return 1;
- case mlir::AffineExprKind::Mod:
- // (a * c) % (b * c) = (a % b) * c.
- return std::gcd(lhs_align, rhs_align);
- default:
- llvm_unreachable("expr is none of the binary expressions");
- }
- }
- if (auto cst = mlir::dyn_cast<mlir::AffineConstantExpr>(expr)) {
- return cst.getValue();
- }
- return 1;
-}
-
-// Attempts to extract the vector type for the given loop. This means:
-// - checks that the lower bound is 0
-// - checks that the step is 1
-// - checks that the upper bound is 2 or 4.
-// Returns a vector type with the given upper bound and the tensor's element
-// type.
-mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type,
- mlir::scf::ForOp loop) {
- // TODO(jreiffers): Support layouts.
- if (tensor_type.getEncoding()) {
- return nullptr;
- }
- if (!mlir::VectorType::isValidElementType(tensor_type.getElementType())) {
- return nullptr;
- }
- if (mlir::getConstantIntValue(loop.getStep()) != 1 ||
- mlir::getConstantIntValue(loop.getLowerBound()) != 0) {
- return nullptr;
- }
- std::optional<int> vector_size =
- mlir::getConstantIntValue(loop.getUpperBound());
- if (vector_size != 2 && vector_size != 4) {
- return nullptr; // Unsupported vector size.
- }
- if (tensor_type.getRank() > 1 &&
- tensor_type.getShape().back() % *vector_size) {
- return nullptr; // Misaligned start indices.
- }
- return mlir::VectorType::get({*vector_size}, tensor_type.getElementType());
-}
-
-std::optional<llvm::SmallVector<mlir::Value>> GetVectorBaseIndices(
- mlir::ValueRange indices, mlir::scf::ForOp loop,
- mlir::VectorType vector_type, mlir::ImplicitLocOpBuilder& b) {
- if (indices.empty()) {
- return std::nullopt;
- }
-
- // The major dimensions' indices must all be defined outside the loop.
- for (int i = 0; i < indices.size() - 1; ++i) {
- if (!indices[i].getParentRegion()->isProperAncestor(
- &loop.getBodyRegion())) {
- return std::nullopt;
- }
- }
-
- mlir::Value induction_var = loop.getInductionVar();
- if (indices.back() == induction_var) {
- llvm::SmallVector<mlir::Value> ret = indices;
- ret.back() = b.create<mlir::arith::ConstantIndexOp>(0);
- return ret;
- }
-
- auto apply_indexing =
- mlir::dyn_cast_or_null<ApplyIndexingOp>(indices.back().getDefiningOp());
- if (!apply_indexing) {
- return std::nullopt;
- }
-
- // We don't generate these, but they are allowed in theory.
- if (apply_indexing->getNumResults() != 1) {
- return std::nullopt;
- }
- mlir::AffineMap map = apply_indexing.getAffineMap();
-
- int induction_var_operand_index;
- mlir::AffineExpr induction_var_expr = nullptr;
- for (auto [index, operand] : llvm::enumerate(apply_indexing.getOperands())) {
- if (operand == induction_var) {
- if (induction_var_expr) {
- // The induction variable should be used only once.
- return std::nullopt;
- }
- induction_var_operand_index = index;
- induction_var_expr = index < map.getNumDims()
- ? mlir::getAffineDimExpr(index, b.getContext())
- : mlir::getAffineSymbolExpr(
- index - map.getNumDims(), b.getContext());
- }
- }
- if (!induction_var_expr) {
- return std::nullopt;
- }
-
- if (GetStride(map.getResult(0), induction_var_expr) != 1) {
- // The indexing map is not contiguous in the vectorized dimension.
- return std::nullopt;
- }
-
- if (GetAlignmentOfRemainder(map.getResult(0), induction_var_expr) %
- vector_type.getNumElements()) {
- return std::nullopt;
- }
-
- auto operands = llvm::to_vector(apply_indexing.getOperands());
- operands[induction_var_operand_index] =
- b.create<mlir::arith::ConstantIndexOp>(0);
-
- llvm::SmallVector<mlir::Value> ret = indices;
- ret.back() =
- b.create<ApplyIndexingOp>(operands, map, apply_indexing.getLowerBounds(),
- apply_indexing.getUpperBounds())
- ->getResult(0);
- return ret;
-}
-
-bool IsConflictFree(mlir::tensor::ExtractOp op) {
- return op.getTensor().getParentRegion()->isProperAncestor(
- op->getParentRegion());
-}
-
-struct VectorizeLoad : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- mlir::tensor::ExtractOp op,
- mlir::PatternRewriter& rewriter) const override {
- auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
- if (!loop) {
- return rewriter.notifyMatchFailure(op, "no loop found");
- }
- if (!IsConflictFree(op)) {
- return rewriter.notifyMatchFailure(op,
- "source may be written in the loop");
- }
-
- auto vector_type = GetVectorType(op.getTensor().getType(), loop);
- if (!vector_type) {
- return rewriter.notifyMatchFailure(op, "not a vectorizable loop");
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- b.setInsertionPoint(loop);
- auto vector_indices =
- GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
- if (!vector_indices) {
- return rewriter.notifyMatchFailure(
- op, "the instruction does not access contiguous elements");
- }
-
- auto loaded_vector = b.create<mlir::vector::TransferReadOp>(
- vector_type, op.getTensor(), *vector_indices,
- llvm::ArrayRef<bool>{true});
- rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(
- op, loaded_vector, loop.getInductionVar());
- return mlir::success();
- }
-};
-
-// Verifies that the insertions happening in the loop can all safely be batched
-// in the end.
-bool IsConflictFree(mlir::tensor::InsertOp op) {
- // The insertion's only use must be the yield.
- if (!op->hasOneUse() || !mlir::isa<mlir::scf::YieldOp>(*op->user_begin())) {
- return false;
- }
- // The destination must be one of the loop's block arguments, and the
- // destination must be the argument's only use.
- auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(op.getDest());
- return bbarg && bbarg.hasOneUse() &&
- bbarg.getOwner()->getParentOp() == op->getParentOp();
-}
-
-struct VectorizeStore : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult matchAndRewrite(
- mlir::tensor::InsertOp op,
- mlir::PatternRewriter& rewriter) const override {
- auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
- if (!loop) {
- return rewriter.notifyMatchFailure(op, "no loop found");
- }
- if (!IsConflictFree(op)) {
- return rewriter.notifyMatchFailure(op, "write may be read back by loop");
- }
- auto vector_type = GetVectorType(op.getDest().getType(), loop);
- if (!vector_type) {
- return rewriter.notifyMatchFailure(op, "loop is not vectorizable");
- }
-
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- b.setInsertionPoint(loop);
- auto vector_indices =
- GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
- if (!vector_indices) {
- return rewriter.notifyMatchFailure(
- op, "the instruction does not access contiguous elements");
- }
-
- auto init = b.create<mlir::arith::ConstantOp>(b.getZeroAttr(vector_type))
- .getResult();
-
- auto yield_fn = [&](mlir::OpBuilder& yield_b, mlir::Location yield_loc,
- llvm::ArrayRef<mlir::BlockArgument> bbarg) {
- auto induction_var =
- mlir::cast<mlir::scf::ForOp>(bbarg.front().getOwner()->getParentOp())
- .getInductionVar();
- auto insert_op = yield_b.create<mlir::vector::InsertOp>(
- yield_loc, op.getScalar(), bbarg.front(), induction_var);
- return llvm::SmallVector<mlir::Value>{insert_op.getResult()};
- };
- int result_index = op->use_begin()->getOperandNumber();
- auto new_for = *loop.replaceWithAdditionalYields(
- rewriter, init,
- /*replaceInitOperandUsesInLoop=*/false, yield_fn);
-
- b.setInsertionPointAfter(new_for);
- rewriter.replaceOp(op, op.getDest());
-
- auto filled_vector = new_for->getResults().back();
- auto written = b.create<mlir::vector::TransferWriteOp>(
- filled_vector, new_for.getInits()[result_index], *vector_indices,
- llvm::ArrayRef<bool>{true});
- new_for->getResult(result_index).replaceAllUsesWith(written.getResult());
-
- return mlir::success();
- }
-};
-
-class VectorizeLoadsAndStoresPass
- : public impl::VectorizeLoadsAndStoresPassBase<
- VectorizeLoadsAndStoresPass> {
- public:
- void runOnOperation() override {
- mlir::RewritePatternSet patterns(&getContext());
- patterns.add<VectorizeLoad, VectorizeStore>(&getContext());
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateVectorizeLoadsAndStoresPass() {
- return std::make_unique<VectorizeLoadsAndStoresPass>();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc
index 774fa8e..db4a93a 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc
@@ -38,7 +38,7 @@
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/model/affine_map_printer.h"
#include "xla/tests/filecheck.h"
@@ -82,7 +82,7 @@
TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string));
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
auto fusion_emitter = GetEmitter(analysis);
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h
index 0006dc5..3b0c78c 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h
+++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h
@@ -72,7 +72,7 @@
auto& module =
modules_.emplace_back(ParseAndReturnVerifiedModule(hlo_string).value());
auto* root = module->entry_computation()->root_instruction();
- analyses_.push_back(AnalyzeFusion(*root, device_info_));
+ analyses_.push_back(HloFusionAnalysis::Create(*root, device_info_));
return GetEmitter(analyses_.back());
}
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc
deleted file mode 100644
index 77c1e91..0000000
--- a/third_party/xla/xla/service/gpu/fusions/reduction.cc
+++ /dev/null
@@ -1,1330 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/reduction.h"
-
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/container/node_hash_map.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/GlobalVariable.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Instructions.h"
-#include "llvm/IR/Type.h"
-#include "llvm/IR/Value.h"
-#include "llvm/Support/AtomicOrdering.h"
-#include "llvm/Support/Casting.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/reduction_base.h"
-#include "xla/service/gpu/fusions/thunk_util.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/ir_emitter_nested.h"
-#include "xla/service/gpu/kernel_arguments.h"
-#include "xla/service/gpu/kernel_reuse_cache.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/gpu/runtime/kernel_thunk.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/service/gpu/target_util.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/kernel_support_library.h"
-#include "xla/service/llvm_ir/llvm_loop.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/service/llvm_ir/loop_emitter.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using TypedPointer = std::pair<llvm::Value* const, llvm::Type* const>;
-
-// Fusion root -> array of indexes, one per reduction output.
-using ReductionOutputMap =
- ConstHloInstructionMap<absl::Span<llvm_ir::IrArray const>>;
-
-using ExtraOutputGensMap = ConstHloInstructionMap<llvm_ir::ElementGenerator>;
-
-int GetNumOutputs(const Shape& shape) {
- if (shape.IsTuple()) {
- return shape.tuple_shapes_size();
- }
- return 1;
-}
-
-const Shape& OutputShape(const Shape& output_shape, int output_index) {
- CHECK(output_index == 0 || output_shape.IsTuple());
- return output_shape.IsTuple() ? output_shape.tuple_shapes(output_index)
- : output_shape;
-}
-
-llvm::Type* GetIndexType(const HloFusionInstruction& fusion,
- const Tiling& tiling, llvm::IRBuilder<>* builder) {
- return GetIndexTypeForKernel(
- &fusion, tiling.GetNumThreadsPerBlock() * tiling.GetNumBlocks(), builder);
-}
-
-llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input,
- llvm::Type* element_type, llvm::Twine name) {
- return builder->CreateAddrSpaceCast(
- input,
- llvm::PointerType::get(element_type,
- /*AddressSpace=*/0),
- name);
-}
-
-class ReductionEmitter {
- public:
- ReductionEmitter(const HloFusionAnalysis& analysis,
- const ReductionInfo& reduction_codegen_info,
- IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- llvm::IRBuilder<>* builder)
- : builder_(builder),
- elemental_emitter_(ir_emitter_context, builder_),
- analysis_(analysis),
- reduction_codegen_info_(reduction_codegen_info),
- ir_emitter_context_(ir_emitter_context),
- fusion_(fusion),
- index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(),
- elemental_emitter_.builder())) {
- for (auto hero : analysis.fusion_heroes()) {
- if (hero.opcode() == HloOpcode::kReduce) {
- for (int i = 0; i < hero.instruction().operand_count() / 2; ++i) {
- CHECK(LayoutUtil::IsMonotonicWithDim0Major(
- hero.instruction().operand(i)->shape().layout()))
- << "reduction-layout-normalizer must run before code generation";
- }
- }
- }
- }
-
- absl::StatusOr<FusionEmissionResult> EmitInitializers();
- absl::Status EmitKernel(const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs);
-
- private:
- friend class ReductionGroupEmitter;
-
- absl::StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkForFusion(
- const LaunchDimensions& launch_dimensions,
- absl::string_view discriminator,
- std::function<absl::Status(std::vector<llvm_ir::IrArray>,
- std::vector<llvm_ir::IrArray>)>
- kernel_builder_fn);
-
- absl::StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
- const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice,
- int output_index);
-
- absl::Status EmitIRForReduction(
- absl::Span<const HloInstruction* const> instr_index_group,
- FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
- const Shape& input_shape);
-
- void MaybeEmitFenceForAMDGPU();
- void EmitSyncThreads();
-
- int ReducedDimensionSize() const {
- return reduction_codegen_info_.GetTiling().GetShape()[2];
- }
-
- llvm::IRBuilder<>* builder_;
- GpuElementalIrEmitter elemental_emitter_;
- const HloFusionAnalysis& analysis_;
- const ReductionInfo& reduction_codegen_info_;
- IrEmitterContext& ir_emitter_context_;
- const HloFusionInstruction& fusion_;
- llvm::Type* index_ty_;
-};
-
-class ReductionEmitter;
-
-class ReductionGroupEmitter {
- public:
- struct ReductionCalculationState {
- std::optional<llvm_ir::SharedMemoryTile> shared_cache;
- llvm::Value* initial_value;
- llvm::AllocaInst* partial_result_address;
- llvm::AllocaInst* input_address;
- llvm_ir::ElementGenerator input_gen;
- };
-
- ReductionGroupEmitter(
- ReductionEmitter& reduction_emitter,
- absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
- const ReductionOutputMap& result_ir_arrays,
- FusedIrEmitter& fused_emitter);
-
- const ReductionCalculationState& GetCalculationStateFor(
- const HloInstruction* instruction, int operand_idx) const {
- const ReductionOpState& op_state = state_.at(instruction);
- CHECK_LT(operand_idx, op_state.size());
- return op_state[operand_idx];
- }
-
- void SetCalculationStateFor(
- const ReductionCalculationState& calculation_state,
- const HloInstruction* instruction, int operand_idx) {
- ReductionOpState& op_state = state_[instruction];
- CHECK_EQ(operand_idx, op_state.size());
- op_state.push_back(calculation_state);
- }
-
- void EmitReductionOutputForRowReduction(
- const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction,
- const std::vector<const HloInstruction*>& roots) const;
-
- void EmitReductionOutputForColumnReduction(
- const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction,
- const std::vector<const HloInstruction*>& roots) const;
-
- void EmitFullWarpShuffleDownLoopForReduce(
- const HloComputation* reducer,
- absl::Span<TypedPointer const> partial_result_addresses,
- int threads_per_block, int num_results_per_warp) const;
-
- void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction,
- const std::vector<const HloInstruction*>& roots,
- absl::Span<TypedPointer const> values) const;
-
- llvm_ir::IrArray::Index GetOutputIndexForReduction(
- const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction, const HloInstruction* root,
- int output_idx) const;
-
- void GenerateElementForReducer(const HloReduceInstruction* reduction,
- const llvm_ir::IrArray::Index& index) const;
-
- absl::Status EmitExtraOutputsForReduce(
- const Shape& reduction_operand_shape,
- const llvm_ir::IrArray::Index& index,
- const ExtraOutputGensMap& extra_output_gens);
-
- private:
- ReductionEmitter& reduction_emitter_;
- const ReductionOutputMap& result_ir_arrays_;
-
- // One state per reduction operand.
- using ReductionOpState = absl::InlinedVector<ReductionCalculationState, 2>;
-
- // HloInstruction -> operand_idx -> cache
- absl::flat_hash_map<const HloInstruction*, ReductionOpState> state_;
-};
-
-// Creates accumulator allocas, populates them with initial values, generates
-// __shared__ caches, and returns the populated object.
-ReductionGroupEmitter::ReductionGroupEmitter(
- ReductionEmitter& reduction_emitter,
- absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
- const ReductionOutputMap& result_ir_arrays, FusedIrEmitter& fused_emitter)
- : reduction_emitter_(reduction_emitter),
- result_ir_arrays_(result_ir_arrays) {
- const ReductionInfo& reduction_info =
- reduction_emitter_.reduction_codegen_info_;
- VLOG(10) << "Emit prologue for reduction: "
- << reduction_emitter_.fusion_.ToString();
-
- auto* builder = reduction_emitter_.builder_;
- for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) {
- for (int op_result_idx = 0;
- op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) {
- Shape result_shape = OutputShape(reduce_hlo->shape(), op_result_idx);
-
- llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
- result_shape.element_type(), builder->GetInsertBlock()->getModule());
- llvm::AllocaInst* reduction_input_address =
- llvm_ir::EmitAllocaAtFunctionEntry(
- element_type, "reduction_input_address", builder);
-
- llvm::AllocaInst* result_address = llvm_ir::EmitAllocaAtFunctionEntry(
- element_type, "partial_reduction_result", builder);
-
- const HloInstruction* init_value =
- reduce_hlo->init_values()[op_result_idx];
-
- // Initialize the partial result with the initial value of the reduction.
- llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(
- *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty()))
- .value();
-
- builder->CreateStore(init_ir_value, result_address);
- const Tiling& tiling = reduction_info.GetTiling();
- auto shared_cache = [&]() -> std::optional<llvm_ir::SharedMemoryTile> {
- auto* module = reduction_emitter.ir_emitter_context_.llvm_module();
- if (reduction_info.IsRowReduction()) {
- // Multi-row reductions do not use shared memory.
- if (RowReductionGetRowsPerWarp(
- reduction_emitter_.ReducedDimensionSize()) > 1) {
- return std::nullopt;
- }
- // Allocate one shared memory element per warp.
- auto block_size = tiling.GetThreadsPerBlock();
- CHECK_EQ(block_size[ReductionDimensions::kRowMinorReducedDimension] %
- WarpSize(),
- 0);
- return llvm_ir::AllocateSharedMemoryTile(
- module, element_type,
- {block_size[ReductionDimensions::kRowKeptDimension],
- block_size[ReductionDimensions::kRowMinorReducedDimension] /
- WarpSize()},
- "shared_cache");
- }
- const auto& num_threads = tiling.GetThreadsPerBlock();
- int n = num_threads[ReductionDimensions::kColReducedDimension];
- CHECK_EQ(n, num_threads[ReductionDimensions::kColMinorKeptDimension]);
- // The "+1" is used to avoid bank conflicts.
- return llvm_ir::AllocateSharedMemoryTile(module, element_type,
- {n, n + 1}, "shared_cache");
- }();
-
- llvm_ir::ElementGenerator input_gen =
- *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]);
- SetCalculationStateFor({shared_cache, init_ir_value, result_address,
- reduction_input_address, input_gen},
- reduce_hlo, op_result_idx);
- }
- }
-}
-
-void ReductionEmitter::MaybeEmitFenceForAMDGPU() {
- auto* module = builder_->GetInsertBlock()->getModule();
- if (IsAMDGPU(module) &&
- ir_emitter_context_.rocm_compute_capability().fence_before_barrier()) {
- builder_->CreateFence(
- llvm::AtomicOrdering::SequentiallyConsistent,
- builder_->getContext().getOrInsertSyncScopeID("workgroup"));
- }
-}
-
-void ReductionEmitter::EmitSyncThreads() {
- MaybeEmitFenceForAMDGPU();
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder_);
-}
-
-// Builds a thunk that calls a new or reused kernel for a fusion operation.
-//
-// The caller must specify the same launch dimensions for fusions which have
-// the same computation.
-//
-// If a given fusion is implemented using multiple kernels, then for each
-// kernel we should provide a discriminator, such as "init" and "impl".
-//
-// The builder_fn is only invoked if the kernel couldn't be reused.
-//
-// This is the typical usage pattern of this method:
-//
-// ```
-// auto builder_fn = [](std::vector<llvm_ir::IrArray> inputs,
-// std::vector<llvm_ir::IrArray> outputs) { ... };
-// TF_ASSIGN_OR_RETURN(
-// auto thunk,
-// BuildKernelThunkForFusion(..., launch_dimensions, builder_fn));
-// AddThunkToThunkSequence(std::move(thunk))
-// ```
-absl::StatusOr<std::unique_ptr<Thunk>>
-ReductionEmitter::BuildKernelThunkForFusion(
- const LaunchDimensions& launch_dimensions, absl::string_view discriminator,
- std::function<absl::Status(std::vector<llvm_ir::IrArray>,
- std::vector<llvm_ir::IrArray>)>
- kernel_builder_fn) {
- const HloComputation* fused_computation =
- fusion_.fused_instructions_computation();
- std::string suggested_kernel_name = std::string(fusion_.name());
-
- TF_ASSIGN_OR_RETURN(auto kernel_arguments,
- KernelArguments::Create(
- ir_emitter_context_.buffer_assignment(), &fusion_));
-
- auto [status_or_entry, cached] =
- ir_emitter_context_.kernel_cache().GetWithStatus(
- fused_computation, kernel_arguments.args(), discriminator,
- [&]() -> absl::StatusOr<KernelReuseCache::Entry> {
- llvm::Function* kernel;
- std::vector<llvm_ir::IrArray> input_arrays;
- std::vector<llvm_ir::IrArray> output_arrays;
- TF_ASSIGN_OR_RETURN(
- std::tie(kernel, input_arrays, output_arrays),
- BuildKernelPrototype(ir_emitter_context_, suggested_kernel_name,
- kernel_arguments.args(),
- fusion_.operand_count(), launch_dimensions,
- builder_));
- TF_RETURN_IF_ERROR(kernel_builder_fn(input_arrays, output_arrays));
- // Shared memory is allocated statically.
- return {{kernel->getName().str(), launch_dimensions,
- /*cluster_dim=*/std::nullopt,
- /*shmem_bytes=*/0}};
- });
- TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);
- if (cached) {
- VLOG(3) << "Reuse: " << suggested_kernel_name << " -> "
- << entry->kernel_name;
- }
-
- return std::make_unique<KernelThunk>(
- &fusion_, entry->kernel_name, kernel_arguments.args(), launch_dimensions,
- entry->cluster_dim, entry->shmem_bytes);
-}
-
-absl::Status ReductionGroupEmitter::EmitExtraOutputsForReduce(
- const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index,
- const ExtraOutputGensMap& extra_output_gens) {
- if (extra_output_gens.empty()) {
- return absl::OkStatus();
- }
-
- auto* builder = reduction_emitter_.builder_;
- // Compute all extra output values before writing them. This avoids
- // overwriting aliased input/output buffers before all reads have occurred.
- std::vector<std::pair<const HloInstruction*, llvm::Value*>>
- extra_output_ir_values;
- extra_output_ir_values.reserve(extra_output_gens.size());
-
- auto get_index = [&](const HloInstruction* instr) {
- const Shape& s = instr->shape();
- return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s)
- ? index
- : index.SourceIndexOfBitcast(reduction_operand_shape, s,
- builder);
- };
-
- for (const auto& [instr, generator] : extra_output_gens) {
- TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
- generator(get_index(instr)));
- extra_output_ir_values.emplace_back(instr, extra_output_ir_value);
- }
-
- for (const auto& [instr, generator] : extra_output_ir_values) {
- absl::Span<llvm_ir::IrArray const> result_ir = result_ir_arrays_.at(instr);
- CHECK_EQ(result_ir.size(), 1);
- result_ir[0].EmitWriteArrayElement(get_index(instr), generator, builder);
- }
- return absl::OkStatus();
-}
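The two-phase structure of `EmitExtraOutputsForReduce` above (generate every extra output value first, then emit all writes) is what protects aliased input/output buffers. A minimal host-side sketch of the same pattern, with purely illustrative names:

```
#include <cstddef>
#include <functional>
#include <vector>

// Phase 1 computes every value while the buffers are still intact; phase 2
// performs all writes. Interleaving the two could let a write clobber data an
// aliasing generator still needs to read. Assumes `generators` and `outputs`
// are parallel arrays of the same size.
void EmitExtraOutputs(const std::vector<std::function<float(int)>>& generators,
                      const std::vector<float*>& outputs, int index) {
  std::vector<float> computed;
  computed.reserve(generators.size());
  for (const auto& gen : generators) {
    computed.push_back(gen(index));  // phase 1: reads only
  }
  for (std::size_t i = 0; i < computed.size(); ++i) {
    outputs[i][index] = computed[i];  // phase 2: writes only
  }
}
```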
-
-absl::StatusOr<std::unique_ptr<Thunk>>
-ReductionEmitter::BuildFusedInitializerThunk(const HloInstruction* fusion_root,
- BufferAllocation::Slice dest_slice,
- int output_index) {
- const HloReduceInstruction* reduce =
- DynCast<HloReduceInstruction>(fusion_root);
- TF_RET_CHECK(reduce);
-
- const HloInstruction* init_value = reduce->init_values()[0];
- TF_ASSIGN_OR_RETURN(
- std::optional<std::unique_ptr<Thunk>> constant_init_thunk,
- BuildConstantInitializerThunk(ir_emitter_context_, fusion_root,
- init_value, dest_slice));
- if (constant_init_thunk) {
- return *std::move(constant_init_thunk);
- }
-
- const Shape& dest_shape = fusion_root->shape();
-
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- dest_shape, ir_emitter_context_.gpu_device_info());
- const HloComputation* fused_computation =
- fusion_.fused_instructions_computation();
-
- auto builder_fn = [&](std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs) -> absl::Status {
- FusedIrEmitter fused_emitter(elemental_emitter_);
- for (int i = 0; i < fused_computation->num_parameters(); i++) {
- fused_emitter.BindGenerator(
- *fused_computation->parameter_instruction(i),
- [builder = builder_,
- input = inputs[i]](llvm_ir::IrArray::Index index) {
- return input.EmitReadArrayElement(index, builder);
- });
- }
- HloInstruction* instr = fused_computation->root_instruction();
- if (instr->opcode() == HloOpcode::kTuple) {
- instr = instr->mutable_operand(output_index);
- } else {
- CHECK_EQ(0, output_index);
- }
- TF_RET_CHECK(instr->shape().IsArray());
- TF_ASSIGN_OR_RETURN(auto generator,
- fused_emitter.GetGenerator(*instr->operand(1)));
- TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]},
- launch_dimensions, builder_)
- .EmitLoop(fusion_.name()));
- return absl::OkStatus();
- };
-
- return BuildKernelThunkForFusion(launch_dimensions,
- /*discriminator=*/
- absl::StrCat("init_", output_index),
- builder_fn);
-}
-
-// Emits a shuffle-down reduction for `partial_result_address` using the
-// reduction computation `reducer`, and writes the output back into
-// `partial_result_address`.
-//
-// Multiple partial_result_address inputs happen when doing variadic
-// reduction: each one should get the output value.
-void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce(
- const HloComputation* reducer,
- absl::Span<TypedPointer const> partial_result_addresses,
- int threads_per_block, int num_results_per_warp) const {
- // This only works when the block size is a multiple of 32 threads.
- // We check this here as a mistake in the number of threads per
- // block is very hard to detect.
- CHECK_EQ(threads_per_block % 32, 0);
- CHECK_EQ(WarpSize() % num_results_per_warp, 0);
-
- auto* builder = reduction_emitter_.builder_;
- for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) {
- absl::InlinedVector<llvm::Value*, 2> reduction_params;
-
- for (auto acc : partial_result_addresses) {
- reduction_params.push_back(acc.first);
- }
-
- for (auto [partial_result_address, element_type] :
- partial_result_addresses) {
- int bit_width = llvm_ir::GetSizeInBits(element_type);
- llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
- element_type, "result_from_other_lane", builder);
-
- reduction_params.push_back(result_from_other_lane);
-
- // Bitcast cannot be applied to aggregate types (even packed ones), so
- // we bitcast addresses of load/store to intN* of the same bit-width.
- llvm::Type* shuffled_value_type = element_type->isStructTy()
- ? builder->getIntNTy(bit_width)
- : element_type;
-
- llvm::Value* partial_result =
- builder->CreateLoad(shuffled_value_type, partial_result_address,
- "partial_reduction_result");
- builder->CreateStore(
- EmitFullWarpShuffleDown(
- partial_result, builder->getInt32(distance), builder,
- reduction_emitter_.ir_emitter_context_.gpu_device_info()),
- result_from_other_lane);
- }
-
- absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
- CallNestedComputationWithScalarAddrs(
- builder, reduction_emitter_.ir_emitter_context_, *reducer,
- reduction_params);
- TF_CHECK_OK(returned_scalars.status());
-
- for (int i = 0; i < returned_scalars->size(); i++) {
- builder->CreateStore(/*Val=*/returned_scalars->at(i),
- /*Ptr=*/partial_result_addresses[i].first);
- }
- }
-}
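The shuffle-down loop above implements the classic butterfly reduction: for a full warp the distances are 16, 8, 4, 2, 1, and after the last step lane 0 holds the combined value. A small host-side model of that data flow, assuming a warp size of 32 and addition as the reducer (this is not GPU code):

```
#include <array>
#include <iostream>

constexpr int kWarpSize = 32;

// Host-side model of a full-warp shuffle-down reduction: at each step every
// lane adds the value held by the lane `distance` above it, halving the
// distance until lane 0 holds the sum of all 32 lanes.
float WarpReduceSum(std::array<float, kWarpSize> lanes) {
  for (int distance = kWarpSize / 2; distance >= 1; distance /= 2) {
    std::array<float, kWarpSize> shifted = lanes;
    for (int lane = 0; lane < kWarpSize; ++lane) {
      int src = lane + distance;
      shifted[lane] = lanes[lane] + (src < kWarpSize ? lanes[src] : 0.0f);
    }
    lanes = shifted;
  }
  return lanes[0];  // lane 0 now holds the reduction of the whole warp
}

int main() {
  std::array<float, kWarpSize> lanes;
  for (int i = 0; i < kWarpSize; ++i) lanes[i] = 1.0f;
  std::cout << WarpReduceSum(lanes) << "\n";  // prints 32
  return 0;
}
```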
-
-llvm_ir::IrArray::Index ReductionGroupEmitter::GetOutputIndexForReduction(
- const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction, const HloInstruction* root,
- int output_idx) const {
- auto* builder = reduction_emitter_.builder_;
- auto* index_ty = reduction_emitter_.index_ty_;
-
- // 1d or 2d output index (for row/column reduction).
- auto projected_index = [&]() -> llvm_ir::IrArray::Index {
- const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
- const auto& offset = tiling_kernel_info.tile_origin;
- const auto& shape = reduction_info.GetTiling().GetXlaShape();
- const auto& thread_ids = tiling_kernel_info.thread_id_info.thread_ids;
- if (reduction_info.IsRowReduction()) {
- constexpr int kDim = ReductionDimensions::kRowKeptDimension;
- return {{builder->CreateAdd(offset[kDim], thread_ids[kDim])},
- {shape.dimensions(kDim)},
- index_ty};
- }
- auto* major_idx = offset[ReductionDimensions::kColMajorKeptDimension];
- auto* minor_idx = builder->CreateAdd(
- offset[ReductionDimensions::kColMinorKeptDimension],
- thread_ids[ReductionDimensions::kColReducedDimension]);
- return {{major_idx, minor_idx},
- ShapeUtil::DeleteDimension(
- ReductionDimensions::kColReducedDimension, shape),
- index_ty};
- }();
-
- auto physical_shape = ShapeUtil::DeleteDimensions(
- reduction->dimensions(), reduction->operand(output_idx)->shape());
- auto physical_index =
- projected_index.SourceIndexOfBitcast(physical_shape, builder);
- return llvm_ir::IrArray::Index(physical_index.multidim(),
- OutputShape(reduction->shape(), output_idx),
- index_ty)
- .SourceIndexOfBitcast(OutputShape(root->shape(), output_idx), builder);
-}
-
-void ReductionGroupEmitter::WriteReductionOutput(
- const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction,
- const std::vector<const HloInstruction*>& roots,
- const absl::Span<TypedPointer const> values) const {
- auto* builder = reduction_emitter_.builder_;
- const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
- const HloComputation* reducer = reduction->to_apply();
- for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) {
- auto [output_ptr, type] = typed_ptr;
- for (auto root : roots) {
- llvm_ir::IrArray::Index output_index =
- GetOutputIndexForReduction(tiling_kernel_info, reduction, root, oidx);
-
- llvm::Value* output_address =
- result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress(
- output_index, builder, "output_element_address");
- if (reduction_info.IsRaceFree()) {
- FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_);
- llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output");
- fused_emitter.BindGenerator(
- *reduction,
- [&](const llvm_ir::IrArray::Index& index) { return loaded; });
- llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root);
- llvm::Value* generated = *gen(output_index);
- builder->CreateStore(generated, output_address);
- } else {
- CHECK_EQ(values.size(), 1);
- CHECK_EQ(roots.size(), 1);
- CHECK_EQ(reduction, root)
- << "output fusion is not allowed for racing reductions";
- TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
- builder, reduction_emitter_.ir_emitter_context_, *reducer,
- output_address, output_ptr, type));
- }
- }
- }
-}
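`WriteReductionOutput` above takes one of two paths: a race-free reduction stores its final value directly (after running the fusion epilogue), while a racing reduction combines its partial result into the destination with an emitted atomic. A loose host-side analogue, using `std::atomic` as a stand-in for the device atomic (sketch only):

```
#include <atomic>

// Race-free path: a plain store of the final value. Racing path: combine the
// partial result into the destination atomically via a CAS loop (atomic<float>
// only gained fetch_add in C++20, so the CAS loop is the portable form).
void WriteOutput(bool race_free, std::atomic<float>& out, float partial) {
  if (race_free) {
    out.store(partial, std::memory_order_relaxed);
    return;
  }
  float expected = out.load(std::memory_order_relaxed);
  while (!out.compare_exchange_weak(expected, expected + partial,
                                    std::memory_order_relaxed)) {
    // `expected` is refreshed by compare_exchange_weak on failure; retry.
  }
}
```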
-
-void ReductionGroupEmitter::EmitReductionOutputForRowReduction(
- const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction,
- const std::vector<const HloInstruction*>& roots) const {
- const HloComputation* reducer = reduction->to_apply();
- const auto& thread_id_info = tiling_kernel_info.thread_id_info;
- const auto& thread_ids = thread_id_info.thread_ids;
- auto* thread_id_x =
- thread_ids[ReductionDimensions::kRowMinorReducedDimension];
- auto constant = [&](uint64_t c) -> llvm::Constant* {
- return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
- };
-
- auto* builder = reduction_emitter_.builder_;
- auto is_zero = [&](llvm::Value* value) {
- return builder->CreateICmpEQ(value, constant(0));
- };
-
- int num_outputs = reducer->num_parameters() / 2;
- absl::InlinedVector<TypedPointer, 2> current_outputs;
- for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
- const auto& state = GetCalculationStateFor(reduction, output_idx);
- current_outputs.push_back(
- {state.partial_result_address,
- state.partial_result_address->getAllocatedType()});
- }
-
- const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
- const Tiling& tiling = reduction_info.GetTiling();
- int num_rows_per_warp =
- RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize());
- EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs),
- tiling.GetNumThreadsPerBlock(),
- num_rows_per_warp);
-
- KernelSupportLibrary ksl(builder);
- llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize()));
-
- auto emit_write_output = [&](llvm::Value* write_condition,
- const absl::Span<TypedPointer const> values) {
- ksl.If("reduction_write_output", write_condition, [&] {
- WriteReductionOutput(tiling_kernel_info, reduction, roots, values);
- });
- };
-
- // The major kept dimension and vector dimension are not tiled, so they're
- // always in bounds.
- llvm::Value* is_in_bounds_y = builder->CreateICmpULT(
- thread_ids[ReductionDimensions::kRowKeptDimension],
- tiling_kernel_info
- .output_tile_bounds[ReductionDimensions::kRowKeptDimension]);
-
- ksl.If("thread_in_bounds", is_in_bounds_y, [&] {
- if (num_rows_per_warp > 1) {
- llvm::Value* is_writing_thread = is_zero(builder->CreateAnd(
- thread_id_x,
- constant(reduction_emitter_.ReducedDimensionSize() - 1)));
- emit_write_output(is_writing_thread, current_outputs);
- return;
- }
-
- ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
- for (int oidx = 0; oidx < num_outputs; oidx++) {
- auto& state = GetCalculationStateFor(reduction, oidx);
- state.shared_cache->Store(
- builder->CreateLoad(current_outputs[oidx].second,
- current_outputs[oidx].first),
- {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
- warp_id},
- builder);
- }
- });
-
- // TODO(cheshire): Don't we want to sync it once for everything in the
- // output? Not once per each?
- reduction_emitter_.EmitSyncThreads();
- ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
- absl::InlinedVector<TypedPointer, 2> selected_values;
- for (int oidx = 0; oidx < num_outputs; oidx++) {
- auto& state = GetCalculationStateFor(reduction, oidx);
- llvm::Value* block_accum_addr = state.shared_cache->Address(
- {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
- thread_id_info.lane_id},
- builder);
-
- llvm::Type* element_type =
- state.partial_result_address->getAllocatedType();
-
- // Ensure the initial value address is in the generic address space, not scratch.
- llvm::Value* initial_value_addr =
- CastSharedToGlobal(builder,
- llvm_ir::EmitAllocaAtFunctionEntry(
- element_type, "initial_value_addr", builder),
- element_type, /*name=*/"");
- builder->CreateStore(state.initial_value, initial_value_addr);
-
- llvm::Value* warp_exists = builder->CreateICmpULT(
- thread_id_x,
- constant(tiling.GetThreadsPerBlock()
- [ReductionDimensions::kRowMinorReducedDimension] /
- WarpSize()));
-
- llvm::Value* selected_value = builder->CreateSelect(
- warp_exists, block_accum_addr, initial_value_addr);
-
- selected_values.push_back({selected_value, element_type});
- }
-
- // If only one warp produces the output element, we don't need to emit
- // an inter warp reduce. In our tiling, DimX is the minor reduced
- // dimension. The major reduced dimension is always emitted as a loop.
- // TODO(b/241414088) If only one warp is present, then inter-warp
- // communication using shared memory and synchronization using a barrier
- // are also unnecessary and should be removed.
- if (tiling.GetThreadsPerBlock()
- [ReductionDimensions::kRowMinorReducedDimension] > WarpSize()) {
- EmitFullWarpShuffleDownLoopForReduce(
- reducer, absl::MakeSpan(selected_values),
- tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1);
- }
-
- emit_write_output(is_zero(thread_id_x), selected_values);
- });
- });
-}
-
-// Same arguments as EmitReductionOutputForRowReduction.
-void ReductionGroupEmitter::EmitReductionOutputForColumnReduction(
- const TilingKernelInfo& tiling_kernel_info,
- const HloReduceInstruction* reduction,
- const std::vector<const HloInstruction*>& roots) const {
- auto* builder = reduction_emitter_.builder_;
- KernelSupportLibrary ksl(builder);
- const HloComputation* reducer = reduction->to_apply();
- const auto& thread_id_info = tiling_kernel_info.thread_id_info;
- const auto& thread_ids = thread_id_info.thread_ids;
-
- auto constant = [&](uint64_t c) -> llvm::Constant* {
- return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
- };
- auto is_zero = [&](llvm::Value* value) {
- return builder->CreateICmpEQ(value, constant(0));
- };
- const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
- const Tiling& tiling = reduction_info.GetTiling();
- int num_outputs = reducer->num_parameters() / 2;
-
- auto* kept_index = thread_ids[ReductionDimensions::kColMinorKeptDimension];
- auto* reduced_index = thread_ids[ReductionDimensions::kColReducedDimension];
-
- // Store the transpose in shared memory.
- for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
- const auto& state = GetCalculationStateFor(reduction, output_idx);
- auto* current_output_value =
- builder->CreateLoad(state.partial_result_address->getAllocatedType(),
- state.partial_result_address);
- state.shared_cache->Store(current_output_value, {kept_index, reduced_index},
- builder);
- }
-
- reduction_emitter_.EmitSyncThreads();
-
- // Get transposed element from shared memory.
- absl::InlinedVector<TypedPointer, 2> shmem_transposed_addrs;
- for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
- const auto& state = GetCalculationStateFor(reduction, output_idx);
- auto* shmem_transposed_addr =
- state.shared_cache->Address({reduced_index, kept_index}, builder);
- shmem_transposed_addrs.push_back(
- {shmem_transposed_addr, state.shared_cache->GetElementType()});
- }
-
- EmitFullWarpShuffleDownLoopForReduce(reducer,
- absl::MakeSpan(shmem_transposed_addrs),
- tiling.GetNumThreadsPerBlock(),
- /*num_results_per_warp=*/1);
-
- // Some warps in the block are completely outside the bounds of the
- // tensor, so they should not write any output at all.
- llvm::Value* has_output = builder->CreateAnd(
- builder->CreateICmpULT(
- reduced_index,
- tiling_kernel_info
- .output_tile_bounds[ReductionDimensions::kColMinorKeptDimension]),
- builder->CreateICmpULT(
- kept_index,
- tiling_kernel_info
- .output_tile_bounds[ReductionDimensions::kColReducedDimension]));
-
- ksl.If("reduction_write_output",
- builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
- WriteReductionOutput(tiling_kernel_info, reduction, roots,
- shmem_transposed_addrs);
- });
-}
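The column-reduction path above stages partial results in a shared tile of shape `n x (n + 1)`, written at `{kept, reduced}` and read back transposed at `{reduced, kept}`; the extra column exists so that, on the GPU, the strided transposed accesses don't all land in the same shared-memory bank. A plain C++ model of that layout, with illustrative names:

```
#include <vector>

// Host-side model of the padded shared cache used for column reductions: an
// n x (n + 1) tile, written at (kept, reduced) and read back transposed at
// (reduced, kept). The padding only matters on a GPU, where it keeps strided
// column accesses from hitting the same shared-memory bank.
class PaddedTile {
 public:
  explicit PaddedTile(int n) : n_(n), data_(n * (n + 1), 0.0f) {}

  void Store(int kept, int reduced, float value) {
    data_[kept * (n_ + 1) + reduced] = value;
  }
  float LoadTransposed(int kept, int reduced) const {
    return data_[reduced * (n_ + 1) + kept];
  }

 private:
  int n_;
  std::vector<float> data_;
};
```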
-
-// Generate a single element of the tile (update the accumulator state) for a
-// given reducer.
-void ReductionGroupEmitter::GenerateElementForReducer(
- const HloReduceInstruction* reduction,
- const llvm_ir::IrArray::Index& index) const {
- HloComputation* reducer = reduction->to_apply();
- auto* builder = reduction_emitter_.builder_;
- CHECK_EQ(reducer->num_parameters() % 2, 0);
-
- absl::InlinedVector<llvm::Value*, 2> reduction_accumulators;
- absl::InlinedVector<llvm::Value*, 2> reduction_input_value;
- for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) {
- const auto& state = GetCalculationStateFor(reduction, red_idx);
-
- llvm::AllocaInst* input_address = state.input_address;
- auto input_index =
- index.SourceIndexOfBitcast(reduction->operand(0)->shape(), builder);
- llvm::Value* const input_ir_value = *state.input_gen(input_index);
- builder->CreateStore(input_ir_value, input_address);
- reduction_accumulators.push_back(state.partial_result_address);
- reduction_input_value.push_back(input_address);
- }
-
- absl::InlinedVector<llvm::Value*, 4> reduction_params;
- for (llvm::Value* acc : reduction_accumulators) {
- reduction_params.push_back(acc);
- }
- for (llvm::Value* value : reduction_input_value) {
- reduction_params.push_back(value);
- }
-
- // Emit a call to the variadic reducer. Since it may return a tuple, we
- // can't return the result directly as a value. Instead, before the call
- // we create N allocas (N = number of elements in the tuple), one for each
- // returned value, and pass N pointers to them as the last parameters. The
- // called computation writes into those pointers, leaving the returned
- // values on the stack (as well as pointers to them).
- absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
- CallNestedComputationWithScalarAddrs(
- builder, reduction_emitter_.ir_emitter_context_, *reducer,
- reduction_params);
- TF_CHECK_OK(returned_scalars.status());
-
- for (int i = 0; i < returned_scalars->size(); i++) {
- builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]);
- }
-}
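The comment inside `GenerateElementForReducer` describes the scalar-address calling convention for variadic reducers: accumulators and inputs are passed as pointers, and results come back through accumulator storage rather than as a returned tuple. A loose host-side sketch of that convention, using an invented max-with-index reducer (all names illustrative):

```
#include <cstddef>
#include <utility>
#include <vector>

// A two-output ("variadic") reducer: (max value, index of max). Instead of
// returning a tuple, it receives one pointer per accumulator and one per
// input, and writes results back through the accumulator pointers.
void VariadicMaxWithIndex(const std::vector<float*>& accumulators,
                          const std::vector<float*>& inputs) {
  if (*inputs[0] > *accumulators[0]) {
    *accumulators[0] = *inputs[0];
    *accumulators[1] = *inputs[1];
  }
}

// Caller side: allocate storage for the accumulators up front (the emitter's
// allocas), then repeatedly call the reducer with scalar addresses.
std::pair<float, float> ReduceRow(const std::vector<float>& values) {
  float best_value = values.empty() ? 0.0f : values[0];
  float best_index = 0.0f;
  for (std::size_t i = 1; i < values.size(); ++i) {
    float candidate = values[i];
    float candidate_index = static_cast<float>(i);
    VariadicMaxWithIndex({&best_value, &best_index},
                         {&candidate, &candidate_index});
  }
  return {best_value, best_index};
}
```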
-
-// Emits code for reductions in the output_instructions.
-absl::Status ReductionEmitter::EmitIRForReduction(
- absl::Span<const HloInstruction* const> instr_index_group,
- FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
- const Shape& input_shape) {
- ExtraOutputGensMap extra_output_gens;
- absl::flat_hash_map<const HloReduceInstruction*,
- std::vector<const HloInstruction*>>
- heroes_to_roots;
- // Keep a list of deduplicated heroes separate from heroes_to_roots to make
- // the CodeGen deterministic.
- std::vector<const HloReduceInstruction*> heroes;
-
- for (const HloInstruction* hlo : instr_index_group) {
- auto& hero = FindNonTrivialHero(*hlo);
- if (IsRealReductionHero(*hlo, hero)) {
- auto reduction = Cast<HloReduceInstruction>(&hero);
- if (heroes_to_roots.find(reduction) == heroes_to_roots.end()) {
- heroes.push_back(reduction);
- }
- heroes_to_roots[reduction].push_back(hlo);
- } else {
- extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo);
- }
- }
-
- CHECK(!heroes.empty()) << " expected at least one reduce instruction.";
- const Tiling& tiling = reduction_codegen_info_.GetTiling();
- CHECK_EQ(tiling.GetNumThreadsPerBlock() % WarpSize(), 0);
- ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays,
- fused_emitter);
-
- TF_ASSIGN_OR_RETURN(
- TilingKernelInfo tiling_kernel_info,
- EmitTilingKernel(
- builder_, tiling, index_ty_,
- [&](const TilingThreadIdInfo& thread_id_info,
- const llvm_ir::IrArray::Index& tile_index,
- absl::Span<llvm::Value* const> tile_dimensions) {
- auto emit_element =
- [&](absl::Span<llvm::Value* const> index_in_tile) {
- auto index = tile_index.AddOffset(index_in_tile, builder_);
-
- // Emit code to generate the input and perform the reduction
- // computation for each reduction instruction.
- for (const HloReduceInstruction* reduce : heroes) {
- group_emitter.GenerateElementForReducer(reduce, index);
- }
-
- // Emit code to generate the output for the non-reduction
- // instructions in the fusion, if any.
- TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce(
- ShapeUtil::MakeShape(
- F32, reduction_codegen_info_.GetTiling().GetShape()),
- index, extra_output_gens));
- };
- EmitTile(builder_, reduction_codegen_info_.GetTiling(),
- thread_id_info, tile_dimensions, emit_element);
- }));
-
- KernelSupportLibrary ksl(builder_);
- for (auto reduce : heroes) {
- if (reduction_codegen_info_.IsRowReduction()) {
- group_emitter.EmitReductionOutputForRowReduction(
- tiling_kernel_info, reduce, heroes_to_roots[reduce]);
- } else {
- group_emitter.EmitReductionOutputForColumnReduction(
- tiling_kernel_info, reduce, heroes_to_roots[reduce]);
- }
- }
-
- return absl::OkStatus();
-}
-
-absl::StatusOr<FusionEmissionResult> ReductionEmitter::EmitInitializers() {
- FusionEmissionResult result;
- if (reduction_codegen_info_.IsRaceFree()) {
- return result;
- }
- // We need to get the dest slice by traversing the slice assigned to
- // fusion, because instructions inside fusion don't have buffer assignment.
- //
- // The order of fusion roots is determined by its position in the result
- // tuple. For example, in the following fused computation
- //
- // %fused_computation {
- // %a = ...
- // %b = ...
- // ROOT %root = tuple(%a, %b)
- // }
- //
- // The fusion root with index = 0 is %a, and the fusion root %b has index 1.
- // Therefore we can get the ordered slices by calling ForEachSubshape on the
- // result shape.
- std::vector<BufferAllocation::Slice> slices;
- TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
- fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) {
- if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) {
- return absl::OkStatus();
- }
-
- TF_ASSIGN_OR_RETURN(
- BufferAllocation::Slice slice,
- ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_,
- index));
- slices.push_back(slice);
- return absl::OkStatus();
- }));
-
- absl::Span<HloInstructionAdaptor const> fusion_roots =
- analysis_.fusion_roots();
- for (int i = 0; i < fusion_roots.size(); ++i) {
- const HloInstruction* fusion_root = &fusion_roots[i].instruction();
-
- if (IsReductionFromOrToContiguousDimensions(*fusion_root)) {
- TF_ASSIGN_OR_RETURN(
- result.thunks.emplace_back(),
- BuildFusedInitializerThunk(fusion_root, slices[i], i));
- }
- }
- return result;
-}
-
-absl::Status ReductionEmitter::EmitKernel(
- const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs) {
- const HloComputation* fused_computation =
- fusion_.fused_instructions_computation();
- FusedIrEmitter fused_emitter(elemental_emitter_);
- for (int i = 0; i < fused_computation->num_parameters(); i++) {
- HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
- fused_emitter.BindGenerator(
- *fused_operand, [builder = builder_, input = inputs[i],
- fused_operand](const llvm_ir::IrArray::Index& index) {
- return input.EmitReadArrayElement(index, builder,
- fused_operand->name());
- });
- }
-
- // Get outputs.
- ReductionOutputMap result_ir_arrays;
-
- int ir_arrays_idx = 0;
- for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) {
- int get_num_results = GetNumOutputs(root.shape());
- result_ir_arrays[&root.instruction()] =
- absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results);
- ir_arrays_idx += get_num_results;
- }
-
- KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll);
-
- // Use raw block_id_y to select the i-th parallel reduction to run. Using
- // block_id_y instead of block_id_x simplifies the index calculation
- // for reduction code generation as the block_id_y is orthogonal to
- // the indices used within the reductions.
- const auto& instr_index_groups =
- reduction_codegen_info_.GetGroups().grouped_roots;
- Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape();
-
- llvm::Value* block_id_y = gpu::EmitCallToTargetIntrinsic(
- gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_);
- llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
- llvm::cast<llvm::Instruction>(block_id_y),
- builder_->GetInsertBlock()->getModule());
- block_id_y = builder_->CreateZExtOrTrunc(block_id_y, builder_->getInt32Ty());
- block_id_y->setName("block.id.y");
- for (int i = 0; i < instr_index_groups.size(); ++i) {
- TF_RETURN_IF_ERROR(ksl.IfWithStatus(
- absl::StrCat("reduce-group-", i),
- builder_->CreateICmpEQ(block_id_y, builder_->getInt32(i)), [&] {
- return EmitIRForReduction(instr_index_groups[i], fused_emitter,
- result_ir_arrays, reduce_operand_shape);
- }));
- }
-
- return absl::OkStatus();
-}
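`EmitKernel` above guards each reduction group with `block_id_y == i`, so a single kernel body serves all groups and every block executes exactly one of them. Modeled as plain control flow (illustrative sketch, not the emitted IR):

```
#include <cstdio>
#include <vector>

// Host-side model of the block_id_y dispatch emitted above: the kernel body is
// a chain of `if (block_id_y == i)` guards, so each block executes exactly one
// reduction group and the groups run in parallel across blocks.
void RunGroups(int block_id_y, const std::vector<void (*)()>& groups) {
  for (int i = 0; i < static_cast<int>(groups.size()); ++i) {
    if (block_id_y == i) {
      groups[i]();  // only the group selected by this block's y index runs
    }
  }
}

int main() {
  std::vector<void (*)()> groups = {
      +[] { std::puts("reduce-group-0"); },
      +[] { std::puts("reduce-group-1"); },
  };
  RunGroups(/*block_id_y=*/1, groups);  // prints "reduce-group-1"
  return 0;
}
```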
-
-} // namespace
-
-absl::StatusOr<FusionEmissionResult> ReductionFusion::EmitInitializers(
- IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion) const {
- llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext());
- return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
- fusion, &builder)
- .EmitInitializers();
-}
-
-absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const {
- return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
- fusion, builder)
- .EmitKernel(launch_dims, inputs, outputs);
-}
-
-int ReductionInfo::GetRowsPerWarp() const {
- if (!is_row_reduction_) return 1;
- return RowReductionGetRowsPerWarp(
- tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]);
-}
-
-LaunchDimensions ReductionInfo::launch_dimensions() const {
- size_t blocks_y = groups_.grouped_roots.size();
- return {se::BlockDim(/*x=*/tiling_.GetNumBlocks(),
- /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
- se::ThreadDim(/*x=*/tiling_.GetNumThreadsPerBlock(),
- /*y=*/1, /*z=*/1)};
-}
-
-ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) {
- auto* hero_reduction = analysis.FindHeroReduction();
- CHECK_NE(hero_reduction, nullptr);
- Shape input_shape = hero_reduction->operand(0)->shape();
- ReductionDimensions reduction_dimensions =
- GetReductionKindAndContiguousComponents(*hero_reduction);
- auto shape = reduction_dimensions.dimensions;
- VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
- << " " << shape[0] << " " << shape[1] << " " << shape[2];
- Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions);
-
- int64_t num_threads_y =
- reduction_dimensions.is_row_reduction ? 1 : WarpSize();
- int64_t rows_per_warp =
- reduction_dimensions.is_row_reduction
- ? RowReductionGetRowsPerWarp(
- shape[ReductionDimensions::kRowMinorReducedDimension])
- : 1;
- int64_t num_threads_x = [&] {
- if (reduction_dimensions.is_row_reduction) {
- if (rows_per_warp > 1) {
- return shape[ReductionDimensions::kRowMinorReducedDimension];
- }
- int64_t max_block_size =
- MinThreadsXRowReduction(hero_reduction->GetModule()->config());
- return std::min(
- max_block_size,
- RoundUpTo(
- CeilOfRatio(shape[ReductionDimensions::kRowMinorReducedDimension],
- reduction_tiling
- [ReductionDimensions::kRowMinorReducedDimension]),
- WarpSize()));
- }
- return WarpSize();
- }();
-
- // If we're limited by the size of the x dimension, add additional parallelism
- // in the y dimension. The code generator doesn't currently support
- // parallelizing the z dimension (major reduced dimensions). The general
- // recommendation is to use between 128 and 512 threads, so we just go for
- // 256. See https://forums.developer.nvidia.com/t/55529
- constexpr int64_t kThreadsPerBlockTarget = 256;
- if (reduction_dimensions.is_row_reduction &&
- num_threads_x * 2 <= kThreadsPerBlockTarget) {
- int64_t kept_size =
- reduction_dimensions.dimensions[ReductionDimensions::kRowKeptDimension];
- // Increase the size of the y dimension as long as there's remaining
- // parallelism.
- if (kept_size * num_threads_x <= kThreadsPerBlockTarget) {
- num_threads_y = kept_size;
- // num_threads_x is a power of two, but it may be less than 32. If dim_y
- // is also small, we may have to increase the bound so the total number of
- // threads is a multiple of 32.
- while ((num_threads_x * num_threads_y) % 32) ++num_threads_y;
- } else {
- num_threads_y = kThreadsPerBlockTarget / num_threads_x;
- }
- }
-
- int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x,
- reduction_tiling);
-
- absl::InlinedVector<int64_t, 4> num_threads{1, num_threads_y, num_threads_x};
- absl::InlinedVector<int64_t, 4> tiled_shape{shape[0], shape[1],
- shape[2] / vector_size};
- absl::InlinedVector<int64_t, 4> tile_per_thread{
- reduction_tiling[0], reduction_tiling[1],
- std::max<int64_t>(reduction_tiling[2] / vector_size, 1)};
- if (rows_per_warp > 1) {
- // If we produce more than one element per thread, that means the reduced
- // dimension is small and it can't be tiled - we already have more threads
- // in a warp than the size of the reduced dimension. The code generator
- // doesn't currently support tiling the kept dimension, because it just
- // uses the thread ID as the coordinate.
- tile_per_thread[2] = 1;
- }
- if (vector_size != 1) {
- num_threads.push_back(1); // The vector dimension is a loop.
- tiled_shape.push_back(vector_size);
- tile_per_thread.push_back(vector_size);
- }
-
- Tiling tiling(tiled_shape, tile_per_thread, num_threads,
- /*loops_to_unroll=*/{false, false, true, false});
- bool reduction_is_race_free = ReductionIsRaceFree(
- hero_reduction->GetModule()->config(), reduction_dimensions);
- return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction,
- reduction_is_race_free,
- GroupDisjointReductions(analysis, /*for_mlir=*/false),
- hero_reduction);
-}
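The row-reduction block-shape heuristic in `Create` above is easiest to follow with concrete numbers. The helper below re-states just that heuristic (256-thread target per block, totals rounded up to a multiple of the 32-thread warp) so both branches can be exercised directly; the function name is illustrative:

```
#include <cstdint>
#include <iostream>

// Standalone sketch of the row-reduction block-shape heuristic above: if the
// x dimension leaves the block well under the 256-thread target, grow the y
// dimension, keeping the total thread count a multiple of the warp size (32).
int64_t PickNumThreadsY(int64_t num_threads_x, int64_t kept_size) {
  constexpr int64_t kThreadsPerBlockTarget = 256;
  if (num_threads_x * 2 > kThreadsPerBlockTarget) return 1;
  if (kept_size * num_threads_x <= kThreadsPerBlockTarget) {
    int64_t num_threads_y = kept_size;
    while ((num_threads_x * num_threads_y) % 32 != 0) ++num_threads_y;
    return num_threads_y;
  }
  return kThreadsPerBlockTarget / num_threads_x;
}

int main() {
  // 8 threads in x, 10 kept rows: y grows from 10 to 12 so 8 * 12 = 96 = 3 warps.
  std::cout << PickNumThreadsY(8, 10) << "\n";     // prints 12
  // 64 threads in x, 1000 kept rows: y is capped at 256 / 64 = 4.
  std::cout << PickNumThreadsY(64, 1000) << "\n";  // prints 4
  return 0;
}
```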
-
-std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const {
- if (!groups_.is_reduction_root[root_index]) {
- auto map = ComposeIndexingMaps(
- GetIndexingMapForTiling(tiling_, ctx),
- GetBitcastMap(tiling_.GetXlaShape(),
- analysis_.fusion_root(root_index).shape(), ctx));
- AddGroupIdConstraint(map, root_index, groups_);
- return map;
- }
- const auto& hero = analysis_.fusion_hero(root_index).instruction();
-
- auto block_offsets = GetBlockOffsetsForTiling(tiling_, ctx);
- auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx),
- tiling_.GetThreadsPerBlock());
-
- auto physical_shape =
- ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape());
- std::vector<DimVar> dimension_ranges{
- {{0, tiling_.GetNumThreadsPerBlock() - 1}},
- {},
- {},
- {{0, tiling_.GetNumBlocks() - 1}},
- {{0, static_cast<int64_t>(groups_.grouped_roots.size() - 1)}},
- {},
- };
-
- constexpr int kRowKept = ReductionDimensions::kRowKeptDimension;
- constexpr int kRowMinorReduced =
- ReductionDimensions::kRowMinorReducedDimension;
-
- constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension;
- constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension;
- constexpr int kColReduced = ReductionDimensions::kColReducedDimension;
-
- auto map = [&]() {
- if (is_row_reduction_) {
- IndexingMap linear_index(
- mlir::AffineMap::get(
- 6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept],
- ctx),
- dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});
- int rows_per_warp = GetRowsPerWarp();
- if (rows_per_warp > 1) {
- linear_index.AddConstraint(
- thread_ids[kRowMinorReduced] % (WarpSize() / rows_per_warp),
- {0, 0});
- } else {
- linear_index.AddConstraint(thread_ids[kRowMinorReduced], {0, 0});
- }
- return ComposeIndexingMaps(
- linear_index, GetBitcastMap(ShapeUtil::MakeShape(
- PRED, {tiling_.GetShape()[kRowKept]}),
- physical_shape, ctx));
- }
-
- mlir::SmallVector<mlir::AffineExpr> projected_dims{
- block_offsets.getResult(kColMajorKept),
- block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]};
- std::vector<RangeVar> range_vars;
- if (thread_ids.size() == 4) {
- int vector_size = tiling_.GetThreadTileSize().back();
- range_vars.push_back({0, vector_size - 1});
- projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx));
- }
- IndexingMap projected_index(
- mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx),
- dimension_ranges, range_vars, /*rt_vars=*/{});
-
- projected_index.AddConstraint(
- mlir::getAffineDimExpr(
- KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) %
- WarpSize(),
- {0, 0});
- if (!is_row_reduction_) {
- projected_index.AddConstraint(
- projected_index.GetAffineMap().getResult(1),
- {0, tiling_.GetShape()[ReductionDimensions::kColMinorKeptDimension] -
- 1});
- }
-
- return ComposeIndexingMaps(
- projected_index,
- GetBitcastMap(ShapeUtil::DeleteDimension(
- ReductionDimensions::kColReducedDimension,
- tiling_.GetXlaShape()),
- physical_shape, ctx));
- }();
-
- AddGroupIdConstraint(map, root_index, groups_);
- map.Simplify();
- return map;
-}
-
-std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const {
- const auto& hero = analysis_.fusion_hero(root_index).instruction();
- if (groups_.is_reduction_root[root_index] &&
- hero_operand_index >= hero.operand_count() / 2) {
- // We don't have indexing for the init values.
- return std::nullopt;
- }
- if (!groups_.is_reduction_root[root_index]) {
- return ComposeIndexingMaps(
- *ComputeThreadIdToOutputIndexing(root_index, ctx),
- *ComputeOutputToInputIndexing(
- &analysis_.fusion_root(root_index).instruction(), 0, ctx)
- .indexing_maps[hero_operand_index]
- .begin());
- }
-
- auto map = ComposeIndexingMaps(
- GetIndexingMapForTiling(tiling_, ctx),
- GetBitcastMap(tiling_.GetXlaShape(),
- hero.operand(hero_operand_index)->shape(), ctx));
- AddGroupIdConstraint(map, root_index, groups_);
- map.Simplify();
- return map;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.h b/third_party/xla/xla/service/gpu/fusions/reduction.h
deleted file mode 100644
index a15462f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/reduction.h
+++ /dev/null
@@ -1,190 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_
-#define XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_
-
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/reduction_base.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-
-namespace xla {
-namespace gpu {
-
-class ReductionInfo {
- public:
- static ReductionInfo Create(const HloFusionAnalysis& analysis);
-
- const Tiling& GetTiling() const { return tiling_; }
- const ReductionGroups& GetGroups() const { return groups_; }
- Shape GetReduceOperandShape() const {
- return first_reduce_->operand(0)->shape();
- }
-
- bool IsRowReduction() const { return is_row_reduction_; }
- bool IsRaceFree() const { return is_race_free_; }
- int GetRowsPerWarp() const;
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const;
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const;
-
- LaunchDimensions launch_dimensions() const;
-
- private:
- ReductionInfo(const HloFusionAnalysis& analysis, Tiling tiling,
- bool is_row_reduction, bool is_race_free,
- ReductionGroups groups, const HloInstruction* first_reduce)
- : analysis_(analysis),
- tiling_(tiling),
- is_row_reduction_(is_row_reduction),
- is_race_free_(is_race_free),
- groups_(std::move(groups)),
- first_reduce_(first_reduce) {}
-
- const HloFusionAnalysis& analysis_;
- Tiling tiling_;
- bool is_row_reduction_;
- bool is_race_free_;
- ReductionGroups groups_;
- const HloInstruction* first_reduce_;
-};
-
-// Generates code for reduction to contiguous dimensions.
-//
-// Row reduction uses the following algorithm described in CUDA-like
-// pseudocode:
-//
-// ```
-// __global__ void reduce(int num_rows, float *in, float *out) {
-// __shared__ float[32] cache;
-// int offset = blockDim.x * blockIdx.x + threadIdx.x;
-// if (offset >= num_rows) return;
-// int tile_bound = std::min(offset + kTileSizeX, num_rows);
-// float accum = 0;
-// for (int i=offset; i<num_rows; i+= blockDim.x) {
-// accum += in[i];
-// }
-// accum = warp_reduce(accum);
-// if (threadIdx.x % WarpSize == 0) {
-// cache[threadIdx.x / WarpSize] = accum;
-// }
-// __syncthreads();
-// if (threadIdx.x / WarpSize == 0) {
-// bool warp_exists = threadIdx.x < (blockDim.x / WarpSize);
-// float block_accum = warp_exists ? cache[threadIdx.x % WarpSize] : 0;
-// block_accum = warp_reduce(block_accum);
-// if (threadIdx.x == 0) {
-// *out += block_accum;
-// }
-// }
-// }
-// ```
-//
-// Column reduction uses the following algorithm:
-//
-// ```
-// void reduce(float** in, float* out) {
-// __shared__ float[32][33] cache;
-// int thread_id = GetThreadId();
-// int block_id = GetBlockId();
-// int tile_size = 128;
-//
-// float accum = 0;
-// for (int i=0; i<tile_size; i++) {
-// accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
-// }
-// cache[thread_id.x][thread_id.y] = accum;
-//
-// __syncthreads();
-// accum = cache[thread_id.y][thread_id.x];
-// accum = warp_reduce(accum); // Sum all the values of `accum` in the same
-// // warp.
-//
-// if (thread_id.y % 32 == 0) {
-// out[block_id * 32 + thread_id.x] = accum;
-// }
-// }
-// ```
-//
-// Moreover, a heuristic is implemented to divide the reduce instructions
-// into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
-// for details about the heuristic). Reduce instructions in the same group
-// will run sequentially, while different groups will run in parallel.
-//
-// We use the raw block_id_y to select the reduce group to execute without
-// complicating the index calculation in the code generation of the reduce
-// instructions. In other words, each block_id_y is assigned to a group, so
-// different groups can run in parallel.
-class ReductionFusion : public KernelFusionEmitterBase {
- public:
- explicit ReductionFusion(const HloFusionAnalysis& analysis)
- : analysis_(analysis), reduction_info_(ReductionInfo::Create(analysis)) {}
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const override {
- return reduction_info_.ComputeThreadIdToOutputIndexing(root_index, ctx);
- }
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const override {
- return reduction_info_.ComputeThreadIdToInputIndexing(
- root_index, hero_operand_index, ctx);
- }
-
- LaunchDimensions launch_dimensions() const override {
- return reduction_info_.launch_dimensions();
- }
-
- const ReductionInfo& reduction_info() const { return reduction_info_; }
-
- protected:
- absl::StatusOr<FusionEmissionResult> EmitInitializers(
- IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion) const override;
-
- absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const override;
-
- private:
- const HloFusionAnalysis& analysis_;
- ReductionInfo reduction_info_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_
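The CUDA-like pseudocode in the removed header sketches row reduction as per-warp partial sums followed by a second reduction over the per-warp results. Below is a rough, host-side C++ model of that two-stage structure; it is illustrative only, and `RowReduce`, the chunking, and the warp count are assumptions made for this sketch rather than XLA code.

```
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// Stage 1: each "warp" accumulates a contiguous chunk of the row.
// Stage 2: the per-warp partials are reduced, mirroring the shared-memory
// pass plus the second warp_reduce in the pseudocode.
float RowReduce(const std::vector<float>& row, int num_warps) {
  std::vector<float> warp_partials(num_warps, 0.0f);
  int n = static_cast<int>(row.size());
  int chunk = (n + num_warps - 1) / num_warps;
  for (int w = 0; w < num_warps; ++w) {
    int begin = w * chunk;
    int end = std::min(n, begin + chunk);
    if (begin < end) {
      warp_partials[w] =
          std::accumulate(row.begin() + begin, row.begin() + end, 0.0f);
    }
  }
  return std::accumulate(warp_partials.begin(), warp_partials.end(), 0.0f);
}

int main() {
  std::vector<float> row(2048, 1.0f);
  std::cout << RowReduce(row, /*num_warps=*/8) << "\n";  // prints 2048
}
```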
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc
index 7895108..e0af3b4 100644
--- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc
+++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc
@@ -36,7 +36,6 @@
#include "xla/hlo/utils/hlo_query.h"
#include "xla/primitive_util.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
index d58f809..075678d 100644
--- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
@@ -47,9 +47,9 @@
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/type_util.h"
#include "xla/service/gpu/fusions/reduction_base.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
@@ -79,79 +79,6 @@
constexpr int kRowKept = ReductionDimensions::kRowKeptDimension;
constexpr int kRowMinorReduced = ReductionDimensions::kRowMinorReducedDimension;
-LaunchDimensions MlirReductionFusion::launch_dimensions() const {
- size_t blocks_y = groups_.grouped_roots.size();
- return {se::BlockDim(/*x=*/Product(num_blocks_),
- /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
- se::ThreadDim(/*x=*/Product(num_threads_),
- /*y=*/1, /*z=*/1)};
-}
-
-MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis)
- : analysis_(analysis) {
- auto* hero_reduction = analysis.FindHeroReduction();
- CHECK_NE(hero_reduction, nullptr);
- Shape input_shape = hero_reduction->operand(0)->shape();
- reduction_dimensions_ =
- GetReductionKindAndContiguousComponents(*hero_reduction);
- VLOG(10) << reduction_dimensions_;
-
- CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(),
- reduction_dimensions_))
- << "Non-race-free reductions should have been decomposed. Did "
- "tree_reduction_rewriter run?";
-
- groups_ = GroupDisjointReductions(analysis, /*for_mlir=*/true);
- first_reduce_ = hero_reduction;
-
- const auto& groups = GetGroups();
- int num_groups = groups.grouped_roots.size();
- side_output_roots_.resize(num_groups);
- reduction_heroes_.resize(num_groups);
- reduction_roots_.resize(num_groups);
-
- absl::flat_hash_set<const HloInstruction*> seen_heroes;
- for (auto [root_adaptor, hero_adaptor, is_reduction, group_id] :
- llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(),
- groups.is_reduction_root, groups.group_id_per_root)) {
- const HloInstruction* root = &root_adaptor.instruction();
- const HloInstruction* hero = &hero_adaptor.instruction();
- if (is_reduction) {
- if (seen_heroes.insert(hero).second) {
- reduction_heroes_[group_id].push_back(hero);
- }
- reduction_roots_[group_id].push_back(root);
- } else {
- side_output_roots_[group_id].push_back(root);
- }
- }
-}
-
-IndexingMap MlirReductionFusion::GetIndexingMap(
- llvm::ArrayRef<mlir::AffineExpr> results,
- absl::Span<int64_t const> symbol_sizes) const {
- auto* ctx = results.front().getContext();
- auto num_groups = static_cast<int64_t>(reduction_heroes_.size());
- return IndexingMap{
- AffineMap::get(6, symbol_sizes.size(), results, ctx),
- DimVarsFromTensorSizes(
- {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}),
- RangeVarsFromTensorSizes(symbol_sizes),
- /*rt_vars=*/{}};
-}
-
-IndexingMap MlirReductionFusion::GetThreadIndexingMap(
- llvm::ArrayRef<mlir::AffineExpr> results,
- absl::Span<std::pair<mlir::AffineExpr, Interval> const> constraints,
- absl::Span<int64_t const> symbol_sizes) const {
- auto affine_map = AffineMap::get(1, symbol_sizes.size(), results,
- results.front().getContext());
- return IndexingMap{affine_map,
- DimVarsFromTensorSizes({Product(num_threads_)}),
- RangeVarsFromTensorSizes(symbol_sizes),
- /*rt_vars=*/{}, constraints};
-}
-
struct PerThreadOutputs {
// The partially reduced scalars for each thread.
HloValueMap reduction_scalars;
@@ -232,140 +159,6 @@
SmallVector<Value> thread_and_block_ids;
};
-std::vector<mlir_converter::EpilogueSpecification>
-MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion,
- MLIRContext* mlir_context) const {
- std::vector<mlir_converter::EpilogueSpecification> epilogues;
- epilogues.reserve(reduction_heroes_.size());
- for (const auto& [heroes, roots] :
- llvm::zip(reduction_heroes_, reduction_roots_)) {
- epilogues.push_back(
- mlir_converter::EpilogueSpecification::FromOutputIndexing(
- analysis_, heroes, roots, *this, mlir_context));
- }
- // Add empty epilogues for the side outputs. This ensures their roots don't
- // get "fused" into the tuple function.
- for (const auto& roots : side_output_roots_) {
- for (const auto* root : roots) {
- epilogues.push_back(
- mlir_converter::EpilogueSpecification::FromIdentityIndexing(
- root, root, mlir_context));
- }
- }
- return epilogues;
-}
-
-absl::Status MlirReductionFusion::EmitEntryFunction(
- const PartitionedComputations& computations,
- const mlir_converter::CallTargetProvider& call_targets,
- mlir::func::FuncOp entry_function,
- const HloFusionInstruction& fusion) const {
- EmitterState state{*this, entry_function, fusion, computations, call_targets};
- auto& b = state.builder;
- b.setInsertionPointToStart(entry_function.addEntryBlock());
- state.thread_and_block_ids = EmitThreadAndBlockIds(b);
- if (reduction_heroes_.size() == 1) {
- b.create<mlir::func::ReturnOp>(EmitReduction(0, state));
- return absl::OkStatus();
- }
- SmallVector<int64_t> cases(reduction_heroes_.size() - 1);
- absl::c_iota(cases, 1); // `default` is region 0.
- auto switch_op = b.create<mlir::scf::IndexSwitchOp>(
- entry_function.getResultTypes(), EmitBlockId(b, 1), cases, cases.size());
- b.create<mlir::func::ReturnOp>(switch_op.getResults());
- for (auto [id, region] : llvm::enumerate(switch_op->getRegions())) {
- b.setInsertionPointToStart(&region.emplaceBlock());
- b.create<mlir::scf::YieldOp>(EmitReduction(id, state));
- }
- return absl::OkStatus();
-}
-
-IndexingMap MlirRowReductionFusion::ComputeReductionInputIndexing(
- mlir::MLIRContext* ctx) const {
- auto thread_id =
- DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
- auto block_id =
- DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
- auto major_reduced = getAffineSymbolExpr(0, ctx);
- auto minor_reduced = getAffineSymbolExpr(1, ctx);
- auto vector_index = getAffineSymbolExpr(2, ctx);
-
- SmallVector<AffineExpr> indices{
- major_reduced,
- block_id[0] * tile_sizes_per_block_[0] + thread_id[0],
- block_id[1] * tile_sizes_per_block_[1] +
- (minor_reduced * num_threads_[1]) + thread_id[1],
- vector_index,
- };
-
- auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
- for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
- map.AddConstraint(result, {0, input_dim - 1});
- }
- return map;
-}
-
-IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing(
- MLIRContext* ctx) const {
- auto thread_id =
- DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
- auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
- : mlir::getAffineDimExpr(3, ctx);
- IndexingMap projected_index =
- GetIndexingMap(block_id * num_threads_[0] + thread_id[0]);
- projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()),
- {0, 0});
- // We don't need a constraint on the loop dimensions, because they are removed
- // by GetIndexingMap (since they don't show up in the output index
- // computation).
- return projected_index;
-}
-
-IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing(
- mlir::MLIRContext* ctx) const {
- auto thread_id =
- DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
- auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
- : mlir::getAffineDimExpr(3, ctx);
- auto major_reduced = getAffineSymbolExpr(0, ctx);
- auto vector_index = getAffineSymbolExpr(1, ctx);
-
- SmallVector<AffineExpr> indices{
- major_reduced, block_id * num_threads_[0] + thread_id[0],
- thread_id[1] * tile_sizes_per_thread_[1] + vector_index};
-
- auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
- for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
- map.AddConstraint(result, {0, input_dim - 1});
- }
- return map;
-}
-
-IndexingMap MlirRowReductionFusion::ComputeReductionOutputIndexing(
- MLIRContext* ctx) const {
- auto thread_id =
- DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
- auto block_id =
- DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
- IndexingMap projected_index =
- GetIndexingMap(block_id[0] * tile_sizes_per_block_[0] + thread_id[0]);
- projected_index.AddConstraint(thread_id[1], {0, 0});
- return projected_index;
-}
-
-HloValueMap MlirReductionFusion::GetInits(int group_id,
- EmitterState& state) const {
- HloValueMap result;
- const auto& reductions = reduction_heroes_[group_id];
- for (auto* hero : reductions) {
- int arity = hero->operand_count() / 2;
- result[hero] = ProvideParameterRange(state.computation, hero, arity, arity,
- {}, state.call_target,
- state.entry_function, state.builder);
- }
- return result;
-}
-
PerThreadOutputs MlirReductionFusion::EmitterState::EmitPerThreadElements(
int group_id, const HloValueMap& inits, const SmallVector<Value>& outputs) {
auto tile_indexing =
@@ -558,6 +351,140 @@
});
}
+MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis)
+ : analysis_(analysis) {
+ auto* hero_reduction = analysis.FindHeroReduction();
+ CHECK_NE(hero_reduction, nullptr);
+ Shape input_shape = hero_reduction->operand(0)->shape();
+ reduction_dimensions_ =
+ GetReductionKindAndContiguousComponents(*hero_reduction);
+ VLOG(10) << reduction_dimensions_;
+
+ CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(),
+ reduction_dimensions_))
+ << "Non-race-free reductions should have been decomposed. Did "
+ "tree_reduction_rewriter run?";
+
+ groups_ = GroupDisjointReductions(analysis, /*for_mlir=*/true);
+ first_reduce_ = hero_reduction;
+
+ const auto& groups = GetGroups();
+ int num_groups = groups.grouped_roots.size();
+ side_output_roots_.resize(num_groups);
+ reduction_heroes_.resize(num_groups);
+ reduction_roots_.resize(num_groups);
+
+ absl::flat_hash_set<const HloInstruction*> seen_heroes;
+ for (auto [root_adaptor, hero_adaptor, is_reduction, group_id] :
+ llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(),
+ groups.is_reduction_root, groups.group_id_per_root)) {
+ const HloInstruction* root = &root_adaptor.instruction();
+ const HloInstruction* hero = &hero_adaptor.instruction();
+ if (is_reduction) {
+ if (seen_heroes.insert(hero).second) {
+ reduction_heroes_[group_id].push_back(hero);
+ }
+ reduction_roots_[group_id].push_back(root);
+ } else {
+ side_output_roots_[group_id].push_back(root);
+ }
+ }
+}
+
+IndexingMap MlirReductionFusion::GetIndexingMap(
+ llvm::ArrayRef<mlir::AffineExpr> results,
+ absl::Span<int64_t const> symbol_sizes) const {
+ auto* ctx = results.front().getContext();
+ auto num_groups = static_cast<int64_t>(reduction_heroes_.size());
+ return IndexingMap{
+ AffineMap::get(6, symbol_sizes.size(), results, ctx),
+ DimVarsFromTensorSizes(
+ {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}),
+ RangeVarsFromTensorSizes(symbol_sizes),
+ /*rt_vars=*/{}};
+}
+
+IndexingMap MlirReductionFusion::GetThreadIndexingMap(
+ llvm::ArrayRef<mlir::AffineExpr> results,
+ absl::Span<std::pair<mlir::AffineExpr, Interval> const> constraints,
+ absl::Span<int64_t const> symbol_sizes) const {
+ auto affine_map = AffineMap::get(1, symbol_sizes.size(), results,
+ results.front().getContext());
+ return IndexingMap{affine_map,
+ DimVarsFromTensorSizes({Product(num_threads_)}),
+ RangeVarsFromTensorSizes(symbol_sizes),
+ /*rt_vars=*/{}, constraints};
+}
+
+LaunchDimensions MlirReductionFusion::launch_dimensions() const {
+ size_t blocks_y = groups_.grouped_roots.size();
+ return {se::BlockDim(/*x=*/Product(num_blocks_),
+ /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
+ se::ThreadDim(/*x=*/Product(num_threads_),
+ /*y=*/1, /*z=*/1)};
+}
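For concreteness, here is a tiny standalone sketch of how the launch grid above is assembled, using hypothetical sizes (two reduction groups, an 11x13 per-group block tiling, and 4x64 threads); `block_id_y` selects the group, matching the `scf.index_switch` emitted in `EmitEntryFunction`.

```
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> num_blocks = {11, 13};   // per-group block tiling
  std::vector<int64_t> num_threads = {4, 64};   // kept x reduced threads
  int64_t num_groups = 2;                       // disjoint reduction groups
  int64_t grid_x = std::accumulate(num_blocks.begin(), num_blocks.end(),
                                   int64_t{1}, std::multiplies<int64_t>());
  int64_t block_x = std::accumulate(num_threads.begin(), num_threads.end(),
                                    int64_t{1}, std::multiplies<int64_t>());
  // Grid (143, 2, 1), block (256, 1, 1); block_id_y picks the group.
  std::cout << "grid=(" << grid_x << "," << num_groups << ",1) "
            << "block=(" << block_x << ",1,1)\n";
}
```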
+
+std::vector<mlir_converter::EpilogueSpecification>
+MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion,
+ MLIRContext* mlir_context) const {
+ std::vector<mlir_converter::EpilogueSpecification> epilogues;
+ epilogues.reserve(reduction_heroes_.size());
+ for (const auto& [heroes, roots] :
+ llvm::zip(reduction_heroes_, reduction_roots_)) {
+ epilogues.push_back(
+ mlir_converter::EpilogueSpecification::FromOutputIndexing(
+ analysis_, heroes, roots, *this, mlir_context));
+ }
+ // Add empty epilogues for the side outputs. This ensures their roots don't
+ // get "fused" into the tuple function.
+ for (const auto& roots : side_output_roots_) {
+ for (const auto* root : roots) {
+ epilogues.push_back(
+ mlir_converter::EpilogueSpecification::FromIdentityIndexing(
+ root, root, mlir_context));
+ }
+ }
+ return epilogues;
+}
+
+absl::Status MlirReductionFusion::EmitEntryFunction(
+ const PartitionedComputations& computations,
+ const mlir_converter::CallTargetProvider& call_targets,
+ mlir::func::FuncOp entry_function,
+ const HloFusionInstruction& fusion) const {
+ EmitterState state{*this, entry_function, fusion, computations, call_targets};
+ auto& b = state.builder;
+ b.setInsertionPointToStart(entry_function.addEntryBlock());
+ state.thread_and_block_ids = EmitThreadAndBlockIds(b);
+ if (reduction_heroes_.size() == 1) {
+ b.create<mlir::func::ReturnOp>(EmitReduction(0, state));
+ return absl::OkStatus();
+ }
+ SmallVector<int64_t> cases(reduction_heroes_.size() - 1);
+ absl::c_iota(cases, 1); // `default` is region 0.
+ auto switch_op = b.create<mlir::scf::IndexSwitchOp>(
+ entry_function.getResultTypes(), EmitBlockId(b, 1), cases, cases.size());
+ b.create<mlir::func::ReturnOp>(switch_op.getResults());
+ for (auto [id, region] : llvm::enumerate(switch_op->getRegions())) {
+ b.setInsertionPointToStart(&region.emplaceBlock());
+ b.create<mlir::scf::YieldOp>(EmitReduction(id, state));
+ }
+ return absl::OkStatus();
+}
+
+HloValueMap MlirReductionFusion::GetInits(int group_id,
+ EmitterState& state) const {
+ HloValueMap result;
+ const auto& reductions = reduction_heroes_[group_id];
+ for (auto* hero : reductions) {
+ int arity = hero->operand_count() / 2;
+ result[hero] = ProvideParameterRange(state.computation, hero, arity, arity,
+ {}, state.call_target,
+ state.entry_function, state.builder);
+ }
+ return result;
+}
+
std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToInputIndexing(
int64_t root_index, int64_t hero_operand_index, MLIRContext* ctx) const {
const auto& hero = analysis_.fusion_hero(root_index).instruction();
@@ -645,182 +572,6 @@
return outputs;
}
-MlirRowReductionFusion::MlirRowReductionFusion(
- const HloFusionAnalysis& analysis)
- : MlirReductionFusion(analysis) {
- CHECK(reduction_dimensions_.is_row_reduction);
- Vector3 shape = reduction_dimensions_.dimensions;
- CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1);
- constexpr int64_t kMinorReducedElementsPerThread = 16;
-
- int64_t num_threads_kept = 1;
- int64_t num_threads_reduced = [&] {
- int64_t max_block_size =
- MinThreadsXRowReduction(first_reduce_->GetModule()->config());
- return std::min(max_block_size,
- RoundUpTo(CeilOfRatio(shape[kRowMinorReduced],
- kMinorReducedElementsPerThread),
- WarpSize()));
- }();
-
- // If we're limited by the size of the x dimension, add additional parallelism
- // in the y dimension. The code generator doesn't currently support
- // parallelizing the z dimension (major reduced dimensions). The general
- // recommendation is to use between 128 and 512 threads, so we just go for
- // 256. See https://forums.developer.nvidia.com/t/55529
- constexpr int64_t kThreadsPerBlockTarget = 256;
- if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) {
- int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
- // Increase the size of the y dimension as long as there's remaining
- // parallelism.
- if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
- num_threads_kept = kept_size;
- } else {
- num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
- }
- }
-
- int vector_size = GetVectorSizeForMlir(analysis, reduction_dimensions_,
- num_threads_reduced);
- num_threads_ = {num_threads_kept, num_threads_reduced};
- // TODO(jreiffers): Get rid of `vector_size` in here.
- input_shape_ = {shape[0], shape[1], shape[2] / vector_size, vector_size};
- // TODO(jreiffers): Tighten ranges based on constraints when simplifying
- // instead of using min here. For example, based on
- //
- // s1 in [0, 127]
- // d0 floordiv 32 + s1 * 32 in [0, 63]
- //
- // Tighten the bound of s1 to [0, 1].
- int minor_reduced_tile_size =
- std::min(kMinorReducedElementsPerThread / vector_size,
- CeilOfRatio(input_shape_[2], num_threads_[1]));
-
- tile_sizes_per_thread_ = {shape[0], minor_reduced_tile_size, vector_size};
- tile_sizes_per_block_ = {num_threads_kept,
- minor_reduced_tile_size * num_threads_reduced};
- num_blocks_ = {CeilOfRatio(input_shape_[1], tile_sizes_per_block_[0]),
- CeilOfRatio(input_shape_[2], tile_sizes_per_block_[1])};
-}
-
-MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
- const HloFusionAnalysis& analysis)
- : MlirReductionFusion(analysis) {
- CHECK(reduction_dimensions_.is_row_reduction);
- Vector3 shape = reduction_dimensions_.dimensions;
- int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]);
- input_shape_ = {shape[0], shape[1], shape[2]};
- CHECK_GT(rows_per_warp, 1);
-
- auto compute_block_size = [&](int vector_size) {
- int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size;
-
- constexpr int64_t kThreadsPerBlockTarget = 256;
- int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
- int64_t num_threads_kept = 1;
- if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
- num_threads_kept = kept_size;
- } else {
- num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
- }
- num_threads_ = {num_threads_kept, num_threads_reduced};
- tile_sizes_per_thread_ = {shape[0], vector_size};
- num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)};
- };
-
- // Compute the launch grid without vectorization. We use the results to
- // compute the vectorized launch grid.
- compute_block_size(1);
-
- // Normally, we only consider input types for vectorization. However, in
- // multi-row reductions, the input:output ratio is much higher, so we consider
- // both inputs and outputs.
- int smallest_input_or_output_bits =
- std::min(analysis.input_output_info().smallest_input_dtype_bits,
- analysis.input_output_info().smallest_output_dtype_bits);
-
- // This vector size is always valid: we know that the reduced dimension is a
- // power of 2, since otherwise RowReductionGetRowsPerWarp would have
- // returned 1.
- // Our codegen can't currently deal with vectorization across rows, so we
- // limit the vector size to the size of the row. Note that this emitter
- // essentially reverts to the loop emitter in this case, except for side
- // outputs.
- int vector_size = std::min(static_cast<int>(input_shape_[kRowMinorReduced]),
- 32 / smallest_input_or_output_bits);
-
- // We target 8 warps per block, which means there could be up to 8 blocks per
- // SM, but we have no good way of knowing. In practice, enabling vectorization
- // for decently sized reductions at least does not hurt.
- if (num_blocks_.front() > analysis.device_info().core_count() &&
- vector_size > 1) {
- compute_block_size(vector_size);
- }
-}
-
-int MlirMultiRowReductionFusion::GetRowsPerWarp() const {
- return RowReductionGetRowsPerWarp(
- input_shape_[ReductionDimensions::kRowMinorReducedDimension]) *
- tile_sizes_per_thread_[1];
-}
-
-int MlirRowReductionFusion::GetWarpsPerRow() const {
- return CeilOfRatio(num_threads_[1], WarpSize());
-}
-
-IndexingMap MlirRowReductionFusion::GetSharedMemoryReductionReadMap(
- mlir::MLIRContext* ctx) const {
- auto thread_id =
- DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
- auto lane_id = thread_id[1] % WarpSize();
- return GetThreadIndexingMap({thread_id[0], lane_id},
- {{thread_id[1], {0, GetWarpsPerRow() - 1}}});
-}
-
-IndexingMap MlirRowReductionFusion::GetSharedMemoryWriteMap(
- mlir::MLIRContext* ctx) const {
- auto thread_id =
- DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
- // The reduced dimension is tiled; each warp writes one element to shared
- // memory (from lane 0).
- auto lane_id = thread_id[1] % WarpSize();
- auto warp_id = thread_id[1].floorDiv(WarpSize());
- return GetThreadIndexingMap({thread_id[0], warp_id}, {{lane_id, {0, 0}}});
-}
-
-llvm::SmallVector<mlir::Value> MlirRowReductionFusion::EmitReduction(
- int group_id, EmitterState& state) const {
- const auto& reductions = reduction_heroes_[group_id];
-
- HloValueMap inits = GetInits(group_id, state);
- auto per_thread =
- state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
- per_thread.reduction_scalars =
- state.ShuffleReduce(reductions, per_thread.reduction_scalars);
-
- if (GetWarpsPerRow() == 1) {
- // If only a single warp works on an element, we don't need to go through
- // shared memory.
- return EvaluateEpilogue(per_thread.reduction_scalars,
- std::move(per_thread.outputs), state, group_id,
- /*symbol_values=*/{});
- }
-
- return state.ReduceViaSharedMemory(group_id, per_thread, inits);
-}
-
-llvm::SmallVector<mlir::Value> MlirMultiRowReductionFusion::EmitReduction(
- int group_id, EmitterState& state) const {
- HloValueMap inits = GetInits(group_id, state);
- const auto& reductions = reduction_heroes_[group_id];
- auto per_thread =
- state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
- auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars,
- WarpSize() / 2 / GetRowsPerWarp());
- return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state,
- group_id, /*symbol_values=*/{});
-}
-
MlirColumnReductionFusion::MlirColumnReductionFusion(
const HloFusionAnalysis& analysis)
: MlirReductionFusion(analysis) {
@@ -930,5 +681,254 @@
return std::make_unique<MlirColumnReductionFusion>(analysis);
}
+MlirRowReductionFusion::MlirRowReductionFusion(
+ const HloFusionAnalysis& analysis)
+ : MlirReductionFusion(analysis) {
+ CHECK(reduction_dimensions_.is_row_reduction);
+ Vector3 shape = reduction_dimensions_.dimensions;
+ CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1);
+ constexpr int64_t kMinorReducedElementsPerThread = 16;
+
+ int64_t num_threads_kept = 1;
+ int64_t num_threads_reduced = [&] {
+ int64_t max_block_size =
+ MinThreadsXRowReduction(first_reduce_->GetModule()->config());
+ return std::min(max_block_size,
+ RoundUpTo(CeilOfRatio(shape[kRowMinorReduced],
+ kMinorReducedElementsPerThread),
+ WarpSize()));
+ }();
+
+ // If we're limited by the size of the x dimension, add additional parallelism
+ // in the y dimension. The code generator doesn't currently support
+ // parallelizing the z dimension (major reduced dimensions). The general
+ // recommendation is to use between 128 and 512 threads, so we just go for
+ // 256. See https://forums.developer.nvidia.com/t/55529
+ constexpr int64_t kThreadsPerBlockTarget = 256;
+ if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) {
+ int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
+ // Increase the size of the y dimension as long as there's remaining
+ // parallelism.
+ if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
+ num_threads_kept = kept_size;
+ } else {
+ num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
+ }
+ }
+
+ int vector_size = GetVectorSizeForMlir(analysis, reduction_dimensions_,
+ num_threads_reduced);
+ num_threads_ = {num_threads_kept, num_threads_reduced};
+ // TODO(jreiffers): Get rid of `vector_size` in here.
+ input_shape_ = {shape[0], shape[1], shape[2] / vector_size, vector_size};
+ // TODO(jreiffers): Tighten ranges based on constraints when simplifying
+ // instead of using min here. For example, based on
+ //
+ // s1 in [0, 127]
+ // d0 floordiv 32 + s1 * 32 in [0, 63]
+ //
+ // Tighten the bound of s1 to [0, 1].
+ int minor_reduced_tile_size =
+ std::min(kMinorReducedElementsPerThread / vector_size,
+ CeilOfRatio(input_shape_[2], num_threads_[1]));
+
+ tile_sizes_per_thread_ = {shape[0], minor_reduced_tile_size, vector_size};
+ tile_sizes_per_block_ = {num_threads_kept,
+ minor_reduced_tile_size * num_threads_reduced};
+ num_blocks_ = {CeilOfRatio(input_shape_[1], tile_sizes_per_block_[0]),
+ CeilOfRatio(input_shape_[2], tile_sizes_per_block_[1])};
+}
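A minimal sketch of the thread-count heuristic in the constructor above, with hypothetical inputs; `CeilOfRatio` and `RoundUpTo` are reimplemented locally, and the maximum block size of 512 is only a stand-in for whatever `MinThreadsXRowReduction` returns.

```
#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }
int64_t RoundUpTo(int64_t a, int64_t b) { return CeilOfRatio(a, b) * b; }

int main() {
  constexpr int64_t kWarpSize = 32;
  constexpr int64_t kMinorReducedElementsPerThread = 16;
  constexpr int64_t kThreadsPerBlockTarget = 256;
  constexpr int64_t kMaxBlockSize = 512;  // stand-in for MinThreadsXRowReduction

  int64_t kept_size = 100;       // example kept dimension
  int64_t minor_reduced = 4096;  // example reduced row length

  int64_t num_threads_reduced = std::min(
      kMaxBlockSize,
      RoundUpTo(CeilOfRatio(minor_reduced, kMinorReducedElementsPerThread),
                kWarpSize));
  int64_t num_threads_kept = 1;
  if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) {
    num_threads_kept =
        kept_size * num_threads_reduced <= kThreadsPerBlockTarget
            ? kept_size
            : kThreadsPerBlockTarget / num_threads_reduced;
  }
  // 4096 / 16 = 256 -> already a warp multiple -> 256 reduced threads;
  // 256 * 2 > 256, so the kept dimension keeps a single thread.
  std::cout << num_threads_kept << " x " << num_threads_reduced << "\n";
}
```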
+
+IndexingMap MlirRowReductionFusion::ComputeReductionInputIndexing(
+ mlir::MLIRContext* ctx) const {
+ auto thread_id =
+ DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+ auto block_id =
+ DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
+ auto major_reduced = getAffineSymbolExpr(0, ctx);
+ auto minor_reduced = getAffineSymbolExpr(1, ctx);
+ auto vector_index = getAffineSymbolExpr(2, ctx);
+
+ SmallVector<AffineExpr> indices{
+ major_reduced,
+ block_id[0] * tile_sizes_per_block_[0] + thread_id[0],
+ block_id[1] * tile_sizes_per_block_[1] +
+ (minor_reduced * num_threads_[1]) + thread_id[1],
+ vector_index,
+ };
+
+ auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
+ for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
+ map.AddConstraint(result, {0, input_dim - 1});
+ }
+ return map;
+}
+
+IndexingMap MlirRowReductionFusion::ComputeReductionOutputIndexing(
+ MLIRContext* ctx) const {
+ auto thread_id =
+ DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+ auto block_id =
+ DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
+ IndexingMap projected_index =
+ GetIndexingMap(block_id[0] * tile_sizes_per_block_[0] + thread_id[0]);
+ projected_index.AddConstraint(thread_id[1], {0, 0});
+ return projected_index;
+}
+
+int MlirRowReductionFusion::GetWarpsPerRow() const {
+ return CeilOfRatio(num_threads_[1], WarpSize());
+}
+
+IndexingMap MlirRowReductionFusion::GetSharedMemoryReductionReadMap(
+ mlir::MLIRContext* ctx) const {
+ auto thread_id =
+ DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
+ auto lane_id = thread_id[1] % WarpSize();
+ return GetThreadIndexingMap({thread_id[0], lane_id},
+ {{thread_id[1], {0, GetWarpsPerRow() - 1}}});
+}
+
+IndexingMap MlirRowReductionFusion::GetSharedMemoryWriteMap(
+ mlir::MLIRContext* ctx) const {
+ auto thread_id =
+ DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
+ // The reduced dimension is tiled; each warp writes one element to shared
+ // memory (from lane 0).
+ auto lane_id = thread_id[1] % WarpSize();
+ auto warp_id = thread_id[1].floorDiv(WarpSize());
+ return GetThreadIndexingMap({thread_id[0], warp_id}, {{lane_id, {0, 0}}});
+}
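Rough illustration (not XLA code) of the shared-memory hand-off encoded by the two maps above, for one row handled by 128 threads with a warp size of 32, i.e. `GetWarpsPerRow() == 4`: lane 0 of each warp writes its partial to slot `warp_id`, and the first `GetWarpsPerRow()` threads read the slots back for the final shuffle reduce.

```
#include <iostream>

int main() {
  constexpr int kWarpSize = 32;
  constexpr int kNumThreadsReduced = 128;
  constexpr int kWarpsPerRow = kNumThreadsReduced / kWarpSize;  // 4
  for (int t = 0; t < kNumThreadsReduced; ++t) {
    int lane = t % kWarpSize;
    int warp = t / kWarpSize;
    // Write map: lane 0 of each warp stores its warp's partial at slot `warp`.
    if (lane == 0) std::cout << "t" << t << " writes shared[" << warp << "]\n";
    // Read map: the first kWarpsPerRow threads read one slot each for the
    // final shuffle reduction.
    if (t < kWarpsPerRow) std::cout << "t" << t << " reads shared[" << lane << "]\n";
  }
}
```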
+
+llvm::SmallVector<mlir::Value> MlirRowReductionFusion::EmitReduction(
+ int group_id, EmitterState& state) const {
+ const auto& reductions = reduction_heroes_[group_id];
+
+ HloValueMap inits = GetInits(group_id, state);
+ auto per_thread =
+ state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
+ per_thread.reduction_scalars =
+ state.ShuffleReduce(reductions, per_thread.reduction_scalars);
+
+ if (GetWarpsPerRow() == 1) {
+ // If only a single warp works on an element, we don't need to go through
+ // shared memory.
+ return EvaluateEpilogue(per_thread.reduction_scalars,
+ std::move(per_thread.outputs), state, group_id,
+ /*symbol_values=*/{});
+ }
+
+ return state.ReduceViaSharedMemory(group_id, per_thread, inits);
+}
+
+MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
+ const HloFusionAnalysis& analysis)
+ : MlirReductionFusion(analysis) {
+ CHECK(reduction_dimensions_.is_row_reduction);
+ Vector3 shape = reduction_dimensions_.dimensions;
+ int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]);
+ input_shape_ = {shape[0], shape[1], shape[2]};
+ CHECK_GT(rows_per_warp, 1);
+
+ auto compute_block_size = [&](int vector_size) {
+ int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size;
+
+ constexpr int64_t kThreadsPerBlockTarget = 256;
+ int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
+ int64_t num_threads_kept = 1;
+ if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
+ num_threads_kept = kept_size;
+ } else {
+ num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
+ }
+ num_threads_ = {num_threads_kept, num_threads_reduced};
+ tile_sizes_per_thread_ = {shape[0], vector_size};
+ num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)};
+ };
+
+ // Compute the launch grid without vectorization. We use the results to
+ // compute the vectorized launch grid.
+ compute_block_size(1);
+
+ // Normally, we only consider input types for vectorization. However, in
+ // multi-row reductions, the input:output ratio is much higher, so we consider
+ // both inputs and outputs.
+ int smallest_input_or_output_bits =
+ std::min(analysis.input_output_info().smallest_input_dtype_bits,
+ analysis.input_output_info().smallest_output_dtype_bits);
+
+ // This vector size is always valid: we know that the reduced dimension is a
+ // power of 2, since otherwise RowReductionGetRowsPerWarp would have
+ // returned 1.
+ // Our codegen can't currently deal with vectorization across rows, so we
+ // limit the vector size to the size of the row. Note that this emitter
+ // essentially reverts to the loop emitter in this case, except for side
+ // outputs.
+ int vector_size = std::min(static_cast<int>(input_shape_[kRowMinorReduced]),
+ 32 / smallest_input_or_output_bits);
+
+ // We target 8 warps per block, which means there could be up to 8 blocks per
+ // SM, but we have no good way of knowing. In practice, enabling vectorization
+ // for decently sized reductions at least does not hurt.
+ if (num_blocks_.front() > analysis.device_info().core_count() &&
+ vector_size > 1) {
+ compute_block_size(vector_size);
+ }
+}
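A minimal sketch of the vector-size choice above, assuming the 32-bit vectorization budget implied by `32 / smallest_input_or_output_bits`; the row length and dtype width below are hypothetical.

```
#include <algorithm>
#include <iostream>

int main() {
  int row_length = 4;            // minor reduced dimension (a power of two)
  int smallest_dtype_bits = 16;  // e.g. f16 inputs/outputs
  int vector_size = std::min(row_length, 32 / smallest_dtype_bits);
  std::cout << vector_size << "\n";  // prints 2
}
```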
+
+IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing(
+ mlir::MLIRContext* ctx) const {
+ auto thread_id =
+ DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+ auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
+ : mlir::getAffineDimExpr(3, ctx);
+ auto major_reduced = getAffineSymbolExpr(0, ctx);
+ auto vector_index = getAffineSymbolExpr(1, ctx);
+
+ SmallVector<AffineExpr> indices{
+ major_reduced, block_id * num_threads_[0] + thread_id[0],
+ thread_id[1] * tile_sizes_per_thread_[1] + vector_index};
+
+ auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
+ for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
+ map.AddConstraint(result, {0, input_dim - 1});
+ }
+ return map;
+}
+
+IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing(
+ MLIRContext* ctx) const {
+ auto thread_id =
+ DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+ auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
+ : mlir::getAffineDimExpr(3, ctx);
+ IndexingMap projected_index =
+ GetIndexingMap(block_id * num_threads_[0] + thread_id[0]);
+ projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()),
+ {0, 0});
+ // We don't need a constraint on the loop dimensions, because they are removed
+ // by GetIndexingMap (since they don't show up in the output index
+ // computation).
+ return projected_index;
+}
+
+int MlirMultiRowReductionFusion::GetRowsPerWarp() const {
+ return RowReductionGetRowsPerWarp(
+ input_shape_[ReductionDimensions::kRowMinorReducedDimension]) *
+ tile_sizes_per_thread_[1];
+}
+
+llvm::SmallVector<mlir::Value> MlirMultiRowReductionFusion::EmitReduction(
+ int group_id, EmitterState& state) const {
+ HloValueMap inits = GetInits(group_id, state);
+ const auto& reductions = reduction_heroes_[group_id];
+ auto per_thread =
+ state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
+ auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars,
+ WarpSize() / 2 / GetRowsPerWarp());
+ return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state,
+ group_id, /*symbol_values=*/{});
+}
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
index 761ecb4..4798528 100644
--- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
@@ -23,10 +23,10 @@
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
-#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "xla/error_spec.h"
#include "xla/service/gpu/fusions/mlir_emitter_test_base.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/indexing_test_utils.h"
#include "xla/tsl/lib/core/status_test_util.h"
@@ -54,69 +54,8 @@
}
};
-using MlirRowReductionTest = ReductionTest<MlirRowReductionFusion>;
-using MlirColumnReductionTest = ReductionTest<MlirColumnReductionFusion>;
using MlirMultiRowReductionTest = ReductionTest<MlirMultiRowReductionFusion>;
-constexpr std::string_view kVariadicRowReduction = R"(
- Add {
- scalar_lhs.0 = f32[] parameter(0)
- scalar_rhs.0 = f32[] parameter(1)
- scalar_lhs.1 = f32[] parameter(2)
- scalar_rhs.1 = f32[] parameter(3)
- add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1)
- add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1)
- ROOT t = (f32[], f32[]) tuple(add.0, add.1)
- }
- fused_computation {
- param_0 = f32[2, 3, 2048] parameter(0)
- param_1 = f32[2, 3, 2048] parameter(1)
- param_2 = f32[] parameter(2)
- ROOT d.1 = (f32[2, 3], f32[2, 3])
- reduce(param_0, param_1, param_2, param_2), dimensions={2}, to_apply=Add
- }
- ENTRY main {
- a = f32[2, 3, 2048] parameter(0)
- b = f32[2, 3, 2048] parameter(1)
- c = f32[] constant(0)
- ROOT fusion = (f32[2, 3], f32[2, 3]) fusion(a, b, c),
- kind=kInput, calls=fused_computation
- })";
-
-constexpr std::string_view kF64RowReduction = R"(
- Add {
- lhs = f64[] parameter(0)
- rhs = f64[] parameter(1)
- ROOT add = f64[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f64[100,128] parameter(0)
- param_1 = f64[] parameter(1)
- ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- }
- ENTRY main {
- a = f64[100,128] parameter(0)
- c = f64[] constant(0)
- ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation
- })";
-
-constexpr auto kRowReductionMinorAndMajor = R"(
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[7,100,128] parameter(0)
- param_1 = f32[] parameter(1)
- ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=Add
- }
- ENTRY main {
- a = f32[7,100,128] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation
- })";
-
constexpr auto kMultiRowReductionX8 = R"(
Add {
lhs = f32[] parameter(0)
@@ -179,181 +118,6 @@
ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion
})";
-constexpr std::string_view kRowReductionSideOutput = R"(
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[8,2048] parameter(0)
- param_1 = f32[] parameter(1)
- exp = f32[8,2048] exponential(param_0)
- reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp)
- }
- ENTRY main {
- a = f32[8,2048] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = (f32[8], f32[8,2048]) fusion(a, c), kind=kInput,
- calls=fused_computation
- })";
-
-TEST_F(MlirRowReductionTest, VariadicRowReductionIndexing) {
- auto fusion = GetEmitter(kVariadicRowReduction);
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
- {2, 3, 2048}));
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {2, 3}));
-}
-
-TEST_F(MlirRowReductionTest, VariadicRowReductionCorrectness) {
- EXPECT_TRUE(RunAndCompareNoHloPasses(kVariadicRowReduction, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReduceEpilogue) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
-
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[8,2048] parameter(0)
- param_1 = f32[] parameter(1)
- reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- ROOT log = f32[8] log(reduce)
- }
- ENTRY main {
- a = f32[8,2048] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = f32[8] fusion(a, c), kind=kInput, calls=fused_computation
- })";
- TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: pure_call @Add_add
- // CHECK: shuffle_reduce
- // CHECK: allocate_shared
- // CHECK: sync_threads
- // CHECK: shuffle_reduce
- )"));
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReduceMOFEpilogue) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
-
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- Mul {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT mul = f32[] multiply(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[8,1024] parameter(0)
- param_1 = f32[] parameter(1)
- reduce1 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- reduce2 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Mul
- log = f32[8] log(reduce1)
- abs = f32[8] abs(reduce1)
- neg = f32[8] negate(reduce2)
- ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs)
- }
- ENTRY main {
- a = f32[8,1024] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = (f32[8], f32[8], f32[8]) fusion(a, c), kind=kInput,
- calls=fused_computation
- })";
- TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK-DAG: pure_call @Add_add
- // CHECK-DAG: shuffle_reduce @Add_add
- // CHECK-DAG: pure_call @Mul_mul
- // CHECK-DAG: shuffle_reduce @Mul_mul
- // CHECK: allocate_shared
- // CHECK: allocate_shared
- // CHECK: sync_threads
- // CHECK-DAG: shuffle_reduce @Add_add
- // CHECK-DAG: shuffle_reduce @Mul_mul
- )"));
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReduceMOFGroups) {
- constexpr auto kHloString = R"(
- %add_f32 {
- %x = f32[] parameter(0)
- %y = f32[] parameter(1)
- ROOT %add = f32[] add(%x, %y)
- }
-
- %fused_computation {
- %param0 = f32[1024] parameter(0)
- %param1 = f32[1024] parameter(1)
- %constant0 = f32[] constant(0)
- %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
- %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
- ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2)
- }
-
- ENTRY %cluster {
- %param0 = f32[1024] parameter(0)
- %param1 = f32[1024] parameter(1)
- ROOT %fusion = (f32[], f32[])
- fusion(%param0, %param1), kind=kInput, calls=%fused_computation
- })";
- TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: scf.index_switch %block_id_y
- // CHECK: case 1 {
- // CHECK: default {
- )"));
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, F64RowReductionIndexing) {
- auto fusion = GetEmitter(kF64RowReduction);
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
- /*shape=*/{100, 128}));
- TF_EXPECT_OK(
- TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_),
- /*shape=*/{100}));
-}
-
-TEST_F(MlirRowReductionTest, F64RowReductionIr) {
- // This reduction is small enough not to require shared memory.
- TF_ASSERT_OK(EmitAndCheckIR(kF64RowReduction, R"(
- // CHECK-NOT: allocate_shared
- )"));
-}
-
-TEST_F(MlirRowReductionTest, F64RowReductionCorrectness) {
- EXPECT_TRUE(RunAndCompareNoHloPasses(kF64RowReduction, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorIndexing) {
- auto fusion = GetEmitter(kRowReductionMinorAndMajor);
-
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
- /*shape=*/{7, 100, 128}));
- TF_EXPECT_OK(
- TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_),
- /*shape=*/{100}));
-}
-
-TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorCorrectness) {
- EXPECT_TRUE(
- RunAndCompareNoHloPasses(kRowReductionMinorAndMajor, ErrorSpec{1e-3}));
-}
-
TEST_F(MlirMultiRowReductionTest, MultiRowReductionIndexing) {
auto fusion = GetEmitter(kMultiRowReductionX8);
@@ -379,207 +143,6 @@
EXPECT_TRUE(RunAndCompareNoHloPasses(kMultiRowReductionX8, ErrorSpec{1e-3}));
}
-TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
-
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[100,568] parameter(0)
- param_1 = f32[] parameter(1)
- ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- }
- ENTRY main {
- a = f32[100,568] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation
- })";
- TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0)>
- // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512)>
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
- // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]]
- // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
- // CHECK-NOT: scf.if
- // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 3]]
- // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]])
- // CHECK: scf.if
- // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 1], %thread_id_x in [0, 255])
- )"));
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirMultiRowReductionTest, NonTrivialEpilogueCorrectness) {
- constexpr auto kHloString = R"(
- HloModule module
- add {
- p0 = f64[] parameter(0)
- p1 = f64[] parameter(1)
- ROOT add = f64[] add(p0, p1)
- }
- fusion {
- %p0 = f64[4] parameter(0)
- %p1 = f64[4] parameter(1)
- %c0 = f64[] constant(-inf)
- %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add
- %bc0 = f64[4] broadcast(reduce0), dimensions={}
- %compare0 = pred[4] compare(p1, bc0), direction=EQ
- %c1 = f64[] constant(0)
- %bc1 = f64[4] broadcast(c1), dimensions={}
- %select.3.1 = f64[4] select(compare0, p0, bc1)
- %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add
- %convert0 = f64[4] convert(compare0)
- %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add
- ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2)
- }
- ENTRY main {
- %p0 = f64[4] parameter(0)
- %p1 = f64[4] parameter(1)
- ROOT %fusion = (f64[], f64[], f64[]) fusion(%p0, %p1), kind=kInput,
- calls=fusion
- })";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, SideOutputIndexing) {
- auto fusion = GetEmitter(kRowReductionSideOutput);
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
- {8, 2048}));
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {8}));
- TF_EXPECT_OK(
- TestBijection(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context_),
- {8, 2048})); // Side output.
-}
-
-TEST_F(MlirRowReductionTest, SideOutputIr) {
- TF_ASSERT_OK(EmitAndCheckIR(kRowReductionSideOutput, R"(
- // CHECK: @fused_computation
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp
- // CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]]
- )"));
-}
-
-TEST_F(MlirRowReductionTest, SideOutputCorrectness) {
- EXPECT_TRUE(
- RunAndCompareNoHloPasses(kRowReductionSideOutput, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, UnsignedSideOutputCorrectness) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
-
- Add {
- lhs = u32[] parameter(0)
- rhs = u32[] parameter(1)
- ROOT add = u32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = u32[8,2048] parameter(0)
- param_1 = u32[] parameter(1)
- add = u32[8,2048] add(param_0, param_0)
- reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add)
- }
- ENTRY main {
- a = u32[8,2048] parameter(0)
- c = u32[] constant(0)
- ROOT fusion = (u32[8], u32[8,2048]) fusion(a, c), kind=kInput,
- calls=fused_computation
- })";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, BroadcastSideOutputCorrectness) {
- constexpr auto kHloString = R"(
- %add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
- %fusion {
- %p0 = f32[6,6] parameter(0)
- %c0 = f32[] constant(0)
- %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
- %broadcast = f32[6,6] broadcast(%reduce), dimensions={}
- ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce)
- }
- ENTRY main {
- %p0 = f32[6,6] parameter(0)
- ROOT %fusion = (f32[6,6], f32[]) fusion(%p0), kind=kInput, calls=%fusion
- })";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, VariadicMOFCorrectness) {
- constexpr auto kHloString = R"(
- %reducer1 {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
- %reducer2 {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- p2 = f32[] parameter(2)
- p3 = f32[] parameter(3)
- add0 = f32[] add(p0, p2)
- add1 = f32[] add(p1, p3)
- ROOT tuple = (f32[], f32[]) tuple(add0, add1)
- }
- %fusion {
- %p0 = f32[6,6] parameter(0)
- %c0 = f32[] constant(0)
- %neg = f32[6,6] negate(%p0)
- %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1
- %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2
- ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg)
- }
- ENTRY main {
- %p0 = f32[6,6] parameter(0)
- ROOT %fusion = (f32[], (f32[], f32[]), f32[6,6]) fusion(%p0), kind=kInput, calls=%fusion
- })";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, OutputLayoutCorrectness) {
- constexpr std::string_view kHloString = R"(
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
-
- fusion {
- %input = f32[17,19,127] parameter(0)
- %c0 = f32[] constant(0)
- ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add
- }
-
- ENTRY entry {
- %input = f32[17,19,127] parameter(0)
- ROOT %fusion = f32[17,19]{0,1} fusion(%input), kind=kInput, calls=fusion
- })";
-
- auto fusion = GetEmitter(kHloString);
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
- {17, 19, 127}));
- TF_EXPECT_OK(TestBijection(
- *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {17, 19}));
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
TEST_F(MlirMultiRowReductionTest, TwoGroups) {
auto module = ParseAndReturnVerifiedModule(R"(
add {
@@ -604,7 +167,7 @@
.value();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirMultiRowReductionFusion fusion(analysis);
EXPECT_THAT(fusion.GetGroups().grouped_roots,
@@ -635,231 +198,12 @@
.value();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirMultiRowReductionFusion mlir_fusion(analysis);
EXPECT_THAT(mlir_fusion.GetGroups().grouped_roots, SizeIs(1));
}
-constexpr absl::string_view kColumnVectorizationTemplate = R"(
- add {
- b = $0[] parameter(1)
- a = $0[] parameter(0)
- ROOT out = $0[] add(a, b)
- }
- fusion {
- %p0 = $0[192,64,1536] parameter(0)
- %p1 = $0[] parameter(1)
- ROOT reduce = $0[192,1536] reduce(p0, p1), dimensions={1}, to_apply=add
- }
- ENTRY entry {
- %p0 = $0[192,64,1536] parameter(0)
- %p1 = $0[] parameter(1)
- ROOT %fusion = $0[192,1536] fusion(p0, p1), kind=kInput, calls=fusion
- })";
-
-TEST_F(MlirColumnReductionTest, ColumnReduction) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
-
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[13,1051,321] parameter(0)
- param_1 = f32[] parameter(1)
- ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- }
- ENTRY main {
- a = f32[13,1051,321] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = f32[13,321] fusion(a, c), kind=kInput, calls=fused_computation
- })";
-
- auto module = ParseAndReturnVerifiedModule(kHloString).value();
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
- MlirColumnReductionFusion fusion(analysis);
- EXPECT_THAT(
- fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1] -> (
- d3 floordiv 11,
- d0 floordiv 32 + s0 * 32,
- (d3 mod 11) * 32 + d0 mod 32
- )
- domain:
- d0 in [0, 1023]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 142]
- d4 in [0, 0]
- d5 in [0, 0]
- s0 in [0, 32]
- s1 in [0, 0]
- (d3 mod 11) * 32 + d0 mod 32 in [0, 320]
- d0 floordiv 32 + s0 * 32 in [0, 1050]
- )"));
- EXPECT_THAT(
- fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0] -> (
- d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32
- )
- domain:
- d0 in [0, 992]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 142]
- d4 in [0, 0]
- d5 in [0, 0]
- s0 in [0, 0]
- (d3 mod 11) * 32 + d0 floordiv 32 in [0, 320]
- d0 mod 32 in [0, 0]
- )"));
- TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: xla_gpu.pure_call @Add_add
- // CHECK: allocate_shared
- // CHECK: tensor.insert
- // CHECK: sync_threads
- // CHECK: predicated_extract
- // CHECK: shuffle_reduce
- // CHECK: predicated_insert
- )"));
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, SmallColumnReduction) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
-
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[3,128,4] parameter(0)
- param_1 = f32[] parameter(1)
- ROOT reduce = f32[3,4] reduce(param_0, param_1), dimensions={1}, to_apply=Add
- }
- ENTRY main {
- a = f32[3,128,4] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = f32[3,4] fusion(a, c), kind=kInput, calls=fused_computation
- })";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, MixedIndexing) {
- constexpr auto kHloString = R"(
- HloModule module
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
- fusion {
- %param_0 = f32[64,128] parameter(0)
- %constant_0 = f32[] constant(0)
- %reduce.1 = f32[128] reduce(f32[64,128] %param_0, f32[] %constant_0), dimensions={0}, to_apply=%add
- %neg = f32[64,128] negate(f32[64,128] %param_0)
- %bitcast = f32[8,8,128]{2,1,0} bitcast(f32[64,128] %neg)
- %reduce.2 = f32[128] reduce(f32[8,8,128]{2,1,0} %bitcast, f32[] %constant_0), dimensions={0,1}, to_apply=%add
- ROOT %tuple.12 = (f32[128], f32[128]) tuple(f32[128] %reduce.1, f32[128] %reduce.2)
- }
- ENTRY entry {
- %param_0 = f32[64,128] parameter(0)
- ROOT %fusion = (f32[128], f32[128]) fusion(%param_0), kind=kInput, calls=fusion
- })";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, ColumnReductionVectorizationCorrectness) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = f32[2048,16384] parameter(0)
- param_1 = f32[] parameter(1)
- ROOT reduce = f32[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add
- }
- ENTRY main {
- a = f32[2048,16384] parameter(0)
- c = f32[] constant(0)
- ROOT fusion = f32[16384] fusion(a, c), kind=kInput, calls=fused_computation
- })";
- TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: vector<2xf32>
- )"));
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, ColumnReductionVectorization_v4) {
- constexpr auto kHloString = R"(
- HloModule Test, is_scheduled=true
- Add {
- lhs = s16[] parameter(0)
- rhs = s16[] parameter(1)
- ROOT add = s16[] add(lhs, rhs)
- }
- fused_computation {
- param_0 = s16[2048,16384] parameter(0)
- param_1 = s16[] parameter(1)
- ROOT reduce = s16[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add
- }
- ENTRY main {
- a = s16[2048,16384] parameter(0)
- c = s16[] constant(0)
- ROOT fusion = s16[16384] fusion(a, c), kind=kInput, calls=fused_computation
- })";
- TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: vector<4xi16>
- )"));
- // We don't use RunAndCompareNoHloPasses because the interpreter is too slow
- // for this input.
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) {
- const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f32");
- auto fusion = GetEmitter(hlo_string);
- EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
- 0, 0, &mlir_context_)),
- ElementsAre(2 /* major reduced */, 2 /* vector size */));
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) {
- const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f16");
- auto fusion = GetEmitter(hlo_string);
- EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
- 0, 0, &mlir_context_)),
- ElementsAre(2 /* major reduced */, 4 /* vector size */));
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_f64) {
- // Verifies that we do not use the vectorized indexing for f64.
- const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f64");
- auto fusion = GetEmitter(hlo_string);
- EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
- 0, 0, &mlir_context_)),
- ElementsAre(2 /* major reduced */, 1 /* vector size */));
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) {
- // Verifies that we do not use the vectorized indexing for complex types.
- const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "c64");
- auto fusion = GetEmitter(hlo_string);
- EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
- 0, 0, &mlir_context_)),
- ElementsAre(2 /* major reduced */, 1 /* vector size */));
-}
-
TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) {
auto fusion = GetEmitter(kMultiRowReductionX2VectorX4);
@@ -883,61 +227,6 @@
RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3}));
}
-TEST_F(MlirRowReductionTest, LargeToUnit) {
- // Regression test for a bug where not all threads in the warp produced a
- // valid value for the final warp shuffle.
- constexpr auto kHloString = R"(
- and {
- p0 = pred[] parameter(0)
- p1 = pred[] parameter(1)
- ROOT and = pred[] and(p0, p1)
- }
-
- %fused_reduce {
- c1 = pred[] constant(true)
- p0 = pred[10000] broadcast(c1), dimensions={}
- ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and
- }
- )";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, MOFTwoVariadic) {
- // Regression test for a compilation crash with a MOF with two variadic
- // reductions.
- constexpr auto kHloString = R"(
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- p2 = f32[] parameter(2)
- p3 = f32[] parameter(3)
- a = f32[] add(p0, p2)
- b = f32[] add(p1, p3)
- ROOT out = (f32[], f32[]) tuple(a, b)
- }
-
- fused_reduce {
- p0 = f32[3,2] parameter(0)
- p1 = f32[3,2] parameter(1)
- c0 = f32[] constant(0)
- iota0 = f32[3,2] iota(), iota_dimension=1
- iota1 = f32[3,2] iota(), iota_dimension=1
- reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1},
- to_apply=add
- reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1},
- to_apply=add
- ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1)
- }
-
- ENTRY main {
- p0 = f32[3,2] parameter(0)
- p1 = f32[3,2] parameter(1)
- ROOT fusion = ((f32[3], f32[3]), (f32[3], f32[3])) fusion(p0, p1),
- kind=kInput, calls=fused_reduce
- })";
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
} // namespace
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc
deleted file mode 100644
index 81649a7..0000000
--- a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc
+++ /dev/null
@@ -1,176 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/reduction.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-using ::testing::SizeIs;
-
-class ReductionTest : public HloTestBase {
- protected:
- stream_executor::DeviceDescription device_info_ =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
- mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(ReductionTest, ThreadIndexingRowReduction) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
-
- fusion {
- %input = f32[100,64,512] parameter(0)
- %c0 = f32[] constant(0)
- ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add
- }
-
- ENTRY entry {
- %input = f32[100,64,512] parameter(0)
- ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
- ReductionFusion fusion(analysis);
-
- EXPECT_THAT(
- fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
- d3 floordiv 8,
- (d3 mod 8) * 8 + d0 floordiv 32,
- (d0 mod 32) * 2 + s2 * 64 + s3
- )
- domain:
- d0 in [0, 255]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 799]
- d4 in [0, 0]
- d5 in [0, 0]
- s0 in [0, 0]
- s1 in [0, 0]
- s2 in [0, 7]
- s3 in [0, 1]
- )"));
- EXPECT_THAT(
- fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5) -> (
- d3 floordiv 8,
- (d3 mod 8) * 8 + d0 floordiv 32
- )
- domain:
- d0 in [0, 224]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 799]
- d4 in [0, 0]
- d5 in [0, 0]
- d0 mod 32 in [0, 0]
- )"));
-}
-
-TEST_F(ReductionTest, TwoGroups) {
- auto module = ParseAndReturnVerifiedModule(R"(
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
- fusion {
- %p0 = f32[2] parameter(0)
- %p1 = f32[2] parameter(1)
- %c0 = f32[] constant(-inf)
- %r0 = f32[] reduce(%p0, %c0), dimensions={0}, to_apply=add
- %c1 = f32[] constant(inf)
- %r1 = f32[] reduce(%p1, %c1), dimensions={0}, to_apply=add
- ROOT %tuple = (f32[], f32[]) tuple(%r0, %r1)
- }
- ENTRY entry {
- %p0 = f32[2] parameter(0)
- %p1 = f32[2] parameter(1)
- ROOT %fusion = (f32[], f32[]) fusion(%p0, %p1), kind=kInput, calls=fusion
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
- ReductionFusion fusion(analysis);
-
- EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots,
- ElementsAre(ElementsAre(&analysis.fusion_root(0).instruction()),
- ElementsAre(&analysis.fusion_root(1).instruction())));
-}
-
-TEST_F(ReductionTest, OneGroup) {
- auto module = ParseAndReturnVerifiedModule(R"(
- %add {
- %p0 = c128[] parameter(0)
- %p1 = c128[] parameter(1)
- ROOT %add.35 = c128[] add(c128[] %p0, c128[] %p1)
- }
- %fusion {
- %p0 = c128[1,2] parameter(0)
- %c0 = c128[] constant((0, 0))
- %reduce = c128[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
- %real = f64[] real(c128[] %reduce)
- %imag = f64[] imag(c128[] %reduce)
- %negate = f64[] negate(f64[] %imag)
- ROOT %tuple.29 = (f64[], f64[]) tuple(f64[] %real, f64[] %negate)
- }
- ENTRY entry {
- %p0 = c128[1,2] parameter(0)
- ROOT %fusion = (f64[], f64[]) fusion(%p0), kind=kInput, calls=fusion
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
- ReductionFusion fusion(analysis);
-
- EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, SizeIs(2));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.cc b/third_party/xla/xla/service/gpu/fusions/scatter.cc
deleted file mode 100644
index 8f7f773..0000000
--- a/third_party/xla/xla/service/gpu/fusions/scatter.cc
+++ /dev/null
@@ -1,293 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/scatter.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Value.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/fusions/loop.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/ir_emitter_nested.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-ScatterFusion::ScatterFusion(const HloFusionAnalysis& analysis)
- : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {
- CHECK_EQ(analysis.fusion_root_count(), 1);
- CHECK_EQ(analysis.fusion_root(0).opcode(), HloOpcode::kScatter);
-}
-
-LaunchDimensions ScatterFusion::launch_dimensions() const {
- const auto& updates_shape =
- analysis_.fusion_root(0).instruction().operands().back()->shape();
- return CalculateLaunchDimensions(updates_shape, analysis_.device_info());
-}
-
-absl::Status ScatterFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const {
- GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
- // Spin up a new fused emitter for the scatter kernel and emit it.
- FusedIrEmitter scatter_fused_emitter(elemental_emitter);
- auto* fused_computation = fusion.fused_instructions_computation();
- for (int i = 0; i < fused_computation->num_parameters(); i++) {
- auto fused_operand = fused_computation->parameter_instruction(i);
- scatter_fused_emitter.BindGenerator(
- *fused_operand, [builder, &input = inputs[i],
- fused_operand](llvm_ir::IrArray::Index index) {
- return input.EmitReadArrayElement(index, builder,
- fused_operand->name());
- });
- }
-
- auto* root = fused_computation->root_instruction();
- const xla::ScatterDimensionNumbers& scatter_dims =
- Cast<HloScatterInstruction>(root)->scatter_dimension_numbers();
-
- std::string name = llvm_ir::IrName(root);
- const Shape& operand_shape = root->operand(0)->shape();
- const Shape& scatter_indices_shape = root->operand(1)->shape();
- const Shape& updates_shape = root->operand(2)->shape();
- const HloComputation& update_computation = *root->called_computations()[0];
-
- TF_ASSIGN_OR_RETURN(auto scatter_indices_gen,
- scatter_fused_emitter.GetGenerator(*root->operand(1)));
- TF_ASSIGN_OR_RETURN(auto updates_gen,
- scatter_fused_emitter.GetGenerator(*root->operand(2)));
-
- auto loop_body_emitter =
- [&](const llvm_ir::IrArray::Index& index) -> absl::Status {
- std::vector<llvm::Value*> raw_window_multidim;
- std::vector<llvm::Value*> input_scatter_multidim;
- std::vector<int64_t> raw_window_bounds;
-
- auto get_i64_array = [](absl::Span<const int64_t> container) {
- return llvm::ArrayRef<int64_t>{container.data(),
- static_cast<size_t>(container.size())};
- };
-
- llvm::ArrayRef<int64_t> update_window_dims =
- get_i64_array(scatter_dims.update_window_dims());
- // Partition the index into window indices and scatter indices.
- for (int64_t i = 0, e = index.size(); i != e; ++i) {
- // For window indices, also remember the window size; this comes in handy
- // later.
- if (llvm::is_contained(update_window_dims, i)) {
- raw_window_multidim.push_back(index[i]);
- raw_window_bounds.push_back(updates_shape.dimensions(i));
- } else {
- input_scatter_multidim.push_back(index[i]);
- }
- }
- DCHECK_EQ(raw_window_multidim.size(),
- scatter_dims.update_window_dims_size());
-
- // Apply inserted_window_dims to the window dimensions.
- int64_t raw_window_multidim_idx = 0;
- llvm::SmallVector<llvm::Value*> input_window_multidim;
- llvm::SmallVector<int64_t> input_window_bounds;
- const int64_t rank = operand_shape.rank();
- input_window_bounds.reserve(rank);
- input_window_multidim.reserve(rank);
-
- llvm::ArrayRef<int64_t> inserted_window_dims =
- get_i64_array(scatter_dims.inserted_window_dims());
- for (int64_t i = 0; i != rank; ++i) {
- if (llvm::is_contained(inserted_window_dims, i)) {
- input_window_bounds.push_back(1); // Trivial dimension.
- input_window_multidim.push_back(index.GetConstantWithIndexType(0));
- } else {
- input_window_bounds.push_back(
- raw_window_bounds[raw_window_multidim_idx]);
- input_window_multidim.push_back(
- raw_window_multidim[raw_window_multidim_idx]);
- ++raw_window_multidim_idx;
- }
- }
- DCHECK_EQ(input_window_multidim.size(), operand_shape.rank());
-
- // Insert a size-1 dimension at the end if index_vector_dim requests one.
- Shape scatter_indices_shape_fixed = scatter_indices_shape;
- if (scatter_dims.index_vector_dim() == scatter_indices_shape.rank()) {
- scatter_indices_shape_fixed.add_dimensions(1);
- scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
- scatter_dims.index_vector_dim());
- }
-
- // Now load the indices corresponding to the current window from
- // scatter_indices.
- std::vector<llvm::Value*> raw_scatter_index_multidim =
- input_scatter_multidim;
- raw_scatter_index_multidim.insert(
- raw_scatter_index_multidim.begin() + scatter_dims.index_vector_dim(),
- nullptr);
-
- llvm::ArrayRef<int64_t> scatter_dims_to_operand_dims =
- get_i64_array(scatter_dims.scatter_dims_to_operand_dims());
- llvm::Value* is_in_bounds = builder->getTrue();
- for (int64_t i = 0, e = scatter_dims_to_operand_dims.size(); i != e; ++i) {
- // Our index is stored along index_vector_dim; insert that into the lookup
- // index into scatter_indices.
- raw_scatter_index_multidim[scatter_dims.index_vector_dim()] =
- index.GetConstantWithIndexType(i);
- llvm_ir::IrArray::Index raw_scatter_index_index(
- raw_scatter_index_multidim, scatter_indices_shape_fixed,
- index.GetType());
-
- int64_t operand_dim = scatter_dims_to_operand_dims[i];
- if (operand_dim > rank) {
- return absl::OutOfRangeError(
- "The provided scatter_dims_to_operand_dims was out of range.");
- }
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const loaded_scatter_index,
- scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
- scatter_indices_shape_fixed, scatter_indices_shape, builder)));
- // And add the index to our window index. This yields the output index.
- llvm::Value* casted_scatter_index = builder->CreateIntCast(
- loaded_scatter_index, index.GetType(),
- /*isSigned=*/ShapeUtil::ElementIsSigned(scatter_indices_shape));
- llvm::Value* dim_offset = builder->CreateAdd(
- input_window_multidim[operand_dim], casted_scatter_index);
- input_window_multidim[operand_dim] = dim_offset;
-
- // Also do the bounds check now.
- int64_t max_index = operand_shape.dimensions(operand_dim) -
- input_window_bounds[operand_dim] + 1;
- // is_in_bounds = index >= 0 && index < dim_size-window_size+1
- // --> index u< dim_size-window_size+1
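- // (The unsigned compare covers both conditions at once: a negative index
- // wraps around to a large unsigned value and therefore also fails "u<".)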
- is_in_bounds = builder->CreateAnd(
- is_in_bounds,
- builder->CreateICmpULT(casted_scatter_index,
- index.GetConstantWithIndexType(max_index)));
- }
-
- llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
- is_in_bounds, "scatter.in_bounds", builder, /*emit_else=*/false);
- llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block,
- builder);
- // All done; now just read the input value at the calculated window index and
- // do an atomic store to the calculated location in the output.
- llvm_ir::IrArray::Index input_window_index(
- input_window_multidim, outputs.back().GetShape(), index.GetType());
- llvm::Value* output_address =
- outputs.back().EmitArrayElementAddress(input_window_index, builder);
- llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(updates_shape.element_type(),
- ir_emitter_context.llvm_module()),
- "input_address", builder);
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
- builder->CreateStore(input_ir_value, input_address);
-
- if (root->unique_indices()) {
- return CallNestedComputation(
- builder, ir_emitter_context, update_computation,
- {output_address, input_address}, output_address);
- }
- return EmitAtomicOperationForNestedComputation(
- builder, ir_emitter_context, update_computation, output_address,
- input_address, outputs.back().GetElementLlvmType());
- };
-
- // Launch a kernel that reads every element in the updates tensor. We could
- // also do one kernel per window instead if bounds checks turn out to be a
- // bottleneck.
- auto index_type =
- GetIndexTypeForKernel(root, launch_dims.launch_bound(), builder);
- return ParallelLoopEmitter(loop_body_emitter, updates_shape, launch_dims,
- builder)
- .EmitLoop(name, index_type);
-}
-
-std::optional<IndexingMap> ScatterFusion::ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const {
- const auto* scatter =
- DynCast<HloScatterInstruction>(&analysis_.fusion_hero(0).instruction());
- int64_t scatter_operand_count = scatter->scatter_operand_count();
- // Scatter operands are packed in the following way:
- // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`.
- // Operand ID scatter_operand_count for `scatter indices`.
- // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for
- // `scatter updates`.
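- // For example, with scatter_operand_count == 2: operand IDs 0 and 1 are the
- // scatter operands, ID 2 is the scatter indices, and IDs 3 and 4 are the
- // scatter updates.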
-
- // For scatter operands we do not know the thread ID indexing.
- if (hero_operand_index < scatter_operand_count) {
- return std::nullopt;
- }
- // Compute thread id mapping based on the first update operand.
- Shape scatter_update_shape = scatter->scatter_updates().front()->shape();
- IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap(
- launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx);
-
- // For scatter indices we project the indexing for scatter updates and take
- // only the first result of the affine map, because the first update dimension
- // coincides with the scatter_indices rows.
- if (hero_operand_index == scatter_operand_count) {
- Shape scatter_indices_shape = scatter->scatter_indices()->shape();
- CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString();
- // Create a map from scatter update to scatter indices.
- IndexingMap updates_to_indices_map{
- mlir::AffineMap::get(
- /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1,
- {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)},
- ctx),
- DimVarsFromTensorSizes(scatter_update_shape.dimensions()),
- RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}),
- /*rt_vars=*/{}};
- auto scatter_indices_map = scatter_update_map * updates_to_indices_map;
- scatter_indices_map.Simplify();
- return scatter_indices_map;
- }
- return scatter_update_map;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.h b/third_party/xla/xla/service/gpu/fusions/scatter.h
deleted file mode 100644
index dda11c0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/scatter.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_SCATTER_H_
-#define XLA_SERVICE_GPU_FUSIONS_SCATTER_H_
-
-#include <optional>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-
-// A scatter, implemented as a loop over the updates. All scatters are in-place.
-class ScatterFusion : public KernelFusionEmitterBase {
- public:
- explicit ScatterFusion(const HloFusionAnalysis& analysis);
-
- LaunchDimensions launch_dimensions() const override;
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const override {
- // The kernel iterates over updates, whose correspondence to output
- // elements cannot be computed statically.
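- // (The scatter indices are runtime values, so the same output element may
- // receive contributions from any number of update elements.)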
- return std::nullopt;
- }
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const override;
-
- protected:
- absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const override;
-
- private:
- const HloFusionAnalysis& analysis_;
- LaunchDimensionsConfig config_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_SCATTER_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
index a281c0e..85e1e50 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
@@ -39,9 +39,10 @@
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/primitive_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/indexing_map.h"
@@ -62,7 +63,6 @@
using mlir::Value;
using mlir::ValueRange;
using mlir::func::ReturnOp;
-using mlir::tensor::InsertOp;
using mlir_converter::CallTargetProvider;
using mlir_converter::PartitionedComputations;
using mlir_converter::ProvideParameter;
@@ -174,7 +174,8 @@
auto reduced_val = mlir_converter::InlineBlock(
b, reducer.getBody().front(), {operand_elem, update_elem})[0];
- return b.create<InsertOp>(reduced_val, output_tensor, indices);
+ return b.create<mlir::tensor::InsertOp>(reduced_val, output_tensor,
+ indices);
}
auto atomic_rmw = b.create<AtomicRMWOp>(output_tensor, indices);
mlir::OpBuilder body_builder = atomic_rmw.getBodyBuilder();
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h
index de9743a..3efaa0a 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h
@@ -25,7 +25,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/loop.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
index 6b8d013..2e9a11a 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
@@ -77,7 +77,7 @@
thread_id_printer_.SetSymbolName(2, "index_id");
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirScatterFusion fusion(analysis);
constexpr auto kUpdatesIndexing = R"(
@@ -187,8 +187,8 @@
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
- // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 2)>
- // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 2)>
+ // CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 2)
+ // CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2)
// CHECK-LABEL: func.func @fused_computation(
// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32>
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_test.cc
deleted file mode 100644
index 284d308..0000000
--- a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc
+++ /dev/null
@@ -1,224 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/scatter.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class ScatterFusionTest : public HloTestBase {
- public:
- void SetUp() override {
- HloTestBase::SetUp();
- printer_ =
- AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
- {"chunk_id", "unroll_id", "index_id"});
- }
- DebugOptions GetDebugOptionsForTest() override {
- auto opts = HloTestBase::GetDebugOptionsForTest();
- opts.set_xla_gpu_mlir_emitter_level(0);
- return opts;
- }
-
- protected:
- AffineMapPrinter printer_;
- mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(ScatterFusionTest, ScatterFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- add (lhs: f32[], rhs: f32[]) -> f32[] {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT sum = f32[] add(lhs, rhs)
- }
-
- fused_computation {
- %input = f32[2,9] parameter(0)
- %indices = s32[3] parameter(1)
- %updates = f32[3,9] parameter(2)
- ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
- to_apply=add,
- update_window_dims={1},
- inserted_window_dims={0},
- scatter_dims_to_operand_dims={0},
- index_vector_dim=1
- }
-
- ENTRY entry {
- %input = f32[2,9] parameter(0)
- %indices = s32[3] parameter(1)
- %updates = f32[3,9] parameter(2)
- ROOT %fusion = f32[2,9] fusion(%input, %indices, %updates), kind=kLoop, calls=fused_computation
- })")
- .value();
-
- stream_executor::DeviceDescription device_info =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis_fused = AnalyzeFusion(*root, device_info);
-
- auto emitter =
- GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
- auto scatter_fusion = dynamic_cast<ScatterFusion*>(emitter.get());
- ASSERT_NE(scatter_fusion, nullptr);
- EXPECT_EQ(scatter_fusion->launch_dimensions().launch_bound(),
- 3 * 9 /* updates size */);
-}
-
-TEST_F(ScatterFusionTest, ThreadIdIndexing) {
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- computation {
- %p0 = f32[] parameter(0)
- %p1 = f32[] parameter(1)
- %p2 = f32[] parameter(2)
- %p3 = f32[] parameter(3)
- ROOT %tuple = (f32[], f32[]) tuple(f32[] %p2, f32[] %p3)
- }
- scatter {
- %operand0 = f32[300,200] parameter(0)
- %operand1 = f32[300,200] parameter(1)
- %indices = s32[42,1] parameter(2)
- %update.1 = f32[42,10,20] parameter(3)
- %update.2 = f32[42,10,20] parameter(4)
-
- ROOT %scatter = (f32[300,200], f32[300,200]) scatter(
- f32[300,200] %operand0,
- f32[300,200] %operand1,
- s32[42,1] %indices,
- f32[42,10,20] %update.1,
- f32[42,10,20] %update.2
- ),
- update_window_dims={1,2},
- inserted_window_dims={},
- scatter_dims_to_operand_dims={0},
- index_vector_dim=1,
- to_apply=computation
- }
- ENTRY entry {
- %operand0 = f32[300,200] parameter(0)
- %operand1 = f32[300,200] parameter(1)
- %indices = s32[42,1] parameter(2)
- %update.1 = f32[42,10,20] parameter(3)
- %update.2 = f32[42,10,20] parameter(4)
- ROOT %fusion = (f32[300,200], f32[300,200]) fusion(
- %operand0, %operand1, %indices, %update.1, %update.2),
- kind=kLoop, calls=scatter
- }
- )"));
- stream_executor::DeviceDescription device_info =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis_fused = AnalyzeFusion(*root, device_info);
-
- auto emitter =
- GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
- auto fusion = dynamic_cast<ScatterFusion*>(emitter.get());
- ASSERT_NE(fusion, nullptr);
-
- constexpr auto kUpdatesIndexing = R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
- (bl_x * 128 + th_x) floordiv 200,
- ((bl_x * 128 + th_x) floordiv 20) mod 10,
- (bl_x * 128 + th_x) mod 20
- )
- domain:
- th_x in [0, 127]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 65]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- bl_x * 128 + th_x in [0, 8399]
- )";
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kUpdatesIndexing));
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kUpdatesIndexing));
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kUpdatesIndexing));
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kUpdatesIndexing));
-
- constexpr auto kIndicesIndexing = R"(
- (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] ->
- ((bl_x * 128 + th_x) floordiv 200, 0)
- domain:
- th_x in [0, 127]
- th_y in [0, 0]
- th_z in [0, 0]
- bl_x in [0, 65]
- bl_y in [0, 0]
- bl_z in [0, 0]
- chunk_id in [0, 0]
- unroll_id in [0, 0]
- index_id in [0, 0]
- bl_x * 128 + th_x in [0, 8399]
- )";
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kIndicesIndexing));
- EXPECT_THAT(
- fusion
- ->ComputeThreadIdToInputIndexing(
- /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_)
- ->ToString(printer_),
- MatchIndexingString(kIndicesIndexing));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/tests/BUILD
new file mode 100644
index 0000000..d3e3b66
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/BUILD
@@ -0,0 +1,19 @@
+load("//xla:lit.bzl", "lit_test_suite")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ licenses = ["notice"],
+)
+
+lit_test_suite(
+ name = "tests",
+ srcs = glob(["**/*.hlo"]),
+ cfg = "//xla:lit.cfg.py",
+ default_tags = ["requires-gpu-sm80-only"],
+ tools = [
+ "//xla/service/gpu/fusions/tools:fusion_to_mlir",
+ "//xla/service/gpu/fusions/tools:mlir_fusions_opt",
+ "//xla/service/gpu/fusions/tools:test_correctness",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo
new file mode 100644
index 0000000..1646ade
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo
@@ -0,0 +1,19 @@
+// RUN: test_correctness %s --bijection_inputs=reduce.1:0 --bijection_inputs=reduce.2:0 --bijection_outputs=reduce.1 --bijection_outputs=reduce.2
+
+add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+}
+
+fusion {
+ %param_0 = f32[64,128] parameter(0)
+ %constant_0 = f32[] constant(0)
+ %reduce.1 = f32[128] reduce(param_0, constant_0), dimensions={0},
+ to_apply=%add
+ %neg = f32[64,128] negate(param_0)
+ %bitcast = f32[8,8,128] bitcast(neg)
+ %reduce.2 = f32[128] reduce(bitcast, constant_0), dimensions={0,1},
+ to_apply=%add
+ ROOT %tuple = (f32[128], f32[128]) tuple(reduce.1, reduce.2)
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo
new file mode 100644
index 0000000..e7ae070
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo
@@ -0,0 +1,22 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f32[13,1051,321] parameter(0)
+ param_1 = f32[] parameter(1)
+ ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=add
+}
+
+// CHECK: xla_gpu.pure_call @add_add
+// CHECK: allocate_shared
+// CHECK: tensor.insert
+// CHECK: sync_threads
+// CHECK: predicated_extract
+// CHECK: shuffle_reduce
+// CHECK: predicated_insert
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo
new file mode 100644
index 0000000..958b391
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo
@@ -0,0 +1,13 @@
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f32[3,128,4] parameter(0)
+ c0 = f32[] constant(0)
+ ROOT reduce = f32[3,4] reduce(param_0, c0), dimensions={1}, to_apply=add
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo
new file mode 100644
index 0000000..a2a22363
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = c64[] parameter(0)
+ rhs = c64[] parameter(1)
+ ROOT add = c64[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = c64[128,64] parameter(0)
+ c0 = c64[] constant((0, 0))
+ ROOT reduce = c64[64] reduce(param_0, c0), dimensions={0},
+ to_apply=add
+}
+
+// CHECK-NOT: vector<
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo
new file mode 100644
index 0000000..660664b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f64[] parameter(0)
+ rhs = f64[] parameter(1)
+ ROOT add = f64[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f64[128,64] parameter(0)
+ c0 = f64[] constant(0)
+ ROOT reduce = f64[64] reduce(param_0, c0), dimensions={0},
+ to_apply=add
+}
+
+// CHECK-NOT: vector<
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo
new file mode 100644
index 0000000..a142ad4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f32[2048,64] parameter(0)
+ c0 = f32[] constant(0)
+ ROOT reduce = f32[64] reduce(param_0, c0), dimensions={0},
+ to_apply=add
+}
+
+// CHECK: vector<2xf32>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo
new file mode 100644
index 0000000..81da088
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = s16[] parameter(0)
+ rhs = s16[] parameter(1)
+ ROOT add = s16[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = s16[256,128] parameter(0)
+ c0 = s16[] constant(0)
+ ROOT reduce = s16[128] reduce(param_0, c0), dimensions={0},
+ to_apply=add
+}
+
+// CHECK: vector<4xi16>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo
new file mode 100644
index 0000000..f8a9e86
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo
@@ -0,0 +1,22 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f32[8,2048] parameter(0)
+ param_1 = f32[] parameter(1)
+ reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add
+ ROOT log = f32[8] log(reduce)
+}
+
+// CHECK: shuffle_reduce
+// CHECK: allocate_shared
+// CHECK: sync_threads
+// CHECK: shuffle_reduce
+// CHECK-NEXT: %[[OUT:.*]] = math.log
+// CHECK: predicated_insert %[[OUT]]
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo
new file mode 100644
index 0000000..bc84174
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo
@@ -0,0 +1,43 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2
+
+add {
+ p0 = f64[] parameter(0)
+ p1 = f64[] parameter(1)
+ ROOT add = f64[] add(p0, p1)
+}
+
+// This fusion is valid, but we can't efficiently codegen it.
+fusion {
+ %p0 = f64[4] parameter(0)
+ %p1 = f64[4] parameter(1)
+ %c0 = f64[] constant(-inf)
+ %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add
+ %bc0 = f64[4] broadcast(reduce0), dimensions={}
+ %compare0 = pred[4] compare(p1, bc0), direction=EQ
+ %c1 = f64[] constant(0)
+ %bc1 = f64[4] broadcast(c1), dimensions={}
+ %select.3.1 = f64[4] select(compare0, p0, bc1)
+ %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add
+ %convert0 = f64[4] convert(compare0)
+ %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add
+ ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2)
+}
+
+// We read all of %p1 once from each thread, and then read one element again.
+// CHECK: func.func @main
+// CHECK-SAME: , %[[P1:.*]]: tensor<4xf64> {xla.slice_index = 1 : index}
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST0:.*]] = arith.constant 0xFFF0000000000000
+// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
+
+// reduce0 in the context of reduce2 and reduce1's prologue:
+// CHECK: scf.for %[[I:.*]] = %[[C0]]
+// CHECK-NEXT: tensor.extract %[[P1]][%[[I]]]
+// CHECK-NEXT: addf
+// CHECK-NEXT: yield
+
+// reduce0 again, in the context of its status as a fusion hero:
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[P1]][%[[TID_X]]]
+// CHECK: %[[ADDED:.*]] = arith.addf %[[CST0]], %[[EXTRACTED]]
+// CHECK: shuffle_reduce @add_add(%[[ADDED]]) to 2
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo
new file mode 100644
index 0000000..ee155c8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo
@@ -0,0 +1,15 @@
+// Regression test for a bug where not all threads in the warp produced a valid
+// value for the final warp shuffle.
+// RUN: test_correctness %s
+
+and {
+ p0 = pred[] parameter(0)
+ p1 = pred[] parameter(1)
+ ROOT and = pred[] and(p0, p1)
+}
+
+fused_reduce {
+ c1 = pred[] constant(true)
+ p0 = pred[10000] broadcast(c1), dimensions={}
+ ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo
new file mode 100644
index 0000000..102e32b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+}
+
+fusion {
+ %input = f32[17,19,127] parameter(0)
+ %c0 = f32[] constant(0)
+ // The output is physically transposed.
+ ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add
+}
+
+// CHECK: xla_gpu.predicated_insert {{.*}} : tensor<17x19xf32, dense<[0, 1]> : tensor<2xi64>>
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo
new file mode 100644
index 0000000..c9481f3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo
@@ -0,0 +1,20 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f32[7,100,128] parameter(0)
+ param_1 = f32[] parameter(1)
+ ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=add
+}
+
+// Our codegen doesn't support parallelizing the major reduction dimension. In
+// principle, this could be done via shared memory.
+// CHECK-NOT: allocate_shared
+// CHECK: shuffle_reduce
+// CHECK-NOT: allocate_shared
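+// Concretely, for f32[7,100,128] with dimensions={0,2}: the minor reduced
+// dimension (128) is combined via the shuffle_reduce checked above, while the
+// major reduced dimension (7) is presumably walked sequentially by each thread
+// rather than staged through shared memory.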
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo
new file mode 100644
index 0000000..315d604
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo
@@ -0,0 +1,40 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+mul {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f32[8,1024] parameter(0)
+ c0 = f32[] constant(0)
+ c1 = f32[] constant(1)
+ reduce1 = f32[8] reduce(param_0, c0), dimensions={1}, to_apply=add
+ reduce2 = f32[8] reduce(param_0, c1), dimensions={1}, to_apply=mul
+ log = f32[8] log(reduce1)
+ abs = f32[8] abs(reduce1)
+ neg = f32[8] negate(reduce2)
+ ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs)
+}
+
+// CHECK-DAG: shuffle_reduce @add_add
+// CHECK-DAG: shuffle_reduce @mul_mul
+// CHECK: allocate_shared
+// CHECK: allocate_shared
+// CHECK: sync_threads
+// CHECK-DAG: %[[ADDED:.*]] = xla_gpu.shuffle_reduce @add_add
+// CHECK-DAG: %[[MULTIPLIED:.*]] = xla_gpu.shuffle_reduce @mul_mul
+// CHECK-DAG: %[[LOG:.*]] = math.log %[[ADDED]]
+// CHECK-DAG: %[[ABS:.*]] = math.absf %[[ADDED]]
+// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[MULTIPLIED]]
+// CHECK-DAG: xla_gpu.predicated_insert %[[LOG]]
+// CHECK-DAG: xla_gpu.predicated_insert %[[ABS]]
+// CHECK-DAG: xla_gpu.predicated_insert %[[NEG]]
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo
new file mode 100644
index 0000000..48a2033
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo
@@ -0,0 +1,26 @@
+// RUN: test_correctness %s
+
+%reducer1 {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+}
+
+%reducer2 {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ p2 = f32[] parameter(2)
+ p3 = f32[] parameter(3)
+ add0 = f32[] add(p0, p2)
+ add1 = f32[] add(p1, p3)
+ ROOT tuple = (f32[], f32[]) tuple(add0, add1)
+}
+
+%fusion {
+ %p0 = f32[6,6] parameter(0)
+ %c0 = f32[] constant(0)
+ %neg = f32[6,6] negate(%p0)
+ %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1
+ %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2
+ ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg)
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo
new file mode 100644
index 0000000..6d47fc6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo
@@ -0,0 +1,26 @@
+// Regression test for a compilation crash with a MOF with two variadic
+// reductions.
+// RUN: test_correctness %s
+
+add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ p2 = f32[] parameter(2)
+ p3 = f32[] parameter(3)
+ a = f32[] add(p0, p2)
+ b = f32[] add(p1, p3)
+ ROOT out = (f32[], f32[]) tuple(a, b)
+}
+
+fused_reduce {
+ p0 = f32[3,2] parameter(0)
+ p1 = f32[3,2] parameter(1)
+ c0 = f32[] constant(0)
+ iota0 = f32[3,2] iota(), iota_dimension=1
+ iota1 = f32[3,2] iota(), iota_dimension=1
+ reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1},
+ to_apply=add
+ reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1},
+ to_apply=add
+ ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1)
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo
new file mode 100644
index 0000000..30202d0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo
@@ -0,0 +1,31 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+fused_computation {
+ param_0 = f32[100,568] parameter(0)
+ param_1 = f32[] parameter(1)
+ ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=add
+}
+
+// CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 3]>
+// CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512), domain: d0 in [0, 1], d1 in [0, 255]>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+
+// The full loop without bounds checks:
+// CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-NOT: scf.if
+// CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]], %thread_id_x)[%[[I]]]
+
+// The tail loop:
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]])
+// CHECK: scf.if
+// CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]], %thread_id_x)
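+// Arithmetic behind the split: 568 = 4 * 128 + 56. The 64 in-row lanes
+// (d1 mod 64) each cover 2 adjacent columns (d0), so a full tile is 128
+// columns wide; s0 in [0, 3] walks the four full tiles, and the guarded tail
+// loop (MAP2, offset 512) covers the remaining 56 valid columns.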
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo
new file mode 100644
index 0000000..a7e6415
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo
@@ -0,0 +1,22 @@
+// RUN: fusion_to_mlir %s | FileCheck %s
+// RUN: test_correctness %s
+
+%add_f32 {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+
+%fused_computation {
+ %param0 = f32[1024] parameter(0)
+ %param1 = f32[1024] parameter(1)
+ %constant0 = f32[] constant(0)
+ %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
+ %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
+ ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2)
+}
+
+// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK: scf.index_switch %[[BLOCK_ID_Y]]
+// CHECK: case 1 {
+// CHECK: default {
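+// Each independent reduction group gets its own block_id.y value, so reduce1
+// and reduce2 are emitted in separate cases of the index_switch checked above.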
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo
new file mode 100644
index 0000000..e950e3c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo
@@ -0,0 +1,24 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce --bijection_outputs=exp
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f32[8,2048] parameter(0)
+ param_1 = f32[] parameter(1)
+ exp = f32[8,2048] exponential(param_0)
+ reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add
+ ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp)
+}
+
+// CHECK: @fused_computation
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp
+// CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]]
+// CHECK: scf.yield
+// CHECK: scf.yield
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo
new file mode 100644
index 0000000..0db1901
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo
@@ -0,0 +1,15 @@
+// RUN: test_correctness %s
+
+%add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+}
+
+%fusion {
+ %p0 = f32[6,6] parameter(0)
+ %c0 = f32[] constant(0)
+ %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
+ %broadcast = f32[6,6] broadcast(%reduce), dimensions={}
+ ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce)
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo
new file mode 100644
index 0000000..5371b80
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo
@@ -0,0 +1,15 @@
+// RUN: test_correctness %s
+
+add {
+ lhs = u32[] parameter(0)
+ rhs = u32[] parameter(1)
+ ROOT add = u32[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = u32[8,2048] parameter(0)
+ param_1 = u32[] parameter(1)
+ add = u32[8,2048] add(param_0, param_0)
+ reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add
+ ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add)
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo
new file mode 100644
index 0000000..56e3266
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+ lhs = f64[] parameter(0)
+ rhs = f64[] parameter(1)
+ ROOT add = f64[] add(lhs, rhs)
+}
+
+fused_computation {
+ param_0 = f64[100,128] parameter(0)
+ param_1 = f64[] parameter(1)
+ ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=add
+}
+
+// This reduction is small enough to not require any shared memory.
+// CHECK-NOT: allocate_shared
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo
new file mode 100644
index 0000000..b28bff4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo
@@ -0,0 +1,23 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0,1 --bijection_outputs=reduce
+
+add {
+ scalar_lhs.0 = f32[] parameter(0)
+ scalar_rhs.0 = f32[] parameter(1)
+ scalar_lhs.1 = f32[] parameter(2)
+ scalar_rhs.1 = f32[] parameter(3)
+ add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1)
+ add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1)
+ ROOT t = (f32[], f32[]) tuple(add.0, add.1)
+}
+
+fused_computation {
+ param_0 = f32[2, 3, 2048] parameter(0)
+ param_1 = f32[2, 3, 2048] parameter(1)
+ c0 = f32[] constant(0)
+ ROOT reduce = (f32[2, 3], f32[2, 3])
+ reduce(param_0, param_1, c0, c0), dimensions={2}, to_apply=add
+}
+
+// CHECK: allocate_shared
+// CHECK: allocate_shared
diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/tiling_util.cc
deleted file mode 100644
index 9ad085f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc
+++ /dev/null
@@ -1,259 +0,0 @@
-/*Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/tiling_util.h"
-
-#include <cstdint>
-#include <limits>
-#include <string>
-#include <vector>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Instructions.h"
-#include "llvm/IR/Value.h"
-#include "llvm/Support/Casting.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/target_util.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/kernel_support_library.h"
-#include "xla/service/llvm_ir/llvm_loop.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling,
- int dim, absl::InlinedVector<llvm::Value*, 4> tile_idx,
- absl::Span<llvm::Value* const> tile_dimensions,
- llvm::IRBuilder<>* b, const TileElementGenerator& emit_elem) {
- llvm::Type* index_ty = thread_id_info.thread_id->getType();
- auto constant = [&](int64_t val) {
- return llvm::ConstantInt::get(index_ty, val);
- };
-
- auto recurse = [&] {
- if (dim == tile_idx.size() - 1) {
- emit_elem(tile_idx);
- } else {
- EmitTileRec(thread_id_info, tiling, dim + 1, tile_idx, tile_dimensions, b,
- emit_elem);
- }
- };
-
- bool unroll = tiling.GetLoopsToUnroll()[dim];
- KernelSupportLibrary ksl(b, unroll ? llvm_ir::UnrollMode::kFullyUnroll
- : llvm_ir::UnrollMode::kDefaultUnroll);
-
- if (tiling.GetBlockTileSize()[dim] == 1) {
- tile_idx[dim] = constant(0);
- recurse();
- } else if (unroll) {
- // TODO(jreiffers): Check if this unrolling does anything useful.
- int64_t stride = tiling.GetThreadsPerBlock()[dim];
- int64_t dim_size = tiling.GetThreadTileSize()[dim];
-
- auto make_loop = [&](bool emit_bounds_checks) {
- auto body = [&, emit_bounds_checks](llvm::Value* i) {
- tile_idx[dim] = b->CreateAdd(i, thread_id_info.thread_ids[dim]);
- if (emit_bounds_checks) {
- auto* in_bounds =
- b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]);
- ksl.If("x_in_tile", in_bounds, recurse);
- } else {
- recurse();
- }
- };
- return [&, body] {
- ksl.For(absl::StrCat("loop", dim), constant(0),
- constant(dim_size * stride), constant(stride), body);
- };
- };
- if (stride > 1 && dim_size > 1) {
- // Most tiles will be full, so we emit a single bounds check for those.
- auto* is_full_tile = b->CreateICmpEQ(
- constant(tiling.GetBlockTileSize()[dim]), tile_dimensions[dim]);
- ksl.If("is_full_tile", is_full_tile, make_loop(false), make_loop(true));
- } else {
- make_loop(true)();
- }
- } else {
- // All dimensions are strided (thread 0 processes elements 0, num_threads,
- // num_threads+2, ...; thread 1 processes elements 1, num_threads + 1 and so
- // on).
- ksl.For(absl::StrCat("loop", dim), /*start=*/thread_id_info.thread_ids[dim],
- /*end=*/tile_dimensions[dim],
- /*step=*/tiling.GetThreadsPerBlock()[dim], [&](llvm::Value* i) {
- tile_idx[dim] = i;
- recurse();
- });
- }
-}
-
-} // namespace
-
-void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
- const TilingThreadIdInfo& thread_id_info,
- absl::Span<llvm::Value* const> tile_dimensions,
- const TileElementGenerator& emit_elem_function) {
- absl::InlinedVector<llvm::Value*, 4> tile_idx(tiling.GetShape().size());
- EmitTileRec(thread_id_info, tiling, 0, tile_idx, tile_dimensions, builder,
- emit_elem_function);
-}
-
-namespace {
-
-// Emits current block id.
-llvm::Value* EmitBlockId(llvm::IRBuilder<>* builder, int32_t num_blocks,
- llvm::Type* index_ty) {
- llvm::Value* block_id =
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, builder);
- if (num_blocks != 0) {
- llvm_ir::AddRangeMetadata(0, num_blocks,
- llvm::cast<llvm::Instruction>(block_id),
- builder->GetInsertBlock()->getModule());
- }
- auto ret = builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true);
- ret->setName("block.id.x");
- return ret;
-}
-
-// Emits current thread id with the given type.
-//
-// Sets the return value range to [0, threads_per_block).
-llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block,
- llvm::Type* index_ty) {
- // Calculate (y, x) coordinates respectively in the 2D view of thread block,
- // defined by (num_thread_y, num_thread_x) from thread_id.
- llvm::CallInst* thread_id =
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, builder);
- llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id,
- builder->GetInsertBlock()->getModule());
- auto ret = builder->CreateIntCast(thread_id, index_ty, /*isSigned=*/true);
- ret->setName("thread.id.x");
- return ret;
-}
-
-// Emits the LLVM values for thread_id, block_id, coordinates of the current
-// tile and strides of the loops to iterate over the current tile.
-absl::StatusOr<TilingThreadIdInfo> EmitThreadIdInfo(llvm::IRBuilder<>* builder,
- const Tiling& tiling,
- llvm::Type* index_ty) {
- auto constant = [&](uint64_t c) -> llvm::Constant* {
- return llvm::ConstantInt::get(index_ty, c);
- };
- int64_t num_blocks = tiling.GetNumBlocks();
- if (num_blocks > (int64_t)std::numeric_limits<uint32_t>::max()) {
- return FailedPrecondition(
- "Number of physical blocks (%d) does not fit in an i32 in tiling "
- "scheme: %s",
- num_blocks, tiling.ToString());
- }
-
- TilingThreadIdInfo info;
- info.thread_id =
- EmitThreadId(builder, tiling.GetNumThreadsPerBlock(), index_ty);
- info.block_id = EmitBlockId(builder, num_blocks, index_ty);
-
- for (auto [dim, stride] : llvm::enumerate(tiling.GetThreadStrides())) {
- int64_t size = tiling.GetThreadsPerBlock()[dim];
- if (size == 1) {
- info.thread_ids.emplace_back(constant(0));
- } else {
- auto& dim_id = info.thread_ids.emplace_back(info.thread_id);
- if (stride > 1) {
- dim_id = builder->CreateUDiv(dim_id, constant(stride));
- }
- if (dim) {
- dim_id = builder->CreateURem(dim_id, constant(size));
- }
- dim_id->setName(absl::StrCat("thread.id.", dim));
- }
- }
-
- info.lane_id =
- builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id");
- return info;
-}
-
-} // namespace
-
-absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
- llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
- const TileGenerator& tile_element_generator) {
- absl::Span<const int64_t> dims_in_elems = tiling.GetShape();
- const auto& block_counts = tiling.GetBlockCounts();
- auto constant = [&](uint64_t c) -> llvm::Constant* {
- return llvm::ConstantInt::get(index_ty, c);
- };
-
- TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info,
- EmitThreadIdInfo(builder, tiling, index_ty));
-
- KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
-
- const llvm_ir::IrArray::Index block_coords(
- thread_id_info.block_id,
- ShapeUtil::MakeShape(PRED /*arbitrary*/, block_counts), builder);
-
- absl::InlinedVector<llvm::Value*, 4> tile_dimensions;
- for (int i = 0; i < block_counts.size(); ++i) {
- int64_t block_tile_size = tiling.GetBlockTileSize()[i];
- if (dims_in_elems[i] % block_tile_size == 0) {
- // The block tile size evenly divides the tiled shape -> no need to emit
- // the bounds check.
- tile_dimensions.push_back(constant(block_tile_size));
- } else {
- // Only the last tile in each dimension may not have full size.
- llvm::Value* is_last =
- builder->CreateICmpEQ(block_coords[i], constant(block_counts[i] - 1));
- int64_t partial_row =
- dims_in_elems[i] - (block_counts[i] - 1) * block_tile_size;
- tile_dimensions.push_back(builder->CreateSelect(
- is_last, constant(partial_row), constant(block_tile_size),
- absl::StrCat("tile_bound.", i)));
- }
- }
-
- llvm_ir::IrArray::Index tile_offset = [&] {
- std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
- llvm::Type* index_ty = block_coords.GetType();
- for (int i = 0; i < block_counts.size(); ++i) {
- elem_multi_index[i] = builder->CreateMul(
- block_coords[i],
- llvm::ConstantInt::get(index_ty, tiling.GetBlockTileSize()[i]),
- absl::StrCat("tile_origin.", i));
- }
- return llvm_ir::IrArray::Index(elem_multi_index, tiling.GetShape(),
- index_ty);
- }();
-
- tile_element_generator(thread_id_info, tile_offset, tile_dimensions);
- return {{tile_dimensions, tile_offset, thread_id_info}};
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/tiling_util.h
deleted file mode 100644
index 66014ae..0000000
--- a/third_party/xla/xla/service/gpu/fusions/tiling_util.h
+++ /dev/null
@@ -1,183 +0,0 @@
-/*Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_
-#define XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_
-
-#include <cstdint>
-#include <functional>
-#include <string>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/types/span.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-#include "llvm/IR/Value.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// Describes tiling used by the kernel.
-//
-// Used by reduction and transpose emitters.
-class Tiling {
- public:
- Tiling(absl::Span<const int64_t> shape, absl::Span<const int64_t> tile_sizes,
- absl::Span<const int64_t> num_threads,
- // By default, don't unroll anything.
- absl::InlinedVector<bool, 4> loops_to_unroll = {})
- : shape_{shape.begin(), shape.end()},
- tile_sizes_per_thread_{tile_sizes.begin(), tile_sizes.end()},
- tile_sizes_per_block_(shape.size()),
- num_threads_{num_threads.begin(), num_threads.end()},
- num_blocks_(shape.size()),
- loops_to_unroll_(loops_to_unroll) {
- for (int64_t i = 0; i < shape.size(); ++i) {
- tile_sizes_per_block_[i] = tile_sizes[i] * num_threads[i];
- CHECK_NE(tile_sizes_per_block_[i], 0);
- num_blocks_[i] = CeilOfRatio(shape[i], tile_sizes_per_block_[i]);
- CHECK_NE(num_blocks_[i], 0);
- }
- if (loops_to_unroll_.empty()) loops_to_unroll_.resize(shape.size());
- }
- Tiling() = default;
-
- std::string ToString() const {
- return absl::StrJoin(
- {absl::StrFormat("shape = {%s}", absl::StrJoin(shape_, ", ")),
- absl::StrFormat("tile_sizes = {%s}",
- absl::StrJoin(tile_sizes_per_thread_, ", ")),
- absl::StrFormat("num_threads = {%s}",
- absl::StrJoin(num_threads_, ", "))},
- ", ");
- }
-
- // Number of elements in each dimension.
- const absl::InlinedVector<int64_t, 4>& GetShape() const { return shape_; }
- xla::Shape GetXlaShape(PrimitiveType element_type = F32) const {
- return ShapeUtil::MakeShape(element_type, shape_);
- }
-
- const absl::InlinedVector<int64_t, 4>& GetBlockCounts() const {
- return num_blocks_;
- }
-
- // Tile size for each thread.
- //
- // Equals to the number of iterations in the loop each tile will make.
- const absl::InlinedVector<int64_t, 4>& GetThreadTileSize() const {
- return tile_sizes_per_thread_;
- }
-
- // Tile size for an entire thread block.
- const absl::InlinedVector<int64_t, 4>& GetBlockTileSize() const {
- return tile_sizes_per_block_;
- }
-
- const absl::InlinedVector<int64_t, 4>& GetThreadsPerBlock() const {
- return num_threads_;
- }
-
- // Returns the strides of the thread index dimensions wrt. the linear thread
- // id.
- absl::InlinedVector<int64_t, 4> GetThreadStrides() const {
- return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_threads_));
- }
-
- int64_t GetNumThreadsPerBlock() const { return Product(num_threads_); }
-
- int64_t GetNumBlocks() const { return Product(num_blocks_); }
-
- const absl::InlinedVector<bool, 4>& GetLoopsToUnroll() const {
- return loops_to_unroll_;
- }
-
- private:
- // The number of elements in each dimension.
- absl::InlinedVector<int64_t, 4> shape_;
-
- // The number of elements for each dimension of a tile.
- absl::InlinedVector<int64_t, 4> tile_sizes_per_thread_;
- absl::InlinedVector<int64_t, 4> tile_sizes_per_block_;
-
- absl::InlinedVector<int64_t, 4> num_threads_;
- absl::InlinedVector<int64_t, 4> num_blocks_;
-
- absl::InlinedVector<bool, 4> loops_to_unroll_;
-};
-
-struct TilingThreadIdInfo {
- llvm::Value* thread_id;
-
- absl::InlinedVector<llvm::Value*, 4> thread_ids;
-
- // Lane id: `thread_id % WarpSize`
- llvm::Value* lane_id;
-
- // Block id.
- llvm::Value* block_id;
-};
-
-struct TilingKernelInfo {
- // Tiling bounds.
- absl::InlinedVector<llvm::Value*, 4> output_tile_bounds;
-
- // Starting tile, as calculated from block id only.
- llvm_ir::IrArray::Index tile_origin;
-
- // Thread meta-info.
- TilingThreadIdInfo thread_id_info;
-};
-
-// A function to generate the code to emit the entire tile.
-//
-// index: Absolute coordinate of the start of the tile in input.
-// tile_dimensions: Size of the tile
-using TileGenerator =
- std::function<void(const TilingThreadIdInfo& thread_id_info,
- const llvm_ir::IrArray::Index& tile_start_index,
- absl::Span<llvm::Value* const> tile_dimensions)>;
-
-// A function object to generate code to process one element in a tile.
-//
-// index_in_tile: the current coordinates within the tile. To get the global
-// coordinates, use `tile_start_index.AddOffset(index_in_tile, ...)`.
-using TileElementGenerator =
- std::function<void(absl::Span<llvm::Value* const> index_in_tile)>;
-
-// Emits code to iterate through a tile with given tile dimensions and generate
-// elements using the callback.
-void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
- const TilingThreadIdInfo& thread_id_info,
- absl::Span<llvm::Value* const> tile_dimensions,
- const TileElementGenerator& emit_elem_function);
-
-// Emits a kernel for the hlo instruction using the given kernel mapping
-// scheme.
-absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
- llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
- const TileGenerator& tile_element_generator);
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/BUILD b/third_party/xla/xla/service/gpu/fusions/tools/BUILD
new file mode 100644
index 0000000..2886ad1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/BUILD
@@ -0,0 +1,113 @@
+load("//xla:xla.bzl", "xla_cc_binary")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ licenses = ["notice"],
+)
+
+xla_cc_binary(
+ name = "mlir_fusions_opt",
+ srcs = ["mlir_fusions_opt.cc"],
+ visibility = ["//xla/service/gpu/fusions:__subpackages__"],
+ deps = [
+ "//xla/mlir_hlo",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
+ "//xla/service/gpu/fusions/transforms:passes",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:ComplexDialect",
+ "@llvm-project//mlir:DLTIDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FuncExtensions",
+ "@llvm-project//mlir:GPUDialect",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MlirOptLib",
+ "@llvm-project//mlir:NVVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorDialect",
+ ],
+)
+
+cc_library(
+ name = "test_lib",
+ testonly = 1,
+ srcs = ["test_lib.cc"],
+ hdrs = ["test_lib.h"],
+ deps = [
+ "//xla:status_macros",
+ "//xla/hlo/ir:hlo",
+ "//xla/mlir_hlo",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/fusions",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
+ "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
+ "//xla/stream_executor:device_description",
+ "//xla/tools:hlo_module_loader",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:ComplexDialect",
+ "@llvm-project//mlir:DLTIDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FuncExtensions",
+ "@llvm-project//mlir:GPUDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MlirOptLib",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:VectorDialect",
+ ],
+)
+
+xla_cc_binary(
+ name = "fusion_to_mlir",
+ testonly = 1,
+ srcs = ["fusion_to_mlir.cc"],
+ visibility = ["//xla/service/gpu/fusions:__subpackages__"],
+ deps = [
+ ":test_lib",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@llvm-project//llvm:Support",
+ "@local_tsl//tsl/platform:platform_port",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_binary(
+ name = "test_correctness",
+ testonly = 1,
+ srcs = ["test_correctness.cc"],
+ visibility = ["//xla/service/gpu/fusions:__subpackages__"],
+ deps = [
+ ":test_lib",
+ "//xla:debug_options_flags",
+ "//xla:error_spec",
+ "//xla:shape_util",
+ "//xla/service:gpu_plugin",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/service/gpu/model:indexing_test_utils",
+ "//xla/tests:hlo_test_base",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@com_google_googletest//:gtest",
+ "@llvm-project//llvm:Support",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc
new file mode 100644
index 0000000..9fe41b6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc
@@ -0,0 +1,48 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "llvm/Support/raw_ostream.h"
+#include "xla/service/gpu/fusions/tools/test_lib.h"
+#include "tsl/platform/init_main.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::Status Run(const std::string& filename) {
+ TF_ASSIGN_OR_RETURN(auto module, LoadTestModule(filename));
+ TF_ASSIGN_OR_RETURN(auto emitter_data, GetMlirFusionEmitter(*module));
+
+ auto context = GetMlirContextForTest();
+ TF_ASSIGN_OR_RETURN(auto mlir_module,
+ emitter_data->emitter->CreateMLIRModule(
+ context, *emitter_data->fusion, "main",
+ /*buffer_assignment=*/nullptr));
+ llvm::outs() << *mlir_module;
+ return absl::OkStatus();
+}
+
+} // namespace gpu
+} // namespace xla
+
+int main(int argc, char** argv) {
+ tsl::port::InitMain(argv[0], &argc, &argv);
+ CHECK_EQ(argc, 2) << "Must specify an input file";
+ CHECK_OK(xla::gpu::Run(argv[1]));
+ return 0;
+}
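
Editorial note: this binary backs the "// RUN: fusion_to_mlir %s | FileCheck %s" lines in the .hlo tests added above. It loads the file with LoadTestModule, picks up the (single) fusion, and prints the emitted MLIR module to stdout for FileCheck or mlir_fusions_opt to consume.
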
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc
new file mode 100644
index 0000000..7206cc1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc
@@ -0,0 +1,77 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Transforms/Passes.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+
+int main(int argc, char** argv) {
+ mlir::DialectRegistry registry;
+ registry.insert<mlir::DLTIDialect, mlir::tensor::TensorDialect,
+ mlir::func::FuncDialect, mlir::affine::AffineDialect,
+ mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
+ mlir::math::MathDialect, mlir::scf::SCFDialect,
+ mlir::mhlo::MhloDialect, mlir::LLVM::LLVMDialect,
+ mlir::gpu::GPUDialect, mlir::vector::VectorDialect,
+ xla::gpu::XlaGpuDialect, mlir::NVVM::NVVMDialect>();
+ mlir::func::registerAllExtensions(registry);
+ mlir::registerCanonicalizerPass();
+ mlir::registerCSEPass();
+ mlir::registerInliner();
+ xla::gpu::registerGpuFusionTransformsPasses();
+ mlir::registerPassPipeline(
+ "xla-gpu-test-to-inline",
+ "Test pipeline of passes up to inlining. No vectorization, also does not "
+ "lower xla_gpu. Intended to simplify IR in tests.",
+ [=](mlir::OpPassManager& pm, llvm::StringRef options,
+ llvm::function_ref<mlir::LogicalResult(const llvm::Twine&)>
+ errorHandler) {
+ if (!options.empty()) return mlir::failure();
+
+ pm.addNestedPass<mlir::func::FuncOp>(
+ xla::gpu::CreateSimplifyArithPass());
+ pm.addPass(xla::gpu::CreateEraseDeadFunctionsPass());
+ pm.addPass(mlir::createCSEPass());
+ pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
+ pm.addPass(mlir::createCSEPass());
+ }));
+ return mlir::success();
+ },
+ [](llvm::function_ref<void(const mlir::detail::PassOptions&)>) {});
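+ // Editorial note: once registered, the pipeline is expected to be available
+ // as a command-line option of this tool (standard MlirOptMain behavior),
+ // e.g. mlir_fusions_opt input.mlir -xla-gpu-test-to-inline.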
+
+ return mlir::failed(
+ MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry));
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc
new file mode 100644
index 0000000..72529cd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc
@@ -0,0 +1,192 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "xla/debug_options_flags.h"
+#include "xla/error_spec.h"
+#include "xla/service/gpu/fusions/tools/test_lib.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/shape.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+
+struct Flags {
+ std::string input_file = "";
+ float abs_error_bound = 1e-4;
+ float rel_error_bound = 1e-4;
+ std::vector<std::pair<std::string, std::vector<int64_t>>> bijection_inputs;
+ std::vector<std::string> bijection_outputs;
+};
+
+Flags& flags = *new Flags;
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using CorrectnessTest = HloTestBase;
+
+const Shape& GetFirstArrayShape(const Shape& shape) {
+ if (shape.IsArray()) {
+ return shape;
+ }
+ CHECK(shape.IsTuple());
+ return GetFirstArrayShape(shape.tuple_shapes(0));
+}
+
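+// Verifies that `map` is a bijection onto the index space of `shape`: for a
+// shape of {100, 128}, it builds the intervals [0, 99] and [0, 127] and
+// delegates to VerifyBijection, appending the map itself to the error message
+// on failure.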
+absl::Status TestBijection(const IndexingMap& map,
+ absl::Span<int64_t const> shape) {
+ std::vector<Interval> intervals;
+ for (int64_t size : shape) {
+ intervals.push_back({0, size - 1});
+ }
+ auto status = VerifyBijection(map, intervals);
+ if (status.ok()) return status;
+ return absl::FailedPreconditionError(
+ absl::StrCat(status.message(), " in map ", map.ToString()));
+}
+
+TEST_F(CorrectnessTest, RunAndCompare) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file));
+ EXPECT_TRUE(RunAndCompareNoHloPasses(
+ std::move(module),
+ ErrorSpec{flags.abs_error_bound, flags.rel_error_bound}));
+}
+
+absl::StatusOr<int64_t> GetHeroIndex(absl::string_view name,
+ const HloFusionAnalysis& analysis) {
+ for (auto [index, hero] : llvm::enumerate(analysis.fusion_heroes())) {
+ if (hero.name() == name) {
+ return index;
+ }
+ }
+ return absl::NotFoundError(absl::StrCat("Hero ", name, " not found"));
+}
+
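+// Parses a --bijection_inputs entry of the form "hero: id, id, ...". For
+// example, "reduce: 0, 1" becomes {"reduce", {0, 1}}; whitespace around the
+// hero name and the ids is stripped.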
+std::pair<std::string, std::vector<int64_t>> ParseHeroAndIds(
+ absl::string_view hero_and_ids) {
+ std::pair<absl::string_view, absl::string_view> hero_and_ids_pair =
+ absl::StrSplit(hero_and_ids, ':');
+ std::vector<int64_t> ids;
+ for (absl::string_view id : absl::StrSplit(hero_and_ids_pair.second, ',')) {
+ ids.push_back(std::stoi(std::string(absl::StripAsciiWhitespace(id))));
+ }
+ return {std::string(absl::StripAsciiWhitespace(hero_and_ids_pair.first)),
+ ids};
+}
+
+TEST_F(CorrectnessTest, InputIndexingIsBijection) {
+ auto context = GetMlirContextForTest();
+ TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file));
+ TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module));
+ for (const auto& [hero_name, ids] : flags.bijection_inputs) {
+ TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index,
+ GetHeroIndex(hero_name, *emitter_data->analysis));
+ for (int64_t id : ids) {
+ auto indexing = emitter_data->emitter->ComputeThreadIdToInputIndexing(
+ hero_index, id, &context);
+ ASSERT_TRUE(indexing.has_value());
+ TF_ASSERT_OK(TestBijection(*indexing,
+ emitter_data->analysis->fusion_hero(hero_index)
+ .GetOperand(id)
+ .shape()
+ .dimensions()))
+ << "Expected operand " << id << " of " << hero_name << " (root index "
+ << hero_index << ") to be read exactly once.";
+ }
+ }
+}
+
+TEST_F(CorrectnessTest, OutputIndexingIsBijection) {
+ auto context = GetMlirContextForTest();
+ TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file));
+ TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module));
+ for (const auto& hero_name : flags.bijection_outputs) {
+ TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index,
+ GetHeroIndex(hero_name, *emitter_data->analysis));
+ auto indexing = emitter_data->emitter->ComputeThreadIdToOutputIndexing(
+ hero_index, &context);
+ ASSERT_TRUE(indexing.has_value());
+ TF_ASSERT_OK(TestBijection(
+ *indexing, GetFirstArrayShape(
+ emitter_data->analysis->fusion_root(hero_index).shape())
+ .dimensions()))
+ << "Expected output of " << hero_name << " (root index " << hero_index
+ << ") to be written exactly once.";
+ }
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
+
+int main(int argc, char* argv[]) {
+ std::vector<tsl::Flag> flag_list = {
+ tsl::Flag("abs_error_bound", &flags.abs_error_bound,
+ "Absolute error bound."),
+ tsl::Flag("rel_error_bound", &flags.rel_error_bound,
+ "Relative error bound."),
+ tsl::Flag(
+ "bijection_inputs",
+ [](std::string name_and_ids) {
+ if (name_and_ids.empty()) return false;
+ flags.bijection_inputs.push_back(
+ xla::gpu::ParseHeroAndIds(name_and_ids));
+ return true;
+ },
+ "",
+ "The name of a hero followed by operand ids that should be read "
+ "exactly once, i.e. there's a bijection between a subset of threads "
+ "and the input shape. Example: 'reduction0: 0, 1'."),
+ tsl::Flag(
+ "bijection_outputs",
+ [](std::string name) {
+ if (name.empty()) return false;
+ flags.bijection_outputs.push_back(name);
+ return true;
+ },
+ "",
+ "The name of a hero whose outputs should be written exactly once, "
+ "i.e. there's a bijection between a subset of threads and the output "
+ "shape.")};
+
+ xla::AppendDebugOptionsFlags(&flag_list);
+ std::string usage = tsl::Flags::Usage(argv[0], flag_list);
+ bool parseResult = tsl::Flags::Parse(&argc, argv, flag_list);
+ if (!parseResult || argc != 2) {
+ LOG(ERROR) << "\n" << usage;
+ return 1;
+ }
+
+ flags.input_file = argv[1];
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc
new file mode 100644
index 0000000..11b82dd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc
@@ -0,0 +1,118 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/tools/test_lib.h"
+
+#include <memory>
+
+#include "absl/algorithm/container.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/status_macros.h"
+#include "xla/tools/hlo_module_loader.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<std::unique_ptr<HloModule>> LoadTestModule(
+ absl::string_view filename) {
+ auto module = *xla::LoadModuleFromFile(std::string(filename));
+ module->mutable_config()
+ .mutable_debug_options()
+ .set_xla_gpu_mlir_emitter_level(4);
+
+ int num_fusions = absl::c_count_if(
+ module->entry_computation()->instructions(),
+ [](const HloInstruction* instruction) {
+ return instruction->opcode() == xla::HloOpcode::kFusion;
+ });
+ TF_RET_CHECK(num_fusions <= 1) << "HLO must contain at most one fusion";
+
+ if (num_fusions == 0) {
+ // Generate a fusion from the entry computation.
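+ // For example, an entry computation that just computes add(p0, p1) is
+ // wrapped into a single kLoop fusion whose operands are clones of p0 and
+ // p1, and that fusion becomes the new entry computation's root.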
+ HloComputation::Builder builder("generated_main");
+ std::vector<HloInstruction*> params;
+ for (const auto* param :
+ module->entry_computation()->parameter_instructions()) {
+ params.push_back(*builder.AddParameter(param->Clone(/*suffix=*/"")));
+ }
+ builder.AddInstruction(HloInstruction::CreateFusion(
+ module->entry_computation()->root_instruction()->shape(),
+ HloInstruction::FusionKind::kLoop /* irrelevant */, params,
+ module->entry_computation()));
+
+ auto* new_entry = module->AddComputationAndUnifyNamesAndIds(
+ builder.Build(), /*is_entry=*/false);
+ module->ReplaceEntryComputation(new_entry);
+ }
+
+ return module;
+}
+
+absl::StatusOr<std::unique_ptr<EmitterData>> GetMlirFusionEmitter(
+ const HloModule& module) {
+ auto data = std::make_unique<EmitterData>();
+ data->fusion = DynCast<HloFusionInstruction>(
+ module.entry_computation()->root_instruction());
+ TF_RET_CHECK(data->fusion != nullptr) << "Root instruction must be a fusion";
+ data->device.emplace(TestGpuDeviceInfo::RTXA6000DeviceInfo());
+ data->analysis.emplace(
+ HloFusionAnalysis::Create(*data->fusion, data->device.value()));
+ PreBufferAssignmentFusionInfo info(data->analysis.value());
+ auto emitter = GetFusionEmitter(info);
+
+ auto mlir_emitter = dynamic_cast<MlirFusionEmitterBase*>(emitter.get());
+ TF_RET_CHECK(mlir_emitter != nullptr)
+ << "Expected emitter to be an MlirFusionEmitter";
+
+ emitter.release();
+ data->emitter.reset(mlir_emitter);
+ return data;
+}
+
+mlir::MLIRContext GetMlirContextForTest() {
+ mlir::DialectRegistry registry;
+ registry.insert<mlir::DLTIDialect, mlir::tensor::TensorDialect,
+ mlir::func::FuncDialect, mlir::affine::AffineDialect,
+ mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
+ mlir::math::MathDialect, mlir::scf::SCFDialect,
+ mlir::mhlo::MhloDialect, mlir::gpu::GPUDialect,
+ mlir::vector::VectorDialect, XlaGpuDialect>();
+ return mlir::MLIRContext(registry);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h
new file mode 100644
index 0000000..5dfa300
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h
@@ -0,0 +1,58 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_
+#define XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_
+
+#include <memory>
+#include <optional>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+
+namespace gpu {
+
+// Loads a test module from the given filename, ensuring it has a single fusion.
+// If the file contains more than one fusion, the function fails. If the file
+// contains no fusions, the function generates a fusion from the entry
+// computation.
+absl::StatusOr<std::unique_ptr<HloModule>> LoadTestModule(
+ absl::string_view filename);
+
+// Bundles everything needed to exercise an MLIR fusion emitter in tests: the
+// fusion instruction, the device description, the fusion analysis and the
+// emitter itself.
+struct EmitterData {
+ HloFusionInstruction* fusion;
+ std::optional<se::DeviceDescription> device;
+ std::optional<HloFusionAnalysis> analysis;
+ std::unique_ptr<MlirFusionEmitterBase> emitter;
+};
+
+// Returns the MLIR fusion emitter (and supporting data) for the given module,
+// which should have been loaded using LoadTestModule.
+absl::StatusOr<std::unique_ptr<EmitterData>> GetMlirFusionEmitter(
+ const HloModule& module);
+
+// Returns an MLIR context with all the dialects needed for testing.
+mlir::MLIRContext GetMlirContextForTest();
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_
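
Editorial aside, not part of the patch: the helpers declared above compose the same way fusion_to_mlir.cc and test_correctness.cc (earlier in this change) use them. A minimal sketch under that assumption; the helper name CheckRootIndexing and the hard-coded root index 0 are illustrative only:

#include <string>

#include "absl/status/status.h"
#include "llvm/Support/raw_ostream.h"
#include "tsl/platform/statusor.h"
#include "xla/service/gpu/fusions/tools/test_lib.h"

// Prints the thread-id -> output indexing map of the fusion's first root.
absl::Status CheckRootIndexing(const std::string& path) {
  TF_ASSIGN_OR_RETURN(auto module, xla::gpu::LoadTestModule(path));
  TF_ASSIGN_OR_RETURN(auto data, xla::gpu::GetMlirFusionEmitter(*module));
  auto context = xla::gpu::GetMlirContextForTest();
  auto indexing = data->emitter->ComputeThreadIdToOutputIndexing(
      /*root_index=*/0, &context);
  if (!indexing.has_value()) {
    return absl::InternalError("emitter exposes no output indexing");
  }
  llvm::outs() << indexing->ToString() << "\n";
  return absl::OkStatus();
}
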
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD
new file mode 100644
index 0000000..3cf1d26
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD
@@ -0,0 +1,105 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = [":friends"],
+ licenses = ["notice"],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//xla:friends",
+ ],
+)
+
+gentbl_cc_library(
+ name = "passes_inc_gen",
+ tbl_outs = [
+ (
+ [
+ "-gen-pass-decls",
+ "-name=GpuFusionTransforms",
+ ],
+ "passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "passes.td",
+ visibility = ["//visibility:private"],
+ deps = ["@llvm-project//mlir:PassBaseTdFiles"],
+)
+
+cc_library(
+ name = "passes",
+ srcs = [
+ "convert_xla_gpu_pure_call_ops.cc",
+ "erase_dead_functions.cc",
+ "expand_float_ops.cc",
+ "flatten_tensors.cc",
+ "lower_tensors.cc",
+ "lower_to_llvm.cc",
+ "lower_xla_gpu_to_scf.cc",
+ "merge_pointers_to_same_slice.cc",
+ "optimize_loops.cc",
+ "peel_loops.cc",
+ "propagate_slice_indices.cc",
+ "simplify_affine.cc",
+ "simplify_arith.cc",
+ "unswitch_loops.cc",
+ "vectorize_loads_stores.cc",
+ ],
+ hdrs = ["passes.h"],
+ deps = [
+ ":passes_inc_gen",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/mlir_hlo",
+ "//xla/mlir_hlo:map_mhlo_to_scalar_op",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
+ "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
+ "//xla/service/gpu/model:indexing_analysis",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:AffineToStandard",
+ "@llvm-project//mlir:AffineUtils",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:ArithToLLVM",
+ "@llvm-project//mlir:ArithTransforms",
+ "@llvm-project//mlir:CallOpInterfaces",
+ "@llvm-project//mlir:ComplexDialect",
+ "@llvm-project//mlir:ComplexToLLVM",
+ "@llvm-project//mlir:ControlFlowToLLVM",
+ "@llvm-project//mlir:DataLayoutInterfaces",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FuncToLLVM",
+ "@llvm-project//mlir:GPUDialect",
+ "@llvm-project//mlir:GPUToNVVMTransforms",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LLVMCommonConversion",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MathToLLVM",
+ "@llvm-project//mlir:MathTransforms",
+ "@llvm-project//mlir:NVVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SCFToControlFlow",
+ "@llvm-project//mlir:SCFUtils",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:VectorDialect",
+ "@llvm-project//mlir:VectorToLLVM",
+ "@llvm-project//mlir:VectorTransforms",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc
new file mode 100644
index 0000000..0c9053a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc
@@ -0,0 +1,61 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_CONVERTPURECALLOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+struct RewriteCall : mlir::OpRewritePattern<PureCallOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ PureCallOp op, mlir::PatternRewriter& rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
+ op, op.getResultTypes(), op.getOperands(), op->getAttrs());
+ return mlir::success();
+ }
+};
+
+class ConvertPureCallOpsPass
+ : public impl::ConvertPureCallOpsPassBase<ConvertPureCallOpsPass> {
+ public:
+ void runOnOperation() override {
+ auto* ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ patterns.add<RewriteCall>(ctx);
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<::mlir::Pass> CreateConvertPureCallOpsPass() {
+ return std::make_unique<ConvertPureCallOpsPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc
new file mode 100644
index 0000000..3918a19
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc
@@ -0,0 +1,86 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <queue>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_ERASEDEADFUNCTIONSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+struct CallInfo {
+ PureCallOp call;
+ int count;
+};
+
+llvm::DenseSet<mlir::func::FuncOp> FindLiveFunctions(mlir::ModuleOp module) {
+ std::queue<mlir::func::FuncOp> worklist;
+ llvm::DenseSet<mlir::func::FuncOp> live_funcs;
+ module.walk([&](mlir::func::FuncOp func) {
+ if (!func.isPrivate()) {
+ worklist.push(func);
+ live_funcs.insert(func);
+ }
+ });
+
+ mlir::SymbolTableCollection symbol_table;
+ while (!worklist.empty()) {
+ auto func = worklist.front();
+ worklist.pop();
+ func.walk([&](mlir::CallOpInterface call) {
+ auto callee =
+ mlir::cast<mlir::func::FuncOp>(call.resolveCallable(&symbol_table));
+ if (live_funcs.insert(callee).second) {
+ worklist.push(callee);
+ }
+ });
+ }
+ return live_funcs;
+}
+
+class EraseDeadFunctionsPass
+ : public impl::EraseDeadFunctionsPassBase<EraseDeadFunctionsPass> {
+ public:
+ void runOnOperation() override {
+ // Find live functions and erase dead ones.
+ auto live = FindLiveFunctions(getOperation());
+ getOperation().walk([&](mlir::func::FuncOp func) {
+ if (!live.contains(func)) {
+ func.erase();
+ }
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateEraseDeadFunctionsPass() {
+ return std::make_unique<EraseDeadFunctionsPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc
new file mode 100644
index 0000000..2274576
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc
@@ -0,0 +1,675 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/log/check.h"
+#include "llvm/ADT/APFloat.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+namespace ma = ::mlir::arith;
+
+using ma::SelectOp;
+using mlir::Value;
+
+#define GEN_PASS_DEF_EXPANDFLOATOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+// Wraps a Value to provide operator overloading for more readable expressions.
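+// For example, given Val bits{b.create<ma::BitcastOp>(b.getI8Type(), v), &b},
+// the expression (bits & 0x7F) == 0x7C builds the constants, the mask and the
+// comparison in a single line (see IsInf below).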
+struct Val {
+ Value value;
+ mlir::ImplicitLocOpBuilder* b;
+
+ operator Value() const { return value; } // NOLINT
+
+ Val operator+(int64_t rhs) const { return Binop<ma::AddIOp>(rhs); }
+ Val operator+(Value rhs) const { return Binop<ma::AddIOp>(rhs); }
+ Val operator-(int64_t rhs) const { return Binop<ma::SubIOp>(rhs); }
+ Val operator-(Value rhs) const { return Binop<ma::SubIOp>(rhs); }
+ Val operator*(int64_t rhs) const { return Binop<ma::MulIOp>(rhs); }
+ Val operator*(Value rhs) const { return Binop<ma::MulIOp>(rhs); }
+ Val operator&(Value rhs) const { return Binop<ma::AndIOp>(rhs); }
+ Val operator&(int64_t rhs) const { return Binop<ma::AndIOp>(rhs); }
+ Val operator|(Value rhs) const { return Binop<ma::OrIOp>(rhs); }
+ Val operator|(int64_t rhs) const { return Binop<ma::OrIOp>(rhs); }
+ Val operator^(Value rhs) const { return Binop<ma::XOrIOp>(rhs); }
+ Val shl(Value rhs) const { return Binop<ma::ShLIOp>(rhs); }
+ Val shl(int64_t rhs) const { return Binop<ma::ShLIOp>(rhs); }
+ Val shrui(Value rhs) const { return Binop<ma::ShRUIOp>(rhs); }
+ Val shrui(int64_t rhs) const { return Binop<ma::ShRUIOp>(rhs); }
+
+ Val cmp(ma::CmpIPredicate pred, Value rhs) const {
+ return {b->create<ma::CmpIOp>(pred, value, rhs), b};
+ }
+ Val cmp(ma::CmpIPredicate pred, int64_t rhs) const {
+ return cmp(pred, MakeConstant(rhs));
+ }
+ Val operator==(Value rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
+ Val operator==(int64_t rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
+ Val operator!=(int64_t rhs) const { return cmp(ma::CmpIPredicate::ne, rhs); }
+
+ Val MakeConstant(int64_t c) const {
+ return {b->create<ma::ConstantIntOp>(c, value.getType()), b};
+ }
+
+ private:
+ template <typename Op>
+ Val Binop(Value rhs) const {
+ return {b->create<Op>(value, rhs), b};
+ }
+
+ template <typename Op>
+ Val Binop(int64_t rhs) const {
+ return Binop<Op>(MakeConstant(rhs));
+ }
+};
+
+template <typename OpTy, ma::CmpFPredicate pred>
+struct RewriteToCmpSelect : public mlir::OpRewritePattern<OpTy> {
+ using mlir::OpRewritePattern<OpTy>::OpRewritePattern;
+
+ RewriteToCmpSelect(mlir::MLIRContext* context, bool include_f32)
+ : mlir::OpRewritePattern<OpTy>(context), include_f32(include_f32) {}
+
+ mlir::LogicalResult matchAndRewrite(
+ OpTy op, mlir::PatternRewriter& rewriter) const override {
+ if (op.getType().isF32() && !include_f32) {
+ return rewriter.notifyMatchFailure(op, "not rewriting f32 min/max");
+ }
+
+ auto lhs_is_nan = rewriter.create<ma::CmpFOp>(
+ op.getLoc(), ma::CmpFPredicate::UNE, op.getLhs(), op.getLhs());
+ auto rhs_is_not_nan = rewriter.create<ma::CmpFOp>(
+ op.getLoc(), ma::CmpFPredicate::OEQ, op.getRhs(), op.getRhs());
+
+ auto return_lhs =
+ rewriter.create<ma::CmpFOp>(op.getLoc(), pred, op.getLhs(), op.getRhs())
+ .getResult();
+
+ // Logic: isNaN(lhs) || (!isNaN(rhs) && return_lhs) ? lhs : rhs
+ return_lhs = rewriter.create<ma::OrIOp>(
+ op.getLoc(), lhs_is_nan,
+ rewriter.create<ma::AndIOp>(op.getLoc(), rhs_is_not_nan, return_lhs));
+
+ rewriter.replaceOpWithNewOp<SelectOp>(op, op.getResult().getType(),
+ return_lhs, op.getLhs(), op.getRhs());
+ return mlir::success();
+ }
+
+ bool include_f32;
+};
+
+struct RewriteErf32Pattern : public mlir::OpRewritePattern<mlir::math::ErfOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::math::ErfOp op, mlir::PatternRewriter& rewriter) const override {
+ if (!op.getType().isF32()) {
+ return rewriter.notifyMatchFailure(op, "not an f32 erf");
+ }
+
+ static const std::array<float, 5> kAlpha{
+ 0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f,
+ 0.18520832239976145f, 1.128379143519084f};
+
+ static const std::array<float, 7> kBeta{-1.1791602954361697e-7,
+ 0.000023547966471313185f,
+ 0.0010179625278914885f,
+ 0.014070470171167667f,
+ 0.11098505178285362f,
+ 0.49746925110067538f,
+ 1.0f};
+
+ // We clamp x to be within [-c, c] where c = erfinv(1-2^-23), outside of
+ // which x should be +/-1.
+ constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f;
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto c = [&](float v) -> Value {
+ return b.create<ma::ConstantFloatOp>(llvm::APFloat(v),
+ rewriter.getF32Type());
+ };
+
+ auto poly = [&](auto x, auto coefficients) -> Value {
+ auto r = c(coefficients[0]);
+ for (int i = 1; i < coefficients.size(); ++i) {
+ r = b.create<mlir::math::FmaOp>(r, x, c(coefficients[i]));
+ }
+ return r;
+ };
+
+ Value x = op.getOperand();
+ x = b.create<ma::MaximumFOp>(x, c(-kErfInvOneMinusHalfULP));
+ x = b.create<ma::MinimumFOp>(x, c(kErfInvOneMinusHalfULP));
+ Value x2 = b.create<ma::MulFOp>(x, x);
+
+ rewriter.replaceOpWithNewOp<ma::DivFOp>(
+ op, b.create<ma::MulFOp>(x, poly(x2, kAlpha)), poly(x2, kBeta));
+
+ return mlir::success();
+ }
+};
+
+int GetSignificandBits(mlir::FloatType ty) {
+ return llvm::APFloat::semanticsPrecision(ty.getFloatSemantics()) - 1;
+}
+
+int GetExponentBias(mlir::FloatType ty) {
+ return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
+}
+
+Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
+ auto ty = mlir::cast<mlir::FloatType>(value.getType());
+ if (mlir::LLVM::isCompatibleOuterType(ty)) {
+ value = b.create<mlir::math::AbsFOp>(value);
+ Value inf = b.create<ma::ConstantFloatOp>(
+ llvm::APFloat::getInf(ty.getFloatSemantics()), ty);
+ return b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, value, inf);
+ }
+
+ assert(ty.getIntOrFloatBitWidth() == 8);
+ if (!ty.isFloat8E5M2()) {
+ // F8E5M2 is the only 8 bit float with infinities.
+ return b.create<ma::ConstantIntOp>(false, b.getI1Type());
+ }
+ Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
+ return (bits & 0x7F) == 0x7C;
+}
+
+Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
+ auto ty = value.getType();
+ if (mlir::LLVM::isCompatibleOuterType(ty)) {
+ return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value);
+ }
+
+ assert(ty.getIntOrFloatBitWidth() == 8);
+ Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
+ if (ty.isFloat8E5M2() || ty.isFloat8E4M3FN()) {
+ return (bits & 0x7F) == 0x7F;
+ }
+ return bits == 0x80;
+}
+
+Value EmitReducePrecision(Value value, int exponent_bits, int mantissa_bits,
+ mlir::ImplicitLocOpBuilder& b) {
+ mlir::mhlo::ReducePrecisionOp::Properties properties;
+ properties.exponent_bits = b.getI32IntegerAttr(exponent_bits);
+ properties.mantissa_bits = b.getI32IntegerAttr(mantissa_bits);
+ return mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType<
+ mlir::mhlo::ReducePrecisionOp>(
+ b.getLoc(), value.getType(), {value.getType()},
+ mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), &b);
+}
+
+Value EmitF16ToF8e5m2(Value in, mlir::ImplicitLocOpBuilder& b) {
+ Val in_bits{b.create<ma::BitcastOp>(b.getI16Type(), in), &b};
+ // Use this method of checking for NaN because it's the same as what's used
+ // in the reduce precision lowering.
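+ // 32767 == 0x7fff drops the sign bit; 31744 == 0x7c00 is an f16 with all
+ // exponent bits set and a zero mantissa (i.e. infinity), so anything
+ // strictly greater is a NaN.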
+ Value is_nan = (in_bits & 32767).cmp(ma::CmpIPredicate::ugt, 31744);
+
+ Value value = EmitReducePrecision(in, 5, 2, b);
+ value = b.create<ma::BitcastOp>(b.getI16Type(), value);
+ value = b.create<ma::ShRUIOp>(value,
+ b.create<ma::ConstantIntOp>(8, b.getI16Type()));
+ value = b.create<ma::TruncIOp>(b.getI8Type(), value);
+ // When the input is NaN, just truncating can turn a NaN into an inf if the
+ // mantissa becomes 0.
+ value = b.create<ma::SelectOp>(
+ is_nan, b.create<ma::ConstantIntOp>(0x7F, value.getType()), value);
+ return b.create<ma::BitcastOp>(b.getFloat8E5M2Type(), value);
+}
+
+Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
+ mlir::ImplicitLocOpBuilder& b) {
+ using ma::CmpIPredicate;
+
+ // This is a port of ConvertImpl in
+ // https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h
+ auto from_ty = mlir::cast<mlir::FloatType>(value.getType());
+ if (to_ty == b.getFloat8E5M2Type() && from_ty == b.getF16Type()) {
+ return EmitF16ToF8e5m2(value, b);
+ }
+
+ int from_mantissa = GetSignificandBits(from_ty);
+ int from_bias = GetExponentBias(from_ty);
+ int from_min_exp =
+ llvm::APFloat::semanticsMinExponent(from_ty.getFloatSemantics());
+ int from_max_exp =
+ llvm::APFloat::semanticsMaxExponent(from_ty.getFloatSemantics());
+ auto from_int_ty = b.getIntegerType(from_ty.getIntOrFloatBitWidth());
+
+ int to_mantissa = GetSignificandBits(to_ty);
+ int to_bias = GetExponentBias(to_ty);
+ int to_min_exp =
+ llvm::APFloat::semanticsMinExponent(to_ty.getFloatSemantics());
+ int to_max_exp =
+ llvm::APFloat::semanticsMaxExponent(to_ty.getFloatSemantics());
+ auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth());
+
+ mlir::IntegerType wide_int_ty;
+ if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) {
+ wide_int_ty = b.getI16Type();
+ } else {
+ wide_int_ty = b.getIntegerType(
+ std::max(from_int_ty.getWidth(), to_int_ty.getWidth()));
+ }
+ auto convert_int = [&](mlir::Type ty, Value v) -> Val {
+ if (v.getType() == ty) {
+ return {v, &b};
+ }
+ if (ty.getIntOrFloatBitWidth() < v.getType().getIntOrFloatBitWidth()) {
+ return {b.create<ma::TruncIOp>(ty, v), &b};
+ }
+ return {b.create<ma::ExtUIOp>(ty, v), &b};
+ };
+
+ int64_t exp_offset = to_bias - from_bias;
+ int digit_shift = to_mantissa - from_mantissa;
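+  // For example, for an f16 -> f8e4m3fn conversion: from_mantissa = 10,
+  // to_mantissa = 3, so digit_shift = -7; from_bias = 15, to_bias = 7, so
+  // exp_offset = -8.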
+
+ Val from_bits{
+ b.create<ma::BitcastOp>(
+ b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value),
+ &b};
+
+ auto cst = [&](mlir::Type ty, int64_t n) -> Val {
+ return {b.create<ma::ConstantIntOp>(n, ty), &b};
+ };
+
+ // Shift bits to destination type, without sign bit.
+ Val from_sign_bit =
+ from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0;
+
+ from_bits =
+ from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1);
+
+ Value result_is_inf = IsInf(value, b);
+ Value input_is_nan = IsNaN(value, b);
+
+ auto cst_bits = [&](llvm::APFloat f) {
+ return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())),
+ f.bitcastToAPInt().getZExtValue());
+ };
+ Value to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics()));
+ Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics()));
+ Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics()));
+
+ auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
+ assert(bits.value.getType() == roundoff.value.getType());
+ // Round to nearest even by adding a bias term.
+ // Consider a bit pattern
+ // FFF...FLRTT...T,
+ // where bits RTT...T need to be rounded-off. We add a bias term to the
+ // bit pattern s.t. a carry is introduced to round up only if
+ // - L is 1, R is 1, OR
+ // - L is 0, R is 1, any T is one.
+ // We do this by adding L to a bit pattern consisting of all T = 1.
+ Val rounded = (bits.shrui(roundoff) & 1) +
+ (bits.MakeConstant(1).shl(roundoff - 1) - 1);
+ Val bias{b.create<SelectOp>(roundoff == 0, roundoff, rounded), &b};
+ return bits + bias;
+ };
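+  // Worked example: bits = 0b0110 with roundoff = 2, i.e. L = 1 and RT = 10
+  // (a tie at an odd L). bias = L + 0b01 = 0b10, bits + bias = 0b1000, and the
+  // caller's subsequent shift by `roundoff` yields 0b10: rounded up to even.
+  // For bits = 0b0010 (L = 0), bias = 0b01 gives 0b0011, which shifts down to
+  // 0b00: the tie rounds down to the even value.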
+
+ // Happy path: no subnormals, infinities or NaNs.
+ Value result;
+ {
+ // Round the mantissa if it is shrinking.
+ Val rounded_from_bits = convert_int(wide_int_ty, from_bits);
+ if (digit_shift < 0) {
+      rounded_from_bits = round_bits_to_nearest_even(
+                              rounded_from_bits,
+                              rounded_from_bits.MakeConstant(-digit_shift)) &
+                          ~((1ll << (-digit_shift)) - 1);
+ }
+
+ // Re-bias the exponent.
+ rounded_from_bits = rounded_from_bits + (exp_offset << from_mantissa);
+
+ // Check for overflows by aligning the significands. We always align the
+ // narrower significand to the wider significand.
+ int64_t to_highest = llvm::APFloat::getLargest(to_ty.getFloatSemantics())
+ .bitcastToAPInt()
+ .getZExtValue();
+ int64_t aligned_highest = to_highest;
+ if (digit_shift < 0) {
+ aligned_highest <<= -digit_shift;
+ // Shift down, all dropped bits should already be zero.
+ result = rounded_from_bits.shrui(-digit_shift);
+ } else {
+ // Shift up, inserting zeros in the newly created digits.
+ rounded_from_bits = rounded_from_bits.shl(digit_shift);
+ result = rounded_from_bits;
+ }
+ result = convert_int(to_int_ty, result);
+
+    // `From` supports larger values than `To`, so we may overflow.
+ if (std::make_pair(to_max_exp, to_mantissa) <
+ std::make_pair(from_max_exp, from_mantissa)) {
+ result = b.create<SelectOp>(
+ rounded_from_bits.cmp(CmpIPredicate::ugt, aligned_highest), to_inf,
+ result);
+ }
+ }
+
+ auto i32_ty = b.getI32Type();
+ Val biased_from_exp = convert_int(i32_ty, from_bits.shrui(from_mantissa));
+
+ if (to_min_exp < from_min_exp) {
+ // `To` supports more exponents near zero which means that some subnormal
+ // values in `From` may become normal.
+
+ // Subnormals.
+ Val bits = convert_int(wide_int_ty, from_bits);
+
+ // Determine exponent in target type.
+ Value normalization_factor =
+ convert_int(i32_ty,
+ b.create<mlir::math::CountLeadingZerosOp>(from_bits)) -
+ (from_int_ty.getWidth() - from_mantissa - 1);
+
+ Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor;
+ // If the result is subnormal, adjust the subnormal bits to account for
+ // the difference in exponent bias.
+ Value subnormal_bits = bits;
+ if (exp_offset < wide_int_ty.getWidth()) {
+ subnormal_bits = bits.shl(exp_offset);
+ }
+
+ // Result is normal. Shift the mantissa to account for the number of
+ // leading zero digits, and clear the hidden bit.
+ // Insert the exponent bits.
+ Value normal_bits =
+ (bits.shl(convert_int(wide_int_ty, normalization_factor)) &
+ ~(1 << from_mantissa)) |
+ convert_int(wide_int_ty, biased_exponent).shl(from_mantissa);
+
+ Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0);
+ bits.value =
+ b.create<SelectOp>(biased_exp_sle_zero, subnormal_bits, normal_bits);
+ if (digit_shift > 0) {
+ bits = bits.shl(digit_shift);
+ } else {
+ bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift));
+ bits = bits.shrui(-digit_shift);
+ }
+ bits = convert_int(to_int_ty, bits);
+
+ result = b.create<SelectOp>(biased_from_exp == 0, bits, result);
+ } else if (to_min_exp > from_min_exp) {
+ // `To` supports fewer exponents near zero which means that some values in
+ // `From` may become subnormal.
+ Val unbiased_exp = biased_from_exp - from_bias;
+ Val biased_to_exp = unbiased_exp + to_bias;
+ // Subnormals and zero.
+ // Round and shift mantissa down.
+ Val from_has_leading_one = biased_from_exp != 0;
+ Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one);
+ from_has_leading_one = convert_int(from_int_ty, from_has_leading_one);
+ Val exponent_shift_i32 =
+ (from_has_leading_one_i32 - biased_to_exp) - digit_shift;
+ // Insert the implicit leading 1 bit on the mantissa for normalized
+ // inputs.
+ Val rounded_from_bits = (from_bits & ((1ll << from_mantissa) - 1)) |
+ from_has_leading_one.shl(from_mantissa);
+
+ // NOTE: we need to round again from the original from_bits,
+ // otherwise the lower precision bits may already be lost. There is
+ // an edge-case where rounding to a normalized value would normally
+ // round down, but for a subnormal, we need to round up.
+ Val exponent_shift_from_ty = convert_int(from_int_ty, exponent_shift_i32);
+ Val exponent_shift_to_ty = convert_int(to_int_ty, exponent_shift_i32);
+ Val positive_bits = convert_int(
+ to_int_ty,
+ round_bits_to_nearest_even(rounded_from_bits, exponent_shift_from_ty)
+ .shrui(exponent_shift_from_ty));
+ // To avoid UB, limit rounding and shifting to the full mantissa plus
+ // leading 1.
+ positive_bits.value = b.create<SelectOp>(
+ exponent_shift_i32.cmp(CmpIPredicate::sle, from_mantissa + 1),
+ positive_bits, to_zero);
+
+ Val negative_bits = convert_int(to_int_ty, rounded_from_bits)
+ .shl(to_zero - exponent_shift_to_ty);
+ Value bits =
+ b.create<SelectOp>(exponent_shift_i32.cmp(CmpIPredicate::sgt, 0),
+ positive_bits, negative_bits);
+ result = b.create<SelectOp>(biased_to_exp.cmp(CmpIPredicate::sle, 0), bits,
+ result);
+ }
+
+  // Handle the 'nuz' types, which have an unsigned zero and no negative zero.
+ auto is_nuz = [](mlir::FloatType ty) {
+ return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() ||
+ ty.isFloat8E5M2FNUZ();
+ };
+
+ if (is_nuz(to_ty)) {
+ // Clear the sign bit if the result is zero (the output has no negative
+ // zero).
+ Val result_is_non_zero = Val{result, &b} != 0;
+ from_sign_bit = from_sign_bit & result_is_non_zero;
+ } else if (is_nuz(from_ty)) {
+ // Clear the sign bit if the input is NaN (it's positive but encoded as
+ // negative 0).
+ from_sign_bit = from_sign_bit ^ input_is_nan;
+ }
+
+ result = b.create<SelectOp>(result_is_inf, to_inf, result);
+ result = b.create<SelectOp>(from_bits == 0, to_zero, result);
+ result = b.create<SelectOp>(input_is_nan, to_nan, result);
+
+ Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));
+
+ // Insert sign bit.
+ result = b.create<SelectOp>(from_sign_bit, neg_result, result);
+ result = b.create<ma::BitcastOp>(to_ty, result);
+ return result;
+}
+
+struct RewriteTruncFPattern : public mlir::OpRewritePattern<ma::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ ma::TruncFOp op, mlir::PatternRewriter& rewriter) const override {
+ using FloatValue = mlir::TypedValue<mlir::FloatType>;
+ auto src = mlir::cast<FloatValue>(op.getOperand());
+ auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
+ if (dst_ty.getWidth() != 8) {
+ return rewriter.notifyMatchFailure(op, "not an 8 bit truncf");
+ }
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
+ return mlir::success();
+ }
+};
+
+struct RewriteExtFPattern : public mlir::OpRewritePattern<ma::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ ma::ExtFOp op, mlir::PatternRewriter& rewriter) const override {
+ using FloatValue = mlir::TypedValue<mlir::FloatType>;
+ auto src = mlir::cast<FloatValue>(op.getOperand());
+ auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
+ if (src.getType().getWidth() != 8) {
+ return rewriter.notifyMatchFailure(op, "not an 8 bit extf");
+ }
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
+ return mlir::success();
+ }
+};
+
+// Lowering for cmpf with f8 operands (float-to-predicate comparisons).
+struct RewriteF8Cst : public mlir::OpRewritePattern<ma::CmpFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ ma::CmpFOp op, mlir::PatternRewriter& rewriter) const override {
+ using FloatValue = mlir::TypedValue<mlir::FloatType>;
+ auto lhs = mlir::cast<FloatValue>(op.getLhs());
+ auto rhs = mlir::cast<FloatValue>(op.getRhs());
+
+ if (lhs.getType().getWidth() != 8) {
+ return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf");
+ }
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    // Skip the f32 conversion if we're comparing UNE against a constant.
+ llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics());
+ if (op.getPredicate() == ma::CmpFPredicate::UNE &&
+ mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) {
+ Val int_value{b.create<ma::BitcastOp>(rewriter.getI8Type(), lhs), &b};
+ int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue();
+ // If we're comparing to +-0, compare the absolute values.
+ if (rhs_cst.isZero() &&
+ (lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) {
+ int_value = int_value & 0x7f;
+ constant &= 0x7f;
+ }
+ auto cst = b.create<ma::ConstantIntOp>(constant, rewriter.getI8Type());
+ rewriter.replaceOpWithNewOp<ma::CmpIOp>(op, ma::CmpIPredicate::ne,
+ int_value, cst);
+ return mlir::success();
+ }
+
+ auto lhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), lhs);
+ auto rhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), rhs);
+ rewriter.replaceOpWithNewOp<ma::CmpFOp>(op, op->getResultTypes(),
+ mlir::ValueRange{lhs_ext, rhs_ext},
+ op->getAttrs());
+ return mlir::success();
+ }
+};
+
+struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::math::AbsFOp op, mlir::PatternRewriter& rewriter) const override {
+ using FloatValue = mlir::TypedValue<mlir::FloatType>;
+ auto src = mlir::cast<FloatValue>(op.getOperand());
+ // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16.
+ // Once that's removed, remove the code for BF16 here.
+ if (src.getType().getWidth() != 8 && !src.getType().isBF16()) {
+ return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf");
+ }
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
+ Val value{b.create<ma::BitcastOp>(i_ty, src), &b};
+ if (src.getType().getWidth() == 8) {
+ value = value & 0x7f;
+ } else {
+ CHECK(src.getType().isBF16());
+ value = value & 0x7fff;
+ }
+ rewriter.replaceOpWithNewOp<ma::BitcastOp>(op, src.getType(), value);
+ return mlir::success();
+ }
+};
+
+template <typename Op>
+struct RewriteIToFpPattern : public mlir::OpRewritePattern<Op> {
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ Op op, mlir::PatternRewriter& rewriter) const override {
+ if (op.getType().getIntOrFloatBitWidth() != 8) {
+ return rewriter.notifyMatchFailure(op, "not an f8 itofp");
+ }
+ Value to_float =
+ rewriter.create<Op>(op.getLoc(), rewriter.getF32Type(), op.getIn());
+ rewriter.replaceOpWithNewOp<ma::TruncFOp>(op, op.getType(), to_float);
+ return mlir::success();
+ }
+};
+
+template <typename Op>
+struct RewriteFpToIPattern : public mlir::OpRewritePattern<Op> {
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ Op op, mlir::PatternRewriter& rewriter) const override {
+ if (op.getIn().getType().getIntOrFloatBitWidth() != 8) {
+ return rewriter.notifyMatchFailure(op, "not an f8 fptoi");
+ }
+ Value to_f32 = rewriter.create<ma::ExtFOp>(
+ op.getLoc(), rewriter.getF32Type(), op.getIn());
+ rewriter.replaceOpWithNewOp<Op>(op, op.getType(), to_f32);
+ return mlir::success();
+ }
+};
+
+class ExpandFloatOpsPass
+ : public impl::ExpandFloatOpsPassBase<ExpandFloatOpsPass> {
+ public:
+ using ExpandFloatOpsPassBase::ExpandFloatOpsPassBase;
+ void runOnOperation() override {
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<RewriteToCmpSelect<ma::MinimumFOp, ma::CmpFPredicate::OLE>>(
+ &getContext(), /*include_f32=*/pre_ampere_);
+ patterns.add<RewriteToCmpSelect<ma::MaximumFOp, ma::CmpFPredicate::OGE>>(
+ &getContext(), /*include_f32=*/pre_ampere_);
+ patterns.add<RewriteTruncFPattern, RewriteExtFPattern, RewriteAbsFPattern,
+ RewriteF8Cst, RewriteIToFpPattern<ma::SIToFPOp>,
+ RewriteIToFpPattern<ma::UIToFPOp>,
+ RewriteFpToIPattern<ma::FPToSIOp>,
+ RewriteFpToIPattern<ma::FPToUIOp>>(&getContext());
+ mlir::populatePolynomialApproximateTanhPattern(patterns);
+ patterns.add<RewriteErf32Pattern>(&getContext());
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere) {
+ return createExpandFloatOpsPass(ExpandFloatOpsPassOptions{pre_ampere});
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc
new file mode 100644
index 0000000..a60a003
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc
@@ -0,0 +1,452 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/shape_util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_FLATTENTENSORSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::Location;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::RankedTensorType;
+using mlir::SmallVector;
+using mlir::Type;
+using mlir::TypedValue;
+using mlir::TypeRange;
+using mlir::UnrealizedConversionCastOp;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::func::FuncOp;
+using mlir::func::ReturnOp;
+using mlir::scf::ForOp;
+using mlir::scf::IfOp;
+using mlir::tensor::ExtractOp;
+using mlir::tensor::InsertOp;
+
+RankedTensorType GetFlattenedType(RankedTensorType tensor_type) {
+ return RankedTensorType::get({tensor_type.getNumElements()},
+ tensor_type.getElementType());
+}
+
+bool HasOnlyFlatTensorsOrScalars(TypeRange types) {
+ return llvm::all_of(types, [](Type ty) {
+ auto tensor_type = mlir::dyn_cast<RankedTensorType>(ty);
+ if (!tensor_type) return true;
+ return tensor_type.getRank() < 2;
+ });
+}
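+
+// Rewrites function signatures so that all tensor arguments and results are
+// rank <= 1, inserting unrealized_conversion_casts at the boundary. For
+// example (illustrative types only), `(tensor<4x8xf32>) -> tensor<2x3xf32>`
+// becomes `(tensor<32xf32>) -> tensor<6xf32>`.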
+
+struct RewriteFunctionSignatures : OpRewritePattern<FuncOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(FuncOp op,
+ PatternRewriter& rewriter) const override {
+ auto input_types = op.getFunctionType().getInputs();
+ auto result_types = op.getFunctionType().getResults();
+ if (HasOnlyFlatTensorsOrScalars(input_types) &&
+ HasOnlyFlatTensorsOrScalars(result_types)) {
+ return rewriter.notifyMatchFailure(op, "nothing to flatten");
+ }
+
+ auto loc = op.getLoc();
+ mlir::Block* entry_block = &op.getBody().front();
+ SmallVector<Type> new_result_types;
+ SmallVector<Value> new_results;
+
+ // If some results are tensors, we need to flatten them.
+ auto terminator = entry_block->getTerminator();
+ rewriter.setInsertionPoint(terminator);
+
+ for (Value result : terminator->getOperands()) {
+ auto tensor_type = mlir::dyn_cast<RankedTensorType>(result.getType());
+ if (!tensor_type) {
+ new_result_types.push_back(result.getType());
+ new_results.push_back(result);
+ continue;
+ }
+ auto new_result_type = GetFlattenedType(tensor_type);
+ new_result_types.push_back(new_result_type);
+
+ Value result_1d =
+ rewriter
+ .create<UnrealizedConversionCastOp>(loc, new_result_type, result)
+ .getResult(0);
+ new_results.push_back(result_1d);
+ }
+ rewriter.replaceOpWithNewOp<ReturnOp>(terminator, new_results);
+
+ // Cast all function arguments to the original type.
+ SmallVector<Type> new_operand_types(input_types);
+ rewriter.setInsertionPointToStart(entry_block);
+ for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) {
+ if (auto tensor_type = mlir::dyn_cast<RankedTensorType>(operand_type)) {
+ if (tensor_type.getRank() > 1) {
+ mlir::BlockArgument func_argument = op.getArgument(index);
+ auto cast_to_orig_type = rewriter.create<UnrealizedConversionCastOp>(
+ loc, operand_type, func_argument);
+ func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0),
+ cast_to_orig_type);
+ operand_type = GetFlattenedType(tensor_type);
+ }
+ }
+ }
+ // Replace the function arguments with the new types.
+ for (auto [arg, arg_type] :
+ llvm::zip(entry_block->getArguments(), new_operand_types)) {
+ arg.setType(arg_type);
+ }
+ // Update function signature.
+ op.setType(rewriter.getFunctionType(new_operand_types, new_result_types));
+ return mlir::success();
+ }
+};
+
+// Returns the linearized index, if the rank is greater than 1. Otherwise,
+// returns nullptr.
+Value LinearizeIndex(TypedValue<mlir::RankedTensorType> tensor,
+ ValueRange indices, PatternRewriter& rewriter) {
+ if (tensor.getType().getRank() < 2) {
+ return nullptr;
+ }
+ auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
+ if (auto encoding = tensor.getType().getEncoding()) {
+ *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
+ mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
+ }
+ auto linear_shape =
+ ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
+ auto linearized_map =
+ GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
+ mlir::SmallVector<Value> result;
+ rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
+ ValueRange{}, linearized_map);
+ return result.front();
+}
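+
+// For example, for a tensor<4x8xf32> with the default layout, LinearizeIndex
+// maps indices (i, j) to the single index i * 8 + j into tensor<32xf32>.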
+
+struct RewriteTensorExtract : OpRewritePattern<ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp op,
+ PatternRewriter& rewriter) const override {
+ auto tensor = op.getTensor();
+ auto tensor_type = tensor.getType();
+ auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
+ if (linear_index == nullptr) {
+ return rewriter.notifyMatchFailure(op, "the tensor is already flat");
+ }
+ auto tensor_1D = rewriter
+ .create<UnrealizedConversionCastOp>(
+ op.getLoc(), GetFlattenedType(tensor_type), tensor)
+ .getResult(0);
+ rewriter.replaceOpWithNewOp<ExtractOp>(op, tensor_1D, linear_index);
+ return mlir::success();
+ }
+};
+
+struct RewriteTensorInsert : OpRewritePattern<InsertOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter& rewriter) const override {
+ auto tensor = op.getDest();
+ auto tensor_type = tensor.getType();
+ auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
+ if (linear_index == nullptr) {
+ return rewriter.notifyMatchFailure(op, "the tensor is already flat");
+ }
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto tensor_1D = b.create<UnrealizedConversionCastOp>(
+ GetFlattenedType(tensor_type), tensor)
+ .getResult(0);
+ auto new_insert =
+ b.create<InsertOp>(op.getScalar(), tensor_1D, linear_index);
+ auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
+ tensor_type, new_insert.getResult());
+ rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
+ return mlir::success();
+ }
+};
+
+struct RewriteAtomicRMW : OpRewritePattern<AtomicRMWOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AtomicRMWOp op,
+ PatternRewriter& rewriter) const override {
+ auto tensor = op.getInput();
+ auto tensor_type = tensor.getType();
+ auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
+ if (linear_index == nullptr) {
+ return rewriter.notifyMatchFailure(op, "the tensor is already flat");
+ }
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto tensor_1D = b.create<UnrealizedConversionCastOp>(
+ GetFlattenedType(tensor_type), tensor)
+ .getResult(0);
+ auto new_atomic_rmw = b.create<AtomicRMWOp>(tensor_1D, linear_index);
+ rewriter.inlineRegionBefore(op.getRegion(),
+ &new_atomic_rmw.getRegion().front());
+ auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
+ tensor_type, new_atomic_rmw.getResult());
+ rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
+ return mlir::success();
+ }
+};
+
+// Checks whether the value is produced by an unrealized conversion cast from
+// a 1D tensor to an ND one. Returns the 1D tensor if so, otherwise nullopt.
+std::optional<Value> GetDelinearizedTensor(Value value) {
+ auto tensor_type = mlir::dyn_cast<RankedTensorType>(value.getType());
+ if (!tensor_type || tensor_type.getRank() < 2) {
+ return std::nullopt;
+ }
+ auto cast = value.getDefiningOp<UnrealizedConversionCastOp>();
+ if (!cast || cast->getNumResults() != 1 || cast->getNumOperands() != 1) {
+ return std::nullopt;
+ }
+ auto type_before_linearization =
+ mlir::dyn_cast<RankedTensorType>(cast->getOperand(0).getType());
+ if (!type_before_linearization || type_before_linearization.getRank() != 1) {
+ return std::nullopt;
+ }
+ return cast->getOperand(0);
+}
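+
+// For example, if %nd is defined by
+//   %nd = builtin.unrealized_conversion_cast %flat
+//     : tensor<32xf32> to tensor<4x8xf32>
+// then GetDelinearizedTensor(%nd) returns %flat.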
+
+struct RewriteForOp : public OpRewritePattern<ForOp> {
+ using OpRewritePattern<ForOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForOp op,
+ PatternRewriter& rewriter) const override {
+ llvm::SmallBitVector args_to_update(op.getNumResults(), false);
+ mlir::SmallVector<Value> new_init_args;
+ new_init_args.reserve(op.getNumResults());
+ for (auto [index, arg] : llvm::enumerate(op.getInitArgs())) {
+ auto type_before_linearization = GetDelinearizedTensor(arg);
+ if (!type_before_linearization.has_value()) {
+ new_init_args.push_back(arg);
+ continue;
+ }
+ new_init_args.push_back(*type_before_linearization);
+ args_to_update.set(index);
+ }
+ if (args_to_update.none()) {
+ return rewriter.notifyMatchFailure(op, "no args to update");
+ }
+ // Create new ForOp with updated init args.
+ Location loc = op.getLoc();
+ auto new_for_op =
+ rewriter.create<ForOp>(loc, op.getLowerBound(), op.getUpperBound(),
+ op.getStep(), new_init_args);
+ new_for_op->setAttrs(op->getAttrs());
+
+ // Insert casts for the block arguments.
+ mlir::Block* new_body = new_for_op.getBody();
+ mlir::Block* old_body = op.getBody();
+ rewriter.setInsertionPoint(new_body, new_body->begin());
+ SmallVector<Value, 4> updated_block_args{new_body->getArguments().begin(),
+ new_body->getArguments().end()};
+ for (auto [index, arg] :
+ llvm::enumerate(new_body->getArguments().drop_front())) {
+ if (!args_to_update.test(index)) continue;
+ updated_block_args[index + 1] =
+ rewriter
+ .create<UnrealizedConversionCastOp>(
+ loc, old_body->getArgument(index + 1).getType(), arg)
+ .getResult(0);
+ }
+
+ // Move the body of the old ForOp to the new one.
+ rewriter.mergeBlocks(old_body, new_body, updated_block_args);
+
+ // Update the terminator.
+ auto new_terminator =
+ mlir::cast<mlir::scf::YieldOp>(new_body->getTerminator());
+ rewriter.setInsertionPoint(new_terminator);
+ for (auto&& [index, yielded_value] :
+ llvm::enumerate(new_terminator.getResultsMutable())) {
+ if (!args_to_update.test(index)) continue;
+ yielded_value.assign(
+ rewriter
+ .create<UnrealizedConversionCastOp>(
+ loc, new_init_args[index].getType(), yielded_value.get())
+ .getResult(0));
+ }
+
+ // Cast back the results.
+ rewriter.setInsertionPointAfter(new_for_op);
+ SmallVector<Value> new_results(new_for_op.getResults());
+ for (auto&& [index, result] : llvm::enumerate(new_results)) {
+ if (!args_to_update.test(index)) continue;
+ result = rewriter
+ .create<UnrealizedConversionCastOp>(
+ loc, op->getResult(index).getType(), result)
+ .getResult(0);
+ }
+ rewriter.replaceOp(op, new_results);
+    return mlir::success();
+ }
+};
+
+struct RewriteIfOp : public OpRewritePattern<IfOp> {
+ using OpRewritePattern<IfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IfOp op,
+ PatternRewriter& rewriter) const override {
+ auto result_types = op.getResultTypes();
+ if (HasOnlyFlatTensorsOrScalars(result_types)) {
+ return rewriter.notifyMatchFailure(op, "nothing to flatten");
+ }
+ mlir::scf::YieldOp then_yield = op.thenYield();
+ SmallVector<Type> new_result_types;
+ new_result_types.reserve(then_yield.getNumOperands());
+ bool found_cast = false;
+ for (auto& result : then_yield->getOpOperands()) {
+ auto delinearized_tensor = GetDelinearizedTensor(result.get());
+ if (!delinearized_tensor.has_value()) {
+ new_result_types.push_back(result.get().getType());
+ continue;
+ }
+ new_result_types.push_back(delinearized_tensor->getType());
+ result.set(*delinearized_tensor);
+ found_cast = true;
+ }
+ if (!found_cast) {
+ return rewriter.notifyMatchFailure(op, "no cast found");
+ }
+ Location loc = op.getLoc();
+ // Update the else branch if present.
+ bool has_else_region = !op.getElseRegion().empty();
+ if (has_else_region) {
+ mlir::scf::YieldOp else_yield = op.elseYield();
+ mlir::OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(else_yield);
+ for (auto&& [result, type] :
+ llvm::zip(else_yield->getOpOperands(), new_result_types)) {
+ if (result.get().getType() == type) continue;
+ result.set(
+ rewriter.create<UnrealizedConversionCastOp>(loc, type, result.get())
+ .getResult(0));
+ }
+ }
+ // Create new IfOp and move the old op's regions to the new one.
+ auto new_if_op = rewriter.create<IfOp>(loc, new_result_types,
+ op.getCondition(), has_else_region);
+ rewriter.inlineRegionBefore(op.getThenRegion(),
+ &new_if_op.getThenRegion().back());
+ rewriter.eraseBlock(&new_if_op.getThenRegion().back());
+ if (has_else_region) {
+ rewriter.inlineRegionBefore(op.getElseRegion(),
+ &new_if_op.getElseRegion().back());
+ rewriter.eraseBlock(&new_if_op.getElseRegion().back());
+ }
+
+ // Update the results.
+ rewriter.setInsertionPointAfter(new_if_op);
+ SmallVector<Value> new_results(new_if_op.getResults());
+ for (auto&& [index, result] : llvm::enumerate(new_results)) {
+ Type old_type = op->getResult(index).getType();
+ if (result.getType() == old_type) continue;
+ result =
+ rewriter.create<UnrealizedConversionCastOp>(loc, old_type, result)
+ .getResult(0);
+ }
+ rewriter.replaceOp(op, new_results);
+ return mlir::success();
+ }
+};
+
+class FlattenTensorsPass
+ : public impl::FlattenTensorsPassBase<FlattenTensorsPass> {
+ public:
+ void runOnOperation() override {
+ mlir::ModuleOp module = getOperation();
+ MLIRContext* mlir_context = &getContext();
+ mlir::RewritePatternSet patterns(mlir_context);
+ // clang-format off
+ patterns.add<
+ RewriteAtomicRMW,
+ RewriteForOp,
+ RewriteFunctionSignatures,
+ RewriteIfOp,
+ RewriteTensorExtract,
+ RewriteTensorInsert
+ >(mlir_context);
+ // clang-format on
+ ApplyIndexingOp::getCanonicalizationPatterns(patterns, mlir_context);
+ if (mlir::failed(
+ mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
+ signalPassFailure();
+ return;
+ }
+    // Check that no unrealized_conversion_casts remain.
+ bool module_has_casts = module
+ .walk([](UnrealizedConversionCastOp op) {
+ return mlir::WalkResult::interrupt();
+ })
+ .wasInterrupted();
+ if (module_has_casts) {
+ llvm::outs() << "FlattenTensorsPass failed to converge";
+ signalPassFailure();
+ return;
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass() {
+ return std::make_unique<FlattenTensorsPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc
new file mode 100644
index 0000000..efccaaa
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc
@@ -0,0 +1,1095 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cassert>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/check.h"
+#include "absl/strings/str_cat.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DECL_LOWERTENSORSPASS
+#define GEN_PASS_DEF_LOWERTENSORSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::failure;
+using mlir::Location;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpBuilder;
+using mlir::Operation;
+using mlir::success;
+using mlir::Type;
+using mlir::TypedValue;
+using mlir::TypeRange;
+using mlir::Value;
+using mlir::ValueRange;
+
+namespace arith = ::mlir::arith;
+namespace scf = ::mlir::scf;
+namespace ml = ::mlir::LLVM;
+
+Value GetDestinationBuffer(Value dest) {
+ while (dest.getDefiningOp()) {
+ int result_number = mlir::cast<mlir::OpResult>(dest).getResultNumber();
+ if (auto insert = dest.getDefiningOp<mlir::tensor::InsertOp>()) {
+ dest = insert.getDest();
+ } else if (auto scf_if = dest.getDefiningOp<scf::IfOp>()) {
+ // Pick one of the branches, they're required to yield the same buffers.
+ dest = scf_if.getThenRegion().front().getTerminator()->getOperand(
+ result_number);
+ } else if (auto scf_for = dest.getDefiningOp<scf::ForOp>()) {
+ dest = scf_for.getInitArgs()[result_number];
+ } else if (dest.getDefiningOp<mlir::UnrealizedConversionCastOp>() ||
+ dest.getDefiningOp<AllocateSharedOp>()) {
+ break;
+ } else if (auto transfer_write =
+ dest.getDefiningOp<mlir::vector::TransferWriteOp>()) {
+ dest = transfer_write.getSource();
+ } else {
+ dest.getDefiningOp()->emitOpError("unsupported dest type");
+ return nullptr;
+ }
+ }
+ return dest;
+}
+
+template <typename Op>
+bool IsSupportedTransfer(Op op) {
+ return !absl::c_linear_search(op.getInBoundsValues(), false) &&
+ op.getVectorType().getRank() == 1 && !op.getMask() &&
+ op.getPermutationMap().isMinorIdentity();
+}
+
+struct RewriteFunctionSignatures : mlir::OpRewritePattern<mlir::func::FuncOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override {
+ auto is_tensor = [](Type ty) {
+ return mlir::isa<mlir::RankedTensorType>(ty);
+ };
+ if (!llvm::any_of(op.getFunctionType().getInputs(), is_tensor)) {
+ return rewriter.notifyMatchFailure(op,
+ "the function has no input tensors");
+ }
+
+ bool some_tensor_result =
+ llvm::any_of(op.getFunctionType().getResults(), is_tensor);
+ bool all_tensor_results =
+ llvm::all_of(op.getFunctionType().getResults(), is_tensor);
+ if (some_tensor_result && !all_tensor_results) {
+ op->emitOpError("function has a mix of tensor and non-tensor results");
+ return failure();
+ }
+
+ TypeRange new_results = op.getFunctionType().getResults();
+ if (some_tensor_result) {
+ new_results = {};
+ auto terminator = op.getFunctionBody().front().getTerminator();
+ rewriter.setInsertionPoint(terminator);
+ rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(terminator);
+ }
+
+ llvm::SmallVector<Type> new_operands(op.getFunctionType().getInputs());
+ for (auto&& [index, operand] : llvm::enumerate(new_operands)) {
+ if (is_tensor(operand)) {
+ rewriter.setInsertionPointToStart(&op.getBody().front());
+ auto cast = rewriter.create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), operand, op.getArgument(index));
+ op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast);
+ operand = mlir::LLVM::LLVMPointerType::get(op.getContext());
+ }
+ }
+
+ op.setFunctionType(rewriter.getFunctionType(new_operands, new_results));
+ auto& entry = op->getRegion(0).front();
+ for (auto [arg, arg_type] : llvm::zip(entry.getArguments(), new_operands)) {
+ arg.setType(arg_type);
+ }
+
+ return success();
+ }
+};
+
+Value GetLinearIndex(TypedValue<mlir::RankedTensorType> tensor,
+ ValueRange indices, mlir::PatternRewriter& rewriter) {
+ auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
+ if (auto encoding = tensor.getType().getEncoding()) {
+ *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
+ mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
+ }
+ auto linear_shape =
+ ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
+ auto linearize_map =
+ GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
+ mlir::SmallVector<Value> result;
+ rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
+ ValueRange{}, linearize_map);
+ CHECK_EQ(result.size(), 1);
+ auto index = result.front();
+ auto index_ty = rewriter.getIntegerType(
+ mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp())
+ .getTypeSizeInBits(index.getType()));
+ return rewriter.create<mlir::arith::IndexCastUIOp>(tensor.getLoc(), index_ty,
+ index);
+}
+
+std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
+ mlir::ImplicitLocOpBuilder& b) {
+ Value one = b.create<mlir::arith::ConstantIntOp>(1, linear_index.getType());
+ Value is_low_nibble = b.create<mlir::arith::CmpIOp>(
+ mlir::arith::CmpIPredicate::eq, one,
+ b.create<mlir::arith::AndIOp>(linear_index, one));
+ Value i8_index = b.create<mlir::arith::ShRUIOp>(linear_index, one);
+ return {i8_index, is_low_nibble};
+}
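+
+// For example, i4 element index 5 maps to byte index 2 with is_low_nibble set
+// (odd linear indices address the low nibble), while index 4 maps to byte
+// index 2 with is_low_nibble unset.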
+
+mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+ Value linear_index, mlir::PatternRewriter& rewriter,
+ Type element_type = nullptr) {
+ if (!element_type) {
+ element_type = tensor.getType().getElementType();
+ }
+ auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
+ auto tensor_ptr = rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ tensor.getLoc(), ptr, tensor)
+ .getResult(0);
+ mlir::LLVMTypeConverter converter(rewriter.getContext());
+ auto llvm_element_type = converter.convertType(element_type);
+ auto gep = rewriter.create<mlir::LLVM::GEPOp>(
+ tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, linear_index);
+ gep.setInbounds(true);
+ return gep;
+}
+
+mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+ ValueRange indices,
+ mlir::PatternRewriter& rewriter) {
+ return CreateGep(tensor, GetLinearIndex(tensor, indices, rewriter), rewriter);
+}
+
+struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ mlir::tensor::ExtractOp op,
+ mlir::PatternRewriter& rewriter) const override {
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto linear_index =
+ GetLinearIndex(op.getTensor(), op.getIndices(), rewriter);
+ Type element_type = op.getTensor().getType().getElementType();
+ Value is_low_nibble = nullptr;
+ if (element_type == rewriter.getI4Type()) {
+ element_type = rewriter.getI8Type();
+ std::tie(linear_index, is_low_nibble) =
+ GetI4IndexAndNibble(linear_index, b);
+ }
+
+ auto gep = CreateGep(op.getTensor(), linear_index, rewriter, element_type);
+ auto load =
+ rewriter
+ .create<mlir::LLVM::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
+ .getResult();
+
+ if (is_low_nibble) {
+ auto high_value = b.create<mlir::arith::ShRUIOp>(
+ load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
+ load = b.create<mlir::arith::TruncIOp>(
+ op.getType(),
+ b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
+ }
+
+ rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+ op, op.getType(), load);
+ return success();
+ }
+};
+
+// Swaps pairs of values in the vector: [0, 1, 2, 3] -> [1, 0, 3, 2].
+Value PermutePairsInVector(Value vector, mlir::ImplicitLocOpBuilder& b) {
+ // There is a `vector.extract_strided_slice` op that would be useful here, but
+ // it actually requires the strides to be 1.
+ auto ty = mlir::cast<mlir::VectorType>(vector.getType());
+ int size = ty.getNumElements();
+ Value result = vector;
+ for (int i = 0; i < size; i += 2) {
+ auto v0 = b.create<mlir::vector::ExtractOp>(vector, i);
+ auto v1 = b.create<mlir::vector::ExtractOp>(vector, i + 1);
+ result = b.create<mlir::vector::InsertOp>(v1, result, i);
+ result = b.create<mlir::vector::InsertOp>(v0, result, i + 1);
+ }
+ return result;
+}
+
+struct RewriteTransferRead
+ : mlir::OpRewritePattern<mlir::vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ mlir::vector::TransferReadOp op,
+ mlir::PatternRewriter& rewriter) const override {
+ assert(IsSupportedTransfer(op));
+
+ auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
+ op.getSource());
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto linear_index = GetLinearIndex(source, op.getIndices(), rewriter);
+
+ mlir::VectorType vector_type = op.getVectorType();
+ if (vector_type.getElementType().isInteger(1)) {
+ vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
+ }
+ mlir::Type gep_element_type = vector_type.getElementType();
+ if (op.getVectorType().getElementType().isInteger(4)) {
+ linear_index = b.create<arith::ShRUIOp>(
+ linear_index,
+ b.create<arith::ConstantIntOp>(1, linear_index.getType()));
+ gep_element_type = b.getI8Type();
+ }
+ auto gep = CreateGep(source, linear_index, rewriter, gep_element_type);
+
+ mlir::LLVMTypeConverter converter(b.getContext());
+ auto llvm_vector_type = converter.convertType(vector_type);
+ auto loaded =
+ b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
+
+ if (source.getType().getElementType().isInteger(1)) {
+ Value zero = b.create<mlir::arith::ConstantOp>(
+ mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
+ loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
+ } else if (source.getType().getElementType().isInteger(4)) {
+ // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
+ // elements.
+ loaded = PermutePairsInVector(loaded, b);
+ }
+
+ rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+ op, op.getType(), loaded);
+ return success();
+ }
+};
+
+struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ mlir::tensor::InsertOp op,
+ mlir::PatternRewriter& rewriter) const override {
+ Value dest = GetDestinationBuffer(op.getDest());
+ if (!dest) {
+ return failure();
+ }
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
+ auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
+ auto element_type = tensor_dest.getType().getElementType();
+ Value is_low_nibble = nullptr;
+
+ if (element_type == rewriter.getI4Type()) {
+ element_type = rewriter.getI8Type();
+ std::tie(linear_index, is_low_nibble) =
+ GetI4IndexAndNibble(linear_index, b);
+ }
+
+ auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
+ auto scalar_value = op.getScalar();
+
+ if (is_low_nibble) {
+ Value current_value =
+ b.create<mlir::LLVM::LoadOp>(gep.getElemType(), gep);
+ auto ty = current_value.getType();
+ scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);
+ Value low_updated = b.create<mlir::arith::OrIOp>(
+ b.create<mlir::arith::AndIOp>(
+ current_value, b.create<mlir::arith::ConstantIntOp>(0xf0, ty)),
+ scalar_value);
+ Value high_updated = b.create<mlir::arith::OrIOp>(
+ b.create<mlir::arith::AndIOp>(
+ current_value, b.create<mlir::arith::ConstantIntOp>(0x0f, ty)),
+ b.create<mlir::arith::ShLIOp>(
+ scalar_value, b.create<mlir::arith::ConstantIntOp>(4, ty)));
+ scalar_value = b.create<mlir::arith::SelectOp>(is_low_nibble, low_updated,
+ high_updated);
+ }
+
+ mlir::LLVMTypeConverter converter(getContext());
+ auto llvm_type = converter.convertType(scalar_value.getType());
+ scalar_value = rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ gep.getLoc(), llvm_type, scalar_value)
+ .getResult(0);
+ rewriter.create<mlir::LLVM::StoreOp>(gep.getLoc(), scalar_value, gep);
+
+ op.replaceAllUsesWith(op.getDest());
+ op.erase();
+ return success();
+ }
+};
+
+struct RewriteTransferWrite
+ : mlir::OpRewritePattern<mlir::vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ mlir::vector::TransferWriteOp op,
+ mlir::PatternRewriter& rewriter) const override {
+ assert(IsSupportedTransfer(op));
+ Value dest = GetDestinationBuffer(op.getSource());
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
+ auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
+ auto element_type = tensor_dest.getType().getElementType();
+
+ mlir::Value vector_value = op.getVector();
+ if (op.getVectorType().getElementType().isInteger(1)) {
+ vector_value = b.create<arith::ExtUIOp>(
+ op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
+ vector_value);
+ }
+ if (op.getVectorType().getElementType().isInteger(4)) {
+ linear_index = b.create<arith::ShRUIOp>(
+ linear_index,
+ b.create<arith::ConstantIntOp>(1, linear_index.getType()));
+ element_type = rewriter.getI8Type();
+ // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
+ // elements.
+ vector_value = PermutePairsInVector(vector_value, b);
+ }
+ auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
+
+ mlir::LLVMTypeConverter converter(getContext());
+ auto llvm_type = converter.convertType(vector_value.getType());
+ vector_value =
+ b.create<mlir::UnrealizedConversionCastOp>(llvm_type, vector_value)
+ .getResult(0);
+ b.create<mlir::LLVM::StoreOp>(vector_value, gep);
+
+ rewriter.replaceOp(op, mlir::ValueRange{op.getSource()});
+ return success();
+ }
+};
+
+struct RewriteCall : mlir::OpRewritePattern<mlir::func::CallOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ mlir::func::CallOp op, mlir::PatternRewriter& rewriter) const override {
+ if (!llvm::any_of(op->getOperandTypes(), [](Type ty) {
+ return mlir::isa<mlir::RankedTensorType>(ty);
+ })) {
+ return rewriter.notifyMatchFailure(op, "the call has no input tensors");
+ }
+
+ for (const auto&& [index, arg] : llvm::enumerate(op.getOperands())) {
+ if (mlir::isa<mlir::RankedTensorType>(arg.getType())) {
+ op.setOperand(
+ index,
+ rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(),
+ mlir::LLVM::LLVMPointerType::get(op.getContext()), arg)
+ .getResult(0));
+ }
+ }
+ return success();
+ }
+};
+
+mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
+ const std::string& name_prefix,
+ mlir::ShapedType shaped_ty,
+ mlir::ModuleOp module, bool is_constant,
+ int addr_space,
+ mlir::ImplicitLocOpBuilder& b) {
+ if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) {
+ // The lowering to LLVM only works for 1d tensors or those with trailing
+ // unit dimensions.
+ value = elements.reshape(mlir::RankedTensorType::get(
+ {elements.getNumElements()}, elements.getElementType()));
+ }
+
+ Type element_type = shaped_ty.getElementType();
+ // Needed to support complex element type.
+ mlir::LLVMTypeConverter converter(b.getContext());
+ auto llvm_element_type = converter.convertType(element_type);
+ auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type,
+ shaped_ty.getNumElements());
+ std::string name;
+ int index = 0;
+ do {
+ name = absl::StrCat(name_prefix, index);
+ ++index;
+ } while (module.lookupSymbol(name));
+ b.setInsertionPointToStart(module.getBody());
+ return b.create<mlir::LLVM::GlobalOp>(
+ array_ty, is_constant,
+ /*linkage=*/mlir::LLVM::Linkage::Private, name, value, /*alignment=*/0,
+ addr_space);
+}
+
+struct RewriteAllocateShared : mlir::OpRewritePattern<AllocateSharedOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ AllocateSharedOp op, mlir::PatternRewriter& rewriter) const override {
+ auto module = op->getParentOfType<mlir::ModuleOp>();
+ auto shaped_ty = mlir::cast<mlir::ShapedType>(op.getResult().getType());
+ constexpr int kGPUSharedMemoryAddrSpace = 3;
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+ auto global =
+ CreateGlobalOp(mlir::Attribute{}, "shared_", shaped_ty, module,
+ /*is_constant=*/false, kGPUSharedMemoryAddrSpace, b);
+
+ rewriter.setInsertionPoint(op);
+ auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+ rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+ op, op.getResult().getType(),
+ rewriter
+ .create<mlir::LLVM::AddrSpaceCastOp>(
+ op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
+ addr)
+ .getResult());
+ return success();
+ }
+};
+
+struct RewriteNonScalarConstants
+ : mlir::OpRewritePattern<mlir::arith::ConstantOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::arith::ConstantOp op,
+ mlir::PatternRewriter& rewriter) const override {
+ if (mlir::isa<mlir::VectorType>(op.getType())) {
+ return rewriter.notifyMatchFailure(op, "the op is a vector constant");
+ }
+ auto shaped_ty = mlir::dyn_cast<mlir::ShapedType>(op.getValue().getType());
+ // We only need to rewrite non-scalar constants.
+ if (!shaped_ty || shaped_ty.getNumElements() < 2) {
+ return rewriter.notifyMatchFailure(
+ op, "the op is an effective scalar constant");
+ }
+
+ constexpr int kGPUGlobalMemoryAddrSpace = 0;
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto module = op->getParentOfType<mlir::ModuleOp>();
+ auto global =
+ CreateGlobalOp(op.getValue(), "global_cst_", shaped_ty, module,
+ /*is_constant=*/true, kGPUGlobalMemoryAddrSpace, b);
+
+ rewriter.setInsertionPoint(op);
+ auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+ rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+ op, op.getResult().getType(),
+ rewriter
+ .create<mlir::LLVM::AddrSpaceCastOp>(
+ op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
+ addr)
+ .getResult());
+ return success();
+ }
+};
+
+struct RewriteSyncThreads : mlir::OpRewritePattern<SyncThreadsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ SyncThreadsOp op, mlir::PatternRewriter& rewriter) const override {
+ rewriter.create<mlir::gpu::BarrierOp>(op.getLoc());
+ rewriter.replaceOp(op, op.getOperands());
+ return success();
+ }
+};
+
+// TODO(jreiffers): Generalize this to support index switches with some used
+// results and upstream it as a canonicalization pattern.
+struct RemoveUnusedIndexSwitchResults
+ : mlir::OpRewritePattern<scf::IndexSwitchOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ scf::IndexSwitchOp op, mlir::PatternRewriter& rewriter) const override {
+ if (op->getNumResults() == 0 || !op->use_empty()) {
+ return rewriter.notifyMatchFailure(op, "the op has users");
+ }
+
+ auto new_op = rewriter.create<scf::IndexSwitchOp>(
+ op.getLoc(), mlir::TypeRange{}, op.getArg(), op.getCases(),
+ op.getNumCases());
+ for (int i = 0; i < op->getNumRegions(); ++i) {
+ auto& old_region = op->getRegion(i);
+ auto& new_region = new_op->getRegion(i);
+ rewriter.mergeBlocks(&old_region.getBlocks().front(),
+ &new_region.emplaceBlock());
+ auto yield_op = new_region.getBlocks().front().getTerminator();
+ rewriter.modifyOpInPlace(yield_op, [&]() { yield_op->setOperands({}); });
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+bool IsAtomicIntegral(Type element_type) {
+ if (!element_type.isInteger()) {
+ return false;
+ }
+ unsigned element_bitwidth = element_type.getIntOrFloatBitWidth();
+ return element_bitwidth == 32 || element_bitwidth == 64;
+}
+
+Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) {
+ if (value.getType().isIntOrFloat() && ty.isIntOrFloat()) {
+ return b.create<ml::BitcastOp>(ty, value);
+ }
+
+ mlir::LLVMTypeConverter converter(b.getContext());
+ // If either type is a complex, we need to go through an alloca, since no
+ // direct bitcast from a struct to an int is possible.
+ Type llvm_input_ty = converter.convertType(value.getType());
+ Type llvm_result_ty = converter.convertType(ty);
+ Type ptr_ty = mlir::LLVM::LLVMPointerType::get(b.getContext());
+
+ Value llvm_value =
+ b.create<mlir::UnrealizedConversionCastOp>(llvm_input_ty, value)
+ .getResult(0);
+ Value alloca = b.create<ml::AllocaOp>(
+ ptr_ty, llvm_input_ty, b.create<ml::ConstantOp>(b.getI32Type(), 1));
+ b.create<ml::StoreOp>(llvm_value, alloca);
+ auto result = b.create<ml::LoadOp>(llvm_result_ty, alloca).getResult();
+ return b.create<mlir::UnrealizedConversionCastOp>(ty, result).getResult(0);
+}
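+
+// For example, an f32 -> i32 bitcast takes the direct llvm.bitcast path above,
+// whereas complex<f32> (converted to !llvm.struct<(f32, f32)>) -> i64 goes
+// through the alloca + store + load sequence.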
+
+class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
+ public:
+ RewriteAtomicRMW(mlir::MLIRContext* context, bool is_amd,
+ const std::string& gpu_arch)
+ : mlir::OpRewritePattern<AtomicRMWOp>(context),
+ is_amd_(is_amd),
+ gpu_arch_(gpu_arch) {}
+
+ LogicalResult matchAndRewrite(
+ AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override {
+ if (failed(rewriteAsDirectAtomicRMW(op, rewriter))) {
+ rewriteAsAtomicCAS(op, rewriter);
+ }
+ rewriter.replaceOp(op, op.getInput());
+ return success();
+ }
+
+ private:
+ // Returns atomic op modifier and the atomic bin op kind.
+ std::optional<std::pair<Value, ml::AtomicBinOp>> GetAtomicModifierParameters(
+ AtomicRMWOp op) const {
+ Type element_type = op.getInput().getType().getElementType();
+ auto& operations = op.getBody()->getOperations();
+ auto terminator = op.getBody()->getTerminator();
+ if (operations.size() > 2) {
+ return std::nullopt;
+ }
+ // If the body contains only the terminator, then it is an atomic store.
+ if (operations.size() == 1) {
+ // TODO(b/336367145): Support complex<f32> atomic store.
+ if (element_type.isF32() || IsAtomicIntegral(element_type)) {
+ return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg);
+ }
+ return std::nullopt;
+ }
+ // Match the kind of the atomic op.
+ mlir::Operation* modifier_op = &operations.front();
+ std::optional<ml::AtomicBinOp> kind =
+ llvm::TypeSwitch<Operation*, std::optional<ml::AtomicBinOp>>(
+ modifier_op)
+ // Floating-point operations.
+ .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; })
+ .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; })
+ .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; })
+ // Integer operations.
+ .Case([&](arith::AddIOp op) {
+ return IsAtomicIntegral(element_type)
+ ? std::make_optional(ml::AtomicBinOp::add)
+ : std::nullopt;
+ })
+ .Case([&](arith::MaxUIOp op) {
+ return IsAtomicIntegral(element_type)
+ ? std::make_optional(ml::AtomicBinOp::umax)
+ : std::nullopt;
+ })
+ .Case([&](arith::MinUIOp op) {
+ return IsAtomicIntegral(element_type)
+ ? std::make_optional(ml::AtomicBinOp::umin)
+ : std::nullopt;
+ })
+ .Case([&](arith::MaxSIOp op) {
+ return IsAtomicIntegral(element_type)
+ ? std::make_optional(ml::AtomicBinOp::max)
+ : std::nullopt;
+ })
+ .Case([&](arith::MinSIOp op) {
+ return IsAtomicIntegral(element_type)
+ ? std::make_optional(ml::AtomicBinOp::min)
+ : std::nullopt;
+ })
+ .Default([](Operation* op) { return std::nullopt; });
+ if (!kind.has_value()) {
+ return std::nullopt;
+ }
+ // Find the modifier arg that does not match the argument of `atomic_rmw`
+ // body.
+ Value block_arg = op.getBody()->getArgument(0);
+ Value modifier_arg = modifier_op->getOperand(0) == block_arg
+ ? modifier_op->getOperand(1)
+ : modifier_op->getOperand(0);
+ return std::make_pair(modifier_arg, *kind);
+ }
+
+ // Certain computations, such as floating-point addition and integer
+ // maximization, can be implemented directly with an LLVM atomic instruction.
+ // If the modifier computation is of this kind, emits code to do that and
+ // returns success(); otherwise returns failure().
+ LogicalResult rewriteAsDirectAtomicRMW(
+ AtomicRMWOp op, mlir::PatternRewriter& rewriter) const {
+ auto modifier_parameters = GetAtomicModifierParameters(op);
+ if (!modifier_parameters.has_value()) {
+ return failure();
+ }
+ Value modifier_arg = modifier_parameters->first;
+ Type element_type = modifier_arg.getType();
+ ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;
+
+ Location loc = op.getLoc();
+ llvm::StringRef sync_scope = is_amd_ ? "agent" : "";
+ Value addr = CreateGep(op.getInput(), op.getIndices(), rewriter);
+
+ switch (atomic_bin_op) {
+ case ml::AtomicBinOp::xchg: {
+ rewriter.create<ml::StoreOp>(
+ loc, modifier_arg, addr,
+ /*alignment=*/element_type.getIntOrFloatBitWidth() / 8,
+ /*isVolatile=*/false, /*isNonTemporal=*/false,
+ ml::AtomicOrdering::unordered);
+ return success();
+ }
+ case ml::AtomicBinOp::add:
+ case ml::AtomicBinOp::max:
+ case ml::AtomicBinOp::min:
+ case ml::AtomicBinOp::umax:
+ case ml::AtomicBinOp::umin: {
+ rewriter.create<ml::AtomicRMWOp>(loc, atomic_bin_op, addr, modifier_arg,
+ ml::AtomicOrdering::seq_cst,
+ sync_scope);
+ return success();
+ }
+ case ml::AtomicBinOp::fadd: {
+ // TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr.
+ return is_amd_ ? emitAMDAtomicFAdd(loc, modifier_arg, addr, sync_scope,
+ gpu_arch_, rewriter)
+ : emitNVidiaAtomicFAdd(loc, modifier_arg, addr,
+ sync_scope, gpu_arch_, rewriter);
+ }
+ case ml::AtomicBinOp::fmax: {
+ return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr,
+ sync_scope, rewriter);
+ }
+ default:
+ return failure();
+ }
+ return success();
+ }
+
+ LogicalResult emitNVidiaAtomicFAdd(Location loc, Value modifier_arg,
+ Value addr, llvm::StringRef sync_scope,
+ llvm::StringRef cuda_arch,
+ OpBuilder& b) const {
+ se::CudaComputeCapability cuda_compute_capability(cuda_arch.str());
+ Type element_type = modifier_arg.getType();
+ // "atom.add.f64 requires sm_60 or higher."
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
+ bool is_supported_f16_atomic =
+ element_type.isF16() &&
+ cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA);
+ bool is_supported_bf16_atomic =
+ element_type.isBF16() &&
+ cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER);
+ bool is_supported_f64_atomic =
+ element_type.isF64() &&
+ cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_);
+ if (!element_type.isF32() && !is_supported_f16_atomic &&
+ !is_supported_bf16_atomic && !is_supported_f64_atomic) {
+ return failure();
+ }
+ b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
+ ml::AtomicOrdering::seq_cst, sync_scope);
+ return success();
+ }
+
+ LogicalResult emitAMDAtomicFAdd(Location loc, Value modifier_arg, Value addr,
+ llvm::StringRef sync_scope,
+ llvm::StringRef gcn_arch,
+ OpBuilder& b) const {
+ se::RocmComputeCapability rocm_compute_capability(gcn_arch.str());
+ Type element_type = modifier_arg.getType();
+ bool is_supported_f16_atomic =
+ element_type.isF16() &&
+ rocm_compute_capability.has_fp16_atomics_support();
+ if (!element_type.isF32() && !is_supported_f16_atomic) {
+ return failure();
+ }
+ constexpr int kGlobalMemory = 1;
+ constexpr int kSharedMemory = 3;
+ auto addr_type = mlir::cast<ml::LLVMPointerType>(addr.getType());
+ // Adds to shared memory are always atomic.
+ if (addr_type.getAddressSpace() != kSharedMemory) {
+ // The compiler will only generate a global_atomic_fadd if the pointer is
+ // in the global address space (1).
+ addr = b.create<ml::AddrSpaceCastOp>(
+ loc, ml::LLVMPointerType::get(b.getContext(), kGlobalMemory), addr);
+ }
+ b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
+ ml::AtomicOrdering::seq_cst, sync_scope);
+ return success();
+ }
+
+ LogicalResult rewriteAtomicFMaxAsIntAtomics(Location loc, Value modifier_arg,
+ Value addr,
+ llvm::StringRef sync_scope,
+ OpBuilder& b) const {
+ Type element_type = modifier_arg.getType();
+ if (!element_type.isF32()) {
+ return failure();
+ }
+ // Evaluating floating-point max using integer atomics has the limitation
+ // of not propagating -NaNs. To handle this, if the update value is a NaN,
+ // we replace it with a canonical positive NaN before comparing.
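+ // A hedged CUDA-style sketch of the sequence emitted below (helper names
+ // are illustrative only):
+ //   float old = *addr;
+ //   if (!isnan(old)) {
+ //     float upd = isnan(val) ? NAN : val;   // canonicalize any NaN update
+ //     if (!(old >= upd)) {                  // 'ult': also true if upd is NaN
+ //       int bits = __float_as_int(upd);
+ //       if (bits >= 0)
+ //         atomicMax((int*)addr, bits);
+ //       else
+ //         atomicMin((unsigned int*)addr, (unsigned)bits);
+ //     }
+ //   }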
+ Value current = b.create<ml::LoadOp>(loc, element_type, addr);
+ Value current_is_nan =
+ b.create<ml::FCmpOp>(loc, ml::FCmpPredicate::uno, current, current);
+ auto is_current_nan =
+ b.create<scf::IfOp>(loc, /*resultTypes=*/TypeRange{}, current_is_nan,
+ /*addThenBlock=*/true, /*addElseBlock=*/true);
+ auto if_current_nan_then_builder =
+ OpBuilder::atBlockEnd(is_current_nan.thenBlock(), b.getListener());
+ if_current_nan_then_builder.create<scf::YieldOp>(loc);
+
+ auto if_current_nan_else_builder =
+ OpBuilder::atBlockEnd(is_current_nan.elseBlock(), b.getListener());
+ Value is_modifier_nan = if_current_nan_else_builder.create<ml::FCmpOp>(
+ loc, ml::FCmpPredicate::uno, modifier_arg, modifier_arg);
+ auto f32_nan = mlir::APFloat::getNaN(mlir::APFloat::IEEEsingle());
+ Value nan = if_current_nan_else_builder.create<ml::ConstantOp>(
+ loc, b.getF32Type(), f32_nan);
+ Value no_negative_nan_source =
+ if_current_nan_else_builder.create<ml::SelectOp>(loc, is_modifier_nan,
+ nan, modifier_arg);
+ Value current_less_than_modifier =
+ if_current_nan_else_builder.create<ml::FCmpOp>(
+ loc, ml::FCmpPredicate::ult, current, no_negative_nan_source);
+
+ // This check allows us to skip the atomic update altogether at the
+ // expense of reading the value in memory for every update. Evaluated
+ // against Waymo's benchmarks, adding the check achieves better overall
+ // performance.
+ auto if_need_update = if_current_nan_else_builder.create<scf::IfOp>(
+ loc, /*resultTypes=*/TypeRange{}, current_less_than_modifier,
+ /*withElseRegion=*/true,
+ /*addElseBlock=*/false);
+ if_current_nan_else_builder.create<scf::YieldOp>(loc);
+
+ auto then_builder =
+ OpBuilder::atBlockEnd(if_need_update.thenBlock(), b.getListener());
+ Value source_float_as_int = then_builder.create<ml::BitcastOp>(
+ loc, then_builder.getI32Type(), no_negative_nan_source);
+ Value c0 = then_builder.create<ml::ConstantOp>(loc, b.getI32Type(), 0);
+ Value is_not_negative = then_builder.create<ml::ICmpOp>(
+ loc, ml::ICmpPredicate::sge, source_float_as_int, c0);
+ then_builder.create<scf::IfOp>(
+ loc, is_not_negative,
+ [&](OpBuilder& nested_b, Location nested_loc) {
+ // atomicMax((int *)address, __float_as_int(val))
+ nested_b.create<ml::AtomicRMWOp>(
+ loc, ml::AtomicBinOp::max, addr, source_float_as_int,
+ ml::AtomicOrdering::seq_cst, sync_scope);
+ nested_b.create<scf::YieldOp>(nested_loc);
+ },
+ [&](OpBuilder& nested_b, Location nested_loc) {
+ // atomicMin((unsigned int *)address, __float_as_uint(val))
+ nested_b.create<ml::AtomicRMWOp>(
+ loc, ml::AtomicBinOp::umin, addr, source_float_as_int,
+ ml::AtomicOrdering::seq_cst, sync_scope);
+ nested_b.create<scf::YieldOp>(nested_loc);
+ });
+ then_builder.create<scf::YieldOp>(loc);
+ return success();
+ }
+
+ // Implements atomic binary operations using atomic compare-and-swap
+ // (atomicCAS) as follows:
+ // 1. Reads the value from the memory pointed to by output_address and
+ // records it as old_output.
+ // 2. Uses old_output as one of the source operands to perform the binary
+ // operation and stores the result in new_output.
+ // 3. Calls atomicCAS which implements compare-and-swap as an atomic
+ // operation. In particular, atomicCAS reads the value from the memory
+ // pointed to by output_address, and compares the value with old_output.
+ // If the two values are equal, new_output is written to the same memory
+ // location and true is returned to indicate that the atomic operation
+ // succeeds. Otherwise, the new value read from the memory is returned. In
+ // this case, the new value is copied to old_output, and steps 2. and 3.
+ // are repeated until atomicCAS succeeds.
+ //
+ // On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers.
+ // If the element type of the binary operation is 32 bits or 64 bits, the
+ // integer type of the same size is used for the atomicCAS operation. On the
+ // other hand, if the element type is smaller than 32 bits, int32_t is used
+ // for the atomicCAS operation. In this case, atomicCAS reads and writes 32
+ // bit values from the memory, which is larger than the memory size required
+ // by the original atomic binary operation. We mask off the last two bits of
+ // the output_address and use the result as an address to read the 32 bit
+ // values from the memory. This can avoid out of bound memory accesses if
+ // tensor buffers are 4 byte aligned and have a size of 4N, an assumption that
+ // the runtime can guarantee.
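+ // Worked example for a 16-bit element stored at an address with
+ // addr % 4 == 2: the address is rounded down by two bytes to a 32-bit
+ // boundary, shift = 2 * 8 = 16, and
+ //   mask = 0xFFFFFFFF ^ (0x0000FFFF << 16) = 0x0000FFFF,
+ // so each CAS iteration computes
+ //   new = (old & mask) | (zext(bitcast(result)) << 16),
+ // leaving the neighboring 16 bits untouched.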
+ void rewriteAsAtomicCAS(AtomicRMWOp op,
+ mlir::PatternRewriter& rewriter) const {
+ Location loc = op.getLoc();
+ auto input = op.getInput();
+
+ // Use 32-bit atomic type for small input types.
+ Type result_ty = op.getResult().getType().getElementType();
+ int result_size;
+ if (auto complex_ty = mlir::dyn_cast<mlir::ComplexType>(result_ty)) {
+ result_size = complex_ty.getElementType().getIntOrFloatBitWidth() * 2;
+ } else {
+ result_size = result_ty.getIntOrFloatBitWidth();
+ }
+
+ bool small_type = result_size < 32;
+ Type atomic_ty =
+ mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size);
+
+ // Calculate load address for the input.
+ Value addr = CreateGep(input, op.getIndices(), rewriter);
+ Value shift, mask;
+ if (small_type) {
+ // Update input pointer by discarding the last two bits - i.e. align to
+ // 32-bit boundary for small input types (will not result in OOB, as the
+ // input alignment is at least 32 bits).
+ Type addr_int_ty = rewriter.getI64Type();
+ Value addr_int = rewriter.create<ml::PtrToIntOp>(loc, addr_int_ty, addr);
+ Value addr_offset = rewriter.create<ml::AndOp>(
+ loc, addr_int, rewriter.create<ml::ConstantOp>(loc, addr_int_ty, 3));
+ Value index = rewriter.create<ml::MulOp>(
+ loc, addr_offset,
+ rewriter.create<ml::ConstantOp>(loc, addr_int_ty, -1));
+ addr =
+ rewriter.create<ml::GEPOp>(loc, addr.getType(), rewriter.getI8Type(),
+ addr, index, /*inbounds=*/true);
+
+ // Calculate the bit shift (assume little-endianness).
+ Value offset = rewriter.create<ml::TruncOp>(loc, atomic_ty, addr_offset);
+ shift = rewriter.create<ml::MulOp>(
+ loc, offset,
+ rewriter.create<ml::ConstantOp>(loc, offset.getType(), 8));
+
+ // Compose the update mask.
+ Value bits_long = rewriter.create<ml::ConstantOp>(loc, atomic_ty, -1);
+ Value bits_short = rewriter.create<ml::ZExtOp>(
+ loc, atomic_ty,
+ rewriter.create<ml::ConstantOp>(
+ loc, rewriter.getIntegerType(result_size), -1));
+ mask = rewriter.create<ml::XOrOp>(
+ loc, bits_long, rewriter.create<ml::ShlOp>(loc, bits_short, shift));
+ }
+
+ // Load initial atomic value and create the loop.
+ Value initial = rewriter.create<ml::LoadOp>(loc, atomic_ty, addr);
+ rewriter.create<scf::WhileOp>(
+ loc, TypeRange{atomic_ty}, ValueRange{initial},
+ [&](mlir::OpBuilder& builder, Location loc, ValueRange values) {
+ mlir::ImplicitLocOpBuilder b(loc, builder);
+ Value old_value = values[0];
+
+ // Convert atomic value to input value.
+ Value input_value;
+ if (small_type) {
+ Value short_value =
+ b.create<ml::TruncOp>(b.getIntegerType(result_size),
+ b.create<ml::LShrOp>(old_value, shift));
+ input_value = b.create<ml::BitcastOp>(result_ty, short_value);
+ } else {
+ input_value = CreateBitcast(b, old_value, result_ty);
+ }
+
+ // Perform computation on the loaded input value.
+ rewriter.mergeBlocks(&op.getComputation().front(), b.getBlock(),
+ {input_value});
+ auto yield_op = b.getBlock()->getTerminator();
+ Value result = yield_op->getOperand(0);
+ rewriter.eraseOp(yield_op);
+
+ // Convert resulting value to atomic value.
+ Value new_value;
+ if (small_type) {
+ Value cast_value = b.create<ml::ZExtOp>(
+ atomic_ty, b.create<ml::BitcastOp>(
+ rewriter.getIntegerType(result_size), result));
+ new_value =
+ b.create<ml::OrOp>(b.create<ml::AndOp>(old_value, mask),
+ b.create<ml::ShlOp>(cast_value, shift));
+ } else {
+ new_value = CreateBitcast(b, result, atomic_ty);
+ }
+
+ // Try saving the result atomically, retry if failed.
+ Value cmpxchg = b.create<ml::AtomicCmpXchgOp>(
+ loc, addr, old_value, new_value,
+ /*success_ordering=*/ml::AtomicOrdering::seq_cst,
+ /*failure_ordering=*/ml::AtomicOrdering::seq_cst);
+ Value next = b.create<ml::ExtractValueOp>(cmpxchg, 0);
+ Value ok = b.create<ml::ExtractValueOp>(cmpxchg, 1);
+ Value low_bit = b.create<ml::ConstantOp>(b.getOneAttr(b.getI1Type()));
+ Value not_ok = b.create<ml::XOrOp>(ok, low_bit);
+ b.create<scf::ConditionOp>(not_ok, ValueRange{next});
+ },
+ [&](mlir::OpBuilder& b, Location loc, ValueRange values) {
+ b.create<scf::YieldOp>(loc, values);
+ });
+ }
+
+ bool is_amd_;
+ std::string gpu_arch_;
+};
+
+class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
+ public:
+ explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
+ : LowerTensorsPassBase(options) {}
+
+ void runOnOperation() override {
+ MLIRContext* mlir_context = &getContext();
+ mlir::RewritePatternSet tensor_patterns(mlir_context);
+ tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
+ tensor_patterns
+ .add<RewriteAllocateShared, RewriteNonScalarConstants,
+ RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
+ RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+ getOperation(), std::move(tensor_patterns)))) {
+ signalPassFailure();
+ return;
+ }
+
+ mlir::RewritePatternSet function_patterns(mlir_context);
+ function_patterns.add<RewriteFunctionSignatures, RewriteCall,
+ RemoveUnusedIndexSwitchResults>(mlir_context);
+ scf::ForOp::getCanonicalizationPatterns(function_patterns, mlir_context);
+ scf::IfOp::getCanonicalizationPatterns(function_patterns, mlir_context);
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+ getOperation(), std::move(function_patterns)))) {
+ signalPassFailure();
+ return;
+ }
+
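+ // Annotate loads whose base pointer (after stripping GEPs) is an entry
+ // argument marked `xla.invariant`, so LLVM can hoist or CSE them. Shared
+ // memory, global constants and allocas need no annotation; any other
+ // address must trace back to a function argument.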
+ getOperation()->walk([this](mlir::LLVM::LoadOp load) {
+ Value addr = load.getAddr();
+ while (auto gep = addr.getDefiningOp<mlir::LLVM::GEPOp>()) {
+ addr = gep.getBase();
+ }
+ if (addr.getDefiningOp<mlir::LLVM::AddrSpaceCastOp>() ||
+ addr.getDefiningOp<mlir::LLVM::AddressOfOp>() ||
+ addr.getDefiningOp<mlir::LLVM::AllocaOp>()) {
+ // Shared memory, global constant or temporary - no need to annotate
+ // anything.
+ return;
+ }
+ if (auto base = mlir::dyn_cast<mlir::BlockArgument>(addr)) {
+ if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(
+ base.getOwner()->getParentOp())) {
+ if (func.getArgAttr(base.getArgNumber(), "xla.invariant")) {
+ load.setInvariant(true);
+ }
+ return;
+ }
+ }
+ load.emitOpError("load op address is not (a GEP of) a function argument");
+ signalPassFailure();
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
+ bool is_amd_gpu, const std::string& gpu_arch) {
+ LowerTensorsPassOptions options;
+ options.is_amd_gpu_ = is_amd_gpu;
+ options.gpu_arch_ = gpu_arch;
+ return std::make_unique<LowerTensorsPass>(options);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc
new file mode 100644
index 0000000..28762d0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc
@@ -0,0 +1,99 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // IWYU pragma: keep
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // IWYU pragma: keep
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_LOWERTOLLVMPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
+ public:
+ using LowerToLLVMPassBase::LowerToLLVMPassBase;
+
+ void runOnOperation() override {
+ // Populate type conversions.
+ mlir::LowerToLLVMOptions llvm_opts(&getContext(),
+ mlir::DataLayout(getOperation()));
+ mlir::LLVMTypeConverter type_converter(getOperation().getContext(),
+ llvm_opts);
+ mlir::LLVMConversionTarget target(*getOperation().getContext());
+
+ // Populate patterns.
+ mlir::RewritePatternSet patterns(&getContext());
+ mlir::populateAffineToStdConversionPatterns(patterns);
+ mlir::populateSCFToControlFlowConversionPatterns(patterns);
+ mlir::arith::populateArithExpandOpsPatterns(patterns);
+ mlir::arith::populateArithToLLVMConversionPatterns(type_converter,
+ patterns);
+ mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
+ mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
+ mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
+ mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter,
+ patterns);
+ mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns);
+ mlir::populateMathToLLVMConversionPatterns(type_converter, patterns);
+
+ // Setup target.
+ mlir::configureGpuToNVVMConversionLegality(target);
+ target.addIllegalDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
+ mlir::complex::ComplexDialect,
+ mlir::math::MathDialect>();
+ target.addLegalOp<mlir::ModuleOp>();
+
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass() {
+ return std::make_unique<LowerToLLVMPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
new file mode 100644
index 0000000..cbd64b8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
@@ -0,0 +1,278 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS
+#define GEN_PASS_DEF_LOWERXLAGPULOOPSTOSCFPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::ImplicitLocOpBuilder;
+using mlir::Location;
+using mlir::OpBuilder;
+using mlir::SmallVector;
+using mlir::success;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::scf::IfOp;
+
+struct RewritePredicatedInsert : mlir::OpRewritePattern<PredicatedInsertOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ PredicatedInsertOp op, mlir::PatternRewriter& rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
+ op, op.getCondition(),
+ [&](mlir::OpBuilder& b, mlir::Location loc) {
+ b.create<mlir::scf::YieldOp>(
+ loc, b.create<mlir::tensor::InsertOp>(
+ loc, op.getValue(), op.getDest(), op.getIndices())
+ .getResult());
+ },
+ [&](mlir::OpBuilder& b, mlir::Location loc) {
+ b.create<mlir::scf::YieldOp>(loc, op.getDest());
+ });
+ return success();
+ }
+};
+
+struct RewritePredicatedExtract : mlir::OpRewritePattern<PredicatedExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ PredicatedExtractOp op, mlir::PatternRewriter& rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
+ op, op.getCondition(),
+ [&](mlir::OpBuilder& b, mlir::Location loc) {
+ b.create<mlir::scf::YieldOp>(
+ loc, b.create<mlir::tensor::ExtractOp>(loc, op.getSrc(),
+ op.getIndices())
+ .getResult());
+ },
+ [&](mlir::OpBuilder& b, mlir::Location loc) {
+ b.create<mlir::scf::YieldOp>(loc, op.getFallback());
+ });
+ return success();
+ }
+};
+
+struct RewriteShuffleReduce : mlir::OpRewritePattern<ShuffleReduceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override {
+ int max_distance =
+ mlir::cast<mlir::IntegerAttr>(op->getAttr("max_distance")).getInt();
+ // TODO(jreiffers): Do this in a verifier.
+ if (max_distance & (max_distance - 1) || max_distance >= WarpSize()) {
+ return op->emitOpError("max_distance must be a power of 2 < WarpSize()");
+ }
+
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ ValueRange values = op.getOperands();
+ for (int distance = max_distance; distance > 0; distance /= 2) {
+ namespace ml = mlir::LLVM;
+ auto shuffle_32 = [&](Value v) {
+ return b
+ .create<mlir::gpu::ShuffleOp>(v, distance, WarpSize(),
+ mlir::gpu::ShuffleMode::DOWN)
+ .getShuffleResult();
+ };
+
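+ // gpu.shuffle moves 32 bits at a time, so wider values (e.g. f64) are
+ // bitcast to an integer, padded to a multiple of 32 bits, split into i32
+ // lanes and shuffled lane by lane; narrower values are extended to 32 bits
+ // and shuffled directly.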
+ auto shuffle_int_or_float = [&](Value value) {
+ auto ty = value.getType();
+ int bit_width = ty.getIntOrFloatBitWidth();
+ if (bit_width == 32) {
+ return shuffle_32(value);
+ }
+ int n_shuffles = CeilOfRatio(bit_width, 32);
+ auto int_ty = b.getIntegerType(bit_width);
+ auto padded_int_ty = b.getIntegerType(n_shuffles * 32);
+ value = b.create<mlir::arith::BitcastOp>(int_ty, value);
+ value = b.create<mlir::arith::ExtUIOp>(padded_int_ty, value);
+ if (n_shuffles > 1) {
+ // Don't generate vectors if the size is 1.
+ auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles);
+ value = b.create<ml::BitcastOp>(vector_type, value);
+ Value result_vec = b.create<ml::UndefOp>(vector_type);
+ for (int i = 0; i < n_shuffles; ++i) {
+ auto idx = b.create<mlir::arith::ConstantIntOp>(i, 32);
+ result_vec = b.create<ml::InsertElementOp>(
+ result_vec,
+ shuffle_32(b.create<ml::ExtractElementOp>(value, idx)), idx);
+ }
+ value = b.create<ml::BitcastOp>(padded_int_ty, result_vec);
+ } else {
+ value = shuffle_32(value);
+ }
+ value = b.create<mlir::arith::TruncIOp>(int_ty, value);
+ value = b.create<ml::BitcastOp>(ty, value);
+ return value;
+ };
+
+ auto shuffle = [&](Value value) -> Value {
+ if (mlir::isa<mlir::ComplexType>(value.getType())) {
+ return b.create<mlir::complex::CreateOp>(
+ value.getType(),
+ shuffle_int_or_float(b.create<mlir::complex::ReOp>(value)),
+ shuffle_int_or_float(b.create<mlir::complex::ImOp>(value)));
+ }
+ if (value.getType().isUnsignedInteger()) {
+ auto ty = value.getType();
+ auto signless_ty = b.getIntegerType(ty.getIntOrFloatBitWidth());
+ value = b.create<mlir::UnrealizedConversionCastOp>(
+ mlir::TypeRange{signless_ty}, value)
+ .getResult(0);
+ value = shuffle_int_or_float(value);
+ value = b.create<mlir::UnrealizedConversionCastOp>(
+ mlir::TypeRange{ty}, value)
+ .getResult(0);
+ return value;
+ }
+ return shuffle_int_or_float(value);
+ };
+
+ SmallVector<Value> args = values;
+ for (auto value : values) {
+ args.push_back(shuffle(value));
+ }
+ values = b.create<PureCallOp>(op.getResultTypes(),
+ op.getReducerAttr().getAttr(), args)
+ .getResults();
+ }
+ rewriter.replaceOp(op, values);
+ return success();
+ }
+};
+
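+// Lowers xla_gpu.loop to an scf.for nest. The loop bounds come from the op's
+// indexing map; in the innermost body the map's constraints are checked, and
+// the original loop body runs only when they hold, otherwise the iteration
+// arguments are forwarded unchanged.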
+struct RewriteXlaGpuLoop : mlir::OpRewritePattern<LoopOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ LoopOp op, mlir::PatternRewriter& rewriter) const override {
+ Location loc = op.getLoc();
+ ImplicitLocOpBuilder b(loc, rewriter);
+
+ IndexingMap indexing_map = op.getIndexingMap();
+ SmallVector<Value, 4> lbs, ubs, steps;
+ mlir_converter::GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs,
+ &steps);
+ mlir::scf::LoopNest loop_nest = mlir::scf::buildLoopNest(
+ b, loc, lbs, ubs, steps, op.getInits(),
+ [&](OpBuilder& nested_builder, Location loc, ValueRange symbol_values,
+ ValueRange iter_args) -> mlir::scf::ValueVector {
+ mlir::ImplicitLocOpBuilder nested_b(loc, nested_builder);
+ auto is_in_bounds = mlir_converter::CheckConstraints(
+ indexing_map, op.getDims(), symbol_values, nested_b);
+ auto if_op = nested_b.create<mlir::scf::IfOp>(
+ is_in_bounds,
+ [&](OpBuilder& then_builder, Location then_loc) -> void {
+ SmallVector<Value, 4> bb_args(symbol_values);
+ bb_args.append(iter_args.begin(), iter_args.end());
+
+ mlir::Block* then_block = then_builder.getInsertionBlock();
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(then_block);
+ rewriter.mergeBlocks(op.getBody(), then_block, bb_args);
+
+ auto old_terminator = then_block->getTerminator();
+ then_builder.create<mlir::scf::YieldOp>(
+ then_loc, old_terminator->getOperands());
+ old_terminator->erase();
+ },
+ [&](OpBuilder& else_b, Location else_loc) {
+ else_b.create<mlir::scf::YieldOp>(loc, iter_args);
+ });
+ return if_op.getResults();
+ });
+ rewriter.replaceOp(op, loop_nest.results);
+ return mlir::success();
+ }
+};
+
+class LowerXlaGpuToScfPass
+ : public impl::LowerXlaGpuToScfPassBase<LowerXlaGpuToScfPass> {
+ public:
+ void runOnOperation() override {
+ auto* ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ patterns.add<RewritePredicatedInsert, RewritePredicatedExtract,
+ RewriteShuffleReduce>(ctx);
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+class LowerXlaGpuLoopsToScfPass
+ : public impl::LowerXlaGpuLoopsToScfPassBase<LowerXlaGpuLoopsToScfPass> {
+ public:
+ void runOnOperation() override {
+ auto* ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ patterns.add<RewriteXlaGpuLoop>(ctx);
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() {
+ return std::make_unique<LowerXlaGpuToScfPass>();
+}
+
+std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuLoopsToScfPass() {
+ return std::make_unique<LowerXlaGpuLoopsToScfPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc b/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc
new file mode 100644
index 0000000..50193e3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc
@@ -0,0 +1,117 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class MergePointersToSameSlicePass
+ : public impl::MergePointersToSameSlicePassBase<
+ MergePointersToSameSlicePass> {
+ public:
+ void runOnOperation() override;
+};
+
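+// Example: for a function whose arguments carry
+//   arg0 {xla.slice_index = 0}, arg1 {xla.slice_index = 1},
+//   arg2 {xla.slice_index = 0},
+// arg2 aliases arg0, so its uses are replaced with arg0 and it is erased from
+// the function and from its call sites. The surviving arguments get
+// llvm.noalias in place of xla.slice_index.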
+struct PackedArgs {
+ llvm::BitVector args_to_erase;
+ // replacement_args[i] == i iff !args_to_erase[i].
+ llvm::SmallVector<int> replacement_args;
+
+ PackedArgs() = default;
+ explicit PackedArgs(mlir::func::FuncOp func) {
+ absl::flat_hash_map<int, std::optional<int>> slice_to_operand;
+ args_to_erase.resize(func.getNumArguments());
+ replacement_args.reserve(func.getNumArguments());
+ for (int i = 0; i < func.getNumArguments(); ++i) {
+ replacement_args.push_back(i);
+ }
+
+ for (auto [idx, operand] : llvm::enumerate(func.getArguments())) {
+ auto slice_index = func.getArgAttr(idx, "xla.slice_index");
+ if (!slice_index) {
+ continue;
+ }
+
+ auto& target_index = slice_to_operand[static_cast<int>(
+ mlir::cast<mlir::IntegerAttr>(slice_index).getInt())];
+ if (target_index) {
+ replacement_args[idx] = *target_index;
+ args_to_erase[idx] = true;
+ } else {
+ target_index = idx;
+ }
+ }
+ }
+
+ void Pack(mlir::func::FuncOp op) {
+ for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
+ if (replacement_args[idx] != idx) {
+ arg.replaceAllUsesWith(op.getArgument(replacement_args[idx]));
+ }
+ }
+ op.eraseArguments(args_to_erase);
+ for (int i = 0; i < op.getNumArguments(); ++i) {
+ if (op.getArgAttr(i, "xla.slice_index")) {
+ op.removeArgAttr(i, "xla.slice_index");
+ op.setArgAttr(i, mlir::LLVM::LLVMDialect::getNoAliasAttrName(),
+ mlir::UnitAttr::get(op->getContext()));
+ }
+ }
+ }
+
+ void Pack(mlir::func::CallOp op) { op->eraseOperands(args_to_erase); }
+};
+
+void MergePointersToSameSlicePass::runOnOperation() {
+
+ absl::flat_hash_map<std::string, PackedArgs> args_to_pack;
+ getOperation()->walk([&](mlir::func::FuncOp func) {
+ args_to_pack[func.getName()] = PackedArgs(func);
+ });
+ getOperation()->walk([&](mlir::func::CallOp call) {
+ args_to_pack[call.getCallee()].Pack(call);
+ });
+ getOperation()->walk([&](mlir::func::FuncOp func) {
+ args_to_pack[func.getName()].Pack(func);
+ });
+}
+
+} // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateMergePointersToSameSlicePass() {
+ return std::make_unique<MergePointersToSameSlicePass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc
new file mode 100644
index 0000000..e483bfe
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc
@@ -0,0 +1,326 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/check.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_OPTIMIZELOOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+mlir::Value GetSource(mlir::vector::TransferReadOp op) {
+ return op.getSource();
+}
+
+bool DoIndicesDependOnInductionVar(mlir::ValueRange indices,
+ mlir::scf::ForOp loop) {
+ // We assume LICM ran, so we can just check if any index is defined in the
+ // loop.
+ return absl::c_any_of(indices, [&](mlir::Value v) {
+ return v.getParentRegion() == &loop.getBodyRegion();
+ });
+}
+
+bool CanReplaceInductionVar(mlir::ValueRange indices) {
+ return absl::c_all_of(indices, [&](mlir::Value v) {
+ if (auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(v)) {
+ auto for_op = mlir::dyn_cast_or_null<mlir::scf::ForOp>(
+ v.getParentRegion()->getParentOp());
+ // This is a bbarg that is defined outside of the loop, so it doesn't
+ // affect pipelining.
+ if (!for_op) {
+ return true;
+ }
+ // We can only replace the induction variable, not other loop-carried
+ // values.
+ return v == for_op.getInductionVar();
+ }
+ auto* op = v.getDefiningOp();
+ return op &&
+ mlir::isa<mlir::arith::ConstantOp, ApplyIndexingOp,
+ mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
+ mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(
+ op) &&
+ CanReplaceInductionVar(op->getOperands());
+ });
+}
+
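+// Returns a copy of `indices` with `induction_var` replaced by `replacement`.
+// Index computations (apply_indexing, min/max, index casts) are recursively
+// cloned with the substitution applied; constants are left untouched.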
+llvm::SmallVector<mlir::Value> ReplaceInductionVar(
+ mlir::Value induction_var, mlir::Value replacement,
+ llvm::SmallVector<mlir::Value> indices,
+ mlir::ImplicitLocOpBuilder& builder) {
+ for (mlir::Value& index : indices) {
+ if (mlir::isa<mlir::BlockArgument>(index)) {
+ if (index == induction_var) {
+ index = replacement;
+ }
+ continue;
+ }
+
+ auto* op = index.getDefiningOp();
+ CHECK(op) << "Did CanReplaceInductionVar() fail?";
+ if (mlir::isa<mlir::arith::ConstantOp>(op)) {
+ continue;
+ }
+
+ CHECK(
+ (mlir::isa<ApplyIndexingOp, mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
+ mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(op)))
+ << "Did CanReplaceInductionVar() fail?";
+ auto replaced_args = ReplaceInductionVar(induction_var, replacement,
+ op->getOperands(), builder);
+ index = builder
+ .create(builder.getLoc(), op->getName().getIdentifier(),
+ replaced_args, op->getResultTypes(), op->getAttrs())
+ ->getResult(0);
+ }
+ return indices;
+}
+
+mlir::Value GetSource(mlir::tensor::ExtractOp op) { return op.getTensor(); }
+
+// TODO(jreiffers): Use a shared memory queue for pipelining instead of
+// registers.
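+//
+// Pipelines tensor/vector loads inside an scf.for loop: the value for
+// iteration 0 is loaded before the loop and carried as an extra iter_arg, and
+// inside the body the load for iteration i + 1 is issued (guarded by
+// i + 1 < trip_count) while the current iteration consumes the previously
+// loaded value.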
+template <typename Op>
+struct PipelineLoad : mlir::OpRewritePattern<Op> {
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ Op op, mlir::PatternRewriter& rewriter) const override {
+ auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
+ if (!loop) {
+ return rewriter.notifyMatchFailure(op, "no loop found");
+ }
+
+ if (auto step = loop.getConstantStep();
+ !step || step->getSExtValue() != 1) {
+ return rewriter.notifyMatchFailure(op, "loop step is not 1");
+ }
+
+ llvm::APInt lb, ub;
+ if (!mlir::matchPattern(loop.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
+ !mlir::matchPattern(loop.getUpperBound(), mlir::m_ConstantInt(&ub))) {
+ return rewriter.notifyMatchFailure(op, "bounds are not constants");
+ }
+ if (lb.getSExtValue() != 0) {
+ return rewriter.notifyMatchFailure(op, "lower bound is not 0");
+ }
+
+ auto source = GetSource(op);
+ if (!source.getParentRegion()->isProperAncestor(&loop.getBodyRegion())) {
+ return rewriter.notifyMatchFailure(
+ op, "source is not defined outside the loop");
+ }
+
+ if (!DoIndicesDependOnInductionVar(op.getIndices(), loop)) {
+ // We don't run LICM between iterations, so this could happen.
+ // Just hoist the load out of the loop.
+ rewriter.moveOpBefore(op, loop);
+ return mlir::success();
+ }
+
+ if (!CanReplaceInductionVar(op.getIndices())) {
+ return rewriter.notifyMatchFailure(op, "unable to replace indices");
+ }
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ mlir::Value zero = b.create<mlir::arith::ConstantIndexOp>(0);
+
+ b.setInsertionPoint(loop);
+ auto first_args =
+ ReplaceInductionVar(loop.getInductionVar(), zero, op.getOperands(), b);
+ auto loaded_first =
+ b.create<Op>(op->getResultTypes(), first_args, op->getAttrs());
+ auto ub_minus_one =
+ b.create<mlir::arith::ConstantIndexOp>(ub.getSExtValue() - 1);
+
+ b.setInsertionPointToStart(loop.getBody());
+
+ auto needs_load = b.create<mlir::arith::CmpIOp>(
+ mlir::arith::CmpIPredicate::ult, loop.getInductionVar(), ub_minus_one);
+ auto next_value =
+ b.create<mlir::scf::IfOp>(op->getResultTypes(), needs_load, true, true);
+ auto new_for =
+ mlir::cast<mlir::scf::ForOp>(*loop.replaceWithAdditionalYields(
+ rewriter, loaded_first->getResult(0),
+ /*replaceInitOperandUsesInLoop=*/false,
+ [&](mlir::OpBuilder&, mlir::Location,
+ llvm::ArrayRef<mlir::BlockArgument>) {
+ return llvm::SmallVector<mlir::Value>{next_value->getResult(0)};
+ }));
+ rewriter.replaceAllUsesWith(op, new_for.getRegionIterArgs().back());
+
+ b.setInsertionPointToStart(next_value.thenBlock());
+ auto yield = b.create<mlir::scf::YieldOp>(op->getResult(0));
+
+ // We add 1 via an apply_indexing op (with a d0 + 1 affine map) rather
+ // than arith.addi so that folding works properly.
+ auto plus_one_map = mlir::AffineMap::get(
+ 1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1);
+ b.setInsertionPoint(next_value);
+ IndexingMap indexing_map(plus_one_map, {DimVar{0, ub.getSExtValue() - 1}},
+ /*range_vars=*/{}, /*rt_vars=*/{});
+ auto induction_plus_one =
+ b.create<ApplyIndexingOp>(new_for.getInductionVar(), indexing_map)
+ ->getResult(0);
+
+ // Create the new apply_indexing ops outside the if, to improve CSE.
+ rewriter.modifyOpInPlace(op, [&]() {
+ op->setOperands(ReplaceInductionVar(
+ new_for.getInductionVar(), induction_plus_one, op->getOperands(), b));
+ });
+ rewriter.moveOpBefore(op, yield);
+
+ b.setInsertionPointToStart(next_value.elseBlock());
+ b.create<mlir::scf::YieldOp>(new_for.getRegionIterArgs().back());
+ return mlir::success();
+ }
+};
+
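+// Returns the factor by which to unroll `op` (values <= 1 mean no unrolling).
+// The loop body size is estimated in rough instruction counts (calls and
+// nested loops count as kMaxSize, non-trivial float math ops as ~20, vector
+// results scale with their element count); the largest factor that keeps
+// roughly factor * size <= kMaxSize and divides the trip count is chosen.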
+int GetUnrollingFactor(mlir::scf::ForOp op) {
+ // We only unroll loops with a step of 1 and a lower bound of 0. That's the
+ // only type we generate.
+ if (auto step = op.getConstantStep(); !step || step->getSExtValue() != 1) {
+ return 1;
+ }
+ llvm::APInt lb, ub;
+ if (!mlir::matchPattern(op.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
+ !mlir::matchPattern(op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
+ return 1;
+ }
+ if (lb.getSExtValue() != 0) {
+ return 1;
+ }
+
+ int64_t trip_count = ub.getSExtValue();
+ constexpr int kMaxSize = 400; // Chosen empirically.
+
+ // Get a rough estimate of the size of the loop body.
+ int64_t size = 0;
+ op.getBodyRegion().walk([&](mlir::Operation* op) {
+ if (mlir::isa<mlir::func::CallOp, mlir::scf::ForOp>(op)) {
+ size += kMaxSize;
+ return;
+ }
+
+ int64_t this_size = 1;
+ if (mlir::isa<mlir::math::MathDialect>(op->getDialect())) {
+ // Integer instructions in math are ok, but many float ops lower to lots
+ // of instructions.
+ if (!op->getResultTypes().front().isIntOrIndex()) {
+ namespace mm = mlir::math;
+ // We err on the side of not unrolling, so we maintain a list of ops
+ // known to be cheap.
+ if (!mlir::isa<mm::AbsFOp, mm::CeilOp, mm::CopySignOp, mm::FloorOp,
+ mm::FmaOp, mm::RoundEvenOp, mm::RoundOp, mm::RsqrtOp,
+ mm::SqrtOp, mm::TruncOp>(op)) {
+ this_size = 20; // Rough estimate.
+ }
+ }
+ }
+
+ if (!op->getResultTypes().empty()) {
+ if (auto vector_ty =
+ mlir::dyn_cast<mlir::VectorType>(op->getResultTypes().front())) {
+ this_size *= vector_ty.getNumElements();
+ }
+ }
+
+ size += this_size;
+ });
+
+ int factor = std::min(trip_count, kMaxSize / size);
+ while (factor > 1 && trip_count % factor) {
+ --factor;
+ }
+ return factor;
+}
+
+struct UnrollLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
+ using mlir::OpRewritePattern<mlir::scf::ForOp>::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
+ if (int factor = GetUnrollingFactor(op); factor > 1) {
+ return mlir::loopUnrollByFactor(op, factor);
+ }
+ return rewriter.notifyMatchFailure(op, "loop can't be unrolled");
+ }
+};
+
+class OptimizeLoopsPass
+ : public impl::OptimizeLoopsPassBase<OptimizeLoopsPass> {
+ public:
+ void runOnOperation() override {
+ // First unroll loops. If unrolling is possible, we prefer it.
+ mlir::RewritePatternSet unroll_patterns(&getContext());
+ unroll_patterns.add<UnrollLoops>(&getContext());
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+ getOperation(), std::move(unroll_patterns)))) {
+ signalPassFailure();
+ return;
+ }
+
+ // Then pipeline the remaining loops.
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<PipelineLoad<mlir::vector::TransferReadOp>,
+ PipelineLoad<mlir::tensor::ExtractOp>>(&getContext());
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateOptimizeLoopsPass() {
+ return std::make_unique<OptimizeLoopsPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h
new file mode 100644
index 0000000..e70af75
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h
@@ -0,0 +1,63 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_
+#define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_
+
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DECL
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+// Returns the range of a given value, if it can be statically determined.
+std::optional<Interval> GetRange(mlir::Value value);
+
+// Returns the range for the induction variable, if it can be statically
+// determined.
+std::optional<Interval> GetIVRange(mlir::Value iv);
+
+std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
+std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere);
+std::unique_ptr<mlir::Pass> CreateConvertPureCallOpsPass();
+std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
+std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
+ bool is_amd_gpu = false, const std::string& gpu_arch = "6.0");
+std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass();
+std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass();
+std::unique_ptr<mlir::Pass> CreateLowerXlaGpuLoopsToScfPass();
+std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
+std::unique_ptr<mlir::Pass> CreateOptimizeLoopsPass();
+std::unique_ptr<mlir::Pass> CreatePeelLoopsPass();
+std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass();
+std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass();
+std::unique_ptr<mlir::Pass> CreateSimplifyArithPass();
+std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass();
+std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass();
+
+#define GEN_PASS_REGISTRATION
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td
new file mode 100644
index 0000000..af27b36
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td
@@ -0,0 +1,319 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_
+#define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def PropagateSliceIndicesPass :
+ Pass<"xla-gpu-propagate-slice-indices", "mlir::ModuleOp"> {
+ let summary = "Propagates slice indices from the entry function to all callees.";
+
+ let description = [{
+ Propagates xla.slice_index attributes from the function with the xla.entry
+ attribute to all other functions.
+ }];
+
+ let dependentDialects = [
+ "mlir::func::FuncDialect"
+ ];
+
+ let constructor = "CreatePropagateSliceIndicesPass()";
+}
+
+def ConvertPureCallOpsPass
+ : Pass<"xla-gpu-convert-pure-call-ops", "mlir::func::FuncOp"> {
+ let summary = "Converts xla_gpu.pure_call to func.call";
+ let description = [{
+ We use xla_gpu.pure_call ops for calls to enable CSE and other
+ transformations (e.g. LICM). This pass rewrites our custom ops to standard
+ ops.
+ }];
+ let dependentDialects = [
+ "mlir::func::FuncDialect",
+ "xla::gpu::XlaGpuDialect"
+ ];
+ let constructor = "CreateConvertPureCallOpsPass()";
+}
+
+def FlattenTensorsPass : Pass<"xla-gpu-flatten-tensors", "mlir::ModuleOp"> {
+ let summary = "Flatten tensors.";
+
+ let description = [{
+ Linearizes all tensor loads and stores.
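+
+ For example (illustrative, simplified IR), an access such as
+   tensor.extract %arg[%i, %j] : tensor<4x8xf32>
+ becomes an access into the flattened tensor,
+   tensor.extract %arg[%linear] : tensor<32xf32>
+ where %linear = %i * 8 + %j is produced by an indexing map.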
+ }];
+
+ let dependentDialects = [
+ "mlir::func::FuncDialect",
+ "mlir::tensor::TensorDialect",
+ "xla::gpu::XlaGpuDialect",
+ ];
+ let constructor = "CreateFlattenTensorsPass()";
+}
+
+def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
+ let summary = "Lowers tensors to LLVM pointers and loads/stores.";
+
+ let description = [{
+ Lowers tensors to LLVM. We cannot use the memref lowerings because they
+ are not compatible with XLA's ABI.
+ }];
+
+ let dependentDialects = [
+ "mlir::LLVM::LLVMDialect",
+ "mlir::func::FuncDialect",
+ "mlir::gpu::GPUDialect",
+ "mlir::scf::SCFDialect",
+ "mlir::tensor::TensorDialect",
+ "xla::gpu::XlaGpuDialect",
+ ];
+ let options = [
+ Option<"is_amd_gpu_", "is_amd_gpu", "bool", /*default=*/"false",
+ "True if AMD GPU.">,
+ Option<"gpu_arch_", "gpu_arch", "std::string", /*default=*/"",
+ "CUDA or ROCm compute capability.">,
+ ];
+ let constructor = "CreateLowerTensorsPass()";
+}
+
+def MergePointersToSameSlicePass :
+ Pass<"xla-gpu-merge-pointers", "mlir::ModuleOp"> {
+ let summary = "Merges pointers that share slices.";
+
+ let description = [{
+ When a function has multiple pointer arguments with the same slice index,
+ merges them.
+ }];
+
+ let dependentDialects = [
+ "mlir::func::FuncDialect"
+ ];
+
+ let constructor = "CreateMergePointersToSameSlicePass()";
+}
+
+def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::func::FuncOp"> {
+ let summary = "Simplifies arith using XLA's range-aware simplifier.";
+
+ let description = [{
+ We often emit bounds checks that are statically known to be satisfied.
+ This pass removes them.
+ }];
+
+ let dependentDialects = [
+ "mlir::arith::ArithDialect",
+ "mlir::func::FuncDialect",
+ ];
+
+ let constructor = "CreateSimplifyArithPass()";
+}
+
+def SimplifyAffinePass : Pass<"xla-gpu-simplify-affine", "mlir::ModuleOp"> {
+ let summary = "Simplifies affine.apply using XLA's range-aware simplifier.";
+
+ let description = [{
+ The standard affine canonicalizer cannot simplify all expressions, since
+ it is unaware of range information. This pass uses `xla.range` attributes
+ on arguments and ops for simplification. It also lowers floordiv and mod
+ to simpler expressions than lower-affine. This pass only works for
+ expressions for which we can prove the LHS of mod and div is nonnegative.
+ }];
+
+ let dependentDialects = [
+ "mlir::affine::AffineDialect", "mlir::func::FuncDialect",
+ "mlir::scf::SCFDialect",
+ ];
+
+ let constructor = "CreateSimplifyAffinePass()";
+}
+
+def ExpandFloatOpsPass : Pass<"xla-gpu-expand-float-ops", "mlir::ModuleOp"> {
+ let summary = "Expands float ops that are not natively supported.";
+
+ let description = [{
+ Not all float ops are natively supported, either because they don't exist
+ in hardware or they are too inaccurate.
+
+ This pass replaces these ops with alternative implementations.
+ }];
+
+ let dependentDialects = [
+ "mlir::arith::ArithDialect", "mlir::math::MathDialect",
+ "mlir::mhlo::MhloDialect"
+ ];
+
+ let options = [
+ Option<"pre_ampere_", "pre-ampere", "bool", /*default=*/"false",
+ "Rewrite ops that are not supported on architectures before Ampere">,
+ ];
+}
+
+def LowerXlaGpuToScfPass :
+ Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::func::FuncOp"> {
+ let summary = "Lowers xla_gpu to SCF.";
+
+ let dependentDialects = [
+ "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect",
+ "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect",
+ ];
+
+ let constructor = "CreateLowerXlaGpuToScfPass()";
+}
+
+def LowerXlaGpuLoopsToScfPass : Pass<
+ "xla-gpu-lower-xla-gpu-loops-to-scf", "mlir::func::FuncOp"> {
+ let summary = "Lowers xla_gpu.loop to SCF.";
+
+ let description = [{
+ This pass is separate from lower-xla-gpu-to-scf because
+ lower-xla-gpu-to-scf, inliner, peeling and lower-xla-gpu-loops-to-scf
+ have to run in that order.
+ }];
+
+ let dependentDialects = [
+ "mlir::scf::SCFDialect",
+ "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect",
+ ];
+
+ let constructor = "CreateLowerXlaGpuLoopsToScfPass()";
+}
+
+def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> {
+ let summary = "Deletes unused functions.";
+
+ let description = [{
+ Deletes functions that are not called.
+ }];
+
+ let dependentDialects = [
+ "mlir::func::FuncDialect",
+ "xla::gpu::XlaGpuDialect"
+ ];
+
+ let constructor = "CreateEraseDeadFunctionsPass()";
+}
+
+def LowerToLLVMPass :
+ Pass<"xla-gpu-lower-to-llvm", "mlir::ModuleOp"> {
+ let summary = "Lowers to LLVM.";
+
+ let description = [{
+ Lowers the remaining operations to the LLVM dialect.
+ }];
+
+ let dependentDialects = [
+ "mlir::func::FuncDialect",
+ "mlir::LLVM::LLVMDialect",
+ "mlir::NVVM::NVVMDialect",
+ ];
+
+ let constructor = "CreateLowerToLLVMPass()";
+}
+
+def VectorizeLoadsAndStoresPass :
+ Pass<"xla-gpu-vectorize-loads-stores", "mlir::func::FuncOp"> {
+ let summary = "Vectorizes loads and stores.";
+
+ let description = [{
+ Rewrites tensor.extract and tensor.insert ops inside loops to their vector
+ equivalents (vector.transfer_read and vector.transfer_write + vector.extract
+ and vector.insert).
+ }];
+
+ let dependentDialects = [
+ "mlir::vector::VectorDialect",
+ ];
+
+ let constructor = "CreateVectorizeLoadsAndStoresPass()";
+}
+
+def PeelLoopsPass : Pass<"xla-gpu-peel-loops", "mlir::func::FuncOp"> {
+ let summary = "Peels xla_gpu.loop.";
+ let description = [{
+ Attempts to split each loop dimension [0, NUM_ITERATIONS)
+ as [0, NUM_ITERATIONS - 1) and [NUM_ITERATIONS - 1, NUM_ITERATIONS)
+ if it removes a constraint.
+ }];
+ let dependentDialects = ["xla::gpu::XlaGpuDialect"];
+ let constructor = "CreatePeelLoopsPass()";
+}
+
+def OptimizeLoopsPass :
+ Pass<"xla-gpu-optimize-loops", "mlir::func::FuncOp"> {
+ let summary = "Unrolls and pipelines loops.";
+
+ let description = [{
+ Unrolls loops with a small trip count. Pipelines loops with a large trip
+ count.
+ }];
+
+ let dependentDialects = [
+ "mlir::vector::VectorDialect",
+ "xla::gpu::XlaGpuDialect",
+ ];
+
+ let constructor = "CreateOptimizeLoopsPass()";
+}
+
+def UnswitchLoopsPass :
+ Pass<"xla-gpu-unswitch-loops", "mlir::func::FuncOp"> {
+ let summary = "Swaps scf.if and scf.for.";
+
+ let description = [{
+ Extracts `scf.if` ops with conditions that are independent of the loop
+ variable from `scf.for` by doing the following rewrite:
+
+ Before:
+
+ %cond = some_cond() : i1
+ %results = scf.for {
+ %some_val = scf.if %cond {
+ } else {
+ }
+ scf.yield %some_val
+ }
+
+ After:
+
+ %cond = some_cond() : i1
+ %results = scf.if %cond {
+ %results = scf.for {
+ %some_val = scf.if %true {
+ } else {
+ }
+ }
+ yield %results
+ } else {
+ %results = scf.for {
+ %some_val = scf.if %false {
+ } else {
+ }
+ }
+ yield %results
+ }
+
+ This only triggers if there is a single `scf.if` op in the loop body (and
+ nothing else).
+ }];
+
+ let dependentDialects = [
+ "mlir::func::FuncDialect", "mlir::scf::SCFDialect"
+ ];
+
+ let constructor = "CreateUnswitchLoopsPass()";
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc
new file mode 100644
index 0000000..a8157a4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc
@@ -0,0 +1,149 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_PEELLOOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::Location;
+using mlir::OpBuilder;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::SmallVector;
+using mlir::Value;
+using mlir::ValueRange;
+
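+// Rewrites an xla_gpu.loop into a chain of loops: a "main" loop covering the
+// bulk of the iteration space (on which peeling made the constraint redundant,
+// so it is dropped) and a "tail" loop per peeled symbol handling the last
+// iteration. The results of each emitted loop feed the inits of the next one.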
+struct PeelLoop : public OpRewritePattern<LoopOp> {
+ using OpRewritePattern<LoopOp>::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ LoopOp loop_op, PatternRewriter& rewriter) const override {
+ int64_t cumulative_loop_size = 1;
+
+ // Compute the list of indexing maps. The last element is the "peeled" or
+ // "main" loop. Everything else is a "tail" loop.
+ auto indexing_map = loop_op.getIndexingMap();
+ // TODO(b/358274367): Remove the simplify call once we have `is_simplified`
+ // field and a canonicalization pattern to simplify indexing map in
+ // xla_gpu.loop.
+ indexing_map.Simplify();
+ SmallVector<IndexingMap> indexing_maps{indexing_map};
+ for (int sym_index = indexing_map.GetSymbolCount() - 1;
+ sym_index >= 0 && cumulative_loop_size < 64; --sym_index) {
+ IndexingMap indexing_map = indexing_maps.back();
+ auto& bound = indexing_map.GetSymbolBound(sym_index);
+ cumulative_loop_size *= bound.GetLoopTripCount();
+ if (!indexing_map.IsSymbolConstrained(sym_index) ||
+ bound.upper == bound.lower) {
+ continue;
+ }
+ // Create peeled indexing map.
+ IndexingMap peeled_map = indexing_map;
+ --peeled_map.GetMutableSymbolBound(sym_index).upper;
+ peeled_map.Simplify();
+
+ // If the symbol is still constrained, peeling does not help.
+ if (peeled_map.IsSymbolConstrained(sym_index)) continue;
+
+ // Create remainder indexing map.
+ IndexingMap tail_map = indexing_map;
+ tail_map.GetMutableSymbolBound(sym_index).lower = bound.upper;
+ tail_map.Simplify();
+
+ VLOG(5) << "Peeled indexing map\n"
+ << indexing_map.ToString() << "into\n"
+ << peeled_map.ToString() << "and\n"
+ << tail_map.ToString() << "\n";
+ indexing_maps.pop_back();
+ indexing_maps.push_back(tail_map);
+ indexing_maps.push_back(peeled_map);
+ }
+
+ if (indexing_maps.size() == 1) {
+ return rewriter.notifyMatchFailure(loop_op,
+ "No range variables to peel.");
+ }
+
+ // Create chained loops from the list of indexing maps.
+ Location loc = loop_op.getLoc();
+ SmallVector<Value, 4> inits = loop_op.getInits();
+ for (const auto& indexing_map : llvm::reverse(indexing_maps)) {
+ if (indexing_map.IsKnownEmpty()) continue;
+ auto tail_loop = rewriter.create<LoopOp>(
+ loc, indexing_map, loop_op.getDims(), inits,
+ [&](OpBuilder& nested_b, Location nested_loc, ValueRange ivs,
+ ValueRange iter_args) {
+ OpBuilder::InsertionGuard guard(nested_b);
+ mlir::IRMapping mapping;
+ mapping.map(loop_op.getInductionVars(), ivs);
+ mapping.map(loop_op.getRegionIterArgs(), iter_args);
+ for (auto& op : loop_op.getBody()->getOperations()) {
+ nested_b.clone(op, mapping);
+ }
+ });
+ inits = tail_loop.getResults();
+ }
+ rewriter.replaceOp(loop_op, inits);
+ return mlir::success();
+ }
+};
+
+struct PeelLoopsPass : public impl::PeelLoopsPassBase<PeelLoopsPass> {
+ void runOnOperation() override {
+ auto func = getOperation();
+ mlir::MLIRContext* mlir_context = &getContext();
+ mlir::RewritePatternSet patterns(mlir_context);
+ patterns.add<PeelLoop>(mlir_context);
+ if (mlir::failed(
+ mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+ signalPassFailure();
+ return;
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> CreatePeelLoopsPass() {
+ return std::make_unique<PeelLoopsPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc b/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc
new file mode 100644
index 0000000..31a6379
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc
@@ -0,0 +1,80 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class PropagateSliceIndicesPass
+ : public impl::PropagateSliceIndicesPassBase<PropagateSliceIndicesPass> {
+ public:
+ void runOnOperation() override;
+};
+
+void PropagateSliceIndicesPass::runOnOperation() {
+ mlir::func::FuncOp entry;
+ for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
+ if (func->getAttr("xla.entry")) {
+ entry = func;
+ break;
+ }
+ }
+
+ if (!entry) {
+ getOperation()->emitOpError("No entry function found.");
+ signalPassFailure();
+ return;
+ }
+
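+  // Copy the entry function's per-argument slice metadata to the matching
+  // argument positions of every other function, stopping at the first
+  // non-tensor argument.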
+ for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
+ if (func.getNumArguments() == 0 || func == entry) {
+ continue;
+ }
+
+ for (int i = 0; i < func.getNumArguments(); ++i) {
+ if (mlir::isa<mlir::RankedTensorType>(func.getArgument(i).getType())) {
+ if (auto index = entry.getArgAttr(i, "xla.slice_index")) {
+ func.setArgAttr(i, "xla.slice_index", index);
+ }
+ if (auto invariant = entry.getArgAttr(i, "xla.invariant")) {
+ func.setArgAttr(i, "xla.invariant", invariant);
+ }
+ } else {
+ break;
+ }
+ }
+ }
+}
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass() {
+ return std::make_unique<PropagateSliceIndicesPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc
new file mode 100644
index 0000000..acbd9d3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc
@@ -0,0 +1,368 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/base/optimization.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using mlir::AffineBinaryOpExpr;
+using mlir::AffineConstantExpr;
+using mlir::AffineDimExpr;
+using mlir::AffineExpr;
+using mlir::AffineExprKind;
+using mlir::AffineMap;
+using mlir::AffineSymbolExpr;
+using mlir::ImplicitLocOpBuilder;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::SmallVector;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::affine::AffineApplyOp;
+
+namespace arith = mlir::arith;
+
+#define GEN_PASS_DEF_SIMPLIFYAFFINEPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
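+// Returns how many nested blocks lie between the builder's insertion block and
+// the block in which `a` is defined (0 if they are the same block).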
+int Distance(ImplicitLocOpBuilder& builder, Value a) {
+ auto* block = builder.getInsertionBlock();
+ auto* parent = a.getParentBlock();
+ int distance = 0;
+ while (block && block != parent) {
+ ++distance;
+ block = block->getParentOp()->getBlock();
+ }
+ return distance;
+}
+
+void CollectArgs(AffineExpr expr, AffineExprKind kind,
+ llvm::SmallVector<AffineExpr>& ret) {
+ if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
+ if (bin_op.getKind() == kind) {
+ CollectArgs(bin_op.getLHS(), kind, ret);
+ CollectArgs(bin_op.getRHS(), kind, ret);
+ return;
+ }
+ }
+ ret.push_back(expr);
+}
+
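+// Emits arith ops that evaluate an affine expression over the given operands
+// (dimensions first, then symbols), ordering add/mul operands to help LICM.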
+struct ExpressionEvaluator {
+ ExpressionEvaluator(ImplicitLocOpBuilder& builder, unsigned dim_count,
+ ValueRange operands)
+ : builder(builder), operands(operands) {
+ for (int i = 0; i < dim_count; ++i) {
+ dim_distances.push_back(Distance(builder, operands[i]));
+ }
+ for (int i = dim_count; i < operands.size(); ++i) {
+ sym_distances.push_back(Distance(builder, operands[i]));
+ }
+ }
+
+ // Returns the distance (in basic blocks) from the insertion point to the
+ // values used in the given expression.
+ int ExprDistance(AffineExpr e, int depth = 0) {
+ if (auto dim = mlir::dyn_cast<AffineDimExpr>(e)) {
+ return dim_distances[dim.getPosition()];
+ }
+ if (auto sym = mlir::dyn_cast<AffineSymbolExpr>(e)) {
+ return sym_distances[sym.getPosition()];
+ }
+ if (auto binop = mlir::dyn_cast<AffineBinaryOpExpr>(e)) {
+ return std::min(ExprDistance(binop.getLHS(), depth + 1),
+ ExprDistance(binop.getRHS(), depth + 1));
+ }
+ if (depth == 0) {
+ // Top-level constant. Always add these last.
+ return std::numeric_limits<int>::min();
+ }
+ // Nested constant. Ignore these for distances.
+ return std::numeric_limits<int>::max();
+ }
+
+ Value EvaluateExpression(AffineExpr expr);
+
+ template <typename Op>
+ Value EvaluateAddMul(AffineExpr expr);
+
+ ImplicitLocOpBuilder& builder;
+ ValueRange operands;
+ SmallVector<int> dim_distances;
+ SmallVector<int> sym_distances;
+};
+
+template <typename Op>
+Value ExpressionEvaluator::EvaluateAddMul(AffineExpr expr) {
+ llvm::SmallVector<AffineExpr> args;
+ CollectArgs(expr, expr.getKind(), args);
+ // Sort the args so that the ones that are closest to the insertion point
+ // are evaluated last - this improves LICM.
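+  // For example (roughly): for d0 * 64 + d1 * 8 + s0, where d0 and d1 are
+  // defined outside the loop and s0 is the loop induction variable, the sorted
+  // order yields ((d0 * 64 + d1 * 8) + s0), so the loop-invariant partial sum
+  // can be hoisted.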
+ llvm::stable_sort(args, [&](AffineExpr a, AffineExpr b) {
+ int dist_a = ExprDistance(a);
+ int dist_b = ExprDistance(b);
+ return dist_a > dist_b;
+ });
+
+ Value result = nullptr;
+ for (auto arg : args) {
+ Value arg_evaluated = EvaluateExpression(arg);
+ if (result) {
+ result = builder.create<Op>(result, arg_evaluated);
+ } else {
+ result = arg_evaluated;
+ }
+ }
+
+ return result;
+}
+
+Value ExpressionEvaluator::EvaluateExpression(AffineExpr expr) {
+ if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
+ switch (expr.getKind()) {
+ case AffineExprKind::Add:
+ return EvaluateAddMul<arith::AddIOp>(expr);
+ case AffineExprKind::Mul:
+ return EvaluateAddMul<arith::MulIOp>(expr);
+ case AffineExprKind::Mod:
+ return builder.create<arith::RemUIOp>(
+ EvaluateExpression(bin_op.getLHS()),
+ EvaluateExpression(bin_op.getRHS()));
+ case AffineExprKind::FloorDiv:
+ return builder.create<arith::DivUIOp>(
+ EvaluateExpression(bin_op.getLHS()),
+ EvaluateExpression(bin_op.getRHS()));
+ default:
+ ABSL_UNREACHABLE();
+ }
+ }
+ switch (expr.getKind()) {
+ case AffineExprKind::Constant:
+ return builder.create<arith::ConstantIndexOp>(
+ mlir::cast<AffineConstantExpr>(expr).getValue());
+ case AffineExprKind::DimId:
+ return operands[mlir::cast<AffineDimExpr>(expr).getPosition()];
+ case AffineExprKind::SymbolId:
+ return operands[dim_distances.size() +
+ mlir::cast<AffineSymbolExpr>(expr).getPosition()];
+ default:
+ ABSL_UNREACHABLE();
+ }
+}
+
+bool IsLoweringSupported(AffineExpr expr, RangeEvaluator& range_evaluator) {
+ auto bin_op = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!bin_op) {
+ return true;
+ }
+ // Mod and div can be lowered if their LHS is >= 0 and their RHS is a
+ // constant.
+ if (expr.getKind() == AffineExprKind::Mod ||
+ expr.getKind() == AffineExprKind::FloorDiv) {
+ if (!range_evaluator.IsAlwaysPositiveOrZero(bin_op.getLHS()) ||
+ !range_evaluator.ComputeExpressionRange(bin_op.getRHS()).IsPoint()) {
+ return false;
+ }
+ }
+ if (expr.getKind() == AffineExprKind::CeilDiv) {
+ return false;
+ }
+ return IsLoweringSupported(bin_op.getLHS(), range_evaluator) &&
+ IsLoweringSupported(bin_op.getRHS(), range_evaluator);
+}
+
+struct RewriteAffineApply : OpRewritePattern<mlir::affine::AffineApplyOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::affine::AffineApplyOp op,
+ PatternRewriter& rewriter) const override {
+ AffineMap affine_map = op.getAffineMap();
+ std::vector<DimVar> dim_ranges(affine_map.getNumDims());
+ std::vector<RangeVar> symbol_ranges(affine_map.getNumSymbols());
+
+ for (int i = 0; i < affine_map.getNumInputs(); ++i) {
+ if (auto range = GetRange(op->getOperand(i))) {
+ if (i >= dim_ranges.size()) {
+ symbol_ranges[i - dim_ranges.size()] = RangeVar{*range};
+ } else {
+ dim_ranges[i] = DimVar{*range};
+ }
+ } else {
+ return rewriter.notifyMatchFailure(op, "failed to deduce range");
+ }
+ }
+
+ IndexingMap indexing_map(affine_map, std::move(dim_ranges),
+ std::move(symbol_ranges),
+ /*rt_vars=*/{});
+ indexing_map.Simplify();
+ auto result_expr = indexing_map.GetAffineMap().getResult(0);
+
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
+ if (!IsLoweringSupported(result_expr, range_evaluator)) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to lower the affine apply");
+ }
+ b.setInsertionPoint(op);
+ auto result = ExpressionEvaluator(b, indexing_map.GetDimensionCount(),
+ op->getOperands())
+ .EvaluateExpression(result_expr);
+ rewriter.replaceOp(op, result);
+ return mlir::success();
+ }
+};
+
+struct RewriteApplyIndexingOp : OpRewritePattern<ApplyIndexingOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ApplyIndexingOp op,
+ PatternRewriter& rewriter) const override {
+ auto indexing_map = op.getIndexingMap();
+ indexing_map.Simplify();
+ auto affine_map = indexing_map.GetAffineMap();
+ int64_t dim_count = indexing_map.GetDimensionCount();
+ auto operands = op->getOperands();
+
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
+
+ b.setInsertionPoint(op);
+ SmallVector<Value, 4> results;
+ results.reserve(affine_map.getNumResults());
+ for (unsigned i = 0; i < affine_map.getNumResults(); ++i) {
+ AffineExpr result_expr = affine_map.getResult(i);
+ // If the expression cannot be lowered, we convert it to affine.apply,
+ // since it supports more expression types.
+ if (IsLoweringSupported(result_expr, range_evaluator)) {
+ results.push_back(ExpressionEvaluator(b, dim_count, operands)
+ .EvaluateExpression(result_expr));
+ } else {
+ results.push_back(
+ b.create<AffineApplyOp>(affine_map.getSubMap({i}), operands));
+ }
+ }
+ rewriter.replaceOp(op, results);
+ return mlir::success();
+ }
+};
+
+struct SimplifyAffinePass
+ : public impl::SimplifyAffinePassBase<SimplifyAffinePass> {
+ public:
+ void runOnOperation() override {
+ MLIRContext* ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ patterns.add<RewriteAffineApply, RewriteApplyIndexingOp>(ctx);
+ mlir::GreedyRewriteConfig config;
+ // There's no point simplifying more than once.
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+ getOperation(), std::move(patterns), config))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::optional<Interval> GetRange(mlir::Value value) {
+ auto attr_to_range = [](mlir::Attribute attr) -> std::optional<Interval> {
+ if (!attr) {
+ return std::nullopt;
+ }
+ auto values = llvm::to_vector(
+ mlir::cast<mlir::ArrayAttr>(attr).getAsValueRange<mlir::IntegerAttr>());
+ return {{values[0].getSExtValue(), values[1].getSExtValue()}};
+ };
+
+ if (auto apply = value.getDefiningOp<ApplyIndexingOp>()) {
+ return apply.getIndexingMap().GetRangeEvaluator().ComputeExpressionRange(
+ apply.getIndexingMap().GetAffineMap().getResult(
+ mlir::cast<mlir::OpResult>(value).getResultNumber()));
+ } else if (auto cst = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
+ return {{cst.value(), cst.value()}};
+ } else if (value.getDefiningOp()) {
+ return attr_to_range(value.getDefiningOp()->getAttr("xla.range"));
+ }
+
+ auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(value);
+ if (!bbarg) {
+ return std::nullopt;
+ }
+
+ auto parent = bbarg.getParentBlock()->getParentOp();
+ if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(parent)) {
+ return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range"));
+ }
+ return GetIVRange(value);
+}
+
+std::optional<Interval> GetIVRange(mlir::Value iv) {
+ auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(iv);
+ if (!bbarg) {
+ return std::nullopt;
+ }
+ auto parent = bbarg.getParentBlock()->getParentOp();
+ if (auto for_op = mlir::dyn_cast<mlir::scf::ForOp>(parent)) {
+ llvm::APInt lb, ub;
+ if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) &&
+ mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
+ return {{lb.getSExtValue(), ub.getSExtValue() - 1}};
+ }
+ }
+ return std::nullopt;
+}
+
+std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass() {
+ return std::make_unique<SimplifyAffinePass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc
new file mode 100644
index 0000000..bbf4831
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc
@@ -0,0 +1,344 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_SIMPLIFYARITHPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::LogicalResult;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::arith::CmpIOp;
+using mlir::arith::CmpIPredicate;
+
+Interval::ComparisonResult EvaluateCmpI(CmpIPredicate pred, Interval lhs,
+ Interval rhs) {
+ switch (pred) {
+ case CmpIPredicate::eq:
+ return lhs.Eq(rhs);
+ case CmpIPredicate::ne:
+ return lhs.Ne(rhs);
+ case CmpIPredicate::slt:
+ case CmpIPredicate::ult:
+ return lhs.Lt(rhs);
+ case CmpIPredicate::sle:
+ case CmpIPredicate::ule:
+ return lhs.Le(rhs);
+ case CmpIPredicate::sgt:
+ case CmpIPredicate::ugt:
+ return lhs.Gt(rhs);
+ case CmpIPredicate::sge:
+ case CmpIPredicate::uge:
+ return lhs.Ge(rhs);
+ }
+}
+
+struct RewriteCmpI : OpRewritePattern<CmpIOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CmpIOp op,
+ PatternRewriter& rewriter) const override {
+ auto rhs = GetRange(op.getRhs());
+ auto lhs = GetRange(op.getLhs());
+ if (!lhs || !rhs) {
+ return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
+ }
+ Interval::ComparisonResult result =
+ EvaluateCmpI(op.getPredicate(), *lhs, *rhs);
+ if (result != std::nullopt) {
+ rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(
+ op, *result, rewriter.getI1Type());
+ return mlir::success();
+ }
+ return rewriter.notifyMatchFailure(op, "not a constant result");
+ }
+};
+
+struct RewriteMaxSi : OpRewritePattern<mlir::arith::MaxSIOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::arith::MaxSIOp op,
+ PatternRewriter& rewriter) const override {
+ auto lhs = GetRange(op.getLhs());
+ auto rhs = GetRange(op.getRhs());
+ if (!lhs || !rhs) {
+ return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
+ }
+ if (auto lhs_ge_rhs = lhs->Ge(*rhs); lhs_ge_rhs == true) {
+ rewriter.replaceOp(op, op.getLhs());
+ } else if (auto rhs_ge_lhs = rhs->Ge(*lhs); rhs_ge_lhs == true) {
+ rewriter.replaceOp(op, op.getRhs());
+ } else {
+ return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
+ }
+ return mlir::success();
+ }
+};
+
+struct RewriteMinSi : OpRewritePattern<mlir::arith::MinSIOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::arith::MinSIOp op,
+ PatternRewriter& rewriter) const override {
+ auto lhs = GetRange(op.getLhs());
+ auto rhs = GetRange(op.getRhs());
+ if (!lhs || !rhs) {
+ return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
+ }
+ if (auto lhs_le_rhs = lhs->Le(*rhs); lhs_le_rhs == true) {
+ rewriter.replaceOp(op, op.getLhs());
+ } else if (auto rhs_le_lhs = rhs->Le(*lhs); rhs_le_lhs == true) {
+ rewriter.replaceOp(op, op.getRhs());
+ } else {
+ return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
+ }
+ return mlir::success();
+ }
+};
+
+// Finds the narrowest value in a use-def chain of truncis/extuis.
+mlir::Value FindNarrowestValueInChain(mlir::Value value) {
+ if (auto ext = value.getDefiningOp<mlir::arith::ExtUIOp>()) {
+ return FindNarrowestValueInChain(ext.getOperand());
+ }
+ auto defining_op = value.getDefiningOp<mlir::arith::TruncIOp>();
+ if (defining_op) {
+ auto first_trunc = FindNarrowestValueInChain(defining_op.getOperand());
+ if (first_trunc && first_trunc.getType().getIntOrFloatBitWidth() <=
+ defining_op.getType().getIntOrFloatBitWidth()) {
+ return first_trunc;
+ }
+ return defining_op;
+ }
+ return value;
+}
+
+// Rewrites trunc-bitwise to bitwise-trunc.
+//
+// For pred reductions, we generate code like this:
+//
+// %1 = arith.trunci %0 : i32 to i1
+// %2 = arith.ori %1, %x
+// %3 = arith.extui %2 : i1 to i32
+// %4 = gpu.shuffle %3
+//
+// By swapping the trunc with the or, we get a trunc-ext-shuffle sequence, which
+// can be rewritten to shuffle-trunc-ext. If there is another copy of the
+// pattern afterwards, we can push the truncs/exts further down.
+template <typename Op>
+struct RewriteTruncBitExt : OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(Op op,
+ PatternRewriter& rewriter) const override {
+ mlir::Value lhs = FindNarrowestValueInChain(op.getLhs());
+ mlir::Value rhs = FindNarrowestValueInChain(op.getRhs());
+
+ if (lhs.getType() != rhs.getType()) {
+ return rewriter.notifyMatchFailure(op, "mismatched narrowest types");
+ }
+
+ auto trunci_lhs = lhs.getDefiningOp<mlir::arith::TruncIOp>();
+ auto trunci_rhs = rhs.getDefiningOp<mlir::arith::TruncIOp>();
+ if (!trunci_lhs && !trunci_rhs) {
+ return rewriter.notifyMatchFailure(
+ op, "neither narrowest value is the result of a truncation");
+ }
+
+ auto wide_type =
+ (trunci_lhs ? trunci_lhs : trunci_rhs).getOperand().getType();
+ if (trunci_rhs && trunci_rhs.getOperand().getType() != wide_type) {
+ return rewriter.notifyMatchFailure(op, "mismatched truncation types");
+ }
+
+ mlir::Value new_lhs = trunci_lhs ? trunci_lhs.getOperand()
+ : rewriter.create<mlir::arith::ExtUIOp>(
+ op.getLoc(), wide_type, lhs);
+ mlir::Value new_rhs = trunci_rhs ? trunci_rhs.getOperand()
+ : rewriter.create<mlir::arith::ExtUIOp>(
+ op.getLoc(), wide_type, rhs);
+ mlir::Value new_op = rewriter.create<Op>(op.getLoc(), new_lhs, new_rhs);
+ rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, op.getType(),
+ new_op);
+
+ return mlir::success();
+ }
+};
+
+// Rewrites trunc-ext-shuffle to shuffle-trunc-ext. This pattern is designed to
+// work together with RewriteTruncBitExt to optimize pred reductions.
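+//
+// Roughly:
+//
+//   %0 = arith.trunci %in : i32 to i1
+//   %1 = arith.extui %0 : i1 to i32
+//   %2 = gpu.shuffle %1, ...
+//
+// becomes
+//
+//   %0 = gpu.shuffle %in, ...
+//   %1 = arith.trunci %0 : i32 to i1
+//   %2 = arith.extui %1 : i1 to i32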
+struct RewriteTruncExtShuffle : public OpRewritePattern<mlir::gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::gpu::ShuffleOp op,
+ PatternRewriter& rewriter) const override {
+ auto ext = op.getOperand(0).getDefiningOp<mlir::arith::ExtUIOp>();
+ if (!ext) {
+ return rewriter.notifyMatchFailure(op, "no ext");
+ }
+ auto trunc = ext.getOperand().getDefiningOp<mlir::arith::TruncIOp>();
+ if (!trunc || trunc.getOperand().getType() != ext.getType()) {
+ return rewriter.notifyMatchFailure(op, "no trunc or type mismatch");
+ }
+ rewriter.setInsertionPointAfter(op);
+ auto new_trunc = rewriter.create<mlir::arith::TruncIOp>(
+ op.getLoc(), trunc.getType(), op.getResult(0));
+ auto new_ext = rewriter.create<mlir::arith::ExtUIOp>(
+ op.getLoc(), ext.getType(), new_trunc.getResult());
+ rewriter.modifyOpInPlace(op,
+ [&]() { op->setOperand(0, trunc.getOperand()); });
+ rewriter.replaceAllUsesExcept(op.getResult(0), new_ext, new_trunc);
+ return mlir::success();
+ }
+};
+
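+// Annotates the single result of min/max/add/mul ops with an "xla.range"
+// attribute derived from the ranges of their operands, so that later patterns
+// can query it via GetRange.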
+void AnnotateRanges(mlir::func::FuncOp func) {
+ func->walk([](mlir::Operation* op) {
+ if (op->getNumResults() != 1) {
+ return;
+ }
+
+ auto result = op->getResult(0);
+ if (GetRange(result).has_value()) {
+ return;
+ }
+
+ auto get_range = [](mlir::Value value) -> Interval {
+ auto range = GetRange(value);
+ if (range) {
+ return *range;
+ }
+ return {std::numeric_limits<int64_t>::min(),
+ std::numeric_limits<int64_t>::max()};
+ };
+
+ std::optional<Interval> out_range = std::nullopt;
+ if (mlir::isa<mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
+ mlir::arith::AddIOp, mlir::arith::MulIOp>(op)) {
+ auto lhs_range = get_range(op->getOperand(0));
+ auto rhs_range = get_range(op->getOperand(1));
+ if (mlir::isa<mlir::arith::MaxSIOp>(op)) {
+ out_range = lhs_range.max(rhs_range);
+ } else if (mlir::isa<mlir::arith::MinSIOp>(op)) {
+ out_range = lhs_range.min(rhs_range);
+ } else if (mlir::isa<mlir::arith::AddIOp>(op)) {
+ out_range = lhs_range + rhs_range;
+ } else {
+ out_range = lhs_range * rhs_range;
+ }
+ }
+
+ if (out_range) {
+ mlir::OpBuilder b(op);
+ op->setAttr("xla.range",
+ b.getIndexArrayAttr({out_range->lower, out_range->upper}));
+ }
+ });
+}
+
+// Pattern to refine the bounds of an indexing map if some of its operands have
+// known bounds, e.g. loop induction variables.
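+//
+// For example (sketch): if an scf.for induction variable with constant bounds
+// feeds a dimension declared as d0 in [0, 1023] while the loop only runs over
+// [0, 3], the dimension bound is tightened to [0, 3].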
+struct RefineConstraints : public OpRewritePattern<ApplyIndexingOp> {
+ using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+ PatternRewriter& rewriter) const override {
+ // Right now, we only handle loop induction variables, but other rules might
+ // be added.
+ IndexingMap indexing_map = indexing_op.getIndexingMap();
+ int64_t dim_count = indexing_map.GetDimensionCount();
+ bool updated_bounds = false;
+ for (mlir::OpOperand& operand : indexing_op->getOpOperands()) {
+ auto range = GetIVRange(operand.get());
+ if (!range) {
+ continue;
+ }
+ auto operand_id = operand.getOperandNumber();
+ Interval& current_interval =
+ operand_id < dim_count
+ ? indexing_map.GetMutableDimensionBound(operand_id)
+ : indexing_map.GetMutableSymbolBound(operand_id - dim_count);
+ if (!range->Contains(current_interval)) {
+ current_interval = current_interval.Intersect(*range);
+ updated_bounds = true;
+ }
+ }
+ if (!updated_bounds) {
+ return rewriter.notifyMatchFailure(indexing_op, "No bounds to refine");
+ }
+ indexing_map.Simplify();
+ rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
+ indexing_op, indexing_op.getOperands(), indexing_map);
+ return mlir::success();
+ }
+};
+
+class SimplifyArithPass
+ : public impl::SimplifyArithPassBase<SimplifyArithPass> {
+ public:
+ void runOnOperation() override {
+ auto ctx = &getContext();
+ auto func = getOperation();
+ mlir::RewritePatternSet patterns(ctx);
+ AnnotateRanges(func);
+ // clang-format off
+ patterns.add<
+ RefineConstraints,
+ RewriteCmpI,
+ RewriteMaxSi,
+ RewriteMinSi,
+ RewriteTruncBitExt<mlir::arith::AndIOp>,
+ RewriteTruncBitExt<mlir::arith::OrIOp>,
+ RewriteTruncExtShuffle
+ >(ctx);
+ // clang-format on
+ if (mlir::failed(
+ mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> CreateSimplifyArithPass() {
+ return std::make_unique<SimplifyArithPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD
new file mode 100644
index 0000000..381d5a3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD
@@ -0,0 +1,16 @@
+load("//xla:lit.bzl", "lit_test_suite")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ licenses = ["notice"],
+)
+
+lit_test_suite(
+ name = "tests",
+ srcs = glob(["*.mlir"]),
+ cfg = "//xla:lit.cfg.py",
+ tools = [
+ "//xla/service/gpu/fusions/tools:mlir_fusions_opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/convert_xla_gpu_pure_calls.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/convert_xla_gpu_pure_calls.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir
new file mode 100644
index 0000000..21a8dc2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-flatten-tensors \
+// RUN: --verify-diagnostics | FileCheck %s
+
+func.func @tensor_extract(
+ %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
+ %arg1: index, %arg2: index) -> f32 {
+ %v = tensor.extract %arg0[%arg1, %arg2]
+ : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
+ func.return %v : f32
+}
+// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0),
+// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 2]>
+
+// CHECK-LABEL: func.func @tensor_extract(
+// CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>,
+// CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index) -> f32 {
+// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]], %[[J]])
+// CHECK: tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32>
+
+// -----
+
+func.func @tensor_insert(
+ %arg0: tensor<10x24xcomplex<f32>>) -> tensor<10x24xcomplex<f32>> {
+ %c1 = arith.constant 1 : index
+ %real = arith.constant 3.0 : f32
+ %imag = arith.constant 2.0 : f32
+ %complex = complex.create %real, %imag : complex<f32>
+ %out = tensor.insert %complex into %arg0[%c1, %c1] : tensor<10x24xcomplex<f32>>
+ func.return %out : tensor<10x24xcomplex<f32>>
+}
+// CHECK-LABEL: func.func @tensor_insert(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<240xcomplex<f32>>) -> tensor<240xcomplex<f32>> {
+// CHECK: %[[INDEX:.*]] = arith.constant 25
+// CHECK: %[[COMPLEX:.*]] = complex.create
+// CHECK: tensor.insert %[[COMPLEX]] into %[[TENSOR]][%[[INDEX]]]
+// CHECK-SAME: : tensor<240xcomplex<f32>>
+
+// -----
+
+func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index)
+ -> (tensor<2x4xf32>) {
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+ ^bb0(%current : f32):
+ %c42 = arith.constant 1.0 : f32
+ %add = arith.minimumf %current, %c42 : f32
+ xla_gpu.yield %add : f32
+ }
+ return %ret : tensor<2x4xf32>
+}
+// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1),
+// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3]>
+
+// CHECK-LABEL: func.func @atomic_rmw(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index,
+// CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> {
+// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME: (%[[I]], %[[J]])
+// CHECK: xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32>
+
+// -----
+
+func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>)
+ -> (tensor<32x1024xf32>, tensor<64x8x4xf32>, f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %for:2 = scf.for %i = %c0 to %c64 step %c32 iter_args(%t0_ = %t0, %t1_ = %t1)
+ -> (tensor<32x1024xf32>, tensor<64x8x4xf32>) {
+ %update0 = tensor.insert %c0_f32 into %t0_[%c1, %i] : tensor<32x1024xf32>
+ %update1 = tensor.insert %c0_f32 into %t1_[%i, %c1, %c1] : tensor<64x8x4xf32>
+ scf.yield %update0, %update1 : tensor<32x1024xf32>, tensor<64x8x4xf32>
+ } {some_attr}
+ return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32
+}
+
+// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024),
+// CHECK-SAME: domain: d0 in [0, 1023]>
+// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5),
+// CHECK-SAME: domain: d0 in [0, 63]>
+// CHECK-LABEL: func.func @for_loop(
+// CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>,
+// CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) {
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG: %[[F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]]
+// CHECK-SAME: step %[[C32]]
+// CHECK-SAME: iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]])
+// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]])
+// CHECK: %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]]
+// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]])
+// CHECK: %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]]
+// CHECK: scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32>
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36),
+ domain: d0 in [0, 127], d1 in [0, 393749]>
+#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4),
+ domain: d0 in [0, 127], d1 in [0, 393749]>
+#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9),
+ domain: d0 in [0, 127], d1 in [0, 393749]>
+func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>,
+ %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>)
+ -> tensor<4000x4x9xf32> {
+ %c0 = arith.constant 0 : index
+ %c3999 = arith.constant 3999 : index
+ %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
+ %bl_x = gpu.block_id x {xla.range = [0 : index, 393749 : index]}
+ %0 = xla_gpu.apply_indexing #map(%th_x, %bl_x)
+ %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32>
+ %1 = arith.index_cast %extracted : i32 to index
+ %2 = arith.cmpi ule, %1, %c3999 : index
+ %3 = scf.if %2 -> (tensor<4000x4x9xf32>) {
+ %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x)
+ %5 = xla_gpu.apply_indexing #map2(%th_x, %bl_x)
+ %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32>
+ %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> {
+ ^bb0(%arg4: f32):
+ %6 = arith.addf %arg4, %elem : f32
+ xla_gpu.yield %6 : f32
+ }
+ scf.yield %atomic_rmw : tensor<4000x4x9xf32>
+ } else {
+ scf.yield %arg3 : tensor<4000x4x9xf32>
+ }
+ return %3 : tensor<4000x4x9xf32>
+}
+// CHECK-LABEL: func.func @if_op
+// CHECK-NOT: builtin.unrealized_conversion_cast
+// CHECK: scf.if %{{.*}} -> (tensor<144000xf32>) {
+// CHECK-COUNT-2: scf.yield %{{.*}} : tensor<144000xf32>
+// CHECK: return %{{.*}} : tensor<144000xf32>
+
+// -----
+
+func.func @dangling_cast(%arg0: tensor<6xf32>, %arg1: index) -> i32 {
+ %v = tensor.extract %arg0[%arg1] : tensor<6xf32>
+ %cast = builtin.unrealized_conversion_cast %v : f32 to i32
+ func.return %cast : i32
+}
+// CHECK: FlattenTensorsPass failed to converge
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir
similarity index 100%
copy from third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir
copy to third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
new file mode 100644
index 0000000..be8eb1e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
@@ -0,0 +1,854 @@
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=6.0" \
+// RUN: | FileCheck %s
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=7.0" \
+// RUN: | FileCheck %s --check-prefix=CHECK-VOLTA
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=8.0" \
+// RUN: | FileCheck %s --check-prefix=CHECK-AMPERE
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=9.0" \
+// RUN: | FileCheck %s --check-prefix=CHECK-HOPPER
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx908:sramecc+:xnack" \
+// RUN: | FileCheck %s --check-prefix=CHECK-GFX908-MI100
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx90a:sramecc+:xnack" \
+// RUN: | FileCheck %s --check-prefix=CHECK-GFX90A-MI200
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
+ func.func private @add(%arg0: f32, %arg1: f32) -> f32 {
+ %sum = arith.addf %arg0, %arg1 : f32
+ func.return %sum : f32
+ }
+
+ func.func private @tensorarg(%arg0: tensor<43xf32> {xla.invariant, xla.slice_index = 0}, %arg1: index) -> f32 {
+ %v1 = arith.constant 2.0 : f32
+ %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32>
+ %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32
+ func.return %sum : f32
+ }
+
+ func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, %arg1: index) -> f32 {
+ %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32
+ func.return %call : f32
+ }
+
+ func.func @stores(%arg0: tensor<17xf32> {xla.slice_index = 0}, %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
+ %c17 = arith.constant 17 : index
+ %c23 = arith.constant 23 : index
+ %cst = arith.constant 3.0 : f32
+ %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32>
+ %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32>
+ func.return %out2 : tensor<43xf32>
+ }
+}
+
+// CHECK: func.func private @add(%{{.*}}: f32, %{{.*}}: f32) -> f32 {
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: return
+
+// CHECK: func.func private @tensorarg(%[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME: {xla.invariant, xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) -> f32 {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00
+// CHECK-DAG: %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i32
+// CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX]]]
+// CHECK-DAG: %[[V2:.*]] = llvm.load %[[PTR]] invariant
+// CHECK: %[[RET:.*]] = call @add(%[[C2]], %[[V2]])
+// CHECK: return %[[RET]]
+
+// CHECK: func.func @tensorcall(%[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME: {xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index)
+// CHECK: %[[RET:.*]] = call @tensorarg(%[[ARG0]], %[[ARG1]])
+// CHECK: return %[[RET]]
+
+// CHECK: func.func @stores(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {xla.slice_index = 0 : i64},
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {xla.slice_index = 1 : i64})
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 3.000000e+00 : f32
+// CHECK-NEXT: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[ARG1]][17]
+// CHECK-NEXT: llvm.store %[[CST]], %[[PTR1]]
+// CHECK-NEXT: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[ARG1]][23]
+// CHECK-NEXT: llvm.store %[[CST]], %[[PTR2]]
+// CHECK-NEXT: return
+
+// -----
+
+module {
+ func.func @layout(
+ %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
+ %arg1: index, %arg2: index) -> f32 {
+ %v = tensor.extract %arg0[%arg1, %arg2]
+ : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
+ func.return %v : f32
+ }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0),
+// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 2]>
+// CHECK-LABEL: @layout(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index
+// CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[X]], %[[Y]])
+// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64
+// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]]
+// CHECK: llvm.load %[[PTR]]
+
+// -----
+
+module {
+ func.func @store_control_flow(
+ %arg0: tensor<2xf32>,
+ %arg1: index
+ ) -> tensor<2xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %cst = arith.constant 0.0 : f32
+ %cst2 = arith.constant 1.0 : f32
+
+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> {
+ %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32>
+ scf.yield %new_out : tensor<2xf32>
+ }
+
+ %inbounds = arith.cmpi sle, %arg1, %c1 : index
+ %result = scf.if %inbounds -> tensor<2xf32> {
+ %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32>
+ scf.yield %if : tensor<2xf32>
+ } else {
+ scf.yield %for : tensor<2xf32>
+ }
+ func.return %result : tensor<2xf32>
+ }
+}
+
+// CHECK: @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i64
+// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]]
+// CHECK: llvm.store {{.*}}, %[[PTR]]
+// CHECK: %[[INBOUNDS:.*]] = arith.cmpi
+// CHECK: scf.if %[[INBOUNDS]] {
+// CHECK: llvm.store
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// -----
+
+module {
+ func.func @large_tensor(
+ %arg0: tensor<1024x1024x1024x6xf32>,
+ %arg1: index) -> f32 {
+ %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32>
+ func.return %v : f32
+ }
+}
+
+// CHECK: @large_tensor
+// CHECK: arith.index_castui {{.*}} : index to i64
+
+// -----
+
+module {
+ func.func @extract_from_constant(%arg0: tensor<2x1xf32>,
+ %arg1: index, %arg2: index) -> f32 {
+ %cst = arith.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>
+ %extracted = tensor.extract %arg0[%arg1, %arg2] : tensor<2x1xf32>
+ %extracted_0 = tensor.extract %cst[%arg1, %arg2] : tensor<2x1xf32>
+ %0 = arith.addf %extracted, %extracted_0 : f32
+ return %0 : f32
+ }
+}
+// CHECK: llvm.mlir.global private constant @global_cst_0(dense<
+// CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32>
+// CHECK: @extract_from_constant
+// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr
+// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f32
+// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[LOAD]] : f32
+// CHECK: return %[[ADD]] : f32
+
+// -----
+
+module {
+ func.func @vector_constant() -> vector<2xindex> {
+ %c1 = arith.constant dense<[1, 2]> : vector<2xindex>
+ func.return %c1 : vector<2xindex>
+ }
+}
+
+// Vector constants should not be rewritten.
+// CHECK: @vector_constant
+// CHECK-NEXT: arith.constant
+
+// -----
+
+module {
+ func.func @complex_tensor_insert(
+ %arg0: tensor<10xcomplex<f32>>) -> tensor<10xcomplex<f32>> {
+ %c1 = arith.constant 1 : index
+ %real = arith.constant 3.0 : f32
+ %imag = arith.constant 2.0 : f32
+ %complex = complex.create %real, %imag : complex<f32>
+ %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex<f32>>
+ func.return %out : tensor<10xcomplex<f32>>
+ }
+}
+
+// CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr
+// CHECK: %[[C:.*]] = complex.create
+// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[C]] : complex<f32> to !llvm.struct<(f32, f32)>
+// CHECK: llvm.store %[[CAST]], %[[GEP]] : !llvm.struct<(f32, f32)>, !llvm.ptr
+
+// -----
+
+module {
+ func.func @complex_tensor_extract(
+ %arg0: tensor<10xcomplex<f32>>) -> complex<f32> {
+ %c1 = arith.constant 1 : index
+ %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex<f32>>
+ func.return %v2 : complex<f32>
+ }
+}
+
+// CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr
+// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
+// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)>
+// CHECK: builtin.unrealized_conversion_cast %[[LOAD]] : !llvm.struct<(f32, f32)> to complex<f32>
+
+// -----
+
+module {
+  // This example is a bit silly; in real life there wouldn't be a loop (the
+  // loop body would be executed by different threads). We're just doing it
+  // this way so that control flow with shared memory is tested as well.
+ func.func @transpose_shared(%in: tensor<32x32xf32>,
+ %out: tensor<32x32xf32>) -> tensor<32x32xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+
+ %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
+ %loaded_tile = scf.for %i = %c0 to %c32 step %c1
+ iter_args(%tile = %shared) -> tensor<32x32xf32> {
+ %inner_loaded_tile = scf.for %j = %c0 to %c32 step %c1
+ iter_args(%inner_tile = %tile) -> tensor<32x32xf32> {
+ %v = tensor.extract %in[%i, %j] : tensor<32x32xf32>
+ %inserted = tensor.insert %v into %inner_tile[%i, %j]
+ : tensor<32x32xf32>
+ scf.yield %inserted : tensor<32x32xf32>
+ }
+ scf.yield %inner_loaded_tile : tensor<32x32xf32>
+ }
+
+ %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32>
+ %written_tile = scf.for %i = %c0 to %c32 step %c1
+ iter_args(%written = %out) -> tensor<32x32xf32> {
+ %inner_written_tile = scf.for %j = %c0 to %c32 step %c1
+ iter_args(%inner_written = %written) -> tensor<32x32xf32> {
+ %v = tensor.extract %shared[%j, %i] : tensor<32x32xf32>
+ %inserted = tensor.insert %v into %inner_written[%i, %j]
+ : tensor<32x32xf32>
+ scf.yield %inserted : tensor<32x32xf32>
+ }
+ scf.yield %inner_written_tile : tensor<32x32xf32>
+ }
+
+ return %written_tile : tensor<32x32xf32>
+ }
+}
+
+// CHECK: llvm.mlir.global private @[[SHARED:shared_.*]]()
+// CHECK-SAME: {addr_space = 3 : i32} : !llvm.array<1024 x f32>
+// CHECK: @transpose_shared
+// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @[[SHARED]] : !llvm.ptr<3>
+// CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[ADDR]]
+// CHECK-SAME: : !llvm.ptr<3> to !llvm.ptr
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
+// CHECK: llvm.store {{.*}} %[[ELEM_ADDR]]
+// CHECK: gpu.barrier
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
+// CHECK: llvm.load %[[ELEM_ADDR]]
+
+// -----
+
+module {
+ func.func @atomic_rmw_f32(%in: tensor<2x4xf32>, %i: index, %j: index)
+ -> (tensor<2x4xf32>) {
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+ ^bb0(%current : f32):
+ %c42 = arith.constant 1.0 : f32
+ %add = arith.minimumf %current, %c42 : f32
+ xla_gpu.yield %add : f32
+ }
+ return %ret : tensor<2x4xf32>
+ }
+}
+
+// CHECK: @atomic_rmw_f32
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]]
+// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
+// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f32 to i32
+// CHECK-NEXT: llvm.cmpxchg %[[ADDR]], %[[VAR]], %[[RES]]
+
+// -----
+
+module {
+ func.func @atomic_rmw_f16(%in: tensor<2x4xf16>, %i: index, %j: index)
+ -> (tensor<2x4xf16>) {
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
+ ^bb0(%current : f16):
+ %c1 = arith.constant 1.0 : f16
+ %add = arith.addf %current, %c1 : f16
+ xla_gpu.yield %add : f16
+ }
+ return %ret : tensor<2x4xf16>
+ }
+}
+
+// CHECK: @atomic_rmw_f16
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
+// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
+// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
+// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
+// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
+// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
+// CHECK-NEXT: %[[VAR_SHIFT:.*]] = llvm.lshr %[[VAR]], %{{.*}}
+// CHECK-NEXT: %[[VAR_TRUNC:.*]] = llvm.trunc %[[VAR_SHIFT]]
+// CHECK-NEXT: llvm.bitcast %[[VAR_TRUNC]] : i16 to f16
+// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
+// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
+// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
+// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
+// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
+// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
+
+// -----
+
+module {
+ func.func @atomic_rmw_overwrite(%in: tensor<2x4xf16>, %i: index, %j: index)
+ -> (tensor<2x4xf16>) {
+ %c1 = arith.constant 1.0 : f16
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
+ ^bb0(%current : f16):
+ xla_gpu.yield %c1 : f16
+ }
+ return %ret : tensor<2x4xf16>
+ }
+}
+// CHECK: @atomic_rmw_overwrite
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
+// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
+// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
+// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
+// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
+// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
+// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
+// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
+// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
+// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
+// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
+// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
+
+// -----
+
+module {
+ func.func @shared_complex() -> tensor<10xcomplex<f32>> {
+ %shared = xla_gpu.allocate_shared : tensor<10xcomplex<f32>>
+ return %shared : tensor<10xcomplex<f32>>
+ }
+}
+
+// CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>>
+// CHECK: @shared_complex
+
+// -----
+
+module {
+ func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> {
+ %v = tensor.extract %arg[%i] : tensor<10xi4>
+ %r = tensor.insert %v into %arg[%j] : tensor<10xi4>
+ return %r : tensor<10xi4>
+ }
+}
+
+// CHECK: @i4_load_store
+// CHECK: llvm.getelementptr
+// CHECK-SAME: -> !llvm.ptr, i8
+// CHECK: llvm.load
+// CHECK: llvm.getelementptr
+// CHECK-SAME: -> !llvm.ptr, i8
+// CHECK: llvm.load
+// CHECK: llvm.store
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_overwrite(%in: tensor<2x4xi32>,
+ %i: index, %j: index) -> (tensor<2x4xi32>) {
+ %c2 = arith.constant 2 : i32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+ ^bb0(%current : i32):
+ xla_gpu.yield %c2 : i32
+ }
+ return %ret : tensor<2x4xi32>
+ }
+}
+// CHECK: @direct_atomic_rmw_overwrite
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.store %[[C2]], %[[ADDR]] atomic unordered {alignment = 4 : i64}
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_addi(%in: tensor<2x4xi32>,
+ %i: index, %j: index) -> (tensor<2x4xi32>) {
+ %c2 = arith.constant 2 : i32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+ ^bb0(%current : i32):
+ %min = arith.addi %current, %c2 : i32
+ xla_gpu.yield %c2 : i32
+ }
+ return %ret : tensor<2x4xi32>
+ }
+}
+// CHECK: @direct_atomic_rmw_addi
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw add %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_maxsi(%in: tensor<2x4xi32>,
+ %i: index, %j: index) -> (tensor<2x4xi32>) {
+ %c2 = arith.constant 2 : i32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+ ^bb0(%current : i32):
+ %min = arith.maxsi %current, %c2 : i32
+ xla_gpu.yield %c2 : i32
+ }
+ return %ret : tensor<2x4xi32>
+ }
+}
+// CHECK: @direct_atomic_rmw_maxsi
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw max %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_maxui(%in: tensor<2x4xi32>,
+ %i: index, %j: index) -> (tensor<2x4xi32>) {
+ %c2 = arith.constant 2 : i32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+ ^bb0(%current : i32):
+ %min = arith.maxui %current, %c2 : i32
+ xla_gpu.yield %c2 : i32
+ }
+ return %ret : tensor<2x4xi32>
+ }
+}
+// CHECK: @direct_atomic_rmw_maxui
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw umax %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_minsi(%in: tensor<2x4xi32>,
+ %i: index, %j: index) -> (tensor<2x4xi32>) {
+ %c2 = arith.constant 2 : i32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+ ^bb0(%current : i32):
+ %min = arith.minsi %current, %c2 : i32
+ xla_gpu.yield %c2 : i32
+ }
+ return %ret : tensor<2x4xi32>
+ }
+}
+// CHECK: @direct_atomic_rmw_minsi
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw min %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_minui(%in: tensor<2x4xi32>,
+ %i: index, %j: index) -> (tensor<2x4xi32>) {
+ %c2 = arith.constant 2 : i32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+ ^bb0(%current : i32):
+ %min = arith.minui %current, %c2 : i32
+ xla_gpu.yield %c2 : i32
+ }
+ return %ret : tensor<2x4xi32>
+ }
+}
+// CHECK: @direct_atomic_rmw_minui
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw umin %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_fadd_f32(%in: tensor<2x4xf32>,
+ %i: index, %j: index) -> (tensor<2x4xf32>) {
+ %c2 = arith.constant 2.0 : f32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+ ^bb0(%current : f32):
+ %min = arith.addf %current, %c2 : f32
+ xla_gpu.yield %c2 : f32
+ }
+ return %ret : tensor<2x4xf32>
+ }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
+// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
+// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-GFX908-MI100: %[[C2:.*]] = arith.constant 2
+// CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-GFX908-MI100: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
+
+// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
+// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
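+
+// The per-prefix CHECK blocks above (CHECK, CHECK-VOLTA, CHECK-AMPERE,
+// CHECK-GFX908-MI100, CHECK-GFX90A-MI200) record which targets lower this to a
+// native `llvm.atomicrmw fadd`; they presumably correspond to the different
+// target architectures exercised by this test's RUN lines. On the AMD targets
+// the pointer is additionally cast to the global address space and the op
+// carries syncscope("agent").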
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_fadd_f16(%in: tensor<2x4xf16>,
+ %i: index, %j: index) -> (tensor<2x4xf16>) {
+ %c2 = arith.constant 2.0 : f16
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
+ ^bb0(%current : f16):
+ %min = arith.addf %current, %c2 : f16
+ xla_gpu.yield %c2 : f16
+ }
+ return %ret : tensor<2x4xf16>
+ }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-NOT: llvm.atomicrmw fadd
+
+// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
+// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
+// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
+
+// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
+// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<2x4xbf16>,
+ %i: index, %j: index) -> (tensor<2x4xbf16>) {
+ %c2 = arith.constant 2.0 : bf16
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xbf16> {
+ ^bb0(%current : bf16):
+ %min = arith.addf %current, %c2 : bf16
+ xla_gpu.yield %c2 : bf16
+ }
+ return %ret : tensor<2x4xbf16>
+ }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_bf16
+// CHECK-NOT: llvm.atomicrmw fadd
+
+// CHECK-HOPPER-LABEL: @direct_atomic_rmw_fadd_bf16
+// CHECK-HOPPER: %[[C2:.*]] = arith.constant 2
+// CHECK-HOPPER: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-HOPPER: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_fadd_f64(%in: tensor<2x4xf64>,
+ %i: index, %j: index) -> (tensor<2x4xf64>) {
+ %c2 = arith.constant 2.0 : f64
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf64> {
+ ^bb0(%current : f64):
+ %min = arith.addf %current, %c2 : f64
+ xla_gpu.yield %c2 : f64
+ }
+ return %ret : tensor<2x4xf64>
+ }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
+// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
+// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
+
+// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-GFX90A-MI200-NOT: llvm.atomicrmw fadd
+
+// -----
+
+module {
+ func.func @direct_atomic_rmw_maximumf(%in: tensor<2x4xf32>,
+ %i: index, %j: index) -> (tensor<2x4xf32>) {
+ %c2 = arith.constant 2.0 : f32
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+ ^bb0(%current : f32):
+ %min = arith.maximumf %current, %c2 : f32
+ xla_gpu.yield %c2 : f32
+ }
+ return %ret : tensor<2x4xf32>
+ }
+}
+// CHECK-LABEL: @direct_atomic_rmw_maximumf
+
+// CHECK: %[[MODIFIER:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: %[[CURRENT:.*]] = llvm.load %[[ADDR]] : !llvm.ptr -> f32
+// CHECK: %[[CURRENT_IS_NAN:.*]] = llvm.fcmp "uno" %[[CURRENT]], %[[CURRENT]] : f32
+// CHECK: scf.if %[[CURRENT_IS_NAN]] {
+// CHECK: } else {
+// CHECK: %[[MODIFIER_IS_NAN:.*]] = llvm.fcmp "uno" %[[MODIFIER]], %[[MODIFIER]] : f32
+// CHECK: %[[MODIFIER_OR_NAN:.*]] = llvm.select %[[MODIFIER_IS_NAN]], %[[NAN]], %[[MODIFIER]] : i1, f32
+// CHECK: %[[VAL_13:.*]] = llvm.fcmp "ult" %[[CURRENT]], %[[MODIFIER_OR_NAN]] : f32
+// CHECK: scf.if %[[VAL_13]] {
+// CHECK: %[[INT_MODIFIER_OR_NAN:.*]] = llvm.bitcast %[[MODIFIER_OR_NAN]] : f32 to i32
+// CHECK: %[[IS_POSITIVE:.*]] = llvm.icmp "sge" %[[INT_MODIFIER_OR_NAN]], %[[C0]] : i32
+// CHECK: scf.if %[[IS_POSITIVE]] {
+// CHECK: llvm.atomicrmw max %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
+// CHECK: } else {
+// CHECK: llvm.atomicrmw umin %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return
+
+// -----
+
+module {
+ func.func @atomic_rmw_c32(%in: tensor<2x4xcomplex<f32>>, %i: index, %j: index)
+ -> (tensor<2x4xcomplex<f32>>) {
+ %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xcomplex<f32>> {
+ ^bb0(%current : complex<f32>):
+ %a = complex.add %current, %current : complex<f32>
+ xla_gpu.yield %a : complex<f32>
+ }
+ return %ret : tensor<2x4xcomplex<f32>>
+ }
+}
+
+// CHECK-LABEL: @atomic_rmw_c32
+
+// CHECK: scf.while (%[[ITER_ARG:.*]] = %{{.*}}) : (i64) -> i64
+// CHECK: %[[TMP:.*]] = llvm.alloca
+// CHECK: llvm.store %[[ITER_ARG]], %[[TMP]]
+// CHECK: %[[LD:.*]] = llvm.load %[[TMP]] : {{.*}} -> !llvm.struct<(f32, f32)>
+// CHECK: builtin.unrealized_conversion_cast %[[LD]] : {{.*}} to complex<f32>
+
+// -----
+
+module {
+ func.func @unused_index_switch_results(%i: index) -> index {
+ %ret, %ret2 = scf.index_switch %i -> tensor<2x4xi32>, tensor<3xf32>
+ case 0 {
+ %x, %y = "dummy.op1"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
+ scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
+ }
+ default {
+ %x, %y = "dummy.op2"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
+ scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
+ }
+ return %i : index
+ }
+}
+
+// CHECK-LABEL: func.func @unused_index_switch_results
+// CHECK-SAME: (%[[I:.*]]: index)
+// CHECK-NEXT: scf.index_switch %[[I]]
+// CHECK-NEXT: case 0 {
+// CHECK-NEXT: "dummy.op1"
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: default {
+// CHECK-NEXT: "dummy.op2"
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[I]] : index
+
+// -----
+
+module {
+ func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
+ %c16 = arith.constant 16 : index
+ %c22 = arith.constant 22 : index
+ %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32>
+ %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32>
+ %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32>
+ func.return %out2 : tensor<43xf32>
+ }
+}
+
+// CHECK-LABEL: @transfer_write
+// CHECK: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
+// CHECK-NEXT: llvm.store %[[CST:.*]], %[[PTR1]]
+// CHECK-NEXT: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
+// CHECK-NEXT: llvm.store %[[CST]], %[[PTR2]]
+
+// -----
+
+module {
+ func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> {
+ %c16 = arith.constant 16 : index
+ %c0 = arith.constant 0.0 : f32
+ %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32>
+ func.return %out : vector<2xf32>
+ }
+}
+
+// CHECK-LABEL: @transfer_read
+// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
+// CHECK-NEXT: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xf32>
+
+// -----
+
+module {
+ func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1},
+ %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> {
+ %c16 = arith.constant 16 : index
+ %c22 = arith.constant 22 : index
+ %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1>
+ %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1>
+ func.return %out2 : tensor<43xi1>
+ }
+}
+
+// CHECK-LABEL: @transfer_write_i1
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME: %[[V1:.*]]: vector<2xi1>, %[[V2:.*]]: vector<2xi1>)
+// CHECK-DAG: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
+// CHECK-DAG: %[[V1_EXT:.*]] = arith.extui %[[V1]]
+// CHECK: llvm.store %[[V1_EXT]], %[[PTR1]]
+// CHECK-DAG: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
+// CHECK-DAG: %[[V2_EXT:.*]] = arith.extui %[[V2]]
+// CHECK: llvm.store %[[V2_EXT]], %[[PTR2]]
+
+// -----
+
+module {
+ func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> {
+ %c16 = arith.constant 16 : index
+ %false = arith.constant false
+ %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1>
+ func.return %out : vector<2xi1>
+ }
+}
+
+// CHECK-LABEL: @transfer_read_i1
+// CHECK-DAG: %[[C0:.*]] = arith.constant dense<0> : vector<2xi8>
+// CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
+// CHECK: %[[LOADED:.*]] = llvm.load %[[PTR]] : !llvm.ptr
+// CHECK: %[[CAST:.*]] = arith.cmpi ne, %[[LOADED]], %[[C0]]
+// CHECK: return %[[CAST]] : vector<2xi1>
+
+// -----
+
+module {
+ func.func @transfer_write_i4(%arg0: tensor<43xi4> {xla.slice_index = 1},
+ %v1: vector<4xi4>) -> tensor<43xi4> {
+ %c16 = arith.constant 16 : index
+ %out = vector.transfer_write %v1, %arg0[%c16] : vector<4xi4>, tensor<43xi4>
+ func.return %out : tensor<43xi4>
+ }
+}
+
+// CHECK-LABEL: @transfer_write_i4
+// CHECK-SAME: , %[[V1:.*]]: vector<4xi4>
+// CHECK-DAG: %[[A0:.*]] = vector.extract %[[V1]][0]
+// CHECK-DAG: %[[A1:.*]] = vector.extract %[[V1]][1]
+// CHECK-DAG: %[[A2:.*]] = vector.extract %[[V1]][2]
+// CHECK-DAG: %[[A3:.*]] = vector.extract %[[V1]][3]
+// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1]
+// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0]
+// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3]
+// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2]
+
+module {
+ func.func @transfer_read_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}) -> vector<4xi4> {
+ %c16 = arith.constant 16 : index
+ %c0 = arith.constant 0 : i4
+ %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xi4>, vector<4xi4>
+ func.return %out : vector<4xi4>
+ }
+}
+
+// CHECK-LABEL: @transfer_read_i4
+// CHECK: %[[LOADED:.*]] = llvm.load
+// CHECK-DAG: %[[A0:.*]] = vector.extract %[[LOADED]][0]
+// CHECK-DAG: %[[A1:.*]] = vector.extract %[[LOADED]][1]
+// CHECK-DAG: %[[A2:.*]] = vector.extract %[[LOADED]][2]
+// CHECK-DAG: %[[A3:.*]] = vector.extract %[[LOADED]][3]
+// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1]
+// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0]
+// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3]
+// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2]
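+
+// In both i4 tests above, the CHECK-DAG lines show the lanes being swapped
+// pairwise (element 0 <-> 1, 2 <-> 3), which is consistent with two i4 values
+// being packed into each i8 byte when the vector is stored to or loaded from
+// memory.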
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir
new file mode 100644
index 0000000..f0de25a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf | FileCheck %s
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1),
+ domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]>
+func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) {
+ %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+ %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+ %add = arith.addf %sum_, %t : f32
+ xla_gpu.yield %add : f32
+ } {xla.range = [0 : index, 42 : index]}
+ func.return %sum : f32
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + s1),
+// CHECK-SAME: domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]>
+
+// CHECK-LABEL: func.func @loop_op(
+// CHECK-SAME: %[[IN:.*]]: tensor<1024x32xf32>,
+// CHECK-SAME: %[[INIT:.*]]: f32, %[[DIM:.*]]: index) -> f32 {
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C33:.*]] = arith.constant 33 : index
+// CHECK-DAG: %[[C90:.*]] = arith.constant 90 : index
+// CHECK-DAG: %[[C1025:.*]] = arith.constant 1025 : index
+
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C1025]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[INIT_:.*]] = %[[INIT]]) -> (f32) {
+
+// CHECK: %[[INNER_FOR:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C33]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[INIT__:.*]] = %[[INIT_]]) -> (f32) {
+
+// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing
+// CHECK-SAME: #[[$MAP]](%[[DIM]])[%[[I]], %[[J]]]
+// CHECK: %[[VAL1:.*]] = arith.cmpi sge, %[[INDEX]], %[[C0]] : index
+// CHECK: %[[VAL2:.*]] = arith.cmpi sle, %[[INDEX]], %[[C90]] : index
+// CHECK: %[[VAL3:.*]] = arith.andi %[[VAL1]], %[[VAL2]] : i1
+// CHECK: %[[VAL4:.*]] = arith.cmpi sge, %[[DIM]], %[[C0]] : index
+// CHECK: %[[VAL5:.*]] = arith.cmpi sle, %[[DIM]], %[[C3]] : index
+// CHECK: %[[VAL6:.*]] = arith.andi %[[VAL4]], %[[VAL5]] : i1
+// CHECK: %[[INBOUNDS:.*]] = arith.andi %[[VAL3]], %[[VAL6]] : i1
+// CHECK: %[[IF_RESULT:.*]] = scf.if %[[INBOUNDS]] -> (f32) {
+// CHECK: %[[ELEM:.*]] = tensor.extract %[[IN]][%[[I]], %[[J]]]
+// CHECK: %[[SUM:.*]] = arith.addf %[[INIT__]], %[[ELEM]] : f32
+// CHECK: scf.yield %[[SUM]] : f32
+// CHECK: } else {
+// CHECK: scf.yield %[[INIT__]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[IF_RESULT]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[INNER_FOR]] : f32
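+
+// Summary of the lowering checked above: the `xla_gpu.loop` becomes two nested
+// `scf.for` loops over the symbol ranges (s0 and s1, hence the 1025 and 33
+// upper bounds), and the body is wrapped in an `scf.if` that re-checks the
+// indexing-map constraints (s0 + s1 in [0, 90] and d0 in [0, 3]); out-of-range
+// iterations just forward the accumulator unchanged.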
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
new file mode 100644
index 0000000..2f9494a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \
+// RUN: | FileCheck %s
+
+func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
+ return %a, %b : f32, i32
+}
+
+func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
+ %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32
+ return %ret#0, %ret#1 : f32, i32
+}
+// CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32)
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32
+// CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]]
+// CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]]
+// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
+// CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]]
+// CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]]
+// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @reducer(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
+// CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]]
+// CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]]
+// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @reducer(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
+// CHECK: return %[[AB1_0]], %[[AB1_1]]
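+
+// As the CHECKs above show, `shuffle_reduce ... to 4` expands into a shuffle
+// ladder: each of the two values is shuffled down by 4, 2 and then 1 lanes,
+// and the reducer is re-invoked via `xla_gpu.pure_call` after every step.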
+
+// -----
+
+func.func @reducer(%a: f64, %b: f64) -> f64 {
+ return %a : f64
+}
+
+func.func @shuffler(%a: f64) -> f64 {
+ %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64
+ return %ret : f64
+}
+// CHECK: @shuffler(%[[A:.*]]: f64
+// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
+// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @reducer(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
+ return %a : complex<f64>
+}
+
+func.func @shuffler(%a: complex<f64>) -> complex<f64> {
+ %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex<f64>
+ return %ret : complex<f64>
+}
+// CHECK: @shuffler
+// CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @reducer(%a: ui64, %b: ui64) -> ui64 {
+ return %a : ui64
+}
+
+func.func @shuffler(%a: ui64) -> ui64 {
+ %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64
+ return %ret : ui64
+}
+// CHECK: @shuffler
+// CHECK: unrealized_conversion_cast
+// CHECK-COUNT-2: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @reducer(%a: i8, %b: i8) -> i8 {
+ return %a : i8
+}
+
+func.func @shuffler_i8(%a: i8) -> i8 {
+ %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8
+ return %ret : i8
+}
+// CHECK: @shuffler_i8(
+// CHECK-NOT: vector
+// CHECK-COUNT-1: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @predicated_insert(
+ %v: i32, %tensor: tensor<2xi32>, %index: index,
+ %cond: i1) -> tensor<2xi32> {
+ %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond
+ : tensor<2xi32>
+ return %ret : tensor<2xi32>
+}
+// CHECK: @predicated_insert
+// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
+// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
+// CHECK-NEXT: %[[UPD:.*]] = tensor.insert %[[V]] into %[[TENSOR]][%[[INDEX]]]
+// CHECK-NEXT: yield %[[UPD]]
+// CHECK-NEXT: else
+// CHECK-NEXT: yield %[[TENSOR]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RET]]
+
+// -----
+
+func.func @predicated_extract(
+ %v: i32, %tensor: tensor<2xi32>, %index: index,
+ %cond: i1) -> i32 {
+ %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v
+ : tensor<2xi32>
+ return %ret : i32
+}
+// CHECK: @predicated_extract
+// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
+// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
+// CHECK-NEXT: %[[VAL:.*]] = tensor.extract %[[TENSOR]][%[[INDEX]]]
+// CHECK-NEXT: yield %[[VAL]]
+// CHECK-NEXT: else
+// CHECK-NEXT: yield %[[V]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RET]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir
new file mode 100644
index 0000000..cb6c048
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir
@@ -0,0 +1,211 @@
+// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s
+
+#map = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 8),
+ domain: d0 in [0, 31]>
+#map1 = #xla_gpu.indexing_map<(d0) -> (d0 mod 8),
+ domain: d0 in [0, 31]>
+#map2 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512),
+ domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]>
+module {
+ func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>,
+ %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>,
+ %arg4: tensor<4x8x4096xbf16>, %arg5: tensor<4x8xf32>,
+ %arg6: tensor<4x8x4096xf32>) -> (tensor<4x8x4096xf32>, f32) {
+ %cst = arith.constant 1.000000e+00 : f32
+ %cst_1 = arith.constant 1.000000e+00 : bf16
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 255 : index]}
+ %block_id_x = gpu.block_id x {xla.range = [0 : index, 31 : index]}
+ %0 = gpu.lane_id
+ %1 = arith.cmpi eq, %0, %c0 : index
+ %2 = arith.divui %thread_id_x, %c32 : index
+ %3 = arith.cmpi ult, %thread_id_x, %c8 : index
+ %4 = xla_gpu.apply_indexing #map(%block_id_x)
+ %5 = xla_gpu.apply_indexing #map1(%block_id_x)
+ %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32>
+ %6 = arith.mulf %extracted, %cst : f32
+ %7 = arith.addf %6, %cst : f32
+ %8 = math.rsqrt %7 : f32
+ %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) {
+ %18 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+ %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
+ %20 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+ %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
+ %22 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+ %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16>
+ %24 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+ %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32>
+ %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) {
+ %27 = xla_gpu.apply_indexing #map2(%arg10, %thread_id_x)[%arg7]
+ %28 = vector.extract %25[%arg10] : f32 from vector<2xf32>
+ %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16>
+ %30 = arith.extf %29 : bf16 to f32
+ %31 = vector.extract %21[%arg10] : bf16 from vector<2xbf16>
+ %32 = arith.extf %31 : bf16 to f32
+ %33 = arith.mulf %30, %32 : f32
+ %34 = arith.mulf %33, %8 : f32
+ %35 = vector.extract %19[%arg10] : bf16 from vector<2xbf16>
+ %36 = arith.extf %35 : bf16 to f32
+ %37 = arith.addf %36, %cst : f32
+ %38 = arith.mulf %34, %37 : f32
+ %39 = arith.addf %28, %38 : f32
+ %40 = arith.mulf %39, %39 : f32
+ %41 = arith.addf %arg12, %40 : f32
+ %inserted = tensor.insert %39 into %arg11[%4, %5, %27] : tensor<4x8x4096xf32>
+ scf.yield %inserted, %41 : tensor<4x8x4096xf32>, f32
+ }
+ scf.yield %26#0, %26#1 : tensor<4x8x4096xf32>, f32
+ }
+ return %9#0, %9#1 : tensor<4x8x4096xf32>, f32
+ }
+}
+
+// CHECK-LABEL: @fully_unroll
+// CHECK-NOT: scf.for
+
+// -----
+
+module {
+ func.func @unroll_by_factor(%arg0: f32) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c256 = arith.constant 256 : index
+ %ret = scf.for %i = %c0 to %c256 step %c1 iter_args (%v = %arg0) -> (f32) {
+ %exp = math.exp %v : f32
+ %add = arith.addf %v, %exp : f32
+ %log = math.log %add : f32
+ scf.yield %log : f32
+ }
+ return %ret : f32
+ }
+}
+
+// CHECK-LABEL: @unroll_by_factor
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: scf.for {{.*}} step %[[C8]]
+
+// -----
+
+module {
+ func.func @do_not_unroll(%arg0: f32) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c31 = arith.constant 31 : index
+ %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%v = %arg0) -> (f32) {
+ %exp = math.exp %v : f32
+ %add = arith.addf %v, %exp : f32
+ %log = math.log %add : f32
+ scf.yield %log : f32
+ }
+ return %ret : f32
+ }
+}
+
+// CHECK-LABEL: @do_not_unroll
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for {{.*}} step %[[C1]]
+
+// -----
+
+module {
+ func.func @pipeline_extract(%arg: tensor<31xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c31 = arith.constant 31 : index
+ %cst = arith.constant 0.0 : f32
+ %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%iter = %cst) -> (f32) {
+ %val = tensor.extract %arg[%i] : tensor<31xf32>
+ %log = math.log %val : f32
+ %add = arith.addf %log, %iter : f32
+ scf.yield %add : f32
+ }
+ return %ret : f32
+ }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1),
+// CHECK-LABEL: @pipeline_extract
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
+// CHECK: %[[VAL0:.*]] = tensor.extract %[[ARG0:.*]][%[[C0]]]
+// CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
+// CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C30]]
+// CHECK-DAG: %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
+// CHECK: %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
+// CHECK-NEXT: tensor.extract %[[ARG0]][%[[NEXT_I]]]
+// CHECK-NEXT: yield
+// CHECK-NEXT: else
+// CHECK-NEXT: yield %[[VAL]]
+// CHECK: math.log %[[VAL]]
+// CHECK: %[[ADD:.*]] = arith.addf
+// CHECK: yield %[[ADD]], %[[NEXT_VAL]]
+
+// -----
+
+module {
+ func.func @pipeline_transfer(%arg: tensor<34xf32>) -> vector<2xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c17 = arith.constant 17 : index
+ %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32>
+ %cst0 = arith.constant 0.0 : f32
+ %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) {
+ %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 15]>(%i)
+ %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32>
+ %log = math.log %val : vector<2xf32>
+ %add = arith.addf %log, %iter : vector<2xf32>
+ scf.yield %add : vector<2xf32>
+ }
+ return %ret : vector<2xf32>
+ }
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2),
+// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1),
+// CHECK-LABEL: @pipeline_transfer
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[BASE0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[C0]]
+// CHECK: %[[VAL0:.*]] = vector.transfer_read %[[ARG0:.*]][%[[BASE0]]]
+// CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
+// CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C16]]
+// CHECK-DAG: %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]
+// CHECK-DAG: %[[NEXT_BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[NEXT_I]]
+// CHECK: %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
+// CHECK-NEXT: vector.transfer_read %[[ARG0]][%[[NEXT_BASE]]]
+// CHECK-NEXT: yield
+// CHECK-NEXT: else
+// CHECK-NEXT: yield %[[VAL]]
+// CHECK: math.log %[[VAL]]
+// CHECK: %[[ADD:.*]] = arith.addf
+// CHECK: yield %[[ADD]], %[[NEXT_VAL]]
+
+// -----
+
+module {
+ func.func @sequential_extract(%arg0: tensor<6xindex>, %arg1: tensor<22xindex>) -> (index) {
+ %c1 = arith.constant 1 : index
+ %c733 = arith.constant 733 : index
+ %c0 = arith.constant 0 : index
+ %2 = scf.for %i = %c0 to %c733 step %c1 iter_args(%x = %c1) -> (index) {
+ %extracted = tensor.extract %arg0[%i] : tensor<6xindex>
+ %extracted_1 = tensor.extract %arg1[%extracted] : tensor<22xindex>
+ scf.yield %extracted_1 : index
+ }
+ return %2 : index
+ }
+}
+
+// Once `extracted` is pipelined, it becomes an iter arg, so `extracted_1`
+// becomes a `tensor.extract %arg1[<iter arg>]`. While it is possible to
+// pipeline this in principle, we do not currently do so.
+
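+// For illustration only, a rough sketch (hypothetical, not what the pass emits
+// today) of the form pipelining would have to take here: the value loaded from
+// %arg0 becomes an extra iter arg, and the dependent extract indexes that arg:
+//
+//   %first = tensor.extract %arg0[%c0] : tensor<6xindex>
+//   scf.for %i = ... iter_args(%x = %c1, %next = %first) -> (index, index) {
+//     %extracted_1 = tensor.extract %arg1[%next] : tensor<22xindex>
+//     // prefetch %arg0[%i + 1] (guarded by an scf.if) and yield it as %next
+//   }
+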
+// CHECK-LABEL: @sequential_extract
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<6xindex>, %[[ARG1:.*]]: tensor<22xindex>)
+// CHECK: tensor.extract %[[ARG0]]
+// CHECK-NOT: tensor.extract
+// CHECK: scf.for
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir
new file mode 100644
index 0000000..2044254
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-peel-loops \
+// RUN: | FileCheck %s
+
+#map = #xla_gpu.indexing_map<
+ (d0)[s0, s1] -> (s0, s1),
+ domain:
+ d0 in [0, 3],
+ s0 in [0, 7],
+ s1 in [0, 10],
+ d0 + s0 in [0, 9],
+ d0 + s1 in [0, 12]
+>
+func.func @peel_both_loops(%input: tensor<16x32xf32>,
+ %init: f32, %dim: index) -> (f32) {
+ %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+ %t = tensor.extract %input[%i, %j] : tensor<16x32xf32>
+ %add = arith.addf %sum_, %t : f32
+ xla_gpu.yield %add : f32
+ }
+ func.return %sum : f32
+}
+// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]>
+// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]>
+// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]>
+
+// CHECK-LABEL: func.func @peel_both_loops(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32xf32>,
+// CHECK-SAME: %[[INIT:.*]]: f32, %[[DIM:.*]]: index)
+
+// CHECK: %[[PEELED:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]]
+// CHECK-SAME: in #[[$PEELED_MAP]] iter_args(%[[INIT_:.*]] = %[[INIT]])
+// CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]] : tensor<16x32xf32>
+// CHECK: arith.addf %[[INIT_]]
+
+// CHECK: %[[TAIL0:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]]
+// CHECK-SAME: in #[[$TAIL_MAP0]] iter_args(%[[INIT_:.*]] = %[[PEELED]])
+// CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]]
+// CHECK: arith.addf %[[INIT_]]
+
+// CHECK: %[[TAIL1:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]]
+// CHECK-SAME: in #[[$TAIL_MAP1]] iter_args(%[[INIT_:.*]] = %[[TAIL0]])
+// CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]]
+// CHECK: arith.addf %[[INIT_]]
+
+// CHECK: return %[[TAIL1]] : f32
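+
+// What the CHECKs above verify: the constrained loop is split into a main loop
+// over the unconstrained interior (s0 in [0, 6], s1 in [0, 9]) plus one tail
+// loop per constraint (s0 pinned to 7, then s1 pinned to 10), chained through
+// the accumulator, so the peeled maps carry the constraints in their ranges
+// instead of as separate `d0 + s0` / `d0 + s1` conditions.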
+
+// -----
+
+#map = #xla_gpu.indexing_map<
+ (d0)[s0] -> (s0),
+ domain:
+ d0 in [0, 3],
+ s0 in [0, 7]
+>
+func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32,
+ %dim: index) -> (f32) {
+ %sum = xla_gpu.loop (%dim)[%i] in #map iter_args(%sum_ = %init) -> (f32) {
+ %t = tensor.extract %input[%i] : tensor<16xf32>
+ %add = arith.addf %sum_, %t : f32
+ xla_gpu.yield %add : f32
+ }
+ func.return %sum : f32
+}
+// CHECK-LABEL: func.func @not_constrained_symbol
+// CHECK: xla_gpu.loop
+// CHECK-NOT: xla_gpu.loop
+
+// -----
+
+#map = #xla_gpu.indexing_map<
+ (d0)[s0] -> (s0),
+ domain:
+ d0 in [0, 3],
+ s0 in [0, 7],
+ s0 mod 5 in [0, 1]
+>
+func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32,
+ %dim: index) -> (f32) {
+ %sum = xla_gpu.loop (%dim)[%i] in #map iter_args(%sum_ = %init) -> (f32) {
+ %t = tensor.extract %input[%i] : tensor<16xf32>
+ %add = arith.addf %sum_, %t : f32
+ xla_gpu.yield %add : f32
+ }
+ func.return %sum : f32
+}
+// CHECK-LABEL: func.func @constraint_exists_after_peeling
+// CHECK: xla_gpu.loop
+// CHECK-NOT: xla_gpu.loop
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir
new file mode 100644
index 0000000..d51566a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir_fusions_opt --allow-unregistered-dialect %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s
+
+func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %0 = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
+ %1 = gpu.block_id x {xla.range = [0 : index, 3071 : index]}
+ scf.for %arg3 = %c0 to %c4 step %c1 {
+ %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))>()[%1, %0, %arg3]
+ %3 = arith.index_castui %2 : index to i64
+ %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %5 = llvm.load %4 invariant : !llvm.ptr -> f32
+ %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %9 = llvm.load %8 invariant : !llvm.ptr -> f32
+ %10 = arith.cmpf oge, %5, %9 : f32
+ %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
+ llvm.store %10, %11 : i1, !llvm.ptr
+ }
+ return
+}
+
+// CHECK-LABEL: @op_and_for_ranges
+// CHECK-DAG: %[[C512:.*]] = arith.constant 512
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4
+// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
+// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
+// CHECK: scf.for %[[I:.*]] =
+// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
+// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
+// CHECK: %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
+// CHECK: arith.addi %[[OFFSET]], %[[I]]
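+
+// The `floordiv` terms vanish above because the annotated ranges make them
+// provably zero: s1 <= 127 gives s1 floordiv 128 == 0 and s2 <= 3 gives
+// s2 floordiv 4 == 0, so only the linear s0 * 512 + s1 * 4 + s2 part is
+// materialized as muls and adds.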
+
+// -----
+
+func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
+ %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
+ return %0 : index
+}
+
+// CHECK-LABEL: @arg_ranges
+// CHECK-NEXT: %[[C100:.*]] = arith.constant 100
+// CHECK-NEXT: %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
+// CHECK-NEXT: return %[[RET]]
+
+// -----
+
+func.func @cant_lower(%arg0: index {xla.range = [-10 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
+ %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
+ return %0 : index
+}
+
+// CHECK-LABEL: @cant_lower
+// CHECK: affine.apply
+
+// -----
+
+func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %0 = gpu.thread_id x
+ %1 = gpu.block_id x
+ scf.for %i = %c0 to %c4 step %c1 {
+ %2 = xla_gpu.apply_indexing
+ #xla_gpu.indexing_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4)),
+ domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]>[%1, %0, %i]
+ %3 = arith.index_castui %2 : index to i64
+ %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %5 = llvm.load %4 invariant : !llvm.ptr -> f32
+ %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %9 = llvm.load %8 invariant : !llvm.ptr -> f32
+ %10 = arith.cmpf oge, %5, %9 : f32
+ %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
+ llvm.store %10, %11 : i1, !llvm.ptr
+ }
+ return
+}
+
+// CHECK-LABEL: @op_and_for_ranges
+// CHECK-DAG: %[[C512:.*]] = arith.constant 512
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4
+// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
+// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
+// CHECK: scf.for %[[I:.*]] =
+// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
+// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
+// CHECK: %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
+// CHECK: arith.addi %[[OFFSET]], %[[I]]
+
+// -----
+
+func.func @arg_ranges(%arg0: index, %arg1: index) -> index {
+ %0 = xla_gpu.apply_indexing
+ #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100),
+ domain: s0 in [0, 42], s1 in [0, 1000]>[%arg0, %arg1]
+ return %0 : index
+}
+
+// CHECK-LABEL: @arg_ranges
+// CHECK-NEXT: %[[C100:.*]] = arith.constant 100
+// CHECK-NEXT: %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
+// CHECK-NEXT: return %[[RET]]
+
+// -----
+
+func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) {
+ %0:2 = xla_gpu.apply_indexing
+ #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1),
+ domain: s0 in [-10, 42], s1 in [0, 1000]>[%arg0, %arg1]
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: @cant_lower
+// CHECK: affine.apply
+// CHECK-NEXT: arith.addi
+
+// -----
+
+func.func @order_summands(%arg1: index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ scf.for %arg2 = %c0 to %c4 step %c1 {
+ scf.for %arg3 = %c0 to %c4 step %c1 {
+ %0 = xla_gpu.apply_indexing
+ #xla_gpu.indexing_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10),
+ domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]>[%arg2, %arg1, %arg3]
+ "dummy.op"(%0) : (index) -> ()
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: @order_summands
+// CHECK-SAME: (%[[ARG1:.*]]: index)
+// CHECK: scf.for %[[ARG2:.*]] =
+// CHECK: scf.for %[[ARG3:.*]] =
+// CHECK: arith.muli %[[ARG1]]
+// CHECK: arith.muli %[[ARG2]]
+// CHECK: arith.addi
+// CHECK: arith.addi %[[ARG1]], %[[ARG2]]
+// CHECK: arith.divui
+// CHECK: arith.addi
+// CHECK: arith.muli %[[ARG3]]
+// CHECK: arith.addi %5, %6 : index
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir
new file mode 100644
index 0000000..09c8901
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir
@@ -0,0 +1,292 @@
+// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -cse -canonicalize | FileCheck %s
+
+module {
+ func.func @unknown(%arg0: index {xla.range = [0 : index, 42 : index]}) -> i1 {
+ %c12 = arith.constant 12 : index
+ %eq = arith.cmpi eq, %arg0, %c12 : index
+ return %eq : i1
+ }
+}
+
+// CHECK: @unknown
+// CHECK: cmpi
+
+// -----
+
+module {
+ func.func @true(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
+ %c5 = arith.constant 5 : index
+ %eq = arith.cmpi sge, %arg0, %c5 : index
+ return %eq : i1
+ }
+}
+
+// CHECK: @true
+// CHECK-NEXT: constant true
+// CHECK-NEXT: return
+
+// -----
+
+module {
+ func.func @false(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
+ %c5 = arith.constant 5 : index
+ %eq = arith.cmpi slt, %arg0, %c5 : index
+ return %eq : i1
+ }
+}
+
+// CHECK: @false
+// CHECK-NEXT: constant false
+// CHECK-NEXT: return
+
+// -----
+
+module {
+ func.func @rhs_range(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
+ %c42 = arith.constant 64 : index
+ %eq = arith.cmpi slt, %c42, %arg0 : index
+ return %eq : i1
+ }
+}
+
+// CHECK: @rhs_range
+// CHECK-NEXT: constant false
+// CHECK-NEXT: return
+
+// -----
+
+module {
+ func.func @both_range(%arg0: index {xla.range = [12 : index, 42 : index]},
+ %arg1: index {xla.range = [63 : index, 100 : index]}) -> i1 {
+ %eq = arith.cmpi slt, %arg0, %arg1 : index
+ return %eq : i1
+ }
+}
+
+// CHECK-LABEL: @both_range
+// CHECK-NEXT: constant true
+// CHECK-NEXT: return
+
+// -----
+
+module {
+ func.func @minsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+ %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+ %min = arith.minsi %arg0, %arg1 : index
+ return %min : index
+ }
+}
+
+// CHECK-LABEL: @minsi_lhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG0]]
+
+// -----
+
+module {
+ func.func @minsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+ %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+ %min = arith.minsi %arg1, %arg0 : index
+ return %min : index
+ }
+}
+
+// CHECK-LABEL: @minsi_rhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG0]]
+
+// -----
+
+module {
+ func.func @maxsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+ %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+ %min = arith.maxsi %arg1, %arg0 : index
+ return %min : index
+ }
+}
+
+// CHECK-LABEL: @maxsi_lhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG1]]
+
+// -----
+
+module {
+ func.func @maxsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+ %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+ %min = arith.maxsi %arg0, %arg1 : index
+ return %min : index
+ }
+}
+
+// CHECK-LABEL: @maxsi_rhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG1]]
+
+// -----
+
+module {
+ func.func @maxsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
+ %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+ %add = arith.addi %arg0, %arg1 : index
+ %min = arith.maxsi %add, %arg1 : index
+ return %min : index
+ }
+}
+
+// CHECK-LABEL: @maxsi_add
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG1]]
+// CHECK-NEXT: return %[[ADD]]
+
+// -----
+
+module {
+ func.func @minsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
+ %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+ %add = arith.addi %arg0, %arg1 : index
+ %min = arith.minsi %add, %arg1 : index
+ return %min : index
+ }
+}
+
+// CHECK-LABEL: @minsi_add
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG1]]
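+
+// Folds because the ranges are disjoint: %add is in [102 + 63, 142 + 100] =
+// [165, 242], which lies entirely above %arg1's range [63, 100], so the
+// minimum is always %arg1.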
+
+// -----
+
+module {
+ func.func @pred_reduce(%in: i1) -> i1 {
+ %c1_i32 = arith.constant 1 : i32
+ %c2_i32 = arith.constant 2 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c8_i32 = arith.constant 8 : i32
+ %c16_i32 = arith.constant 16 : i32
+ %c32_i32 = arith.constant 32 : i32
+ %0 = arith.extui %in : i1 to i32
+ %shuffleResult, %valid = gpu.shuffle down %0, %c16_i32, %c32_i32 : i32
+ %1 = arith.trunci %shuffleResult : i32 to i1
+ %2 = arith.ori %in, %1 : i1
+ %3 = arith.extui %2 : i1 to i32
+ %shuffleResult_0, %valid_1 = gpu.shuffle down %3, %c8_i32, %c32_i32 : i32
+ %4 = arith.trunci %shuffleResult_0 : i32 to i1
+ %5 = arith.ori %2, %4 : i1
+ %6 = arith.extui %5 : i1 to i32
+ %shuffleResult_2, %valid_3 = gpu.shuffle down %6, %c4_i32, %c32_i32 : i32
+ %7 = arith.trunci %shuffleResult_2 : i32 to i1
+ %8 = arith.ori %5, %7 : i1
+ %9 = arith.extui %8 : i1 to i32
+ %shuffleResult_4, %valid_5 = gpu.shuffle down %9, %c2_i32, %c32_i32 : i32
+ %10 = arith.trunci %shuffleResult_4 : i32 to i1
+ %11 = arith.ori %8, %10 : i1
+ %12 = arith.extui %11 : i1 to i32
+ %shuffleResult_6, %valid_7 = gpu.shuffle down %12, %c1_i32, %c32_i32 : i32
+ %13 = arith.trunci %shuffleResult_6 : i32 to i1
+ %14 = arith.ori %11, %13 : i1
+ return %14 : i1
+ }
+}
+
+// CHECK-LABEL: @pred_reduce
+// CHECK-SAME: (%[[IN:.*]]: i1)
+// CHECK: %[[IN_EXT:.*]] = arith.extui %[[IN]]
+// CHECK-NEXT: %[[SHUFFLE0:.*]], {{.*}} = gpu.shuffle down %[[IN_EXT]]
+// CHECK-NEXT: %[[OR0:.*]] = arith.ori %[[IN_EXT]], %[[SHUFFLE0]]
+// CHECK-NEXT: %[[SHUFFLE1:.*]], {{.*}} = gpu.shuffle down %[[OR0]]
+// CHECK-NEXT: %[[OR1:.*]] = arith.ori %[[OR0]], %[[SHUFFLE1]]
+// CHECK-NEXT: %[[SHUFFLE2:.*]], {{.*}} = gpu.shuffle down %[[OR1]]
+// CHECK-NEXT: %[[OR2:.*]] = arith.ori %[[OR1]], %[[SHUFFLE2]]
+// CHECK-NEXT: %[[SHUFFLE3:.*]], {{.*}} = gpu.shuffle down %[[OR2]]
+// CHECK-NEXT: %[[OR3:.*]] = arith.ori %[[OR2]], %[[SHUFFLE3]]
+// CHECK-NEXT: %[[SHUFFLE4:.*]], {{.*}} = gpu.shuffle down %[[OR3]]
+// CHECK-NEXT: %[[OR4:.*]] = arith.ori %[[OR3]], %[[SHUFFLE4]]
+// CHECK-NEXT: %[[RET:.*]] = arith.trunci %[[OR4]]
+// CHECK-NEXT: return %[[RET]]
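+
+// The CHECKs above verify that the extui/trunci pairs between shuffles are
+// removed: the i1 is widened once, all five `gpu.shuffle down` / `arith.ori`
+// steps stay in i32, and a single final trunci recovers the i1 result.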
+
+// -----
+
+module {
+ func.func @andi_no_trunc_arg(%a: i4, %b: i8) -> i4 {
+ %lhs = arith.extui %a : i4 to i8
+ %add = arith.andi %lhs, %b : i8
+ %ret = arith.trunci %add : i8 to i4
+ return %ret : i4
+ }
+}
+
+// CHECK-LABEL: @andi_no_trunc_arg
+// CHECK-NEXT: extui
+// CHECK-NEXT: andi
+// CHECK-NEXT: trunci
+// CHECK-NEXT: return
+
+// -----
+
+module {
+ func.func @ori_mismatched_narrowest(%a: i8, %b: i8) -> i8 {
+ %0 = arith.trunci %a : i8 to i4
+ %1 = arith.extui %0 : i4 to i8
+ %ret = arith.ori %b, %1 : i8
+ return %ret : i8
+ }
+}
+
+// CHECK-LABEL: @ori_mismatched_narrowest
+// CHECK-NEXT: trunci
+// CHECK-NEXT: extui
+// CHECK-NEXT: ori
+// CHECK-NEXT: return
+
+// -----
+
+func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c42_f32 = arith.constant 42.0 : f32
+ %loop = scf.for %i = %c0 to %c3 step %c1
+ iter_args(%in_ = %tensor) -> (tensor<100xf32>) {
+ %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 4), domain: d0 in [0, 9]>(%i)
+ %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32>
+ scf.yield %updated :tensor<100xf32>
+ }
+ func.return %loop : tensor<100xf32>
+}
+// CHECK-LABEL: func.func @refine_constraints
+// CHECK: %[[CST:.*]] = arith.constant 4.2
+// CHECK: scf.for
+// CHECK: tensor.insert %[[CST]]
+
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000),
+ domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9),
+ domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]>
+func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>,
+ %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c73 = arith.constant 73 : index
+ %c42_f32 = arith.constant 42.0 : f32
+ %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
+ %bl_x = gpu.block_id x {xla.range = [0 : index, 575 : index]}
+ %0 = scf.for %i = %c0 to %c73 step %c1 iter_args(%arg3 = %arg1)
+ -> (tensor<2400000x9xf32>) {
+ %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3)
+ -> (tensor<2400000x9xf32>) {
+ %3 = xla_gpu.apply_indexing #map(%th_x, %bl_x)[%i, %j]
+ %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x)[%j]
+ %inserted = tensor.insert %c42_f32 into %arg5[%3, %4]
+ : tensor<2400000x9xf32>
+ scf.yield %inserted : tensor<2400000x9xf32>
+ }
+ scf.yield %2 : tensor<2400000x9xf32>
+ }
+ return %0 : tensor<2400000x9xf32>
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768),
+// CHECK-LABEL: func.func @refine_constraints_for_symbol
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/unswitch_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/unswitch_loops.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir
new file mode 100644
index 0000000..16e4498
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir
@@ -0,0 +1,413 @@
+// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file -xla-gpu-vectorize-loads-stores -canonicalize -cse | FileCheck %s
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+ domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c64 = arith.constant 64 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %idx = xla_gpu.apply_indexing #map(%i)[%j]
+ %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 63]>
+// CHECK-LABEL: @simple_read
+// CHECK-SAME: (%[[ARG0:.*]]: tensor
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] =
+// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]])
+// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]]
+// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]]
+// CHECK-NEXT: vector.extract %[[V]][%[[J]]]
+// CHECK-NEXT: addf
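+
+// What the CHECKs above verify: because the index map is d0 * 2 + s0 with
+// s0 in [0, 1] (unit stride, base aligned to the vector width), the
+// per-element `tensor.extract`s are replaced by a single `vector.transfer_read`
+// of the two contiguous elements, hoisted in front of the inner loop; the
+// inner loop then only does `vector.extract`.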
+
+// -----
+
+module {
+ func.func @simple_read_2d(%arg0: tensor<64x2xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c64 = arith.constant 64 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK-LABEL: @simple_read_2d
+// CHECK-SAME: (%[[ARG0:.*]]: tensor
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]]
+// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[C0]]]
+// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]]
+// CHECK-NEXT: vector.extract %[[V]][%[[J]]]
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 + 1),
+ domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+ func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c63 = arith.constant 63 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %idx = xla_gpu.apply_indexing #map(%i)[%j]
+ %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK-LABEL: @misaligned_indexing_map
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 3 + s0),
+ domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+ func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c63 = arith.constant 63 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %idx = xla_gpu.apply_indexing #map(%i)[%j]
+ %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK-LABEL: @misaligned_indexing_map_2
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+module {
+ func.func @misaligned_shape(%arg0: tensor<64x3xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c64 = arith.constant 64 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %extracted = tensor.extract %arg0[%i, %j] : tensor<64x3xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK-LABEL: @misaligned_shape
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2),
+ domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+ func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c63 = arith.constant 63 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %idx = xla_gpu.apply_indexing #map(%i)[%j]
+ %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK-LABEL: @wrong_stride
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+// We could vectorize this by loading an f32 vector with twice as many
+// elements, but we don't currently do so.
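+// (Hypothetical illustration only, not output of the pass: the two
+// complex<f32> elements of a row could be read as one f32 vector of twice as
+// many elements, e.g.
+//   %v = vector.transfer_read %as_f32[%i, %c0], %cst_f32 : tensor<64x4xf32>, vector<4xf32>
+// where %as_f32 is an assumed f32 reinterpretation of %arg0.)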
+module {
+ func.func @simple_read_complex(%arg0: tensor<64x2xcomplex<f32>>, %i: index) -> (complex<f32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex<f32>
+ %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex<f32> {
+ %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xcomplex<f32>>
+ %added = complex.add %iter, %extracted : complex<f32>
+ scf.yield %added : complex<f32>
+ }
+ return %loop : complex<f32>
+ }
+}
+
+// CHECK-LABEL: @simple_read_complex
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+// This is vectorizable, but not currently supported.
+module {
+ func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c64 = arith.constant 64 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %extracted = tensor.extract %arg0[%j, %i]
+ : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK-LABEL: @layout
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
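+// The inserts into the iteration argument are collected in a vector and
+// written back with a single vector.transfer_write after the loop.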
+module {
+ func.func @simple_write(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+ %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
+ scf.yield %inserted : tensor<16x4xf32>
+ }
+ return %loop : tensor<16x4xf32>
+ }
+}
+
+// CHECK-LABEL: @simple_write
+// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[I:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[V:.*]] = scf.for
+// CHECK-NEXT: vector.insert
+// CHECK-NEXT: scf.yield
+// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[I]], %[[C0]]]
+// CHECK-NEXT: return %[[WRITTEN]]
+
+// -----
+
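+// Not vectorized: the inserted tensor has a use ("dummy.op1") other than the
+// yield.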
+module {
+ func.func @write_with_use(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+ %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
+ "dummy.op1"(%inserted) : (tensor<16x4xf32>) -> ()
+ scf.yield %inserted : tensor<16x4xf32>
+ }
+ return %loop : tensor<16x4xf32>
+ }
+}
+
+// CHECK-LABEL: @write_with_use
+// CHECK-NOT: transfer_write
+
+// -----
+
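+// Not vectorized: the insert writes into %arg0 instead of the loop's
+// iteration argument.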
+module {
+ func.func @write_not_to_iter_arg(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+ %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
+ scf.yield %inserted : tensor<16x4xf32>
+ }
+ return %loop : tensor<16x4xf32>
+ }
+}
+
+// CHECK-LABEL: @write_not_to_iter_arg
+// CHECK-NOT: transfer_write
+
+// -----
+
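+// Not vectorized: the loop yields %arg0, so the inserted value never reaches
+// the loop result.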
+module {
+ func.func @write_not_yielded(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+ %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
+ scf.yield %arg0 : tensor<16x4xf32>
+ }
+ return %loop : tensor<16x4xf32>
+ }
+}
+
+// CHECK-LABEL: @write_not_yielded
+// CHECK-NOT: transfer_write
+
+// -----
+
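+// Both reads and the write in the inner loop are contiguous in %j, so they
+// are all turned into vector transfers that share a single base index.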
+#map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512),
+ domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]>
+module {
+ func.func @multiple(%arg0: tensor<32x4096xf32>, %arg1: tensor<4096xbf16>,
+ %arg2: tensor<32xf32>, %arg3: tensor<32x4096xf32>,
+ %arg4: index) -> (tensor<32x4096xf32>, f32) {
+ %cst = arith.constant 1.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32>
+ %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) {
+ %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) {
+ %2 = xla_gpu.apply_indexing #map(%j, %arg4)[%i]
+ %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32>
+ %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16>
+ %3 = arith.extf %extracted3 : bf16 to f32
+ %4 = arith.addf %extracted2, %3 : f32
+ %5 = arith.addf %extracted1, %4 : f32
+ %6 = arith.addf %iter3, %5 : f32
+ %inserted = tensor.insert %5 into %iter2[%i, %2] : tensor<32x4096xf32>
+ scf.yield %inserted, %6 : tensor<32x4096xf32>, f32
+ }
+ scf.yield %1#0, %1#1 : tensor<32x4096xf32>, f32
+ }
+ return %0#0, %0#1 : tensor<32x4096xf32>, f32
+ }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 * 512),
+// CHECK-SAME: domain: d0 in [0, 255], s0 in [0, 7]>
+// CHECK-LABEL: @multiple
+// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]]
+// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]])[%[[I]]]
+// CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]]
+// CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]]
+// CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>)
+// CHECK-DAG: vector.extract %[[READ1]][%[[J]]]
+// CHECK-DAG: vector.extract %[[READ2]][%[[J]]]
+// CHECK: extf
+// CHECK-NEXT: addf
+// CHECK-NEXT: %[[TO_INSERT:.*]] = arith.addf
+// CHECK-NEXT: %[[TO_YIELD:.*]] = arith.addf
+// CHECK-NEXT: %[[V_NEXT:.*]] = vector.insert %[[TO_INSERT]], %[[V]] [%[[J]]]
+// CHECK-NEXT: scf.yield %[[TO_YIELD]], %[[V_NEXT]]
+// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[I]], %[[BASE]]]
+// CHECK: scf.yield %[[WRITTEN]], %[[INNER]]#0
+
+// -----
+
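+// (d0 * 4) mod 64 is always a multiple of 4, so the base offset is aligned
+// for a two-element vector and the read is vectorized.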
+#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0),
+ domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c63 = arith.constant 63 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %idx = xla_gpu.apply_indexing #map(%i)[%j]
+ %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4),
+// CHECK-LABEL: @remainder_with_modulo
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]]
+// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
+// CHECK: vector.transfer_read {{.*}}[%[[BASE]]]
+
+// -----
+
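+// With mod 65 the base offset is not guaranteed to be a multiple of the
+// vector size, so the read is not vectorized.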
+#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0),
+ domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+ func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c63 = arith.constant 63 : index
+ %cst = arith.constant 0.0 : f32
+ %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+ %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+ %idx = xla_gpu.apply_indexing #map(%i)[%j]
+ %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+ %added = arith.addf %iter1, %extracted : f32
+ scf.yield %added : f32
+ }
+ scf.yield %inner : f32
+ }
+ return %outer : f32
+ }
+}
+
+// CHECK-LABEL: @remainder_with_modulo_misaligned
+// CHECK-NOT: vector.transfer_read
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc
new file mode 100644
index 0000000..d514a67
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc
@@ -0,0 +1,106 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_UNSWITCHLOOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class UnswitchLoopsPass
+ : public impl::UnswitchLoopsPassBase<UnswitchLoopsPass> {
+ public:
+ void runOnOperation() override;
+};
+
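+// Matches an scf.for whose body is a single scf.if (plus the terminator) with
+// a loop-invariant, non-constant condition and replaces it with an scf.if
+// containing two specialized copies of the loop: one with the condition
+// folded to true and one with it folded to false.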
+struct UnswitchLoop : mlir::OpRewritePattern<mlir::scf::ForOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
+ if (op.getBody()->getOperations().size() != 2) {
+ return rewriter.notifyMatchFailure(
+ op, "loop body is not a single instruction");
+ }
+ auto if_op = mlir::dyn_cast<mlir::scf::IfOp>(op.getBody()->front());
+ if (!if_op) {
+ return rewriter.notifyMatchFailure(op, "no if found inside the loop");
+ }
+ if (mlir::matchPattern(if_op.getCondition(), mlir::m_Constant())) {
+ return rewriter.notifyMatchFailure(op, "condition is a constant");
+ }
+
+ auto true_cst = rewriter.create<mlir::arith::ConstantOp>(
+ op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+ auto false_cst = rewriter.create<mlir::arith::ConstantOp>(
+ op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
+ rewriter.setInsertionPoint(op);
+ mlir::IRMapping mapping;
+ mapping.map(if_op.getCondition(), false_cst);
+ auto false_branch_loop = op->clone(mapping);
+ auto new_if = rewriter.create<mlir::scf::IfOp>(
+ op.getLoc(), op.getResultTypes(), if_op.getCondition(), true, true);
+ rewriter.replaceAllUsesWith(op.getResults(), new_if.getResults());
+
+ auto then_builder = new_if.getThenBodyBuilder(rewriter.getListener());
+ auto then_yield =
+ then_builder.create<mlir::scf::YieldOp>(op.getLoc(), op.getResults());
+ rewriter.moveOpBefore(op, then_yield);
+ rewriter.modifyOpInPlace(if_op, [&]() { if_op->setOperand(0, true_cst); });
+
+ auto else_builder = new_if.getElseBodyBuilder(rewriter.getListener());
+ else_builder.insert(false_branch_loop);
+ else_builder.create<mlir::scf::YieldOp>(op.getLoc(),
+ false_branch_loop->getResults());
+
+ return mlir::success();
+ }
+};
+
+void UnswitchLoopsPass::runOnOperation() {
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<UnswitchLoop>(&getContext());
+ mlir::scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
+ mlir::scf::IfOp::getCanonicalizationPatterns(patterns, &getContext());
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+}
+
+} // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateUnswitchLoopsPass() {
+ return std::make_unique<UnswitchLoopsPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc
new file mode 100644
index 0000000..dd5d443
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc
@@ -0,0 +1,358 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <memory>
+#include <numeric>
+#include <optional>
+#include <utility>
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_VECTORIZELOADSANDSTORESPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+// Tries to find the stride of a symbol or dimension in an affine expression.
+// Returns std::nullopt if the stride could not be determined.
+//
+// Note: this function only attempts to handle the cases where the stride is
+// known to be 0 or 1.
+//
+// Example: the stride of `d0` in `(d0 + d1)` is 1.
+// Example: the stride of `d0` in `d0 * 2` is unknown (nullopt).
+std::optional<int> GetStride(mlir::AffineExpr expr,
+ mlir::AffineExpr dim_or_sym) {
+ if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
+ auto lhs_stride = GetStride(binop.getLHS(), dim_or_sym);
+ auto rhs_stride = GetStride(binop.getRHS(), dim_or_sym);
+
+ if (binop.getKind() == mlir::AffineExprKind::Add) {
+ if (lhs_stride && rhs_stride) {
+ return *lhs_stride + *rhs_stride;
+ }
+ return std::nullopt;
+ }
+ // Just return 0 if the expression doesn't occur on either side.
+ if (lhs_stride == 0 && rhs_stride == 0) {
+ return 0;
+ }
+ // Otherwise, we don't know the stride.
+ return std::nullopt;
+ }
+ return expr == dim_or_sym ? 1 : 0;
+}
+
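+// Computes a conservative alignment (a known divisor) of the value of `expr`
+// once the term containing `dim_or_sym` is dropped (the "remainder"). This is
+// used to check that the base offset of a vectorized access is a multiple of
+// the vector size.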
+int64_t GetAlignmentOfRemainder(mlir::AffineExpr expr,
+ mlir::AffineExpr dim_or_sym) {
+ if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
+ auto lhs_align = GetAlignmentOfRemainder(binop.getLHS(), dim_or_sym);
+ auto rhs_align = GetAlignmentOfRemainder(binop.getRHS(), dim_or_sym);
+
+ std::optional<int64_t> rhs_cst = std::nullopt;
+ if (binop.getRHS().getKind() == mlir::AffineExprKind::Constant) {
+ rhs_cst = binop.getRHS().cast<mlir::AffineConstantExpr>().getValue();
+ }
+
+ switch (binop.getKind()) {
+ case mlir::AffineExprKind::Add:
+ if (binop.getLHS() == dim_or_sym) return rhs_align;
+ if (binop.getRHS() == dim_or_sym) return lhs_align;
+ return std::gcd(lhs_align, rhs_align);
+ case mlir::AffineExprKind::Mul:
+ return lhs_align * rhs_align;
+ case mlir::AffineExprKind::FloorDiv:
+ case mlir::AffineExprKind::CeilDiv:
+ return 1;
+ case mlir::AffineExprKind::Mod:
+ // (a * c) % (b * c) = (a % b) * c.
+ return std::gcd(lhs_align, rhs_align);
+ default:
+ llvm_unreachable("expr is none of the binary expressions");
+ }
+ }
+ if (auto cst = mlir::dyn_cast<mlir::AffineConstantExpr>(expr)) {
+ return cst.getValue();
+ }
+ return 1;
+}
+
+// Attempts to derive the vector type for the given loop. This checks that:
+// - the tensor has no layout/encoding and a vectorizable element type,
+// - the loop's lower bound is 0 and its step is 1,
+// - the loop's upper bound (the vector size) is 2 or 4,
+// - for tensors of rank > 1, the minor dimension is a multiple of the vector
+//   size, so that each row starts at an aligned offset.
+// Returns a vector type with the loop's trip count and the tensor's element
+// type, or nullptr if any of these checks fails.
+mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type,
+ mlir::scf::ForOp loop) {
+ // TODO(jreiffers): Support layouts.
+ if (tensor_type.getEncoding()) {
+ return nullptr;
+ }
+ if (!mlir::VectorType::isValidElementType(tensor_type.getElementType())) {
+ return nullptr;
+ }
+ if (mlir::getConstantIntValue(loop.getStep()) != 1 ||
+ mlir::getConstantIntValue(loop.getLowerBound()) != 0) {
+ return nullptr;
+ }
+ std::optional<int> vector_size =
+ mlir::getConstantIntValue(loop.getUpperBound());
+ if (vector_size != 2 && vector_size != 4) {
+ return nullptr; // Unsupported vector size.
+ }
+ if (tensor_type.getRank() > 1 &&
+ tensor_type.getShape().back() % *vector_size) {
+ return nullptr; // Misaligned start indices.
+ }
+ return mlir::VectorType::get({*vector_size}, tensor_type.getElementType());
+}
+
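+// Rewrites the indices of a tensor access inside `loop` into the base indices
+// of the corresponding vector access: the minor index is recomputed with the
+// induction variable replaced by 0. Returns std::nullopt if the access is not
+// contiguous in the induction variable or its base offset is not aligned to
+// the vector size.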
+std::optional<llvm::SmallVector<mlir::Value>> GetVectorBaseIndices(
+ mlir::ValueRange indices, mlir::scf::ForOp loop,
+ mlir::VectorType vector_type, mlir::ImplicitLocOpBuilder& b) {
+ if (indices.empty()) {
+ return std::nullopt;
+ }
+
+ // The major dimensions' indices must all be defined outside the loop.
+ for (int i = 0; i < indices.size() - 1; ++i) {
+ if (!indices[i].getParentRegion()->isProperAncestor(
+ &loop.getBodyRegion())) {
+ return std::nullopt;
+ }
+ }
+
+ mlir::Value induction_var = loop.getInductionVar();
+ if (indices.back() == induction_var) {
+ llvm::SmallVector<mlir::Value> ret = indices;
+ ret.back() = b.create<mlir::arith::ConstantIndexOp>(0);
+ return ret;
+ }
+
+ auto apply_indexing =
+ mlir::dyn_cast_or_null<ApplyIndexingOp>(indices.back().getDefiningOp());
+ if (!apply_indexing) {
+ return std::nullopt;
+ }
+
+ // We don't generate these, but they are allowed in theory.
+ if (apply_indexing->getNumResults() != 1) {
+ return std::nullopt;
+ }
+ mlir::AffineMap map = apply_indexing.getAffineMap();
+
+ int induction_var_operand_index;
+ mlir::AffineExpr induction_var_expr = nullptr;
+ for (auto [index, operand] : llvm::enumerate(apply_indexing.getOperands())) {
+ if (operand == induction_var) {
+ if (induction_var_expr) {
+ // The induction variable should be used only once.
+ return std::nullopt;
+ }
+ induction_var_operand_index = index;
+ induction_var_expr = index < map.getNumDims()
+ ? mlir::getAffineDimExpr(index, b.getContext())
+ : mlir::getAffineSymbolExpr(
+ index - map.getNumDims(), b.getContext());
+ }
+ }
+ if (!induction_var_expr) {
+ return std::nullopt;
+ }
+
+ if (GetStride(map.getResult(0), induction_var_expr) != 1) {
+ // The indexing map is not contiguous in the vectorized dimension.
+ return std::nullopt;
+ }
+
+ if (GetAlignmentOfRemainder(map.getResult(0), induction_var_expr) %
+ vector_type.getNumElements()) {
+ return std::nullopt;
+ }
+
+ auto operands = llvm::to_vector(apply_indexing.getOperands());
+ operands[induction_var_operand_index] =
+ b.create<mlir::arith::ConstantIndexOp>(0);
+
+ llvm::SmallVector<mlir::Value> ret = indices;
+ ret.back() =
+ b.create<ApplyIndexingOp>(operands, apply_indexing.getIndexingMap())
+ ->getResult(0);
+ return ret;
+}
+
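+// The source tensor must be defined outside the loop, i.e. it cannot be the
+// result of an insertion inside the loop.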
+bool IsConflictFree(mlir::tensor::ExtractOp op) {
+ return op.getTensor().getParentRegion()->isProperAncestor(
+ op->getParentRegion());
+}
+
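+// Rewrites a tensor.extract of consecutive elements inside an scf.for into a
+// vector.transfer_read hoisted before the loop plus a vector.extract inside
+// the loop body.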
+struct VectorizeLoad : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::tensor::ExtractOp op,
+ mlir::PatternRewriter& rewriter) const override {
+ auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
+ if (!loop) {
+ return rewriter.notifyMatchFailure(op, "no loop found");
+ }
+ if (!IsConflictFree(op)) {
+ return rewriter.notifyMatchFailure(op,
+ "source may be written in the loop");
+ }
+
+ auto vector_type = GetVectorType(op.getTensor().getType(), loop);
+ if (!vector_type) {
+ return rewriter.notifyMatchFailure(op, "not a vectorizable loop");
+ }
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ b.setInsertionPoint(loop);
+ auto vector_indices =
+ GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
+ if (!vector_indices) {
+ return rewriter.notifyMatchFailure(
+ op, "the instruction does not access contiguous elements");
+ }
+
+ auto loaded_vector = b.create<mlir::vector::TransferReadOp>(
+ vector_type, op.getTensor(), *vector_indices,
+ llvm::ArrayRef<bool>{true});
+ rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(
+ op, loaded_vector, loop.getInductionVar());
+ return mlir::success();
+ }
+};
+
+// Verifies that the insertions happening in the loop can all safely be batched
+// into a single write after the loop.
+bool IsConflictFree(mlir::tensor::InsertOp op) {
+ // The insertion's only use must be the yield.
+ if (!op->hasOneUse() || !mlir::isa<mlir::scf::YieldOp>(*op->user_begin())) {
+ return false;
+ }
+ // The destination must be one of the loop's block arguments, and this
+ // insertion must be that argument's only use.
+ auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(op.getDest());
+ return bbarg && bbarg.hasOneUse() &&
+ bbarg.getOwner()->getParentOp() == op->getParentOp();
+}
+
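+// Rewrites a tensor.insert of consecutive elements into a loop iteration
+// argument: the scalars are accumulated in a vector iter_arg via
+// vector.insert and written back with a single vector.transfer_write after
+// the loop.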
+struct VectorizeStore : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::tensor::InsertOp op,
+ mlir::PatternRewriter& rewriter) const override {
+ auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
+ if (!loop) {
+ return rewriter.notifyMatchFailure(op, "no loop found");
+ }
+ if (!IsConflictFree(op)) {
+ return rewriter.notifyMatchFailure(op, "write may be read back by loop");
+ }
+ auto vector_type = GetVectorType(op.getDest().getType(), loop);
+ if (!vector_type) {
+ return rewriter.notifyMatchFailure(op, "loop is not vectorizable");
+ }
+
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ b.setInsertionPoint(loop);
+ auto vector_indices =
+ GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
+ if (!vector_indices) {
+ return rewriter.notifyMatchFailure(
+ op, "the instruction does not access contiguous elements");
+ }
+
+ auto init = b.create<mlir::arith::ConstantOp>(b.getZeroAttr(vector_type))
+ .getResult();
+
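+ // Each iteration inserts the stored scalar into a new vector iter_arg at
+ // the position of the induction variable.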
+ auto yield_fn = [&](mlir::OpBuilder& yield_b, mlir::Location yield_loc,
+ llvm::ArrayRef<mlir::BlockArgument> bbarg) {
+ auto induction_var =
+ mlir::cast<mlir::scf::ForOp>(bbarg.front().getOwner()->getParentOp())
+ .getInductionVar();
+ auto insert_op = yield_b.create<mlir::vector::InsertOp>(
+ yield_loc, op.getScalar(), bbarg.front(), induction_var);
+ return llvm::SmallVector<mlir::Value>{insert_op.getResult()};
+ };
+ int result_index = op->use_begin()->getOperandNumber();
+ auto new_for = *loop.replaceWithAdditionalYields(
+ rewriter, init,
+ /*replaceInitOperandUsesInLoop=*/false, yield_fn);
+
+ b.setInsertionPointAfter(new_for);
+ rewriter.replaceOp(op, op.getDest());
+
+ auto filled_vector = new_for->getResults().back();
+ auto written = b.create<mlir::vector::TransferWriteOp>(
+ filled_vector, new_for.getInits()[result_index], *vector_indices,
+ llvm::ArrayRef<bool>{true});
+ new_for->getResult(result_index).replaceAllUsesWith(written.getResult());
+
+ return mlir::success();
+ }
+};
+
+class VectorizeLoadsAndStoresPass
+ : public impl::VectorizeLoadsAndStoresPassBase<
+ VectorizeLoadsAndStoresPass> {
+ public:
+ void runOnOperation() override {
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<VectorizeLoad, VectorizeStore>(&getContext());
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateVectorizeLoadsAndStoresPass() {
+ return std::make_unique<VectorizeLoadsAndStoresPass>();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.cc b/third_party/xla/xla/service/gpu/fusions/transpose.cc
deleted file mode 100644
index 611099d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/transpose.cc
+++ /dev/null
@@ -1,366 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/transpose.h"
-
-#include <array>
-#include <cstdint>
-#include <optional>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-#include "llvm/IR/Value.h"
-#include "llvm/Support/AtomicOrdering.h"
-#include "mlir/IR/AffineMap.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/target_util.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/service/llvm_ir/loop_emitter.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info,
- const TransposeDescription& tiled_transpose) {
- constexpr int kNumRows = 4;
- static_assert(WarpSize() % kNumRows == 0);
-
- // 3D view over the output shape.
- Vector3 transposed_dims = tiled_transpose.dimensions;
- Vector3 permutation = tiled_transpose.permutation;
-
- // Note: the supported permutations are their own inverses. Therefore we
- // always use the permutation, even when we want the inverse.
- CHECK((permutation == Vector3{0, 2, 1}) || (permutation == Vector3{2, 1, 0}));
-
- absl::InlinedVector<int64_t, 4> input_dims{transposed_dims[permutation[0]],
- transposed_dims[permutation[1]],
- transposed_dims[permutation[2]]};
-
- // We tile along the minor dimensions pre- and post-transpose.
- absl::InlinedVector<int64_t, 4> tile_sizes{1, 1, 1};
- tile_sizes[permutation[2]] = WarpSize() / kNumRows;
- absl::InlinedVector<int64_t, 4> num_threads{1, 1, WarpSize()};
- num_threads[permutation[2]] = kNumRows;
-
- auto capability = gpu_device_info.gpu_compute_capability();
- std::visit(
- [&](const auto& capability) {
- if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
- stream_executor::RocmComputeCapability>) {
- // kNumRows = 8 works well on MI300 with wavefront size 64.
- if (capability.gfx9_mi300()) {
- tile_sizes[permutation[2]] = gpu_device_info.threads_per_warp() / 8;
- num_threads[permutation[2]] = 8;
- }
- }
- },
- capability);
-
- return Tiling(input_dims, tile_sizes, num_threads);
-}
-
-void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
- IrEmitterContext& ir_emitter_context) {
- auto* module = builder->GetInsertBlock()->getModule();
- if (IsAMDGPU(module) &&
- ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
- builder->CreateFence(
- llvm::AtomicOrdering::SequentiallyConsistent,
- builder->getContext().getOrInsertSyncScopeID("workgroup"));
- }
-}
-
-void EmitSyncThreads(llvm::IRBuilder<>* builder,
- IrEmitterContext& ir_emitter_context) {
- MaybeEmitFenceForAMDGPU(builder, ir_emitter_context);
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder);
-}
-
-llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index,
- absl::Span<const int64_t> permutation) {
- return llvm_ir::IrArray::Index{Permute(index.multidim(), permutation),
- Permute(index.dims(), permutation),
- index.GetType()};
-}
-
-} // namespace
-
-TransposeFusion::TransposeFusion(const se::DeviceDescription& gpu_device_info,
- const HloFusionAnalysis& analysis)
- : analysis_(analysis),
- tiling_(
- ComputeTransposeTiling(gpu_device_info, analysis.tiled_transpose())) {
- for (auto [root, hero] :
- llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) {
- if (auto transpose = GetDescriptionForTiledTransposeEmitter(
- root.instruction(), hero.instruction())) {
- permutation_ = transpose->permutation;
- break;
- }
- }
-}
-
-absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const {
- const auto& hlo_roots = analysis_.fusion_roots();
- GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
- FusedIrEmitter fused_emitter(elemental_emitter);
- for (auto [i, input] : llvm::enumerate(inputs)) {
- HloInstruction* fused_operand = fusion.fused_parameter(i);
- fused_emitter.BindGenerator(
- *fused_operand, [input = input, builder,
- fused_operand](const llvm_ir::IrArray::Index& index) {
- return input.EmitReadArrayElement(index, builder,
- fused_operand->name());
- });
- }
-
- absl::flat_hash_map<const HloInstruction*,
- std::vector<std::pair<int64_t, const HloInstruction*>>>
- transposes_to_roots;
- // Keep a list of deduplicated transpose heroes separate from
- // transposes_to_roots to make the CodeGen deterministic.
- std::vector<TransposeDescription> transposes;
- transposes.reserve(hlo_roots.size());
- std::vector<std::pair<int64_t, const HloInstruction*>> extra_outputs;
-
- for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) {
- const auto& hero = analysis_.fusion_hero(output_idx).instruction();
- auto transpose_descr =
- GetDescriptionForTiledTransposeEmitter(root.instruction(), hero);
- if (transpose_descr.has_value()) {
- auto iterator_inserted = transposes_to_roots.insert(std::make_pair(
- &hero, std::vector<std::pair<int64_t, const HloInstruction*>>{
- {output_idx, &root.instruction()}}));
- if (iterator_inserted.second) {
- transposes.push_back(*transpose_descr);
- } else {
- iterator_inserted.first->second.push_back(
- {output_idx, &root.instruction()});
- }
- } else {
- extra_outputs.push_back({output_idx, &root.instruction()});
- }
- }
-
- absl::flat_hash_map<const HloInstruction*, llvm_ir::SharedMemoryTile> tiles;
- Vector3 permutation;
- for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) {
- permutation = tr.permutation;
- auto tile_size = tiling_.GetBlockTileSize();
- ++tile_size.back(); // Prevent bank conflicts.
- auto* module = ir_emitter_context.llvm_module();
- tiles[tr.instr] = llvm_ir::AllocateSharedMemoryTile(
- module,
- llvm_ir::PrimitiveTypeToIrType(tr.instr->shape().element_type(),
- module),
- tile_size, absl::StrCat("tr_tile_", tile_idx));
- }
-
- auto tile_generator = [&](const TilingThreadIdInfo& thread_id_info,
- const llvm_ir::IrArray::Index& tile_start_index,
- absl::Span<llvm::Value* const> tile_dimensions) {
- // Copy input parameter values to shared memory buffers:
- // tile[thread_id_y, thread_id_x] = input[index]
- EmitTile(builder, tiling_, thread_id_info, tile_dimensions,
- [&](absl::Span<llvm::Value* const> index_in_tile) {
- auto index = tile_start_index.AddOffset(index_in_tile, builder);
- for (const auto& tr : transposes) {
- auto input_gen =
- *fused_emitter.GetGenerator(*tr.instr->operand(0));
- auto input_index = index.SourceIndexOfBitcast(
- tr.instr->operand(0)->shape(), builder);
- llvm::Value* value = *input_gen(input_index);
- tiles[tr.instr].Store(value, index_in_tile, builder);
- }
-
- // Compute all extra output values before writing them. This
- // avoids overwriting aliased input/output values before all
- // reads occurred.
- std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
- llvm::Value*>>
- scheduled_writes;
- for (const auto& [output_idx, root] : extra_outputs) {
- auto extra_output_index =
- index.SourceIndexOfBitcast(root->shape(), builder);
- auto output_gen = *fused_emitter.GetGenerator(*root);
- llvm::Value* output_value = *output_gen(extra_output_index);
- scheduled_writes.emplace_back(
- outputs[output_idx], extra_output_index, output_value);
- }
-
- for (const auto& [output, idx, value] : scheduled_writes) {
- output.EmitWriteArrayElement(idx, value, builder);
- }
- });
-
- EmitSyncThreads(builder, ir_emitter_context);
-
- auto output_tile_index = PermuteIndex(tile_start_index, permutation);
- auto transposed_tile_dimensions = Permute(tile_dimensions, permutation);
-
- EmitTile(
- builder, tiling_, thread_id_info, transposed_tile_dimensions,
- /*emit_elem_function=*/
- [&](absl::Span<llvm::Value* const> index_in_tile) {
- auto index = output_tile_index.AddOffset(index_in_tile, builder);
- for (const auto& tr : transposes) {
- llvm::Value* loaded = tiles[tr.instr].Load(
- Permute(index_in_tile, permutation), builder);
-
- FusedIrEmitter fused_emitter(elemental_emitter);
- fused_emitter.BindGenerator(
- *tr.instr,
- [&](const llvm_ir::IrArray::Index&) { return loaded; });
- for (int64_t i = 0;
- i < fusion.fused_instructions_computation()->num_parameters();
- ++i) {
- llvm_ir::IrArray ir_array = inputs[i];
- HloInstruction* fused_operand = fusion.fused_parameter(i);
- fused_emitter.BindGenerator(
- *fused_operand, [=](const llvm_ir::IrArray::Index& index) {
- return ir_array.EmitReadArrayElement(index, builder,
- fused_operand->name());
- });
- }
-
- // Apply code generation for the code after the real hero.
- // Compute all output values before writing them. This avoids
- // overwriting aliased input/output values before all reads
- // occurred.
- std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
- llvm::Value*>>
- scheduled_writes;
- for (const auto& [output_idx, root] :
- transposes_to_roots[tr.instr]) {
- TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen,
- fused_emitter.GetGenerator(*root));
-
- // Both for emission and writing it should be
- // index-as-transformed by the computation.
- auto untiled_index =
- index.SourceIndexOfBitcast(root->shape(), builder);
- TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index));
- scheduled_writes.emplace_back(outputs[output_idx], untiled_index,
- generated);
- }
- for (const auto& [output, idx, value] : scheduled_writes) {
- output.EmitWriteArrayElement(idx, value, builder);
- }
- }
- return absl::OkStatus();
- });
- };
-
- llvm::Type* index_type =
- GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
- return EmitTilingKernel(builder, tiling_, index_type, tile_generator)
- .status();
-}
-
-LaunchDimensions TransposeFusion::launch_dimensions() const {
- return LaunchDimensions(tiling_.GetNumBlocks(),
- tiling_.GetNumThreadsPerBlock());
-}
-
-std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const {
- const auto& hero = analysis_.fusion_hero(root_index);
- if (hero.opcode() != HloOpcode::kTranspose) {
- // The shape of non-transpose roots are bitcast compatible with the input
- // shape of transpose heroes.
- auto map = ComposeIndexingMaps(
- GetIndexingMapForTiling(tiling_, ctx),
- GetBitcastMap(tiling_.GetXlaShape(),
- analysis_.fusion_root(root_index).shape(), ctx));
- map.Simplify();
- return map;
- }
-
- // The block offsets are permuted, but the thread offsets remain the same.
- auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx)
- .getSubMap(std::vector<unsigned>{permutation_.begin(),
- permutation_.end()});
- auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx);
- auto permuted_tiled_shape =
- ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_));
-
- auto map = ComposeIndexingMaps(
- GetIndexingMapForTiling(
- block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(),
- tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(),
- permuted_tiled_shape.dimensions()),
- GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx));
- map.Simplify();
- return map;
-}
-
-std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const {
- const auto& hero = analysis_.fusion_hero(root_index).instruction();
- if (hero.opcode() != HloOpcode::kTranspose) {
- auto map = ComposeIndexingMaps(
- *ComputeThreadIdToOutputIndexing(root_index, ctx),
- *ComputeOutputToInputIndexing(
- &analysis_.fusion_root(root_index).instruction(), 0, ctx)
- .indexing_maps[hero_operand_index]
- .begin());
- map.Simplify();
- return map;
- }
-
- auto map = ComposeIndexingMaps(
- GetIndexingMapForTiling(tiling_, ctx),
- GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx));
- map.Simplify();
- return map;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.h b/third_party/xla/xla/service/gpu/fusions/transpose.h
deleted file mode 100644
index 3f369a4..0000000
--- a/third_party/xla/xla/service/gpu/fusions/transpose.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_
-#define XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
-// algorithm to improve the memory access patterns for the input parameters
-// with a shape that is a 0-2-1 transpose of the output tensor shape. The
-// caller is responsible for making sure that it is safe to apply the shared
-// memory transpose on the input parameters.
-//
-// For the purpose of tiling, the output tensors have a logical shape of three
-// components 0-2-1 while the relevant input parameters have a logical shape
-// of three components 0-1-2 in the order major to minor. The x- and y-
-// dimensions of the tensors are tiled in square tiles with an edge length
-// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
-// transposes one tile: each thread copies kTileSize/kNumRows elements from
-// the input to a shared memory tile, then the otherwise "regular HLO kernel"
-// reads from the shared memory instead of the original input.
-//
-// This is similar to the following CUDA algorithm in TensorFlow:
-// https://goo.gl/MStRV6.
-//
-// `kTileSize` should usually be same as warp size. We currently choose 32 for
-// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
-//
-// TODO(b/33320379): Here each block transposes 1 tile. It may be more
-// efficient to launch fewer blocks so each transposes many tiles.
-class TransposeFusion : public KernelFusionEmitterBase {
- public:
- explicit TransposeFusion(const se::DeviceDescription& gpu_device_info,
- const HloFusionAnalysis& analysis);
- LaunchDimensions launch_dimensions() const override;
-
- std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
- int64_t root_index, mlir::MLIRContext* ctx) const override;
-
- std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
- int64_t root_index, int64_t hero_operand_index,
- mlir::MLIRContext* ctx) const override;
-
- protected:
- absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
- const HloFusionInstruction& fusion,
- const LaunchDimensions& launch_dims,
- std::vector<llvm_ir::IrArray> inputs,
- std::vector<llvm_ir::IrArray> outputs,
- llvm::IRBuilder<>* builder) const override;
-
- private:
- const HloFusionAnalysis& analysis_;
- Tiling tiling_;
- Vector3 permutation_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
index 8053a89..c6ce433 100644
--- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
@@ -24,6 +24,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
+#include "absl/status/status.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -41,14 +42,12 @@
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/mlir/utils/type_util.h"
#include "xla/permutation_util.h"
#include "xla/primitive_util.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/type_util.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emission_utils.h"
@@ -72,8 +71,6 @@
using mlir::ValueRange;
using mlir::func::FuncOp;
using mlir::func::ReturnOp;
-using mlir::tensor::ExtractOp;
-using mlir::tensor::InsertOp;
using mlir_converter::ApplyIndexing;
constexpr int kNumRows = 4;
@@ -99,6 +96,11 @@
transposes_to_tile.insert(&hero.instruction());
shmem_transpose_roots_.push_back(&root.instruction());
int size = primitive_util::ByteWidth(hero.shape().element_type());
+ // If the last dimension stays the same, we need to make it part of the
+ // shared memory tile.
+ if (MostMinorDimensionUnchanged()) {
+ size *= input_shape_.back();
+ }
max_element_bytes = std::max(max_element_bytes, size);
shmem_usage += kBaseBlockSize * (kBaseBlockSize + 1) * size;
shmem_transpose_root_indices_.push_back(index);
@@ -113,8 +115,12 @@
auto compute_block_sizes = [this](int vector_size) {
vector_size_ = vector_size;
block_size_ = kBaseBlockSize * vector_size_;
- block_sizes_ = {1, 1, block_size_};
- block_sizes_[permutation_[2]] = block_size_;
+ if (MostMinorDimensionUnchanged()) {
+ block_sizes_ = {block_size_, block_size_, input_shape_.back()};
+ } else {
+ block_sizes_ = {1, 1, block_size_};
+ block_sizes_[permutation_[2]] = block_size_;
+ }
block_counts_ = {CeilOfRatio(input_shape_[0], block_sizes_[0]),
CeilOfRatio(input_shape_[1], block_sizes_[1]),
CeilOfRatio(input_shape_[2], block_sizes_[2])};
@@ -137,6 +143,10 @@
shmem_usage * elems_per_thread <= device.shared_memory_per_block();
bool aligned_dims = (input_shape_[2] % vec_size == 0) &&
(input_shape_[permutation_[2]] % vec_size == 0);
+ if (MostMinorDimensionUnchanged()) {
+ aligned_dims =
+ input_shape_[0] % vec_size == 0 && input_shape_[1] % vec_size == 0;
+ }
if (enough_work && enough_shmem && aligned_dims) {
compute_block_sizes(vec_size);
break;
@@ -187,7 +197,15 @@
IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing(
bool read, mlir::MLIRContext* ctx) const {
auto thread_offsets =
- Permute(GetThreadOffsets(ctx), read ? Vector3{0, 1, 2} : permutation_);
+ Permute(GetThreadOffsets(ctx),
+ read ? absl::InlinedVector<int64_t, 3>{0, 1, 2} : permutation_);
+ if (MostMinorDimensionUnchanged()) {
+ return {mlir::AffineMap::get(6, 3, thread_offsets, ctx),
+ DimVarsFromTensorSizes({kNumThreadsPerBlock, 1, 1, 1, 1, 1}),
+ RangeVarsFromTensorSizes(
+ {block_size_ / kNumRows, vector_size_, input_shape_[2]}),
+ {}};
+ }
return {mlir::AffineMap::get(6, 2, thread_offsets, ctx),
DimVarsFromTensorSizes({kNumThreadsPerBlock, 1, 1, 1, 1, 1}),
RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}),
@@ -203,7 +221,11 @@
MLIRContext* ctx = builder.getContext();
auto shmem_tensor_size = block_sizes_;
// Avoid bank conflicts.
- ++shmem_tensor_size.back();
+ if (MostMinorDimensionUnchanged()) {
+ ++shmem_tensor_size[1];
+ } else {
+ ++shmem_tensor_size.back();
+ }
// Allocate shared memory.
SmallVector<Value> inits;
@@ -237,8 +259,8 @@
root_computation, transpose,
/*operand_index=*/0, input_indices(transpose->operand(0)),
call_target_provider, entry_function, builder)[0];
- result_tensors.push_back(
- builder.create<InsertOp>(result_scalar, output, shmem_indices));
+ result_tensors.push_back(builder.create<mlir::tensor::InsertOp>(
+ result_scalar, output, shmem_indices));
}
// Produce all side outputs and then write them.
@@ -258,7 +280,7 @@
llvm::zip(side_outputs, side_output_indices,
output_tensors.take_back(side_output_roots_.size()))) {
result_tensors.push_back(
- builder.create<InsertOp>(value, output, indices));
+ builder.create<mlir::tensor::InsertOp>(value, output, indices));
}
return result_tensors;
@@ -306,7 +328,7 @@
for (auto [transpose, shmem] :
llvm::zip(shmem_transposes_, written.shmem_tensors)) {
transpose_values[transpose].push_back(
- builder.create<ExtractOp>(shmem, shmem_indices));
+ builder.create<mlir::tensor::ExtractOp>(shmem, shmem_indices));
}
llvm::SmallVector<Value> epilogue_indices = dim_values;
absl::c_copy(symbol_values, std::back_inserter(epilogue_indices));
@@ -320,7 +342,7 @@
shmem_transpose_root_indices_)) {
llvm::SmallVector<Value> indices =
ApplyIndexing(indexing, dim_values, symbol_values, builder);
- results[root_index] = builder.create<InsertOp>(
+ results[root_index] = builder.create<mlir::tensor::InsertOp>(
result_scalars.at(root).front(), results[root_index], indices);
}
return results;
@@ -372,6 +394,10 @@
auto vector = mlir::getAffineSymbolExpr(1, ctx);
int loop_stride = block_size_ * kNumRows;
auto linear_index = loop * loop_stride + thread * vector_size_ + vector;
+ if (MostMinorDimensionUnchanged()) {
+ auto minor_dim = mlir::getAffineSymbolExpr(2, ctx);
+ linear_index = linear_index * input_shape_[2] + minor_dim;
+ }
return DelinearizeInBoundsIndex(linear_index, block_sizes_);
}
@@ -380,19 +406,25 @@
mlir::MLIRContext* ctx) const {
auto raw_id = mlir::getAffineDimExpr(
KernelFusionInterface::kIndexingMapBlockIdxDims[0], ctx);
- auto block_ids = Permute(DelinearizeInBoundsIndex(raw_id, block_counts_),
- input ? Vector3{0, 1, 2} : permutation_);
+ auto block_ids =
+ Permute(DelinearizeInBoundsIndex(raw_id, block_counts_),
+ input ? absl::InlinedVector<int64_t, 3>{0, 1, 2} : permutation_);
auto thread_offsets = GetThreadOffsets(ctx);
llvm::SmallVector<AffineExpr, 3> offsets;
for (auto [block_id, block_size, thread] :
llvm::zip(block_ids, block_sizes_, thread_offsets)) {
offsets.push_back(block_id * block_size + thread);
}
+ auto range_var_sizes =
+ std::vector<int64_t>{block_size_ / kNumRows, vector_size_};
+ if (MostMinorDimensionUnchanged()) {
+ range_var_sizes.push_back(input_shape_[2]);
+ }
IndexingMap result{
- mlir::AffineMap::get(6, 2, offsets, ctx),
+ mlir::AffineMap::get(6, range_var_sizes.size(), offsets, ctx),
DimVarsFromTensorSizes(
{kNumThreadsPerBlock, 1, 1, Product(block_counts_), 1, 1}),
- RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}),
+ RangeVarsFromTensorSizes(range_var_sizes),
{}};
auto normalized_shape =
input ? ShapeUtil::MakeShape(shape.element_type(), input_shape_)
@@ -407,5 +439,9 @@
return result;
}
+bool MlirTransposeFusion::MostMinorDimensionUnchanged() const {
+ return permutation_.back() == permutation_.size() - 1;
+}
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h
index 07d1e99..538ad2f 100644
--- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h
@@ -19,6 +19,7 @@
#include <optional>
#include <vector>
+#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -35,7 +36,6 @@
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/util.h"
namespace xla {
namespace gpu {
@@ -98,9 +98,10 @@
IndexingMap GetSharedMemoryIndexing(bool read, mlir::MLIRContext* ctx) const;
llvm::SmallVector<mlir::AffineExpr, 4> GetThreadOffsets(
mlir::MLIRContext* ctx) const;
+ bool MostMinorDimensionUnchanged() const;
TransposeDescription transpose_;
- Vector3 permutation_;
+ absl::InlinedVector<int64_t, 3> permutation_;
std::vector<int64_t> input_shape_;
std::vector<int64_t> block_sizes_; // In input elements.
std::vector<int64_t> block_counts_;
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
index eb71bb7..eec531d 100644
--- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
@@ -44,7 +44,7 @@
)"));
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirTransposeFusion fusion(analysis);
EXPECT_THAT(
@@ -101,7 +101,7 @@
})"));
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirTransposeFusion fusion(analysis);
EXPECT_THAT(
@@ -158,7 +158,7 @@
)"));
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirTransposeFusion fusion(analysis);
EXPECT_THAT(
@@ -212,7 +212,7 @@
})"));
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirTransposeFusion fusion(analysis);
EXPECT_THAT(
@@ -295,6 +295,55 @@
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}
+TEST_F(MlirTransposeFusionTest, FusedTranspose102) {
+ auto kHloString = R"(
+ HloModule Transpose
+
+ %fused_computation {
+ %p0 = s8[160,170,3] parameter(0)
+ ROOT %transpose = s8[170,160,3] transpose(%p0), dimensions={1,0,2}
+ }
+ ENTRY main {
+ %param = s8[160,170,3] parameter(0)
+ ROOT %fusion = s8[170,160,3] fusion(%param), kind=kInput,
+ calls=%fused_computation
+ }
+ )";
+ TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"(
+ // CHECK-LABEL: func.func @fused_computation(
+ // CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x3xi8>
+ //
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+
+ // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x33x3xi8>
+ // CHECK: %[[SHMEM_WITH_VALS:.*]] = scf.for
+ // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]]
+ // CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
+ // CHECK: %[[SHMEM_WITH_VALS2:.*]] = scf.for
+ // CHECK-SAME: %[[C0]] to %[[C3]] step %[[C1]]
+ // CHECK-SAME: iter_args(%[[SHMEM2_:.*]] = %[[SHMEM_]])
+ // CHECK: %[[P0:.*]] = xla_gpu.pure_call @fused_computation_p0
+ // CHECK: tensor.insert %[[P0]] into %[[SHMEM2_]]
+
+ // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]
+
+ // CHECK: scf.for
+ // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]]
+ // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]])
+ // CHECK: scf.for
+ // CHECK-SAME: %[[C0]] to %[[C3]] step %[[C1]]
+ // CHECK-SAME: iter_args(%[[OUT2_:.*]] = %[[OUT_]])
+ // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[SYNC]]
+ // CHECK: %[[RES:.*]] = xla_gpu.pure_call @fused_computation__epilogue__transpose
+ // CHECK: tensor.insert %[[RES]] into %[[OUT2_]]
+ )"));
+
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
+}
+
TEST_F(MlirTransposeFusionTest, FusedTranspose210) {
auto kHloString = R"(
HloModule Transpose
@@ -578,7 +627,7 @@
.value();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirTransposeFusion fusion(analysis);
mlir::MLIRContext mlir_context;
@@ -608,7 +657,7 @@
.value();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
MlirTransposeFusion fusion(analysis);
mlir::MLIRContext mlir_context;
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc
deleted file mode 100644
index f942469..0000000
--- a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc
+++ /dev/null
@@ -1,346 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/transpose.h"
-
-#include <memory>
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/statusor.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class TransposeTest : public HloTestBase {
- protected:
- DebugOptions GetDebugOptionsForTest() override {
- auto opts = HloTestBase::GetDebugOptionsForTest();
- opts.set_xla_gpu_mlir_emitter_level(0);
- return opts;
- }
- stream_executor::DeviceDescription device_info_ =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
-};
-
-absl::StatusOr<std::unique_ptr<TransposeFusion>> GetTransposeFusion(
- const HloFusionAnalysis& analysis) {
- auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
- auto fusion = dynamic_cast<TransposeFusion*>(emitter.get());
- TF_RET_CHECK(fusion != nullptr);
-
- emitter.release();
- return std::unique_ptr<TransposeFusion>{fusion};
-}
-
-TEST_F(TransposeTest, ThreadIndexing021) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fusion {
- %input = f32[100,32,64] parameter(0)
- ROOT transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
- }
-
- ENTRY entry {
- %input = f32[100,32,64] parameter(0)
- ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
- mlir::MLIRContext mlir_context;
-
- EXPECT_THAT(
- fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- d3 floordiv 2,
- d0 floordiv 32 + s1 * 4,
- (d3 mod 2) * 32 + d0 mod 32
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 199]
- d4 in [0, 0]
- d5 in [0, 0]
-
- s0 in [0, 0]
- s1 in [0, 7]
- s2 in [0, 0]
- )"));
- EXPECT_THAT(
- fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- d3 floordiv 2,
- (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
- d0 mod 32
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 199]
- d4 in [0, 0]
- d5 in [0, 0]
-
- s0 in [0, 0]
- s1 in [0, 7]
- s2 in [0, 0]
- )"));
-}
-
-TEST_F(TransposeTest, ThreadIndexing201) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fusion {
- %input = f32[100,64,32] parameter(0)
- ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1}
- }
-
- ENTRY entry {
- %input = f32[100,64,32] parameter(0)
- ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
- mlir::MLIRContext mlir_context;
- EXPECT_THAT(
- fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- d3 floordiv 2,
- (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
- d0 mod 32
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 199]
- d4 in [0, 0]
- d5 in [0, 0]
-
- s0 in [0, 0]
- s1 in [0, 7]
- s2 in [0, 0]
- )"));
- EXPECT_THAT(
- fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- d0 floordiv 32 + s1 * 4,
- d3 floordiv 2,
- (d3 mod 2) * 32 + d0 mod 32
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 199]
- d4 in [0, 0]
- d5 in [0, 0]
-
- s0 in [0, 0]
- s1 in [0, 7]
- s2 in [0, 0]
- )"));
-}
-
-TEST_F(TransposeTest, ThreadIndexingPartialBlock) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- fused_computation {
- %p0 = f64[24,2,6,4] parameter(0)
- ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0}
- }
-
- ENTRY main {
- %p0 = f64[24,2,6,4] parameter(0)
- ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput,
- calls=%fused_computation
- }
- )")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
- mlir::MLIRContext mlir_context;
- EXPECT_THAT(
- fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- d0 floordiv 32 + s0 * 4,
- d3,
- (d0 floordiv 4) mod 8,
- d0 mod 4
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 1]
- d4 in [0, 0]
- d5 in [0, 0]
- s0 in [0, 5]
- s1 in [0, 0]
- s2 in [0, 0]
- d0 mod 32 in [0, 23]
- )"));
- EXPECT_THAT(
- fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- s0,
- d0 floordiv 32,
- d3,
- d0 mod 32
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 1]
- d4 in [0, 0]
- d5 in [0, 0]
- s0 in [0, 5]
- s1 in [0, 0]
- s2 in [0, 0]
- d0 mod 32 in [0, 23]
- )"));
-}
-
-TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fusion {
- %input = f32[100,32,64] parameter(0)
- %transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
- %bitcast = f32[100,2048] bitcast(%input)
- ROOT %tuple = (f32[100,64,32], f32[100,2048]) tuple(%transpose, %bitcast)
- }
-
- ENTRY entry {
- %input = f32[100,32,64] parameter(0)
- ROOT %fusion = (f32[100,64,32], f32[100,2048]) fusion(%input), kind=kInput, calls=fusion
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
- mlir::MLIRContext mlir_context;
-
- EXPECT_THAT(
- fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
- fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString());
-}
-
-TEST_F(TransposeTest, ThreadIndexingSideOutput) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fusion {
- %input0 = f32[100,32,64] parameter(0)
- %input1 = f32[100,32] parameter(1)
- %transpose = f32[100,64,32] transpose(%input0), dimensions={0,2,1}
- %broadcast = f32[100,32,64] broadcast(%input1), dimensions={0,1}
- ROOT %tuple = (f32[100,64,32], f32[100,32,64]) tuple(%transpose, %broadcast)
- }
-
- ENTRY entry {
- %input0 = f32[100,32,64] parameter(0)
- %input1 = f32[100,32] parameter(1)
- ROOT %fusion = (f32[100,64,32], f32[100,32,64]) fusion(%input0, %input1), kind=kInput, calls=fusion
- })")
- .value();
-
- auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
-
- TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
- mlir::MLIRContext mlir_context;
-  // Check that the side output `%broadcast` gets the correct input indexing,
-  // which should correspond to `%input1` with shape [100,32].
- EXPECT_THAT(
- fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- d3 floordiv 2,
- d0 floordiv 32 + s1 * 4
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 199]
- d4 in [0, 0]
- d5 in [0, 0]
-
- s0 in [0, 0]
- s1 in [0, 7]
- s2 in [0, 0]
- )"));
- EXPECT_THAT(
- fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(),
- MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- d3 floordiv 2,
- d0 floordiv 32 + s1 * 4,
- (d3 mod 2) * 32 + d0 mod 32
- )
- domain:
- d0 in [0, 127]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 199]
- d4 in [0, 0]
- d5 in [0, 0]
-
- s0 in [0, 0]
- s1 in [0, 7]
- s2 in [0, 0]
- )"));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD
index 8c9eadf..3bedbcf 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD
@@ -1,3 +1,4 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
load("//xla:xla.bzl", "xla_cc_test")
@@ -32,8 +33,7 @@
]),
hdrs = ["triton_fusion_emitter.h"],
deps = [
- ":prevent_mmav3_loop_unrolling",
- ":sparse_extensions",
+ ":passes",
"//xla:autotuning_proto_cc",
"//xla:comparison_util",
"//xla:debug_options_flags",
@@ -58,9 +58,9 @@
"//xla/service/gpu:target_util",
"//xla/service/gpu:triton_fusion_analysis",
"//xla/service/gpu:triton_tiling_propagation",
+ "//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
- "//xla/service/gpu/fusions/mlir:passes",
- "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+ "//xla/service/gpu/fusions/transforms:passes",
"//xla/service/gpu/llvm_gpu_backend",
"//xla/service/gpu/model:affine_map_printer",
"//xla/service/gpu/model:indexing_analysis",
@@ -135,11 +135,33 @@
]),
)
+gentbl_cc_library(
+ name = "passes_inc_gen",
+ tbl_outs = [
+ (
+ [
+ "-gen-pass-decls",
+ "-name=TritonFusionTransforms",
+ ],
+ "passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "passes.td",
+ visibility = ["//visibility:private"],
+ deps = ["@llvm-project//mlir:PassBaseTdFiles"],
+)
+
cc_library(
- name = "sparse_extensions",
- srcs = ["sparse_extensions.cc"],
- hdrs = ["sparse_extensions.h"],
+ name = "passes",
+ srcs = [
+ "passes.cc",
+ "prevent_mmav3_loop_unrolling.cc",
+ "sparse_extensions.cc",
+ ],
+ hdrs = ["passes.h"],
deps = [
+ ":passes_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:GPUCommonTransforms",
@@ -151,6 +173,7 @@
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
@@ -165,20 +188,6 @@
],
)
-cc_library(
- name = "prevent_mmav3_loop_unrolling",
- srcs = ["prevent_mmav3_loop_unrolling.cc"],
- hdrs = ["prevent_mmav3_loop_unrolling.h"],
- deps = [
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:SCFDialect",
- "@llvm-project//mlir:Support",
- "@triton//:TritonDialects",
- ],
-)
-
xla_test(
name = "triton_fusion_emitter_device_legacy_test",
srcs = if_gpu_is_configured(["triton_fusion_emitter_device_legacy_test.cc"]),
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc
index 471a91c..46a569d 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc
@@ -24,8 +24,7 @@
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
-#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h"
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/service/hlo_module_config.h"
@@ -65,9 +64,9 @@
pm.addPass(mt::createConvertTritonToTritonGPUPass(
absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps,
threadsPerWarp, block_level_parameters.num_ctas));
- pm.addPass(CreateAddSparseDotEncodingPass(block_level_parameters.num_warps,
- threadsPerWarp,
- block_level_parameters.num_ctas));
+ pm.addPass(CreateSparseAddEncodingPass(block_level_parameters.num_warps,
+ threadsPerWarp,
+ block_level_parameters.num_ctas));
pm.addPass(mt::gpu::createTritonGPUCoalesce());
if (ccCuda.IsAtLeastAmpere()) {
pm.addPass(mt::gpu::createTritonGPUF32DotTC());
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
index e31e29b..2a95ea8 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
@@ -22,7 +22,6 @@
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
@@ -57,8 +56,6 @@
mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
const BlockLevelParameters& block_level_parameters,
mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
- // TODO(ROCm): Check whether value different than 0 can be used.
- const int ccAsInt = 0;
// TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
const int threadsPerWarp = 32;
auto ccRocm = std::get<se::RocmComputeCapability>(cc);
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.cc b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc
new file mode 100644
index 0000000..fefb0f7
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc
@@ -0,0 +1,38 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/triton/passes.h"
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Visitors.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::mlir::WalkResult;
+
+} // namespace
+
+bool ContainsOp(mlir::Operation* op,
+ llvm::function_ref<bool(mlir::Operation*)> fn) {
+ return op
+ ->walk([&](mlir::Operation* nested_op) {
+ return fn(nested_op) ? WalkResult::interrupt() : WalkResult::advance();
+ })
+ .wasInterrupted();
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.h b/third_party/xla/xla/service/gpu/fusions/triton/passes.h
new file mode 100644
index 0000000..39066ba
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.h
@@ -0,0 +1,50 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_
+#define XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace xla::gpu {
+
+#define GEN_PASS_DECL
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
+std::unique_ptr<mlir::Pass> CreateSparseAddEncodingPass(
+ int32_t num_warps = 4, int32_t threads_per_warp = 32, int32_t num_ctas = 1);
+std::unique_ptr<mlir::Pass> CreateSparseBlockedToMMAPass();
+std::unique_ptr<mlir::Pass> CreateSparseRemoveLayoutConversionPass();
+std::unique_ptr<mlir::Pass> CreateSparseLocalLoadToLLVMPass();
+std::unique_ptr<mlir::Pass> CreateSparseDotOpToLLVMPass();
+std::unique_ptr<mlir::Pass> CreateSparseWGMMAOpToLLVMPass();
+std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass();
+
+// Returns true if `op` contains an operation in its regions that satisfies
+// `fn`.
+bool ContainsOp(mlir::Operation* op,
+ llvm::function_ref<bool(mlir::Operation*)> fn);
+
+#define GEN_PASS_REGISTRATION
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_
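A minimal usage sketch of the `ContainsOp` helper declared above; the wrapper function, the module contents, and the choice of `triton::DotOp` as the predicate are illustrative assumptions, not part of this change:

// Sketch: does a module contain any Triton dot op? ContainsOp walks all
// nested operations and interrupts the walk on the first match, so the
// traversal stops early once the predicate returns true.
#include "mlir/IR/BuiltinOps.h"
#include "xla/service/gpu/fusions/triton/passes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace xla::gpu {

bool ModuleHasTritonDot(mlir::ModuleOp module) {
  return ContainsOp(module.getOperation(), [](mlir::Operation* op) {
    return mlir::isa<mlir::triton::DotOp>(op);
  });
}

}  // namespace xla::gpu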
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.td b/third_party/xla/xla/service/gpu/fusions/triton/passes.td
new file mode 100644
index 0000000..b1366d0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.td
@@ -0,0 +1,94 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_
+#define XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def SparseAddEncodingPass : Pass<"sparse-add-encoding", "mlir::ModuleOp"> {
+ let summary = "Adds sparse dot encoding.";
+ let options = [
+ Option<"num_warps_", "num-warps", "int32_t", /*default=*/"4",
+ "Number of warps">,
+ Option<"threads_per_warp_", "threads-per-warp", "int32_t", /*default=*/"32",
+ "Number of threads per warp">,
+ Option<"num_ctas_", "num-ctas", "int32_t", /*default=*/"1",
+ "Number of CTAs in a CGA">,
+ ];
+ let dependentDialects = [
+ "triton::gpu::TritonGPUDialect",
+ ];
+ let constructor = "CreateSparseAddEncodingPass()";
+}
+
+def SparseBlockedToMMAPass : Pass<"sparse-blocked-to-mma", "mlir::ModuleOp"> {
+  let summary = "Adds convert layouts to/from MMA before and after SparseDotOp.";
+ let dependentDialects = [
+ "triton::gpu::TritonGPUDialect",
+ ];
+ let constructor = "CreateSparseBlockedToMMAPass()";
+}
+
+def SparseRemoveLayoutConversionPass
+ : Pass<"sparse-remove-layout-conversion", "mlir::ModuleOp"> {
+ let summary = "Replaces ConvertLayoutOp with sparse dot encoding";
+ let dependentDialects = [
+ "triton::gpu::TritonGPUDialect",
+ ];
+ let constructor = "CreateSparseRemoveLayoutConversionPass()";
+}
+
+def SparseLocalLoadToLLVMPass
+ : Pass<"sparse-local-load-to-llvm", "mlir::ModuleOp"> {
+ let summary = "Lowers sparse local load to LLVM";
+ let dependentDialects = [
+ "triton::gpu::TritonGPUDialect",
+ "mlir::LLVM::LLVMDialect"
+ ];
+ let constructor = "CreateSparseLocalLoadToLLVMPass()";
+}
+
+def SparseDotOpToLLVMPass : Pass<"sparse-dot-to-llvm", "mlir::ModuleOp"> {
+ let summary = "Lowers sparse dot to LLVM";
+ let constructor = "CreateSparseDotOpToLLVMPass()";
+ let dependentDialects = [
+ "triton::gpu::TritonGPUDialect",
+ "mlir::triton::nvgpu::NVGPUDialect",
+ ];
+}
+
+def SparseWGMMAOpToLLVMPass : Pass<"sparse-wgmma-to-llvm", "mlir::ModuleOp"> {
+ let summary = "Lowers sparse WGMMA to LLVM";
+ let dependentDialects = [
+ "triton::gpu::TritonGPUDialect",
+ "mlir::triton::nvgpu::NVGPUDialect",
+ ];
+ let constructor = "CreateSparseWGMMAOpToLLVMPass()";
+}
+
+def PreventMmaV3LoopUnrollingPass
+ : Pass<"prevent-mmav3-loop-unrolling", "mlir::ModuleOp"> {
+ let summary = "Prevent MMAv3 loop unrolling.";
+ let description = [{
+ This pass is a result of b/344841434:
+    PTX sometimes unrolls wgmma loops, which can cause a 1000x slowdown in
+    compilation time. Most unrolling has already been done before PTX;
+    this pragma prevents ptxas from doing more.
+ }];
+ let constructor = "CreatePreventMmaV3LoopUnrollingPass()";
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_
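For orientation, the factories declared in passes.h give a programmatic way to stitch these tablegen-declared passes into an MLIR pipeline. A minimal sketch, assuming an already-configured OpPassManager; the pass ordering and option values here are illustrative and do not reproduce the real compilation pipelines:

// Sketch: adding the sparse passes through the Create*() factories declared
// in xla/service/gpu/fusions/triton/passes.h. Option values mirror the
// defaults declared in passes.td unless overridden.
#include "mlir/Pass/PassManager.h"
#include "xla/service/gpu/fusions/triton/passes.h"

void AddSparsePassesForIllustration(mlir::OpPassManager& pm) {
  pm.addPass(xla::gpu::CreateSparseAddEncodingPass(
      /*num_warps=*/4, /*threads_per_warp=*/32, /*num_ctas=*/1));
  pm.addPass(xla::gpu::CreateSparseBlockedToMMAPass());
  pm.addPass(xla::gpu::CreateSparseRemoveLayoutConversionPass());
  pm.addPass(xla::gpu::CreatePreventMmaV3LoopUnrollingPass());
}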
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc
index 7cb0a55..e5b3d4e 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc
@@ -13,29 +13,27 @@
limitations under the License.
==============================================================================*/
-#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h"
-
#include <memory>
-#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
-class PreventMmaV3LoopUnrollingPass
- : public mlir::PassWrapper<PreventMmaV3LoopUnrollingPass,
- mlir::OperationPass<mlir::ModuleOp>> {
- public:
- llvm::StringRef getArgument() const override {
- return "prevent-mmav3-loop-unrolling";
- }
+namespace xla::gpu {
+namespace {
+#define GEN_PASS_DEF_PREVENTMMAV3LOOPUNROLLINGPASS
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
+struct PreventMmaV3LoopUnrollingPass
+ : public impl::PreventMmaV3LoopUnrollingPassBase<
+ PreventMmaV3LoopUnrollingPass> {
// TODO(b/344841434): Remove this if NVIDIA fixes compile-time issue.
// PTX sometimes unrolls wgmma loops that can cause a 1000x slow down in
// compilation time. Most unrolling has already been done before PTX;
@@ -60,10 +58,10 @@
}
};
-std::unique_ptr<mlir::Pass> xla::gpu::CreatePreventMmaV3LoopUnrollingPass() {
+} // namespace
+
+std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass() {
return std::make_unique<PreventMmaV3LoopUnrollingPass>();
}
-void xla::gpu::RegisterPreventMmaV3LoopUnrollingPass() {
- registerPass(CreatePreventMmaV3LoopUnrollingPass);
-}
+} // namespace xla::gpu
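Condensed, the migration above (repeated for the sparse passes in sparse_extensions.cc below) swaps a hand-rolled PassWrapper for the tablegen-generated base class; a sketch of the resulting shape, with the pass body elided:

// Pattern: one GEN_PASS_DEF_* macro per pass defined in this file, then the
// include of the generated header, then a struct deriving from impl::*Base.
#define GEN_PASS_DEF_PREVENTMMAV3LOOPUNROLLINGPASS
#include "xla/service/gpu/fusions/triton/passes.h.inc"

struct PreventMmaV3LoopUnrollingPass
    : public impl::PreventMmaV3LoopUnrollingPassBase<
          PreventMmaV3LoopUnrollingPass> {
  void runOnOperation() override {
    // ... unchanged rewriting logic from the diff above ...
  }
};

// Only the factory is exported; registration comes from GEN_PASS_REGISTRATION
// in passes.h.
std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass() {
  return std::make_unique<PreventMmaV3LoopUnrollingPass>();
}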
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h
deleted file mode 100644
index f8e1af0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_
-#define XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_
-
-#include <memory>
-
-#include "mlir/Pass/Pass.h"
-
-namespace xla::gpu {
-
-// This pass is a result of b/344841434:
-// PTX sometimes unrolls wgmma loops that can cause a 1000x slow down in
-// compilation time. Most unrolling has already been done before PTX,
-// this pragma prevents ptxas from doing more.
-std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass();
-
-void RegisterPreventMmaV3LoopUnrollingPass();
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc
index bfc9d7a..5b2d331 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc
@@ -13,8 +13,6 @@
limitations under the License.
==============================================================================*/
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
-
#include <algorithm>
#include <cassert>
#include <cstdint>
@@ -48,9 +46,9 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
@@ -65,6 +63,13 @@
using namespace mlir; // NOLINT(build/namespaces)
+namespace ttn = triton::nvgpu;
+using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
+using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getShapePerCTATile;
+using ::mlir::triton::gpu::SharedEncodingAttr;
+using ttn::OperandsAndConstraints;
+
// The functions below are defined in AccelerateMatmul.cpp.
namespace mlir::triton::gpu {
SmallVector<unsigned, 3> getWarpsPerTile(
@@ -80,13 +85,22 @@
int64_t swizzling, uint32_t stride);
int64_t getSwizzlingFromLayout(const triton::gpu::SharedEncodingAttr &layout,
uint32_t widthInByte);
-triton::nvgpu::WGMMAEltType getMmaRetType(Value);
-triton::nvgpu::WGMMAEltType getMmaOperandType(Value, bool);
+ttn::WGMMAEltType getMmaRetType(Value);
+ttn::WGMMAEltType getMmaOperandType(Value, bool);
+namespace xla::gpu {
namespace {
+#define GEN_PASS_DEF_SPARSEADDENCODINGPASS
+#define GEN_PASS_DEF_SPARSEBLOCKEDTOMMAPASS
+#define GEN_PASS_DEF_SPARSEDOTOPTOLLVMPASS
+#define GEN_PASS_DEF_SPARSELOCALLOADTOLLVMPASS
+#define GEN_PASS_DEF_SPARSEREMOVELAYOUTCONVERSIONPASS
+#define GEN_PASS_DEF_SPARSEWGMMAOPTOLLVMPASS
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
// Add sparse encoding for all the arguments of a SparseDotOp.
-struct AddSparseEncoding
+struct SparseAddEncoding
: public OpConversionPattern<triton::gpu::SparseDotOp> {
using OpConversionPattern<triton::gpu::SparseDotOp>::OpConversionPattern;
@@ -179,29 +193,16 @@
}
};
-class AddSparseEncodingPass
- : public PassWrapper<AddSparseEncodingPass, OperationPass<ModuleOp>> {
- public:
- AddSparseEncodingPass() = default;
- AddSparseEncodingPass(int32_t num_warps, int32_t threads_per_warp,
- int32_t num_ctas) {
- num_warps_ = num_warps;
- threads_per_warp_ = threads_per_warp;
- num_ctas_ = num_ctas;
- }
- AddSparseEncodingPass(const AddSparseEncodingPass &other) {
- num_warps_ = other.num_warps_;
- threads_per_warp_ = other.threads_per_warp_;
- num_ctas_ = other.num_ctas_;
- };
-
- StringRef getArgument() const override { return "add-sparse-encoding"; }
+struct SparseAddEncodingPass
+ : public impl::SparseAddEncodingPassBase<SparseAddEncodingPass> {
+ using impl::SparseAddEncodingPassBase<
+ SparseAddEncodingPass>::SparseAddEncodingPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
TritonGPUTypeConverter type_converter(context, num_warps_,
threads_per_warp_, num_ctas_);
- auto pattern = std::make_unique<AddSparseEncoding>(type_converter, context);
+ auto pattern = std::make_unique<SparseAddEncoding>(type_converter, context);
RewritePatternSet patterns(context, std::move(pattern));
TritonGPUConversionTarget target(*context, type_converter);
target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
@@ -212,18 +213,6 @@
std::move(patterns))))
return signalPassFailure();
}
-
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSparseEncodingPass)
-
- private:
- Option<int32_t> num_warps_{
- *this, "num-warps", llvm::cl::desc("number of warps"), llvm::cl::init(4)};
- Option<int32_t> threads_per_warp_{
- *this, "threads-per-warp", llvm::cl::desc("number of threads per warp"),
- llvm::cl::init(32)};
- Option<int32_t> num_ctas_{*this, "num-ctas",
- llvm::cl::desc("number of ctas in a cga"),
- llvm::cl::init(1)};
};
// Add convert layouts to and from MMA before and after SparseDotOp. In MMAV3,
@@ -332,13 +321,8 @@
int compute_capability_;
};
-class SparseBlockedToMMAPass
- : public PassWrapper<SparseBlockedToMMAPass, OperationPass<ModuleOp>> {
- public:
- SparseBlockedToMMAPass() = default;
-
- StringRef getArgument() const override { return "sparse-blocked-to-mma"; }
-
+struct SparseBlockedToMMAPass
+ : public impl::SparseBlockedToMMAPassBase<SparseBlockedToMMAPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
@@ -350,8 +334,6 @@
return signalPassFailure();
}
}
-
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseBlockedToMMAPass)
};
class SparseLocalLoadToLLVM
@@ -469,16 +451,9 @@
}
};
-class SparseRemoveLayoutConversionPass
- : public PassWrapper<SparseRemoveLayoutConversionPass,
- OperationPass<ModuleOp>> {
- public:
- SparseRemoveLayoutConversionPass() = default;
-
- StringRef getArgument() const override {
- return "sparse-remove-layout-conversion";
- }
-
+struct SparseRemoveLayoutConversionPass
+ : public impl::SparseRemoveLayoutConversionPassBase<
+ SparseRemoveLayoutConversionPass> {
void runOnOperation() override {
getOperation().walk([&](triton::gpu::ConvertLayoutOp op) {
ImplicitLocOpBuilder builder(op.getLoc(), op);
@@ -507,35 +482,22 @@
op.erase();
});
}
-
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseRemoveLayoutConversionPass)
};
-class SparseLocalLoadToLLVMPass
- : public PassWrapper<SparseLocalLoadToLLVMPass, OperationPass<ModuleOp>> {
- public:
- SparseLocalLoadToLLVMPass() = default;
+bool IsLocalLoadWithSparseEncoding(Operation *op) {
+ auto local_load = mlir::dyn_cast<triton::gpu::LocalLoadOp>(op);
+ if (!local_load) return false;
+ return isa<triton::gpu::SparseDotMetaEncodingAttr>(
+ local_load.getType().getEncoding());
+}
- StringRef getArgument() const override { return "sparse-local-load-to-llvm"; }
-
- void getDependentDialects(mlir::DialectRegistry &registry) const override {
- registry.insert<LLVM::LLVMDialect, mlir::gpu::GPUDialect,
- arith::ArithDialect>();
- }
-
+struct SparseLocalLoadToLLVMPass
+ : public impl::SparseLocalLoadToLLVMPassBase<SparseLocalLoadToLLVMPass> {
void runOnOperation() override {
// Exit early if there are no sparse ops.
- mlir::ModuleOp mod = getOperation();
- if (!mod.walk([](triton::gpu::LocalLoadOp op) {
- if (isa<triton::gpu::SparseDotMetaEncodingAttr>(
- op.getType().getEncoding())) {
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- })
- .wasInterrupted()) {
- return;
- }
+ ModuleOp mod = getOperation();
+ if (!ContainsOp(mod, IsLocalLoadWithSparseEncoding)) return;
+
// Allocate shared memory and set barrier
// This is also done in the TritonGPUToLLVMPass but we need to do it before
// we write the local load op to LLVM to have barriers in the right place.
@@ -553,7 +515,7 @@
return !isa<triton::gpu::SparseDotMetaEncodingAttr>(
op.getType().getEncoding());
});
- mlir::LowerToLLVMOptions option(context);
+ LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
auto pattern = std::make_unique<SparseLocalLoadToLLVM>(typeConverter);
RewritePatternSet patterns(context, std::move(pattern));
@@ -562,15 +524,8 @@
return signalPassFailure();
}
}
-
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass)
};
-using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
-using ::mlir::triton::gpu::SharedEncodingAttr;
-
using ValueTableV2 = std::map<std::pair<unsigned, unsigned>, Value>;
constexpr int kContractingFactor = 2; // implied by N:M (2:4)
@@ -818,17 +773,17 @@
assert(hMetaPacked.size() == repM * repK);
// Generate prologue.
- triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(op.getA(), false);
- triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(op.getB(), false);
- triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(op.getD());
+ ttn::WGMMAEltType eltTypeA = getMmaOperandType(op.getA(), false);
+ ttn::WGMMAEltType eltTypeB = getMmaOperandType(op.getB(), false);
+ ttn::WGMMAEltType eltTypeC = getMmaRetType(op.getD());
- triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col
- : triton::nvgpu::WGMMALayout::row;
- triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row
- : triton::nvgpu::WGMMALayout::col;
+ ttn::WGMMALayout layoutA =
+ transA ? ttn::WGMMALayout::col : ttn::WGMMALayout::row;
+ ttn::WGMMALayout layoutB =
+ transB ? ttn::WGMMALayout::row : ttn::WGMMALayout::col;
- rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
- rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
+ rewriter.create<ttn::FenceAsyncSharedOp>(loc, 0);
+ rewriter.create<ttn::WGMMAFenceOp>(loc);
// Generate main loop.
for (int m = 0; m < repM; ++m) {
@@ -841,7 +796,7 @@
Value a = loadA(m, k);
Value b = loadB(n, k);
Value meta = hMetaPacked[k * repM + m];
- d = rewriter.create<triton::nvgpu::SparseWGMMAOp>(
+ d = rewriter.create<ttn::SparseWGMMAOp>(
loc, accTy, a, meta, b, d, kWarpsInGroup * instrShape[0],
instrShape[1], kContractingFactor * instrShape[2], eltTypeC,
eltTypeA, eltTypeB, layoutA, layoutB);
@@ -858,8 +813,8 @@
op.getContext(), SmallVector<Type>(fc.size(), f32_ty));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);
- rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
- res = rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, res, 0);
+ rewriter.create<ttn::WGMMACommitGroupOp>(loc);
+ res = rewriter.create<ttn::WGMMAWaitGroupOp>(loc, res, 0);
rewriter.replaceOp(op, res);
return success();
@@ -899,26 +854,16 @@
}
};
-class SparseDotOpToLLVMPass
- : public PassWrapper<SparseDotOpToLLVMPass, OperationPass<ModuleOp>> {
- public:
- SparseDotOpToLLVMPass() = default;
-
- StringRef getArgument() const override { return "sparse-dot-to-llvm"; }
-
- void getDependentDialects(mlir::DialectRegistry &registry) const override {
- registry.insert<LLVM::LLVMDialect, mlir::gpu::GPUDialect,
- arith::ArithDialect>();
- }
-
+struct SparseDotOpToLLVMPass
+ : public impl::SparseDotOpToLLVMPassBase<SparseDotOpToLLVMPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect, NVVM::NVVMDialect,
- arith::ArithDialect, triton::nvgpu::NVGPUDialect>();
+ arith::ArithDialect, ttn::NVGPUDialect>();
target.addIllegalOp<triton::gpu::SparseDotOp>();
target.addIllegalDialect<mlir::gpu::GPUDialect>();
- mlir::LowerToLLVMOptions option(context);
+ LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
RewritePatternSet patterns(context);
patterns.add<SparseDotOpConversion>(typeConverter);
@@ -928,13 +873,8 @@
return signalPassFailure();
}
}
-
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass)
};
-namespace ttn = mlir::triton::nvgpu;
-using ttn::OperandsAndConstraints;
-
class SparseWGMMAOpPattern : public OpRewritePattern<ttn::SparseWGMMAOp> {
public:
using OpRewritePattern<ttn::SparseWGMMAOp>::OpRewritePattern;
@@ -1021,13 +961,8 @@
}
};
-class SparseWGMMAOpToLLVMPass
- : public PassWrapper<SparseWGMMAOpToLLVMPass, OperationPass<ModuleOp>> {
- public:
- SparseWGMMAOpToLLVMPass() = default;
-
- StringRef getArgument() const override { return "sparse-wgmma-to-llvm"; }
-
+struct SparseWGMMAOpToLLVMPass
+ : public impl::SparseWGMMAOpToLLVMPassBase<SparseWGMMAOpToLLVMPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto pattern = std::make_unique<SparseWGMMAOpPattern>(context);
@@ -1037,43 +972,38 @@
return signalPassFailure();
}
}
-
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass)
};
} // namespace
-std::unique_ptr<Pass> xla::gpu::CreateAddSparseDotEncodingPass(
- int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas) {
- return std::make_unique<AddSparseEncodingPass>(num_warps, threads_per_warp,
- num_ctas);
+std::unique_ptr<Pass> CreateSparseAddEncodingPass(int32_t num_warps,
+ int32_t threads_per_warp,
+ int32_t num_ctas) {
+ SparseAddEncodingPassOptions options;
+ options.num_warps_ = num_warps;
+ options.threads_per_warp_ = threads_per_warp;
+ options.num_ctas_ = num_ctas;
+ return std::make_unique<SparseAddEncodingPass>(options);
}
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseBlockedToMMAPass() {
+std::unique_ptr<Pass> CreateSparseBlockedToMMAPass() {
return std::make_unique<SparseBlockedToMMAPass>();
}
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseRemoveLayoutConversionPass() {
+std::unique_ptr<Pass> CreateSparseRemoveLayoutConversionPass() {
return std::make_unique<SparseRemoveLayoutConversionPass>();
}
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseLocalLoadToLLVMPass() {
+std::unique_ptr<Pass> CreateSparseLocalLoadToLLVMPass() {
return std::make_unique<SparseLocalLoadToLLVMPass>();
}
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseDotOpToLLVMPass() {
+std::unique_ptr<Pass> CreateSparseDotOpToLLVMPass() {
return std::make_unique<SparseDotOpToLLVMPass>();
}
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseWGMMAOpToLLVMPass() {
+std::unique_ptr<Pass> CreateSparseWGMMAOpToLLVMPass() {
return std::make_unique<SparseWGMMAOpToLLVMPass>();
}
-void xla::gpu::RegisterSparsePasses() {
- registerPass([] { return std::make_unique<AddSparseEncodingPass>(); });
- registerPass(CreateSparseBlockedToMMAPass);
- registerPass(CreateSparseRemoveLayoutConversionPass);
- registerPass(CreateSparseLocalLoadToLLVMPass);
- registerPass(CreateSparseDotOpToLLVMPass);
- registerPass(CreateSparseWGMMAOpToLLVMPass);
-}
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h
deleted file mode 100644
index 988a63d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_
-#define XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "mlir/Pass/Pass.h"
-
-namespace xla::gpu {
-
-std::unique_ptr<mlir::Pass> CreateAddSparseDotEncodingPass(
- int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas);
-std::unique_ptr<mlir::Pass> CreateSparseBlockedToMMAPass();
-std::unique_ptr<mlir::Pass> CreateSparseRemoveLayoutConversionPass();
-std::unique_ptr<mlir::Pass> CreateSparseLocalLoadToLLVMPass();
-std::unique_ptr<mlir::Pass> CreateSparseDotOpToLLVMPass();
-std::unique_ptr<mlir::Pass> CreateSparseWGMMAOpToLLVMPass();
-
-void RegisterSparsePasses();
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
index 20fa19c..e6c0374 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -110,9 +110,9 @@
#include "xla/service/algorithm_util.h"
#include "xla/service/dump.h"
#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
@@ -1507,7 +1507,7 @@
}
}
}
- CHECK(to_order.insert(current).second);
+ to_order.insert(current);
to_add.pop();
}
}
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h
index fe133d8..3f7c3bc 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h
@@ -17,7 +17,6 @@
#define XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_FUSION_EMITTER_H_
#include <cstdint>
-#include <functional>
#include <optional>
#include <string>
@@ -87,19 +86,6 @@
mlir::triton::FuncOp fn,
const BlockLevelParameters& block_level_parameters);
-// Generate Softmax in Triton IR inside 'fn'.
-// Use execution parameters from 'block_level_parameters'.
-absl::Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path,
- const se::DeviceDescription& device_info,
- const HloFusionInstruction* fusion,
- mlir::triton::FuncOp fn,
- const BlockLevelParameters& block_level_parameters);
-
-using TritonIrEmitter = std::function<absl::Status(
- mlir::OpBuilder, absl::string_view, const se::DeviceDescription&,
- const HloFusionInstruction*, mlir::triton::FuncOp,
- const BlockLevelParameters&)>;
-
// Load the MLIR dialects required for Triton IR generation.
void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context);
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
index c162c34..720a8b4 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
@@ -1549,6 +1549,35 @@
kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
}
+TEST_F(TritonGemmTest, MultiplePathsToSameOperandWorks) {
+ const std::string kHloText = R"(
+triton_computation {
+ p0 = bf16[8192,512]{1,0} parameter(0)
+ p1 = bf16[512,512]{1,0} parameter(1)
+ dot = bf16[8192,512]{1,0} dot(bf16[8192,512]{1,0} p0, bf16[512,512]{1,0} p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ p2 = bf16[8192,512]{1,0} parameter(2)
+ multiply.1 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} dot, bf16[8192,512]{1,0} p2)
+ ROOT multiply.2 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} multiply.1, bf16[8192,512]{1,0} p2)
+}
+
+ENTRY e {
+ p0 = bf16[8192,512]{1,0} parameter(0)
+ p1 = bf16[512,512]{1,0} parameter(1)
+ p2 = bf16[8192,512]{1,0} parameter(2)
+ ROOT fusion = bf16[8192,512]{1,0} fusion(p0,p1,p2), kind=kCustom, calls=triton_computation,
+ backend_config={"fusion_backend_config":
+ {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"256","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}}
+})";
+
+ TF_ASSERT_OK(
+ CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_computation", R"(
+ CHECK: tt.dot
+ CHECK-SAME: tensor<64x32xbf16> * tensor<32x256xbf16> -> tensor<64x256xf32>
+ CHECK: arith.mulf
+ CHECK: arith.mulf
+ )"));
+}
+
class TritonGemmDynamicSliceClampingTest
: public TritonTest,
public ::testing::WithParamInterface<int> {};
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc
index 8f07bba..d59b6cc 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc
@@ -214,7 +214,7 @@
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
FromOutputTileSizes({1, 127}),
"triton_softmax_computation", R"(
-CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
CHECK: %[[PID:.*]] = tt.get_program_id x : i32
CHECK: arith.index_castui %[[PID]] : i32 to index
@@ -272,7 +272,7 @@
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
FromOutputTileSizes({1, 127}),
"triton_softmax_computation", R"(
-CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
CHECK-LABEL: tt.func @triton_fn(
CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr<f32>
CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr<f32>
@@ -339,7 +339,7 @@
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
FromOutputTileSizes({1, 1, 127}),
"triton_softmax_computation", R"(
-CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32
CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index
@@ -517,7 +517,7 @@
TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText,
FromOutputTileSizes({1, 1, 16}),
"triton_softmax_computation", R"(
-// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 16)
// CHECK-LABEL: tt.func @triton_fn(
// CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr<f32>
// CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr<f32>
@@ -674,7 +674,7 @@
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
FromOutputTileSizes({1, 127}),
"triton_softmax_computation", R"(
-CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
CHECK: %[[PID:.*]] = tt.get_program_id x : i32
CHECK: arith.index_castui %[[PID]] : i32 to index
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
index f946cc4..7ae8f5e 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
@@ -52,13 +52,6 @@
class MixedTypeTest : public GpuCodegenTest,
public ::testing::WithParamInterface<MixTypeParams> {
public:
- se::CudaComputeCapability GetCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
-
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
// We are testing Triton, remove cuBLAS fallback for these tests.
@@ -803,13 +796,6 @@
debug_options.clear_xla_disable_hlo_passes();
return debug_options;
}
-
- se::CudaComputeCapability GetCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
};
TEST_P(TritonSoftmaxTest, CanFuseAndEmitExactSoftmax) {
diff --git a/third_party/xla/xla/service/gpu/fusions/triton_test.cc b/third_party/xla/xla/service/gpu/fusions/triton_test.cc
index c2cfabf..1738d6f 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton_test.cc
@@ -64,7 +64,7 @@
TestGpuDeviceInfo::RTXA6000DeviceInfo();
auto* root = module->entry_computation()->root_instruction();
- HloFusionAnalysis analysis = AnalyzeFusion(*root, device_info);
+ HloFusionAnalysis analysis = HloFusionAnalysis::Create(*root, device_info);
std::unique_ptr<FusionInterface> emitter =
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
@@ -100,7 +100,7 @@
TestGpuDeviceInfo::RTXA6000DeviceInfo();
auto* root = module->entry_computation()->root_instruction();
- HloFusionAnalysis analysis = AnalyzeFusion(*root, device_info);
+ HloFusionAnalysis analysis = HloFusionAnalysis::Create(*root, device_info);
std::unique_ptr<FusionInterface> emitter =
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
deleted file mode 100644
index a2de14c..0000000
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
+++ /dev/null
@@ -1,500 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/types/span.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/gpu/variant_visitor.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/tsl/util/proto/proto_utils.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/profiler/lib/scoped_annotation.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using se::gpu::BlasLt;
-
-absl::StatusOr<BlasLt::Epilogue> AsBlasLtEpilogue(
- GemmBackendConfig_Epilogue epilogue) {
- switch (epilogue) {
- case GemmBackendConfig::DEFAULT:
- return BlasLt::Epilogue::kDefault;
- case GemmBackendConfig::RELU:
- return BlasLt::Epilogue::kReLU;
- case GemmBackendConfig::GELU:
- return BlasLt::Epilogue::kGELU;
- case GemmBackendConfig::GELU_AUX:
- return BlasLt::Epilogue::kGELUWithAux;
- case GemmBackendConfig::BIAS:
- return BlasLt::Epilogue::kBias;
- case GemmBackendConfig::BIAS_RELU:
- return BlasLt::Epilogue::kBiasThenReLU;
- case GemmBackendConfig::BIAS_GELU:
- return BlasLt::Epilogue::kBiasThenGELU;
- case GemmBackendConfig::BIAS_GELU_AUX:
- return BlasLt::Epilogue::kBiasThenGELUWithAux;
- default:
- return Internal("Unsupported Epilogue.");
- }
-}
-
-class GemmAutotuner {
- const AutotuneConfig& autotune_config_;
- RedzoneBuffers rz_buffers_;
- se::Stream* stream_ = nullptr;
- bool deterministic_ops_ = false;
- size_t solutions_limit_ = 0;
- size_t num_algorithms_left_ = 0;
-
- public:
- explicit GemmAutotuner(const AutotuneConfig& autotune_config)
- : autotune_config_(autotune_config) {}
-
- size_t num_algorithms_left() const { return num_algorithms_left_; }
-
- absl::StatusOr<AutotuneResult> operator()(const HloInstruction* gemm,
- const AutotuneCacheKey& key) {
- num_algorithms_left_ = 0;
- if (autotune_config_.IsDeviceless()) {
- // Return empty result, will tune at runtime.
- return AutotuneResult{};
- }
- VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
-
- TF_ASSIGN_OR_RETURN(stream_, autotune_config_.GetStream());
- const DebugOptions& debug_options =
- gemm->GetModule()->config().debug_options();
- deterministic_ops_ = RequireDeterminism(gemm->GetModule()->config());
- solutions_limit_ = debug_options.xla_gpu_autotune_max_solutions();
-
- TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm));
-
- // Don't run autotuning concurrently on the same GPU.
- absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent()));
-
- TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromInstruction(
- *gemm, autotune_config_, debug_options,
- RedzoneBuffers::kAllInputsAllOutputs));
-
- return IsCublasLtMatmul(*gemm) || IsCublasLtMatmulF8(*gemm)
- ? TuneGpuBlasLt(gemm, gemm_config)
- : TuneGpuBlas(gemm, gemm_config);
- }
-
- private:
- se::DeviceMemoryBase LhsBuffer() { return rz_buffers_.input_buffers().at(0); }
- se::DeviceMemoryBase RhsBuffer() { return rz_buffers_.input_buffers().at(1); }
- se::DeviceMemoryBase OutputBuffer() {
- return rz_buffers_.output_buffers().at(0);
- }
-
- const Shape& GetOutputShape(const HloInstruction* gemm) {
- return gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0)
- : gemm->shape();
- }
-
- absl::StatusOr<AutotuneResult> TuneGpuBlasLt(const HloInstruction* gemm,
- const GemmConfig& gemm_config) {
- auto workspace_buffer =
- rz_buffers_.output_buffers().at(gemm->shape().tuple_shapes_size() - 1);
-
- GpuBackendConfig gpu_config =
- gemm->backend_config<GpuBackendConfig>().value();
- const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config();
-
- bool has_matrix_bias = gemm_config.beta != 0.;
-
- TF_ASSIGN_OR_RETURN(
- bool has_vector_bias,
- gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue()));
-
- TF_ASSIGN_OR_RETURN(
- bool has_aux_output,
- gpublas_lt::EpilogueHasAuxiliaryOutput(backend_config.epilogue()));
-
- TF_ASSIGN_OR_RETURN(auto epilogue,
- AsBlasLtEpilogue(backend_config.epilogue()));
-
- se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer,
- d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer;
-
- if (has_vector_bias) {
- bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2);
- }
- if (has_aux_output) {
- aux_buffer = rz_buffers_.output_buffers().at(1);
- }
-
- TF_ASSIGN_OR_RETURN(auto plan,
- BlasLt::GetMatmulPlan(stream_, gemm_config, epilogue));
-
- TF_ASSIGN_OR_RETURN(
- auto algorithms,
- plan->GetAlgorithms(/*max_algorithm_count*/ 128,
- /*max_workspace_size*/ workspace_buffer.size()));
-
- auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm)
- -> absl::StatusOr<se::blas::ProfileResult> {
- // Run a warmup iteration without the profiler active.
- TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
- stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
- bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
- c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
- workspace_buffer));
- se::blas::ProfileResult profile_result;
- profile_result.set_warmup_run_executed(true);
- TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
- stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
- bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
- c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
- workspace_buffer, &profile_result));
- return std::move(profile_result);
- };
-
- return GetBestAlgorithm<BlasLt::MatmulAlgorithm>(
- gemm, algorithms, gemm_config.beta, /*return_algo_index*/ true,
- tuned_func);
- }
-
- absl::StatusOr<AutotuneResult> TuneGpuBlas(const HloInstruction* gemm,
- const GemmConfig& gemm_config) {
- auto workspace_buffer = rz_buffers_.output_buffers().at(1);
-
- std::vector<se::blas::AlgorithmType> algorithms;
- TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc,
- gemm_config.GetMatrixDescriptors(
- LhsBuffer(), RhsBuffer(), OutputBuffer()));
-
- auto blas = stream_->parent()->AsBlas();
- if (blas == nullptr) {
- return absl::InternalError("No BLAS support for stream");
- }
- blas->GetBlasGemmAlgorithms(stream_, desc.lhs, desc.rhs, &desc.output,
- &gemm_config.alpha, &gemm_config.beta,
- &algorithms);
-
- auto tuned_func = [&](const se::blas::AlgorithmType& algorithm)
- -> absl::StatusOr<se::blas::ProfileResult> {
- // Do a warm-up run first, without a profile result. RunGemm swallows
- // error codes when profile_result is passed, as it is in the measurement
- // below, but not otherwise. It is, therefore, consistent to ignore the
- // error code here.
- static_cast<void>(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
- OutputBuffer(), workspace_buffer,
- deterministic_ops_, stream_, algorithm));
- se::blas::ProfileResult profile_result;
- // Allow GpuTimer to use its delay kernel implementation to improve
- // accuracy.
- profile_result.set_warmup_run_executed(true);
- // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
- // for all algorithms if we're targeting < sm_50. But because we pass a
- // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
- // and the actual success-ness is returned in ProfileResult::is_valid.
- TF_RETURN_IF_ERROR(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
- OutputBuffer(), workspace_buffer,
- deterministic_ops_, stream_, algorithm,
- &profile_result));
- return std::move(profile_result);
- };
-
- return GetBestAlgorithm<se::blas::AlgorithmType>(
- gemm, algorithms, gemm_config.beta, /*return_algo_index=*/false,
- tuned_func);
- }
-
- // Returns the index (into `algorithms`) of the fastest algorithm.
- template <typename AlgoT, typename TunedFunc>
- absl::StatusOr<AutotuneResult> GetBestAlgorithm(
- const HloInstruction* gemm, absl::Span<const AlgoT> algorithms,
- double beta, bool return_algo_index, TunedFunc&& run_benchmark) {
- static_assert(std::is_invocable_r_v<absl::StatusOr<se::blas::ProfileResult>,
- TunedFunc, const AlgoT&>,
- "Tuned function has incorrect prototype!");
-
- if (!stream_->parent()->SynchronizeAllActivity()) {
- return Internal("Failed to synchronize GPU for autotuning.");
- }
- tsl::profiler::ScopedAnnotation annotation([&] {
- return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
- gemm->name());
- });
-
- auto& hlo_module_config = gemm->GetModule()->mutable_config();
- const auto& output_shape = GetOutputShape(gemm);
-
- se::DeviceMemoryBase reference_buffer;
- if (autotune_config_.should_check_correctness()) {
- TF_ASSIGN_OR_RETURN(reference_buffer,
- rz_buffers_.RedzoneAllocator().AllocateBytes(
- ShapeUtil::ByteSizeOf(output_shape)));
- }
-
- // Do not print error messages if should_skip_wrong_results() is ON.
- BufferComparator comparator(
- output_shape,
- hlo_module_config.debug_options().xla_gpu_autotune_gemm_rtol(),
- /* verbose */ !autotune_config_.should_skip_wrong_results());
- std::vector<AutotuneResult> results;
- results.reserve(algorithms.size());
- std::optional<int64_t> reference_algorithm;
-
- auto num = algorithms.size();
- if (solutions_limit_ > 0) num = std::min(num, solutions_limit_);
- for (size_t i = 0; i < num; i++) {
- const AlgoT& algorithm = algorithms[i];
- // Make sure the output buffer always has the same value if we use
- // the bias parameter.
- if (autotune_config_.should_reinit_output_buffer() && beta != 0) {
- int64_t rng_state = 0;
- InitializeBuffer(stream_, output_shape.element_type(), &rng_state,
- OutputBuffer());
- }
- TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm));
-
- AutotuneResult& result = results.emplace_back();
- result.mutable_gemm()->set_algorithm(profile_result.algorithm());
-
- if (!profile_result.is_valid()) { // Unsupported algorithm.
- result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
- continue;
- }
-
- VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took "
- << profile_result.elapsed_time_in_ms() << "ms";
-
- *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
- absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-
- if (!autotune_config_.should_check_correctness()) {
- num_algorithms_left_++;
- continue;
- }
- TF_ASSIGN_OR_RETURN(
- se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
- rz_buffers_.RedzoneAllocator().CheckRedzones());
-
- if (!rz_check_status.ok()) {
- result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
- *result.mutable_failure()->mutable_msg() =
- rz_check_status.RedzoneFailureMsg();
- LOG(ERROR) << "Detected out-of-bounds write in gemm buffer";
- CHECK(!autotune_config_.should_crash_on_check_failure());
- continue;
- }
-
- num_algorithms_left_++;
- if (!reference_algorithm) {
- TF_RETURN_IF_ERROR(stream_->Memcpy(&reference_buffer, OutputBuffer(),
- OutputBuffer().size()));
- reference_algorithm = profile_result.algorithm();
- continue;
- }
- // Perform the comparison versus the reference algorithm.
- TF_ASSIGN_OR_RETURN(
- bool outputs_match,
- comparator.CompareEqual(stream_, /*current=*/OutputBuffer(),
- /*expected=*/reference_buffer));
- if (!outputs_match) {
- LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
- << "This is likely a bug/unexpected loss of precision.";
- CHECK(!autotune_config_.should_crash_on_check_failure());
-
- // By default, the autotuner does NOT actually skip wrong results, but
- // merely prints the above error message, which can be confusing. When
- // should_skip_wrong_results() is set to true, solutions with accuracy
- // problems are disqualified.
- auto kind = AutotuneResult::WRONG_RESULT;
- if (autotune_config_.should_skip_wrong_results()) {
- kind = AutotuneResult::DISQUALIFIED;
- num_algorithms_left_--; // Decrement again since we disqualified it.
- }
- result.mutable_failure()->set_kind(kind);
- result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
- *reference_algorithm);
- }
- } // for algorithms
-
- absl::StatusOr<AutotuneResult> best =
- PickBestResult(results, gemm->ToString(), hlo_module_config);
- if (best.ok()) {
- // Note that cuBLASLt returns an opaque object as an algorithm ID, so we
- // need to convert it to an index into the algorithms list (otherwise we
- // cannot store this ID inside a gemm_backend_config). In contrast,
- // legacy cuBLAS returns a 32-bit integer algorithm ID which can be
- // stored inside an HLO directly (hence return_algo_index is false for
- // the cuBLAS case).
- if (!return_algo_index) return best;
- // Otherwise, map a real algorithm ID to its index among the results.
- for (size_t i = 0; i < results.size(); ++i) {
- if (best->gemm().algorithm() == results[i].gemm().algorithm()) {
- best->mutable_gemm()->set_algorithm(i);
- return best;
- }
- }
- return Internal("unknown best algorithm");
- }
- LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance "
- "might be suboptimal: "
- << best.status();
- return AutotuneResult{};
- } // GetBestAlgorithm
-}; // GemmAutotuner
-
-// Autotunes a single GEMM instruction. The result is taken from the autotune
-// cache when available; otherwise the autotuner is invoked.
-absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
- const AutotuneConfig& config,
- size_t* num_algorithms_left) {
- VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString();
-
- GpuBackendConfig gpu_config =
- gemm->backend_config<GpuBackendConfig>().value();
- GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config();
-
- *num_algorithms_left = 0;
- // Degenerate gemms are replaced with a memzero operation; no need to autotune them.
- if (backend_config.alpha_real() == 0.0 &&
- backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) {
- VLOG(3) << "Skip degenerate gemm instruction auto tuning";
- return false;
- }
-
- AutotuneCacheKey key(config.GetModelStr(), *gemm);
- GemmAutotuner autotuner(config);
- TF_ASSIGN_OR_RETURN(AutotuneResult algorithm,
- AutotunerUtil::Autotune(
- gemm, config, [&] { return autotuner(gemm, key); }));
-
- *num_algorithms_left = autotuner.num_algorithms_left();
- auto old_algorithm = backend_config.selected_algorithm();
- bool update_algorithm =
- IsCublasLtMatmulF8(*gemm) ||
- std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) {
- // We only set the 'algorithm' field on
- // non-Ampere architectures, as for Ampere
- // it's ignored in any case.
- return !cc.IsAtLeast(
- se::CudaComputeCapability::AMPERE);
- },
- [](const se::RocmComputeCapability&) {
- return true; // TODO: not decided yet
- }},
- config.GetGpuComputeCapability());
-
- if (update_algorithm) {
- int64_t new_algorithm{};
- if (algorithm.has_gemm()) {
- new_algorithm = algorithm.gemm().algorithm();
- } else {
- // NOTE: runtime autotuning is no longer available => set to default
- new_algorithm = se::blas::kDefaultAlgorithm;
- }
-
- if (new_algorithm == old_algorithm &&
- backend_config.has_selected_algorithm()) {
- // We don't need to update the backend config if the algorithm hasn't
- // changed, unless the algorithm wasn't previously set explicitly.
- return false;
- }
-
- backend_config.set_selected_algorithm(new_algorithm);
- TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config));
- return true; // We changed `gemm`
- }
-
- return false; // No change to `gemm`
-}
-
-absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
- AutotuneConfig config,
- size_t* num_algorithms_left) {
- bool changed = false;
-
- for (HloInstruction* instr : computation->instructions()) {
- if (IsCublasGemm(*instr)) {
- size_t num_left;
- TF_ASSIGN_OR_RETURN(bool result,
- RunOnInstruction(instr, config, &num_left));
- // Gathering statistics on the algorithms left after tuning (for testing)
- *num_algorithms_left = std::max(*num_algorithms_left, num_left);
- changed |= result;
- }
- }
- return changed;
-}
-
-} // namespace
-
-absl::StatusOr<bool> GemmAlgorithmPicker::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_SCOPED_LOGGING_TIMER(
- absl::StrCat("GemmAlgorithmPicker for ", module->name()));
-
- num_algorithms_left_ = 0;
- if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
- VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
- return false;
- }
-
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_,
- &num_algorithms_left_));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
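Both tuned_func lambdas above (cuBLASLt and legacy cuBLAS) feed GetBestAlgorithm with the same pattern: warm up each candidate algorithm, run it again under measurement, disqualify invalid candidates, and keep the fastest result. Below is a minimal standalone sketch of that selection loop; the Candidate and PickFastest names are invented for illustration only, std::chrono stands in for the GPU profiler, and the redzone and reference-output checks of the real code are omitted.

// Standalone sketch of the warmup-then-measure selection loop used by
// GetBestAlgorithm above. Nothing here is an XLA/StreamExecutor API; the real
// code times GPU execution via ProfileResult instead of std::chrono.
#include <chrono>
#include <functional>
#include <iostream>
#include <optional>
#include <vector>

struct Candidate {
  int id;
  std::function<bool()> run;  // Returns false if the algorithm is unsupported.
};

// Runs each candidate once as a warmup, then once under measurement, and
// returns the id of the fastest candidate that ran successfully.
std::optional<int> PickFastest(const std::vector<Candidate>& candidates) {
  std::optional<int> best;
  double best_ms = 0.0;
  for (const Candidate& c : candidates) {
    if (!c.run()) continue;  // Warmup run; disqualify unsupported candidates.
    auto start = std::chrono::steady_clock::now();
    if (!c.run()) continue;  // Measured run.
    std::chrono::duration<double, std::milli> elapsed =
        std::chrono::steady_clock::now() - start;
    if (!best || elapsed.count() < best_ms) {
      best = c.id;
      best_ms = elapsed.count();
    }
  }
  return best;
}

int main() {
  std::vector<Candidate> candidates = {
      {0, [] {
         volatile long sink = 0;
         for (int i = 0; i < 1000000; ++i) sink = sink + i;
         return true;
       }},
      {1, [] { return false; }},  // Unsupported algorithm: always skipped.
      {2, [] {
         volatile long sink = 0;
         for (int i = 0; i < 10000; ++i) sink = sink + i;
         return true;
       }},
  };
  if (std::optional<int> best = PickFastest(candidates)) {
    std::cout << "fastest candidate: " << *best << "\n";
  }
  return 0;
}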
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h
deleted file mode 100644
index be2686d..0000000
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_
-#define XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_
-
-#include <functional>
-#include <optional>
-#include <string_view>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream_executor.h"
-
-namespace xla {
-namespace gpu {
-
-// GemmAlgorithmPicker supports two modes: device and deviceless.
-// In device mode, we run autotuning on the device and store autotune results.
-// In deviceless mode, we pass in some information related to the device and
-// use stored autotune results to rewrite Gemm instructions. If the required
-// autotune result is not stored, the algorithm is set to kRuntimeAutotuning.
-class GemmAlgorithmPicker : public HloModulePass {
- public:
- explicit GemmAlgorithmPicker(AutotuneConfig config) : config_(config) {}
-
- absl::string_view name() const override { return "gemm-algorithm-picker"; }
-
- size_t num_algorithms_left() const { return num_algorithms_left_; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- AutotuneConfig config_;
- // The number of valid algorithms used for autotuning (from the last call),
- // to be used for testing purposes.
- size_t num_algorithms_left_ = 0;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc
deleted file mode 100644
index 8049aa1..0000000
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc
+++ /dev/null
@@ -1,301 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-
-#include <cstdint>
-#include <variant>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/variant_visitor.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/platform_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-#include "tsl/protobuf/dnn.pb.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GemmAlgorithmPickerTest : public HloTestBase,
- public ::testing::WithParamInterface<bool> {
- public:
- GemmAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
-
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_cublaslt(GetParam());
- debug_options.set_xla_gpu_enable_triton_gemm(false);
- return debug_options;
- }
-
- const se::DeviceDescription& device_desc() {
- return backend().default_stream_executor()->GetDeviceDescription();
- }
-
- se::StreamExecutor* stream_exec() {
- return backend().default_stream_executor();
- }
- const se::DeviceDescription& gpu_device_desc() {
- return stream_exec()->GetDeviceDescription();
- }
- const se::GpuComputeCapability& gpu_comp() {
- return gpu_device_desc().gpu_compute_capability();
- }
-
- void SetUp() override {
- std::string_view name =
- ::testing::UnitTest::GetInstance()->current_test_info()->name();
- // We need special handling for BlasGetVersion test.
- bool blas_get_version = name.rfind("BlasGetVersion") == 0;
-
- std::visit(
- VariantVisitor{
- [&](const se::CudaComputeCapability& cc) {
- if (!blas_get_version && cc.IsAtLeastAmpere()) {
- GTEST_SKIP()
- << "Skipping this test for Ampere+ as it is supported "
- "and recommended with the Nvidia Volta+ GPUs.";
- }
- },
- [&](const se::RocmComputeCapability& cc) {
- if (blas_get_version) {
- auto version = std::stol(device_desc().runtime_version());
- if (version < 60200) {
- GTEST_SKIP()
- << "This API is not available on ROCM 6.1 and below.";
- }
- } else if (GetDebugOptionsForTest().xla_gpu_enable_cublaslt() &&
- !cc.has_hipblaslt()) {
- GTEST_SKIP() << "No gpublas-lt support on this architecture!";
- }
- }},
- gpu_comp());
- }
-};
-
-TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) {
- auto* blas = backend().default_stream_executor()->AsBlas();
- ASSERT_TRUE(blas != nullptr);
- std::string version;
- ASSERT_TRUE(blas->GetVersion(&version).ok());
- VLOG(0) << "Blas version: " << version;
- ASSERT_TRUE(!version.empty());
-}
-
-TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) {
- constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
- %arg0 = f32[100,100]{1,0} parameter(0)
- %arg1 = f32[100,100]{1,0} parameter(1)
- ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
- auto module_cfg = GetModuleConfigForTest();
- auto debug_opts = module_cfg.debug_options();
- size_t num_left1 = 0, num_left2 = 0;
-
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kHlo, module_cfg));
-
- {
- // Run first with default settings (autotune level = 4) and record the
- // number of algorithms left after autotuning.
- TF_ASSERT_OK_AND_ASSIGN(
- bool changed,
- RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
- module.get()));
-
- AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
- GemmAlgorithmPicker gpicker(cfg);
- // Note that we do not care whether the algorithm index has changed:
- // what matters is the number of algorithms left after filtering.
- TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
- num_left1 = gpicker.num_algorithms_left();
- if (num_left1 < 2) {
- GTEST_SKIP() << "Too few algorithms left after the first step";
- }
- }
-
- // Clear cache before the second run!
- AutotunerUtil::ClearAutotuneResults();
- {
- // Run once again, but now with autotune level 5 and an embarrassingly
- // tight rtol which should disqualify most of the algorithms.
-
- // Note that we have "two sources of truth" for GemmAlgorithmPicker:
- // debug_options are used to initialize both 'HloModuleConfig' and
- // 'AutotuneConfig'.
- debug_opts.set_xla_gpu_autotune_gemm_rtol(1e-12);
- debug_opts.set_xla_gpu_autotune_level(5);
- module->mutable_config().set_debug_options(debug_opts);
- TF_ASSERT_OK_AND_ASSIGN(
- bool changed,
- RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
- module.get()));
-
- AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
- GemmAlgorithmPicker gpicker(cfg);
- TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
- num_left2 = gpicker.num_algorithms_left();
- }
- // Assert that we have fewer algorithms left after the second run.
- ASSERT_TRUE(num_left1 > num_left2);
-}
-
-TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) {
- constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
- %arg0 = f32[100,100]{1,0} parameter(0)
- %arg1 = f32[100,100]{1,0} parameter(1)
- ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
- auto module_cfg = GetModuleConfigForTest();
- TF_ASSERT_OK_AND_ASSIGN(auto m,
- ParseAndReturnVerifiedModule(kHlo, module_cfg));
-
- bool changed = false;
- TF_ASSERT_OK_AND_ASSIGN(
- changed,
- RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
- changed = false;
- DebugOptions opts;
- AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
- TF_ASSERT_OK_AND_ASSIGN(changed,
- RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
- ASSERT_TRUE(changed);
-
- AutotuneResults results;
- TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
- ASSERT_EQ(results.results_size(), 1);
- auto& result = *results.mutable_results(0)->mutable_result();
- int64_t old_algo_id = result.algorithm().algo_id();
- int64_t new_algo_id = old_algo_id + 1;
- result.mutable_gemm()->set_algorithm(new_algo_id);
-
- AutotunerUtil::ClearAutotuneResults();
- TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
-
- // Now send the same module through GemmAlgorithmPicker again. The dot should
- // have the new algorithm.
- TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
- changed = false;
- TF_ASSERT_OK_AND_ASSIGN(
- changed,
- RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
- changed = false;
- TF_ASSERT_OK_AND_ASSIGN(changed,
- RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
- ASSERT_TRUE(changed);
-
- SCOPED_TRACE(m->ToString());
- HloInstruction* dot;
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
-
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- dot->backend_config<GpuBackendConfig>());
- const GemmBackendConfig& config = gpu_config.gemm_backend_config();
- EXPECT_EQ(config.selected_algorithm(), new_algo_id);
-}
-
-TEST_P(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) {
- constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
- %arg0 = f32[100,100]{1,0} parameter(0)
- %arg1 = f32[100,100]{1,0} parameter(1)
- ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
- TF_ASSERT_OK_AND_ASSIGN(
- auto m, ParseAndReturnVerifiedModule(kHlo, GetModuleConfigForTest()));
-
- bool changed = false;
- TF_ASSERT_OK_AND_ASSIGN(
- changed,
- RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
- changed = false;
-
- DebugOptions opts;
- AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
-
- TF_ASSERT_OK_AND_ASSIGN(changed,
- RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
- ASSERT_TRUE(changed);
-
- AutotuneResults results;
- TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
- ASSERT_EQ(results.results_size(), 1);
- auto& result = *results.mutable_results(0)->mutable_result();
- int64_t old_algo_id = result.algorithm().algo_id();
- int64_t new_algo_id = old_algo_id + 1;
- result.mutable_gemm()->set_algorithm(new_algo_id);
-
- AutotunerUtil::ClearAutotuneResults();
- TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
-
- auto module_cfg = GetModuleConfigForTest();
- // Now send the same module through GemmAlgorithmPicker again. The dot should
- // have the new algorithm.
- TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
- changed = false;
-
- DevicelessConfig deviceless_config{gpu_device_desc()};
- AutotuneConfig deviceless_cfg{deviceless_config, opts};
- TF_ASSERT_OK_AND_ASSIGN(changed,
- RunHloPass(GemmRewriter(gpu_comp(),
- /*toolkit_version=*/12040),
- m.get()));
- changed = false;
- TF_ASSERT_OK_AND_ASSIGN(
- changed, RunHloPass(GemmAlgorithmPicker(deviceless_cfg), m.get()));
- ASSERT_TRUE(changed);
-
- SCOPED_TRACE(m->ToString());
- HloInstruction* dot;
-
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
-
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- dot->backend_config<GpuBackendConfig>());
- const GemmBackendConfig& config = gpu_config.gemm_backend_config();
-
- EXPECT_EQ(config.selected_algorithm(), new_algo_id);
-}
-
-INSTANTIATE_TEST_SUITE_P(GemmAlgorithmPickerTestSuite, GemmAlgorithmPickerTest,
- ::testing::Bool());
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc
deleted file mode 100644
index a6cbbf11..0000000
--- a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h"
-
-#include <cstdint>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace m = match;
-
-class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleCustomCall(HloInstruction *instr) override {
- HloInstruction *existing_gemm;
- HloInstruction *bcast;
- if (Match(instr, m::CustomCall(&existing_gemm,
- {kGemmCallTarget, kCublasLtMatmulCallTarget})
- .WithOperand(0, m::Broadcast(&bcast, m::Op()))) ||
- (Match(instr, m::CustomCall(&existing_gemm, {kGemmCallTarget,
- kCublasLtMatmulCallTarget})
- .WithOperand(1, m::Broadcast(&bcast, m::Op()))))) {
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- existing_gemm->backend_config<GpuBackendConfig>());
- GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
- DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers();
- int bcast_operand_index = instr->operand_index(bcast);
- int num_bcast_dims = (bcast->shape().dimensions_size() -
- bcast->operand(0)->shape().dimensions_size());
- int num_batch_dims = dim_nums->lhs_batch_dimensions_size();
-
- const tsl::protobuf::RepeatedField<int64_t> &batch_dimensions =
- (bcast_operand_index == 1) ? dim_nums->rhs_batch_dimensions()
- : dim_nums->lhs_batch_dimensions();
- // This optimization is only valid if the set of broadcasted dimensions
- // is exactly the set of batch dimensions. First, check that all newly
- // broadcast dimensions have been inserted on the left i.e. all new
- // dimensions must be in [0, num_bcast_dims) or equivalently all original
- // dimensions are >= num_bcast_dims.
- for (int64_t bcast_dim : bcast->dimensions()) {
- if (bcast_dim < num_bcast_dims) {
- return absl::OkStatus();
- }
- // bcast_dim should not be in batch_dimensions.
- if (absl::c_linear_search(batch_dimensions, bcast_dim)) {
- return absl::OkStatus();
- }
- }
-
- // Then check that all batch dimensions are being broadcast, and that
- // there is at least one batch dimension.
- CHECK_GT(num_bcast_dims, 0);
- if (num_bcast_dims != num_batch_dims) {
- return absl::OkStatus();
- }
-
- if (bcast_operand_index == 1) {
- CHECK_EQ(dim_nums->rhs_contracting_dimensions_size(), 1);
- dim_nums->set_rhs_contracting_dimensions(
- 0, dim_nums->rhs_contracting_dimensions(0) - num_batch_dims);
- dim_nums->clear_rhs_batch_dimensions();
- } else {
- CHECK_EQ(dim_nums->lhs_contracting_dimensions_size(), 1);
- dim_nums->set_lhs_contracting_dimensions(
- 0, dim_nums->lhs_contracting_dimensions(0) - num_batch_dims);
- dim_nums->clear_lhs_batch_dimensions();
- }
- TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape(
- bcast_operand_index, bcast->mutable_operand(0)));
- TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
- MarkAsChanged();
- }
- return absl::OkStatus();
- }
-};
-
-static absl::StatusOr<bool> RunOnComputation(HloComputation *computation) {
- GemmBroadcastFoldingVisitor visitor;
- TF_RETURN_IF_ERROR(computation->Accept(&visitor));
- return visitor.changed();
-}
-
-absl::StatusOr<bool> GemmBroadcastFoldingRewriter::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- bool changed = false;
- for (HloComputation *computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h b/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h
deleted file mode 100644
index bac14bc..0000000
--- a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_
-#define XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// cuBLAS GEMM has support for strided batched calls, where the stride is used
-// to determine the offset between the batches.
-//
-// This allows (kCustomCall:gemm A kBroadcast(B)) or
-// (kCustomCall:gemm kBroadcast(A) B)
-// to be rewritten as (kCustomCall:gemm A B) with a zero stride for the
-// broadcasted operand if the broadcast operates on all the batch dimensions.
-//
-// This pass matches the above patterns and removes the unnecessary broadcast.
-class GemmBroadcastFoldingRewriter : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "cublas-gemm-broadcast-folding-rewriter";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_
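The header above describes the rewrite at the HLO level; the corresponding dimension-number update is the dim_nums manipulation in HandleCustomCall in the deleted .cc file: when the broadcast covers exactly the batch dimensions of one operand, that operand's batch dimensions are dropped and its contracting dimension index shifts down by their count. Below is a minimal standalone sketch of just that adjustment; DotSideDims and FoldBroadcastedBatchDims are invented stand-ins, not XLA types.

// Standalone sketch of the batch-dimension folding performed by
// GemmBroadcastFoldingRewriter above. DotSideDims models only the per-operand
// part of DotDimensionNumbers that the pass touches.
#include <cstdint>
#include <iostream>
#include <vector>

struct DotSideDims {
  std::vector<int64_t> batch_dimensions;
  int64_t contracting_dimension;
};

// Drops the (broadcasted) batch dimensions and shifts the contracting
// dimension down by the number of removed batch dimensions, mirroring
// set_rhs_contracting_dimensions(...) + clear_rhs_batch_dimensions().
void FoldBroadcastedBatchDims(DotSideDims& dims) {
  dims.contracting_dimension -=
      static_cast<int64_t>(dims.batch_dimensions.size());
  dims.batch_dimensions.clear();
}

int main() {
  // e.g. the rhs of a batched GEMM: batch dim {0}, contracting dim 1.
  DotSideDims rhs{{0}, 1};
  FoldBroadcastedBatchDims(rhs);
  std::cout << "contracting dim after folding: " << rhs.contracting_dimension
            << ", remaining batch dims: " << rhs.batch_dimensions.size()
            << "\n";  // Prints 0 and 0.
  return 0;
}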
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc
deleted file mode 100644
index 4e37a4b..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion.cc
+++ /dev/null
@@ -1,815 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_fusion.h"
-
-#include <array>
-#include <cstddef>
-#include <cstdint>
-#include <optional>
-#include <queue>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/triton_fusion_analysis.h"
-#include "xla/service/gpu/triton_tiling_propagation.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-using triton_fusion::CombineDotRequirements;
-using triton_fusion::DimensionOrder;
-using triton_fusion::DimOrderMap;
-using triton_fusion::DimOrdersAndReqs;
-using triton_fusion::DimOrdersAndReqsOrError;
-using triton_fusion::DotProperties;
-using triton_fusion::DotRequirements;
-using triton_fusion::DotRequirementsOrError;
-using triton_fusion::FusionContext;
-using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible;
-using triton_fusion::TransformDirection;
-
-// This represents a directed graph.
-class AdjacencyList {
- public:
- using NodeId = int64_t;
-
- NodeId AddNode() {
- adj_.emplace_back();
- return adj_.size() - 1;
- }
-
- const std::vector<NodeId>& GetOutNeighbors(NodeId node_id) const {
- return adj_.at(node_id);
- }
-
- void ReserveSpaceForOutNeighbors(NodeId node_id, size_t count) {
- adj_.at(node_id).reserve(count);
- }
-
- void AddArc(NodeId from, NodeId to) { adj_.at(from).push_back(to); }
-
- // Currently the Root node is the node which was added first.
- NodeId GetRoot() const {
- CHECK(!adj_.empty());
- return 0;
- }
-
- private:
- // Adjacency list: A vector of out-neighbors for each node.
- std::vector<std::vector<NodeId>> adj_;
-};
-
-struct HloAndDimOrder {
- const HloInstruction* original_hlo = nullptr;
- DimensionOrder dim_order;
-};
-
-struct HloAndIterSpec {
- const HloInstruction* original_hlo;
- TensorIterationSpec iter_spec;
-
- auto ToTuple() const { return std::make_tuple(original_hlo, iter_spec); }
- bool operator==(const HloAndIterSpec& other) const {
- return ToTuple() == other.ToTuple();
- }
- template <typename H>
- friend H AbslHashValue(H h, const HloAndIterSpec& key) {
- return H::combine(std::move(h), key.ToTuple());
- }
-};
-
-struct NodeFusionPlan {
- const HloInstruction* original_hlo = nullptr;
- bool should_fuse = false;
-};
-
-struct FusionPlan {
- // The graph describing the structure of the fusion that we build - nodes
- // corresponding to the instructions and arcs pointing from users to operands.
- AdjacencyList graph;
- // The fusion plan for each node.
- absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> map;
-};
-
-struct FusionPlanAndRequirements {
- FusionPlan fusion_plan;
- DotRequirements requirements;
-};
-
-struct HlosAndRequirements {
- // The original HLO (which is outside the fusion computation).
- const HloInstruction* original_hlo = nullptr;
- // The fused HLO inside the new fusion computation, built by the builder.
- //
- // This can have the same opcode as `original_hlo` or it can be a parameter if
- // the original HLO can't be fused.
- const HloInstruction* fused_hlo = nullptr;
- // The requirements imposed by the fused operations.
- //
- // If we fuse further operations they may have to conform to these
- // requirements.
- DotRequirements requirements;
-};
-
-// Clones the hero kDot operation into the fusion.
-HloInstruction& FuseDot(const HloDotInstruction& dot,
- const HloInstruction& fused_lhs,
- const HloInstruction& fused_rhs,
- std::optional<const HloInstruction*> fused_meta,
- HloComputation::Builder& builder // append
-) {
- VLOG(3) << "Fusing " << dot.ToString();
-
- std::vector<HloInstruction*> hlo_new_operands = {
- const_cast<HloInstruction*>(&fused_lhs),
- const_cast<HloInstruction*>(&fused_rhs)};
- if (fused_meta.has_value()) {
- hlo_new_operands.push_back(const_cast<HloInstruction*>(fused_meta.value()));
- }
- return *builder.AddInstruction(
- dot.CloneWithNewOperands(dot.shape(), hlo_new_operands));
-}
-
-// Tells how many new parameters a fusion gains by fusing the operation as an
-// input.
-int64_t NumAddedParameters(const HloInstruction& hlo) {
- // Non-scalar constant is equivalent to a parameter: one input, one output.
- if (hlo.opcode() == HloOpcode::kParameter ||
- (hlo.opcode() == HloOpcode::kConstant &&
- !ShapeUtil::IsScalar(hlo.shape()))) {
- return 0;
- }
- // All other instructions add all of their own inputs and remove their own
- // single output.
- return hlo.operand_count() - 1;
-}
-
-// Just a helper to reduce "unwrapping" code where we use this.
-std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqs(
- const HloInstruction& hlo, const DimensionOrder& dim_order,
- const DotProperties& properties,
- const se::GpuComputeCapability& gpu_version,
- const DotRequirements& requirements) {
- DimOrdersAndReqsOrError dim_orders_and_new_reqs =
- GetPropagatedDimOrdersAndRequirements(
- hlo, dim_order, TransformDirection::kOutputToInput, properties);
- if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
- return std::nullopt;
- }
- DotRequirementsOrError combined_reqs = CombineDotRequirements(
- requirements,
- std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
- if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
- return std::nullopt;
- }
- return DimOrdersAndReqs{
- std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
- std::get<DotRequirements>(combined_reqs)};
-}
-
-// Just a helper to reduce "unwrapping" code where we use this.
-std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqsIfProfitable(
- const HloInstruction& hlo, const DimensionOrder& dim_order,
- const DotProperties& properties,
- const se::GpuComputeCapability& gpu_version,
- const DotRequirements& requirements) {
- DimOrdersAndReqsOrError dim_orders_and_new_reqs =
- GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
- hlo, TransformDirection::kOutputToInput,
- /*src_operand_index=*/std::nullopt, dim_order, gpu_version,
- properties);
- if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
- return std::nullopt;
- }
- DotRequirementsOrError combined_reqs = CombineDotRequirements(
- requirements,
- std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
- if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
- return std::nullopt;
- }
- return DimOrdersAndReqs{
- std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
- std::get<DotRequirements>(combined_reqs)};
-}
-
-// Just a helper to reduce "unwrapping" code where we use this.
-std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
- const HloInstruction& hlo, const DimensionOrder& hlo_dim_order,
- const HloInstruction& user, const DotProperties& properties,
- const se::GpuComputeCapability& gpu_version,
- const DotRequirements& requirements) {
- DimOrdersAndReqsOrError dim_orders_and_new_reqs =
- GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
- user, TransformDirection::kInputToOutput, user.operand_index(&hlo),
- hlo_dim_order, gpu_version, properties);
- if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
- return std::nullopt;
- }
- DotRequirementsOrError combined_reqs = CombineDotRequirements(
- requirements,
- std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
- if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
- return std::nullopt;
- }
- return DimOrdersAndReqs{
- std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
- std::get<DotRequirements>(combined_reqs)};
-}
-
-// Builds the fusion map and the requirements which can later be used to
-// actually fuse that subgraph.
-FusionPlanAndRequirements BuildFusionPlanTowardOperands(
- const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
- const std::optional<int>& max_params,
- const se::GpuComputeCapability& gpu_version,
- const DotProperties& properties,
- const DotRequirements& requirements_so_far) {
- CHECK(!max_params.has_value() || max_params.value() >= 1);
-
- // The graph describing the structure of the fusion that we build - nodes
- // corresponding to the instructions and arcs pointing from users to operands.
- // We can build and modify this graph easily without the need to create
- // HloInstructions at this point.
- AdjacencyList graph;
- // Stores the original HLO and the dimension order for each node. This is a
- // temporary map which is used when processing the nodes in this function.
- absl::flat_hash_map<AdjacencyList::NodeId, HloAndDimOrder>
- hlo_and_dim_order_map;
- // Stores the information needed to build the fused HLO for each node (what
- // was the original HLO and whether we should fuse it or create a parameter).
- // This is one of the outputs of this function.
- absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> fusion_plan_map;
- // Allows reusing nodes when multiple instructions iterate over the same HLO
- // using the same iteration spec. In that case we don't duplicate the
- // instruction in the fusion.
- absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map;
- // The requirements imposed by the fusion choices made in this function,
- // combined with the existing requirements. This is one of the outputs of this
- // function.
- DotRequirements combined_reqs = requirements_so_far;
-
- auto get_or_create_fusion_node =
- [&](const HloInstruction& hlo, const DimensionOrder& dim_order,
- bool* is_new_node = nullptr) -> AdjacencyList::NodeId {
- HloAndIterSpec reuse_key = {&hlo, dim_order.ToTensorIterationSpec()};
- if (auto it = node_reuse_map.find(reuse_key); it != node_reuse_map.end()) {
- if (is_new_node != nullptr) {
- *is_new_node = false;
- }
- return it->second;
- }
- AdjacencyList::NodeId node_id = graph.AddNode();
- CHECK(hlo_and_dim_order_map.insert({node_id, {&hlo, dim_order}}).second);
- CHECK(node_reuse_map.insert({reuse_key, node_id}).second);
- if (is_new_node != nullptr) {
- *is_new_node = true;
- }
- return node_id;
- };
- AdjacencyList::NodeId root =
- get_or_create_fusion_node(root_hlo, root_dim_order);
-
- // Nodes at the fusion edge that can either get fused too or become parameters
- // of the fusion. Used to track the number of parameters.
- absl::flat_hash_set<AdjacencyList::NodeId> inputs({root});
- std::queue<AdjacencyList::NodeId> queue({root});
- int64_t num_requeued = 0;
- // BFS
- while (queue.size() > num_requeued) {
- AdjacencyList::NodeId node_id = queue.front();
- queue.pop();
- const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
- const HloInstruction& original_hlo = *hlo_and_dim_order.original_hlo;
- const DimensionOrder& dim_order = hlo_and_dim_order.dim_order;
-
- // Watch the total number of fusion parameters.
- if (max_params.has_value() &&
- inputs.size() + NumAddedParameters(original_hlo) > max_params.value()) {
- // Re-queue: the number of parameters may go down when other instructions
- // are processed.
- queue.push(node_id);
- // Prevent infinite loops.
- ++num_requeued;
- continue;
- }
- num_requeued = 0;
- if (original_hlo.opcode() == HloOpcode::kParameter) {
- CHECK(fusion_plan_map
- .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
- .second);
- continue;
- }
- auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(
- original_hlo, dim_order, properties, gpu_version, combined_reqs);
- if (!opt_result.has_value()) {
- CHECK(fusion_plan_map
- .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
- .second);
- continue;
- }
- const DimOrderMap operand_dim_orders = std::move(opt_result->dim_orders);
- combined_reqs = std::move(opt_result->requirements);
- inputs.erase(node_id);
- graph.ReserveSpaceForOutNeighbors(node_id, original_hlo.operand_count());
- for (int64_t i = 0; i < original_hlo.operand_count(); ++i) {
- const HloInstruction& operand = *original_hlo.operand(i);
- const DimensionOrder& operand_dim_order = operand_dim_orders.at(&operand);
- bool is_new_node = false;
- AdjacencyList::NodeId operand_node_id =
- get_or_create_fusion_node(operand, operand_dim_order, &is_new_node);
- graph.AddArc(node_id, operand_node_id);
- if (is_new_node) {
- VLOG(6) << "Enqueueing " << operand.ToString() << ":"
- << operand_dim_order.ToString();
- inputs.insert(operand_node_id);
- queue.push(operand_node_id);
- }
- }
- CHECK(
- fusion_plan_map.insert({node_id, {&original_hlo, /*should_fuse=*/true}})
- .second);
- }
- // Handle the remaining requeued items.
- while (!queue.empty()) {
- AdjacencyList::NodeId node_id = queue.front();
- queue.pop();
-
- const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
- CHECK(fusion_plan_map
- .insert({node_id,
- {hlo_and_dim_order.original_hlo, /*should_fuse=*/false}})
- .second);
- }
- return {{std::move(graph), std::move(fusion_plan_map)},
- std::move(combined_reqs)};
-}
-
-// Builds the HLO instructions for the fusion represented by `fusion_plan`,
-// starting from `node_id`.
-HloInstruction& BuildFusionTowardOperandsImpl(
- AdjacencyList::NodeId node_id, const FusionPlan& fusion_plan,
- absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*>&
- fused_hlo_map, // read/append
- HloComputation::Builder& builder, // append
- std::vector<HloInstruction*>& fusion_params // append
-) {
- if (auto it = fused_hlo_map.find(node_id); it != fused_hlo_map.end()) {
- return *it->second;
- }
-
- const NodeFusionPlan& node_fusion_plan = fusion_plan.map.at(node_id);
- const bool should_fuse = node_fusion_plan.should_fuse;
- const HloInstruction& original_hlo = *node_fusion_plan.original_hlo;
-
- HloInstruction* fused_hlo = nullptr;
- if (should_fuse) {
- HloInstruction::InstructionVector new_operands;
- for (AdjacencyList::NodeId operand_id :
- fusion_plan.graph.GetOutNeighbors(node_id)) {
- new_operands.push_back(&BuildFusionTowardOperandsImpl(
- operand_id, fusion_plan, fused_hlo_map, builder, fusion_params));
- }
- fused_hlo = builder.AddInstruction(
- original_hlo.CloneWithNewOperands(original_hlo.shape(), new_operands));
- } else {
- fusion_params.push_back(const_cast<HloInstruction*>(&original_hlo));
- fused_hlo = builder.AddInstruction(HloInstruction::CreateParameter(
- fusion_params.size() - 1, original_hlo.shape(),
- absl::StrCat("parameter_", fusion_params.size() - 1)));
- }
-
- CHECK(fused_hlo_map.insert({node_id, fused_hlo}).second);
- return *fused_hlo;
-}
-
-// Builds the HLO instructions for the fusion represented by `fusion_plan`.
-HloInstruction& BuildFusionTowardOperands(
- const FusionPlan& fusion_plan,
- HloComputation::Builder& builder, // append
- std::vector<HloInstruction*>& fusion_params // append
-) {
- absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*> fused_hlo_map;
- return BuildFusionTowardOperandsImpl(fusion_plan.graph.GetRoot(), fusion_plan,
- fused_hlo_map, builder, fusion_params);
-}
-
-// Grows the fusion toward the operands.
-//
-// This always succeeds.
-//
-// If it's not possible to fuse something, it fuses a parameter instead.
-//
-// The fusion can grow until it has `max_params` params, and it can only grow
-// with operations for which the DimOrder propagation works and which don't
-// impose requirements contradicting the existing ones.
-//
-// The return value contains the HLOs corresponding to `root_hlo` and the
-// requirements corresponding to the whole fusion so far.
-HlosAndRequirements FuseTowardOperands(
- const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
- const std::optional<int>& max_params,
- const se::GpuComputeCapability& gpu_version,
- const DotProperties& properties, const DotRequirements& requirements_so_far,
- HloComputation::Builder& builder, // append
- std::vector<HloInstruction*>& fusion_params // append
-) {
- FusionPlanAndRequirements fusion_plan_and_reqs =
- BuildFusionPlanTowardOperands(root_hlo, root_dim_order, max_params,
- gpu_version, properties,
- requirements_so_far);
- HloInstruction& fused_hlo_or_param = BuildFusionTowardOperands(
- fusion_plan_and_reqs.fusion_plan, builder, fusion_params);
- return HlosAndRequirements{&root_hlo, &fused_hlo_or_param,
- fusion_plan_and_reqs.requirements};
-}
-
-// Grows the fusion toward the given dot operand.
-//
-// This always succeeds.
-//
-// If it's not possible to fuse something, it fuses a parameter instead.
-//
-// The fusion can grow until it has `max_params` params, and it can only grow
-// with operations for which the DimOrder propagation works and which don't
-// impose requirements contradicting the existing ones.
-//
-// The return value contains the HLOs corresponding to the given dot operand and
-// the requirements corresponding to the whole fusion so far.
-absl::StatusOr<HlosAndRequirements> FuseDotOperand(
- const HloInstruction& dot, int operand_index,
- const se::GpuComputeCapability& gpu_version,
- HloComputation::Builder& builder, // append
- std::vector<HloInstruction*>& fusion_params // append
-) {
- // Direct dot inputs have well defined dimension orders.
- TF_ASSIGN_OR_RETURN(const FusionContext context,
- FusionContext::FromDotOperand(dot, operand_index));
- const HloInstruction& operand = *dot.operand(operand_index);
- return FuseTowardOperands(operand, context.dim_orders().at(&operand),
- TritonFusionAnalysis::kMaxParameterPerDotOperand,
- gpu_version, context.dot_properties(),
- context.requirements(), builder, fusion_params);
-}
-
-// Grows the fusion toward the users.
-//
-// This always succeeds.
-//
-// The fusion can grow as long as the DimOrder propagation works and the users
-// don't impose requirements contradicting the existing requirements.
-//
-// The return value contains the HLOs corresponding to the "lowest" fused user
-// or `hlo` if no users can be fused.
-//
-// It also grows the fusion upward, toward the "other" operands of the users,
-// but currently only in special cases, such as binary elementwise operation
-// with broadcast of scalar constant.
-HlosAndRequirements FuseTowardUsers(
- const HloInstruction& hlo, const HloInstruction& fused_hlo,
- const DimensionOrder& hlo_dim_order,
- const se::GpuComputeCapability& gpu_version,
- const DotProperties& properties, const DotRequirements& requirements,
- HloComputation::Builder& builder, // append
- std::vector<HloInstruction*>& fusion_params // append
-) {
- const HlosAndRequirements existing_hlos_and_requirements = {&hlo, &fused_hlo,
- requirements};
- if (hlo.user_count() != 1) {
- return existing_hlos_and_requirements;
- }
- const HloInstruction& user = *hlo.users()[0];
- if (!legacy_triton::IsDistributiveOverAddition(user)) {
- return existing_hlos_and_requirements;
- }
-
- // Get the dim orders for the user.
- auto opt_user_result = GetUserDimOrdersAndCombinedReqsIfProfitable(
- hlo, hlo_dim_order, user, properties, gpu_version, requirements);
- if (!opt_user_result.has_value()) {
- return existing_hlos_and_requirements;
- }
- DimensionOrder user_dim_order = opt_user_result->dim_orders.at(&user);
- DotRequirements combined_requirements = opt_user_result->requirements;
-
- HloInstruction::InstructionVector new_operands;
- if (user.operand_count() == 1) {
- new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
- } else {
- // Get the dim orders for the operands of the user.
- // We shouldn't do a profitability check here; we already made that
- // decision in GetUserDimOrdersAndCombinedReqsIfProfitable.
- auto opt_operand_result = GetOperandDimOrdersAndCombinedReqs(
- user, user_dim_order, properties, gpu_version, combined_requirements);
- // This shouldn't fail, because currently we only encounter this when we
- // have just propagated down the DimOrders on a binary elementwise
- // operation (user). In that case propagating up the DimOrders should always
- // work.
- if (!opt_operand_result.has_value()) {
- return existing_hlos_and_requirements;
- }
- DimOrderMap operand_dim_orders = opt_operand_result->dim_orders;
- combined_requirements = opt_operand_result->requirements;
-
- // Fuse the other operands of the user.
- for (int i = 0; i < user.operand_count(); ++i) {
- const HloInstruction& operand = *user.operand(i);
- if (&operand == &hlo) {
- new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
- } else {
- HlosAndRequirements hlos_and_requirements = FuseTowardOperands(
- operand, operand_dim_orders.at(&operand),
- /*max_params=*/std::nullopt, gpu_version, properties,
- combined_requirements, builder, fusion_params);
- new_operands.push_back(
- const_cast<HloInstruction*>(hlos_and_requirements.fused_hlo));
- combined_requirements = hlos_and_requirements.requirements;
- }
- }
- }
-
- const HloInstruction& fused_user = *builder.AddInstruction(
- user.CloneWithNewOperands(user.shape(), new_operands));
- return FuseTowardUsers(user, fused_user, user_dim_order, gpu_version,
- properties, combined_requirements, builder,
- fusion_params);
-}
-
-// Grows the fusion toward the users of the dot.
-//
-// This always succeeds.
-//
-// The fusion can grow as long as the DimOrder propagation works and the users
-// don't impose requirements contradicting the existing requirements.
-//
-// The return value contains the HLOs corresponding to the "lowest" fused user
-// or `dot` if no users can be fused.
-//
-// It also grows the fusion towards the "other" operands of the users, but
-// currently only in special cases, such as binary elementwise operation with
-// broadcast of scalar constant.
-HlosAndRequirements FuseDotOutput(
- const HloInstruction& dot, const HloInstruction& fused_dot,
- const se::GpuComputeCapability& gpu_version,
- const DotRequirements& requirements,
- HloComputation::Builder& builder, // append
- std::vector<HloInstruction*>& fusion_params // append
-) {
- const auto context =
- FusionContext::FromDotOutput(dot, /*split_k=*/1, requirements);
- return FuseTowardUsers(dot, fused_dot, context.dim_orders().at(&dot),
- gpu_version, context.dot_properties(),
- context.requirements(), builder, fusion_params);
-}
-
-// Fuses the dot and the operations around it that are compatible and
-// profitable to fuse into a new fusion computation constructed using the
-// builder. fusion_inputs gets populated with the non-fused instructions that
-// become operands of the call to this fusion. fusion_output_ptr (if not
-// nullptr) gets assigned the original instruction that has to be replaced by
-// the call to the fusion.
-absl::StatusOr<FusionDecision> CreateDotFusion(
- const HloDotInstruction& dot, const se::GpuComputeCapability gpu_version,
- HloComputation::Builder& builder,
- std::vector<HloInstruction*>& fusion_inputs,
- HloInstruction** fusion_output_ptr) {
- VLOG(5) << dot.ToString();
- if (CodegenDecision is_supported =
- legacy_triton::IsTritonSupportedInstruction(dot, gpu_version);
- !is_supported) {
- VLOG(3) << is_supported.Explain();
- return is_supported;
- }
-
- // Verify sparse dot constraints.
- if (dot.sparse_operands()) {
- const SparsityDescriptor& descriptor = dot.sparsity().front();
- if (dot.sparse_operands() != 1 || descriptor.index() != 0) {
- return InvalidArgument("Sparsity is only supported on left operand");
- }
- if (descriptor.type() != SparsityType::SPARSITY_STRUCTURED_N_M ||
- descriptor.n() != 2 || descriptor.m() != 4) {
- return InvalidArgument("Only 2:4 structured sparsity is supported");
- }
- // DotDimensionSorter pass makes sure the sparse dimension is minor.
- CHECK_EQ(descriptor.dimension(), dot.operand(0)->shape().rank() - 1);
- }
-
- TF_ASSIGN_OR_RETURN(HlosAndRequirements lhs_hlos_and_reqs,
- FuseDotOperand(dot, /*operand_index=*/0, gpu_version,
- builder, fusion_inputs));
- TF_ASSIGN_OR_RETURN(HlosAndRequirements rhs_hlos_and_reqs,
- FuseDotOperand(dot, /*operand_index=*/1, gpu_version,
- builder, fusion_inputs));
- std::optional<const HloInstruction*> meta_hlo;
- if (dot.sparse_operands()) {
- TF_ASSIGN_OR_RETURN(HlosAndRequirements meta_hlos_and_reqs,
- FuseDotOperand(dot, /*operand_index=*/2, gpu_version,
- builder, fusion_inputs));
- meta_hlo.emplace(meta_hlos_and_reqs.fused_hlo);
- }
- HloInstruction& fused_dot =
- FuseDot(dot, *lhs_hlos_and_reqs.fused_hlo, *rhs_hlos_and_reqs.fused_hlo,
- meta_hlo, builder);
- // For now the RHS doesn't support splits, so it also doesn't impose any
- // requirements.
- HlosAndRequirements fused_output_and_reqs =
- FuseDotOutput(dot, fused_dot, gpu_version, lhs_hlos_and_reqs.requirements,
- builder, fusion_inputs);
-
- if (fusion_output_ptr != nullptr) {
- *fusion_output_ptr =
- const_cast<HloInstruction*>(fused_output_and_reqs.original_hlo);
- }
-
- const PrecisionConfig::Algorithm algorithm =
- dot.precision_config().algorithm();
- if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
- algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
- dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
- dot.sparse_operands()) {
- return FusionDecision{};
- }
-
- bool is_pure_matmul = true;
- (void)builder.ForEachInstruction([&](const HloInstruction* fused_hlo) {
- static constexpr std::array<HloOpcode, 4> kPureOpcodes = {
- HloOpcode::kBitcast, HloOpcode::kDot, HloOpcode::kParameter,
- HloOpcode::kReshape};
- if (absl::c_find(kPureOpcodes, fused_hlo->opcode()) == kPureOpcodes.end()) {
- is_pure_matmul = false;
- // Stop iterating.
- return absl::CancelledError();
- }
- return absl::OkStatus();
- });
- if (!is_pure_matmul) {
- return FusionDecision{};
- }
-
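- // A default-constructed FusionDecision means "fuse"; one constructed from a
- // string carries the reason not to fuse (see ShouldTritonHandleGEMM below,
- // which checks CanFuse() on the returned value).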
- return "No profitable operations to fuse.";
-}
-
-// Extracts the parts of the HLO graph that include dot() operations and can
-// target the Triton GEMM emitter into fused computations.
-class GemmFusionVisitor : public DfsHloRewriteVisitor {
- public:
- explicit GemmFusionVisitor(const se::GpuComputeCapability& gpu_version)
- : gpu_version_(gpu_version) {}
- // Checks whether a dot() should target the Triton GEMM emitter;
- // if so, fuses all its compatible inputs and outputs as a new computation
- // and replaces the original dot() with a call to that computation.
- absl::Status HandleDot(HloInstruction* dot) override {
- CHECK_EQ(dot->opcode(), HloOpcode::kDot);
-
- int64_t gemm_rewrite_size_threshold =
- dot->GetModule()
- ->config()
- .debug_options()
- .xla_gpu_gemm_rewrite_size_threshold();
- TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
- IsMatrixMultiplicationTooSmallForRewriting(
- *dot, gemm_rewrite_size_threshold));
- if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*dot)) {
- return absl::OkStatus();
- }
-
- std::string fusion_name = absl::StrCat("gemm_fusion_", dot->name());
- HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation"));
- std::vector<HloInstruction*> fusion_inputs;
- HloInstruction* fusion_output = nullptr;
- TF_ASSIGN_OR_RETURN(
- const FusionDecision should_fuse,
- CreateDotFusion(*Cast<HloDotInstruction>(dot), gpu_version_, builder,
- fusion_inputs, &fusion_output));
- if (builder.last_added_instruction() == nullptr) {
- return absl::OkStatus();
- }
- // If a GEMM requiring padding for cuBLAS is encountered here, it is
- // because ShouldTritonHandleGEMM() accepted it earlier and padding was
- // skipped. Accept it, ignoring profitability checks.
- // TODO(rocm): check ROCM padding requirements.
- if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
- if (!CublasRequiresPadding(
- *Cast<HloDotInstruction>(dot),
- std::get<se::CudaComputeCapability>(gpu_version_)) &&
- !should_fuse) {
- return absl::OkStatus();
- }
- }
-
- HloComputation* computation =
- dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
- /*is_entry=*/false);
- HloInstruction* dot_fusion =
- dot->parent()->AddInstruction(HloInstruction::CreateFusion(
- computation->root_instruction()->shape(),
- HloInstruction::FusionKind::kCustom, fusion_inputs, computation));
- // Copy the metadata of the `dot` to the newly created `fusion` op. This
- // is convenient for handling metadata in split-k rewriting subsequently.
- dot_fusion->set_metadata(dot->metadata());
- dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name);
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- dot_fusion->backend_config<GpuBackendConfig>());
- FusionBackendConfig& backend_config =
- *gpu_config.mutable_fusion_backend_config();
- backend_config.set_kind(std::string(kTritonGemmFusionKind));
- TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(gpu_config));
-
- if (fusion_output->IsRoot()) {
- fusion_output->parent()->set_root_instruction(dot_fusion);
- TF_RETURN_IF_ERROR(
- fusion_output->parent()->RemoveInstructionAndUnusedOperands(
- fusion_output));
- MarkAsChanged();
- } else {
- TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion));
- }
- XLA_VLOG_LINES(5, computation->ToString(HloPrintOptions::ShortParsable()));
- return absl::OkStatus();
- }
-
- private:
- se::GpuComputeCapability gpu_version_;
-};
-
-absl::StatusOr<bool> RunOnComputation(
- HloComputation* computation, const se::GpuComputeCapability& gpu_version) {
- GemmFusionVisitor visitor(gpu_version);
- TF_RETURN_IF_ERROR(computation->Accept(&visitor));
- return visitor.changed();
-}
-
-
-} // namespace
-
-bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
- const se::GpuComputeCapability& gpu_version) {
- std::vector<HloInstruction*> fusion_inputs;
- HloComputation::Builder builder("disposable");
- return CreateDotFusion(dot, gpu_version, builder, fusion_inputs,
- /*fusion_output_ptr=*/nullptr)
- ->CanFuse();
-}
-
-absl::StatusOr<bool> GemmFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- TF_RETURN_IF_ERROR(
- EnsureTritonSupportsComputeCapability(compute_capability_));
-
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result,
- RunOnComputation(computation, compute_capability_));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.h b/third_party/xla/xla/service/gpu/gemm_fusion.h
deleted file mode 100644
index c858b43..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_FUSION_H_
-#define XLA_SERVICE_GPU_GEMM_FUSION_H_
-
-// This file contains the code for fusing dots and other operations into Triton
-// GEMM fusions.
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Filters GEMMs which are better to handle using Triton.
-bool ShouldTritonHandleGEMM(HloDotInstruction&,
- const se::GpuComputeCapability&);
-
-// Rewrite compatible dot() calls into custom calls with fused computations
-// that target Triton-based matmul emitter.
-class GemmFusion : public HloModulePass {
- public:
- explicit GemmFusion(const se::GpuComputeCapability& compute_capability)
- : compute_capability_(compute_capability) {}
- absl::string_view name() const override { return "triton-gemm-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::GpuComputeCapability compute_capability_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GEMM_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc
deleted file mode 100644
index 5a5a0f3..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc
+++ /dev/null
@@ -1,1280 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_fusion_autotuner.h"
-
-#include <algorithm>
-#include <array>
-#include <atomic>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/primitive_util.h"
-#include "xla/service/algorithm_util.h"
-#include "xla/service/dump.h"
-#include "xla/service/executable.h"
-#include "xla/service/float_normalization.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
-#include "xla/service/gpu/fusion_wrapper.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/gpu_float_support.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/instruction_fusion.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/priority_fusion.h"
-#include "xla/service/gpu/split_k_gemm_rewriter.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/tsl/util/proto/proto_utils.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/lib/core/bits.h"
-#include "tsl/platform/blocking_counter.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/protobuf.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/threadpool.h"
-#include "tsl/profiler/lib/scoped_annotation.h"
-
-// Log levels used in this file:
-// VLOG(1): Overview
-// VLOG(2): Autotuning progress
-// VLOG(3): Autotuning progress - more frequent
-// VLOG(4): Print all fusions
-// VLOG(5): Profiling information for every tiling
-// VLOG(10): Print fusion computations and each configuration
-
-// TODO(b/317016172): Update usages of TritonGemmConfig to use newly exposed
-// parameters.
-
-namespace xla {
-namespace gpu {
-
-using Config = GemmFusionAutotunerImpl::Config;
-using TilingConfigs = GemmFusionAutotunerImpl::TilingConfigs;
-using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
-
-namespace {
-
-// Minimum tile size.
-constexpr int kMinTileSize = 16;
-
-// Default tiling when autotuning is disabled.
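-// The six values are, in order, block_m, block_n, block_k, split_k,
-// num_stages and num_warps, matching the order of the TritonGemmConfig
-// constructor arguments used in GetExhaustiveTritonConfigs() below.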
-constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4};
-
-// Split-K is enabled when the estimated number of waves is lower than the limit.
-constexpr int kMaxWavesForSplitK = 5;
-
-// Search space for exhaustive matmul autotuning.
-constexpr std::array<int, 6> kBlockSizes = {16, 32, 64, 128, 256, 512};
-constexpr std::array<int, 4> kNumStages = {1, 2, 3, 4};
-constexpr std::array<int, 4> kNumWarps = {2, 4, 8, 16};
-constexpr std::array<int, 5> kSplitK = {1, 2, 4, 8, 16};
-constexpr std::array<int, 5> kNumCtas = {1, 2, 4, 8, 16};
-
-using AutoTuneCacheKeyCount = absl::flat_hash_map<AutotuneCacheKey, uint64_t>;
-
-class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor {
- public:
- explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config)
- : config_(config) {}
-
- absl::Status HandleFusion(HloInstruction* hlo) override {
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- hlo->backend_config<GpuBackendConfig>());
- FusionBackendConfig& backend_config =
- *gpu_config.mutable_fusion_backend_config();
- if (backend_config.kind() != kTritonGemmFusionKind &&
- backend_config.kind() != kCuDnnFusionKind) {
- return absl::OkStatus();
- }
-
- VLOG(4) << "Processing " << hlo->ToString();
- if (!backend_config.has_triton_gemm_config() &&
- !backend_config.has_cudnn_fusion_config()) {
- TF_ASSIGN_OR_RETURN(
- AutotuneResult autotune_result,
- AutotunerUtil::Autotune(
- hlo, config_, [&]() -> absl::StatusOr<AutotuneResult> {
- if (config_.IsDeviceless()) {
- return absl::InternalError(absl::StrCat(
- "Expect autotune result cache hit for deviceless "
- "compilation (HLO: ",
- hlo->ToString(), ")"));
- }
- return absl::InternalError("Expect autotune result cache hit.");
- }));
- VLOG(4) << "Result: " << autotune_result.ShortDebugString();
-
- if (autotune_result.has_triton()) {
- *backend_config.mutable_triton_gemm_config() = autotune_result.triton();
- TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
- } else if (autotune_result.has_gemm()) {
- // Falling back to cuBLAS: Converting the fusion to a Call, so that it
- // can be inlined back again.
- HloComputation* const computation = hlo->parent();
- HloInstruction* const call = computation->AddInstruction(
- HloInstruction::CreateCall(hlo->shape(), hlo->operands(),
- hlo->fused_instructions_computation()));
- TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call));
- hlo = call;
- } else {
- CHECK(autotune_result.has_algorithm());
- backend_config.set_kind(std::string(kCuDnnFusionKind));
- backend_config.mutable_cudnn_fusion_config()->set_plan_id(
- autotune_result.algorithm().algo_id());
- TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
- }
- }
-
- if (backend_config.has_triton_gemm_config()) {
- TF_ASSIGN_OR_RETURN(
- const TritonGemmConfig config,
- TritonGemmConfig::FromProto(backend_config.triton_gemm_config()));
- if (config.split_k > 1) {
- TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config));
- }
- }
-
- MarkAsChanged();
- return absl::OkStatus();
- }
-
- private:
- AutotuneConfig config_;
-};
-
-class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {
- public:
- explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl)
- : impl_(impl) {}
-
- // Find configurations to tune.
- absl::StatusOr<TilingConfigs> CollectGemmConfigSets(
- const HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
- error_out_on_cache_miss_ =
- module->config()
- .debug_options()
- .xla_gpu_require_complete_aot_autotune_results();
- gemm_config_sets_.clear();
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_RETURN_IF_ERROR(computation->Accept(this));
- }
- return std::move(gemm_config_sets_);
- }
-
- AutoTuneCacheKeyCount GetFusionsCount() {
- return std::move(fusion_count_map_);
- }
-
- absl::Status HandleFusion(const HloInstruction* hlo) override {
- const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- hlo->backend_config<GpuBackendConfig>());
- const FusionBackendConfig& backend_config =
- gpu_config.fusion_backend_config();
-
- AutotuneCacheKey key = AutotunerUtil::GetKey(hlo, impl_->GetConfig());
-
- auto [iterator, inserted] = fusion_count_map_.insert({key, 1});
- if (!inserted) {
- ++(iterator->second);
- }
-
- TF_ASSIGN_OR_RETURN(bool is_in_cache,
- AutotunerUtil::IsInCache(key, impl_->GetConfig()));
- if (is_in_cache || handled_fusions_.contains(key)) {
- return absl::OkStatus();
- }
-
- bool missing_config = (backend_config.kind() == kTritonGemmFusionKind &&
- !backend_config.has_triton_gemm_config()) ||
- (backend_config.kind() == kCuDnnFusionKind &&
- !backend_config.has_cudnn_fusion_config());
- if (missing_config) {
- if (error_out_on_cache_miss_) {
- return absl::NotFoundError(absl::StrCat(
- "Complete autotuning results are required, but no cache result "
- "found for key: ",
- key.ToString()));
- }
-
- TF_ASSIGN_OR_RETURN(std::vector<Config> configs,
- impl_->GenerateConfigs(*fusion));
- gemm_config_sets_.push_back({fusion, std::move(configs)});
- }
-
- handled_fusions_.insert(key);
- return absl::OkStatus();
- }
-
- absl::Status DefaultAction(const HloInstruction* hlo) override {
- return absl::OkStatus();
- }
-
- private:
- bool error_out_on_cache_miss_;
- GemmFusionAutotunerImpl* impl_;
- TilingConfigs gemm_config_sets_;
- AutoTuneCacheKeyCount fusion_count_map_;
- absl::flat_hash_set<AutotuneCacheKey> handled_fusions_;
-};
-
-struct TileSizeLimit {
- int block_m = 0;
- int block_n = 0;
- int block_k = 0;
-};
-
-absl::StatusOr<TileSizeLimit> GetLimits(const HloDotInstruction& dot) {
- TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_lhs,
- NonContractingDimensionIndex(dot, /*operand_number=*/0));
- TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_rhs,
- NonContractingDimensionIndex(dot, /*operand_number=*/1));
- TF_ASSIGN_OR_RETURN(int64_t contracting_index,
- ContractingDimensionIndex(dot, /*operand_number=*/1));
- // This is not a sharp upper limit; the actual m value can be much smaller,
- // depending on how much of the m dimension is physically contiguous.
- // TODO(tdanyluk): Get the exact m value by running a TritonFusionAnalysis.
- const int max_m = tsl::NextPowerOfTwoS64(
- dot.operand(0)->shape().dimensions(non_contracting_index_lhs));
- // Theoretically the same holds for n as for m, but that is not possible in
- // practice with the current implementation.
- const int max_n = tsl::NextPowerOfTwoS64(
- dot.operand(1)->shape().dimensions(non_contracting_index_rhs));
- // This is before doing the split-k transform.
- const int max_k = tsl::NextPowerOfTwoS64(
- dot.operand(1)->shape().dimensions(contracting_index));
-
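- // E.g., a hypothetical [300, 70] x [70, 5] dot yields block_m = 512,
- // block_k = 128, and block_n = max(8, kMinTileSize) = 16.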
- return TileSizeLimit{
- /*block_m=*/std::max(max_m, kMinTileSize),
- /*block_n=*/std::max(max_n, kMinTileSize),
- /*block_k=*/std::max(max_k, kMinTileSize),
- };
-}
-
-int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; }
-
-absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
- const TritonGemmConfig& config,
- const se::DeviceDescription& gpu_device_info,
- const HloFusionInstruction* fusion, DebugOptions debug_opts,
- bool allow_filtering_kernels_spilling_registers) {
- std::unique_ptr<HloModule> new_module =
- ExtractInstructionIntoNewModule(*fusion);
- if (!allow_filtering_kernels_spilling_registers) {
- debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
- false);
- }
- new_module->mutable_config().set_debug_options(debug_opts);
-
- HloComputation* entry_computation = new_module->entry_computation();
- HloInstruction* cloned_dot_fusion = entry_computation->root_instruction();
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- cloned_dot_fusion->backend_config<GpuBackendConfig>());
- FusionBackendConfig& backend_config =
- *gpu_config.mutable_fusion_backend_config();
-
- *backend_config.mutable_triton_gemm_config() = config.ToProto();
- TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(gpu_config));
-
- if (config.split_k > 1) {
- TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config));
- GpuFloatSupport bf16_support(gpu_device_info.cuda_compute_capability(),
- BF16);
- FloatNormalization float_normalization(&bf16_support);
- TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status());
-
- auto shape_size_function = [&](const Shape& shape) {
- // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the
- // pointer size is used only to determine the size of tuple types. We
- // shouldn't have any tuples in the autotuned module, so it's safe to use
- // a constant here, instead of piping the real value.
- constexpr int64_t kPointerSize = 8;
- return ShapeUtil::ByteSizeOf(shape, kPointerSize);
- };
- GpuPriorityFusion priority_fusion(
- /*thread_pool=*/nullptr, gpu_device_info,
- GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function,
- /*per_second_rates=*/{},
- /*count_multiple_input_accesses=*/true});
- TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status());
-
- // If the priority fusion pass above skipped some instructions, turn them
- // into fusions.
- FusionWrapper fusion_wrapper;
- TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status());
- }
- return new_module;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> CublasGemmAutotuneExtractor(
- const AutotuneConfig& config, const int32_t toolkit_version,
- const HloFusionInstruction* fusion, const DebugOptions& debug_opts) {
- const HloComputation* fusion_computation =
- fusion->called_computations().at(0);
- std::unique_ptr<HloModule> new_module =
- ExtractComputationIntoNewModule(*fusion_computation);
- new_module->mutable_config().set_debug_options(debug_opts);
-
- auto* dot = hlo_query::GetFirstInstructionWithOpcode(
- *new_module->entry_computation(), HloOpcode::kDot);
- // Substitute algorithms that are not supported by cuBLAS for the check, but
- // don't use cuBLAS in the end. This assumes that the substituted algorithm
- // produces results that are close enough for the check in this file.
- if (dot->precision_config().algorithm() ==
- PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
- dot->precision_config().algorithm() ==
- PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) {
- dot->mutable_precision_config()->set_algorithm(
- PrecisionConfig::ALG_DOT_F32_F32_F32);
- }
-
- for (bool fp8 : {true, false}) {
- GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version,
- fp8);
- GpuInstructionFusion fusion_pass(
- /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription());
- TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status());
- TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status());
- }
- // TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS
- // performance. It is probably not needed on Ampere and later because cuBLAS
- // ignores the algorithm parameter for those targets. If we run
- // GemmAlgorithmPicker, we probably should not run this in parallel with other
- // compilations.
- return new_module;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> FusionExtractor(
- const HloFusionInstruction& fusion, const DebugOptions& debug_opts) {
- std::unique_ptr<HloModule> module = ExtractInstructionIntoNewModule(fusion);
- module->mutable_config().set_debug_options(debug_opts);
- return module;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> CuDnnFusionExtractor(
- const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
- const int plan_id) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- FusionExtractor(fusion, debug_opts));
-
- GpuBackendConfig gpu_config;
- FusionBackendConfig& backend_config =
- *gpu_config.mutable_fusion_backend_config();
- backend_config.set_kind(std::string(kCuDnnFusionKind));
- // Provided a plan ID, the autotuner compiles just that one plan.
- backend_config.mutable_cudnn_fusion_config()->set_plan_id(plan_id);
- TF_RETURN_IF_ERROR(
- module->entry_computation()->root_instruction()->set_backend_config(
- gpu_config));
- return module;
-}
-
-bool IsFusionKind(const HloInstruction& hlo, absl::string_view kind) {
- auto gpu_config = hlo.backend_config<GpuBackendConfig>();
- if (!gpu_config.ok()) {
- return false;
- }
- return gpu_config->fusion_backend_config().kind() == kind;
-}
-
-int GetCuDnnPlanCount(const HloInstruction& hlo,
- const AutotuneConfig& autotune_config) {
- if (auto gpu_config = hlo.backend_config<GpuBackendConfig>();
- !gpu_config.ok() ||
- gpu_config->fusion_backend_config().has_cudnn_fusion_config()) {
- return {};
- }
- return CuDnnFusionCompiler::GetAvailablePlanCount(
- *autotune_config.GetExecutor(), *DynCast<HloFusionInstruction>(&hlo));
-}
-
-AutotuneResult FromConfig(const Config& config) {
- AutotuneResult res;
- if (std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(config)) {
- res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT);
- } else if (std::holds_alternative<GemmFusionAutotunerImpl::CuDnnConfig>(
- config)) {
- res.mutable_algorithm()->set_algo_id(
- std::get<GemmFusionAutotunerImpl::CuDnnConfig>(config).plan_id);
- } else if (std::holds_alternative<TritonGemmConfig>(config)) {
- *res.mutable_triton() = std::get<TritonGemmConfig>(config).ToProto();
- } else {
- LOG(FATAL) << "Unsupported config type: " << config.index();
- }
- return res;
-}
-
-absl::Status DumpOriginalFusion(AutotunerCompileUtil& util,
- const HloFusionInstruction& fusion,
- int fusion_id) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- util.ExtractModule([&](const DebugOptions& debug_opts) {
- return FusionExtractor(fusion, debug_opts);
- }));
- module->set_name(std::string(fusion.name()));
- // Using the original module for its debug info and name in the first
- // parameter. It's better to include the name of both the original module
- // and the extracted module, to avoid name clashes.
- DumpToFileInDirOrStdout(
- /*module=*/*fusion.GetModule(),
- /*file_prefix=*/"",
- /*file_suffix=*/
- absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".txt"),
- /*contents=*/module->ToString());
- return absl::OkStatus();
-}
-
-absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config,
- const int32_t toolkit_version,
- AutotunerCompileUtil& util,
- const AutotuneResult result,
- const HloFusionInstruction* fusion,
- int fusion_id) {
- TritonGemmConfig triton_gemm_config;
- if (result.has_triton()) {
- TF_ASSIGN_OR_RETURN(triton_gemm_config,
- TritonGemmConfig::FromProto(result.triton()));
- }
- const se::DeviceDescription& device_desc =
- autotune_config.GetExecutor()->GetDeviceDescription();
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModule> module,
- util.ExtractModule([&](const DebugOptions& debug_opts) {
- if (result.has_algorithm()) {
- return CuDnnFusionExtractor(*fusion, debug_opts,
- result.algorithm().algo_id());
- } else if (result.has_triton()) {
- return TritonGemmAutotuneExtractor(
- triton_gemm_config, device_desc, fusion, debug_opts,
- /*allow_filtering_kernels_spilling_registers=*/true);
- } else if (result.has_gemm()) {
- return CublasGemmAutotuneExtractor(autotune_config, toolkit_version,
- fusion, debug_opts);
- } else {
- LOG(FATAL) << "Unknown result type: " << result.DebugString();
- }
- }));
- module->set_name(std::string(fusion->name()));
- // Using the original module for its debug info and name in the first
- // parameter. It's better to include the name of both the original module
- // and the extracted module, to avoid name clashes.
- DumpToFileInDirOrStdout(
- /*module=*/*fusion->GetModule(),
- /*file_prefix=*/"",
- /*file_suffix=*/
- absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(),
- ".optimized.txt"),
- /*contents=*/module->ToString());
- return absl::OkStatus();
-}
-
-std::string Serialize(const Config& config) {
- if (auto triton_config = std::get_if<TritonGemmConfig>(&config)) {
- tsl::protobuf::TextFormat::Printer printer;
- printer.SetSingleLineMode(true);
- std::string result;
- printer.PrintToString(triton_config->ToProto(), &result);
- return result;
- }
- return GemmFusionAutotunerImpl::ToString(config);
-}
-
-} // anonymous namespace
-
-// Methods required for sorting the configs.
-bool GemmFusionAutotunerImpl::CuBlasConfig::operator<(
- const CuBlasConfig& other) const {
- return false;
-}
-bool GemmFusionAutotunerImpl::CuDnnConfig::operator<(
- const CuDnnConfig& other) const {
- return plan_id < other.plan_id;
-}
-
-bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const {
- return debug_options_.xla_gpu_autotune_level() > 0 &&
- !debug_options_.xla_gpu_deterministic_ops();
-}
-
-/*static*/ std::string GemmFusionAutotunerImpl::ToString(const Config& config) {
- if (std::holds_alternative<TritonGemmConfig>(config)) {
- return std::get<TritonGemmConfig>(config).ToString();
- } else if (std::holds_alternative<CuDnnConfig>(config)) {
- return absl::StrFormat("cuDNN plan %d",
- std::get<CuDnnConfig>(config).plan_id);
- } else if (std::holds_alternative<CuBlasConfig>(config)) {
- return "reference (cublas)";
- } else {
- LOG(FATAL) << "Unsupported config type: " << config.index();
- }
-}
-
-absl::StatusOr<std::vector<Config>> GemmFusionAutotunerImpl::GenerateConfigs(
- const HloFusionInstruction& fusion) {
- const HloDotInstruction* dot =
- Cast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
- *fusion.called_computations().at(0), HloOpcode::kDot));
-
- // Add cuBLAS reference config, if available.
- std::vector<Config> configs;
- if (algorithm_util::IsSupportedByCublasOrCublasLt(
- dot->precision_config().algorithm()) &&
- !dot->sparse_operands() && IsAutotuningEnabled()) {
- configs.push_back(CuBlasConfig{});
- }
-
- // Add cuDNN plans, if available.
- bool is_hopper =
- !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
- bool is_cudnn_enabled =
- debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
- GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
- if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
- (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
- algorithm_util::IsSupportedByCudnn(
- dot->precision_config().algorithm()) &&
- !dot->sparse_operands() && IsAutotuningEnabled())) {
- const int plan_count = GetCuDnnPlanCount(fusion, config_);
- for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
- configs.push_back(CuDnnConfig{plan_id});
- }
- }
- if (IsFusionKind(fusion, kCuDnnFusionKind)) {
- if (!IsAutotuningEnabled()) {
- configs.push_back(CuDnnConfig{-1});
- }
- return configs;
- }
-
- // Add triton configs.
- TF_ASSIGN_OR_RETURN(std::vector<TritonGemmConfig> triton_configs,
- GenerateTritonConfigs(*dot));
- for (TritonGemmConfig& config : triton_configs) {
- configs.push_back(std::move(config));
- }
- return configs;
-}
-
-absl::StatusOr<std::vector<TritonGemmConfig>>
-GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
- // Retrieve the minimum bit-width participating in the dot. This is needed
- // to avoid autotuning configurations that are not supported by Triton and
- // is used to restrict the values for block_k.
- std::vector<const HloInstruction*> converts =
- HloBfsFindAll({&dot}, [&](const HloInstruction* node) {
- return node->opcode() == HloOpcode::kConvert;
- });
- int minBitWidth = primitive_util::BitWidth(dot.shape().element_type());
- for (auto convert : converts) {
- auto in_type = convert->operand(0)->shape().element_type();
- auto out_type = convert->shape().element_type();
- minBitWidth = std::min({minBitWidth, primitive_util::BitWidth(in_type),
- primitive_util::BitWidth(out_type)});
- }
-
- std::vector<TritonGemmConfig> result_configs;
- TF_ASSIGN_OR_RETURN(TileSizeLimit limits, GetLimits(dot));
-
- // Generate the list of configurations (once).
- if (triton_configs_.empty()) {
- triton_configs_ = !IsAutotuningEnabled()
- ? std::vector(1, kDefaultGemmTiling)
- : debug_options_.xla_gpu_exhaustive_tiling_search()
- ? GetExhaustiveTritonConfigs()
- : GetDefaultTritonConfigs();
- }
-
- // Avoid autotuning tiny fusions.
- constexpr int kMinGemmElements = 32 * 32;
- bool small_dot =
- ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements &&
- ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements;
- std::vector<TritonGemmConfig> triton_configs =
- small_dot ? std::vector(1, kDefaultGemmTiling) : triton_configs_;
-
- // Split-K optimization enables more even utilization of a GPU in cases
- // where tiling just the non-contracting dimensions of a GEMM does not create
- // a sufficient number of thread block programs to occupy all available cores.
- // Around 5 full waves completely avoid the need for split-K.
- // n_tiles = split_k * (M * N) / (block_m * block_n)
- const int kCoreCount =
- !config_.IsDeviceless()
- ? config_.GetExecutor()->GetDeviceDescription().core_count()
- : 100; // some sensible default
- const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount;
- const int64_t result_size = ShapeUtil::ElementsIn(dot.shape());
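- // Illustrative arithmetic for the loop below, with a hypothetical
- // core_count of 108: kSufficientNumberOfTiles = 5 * 108 = 540. For a
- // 512x512 output and 64x64 output tiles, ratio = 540 * 64 * 64 /
- // (512 * 512) = 8, so max_split_k becomes 8; for a 4096x4096 output the
- // ratio rounds down to 0 and max_split_k stays at 1.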
-
- // Triton configurations are adjusted and deduplicated.
- absl::flat_hash_set<TritonGemmConfig> added;
- bool is_hopper =
- !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
- for (TritonGemmConfig& config : triton_configs) {
- config.block_m = std::min(config.block_m, limits.block_m);
- config.block_n = std::min(config.block_n, limits.block_n);
- config.block_k = std::min(config.block_k, limits.block_k);
- int max_split_k = 1;
- if (debug_options_.xla_gpu_enable_split_k_autotuning()) {
- int64_t ratio = kSufficientNumberOfTiles * config.block_m *
- config.block_n / result_size;
- max_split_k = 1 << std::max<int>(tsl::Log2Floor64(ratio), 0);
- }
- config.split_k = std::min(config.split_k, max_split_k);
-
- // TODO(b/337839570): Triton currently has a limitation where it crashes
- // on small block_k values depending on the bit-width of the inputs to the
- // dot. The logic below accounts for this limitation.
- constexpr int kLdmatrixGranularity = 256;
- config.block_k =
- std::max(config.block_k, kLdmatrixGranularity / minBitWidth);
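- // E.g., with hypothetical 8-bit inputs (minBitWidth = 8) block_k is raised
- // to at least 256 / 8 = 32; with 16-bit inputs to at least 256 / 16 = 16.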
-
- // Sparse meta should have at least one element per thread.
- // Note: only 2:4 structured sparsity is currently supported.
- if (dot.sparse_operands()) {
- if (is_hopper) {
- config.block_m = std::max(config.block_m, 64);
- config.num_warps = std::max(config.num_warps, 4);
- }
- config.block_k = std::max(
- config.block_k,
- 2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
- int meta_elements = config.block_m * config.block_k / 16;
- config.num_warps =
- std::min<int>(config.num_warps, meta_elements / WarpSize());
- }
-
- if (added.insert(config).second) {
- result_configs.push_back(config);
- }
- }
- return result_configs;
-}
-
-absl::StatusOr<absl::flat_hash_map<
- const HloFusionInstruction*,
- std::vector<GemmFusionAutotunerImpl::ExecutableCandidate>>>
-GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
- const TilingConfigs& task) {
- tsl::profiler::ScopedAnnotation annotation("XlaAutotunerCompilation");
- absl::Mutex results_mu;
- absl::flat_hash_map<const HloFusionInstruction*,
- std::vector<ExecutableCandidate>>
- results;
- if (task.empty()) {
- return results;
- }
-
- const int log_every_n = GetLogEveryN();
- int64_t config_count = 0;
- for (const auto& [unused, configs] : task) {
- config_count += configs.size();
- }
-
- std::atomic<int> done_count = 0;
- std::atomic<int> good_count = 0;
- auto log = [&](bool success) {
- const int done_so_far = done_count.fetch_add(1) + 1;
- const int good_so_far =
- success ? good_count.fetch_add(1) + 1 : good_count.load();
- if (done_so_far % log_every_n == 0) {
- VLOG(2) << "Compiled " << done_so_far << " of " << config_count
- << " configs (successful: " << good_so_far << ")";
- }
- };
-
- auto compile = [&](const HloFusionInstruction* fusion, const Config& config,
- bool allow_filtering_kernels_spilling_registers)
- -> absl::StatusOr<bool> {
- std::unique_ptr<Executable> executable;
- if (std::holds_alternative<TritonGemmConfig>(config)) {
- TF_ASSIGN_OR_RETURN(
- executable, compile_util.Compile([&](const DebugOptions& opts) {
- return TritonGemmAutotuneExtractor(
- std::get<TritonGemmConfig>(config),
- config_.GetExecutor()->GetDeviceDescription(), fusion, opts,
- allow_filtering_kernels_spilling_registers);
- }));
- } else if (std::holds_alternative<CuDnnConfig>(config)) {
- executable =
- compile_util
- .Compile([&](const DebugOptions& opts) {
- return CuDnnFusionExtractor(
- *fusion, opts, std::get<CuDnnConfig>(config).plan_id);
- })
- .value_or(nullptr);
- } else if (std::holds_alternative<CuBlasConfig>(config)) {
- TF_ASSIGN_OR_RETURN(executable,
- compile_util.Compile([&](const DebugOptions& opts) {
- return CublasGemmAutotuneExtractor(
- config_, toolkit_version_, fusion, opts);
- }));
- } else {
- LOG(FATAL) << "Unsupported config type: " << config.index();
- }
- if (executable != nullptr) {
- absl::MutexLock lock(&results_mu);
- results[fusion].push_back({config, std::move(executable)});
- return true;
- }
- return false;
- };
-
- // If the thread pool has only one thread, then it is actually slower to
- // offload the tasks there.
- if (thread_pool_ && thread_pool_->NumThreads() > 1 &&
- debug_options_.xla_gpu_force_compilation_parallelism() != 1) {
- if (task.size() == 1) {
- absl::string_view fusion_name = task.begin()->first->name();
- VLOG(1) << "Compiling " << config_count << " configs for " << fusion_name
- << " on " << thread_pool_->NumThreads() << " threads.";
- } else {
- VLOG(1) << "Compiling " << config_count << " configs for " << task.size()
- << " fusions on " << thread_pool_->NumThreads() << " threads.";
- }
-
- tsl::BlockingCounter counter(config_count);
- for (const auto& key_value : task) {
- const HloFusionInstruction* fusion = key_value.first;
- const std::vector<Config>& gemm_config_set = key_value.second;
-
- VLOG(10) << "Compiling fusion: " << fusion->name();
- VLOG(10) << "Dumping fusion computation: "
- << fusion->called_computation()->ToString();
- for (const Config& config : gemm_config_set) {
- thread_pool_->Schedule([&, fusion] {
- VLOG(10) << "Trying configuration forceable through: "
- "--xla_gpu_override_gemm_autotuner='"
- << Serialize(config) << "'";
- VLOG(10) << "WARNING: you are running in multithreaded-mode, the "
- "last configuration printed out might not be the one "
- "causing issues! Use "
- "--xla_gpu_force_compilation_parallelism=1 to fix.";
- absl::StatusOr<bool> has_executable =
- compile(fusion, config, gemm_config_set.size() > 1);
- TF_CHECK_OK(has_executable.status())
- << "Failure occured when compiling fusion " << fusion->name()
- << " with config '" << ToString(config)
- << "'\nFused HLO computation:\n"
- << fusion->fused_instructions_computation()->ToString();
- log(has_executable.value());
- counter.DecrementCount();
- });
- }
- }
- counter.Wait();
- } else {
- if (task.size() == 1) {
- absl::string_view fusion_name = task.begin()->first->name();
- LOG(WARNING) << "Compiling " << config_count << " configs for "
- << fusion_name << " on a single thread.";
- } else {
- LOG(WARNING) << "Compiling " << config_count << " configs for "
- << task.size() << " fusions on a single thread.";
- }
-
- for (const auto& [fusion, gemm_config_set] : task) {
- VLOG(10) << "Compiling fusion: " << fusion->name();
- VLOG(10) << "Dumping fusion computation: "
- << fusion->called_computation()->ToString();
- for (const Config& config : gemm_config_set) {
- VLOG(10) << "Trying configuration forceable through: "
- "--xla_gpu_override_gemm_autotuner='"
- << Serialize(config) << "'";
- TF_ASSIGN_OR_RETURN(
- bool has_executable,
- compile(fusion, config, gemm_config_set.size() > 1));
- log(has_executable);
- }
- }
- }
-
- VLOG(1) << "Done compiling (successful: " << good_count.load() << ").";
- return results;
-}
-
-absl::StatusOr<std::vector<AutotuneResult>> GemmFusionAutotunerImpl::Profile(
- AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
- absl::Span<const ExecutableCandidate> candidates) {
- const HloComputation* fusion_computation = fusion.called_computations().at(0);
-
- se::StreamExecutor* stream_exec = config_.GetExecutor();
- if (!stream_exec->SynchronizeAllActivity()) {
- return Internal("Failed to synchronize GPU for autotuning.");
- }
- tsl::profiler::ScopedAnnotation annotation([&] {
- return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
- fusion.name());
- });
- se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
- std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
- if (allocator == nullptr) {
- owned_allocator =
- std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
- allocator = owned_allocator.get();
- }
- TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
-
- const HloInstruction& root = *fusion_computation->root_instruction();
- BufferComparator comparator(root.shape(),
- debug_options_.xla_gpu_autotune_gemm_rtol());
-
- TF_ASSIGN_OR_RETURN(auto rz_buffers,
- RedzoneBuffers::FromInstruction(
- *fusion_computation->FusionInstruction(), config_,
- debug_options_, RedzoneBuffers::kAllInputs));
-
- const int log_every_n = GetLogEveryN();
- std::vector<AutotuneResult> results;
- std::optional<ScopedShapedBuffer> reference_buffer;
- for (const ExecutableCandidate& candidate : candidates) {
- VLOG(5) << "Trying : " << ToString(candidate.config);
- AutotuneResult res = FromConfig(candidate.config);
-
- std::optional<ProfilingOutput> profiling_output;
- if (IsAutotuningEnabled()) {
- TF_ASSIGN_OR_RETURN(
- profiling_output,
- compile_util.ProfileExecutable(candidate.executable.get(), stream,
- rz_buffers.input_buffers(),
- rz_buffers.input_shapes()));
- if (std::holds_alternative<CuBlasConfig>(candidate.config) &&
- config_.should_check_correctness()) {
- reference_buffer = std::move(profiling_output->output);
- }
-
- int ran_so_far = results.size() + 1;
- if (ran_so_far % log_every_n == 0) {
- VLOG(2) << "Ran " << ran_so_far << " configs of " << candidates.size()
- << ".";
- }
- if (!profiling_output) {
- VLOG(5) << "Skipping this tiling.";
- continue;
- }
-
- VLOG(5) << "Running the kernel took: " << profiling_output->duration;
- if (profiling_output->duration >= absl::Seconds(1)) {
- LOG(WARNING) << "Slow kernel for "
- << fusion.called_computations()[0]->ToString()
- << " took: " << profiling_output->duration << ". "
- << ToString(candidate.config);
- }
- *res.mutable_run_time() =
- tsl::proto_utils::ToDurationProto(profiling_output->duration);
- }
-
- // The reference buffer is available when `config.should_check_correctness()`
- // is set and the reference executable was compiled.
- if (reference_buffer.has_value() &&
- !std::holds_alternative<CuBlasConfig>(candidate.config)) {
- TF_ASSIGN_OR_RETURN(
- se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
- rz_buffers.RedzoneAllocator().CheckRedzones());
- if (!rz_check_status.ok()) {
- LOG(ERROR) << "Red zone modified";
- res.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
- res.mutable_failure()->set_msg(rz_check_status.RedzoneFailureMsg());
- CHECK(!config_.should_crash_on_check_failure());
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(
- bool outputs_match,
- comparator.CompareEqual(
- stream, /*current=*/profiling_output->output.root_buffer(),
- /*expected=*/reference_buffer->root_buffer()));
- if (!outputs_match) {
- const char kMessage[] =
- "Results do not match the reference. This is likely a "
- "bug/unexpected loss of precision.";
- LOG(ERROR) << kMessage;
- CHECK(!config_.should_crash_on_check_failure());
- // WRONG_RESULT is not taken seriously by PickBestResult(), so
- // use DISQUALIFIED.
- res.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
- res.mutable_failure()->set_msg(kMessage);
- }
- }
- results.push_back(std::move(res));
- }
- VLOG(2) << "Done running.";
- return results;
-}
-
-std::vector<TritonGemmConfig>
-GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
- std::vector<TritonGemmConfig> configs;
- se::CudaComputeCapability cc = GetComputeCapability();
- bool tune_ctas =
- debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
-
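- // Upper bound on the search space spanned by the loops below: 6^3 block
- // size combinations * 4 stage counts * 4 warp counts * 5 split-K values *
- // 5 CTA counts = 86,400 candidate configs, before any pruning.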
- for (int num_stages : kNumStages) {
- // Volta doesn't support num_stages > 2.
- if (!cc.IsAtLeastAmpere() && num_stages > 2) {
- break;
- }
- for (int tile_m : kBlockSizes) {
- for (int tile_n : kBlockSizes) {
- for (int tile_k : kBlockSizes) {
- const int tile_lhs = tile_m * tile_k;
- const int tile_rhs = tile_k * tile_n;
- for (int num_warps : kNumWarps) {
- // Each thread should read at least one input element.
- if (num_warps * WarpSize() > std::min(tile_lhs, tile_rhs)) {
- break;
- }
- for (int split_k : kSplitK) {
- // Split-K autotuning may be disabled by a flag.
- if (!debug_options_.xla_gpu_enable_split_k_autotuning() &&
- split_k > 1) {
- break;
- }
- for (int num_ctas : kNumCtas) {
- // Clusters are only supported on Hopper.
- // Autotuning this parameter is enabled by a flag.
- if (!tune_ctas && num_ctas > 1) {
- break;
- }
- if (num_ctas > num_warps) {
- break;
- }
- configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
- split_k, num_stages,
- num_warps, num_ctas));
- }
- }
- }
- }
- }
- }
- }
- return configs;
-}
-
-std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
- const {
- using Config = TritonGemmConfig;
- std::vector<Config> configs = {
- Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
- Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4),
- Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
- Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
- Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
- Config(64, 32, 64, 1, 2, 8)};
- if (GetComputeCapability().IsAtLeastAmpere()) {
- absl::c_copy(
- std::vector<Config>{
- Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8),
- Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4),
- Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4),
- Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4),
- Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4),
- Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4),
- Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4),
- Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4),
- Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8),
- Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4),
- Config(256, 256, 128, 1, 3, 8)},
- std::back_inserter(configs));
- }
- if (GetComputeCapability().IsAtLeastHopper()) {
- absl::c_copy(
- std::vector<Config>{
- Config(16, 32, 32, 8, 1, 2),
- Config(16, 64, 128, 8, 1, 4),
- Config(16, 64, 128, 16, 3, 4),
- },
- std::back_inserter(configs));
- }
- return configs;
-}
-
-absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts,
- const AutotuningLogs& autotuning_logs) {
- if (absl::string_view file_path = debug_opts.xla_gpu_dump_autotune_logs_to();
- !file_path.empty()) {
- std::string resolved_path;
- if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
- return FailedPrecondition("File path can not be resolved: %s", file_path);
- }
-
- std::string textproto;
- tsl::protobuf::TextFormat::PrintToString(autotuning_logs, &textproto);
-
- TF_RETURN_IF_ERROR(
- tsl::WriteStringToFile(tsl::Env::Default(), resolved_path, textproto));
- LOG(INFO) << "Autotune logs serialized to file: " << resolved_path;
- }
- return absl::OkStatus();
-}
-
-absl::Status GemmFusionAutotunerImpl::Autotune(
- AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
- AutoTuneCacheKeyCount fusion_count_map) {
- TF_ASSIGN_OR_RETURN(auto executable_sets,
- CompileAll(compile_util, gemm_config_sets));
-
- // Sort the candidates to make their execution order well-defined for each
- // fusion.
- for (auto& [unused, candidates] : executable_sets) {
- absl::c_sort(candidates, [](const auto& a, const auto& b) {
- return a.config < b.config;
- });
- }
-
- AutotuningLogs autotuning_logs;
- int fusion_id = 0;
- for (const auto& [fusion, candidates] : executable_sets) {
- TF_ASSIGN_OR_RETURN(std::vector<AutotuneResult> results,
- Profile(compile_util, *fusion, candidates));
-
- // The reference config (if it exists) will be the first in the results:
- // std::variant's operator< orders by alternative index first, and
- // CuBlasConfig is the first alternative of Config.
- if (!debug_options_.xla_gpu_cublas_fallback() &&
- results.front().has_gemm()) {
- results.erase(results.begin());
- }
-
- const HloInstruction* root =
- fusion->called_computations().at(0)->root_instruction();
- TF_ASSIGN_OR_RETURN(
- AutotuneResult best,
- PickBestResult(results, root->ToString(), root->GetModule()->config()));
- VLOG(2) << "Best time: "
- << tsl::proto_utils::FromDurationProto(best.run_time());
-
- if (debug_options_.xla_gpu_dump_autotuned_gemm_fusions()) {
- TF_RETURN_IF_ERROR(DumpOriginalFusion(compile_util, *fusion, fusion_id));
- TF_RETURN_IF_ERROR(DumpAutotunedFusion(
- config_, toolkit_version_, compile_util, best, fusion, fusion_id++));
- }
-
- const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
- TF_ASSIGN_OR_RETURN(
- bool added, AutotunerUtil::AddResult(key, std::move(best), config_));
- if (!added) {
- // In the context of a model server, concurrent autotuning is expected, and
- // insertion of identical autotuning keys is accepted.
- LOG(WARNING) << "AutotunerUtil::AddResult already existed: "
- << key.ToString();
- }
-
- if (!debug_options_.xla_gpu_dump_autotune_logs_to().empty()) {
- auto autotuning_log = autotuning_logs.add_logs();
- autotuning_log->set_fusion_name(std::string(fusion->name()));
-
- for (const auto& autotune_result : results) {
- auto log_result = autotuning_log->add_results();
- log_result->CopyFrom(autotune_result);
- }
-
- if (auto fusion_key_count = fusion_count_map.find(key);
- fusion_key_count != fusion_count_map.end()) {
- auto fusion_key = fusion_key_count->first;
- auto fusion_count = fusion_key_count->second;
- autotuning_log->set_fusion_count(fusion_count);
- }
- }
- }
-
- TF_RETURN_IF_ERROR(DumpAutotuningLogs(debug_options_, autotuning_logs));
-
- return absl::OkStatus();
-}
-
-// Trim the set of configs to what one rank has to run.
-static TilingConfigs TrimConfigs(const TilingConfigs& gemm_config_sets,
- const int shard_index, const int shard_count) {
- const uint64_t bucket_size =
- (gemm_config_sets.size() + shard_count - 1) / shard_count;
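- // E.g., with a hypothetical 10 fusions and 4 shards, bucket_size is
- // (10 + 3) / 4 = 3 and the shards receive the ranges [0,3), [3,6), [6,9)
- // and [9,10) respectively.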
- const uint64_t start = bucket_size * shard_index;
- const uint64_t end = std::min(start + bucket_size, gemm_config_sets.size());
- if (start >= end) {
- return {};
- }
- return TilingConfigs(gemm_config_sets.cbegin() + start,
- gemm_config_sets.cbegin() + end);
-}
-
-// Exchange the results with the other ranks.
-absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store,
- const int module_id, const int shard_index,
- const int shard_count) {
- AutotuneResults results;
- TF_RETURN_IF_ERROR(AutotunerUtil::SerializeAutotuneResults(&results));
- TF_ASSIGN_OR_RETURN(std::string results_str,
- AutotuneResultsToString(results, true));
- constexpr absl::string_view kKeyPrefix = "gemm_fusion_autotuning_results";
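- // The resulting keys have the form
- // "gemm_fusion_autotuning_results_<module_id>_<shard_index>".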
- TF_RETURN_IF_ERROR(key_value_store.Set(
- absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, shard_index),
- results_str));
- VLOG(2) << "Rank " << shard_index << ": published results";
- for (int i = 0; i < shard_count; ++i) {
- if (i == shard_index) {
- continue;
- }
- VLOG(2) << "Rank " << shard_index << ": waiting for results from rank " << i
- << " / " << shard_count;
- TF_ASSIGN_OR_RETURN(
- std::string autotune_results_str,
- key_value_store.Get(
- absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, i),
- absl::InfiniteDuration()));
- TF_RETURN_IF_ERROR(
- AutotunerUtil::LoadAutotuneResults(autotune_results_str, true));
- }
- return absl::OkStatus();
-}
-
-absl::StatusOr<bool> GemmFusionAutotuner::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_SCOPED_LOGGING_TIMER("GEMM fusion autotuner");
-
- const DebugOptions& debug_options = module->config().debug_options();
- GemmFusionAutotunerImpl autotuner(config_, toolkit_version_, debug_options,
- thread_pool_);
- GemmConfigSetCollector gemm_config_set_collector(&autotuner);
- TF_ASSIGN_OR_RETURN(TilingConfigs gemm_config_sets,
- gemm_config_set_collector.CollectGemmConfigSets(
- module, execution_threads));
- const int total_fusion_count = gemm_config_sets.size();
-
- AutoTuneCacheKeyCount fusion_count_map =
- gemm_config_set_collector.GetFusionsCount();
-
- if (!autotuner.IsAutotuningEnabled()) {
- // Pick the first option for each gemm instead of autotuning.
- for (const auto& [fusion, tilings] : gemm_config_sets) {
- const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
- AutotuneResult res = FromConfig(tilings[0]);
- *res.mutable_run_time() =
- tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
- TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
- }
- } else if (!debug_options.xla_gpu_override_gemm_autotuner().empty()) {
- // TODO(gflegar): support overriding with non-Triton configs (cuBLAS, cuDNN)
- AutotuneResult::TritonGemmKey gemm_key;
- CHECK(tsl::protobuf::TextFormat::ParseFromString(
- debug_options.xla_gpu_override_gemm_autotuner(), &gemm_key));
- VLOG(1) << "Overriding GEMM autotuner with the following config: "
- << gemm_key.DebugString();
- for (const auto& [fusion, unused] : gemm_config_sets) {
- const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
- AutotuneResult res;
- *res.mutable_triton() = gemm_key;
- *res.mutable_run_time() =
- tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
- TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
- }
- } else if (!config_.IsDeviceless()) {
- TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
- AutotunerCompileUtil::Create(config_, debug_options));
- TF_RET_CHECK(opt_compile_util.has_value());
- std::string correctness_check_str = config_.should_check_correctness()
- ? "(with correctness check)"
- : "(without correctness check)";
-
- const bool shard_autotuning = debug_options.xla_gpu_shard_autotuning() &&
- key_value_store_.process_count > 1 &&
- total_fusion_count > 0;
- if (shard_autotuning) {
- if (key_value_store_.key_value_store == nullptr) {
- return absl::FailedPreconditionError(
- "Sharded autotuning requested but key-value store is missing.");
- }
- gemm_config_sets =
- TrimConfigs(gemm_config_sets, key_value_store_.process_index,
- key_value_store_.process_count);
- }
-
- VLOG(1) << absl::StrFormat(
- "Shard %d / %d: autotuning %d / %d fusions for %s %s.",
- key_value_store_.process_index + 1, key_value_store_.process_count,
- gemm_config_sets.size(), total_fusion_count, module->name(),
- correctness_check_str);
- TF_RETURN_IF_ERROR(autotuner.Autotune(*opt_compile_util, gemm_config_sets,
- std::move(fusion_count_map)));
- VLOG(1) << "Done autotuning.";
-
- if (shard_autotuning) {
- TF_RETURN_IF_ERROR(ExchangeResults(
- *key_value_store_.key_value_store, module->unique_id(),
- key_value_store_.process_index, key_value_store_.process_count));
- }
- }
-
- return GemmFusionAutotunerVisitor(config_).RunOnModule(module,
- execution_threads);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h
deleted file mode 100644
index 2815792..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h
+++ /dev/null
@@ -1,147 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_
-#define XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-
-// Finds the best tiling configuration for each outlined Triton fusion.
-class GemmFusionAutotuner : public HloModulePass {
- public:
- explicit GemmFusionAutotuner(const AutotuneConfig& config,
- const int32_t toolkit_version,
- tsl::thread::ThreadPool* thread_pool,
- const MultiProcessKeyValueStore& key_value_store)
- : config_(config),
- toolkit_version_(toolkit_version),
- thread_pool_(thread_pool),
- key_value_store_(key_value_store) {}
-
- absl::string_view name() const override { return "triton-autotuner"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const AutotuneConfig config_;
- const int32_t toolkit_version_;
- tsl::thread::ThreadPool* thread_pool_;
- MultiProcessKeyValueStore key_value_store_;
-};
-
-// Autotuner implementation.
-class GemmFusionAutotunerImpl {
- public:
- GemmFusionAutotunerImpl(const AutotuneConfig config,
- const int32_t toolkit_version,
- const DebugOptions debug_options,
- tsl::thread::ThreadPool* thread_pool)
- : config_(std::move(config)),
- toolkit_version_(toolkit_version),
- debug_options_(std::move(debug_options)),
- thread_pool_(thread_pool) {}
-
- struct CuBlasConfig {
- bool operator<(const CuBlasConfig& other) const;
- };
- struct CuDnnConfig {
- int64_t plan_id;
- bool operator<(const CuDnnConfig& other) const;
- };
- using Config = std::variant<CuBlasConfig, CuDnnConfig, TritonGemmConfig>;
- using TilingConfigs =
- std::vector<std::pair<const HloFusionInstruction*, std::vector<Config>>>;
-
- struct ExecutableCandidate {
- Config config;
- std::unique_ptr<Executable> executable;
- };
-
- // Generate all possible configs for a dot operation.
- absl::StatusOr<std::vector<Config>> GenerateConfigs(
- const HloFusionInstruction& fusion);
- absl::StatusOr<std::vector<TritonGemmConfig>> GenerateTritonConfigs(
- const HloDotInstruction& dot);
-
- // Compile all executables for all fusions.
- absl::StatusOr<absl::flat_hash_map<const HloFusionInstruction*,
- std::vector<ExecutableCandidate>>>
- CompileAll(AutotunerCompileUtil& compile_util, const TilingConfigs& task);
-
- // Profile all executables for a fusion.
- absl::StatusOr<std::vector<AutotuneResult>> Profile(
- AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
- absl::Span<const ExecutableCandidate> candidates);
-
- // Autotune and save the results to the autotuning cache.
- absl::Status Autotune(
- AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
- absl::flat_hash_map<AutotuneCacheKey, uint64_t> fusion_count_map);
-
- // Helper methods.
- const AutotuneConfig& GetConfig() const { return config_; }
- bool IsAutotuningEnabled() const;
- static std::string ToString(const Config& config);
-
- private:
- se::CudaComputeCapability GetComputeCapability() const {
- return std::get<se::CudaComputeCapability>(
- config_.GetGpuComputeCapability());
- }
-
- std::vector<TritonGemmConfig> GetDefaultTritonConfigs() const;
- std::vector<TritonGemmConfig> GetExhaustiveTritonConfigs() const;
-
- const AutotuneConfig config_;
- const int32_t toolkit_version_;
- const DebugOptions debug_options_;
- tsl::thread::ThreadPool* thread_pool_;
- std::vector<TritonGemmConfig> triton_configs_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc
deleted file mode 100644
index b0d8ba6..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc
+++ /dev/null
@@ -1,946 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/gemm_fusion_autotuner.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "xla/autotuning.pb.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/service/call_inliner.h"
-#include "xla/service/dump.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gemm_fusion.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_description.pb.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_utils.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/cpu_info.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using HloExtractionTest = HloTestBase;
-
-TEST_F(HloExtractionTest, InstructionExtractionIsCorrect) {
- std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-triton_gemm_dot {
- p0 = s8[10,10] parameter(0)
- p1 = f32[10,10] parameter(1)
- c0 = f32[10,10] convert(p0)
- ROOT dot.0 = f32[10,10] dot(c0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY entry {
- p0 = s8[10,10] parameter(0)
- p1 = f32[10,10] parameter(1)
- s = f32[10,10] sqrt(p1)
- d = f32[10,10] fusion(p0, p1),
- kind=kCustom, calls=triton_gemm_dot
- ROOT r = f32[10,10] add(d, s)
-})")
- .value();
-
- std::unique_ptr<HloModule> extracted_module = ExtractInstructionIntoNewModule(
- *module->entry_computation()->root_instruction()->operand(0));
-
- // Destroy the original module to be sure that the extracted one has no
- // dependency on it.
- module.release();
-
- EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
- EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 3);
- TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
- /*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false));
-}
-
-TEST_F(HloExtractionTest, ComputationExtractionIsCorrect) {
- std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-triton_gemm_dot {
- p0 = s8[10,10] parameter(0)
- p1 = f32[10,10] parameter(1)
- c0 = f32[10,10] convert(p0)
- ROOT dot.0 = f32[10,10] dot(c0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY entry {
- p0 = s8[10,10] parameter(0)
- p1 = f32[10,10] parameter(1)
- s = f32[10,10] sqrt(p1)
- d = f32[10,10] fusion(p0, p1),
- kind=kCustom, calls=triton_gemm_dot
- ROOT r = f32[10,10] add(d, s)
-})")
- .value();
-
- std::unique_ptr<HloModule> extracted_module =
- ExtractComputationIntoNewModule(*module->entry_computation()
- ->root_instruction()
- ->operand(0)
- ->fused_instructions_computation());
-
- // Destroy the original module to be sure that the extracted one has no
- // dependency on it.
- module.release();
-
- EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter())));
- EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 4);
- TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
- /*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false));
-}
-
-class StatelessAutotunerTest : public HloTestBase {
- public:
- StatelessAutotunerTest()
- : HloTestBase(/*verifier_layout_sensitive=*/true,
- /*allow_mixed_precision_in_hlo_verifier=*/false) {}
-
- int32_t GetToolkitVersion() const { return CUDA_VERSION; }
-
- void SetUp() override {
- AutotunerUtil::ClearAutotuneResults();
- HloTestBase::SetUp();
- }
-
- void TearDown() override {
- AutotunerUtil::ClearAutotuneResults();
- HloTestBase::TearDown();
- }
-};
-
-class GemmFusionAutotunerTest : public StatelessAutotunerTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options =
- StatelessAutotunerTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_triton_gemm(true);
- debug_options.set_xla_gpu_cublas_fallback(false);
- debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0);
- return debug_options;
- }
-
- se::CudaComputeCapability GetCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
-
- void CheckTritonAutotuning(absl::string_view hlo,
- absl::string_view expected) {
- HloPassPipeline pipeline("gemm_rewrite");
- pipeline.AddPass<GemmFusion>(backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability());
- tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
- tsl::port::MaxParallelism());
- DebugOptions opts;
- MultiProcessKeyValueStore key_value_store;
- pipeline.AddPass<GemmFusionAutotuner>(
- AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
- backend().memory_allocator()},
- opts},
- GetToolkitVersion(), &thread_pool, key_value_store);
-
- RunAndFilecheckHloRewrite(
- hlo, std::move(pipeline), expected, [](const HloModule* m) {
- VLOG(5) << m->ToString();
- const HloInstruction* dot_fusion =
- m->entry_computation()->root_instruction();
- if (dot_fusion->opcode() == HloOpcode::kReduce) {
- dot_fusion = dot_fusion->operand(0);
- }
- CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion);
- if (!dot_fusion->backend_config<GpuBackendConfig>()
- ->fusion_backend_config()
- .has_cudnn_fusion_config()) {
- CHECK_GT(dot_fusion->backend_config<GpuBackendConfig>()
- .value()
- .fusion_backend_config()
- .triton_gemm_config()
- .block_m(),
- 0);
- }
- });
- }
-};
-
-class GemmFusionAutotunerTestWithMorePreciseReduction
- : public GemmFusionAutotunerTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options =
- GemmFusionAutotunerTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(
- true);
- return debug_options;
- }
-};
-
-absl::StatusOr<std::vector<TritonGemmConfig>> GetPossibleMatmulAutotuneConfigs(
- const HloDotInstruction& dot,
- const se::CudaComputeCapability& compute_capability,
- const int32_t toolkit_version, const DebugOptions& debug_options) {
- se::GpuDeviceInfoProto deviceless_proto;
- auto ccc = deviceless_proto.mutable_cuda_compute_capability();
- ccc->set_major(compute_capability.major);
- ccc->set_minor(compute_capability.minor);
- DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}};
- AutotuneConfig autotune_config{test_config, debug_options};
- GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version,
- debug_options, nullptr);
- return autotuner.GenerateTritonConfigs(dot);
-}
-
-TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) {
- std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[1024,1024] parameter(0)
- p1 = f32[1024,1024] parameter(1)
- ROOT r = f32[1024,1024] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
- .value();
- const se::CudaComputeCapability compute_capability{
- se::CudaComputeCapability::AMPERE, /*minor=*/0};
- TF_ASSERT_OK_AND_ASSIGN(
- const std::vector<TritonGemmConfig> configs,
- GetPossibleMatmulAutotuneConfigs(
- *Cast<HloDotInstruction>(
- module->entry_computation()->root_instruction()),
- compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
- EXPECT_TRUE(std::any_of(
- configs.begin(), configs.end(),
- [](const TritonGemmConfig& config) { return config.num_stages > 2; }));
-}
-
-TEST_F(GemmFusionAutotunerTest, SmallOutputCanUseLargeSplitK) {
- std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[1024,1024] parameter(0)
- p1 = f32[1024,1024] parameter(1)
- ROOT r = f32[1024,1024] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
- .value();
- const se::CudaComputeCapability compute_capability{
- se::CudaComputeCapability::AMPERE, /*minor=*/0};
- TF_ASSERT_OK_AND_ASSIGN(
- const std::vector<TritonGemmConfig> configs,
- GetPossibleMatmulAutotuneConfigs(
- *Cast<HloDotInstruction>(
- module->entry_computation()->root_instruction()),
- compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
- EXPECT_TRUE(std::any_of(
- configs.begin(), configs.end(),
- [](const TritonGemmConfig& config) { return config.split_k >= 4; }));
-}
-
-TEST_F(GemmFusionAutotunerTest, LargeOutputDoesNotUseLargeSplitK) {
- std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[20480,20480] parameter(0)
- p1 = f32[20480,20480] parameter(1)
- ROOT r = f32[20480,20480] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
- .value();
- const se::CudaComputeCapability compute_capability{
- se::CudaComputeCapability::AMPERE, /*minor=*/0};
- TF_ASSERT_OK_AND_ASSIGN(
- const std::vector<TritonGemmConfig> configs,
- GetPossibleMatmulAutotuneConfigs(
- *Cast<HloDotInstruction>(
- module->entry_computation()->root_instruction()),
- compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
- EXPECT_FALSE(std::any_of(
- configs.begin(), configs.end(),
- [](const TritonGemmConfig& config) { return config.split_k > 1; }));
-}
-
-TEST_F(GemmFusionAutotunerTest, Int8FusedGemm) {
- const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
- x = s8[128,64] parameter(0)
- c = f16[128,64] convert(x)
-
- y = f16[64,6144] parameter(1)
-
- ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
- CheckTritonAutotuning(hlo, R"(
-// CHECK: ENTRY
-// CHECK: ROOT
-// CHECK-SAME: kCustom
-// CHECK-SAME: block_m
-)");
-
- EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3}));
-}
-
-TEST_F(GemmFusionAutotunerTest, Int8FusedGemm256) {
- const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
- x = s8[128,256] parameter(0)
- c = f16[128,256] convert(x)
-
- y = f16[256,6144] parameter(1)
-
- ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- CheckTritonAutotuning(hlo, R"(
-// CHECK: ENTRY
-// CHECK: ROOT
-// CHECK-SAME: kCustom
-// CHECK-SAME: block_m
-)");
-
- EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}));
-}
-
-TEST_F(GemmFusionAutotunerTest, SelectsSplitK) {
-  // Shapes with K >> M, N should force split-K configurations.
- const std::string kHloText = R"(
-HloModule t
-
-ENTRY e {
- p0 = s8[7,8192] parameter(0)
- p0c = f16[7,8192] convert(p0)
- p1 = f16[8192,18] parameter(1)
- ROOT dot.0 = f16[7,18] dot(p0c, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
- MatchOptimizedHlo(kHloText, R"(
-; CHECK: reduce
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: kCustom
-; CHECK-NEXT: kLoop
-)");
-
- EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1, /*arel=*/0.5}));
-}
-
-TEST_F(GemmFusionAutotunerTestWithMorePreciseReduction, SelectsSplitK) {
-  // Shapes with K >> M, N should force split-K configurations.
- constexpr absl::string_view kHloText = R"(
-HloModule t
-
-ENTRY e {
- p0 = s8[7,8192] parameter(0)
- p0c = f16[7,8192] convert(p0)
- p1 = f16[8192,18] parameter(1)
- ROOT dot.0 = f16[7,18] dot(p0c, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
- MatchOptimizedHlo(kHloText, R"(
-; CHECK: reduce
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: kCustom
-; CHECK-NEXT: kLoop
-)");
-
- EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3}));
-}
-
-TEST_F(GemmFusionAutotunerTest, ApplySplitKWithoutAlteringTiling) {
- const std::string kHloText = R"(
-triton_dot {
- p0 = f16[55,120] parameter(0)
- p1 = f16[120,20] parameter(1)
- ROOT dot = f16[55,20] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
- p0 = f16[55,120]{1,0} parameter(0)
- p1 = f16[120,20]{1,0} parameter(1)
- ROOT _ = f16[55,20] fusion(p0, p1), kind=kCustom, calls=triton_dot,
- backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}}}
-})";
-
- MatchOptimizedHlo(kHloText, R"(
-; CHECK: f16[3,55,20]
-; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}
-; CHECK: f16[55,20]{1,0} {{(reduce|fusion)}}
-)");
-
- EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
-}
-
-// Modify block_k back to 16 once b/337839570 is fixed.
-// TODO(b/344770374): Make this test not fragile.
-TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
- const std::string kHloText = R"(
-HloModule m
-
-%triton_gemm_dot {
- %p1 = s8[4,12288]{1,0} parameter(1)
- %p0 = s8[12288,1536]{1,0} parameter(0)
- %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
- %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
- %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
-}
-
-ENTRY %e {
- %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
- %convert = s8[4,12288]{1,0} parameter(1)
- ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
- backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
-})";
-
- auto module = ParseAndReturnVerifiedModule(kHloText).value();
- EXPECT_THAT(
- backend().compiler()->RunBackend(std::move(module),
- backend().default_stream_executor(),
- {/*device_allocator=*/nullptr,
- /*thread_pool=*/nullptr,
- /*layout_canonicalization_callback=*/{},
- /*is_autotuning_compilation=*/true}),
- ::testing::AnyOf(
- tsl::testing::StatusIs(
- tsl::error::CANCELLED,
- absl::StrFormat(
- "Compilation result discarded due to register spilling")),
-          // Hopper can't spill registers since wgmma instructions are
-          // asynchronous; instead, it just runs out of them.
- tsl::testing::StatusIs(
- tsl::error::RESOURCE_EXHAUSTED,
- absl::StrFormat("Register allocation failed"))));
-}
-
-// Modify block_k back to 16 once b/337839570 is fixed.
-// TODO(b/344770374): Make this test not fragile.
-TEST_F(GemmFusionAutotunerTest,
- DoNotFilterOutAutotuningKernelSpillingRegisters) {
- if (GetCudaComputeCapability().IsAtLeastHopper()) {
-    GTEST_SKIP() << "Hopper and newer GPUs run out of registers for such HLOs";
- }
- const std::string kHloText = R"(
-HloModule m
-
-%triton_gemm_dot {
- %p1 = s8[4,12288]{1,0} parameter(1)
- %p0 = s8[12288,1536]{1,0} parameter(0)
- %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
- %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
- %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
-}
-
-ENTRY %e {
- %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
- %convert = s8[4,12288]{1,0} parameter(1)
- ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
- backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
-})";
-
- auto module = ParseAndReturnVerifiedModule(kHloText).value();
- HloModuleConfig config = module->config();
- DebugOptions debug_options = config.debug_options();
- debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
- false);
- config.set_debug_options(debug_options);
- module->set_config(config);
-
- std::unique_ptr<Executable> executable =
- backend()
- .compiler()
- ->RunBackend(std::move(module), backend().default_stream_executor(),
- {/*device_allocator=*/nullptr,
- /*thread_pool=*/nullptr,
- /*layout_canonicalization_callback=*/{},
- /*is_autotuning_compilation=*/true})
- .value();
- EXPECT_NE(executable, nullptr);
-}
-
-// Modify block_k back to 16 once b/337839570 is fixed.
-TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) {
- const std::string kHloText = R"(
-HloModule m
-
-%triton_gemm_dot {
- %p1 = f16[4,12288]{1,0} parameter(1)
- %p0 = s8[12288,1536]{1,0} parameter(0)
- %convert.10406 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
- ROOT %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %p1, f16[12288,1536]{1,0} %convert.10406), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY %e {
- %p0 = s8[12288,1536]{1,0} parameter(0)
- %p1 = f16[4,12288]{1,0} parameter(1)
- ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot,
- backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}}
-})";
-
- auto module = ParseAndReturnVerifiedModule(kHloText).value();
- std::unique_ptr<Executable> executable =
- backend()
- .compiler()
- ->RunBackend(std::move(module), backend().default_stream_executor(),
- {/*device_allocator=*/nullptr,
- /*thread_pool=*/nullptr,
- /*layout_canonicalization_callback=*/{},
- /*is_autotuning_compilation=*/true})
- .value();
- EXPECT_NE(executable, nullptr);
-}
-
-using GemmFusionAutotunerDumpTest = GemmFusionAutotunerTest;
-
-TEST_F(GemmFusionAutotunerDumpTest, Fp8CublasltFallbackSupport) {
- const std::string kHloText = R"(
-HloModule o
-
-gemm_fusion {
- p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
- p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
- ROOT %dot.0 = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-}
-
-ENTRY main {
- p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
- p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
- ROOT %dot.0 = f32[64,64]{1,0} fusion(p0, p1), kind=kCustom, calls=gemm_fusion, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(kHloText));
-
- DebugOptions opts;
- AutotuneConfig autotune_config{
- DeviceConfig{backend().default_stream_executor(),
- backend().memory_allocator()},
- opts};
- AutotuneCacheKey cache_key(autotune_config.GetModelStr(),
- *module->entry_computation()->root_instruction());
-
- TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override,
- ParseTextProto<AutotuneResults>(R"pb(
- version: 3
- results {
- device: "..."
- hlo: "..."
- result {
- gemm { algorithm: -1 }
- run_time { nanos: 14 }
- }
- })pb"));
- autotune_results_override.mutable_results(0)->set_device(
- std::string(cache_key.GetModelStr()));
- autotune_results_override.mutable_results(0)->set_hlo(
- std::string(cache_key.GetHlo()));
- CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override));
-
- HloPassPipeline pipeline("gemm_autotune");
- tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
- tsl::port::MaxParallelism());
- MultiProcessKeyValueStore key_value_store;
- pipeline.AddPass<GemmFusionAutotuner>(autotune_config, GetToolkitVersion(),
- &thread_pool, key_value_store);
- pipeline.AddPass<CallInliner>();
- for (bool fp8_rewrite : {true, false}) {
- pipeline.AddPass<GemmRewriter>(autotune_config.GetGpuComputeCapability(),
- GetToolkitVersion(), fp8_rewrite);
- }
-
- TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get()));
- const bool is_at_least_hopper =
- std::holds_alternative<se::CudaComputeCapability>(
- autotune_config.GetGpuComputeCapability()) &&
- std::get<se::CudaComputeCapability>(
- autotune_config.GetGpuComputeCapability())
- .IsAtLeastHopper();
- TF_ASSERT_OK_AND_ASSIGN(
- bool filecheck_matches,
- RunFileCheck(module->ToString(), is_at_least_hopper
- ? "// CHECK: __cublas$lt"
- : "// CHECK: __cublas$gemm"));
- EXPECT_TRUE(filecheck_matches);
-}
-
-TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
- HloModuleConfig config;
- DebugOptions options = GetDebugOptionsForTest();
- options.set_xla_gpu_cublas_fallback(true);
- options.set_xla_gpu_dump_autotuned_gemm_fusions(true);
- std::string output_directory;
- if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) {
- output_directory = tsl::testing::TmpDir();
- }
- options.set_xla_dump_to(output_directory);
- config.set_debug_options(options);
- // Computation is chosen such that relatively heavy math operations before the
- // GEMM are not worth fusing because they would get duplicated many times and
- // slow down execution. Therefore autotuning picks cuBLAS here.
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-fusion1 {
- p0 = f32[333,333] parameter(0)
- s = f32[333,333] sine(p0)
- p1 = f32[333,333] parameter(1)
- c = f32[333,333] cosine(p1)
- ROOT dot = f32[333,333] dot(s, c),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
- p0 = f32[333,333] parameter(0)
- p1 = f32[333,333] parameter(1)
- ROOT rr = f32[333,333] fusion(p0, p1), kind=kCustom, calls=fusion1,
- backend_config={"fusion_backend_config": {kind: "__triton_gemm"}}
-})",
- config));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(std::move(module)));
-
- std::string dump;
- TF_EXPECT_OK(tsl::ReadFileToString(
- tsl::Env::Default(),
- tsl::io::JoinPath(output_directory,
- FilenameFor(*optimized_module, /*prefix=*/"",
- /*suffix=*/"gemm_fusion_0.rr.txt")),
- &dump));
- EXPECT_TRUE(*RunFileCheck(dump, R"(
-CHECK: HloModule rr
-CHECK-NOT: cublas
-CHECK: __triton_gemm
-CHECK-NOT: block_m
-)"));
-
- dump.clear();
-
- TF_EXPECT_OK(tsl::ReadFileToString(
- tsl::Env::Default(),
- tsl::io::JoinPath(
- output_directory,
- FilenameFor(*optimized_module, /*prefix=*/"",
- /*suffix=*/"gemm_fusion_0.rr.optimized.txt")),
- &dump));
- EXPECT_TRUE(*RunFileCheck(dump, R"(
-CHECK: HloModule rr
-CHECK-NOT: triton
-CHECK: cublas
-)"));
-}
-
-TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) {
- const std::string kHlo = R"(
-fusion1 {
- p0 = f32[3,28,32] parameter(0)
- p1 = f32[3,28,32] parameter(1)
- ROOT d = f32[3,32,32] dot(p0, p1),
- lhs_batch_dims={0}, rhs_batch_dims={0},
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-}
-
-ENTRY e {
- p0 = f32[3,28,32] parameter(0)
- p1 = f32[3,28,32] parameter(1)
- ROOT _ = f32[3,32,32] fusion(p0, p1), kind=kCustom, calls=fusion1,
- backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
-})";
-
- CheckTritonAutotuning(kHlo, R"(
-// CHECK: "plan_id":
-)");
-}
-
-// TODO(b/281489442): Write a test case called
-// `SkipConfigsProducingDeviantResults` or similar.
-
-class GemmFusionAutotunerLevelTest : public StatelessAutotunerTest,
- public ::testing::WithParamInterface<int> {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options =
- StatelessAutotunerTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_autotune_level(GetParam());
- debug_options.set_xla_gpu_cublas_fallback(false);
- return debug_options;
- }
-};
-
-TEST_P(GemmFusionAutotunerLevelTest, AllAutotuningLevelsWorkCorrectly) {
- const std::string kHloText = R"(
-HloModule m
-
-ENTRY e {
- p0 = pred[64,10] parameter(0)
- p0c = f32[64,10] convert(p0)
- p1 = f32[10,128] parameter(1)
- ROOT r = f32[64,128] dot(p0c, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
- MatchOptimizedHlo(kHloText, R"(
-; CHECK: kind=kCustom
-; CHECK-SAME: block_m
- )");
-
- EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
-}
-
-TEST_P(GemmFusionAutotunerLevelTest, Deviceless) {
- const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
- x = s8[16,16] parameter(0)
- c = f16[16,16] convert(x)
- y = f16[16,16] parameter(1)
- ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- HloPassPipeline pipeline("gemm_rewrite_deviceless");
- pipeline.AddPass<GemmFusion>(backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability());
- tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
- tsl::port::MaxParallelism());
- DebugOptions opts;
- MultiProcessKeyValueStore key_value_store;
- pipeline.AddPass<GemmFusionAutotuner>(
- AutotuneConfig{
- DevicelessConfig{
- backend().default_stream_executor()->GetDeviceDescription()},
- opts},
- GetToolkitVersion(), &thread_pool, key_value_store);
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(hlo));
- if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) {
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloTestBase::RunHloPass(&pipeline, module.get()));
- EXPECT_TRUE(changed);
-
- // Check default configuration.
- TF_ASSERT_OK_AND_ASSIGN(
- bool filecheck_matches,
- RunFileCheck(
- module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
- R"(
-// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"16","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}},"force_earliest_schedule":false}
- )"));
- EXPECT_TRUE(filecheck_matches);
- } else {
- EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()),
- tsl::testing::StatusIs(
- tsl::error::INTERNAL,
- ::testing::HasSubstr(
- "Expect autotune result cache hit for deviceless")));
- }
-}
-
-INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerLevelSweep,
- GemmFusionAutotunerLevelTest, ::testing::Range(0, 5));
-
-class GemmFusionAutotunerExhaustiveTest : public GemmFusionAutotunerTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options =
- GemmFusionAutotunerTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_exhaustive_tiling_search(true);
- return debug_options;
- }
-};
-
-TEST_F(GemmFusionAutotunerExhaustiveTest, DISABLED_CompileOnly) {
- const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
- x = s8[16,16] parameter(0)
- c = f16[16,16] convert(x)
- y = f16[16,16] parameter(1)
- ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- CheckTritonAutotuning(hlo, R"(
-// CHECK: %triton_gemm_out_computation (
-// CHECK: ROOT %out.1 = f16[16,16]{1,0} dot(%c.1, %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-// CHECK: ROOT %triton_gemm_out = f16[16,16]{1,0} fusion(%x, %y), kind=kCustom, calls=%triton_gemm_out_computation
-// CHECK-SAME: "block_m":
-)");
-}
-
-// TODO(b/337839570): Triton currently has a limitation where it crashes
-// on small block_k values depending on the bit-width of the inputs to the
-// dot. For this test case, it should skip any block_k values that are <= 16
-// since the smallest type has a bit-width of 8.
-TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingTileKConfig) {
- std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-ENTRY e {
- x = s8[33,33]{1,0} parameter(0)
- c = f16[33,33]{1,0} convert(x)
- y = f16[33,33]{1,0} parameter(1)
- ROOT out = f16[33,33]{1,0} dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)")
- .value();
- const se::CudaComputeCapability compute_capability{
- se::CudaComputeCapability::AMPERE, /*minor=*/0};
- TF_ASSERT_OK_AND_ASSIGN(
- const std::vector<TritonGemmConfig> configs,
- GetPossibleMatmulAutotuneConfigs(
- *Cast<HloDotInstruction>(
- module->entry_computation()->root_instruction()),
- compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
- EXPECT_TRUE(std::all_of(
- configs.begin(), configs.end(),
- [](const TritonGemmConfig& config) { return config.block_k > 16; }));
-}
-
-class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options =
- GemmFusionAutotunerTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_split_k_autotuning(false);
- return debug_options;
- }
-};
-
-TEST_F(GemmFusionAutotunerDisableSplitK, SplitKIsDisabled) {
- std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[1024,1024] parameter(0)
- p1 = f32[1024,1024] parameter(1)
- ROOT r = f32[1024,1024] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
- .value();
- const se::CudaComputeCapability compute_capability{
- se::CudaComputeCapability::AMPERE, /*minor=*/0};
- TF_ASSERT_OK_AND_ASSIGN(
- const std::vector<TritonGemmConfig> configs,
- GetPossibleMatmulAutotuneConfigs(
- *Cast<HloDotInstruction>(
- module->entry_computation()->root_instruction()),
- compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
- EXPECT_TRUE(std::all_of(
- configs.begin(), configs.end(),
- [](const TritonGemmConfig& config) { return config.split_k == 1; }));
-}
-
-class GemmFusionAutotunerConfigTest
- : public StatelessAutotunerTest,
- public ::testing::WithParamInterface<bool> {};
-
-TEST_P(GemmFusionAutotunerConfigTest, SparseDotDiscardsUnsupportedTiles) {
- const std::string kHloText = R"(
-HloModule test
-ENTRY wais {
- lhs = f16[5,1600] parameter(0)
- rhs = f16[3200,10] parameter(1)
- meta = u16[5,200] parameter(2)
- ROOT dot = f32[5,10] dot(lhs, rhs, meta),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
- const se::CudaComputeCapability compute_capability{
- se::CudaComputeCapability::AMPERE, /*minor=*/0};
- DebugOptions debug_options = GetDebugOptionsForTest();
- debug_options.set_xla_gpu_exhaustive_tiling_search(GetParam());
-
- TF_ASSERT_OK_AND_ASSIGN(
- const std::vector<TritonGemmConfig> configs,
- GetPossibleMatmulAutotuneConfigs(
- *Cast<HloDotInstruction>(
- module->entry_computation()->root_instruction()),
- compute_capability, GetToolkitVersion(), debug_options));
- for (const auto& config : configs) {
- int metadata_size = config.block_m * config.block_k / 16;
- EXPECT_LE(config.num_warps * WarpSize(), metadata_size);
- EXPECT_GT(config.block_k, 16); // kMinTileSize
- }
-}
-
-INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep,
- GemmFusionAutotunerConfigTest, ::testing::Bool());
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc
deleted file mode 100644
index 44430f5..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc
+++ /dev/null
@@ -1,1334 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_fusion.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/triton_fusion_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-using ::testing::FieldsAre;
-
-namespace m = ::xla::match;
-
-class GemmFusionTest : public HloTestBase {
- public:
- GemmFusionTest()
- : HloTestBase(/*verifier_layout_sensitive=*/true,
- /*allow_mixed_precision_in_hlo_verifier=*/false) {}
-
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_triton_gemm_any(false);
- debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
- return debug_options;
- }
-
- se::GpuComputeCapability gpu_version_{
- se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}};
-
- void MatchHloModule(HloModule& module, absl::string_view pattern) {
- TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result,
- RunFileCheck(module.ToString(), pattern));
- EXPECT_TRUE(filecheck_result);
- }
-};
-
-TEST_F(GemmFusionTest, TransposeSubdimensionGroup) {
- // This HLO is artificial because unnecessary reshapes get optimized
- // out during compilation. It tests the ability of GemmFusion
- // to handle transposes of groups of subdimensions.
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
- p0 = f32[32,3] parameter(0)
- t1 = f32[3,32] transpose(p0), dimensions={1,0}
- r1 = f32[3,8,4] reshape(t1)
- r0 = f32[3,32] reshape(r1)
- p1 = f16[32,7] parameter(1)
- c1 = f32[32,7] convert(p1)
- ROOT d = f32[3,7] dot(r0, c1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
- .value();
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, UnsupportedTransposeIsNotFused) {
- auto module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f16[1,512,8,1024]{3,1,0,2} parameter(0)
- c = f16[1,512,8,1024]{3,2,1,0} copy(p0)
- b = f16[4096,1024]{1,0} bitcast(c)
- p1 = f16[128,1024]{1,0} parameter(1)
- ROOT d = f16[4096,128]{1,0} dot(b, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})")
- .value();
- EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, BitcastChain) {
- // This HLO is artificial because unnecessary reshapes get optimized
- // out during compilation. It tests the ability of GemmFusion
- // to handle various kinds of bitcasts.
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
- p0 = s8[60,5] parameter(0)
- r0 = s8[3,20,5] reshape(p0)
- c0 = f16[3,20,5] convert(r0)
- p1 = f16[3,200] parameter(1)
- r12 = f16[600] reshape(p1)
- r11 = f16[30,20] reshape(r12)
- r1 = f16[3,10,20] reshape(r11)
- ROOT d = f16[3,5,10] dot(c0, r1),
- lhs_contracting_dims={1}, rhs_contracting_dims={2},
- lhs_batch_dims={0}, rhs_batch_dims={0}
-})")
- .value();
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, SplitDimensionTwice) {
- auto module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = s8[4,2,32,4,2] parameter(0)
- r1 = s8[8,32,8] reshape(p0)
- t1 = s8[32,8,8] transpose(r1), dimensions={1,0,2}
- r0 = s8[32,64] reshape(t1)
- p1 = s8[32,32] parameter(1)
- c0 = f16[32,32] convert(p1)
- ROOT d = f16[64,32] dot(r0, c0),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})")
- .value();
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, DoNotTriggerOnUnsupportedOutputConversions) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f16[128,256] parameter(0)
- p1 = f16[256,512] parameter(1)
- r = f16[128,512] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT c = u8[128,512] convert(r)
-})"));
- EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, FuseDotWithTrivialNoncontractingDim) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
- p0 = s8[60,5] parameter(0)
- r0 = s8[3,20,5] reshape(p0)
- c0 = f16[3,20,5] convert(r0)
- p1 = f16[3,1,20] parameter(1)
- ROOT d = f16[3,5,1] dot(c0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={2},
- lhs_batch_dims={0}, rhs_batch_dims={0}
-})")
- .value();
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, HandleDotIfCublasRequiresPadding) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[5,3] parameter(0)
- p1 = f16[5,7] parameter(1)
- ROOT d = f16[3,7] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_TRUE(CublasRequiresPadding(
- *xla::Cast<HloDotInstruction>(
- module->entry_computation()->root_instruction()),
- cc));
- EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, FuseSliceOfParameterWithOtherUsers) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[97,121] parameter(0)
- s0 = f32[7,101] slice(p0), slice={[3:10], [10:111]}
- p1 = f32[101,16] parameter(1)
- d = f32[16,7] dot(p1, s0),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
- s1 = f32[3,33] slice(p0), slice={[10:13], [20:53]}
- ROOT t = tuple(d, s1)
-})"));
-
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, DoNotFuseSliceOfMixedDimensions) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = bf16[768,64] parameter(0)
- s0 = bf16[768,32] slice(p0), slice={[0:768], [0:32]}
- b0 = bf16[256,3,32] reshape(s0)
- b1 = bf16[256,96] reshape(b0)
- p1 = bf16[256,96] parameter(1)
- ROOT d = bf16[96,96] dot(b1, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, DoNotFuseSlicesOfNonMajorFragments) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[2,2,256,256] parameter(0)
- s0 = f32[1,1,256,256] slice(p0),
- slice={[0:1], [0:1], [0:256], [0:256]}
- r0 = f32[256,256] reshape(s0)
- p1 = f16[2,2,256,256] parameter(1)
- s1 = f16[1,1,256,256] slice(p1),
- slice={[0:1], [0:1], [0:256], [0:256]}
- r1 = f16[256,256] reshape(s1)
- ROOT d = f32[256,256] dot(r0, r1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, DynamicSliceIsFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- dot_lhs = f32[2,18] parameter(0)
- dynamic_slice_input = f32[2,64,2] parameter(1)
- start_index0 = s32[] parameter(2)
- start_index1_2 = s32[] constant(0)
- dynamic_slice = f32[1,64,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1_2, start_index1_2),
- dynamic_slice_sizes={1,64,2}
- reshape = f32[64,2] reshape(dynamic_slice)
- ROOT dot = f16[18,64] dot(dot_lhs, reshape),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
- m::Parameter(), m::Constant()))));
-}
-
-TEST_F(GemmFusionTest, DynamicSlicesAreFusedEvenIfTheyShareIndices) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[2,64,2] parameter(0)
- p1 = s32[] parameter(1)
- p2 = s32[] parameter(2)
- p3 = s32[] parameter(3)
- ds0 = f32[1,64,2] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,64,2}
- a = f32[64,2] reshape(ds0)
- ds1 = f32[1,64,2] dynamic-slice(p0, p3, p2, p1), dynamic_slice_sizes={1,64,2}
- b = f32[64,2] reshape(ds1)
- ROOT d = f16[64,64] dot(a, b),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- // TODO(b/339810582): Don't duplicate scalar parameters to dot fusions,
- // because they are never tiled differently.
- // TODO(b/339814210): Don't count scalar parameters towards dot fusion
- // parameter limit.
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
- m::Parameter(), m::Parameter(), m::Parameter(),
- m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionTest, DoNotFuseDynamicSliceOfNonMajorFragments) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- dot_lhs = f32[2,4]{1,0} parameter(0)
- dynamic_slice_input = f32[4,5,2]{2,1,0} parameter(1)
- c0 = s32[] constant(0)
- c2 = s32[] constant(2)
- dynamic_slice = f32[4,1,2]{2,1,0} dynamic-slice(dynamic_slice_input, c0, c2, c0),
- dynamic_slice_sizes={4,1,2}
- reshape = f32[4,2]{1,0} reshape(dynamic_slice)
- ROOT dot = f32[4,4]{1,0} dot(dot_lhs, reshape),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- // FusionDecision "Unsupported dynamic slice on non-major-most dimension."
- EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- dot_lhs = f32[2,4]{1,0} parameter(0)
- dynamic_slice_input = f32[5,5]{1,0} parameter(1)
- start_index0 = s32[] constant(2)
- start_index1 = s32[] constant(0)
- dynamic_slice = f32[2,5]{1,0} dynamic-slice(dynamic_slice_input, start_index0, start_index1),
- dynamic_slice_sizes={2,5}
- ROOT d = f32[4,5]{1,0} dot(dot_lhs, dynamic_slice),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
- m::Constant(), m::Constant()))));
-}
-
-TEST_F(GemmFusionTest, SliceToDegenerateIsSkipped) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p = f32[3] parameter(0)
- s = f32[1] slice(p), slice={[2:3]}
- r = f32[] reshape(s)
- b = f32[3,3] broadcast(r), dimensions={}
- ROOT d = f32[3,3] dot(b, b),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)"));
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-
- ASSERT_TRUE(GemmFusion(cc).Run(module.get()).value());
-
- // Slice is not fused.
- MatchHloModule(*module, R"(
-; CHECK-NOT: slice
-; CHECK: ENTRY
-; CHECK: slice
-)");
-}
-
-TEST_F(GemmFusionTest, MultipleUsesAreHandled) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- c = f32[] constant(1)
- b = f32[6,8] broadcast(c), dimensions={}
- p0 = f32[6,8] parameter(0)
- a1 = f32[6,8] add(p0, b)
- e = f32[6,8] exponential(a1)
- a2 = f32[6,8] add(e, b)
- d = f32[6,8] divide(b, a2)
- p2 = f16[8,6] parameter(1)
- cv = f32[8,6] convert(p2)
- ROOT r = f32[6,6] dot(d, cv),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, BinaryElementwiseOfBroadcastIsFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p2 = f32[3072] parameter(2)
- b = f32[8192,3072] broadcast(p2), dimensions={1}
- p0 = f16[8192,3072] parameter(0)
- p0c = f32[8192,3072] convert(p0)
- a = f32[8192,3072] add(p0c, b)
- p1 = f32[3072,768] parameter(1)
- ROOT r = f32[8192,768] dot(a, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, BinaryElementwiseOfUnsupportedBroadcastIsNotFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p2 = f32[768] parameter(2)
- b = f32[8192,768,4] broadcast(p2), dimensions={1}
- s = f32[8192,3072] bitcast(b)
- p0 = f16[8192,3072] parameter(0)
- p0c = f32[8192,3072] convert(p0)
- a = f32[8192,3072] add(p0c, s)
- p1 = f32[3072,768] parameter(1)
- ROOT r = f32[8192,768] dot(a, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-class GemmFusionLevel2Test : public GemmFusionTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_triton_fusion_level(2);
- return debug_options;
- }
-};
-
-TEST_F(GemmFusionTest, ConcatenationDivisibleBy64IsFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = bf16[8192,1]{1,0} parameter(0)
- p1 = bf16[2752,8192]{1,0} parameter(1)
- p2 = bf16[2752,8192]{1,0} parameter(2)
- concat = bf16[5504,8192]{1,0} concatenate(p1, p2), dimensions={0}
- bitcast = bf16[8192,5504]{0,1} bitcast(concat)
- ROOT r = f32[1,5504]{1,0} dot(p0, bitcast),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
- const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
- EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionLevel2Test, ReshapeToScalarIsHandled) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = s8[5,3] parameter(0)
- c = f16[5,3] convert(p0)
- p1 = f16[1] parameter(1)
- r = f16[] reshape(p1)
- b = f16[5,7] broadcast(r)
- ROOT d = f16[3,7] dot(c, b),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionLevel2Test, DoNotFuseIncompatibleDimensionSplits) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p1 = s8[5,7,2,3]{3,2,1,0} parameter(1)
- t1 = s8[7,5,2,3]{3,2,1,0} transpose(p1), dimensions={1,0,2,3}
- r1 = s8[7,30]{1,0} reshape(t1)
- cvt = f16[7,30]{1,0} convert(r1)
- p2 = f16[2,7,5,3]{3,2,1,0} parameter(2)
- t2 = f16[7,2,5,3]{3,2,1,0} transpose(p2), dimensions={1,0,2,3}
- r2 = f16[7,30]{1,0} reshape(t2)
- a = f16[7,30]{1,0} add(cvt, r2)
- p0 = f16[7,79]{1,0} parameter(0)
- ROOT dot = f16[30,79]{1,0} dot(a, p0),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Transpose(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParameters) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- tmp_0 = f32[] constant(1)
- tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={}
- tmp_2 = f32[3,49]{1,0} parameter(6)
- tmp_3 = f32[] constant(0)
- tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={}
- tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT
- tmp_6 = f32[3,49]{1,0} convert(tmp_5)
- tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6)
- tmp_8 = s32[] parameter(13)
- tmp_9 = f32[] convert(tmp_8)
- tmp_10 = f32[] maximum(tmp_9, tmp_0)
- tmp_11 = f32[] divide(tmp_3, tmp_10)
- tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={}
- tmp_13 = pred[3,49]{1,0} parameter(7)
- tmp_14 = pred[3,49]{1,0} parameter(10)
- tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14)
- tmp_16 = f32[3,49]{1,0} convert(tmp_15)
- tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16)
- tmp_18 = f32[3,49]{1,0} negate(tmp_17)
- tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18)
- tmp_20 = f32[3,49]{1,0} parameter(19)
- tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20)
- tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21)
- tmp_23 = f32[3,49]{1,0} negate(tmp_22)
- tmp_24 = f32[3,49]{1,0} negate(tmp_6)
- tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17)
- tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20)
- tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26)
- tmp_28 = f32[3,49]{1,0} parameter(18)
- tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28)
- tmp_30 = f32[3,49]{1,0} parameter(17)
- tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30)
- tmp_32 = f32[3,49]{1,0} parameter(16)
- tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32)
- tmp_34 = f32[3,49]{1,0} parameter(15)
- tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34)
- tmp_36 = f32[3,49]{1,0} parameter(14)
- tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36)
- tmp_38 = f32[1,1]{1,0} constant({ {0} })
- tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1}
- tmp_40 = f32[] reshape(tmp_39)
- tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={}
- tmp_42 = u32[48]{0} parameter(11)
- tmp_43 = u32[48]{0} parameter(5)
- tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0}
- tmp_45 = u32[3,32]{1,0} reshape(tmp_44)
- tmp_46 = u32[96]{0} reshape(tmp_45)
- tmp_47 = u32[] constant(1)
- tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={}
- tmp_49 = u32[96]{0} reshape(tmp_48)
- tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49)
- tmp_51 = u32[3,32]{1,0} reshape(tmp_50)
- tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48)
- tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52)
- tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={}
- tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54)
- tmp_56 = f32[1,1]{1,0} constant({ {1} })
- tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1}
- tmp_58 = f32[] reshape(tmp_57)
- tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={}
- tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59)
- tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41)
- tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61)
- tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={}
- tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT
- tmp_65 = f32[3,32]{1,0} convert(tmp_64)
- tmp_66 = f32[3,49]{1,0} parameter(9)
- tmp_67 = f32[49]{0} parameter(4)
- tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1}
- tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68)
- tmp_70 = f32[1,49]{1,0} parameter(12)
- tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={}
- tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71)
- tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1}
- tmp_74 = f32[49]{0} reshape(tmp_73)
- tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1}
- tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75)
- tmp_77 = f32[1,49]{1,0} parameter(3)
- tmp_78 = f32[1,49]{1,0} parameter(8)
- tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71)
- tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72)
- tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80)
- tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71)
- tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82)
- tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83)
- tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1}
- tmp_86 = f32[49]{0} reshape(tmp_85)
- tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1}
- tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87)
- tmp_89 = f32[1,49]{1,0} parameter(2)
- tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1}
- tmp_91 = f32[49]{0} reshape(tmp_90)
- tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1}
- tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92)
- tmp_94 = f32[49,32]{1,0} parameter(1)
- tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- tmp_96 = f32[32]{0} parameter(0)
- tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1}
- tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97)
- tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98)
- tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63)
- tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63)
- ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
- HloOpcode::kFusion);
- EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
- HloInstruction::FusionKind::kCustom);
- EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
- TritonFusionAnalysis::kMaxParameterPerDotOperand * 2);
-}
-
-TEST_F(GemmFusionLevel2Test,
- DoNotFuseTooManyParametersWhenAnInstructionWouldAddMultipleParameters) {
- static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
- "We have to update this test.");
- // If we fuse the select, it adds 2 additional parameters at once (not 3,
- // because the select instruction itself is removed from the parameters).
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- a = f32[3,49]{1,0} parameter(0)
- b = f32[3,49]{1,0} parameter(1)
- c = pred[3,49]{1,0} parameter(2)
- d = f32[3,49]{1,0} parameter(3)
- e = f32[3,49]{1,0} parameter(4)
- add0 = f32[3,49]{1,0} add(a, b)
- select = f32[3,49]{1,0} select(c, d, e)
- add1 = f32[3,49]{1,0} add(add0, select)
- f = f32[3,32]{1,0} parameter(5)
- ROOT tmp_102 = f32[49,32]{1,0} dot(add1, f), lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
- HloOpcode::kFusion);
- EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
- HloInstruction::FusionKind::kCustom);
- EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
- TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
-}
-
-TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParametersForConcat) {
- static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
- "We have to update this test.");
- // The concat shouldn't push the fusion over the allowed parameter limit.
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- a = f32[3,3]{1,0} parameter(0)
- b = f32[3,3]{1,0} parameter(1)
- c = f32[3,3]{1,0} parameter(2)
- d = f32[3,3]{1,0} parameter(3)
- e = f32[3,3]{1,0} parameter(4)
- f = f16[3,3]{1,0} parameter(5)
- concat = f32[15,3]{1,0} concatenate(a, b, c, d, e), dimensions={0}
- convert = f32[3,3]{1,0} convert(f)
- ROOT dot = f32[15,3]{1,0} dot(concat, convert), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
- HloOpcode::kFusion);
- EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
- HloInstruction::FusionKind::kCustom);
- EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
- TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
-}
-
-TEST_F(GemmFusionLevel2Test,
- InstructionsReachableFromMultipleOperandsAreHandledCorrectly) {
- static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
- "We have to update this test.");
- // There was a bug where dead code was generated into some fusions in a
- // specific edge case. When some instructions were reachable both through the
- // LHS and the RHS operands, the BFS (Breadth-first search) through the LHS
- // operand "marked" one operation as non-fusible because it would exceed the
- // limit on fusion parameters per operand. But the BFS through the RHS operand
- // went through that node and fused some more operands. So the resulting
- // fusion was not connected and caused errors. This test case checks that such
- // configurations generate a correct HLO now.
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- a = f32[2,4]{1,0} parameter(0)
- b = f32[2,4]{1,0} parameter(1)
- c = f32[2,4]{1,0} parameter(2)
- d = f32[2,4]{1,0} parameter(3)
- e = f32[2,4]{1,0} parameter(4)
- add0 = f32[2,4]{1,0} add(a, b)
- add1 = f32[2,4]{1,0} add(add0, c)
- add2 = f32[2,4]{1,0} add(add1, d)
- add3 = f32[2,4]{1,0} add(add2, e)
- ROOT r = f32[2,2]{1,0} dot(add3, add0),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- // ~VerifiedHloModule() will verify the module.
-}
-
-TEST_F(GemmFusionLevel2Test, EachScopeIsFusedToASeparateSubgraph) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- a = f32[2,4]{1,0} parameter(0)
- b = f32[2,4]{1,0} parameter(1)
- add = f32[2,4]{1,0} add(a, b)
- ROOT r = f32[2,2]{1,0} dot(add, add),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
- MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
-CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]])
-CHECK-DAG: %[[P2:.*]] = f32[2,4]{1,0} parameter(2)
-CHECK-DAG: %[[P3:.*]] = f32[2,4]{1,0} parameter(3)
-CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P2]], f32[2,4]{1,0} %[[P3]])
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
-CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]),
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-// The 2 inputs of the add operation are the same and they are iterated the same
-// way, so the same parameter node is reused for them.
-// The reuse happens per "operand fusion", so the add of the LHS and RHS still
-// use different nodes.
-TEST_F(GemmFusionLevel2Test, ParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- a = f32[2,4]{1,0} parameter(0)
- add = f32[2,4]{1,0} add(a, a)
- ROOT r = f32[2,2]{1,0} dot(add, add),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
- MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
-CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
-CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P1]])
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
-CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-// NEGATE has the same iteration spec at both usages, so the node is reused
-// (implying that P0 is also reused).
-TEST_F(GemmFusionLevel2Test, NonParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- a = f32[4,4]{1,0} parameter(0)
- b = f32[4,4]{1,0} parameter(1)
- negate = f32[4,4]{1,0} negate(a)
- sine = f32[4,4]{1,0} sine(negate)
- add = f32[4,4]{1,0} add(negate, sine)
- ROOT r = f32[4,4]{1,0} dot(add, b),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
- MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: %[[NEGATE:.*]] = f32[4,4]{1,0} negate(f32[4,4]{1,0} %[[P0]])
-CHECK-DAG: %[[SINE:.*]] = f32[4,4]{1,0} sine(f32[4,4]{1,0} %[[NEGATE]])
-CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[NEGATE]], f32[4,4]{1,0} %[[SINE]])
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P1]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
-CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-// The direct read of the input and the transposed read of the input have
-// different iteration specs, so we don't reuse the node.
-TEST_F(GemmFusionLevel2Test, NodesAreNotReusedIfTheyHaveDifferentIterSpecs) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- a = f32[4,4]{1,0} parameter(0)
- b = f32[4,4]{1,0} parameter(1)
- tr_a = f32[4,4]{1,0} transpose(a), dimensions={1,0}
- add = f32[4,4]{1,0} add(a, tr_a)
- ROOT r = f32[4,4]{1,0} dot(add, b),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
- MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: %[[P2:.*]] = f32[4,4]{1,0} parameter(2)
-CHECK-DAG: %[[TRANSPOSE:.*]] = f32[4,4]{1,0} transpose(f32[4,4]{1,0} %[[P1]])
-CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[TRANSPOSE]])
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P2]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
-CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-TEST_F(GemmFusionLevel2Test, OperationsAddingMoreParametersGetMultipleTries) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-e {
- p0 = f32[2,2] parameter(0)
- c0 = f32[] constant(12345)
- b0 = f32[2,2] broadcast(c0), dimensions={}
- m0 = f32[2,2] multiply(p0, b0)
- c1 = f32[] constant(34567)
- b1 = f32[2,2] broadcast(c1), dimensions={}
- a0 = f32[2,2] add(m0, b1)
- b3 = f32[2,2,2] broadcast(a0), dimensions={0,1}
- p2 = f32[2,2,2] parameter(2)
- m2 = f32[2,2,2] multiply(p2, b3)
- p1 = f32[2]{0} parameter(1)
- c2 = f32[] constant(5678)
- b2 = f32[2] broadcast(c2), dimensions={}
- a1 = f32[2]{0} add(p1, b2)
- b4 = f32[2,2,2] broadcast(a1), dimensions={2}
- m1 = f32[2,2,2] multiply(m2, b4)
- b = f32[4,2] bitcast(m1)
- p3 = f16[2,2] parameter(3)
- p3c = f32[2,2] convert(p3)
- ROOT r = f32[4,2] dot(b, p3c),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
- m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutPreAmpere) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[2,53] parameter(0)
- p0e = f32[2,53] exponential(p0)
- p1 = s16[53,2] parameter(1)
- p1c = f32[53,2] convert(p1)
- ROOT dot = f32[2,2] dot(p0e, p1c),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- EXPECT_THAT(
- GemmFusion(se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0})
- .Run(module.get()),
- tsl::testing::StatusIs(
- absl::StatusCode::kFailedPrecondition,
- ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
- "(compute capability 8.0) and up, but got")));
-}
-
-TEST_F(GemmFusionLevel2Test, GemmFusionSucceedsOnNonCudaGpu) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[2,53] parameter(0)
- p0e = f32[2,53] exponential(p0)
- p1 = s16[53,2] parameter(1)
- p1c = f32[53,2] convert(p1)
- ROOT dot = f32[2,2] dot(p0e, p1c),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- EXPECT_TRUE(GemmFusion(se::RocmComputeCapability{}).Run(module.get()).ok());
-}
-
-TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-HloModule t
-
-ENTRY e {
- p0 = f32[2,35] parameter(0)
- p0n = f32[2,35] negate(p0)
- p0e = f32[2,35] exponential(p0)
- a = f32[2,35] add(p0e, p0n)
- p1 = f16[35,2] parameter(1)
- p1c = f32[35,2] convert(p1)
- ROOT dot = f32[2,2] dot(a, p1c),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter()))));
- TF_ASSERT_OK_AND_ASSIGN(
- const auto analysis,
- TritonFusionAnalysis::Execute(*module->entry_computation()
- ->root_instruction()
- ->called_computations()[0]));
- EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(),
- 1);
- EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size(),
- 1);
-}
-
-TEST_F(GemmFusionLevel2Test,
- ParameterUsedNonElementwiseTwiceIsFusedOnBothPaths) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-HloModule t
-
-ENTRY e {
- p0 = f32[4,4] parameter(0)
- p0t = f32[4,4] transpose(p0), dimensions={1,0}
- a = f32[4,4] add(p0, p0t)
- p1 = f16[4,5] parameter(1)
- p1c = f32[4,5] convert(p1)
- ROOT dot = f32[4,5] dot(a, p1c),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test,
- ComputationParameterWithMultipleUsersIsNotTrivialToFuse) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = f32[400,400] parameter(0)
-
- c0 = f16[400,400] convert(p0)
- p1 = f16[400,400] parameter(1)
- dot0 = f16[400,400] dot(c0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
- c1 = f16[400,400] convert(p0)
- p2 = f16[400,400] parameter(2)
- dot1 = f16[400,400] dot(c1, p2),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
- ROOT a = f16[400,400] add(dot0, dot1)
-})"));
- EXPECT_FALSE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
-}
-
-TEST_F(GemmFusionLevel2Test, NarrowingConversionIsAlwaysBetterToFuse) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-ENTRY e {
- p0 = s8[512,512] parameter(0)
- c0 = f16[512,512] convert(p0)
- p1 = f16[512,512] parameter(1)
- dot0 = f16[512,512] dot(c0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
- n = f16[512,512] negate(c0)
- ROOT a = f16[512,512] add(dot0, n)
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Add(m::Fusion(m::Parameter(), m::Parameter()),
- m::Negate()))));
-}
-
-TEST_F(GemmFusionLevel2Test, NestedSlicingIsAnalyzedCorrectly) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-triton_gemm_d_computation {
- p0 = f32[6,24]{1,0} parameter(0)
- slice1 = f32[5,20]{1,0} slice(p0), slice={[1:6], [3:23]}
- n1 = f32[5,20]{1,0} negate(slice1)
- slice2 = f32[3,7]{1,0} slice(n1), slice={[1:4], [13:20]}
- p1 = f32[7,37]{1,0} parameter(1)
- ROOT d = f32[3,37]{1,0} dot(slice2, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
- p0 = f32[7,37]{1,0} parameter(0)
- p1 = f32[6,24]{1,0} parameter(1)
- ROOT triton_gemm_d = f32[3,37]{1,0} fusion(p1, p0), kind=kCustom,
- calls=triton_gemm_d_computation
-})"));
- const HloComputation* computation =
- module->entry_computation()->root_instruction()->called_computations()[0];
- TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
- TritonFusionAnalysis::Execute(*computation));
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
- computation->parameter_instruction(0), 0),
- ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6,
- /*slice_start=*/2, /*sliced_count=*/3,
- /*subfragments=*/ElementsAre(3))));
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
- computation->parameter_instruction(0), 1),
- ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24,
- /*slice_start=*/16, /*sliced_count=*/7,
- /*subfragments=*/ElementsAre(7))));
-}
-
-TEST_F(GemmFusionLevel2Test, FusedConcatenationIsAnalyzedCorrectly) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-e {
- p0 = s8[153,1536] parameter(0)
- p1 = s8[153,128] parameter(1)
- p2 = s8[153,256] parameter(2)
- cat = s8[153,1920] concatenate(p0, p1, p2), dimensions={1}
- cvt = bf16[153,1920] convert(cat)
- p3 = bf16[16,153] parameter(3)
- ROOT d = bf16[16,1920] dot(p3, cvt),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
- m::Parameter(), m::Parameter()))));
- const HloComputation* computation =
- module->entry_computation()->root_instruction()->called_computations()[0];
- TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
- TritonFusionAnalysis::Execute(*computation));
-
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
- computation->parameter_instruction(1), 0),
- ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153,
- /*slice_start=*/0, /*sliced_count=*/153,
- /*subfragments=*/ElementsAre(153))));
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
- computation->parameter_instruction(1), 1),
- ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536,
- /*slice_start=*/0, /*sliced_count=*/1536,
- /*subfragments=*/ElementsAre(1536))));
-
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
- computation->parameter_instruction(2), 0),
- ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153,
- /*slice_start=*/0, /*sliced_count=*/153,
- /*subfragments=*/ElementsAre(153))));
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
- computation->parameter_instruction(2), 1),
- ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128,
- /*slice_start=*/-1536, /*sliced_count=*/128,
- /*subfragments=*/ElementsAre(128))));
-
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
- computation->parameter_instruction(3), 0),
- ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153,
- /*slice_start=*/0, /*sliced_count=*/153,
- /*subfragments=*/ElementsAre(153))));
- EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
- computation->parameter_instruction(3), 1),
- ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256,
- /*slice_start=*/-1536 - 128,
- /*sliced_count=*/256,
- /*subfragments=*/ElementsAre(256))));
-}
-
-TEST_F(GemmFusionLevel2Test, IndivisibleConcatenationIsNotFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-e {
- p0 = s8[124,1024] parameter(0)
- p1 = s8[124,1001] parameter(1)
- cat = s8[124,2025] concatenate(p0, p1), dimensions={1}
- cvt = f16[124,2025] convert(cat)
- p2 = f16[123,124] parameter(2)
- ROOT d = f16[2025,123] dot(cvt, p2),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test, ConcatenationOfContractingIsNotFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-e {
- p0 = s8[124,1024] parameter(0)
- p1 = s8[124,1024] parameter(1)
- cat = s8[124,2048] concatenate(p0, p1), dimensions={1}
- cvt = f16[124,2048] convert(cat)
- p2 = f16[123,2048] parameter(2)
- ROOT d = f16[124,123] dot(cvt, p2),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test, ConcatenationOfBatchIsNotFused) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-e {
- p0 = s8[124,1024,50] parameter(0)
- p1 = s8[124,1024,50] parameter(1)
- cat = s8[124,2048,50] concatenate(p0, p1), dimensions={1}
- cvt = f16[124,2048,50] convert(cat)
- p2 = f16[123,2048,50] parameter(2)
- ROOT d = f16[2048,124,123] dot(cvt, p2),
- lhs_batch_dims={1}, rhs_batch_dims={1},
- lhs_contracting_dims={2}, rhs_contracting_dims={2}
-})"));
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test,
- DifferentConcatenationOfSameParametersIsFusedViaNodeDuplication) {
- // This means that the same input is passed to the fusion multiple times and
- // it is read differently for each use.
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-e {
- p0 = s8[128,2] parameter(0)
- p1 = s8[128,2] parameter(1)
- cat0 = s8[256,2] concatenate(p0, p1), dimensions={0}
- cvt0 = f16[256,2] convert(cat0)
- cat1 = s8[256,2] concatenate(p1, p0), dimensions={0}
- n1 = s8[256,2] negate(cat1)
- cvt1 = f16[256,2] convert(n1)
- a = f16[256,2] add(cvt1, cvt0)
- p2 = f16[2,18] parameter(2)
- ROOT d = f16[18,256] dot(p2, a),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
-
- EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
- se::CudaComputeCapability::AMPERE, 0})
- .Run(module.get())
- .value());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
- m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionTest, CopiesDotMetadataToFusionOp) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[2,18] parameter(0)
- p1 = f16[256,2] parameter(1)
- ROOT d = f16[18,256] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="foo"}
-})")
- .value();
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
- EXPECT_EQ(
- module->entry_computation()->root_instruction()->metadata().op_name(),
- "foo");
-}
-
-// A test fixture class for testing the threshold for small matrices.
-class SmallDotGemmFusionTest : public GemmFusionTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
- return debug_options;
- }
-};
-
-TEST_F(SmallDotGemmFusionTest, SkipSmallMatrixMultiplicationRewrite) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[2,10] parameter(0)
- p1 = f16[10,2] parameter(1)
- ROOT d = f16[10,10] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})")
- .value();
-
- EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
-
- MatchHloModule(*module, R"(
-; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,10], {{.*}}: f16[10,2]) -> f16[10,10] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,10]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[10,2]{1,0} parameter(1)
-; CHECK: ROOT {{.*}} = f16[10,10]{1,0} dot(f16[2,10]{1,0} [[P0]], f16[10,2]{1,0} [[P1]])
-})");
-}
-
-TEST_F(SmallDotGemmFusionTest, LargeMatrixMultiplicationIsRewritten) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
- p0 = f16[2,18] parameter(0)
- p1 = f16[50,2] parameter(1)
- ROOT d = f16[18,50] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})")
- .value();
-
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
- MatchHloModule(*module, R"(
-; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,18], {{.*}}: f16[50,2]) -> f16[18,50] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,18]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[50,2]{1,0} parameter(1)
-; CHECK: ROOT {{.*}} = f16[18,50]{1,0}
-; CHECK: fusion(f16[2,18]{1,0} [[P0]], f16[50,2]{1,0} [[P1]]),
-; CHECK: kind=kCustom
-; CHECK: __triton_gemm
-})");
-}
-
-class SparseDotTest : public GemmFusionTest {};
-
-TEST_F(SparseDotTest, DotWithSparseLhsOperandIsRewritten) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-ENTRY main {
- lhs = f16[2,16] parameter(0)
- rhs = f16[32,2] parameter(1)
- meta = u16[2,2] parameter(2)
- ROOT dot = f32[2,2] dot(lhs, rhs, meta),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
-})")
- .value();
- EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
- MatchHloModule(*module, R"(
-; CHECK-LABEL: ENTRY %main ({{.*}}: f16[2,16], {{.*}}: f16[32,2], {{.*}}: u16[2,2]) -> f32[2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,16]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[32,2]{1,0} parameter(1)
-; CHECK-NEXT: [[META:%[^ ]+]] = u16[2,2]{1,0} parameter(2)
-; CHECK: ROOT {{.*}} = f32[2,2]{1,0}
-; CHECK-SAME: fusion(f16[2,16]{1,0} [[P0]], f16[32,2]{1,0} [[P1]], u16[2,2]{1,0} [[META]]),
-; CHECK-SAME: kind=kCustom
-; CHECK-SAME: __triton_gemm
-})");
-}
-
-TEST_F(SparseDotTest, DotWithSparseRhsOperandIsNotSupported) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-ENTRY main {
- lhs = f16[2,32] parameter(0)
- rhs = f16[16,2] parameter(1)
- meta = u16[2,2] parameter(2)
- ROOT dot = f32[2,2] dot(lhs, rhs, meta),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=R.0@2:4
-})")
- .value();
- auto result = GemmFusion(gpu_version_).Run(module.get());
- EXPECT_FALSE(result.ok());
-}
-
-TEST_F(SparseDotTest, UnsupportedSparsityType) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-ENTRY main {
- lhs = f16[2,8] parameter(0)
- rhs = f16[32,2] parameter(1)
- meta = u16[2,1] parameter(2)
- ROOT dot = f32[2,2] dot(lhs, rhs, meta),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@1:4
-})")
- .value();
- auto result = GemmFusion(gpu_version_).Run(module.get());
- EXPECT_FALSE(result.ok());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc
deleted file mode 100644
index e3dd0cf..0000000
--- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc
+++ /dev/null
@@ -1,2429 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_rewriter.h"
-
-#include <algorithm>
-#include <array>
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <initializer_list>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/evaluator/hlo_evaluator.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal.h"
-#include "xla/literal_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/algorithm_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/gpu/gpu_blas_lt.h"
-#include "xla/types.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/ml_dtypes.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/protobuf/dnn.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = match;
-
-// Give this instruction a more useful name than "custom-call.42".
-absl::Status SetName(HloModule *module, HloInstruction *gemm) {
- if (IsCublasLtMatmul(*gemm)) {
- module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul");
- return absl::OkStatus();
- }
-
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- gemm->backend_config<GpuBackendConfig>());
- const GemmBackendConfig &config = gpu_config.gemm_backend_config();
- const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
- bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
- !dot_dims.rhs_batch_dimensions().empty();
-
- module->SetAndUniquifyInstrName(
- gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm");
- return absl::OkStatus();
-}
-
-// Returns whether a given PrimitiveType is supported by cuBLASLt Epilogue
-// Fusion. A table of supported data types can be found in the cuBLASLt
-// documentation: https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul.
-// Note that `Ctype` also describes the output type of the GEMM. Rows with
-// `Non-default epilogue not supported` entries in the last column indicate data
-// types not compatible with Epilogue Fusion.
-bool SupportsEpilogueFusion(PrimitiveType type) {
- switch (type) {
- case F8E4M3FN:
- case F8E5M2:
- case F16:
- case BF16:
- case F32:
- case F64:
- return true;
- default:
- return false;
- }
-}
-
-bool IsF8Type(const HloInstruction *instr) {
- return primitive_util::IsF8Type(instr->shape().element_type());
-}
-
-// Returns a new shape with non-batch dimensions padded to multiples of 16, as
-// required by cuBLASLt FP8 gemms.
-Shape PadShapeToMultipleOf16(const Shape old_shape,
- const absl::Span<const int64_t> batch_dims) {
- Shape padded_shape = old_shape;
- for (int i = 0; i < old_shape.rank(); ++i) {
- if (!absl::c_linear_search(batch_dims, i)) {
- int64_t padded_dimension =
- RoundUpTo<int64_t>(old_shape.dimensions(i), 16);
- padded_shape.set_dimensions(i, padded_dimension);
- }
- }
- return padded_shape;
-}
-
-// Pads the operand 'x' with zeros to match the target shape.
-HloInstruction *PadOperandToTargetShape(const Shape &target,
- HloInstruction *x) {
- if (ShapeUtil::Equal(target, x->shape()) ||
- !ShapeUtil::SameElementType(x->shape(), target)) {
- return x;
- }
-
- PaddingConfig padding_config;
- for (int i = 0; i < x->shape().rank(); ++i) {
- auto dimension = padding_config.add_dimensions();
- dimension->set_edge_padding_low(0);
- dimension->set_edge_padding_high(target.dimensions(i) -
- x->shape().dimensions(i));
- dimension->set_interior_padding(0);
- }
-
- HloInstruction *zero = x->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(x->shape().element_type())));
- return x->AddInstruction(
- HloInstruction::CreatePad(target, x, zero, padding_config));
-}
-
-// Pad the non-batch dimensions of the operands to multiples of 16 as required
-// by cuBLASLt FP8 gemms.
-HloInstruction *PadOperandToMultipleOf16(absl::Span<const int64_t> batch_dims,
- HloInstruction *x) {
- Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims);
- return PadOperandToTargetShape(padded_shape, x);
-}
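-
-// Editor's note, a hedged worked example of the two padding helpers above
-// (hypothetical shapes): for x = f8e4m3fn[3,20,30] with batch_dims = {0}, the
-// batch dimension is kept and the non-batch dimensions round up to multiples
-// of 16, giving f8e4m3fn[3,32,32]; the operand is then zero-padded, e.g.
-//   padded = f8e4m3fn[3,32,32] pad(x, zero), padding=0_0x0_12x0_2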
-
-// Calculates the reciprocal of scalar when invert is true and converts to FP32.
-absl::StatusOr<HloInstruction *> InvertAndConvertScalar(HloInstruction *scalar,
- bool invert) {
- DCHECK(ShapeUtil::IsScalar(scalar->shape()));
-
- if (invert) {
- Literal one_literal = LiteralUtil::One(scalar->shape().element_type());
- HloInstruction *one = scalar->parent()->AddInstruction(
- HloInstruction::CreateConstant(one_literal.Clone()));
- TF_ASSIGN_OR_RETURN(scalar, MakeBinaryHlo(HloOpcode::kDivide, one, scalar,
- &scalar->metadata()));
- }
- if (scalar->shape().element_type() != F32) {
- scalar = MakeConvertToHlo(scalar, F32, &scalar->metadata());
- }
-
- return scalar;
-}
-
-// A path of instructions by traversing downwards through users, as (op,
-// operand_index) pairs. operand_index is the index to get to the previous
-// element in the path. I.e.,
-// path[i].first->operand(path[i].second) == path[i-1].first
-using InstrPath = std::vector<std::pair<HloInstruction *, int>>;
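-
-// Editor's sketch, not part of the original pass: a debug-style check of the
-// invariant stated above (the helper name is hypothetical). Each path element
-// must reach the previous one through its recorded operand index.
-inline bool InstrPathIsConsistent(const InstrPath &path) {
- for (size_t i = 1; i < path.size(); ++i) {
- if (path[i].first->operand(path[i].second) != path[i - 1].first) {
- return false;
- }
- }
- return true;
-}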
-
-// From 'instr', recursively traverses operands until an FP8 instruction is
-// encountered. Only unary ops and a few types of non-unary ops are traversed.
-// If an FP8 instruction is found, returns the path from the FP8 instruction to
-// 'instr'. Returns nullopt when no FP8 instruction is reached.
-//
-// The intent is, given 'instr' is the operand of a dot, to find a sequence of
-// instructions that can potentially be fused into a cuBLAS LT FP8 gemm.
-std::optional<InstrPath> FindF8SubgraphRecursive(
- HloInstruction *instr, absl::flat_hash_set<int> &visited_instrs) {
- // Avoid visiting the same instruction more than once.
- if (!visited_instrs.emplace(instr->unique_id()).second) {
- return std::nullopt;
- }
- if (IsF8Type(instr)) {
- // The initial operand index is meaningless. Arbitrarily use -1.
- return InstrPath{{instr, -1}};
- }
- if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide ||
- instr->opcode() == HloOpcode::kDynamicSlice ||
- instr->opcode() == HloOpcode::kPad) {
- std::optional<InstrPath> subgraph =
- FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs);
- if (subgraph) {
- subgraph->emplace_back(std::make_pair(instr, 0));
- }
- return subgraph;
- } else if (instr->opcode() == HloOpcode::kMultiply ||
- instr->opcode() == HloOpcode::kSelect) {
- for (int k = 0; k < 2; ++k) {
- // Iterate over operands 0 and 1 for multiply and operands 1 and 2 for
- // select.
- int operand_idx = k + (instr->opcode() == HloOpcode::kSelect);
- std::optional<InstrPath> subgraph = FindF8SubgraphRecursive(
- instr->mutable_operand(operand_idx), visited_instrs);
- if (subgraph) {
- subgraph->emplace_back(std::make_pair(instr, operand_idx));
- return subgraph;
- }
- }
- }
- return std::nullopt;
-}
-
-// Contains information on a parameter (either the LHS or RHS) for a
-// gemm that can be potentially pattern-matched into an FP8 cublasLT gemm.
-struct MatchedFp8Param {
- // The FP8 input to the gemm.
- HloInstruction *fp8_input = nullptr;
- // If non-null, the scale applied to 'x'.
- HloInstruction *scale = nullptr;
- // Whether the scale, if present, multiplies or divides 'x'.
- bool mult_scale = false;
- // A list of instructions from x to the dot instruction commutative with
- // dequantization. Such instructions can be moved before the FP8 gemm.
- InstrPath commutative_ops;
-};
-
-// Given an operand of a dot, `instr`, returns a MatchedFp8Param if this operand
-// allows rewriting the dot into an FP8 cublasLT custom call, optionally with
-// scaling. In particular, returns a MatchedFp8Param if either 'instr' is FP8
-// or there is a path from an FP8 instruction 'fp8_input' to 'instr'
-// consisting of the following.
-// 1. A convert to a wider type.
-// 2. Optionally, a multiplication/division by a scalar, representing the scale.
-// If present, the scalar scale is returned as 'scale' and 'mult_scale'
-// is set to true or false depending on whether there is a multiplication or
-// a division.
-// 3. A possibly-empty set of ops commutative with steps (1) and (2), meaning
-// they can be safely moved before step (1). Such ops are returned in
-// 'commutative_ops'.
-// Steps (1) and (2) together are a dequantization, and can be fused into a
-// cublas LT matmul. Step (3) can be moved before the cublas LT matmul.
-std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
- absl::flat_hash_set<int> visited_instrs;
- std::optional<InstrPath> maybe_subgraph =
- FindF8SubgraphRecursive(instr, visited_instrs);
- if (!maybe_subgraph) {
- return std::nullopt;
- }
- InstrPath &subgraph = maybe_subgraph.value();
-
- MatchedFp8Param param;
-
- // Directly operating on an FP8 operand.
- if (subgraph.size() == 1) {
- CHECK(IsF8Type(subgraph[0].first));
- param.fp8_input = subgraph[0].first;
- return param;
- }
-
- int num_dequant_ops;
- // When not operating directly on an FP8 operand, the second and
- // third instructions in the subgraph can describe a dequantization, i.e. a
- // convert instruction followed by a multiply/divide instruction.
- if (subgraph.size() > 2 &&
- Match(subgraph[2].first,
- m::MultiplyAnyOrder(m::Convert(m::Op(&param.fp8_input)),
- m::Broadcast(m::Op(&param.scale))))) {
- param.mult_scale = true;
- num_dequant_ops = 2;
- } else if (subgraph.size() > 2 &&
- Match(subgraph[2].first,
- m::Divide(m::Convert(m::Op(&param.fp8_input)),
- m::Broadcast(m::Op(&param.scale))))) {
- param.mult_scale = false;
- num_dequant_ops = 2;
- } else if (subgraph.size() > 1 &&
- Match(subgraph[1].first, m::Convert(m::Op(&param.fp8_input)))) {
- // We have a convert from FP8 without a scale in this case.
- param.scale = nullptr;
- num_dequant_ops = 1;
- } else {
- VLOG(1) << "Possible intended FP8 GEMM operating on "
- << instr->ToShortString() << " not rewritten into FP8 Custom Call.";
- return std::nullopt;
- }
-
- auto preserves_element_type = [](const HloInstruction *instr) -> bool {
- return ShapeUtil::SameElementType(instr->shape(),
- instr->operand(0)->shape());
- };
- auto use_spmd_partitioning = [](const HloInstruction *instr) -> bool {
- return instr->GetModule()->config().use_spmd_partitioning();
- };
-
- // Skip the initial FP8 instruction and the dequantization instructions.
- int start = 1 + num_dequant_ops;
- for (int i = start; i < subgraph.size(); ++i) {
- // The remaining instructions must be commutative with dequantization.
- // Bitcast, broadcast, copy, dynamic-slice, pad, reshape, select, slice,
- // transpose, all-gather, all-to-all and collective-permute instructions are
- // supported. Specifically, the all-gather, all-to-all and
- // collective-permute operations are permitted only in SPMD cases since the
- // optimization cannot be guaranteed to be applied to all replicas in the
- // MPMD scenario.
- if (!Match(
- subgraph[i].first,
- m::AnyOf<HloInstruction>(
- m::Bitcast().WithPredicate(preserves_element_type),
- m::Broadcast(), m::Copy(), m::DynamicSlice(), m::Pad(),
- m::Reshape(), m::Select(), m::Slice(), m::Transpose(),
- m::AllGather().WithPredicate(use_spmd_partitioning),
- m::AllToAll().WithPredicate(use_spmd_partitioning),
- m::CollectivePermute().WithPredicate(use_spmd_partitioning)))) {
- VLOG(1) << "Possible intended FP8 GEMM operating on "
- << instr->ToShortString()
- << " not rewritten into FP8 Custom Call.";
- return std::nullopt;
- }
- // One of the operands of select must be zero for the op to be commutative
- // with dequantization.
- if (Match(subgraph[i].first, m::Select()) &&
- !Match(subgraph[i].first->operand(subgraph[i].second == 2 ? 1 : 2),
- m::Broadcast(m::ConstantScalar(0)))) {
- VLOG(1) << "Possible intended FP8 GEMM operating on "
- << instr->ToShortString()
- << " not rewritten into FP8 Custom Call. Select requires a zero "
- "operand to be exchanged with dequantization.";
- return std::nullopt;
- }
- }
-
- param.commutative_ops = {subgraph.begin() + start, subgraph.end()};
- return param;
-}
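-
-// Editor's illustration with hypothetical shapes: an operand chain that
-// MatchFp8Param above recognizes. The convert and multiply form the
-// dequantization (steps 1-2) and the transpose is a commutative op (step 3),
-// so fp8_input is 'x', scale is 's', mult_scale is true, and commutative_ops
-// holds the transpose:
-//   x  = f8e4m3fn[32,16] parameter(0)
-//   xc = f32[32,16] convert(x)
-//   s  = f32[] parameter(1)
-//   sb = f32[32,16] broadcast(s), dimensions={}
-//   xs = f32[32,16] multiply(xc, sb)
-//   xt = f32[16,32] transpose(xs), dimensions={1,0}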
-
-// Transposes a matrix by swapping the contracting and non-contracting
-// dimensions. There must be only one contracting and only one non-contracting
-// dimension. Keeps the layout the same.
-HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
- absl::Span<const int64_t> batch_dims) {
- // Identify the dimensional order which describes a transpose of the
- // contracting and non-contracting dimensions of the GEMM.
- std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
- // Discard the batch dimensions.
- for (int64_t batch_dim : batch_dims) {
- permutation[batch_dim] = batch_dim;
- }
- // Identify the non-contracting dimension.
- int non_contracting_dim;
- for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
- if (permutation[i] == -1 && contracting_dim != i) {
- non_contracting_dim = i;
- }
- }
- permutation[non_contracting_dim] = contracting_dim;
- permutation[contracting_dim] = non_contracting_dim;
-
- Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
- *new_shape.mutable_layout() = instr->shape().layout();
- return instr->AddInstruction(
- HloInstruction::CreateTranspose(new_shape, instr, permutation));
-}
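-
-// Editor's note, a worked example with hypothetical dimensions for the
-// function above: for a rank-3 shape with batch_dims = {0} and
-// contracting_dim = 2, the loop leaves permutation[0] = 0, picks dimension 1
-// as the non-contracting dimension, and yields permutation = {0, 2, 1}, i.e.
-// a transpose of the last two dimensions that keeps the original layout.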
-
-// If the bias is a sequence of ops that depend only on broadcasts of
-// constants, materialize the bias if it's small.
-//
-// Normally the constant-folding pass would materialize the bias if it is
-// calculated entirely from constants. But if the bias is a broadcast of a
-// constant, constant-folding won't expand the broadcast, on the theory that
-// folding broadcasts of constants causes us to consume more memory and can
-// actually make things slower (because any op which reads the constant has
-// to read more memory).
-//
-// OTOH in our case, we don't want to run an op that just broadcasts a
-// constant so we can fuse it into this gemm. That would defeat the whole
-// purpose of this fusion, which is to launch fewer kernels. So if we can,
-// we expand out this constant ourselves.
-//
-// TODO(b/192499646): Even better would be to use cublasLT to fuse the
-// broadcasted bias, if it supports that fusion efficiently.
-HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
- // This limit was not chosen carefully.
- constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;
-
- // Don't fold broadcasts of scalars -- algsimp will just collapse it again.
- auto is_nonscalar = [](const HloInstruction *instr) {
- return !ShapeUtil::IsEffectiveScalar(instr->shape());
- };
-
- // For now, only fold broadcast(constant) or
- // reshape/transpose/bitcast(broadcast(constant)). This lets us avoid the
- // complexity in the constant-folding pass about what is and isn't legal to
- // fold.
- auto broadcast_of_nonscalar =
- m::Broadcast(m::Constant().WithPredicate(is_nonscalar));
-
- if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
- (Match(bias, broadcast_of_nonscalar) ||
- Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
- Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
- Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
- HloEvaluator evaluator(/*max_loop_iterations=*/0);
- Literal result;
- if (evaluator.TryEvaluate(
- bias, &result,
- /*recursively_evaluate_nonconstant_operands=*/true)) {
- return bias->parent()->AddInstruction(
- HloInstruction::CreateConstant(std::move(result)));
- }
- }
-
- return bias;
-}
-
-auto Gemm(HloInstruction **instr) {
- return m::CustomCall(instr, {kGemmCallTarget});
-}
-
-auto CublasLtMatmul(HloInstruction **instr) {
- return m::CustomCall(instr, {kCublasLtMatmulCallTarget});
-}
-
-auto CublasLtMatmulF8(HloInstruction **instr) {
- return m::CustomCall(instr, {kCublasLtMatmulF8CallTarget});
-}
-
-auto CublasLtMatmulMaybeF8(HloInstruction **instr) {
- return m::CustomCall(
- instr, {kCublasLtMatmulCallTarget, kCublasLtMatmulF8CallTarget});
-}
-
-auto GemmOrCublasLtMatmul(HloInstruction **instr) {
- return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget});
-}
-
-auto GemmOrCublasLtMatmulMaybeF8(HloInstruction **instr) {
- return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget,
- kCublasLtMatmulF8CallTarget});
-}
-
-auto BcastConstScalar(HloInstruction **instr, double value) {
- return m::Broadcast(instr, m::ConstantScalar(value));
-}
-
-auto BcastConstScalar(double value) { return BcastConstScalar(nullptr, value); }
-
-auto BcastConstScalarNear(double value) {
- return m::Broadcast(m::ConstantScalar().WithPredicate(
- [expected = value](const HloInstruction *instr) {
- // Not a very robust floating-point comparison, but good enough for our
- // purposes.
- std::optional<double> actual =
- xla::Cast<const HloConstantInstruction>(instr)
- ->literal()
- .GetAsDouble({});
- if (!actual.has_value()) return false;
- double epsilon;
- switch (instr->shape().element_type()) {
- case F16:
- epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
- break;
- case BF16:
- epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
- break;
- case F32:
- epsilon = 128 * std::numeric_limits<float>::epsilon();
- break;
- case F64:
- epsilon = 128 * std::numeric_limits<double>::epsilon();
- break;
- default:
- return false;
- }
- return abs(*actual - expected) < (abs(*actual + expected) * epsilon);
- }));
-}
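-
-// Editor's note, a hedged worked example of the tolerance above: for the GELU
-// match below, expected = sqrt(2/pi) ~= 0.7978845608. For an f32 literal,
-// epsilon = 128 * std::numeric_limits<float>::epsilon() ~= 1.5e-5, so any
-// constant within roughly |actual + expected| * epsilon ~= 2.4e-5 of that
-// value is accepted as "near".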
-
-template <typename Pattern>
-auto OptionalSlice(HloInstruction **optional_slice, Pattern pattern) {
- return m::AnyOf<HloInstruction>(m::Slice(optional_slice, pattern),
- std::move(pattern));
-}
-
-template <typename Pattern>
-auto OptionalConvert(HloInstruction **optional_convert, Pattern pattern) {
- return m::AnyOf<HloInstruction>(m::Convert(optional_convert, pattern),
- std::move(pattern));
-}
-
-template <typename Pattern>
-auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) {
- return m::AnyOf<HloInstruction>(m::Bitcast(optional_bitcast, pattern),
- std::move(pattern));
-}
-
-// The rewriting proceeds in a bottom-up way:
-//
-// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
-//
-// (kMultiply (kCustomCall:gemm A B) C) is rewritten by folding C (provided
-// it's a constant) into the alpha parameter of the custom call.
-//
-// (kAdd (kCustomCall:gemm A B) C) is rewritten into (kCustomCall:gemm A B C),
-// where the "beta" parameter is set to 1 (provided it was zero before,
-// and provided C has no other users).
-// We then guide the buffer assignment to alias the buffer of the custom call
-// and C.
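-//
-// Editor's illustration (hypothetical shapes): the first rewrite above turns
-//   d = f32[64,32] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-// into a custom call whose target is chosen later in this pass (for example
-// kGemmCallTarget), roughly:
-//   d = f32[64,32] custom-call(a, b), custom_call_target=<kGemmCallTarget>,
-//       backend_config={alpha_real: 1, alpha_imag: 0, beta: 0, ...}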
-//
-// For scaled FP8 GEMMs on Hopper systems, the following steps mentioned in
-// RFC #22 (https://github.com/openxla/xla/discussions/22) are elided and
-// rewritten into a Custom Call:
-//
-// 1. Cast each input from FP8 to a wider type such as FP16 or FP32.
-// 2. Unscale each input by multiplying each input by the corresponding input
-// scale.
-// 3. Evaluate the matrix multiplication on the scaled inputs.
-// 4. Compute the maximum of the absolute values in the result of the GEMM
-// (DAmax).
-// 5. Scale the output by dividing the output by the output scale.
-// 6. Cast the output back to FP8. Since saturation should be done on
-// overflow, this is represented by a Clamp instruction followed by a Convert
-// instruction.
-
-// Steps 1 through 3 can be elided independently of the remainder. Steps 5 and
-// 6 are elided only if steps 1 through 3 were successfully transformed. Step
-// 4 requires steps 5 and 6, i.e. the computation of DAmax can be elided only
-// when the output of the GEMM is requested in FP8 format.
-class GemmRewriterVisitor : public DfsHloRewriteVisitor {
- public:
- explicit GemmRewriterVisitor(const se::GpuComputeCapability &gpu_version,
- const int32_t toolkit_version,
- const bool f8_rewrite)
- : gpu_version_(gpu_version),
- toolkit_version_(toolkit_version),
- f8_rewrite_(f8_rewrite) {}
-
- absl::Status HandleDot(HloInstruction *instr) override {
- if (!IsMatrixMultiplication(*instr) &&
- !IsMatrixVectorMultiplication(*instr)) {
- return absl::OkStatus();
- }
- // Sparse dot is not supported.
- if (Cast<HloDotInstruction>(instr)->sparse_operands()) {
- return absl::OkStatus();
- }
-
- int64_t gemm_rewrite_size_threshold =
- instr->GetModule()
- ->config()
- .debug_options()
- .xla_gpu_gemm_rewrite_size_threshold();
- TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
- IsMatrixMultiplicationTooSmallForRewriting(
- *instr, gemm_rewrite_size_threshold));
- if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*instr)) {
- return absl::OkStatus();
- }
-
- CHECK(!instr->IsRank2Transpose());
- if (instr->operand(0)->IsRank2Transpose() ||
- instr->operand(1)->IsRank2Transpose()) {
- return absl::OkStatus();
- }
- // Create a GemmBackendConfig based on the instruction.
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
- instr->backend_config<GpuBackendConfig>());
- GemmBackendConfig &gemm_backend_config =
- *gpu_backend_config.mutable_gemm_backend_config();
- gemm_backend_config.set_alpha_real(1.0);
- gemm_backend_config.set_alpha_imag(0.0);
- gemm_backend_config.set_beta(0.0);
- *gemm_backend_config.mutable_dot_dimension_numbers() =
- instr->dot_dimension_numbers();
- *gemm_backend_config.mutable_precision_config() = instr->precision_config();
-
- HloInstruction *lhs = instr->mutable_operand(0);
- HloInstruction *rhs = instr->mutable_operand(1);
- auto attributes = instr->frontend_attributes().map();
- gemm_backend_config.set_grad_x(attributes["grad_x"] == "true");
- gemm_backend_config.set_grad_y(attributes["grad_y"] == "true");
-
- int64_t lhs_batch_dims_size =
- instr->dot_dimension_numbers().lhs_batch_dimensions_size();
- bool is_lhs_vector =
- lhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
- bool is_rhs_vector =
- rhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
- int64_t lhs_stride =
- is_lhs_vector ? lhs->shape().dimensions(lhs_batch_dims_size)
- : lhs->shape().dimensions(lhs_batch_dims_size) *
- lhs->shape().dimensions(lhs_batch_dims_size + 1);
- int64_t rhs_stride =
- is_rhs_vector ? rhs->shape().dimensions(lhs_batch_dims_size)
- : rhs->shape().dimensions(lhs_batch_dims_size) *
- rhs->shape().dimensions(lhs_batch_dims_size + 1);
-
- gemm_backend_config.set_lhs_stride(lhs_stride);
- gemm_backend_config.set_rhs_stride(rhs_stride);
-
- if (f8_rewrite_) {
- // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call.
- TF_ASSIGN_OR_RETURN(
- bool supported_by_cublaslt,
- GemmIsSupportedByCublasLt(*instr, gemm_backend_config));
- std::optional<MatchedFp8Param> a, b;
- if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot &&
- (a = MatchFp8Param(
- const_cast<HloInstruction *>(instr->operand(0)))) &&
- (b = MatchFp8Param(
- const_cast<HloInstruction *>(instr->operand(1))))) {
- if (IsRocm(gpu_version_) && toolkit_version_ < 60200 &&
- instr->shape().element_type() != F16 &&
- instr->shape().element_type() != F32) {
- TF_ASSIGN_OR_RETURN(instr,
- TurnF8DotWithUnsupportedOutputTypeIntoF32(instr));
- }
- TF_ASSIGN_OR_RETURN(bool created_call,
- CreateF8CustomCall(instr, gpu_backend_config,
- a.value(), b.value()));
- if (created_call) {
- return absl::OkStatus();
- }
- }
- if (IsF8Type(instr->operand(0))) {
- // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt
- // custom call, so turn into an FP16 dot which may be rewritten as an
- // FP16 Triton, cublas or cublasLt call.
- TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr));
- }
- } else {
- // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call.
- TF_ASSIGN_OR_RETURN(
- absl::string_view gemm_custom_call_target,
- GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config));
- const Shape &output_shape = instr->shape();
- HloInstruction *gemm_call =
- instr->AddInstruction(HloInstruction::CreateCustomCall(
- output_shape,
- {instr->mutable_operand(0), instr->mutable_operand(1)},
- gemm_custom_call_target));
- TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config));
- TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call));
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleMultiply(HloInstruction *instr) override {
- HloInstruction *alpha, *existing_gemm;
- if (Match(instr,
- m::MultiplyAnyOrder(
- GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(),
- m::Broadcast(m::ConstantScalar(&alpha)).WithOneUser()))) {
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- existing_gemm->backend_config<GpuBackendConfig>());
- GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
- // Do not fuse alpha into S32 GEMM, as they only support fixed values for
- // alpha/beta.
- if (existing_gemm->shape().element_type() == S32) {
- return absl::OkStatus();
- }
-
- if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
- complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
- complex128 new_alpha =
- *alpha->literal().GetAsComplex128({}) * prev_alpha;
- config.set_alpha_real(new_alpha.real());
- config.set_alpha_imag(new_alpha.imag());
- TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
- return ReplaceInstruction(instr, existing_gemm);
- }
- }
-
- HloInstruction *d_scale;
- if (Match(instr, m::MultiplyAnyOrder(
- CublasLtMatmulF8(&existing_gemm).WithOneUser(),
- m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
- return F8ScaleD(instr, existing_gemm, d_scale);
- }
-
- // Attempt to match approximate GELU activation
- // (https://arxiv.org/abs/1606.08415), where:
- // approx_gelu(x) = x * cdf(x)
- // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
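- // Editor's note, a quick sanity check of the formula above: at x = 1,
- // cdf(1) = 0.5 * (1 + tanh(0.7978845608 * 1.044715)) ~= 0.841, so
- // approx_gelu(1) ~= 0.841, close to the exact GELU value x * Phi(x), i.e.
- // Phi(1) ~= 0.8413.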
- HloInstruction *cdf, *slice_or_bitcast = nullptr;
- if (Match(instr, m::MultiplyAnyOrder(
- m::AnyOf<HloInstruction>(
- m::Slice(&slice_or_bitcast,
- CublasLtMatmulMaybeF8(&existing_gemm)),
- m::Bitcast(&slice_or_bitcast,
- CublasLtMatmulMaybeF8(&existing_gemm)),
- CublasLtMatmulMaybeF8(&existing_gemm)),
- m::Op(&cdf).WithOneUser())) &&
- Match(cdf,
- m::MultiplyAnyOrder(
- BcastConstScalar(0.5),
- m::AddAnyOrder(
- BcastConstScalar(1.0),
- m::Tanh(
- m::MultiplyAnyOrder(
- BcastConstScalarNear(sqrt(M_2_PI)),
- m::AddAnyOrder(
- m::Op().Is(slice_or_bitcast ? slice_or_bitcast
- : existing_gemm),
- m::MultiplyAnyOrder(
- BcastConstScalarNear(0.044715),
- m::MultiplyAnyOrder(
- m::Op().Is(slice_or_bitcast
- ? slice_or_bitcast
- : existing_gemm),
- m::MultiplyAnyOrder(
- m::Op().Is(slice_or_bitcast
- ? slice_or_bitcast
- : existing_gemm),
- m::Op().Is(slice_or_bitcast
- ? slice_or_bitcast
- : existing_gemm))
- .WithOneUser())
- .WithOneUser())
- .WithOneUser())
- .WithOneUser())
- .WithOneUser())
- .WithOneUser())))) {
- return FuseGeluActivation(instr, existing_gemm, slice_or_bitcast);
- }
- return absl::OkStatus();
- }
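For reference, the tanh-based GELU approximation matched above reduces to the following standalone arithmetic; the helper name ApproxGelu and the use of plain float are illustrative only, not part of the rewriter.

#include <cmath>

// approx_gelu(x) = x * cdf(x), with
//   cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
float ApproxGelu(float x) {
  const float kSqrt2OverPi = std::sqrt(2.0f / 3.14159265358979323846f);
  const float cdf =
      0.5f * (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
  return x * cdf;
}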
-
- // Fuse the scaling of an FP8 GEMM into the Custom Call.
- absl::Status HandleDivide(HloInstruction *instr) override {
- HloInstruction *existing_gemm, *d_scale;
- if (Match(instr, m::Divide(CublasLtMatmulF8(&existing_gemm).WithOneUser(),
- m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
- return F8ScaleD(instr, existing_gemm, d_scale);
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleAdd(HloInstruction *instr) override {
- HloInstruction *bias, *existing_gemm = nullptr;
- HloInstruction *optional_slice = nullptr;
- HloInstruction *optional_convert = nullptr;
- HloInstruction *optional_bitcast = nullptr;
- // Attempt to elide broadcast and fuse addition of a vector bias into
- // GEMM, including when slicing is applied to the result.
- if (Match(instr,
- m::AddAnyOrder(
- OptionalBitcast(
- &optional_bitcast,
- OptionalSlice(
- &optional_slice,
- CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
- .WithOneUser())
- .WithOneUser(),
- m::Broadcast(&bias,
- OptionalConvert(&optional_convert, m::Op()))))) {
- TF_ASSIGN_OR_RETURN(
- bool was_fused,
- FuseVectorBiasAdd(instr, bias, existing_gemm, optional_slice,
- optional_convert, optional_bitcast));
-
- if (was_fused) {
- return absl::OkStatus();
- }
- }
- // Attempt to elide broadcast and fuse addition of a vector bias into
- // *batched* GEMM as a matrix bias addition using FuseMatrixBiasAdd.
- // add(bitcast(gemm(a, b)), broadcast(bias)) ->
- // bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) ->
- // bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd)
- //
- if (Match(
- instr,
- m::AddAnyOrder(
- m::Bitcast(CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
- .WithOneUser(),
- m::Broadcast(&bias, m::Op()).WithOneUser()))) {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_add,
- MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
- MakeBitcastHlo(bias, existing_gemm->shape())));
- TF_RETURN_IF_ERROR(
- ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
-
- // Continue below.
- instr = new_add;
- }
-
- // Do not fuse broadcast unless we can fuse its input, as it will cause
- // broadcast materialization.
- auto is_not_broadcast = [](const HloInstruction *instr) {
- return instr->opcode() != HloOpcode::kBroadcast;
- };
-
- // add(bitcast(gemm(a, b)), bias) ->
- // bitcast(add(gemm(a, b), bitcast(bias))) ->
- // bitcast(gemm(a, b, bitcast(bias))) (later down in this function).
- //
- // We see this idiom in models that contain batch-dots, where we cast
- // between a rank-2 shape for non-batch dots and a higher-rank shape for
- // batch-dots.
- //
- // The last stage of the transform may fail (because of any of the checks in
- // FuseMatrixBiasAdd), but if so that's okay -- we'll have done a useless
- // transformation, but it doesn't hurt anything.
- if (Match(instr,
- m::AddAnyOrder(
- m::Bitcast(
- GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
- .WithOneUser(),
- m::Op(&bias).WithPredicate(is_not_broadcast)))) {
- HloInstruction *new_bitcast =
- MakeBitcastHlo(bias, existing_gemm->shape(), &bias->metadata());
- TF_ASSIGN_OR_RETURN(HloInstruction * new_add,
- MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
- new_bitcast, &bias->metadata()));
- TF_RETURN_IF_ERROR(
- ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
-
- // Continue below transforming new_add.
- instr = new_add;
- }
-
- // Attempt to fuse matrix bias into gemm with optional convert
- // add(convert(gemm(a, b)), c) -> gemm(a, b, c)
- // add(gemm(a, b), c) -> gemm(a, b, c)
- if (Match(instr,
- m::AddAnyOrder(
- m::AnyOf<HloInstruction>(
- GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(),
- m::Convert(
- GemmOrCublasLtMatmul(&existing_gemm).WithOneUser())
- .WithOneUser()),
- m::Op(&bias).WithPredicate(is_not_broadcast)))) {
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
- existing_gemm->backend_config<GpuBackendConfig>());
- const GemmBackendConfig &gemm_backend_config =
- gpu_backend_config.gemm_backend_config();
-       // Check whether this type combination is supported.
- TF_ASSIGN_OR_RETURN(
- bool types_are_supported,
- IsLegacyCublasMatmul(*existing_gemm)
- ? TypesAreSupportedByLegacyCublas(*existing_gemm,
- gemm_backend_config, instr)
- : TypesAreSupportedByCublasLt(*existing_gemm, gemm_backend_config,
- instr));
-
-       // For mixed-type GEMMs, only fuse the add if it has no consumers,
-       // i.e. the add is either
-       //   ROOT add
-       //   ROOT tuple(add)
- bool has_no_consumer =
- instr->shape().element_type() ==
- existing_gemm->shape().element_type() ||
- instr->user_count() == 0 ||
- (instr->user_count() == 1 &&
- instr->users()[0]->opcode() == HloOpcode::kTuple &&
- instr->users()[0]->user_count() == 0);
-
- if (types_are_supported && has_no_consumer) {
- return FuseMatrixBiasAdd(instr, bias, existing_gemm);
- }
- }
-
- HloInstruction *optional_bitcast_matrix = nullptr;
- HloInstruction *optional_slice_matrix = nullptr;
- if (Match(instr,
- m::AddAnyOrder(
- OptionalBitcast(
- &optional_bitcast_matrix,
- OptionalSlice(&optional_slice_matrix,
- GemmOrCublasLtMatmulMaybeF8(&existing_gemm)
- .WithOneUser()))
- .WithOneUser(),
- m::Op(&bias).WithPredicate(is_not_broadcast)))) {
- // The matrix bias must not be FP8, see
- // https://docs.nvidia.com/cuda/cublas/index.html.
- if (!IsF8Type(bias)) {
- return FuseMatrixBiasAdd(instr, bias, existing_gemm,
- optional_bitcast_matrix,
- optional_slice_matrix);
- }
- }
-
- return absl::OkStatus();
- }
-
- absl::Status HandleMaximum(HloInstruction *instr) override {
- HloInstruction *existing_gemm, *zeros;
- HloInstruction *optional_slice_or_bitcast = nullptr;
- // Attempt to elide maximum and fuse ReLU activation into GEMM, including
- // when slicing or bitcasting is applied to the result.
- if (Match(instr,
- m::MaximumAnyOrder(
- m::AnyOf<HloInstruction>(
- m::Slice(
- &optional_slice_or_bitcast,
- CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
- m::Bitcast(
- &optional_slice_or_bitcast,
- CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
- CublasLtMatmulMaybeF8(&existing_gemm))
- .WithOneUser(),
- m::Broadcast(&zeros, m::ConstantScalar(0))))) {
- TF_RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm,
- optional_slice_or_bitcast));
- }
- return absl::OkStatus();
- }
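The fused RELU and BIAS_RELU epilogues compute the same elementwise max-with-zero that the matched maximum expresses; a minimal numeric sketch with an illustrative helper:

#include <algorithm>

// What the RELU / BIAS_RELU epilogues compute per output element, instead of
// a separate HLO maximum against a broadcast zero.
float BiasReluEpilogue(float gemm_output, float bias) {
  return std::max(gemm_output + bias, 0.0f);
}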
-
- absl::Status HandleConvert(HloInstruction *instr) override {
- HloInstruction *clamp_lower, *clamp_upper, *existing_gemm,
- *d_scale = nullptr, *binary = nullptr;
- // Attempt to elide the scaling and conversion of the result of an FP8
- // GEMM, including the optional calculation of the maximum of the absolute
- // values before scaling, and adapt the Custom Call.
- if (Match(instr,
- m::Convert(
- m::Clamp(
- m::Broadcast(m::ConstantScalar(&clamp_lower)),
- m::AnyOf<HloInstruction>(
- CublasLtMatmulF8(&existing_gemm),
- m::Divide(&binary, CublasLtMatmulF8(&existing_gemm),
- m::Broadcast(m::Op(&d_scale))),
- m::MultiplyAnyOrder(&binary,
- CublasLtMatmulF8(&existing_gemm),
- m::Broadcast(m::Op(&d_scale)))),
- m::Broadcast(m::ConstantScalar(&clamp_upper)))
- .WithOneUser()))) {
- return F8ConvertD(
- instr, existing_gemm, d_scale, clamp_lower, clamp_upper,
- /*mult_scale=*/(binary && binary->opcode() == HloOpcode::kMultiply));
- }
- return absl::OkStatus();
- }
-
- static bool IsCuda(const se::GpuComputeCapability &gpu_version) {
- return std::holds_alternative<se::CudaComputeCapability>(gpu_version);
- }
-
- static absl::StatusOr<se::CudaComputeCapability> GetCudaComputeCapability(
- const se::GpuComputeCapability &gpu_version) {
- auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
- if (cuda_cc == nullptr) {
- return absl::InvalidArgumentError("Compute Capability is not CUDA.");
- }
- return *cuda_cc;
- }
-
- static bool IsRocm(const se::GpuComputeCapability &gpu_version) {
- return std::holds_alternative<se::RocmComputeCapability>(gpu_version);
- }
-
- static absl::StatusOr<se::RocmComputeCapability> GetRocmComputeCapability(
- const se::GpuComputeCapability &gpu_version) {
- auto rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);
- if (rocm_cc == nullptr) {
- return absl::InvalidArgumentError("Compute Capability is not ROCm.");
- }
- return *rocm_cc;
- }
-
- absl::StatusOr<bool> CreateF8CustomCall(HloInstruction *instr,
- GpuBackendConfig &gpu_backend_config,
- MatchedFp8Param a,
- MatchedFp8Param b) {
- GemmBackendConfig &gemm_backend_config =
- *gpu_backend_config.mutable_gemm_backend_config();
- if (IsCuda(gpu_version_)) {
- TF_ASSIGN_OR_RETURN(auto cuda_compute_capability,
- GetCudaComputeCapability(gpu_version_));
- // FP8 GEMM kernels are only available on Ada, Hopper, and later
- // architectures.
- if (!cuda_compute_capability.IsAtLeast(8, 9)) {
- VLOG(1) << "FP8 Custom Calls require Ada, Hopper, or later "
- "architectures. Got: "
- << cuda_compute_capability.ToString()
- << " and toolkit version: " << toolkit_version_;
- return false;
- }
- // FP8 GEMM kernels are only available with CUDA 12.0 and above
- if (toolkit_version_ < 12000) {
- VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer.";
- return false;
- }
- }
-
- if (IsRocm(gpu_version_)) {
- TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
- GetRocmComputeCapability(gpu_version_));
- if (!rocm_compute_capability.has_fp8_support()) {
-         VLOG(1) << "FP8 Custom Calls require MI300 or later architectures.";
- return false;
- }
- if (toolkit_version_ < 60000) {
- // FP8 GEMM kernels are only available with ROCm 6.0 and above
- VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer.";
- return false;
- }
- }
-
- PrimitiveType a_type = a.fp8_input->shape().element_type();
- PrimitiveType b_type = b.fp8_input->shape().element_type();
-
- // cuBLASLt FP8 GEMM kernels require one of the two operands to be in
- // F8E4M3FN format.
- if (IsCuda(gpu_version_)) {
- if (a_type == F8E5M2 && b_type == F8E5M2) {
- VLOG(1)
- << "Failed to rewrite " << instr->ToShortString()
- << " into FP8 Custom Call. The element type of one of the operands "
- "must be F8E4M3FN.";
- return false;
- }
- if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
- (b_type != F8E5M2 && b_type != F8E4M3FN)) {
- VLOG(1) << "Failed to rewrite " << instr->ToShortString()
- << " into FP8 Custom Call. The input types must be F8E5M2 or "
- "F8E4M3FN, but got "
- << PrimitiveType_Name(a_type) << " and "
- << PrimitiveType_Name(b_type);
- return false;
- }
- }
-
- if (IsRocm(gpu_version_)) {
- if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
- VLOG(1)
- << "Failed to rewrite " << instr->ToShortString()
- << " into FP8 Custom Call. The element type of one of the operands "
- "must be F8E4M3FNUZ.";
- return false;
- }
- if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
- (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
- VLOG(1)
- << "Failed to rewrite " << instr->ToShortString()
- << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
- "F8E4M3FNUZ, but got "
- << PrimitiveType_Name(a_type) << " and "
- << PrimitiveType_Name(b_type);
- return false;
- }
- }
-
- absl::Span<const int64_t> batch_dims =
- gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions();
-
- // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
- // format. Set the factors to one when no scaling factors were captured.
- Literal one_literal = LiteralUtil::One(F32);
- HloInstruction *one = instr->AddInstruction(
- HloInstruction::CreateConstant(one_literal.Clone()));
- std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
- std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
- scales_f32;
- for (int i = 0; i < scales.size(); ++i) {
- if (scales[i]) {
- if (!ShapeUtil::IsScalar(scales[i]->shape())) {
- VLOG(1) << "Failed to rewrite " << instr->ToShortString()
- << " into FP8 Custom Call. The scaling factors must be "
- "scalars.";
- return false;
- }
- if (!mult_scale[i]) {
- inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
- scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
- }
- scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
- if (scales_f32[i]->shape().element_type() != F32) {
- scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
- }
- } else {
- scales_f32[i] = one;
- }
- }
-
- PrimitiveType d_type = instr->shape().element_type();
- bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32);
- if (IsCuda(gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) {
- supported_d_type = true;
- }
- if (IsRocm(gpu_version_) && toolkit_version_ >= 60200 &&
- (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) {
- supported_d_type = true;
- }
- if (!supported_d_type) {
- VLOG(1) << "Failed to rewrite " << instr->ToShortString()
- << " into FP8 Custom Call. Output element type must be "
- << (IsCuda(gpu_version_) ? "F8E4M3FN, F8E5M2, BF16, F16 or F32. "
- : toolkit_version_ >= 60200
- ? "F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. "
- : "BF16, F16 or F32. ")
- << "Actual element type is " << PrimitiveType_Name(d_type);
- return false;
- }
-
- // Each operand must have exactly one contracting and one non-contracting
- // dimension.
- absl::Span<const int64_t> a_contracting_dims =
- gemm_backend_config.dot_dimension_numbers()
- .lhs_contracting_dimensions();
- absl::Span<const int64_t> b_contracting_dims =
- gemm_backend_config.dot_dimension_numbers()
- .rhs_contracting_dimensions();
- if (a_contracting_dims.size() != 1 || b_contracting_dims.size() != 1) {
- VLOG(1) << "Failed to rewrite " << instr->ToShortString()
- << " into FP8 Custom Call. A and B must have one contracting "
- "dimension.";
- return false;
- }
- if ((a.commutative_ops.empty() ? a.fp8_input
- : a.commutative_ops.back().first)
- ->shape()
- .dimensions_size() -
- batch_dims.size() !=
- 2 ||
- (b.commutative_ops.empty() ? b.fp8_input
- : b.commutative_ops.back().first)
- ->shape()
- .dimensions_size() -
- batch_dims.size() !=
- 2) {
- VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-               << " into FP8 Custom Call. A and B must have one non-contracting "
- "dimension.";
- return false;
- }
-
- // Sequentially apply the collected unary, dynamic-slice, pad and select ops
- // to the unconverted and unscaled operands.
- auto shift_ops = [&instr](HloInstruction *&x, InstrPath &x_ops) -> void {
- for (std::pair<HloInstruction *, int> op : x_ops) {
- std::vector<HloInstruction *> operands = {x};
- // Insert the additional operands of dynamic-slice ops.
- if (op.first->opcode() == HloOpcode::kDynamicSlice) {
- for (int i = 1; i < op.first->operand_count(); ++i) {
- operands.emplace_back(op.first->mutable_operand(i));
- }
- }
- // Convert the second operand of pad ops.
- if (op.first->opcode() == HloOpcode::kPad) {
- HloInstruction *convert =
- instr->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::ChangeElementType(op.first->operand(1)->shape(),
- x->shape().element_type()),
- op.first->mutable_operand(1)));
- operands.emplace_back(convert);
- }
- // Convert and insert the additional operands of select ops.
- if (op.first->opcode() == HloOpcode::kSelect) {
- // The first operand is the predicate.
- operands.emplace(operands.begin(), op.first->mutable_operand(0));
- // Convert the remaining operand.
- int operand_idx = op.second == 2 ? 1 : 2;
- HloInstruction *convert =
- instr->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::ChangeElementType(
- op.first->operand(operand_idx)->shape(),
- x->shape().element_type()),
- op.first->mutable_operand(operand_idx)));
- operands.emplace(operands.begin() + operand_idx, convert);
- }
- x = instr->AddInstruction(op.first->CloneWithNewOperands(
- ShapeUtil::MakeShapeWithDenseLayout(
- x->shape().element_type(), op.first->shape().dimensions(),
- op.first->shape().layout().minor_to_major()),
- operands));
- }
- return;
- };
- shift_ops(a.fp8_input, a.commutative_ops);
- shift_ops(b.fp8_input, b.commutative_ops);
-
- TF_ASSIGN_OR_RETURN(bool a_is_col_major,
- MatrixIsColumnMajor(*instr, gemm_backend_config, "a"));
- TF_ASSIGN_OR_RETURN(bool b_is_col_major,
- MatrixIsColumnMajor(*instr, gemm_backend_config, "b"));
-
- DotDimensionNumbers *dim_nums =
- gemm_backend_config.mutable_dot_dimension_numbers();
- int batch_dim_offset = batch_dims.size();
-
- // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to
- // be row-major. If A is column-major, swap the contracting and
-     // non-contracting dimensions and transpose the matrix to effectively make
-     // it row-major.
- // TODO(philipphack): Remove once cuBLASLt supports A being column-major
- if (a_is_col_major) {
- CHECK(a_contracting_dims[0] == batch_dim_offset ||
- a_contracting_dims[0] == batch_dim_offset + 1);
- if (a_contracting_dims[0] == batch_dim_offset) {
- dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset + 1);
- } else {
- dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset);
- }
- a.fp8_input =
- TransposeMatrix(a.fp8_input, a_contracting_dims[0], batch_dims);
- }
-
- // Similarly, cuBLASLt requires the second operand to be column-major, so
- // make it column-major if it is currently row-major.
- if (!b_is_col_major) {
- CHECK(b_contracting_dims[0] == batch_dim_offset ||
- b_contracting_dims[0] == batch_dim_offset + 1);
- if (b_contracting_dims[0] == batch_dim_offset) {
- dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset + 1);
- } else {
- dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset);
- }
- b.fp8_input =
- TransposeMatrix(b.fp8_input, b_contracting_dims[0], batch_dims);
- }
-
- a.fp8_input = PadOperandToMultipleOf16(batch_dims, a.fp8_input);
- b.fp8_input = PadOperandToMultipleOf16(batch_dims, b.fp8_input);
- Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims);
-
- std::vector<HloInstruction *> operands_list = {
- a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
-
- HloInstruction *new_custom_call =
- instr->AddInstruction(HloInstruction::CreateCustomCall(
- ShapeUtil::MakeShapeWithDenseLayout(
- instr->shape().element_type(), new_output_shape.dimensions(),
- instr->shape().layout().minor_to_major()),
- operands_list, kCublasLtMatmulF8CallTarget));
- TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config));
- TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call));
-
- // Slice the result of the GEMM if the operands were padded.
- HloInstruction *slice = nullptr;
- if (new_output_shape.dimensions() != instr->shape().dimensions()) {
- std::vector<int64_t> start_indices(instr->shape().rank(), 0);
- std::vector<int64_t> strides(instr->shape().rank(), 1);
- slice = instr->AddInstruction(HloInstruction::CreateSlice(
- instr->shape(), new_custom_call, start_indices,
- instr->shape().dimensions(), strides));
- }
-
- TF_RETURN_IF_ERROR(
- ReplaceInstruction(instr, slice ? slice : new_custom_call));
- VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call.";
- return true;
- }
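Judging by the helper names above, both FP8 operands and the output are padded so the relevant dimensions become multiples of 16, and the custom-call result is sliced back to the original shape when padding was added. A minimal sketch of that rounding (the function name is illustrative):

#include <cstdint>

// Round a dimension size up to the next multiple of 16, e.g. 1000 -> 1008;
// the padded GEMM output is later sliced back to the original size.
int64_t RoundUpToMultipleOf16(int64_t dim) { return (dim + 15) / 16 * 16; }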
-
- absl::Status F8ScaleD(HloInstruction *instr, HloInstruction *existing_gemm,
- HloInstruction *d_scale) {
- if (!ShapeUtil::IsScalar(d_scale->shape())) {
- return absl::OkStatus();
- }
-
- // When the output of an FP8 GEMM is scaled but not type converted to FP8,
- // cublasLT requires the scaling factor to be forwarded to the Custom Call
- // as a_scale (chosen here) or b_scale. The scaling factor is fused here
- // when no input scaling factors were fused during the creation of the
- // Custom Call. When the maximum of the absolute value of the output of an
- // FP8 GEMM is calculated and the output is scaled and type converted to
- // FP8, the scaling of the output is fused in F8ConvertD.
- if (!existing_gemm->operand(2)->IsConstant() ||
- existing_gemm->operand(2)->literal().GetAsDouble({}) != 1.) {
- return absl::OkStatus();
- }
-
- // The application of the scaling of the output to the input (see previous
- // comment) is not valid for epilogues other than ReLU or when a matrix bias
- // has been fused.
- TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
- existing_gemm->backend_config<GpuBackendConfig>());
- const GemmBackendConfig &config = gpu_backend_config.gemm_backend_config();
- if ((config.epilogue() != GemmBackendConfig::DEFAULT &&
- config.epilogue() != GemmBackendConfig::RELU) ||
- config.beta() != 0.) {
- return absl::OkStatus();
- }
-
- // If necessary, invert the scaling factor of D and convert to F32.
- TF_ASSIGN_OR_RETURN(
- d_scale,
- InvertAndConvertScalar(d_scale, instr->opcode() == HloOpcode::kDivide));
-
- TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, d_scale));
- TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
-
- VLOG(1) << "Scaling of FP8 GEMM fused into Custom Call.";
- return absl::OkStatus();
- }
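Since cublasLt applies scaling factors multiplicatively, a trailing divide by d_scale is forwarded as its reciprocal, which is what the InvertAndConvertScalar call above arranges. Numerically the forwarded factor is just (illustrative helper):

// Multiplicative scale handed to the custom call for a trailing
// multiply (d_scale itself) or divide (1 / d_scale) of the GEMM output.
float AsMultiplicativeScale(float d_scale, bool is_divide) {
  return is_divide ? 1.0f / d_scale : d_scale;
}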
-
- absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm,
- HloInstruction *d_scale, HloInstruction *clamp_lower,
- HloInstruction *clamp_upper,
- bool mult_scale = false) {
- // Verify the data types and the operands of clamp.
- if (instr->shape().element_type() == F8E4M3FN) {
- if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e4m3fn>::lowest())) ||
- !clamp_upper->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e4m3fn>::max()))) {
- return absl::OkStatus();
- }
- } else if (instr->shape().element_type() == F8E5M2) {
- if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e5m2>::lowest())) ||
- !clamp_upper->literal().IsAllFloat(static_cast<float>(
- std::numeric_limits<tsl::float8_e5m2>::max()))) {
- return absl::OkStatus();
- }
- } else {
- return absl::OkStatus();
- }
-
- if (d_scale && !ShapeUtil::IsScalar(d_scale->shape())) {
- return absl::OkStatus();
- }
-
- // The possible second user of the GEMM must be the calculation of the
- // maximum of the absolute value of the result of the GEMM. Since it is
- // unknown in what form this operation will be used, it is identified in a
- // top-down approach by inspecting the users of the GEMM.
- const std::vector<HloInstruction *> gemm_users = existing_gemm->users();
- HloInstruction *reduce_damax = nullptr;
- if (gemm_users.size() == 2) {
- // In the presence of a ReLU activation, the abs instruction is elided
- // since abs(ReLU(x)) = ReLU(x).
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- existing_gemm->backend_config<GpuBackendConfig>());
- const GemmBackendConfig &config = gpu_config.gemm_backend_config();
- for (int i = 0; i < gemm_users.size(); ++i) {
- HloInstruction *maybe_reduce = nullptr;
- if (gemm_users[i]->opcode() == HloOpcode::kAbs) {
- if (gemm_users[i]->users().size() != 1) continue;
- maybe_reduce = gemm_users[i]->users()[0];
- } else {
- // If there is no Abs instruction, relu is required as epilogue to
- // ensure all values are nonnegative.
- if (config.epilogue() != GemmBackendConfig::BIAS_RELU &&
- config.epilogue() != GemmBackendConfig::RELU)
- continue;
- maybe_reduce = gemm_users[i];
- }
-
- if (maybe_reduce->opcode() == HloOpcode::kReduce &&
- maybe_reduce->operands().size() == 2 &&
- maybe_reduce->operand(1)->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsScalar(maybe_reduce->operand(1)->shape())) {
- HloInstruction *reduce = maybe_reduce;
- HloComputation *reduce_comp = reduce->to_apply();
- HloInstruction *reduce_comp_root = reduce_comp->root_instruction();
- if (reduce->operand(1)->literal().GetAsDouble({}) <= 0. &&
- reduce_comp_root->opcode() == HloOpcode::kMaximum &&
- reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
- reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter) {
- reduce_damax = reduce;
- }
- }
- }
- if (!reduce_damax) {
- return absl::OkStatus();
- }
- } else if (gemm_users.size() > 2) {
- return absl::OkStatus();
- }
-
- TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
- existing_gemm->backend_config<GpuBackendConfig>());
- const GemmBackendConfig &gemm_backend_config =
- gpu_backend_config.gemm_backend_config();
-
- if (gemm_backend_config.beta() != 0.0) {
- if (existing_gemm->operand(2)->shape().element_type() != BF16 &&
- existing_gemm->operand(2)->shape().element_type() != F16) {
- VLOG(1) << "The scaling and conversion of the result of "
- << existing_gemm->ToShortString()
- << " is not fused into the FP8 Custom Call because it "
- "conflicts with the existing fusion of the addition of a "
- "matrix bias with element type other than BF16 or F16.";
- return absl::OkStatus();
- } else {
- // Turn off the output to operand aliasing, since the fp8 output and
- // bf16/fp16 bias have different sizes.
- xla::Cast<HloCustomCallInstruction>(existing_gemm)
- ->set_output_to_operand_aliasing({});
- }
- }
-
- // If necessary, invert the scaling factor of D and convert to F32.
- if (d_scale) {
- TF_ASSIGN_OR_RETURN(d_scale,
- InvertAndConvertScalar(d_scale, !mult_scale));
- TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
- gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
- }
-
- // If present, elide the calculation of the maximum of the absolute values
- // of the result of the GEMM.
- if (reduce_damax) {
- return F8AddDAmax(instr, existing_gemm, reduce_damax);
- }
-
- std::unique_ptr<HloInstruction> new_gemm =
- existing_gemm->CloneWithNewShape(instr->shape());
-
- TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm)));
-
- VLOG(1) << "Conversion" << (reduce_damax ? " and amax calculation" : "")
- << " fused into FP8 GEMM.";
- return absl::OkStatus();
- }
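The clamp bounds checked above are the representable ranges of the FP8 target types, roughly ±448 for float8_e4m3fn and ±57344 for float8_e5m2. A minimal sketch of the saturating conversion being elided, with plain float stand-ins and hard-coded bounds (illustrative only):

#include <algorithm>

// Saturate to the FP8 range before converting; the bounds correspond to
// std::numeric_limits<tsl::float8_e4m3fn> and <tsl::float8_e5m2>.
float SaturateToF8E4M3FN(float x) { return std::clamp(x, -448.0f, 448.0f); }
float SaturateToF8E5M2(float x) { return std::clamp(x, -57344.0f, 57344.0f); }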
-
- // Adds a scalar DAmax return value to an FP8 GEMM.
- absl::Status F8AddDAmax(HloInstruction *instr, HloInstruction *existing_gemm,
- HloInstruction *reduce_damax) {
- // Change the output shape of the Custom Call to tuple(D, DAmax).
- Shape damax_shape = ShapeUtil::MakeScalarShape(F32);
- Shape tuple_shape =
- ShapeUtil::MakeTupleShape({instr->shape(), damax_shape});
- HloInstruction *gemm_and_damax =
- instr->AddInstruction(existing_gemm->CloneWithNewShape(tuple_shape));
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- gemm_and_damax->backend_config<GpuBackendConfig>());
- GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
- config.set_damax_output(true);
- TF_RETURN_IF_ERROR(gemm_and_damax->set_backend_config(gpu_config));
-
- // Obtain D and DAmax separately from the output tuple.
- HloInstruction *d =
- instr->AddInstruction(HloInstruction::CreateGetTupleElement(
- instr->shape(), gemm_and_damax, 0));
- HloInstruction *damax = instr->AddInstruction(
- HloInstruction::CreateGetTupleElement(damax_shape, gemm_and_damax, 1));
-
- // Convert DAmax from FP32 to the requested type and elide reduce.
- HloInstruction *damax_converted = instr->AddInstruction(
- HloInstruction::CreateConvert(reduce_damax->shape(), damax));
- TF_RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted));
- TF_RETURN_IF_ERROR(ReplaceInstruction(instr, d));
-
- return absl::OkStatus();
- }
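DAmax is the maximum absolute value over the GEMM output; with set_damax_output(true) the custom call returns it as the second tuple element, replacing the abs + reduce-max pattern. A host-side sketch of the quantity (illustrative helper):

#include <cmath>
#include <vector>

// DAmax = max_i |d_i| over the GEMM output D.
float ComputeDAmax(const std::vector<float> &d) {
  float amax = 0.0f;
  for (float v : d) amax = std::fmax(amax, std::fabs(v));
  return amax;
}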
-
- // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add
- // instruction in the following form:
- // Add(OptionalBitcast(OptionalSlice(gemm)), bias)
- // where 'gemm' is expected to be a cuBLAS custom_call. Slice is introduced
- // when the inputs of the gemm are possibly padded. Bitcast is introduced to
- // handle high rank input.
- absl::Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias,
- const HloInstruction *gemm,
- HloInstruction *bitcast = nullptr,
- HloInstruction *slice = nullptr) {
- TF_RET_CHECK(Shape::Equal().IgnoreElementType()(bias->shape(),
- bitcast ? bitcast->shape()
- : slice ? slice->shape()
- : gemm->shape()));
-
- // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
- // supports fixed values for alpha/beta.
- if (gemm->shape().element_type() == S32) {
- return absl::OkStatus();
- }
-
- // To ensure correctness, only slices that chop off the ends of dimensions
- // are supported.
- if (slice) {
- int slice_op_dim = slice->operand(0)->shape().rank();
- if (slice->slice_starts() != std::vector<int64_t>(slice_op_dim, 0) ||
- slice->slice_strides() != std::vector<int64_t>(slice_op_dim, 1)) {
- return absl::OkStatus();
- }
- }
- // Cublas gemm overwrites the bias matrix, so fusion is only possible if the
- // gemm is the only user. CublasLt gemm can operate out-of-place.
- bool can_overwrite_bias = [bias]() {
- if (bias->user_count() > 1) {
- // There is another user of the data, do not overwrite it.
- return false;
- }
-
- if (bias->opcode() != HloOpcode::kParameter) {
- // Not a parameter; can overwrite.
- return true;
- }
-
- // The bias is a parameter of the computation; check if it is aliased.
- if (!bias->parent()->IsEntryComputation()) {
- // Only the HloModule has input/output aliasing, since this is not the
- // entry computation, there are no guarantees about aliasing; do not
- // overwrite.
- return false;
- }
- const auto &in_out_alias_config =
- bias->GetModule()->input_output_alias_config();
- // If the parameter is aliased, we can overwrite it.
- // TODO(victorstone): The assumption when calling ParameterHasAlias is
- // that bias is not a tuple. This is why we pass {} as the argument for
- // param_index.
- return in_out_alias_config.ParameterHasAlias(bias->parameter_number(),
- /*param_index=*/{});
- }();
- bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) ||
- IsCublasLtMatmul(*gemm) || can_overwrite_bias;
-
- auto gpu_config = gemm->backend_config<GpuBackendConfig>().value();
- GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
- // It is possible to fuse into a cublasLt matmul that already has a vector
- // bias, but no other epilogue will commute with the matrix bias add.
- bool supported_epilogue =
- ((config.epilogue() == GemmBackendConfig::DEFAULT) ||
- (config.epilogue() == GemmBackendConfig::BIAS));
-
- if ((config.beta() != 0) || !want_to_fuse_bias ||
- (gemm->user_count() != 1) || !supported_epilogue) {
- return absl::OkStatus();
- }
-
- config.set_beta(1.0);
-
- std::vector<HloInstruction *> operands(gemm->operands().begin(),
- gemm->operands().end());
- HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias);
- if (bitcast) {
- maybe_constant_folded_bias =
- instr->AddInstruction(HloInstruction::CreateBitcast(
- slice->shape(), maybe_constant_folded_bias));
- }
-
- maybe_constant_folded_bias =
- PadOperandToTargetShape(gemm->shape(), maybe_constant_folded_bias);
-
- operands.insert(operands.begin() + 2, maybe_constant_folded_bias);
-
- std::unique_ptr<HloInstruction> fused_op =
- gemm->CloneWithNewOperands(gemm->shape(), operands);
-     // For mixed-type GEMMs, set the output element type to the bias type.
- fused_op->mutable_shape()->set_element_type(bias->shape().element_type());
- TF_RETURN_IF_ERROR(fused_op->set_backend_config(gpu_config));
-
- // Choose whether the bias must alias the output. Legacy cublas GEMMs must
- // operate in place and alias the bias with the output, whereas with
- // cublasLt we can choose.
- //
- // Operating in place is always safe; copy-insertion will insert copies if
- // necessary. But (we assume) copying is slower than operating
- // out-of-place, so for cublasLt (where we have the choice), we try to
-     // operate in place if we think a copy won't be necessary.
- //
- // We assume that parameters are always read-only and therefore we'd need to
- // copy if we were going to operate in place. (This is not quite true; the
- // param could have input/output aliasing.) We also assume that if there
- // are other uses of the bias, we might need to copy. (Again, not quite
- // true if those uses all come before this operation. But copy-insertion
- // runs before scheduling, so it can't know and has to conservatively insert
- // copies.)
- if (IsLegacyCublasMatmul(*fused_op) || can_overwrite_bias) {
- xla::Cast<HloCustomCallInstruction>(fused_op.get())
- ->set_output_to_operand_aliasing({{{}, {2, {}}}});
- }
- TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
- if (slice) {
- fused_op = slice->CloneWithNewOperands(
- slice->shape(),
- {slice->parent()->AddInstruction(std::move(fused_op))});
- }
-
- if (bitcast) {
- fused_op = bitcast->CloneWithNewOperands(
- bitcast->shape(),
- {bitcast->parent()->AddInstruction(std::move(fused_op))});
- }
-
- return ReplaceWithNewInstruction(instr, std::move(fused_op));
- }
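With beta set to 1.0 above, the custom call computes D = alpha * A * B + beta * C, so the separate add is folded into the GEMM epilogue. A minimal dense sketch of that math for row-major 2D buffers (illustrative, no blocking or vectorization):

#include <vector>

// D = alpha * A * B + beta * C for row-major MxK, KxN and MxN buffers.
void GemmWithMatrixBias(int m, int n, int k, float alpha, float beta,
                        const std::vector<float> &a,
                        const std::vector<float> &b,
                        const std::vector<float> &c, std::vector<float> &d) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
      d[i * n + j] = alpha * acc + beta * c[i * n + j];
    }
  }
}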
-
- // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add
- // instruction in the following form:
- // Add(OptionalBitcast(OptionalSlice(gemm)), Broadcast(OptionalConvert()))
- // where 'gemm' is expected to be a cuBLAS custom_call. The optional
- // convert is only used for F8 matmuls as cublasLt has specific constraints
- // on the vector bias type for such matmuls. The optional bitcast is
- // necessary to handle high rank input cases.
- absl::StatusOr<bool> FuseVectorBiasAdd(HloInstruction *instr,
- HloInstruction *broadcast,
- HloInstruction *gemm,
- HloInstruction *slice = nullptr,
- HloInstruction *convert = nullptr,
- HloInstruction *bitcast = nullptr) {
- if (!bitcast) {
- TF_RET_CHECK(ShapeUtil::Compatible(
- broadcast->shape(), (slice ? slice->shape() : gemm->shape())));
- }
- // Verify that the data type is supported by Epilogue Fusion.
- if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
- return false;
- }
-
- HloInstruction *bias = broadcast->mutable_operand(0);
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- gemm->backend_config<GpuBackendConfig>());
- GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
- // # output column dims == # non-contracting rhs operand dims.
- const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
- size_t num_col_dims = gemm->operand(1)->shape().rank() -
- dot_dims.rhs_batch_dimensions_size() -
- dot_dims.rhs_contracting_dimensions_size();
-
- if ((gemm->user_count() != 1) ||
- (config.epilogue() != GemmBackendConfig::DEFAULT) ||
- (bias->shape().rank() != num_col_dims)) {
- return false;
- }
-     // We require the bias vector to have been broadcast in the most major
-     // dimensions; i.e. its most minor physical dimensions must align with the
-     // most minor physical dimensions of the gemm output.
- absl::Span<const int64_t> broadcast_dims = broadcast->dimensions();
- for (size_t i = 0; i < num_col_dims; ++i) {
- int64_t dim =
- (bitcast ? bitcast : gemm)->shape().layout().minor_to_major(i);
-
- // Find the corresponding dimension from the bias vector.
- auto it = absl::c_find(broadcast_dims, dim);
-
- if (it == broadcast_dims.end()) {
- return false;
- }
-
- int64_t vector_dim = it - broadcast_dims.begin();
- if (bias->shape().layout().minor_to_major(i) != vector_dim) {
- return false;
- }
- }
-
- std::vector<HloInstruction *> operands(gemm->operands().begin(),
- gemm->operands().end());
-     // When a (non-trivial) matrix bias and a vector bias coexist for an FP8
-     // matmul, only fuse the matrix bias.
- if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
- config.beta() != 0.0) {
- return true;
- }
-
- if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
- bias->shape().element_type() == F32) {
- if (convert == nullptr) {
- return false;
- }
-
- HloInstruction *bias_f16_or_bf16 = convert->mutable_operand(0);
- auto compatible_bias_type = [](const PrimitiveType bias_type,
- const PrimitiveType output_type) {
- if (bias_type == BF16) {
- return output_type == F8E4M3FN || output_type == F8E5M2 ||
- output_type == F32 || output_type == BF16;
- } else if (bias_type == F16) {
- return output_type == F16 || output_type == F8E4M3FN ||
- output_type == F8E5M2;
- }
- return false;
- };
-
- // cuBLAS LT does not support FP32 biases on matmuls with FP8 inputs,
- // even if the matmul output is FP32. We do not unconditionally convert
- // the bias to a supported precision (F16 or BF16) because this lowers
- // precision. Instead, we only fuse the bias if the bias itself is a
- // convert from F16 or BF16, fusing the input of the convert instruction
- // to the matmul.
- if (compatible_bias_type(bias_f16_or_bf16->shape().element_type(),
- gemm->shape().element_type())) {
- bias = bias_f16_or_bf16;
- } else {
- VLOG(1) << "Epilogue fusion of FP32 vector bias into FP8 GEMM is "
- "currently not supported. See the cublasLT support matrix.";
- return false;
- }
- }
-
- // In the case of high rank input for FP8, it is necessary to consider
- // potential padding for the bias.
- if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && bitcast) {
- bias = PadOperandToMultipleOf16(
- config.dot_dimension_numbers().rhs_batch_dimensions(), bias);
- }
- // Replace add(gemm, broadcast) with fused new_gemm.
- operands.push_back(bias);
- config.set_epilogue(GemmBackendConfig::BIAS);
- std::unique_ptr<HloInstruction> result =
- gemm->CloneWithNewOperands(gemm->shape(), operands);
- TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
- if (slice) {
- result = slice->CloneWithNewOperands(
- slice->shape(), {slice->parent()->AddInstruction(std::move(result))});
- }
-
- if (bitcast) {
- result = bitcast->CloneWithNewOperands(
- bitcast->shape(),
- {bitcast->parent()->AddInstruction(std::move(result))});
- }
- TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(result)));
- return true;
- }
-
- absl::Status FuseReluActivation(HloInstruction *instr,
- HloInstruction *broadcast,
- HloInstruction *gemm,
- HloInstruction *slice_or_bitcast = nullptr) {
- TF_RET_CHECK(ShapeUtil::Compatible(
- broadcast->shape(),
- (slice_or_bitcast ? slice_or_bitcast->shape() : gemm->shape())));
-
- if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
- return absl::OkStatus();
- }
-
- if (gemm->user_count() != 1) {
- return absl::OkStatus();
- }
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- gemm->backend_config<GpuBackendConfig>());
- GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
- if (config.epilogue() == GemmBackendConfig::DEFAULT) {
- config.set_epilogue(GemmBackendConfig::RELU);
- } else if (config.epilogue() == GemmBackendConfig::BIAS) {
- config.set_epilogue(GemmBackendConfig::BIAS_RELU);
- } else {
- return absl::OkStatus();
- }
-
- std::unique_ptr<HloInstruction> result = gemm->Clone();
- TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
-
- if (slice_or_bitcast) {
- result = slice_or_bitcast->CloneWithNewOperands(
- slice_or_bitcast->shape(),
- {slice_or_bitcast->parent()->AddInstruction(std::move(result))});
- }
-
- return ReplaceWithNewInstruction(instr, std::move(result));
- }
-
- absl::Status FuseGeluActivation(HloInstruction *multiply,
- HloInstruction *gemm,
- HloInstruction *slice_or_bitcast = nullptr) {
- if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
- return absl::OkStatus();
- }
- // For CUDA versions less than 12.3.2, cuBLAS LT returns
- // CUBLAS_STATUS_NOT_SUPPORTED in some cases when fusing gelu into an FP8
- // matmul. We cannot check the patch version, so disable this fusion with
- // CUDA versions less than 12.4.
- if (IsCuda(gpu_version_) && toolkit_version_ < 12040 &&
- IsCublasLtMatmulF8(*gemm)) {
- return absl::OkStatus();
- }
-
- // There are four users of the gemm output within the GELU calculation.
- bool has_aux = gemm->user_count() > 4;
-
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- gemm->backend_config<GpuBackendConfig>());
- GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-
- if (config.epilogue() == GemmBackendConfig::DEFAULT) {
- config.set_epilogue(has_aux ? GemmBackendConfig::GELU_AUX
- : GemmBackendConfig::GELU);
- } else if (config.epilogue() == GemmBackendConfig::BIAS) {
- config.set_epilogue(has_aux ? GemmBackendConfig::BIAS_GELU_AUX
- : GemmBackendConfig::BIAS_GELU);
- } else {
- return absl::OkStatus();
- }
-
- std::unique_ptr<HloInstruction> output = gemm->CloneWithNewShape(
- has_aux ? ShapeUtil::MakeTupleShape({gemm->shape(), gemm->shape()})
- : gemm->shape());
- TF_RETURN_IF_ERROR(output->set_backend_config(gpu_config));
- TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get()));
-
- if (slice_or_bitcast) {
- output = slice_or_bitcast->CloneWithNewOperands(
- slice_or_bitcast->shape(),
- {gemm->parent()->AddInstruction(std::move(output))});
- }
-
- if (has_aux) {
- HloInstruction *tuple_output =
- gemm->parent()->AddInstruction(std::move(output));
- TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
- gemm, HloInstruction::CreateGetTupleElement(tuple_output, 1)));
- output = HloInstruction::CreateGetTupleElement(tuple_output, 0);
- }
-
- return ReplaceWithNewInstruction(multiply, std::move(output));
- }
-
- private:
- se::GpuComputeCapability gpu_version_;
- int32_t toolkit_version_;
- bool f8_rewrite_;
-
- // Choose cublas or cublasLt for the target of the custom call that instr will
- // be rewritten into.
- absl::StatusOr<absl::string_view> GetNonFp8GemmCustomCallTarget(
- const HloInstruction &instr,
- const GemmBackendConfig &gemm_backend_config) const {
- if (!instr.GetModule()
- ->config()
- .debug_options()
- .xla_gpu_enable_cublaslt()) {
- // cublasLt is not enabled.
- return absl::string_view(kGemmCallTarget);
- }
-
- // cublasLt is enabled, check if other internal conditions are met.
- const HloInstruction *lhs = instr.operand(0);
- const HloInstruction *rhs = instr.operand(1);
- if (lhs->shape().element_type() == S8 ||
- rhs->shape().element_type() == S8) {
- // TODO(b/241446501) The XLA usage of cublasLt does not yet handle
- // int8 matmuls. Fallback to legacy cublas.
- return absl::string_view(kGemmCallTarget);
- }
-
- // All internal conditions are met, check if we meet the requirements of
- // cublasLt.
- TF_ASSIGN_OR_RETURN(bool gemm_is_supported_by_cublas_lt,
- GemmIsSupportedByCublasLt(instr, gemm_backend_config));
- if (gemm_is_supported_by_cublas_lt) {
- return absl::string_view(kCublasLtMatmulCallTarget);
- }
-
- // This case is not supported by cublasLt, fallback to legacy cublas.
- return absl::string_view(kGemmCallTarget);
- }
-
- absl::StatusOr<bool> TypesAreSupportedByLegacyCublas(
- const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config,
- const HloInstruction *bias = nullptr) const {
- // Figure out the Atype/Btype.
- const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
- const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
- const PrimitiveType output_type =
- bias ? bias->shape().element_type() : instr.shape().element_type();
-     const std::array<PrimitiveType, 8> supported_type = {
- PrimitiveType::S8, PrimitiveType::F16, PrimitiveType::BF16,
- PrimitiveType::F32, PrimitiveType::S32, PrimitiveType::F64,
- PrimitiveType::C64, PrimitiveType::C128};
- // legacy cublas has a defined set of combinations of types that it
- // supports. Figure out the computeType and scaleType.
- if (!absl::c_linear_search(supported_type, output_type)) return false;
- TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
- se::gpu::AsBlasDataType(output_type));
-     // TODO(tdanyluk): Investigate why we don't use the actual precision (and
-     // algorithm) here. Why do we use the default?
- TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
- se::gpu::GetBlasComputationType(
- PrecisionConfig::ALG_UNSET, a_dtype, output_type,
- stream_executor::blas::kDefaultComputePrecision));
- se::blas::DataType scale_type =
- se::gpu::GetScaleType(output_dtype, compute_type);
-
- using se::blas::ComputationType;
- using se::blas::DataType;
- // This matrix of supported types is taken directly from cublas
- // documentation.
- // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
- const std::array<
- std::tuple<ComputationType, DataType /*scale_type*/,
- PrimitiveType /*a_dtype*/, PrimitiveType /*b_dtype*/,
- DataType /*output_dtype*/>,
- 32>
- supported_type_combinations = {{
- {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
- PrimitiveType::F16, DataType::kHalf},
-
- {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
- PrimitiveType::S8, DataType::kInt32},
-
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
- PrimitiveType::BF16, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
- PrimitiveType::F16, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
- PrimitiveType::S8, DataType::kFloat},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
- PrimitiveType::BF16, DataType::kFloat},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
- PrimitiveType::F16, DataType::kFloat},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
-
- // There would be an entry here for A/BType complex int8, but we do
- // not support that type.
- {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
- PrimitiveType::C64, DataType::kComplexFloat},
-
- {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
- {ComputationType::kF16AsF32, DataType::kComplexFloat,
- PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
- {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
- {ComputationType::kBF16AsF32, DataType::kComplexFloat,
- PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
- {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
- {ComputationType::kTF32AsF32, DataType::kComplexFloat,
- PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
- {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
- PrimitiveType::F64, DataType::kDouble},
- {ComputationType::kF64, DataType::kComplexDouble,
- PrimitiveType::C128, PrimitiveType::C128,
- DataType::kComplexDouble},
- }};
-
- return absl::c_linear_search(
- supported_type_combinations,
- std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
- output_dtype));
- }
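The check above is a linear search for an exact (compute type, scale type, A type, B type, output type) tuple in a fixed table. The same technique in plain standard C++, with placeholder enums and an abbreviated table (all names illustrative):

#include <algorithm>
#include <iterator>
#include <tuple>

enum class Compute { kF16, kF32 };
enum class Dtype { kF16, kF32 };

// Linear search over a fixed table of supported type combinations.
bool IsSupportedCombination(Compute compute, Dtype scale, Dtype a, Dtype b,
                            Dtype out) {
  static const std::tuple<Compute, Dtype, Dtype, Dtype, Dtype> kSupported[] = {
      {Compute::kF16, Dtype::kF16, Dtype::kF16, Dtype::kF16, Dtype::kF16},
      {Compute::kF32, Dtype::kF32, Dtype::kF16, Dtype::kF16, Dtype::kF16},
      {Compute::kF32, Dtype::kF32, Dtype::kF32, Dtype::kF32, Dtype::kF32},
  };
  return std::find(std::begin(kSupported), std::end(kSupported),
                   std::make_tuple(compute, scale, a, b, out)) !=
         std::end(kSupported);
}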
-
- absl::StatusOr<bool> TypesAreSupportedByCublasLt(
- const HloInstruction &instr, const GemmBackendConfig &backend_config,
- const HloInstruction *bias = nullptr) const {
- // Figure out the Atype/Btype.
- const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
- const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
- const PrimitiveType output_type =
- bias ? bias->shape().element_type() : instr.shape().element_type();
- const std::array<PrimitiveType, 12> supported_type = {
- PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN,
- PrimitiveType::S8, PrimitiveType::F16,
- PrimitiveType::BF16, PrimitiveType::F32,
- PrimitiveType::S32, PrimitiveType::F64,
- PrimitiveType::C64, PrimitiveType::C128};
- if (!absl::c_linear_search(supported_type, output_type)) return false;
- // cublasLt has a defined set of combinations of types that it supports.
- // Figure out the computeType and scaleType.
- TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
- se::gpu::AsBlasDataType(output_type));
- const int max_precision = *absl::c_max_element(
- backend_config.precision_config().operand_precision());
- const PrecisionConfig::Algorithm algorithm =
- backend_config.precision_config().algorithm();
- if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm)) return false;
-
- TF_ASSIGN_OR_RETURN(
- const se::blas::ComputationType compute_type,
- se::gpu::GetBlasComputationType(
- algorithm, a_dtype, instr.shape().element_type(), max_precision));
- se::blas::DataType scale_type =
- se::gpu::GetScaleType(output_dtype, compute_type);
-
- using se::blas::ComputationType;
- using se::blas::DataType;
- using TypeCombinations = std::initializer_list<std::tuple<
- ComputationType, DataType /*scale_type*/, PrimitiveType /*a_dtype*/,
- PrimitiveType /*b_dtype*/, DataType /*output_dtype*/>>;
- // This matrix of supported types is taken directly from cublasLt
- // documentation.
- // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
- const TypeCombinations supported_cublas_type_combinations = {
- // FP8 types:
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E4M3FN, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E4M3FN, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E4M3FN, DataType::kFloat},
-
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E5M2, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E5M2, DataType::kF8E4M3FN},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E5M2, DataType::kF8E5M2},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E5M2, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
- PrimitiveType::F8E5M2, DataType::kFloat},
-
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
- PrimitiveType::F8E4M3FN, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
- PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
- PrimitiveType::F8E4M3FN, DataType::kF8E5M2},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
- PrimitiveType::F8E4M3FN, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
- PrimitiveType::F8E4M3FN, DataType::kFloat},
- // There would be an entry here for A/BType complex int8, but we do
- // not support that type.
- {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
- PrimitiveType::C64, DataType::kComplexFloat},
-
- {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
- {ComputationType::kF16AsF32, DataType::kComplexFloat,
- PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
- // The next 4 may be supported by hipblaslt, but they are not
- // covered by any unit tests
- {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
- {ComputationType::kBF16AsF32, DataType::kComplexFloat,
- PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
- {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
- {ComputationType::kTF32AsF32, DataType::kComplexFloat,
- PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
- {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
- PrimitiveType::F64, DataType::kDouble},
- {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128,
- PrimitiveType::C128, DataType::kComplexDouble},
- };
- if (IsCuda(gpu_version_) &&
- absl::c_linear_search(supported_cublas_type_combinations,
- std::tuple{compute_type, scale_type, a_dtype,
- b_dtype, output_dtype})) {
- return true;
- }
- const TypeCombinations supported_hipblas_type_combinations = {
- // FP8 types:
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
-
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
- PrimitiveType::F8E5M2FNUZ, DataType::kFloat},
-
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
- PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
- };
- if (IsRocm(gpu_version_) &&
- absl::c_linear_search(supported_hipblas_type_combinations,
- std::tuple{compute_type, scale_type, a_dtype,
- b_dtype, output_dtype})) {
- return true;
- }
- const TypeCombinations supported_type_combinations = {
- // Other data types:
- {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
- PrimitiveType::F16, DataType::kHalf},
-
- {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
- PrimitiveType::S8, DataType::kInt32},
- {ComputationType::kI32, DataType::kFloat, PrimitiveType::S8,
- PrimitiveType::S8, DataType::kInt8},
-
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
- PrimitiveType::BF16, DataType::kBF16},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
- PrimitiveType::F16, DataType::kHalf},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
- PrimitiveType::S8, DataType::kFloat},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
- PrimitiveType::BF16, DataType::kFloat},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
- PrimitiveType::F16, DataType::kFloat},
- {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
- PrimitiveType::F32, DataType::kFloat},
- };
-
- return absl::c_linear_search(
- supported_type_combinations,
- std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
- output_dtype));
- }
-
- absl::StatusOr<bool> MatrixIsColumnMajor(
- const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config,
- const std::string matrix_name = "output") const {
- const HloInstruction *lhs = instr.operand(0);
- const HloInstruction *rhs = instr.operand(1);
-
- const DotDimensionNumbers &dot_dims =
- gemm_backend_config.dot_dimension_numbers();
- // We use ALG_UNSET and kDefaultComputePrecision because we don't care about
- // the precision, just the layout, since we're just checking if the matrix
- // is column-major.
- TF_ASSIGN_OR_RETURN(
- GemmConfig gemm_config,
- GemmConfig::For(
- lhs->shape(), dot_dims.lhs_batch_dimensions(),
- dot_dims.lhs_contracting_dimensions(), rhs->shape(),
- dot_dims.rhs_batch_dimensions(),
- dot_dims.rhs_contracting_dimensions(),
- /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(),
- gemm_backend_config.alpha_imag(), gemm_backend_config.beta(),
- /*precision_algorithm=*/PrecisionConfig::ALG_UNSET,
- /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision,
- gemm_backend_config.grad_x(), gemm_backend_config.grad_y()));
-
- if (matrix_name == "lhs" || matrix_name == "a") {
- return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor;
- } else if (matrix_name == "rhs" || matrix_name == "b") {
- return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor;
- } else if (matrix_name == "output" || matrix_name == "d") {
- return gemm_config.output_layout.order ==
- MatrixLayout::Order::kColumnMajor;
- } else {
- return Internal("Invalid matrix name.");
- }
- }
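The pass answers this through GemmConfig/MatrixLayout, but for a plain 2D operand the property reduces to reading the layout's minor_to_major order: if dimension 0 (rows) is most minor, columns are contiguous and the matrix is column-major. A simplified sketch under that assumption (illustrative helper):

#include <cstdint>
#include <vector>

// For a 2D shape, minor_to_major == {0, 1} means column-major; XLA's default
// layout {1, 0} is row-major.
bool IsColumnMajor2D(const std::vector<int64_t> &minor_to_major) {
  return minor_to_major.size() == 2 && minor_to_major.front() == 0;
}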
-
- absl::StatusOr<bool> GemmIsSupportedByCublasLt(
- const HloInstruction &instr,
- const GemmBackendConfig &gemm_backend_config) const {
- const HloInstruction *lhs = instr.operand(0);
- const HloInstruction *rhs = instr.operand(1);
- const Shape &output_shape = instr.shape();
-
- TF_ASSIGN_OR_RETURN(
- bool types_are_supported_by_cublas_lt,
- TypesAreSupportedByCublasLt(instr, gemm_backend_config));
- if (!types_are_supported_by_cublas_lt) {
- return false;
- }
-
- // The cublasLt API has two currently known limitations:
- // 1. Batch count must be <2^16.
- constexpr int64_t kMaxBatchCount = 65535;
-    // We get the batch dimension size from lhs here, but we could just as
-    // well use rhs; they are guaranteed to be the same (TODO: verify this).
- const auto &batch_dimensions =
- gemm_backend_config.dot_dimension_numbers().lhs_batch_dimensions();
- int batch_count = (batch_dimensions.empty() ? 0 : 1);
- // All batch dimensions get flattened into a single batch dimension.
- for (auto batch_dimension : batch_dimensions) {
- batch_count *= lhs->shape().dimensions(batch_dimension);
- }
- if (batch_count > kMaxBatchCount) {
- // This is not supported by cublasLt.
- return false;
- }
-
- TF_ASSIGN_OR_RETURN(bool output_is_column_major,
- MatrixIsColumnMajor(instr, gemm_backend_config));
-
- if (auto isrocm = std::get_if<se::RocmComputeCapability>(&gpu_version_);
- isrocm) {
- if (!isrocm->has_hipblaslt()) {
- return false;
- }
- }
-
- // 2. cublasLt does not support rhs col dimension size > 4194240 for
- // C64.
- constexpr int kMaxDimensionSize{4194240};
- if (output_shape.element_type() != C64) {
-      // The output type is not C64, so the size restriction does not apply.
- return true;
- }
-
- if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
- auto cuda_compute_capability_ =
- std::get<se::CudaComputeCapability>(gpu_version_);
- if (cuda_compute_capability_.IsAtLeast(
- se::CudaComputeCapability::AMPERE)) {
- // cuBlasLt has an implementation for complex data with compute type
- // 32F_FAST_32TF that uses tensor cores and that is free from the
- // restriction. This implementation only works on Ampere
- // architecture though (where TF32 was introduced).
- return true;
- }
- }
- // Get the rhs non-contracting dimensions as they will eventually be at the
- // cublasLt level.
- std::vector<int64_t> rhs_non_contracting_dims;
- const DotDimensionNumbers &dot_dims =
- gemm_backend_config.dot_dimension_numbers();
-
- if (!output_is_column_major) {
- // cublasLt's matmul output is column major by default. This gemm requires
- // the output to be in row major. Later we will swap lhs & rhs (and
- // transpose each operand) of this gemm. Since we care about the rhs at
- // the cublasLt level, this swap means that we care about the lhs right
- // here.
- TF_ASSIGN_OR_RETURN(
- rhs_non_contracting_dims,
- GetNonContractingDims(lhs->shape(), dot_dims.lhs_batch_dimensions(),
- dot_dims.lhs_contracting_dimensions()));
- } else {
- TF_ASSIGN_OR_RETURN(
- rhs_non_contracting_dims,
- GetNonContractingDims(rhs->shape(), dot_dims.rhs_batch_dimensions(),
- dot_dims.rhs_contracting_dimensions()));
- }
-
- const auto lhs_non_contracting_dimension_size = absl::c_accumulate(
- rhs_non_contracting_dims, 1, [&](int64_t size, int64_t dim) {
- return size * lhs->shape().dimensions(dim);
- });
-
- // Check that the size of the non-contracting dimension is not too large.
- return lhs_non_contracting_dimension_size <= kMaxDimensionSize;
- }
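A minimal standalone sketch (not part of the original file) of the batch-count check described in the comments above; the helper name and use of std::vector are assumptions made only for illustration.

// Sketch: all batch dimensions are flattened into a single count, which must
// stay within cublasLt's documented 2^16 - 1 batch limit.
#include <cstdint>
#include <vector>

bool BatchCountFitsCublasLt(const std::vector<int64_t>& batch_dim_sizes) {
  constexpr int64_t kMaxBatchCount = 65535;
  int64_t batch_count = batch_dim_sizes.empty() ? 0 : 1;
  for (int64_t size : batch_dim_sizes) batch_count *= size;
  return batch_count <= kMaxBatchCount;
}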
-
-  // Turns an F8 dot with an unsupported output type into an F8 dot with F32
-  // output, then converts the F32 output back to the requested output type.
- absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
- HloInstruction *instr) {
- Shape output_f32_shape = instr->shape();
- output_f32_shape.set_element_type(F32);
- HloInstruction *f32_dot =
- instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape));
- HloInstruction *convert = instr->AddInstruction(
- HloInstruction::CreateConvert(instr->shape(), f32_dot));
- TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert));
- return f32_dot;
- }
-
- // Turns an F8 dot into an F16 dot, converting operands to F16 and
- // converting the output back to F8.
- absl::StatusOr<HloInstruction *> TurnF8DotIntoF16Dot(HloInstruction *instr) {
- DCHECK(IsF8Type(instr->operand(0)));
- DCHECK(IsF8Type(instr->operand(1)));
-
- // Convert operands to F16
- for (int i = 0; i < 2; ++i) {
- Shape operand_f16_shape = instr->operand(i)->shape();
- operand_f16_shape.set_element_type(F16);
- HloInstruction *convert =
- instr->AddInstruction(HloInstruction::CreateConvert(
- operand_f16_shape, instr->mutable_operand(i)));
- TF_RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert));
- }
-
- // If output is F8, change output to F16 and then convert it back to F8
- if (IsF8Type(instr)) {
- Shape output_f16_shape = instr->shape();
- output_f16_shape.set_element_type(F16);
- HloInstruction *f16_dot =
- instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape));
- HloInstruction *convert_to_f8 = instr->AddInstruction(
- HloInstruction::CreateConvert(instr->shape(), f16_dot));
- TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8));
- return f16_dot;
- } else {
- return instr;
- }
- }
-};
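To make the two F8 fallbacks above concrete, here is a hedged HLO-style illustration held in a C++ constant; the shapes and element types are hypothetical and only meant to show the shape of the rewrite.

// Illustration only (hypothetical shapes): how the F8 fallbacks reshape a dot.
constexpr const char* kF8FallbackExample = R"(
  // Unsupported F8 output type -> compute in F32, convert back:
  //   dot32 = f32[4,4] dot(a_f8, b_f8)
  //   out   = s8[4,4] convert(dot32)
  // No F8 GEMM support at all -> promote operands to F16 first:
  //   a16   = f16[4,8] convert(a_f8)
  //   b16   = f16[8,4] convert(b_f8)
  //   out   = f16[4,4] dot(a16, b16)
)";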
-
-// Rewriter that adds a workspace to legacy cuBLAS custom calls. We run it
-// separately after gemm rewriter, so that we can do pattern matching without
-// having to match output tuples.
-class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor {
- public:
- explicit GemmWorkspaceRewriteVisitor(
- const se::GpuComputeCapability &gpu_version)
- : gpu_version_(gpu_version) {}
-
- absl::Status HandleCustomCall(HloInstruction *instr) override {
- bool has_aux_output = false;
- if (instr->custom_call_target() == kCublasLtMatmulCallTarget ||
- instr->custom_call_target() == kCublasLtMatmulF8CallTarget) {
- TF_ASSIGN_OR_RETURN(const auto gpu_config,
- instr->backend_config<xla::gpu::GpuBackendConfig>());
- const xla::gpu::GemmBackendConfig &config =
- gpu_config.gemm_backend_config();
- xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
- TF_ASSIGN_OR_RETURN(
- has_aux_output,
- xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(epilogue));
-
- if (!((instr->shape().IsTuple() &&
- instr->shape().tuple_shapes_size() ==
- has_aux_output + config.damax_output() + 1) ||
- instr->shape().IsArray())) {
- return absl::OkStatus();
- }
- } else if (instr->custom_call_target() != kGemmCallTarget ||
- !instr->shape().IsArray()) {
- return absl::OkStatus();
- }
-
- auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version_);
-
-    // Pass a user-managed workspace to legacy cuBLAS operations; otherwise
-    // cuBLAS uses its own internal pool, which would compete with the XLA
-    // allocator for device memory.
- int64_t workspace = cuda_cc == nullptr ? GemmConfig::kDefaultWorkspace
- : cuda_cc->IsAtLeastHopper()
- ? GemmConfig::kHopperWorkspace
- : GemmConfig::kDefaultWorkspace;
-
-    // We do not know the workspace size required by cuBLAS, but we can guess
-    // that in the worst case cuBLAS will transpose all operands into a tiled
-    // layout optimal for the tensor cores, so it does not make sense to
-    // allocate a larger workspace than that.
-    //
-    // TODO(ezhulenev): This is not based on any measurement, just common
-    // sense; we should tweak it to find the minimal workspace size.
- if (instr->custom_call_target() == kGemmCallTarget) {
- int64_t operands_byte_size = 0;
- for (auto &operand : instr->operands()) {
- operands_byte_size += ShapeUtil::ByteSizeOf(operand->shape());
- }
- workspace = std::min(workspace, operands_byte_size);
- }
-
- // Append workspace buffer to instruction outputs.
- std::vector<Shape> output_shapes = instr->shape().IsArray()
- ? std::vector<Shape>{instr->shape()}
- : instr->shape().tuple_shapes();
- output_shapes.emplace_back(ShapeUtil::MakeShape(S8, {workspace}));
- Shape output_shape = ShapeUtil::MakeTupleShape(output_shapes);
-
- // Clone custom call with a new shape.
- HloInstruction *new_call = instr->AddInstruction(
- instr->CloneWithNewOperands(output_shape, instr->operands()));
-
- // Update operand aliasing if it was a fused gemm with aliased output.
- auto *custom_call = xla::Cast<HloCustomCallInstruction>(new_call);
- if (!custom_call->output_to_operand_aliasing().empty()) {
- custom_call->set_output_to_operand_aliasing({{{0}, {2, {}}}});
- }
-
- if (instr->shape().IsTuple()) {
- for (auto user : instr->users()) {
- auto user_get_tuple =
- dynamic_cast<HloGetTupleElementInstruction *>(user);
- TF_RET_CHECK(user_get_tuple);
- HloInstruction *get_output =
- instr->AddInstruction(HloInstruction::CreateGetTupleElement(
- new_call, user_get_tuple->tuple_index()));
- TF_RETURN_IF_ERROR(ReplaceInstruction(user_get_tuple, get_output));
- }
- return absl::OkStatus();
- } else {
- HloInstruction *get_output = instr->AddInstruction(
- HloInstruction::CreateGetTupleElement(new_call, 0));
- return ReplaceInstruction(instr, get_output);
- }
- }
-
- private:
- se::GpuComputeCapability gpu_version_;
-};
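A small sketch of the sizing policy the visitor above applies; the byte values are assumptions standing in for GemmConfig::kDefaultWorkspace / kHopperWorkspace, not verified constants.

#include <algorithm>
#include <cstdint>

// Sketch: pick a per-architecture cap, then never hand legacy cuBLAS more
// workspace than the combined byte size of its operands.
int64_t ChooseWorkspaceBytes(bool is_hopper_or_newer, bool is_legacy_gemm,
                             int64_t operands_byte_size) {
  const int64_t kDefaultWorkspace = 4LL * 1024 * 1024;   // assumed value
  const int64_t kHopperWorkspace = 32LL * 1024 * 1024;   // assumed value
  int64_t workspace = is_hopper_or_newer ? kHopperWorkspace : kDefaultWorkspace;
  if (is_legacy_gemm) workspace = std::min(workspace, operands_byte_size);
  return workspace;
}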
-
-absl::StatusOr<bool> RunOnComputation(HloComputation *computation,
- se::GpuComputeCapability gpu_version,
- int32_t toolkit_version,
- bool f8_rewrite) {
- GemmRewriterVisitor visitor(gpu_version, toolkit_version, f8_rewrite);
- TF_RETURN_IF_ERROR(computation->Accept(&visitor));
- GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version);
- TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor));
- return visitor.changed();
-}
-
-} // anonymous namespace
-
-GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version,
- int32_t toolkit_version, bool f8_rewrite)
- : gpu_version_(gpu_version),
- toolkit_version_(toolkit_version),
- f8_rewrite_(f8_rewrite) {}
-
-absl::StatusOr<bool> GemmRewriter::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- bool changed = false;
- for (HloComputation *computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result,
- RunOnComputation(computation, gpu_version_,
- toolkit_version_, f8_rewrite_));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.h b/third_party/xla/xla/service/gpu/gemm_rewriter.h
deleted file mode 100644
index 161a29a..0000000
--- a/third_party/xla/xla/service/gpu/gemm_rewriter.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_REWRITER_H_
-#define XLA_SERVICE_GPU_GEMM_REWRITER_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// cuBLAS GEMM in the most general form can run the following operation:
-//
-// (kAdd
-// (kMultiply (kDot A B) alpha)
-// (kMultiply C beta))
-//
-// where A, B, C are matrices or vectors and `alpha` and `beta` are host
-// constants. In matrix-vector multiplication, one operand must be a matrix and
-// the other must be a vector. The additional requirement is that C has no other
-// users (otherwise, it does not make sense to fuse it inside the custom call).
-//
-// Both multiplication and addition can be avoided (equivalent to setting
-// `alpha` to one and `beta` to zero).
-//
-// This pass pattern-matches the most general form of this instruction
-// (we assume transposes are already folded), and rewrites it into a custom call
-// where (A, B, C) are three operands respectively, and `alpha` and `beta` are
-// stored in the backend config.
-class GemmRewriter : public HloModulePass {
- public:
- // When f8_rewrite is true, only FP8 GEMMs are rewritten. Otherwise, non-FP8
- // GEMMs are rewritten.
- GemmRewriter(se::GpuComputeCapability gpu_version, int32_t toolkit_version,
- bool f8_rewrite = false);
- absl::string_view name() const override { return "cublas-gemm-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::GpuComputeCapability gpu_version_;
- int32_t toolkit_version_;
- bool f8_rewrite_;
-};
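For readers skimming the header, a hedged HLO-style sketch of the pattern described in the class comment; the custom-call target string and the exact form of the backend config are illustrative assumptions, not a verified trace of the pass.

// Illustration only: alpha * dot(A, B) + beta * C collapses into one custom
// call, with alpha/beta stored in the backend config.
constexpr const char* kGemmRewriteExample = R"(
  // Before:
  //   dot    = f32[4,8] dot(A, B), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  //   scaled = f32[4,8] multiply(dot, alpha)
  //   ROOT d = f32[4,8] add(scaled, multiply(C, beta))
  // After:
  //   ROOT d = f32[4,8] custom-call(A, B, C), custom_call_target="__cublas$gemm"
)";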
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GEMM_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/gemv_rewriter.cc
deleted file mode 100644
index 21e5f47..0000000
--- a/third_party/xla/xla/service/gpu/gemv_rewriter.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemv_rewriter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/shape.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Construct a new layout by adding a new minor-most dimension to the input
-// layout. For example, {3, 2, 1, 0} is extended to {4, 3, 2, 1, 0}.
-// We expect that the input layout is normalized by LayoutNormalizer, so that
-// the input layout has a descending ordering.
-absl::StatusOr<Layout> GetLayoutWithNewMinorMostDimension(
- const Layout& layout) {
- // Check that the layout is normalized.
- if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) {
- return absl::InvalidArgumentError("Layout is not normalized.");
- }
- return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() + 1);
-}
-
-class GemvRewriterVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleDot(HloInstruction* instr) override {
- HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
- const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers();
- HloInstruction* lhs = dot->mutable_operand(0);
- HloInstruction* rhs = dot->mutable_operand(1);
-
- // This pass relies on dot decomposer which ensures that all non-batch
- // dimensions are merged into one.
- bool lhs_has_non_contracting_dim =
- lhs->shape().rank() ==
- dim_numbers.lhs_batch_dimensions_size() +
- dim_numbers.lhs_contracting_dimensions_size() + 1;
- bool rhs_has_non_contracting_dim =
- rhs->shape().rank() ==
- dim_numbers.rhs_batch_dimensions_size() +
- dim_numbers.rhs_contracting_dimensions_size() + 1;
-
- // Skip matrix-matrix multiplication.
- if (lhs_has_non_contracting_dim && rhs_has_non_contracting_dim) {
- return absl::OkStatus();
- }
-
- // Skip vector-vector multiplication.
- if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) {
- return absl::OkStatus();
- }
-
- if (dot->shape().is_dynamic()) {
- return absl::OkStatus();
- }
-
- changed_ = true;
-
- HloComputation* computation = dot->parent();
- HloInstruction* new_lhs = lhs;
- if (!lhs_has_non_contracting_dim) {
- const Shape& lhs_shape = lhs->shape();
- absl::Span<const int64_t> lhs_dimensions = lhs_shape.dimensions();
- std::vector<int64_t> new_lhs_dimensions(lhs_dimensions.begin(),
- lhs_dimensions.end());
- new_lhs_dimensions.push_back(1);
- Shape new_lhs_shape(
- lhs_shape.element_type(), new_lhs_dimensions,
- absl::InlinedVector<bool, 4>(new_lhs_dimensions.size(), false),
- /*tuple_shapes=*/{});
- TF_ASSIGN_OR_RETURN(
- *new_lhs_shape.mutable_layout(),
- GetLayoutWithNewMinorMostDimension(lhs_shape.layout()));
- new_lhs = computation->AddInstruction(
- HloInstruction::CreateBitcast(new_lhs_shape, lhs));
- }
-
- HloInstruction* new_rhs = rhs;
- if (!rhs_has_non_contracting_dim) {
- const Shape& rhs_shape = rhs->shape();
- absl::Span<const int64_t> rhs_dimensions = rhs_shape.dimensions();
- std::vector<int64_t> new_rhs_dimensions(rhs_dimensions.begin(),
- rhs_dimensions.end());
- new_rhs_dimensions.push_back(1);
- Shape new_rhs_shape(
- rhs_shape.element_type(), new_rhs_dimensions,
- absl::InlinedVector<bool, 4>(new_rhs_dimensions.size(), false),
- /*tuple_shapes=*/{});
- TF_ASSIGN_OR_RETURN(
- *new_rhs_shape.mutable_layout(),
- GetLayoutWithNewMinorMostDimension(rhs_shape.layout()));
- new_rhs = computation->AddInstruction(
- HloInstruction::CreateBitcast(new_rhs_shape, rhs));
- }
-
- std::vector<int64_t> new_out_dimensions;
- new_out_dimensions.reserve(dot->shape().dimensions().size() + 1);
- for (int64_t dim_size : dot->shape().dimensions()) {
- new_out_dimensions.push_back(dim_size);
- }
- if (!lhs_has_non_contracting_dim) {
- // Insert the trivial dimension before the non-contracting dimension from
- // rhs.
- int non_contracting_dim_size = new_out_dimensions.back();
- new_out_dimensions[new_out_dimensions.size() - 1] = 1;
- new_out_dimensions.push_back(non_contracting_dim_size);
- } else {
- new_out_dimensions.push_back(1);
- }
-
- Shape new_out_shape(
- dot->shape().element_type(), new_out_dimensions,
- absl::InlinedVector<bool, 4>(new_out_dimensions.size(), false),
- /*tuple_shapes=*/{});
- TF_ASSIGN_OR_RETURN(
- *new_out_shape.mutable_layout(),
- GetLayoutWithNewMinorMostDimension(dot->shape().layout()));
-
- HloInstruction* new_dot =
- computation->AddInstruction(HloInstruction::CreateDot(
- new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(),
- dot->precision_config()));
- HloInstruction* bitcast = computation->AddInstruction(
- HloInstruction::CreateBitcast(dot->shape(), new_dot));
- return computation->ReplaceInstruction(dot, bitcast);
- }
-
- bool changed() const { return changed_; }
-
- private:
- bool changed_ = false;
-};
-
-} // namespace
-
-absl::StatusOr<bool> GemvRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- GemvRewriterVisitor gemv_rewriter;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_RETURN_IF_ERROR(computation->Accept(&gemv_rewriter));
- }
- return gemv_rewriter.changed();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.h b/third_party/xla/xla/service/gpu/gemv_rewriter.h
deleted file mode 100644
index a041138..0000000
--- a/third_party/xla/xla/service/gpu/gemv_rewriter.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMV_REWRITER_H_
-#define XLA_SERVICE_GPU_GEMV_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrite a matrix-vector or a vector-matrix multiplication into a
-// matrix-matrix multiplication with a trivial dimension. For example,
-// [m x n] @ [n] is rewritten to [m x n] @ [n x 1], and [n] @ [m x n] is
-// rewritten to [n x 1] @ [m x n].
-class GemvRewriter : public HloModulePass {
- public:
- absl::string_view name() const override { return "gemv-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
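A compact sketch of the shape change described in the class comment above (the tests later in this diff exercise the same pattern); the concrete shapes are arbitrary.

// Illustration only: a matrix-vector dot gains a trivial minor-most dimension
// so the downstream GEMM machinery sees a matrix-matrix product.
constexpr const char* kGemvRewriteExample = R"(
  // Before: [m x n] @ [n]      ->  d = f32[32]   dot(f32[32,7], f32[7])
  // After:  [m x n] @ [n x 1]  ->  d = f32[32,1] dot(f32[32,7], f32[7,1]),
  //         followed by a bitcast back to f32[32].
)";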
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GEMV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc
deleted file mode 100644
index 2a8b810..0000000
--- a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc
+++ /dev/null
@@ -1,149 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemv_rewriter.h"
-
-#include <memory>
-#include <optional>
-
-#include <gtest/gtest.h>
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-class GemvRewriterTest : public HloTestBase {};
-
-TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationToGemm) {
- const char* hlo = R"(
- HloModule m
-
- ENTRY e {
- p0 = f32[32,7] parameter(0)
- p1 = f32[7] parameter(1)
- ROOT d = f32[32] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })";
-
- const char* expected = R"()
-// CHECK: %[[P0:.*]] = f32[32,7]{1,0} parameter(0)
-// CHECK: %[[P1:.*]] = f32[7]{0} parameter(1)
-// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P1]])
-// CHECK: %[[DOT:.*]] = f32[32,1]{1,0} dot(%[[P0]], %[[BITCAST]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-// CHECK: ROOT %[[ROOT:.*]] = f32[32]{0} bitcast(%[[DOT]])
-})";
-
- RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
-}
-
-TEST_F(GemvRewriterTest, RewriteVectorMatrixMultiplicationToGemm) {
- const char* hlo = R"(
- HloModule m
-
- ENTRY e {
- p0 = f32[7] parameter(0)
- p1 = f32[7,32] parameter(1)
- ROOT d = f32[32] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
- })";
-
- const char* expected = R"()
-// CHECK: %[[P0:.*]] = f32[7]{0} parameter(0)
-// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P0]])
-// CHECK: %[[P1:.*]] = f32[7,32]{1,0} parameter(1)
-// CHECK: %[[DOT:.*]] = f32[1,32]{1,0} dot(%[[BITCAST]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0}
-// CHECK: ROOT %[[ROOT:.*]].1 = f32[32]{0} bitcast(%[[DOT]])
-})";
-
- RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
-}
-
-TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationWithBatch) {
- const char* hlo = R"(
- HloModule m
-
- ENTRY e {
- p0 = f32[2,5,32,7] parameter(0)
- p1 = f32[2,5,7] parameter(1)
- ROOT d = f32[2,5,32] dot(p0, p1),
- lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
- lhs_contracting_dims={3}, rhs_contracting_dims={2}
- })";
-
- const char* expected = R"()
-// CHECK: %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0)
-// CHECK: %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1)
-// CHECK: %[[BITCAST:.*]] = f32[2,5,7,1]{3,2,1,0} bitcast(%[[P1]])
-// CHECK: %[[DOT:.*]] = f32[2,5,32,1]{3,2,1,0} dot(%[[P0]], %[[BITCAST]]),
-// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-// CHECK: ROOT %[[ROOT:.*]] = f32[2,5,32]{2,1,0} bitcast(%[[DOT]])
-})";
-
- RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
-}
-
-TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) {
- const char* hlo = R"(
- HloModule m
-
- ENTRY e {
- p0 = f32[7] parameter(0)
- p1 = f32[7] parameter(1)
- ROOT d = f32[] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
- })";
-
- RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
-}
-
-TEST_F(GemvRewriterTest, DotNotRewriteMatrixMatrixMultiplication) {
- const char* hlo = R"(
- HloModule m
-
- ENTRY e {
- p0 = f32[5,7] parameter(0)
- p1 = f32[7,32] parameter(1)
- ROOT d = f32[5,32] dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })";
-
- RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
-}
-
-TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) {
- const char* hlo = R"(
- HloModule m
-
- ENTRY e {
- p0 = f32[5,32,7]{2,1,0} parameter(0)
- p1 = f32[5,7]{0,1} parameter(1)
- ROOT d = f32[5,32]{0,1} dot(p0, p1),
- lhs_batch_dims={0}, rhs_batch_dims={0},
- lhs_contracting_dims={2}, rhs_contracting_dims={1}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo));
- GemvRewriter rewriter;
- absl::StatusOr<bool> result = this->RunHloPass(&rewriter, module.get());
- EXPECT_FALSE(result.ok());
- EXPECT_EQ(result.status().message(), "Layout is not normalized.");
-}
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc
deleted file mode 100644
index 21e8d6c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
-
-#include "absl/log/check.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla::gpu {
-
-bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
- const HloInstruction* hlo) {
- if (!options_.enable_dot_strength_reduction()) {
- return false;
- }
-
- const HloDotInstruction* dot = DynCast<HloDotInstruction>(hlo);
- if (dot == nullptr) {
- return false;
- }
-
- const HloInstruction* lhs = dot->operand(0);
- const HloInstruction* rhs = dot->operand(1);
- DotDimensionNumbers dnums = dot->dot_dimension_numbers();
- bool lhs_is_vector = (dnums.lhs_batch_dimensions_size() +
- dnums.lhs_contracting_dimensions_size() ==
- lhs->shape().rank());
- bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() +
- dnums.rhs_contracting_dimensions_size() ==
- rhs->shape().rank());
- // Strength-reduce vector-vector dots since they are not supported by
- // GemmFusion.
- if (lhs_is_vector && rhs_is_vector) {
- return true;
- }
-
- absl::StatusOr<bool> is_too_small =
- IsMatrixMultiplicationTooSmallForRewriting(*hlo, /*threshold=*/10000000);
- CHECK_OK(is_too_small.status());
- if (is_too_small.value()) {
- return true;
- }
-
- // If GemmFusion cannot handle this dot, we should strength-reduce it so that
- // it can be handled by the fusion pipeline.
- return !legacy_triton::CanTritonHandleGEMM(*dot, compute_capability_);
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h
deleted file mode 100644
index 8553596..0000000
--- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_
-#define XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-
-namespace xla::gpu {
-
-class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor {
- public:
- explicit GpuAlgebraicSimplifierVisitor(
- const AlgebraicSimplifierOptions& options,
- se::GpuComputeCapability compute_capability,
- AlgebraicSimplifier* simplifier)
- : AlgebraicSimplifierVisitor(options, simplifier),
- compute_capability_(std::move(compute_capability)) {}
-
- bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override;
-
- private:
- se::GpuComputeCapability compute_capability_;
-};
-
-class GpuAlgebraicSimplifier : public AlgebraicSimplifier {
- public:
- explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options,
- se::GpuComputeCapability compute_capability)
- : AlgebraicSimplifier(options),
- compute_capability_(std::move(compute_capability)) {}
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(HloModule* module,
- const absl::flat_hash_set<absl::string_view>&
- execution_threads) override {
- XLA_VLOG_LINES(
- 2, "GpuAlgebraicSimplifier::Run(), before:\n" + module->ToString());
- bool changed = false;
- GpuAlgebraicSimplifierVisitor visitor(options_, compute_capability_, this);
- for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
- if (visitor.Run(comp, options_, this)) {
- changed = true;
- }
- }
- XLA_VLOG_LINES(
- 2, "GpuAlgebraicSimplifier::Run(), after:\n" + module->ToString());
- return changed;
- }
-
- private:
- se::GpuComputeCapability compute_capability_;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc
deleted file mode 100644
index 135ddb12..0000000
--- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
-
-#include <string>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-class GpuAlgebraicSimplifierTest : public HloTestBase {};
-
-TEST_F(GpuAlgebraicSimplifierTest, VectorVectorDotShouldBeStrengthReduced) {
- const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
- p0 = f32[32, 500] parameter(0)
- p1 = f32[32, 500] parameter(1)
- ROOT dot = f32[32] dot(p0, p1), lhs_batch_dims={0},
- lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- const HloInstruction* dot = module->entry_computation()->root_instruction();
- AlgebraicSimplifierOptions options;
- options.set_enable_dot_strength_reduction(true);
- se::CudaComputeCapability ampere(8, 0);
- GpuAlgebraicSimplifier simplifier(options, ampere);
- GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
- EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest, MatrixVectorDotShouldNotBeStrengthReduced) {
- const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
- p0 = f32[32, 5000, 7000] parameter(0)
- p1 = f32[32, 5000] parameter(1)
- ROOT dot = f32[32,7000] dot(p0, p1), lhs_batch_dims={0},
- lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
- algorithm=dot_bf16_bf16_f32_x6
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- const HloInstruction* dot = module->entry_computation()->root_instruction();
- AlgebraicSimplifierOptions options;
- options.set_enable_dot_strength_reduction(true);
- se::CudaComputeCapability ampere(8, 0);
- GpuAlgebraicSimplifier simplifier(options, ampere);
- GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
- EXPECT_FALSE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest,
- DotWithTypeUnsupportedByGemmFusionShouldBeStrengthReduced) {
- const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
- p0 = c64[32, 5000, 7000] parameter(0)
- p1 = c64[32, 5000] parameter(1)
- ROOT dot = c64[32,7000] dot(p0, p1), lhs_batch_dims={0},
- lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- const HloInstruction* dot = module->entry_computation()->root_instruction();
- AlgebraicSimplifierOptions options;
- options.set_enable_dot_strength_reduction(true);
- se::CudaComputeCapability ampere(8, 0);
- GpuAlgebraicSimplifier simplifier(options, ampere);
- GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
- EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced) {
- const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
- p0 = f32[32, 50, 70] parameter(0)
- p1 = f32[32, 50] parameter(1)
- ROOT dot = f32[32,70] dot(p0, p1), lhs_batch_dims={0},
- lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
- algorithm=dot_bf16_bf16_f32_x6
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- const HloInstruction* dot = module->entry_computation()->root_instruction();
- AlgebraicSimplifierOptions options;
- options.set_enable_dot_strength_reduction(true);
- se::CudaComputeCapability ampere(8, 0);
- GpuAlgebraicSimplifier simplifier(options, ampere);
- GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
- EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced2) {
- const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
- p0 = f32[2000, 3000] parameter(0)
- p1 = f32[2000] parameter(1)
- ROOT dot = f32[3000] dot(p0, p1), lhs_contracting_dims={0},
- rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32_x6
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- const HloInstruction* dot = module->entry_computation()->root_instruction();
- AlgebraicSimplifierOptions options;
- options.set_enable_dot_strength_reduction(true);
- se::CudaComputeCapability ampere(8, 0);
- GpuAlgebraicSimplifier simplifier(options, ampere);
- GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
- EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc b/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc
deleted file mode 100644
index fe2d2d1..0000000
--- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_all_gather_optimizer.h"
-
-#include <cstdint>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> AllGatherOptimizer::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* instruction :
- computation->MakeInstructionPostOrder()) {
- if (!HloOpcodeIsBinaryCommutative(instruction->opcode())) {
- continue;
- }
-
- HloInstruction* left_op = instruction->mutable_operand(0);
- HloInstruction* right_op = instruction->mutable_operand(1);
-
- if (right_op->opcode() != HloOpcode::kAllGather ||
- left_op->opcode() != HloOpcode::kAllGather) {
- VLOG(2) << "Binary op's operands are not all-gather deduced types.";
- continue;
- }
-
- auto* left_all_gather = Cast<HloAllGatherInstruction>(left_op);
- auto* right_all_gather = Cast<HloAllGatherInstruction>(right_op);
-
- if (right_all_gather->constrain_layout() !=
- left_all_gather->constrain_layout() ||
- right_all_gather->use_global_device_ids() !=
- left_all_gather->use_global_device_ids() ||
- !ReplicaGroupsEqual(right_all_gather->replica_groups(),
- left_all_gather->replica_groups())) {
- VLOG(2) << "The right and left all-gather ops are not compatible "
- "to merge. ";
- continue;
- }
-
- if (!ShapeUtil::Equal(left_all_gather->operand(0)->shape(),
- right_all_gather->operand(0)->shape())) {
- VLOG(2) << "all-gather operands have different shapes";
- continue;
- }
-
- if (right_all_gather->user_count() != 1 ||
- left_all_gather->user_count() != 1) {
- VLOG(2) << "all-gather user_count > 1 ";
- continue;
- }
- auto index_in_full_shape =
- computation->AddInstruction(HloInstruction::CreateBinary(
- right_all_gather->operand(0)->shape(), instruction->opcode(),
- left_all_gather->mutable_operand(0),
- right_all_gather->mutable_operand(0)));
-
- int64_t all_gather_dimension =
- Cast<HloAllGatherInstruction>(right_all_gather)
- ->all_gather_dimension();
-
- auto combined = HloInstruction::CreateAllGather(
- left_all_gather->shape(), {index_in_full_shape}, all_gather_dimension,
- left_all_gather->device_list(),
- /*constrain_layout=*/false, left_all_gather->channel_id(),
- Cast<HloAllGatherInstruction>(left_all_gather)
- ->use_global_device_ids());
-
- TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
- instruction, std::move(combined)));
- changed = true;
- }
- }
-
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h b/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h
deleted file mode 100644
index e28e422..0000000
--- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_
-#define XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Transforms binary_op(all-gather(reduce_scatter(a)),
-// all-gather(reduce_scatter(b))) into all-gather(binary_op(reduce_scatter(a),
-// reduce_scatter(b))).
-
-class AllGatherOptimizer : public HloModulePass {
- public:
- AllGatherOptimizer() = default;
- absl::string_view name() const override { return "all-gather-optimizer"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
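A hedged HLO-style sketch of the transform named in the class comment; the shapes and operand names are made up for illustration only.

// Illustration only: the binary op is pushed below a single all-gather, so
// the elementwise work happens on the smaller, pre-gather shards.
constexpr const char* kAllGatherOptimizerExample = R"(
  // Before:
  //   ag0 = f32[16] all-gather(rs_a), dimensions={0}
  //   ag1 = f32[16] all-gather(rs_b), dimensions={0}
  //   ROOT sum = f32[16] add(ag0, ag1)
  // After:
  //   partial = f32[8] add(rs_a, rs_b)
  //   ROOT ag = f32[16] all-gather(partial), dimensions={0}
)";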
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc b/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc
deleted file mode 100644
index c2f6c04..0000000
--- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_async_collective_annotator.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> GpuAsyncCollectiveAnnotator::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (!hlo_query::IsAsyncCollectiveStartOp(instruction)) {
- continue;
- }
- CollectiveBackendConfig config;
- config.set_is_sync(!is_collective_async_(instruction));
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- instruction->backend_config<GpuBackendConfig>());
- *gpu_config.mutable_collective_backend_config() = config;
- TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
- changed = true;
- }
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h b/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h
deleted file mode 100644
index 4000fbc..0000000
--- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// Annotate async collectives with CollectiveBackendConfig.
-class GpuAsyncCollectiveAnnotator : public HloModulePass {
- public:
- explicit GpuAsyncCollectiveAnnotator(HloPredicate is_collective_async)
- : is_collective_async_(std::move(is_collective_async)) {}
- absl::string_view name() const override {
- return "gpu-async-collective-annotator";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- HloPredicate is_collective_async_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc b/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc
deleted file mode 100644
index f874a7e..0000000
--- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_async_collective_annotator.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithAsync
-
- addf32 {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
-
- addf16 {
- p0 = f16[] parameter(0)
- p1 = f16[] parameter(1)
- ROOT add = f16[] add(p0, p1)
- }
-
- reduce_scatterf32 {
- p0 = f32[2] parameter(0)
- ROOT result = f32[1] reduce-scatter(p0), replica_groups={},
- dimensions={0}, to_apply=addf32
- }
-
- ENTRY entry {
- pf32 = f32[1] parameter(0)
- pf16 = f16[1] parameter(1)
-
- arf32-start = f32[1] all-reduce-start(pf32), to_apply=addf32
- arf32-done = f32[1] all-reduce-done(arf32-start)
-
- arf16-start = f16[1] all-reduce-start(pf16), to_apply=addf16
- arf16-done = f16[1] all-reduce-done(arf16-start)
-
- agf32-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}
- agf32-done = f32[2] all-gather-done(agf32-start)
-
- agf16-start = (f16[1], f16[2]) all-gather-start(pf16), dimensions={0}
- agf16-done = f16[2] all-gather-done(agf16-start)
-
- cpf32-start = (f32[1], f32[1], u32[], u32[]) collective-permute-start(pf32),
- source_target_pairs={{0,1}, {1,0}}
- cpf32-done = f32[1] collective-permute-done(cpf32-start)
-
- cpf16-start = (f16[1], f16[1], u32[], u32[]) collective-permute-start(pf16),
- source_target_pairs={{0,1}, {1,0}}
- cpf16-done = f16[1] collective-permute-done(cpf16-start)
-
- rsf32-start = ((f32[2]), f32[1]) async-start(agf32-done), calls=reduce_scatterf32
- rsf32-done = f32[1] async-done(rsf32-start), calls=reduce_scatterf32
-
- ROOT tuple = (f32[1], f16[1], f32[2], f16[2], f32[1], f16[1], f32[1])
- tuple(arf32-done, arf16-done, agf32-done, agf16-done, cpf32-done,
- cpf16-done, rsf32-done)
- }
-)";
-
-struct TestCase {
- std::string test_name;
- HloPredicate is_async_predicate;
- absl::flat_hash_set<absl::string_view> expected_async;
- absl::flat_hash_set<absl::string_view> expected_sync;
-};
-
-class GpuAsyncCollectiveAnnotatorTest
- : public HloTestBase,
- public ::testing::WithParamInterface<TestCase> {};
-
-XLA_TEST_P(GpuAsyncCollectiveAnnotatorTest, Test) {
- const TestCase& test_case = GetParam();
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString, /*replica_count=*/2));
- TF_ASSERT_OK_AND_ASSIGN(
- bool changed, GpuAsyncCollectiveAnnotator(test_case.is_async_predicate)
- .Run(module.get()));
- EXPECT_TRUE(changed);
-
- // Assert that all async collectives are annotated with the backend config.
- for (const HloInstruction* hlo :
- module->entry_computation()->instructions()) {
- if (!hlo_query::IsAsyncCollectiveStartOp(hlo)) {
- continue;
- }
- auto gpu_config = hlo->backend_config<GpuBackendConfig>();
- ASSERT_TRUE(gpu_config.ok());
-
- const CollectiveBackendConfig& backend_config =
- gpu_config.value().collective_backend_config();
- if (test_case.expected_async.contains(hlo->name())) {
- EXPECT_FALSE(backend_config.is_sync());
- }
-
- if (test_case.expected_sync.contains(hlo->name())) {
- EXPECT_TRUE(backend_config.is_sync());
- }
- }
-}
-
-std::vector<TestCase> TestCases() {
- HloPredicate is_f16 = [](const HloInstruction* hlo) {
- return hlo->operand(0)->shape().element_type() == PrimitiveType::F16;
- };
-
- return {
- {"all_async",
- HloPredicateTrue, /*expected_async=*/
- {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
- "cpf32-start", "cpf16-start", "rsf32-start"},
- /*expected_sync=*/{}},
- {"all_sync",
- HloPredicateFalse,
- /*expected_async=*/{},
- /*expected_sync=*/
- {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
- "cpf32-start", "cpf16-start", "rsf32-start"}},
- {"ar_async",
- HloPredicateIsOp<HloOpcode::kAllReduceStart>,
- /*expected_async=*/
- {"arf32-start", "arf16-start"},
- /*expected_sync=*/
- {"agf32-start", "agf16-start", "cpf32-start", "cpf16-start",
- "rsf32-start"}},
- {"cp_async",
- HloPredicateIsOp<HloOpcode::kCollectivePermuteStart>,
- /*expected_async=*/
- {"cpf32-start", "cpf16-start"},
- /*expected_sync=*/
- {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
- "rsf32-start"}},
- {"f16_async",
- is_f16,
- /*expected_async=*/{"arf16-start", "agf16-start", "cpf16-start"},
- /*expected_sync=*/
- {"arf32-start", "agf32-start", "cpf32-start", "rsf32-start"}},
- };
-}
-
-std::string TestCaseName(const ::testing::TestParamInfo<TestCase>& test_case) {
- return test_case.param.test_name;
-}
-
-INSTANTIATE_TEST_SUITE_P(GpuAsyncCollectiveAnnotatorTest,
- GpuAsyncCollectiveAnnotatorTest,
- ::testing::ValuesIn(TestCases()), TestCaseName);
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index 96b537e..b477a88 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -74,7 +74,6 @@
#include "xla/service/all_reduce_folder.h"
#include "xla/service/all_reduce_promotion.h"
#include "xla/service/all_reduce_reassociate.h"
-#include "xla/service/all_reduce_splitter.h"
#include "xla/service/async_collective_creator.h"
#include "xla/service/batchnorm_expander.h"
#include "xla/service/bitcast_dtypes_expander.h"
@@ -110,43 +109,18 @@
#include "xla/service/float_support.h"
#include "xla/service/gather_expander.h"
#include "xla/service/gather_simplifier.h"
-#include "xla/service/gpu/algorithm_checker.h"
-#include "xla/service/gpu/all_reduce_blueconnect.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/collective_permute_cycle_decomposer.h"
-#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h"
-#include "xla/service/gpu/command_buffer_scheduling.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h"
#include "xla/service/gpu/compile_module_to_llvm_ir.h"
#include "xla/service/gpu/conv_layout_normalization.h"
-#include "xla/service/gpu/custom_kernel_fusion_autotuner.h"
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
-#include "xla/service/gpu/dot_dimension_sorter.h"
-#include "xla/service/gpu/dot_operand_converter.h"
-#include "xla/service/gpu/double_buffer_loop_unrolling.h"
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
#include "xla/service/gpu/execution_stream_assignment.h"
#include "xla/service/gpu/fusion_pipeline.h"
-#include "xla/service/gpu/fusion_wrapper.h"
-#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h"
-#include "xla/service/gpu/gemm_fusion.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/gemv_rewriter.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
-#include "xla/service/gpu/gpu_all_gather_optimizer.h"
-#include "xla/service/gpu/gpu_async_collective_annotator.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/gpu_float_support.h"
#include "xla/service/gpu/gpu_hlo_schedule.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
-#include "xla/service/gpu/gpu_layout_assignment.h"
#include "xla/service/gpu/gpu_p2p_pipeliner.h"
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-#include "xla/service/gpu/gpu_scatter_expander.h"
#include "xla/service/gpu/gpu_spmd_pipeline.h"
-#include "xla/service/gpu/gpu_windowed_einsum_handler.h"
#include "xla/service/gpu/hlo_fusion_stats.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/ir_emitter_context.h"
@@ -156,26 +130,53 @@
#include "xla/service/gpu/metrics.h"
#include "xla/service/gpu/model/gpu_cost_model_stats_collection.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/move_copy_to_users.h"
-#include "xla/service/gpu/pipelined_p2p_rewriter.h"
#include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h"
-#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
-#include "xla/service/gpu/reduction_dimension_grouper.h"
-#include "xla/service/gpu/reduction_layout_normalizer.h"
-#include "xla/service/gpu/reduction_splitter.h"
#include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/gpu/rename_fusions.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/runtime_intrinsics.h"
-#include "xla/service/gpu/scatter_slice_simplifier.h"
-#include "xla/service/gpu/softmax_rewriter_triton.h"
-#include "xla/service/gpu/stream_attribute_annotator.h"
-#include "xla/service/gpu/stream_attribute_async_wrapper.h"
#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/gpu/topk_specializer.h"
-#include "xla/service/gpu/topk_splitter.h"
-#include "xla/service/gpu/tree_reduction_rewriter.h"
-#include "xla/service/gpu/triton_fusion_numerics_verifier.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/algorithm_checker.h"
+#include "xla/service/gpu/transforms/all_gather_optimizer.h"
+#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
+#include "xla/service/gpu/transforms/all_reduce_splitter.h"
+#include "xla/service/gpu/transforms/async_collective_annotator.h"
+#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h"
+#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h"
+#include "xla/service/gpu/transforms/command_buffer_scheduling.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h"
+#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h"
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
+#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
+#include "xla/service/gpu/transforms/dot_operand_converter.h"
+#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/gpu/transforms/gemv_rewriter.h"
+#include "xla/service/gpu/transforms/layout_assignment.h"
+#include "xla/service/gpu/transforms/move_copy_to_users.h"
+#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h"
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h"
+#include "xla/service/gpu/transforms/reduction_dimension_grouper.h"
+#include "xla/service/gpu/transforms/reduction_layout_normalizer.h"
+#include "xla/service/gpu/transforms/reduction_splitter.h"
+#include "xla/service/gpu/transforms/rename_fusions.h"
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
+#include "xla/service/gpu/transforms/scatter_expander.h"
+#include "xla/service/gpu/transforms/scatter_slice_simplifier.h"
+#include "xla/service/gpu/transforms/softmax_rewriter_triton.h"
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
+#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h"
+#include "xla/service/gpu/transforms/topk_specializer.h"
+#include "xla/service/gpu/transforms/topk_splitter.h"
+#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
+#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
+#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_computation_deduplicator.h"
#include "xla/service/hlo_constant_folding.h"
@@ -494,7 +495,7 @@
AlgebraicSimplifierOptions layout_insensitive_algsimp_opts =
opts_from_compiler;
layout_insensitive_algsimp_opts.set_conv_is_lowerable_callback(
- GpuConvRewriter::ConvIsLowerable);
+ ConvRewriter::ConvIsLowerable);
layout_insensitive_algsimp_opts.set_enable_dot_strength_reduction(
hlo_module_config.debug_options()
.xla_gpu_enable_dot_strength_reduction());
@@ -526,6 +527,7 @@
HloPassPipeline pre_spmd_pipeline("pre-spmd-partitioner");
// Run some IR cleanup passes before running the SPMD partitioning
// passes.
+ pre_spmd_pipeline.AddPass<CuDnnCustomCallConverter>();
pre_spmd_pipeline.AddPass<ConvertMemoryPlacementToInternalAnnotations>();
pre_spmd_pipeline.AddPass<CallInliner>();
pre_spmd_pipeline.AddPass<ZeroSizedHloElimination>();
@@ -627,7 +629,7 @@
HloPassPipeline pipeline("optimization");
AddHloVerifier(&pipeline);
if (debug_options.xla_gpu_multi_streamed_windowed_einsum()) {
- pipeline.AddPass<GpuWindowedEinsumHandler>();
+ pipeline.AddPass<WindowedEinsumHandler>();
}
pipeline.AddPass<TopKSplitter>();
pipeline.AddPass<TopkSpecializer>();
@@ -1121,7 +1123,7 @@
return false;
}
};
- pipeline.AddPass<GpuAsyncCollectiveAnnotator>(convert_to_async);
+ pipeline.AddPass<AsyncCollectiveAnnotator>(convert_to_async);
return pipeline.Run(hlo_module).status();
}
@@ -1376,6 +1378,7 @@
// heuristic, so we can mix and match various Gemm implementations based
// on projected (measured) performance.
if (debug_options.xla_gpu_enable_custom_fusions()) {
+ pipeline.AddPass<SimplifyFPConversions>();
pipeline.AddPass<CustomKernelFusionRewriter>(
&gpu_target_config.device_description);
pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
@@ -1396,10 +1399,12 @@
pipeline.AddPass<GemmFusion>(gpu_version);
}
- pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
- /*f8_rewrite=*/true);
- pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
- /*f8_rewrite=*/false);
+ pipeline.AddPass<GemmRewriter>(
+ gpu_version, GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ pipeline.AddPass<GemmRewriter>(
+ gpu_version, GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only});
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();
@@ -1442,7 +1447,7 @@
bool ignore_small_reduce_dims =
!debug_options.xla_gpu_enable_priority_fusion();
pipeline.AddPass<HloPassFix<ReductionSplitter>>(ignore_small_reduce_dims);
- pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>(gpu_version);
+ pipeline.AddPass<HloPassFix<TreeReductionRewriter>>(gpu_version);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
@@ -1467,10 +1472,12 @@
pipeline.AddPass<CallInliner>();
// TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated
// here for possibly better cuBLAS performance.
- pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
- /*f8_rewrite=*/true);
- pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
- /*f8_rewrite=*/false);
+ pipeline.AddPass<GemmRewriter>(
+ gpu_version, GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ pipeline.AddPass<GemmRewriter>(
+ gpu_version, GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only});
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();
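Editorial note on the two GemmRewriter hunks above: the pass is now constructed with a GemmRewriterOptions value (first kFp8Only, then kNonFp8Only) instead of the former boolean f8_rewrite flag. The diff does not show the definition of GemmRewriterOptions, so the following is only a minimal, self-contained sketch (stand-in names, assumed default) of how a typed options struct can replace a pair of boolean-flagged pass instantiations.

    // Hypothetical sketch only -- not the real XLA definition -- illustrating
    // how a small options struct with a DType selector replaces the old
    // boolean `f8_rewrite` parameter at the two AddPass<GemmRewriter> sites.
    #include <iostream>

    struct GemmRewriterOptionsSketch {
      enum class DType { kFp8Only, kNonFp8Only };
      DType dtype = DType::kNonFp8Only;  // assumed default, not from the diff
    };

    // Stand-in for the pass body: picks which GEMMs it is allowed to rewrite.
    void RunGemmRewriteSketch(GemmRewriterOptionsSketch opts) {
      const bool fp8_only =
          opts.dtype == GemmRewriterOptionsSketch::DType::kFp8Only;
      std::cout << (fp8_only ? "rewriting FP8 GEMMs only\n"
                             : "rewriting non-FP8 GEMMs only\n");
    }

    int main() {
      // Mirrors the two pipeline.AddPass<GemmRewriter>(...) calls above.
      RunGemmRewriteSketch({GemmRewriterOptionsSketch::DType::kFp8Only});
      RunGemmRewriteSketch({GemmRewriterOptionsSketch::DType::kNonFp8Only});
      return 0;
    }

A named enum keeps the two pipeline entries self-describing at the call site, which is the readability gain the hunks above buy.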
@@ -2147,8 +2154,8 @@
}};
BinaryMap dnn_compiled_graphs;
if (stream_exec) {
- TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec,
- &dnn_compiled_graphs));
+ TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec,
+ &dnn_compiled_graphs));
}
const DebugOptions& debug_opts = module->config().debug_options();
@@ -2481,7 +2488,7 @@
pipeline.AddPass<CommandBufferScheduling>(
gpu_device_info, toolkit_version,
driver_version.value_or(toolkit_version));
- pipeline.AddPass<GpuSanitizeConstantNames>();
+ pipeline.AddPass<SanitizeConstantNames>();
}
AddHloVerifier(&main_pipeline,
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h
index 27a434f..456e675 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.h
@@ -31,7 +31,7 @@
#include "xla/service/buffer_assignment.h"
#include "xla/service/compiler.h"
#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/buffer_sharing.h"
#include "xla/service/gpu/compile_module_to_llvm_ir.h"
#include "xla/service/gpu/executable.pb.h"
@@ -171,10 +171,10 @@
return absl::OkStatus();
}
- // Runs cuDNN fusion compiler pass.
- virtual absl::Status RunCudnnFusionCompilerPass(
- HloModule* module, se::StreamExecutor* stream_exec,
- BinaryMap* dnn_compiled_graphs) {
+ // Runs cuDNN fusion and custom call compiler passes.
+ virtual absl::Status RunCudnnCompilerPasses(HloModule* module,
+ se::StreamExecutor* stream_exec,
+ BinaryMap* dnn_compiled_graphs) {
return absl::OkStatus();
}
@@ -235,7 +235,8 @@
absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
virtual absl::StatusOr<std::vector<uint8_t>> LinkModules(
- se::GpuComputeCapability cc, se::StreamExecutor* stream_exec,
+ se::GpuComputeCapability gpu_compute_capability,
+ se::StreamExecutor* stream_exec,
std::vector<std::vector<uint8_t>> modules,
const DebugOptions& debug_options) {
return Unimplemented("LinkModules is not implemented.");
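The gpu_compiler.h hunk above renames the virtual hook RunCudnnFusionCompilerPass to RunCudnnCompilerPasses, broadening it to cover both fusion and custom-call compilation while keeping a success-returning no-op as the base-class default. A rough, self-contained sketch of that override pattern follows; the types are simplified stand-ins, not the real HloModule, se::StreamExecutor, or absl::Status.

    // Simplified stand-ins for the virtual-hook pattern used in the header.
    #include <iostream>
    #include <map>
    #include <string>

    struct Module {};
    struct StreamExecutor {};
    using BinaryMap = std::map<std::string, std::string>;

    class CompilerSketch {
     public:
      virtual ~CompilerSketch() = default;
      // Default mirrors the base class in the diff: do nothing, report success.
      virtual bool RunCudnnCompilerPasses(Module*, StreamExecutor*, BinaryMap*) {
        return true;
      }
    };

    class CudaCompilerSketch : public CompilerSketch {
     public:
      bool RunCudnnCompilerPasses(Module*, StreamExecutor*,
                                  BinaryMap* dnn_compiled_graphs) override {
        // A real override would compile cuDNN fusion/custom-call graphs here.
        (*dnn_compiled_graphs)["fusion_0"] = "<serialized cuDNN graph>";
        return true;
      }
    };

    int main() {
      Module m;
      StreamExecutor se;
      BinaryMap graphs;
      CudaCompilerSketch().RunCudnnCompilerPasses(&m, &se, &graphs);
      std::cout << "compiled graphs: " << graphs.size() << "\n";
      return 0;
    }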
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
index b74d77c..93057c5 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
@@ -36,7 +36,7 @@
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/primitive_util.h"
#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/gpu_hlo_schedule.h"
#include "xla/service/gpu/metrics.h"
#include "xla/service/hlo_module_config.h"
@@ -391,7 +391,6 @@
TEST_F(GpuCompilerTest,
GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) {
- GTEST_SKIP() << "TODO(b/354864068): Test fails in OSS stack on A100-80.";
auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
@@ -645,12 +644,12 @@
CHECK: %[[RESULT_RECV:.*]] = recv(%[[AFTER_ALL]])
CHECK-SAME: channel_id=[[CHANNEL_ID]]
CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"},
+CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}},
CHECK-SAME: control-predecessors={%[[CUSTOM_CALL]]}
CHECK: %[[RESULT_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[AFTER_ALL]])
CHECK-SAME: channel_id=1
CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"},
+CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}},
CHECK-SAME: control-predecessors={%[[RESULT_RECV]]}
CHECK: ROOT
// We actually expect both RESULT_RECV and RESULT_SEND to match on this line.
@@ -664,11 +663,11 @@
CHECK: %[[ENTRY_RECV:.*]] = recv(%[[ENTRY_AFTER_ALL]])
CHECK-SAME: channel_id=[[CHANNEL_ID]]
CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}
+CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}
CHECK: %[[ENTRY_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[ENTRY_AFTER_ALL]])
CHECK-SAME: channel_id=1
CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"},
+CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}},
CHECK-SAME: control-predecessors={%[[ENTRY_RECV]]}
CHECK: %[[WHILE_INIT:.*]] = tuple
// Check here that the send argument is likewise passed to the while loop, as
@@ -861,10 +860,10 @@
})";
const char* kExpected = R"(
- // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{3,0}}",_xla_send_recv_validation="{{[{]}}{3,9}}"}
- // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{0,1},{1,2},{2,3}}",_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8}}"}
- // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{3,0}}",_xla_send_recv_validation="{{[{]}}{3,9}}"}
- // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{0,1},{1,2},{2,3}}",_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8}}"}
+ // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{3,0}},_xla_send_recv_validation={{[{]}}{3,9}}}
+ // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{0,1},{1,2},{2,3}},_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8}}}
+ // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{3,0}},_xla_send_recv_validation={{[{]}}{3,9}}}
+ // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{0,1},{1,2},{2,3}},_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8}}}
)";
DebugOptions debug_options;
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
index 3549c95..51caadb 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
+++ b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -13,7 +13,7 @@
}
results {
device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
- hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}"
+ hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
run_time {
nanos: 1
@@ -37,7 +37,7 @@
}
results {
device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 40 MB"
- hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}"
+ hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
run_time {
nanos: 1
@@ -61,7 +61,7 @@
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
- hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}"
+ hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
gemm {
algorithm: -1
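The only change to these autotune-cache entries is that the backend_config JSON embedded in each hlo: string is now serialized with its keys in alphabetical order (force_earliest_schedule, gemm_backend_config, operation_queue_id, ...); the measured results are untouched, but the cached strings must byte-match what the compiler emits, which presumably forced the regeneration. As a side illustration only (not XLA code), an ordered map gives exactly this kind of canonical, insertion-order-independent serialization:

    // Illustration only: ordered-map serialization yields a canonical string.
    #include <iostream>
    #include <map>
    #include <string>

    std::string Serialize(const std::map<std::string, std::string>& config) {
      std::string out = "{";
      for (const auto& [key, value] : config) {  // iterates in sorted key order
        out += "\"" + key + "\":" + value + ",";
      }
      if (out.size() > 1) out.pop_back();  // drop trailing comma
      return out + "}";
    }

    int main() {
      // Insertion order differs, output is identical:
      std::map<std::string, std::string> a{{"beta", "0"}, {"alpha_real", "1"}};
      std::map<std::string, std::string> b{{"alpha_real", "1"}, {"beta", "0"}};
      std::cout << Serialize(a) << "\n" << Serialize(b) << "\n";
      return 0;
    }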
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc b/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc
deleted file mode 100644
index 0b55f7d..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc
+++ /dev/null
@@ -1,461 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <cstdlib>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/shape_inference.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/window_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
- CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
- conv.custom_call_target() ==
- kCudnnConvBiasActivationForwardCallTarget ||
- conv.custom_call_target() == kCudnnConvForwardGraphCallTarget);
- return window_util::HasSymmetricPadding(conv.window()) &&
- !window_util::HasNegativePadding(conv.window()) &&
- !window_util::HasDilation(conv.window());
-}
-
-// If the (positive and negative) padding on the input operand of a convolution
-// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
-// dilation), returns kPad and/or kSlice instructions that explicitly apply the
-// padding; otherwise returns the original input operand. When there is both
-// positive padding (including dilation) and negative padding, we insert both
-// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
-// into a kPad or kSlice op.
-HloInstruction* MaybePaddedAndSlicedInput(
- Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
- HloInstruction* input) {
- HloComputation* computation = input->parent();
- if (!window_util::HasSymmetricPadding(*conv_window) ||
- window_util::HasBaseDilation(*conv_window)) {
- // If padding is uneven or has dilation, we insert a kPad instruction that
- // applies positive padding and dilation.
- //
- // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
- // moving all the padding into an explicit pad op, we should keep as much
- // padding inside of cudnn as possible, on the assumption that padding
- // within cudnn is basically free, whereas a kPad's cost increases as the
- // amount of padding increases.
- PaddingConfig padding_config =
- MakeNoPaddingConfig(input->shape().dimensions_size());
- for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
- int64_t dim = conv_dnums.input_spatial_dimensions(i);
- if (conv_window->dimensions(i).padding_low() > 0) {
- padding_config.mutable_dimensions(dim)->set_edge_padding_low(
- conv_window->dimensions(i).padding_low());
- conv_window->mutable_dimensions(i)->set_padding_low(0);
- }
- if (conv_window->dimensions(i).padding_high() > 0) {
- padding_config.mutable_dimensions(dim)->set_edge_padding_high(
- conv_window->dimensions(i).padding_high());
- conv_window->mutable_dimensions(i)->set_padding_high(0);
- }
- if (conv_window->dimensions(i).base_dilation() != 1) {
- padding_config.mutable_dimensions(dim)->set_interior_padding(
- conv_window->dimensions(i).base_dilation() - 1);
- conv_window->mutable_dimensions(i)->set_base_dilation(1);
- }
- }
- PrimitiveType element_type = input->shape().element_type();
- HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
- input =
- MakePadHlo(input, padding, padding_config, &input->metadata()).value();
- }
-
- if (window_util::HasNegativePadding(*conv_window)) {
- // If the window has negative padding, insert a kSlice that explicitly
- // applies negative padding.
- //
- // For each dimension, initialize the start index to 0 and the limit index
- // to the size of that dimension.
- std::vector<int64_t> start_indices(input->shape().dimensions_size(), 0);
- std::vector<int64_t> limit_indices(input->shape().dimensions().begin(),
- input->shape().dimensions().end());
- std::vector<int64_t> strides(input->shape().dimensions_size(), 1);
- for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
- int64_t dim = conv_dnums.input_spatial_dimensions(i);
- // If dimension "dim" has negative padding, increase the start index or
- // decrement the limit index by the amount of negative padding.
- if (conv_window->dimensions(i).padding_low() < 0) {
- start_indices[dim] += -conv_window->dimensions(i).padding_low();
- conv_window->mutable_dimensions(i)->set_padding_low(0);
- }
- if (conv_window->dimensions(i).padding_high() < 0) {
- limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
- conv_window->mutable_dimensions(i)->set_padding_high(0);
- }
- }
-
- input = MakeSliceHlo(input, start_indices, limit_indices, strides).value();
- }
-
- return input;
-}
-
-// If the padding on the kernel operand of a convolution can't be folded into a
-// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
-// explicitly applies the padding; otherwise returns the original kernel
-// operand.
-HloInstruction* MaybePaddedKernel(const Window& conv_window,
- const ConvolutionDimensionNumbers& conv_dnums,
- HloInstruction* kernel) {
- if (!window_util::HasWindowDilation(conv_window)) {
- return kernel;
- }
-
- // Compute the shape and padding config of the pad to be inserted.
- PaddingConfig padding_config;
- padding_config.mutable_dimensions()->Reserve(
- kernel->shape().dimensions_size());
- for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
- padding_config.add_dimensions();
- }
- for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
- int64_t dim = conv_dnums.kernel_spatial_dimensions(i);
- padding_config.mutable_dimensions(dim)->set_interior_padding(
- conv_window.dimensions(i).window_dilation() - 1);
- }
-
- HloComputation* computation = kernel->parent();
- PrimitiveType element_type = kernel->shape().element_type();
- HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
- return MakePadHlo(kernel, padding, padding_config, &kernel->metadata())
- .value();
-}
-} // namespace
-
-bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution(
- HloInstruction* conv) {
- if (IsForwardConvolutionCanonical(*conv)) {
- return false;
- }
-
- // Insert slices and/or pads between the convolution and its input and/or
- // kernel operand.
- Window new_conv_window = conv->window();
- HloInstruction* new_input = MaybePaddedAndSlicedInput(
- &new_conv_window, conv->convolution_dimension_numbers(),
- conv->mutable_operand(0));
- HloInstruction* new_kernel =
- MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
- conv->mutable_operand(1));
-
- // Remove the window dilation from convolution's window field. These paddings
- // are made explicit with the pads inserted by MaybePaddedKernel().
- for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
- WindowDimension* dim = new_conv_window.mutable_dimensions(i);
-
- // The size of the kernel may have changed so update the Window to match.
- dim->set_size(new_kernel->shape().dimensions(
- conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
- dim->set_window_dilation(1);
- }
-
- // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
- // out the shape of conv_result.
- VLOG(1) << "Canonicalizing forward conv";
- std::vector<HloInstruction*> operands(conv->operands().begin(),
- conv->operands().end());
- operands[0] = new_input;
- operands[1] = new_kernel;
- auto new_conv = conv->parent()->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), operands));
- new_conv->set_window(new_conv_window);
- VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
- << new_conv->ToString();
- TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
- return true;
-}
-
-namespace {
-void IncreasePaddingLowBy(int64_t delta, WindowDimension* window_dim) {
- window_dim->set_padding_low(window_dim->padding_low() + delta);
-}
-
-void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) {
- window_dim->set_padding_high(window_dim->padding_high() + delta);
-}
-} // namespace
-
-bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
- HloInstruction* backward_conv) {
- CHECK_EQ(backward_conv->custom_call_target(),
- kCudnnConvBackwardFilterCallTarget);
- if (window_util::HasSymmetricPadding(backward_conv->window())) {
- return false;
- }
-
- // A backward filter convolution with uneven padding can be canonicalized to
- // one with even padding by padding the activations (input) beforehand. For
- // example,
- // BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
- // is equivalent to
- // ABCD0 = Pad(ABCD, padding_high=1)
- // BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
- // We choose the lesser of padding_low and padding_high as the new padding.
- HloInstruction* input = backward_conv->mutable_operand(0);
- Window new_backward_conv_window = backward_conv->window();
- // input_padding_config is the config of the kPad to be inserted.
- PaddingConfig input_padding_config =
- MakeNoPaddingConfig(input->shape().rank());
- ConvolutionDimensionNumbers backward_conv_dnums =
- backward_conv->convolution_dimension_numbers();
- for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
- int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
- int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
- if (padding_low < 0 || padding_high < 0) {
- // TODO(b/32744257): The following canonicalization wouldn't remove
- // negative padding in a backward convolution, and would therefore cause
- // cuDNN convolution (which doesn't support negative padding) to fail.
- return false;
- }
- // Compute the new, even padding for the backward conv operation.
- int64_t new_conv_padding = std::min(padding_low, padding_high);
- int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
- input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
- padding_low - new_conv_padding);
- input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
- padding_high - new_conv_padding);
-
- // Since we move some padding from the backward convolution to the kPad, we
- // need to accordingly reduce the padding amount of the backward convolution
- // and its inner forward convolution.
- auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
- new_dim->set_padding_low(new_conv_padding);
- new_dim->set_padding_high(new_conv_padding);
- }
-
- // Create a new backward convolution replacing the old one.
- HloComputation* computation = backward_conv->parent();
- HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(input->shape().element_type())));
- HloInstruction* padded_input =
- MakePadHlo(input, padding, input_padding_config).value();
-
- // The shape of the backward_conv CustomCall is a tuple (conv_result,
- // scratch_buffer). Extract out the shape of conv_result.
- HloInstruction* new_backward_conv =
- computation->AddInstruction(backward_conv->CloneWithNewOperands(
- backward_conv->shape(), {padded_input, output}));
- new_backward_conv->set_window(new_backward_conv_window);
-
- VLOG(1) << "Canonicalizing backward filter conv";
- VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
- << new_backward_conv->ToString();
-
- TF_CHECK_OK(
- computation->ReplaceInstruction(backward_conv, new_backward_conv));
- return true;
-}
-
-bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
- HloInstruction* backward_conv) {
- if (window_util::HasSymmetricPadding(backward_conv->window())) {
- return false;
- }
-
- Window new_backward_conv_window = backward_conv->window();
- ConvolutionDimensionNumbers backward_conv_dnums =
- backward_conv->convolution_dimension_numbers();
-
- // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
- // Get the shape of conv_result.
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
-
- Shape new_backward_conv_shape = backward_conv_shape;
- for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
- int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
- int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
- if (padding_low < 0 || padding_high < 0) {
- // TODO(b/32744257): The following canonicalization wouldn't remove
- // negative padding in a backward convolution, and would therefore cause
- // cuDNN convolution (which doesn't support negative padding) to fail.
- return false;
- }
- // If the backward convolution has uneven padding on the activations, we
- // move some padding on the larger end to "internal" padding, so that the
- // backward convolution produces larger activations which get sliced later.
- //
- // For example, suppose we have a non-canonical HLO
- // [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
- // where the amount of padding low is larger, we can canonicalize it to
- // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
- // [A] = Slice([B A])
- if (padding_low > padding_high) {
- IncreasePaddingLowBy(padding_high - padding_low,
- new_backward_conv_window.mutable_dimensions(i));
- } else if (padding_low < padding_high) {
- IncreasePaddingHighBy(padding_low - padding_high,
- new_backward_conv_window.mutable_dimensions(i));
- }
- // Decreasing the padding by X *increases* the size of our output by X.
- // Note that we have swapped input spatial dimensions with output spatial
- // dimensions to be compatible with the cuDNN API, so
- // input_spatial_dimensions(i) gives the i-th spatial dimension of the
- // output.
- int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
- new_backward_conv_shape.set_dimensions(
- dim, new_backward_conv_shape.dimensions(dim) +
- std::abs(padding_low - padding_high));
- }
-
- // Create a new backward convolution replacing the old one.
- HloComputation* computation = backward_conv->parent();
- HloInstruction* output = backward_conv->mutable_operand(0);
- HloInstruction* filter = backward_conv->mutable_operand(1);
-
- HloInstruction* new_backward_conv_call =
- computation->AddInstruction(backward_conv->CloneWithNewOperands(
- ShapeUtil::MakeTupleShape(
- {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
- {output, filter}));
- new_backward_conv_call->set_window(new_backward_conv_window);
-
- // The CustomCall created above returns a tuple (conv_result, scratch_memory).
- // Extract out the two elements.
- HloInstruction* new_backward_conv =
- computation->AddInstruction(HloInstruction::CreateGetTupleElement(
- new_backward_conv_shape, new_backward_conv_call, 0));
- HloInstruction* new_backward_conv_scratch =
- computation->AddInstruction(HloInstruction::CreateGetTupleElement(
- new_backward_conv_call->shape().tuple_shapes(1),
- new_backward_conv_call, 1));
-
- // Slice the new backward convolution.
- //
- // Initialize start_indices and limit_indices as no slicing.
- std::vector<int64_t> start_indices(
- new_backward_conv->shape().dimensions_size(), 0LL);
- std::vector<int64_t> limit_indices(
- new_backward_conv->shape().dimensions().begin(),
- new_backward_conv->shape().dimensions().end());
- std::vector<int64_t> strides(new_backward_conv->shape().dimensions_size(),
- 1LL);
- for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
- int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
- int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
- // Note that we have swapped input spatial dimensions with output spatial
- // dimensions to be compatible with the cuDNN API, so
- // input_spatial_dimensions(i) gives the i-th spatial dimension of the
- // output.
- int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
- if (padding_low > padding_high) {
- // If the amount of low padding (of the old backward convolution) is
- // larger, we internally pad the low end of the activations and slice
- // internal padding out here.
- start_indices[dim] += padding_low - padding_high;
- } else if (padding_low < padding_high) {
- // If the amount of high padding is larger, we slice out the internal
- // padding on the high end.
- limit_indices[dim] -= padding_high - padding_low;
- }
- }
-
- // Replace the old backward convolution with the slice.
- Shape slice_shape =
- ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
- limit_indices, strides)
- .value();
- CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
- << ShapeUtil::HumanString(slice_shape) << " vs "
- << ShapeUtil::HumanString(backward_conv_shape);
-
- HloInstruction* slice = computation->AddInstruction(
- HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
- start_indices, limit_indices, strides));
- HloInstruction* new_tuple = computation->AddInstruction(
- HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
-
- VLOG(1) << "Canonicalizing backward input conv";
- VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
- << new_tuple->ToString();
-
- TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
- return true;
-}
-
-absl::StatusOr<bool> GpuConvPaddingLegalization::RunOnComputation(
- HloComputation* computation) {
- bool changed = false;
- std::vector<HloCustomCallInstruction*> convs;
- for (auto* instr : computation->instructions()) {
- if (IsCustomCallToDnnConvolution(*instr)) {
- convs.push_back(Cast<HloCustomCallInstruction>(instr));
- }
- }
- for (HloCustomCallInstruction* instruction : convs) {
- TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
- changed |= [&] {
- switch (kind) {
- case CudnnConvKind::kForward:
- case CudnnConvKind::kForwardActivation:
- case CudnnConvKind::kForwardGraph:
- return CanonicalizeForwardConvolution(instruction);
- case CudnnConvKind::kBackwardInput:
- return CanonicalizeBackwardInputConvolution(instruction);
- case CudnnConvKind::kBackwardFilter:
- return CanonicalizeBackwardFilterConvolution(instruction);
- }
- }();
- }
- return changed;
-}
-
-absl::StatusOr<bool> GpuConvPaddingLegalization::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
- changed |= result;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
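The deleted gpu_conv_padding_legalization.cc above canonicalizes convolutions whose padding cuDNN cannot consume directly. For the backward-filter case its comments spell out the recipe: give the convolution the smaller, symmetric padding min(padding_low, padding_high) and move the excess onto an explicit kPad of the activations. A standalone arithmetic sketch of that split, using the numbers from the deleted comment (padding_low=1, padding_high=2):

    // Standalone arithmetic sketch of the padding split performed by
    // CanonicalizeBackwardFilterConvolution; numbers follow the deleted comment:
    //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
    //   == BackwardFilterConv(Pad(ABCD, padding_high=1), xyz, low=high=1)
    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t padding_low = 1, padding_high = 2;
      const int64_t new_conv_padding = std::min(padding_low, padding_high);  // 1
      const int64_t pad_op_low = padding_low - new_conv_padding;             // 0
      const int64_t pad_op_high = padding_high - new_conv_padding;           // 1
      std::cout << "conv padding (both sides): " << new_conv_padding << "\n"
                << "explicit kPad low/high:    " << pad_op_low << "/"
                << pad_op_high << "\n";
      return 0;
    }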
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h b/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h
deleted file mode 100644
index 32e0238..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_
-#define XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
-// inserts Pad instructions before Convolution instructions with uncanonicalized
-// padding, so that they can be lowered to Cudnn/Miopen convolution.
-class GpuConvPaddingLegalization : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "gpu-conv-padding-legalization";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
- // Each returns true if any changes are made to the parent computation.
- bool CanonicalizeForwardConvolution(HloInstruction* conv);
- bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv);
- bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv);
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_
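The deleted header above exposes the pass through the usual HloModulePass surface: a name() string and a Run(module, execution_threads) that reports whether anything changed, with the per-convolution canonicalization helpers kept private. A minimal stand-in sketch of that shape (not the real xla::HloModulePass interface, and with the status handling dropped):

    // Stand-in sketch of the pass shape declared by the deleted header.
    #include <iostream>
    #include <string_view>

    struct ModuleSketch { int num_convs = 0; };  // stand-in for HloModule

    class ConvPaddingLegalizationSketch {
     public:
      std::string_view name() const { return "gpu-conv-padding-legalization"; }

      // Returns true iff the module was changed, mirroring StatusOr<bool> Run().
      bool Run(ModuleSketch* module) {
        bool changed = false;
        for (int i = 0; i < module->num_convs; ++i) {
          changed |= CanonicalizeForwardConvolution(i);
        }
        return changed;
      }

     private:
      bool CanonicalizeForwardConvolution(int /*conv*/) { return true; }
    };

    int main() {
      ModuleSketch m{/*num_convs=*/2};
      ConvPaddingLegalizationSketch pass;
      std::cout << pass.name() << " changed: " << pass.Run(&m) << "\n";
      return 0;
    }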
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc b/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc
deleted file mode 100644
index edaf9d0..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc
+++ /dev/null
@@ -1,96 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/test.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using GpuConvPaddingLegalizationTest = HloTestBase;
-
-TEST_F(GpuConvPaddingLegalizationTest, BackwardInputConvolve) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule convolution_module
-ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) {
- %operand = f64[2,2,2,3]{3,2,1,0} parameter(0)
- %kernel = f64[2,3,2,3]{3,2,1,0} constant(
- {
- { /*i0=0*/
- { /*i1=0*/
- { 0.29629629629629628, 0.30246913580246915, 0.30864197530864196 },
- { 0.31481481481481483, 0.32098765432098764, 0.3271604938271605 }
- },
- { /*i1=1*/
- { 0.25925925925925924, 0.26543209876543211, 0.27160493827160492 },
- { 0.27777777777777779, 0.2839506172839506, 0.29012345679012347 }
- },
- { /*i1=2*/
- { 0.22222222222222221, 0.22839506172839505, 0.23456790123456789 },
- { 0.24074074074074073, 0.24691358024691357, 0.25308641975308643 }
- }
- },
- { /*i0=1*/
- { /*i1=0*/
- { 0.18518518518518517, 0.19135802469135801, 0.19753086419753085 },
- { 0.20370370370370369, 0.20987654320987653, 0.21604938271604937 }
- },
- { /*i1=1*/
- { 0.14814814814814814, 0.15432098765432098, 0.16049382716049382 },
- { 0.16666666666666666, 0.1728395061728395, 0.17901234567901234 }
- },
- { /*i1=2*/
- { 0.1111111111111111, 0.11728395061728394, 0.12345679012345678 },
- { 0.12962962962962962, 0.13580246913580246, 0.1419753086419753 }
- }
- }
- })
- %reverse = f64[2,3,2,3]{3,2,1,0} reverse(%kernel), dimensions={0,1}
- ROOT %custom-call = (f64[2,2,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f64[2,2,2,3]{3,2,1,0} %operand, f64[2,3,2,3]{3,2,1,0} %reverse), window={size=2x3 stride=2x2 pad=0_0x0_1}, dim_labels=bf01_01io->b01f, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
-}
- )")
- .value();
- ASSERT_TRUE(GpuConvPaddingLegalization().Run(module.get()).value());
- auto root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::Tuple(
- m::Slice(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardInputCallTarget},
- m::Op(), m::Reverse(m::Constant())),
- 0)),
- m::GetTupleElement())));
- auto slice = root->operand(0);
- Shape expected_slice_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 4});
- EXPECT_TRUE(ShapeUtil::Equal(slice->shape(), expected_slice_shape));
- auto conv = slice->operand(0);
- Shape expected_conv_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 5});
- EXPECT_TRUE(ShapeUtil::Equal(conv->shape(), expected_conv_shape));
-}
-
-} // anonymous namespace
-} // namespace gpu
-} // namespace xla
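The deleted test above pins down the size bookkeeping for the backward-input case: with window pad=0_0x0_1 the padding is uneven by one, so the legalized custom call is widened by |padding_low - padding_high| = 1 in that dimension (f64[2,2,4,5]) and a Slice restores the expected f64[2,2,4,4]. A tiny sketch of that arithmetic on plain integers:

    // Sketch of the dimension adjustment checked by the deleted test:
    // uneven padding widens the internal conv result, a slice restores it.
    #include <cstdint>
    #include <cstdlib>
    #include <iostream>

    int main() {
      const int64_t dim_size = 4;  // expected output size in this dimension
      const int64_t padding_low = 0, padding_high = 1;  // window pad=0_0x0_1
      const int64_t widened =
          dim_size + std::abs(padding_low - padding_high);  // 5
      const int64_t start =
          padding_low > padding_high ? padding_low - padding_high : 0;  // 0
      const int64_t limit =
          widened -
          (padding_high > padding_low ? padding_high - padding_low : 0);  // 4
      std::cout << "conv dim: " << widened << ", slice [" << start << ", "
                << limit << ") -> " << (limit - start) << "\n";
      return 0;
    }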
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc
deleted file mode 100644
index cb5b186..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc
+++ /dev/null
@@ -1,869 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-
-#include <cstdint>
-#include <cstdlib>
-#include <memory>
-#include <numeric>
-#include <optional>
-#include <string>
-#include <string_view>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/window_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-absl::Status CheckTypes(HloInstruction* conv,
- const se::GpuComputeCapability cc) {
- auto valid_shape = [conv, &cc](const Shape& shape) -> absl::Status {
- PrimitiveType type = shape.element_type();
- if (!primitive_util::IsFloatingPointType(type) &&
- !primitive_util::IsIntegralType(type)) {
- // Among integral types, only S8 is supported. But CudnnFusedConvRewriter
- // may rewrite convolutions of wider types into S8 convolutions, so allow
- // all integral convolutions here.
- return Unimplemented(
- "Convolutions must have floating-point or integral operands/outputs, "
- "but got convolution with type %s: %s",
- primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
- }
- if (primitive_util::IsF8Type(type)) {
- if (type != F8E4M3FN && type != F8E5M2) {
- return Unimplemented(
- "The only FP8 types supported in convolutions are f8e5m2 and "
- "f8e4m3, "
- "but got convolution with FP8 type %s: %s",
- primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
- }
- if (!std::holds_alternative<se::CudaComputeCapability>(cc)) {
- return Unimplemented(
- "FP8 convolutions are only supported on CUDA GPUs, but got "
- "FP8 convolution on ROCm GPU: %s",
- conv->ToString());
- } else if (!std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper()) {
- return Unimplemented(
- "FP8 convolutions are only supported on CUDA GPUs with compute "
- "capability at least 9.0, but got "
- "FP8 convolution on GPU with compute capability %s: %s",
- std::get<se::CudaComputeCapability>(cc).ToString(),
- conv->ToString());
- }
- }
- return absl::OkStatus();
- };
-
- TF_RETURN_IF_ERROR(valid_shape(conv->shape()));
- TF_RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape()));
- TF_RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape()));
- return absl::OkStatus();
-}
-
-using ConvolutionMatch = std::optional<
- std::tuple<Window, ConvolutionDimensionNumbers, HloInstruction*>>;
-
-// Returns true if this 2D convolution actually represents a 1D convolution
-// whose filter was reshaped to 2D by inserting a degenerate spatial dimension.
-bool MaybeConv1dToConv2d(HloInstruction* conv) {
- if (conv->window().dimensions().size() != 2) {
- return false;
- }
- if (conv->operand(1)->opcode() != HloOpcode::kReshape) {
- return false;
- }
- auto filter = conv->operand(1);
- std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
- filter->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
- if (reshape_degenerate.has_value() &&
- reshape_degenerate->deleted_dimensions.empty() &&
- reshape_degenerate->inserted_dimensions.size() == 1) {
- const auto& dnums = conv->convolution_dimension_numbers();
- for (auto dim : dnums.kernel_spatial_dimensions()) {
- if (dim == reshape_degenerate->inserted_dimensions[0]) {
- return true;
- }
- }
- }
- return false;
-}
-
-bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
- const ConvolutionDimensionNumbers& dnums =
- conv->convolution_dimension_numbers();
- if (dnums.input_spatial_dimensions_size() > 3) {
- return false;
- }
-
- // CuDNN does not accept zero-element arguments
- if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
- ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
- return false;
- }
-
- // CuDNN can perform either cross correlation (no reversal),
- // or convolution (all dimensions reversed).
- if (dnums.input_spatial_dimensions_size() == 2
- ? !window_util::AllOrNoneReversed(conv->window())
- : window_util::HasWindowReversal(conv->window())) {
- return false;
- }
- return true;
-}
-
-// Try to match a backward filter pattern that contains "conv".
-// Precondition: "conv" is a kConvolution.
-ConvolutionMatch MatchBackwardFilter(HloInstruction* conv) {
- VLOG(2) << "Trying to match convolution backward filter.";
-
- if (conv->feature_group_count() > 1) {
- VLOG(1) << conv->ToString()
- << " is a forward convolution. All grouped backward filters are "
- "mapped to batch grouped convolutions in tf2xla bridge. Hence "
- "backward filter "
- "convolutions cannot have feature groups greater than 1 at this "
- "point. No need to fold to backward filter.";
- return std::nullopt;
- }
-
- // Step 1: match the instruction pattern without considering the paddings and
- // dimension numbers just yet. We may need some generic pattern matcher
- // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
- //
- // Backward filter convolution is implemented in XLA as the forward
- // convolution of padded activations and dilated gradients. Padding on
- // activations and dilation on gradients are specified in the "window" field
- // of the forward convolution.
- //
- // activations gradients
- // \ /
- // v v
- // Convolution
- // conv
- CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
-
- // Step 2: match paddings and dimension numbers of the forward convolution.
- const ConvolutionDimensionNumbers& conv_dnums =
- conv->convolution_dimension_numbers();
- auto input_batch_dim = conv_dnums.input_batch_dimension();
- auto input_feature_dim = conv_dnums.input_feature_dimension();
- auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
- auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
- auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
- auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
- auto output_batch_dim = conv_dnums.output_batch_dimension();
- auto output_feature_dim = conv_dnums.output_feature_dimension();
- auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
- for (const WindowDimension& window_dim : conv->window().dimensions()) {
- if (window_dim.stride() != 1) {
- VLOG(1) << "Forward convolution's window "
- << conv->window().ShortDebugString()
- << " should have stride of 1.";
- return std::nullopt;
- }
- if (window_dim.base_dilation() != 1) {
- VLOG(1) << "Forward convolution's window "
- << conv->window().ShortDebugString()
- << " should have no base (LHS) dilation.";
- return std::nullopt;
- }
- if (window_dim.padding_low() < 0) {
- VLOG(1) << "Padding low should be non-negative.";
- return std::nullopt;
- }
- if (window_dim.window_reversal()) {
- VLOG(1) << "Window reversal field not supported";
- return std::nullopt;
- }
- // Padding high will be checked in Step 3.
- }
- // Mathematically, there is no difference between convolution forward vs
- // backward filter. A backward filter:
- // [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
- // Can be treated as a forward convolution with `N` treated as the new
- // contracting (feature) dimension, `O` treated as the new batch dimension,
- // and `C` treated as the new output feature dimension. The only difference is
- // layouts and performance.
- //
- // Since there is no way to tell precisely whether we want a forward conv or
- // a backward filter conv, we have to rely on heuristics. Empirically, forward
- // convolutions have very small kernel spatial dimensions, while in the
- // backward pass the "kernel dimensions" are large. If the kernel dimensions
- // are smaller than the output dimensions, treat it as a forward conv;
- // otherwise proceed as a backward filter conv. Conv1d is the exception:
- // because a 1D filter is always reshaped to a 2D filter, both forward and
- // backward convs end up with one small kernel spatial dimension, so that
- // special case is handled separately below.
- int small_kernel_dimension_num = 0;
- for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
- if (conv->operand(1)->shape().dimensions(kernel_spatial_dims[i]) <=
- conv->shape().dimensions(output_spatial_dims[i])) {
- small_kernel_dimension_num += 1;
- }
- }
- if ((kernel_spatial_dims.empty() || small_kernel_dimension_num > 1 ||
- (!MaybeConv1dToConv2d(conv) && small_kernel_dimension_num == 1)) &&
- !window_util::HasWindowDilation(conv->window())) {
- VLOG(1) << conv->ToString()
- << " is a regular forward convolution. No need "
- "to fold it to a backward filter convolution....";
- return std::nullopt;
- }
-
- // Step 3: fuse the matched HLOs into a backward convolution instruction.
- //
- // Compute the window of the backward convolution.
- Window backward_conv_window;
- for (int i = 0; i < input_spatial_dims.size(); ++i) {
- WindowDimension* dim = backward_conv_window.add_dimensions();
- // The window size of the backward convolution equals the output size of the
- // forward convolution.
- int64_t filter_size = conv->shape().dimensions(output_spatial_dims[i]);
- dim->set_size(filter_size);
- // The window stride equals the window dilation of the forward convolution.
- dim->set_stride(conv->window().dimensions(i).window_dilation());
- // The window's low padding is the same as the low padding of the
- // activations.
- dim->set_padding_low(conv->window().dimensions(i).padding_low());
- dim->set_base_dilation(1);
- dim->set_window_dilation(1);
-
- int64_t input_size =
- conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
- int64_t output_size = conv->window().dimensions(i).size();
- // Compute the range of the amount of valid high padding. We first compute
- // min_padding_high, the amount of padding on the right/bottom to ensure the
- // last patch ends at the border, i.e.,
- //
- // input_size + dim->padding_low() + min_padding_high
- // = (output_size - 1) * stride + filter_size
- //
- // Because convolution ignores trailing incomplete windows, any amount of
- // padding high from min_padding_high to min_padding_high+stride-1
- // (max_padding_high) has the same effect.
- int64_t padded_input_size = filter_size + (output_size - 1) * dim->stride();
- int64_t min_padding_high =
- padded_input_size - input_size - dim->padding_low();
- int64_t max_padding_high = min_padding_high + dim->stride() - 1;
- CHECK_GE(dim->padding_low(), 0);
- // In practice, since cuDNN convolution only supports even padding, we make
- // the amount of high padding the same as the amount of low padding as long
- // as it is between min_padding_high and max_padding_high. If it is not in
- // that range, we pick the one that's closest to dim->padding_low() and let
- // GpuConvPaddingLegalization canonicalize the resultant backward
- // convolution later. Picking the closest one minimizes the cost of the kPad
- // instruction to be inserted by GpuConvPaddingLegalization.
- if (dim->padding_low() >= min_padding_high &&
- dim->padding_low() <= max_padding_high) {
- dim->set_padding_high(dim->padding_low());
- } else {
- if (dim->padding_low() < min_padding_high) {
- dim->set_padding_high(min_padding_high);
- } else {
- dim->set_padding_high(max_padding_high);
- }
- }
- if (dim->padding_high() < 0) {
- LOG(WARNING)
- << "Fusing this pattern to backward filter convolution would cause "
- "negative padding ("
- << dim->padding_high()
- << ") on right/bottom of the weight gradients, which is not "
- "supported by GpuConvPaddingLegalization (b/32744257). "
- "Falling back to "
- "unfused convolution for instruction: "
- << conv->ToString();
- return std::nullopt;
- }
- }
-
- // Restore the dimension numbers of the backward convolution from the forward
- // convolution. The two activation dimensions are reversed (batch and
- // feature).
- ConvolutionDimensionNumbers backward_conv_dnums;
- backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
- backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
- for (int i = 0; i < input_spatial_dims.size(); ++i) {
- backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
- }
- backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
- backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
- for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
- backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
- }
- // The dimension numbering of the output of the forward convolution (before
- // transposition) is the same as that of the activations (according to the
- // semantics of kConvolution). The batch dimension of the activations should
- // be treated as the input feature dimension, and the feature dimension should
- // be treated as the output feature.
- backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
- backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
- for (int i = 0; i < output_spatial_dims.size(); ++i) {
- backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
- }
-
- HloInstruction* lhs = conv->mutable_operand(0);
- return std::make_tuple(backward_conv_window, backward_conv_dnums, lhs);
-}
-
-// Try to match a backward input pattern that contains "conv".
-// Precondition: "conv" is a kConvolution.
-ConvolutionMatch MatchBackwardInput(HloInstruction* conv) {
- VLOG(2) << "Trying to match convolution backward input.";
-
- // TODO(timshen) Theoretically cuDNN also supports grouped convolutions for
- // the backward input convolution, but in cuDNN's current state there is not
- // much performance improvement when using the cuDNN backward input API for
- // grouped convs. This needs to be re-evaluated for future cuDNN versions.
- // Note that we already have the necessary code down below; the only thing
- // needed to enable it is to remove the following early return.
- if (conv->feature_group_count() > 1) {
- return std::nullopt;
- }
-
- // Match instruction pattern.
- CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
- HloInstruction* reverse_filter = conv->mutable_operand(1);
- ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
-
- // Match BackwardInput for a depthwise convolution and fall back to a forward
- // convolution. The output feature dimension and the input feature dimension
- // have been swapped in the bridge, so to get the actual input feature count
- // we need to query the output feature dimension.
- auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
- auto kernel_out_features =
- reverse_filter->shape().dimensions(kernel_out_feature_dim);
-
- // For a depthwise convolution, the input feature count must equal the
- // feature_group_count. We can leverage this property to match a depthwise
- // convolution and fall back to a forward conv.
- if (conv->feature_group_count() > 1 &&
- kernel_out_features == conv->feature_group_count()) {
- return std::nullopt;
- }
-
- // We pattern-match to a backwards input conv if:
- //
- // - all spatial dims of the filter are reversed
- //
- // OR
- //
- // - filter is 1x1 or a constant AND
- // - conv has base dilation (otherwise this is just a regular forward conv).
- //
- // The final criterion above is just for canonicalization; cudnn seems to run
- // just as fast if we canonicalize 1x1/constant filters without base dilation
- // to forward or backward convs. We canonicalize to forward conv because (a)
- // it's more natural (constant filters usually show up when doing inference,
- // and having backwards convolutions in inference graphs would be weird), and
- // (b) cudnn has special fusions for forward conv plus bias and activation,
- // and we want to pattern-match to that after running this pass.
- bool is_reversed_filter =
- reverse_filter->opcode() == HloOpcode::kReverse &&
- absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
- reverse_filter->dimensions());
- // For a conv1d that has been reshaped to a conv2d, the filter-reverse pattern
- // is reshape(reverse(filter)). It seems we could reuse the conv2d backward
- // input pattern matcher, but after the algsimp pass this pattern changes to
- // reverse(reshape(filter)) and fails to match. So matching conv1d backward
- // input requires different processing logic.
- bool is_reversed_conv1d_filter =
- MaybeConv1dToConv2d(conv) &&
- reverse_filter->operand(0)->opcode() == HloOpcode::kReverse;
- bool is_1x1_filter =
- absl::c_all_of(conv->window().dimensions(),
- [](const WindowDimension& d) { return d.size() == 1; });
- if (!is_reversed_filter && !is_reversed_conv1d_filter &&
- !(window_util::HasBaseDilation(conv->window()) &&
- (reverse_filter->IsConstant() || is_1x1_filter))) {
- VLOG(1) << "Can't match to backwards convolution. Either filter is not "
- "kReverse, or it's not a base-dilated conv with a 1x1 or "
- "constant filter.";
- return std::nullopt;
- }
-
- // Match padding and dilation of the forward convolution.
- for (const WindowDimension& window_dim : conv->window().dimensions()) {
- if (window_dim.stride() != 1) {
- VLOG(1) << "Forward convolution's window "
- << conv->window().ShortDebugString()
- << " should have stride of 1.";
- return std::nullopt;
- }
- if (window_dim.window_dilation() != 1) {
- VLOG(1) << "Forward convolution's window "
- << conv->window().ShortDebugString()
- << " should have no window dilation.";
- return std::nullopt;
- }
- if (window_dim.window_reversal()) {
- VLOG(1) << "Window reversal field not supported";
- return std::nullopt;
- }
- }
-
- const auto& input_spatial_dims = dnums.input_spatial_dimensions();
- const auto& output_spatial_dims = dnums.output_spatial_dimensions();
- CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
- CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());
-
- const Window& old_window = conv->window();
- Window new_window = old_window;
- for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
- // Restore backward convolution's padding config from the matched pattern.
- // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
- // convert backward input convolution to a variant of forward convolution.
- //
- // The stride of the backward convolution
- // = the base dilation factor of the forward convolution
- auto dim = new_window.mutable_dimensions(i);
- dim->set_stride(old_window.dimensions(i).base_dilation());
- dim->set_base_dilation(1);
-
- // The low padding = kernel_size - 1 - low padding on the gradients
- // Make sure the low padding is not negative.
- auto kernel_size = old_window.dimensions(i).size();
- auto backward_padding_low =
- kernel_size - 1 - old_window.dimensions(i).padding_low();
- if (backward_padding_low < 0) {
- LOG(WARNING)
- << "The low padding of the backward convolution would be negative ("
- << backward_padding_low
- << "), which isn't supported by GpuConvPaddingLegalization "
- "for now (b/32744257).";
- return std::nullopt;
- }
- dim->set_padding_low(backward_padding_low);
-
- // Compute the range of the amount of padding on the right/bottom of the
- // activations. XLA's convolution requires all patches to be within the
- // padded base. This gives us flexibility to choose the amount of high
- // padding from a set of values without changing the result of the backward
- // convolution. The minimum amount (min_padding_high) makes the last patch
- // end at the border. The maximum amount (max_padding_high) equals
- // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
- // size to change.
- auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
- auto output_size =
- conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
- auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
- auto total_pad_size = padded_input_size - unpadded_input_size;
- auto min_padding_high = total_pad_size - backward_padding_low;
- auto max_padding_high = min_padding_high + dim->stride() - 1;
-
- if (backward_padding_low >= min_padding_high &&
- backward_padding_low <= max_padding_high) {
- // In the best case (most likely), if backward_padding_low is in the range
- // of the amounts of valid high padding, we choose backward_padding_low
- // because cuDNN supports even padding only.
- dim->set_padding_high(backward_padding_low);
- } else {
- // Otherwise, we choose the amount that's closest to backward_padding_low,
- // and GpuConvPaddingLegalization will later insert kSlice
- // instructions to enforce even padding.
- //
- // For example, consider the backward convolution pattern
- //
- // ab xy
- // | pad | reverse
- // .a.b yx
- // \ /
- // ABC
- //
- // The amount of low padding on activations (in backward convolution) is
- // backward_padding_low = kernel_size - 1 - forward_padding_low
- // = 2 - 1 - 1 = 0
- //
- // The amount of padding high must be between 1 and 2, in order to make
- // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
- // the range of [1,2], so we pick the closest valid amount of padding
- // high, which is 1 in this case. Therefore, we fuse the above pattern to
- //
- // ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
- if (backward_padding_low < min_padding_high) {
- dim->set_padding_high(min_padding_high);
- } else {
- dim->set_padding_high(max_padding_high);
- }
- }
- // GpuConvPaddingLegalization doesn't handle backward input
- // convolution with negative padding for now. So fall back to unfused
- // convolution in case of negative padding. For example,
- // ABCD = Conv(abc, reverse(xy), padding_high=2)
- // could be fused to
- // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
- // with positive padding low but negative padding high.
- if (dim->padding_high() < 0) {
- LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
- "negative padding ("
- << dim->padding_high()
- << ") on right/bottom of the activations, which is not "
- "supported by GpuConvPaddingLegalization (b/32744257). "
- "Falling back to unfused convolution for instruction: "
- << conv->ToString();
- return std::nullopt;
- }
- }
-
- // OK, it's a match! Switch the input feature dimension with the output
- // feature dimension. Also switch the output with the input. This is the way
- // cuDNN expects it to be.
- auto conv_dnums = conv->convolution_dimension_numbers();
- dnums.set_kernel_input_feature_dimension(
- conv_dnums.kernel_output_feature_dimension());
- dnums.set_kernel_output_feature_dimension(
- conv_dnums.kernel_input_feature_dimension());
- for (int i = 0; i < input_spatial_dims.size(); ++i) {
- dnums.set_input_spatial_dimensions(i,
- conv_dnums.output_spatial_dimensions(i));
- dnums.set_output_spatial_dimensions(i,
- conv_dnums.input_spatial_dimensions(i));
- }
- dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
- dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
- dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
- dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());
-
- // If we matched against a constant, we need to add a reverse op that can be
- // subsumed by the cuDNN call. algebraic-simplifier will later remove any
- // unnecessary reverses.
- if (reverse_filter->opcode() != HloOpcode::kReverse &&
- reverse_filter->IsConstant()) {
- // Create a double-reverse, which is a nop.
- HloComputation* c = conv->parent();
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- dnums.kernel_spatial_dimensions()));
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- dnums.kernel_spatial_dimensions()));
- TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
- }
-
- // Calculate the 'rhs' that goes into the backward input convolution.
- HloInstruction* rhs = reverse_filter;
- // One reverse is subsumed by the cuDNN call.
- if (rhs->opcode() == HloOpcode::kReverse) {
- rhs = rhs->mutable_operand(0);
- } else if (is_reversed_conv1d_filter) {
- auto src = rhs->mutable_operand(0)->mutable_operand(0);
- rhs = conv->parent()->AddInstruction(
- HloInstruction::CreateReshape(rhs->shape(), src));
- }
- if (conv->feature_group_count() == 1) {
- return std::make_tuple(new_window, dnums, rhs);
- }
-
- // Handle grouped convolutions. Because we swapped the input feature dimension
- // with the output feature dimension, we need to also reshape the kernel so
- // that the 'feature_group_count' parameter still makes sense. The
- // 'feature_group_count' parameter essentially specifies how often the
- // 'kernel_input_feature_dimension' is repeated. So when we swap these
- // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
- // 'feature_group_count' and multiply the new
- // 'kernel_output_feature_dimension' by 'feature_group_count'.
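- //
- // As an illustrative example (hypothetical sizes, HWIO-style layout as in
- // the shape comments below): with feature_group_count G = 2 and a kernel of
- // shape [H, W, in_depth = 6, out_depth / G = 5], the steps below reshape it
- // to [H, W, 2, 3, 5], transpose to [H, W, 3, 2, 5], and reshape again to
- // [H, W, 3, 10], i.e. [H, W, in_depth / G, out_depth].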
- int64_t input_feature_dimension = dnums.kernel_input_feature_dimension();
- int64_t output_feature_dimension = dnums.kernel_output_feature_dimension();
- // The following code assumes that input_feature_dimension and
- // output_feature_dimension are adjacent.
- if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
- return std::nullopt;
- }
-
- int64_t input_features = rhs->shape().dimensions(input_feature_dimension);
- int64_t output_features = rhs->shape().dimensions(output_feature_dimension);
-
- // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
- // out_depth / G]
- std::vector<int64_t> reshape_dims = SpanToVector(rhs->shape().dimensions());
- auto num_groups = conv->feature_group_count();
- CHECK_EQ(input_features % num_groups, 0)
- << "Input feature count should be an exact multiple of feature group "
- "count";
- reshape_dims[input_feature_dimension] =
- reshape_dims[input_feature_dimension] / num_groups;
- reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
- num_groups);
-
- HloComputation* c = conv->parent();
- rhs = c->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));
-
- // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
- // in_depth/G, G, out_depth / G]
- std::vector<int64_t> transpose_dims(rhs->shape().dimensions_size());
- std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
- transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
- transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
- input_feature_dimension);
- std::vector<int64_t> transpose_reshape_dims =
- SpanToVector(rhs->shape().dimensions());
- transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
- input_feature_dimension);
- transpose_reshape_dims.insert(
- transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
- rhs = c->AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
- rhs, transpose_dims));
-
- // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
- // in_depth/G, out_depth]
- Shape new_shape = rhs->shape();
- new_shape.DeleteDimension(output_feature_dimension);
- new_shape.set_dimensions(output_feature_dimension,
- output_features * num_groups);
- rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
- return std::make_tuple(new_window, dnums, rhs);
-}
-
-HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
- HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64_t feature_group_count,
- const PrecisionConfig& precision_config,
- const OpMetadata& metadata) {
- HloComputation* computation = lhs->parent();
-
- // This call returns a tuple of (conv_result, scratch_memory), where
- // conv_result is the actual result of the convolution, and scratch_memory is
- // temporary memory used by cudnn.
- //
- // At the moment, we don't know how much scratch memory this conv is going to
- // use, so we put u8[0] in this place. Later on another pass will choose
- // which conv algorithm to use, and at that point we'll modify the shape of
- // this second tuple element.
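- //
- // For example (a purely illustrative shape), a conv whose result shape is
- // f32[8,32,16,16] gets call_shape (f32[8,32,16,16], u8[0]) here.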
- Shape call_shape =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
-
- HloInstruction* custom_call = computation->AddInstruction(
- HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
- custom_call->set_window(window);
- custom_call->set_convolution_dimension_numbers(dnums);
- custom_call->set_feature_group_count(feature_group_count);
- *custom_call->mutable_precision_config() = precision_config;
- custom_call->set_metadata(metadata);
-
- // Give the customcall a user-friendly name.
- std::optional<std::string> name;
- if (call_target == kCudnnConvForwardCallTarget) {
- name = "cudnn-conv";
- } else if (call_target == kCudnnConvBackwardInputCallTarget) {
- name = "cudnn-conv-bw-input";
- } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
- name = "cudnn-conv-bw-filter";
- } else if (call_target == kCudnnConvBiasActivationForwardCallTarget) {
- name = "cudnn-conv-bias-activation";
- }
- if (name.has_value()) {
- computation->parent()->SetAndUniquifyInstrName(custom_call, *name);
- }
-
- return custom_call;
-}
-
-HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
- HloInstruction* conv) {
- CHECK_EQ(conv->feature_group_count(), 1);
- int64_t num_groups = conv->batch_group_count();
- auto dim_numbers = conv->convolution_dimension_numbers();
- auto lhs = conv->mutable_operand(0);
- auto rhs = conv->mutable_operand(1);
-
- int64_t input_batch_dimension = dim_numbers.input_batch_dimension();
-
- Shape output_shape = conv->shape();
- int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
- int64_t input_feature = lhs->shape().dimensions(input_feature_dimension);
-
- HloComputation* computation = lhs->parent();
- auto add = [&](std::unique_ptr<HloInstruction> inst) {
- return computation->AddInstruction(std::move(inst));
- };
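- //
- // As an illustrative example (hypothetical NHWC shapes): with
- // batch_group_count G = 2 and an lhs of shape [N = 8, H, W, C = 3], the
- // steps below reshape the batch dim to get [2, 4, H, W, 3], transpose the
- // group axis next to the features to get [4, H, W, 2, 3], and merge [G, C]
- // to get [4, H, W, 6]; the cloned conv then uses feature_group_count = 2.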
- // Reshape batch_dim N -> [G, N/G]
- std::vector<int64_t> reshape_dims = SpanToVector(lhs->shape().dimensions());
- reshape_dims[input_batch_dimension] =
- reshape_dims[input_batch_dimension] / num_groups;
- reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
- lhs = add(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));
-
- // Transpose G to the axis before C, e.g. [G, N/G, H, W, C] -> [N/G, H,
- // W, G, C].
- std::vector<int64_t> transpose_dims(lhs->shape().dimensions_size());
- std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
- transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
- transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
- input_batch_dimension);
- std::vector<int64_t> transpose_reshape_dims =
- ComposePermutations(lhs->shape().dimensions(), transpose_dims);
- lhs = add(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
- lhs, transpose_dims));
-
- // Merge [G,C] -> [C*G]
- Shape new_shape = lhs->shape();
- new_shape.DeleteDimension(input_feature_dimension);
- new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
- lhs = add(HloInstruction::CreateReshape(new_shape, lhs));
-
- std::vector<HloInstruction*> new_operands = {lhs, rhs};
- auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
- new_conv->set_feature_group_count(num_groups);
- new_conv->set_batch_group_count(1);
- new_conv->set_convolution_dimension_numbers(dim_numbers);
- return computation->AddInstruction(std::move(new_conv));
-}
-
-CudnnConvBackendConfig GetDefaultBackendConfig() {
- CudnnConvBackendConfig config;
- config.set_conv_result_scale(1);
- return config;
-}
-
-// Helper function to create a custom_call instruction to replace the given
-// conv instruction
-static absl::StatusOr<HloInstruction*> CreateCustomCallHelper(
- HloInstruction* conv, const se::GpuComputeCapability& cc) {
- TF_RETURN_IF_ERROR(CheckTypes(conv, cc));
- if (ConvolutionMatch m = MatchBackwardInput(conv)) {
- auto& [window, dnums, rhs] = *m;
- return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
- conv->mutable_operand(0), rhs, window, dnums,
- conv->feature_group_count(), conv->precision_config(),
- conv->metadata());
- }
-
- if (ConvolutionMatch m = MatchBackwardFilter(conv)) {
- auto& [window, dnums, lhs] = *m;
- return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
- conv->mutable_operand(1), window, dnums,
- conv->batch_group_count(), conv->precision_config(),
- conv->metadata());
- }
-
- // If all else fails, try a forward convolution.
- if (CanImplementAsGpuForwardConv(conv)) {
- if (conv->batch_group_count() > 1) {
- conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
- }
-
- return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
- conv->mutable_operand(0), conv->mutable_operand(1),
- conv->window(), conv->convolution_dimension_numbers(),
- conv->feature_group_count(), conv->precision_config(),
- conv->metadata());
- }
-
- return nullptr;
-}
-
-// Tries to rewrite a single convolution into a call to cudnn/miopen.
-absl::StatusOr<bool> RunOnInstruction(HloInstruction* conv,
- const se::GpuComputeCapability& cc) {
- CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
-
- TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
- CreateCustomCallHelper(conv, cc));
- if (custom_call == nullptr) {
- return false;
- }
-
- GpuBackendConfig gpu_backend_config;
- *gpu_backend_config.mutable_cudnn_conv_backend_config() =
- GetDefaultBackendConfig();
- TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
-
- VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
- << custom_call->ToString();
-
- // The CustomCall returns a tuple (conv_result, scratch_memory). Extract
- // out the conv result and replace `conv` with it.
- TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
- conv,
- HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
- return true;
-}
-
-// Rewrites the convolutions in the given computation into calls to
-// cudnn/miopen.
-// Returns true if it made any changes.
-absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
- const se::GpuComputeCapability& cc) {
- std::vector<HloInstruction*> convs;
- for (auto* hlo : computation->instructions()) {
- if (hlo->opcode() == HloOpcode::kConvolution) {
- convs.push_back(hlo);
- }
- }
-
- bool changed = false;
- for (HloInstruction* conv : convs) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc));
- changed |= result;
- }
- return changed;
-}
-} // namespace
-
-absl::StatusOr<bool> GpuConvRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result,
- RunOnComputation(computation, compute_capability_));
- changed |= result;
- }
- XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
- return changed;
-}
-
-/*static*/ bool GpuConvRewriter::ConvIsLowerable(HloInstruction* conv) {
- return CanImplementAsGpuForwardConv(conv) || MatchBackwardFilter(conv) ||
- MatchBackwardInput(conv);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h
deleted file mode 100644
index 74b860f..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_
-#define XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites plain convolutions, backwards-filter convolutions, and
-// backwards-input convolutions into CustomCall HLOs that call into
-// Cudnn/Miopen.
-//
-// This pass does not fuse other ops into the convolution. Instead, specific
-// patterns of ops will be matched and fused into the custom call in
-// CudnnFusedConvRewriter.
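-//
-// A minimal usage sketch (assuming an existing HloModule* `module` and a
-// se::GpuComputeCapability `cc`; error handling elided):
-//
-//   GpuConvRewriter rewriter(cc);
-//   TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module));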
-
-class GpuConvRewriter : public HloModulePass {
- public:
- explicit GpuConvRewriter(const se::GpuComputeCapability& compute_capability)
- : compute_capability_(compute_capability) {};
-
- absl::string_view name() const override { return "gpu-conv-rewriter"; }
-
- static bool ConvIsLowerable(HloInstruction* conv);
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::GpuComputeCapability compute_capability_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc
deleted file mode 100644
index f83bae8..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc
+++ /dev/null
@@ -1,812 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-
-#include <optional>
-#include <string>
-
-#include "absl/log/check.h"
-#include "absl/strings/str_format.h"
-#include "xla/array4d.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/literal_util.h"
-#include "xla/protobuf_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/shape_inference.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/test.h"
-#include "xla/test_helpers.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuConvRewriterTest : public HloTestBase {
- public:
- GpuConvRewriterTest()
- : HloTestBase(/*verifier_layout_sensitive=*/true,
- /*allow_mixed_precision_in_hlo_verifier=*/false) {
- for (int i = 0; i < 2; ++i) {
- WindowDimension* window_dim = default_conv_window_.add_dimensions();
- window_dim->set_size(1);
- window_dim->set_stride(1);
- window_dim->set_padding_low(0);
- window_dim->set_padding_high(0);
- window_dim->set_window_dilation(1);
- window_dim->set_base_dilation(1);
- }
- // TF data shapes are by default in the NHWC order, and filter shape is by
- // default in HWIO order. For backward filter convolution, we need to swap
- // the batch and feature dimension in the activations, and treat the batch
- // dimension in gradients as the input feature dimension in the filter.
- //
- // TODO(jingyue): Add more tests on NCHW input order, which TF also
- // supports.
- tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
- tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
- tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
- tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
- tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
- tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
- 3);
- tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
- tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
- tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
- tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
- tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
- tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);
-
- tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
- tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
- tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
- tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
- tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
- tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
- tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
- tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
- tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
- tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
- tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
- tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
- }
-
- protected:
- const se::GpuComputeCapability& GetComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .gpu_compute_capability();
- }
-
- bool RunPass(HloModule* module) {
- return GpuConvRewriter(GetComputeCapability()).Run(module).value();
- }
-
- // A convolution window with stride 1 and zero padding. The size fields are
- // not set.
- Window default_conv_window_;
- ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
- ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
-};
-
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) {
- HloComputation::Builder builder(TestName());
- HloInstruction* activations =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
- HloInstruction* gradients =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
- Window conv_window = default_conv_window_;
- conv_window.mutable_dimensions(1)->set_size(2);
- conv_window.mutable_dimensions(1)->set_window_dilation(2);
- auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(
- activations->shape(), gradients->shape(), /*feature_group_count=*/1,
- /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_filter_,
- /*preferred_element_type=*/std::nullopt)
- .value(),
- activations, gradients, /*feature_group_count=*/1,
- /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
- OpMetadata metadata;
- metadata.set_op_name("foo");
- conv->set_metadata(metadata);
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- ASSERT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-
- // Check that metadata was preserved.
- const auto& md_after_opt =
- entry_computation->root_instruction()->operand(0)->metadata();
- EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
- << md_after_opt.DebugString() << " vs " << metadata.DebugString();
-}
-
-TEST_F(GpuConvRewriterTest,
- BackwardFilterConvolveEquivalentToForwardConvolution) {
- HloComputation::Builder builder(TestName());
- HloInstruction* activations =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
- HloInstruction* gradients =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
- Window conv_window = default_conv_window_;
- conv_window.mutable_dimensions(1)->set_size(3);
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(
- activations->shape(), gradients->shape(), /*feature_group_count=*/1,
- /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_filter_,
- /*preferred_element_type=*/std::nullopt)
- .value(),
- activations, gradients, /*feature_group_count=*/1,
- /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Extracted from block35 training.
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) {
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* activations =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
- HloInstruction* gradients =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
-
- Window conv_window = default_conv_window_;
- for (int i = 0; i < 2; ++i) {
- conv_window.mutable_dimensions(i)->set_size(35);
- conv_window.mutable_dimensions(i)->set_padding_low(1);
- conv_window.mutable_dimensions(i)->set_padding_high(1);
- }
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-}
-
-// Extracted from inception v3 training.
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) {
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* activations =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
- HloInstruction* gradients =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));
-
- Window conv_window = default_conv_window_;
- for (int i = 0; i < 2; ++i) {
- conv_window.mutable_dimensions(i)->set_size(4);
- conv_window.mutable_dimensions(i)->set_padding_high(-1);
- conv_window.mutable_dimensions(i)->set_window_dilation(2);
- }
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-}
-
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* activations =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
- HloInstruction* gradients =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
-
- Window conv_window = default_conv_window_;
- for (int i = 0; i < 2; ++i) {
- conv_window.mutable_dimensions(i)->set_size(35);
- // Uneven padding: padding_low=0, padding_high=1
- conv_window.mutable_dimensions(i)->set_padding_high(1);
- }
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-}
-
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) {
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* output =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
- HloInstruction* kernel =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
- HloInstruction* reverse_kernel = builder.AddInstruction(
- HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
-
- Window conv_window = default_conv_window_;
- for (int i = 0; i < 2; ++i) {
- conv_window.mutable_dimensions(i)->set_size(7);
- conv_window.mutable_dimensions(i)->set_padding_low(3);
- conv_window.mutable_dimensions(i)->set_padding_high(3);
- }
- ConvolutionDimensionNumbers conv_dnums;
- conv_dnums.set_input_batch_dimension(0);
- conv_dnums.set_output_batch_dimension(0);
- conv_dnums.set_input_feature_dimension(1);
- conv_dnums.set_output_feature_dimension(1);
- conv_dnums.add_input_spatial_dimensions(2);
- conv_dnums.add_output_spatial_dimensions(2);
- conv_dnums.add_input_spatial_dimensions(3);
- conv_dnums.add_output_spatial_dimensions(3);
- conv_dnums.set_kernel_input_feature_dimension(0);
- conv_dnums.set_kernel_output_feature_dimension(1);
- conv_dnums.add_kernel_spatial_dimensions(2);
- conv_dnums.add_kernel_spatial_dimensions(3);
-
- HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
- /*rhs=*/reverse_kernel, /*feature_group_count=*/1,
- /*batch_group_count=*/1, conv_window, conv_dnums,
- DefaultPrecisionConfig(2)));
- // Verify the convolution's shape is consistent with ShapeInference.
- CHECK(ShapeUtil::Compatible(
- conv->shape(),
- ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(),
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- conv_dnums, /*preferred_element_type=*/std::nullopt)
- .value()));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
-
- ASSERT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
- const HloInstruction* custom_call =
- entry_computation->root_instruction()->operand(0);
- for (int i = 0; i < 2; ++i) {
- const WindowDimension& window_dim = custom_call->window().dimensions(i);
- // Low padding of the backward input convolution
- // = kernel_size - 1 - low padding on gradients.
- EXPECT_EQ(3, window_dim.padding_low());
- EXPECT_EQ(3, window_dim.padding_high());
- EXPECT_EQ(1, window_dim.stride());
- EXPECT_EQ(1, window_dim.base_dilation());
- }
-}
-
-// Convolve([abc], [x], base_dilation=2)
-// = Convolve([abc], Reverse([x]), base_dilation=2)
-// = BackwardInputConvolve([abc], [x], stride=2)
-TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) {
- auto builder = HloComputation::Builder(TestName());
- // NHWC dimension order.
- HloInstruction* output =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
- // HWOI dimension order.
- HloInstruction* kernel =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
-
- Window conv_window = default_conv_window_;
- conv_window.mutable_dimensions(1)->set_base_dilation(2);
-
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(
- output->shape(), kernel->shape(),
- /*feature_group_count=*/1,
- /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_input_,
- /*preferred_element_type=*/std::nullopt)
- .value(),
- /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
- /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
-}
-
-// BackwardInputConvolve([abc], [x], stride=1) is equivalent to
-// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
-// convolution.
-TEST_F(GpuConvRewriterTest,
- BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
- auto builder = HloComputation::Builder(TestName());
- // NHWC dimension order.
- HloInstruction* output =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
- // HWOI dimension order.
- HloInstruction* kernel =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
-
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(
- output->shape(), kernel->shape(), /*feature_group_count=*/1,
- /*batch_group_count=*/1, default_conv_window_,
- tf_default_dnums_for_backward_input_,
- /*preferred_element_type=*/std::nullopt)
- .value(),
- /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
- /*batch_group_count=*/1, default_conv_window_,
- tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Extracted from Inception V3 training.
-//
-// filter(HWIO)
-// 3x3x192x320
-// |
-// v
-// gradients(NHWC) reverse
-// 20x4x4x320 3x3x192x320
-// \ /
-// \ /
-// conv (NHWC) with padding (low=2,high=3,interior=1)
-// 20x10x10x192
-//
-// Gradients are padded unevenly.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* output =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
- HloInstruction* kernel =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
- HloInstruction* reverse_kernel = builder.AddInstruction(
- HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
- Window conv_window = default_conv_window_;
- for (int i = 0; i < 2; ++i) {
- conv_window.mutable_dimensions(i)->set_size(3);
- conv_window.mutable_dimensions(i)->set_padding_low(2);
- conv_window.mutable_dimensions(i)->set_padding_high(3);
- // Interior padding = 1.
- conv_window.mutable_dimensions(i)->set_base_dilation(2);
- }
- HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
- // Verify the convolution's shape is consistent with ShapeInference.
- CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(),
- /*feature_group_count=*/1, /*batch_group_count=*/1,
- conv_window, tf_default_dnums_for_backward_input_,
- /*preferred_element_type=*/std::nullopt)
- .value()));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- ASSERT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
- const HloInstruction* custom_call =
- entry_computation->root_instruction()->operand(0);
- for (int i = 0; i < 2; ++i) {
- const WindowDimension& window_dim = custom_call->window().dimensions(i);
- EXPECT_EQ(0, window_dim.padding_low());
- EXPECT_EQ(0, window_dim.padding_high());
- EXPECT_EQ(2, window_dim.stride());
- EXPECT_EQ(1, window_dim.base_dilation());
- }
-}
-
-// Similar to BackwardInputConvolveUnevenPadding, but the low padding of the
-// gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* output =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
- HloInstruction* kernel =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
- HloInstruction* reverse_kernel = builder.AddInstruction(
- HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
- Window conv_window = default_conv_window_;
- for (int i = 0; i < 2; ++i) {
- conv_window.mutable_dimensions(i)->set_size(3);
- conv_window.mutable_dimensions(i)->set_padding_low(3);
- conv_window.mutable_dimensions(i)->set_padding_high(2);
- conv_window.mutable_dimensions(i)->set_base_dilation(2);
- }
- HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
- // Verify the convolution's shape is consistent with ShapeInference.
- CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(),
- /*feature_group_count=*/1, /*batch_group_count=*/1,
- conv_window, tf_default_dnums_for_backward_input_,
- /*preferred_element_type=*/std::nullopt)
- .value()));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Extracted from Resnet-50.
-//
-// For simplicity, we focus on the column dimension and ignore other dimensions.
-// We use [?] to represent the shape instead of the content.
-//
-// Suppose operator FC does
-// [4] = conv([14], [3], stride=2, padding_high=1) // Padding::kSame
-//
-// BC = BackwardInput(FC) does:
-// [14] = conv([7], reverse([3]),
-// padding_low=2, padding_high=1, base_dilation=2)
-//
-// We should fuse BC even though padding on activations is uneven, because
-// GpuConvPaddingLegalization will canonicalize the fusion HLO.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
- auto builder = HloComputation::Builder(TestName());
- // The gradients are in NCHW layout.
- HloInstruction* output =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
- // The kernel is in HWIO layout.
- HloInstruction* kernel =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
- HloInstruction* reverse_kernel = builder.AddInstruction(
- HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
- Window conv_window = default_conv_window_;
- WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
- forward_conv_col_dim->set_size(3);
- forward_conv_col_dim->set_padding_low(2);
- forward_conv_col_dim->set_padding_high(1);
- forward_conv_col_dim->set_base_dilation(2);
- HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
- // Verify the convolution's shape is consistent with ShapeInference.
- CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(),
- /*feature_group_count=*/1, /*batch_group_count=*/1,
- conv_window, tf_default_dnums_for_backward_input_,
- /*preferred_element_type=*/std::nullopt)
- .value()));
-
- auto module = CreateNewVerifiedModule();
- const HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- ASSERT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
- const WindowDimension& backward_conv_col_dim =
- entry_computation->root_instruction()->operand(0)->window().dimensions(1);
- EXPECT_EQ(0, backward_conv_col_dim.padding_low());
- EXPECT_EQ(1, backward_conv_col_dim.padding_high());
-}
-
-// For simplicity, we focus on the column dimension and ignore other dimensions.
-// We use [?] to represent the shape instead of the content.
-//
-// Suppose operator FC does
-// [3] = conv([4], [2], padding_low=1, padding_high=-1)
-//
-// BC = BackwardInput(FC) does:
-// [4] = conv([3], reverse([2]), padding_high=2)
-//
-// We currently don't fuse BC because GpuConvPaddingLegalization
-// doesn't support negative padding on the gradients of backward convolution
-// (b/32744257).
-TEST_F(GpuConvRewriterTest,
- BackwardInputConvolveNegativePaddingHighOnActivations) {
- auto builder = HloComputation::Builder(TestName());
- // The gradients are in NCHW layout.
- HloInstruction* output =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
- // The kernel is in HWIO layout.
- HloInstruction* kernel =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
- HloInstruction* reverse_kernel = builder.AddInstruction(
- HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
- Window conv_window = default_conv_window_;
- WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
- forward_conv_col_dim->set_size(2);
- forward_conv_col_dim->set_padding_high(2);
- HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
- /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
- tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
- // Verify the convolution's shape is consistent with ShapeInference.
- CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(),
- /*feature_group_count=*/1, /*batch_group_count=*/1,
- conv_window, tf_default_dnums_for_backward_input_,
- /*preferred_element_type=*/std::nullopt)
- .value()));
-
- auto module = CreateNewVerifiedModule();
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
- EXPECT_THAT(entry_computation->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Check that we will materialize a reversed version of a constant in order to
-// pattern-match a backwards input convolution.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) {
- Array4D<float> constant_arr(4, 4, 2, 2);
- constant_arr.FillIota(0);
- std::string constant_str =
- LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape();
-
- const std::string module_str = absl::StrFormat(R"(
- HloModule test
-
- ENTRY entry_computation {
- param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
- constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
- ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
- window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
- dim_labels=bf01_01oi->bf01, feature_group_count=1
- })",
- constant_str);
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- EXPECT_TRUE(RunPass(m.get()));
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardInputCallTarget},
- m::Parameter(), m::Reverse(m::Constant())),
- 0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternMatch) {
- // All filter dimensions are larger than the corresponding output dimensions.
- // This must be a backward filter convolution.
- const std::string module_str = absl::StrFormat(R"(
- HloModule Test
-
- ENTRY Test {
- input = f32[8,120,256,256] parameter(0)
- filter = f32[8,120,256,256] parameter(1)
-
- ROOT conv = f32[120,120,3,3] convolution(input, filter), window={size=256x256 pad=1_1x1_1}, dim_labels=fb01_io01->fb01
- })");
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- EXPECT_TRUE(RunPass(m.get()));
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardFilterCallTarget},
- m::Parameter(0), m::Parameter(1)),
- 0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) {
- // At least one filter dimension is smaller than the corresponding output
- // dimension. This must be a forward convolution.
- const std::string module_str = absl::StrFormat(R"(
- HloModule Test
-
- ENTRY Test {
- input = f32[8,128,2,32] parameter(0)
- filter = f32[3,3,128,128] parameter(1)
-
- ROOT conv = f32[8,128,2,32] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
- })");
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- EXPECT_TRUE(RunPass(m.get()));
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvForwardCallTarget}, m::Parameter(0),
- m::Parameter(1)),
- 0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) {
- // There exists one kernel dimension equal to the corresponding output
- // dimension; regard it as a backward filter convolution if the conv is 1D.
- const std::string module_str = absl::StrFormat(R"(
- HloModule Test
-
- ENTRY Test {
- input = f32[8,256,128] parameter(0)
- filter = f32[8,254,128] parameter(1)
- reshape.1 = f32[8,1,256,128] reshape(input)
- reshape.2 = f32[8,1,254,128] reshape(filter)
- ROOT conv = f32[1,3,128,128] convolution(reshape.1, reshape.2), window={size=1x254}, dim_labels=f01b_i01o->01bf
- })");
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- EXPECT_TRUE(RunPass(m.get()));
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardFilterCallTarget},
- m::Reshape(), m::Reshape()),
- 0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) {
-  // For a conv1d backward input convolution, the filter may be reversed
-  // first and then reshaped.
- const std::string module_str = absl::StrFormat(R"(
- HloModule Test
-
- ENTRY Test {
- input = f32[8,254,128] parameter(0)
- filter = f32[3,128,128] parameter(1)
- reverse = f32[3,128,128] reverse(filter), dimensions={0}
- reshape.1 = f32[8,1,254,128] reshape(input)
- reshape.2 = f32[1,3,128,128] reshape(reverse)
- ROOT conv = f32[8,1,256,128] convolution(reshape.1, reshape.2), window={size=1x3 pad=0_0x2_2}, dim_labels=b01f_01oi->b01f
- })");
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- EXPECT_TRUE(RunPass(m.get()));
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCudnnConvBackwardInputCallTarget},
- m::Reshape(), m::Reshape()),
- 0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestInvalidTypes) {
- const std::string module_str = absl::StrFormat(R"(
- HloModule Test
-
- ENTRY Test {
- input = TYPE[1,17,9,9] parameter(0)
- filter = TYPE[3,3,17,32] parameter(1)
- ROOT conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
- })");
-
- // Test complex types
- for (std::string_view type : {"c64", "c128"}) {
- const std::string module_with_type =
- absl::StrReplaceAll(module_str, {{"TYPE", type}});
- TF_ASSERT_OK_AND_ASSIGN(auto m,
- ParseAndReturnVerifiedModule(module_with_type));
-
- absl::Status s =
- GpuConvRewriter(GetComputeCapability()).Run(m.get()).status();
- EXPECT_THAT(
- s, tsl::testing::StatusIs(
- absl::StatusCode::kUnimplemented,
- ::testing::HasSubstr("Convolutions must have floating-point or "
- "integral operands/outputs")));
- }
-
- // Test FP8 type on unsupported GPUs
- std::string module_with_type =
- absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fn"}});
- TF_ASSERT_OK_AND_ASSIGN(auto m,
- ParseAndReturnVerifiedModule(module_with_type));
- absl::Status s = GpuConvRewriter(se::CudaComputeCapability::Ampere())
- .Run(m.get())
- .status();
- EXPECT_THAT(s, tsl::testing::StatusIs(
- absl::StatusCode::kUnimplemented,
- ::testing::HasSubstr(
- "FP8 convolutions are only supported on CUDA "
- "GPUs with compute capability at least 9.0")));
- s = GpuConvRewriter(se::RocmComputeCapability{"gfx942"})
- .Run(m.get())
- .status();
- EXPECT_THAT(s, tsl::testing::StatusIs(
- absl::StatusCode::kUnimplemented,
- ::testing::HasSubstr(
- "FP8 convolutions are only supported on CUDA GPUs")));
-
- // Test unsupported FP8 type
- module_with_type = absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fnuz"}});
- TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(module_with_type));
- s = GpuConvRewriter(GetComputeCapability()).Run(m.get()).status();
- EXPECT_THAT(s,
- tsl::testing::StatusIs(
- absl::StatusCode::kUnimplemented,
- ::testing::HasSubstr("The only FP8 types supported in "
- "convolutions are f8e5m2 and f8e4m3")));
-}
-
-} // anonymous namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc b/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc
deleted file mode 100644
index b8c87e2..0000000
--- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h"
-
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync(
- HloComputation* computation,
- absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
- const {
- absl::flat_hash_map<HloInstruction*, HloInstruction*> replaced_ops;
- CollectiveBackendConfig sync_config;
- sync_config.set_is_sync(true);
- for (auto& [async_start, async_done] : async_pairs) {
- // Tag the async start with is_sync = true.
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- async_start->backend_config<GpuBackendConfig>());
- *gpu_config.mutable_collective_backend_config() = sync_config;
- TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
- replaced_ops[async_start] = nullptr;
- replaced_ops[async_done] = async_start;
- }
-
- // Update schedule.
- HloModule* module = computation->parent();
- const HloInstructionSequence& sequence =
- module->schedule().sequence(computation);
- std::vector<HloInstruction*> new_sequence;
- new_sequence.reserve(sequence.size());
- for (HloInstruction* instr : sequence.instructions()) {
- auto it = replaced_ops.find(instr);
-    // If it's not a start or done op, add it to the new schedule.
- if (it == replaced_ops.end()) {
- new_sequence.push_back(instr);
- continue;
- }
-
-    // If it's a start op, do not add it to the schedule yet.
- if (it->second == nullptr) {
- continue;
- }
-
-    // It's a done op. First add the start and then the done.
- new_sequence.push_back(it->second);
- new_sequence.push_back(instr);
- }
- module->schedule().set_sequence(computation, new_sequence);
- return absl::OkStatus();
-}
-
-} // namespace gpu
-} // namespace xla
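As a worked illustration of the schedule rewrite in ConvertAsyncInstructionsToSync above (instruction names chosen for illustration): given an input sequence {start1, start2, done2, done1}, each start op is skipped when first encountered and re-emitted immediately before its matching done, while unrelated instructions keep their positions, so the new sequence becomes {start2, done2, start1, done1}.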
diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h b/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h
deleted file mode 100644
index ea56f7a..0000000
--- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
-#define XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
-
-#include <utility>
-
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/convert_async_collectives_to_sync.h"
-
-namespace xla {
-namespace gpu {
-
-class GpuConvertAsyncCollectivesToSync : public ConvertAsyncCollectivesToSync {
- public:
- using ConvertAsyncCollectivesToSync::ConvertAsyncCollectivesToSync;
- absl::string_view name() const override {
- return "gpu-convert-async-collectives-to-sync";
- }
-
- absl::Status ConvertAsyncInstructionsToSync(
- HloComputation* computation,
- absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
- const override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc
deleted file mode 100644
index 03f18bd..0000000
--- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc
+++ /dev/null
@@ -1,347 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h"
-
-#include <string_view>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::IsFalse;
-using ::testing::IsTrue;
-
-// Note: The pass only processes modules that are already scheduled. If a test
-// does not behave as expected, make sure "is_scheduled=true" is set in the
-// HLO module string.
-class GpuConvertAsyncCollectivesToSyncTest : public HloTestBase {
- public:
- absl::Status RunPass(HloModule *module, bool expect_change,
- HloPredicate is_nop = {}) {
- TF_ASSIGN_OR_RETURN(bool changed,
- GpuConvertAsyncCollectivesToSync{is_nop}.Run(module));
- EXPECT_EQ(changed, expect_change);
- return absl::OkStatus();
- }
-
- // Returns true if the instruction with the given name is synchronous.
- bool IsSync(HloModule *module, std::string_view name) {
- const HloInstruction *inst = FindInstruction(module, name);
- if (inst == nullptr) {
- return false;
- }
- auto backend_config = inst->backend_config<GpuBackendConfig>()
- .value()
- .collective_backend_config();
- return backend_config.is_sync();
- }
-
- HloPredicate is_nop_simple_ =
- HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kGetTupleElement,
- HloOpcode::kParameter>;
-};
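A hedged sketch of what the IsSync helper above inspects (the nesting is inferred from the accessors used in this file, not from the proto definition itself): after the pass runs, a converted start op carries a backend config roughly of the form

  collective_backend_config { is_sync: true }

nested inside its GpuBackendConfig, so IsSync returns true exactly for the pairs the pass rewrote.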
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduce) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- apply_op {
- x = u32[] parameter(0)
- y = u32[] parameter(1)
- ROOT apply_op = u32[] add(x, y)
- }
-
- ENTRY test_computation {
- id = u32[] replica-id()
- start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
- ROOT done = u32[] all-reduce-done(start)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNop) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- apply_op {
- x = u32[] parameter(0)
- y = u32[] parameter(1)
- ROOT apply_op = u32[] add(x, y)
- }
-
- ENTRY test_computation {
- id = u32[] replica-id()
- start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3, replica_groups={{0,1}, {2,3}}
- id2 = f32[] bitcast(id)
- ROOT done = u32[] all-reduce-done(start)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true, is_nop_simple_));
- EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
-}
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectiveBroadcast) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- collective_broadcast {
- p0 = u32[8] parameter(0)
- ROOT result = u32[8] collective-broadcast(p0), replica_groups={{0,1}, {2,3}}
- }
-
- ENTRY main {
- data = u32[8] parameter(0)
- cb-start = ((u32[8]{0}), u32[8]{0}) async-start(u32[8]{0} %data), calls=collective_broadcast
- ROOT %ars = u32[8]{0} async-done(((u32[8]{0}), u32[8]{0}) %cb-start), calls=collective_broadcast
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "cb-start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNonNop) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- apply_op {
- x = u32[] parameter(0)
- y = u32[] parameter(1)
- ROOT apply_op = u32[] add(x, y)
- }
-
- ENTRY test_computation {
- id = u32[] replica-id()
- start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
- id2 = u32[] add(id, id)
- ROOT done = u32[] all-reduce-done(start)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/false));
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllGather) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
- ENTRY test_computation {
- a1 = u32[1, 2] parameter(0)
- ags = (u32[1, 2], u32[2, 2]) all-gather-start(a1), dimensions={0}, channel_id=3
- ROOT allgather = u32[2,2] all-gather-done(ags)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "ags"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectivePermute) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- ENTRY test_computation {
- p = u32[2] parameter(0)
- start = (u32[2], u32[2], u32[], u32[]) collective-permute-start(p), source_target_pairs={{0,1}, {1,0}}
- ROOT done = u32[2] collective-permute-done(start)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleReduceScatter) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- add {
- lhs = u32[] parameter(0)
- rhs = u32[] parameter(1)
- ROOT add = u32[] add(lhs, rhs)
- }
-
- reduce_scatter {
- p0 = u32[8] parameter(0)
- ROOT result = u32[4] reduce-scatter(p0), replica_groups={{0,3}, {1,2}},
- dimensions={0}, to_apply=add
- }
-
- ENTRY main {
- data = u32[8] parameter(0)
- rs-start = ((u32[8]{0}), u32[4]{0}) async-start(u32[8]{0} %data), calls=reduce_scatter
- ROOT %ars = u32[4]{0} async-done(((u32[8]{0}), u32[4]{0}) %rs-start), calls=reduce_scatter
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "rs-start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllToAll) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- all_to_all {
- p0 = u32[2] parameter(0)
- ROOT result = u32[2] all-to-all(p0), dimensions={0}, replica_groups={{0,1},{2,3}}
- }
-
- ENTRY test_computation {
- a1 = u32[2] parameter(0)
- a2a-start = ((u32[2]), u32[2]) async-start(u32[2] a1), calls=all_to_all
- ROOT a2s = u32[2] async-done(a2a-start), calls=all_to_all
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "a2a-start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, ControlDeps) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- apply_op {
- x = u32[] parameter(0)
- y = u32[] parameter(1)
- ROOT apply_op = u32[] add(x, y)
- }
-
- ENTRY test_computation {
- id = u32[] replica-id()
- start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
- done1 = u32[] all-reduce-done(start1)
- start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4, control-predecessors={done1}
- done2 = u32[] all-reduce-done(start2)
- ROOT x = u32[] add(done1, done2)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
- EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-// Test multiple in-flight collectives that are ordered in a streaming fashion:
-// i.e., ends are in start order (FIFO).
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightStreaming) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- apply_op {
- x = u32[] parameter(0)
- y = u32[] parameter(1)
- ROOT apply_op = u32[] add(x, y)
- }
-
- ENTRY test_computation {
- id = u32[] replica-id()
- start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
- start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
- done1 = u32[] all-reduce-done(start1)
- done2 = u32[] all-reduce-done(start2)
- ROOT x = u32[] add(done1, done2)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
- EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0}
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNested) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- apply_op {
- x = u32[] parameter(0)
- y = u32[] parameter(1)
- ROOT apply_op = u32[] add(x, y)
- }
-
- ENTRY test_computation {
- id = u32[] replica-id()
- start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
- start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
- done2 = u32[] all-reduce-done(start2)
- done1 = u32[] all-reduce-done(start1)
- ROOT x = u32[] add(done1, done2)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
- EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0} where
-// the inner pair can be converted but not the outer one, because a non-nop
-// instruction (id2) sits between start1 and done1.
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNestedPartial) {
- const absl::string_view hlo_string = R"(
- HloModule test, is_scheduled=true
-
- apply_op {
- x = u32[] parameter(0)
- y = u32[] parameter(1)
- ROOT apply_op = u32[] add(x, y)
- }
-
- ENTRY test_computation {
- id = u32[] replica-id()
- start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
- start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
- done2 = u32[] all-reduce-done(start2)
- id2 = u32[] add(done2, done2)
- done1 = u32[] all-reduce-done(start1)
- ROOT x = u32[] add(done1, done2)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
- EXPECT_THAT(IsSync(module.get(), "start1"), IsFalse());
- EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc
index bf97747..25fe651 100644
--- a/third_party/xla/xla/service/gpu/gpu_executable.cc
+++ b/third_party/xla/xla/service/gpu/gpu_executable.cc
@@ -841,9 +841,14 @@
TF_ASSIGN_OR_RETURN(globals, ResolveConstantGlobals(run_options->stream()));
}
- auto device_ordinal = executor->device_ordinal();
+  // Use the `device_ordinal` from `run_options` if it is provided. This is
+  // the ordinal of the logical device (e.g., a virtual GPU). If it is not
+  // provided, the logical and physical device ordinals are the same.
+ const int device_ordinal = run_options->device_ordinal() != -1
+ ? run_options->device_ordinal()
+ : executor->device_ordinal();
ExecutionOutput result(/*on_device_shape=*/output_shape_, memory_allocator,
- device_ordinal);
+ device_ordinal, executor->device_ordinal());
TF_ASSIGN_OR_RETURN(
BufferAllocations buffer_allocations,
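A minimal sketch of the ordinal selection added above, using hypothetical values (a virtual device with logical ordinal 3 backed by physical GPU 0):

  // Hypothetical values for illustration only.
  const int logical = 3;    // run_options->device_ordinal()
  const int physical = 0;   // executor->device_ordinal()
  const int device_ordinal = logical != -1 ? logical : physical;  // -> 3
  // ExecutionOutput then records both the logical (3) and physical (0) ordinals.

When run_options does not set an ordinal (it returns -1), both values fall back to the executor's physical ordinal.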
diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc
deleted file mode 100644
index 566c006..0000000
--- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc
+++ /dev/null
@@ -1,719 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-
-#include <cstdint>
-#include <optional>
-#include <string>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "Eigen/Core"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-using se::DeviceMemory;
-using se::DeviceMemoryBase;
-using se::dnn::DataType;
-using se::dnn::MatmulTensorDescriptor;
-using se::dnn::TensorDescriptor;
-
-template <typename ElementType, typename BiasType, typename OutputType>
-absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream,
- RunFusedMHAOptions options,
- DeviceMemory<ElementType> lhs_bmm1_buffer,
- DeviceMemory<ElementType> rhs_bmm1_buffer,
- DeviceMemory<ElementType> rhs_bmm2_buffer,
- DeviceMemory<OutputType> output_buffer,
- DeviceMemoryBase bias_buffer,
- DeviceMemoryBase scratch_memory,
- DeviceMemoryBase activation_output,
- DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHAOp> *lazy_runner =
- options.runner_cache->AsFusedMHARunner();
- std::optional<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>> local_runner;
- if (!lazy_runner) {
- local_runner.emplace(params.config->algorithm);
- lazy_runner = &*local_runner;
- }
- std::optional<double> dropout_rate;
- if (params.config->dropout_rate) {
- dropout_rate = *params.config->dropout_rate;
- }
-
- std::optional<int64_t> seed;
- if (params.config->seed) {
- seed = *params.config->seed;
- }
-
- TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config,
- params.config->AsDnnFusedMHAOpConfig());
- TF_ASSIGN_OR_RETURN(auto *runner,
- lazy_runner->GetOrCreateRunner(config, stream));
- return (*runner)(stream, options.profile_result, scratch_memory,
- lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer,
- output_buffer, bias_buffer, activation_output, seqlen_q,
- seqlen_k);
-}
-
-template <typename ElementType, typename BiasType, typename OutputType>
-absl::Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream,
- se::DeviceMemoryBase scratch_memory,
- RunFusedMHAOptions options) {
- auto lhs_bmm1_buffer = se::DeviceMemory<ElementType>(params.lhs_bmm1_buffer);
- auto rhs_bmm1_buffer = se::DeviceMemory<ElementType>(params.rhs_bmm1_buffer);
- auto rhs_bmm2_buffer = se::DeviceMemory<ElementType>(params.rhs_bmm2_buffer);
- auto output_buffer = se::DeviceMemory<OutputType>(params.output_buffer);
- auto activation_buffer =
- params.activation_buffer.has_value()
- ? se::DeviceMemory<OutputType>(*params.activation_buffer)
- : se::DeviceMemoryBase();
- auto bias_buffer = params.bias_buffer.has_value()
- ? se::DeviceMemory<BiasType>(*params.bias_buffer)
- : se::DeviceMemoryBase();
- auto seqlen_q_buffer =
- params.seqlen_q_buffer.has_value()
- ? se::DeviceMemory<BiasType>(*params.seqlen_q_buffer)
- : se::DeviceMemoryBase();
- auto seqlen_k_buffer =
- params.seqlen_k_buffer.has_value()
- ? se::DeviceMemory<BiasType>(*params.seqlen_k_buffer)
- : se::DeviceMemoryBase();
- se::dnn::AlgorithmDesc algorithm = params.config->algorithm;
- if (options.runner_cache) {
- algorithm = options.runner_cache->ToAlgorithmDesc();
- }
-
- absl::Status run_status = absl::OkStatus();
- switch (params.config->kind) {
- case CudnnfMHAKind::kSoftmaxDropout:
- case CudnnfMHAKind::kSoftmax:
- case CudnnfMHAKind::kScaleBiasSoftmax:
- case CudnnfMHAKind::kScaleBiasSoftmaxDropout:
- run_status = RunFusedMHA<ElementType, BiasType, OutputType>(
- params, stream, options, lhs_bmm1_buffer, rhs_bmm1_buffer,
- rhs_bmm2_buffer, output_buffer, bias_buffer, scratch_memory,
- activation_buffer, seqlen_q_buffer, seqlen_k_buffer);
- break;
- default:
- return Internal("Invalid cuDNN fMHA kind");
- }
-
- if (!run_status.ok()) {
- return run_status;
- }
-
- if (!stream->ok()) {
- return Internal("Unable to launch FMHA with type %s and algorithm %s",
- CudnnfMHAKindToString(params.config->kind),
- algorithm.ToString());
- }
-
- return absl::OkStatus();
-}
-
-template <typename ElementType, typename OutputType>
-absl::Status RunFusedMHABackward(
- GpufMHABackwardParams params, se::Stream *stream,
- RunFusedMHABackwardOptions options,
- DeviceMemory<ElementType> bmm1_grad_gemm1_rhs_buffer,
- DeviceMemory<ElementType> bmm1_grad_gemm2_rhs_buffer,
- DeviceMemory<ElementType> bmm2_grad_gemm1_lhs_buffer,
- DeviceMemory<ElementType> bmm2_grad_gemm2_rhs_buffer,
- DeviceMemory<ElementType> d_output_buffer,
- DeviceMemory<OutputType> d_bmm1_lhs_buffer,
- DeviceMemory<OutputType> d_bmm1_rhs_buffer,
- DeviceMemory<OutputType> d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer,
- DeviceMemoryBase d_bias_buffer, DeviceMemoryBase fwd_output_buffer,
- DeviceMemoryBase bias_buffer, DeviceMemoryBase scratch_memory,
- DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp> *lazy_runner =
- options.runner_cache->AsFusedMHABackwardRunner();
- std::optional<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>
- local_runner;
- if (!lazy_runner) {
- local_runner.emplace(params.config->algorithm);
- lazy_runner = &*local_runner;
- }
- std::optional<double> dropout_rate;
- if (params.config->dropout_rate) {
- dropout_rate = *params.config->dropout_rate;
- }
-
- std::optional<int64_t> seed;
- if (params.config->seed) {
- seed = *params.config->seed;
- }
-
- TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config,
- params.config->AsDnnFusedMHABackwardOpConfig());
- TF_ASSIGN_OR_RETURN(auto *runner,
- lazy_runner->GetOrCreateRunner(config, stream));
- // TODO: pass in real softmax_sum, dQ_accum, fwd_output
- return (*runner)(stream, options.profile_result, scratch_memory,
- bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
- bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer,
- d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer,
- d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer,
- fwd_output_buffer, bias_buffer, seqlen_q, seqlen_k);
- return absl::OkStatus();
-}
-
-template <typename ElementType, typename BiasType, typename OutputType>
-absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms,
- se::Stream *stream,
- se::DeviceMemoryBase scratch_memory,
- RunFusedMHABackwardOptions options) {
- auto bmm1_grad_gemm1_rhs_buffer =
- se::DeviceMemory<ElementType>(params.bmm1_grad_gemm1_rhs_buffer);
- auto bmm1_grad_gemm2_rhs_buffer =
- se::DeviceMemory<ElementType>(params.bmm1_grad_gemm2_rhs_buffer);
- auto bmm2_grad_gemm1_lhs_buffer =
- se::DeviceMemory<ElementType>(params.bmm2_grad_gemm1_lhs_buffer);
- auto bmm2_grad_gemm2_rhs_buffer =
- se::DeviceMemory<ElementType>(params.bmm2_grad_gemm2_rhs_buffer);
- auto d_output_buffer = se::DeviceMemory<ElementType>(params.d_output_buffer);
- auto d_bmm1_lhs_buffer =
- se::DeviceMemory<OutputType>(params.d_bmm1_lhs_buffer);
- auto d_bmm1_rhs_buffer =
- se::DeviceMemory<OutputType>(params.d_bmm1_rhs_buffer);
- auto d_bmm2_rhs_buffer =
- se::DeviceMemory<OutputType>(params.d_bmm2_rhs_buffer);
-
- // optional buffers
- auto d_s_buffer = params.d_s_buffer.has_value()
- ? se::DeviceMemory<OutputType>(*params.d_s_buffer)
- : se::DeviceMemoryBase();
-
- auto d_bias_buffer = params.d_bias_buffer.has_value()
- ? se::DeviceMemory<OutputType>(*params.d_bias_buffer)
- : se::DeviceMemoryBase();
-
- auto fwd_output_buffer =
- params.fwd_output_buffer.has_value()
- ? se::DeviceMemory<ElementType>(*params.fwd_output_buffer)
- : se::DeviceMemoryBase();
-
- auto bias_buffer = params.bias_buffer.has_value()
- ? se::DeviceMemory<BiasType>(*params.bias_buffer)
- : se::DeviceMemoryBase();
-
- auto seqlen_q_buffer =
- params.seqlen_q_buffer.has_value()
- ? se::DeviceMemory<BiasType>(*params.seqlen_q_buffer)
- : se::DeviceMemoryBase();
-
- auto seqlen_k_buffer =
- params.seqlen_k_buffer.has_value()
- ? se::DeviceMemory<BiasType>(*params.seqlen_k_buffer)
- : se::DeviceMemoryBase();
-
- se::dnn::AlgorithmDesc algorithm = params.config->algorithm;
- if (options.runner_cache) {
- algorithm = options.runner_cache->ToAlgorithmDesc();
- }
-
- absl::Status run_status = absl::OkStatus();
- switch (params.config->kind) {
- case CudnnfMHAKind::kBackwardSoftmaxDropout:
- case CudnnfMHAKind::kBackwardSoftmax:
- case CudnnfMHAKind::kBackwardScaleBiasSoftmax:
- case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout:
- run_status = RunFusedMHABackward<ElementType, OutputType>(
- params, stream, options, bmm1_grad_gemm1_rhs_buffer,
- bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer,
- bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer,
- d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer,
- fwd_output_buffer, bias_buffer, scratch_memory, seqlen_q_buffer,
- seqlen_k_buffer);
- break;
- default:
- return Internal("Invalid cuDNN fMHA kind");
- }
-
- if (!run_status.ok()) {
- return run_status;
- }
-
- if (!stream->ok()) {
- return Internal("Unable to launch FMHA with type %s and algorithm %s",
- CudnnfMHAKindToString(params.config->kind),
- algorithm.ToString());
- }
-
- return run_status;
-}
-} // namespace
-
-/*static*/ absl::StatusOr<GpufMHAConfig> GpufMHAConfig::For(
- const GpufMHADescriptor &desc) {
- // Get shapes from desc.
- const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape;
- const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape;
- const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape;
- const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape;
- const Shape &output_shape = desc.output_shapes[0];
-
-  // Get DNN dtype from primitive types
- TF_ASSIGN_OR_RETURN(
- DataType lhs_bmm1_type,
- GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type()));
- TF_ASSIGN_OR_RETURN(
- DataType rhs_bmm1_type,
- GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(
- DataType rhs_bmm2_type,
- GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type()));
- TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type,
- GetDNNDataTypeFromPrimitiveType(
- intermediate_lhs_bmm2_shape.element_type()));
- TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType(
- output_shape.element_type()));
- GpufMHAConfig config;
- config.input_type = lhs_bmm1_shape.element_type();
- config.output_type = output_shape.element_type();
-
- // Get MatmulTensorDescriptors for BMM1
- config.lhs_bmm1 =
- MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(),
- desc.lhs_bmm1_shape.layout().minor_to_major(),
- desc.bmm1_dnums.lhs_batch_dimensions(),
- desc.bmm1_dnums.lhs_contracting_dimensions());
- config.rhs_bmm1 =
- MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(),
- desc.rhs_bmm1_shape.layout().minor_to_major(),
- desc.bmm1_dnums.rhs_batch_dimensions(),
- desc.bmm1_dnums.rhs_contracting_dimensions());
-
- // Get MatmulTensorDescriptors for BMM2
- config.rhs_bmm2 =
- MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(),
- desc.rhs_bmm2_shape.layout().minor_to_major(),
- desc.bmm2_dnums.rhs_batch_dimensions(),
- desc.bmm2_dnums.rhs_contracting_dimensions());
-
- config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For(
- lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(),
- desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(),
- desc.bmm2_dnums.lhs_batch_dimensions(),
- desc.bmm2_dnums.lhs_contracting_dimensions());
-
- config.output = TensorDescriptor::For(output_type, output_shape.dimensions(),
- output_shape.layout().minor_to_major());
-
- if (desc.output_shapes.size() > 1) {
- const Shape &activation_shape = desc.output_shapes.back();
-    // Generally, the activation should have the same type as the output, but
-    // set it explicitly just to be safe.
- TF_ASSIGN_OR_RETURN(
- DataType activation_type,
- GetDNNDataTypeFromPrimitiveType(activation_shape.element_type()));
- config.activation =
- TensorDescriptor::For(activation_type, activation_shape.dimensions(),
- activation_shape.layout().minor_to_major());
- }
-
- if (desc.mask_shape) {
- const Shape &mask_shape = *desc.mask_shape;
- TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
- mask_shape.element_type()));
- config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
- mask_shape.layout().minor_to_major());
- }
-
- if (desc.bias_shape) {
- const Shape &bias_shape = *desc.bias_shape;
- TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
- bias_shape.element_type()));
- config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
- bias_shape.layout().minor_to_major());
- }
- config.kind = desc.kind;
- config.mask_type = desc.mask_type;
- const CudnnfMHABackendConfig &backend_config = desc.backend_config;
- config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
- config.fmha_scale.emplace(backend_config.fmha_scale());
- config.dropout_rate.emplace(backend_config.dropout_rate());
- config.seed.emplace(backend_config.seed());
- return config;
-}
-
-absl::StatusOr<se::dnn::FusedMHAOp::Config>
-GpufMHAConfig::AsDnnFusedMHAOpConfig() const {
- double scale = 1.0;
- if (fmha_scale.has_value()) {
- scale = *fmha_scale;
- }
- TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
- GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));
-
- return se::dnn::FusedMHAOp::Config{
- scale, lhs_bmm1, rhs_bmm1, rhs_bmm2, intermediate_lhs_bmm2,
- output, bias, activation, dropout_rate, seed,
- mask_type};
-}
-
-/*static*/ absl::StatusOr<GpufMHABackwardConfig> GpufMHABackwardConfig::For(
- const GpufMHABackwardDescriptor &desc) {
- // Get shapes from desc.
-
- const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape;
- const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape;
- const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape;
- const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape;
- const Shape &d_output_shape = desc.d_output_shape;
- const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape;
- const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape;
- const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape;
-  // Get DNN dtype from primitive types
- TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type,
- GetDNNDataTypeFromPrimitiveType(
- bmm1_grad_gemm1_rhs_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type,
- GetDNNDataTypeFromPrimitiveType(
- bmm1_grad_gemm2_rhs_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type,
- GetDNNDataTypeFromPrimitiveType(
- bmm2_grad_gemm1_lhs_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type,
- GetDNNDataTypeFromPrimitiveType(
- bmm2_grad_gemm2_rhs_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(
- DataType d_output_type,
- GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(
- DataType d_bmm1_lhs_type,
- GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(
- DataType d_bmm1_rhs_type,
- GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type()));
-
- TF_ASSIGN_OR_RETURN(
- DataType d_bmm2_rhs_type,
- GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type()));
-
- GpufMHABackwardConfig config;
- config.input_type = bmm1_grad_gemm1_rhs_shape.element_type();
- config.output_type = d_bmm1_lhs_shape.element_type();
-
- // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1
- config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For(
- bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(),
- desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(),
- desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(),
- desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions());
-
- // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2
- config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For(
- bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(),
- desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(),
- desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(),
- desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions());
-
- // Get MatmulTensorDescriptors for BMM2 grad GEMM 1
- config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For(
- bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
- desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(),
- desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(),
- desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions());
-
- config.d_output = MatmulTensorDescriptor::For(
- d_output_type, d_output_shape.dimensions(),
- desc.d_output_shape.layout().minor_to_major(),
- desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(),
- desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions());
-
- // Get MatmulTensorDescriptors for BMM2 grad GEMM 2
- config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For(
- bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(),
- desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(),
- desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(),
- desc.bmm2_grad_gemm2_dnums
- .rhs_contracting_dimensions()); // FMHA TODO: transpose here?
-
- config.d_bmm1_lhs =
- TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(),
- d_bmm1_lhs_shape.layout().minor_to_major());
- config.d_bmm1_rhs =
- TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(),
- d_bmm1_rhs_shape.layout().minor_to_major());
- config.d_bmm2_rhs =
- TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(),
- d_bmm2_rhs_shape.layout().minor_to_major());
- config.d_s = TensorDescriptor::For(
- bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
- bmm2_grad_gemm1_lhs_shape.layout().minor_to_major());
-
- if (desc.d_bias_shape) {
- const Shape &d_bias_shape = *desc.d_bias_shape;
-    // Get DNN dtype from primitive types
- TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType(
- d_bias_shape.element_type()));
- config.d_bias =
- TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(),
- d_bias_shape.layout().minor_to_major());
- }
-
- if (desc.mask_shape) {
- const Shape &mask_shape = *desc.mask_shape;
- TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
- mask_shape.element_type()));
- config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
- mask_shape.layout().minor_to_major());
- }
- if (desc.fwd_output_shape) {
- const Shape &fwd_output_shape = *desc.fwd_output_shape;
- TF_ASSIGN_OR_RETURN(
- DataType fwd_output_type,
- GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type()));
- config.fwd_output =
- TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(),
- fwd_output_shape.layout().minor_to_major());
- }
-
- if (desc.bias_shape) {
- const Shape &bias_shape = *desc.bias_shape;
- TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
- bias_shape.element_type()));
- config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
- bias_shape.layout().minor_to_major());
- }
-
- config.kind = desc.kind;
- config.mask_type = desc.mask_type;
- config.force_deterministic = desc.force_deterministic;
- const CudnnfMHABackendConfig &backend_config = desc.backend_config;
- config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
- config.fmha_scale.emplace(backend_config.fmha_scale());
- config.dropout_rate.emplace(backend_config.dropout_rate());
- config.seed.emplace(backend_config.seed());
- return config;
-}
-
-absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
-GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const {
- double scale = 1.0;
- if (fmha_scale.has_value()) {
- scale = *fmha_scale;
- }
- TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
- GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));
-
- return se::dnn::FusedMHABackwardOp::Config{scale,
- bmm1_grad_gemm1_rhs,
- bmm1_grad_gemm2_rhs,
- bmm2_grad_gemm1_lhs,
- bmm2_grad_gemm2_rhs,
- d_output,
- d_bmm1_lhs,
- d_bmm1_rhs,
- d_bmm2_rhs,
- d_s,
- d_bias,
- fwd_output,
- bias,
- dropout_rate,
- seed,
- mask_type,
- force_deterministic};
-}
-
-/*static*/ absl::StatusOr<GpufMHAParams> GpufMHAParams::For(
- const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer,
- se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer,
- se::DeviceMemoryBase output_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> activation_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer) {
- GpufMHAParams params;
- params.config = &config;
- params.lhs_bmm1_buffer = lhs_bmm1_buffer;
- params.rhs_bmm1_buffer = rhs_bmm1_buffer;
- params.rhs_bmm2_buffer = rhs_bmm2_buffer;
- params.output_buffer = output_buffer;
- params.activation_buffer = activation_buffer;
- params.bias_buffer = bias_buffer;
- params.seqlen_q_buffer = seqlen_q_buffer;
- params.seqlen_k_buffer = seqlen_k_buffer;
- return params;
-}
-
-/*static*/ absl::StatusOr<GpufMHABackwardParams> GpufMHABackwardParams::For(
- const GpufMHABackwardConfig &config,
- se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
- se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase d_output_buffer,
- se::DeviceMemoryBase d_bmm1_lhs_buffer,
- se::DeviceMemoryBase d_bmm1_rhs_buffer,
- se::DeviceMemoryBase d_bmm2_rhs_buffer,
- std::optional<se::DeviceMemoryBase> d_s_buffer,
- std::optional<se::DeviceMemoryBase> d_bias_buffer,
- std::optional<se::DeviceMemoryBase> fwd_output_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer) {
- GpufMHABackwardParams params;
- params.config = &config;
- params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer;
- params.bmm1_grad_gemm2_rhs_buffer = bmm1_grad_gemm2_rhs_buffer;
- params.bmm2_grad_gemm1_lhs_buffer = bmm2_grad_gemm1_lhs_buffer;
- params.bmm2_grad_gemm2_rhs_buffer = bmm2_grad_gemm2_rhs_buffer;
- params.d_output_buffer = d_output_buffer;
- params.d_bmm1_lhs_buffer = d_bmm1_lhs_buffer;
- params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer;
- params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer;
- params.d_s_buffer = d_s_buffer;
- params.d_bias_buffer = d_bias_buffer;
- params.fwd_output_buffer = fwd_output_buffer;
- params.bias_buffer = bias_buffer;
- params.seqlen_q_buffer = seqlen_q_buffer;
- params.seqlen_k_buffer = seqlen_k_buffer;
- return params;
-}
-
-absl::Status RunGpuFMHA(const GpufMHAConfig &fmha_config,
- se::DeviceMemoryBase lhs_bmm1_buffer,
- se::DeviceMemoryBase rhs_bmm1_buffer,
- se::DeviceMemoryBase rhs_bmm2_buffer,
- se::DeviceMemoryBase output_buffer,
- se::DeviceMemoryBase scratch_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> activation_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer,
- se::Stream *stream, RunFusedMHAOptions options) {
- TF_ASSIGN_OR_RETURN(
- GpufMHAParams params,
- GpufMHAParams::For(fmha_config, lhs_bmm1_buffer, rhs_bmm1_buffer,
- rhs_bmm2_buffer, output_buffer, bias_buffer,
- activation_buffer, seqlen_q_buffer, seqlen_k_buffer));
- PrimitiveType input_primitive_type = fmha_config.input_type;
- switch (input_primitive_type) {
- case F16:
- return RunGpuFMHAImpl<Eigen::half, Eigen::half, Eigen::half>(
- params, stream, scratch_buffer, options);
- case BF16:
- return RunGpuFMHAImpl<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16>(
- params, stream, scratch_buffer, options);
- default:
- return absl::UnimplementedError(absl::StrFormat(
- "Unimplemented fused MHA with %s", ToString(fmha_config)));
- }
- return absl::OkStatus();
-}
-
-absl::Status RunGpuFMHABackward(
- const GpufMHABackwardConfig &fmha_config,
- se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
- se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer,
- se::DeviceMemoryBase d_bmm1_lhs_buffer,
- se::DeviceMemoryBase d_bmm1_rhs_buffer,
- se::DeviceMemoryBase d_bmm2_rhs_buffer,
- std::optional<se::DeviceMemoryBase> d_s_buffer,
- std::optional<se::DeviceMemoryBase> d_bias_buffer,
- std::optional<se::DeviceMemoryBase> fwd_output_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer, se::Stream *stream,
- RunFusedMHABackwardOptions options) {
- TF_ASSIGN_OR_RETURN(
- GpufMHABackwardParams params,
- GpufMHABackwardParams::For(
- fmha_config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
- bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer,
- d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer,
- d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, fwd_output_buffer,
- bias_buffer, seqlen_q_buffer, seqlen_k_buffer));
- PrimitiveType input_primitive_type = fmha_config.input_type;
- switch (input_primitive_type) {
- case F16:
- return RunGpuFMHABackwardImpl<Eigen::half, Eigen::half, Eigen::half>(
- params, stream, scratch_buffer, options);
- case BF16:
- return RunGpuFMHABackwardImpl<Eigen::bfloat16, Eigen::bfloat16,
- Eigen::bfloat16>(params, stream,
- scratch_buffer, options);
- default:
- return Unimplemented("Unimplemented fused MHA backward");
- }
- return absl::OkStatus();
-}
-
-std::string ToString(const GpufMHAConfig &config) {
- std::string result = "GpufMHAConfig:\n";
- absl::StrAppend(&result,
- "input_type: ", PrimitiveType_Name(config.input_type), ", ");
- absl::StrAppend(
- &result, "output_type: ", PrimitiveType_Name(config.output_type), ", ");
- absl::StrAppend(&result, "Kind: ", CudnnfMHAKindToString(config.kind), ", ");
- if (config.fmha_scale) {
- absl::StrAppend(&result, "fmha_scale: ", *config.fmha_scale, ", ");
- }
- if (config.dropout_rate) {
- absl::StrAppend(&result, "dropout_rate: ", *config.dropout_rate, ", ");
- }
- if (config.seed) {
- absl::StrAppend(&result, "seed: ", *config.seed, ", ");
- }
- absl::StrAppend(&result, "Algorithm Desc: ", config.algorithm.ToString(),
- "\n");
- absl::StrAppend(&result, "lhs_bmm1: ", config.lhs_bmm1.ToString(), "\n");
- absl::StrAppend(&result, "rhs_bmm1: ", config.rhs_bmm1.ToString(), "\n");
- absl::StrAppend(&result, "rhs_bmm2: ", config.rhs_bmm2.ToString(), "\n");
- absl::StrAppend(&result, "intermediate_lhs_bmm2: ",
- config.intermediate_lhs_bmm2.ToString(), "\n");
- absl::StrAppend(&result, "output: ", config.output.ToString(), "\n");
-
- if (config.mask) {
- absl::StrAppend(&result, "mask: ", (*config.mask).ToString(), "\n");
- }
-
- if (config.bias) {
- absl::StrAppend(&result, "bias: ", (*config.bias).ToString(), "\n");
- }
-
- return result;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h
deleted file mode 100644
index d0621cb..0000000
--- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h
+++ /dev/null
@@ -1,431 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_
-#define XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-inline absl::StatusOr<xla::gpu::CudnnfMHAMaskKind> AsCudnnFmhaMaskKind(
- xla::gpu::CudnnfMHABackendConfig_MaskType mask_type) {
- switch (mask_type) {
- case xla::gpu::CudnnfMHABackendConfig::NO_MASK:
- return xla::gpu::CudnnfMHAMaskKind::kNoMask;
- case xla::gpu::CudnnfMHABackendConfig::PADDING:
- return xla::gpu::CudnnfMHAMaskKind::kPadding;
- case xla::gpu::CudnnfMHABackendConfig::CAUSAL:
- return xla::gpu::CudnnfMHAMaskKind::kCausal;
- case xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL:
- return xla::gpu::CudnnfMHAMaskKind::kPaddingCausal;
- case xla::gpu::CudnnfMHABackendConfig::ALIBI:
- return xla::gpu::CudnnfMHAMaskKind::kAlibi;
- default:
- return xla::Internal("Unknown fmha mask kind.");
- }
-}
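A hedged usage sketch for the helper above (the mask_type() accessor on the backend config is an assumption made for illustration; only AsCudnnFmhaMaskKind itself is taken from this file):

  // Map the proto mask type to the internal enum; fails on unknown values.
  TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind mask_kind,
                      AsCudnnFmhaMaskKind(backend_config.mask_type()));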
-
-// This is an interim structure to hold the parameters to construct a
-// GpufMHAConfig.
-// It describes the properties of an FMHA without being tied to a specific
-// IR and will be used to help build FMHA thunks from either XLA HLO or the
-// LHLO GPU dialect in MLIR.
-struct GpufMHADescriptor {
- CudnnfMHAKind kind;
- CudnnfMHABackendConfig backend_config;
- CudnnfMHAMaskKind mask_type;
- Shape lhs_bmm1_shape;
- Shape rhs_bmm1_shape;
- Shape rhs_bmm2_shape;
- Shape intermediate_lhs_bmm2_shape;
- // This will contain both output shape and activation shape
- absl::InlinedVector<Shape, 2> output_shapes;
- DotDimensionNumbers bmm1_dnums;
- DotDimensionNumbers bmm2_dnums;
-
- std::optional<Shape> mask_shape;
- std::optional<Shape> bias_shape;
-};
-
-struct GpufMHABackwardDescriptor {
- CudnnfMHAKind kind;
- CudnnfMHABackendConfig backend_config;
- CudnnfMHAMaskKind mask_type;
- Shape bmm1_grad_gemm1_rhs_shape;
- Shape bmm1_grad_gemm2_rhs_shape;
- Shape bmm2_grad_gemm1_lhs_shape;
- Shape bmm2_grad_gemm2_rhs_shape;
- Shape d_output_shape;
- Shape d_bmm1_lhs_shape;
- Shape d_bmm1_rhs_shape;
- Shape d_bmm2_rhs_shape;
- DotDimensionNumbers bmm1_grad_gemm1_dnums;
- DotDimensionNumbers bmm1_grad_gemm2_dnums;
- DotDimensionNumbers bmm2_grad_gemm1_dnums;
- DotDimensionNumbers bmm2_grad_gemm2_dnums;
-
- std::optional<Shape> d_s_shape;
- std::optional<Shape> fwd_output_shape;
- std::optional<Shape> mask_shape;
- std::optional<Shape> d_bias_shape;
- std::optional<Shape> bias_shape;
- bool force_deterministic;
-};
-
-// Structure to describe static properties of a GPU fused Multi-Headed
-// Attention.
-struct GpufMHAConfig {
- static absl::StatusOr<GpufMHAConfig> For(const GpufMHADescriptor& fmha_desc);
-
- absl::StatusOr<se::dnn::FusedMHAOp::Config> AsDnnFusedMHAOpConfig() const;
-
- PrimitiveType
- input_type; // Capture the primitive type of one of the inputs of BMM1
- PrimitiveType output_type;
- CudnnfMHAKind kind;
- std::optional<double> fmha_scale;
- std::optional<double> dropout_rate;
- std::optional<int64_t> seed;
-
- se::dnn::AlgorithmDesc algorithm;
- CudnnfMHAMaskKind mask_type;
- // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len]
- // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
- se::dnn::MatmulTensorDescriptor lhs_bmm1;
- se::dnn::MatmulTensorDescriptor rhs_bmm1;
- se::dnn::MatmulTensorDescriptor rhs_bmm2;
- se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2;
- se::dnn::TensorDescriptor output;
-
- std::optional<se::dnn::TensorDescriptor> activation;
- std::optional<se::dnn::TensorDescriptor> mask;
- std::optional<se::dnn::TensorDescriptor> bias;
-};
-
-// Structure to describe static properties of a GPU fused Multi-Headed
-// Attention backward.
-struct GpufMHABackwardConfig {
- static absl::StatusOr<GpufMHABackwardConfig> For(
- const GpufMHABackwardDescriptor& fmha_desc);
-
- absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
- AsDnnFusedMHABackwardOpConfig() const;
-
- PrimitiveType
- input_type; // Capture the primitive type of one of the inputs of BMM1
- PrimitiveType output_type;
- CudnnfMHAKind kind;
- std::optional<double> fmha_scale;
- std::optional<double> dropout_rate;
- std::optional<int64_t> seed;
-
- se::dnn::AlgorithmDesc algorithm;
- CudnnfMHAMaskKind mask_type;
- // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
- // d_bias -> [1, num_heads, q_seq_len, kv_seq_len]
- se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs;
- se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs;
- se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs;
- se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs;
- se::dnn::MatmulTensorDescriptor d_output;
- se::dnn::TensorDescriptor d_bmm1_lhs;
- se::dnn::TensorDescriptor d_bmm1_rhs;
- se::dnn::TensorDescriptor d_bmm2_rhs;
- std::optional<se::dnn::TensorDescriptor> d_s;
- std::optional<se::dnn::TensorDescriptor> mask;
- std::optional<se::dnn::TensorDescriptor> d_bias;
- std::optional<se::dnn::TensorDescriptor> fwd_output;
- std::optional<se::dnn::TensorDescriptor> bias;
- bool force_deterministic;
-};
-
-// Implementation struct exposed for debugging and log analysis.
-struct GpufMHAParams {
- static absl::StatusOr<GpufMHAParams> For(
- const GpufMHAConfig& config, se::DeviceMemoryBase lhs_bmm1_buffer,
- se::DeviceMemoryBase rhs_bmm1_buffer,
- se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> activation_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer);
-
- const GpufMHAConfig* config; // Not owned
- se::DeviceMemoryBase lhs_bmm1_buffer;
- se::DeviceMemoryBase rhs_bmm1_buffer;
- se::DeviceMemoryBase rhs_bmm2_buffer;
- se::DeviceMemoryBase output_buffer;
- std::optional<se::DeviceMemoryBase> activation_buffer;
- std::optional<se::DeviceMemoryBase> bias_buffer;
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer;
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer;
-};
-
-struct GpufMHABackwardParams {
- static absl::StatusOr<GpufMHABackwardParams> For(
- const GpufMHABackwardConfig& config,
- se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
- se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase d_output_buffer,
- se::DeviceMemoryBase d_bmm1_lhs_buffer,
- se::DeviceMemoryBase d_bmm1_rhs_buffer,
- se::DeviceMemoryBase d_bmm2_rhs_buffer,
- std::optional<se::DeviceMemoryBase> d_s_buffer,
- std::optional<se::DeviceMemoryBase> d_bias_buffer,
- std::optional<se::DeviceMemoryBase> fwd_output_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer);
-
- const GpufMHABackwardConfig* config; // Not owned
- se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer;
- se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer;
- se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer;
- se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer;
- se::DeviceMemoryBase d_output_buffer;
- se::DeviceMemoryBase d_bmm1_lhs_buffer;
- se::DeviceMemoryBase d_bmm1_rhs_buffer;
- se::DeviceMemoryBase d_bmm2_rhs_buffer;
- std::optional<se::DeviceMemoryBase> d_s_buffer;
- std::optional<se::DeviceMemoryBase> d_bias_buffer;
- std::optional<se::DeviceMemoryBase> fwd_output_buffer;
- std::optional<se::DeviceMemoryBase> bias_buffer;
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer;
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer;
-};
-
-class FusedMultiHeadedAttentionRunner {
- public:
- using Repr =
- std::variant<std::monostate, // To allow the default ctor
- std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>>;
-
- FusedMultiHeadedAttentionRunner() = default;
-
- explicit FusedMultiHeadedAttentionRunner(
- std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>> runner)
- : repr_(std::move(runner)) {}
-
- explicit FusedMultiHeadedAttentionRunner(Repr runner)
- : repr_(std::move(runner)) {}
-
- explicit FusedMultiHeadedAttentionRunner(const GpufMHAConfig& config)
- : FusedMultiHeadedAttentionRunner(CreateRunner(config)) {
- if (std::holds_alternative<std::monostate>(repr_)) {
- CHECK(false) << "Cannot construct FusedMultiHeadedAttentionRunner with "
- "std::monostate";
- }
- }
-
- se::dnn::AlgorithmDesc ToAlgorithmDesc() const {
- return std::visit(ToAlgorithmDescVisitor{}, repr_);
- }
-
- se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* AsFusedMHARunner() {
- CHECK(std::holds_alternative<
- std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>>(repr_));
- return std::get<
- std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>>(
- repr_)
- .get();
- }
-
- private:
- // The CreateRunner function is defined as static because it
- // doesn't need access to any non-static member variables of the
- // FusedMultiHeadedAttentionRunner class. Defining it static makes it easy to
- // use and makes it clear that it is a utility function that doesn't rely on
- // the state of any specific instance of the class.
- static Repr CreateRunner(const GpufMHAConfig& config) {
- switch (config.kind) {
- case CudnnfMHAKind::kSoftmaxDropout:
- case CudnnfMHAKind::kSoftmax:
- case CudnnfMHAKind::kScaleBiasSoftmax:
- case CudnnfMHAKind::kScaleBiasSoftmaxDropout:
- return std::make_unique<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>(
- config.algorithm);
- default:
- LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in "
- "FusedMultiHeadedAttentionRunner";
- }
- }
-
- struct ToAlgorithmDescVisitor {
- template <typename RunnerPtr>
- se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) {
- return runner->ToAlgorithmDesc();
- }
-
- se::dnn::AlgorithmDesc operator()(const std::monostate&) {
- CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc";
- }
- };
-
- Repr repr_;
-};
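The runner classes above (and the backward variant that follows) rely on the std::monostate/std::variant idiom to stay default-constructible while normally owning a lazy cuDNN op runner. The following minimal, self-contained sketch shows the same pattern with std::visit dispatching on the held alternative; FakeLazyRunner and Runner are illustrative stand-ins, not part of the patch or the StreamExecutor API.

#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

// Stand-in for se::dnn::LazyOpRunner<...>: anything exposing ToAlgorithmDesc().
struct FakeLazyRunner {
  std::string ToAlgorithmDesc() const { return "algo#0"; }
};

class Runner {
 public:
  // std::monostate makes the variant (and therefore Runner) default-
  // constructible even though there is no meaningful "empty" runner.
  using Repr = std::variant<std::monostate, std::unique_ptr<FakeLazyRunner>>;

  Runner() = default;
  explicit Runner(std::unique_ptr<FakeLazyRunner> r) : repr_(std::move(r)) {}

  std::string ToAlgorithmDesc() const {
    // std::visit dispatches to the branch matching the held alternative.
    return std::visit(
        [](const auto& alt) -> std::string {
          if constexpr (std::is_same_v<std::decay_t<decltype(alt)>,
                                       std::monostate>) {
            return "<uninitialized>";
          } else {
            return alt->ToAlgorithmDesc();
          }
        },
        repr_);
  }

 private:
  Repr repr_;
};

int main() {
  Runner uninitialized;  // holds std::monostate
  Runner ready(std::make_unique<FakeLazyRunner>());
  assert(uninitialized.ToAlgorithmDesc() == "<uninitialized>");
  assert(ready.ToAlgorithmDesc() == "algo#0");
  return 0;
}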
-
-class FusedMultiHeadedAttentionBackwardRunner {
- public:
- using Repr = std::variant<
- std::monostate, // To allow the default ctor
- std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>>;
-
- FusedMultiHeadedAttentionBackwardRunner() = default;
-
- explicit FusedMultiHeadedAttentionBackwardRunner(
- std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>
- runner)
- : repr_(std::move(runner)) {}
-
- explicit FusedMultiHeadedAttentionBackwardRunner(Repr runner)
- : repr_(std::move(runner)) {}
-
- explicit FusedMultiHeadedAttentionBackwardRunner(
- const GpufMHABackwardConfig& config)
- : FusedMultiHeadedAttentionBackwardRunner(CreateRunner(config)) {
- if (std::holds_alternative<std::monostate>(repr_)) {
- CHECK(false)
- << "Cannot construct FusedMultiHeadedAttentionBackwardRunner with "
- "std::monostate";
- }
- }
-
- se::dnn::AlgorithmDesc ToAlgorithmDesc() const {
- return std::visit(ToAlgorithmDescVisitor{}, repr_);
- }
-
- se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>*
- AsFusedMHABackwardRunner() {
- CHECK(std::holds_alternative<
- std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>>(
- repr_));
- return std::get<std::unique_ptr<
- se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>>(repr_)
- .get();
- }
-
- private:
- // The CreateRunner function is defined as static because it
- // doesn't need access to any non-static member variables of the
- // FusedMultiHeadedAttentionBackwardRunner class. Defining it static makes it
- // easy to use and makes it clear that it is a utility function that doesn't
- // rely on the state of any specific instance of the class.
- static Repr CreateRunner(const GpufMHABackwardConfig& config) {
- switch (config.kind) {
- case CudnnfMHAKind::kBackwardSoftmaxDropout:
- case CudnnfMHAKind::kBackwardSoftmax:
- case CudnnfMHAKind::kBackwardScaleBiasSoftmax:
- case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout:
- return std::make_unique<
- se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>(
- config.algorithm);
- default:
- LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in "
- "FusedMultiHeadedAttentionBackwardRunner";
- }
- }
-
- struct ToAlgorithmDescVisitor {
- template <typename RunnerPtr>
- se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) {
- return runner->ToAlgorithmDesc();
- }
-
- se::dnn::AlgorithmDesc operator()(const std::monostate&) {
- CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc";
- }
- };
-
- Repr repr_;
-};
-
-struct RunFusedMHAOptions {
- // Nullable output-parameter pointer for profiling results.
- // Profile results remain unused for now, since cuDNN FMHA currently has
- // only one algorithm.
- se::dnn::ProfileResult* profile_result = nullptr;
-
- // Use this runner cache (and its configured algorithm), instead of the one
- // from the instruction.
- FusedMultiHeadedAttentionRunner* runner_cache;
-};
-
-struct RunFusedMHABackwardOptions {
- // Nullable output-parameter pointer for profiling results.
- // Profile results remain unused for now, since cuDNN FMHA currently has
- // only one algorithm.
- se::dnn::ProfileResult* profile_result = nullptr;
-
- // Use this runner cache (and its configured algorithm), instead of the one
- // from the instruction.
- FusedMultiHeadedAttentionBackwardRunner* runner_cache;
-};
-
-absl::Status RunGpuFMHA(const GpufMHAConfig& fmha_config,
- se::DeviceMemoryBase lhs_bmm1_buffer,
- se::DeviceMemoryBase rhs_bmm1_buffer,
- se::DeviceMemoryBase rhs_bmm2_buffer,
- se::DeviceMemoryBase output_buffer,
- se::DeviceMemoryBase scratch_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> activation_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer,
- se::Stream* stream, RunFusedMHAOptions = {});
-
-absl::Status RunGpuFMHABackward(
- const GpufMHABackwardConfig& fmha_config,
- se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
- se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
- se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
- se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer,
- se::DeviceMemoryBase d_bmm1_lhs_buffer,
- se::DeviceMemoryBase d_bmm1_rhs_buffer,
- se::DeviceMemoryBase d_bmm2_rhs_buffer,
- std::optional<se::DeviceMemoryBase> d_s_buffer,
- std::optional<se::DeviceMemoryBase> d_bias_buffer,
- std::optional<se::DeviceMemoryBase> fwd_output_buffer,
- std::optional<se::DeviceMemoryBase> bias_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer, se::Stream* stream,
- RunFusedMHABackwardOptions = {});
-
-std::string ToString(const GpufMHAConfig& config);
-
-} // namespace gpu
-} // namespace xla
-#endif // XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc
index c40be16..e8710fe 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible.cc
+++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc
@@ -29,11 +29,13 @@
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/permutation_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/reduction_utils.h"
@@ -57,6 +59,103 @@
});
}
+const Shape& GetElementShape(const HloFusionAnalysis& analysis) {
+ const Shape* shape = &analysis.fusion_root(0).shape();
+ while (shape->IsTuple()) {
+ shape = &shape->tuple_shapes(0);
+ }
+ return *shape;
+}
+
+// Computes the maximum valid unroll factor for a given instruction.
+int ComputeMaxUnrollFactor(int64_t num_elements) {
+ constexpr int kMaxUnrollFactor = 4;
+ for (int i = kMaxUnrollFactor; i > 1; i /= 2) {
+ if (num_elements % i == 0) {
+ return i;
+ }
+ }
+ return 1;
+}
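As a quick illustration of ComputeMaxUnrollFactor above: it returns the largest power of two up to 4 that evenly divides the element count, falling back to 1 (no unrolling). A minimal standalone sketch of the same rule, assuming nothing beyond the standard library:

#include <cassert>
#include <cstdint>

// Mirrors ComputeMaxUnrollFactor above: the largest power of two <= 4 that
// divides num_elements, or 1 if the count is odd.
int MaxUnrollFactor(int64_t num_elements) {
  for (int i = 4; i > 1; i /= 2) {
    if (num_elements % i == 0) return i;
  }
  return 1;
}

int main() {
  assert(MaxUnrollFactor(1024) == 4);  // divisible by 4
  assert(MaxUnrollFactor(6) == 2);     // divisible by 2 but not 4
  assert(MaxUnrollFactor(7) == 1);     // odd: no unrolling
  return 0;
}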
+
+// Determines whether to enable the row-optimized codegen. When a fusion contains
+// only pointwise operations, scalar broadcasts and row broadcasts, we can emit a
+// kernel that vectorizes the row loads. This speeds up the kernel, in particular
+// on A100. The returned int is the number of inputs with rank `out_rank`; its
+// value is only meaningful when row vectorization is enabled.
+std::pair<bool /*enabled*/, int> RowVectorizationEnabled(
+ const HloFusionAdaptor& fusion, int64_t out_rank) {
+ auto roots = fusion.GetRoots();
+ const auto is_row_major = [](const HloInstruction* instr) {
+ // Only tested when the inputs are row-major, so only enable that case.
+ // It might also work if only the inner dimension is contiguous.
+ return LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout());
+ };
+ bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() &&
+ is_row_major(&roots[0].instruction());
+ if (!row_vectorized) {
+ return {false, 0};
+ }
+
+ // Check that the operations in the fusion are supported. Each
+ // supported operation (or category) must be manually vetted as XLA
+ // only unrolls and relies on LLVM to vectorize. But this is brittle.
+ // Currently tested and supported operations:
+ // Elementwise, scalar and row broadcasting.
+ //
+ // We also detect at the same time if there is a row broadcasting
+ // operation.
+ int num_big_inputs = 0;
+ bool some_row_broadcasting = false;
+ HloBfsConsumersFirstTraversal(
+ roots, fusion, [&](auto node) -> TraversalResult {
+ if (!row_vectorized) {
+ return TraversalResult::kInterrupt;
+ }
+
+ if (node.instruction().IsElementwise()) {
+ return TraversalResult::kAdvance;
+ }
+
+ switch (node.opcode()) {
+ case HloOpcode::kConstant:
+ return TraversalResult::kSkip;
+ case HloOpcode::kParameter:
+ return TraversalResult::kAdvance;
+ case HloOpcode::kBroadcast: {
+ auto dims = node.instruction().dimensions();
+ if (dims.empty()) {
+ return TraversalResult::kAdvance;
+ }
+
+ if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) {
+ some_row_broadcasting = true;
+ return TraversalResult::kAdvance;
+ }
+ TF_FALLTHROUGH_INTENDED;
+ }
+ default:
+ VLOG(2) << "Row vectorization not enabled due to: "
+ << node.ToString();
+ row_vectorized = false;
+ return TraversalResult::kInterrupt;
+ }
+ });
+ if (row_vectorized) {
+ for (const HloInstruction* argument : fusion.GetParameters()) {
+ if (argument->shape().rank() == out_rank) {
+ ++num_big_inputs;
+ }
+ if (!is_row_major(argument)) {
+ row_vectorized = false;
+ }
+ };
+ }
+ // Trigger only when there is a row broadcast.
+ return std::make_pair(row_vectorized && some_row_broadcasting,
+ num_big_inputs);
+}
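To make the traversal above concrete, here is a simplified standalone model of the decision rule; the Op enum and the flat op list are illustrative simplifications, not the real HloBfsConsumersFirstTraversal API. Row vectorization survives only if every visited op is elementwise, a parameter, a constant, a scalar broadcast, or a broadcast along the last (row) dimension, and at least one row broadcast was actually seen.

#include <cassert>
#include <vector>

// Simplified op categories for the purpose of the illustration.
enum class Op { kElementwise, kParameter, kConstant, kScalarBroadcast,
                kRowBroadcast, kOther };

// Returns true if a fusion made of these ops would keep row vectorization
// enabled: only supported ops, and at least one broadcast along the row
// (minor-most) dimension.
bool RowVectorizationEnabledModel(const std::vector<Op>& ops) {
  bool some_row_broadcasting = false;
  for (Op op : ops) {
    switch (op) {
      case Op::kRowBroadcast:
        some_row_broadcasting = true;
        break;
      case Op::kElementwise:
      case Op::kParameter:
      case Op::kConstant:
      case Op::kScalarBroadcast:
        break;
      default:
        return false;  // An unsupported op disables the optimization.
    }
  }
  return some_row_broadcasting;
}

int main() {
  assert(RowVectorizationEnabledModel(
      {Op::kParameter, Op::kRowBroadcast, Op::kElementwise}));
  assert(!RowVectorizationEnabledModel(
      {Op::kParameter, Op::kElementwise}));  // no row broadcast seen
  assert(!RowVectorizationEnabledModel(
      {Op::kParameter, Op::kOther}));        // unsupported op
  return 0;
}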
+
} // namespace
bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
@@ -612,11 +711,16 @@
// from potential x-tiling).
return 4 * 32 * 33 * primitive_size * num_variadic;
}
- } else if (GetDescriptionForTiledTransposeEmitter(instr, instr).has_value()) {
+ } else if (auto tr = GetDescriptionForTiledTransposeEmitter(instr, instr)) {
// Tile size for transposition.
int64_t primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type());
- return 32 * 33 * primitive_size;
+ int64_t bytes_required = 32 * 33 * primitive_size;
+ // If the last dimension is not changed, it becomes part of the tile.
+ if (tr->permutation.back() == tr->permutation.size() - 1) {
+ bytes_required *= tr->dimensions.back();
+ }
+ return bytes_required;
}
// Other fused expressions for now don't need the shared memory budget.
return 0;
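A worked example of the transpose tile estimate above, under the assumption that the permutation keeps the last dimension in place (as in the f32[128,1024,2] -> {1,0,2} case exercised by the new GetSharedMemoryUsage test later in this patch). The helper below is an illustrative rewrite of the same arithmetic, not the actual FusionInfoCache code: the base 32x33 tile of elements is scaled by the element size and, when the last dimension is untouched, by that dimension as well.

#include <cassert>
#include <cstdint>

// Shared-memory estimate for a tiled transpose: a 32x33 tile of elements,
// additionally multiplied by the last dimension when the permutation keeps
// that dimension in place (so it becomes part of the tile).
int64_t TransposeSharedMemoryBytes(int64_t primitive_size_bytes,
                                   bool last_dim_unchanged,
                                   int64_t last_dim_size) {
  int64_t bytes_required = 32 * 33 * primitive_size_bytes;
  if (last_dim_unchanged) bytes_required *= last_dim_size;
  return bytes_required;
}

int main() {
  // f32 (4 bytes), permutation {1,0,2} keeps the last dimension of size 2.
  assert(TransposeSharedMemoryBytes(4, /*last_dim_unchanged=*/true, 2) ==
         32 * 33 * 4 * 2);  // 8448 bytes, the value expected by the test below.
  return 0;
}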
@@ -1022,5 +1126,78 @@
return result;
}
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+ const HloFusionAnalysis& analysis) {
+ return ComputeLoopFusionConfig(analysis, GetElementShape(analysis));
+}
+
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+ const HloFusionAnalysis& analysis, const Shape& element_shape) {
+ int unroll_factor = 1;
+ // Unrolling helps when reading large inputs with small elements, thanks to
+ // vector loads, but it increases register pressure when one thread has to
+ // produce multiple output elements.
+ // Therefore, for fusions with small outputs, prefer one thread per output
+ // element, i.e. no unrolling.
+ // A fusion counts as 'small' if it uses fewer threads than the GPU provides.
+ int64_t num_elements = ShapeUtil::ElementsIn(element_shape);
+ int64_t n_threads_max = analysis.device_info().threads_per_core_limit() *
+ analysis.device_info().core_count();
+ if (num_elements >= n_threads_max &&
+ !MayPreventVectorization(analysis.fusion())) {
+ unroll_factor = ComputeMaxUnrollFactor(num_elements);
+ }
+ // CHECK that unroll_factor is a power-of-2, as needed by the logic below.
+ CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
+ // Ensure a single thread writes to a byte containing multiple values by
+ // setting unroll_factor to an appropriate number. Setting unroll_factor is
+ // safe even if the new unroll_factor doesn't divide the number of elements,
+ // as the parallel loop emitter will insert a bounds check in this case to
+ // ensure the out-of-bounds element is not computed and written. Setting
+ // unroll_factor is safe even if MayPreventVectorization returns true, as
+ // the MayPreventVectorization check is an optimization, not a correctness
+ // requirement.
+ unroll_factor = std::max(
+ unroll_factor,
+ CeilOfRatio(8, analysis.input_output_info().smallest_output_dtype_bits));
+ CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
+ VLOG(2) << "Unroll factor: " << unroll_factor;
+
+ bool row_vectorized;
+ int num_big_inputs;
+ std::tie(row_vectorized, num_big_inputs) =
+ RowVectorizationEnabled(analysis.fusion(), element_shape.rank());
+ bool few_waves = !HloAnyOf(analysis.fusion(), [&](auto instr) {
+ if (instr.opcode() == HloOpcode::kParameter ||
+ instr.opcode() == HloOpcode::kConstant ||
+ HloInstruction::IsOpElementwise(instr.opcode())) {
+ return false;
+ }
+ if (auto broadcast =
+ DynCast<HloBroadcastInstruction>(&instr.instruction())) {
+ if (broadcast->dimensions().empty() ||
+ // More than 3 big inputs cause a speed regression.
+ (row_vectorized && num_big_inputs <= 3)) {
+ return false;
+ }
+ }
+ VLOG(2) << "few_waves not enabled due to: "
+ << instr.instruction().ToString();
+ return true;
+ });
+
+ LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
+ row_vectorized};
+ // Check that the shape is supported.
+ if (launch_config.row_vectorized &&
+ ThreadsPerBlockRowVectorized(element_shape, analysis.device_info(),
+ launch_config) <= 0) {
+ VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
+ launch_config.row_vectorized = false;
+ launch_config.few_waves = false;
+ }
+ return launch_config;
+}
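One detail of ComputeLoopFusionConfig above worth spelling out is the forced minimum unroll for narrow output types: CeilOfRatio(8, smallest_output_dtype_bits) guarantees that a single thread owns every byte it writes when several values are packed into one byte. A small sketch of that clamp follows; CeilOfRatio is written out explicitly here since XLA's own helper is not shown in this patch, and the function name is illustrative.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Ceiling division, standing in for XLA's CeilOfRatio helper.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Minimum unroll factor so that one thread writes all values packed into a
// byte, mirroring the std::max(...) clamp in ComputeLoopFusionConfig above.
int MinUnrollForPackedOutputs(int current_unroll,
                              int smallest_output_dtype_bits) {
  return std::max<int64_t>(current_unroll,
                           CeilOfRatio(8, smallest_output_dtype_bits));
}

int main() {
  assert(MinUnrollForPackedOutputs(1, 32) == 1);  // f32: no extra unrolling
  assert(MinUnrollForPackedOutputs(1, 8) == 1);   // s8: one element per byte
  assert(MinUnrollForPackedOutputs(1, 4) == 2);   // s4: two values share a byte
  assert(MinUnrollForPackedOutputs(4, 4) == 4);   // existing unroll suffices
  return 0;
}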
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h
index 185c440..0dadbfa 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible.h
+++ b/third_party/xla/xla/service/gpu/gpu_fusible.h
@@ -27,12 +27,14 @@
#include "absl/synchronization/mutex.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"
// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from
-// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion.
+// GpuInstructionFusion, FusionMerger, and MultiOutputFusion.
namespace xla {
namespace gpu {
@@ -226,6 +228,12 @@
// instructions it contains.
bool MayPreventVectorization(const HloFusionAdaptor& fusion);
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+ const HloFusionAnalysis& analysis);
+
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+ const HloFusionAnalysis& analysis, const Shape& shape);
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
index 874b9da..aa3713d 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
@@ -1731,5 +1731,23 @@
module->entry_computation()));
}
+TEST_F(GpuFusibleTest, GetSharedMemoryUsage) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ wrapped_transpose {
+ p0 = f32[128,1024,2]{2,1,0} parameter(0)
+ ROOT transpose = f32[1024,128,2]{2,1,0} transpose(p0), dimensions={1,0,2}
+ }
+ ENTRY main {
+ p = f32[128,1024,2] parameter(0)
+ ROOT res = f32[1024,128,2]{2,1,0} fusion(p), kind=kInput, calls=wrapped_transpose
+ })"))
+ .value();
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+ FusionInfoCache cache;
+ auto fusion = module->entry_computation()->root_instruction();
+ EXPECT_EQ(cache.GetSharedMemoryUsage(*fusion), 32 * 33 * 2 * 4);
+}
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
index 2504b43..a6a68b4 100644
--- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
@@ -46,9 +46,9 @@
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
-#include "xla/service/gpu/gpu_schedule_postprocessing.h"
#include "xla/service/gpu/model/analytical_latency_estimator.h"
-#include "xla/service/gpu/scheduling_instruction_annotator.h"
+#include "xla/service/gpu/transforms/schedule_postprocessing.h"
+#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
#include "xla/service/hlo_memory_scheduler.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/latency_hiding_scheduler.h"
@@ -74,6 +74,9 @@
case HloOpcode::kAllReduceStart:
case HloOpcode::kCollectivePermuteStart:
return !IsSyncCollective(&instr);
+ case HloOpcode::kAsyncStart:
+ // Start async ops as early as possible to allow more concurrency.
+ return true;
case HloOpcode::kCustomCall:
return static_cast<const HloCustomCallInstruction&>(instr)
.custom_call_schedule() ==
@@ -95,6 +98,10 @@
case HloOpcode::kAllReduceDone:
case HloOpcode::kCollectivePermuteDone:
return ShouldScheduleAsEarlyAsPossible(*instr.operand(0));
+ case HloOpcode::kAsyncDone:
+ // Schedule as many other ops as possible before blocking on the
+ // completion of async ops.
+ return true;
case HloOpcode::kCustomCall:
return static_cast<const HloCustomCallInstruction&>(instr)
.custom_call_schedule() == CustomCallSchedule::SCHEDULE_LATEST;
@@ -513,8 +520,8 @@
TF_RETURN_IF_ERROR(pipeline.Run(module).status());
- HloPassPipeline postprocessing_pipeline("gpu-schedule-postprocessing");
- postprocessing_pipeline.AddPass<GpuSchedulePostprocessing>();
+ HloPassPipeline postprocessing_pipeline("schedule-postprocessing");
+ postprocessing_pipeline.AddPass<SchedulePostprocessing>();
TF_RETURN_IF_ERROR(postprocessing_pipeline.Run(module).status());
return ScheduleMetadata{memory_limit};
diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
index 0304f35..3582d40 100644
--- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -46,6 +46,7 @@
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_utils.h"
+#include "xla/tests/verified_hlo_module.h"
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
@@ -54,6 +55,7 @@
namespace xla {
namespace gpu {
+using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::tsl::testing::StatusIs;
@@ -1480,5 +1482,48 @@
}
}
+TEST_F(GpuHloScheduleTest, AsyncOps) {
+ const char* hlo_text = R"(
+ HloModule m
+
+ op1 {
+ p0 = f32[2,2] parameter(0)
+ ROOT add = f32[2,2] add(p0, p0)
+ }
+
+ op2 {
+ p0 = f32[2,2] parameter(0)
+ ROOT add = f32[2,2] add(p0, p0)
+ }
+
+ ENTRY main {
+ p0 = f32[2,2] parameter(0)
+ // The `async-start` blocks should be moved up, and the `async-done` blocks
+ // should be moved down.
+ acc1_start = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0),
+ kind=kLoop, calls=op1
+ acc1_done = f32[2,2] fusion-done(acc1_start)
+ acc2_start = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0),
+ kind=kLoop, calls=op2
+ acc2_done = f32[2,2] fusion-done(acc2_start)
+ ROOT done = f32[2,2] add(acc1_done, acc2_done)
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<xla::VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig{}));
+ SequentialHloOrdering order = BuildHloOrdering(module.get());
+
+ std::vector<HloOpcode> opcodes;
+ for (HloInstruction* instruction :
+ order.SequentialOrder(*module->entry_computation())->instructions()) {
+ opcodes.push_back(instruction->opcode());
+ }
+ EXPECT_THAT(opcodes,
+ ElementsAre(HloOpcode::kParameter, HloOpcode::kAsyncStart,
+ HloOpcode::kAsyncStart, HloOpcode::kAsyncDone,
+ HloOpcode::kAsyncDone, HloOpcode::kAdd));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc b/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc
deleted file mode 100644
index 008dbae..0000000
--- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc
+++ /dev/null
@@ -1,596 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_layout_assignment.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <initializer_list>
-#include <memory>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/host_memory_offload_annotations.h"
-#include "xla/service/logical_buffer.h"
-#include "xla/shape.h"
-#include "xla/shape_layout.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tsl/util/env_var.h"
-#include "xla/util.h"
-#include "xla/window_util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-using se::dnn::DataLayout;
-using se::dnn::FilterLayout;
-
-// Returns (input, filter, output) layouts.
-static std::tuple<DataLayout, FilterLayout, DataLayout>
-HeuristicLayoutAssignment(const HloInstruction* instr,
- const se::GpuComputeCapability& gpu_version,
- const se::dnn::VersionInfo& dnn_version) {
- // DataLayout and FilterLayout use weird enum names. Translations:
- // N <=> Batch or Output
- // C <=> Depth or Input
- // H <=> Y
- // W <=> X
- //
- // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
- //
- // If you have trouble keeping these straight, consider that all that matters
- // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?
-
- constexpr auto kAllNCHW =
- std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
- DataLayout::kBatchDepthYX);
- // kBatchDepthYX4 has the same layout as kBatchDepthYX32; they're both VECT_C
- // layouts as far as cudnn is concerned.
- constexpr auto kAllNCHW_VECT_C =
- std::make_tuple(DataLayout::kBatchDepthYX4, FilterLayout::kOutputInputYX4,
- DataLayout::kBatchDepthYX4);
- constexpr auto kAllNHWC =
- std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
- DataLayout::kBatchYXDepth);
-
- // Integer convolution must use NHWC or NCHW_VECT_C.
- //
- // TODO(jlebar): Do non-VECT_C int8_t convs still require NHWC with new
- // versions of cudnn?
- const ConvolutionDimensionNumbers& dnums =
- instr->convolution_dimension_numbers();
- Shape input_shape = instr->operand(0)->shape();
- PrimitiveType input_ty = instr->operand(0)->shape().element_type();
- if (primitive_util::IsIntegralType(input_ty)) {
- if (input_ty == S8 && dnums.input_spatial_dimensions_size() == 2 &&
- input_shape.dimensions_size() == 5) {
- VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString();
- return kAllNCHW_VECT_C;
- }
- VLOG(2) << "Using NHWC for int8_t conv " << instr->ToString();
- return kAllNHWC;
- }
-
- if (primitive_util::IsF8Type(input_ty)) {
- VLOG(2) << "Using NHWC for FP8 conv " << instr->ToString();
- return kAllNHWC;
- }
-
- const DebugOptions& debug_options =
- instr->GetModule()->config().debug_options();
-
- if (debug_options.xla_gpu_force_conv_nchw()) {
- VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
- return kAllNCHW;
- }
-
- if (debug_options.xla_gpu_force_conv_nhwc()) {
- VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
- return kAllNHWC;
- }
-
- const auto* rocm_compute_capability =
- std::get_if<se::RocmComputeCapability>(&gpu_version);
- if (rocm_compute_capability && input_ty == F16) return kAllNHWC;
-
- // If we're not Volta or not fp16/bfloat16, or not conv2D, the decision is
- // easy: Use NCHW.
- const bool isFloat16 = (input_ty == F16) || (input_ty == BF16);
- if (std::holds_alternative<se::CudaComputeCapability>(gpu_version)) {
- const auto* cuda_compute_capability =
- std::get_if<se::CudaComputeCapability>(&gpu_version);
- bool is_volta =
- cuda_compute_capability &&
- cuda_compute_capability->IsAtLeast(se::CudaComputeCapability::VOLTA);
- if (!isFloat16 || !is_volta ||
- instr->shape().tuple_shapes(0).dimensions_size() != 4) {
- return kAllNCHW;
- }
-
- // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
- // convs with stride are significantly faster with NCHW layouts.
- //
- // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
- // which on paper gives good performance. However, there are two
- // observations:
- // * a mixed layout combination is more cuDNN-bug prone, based on empirical
- // evidence.
- // * we've also observed that for mixed layouts, cuDNN transposes data back
- // and forth from a different layout combination. If we end up with
- // transposes anyway, we prefer to have them in XLA, as they can be fused.
- if (std::make_tuple(dnn_version.major_version(),
- dnn_version.minor_version()) <= std::make_tuple(7, 3) &&
- instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
- window_util::HasStride(instr->window())) {
- return kAllNCHW;
- }
- } else if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
- bool is_enabled = false;
- TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_ROCM_NHWC",
- /*default_val=*/false, &is_enabled));
- auto rocm_compute_capability =
- std::get<se::RocmComputeCapability>(gpu_version);
- if (!isFloat16 || (!rocm_compute_capability.has_nhwc_layout_support()) ||
- instr->shape().tuple_shapes(0).dimensions_size() != 4 || !is_enabled) {
- return kAllNCHW;
- }
- }
-
- VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
-
- // For other Volta f16 convolutions, use NHWC.
- return kAllNHWC;
-}
-
-// Adds layout constraints on the cudnn custom-call instruction. The layout
-// constraints are represented in terms of minor_to_major fields of both
-// operands and the output shape. Depending on the underlying algorithm, one of
-// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
-absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
- HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
- Shape lhs_shape = instr->operand(0)->shape();
- Shape rhs_shape = instr->operand(1)->shape();
- Shape result_shape = instr->shape().tuple_shapes(0);
-
- Shape* input_shape;
- Shape* filter_shape;
- Shape* output_shape;
-
- TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
- switch (kind) {
- case CudnnConvKind::kForward:
- case CudnnConvKind::kForwardActivation:
- case CudnnConvKind::kForwardGraph:
- input_shape = &lhs_shape;
- filter_shape = &rhs_shape;
- output_shape = &result_shape;
- break;
- case CudnnConvKind::kBackwardInput:
- input_shape = &result_shape;
- filter_shape = &rhs_shape;
- output_shape = &lhs_shape;
- break;
- case CudnnConvKind::kBackwardFilter:
- input_shape = &lhs_shape;
- filter_shape = &result_shape;
- output_shape = &rhs_shape;
- break;
- }
-
- {
- DataLayout input;
- FilterLayout filter;
- DataLayout output;
- std::tie(input, filter, output) =
- HeuristicLayoutAssignment(instr, gpu_version_, dnn_version_);
-
- TF_ASSIGN_OR_RETURN(
- std::tie(*input_shape->mutable_layout(),
- *filter_shape->mutable_layout(),
- *output_shape->mutable_layout()),
- StreamExecutorConvLayoutsToXlaLayouts(
- instr->convolution_dimension_numbers(), input, filter, output));
- }
-
- // The custom call returns a tuple of (actual_result, scratch_buffer);
- // call_result_buf is the logical buffer for actual_result, the thing that
- // contains the result of the conv call.
- TF_ASSIGN_OR_RETURN(
- const LogicalBuffer* call_result_buf,
- points_to_analysis_->GetBufferDefinedAt(instr, /*index=*/{0}));
-
- // Set layouts of the instructions' shapes.
- TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0));
- TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1));
- TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf));
- // For fused convolutions, instr->operand(2), if it exists, is the bias buffer.
- // There is no need to assign a layout to it, as it has only one dimension.
- // instr->operand(3), if it exists, is the side input buffer.
- if (kind == CudnnConvKind::kForwardActivation &&
- instr->operand_count() == 4) {
- // The side input layout must match the output layout.
- TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3));
- }
-
- // For graph convolutions, align the layouts of the non-scalar inputs to any
- // pointwise ops with the output layout.
- if (kind == CudnnConvKind::kForwardGraph) {
- for (int k = 2; k < instr->operand_count(); ++k) {
- if (!ShapeUtil::IsScalar(instr->operand(k)->shape())) {
- TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, k));
- }
- }
- }
-
- if (instr->operand_count() > 2 && kind != CudnnConvKind::kForwardActivation &&
- kind != CudnnConvKind::kForwardGraph) {
- return Internal(
- "Invalid convolution. Conv has a side input, but kind is not fused "
- "conv forward or graph conv foward: %s",
- instr->ToString());
- }
-
- return absl::OkStatus();
-}
-
-namespace {
-
-// Imposes the default layout with first two dimensions swapped on input
-// `shape`.
-void SetFortranLayout(Shape* shape) {
- LayoutUtil::SetToDefaultLayout(shape);
- int n = shape->mutable_layout()->minor_to_major_size();
- CHECK_GE(n, 2);
- std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
- shape->mutable_layout()->mutable_minor_to_major()->at(1));
-}
-
-bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
- const Shape& shape) {
- const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
- // If we are able to construct a `MatrixLayout` then the dot can support
- // this layout.
- return MatrixLayout::For(shape, dot_dims.lhs_batch_dimensions().size(),
- dot->operand(0)->shape().rank() -
- dot_dims.lhs_contracting_dimensions().size() -
- dot_dims.lhs_batch_dimensions().size(),
- dot_dims.rhs_batch_dimensions().size(),
- dot->operand(1)->shape().rank() -
- dot_dims.rhs_contracting_dimensions().size() -
- dot_dims.rhs_batch_dimensions().size())
- .ok();
-}
-
-} // namespace
-
-absl::Status GpuLayoutAssignment::AddBackendConstraints(
- LayoutConstraints* constraints) {
- // Add convolution constraints in reverse postorder so that the earliest
- // convolution layout propagates first. This reduces the likelihood of fusion
- // nodes with copies.
- auto post_order = constraints->computation()->MakeInstructionPostOrder();
- for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
- ++iterator) {
- HloInstruction* instruction = *iterator;
- if (IsCustomCallToDnnConvolution(*instruction)) {
- TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
- Cast<HloCustomCallInstruction>(instruction), constraints));
- }
-
- CHECK(!IsCublasGemm(*instruction))
- << "Gemm rewriting should run after layout assignment";
-
- if (instruction->opcode() == HloOpcode::kDot) {
- const Shape& output_shape = instruction->shape();
- const Shape& lhs_shape = instruction->operand(0)->shape();
- const Shape& rhs_shape = instruction->operand(1)->shape();
- const DotDimensionNumbers& dot_dims =
- instruction->dot_dimension_numbers();
-
- // Matmuls require the batch dimensions to be in consecutive physical
- // dimensions and likewise for the contracting and non-contracting
- // dimensions. Additionally, no batch dimension can be in the most
- // minor physical dimension for inputs or the output.
- absl::Span<const int64_t> lhs_batch_dims =
- dot_dims.lhs_batch_dimensions();
- absl::Span<const int64_t> lhs_contracting_dims =
- dot_dims.lhs_contracting_dimensions();
- TF_ASSIGN_OR_RETURN(std::vector<int64_t> lhs_non_contracting_dims,
- GetNonContractingDims(lhs_shape, lhs_batch_dims,
- lhs_contracting_dims));
-
- absl::Span<const int64_t> rhs_batch_dims =
- dot_dims.rhs_batch_dimensions();
- absl::Span<const int64_t> rhs_contracting_dims =
- dot_dims.rhs_contracting_dimensions();
- TF_ASSIGN_OR_RETURN(std::vector<int64_t> rhs_non_contracting_dims,
- GetNonContractingDims(rhs_shape, rhs_batch_dims,
- rhs_contracting_dims));
-
- const DebugOptions& debug_options =
- instruction->GetModule()->config().debug_options();
-
- bool is_bf16_to_bf16 =
- (output_shape.element_type() == PrimitiveType::BF16 &&
- lhs_shape.element_type() == PrimitiveType::BF16 &&
- rhs_shape.element_type() == PrimitiveType::BF16);
- bool is_s8_to_s32 = (output_shape.element_type() == PrimitiveType::S32 &&
- lhs_shape.element_type() == PrimitiveType::S8 &&
- rhs_shape.element_type() == PrimitiveType::S8 &&
- output_shape.dimensions_size() == 2 &&
- lhs_shape.dimensions_size() == 2 &&
- rhs_shape.dimensions_size() == 2);
-
- if (is_s8_to_s32 ||
- (is_bf16_to_bf16 &&
- debug_options.xla_gpu_ensure_minor_dot_contraction_dims())) {
- TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
- instruction, /*operand=*/0,
- /*dim_groups=*/
- {lhs_batch_dims, lhs_non_contracting_dims, lhs_contracting_dims}));
- TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
- instruction, /*operand=*/1,
- /*dim_groups=*/
- {rhs_batch_dims, rhs_non_contracting_dims, rhs_contracting_dims}));
- TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
- } else {
- if (!lhs_batch_dims.empty() || lhs_contracting_dims.size() > 1 ||
- lhs_non_contracting_dims.size() > 1) {
- TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 0, lhs_batch_dims,
- lhs_contracting_dims,
- lhs_non_contracting_dims));
- }
- if (!rhs_batch_dims.empty() || rhs_non_contracting_dims.size() > 1 ||
- rhs_contracting_dims.size() > 1) {
- TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 1, rhs_batch_dims,
- rhs_contracting_dims,
- rhs_non_contracting_dims));
- }
- // If we have at least one batch dimension or there is more than one
- // non-contracting dimension on lhs or rhs, we need to set a layout for
- // the dot output.
- if (!lhs_batch_dims.empty() || lhs_non_contracting_dims.size() > 1 ||
- rhs_non_contracting_dims.size() > 1) {
- TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
- }
- }
- } else if (instruction->opcode() == HloOpcode::kTranspose) {
- const HloInstruction* operand = instruction->operand(0);
- if ((operand->opcode() != HloOpcode::kDot) ||
- (operand->user_count() > 1)) {
- continue;
- }
-
- // If possible, set layout of the dot operation such that the output of
- // the transpose (as a bitcast) has the default layout.
- Shape shape = operand->shape();
- *shape.mutable_layout() =
- LayoutUtil::MakeLayoutFromMajorToMinor(instruction->dimensions());
-
- if (DotCanSupportShapeWithLayout(operand, shape)) {
- TF_RETURN_IF_ERROR(
- SetOperandLayout(shape, instruction, /*operand_no=*/0));
- }
- } else if (instruction->opcode() == HloOpcode::kFft) {
- // cuFFT requires a dim0 major layout.
- Shape op0_shape = instruction->operand(0)->shape();
- LayoutUtil::SetToDefaultLayout(&op0_shape);
- Shape output_shape = instruction->shape();
- LayoutUtil::SetToDefaultLayout(&output_shape);
- TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
- TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
- } else if (instruction->opcode() == HloOpcode::kSort &&
- instruction->operand(0)->shape().rank() > 1) {
- // Make sure that all the operands and the output(s) have the same layout.
- Shape keys_shape = instruction->operand(0)->shape();
- Layout keys_layout =
- LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
- for (int64_t i = 0; i < instruction->operand_count(); ++i) {
- Shape shape = instruction->operand(i)->shape();
- *shape.mutable_layout() = keys_layout;
- TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i));
- const LogicalBuffer* output_buffer;
- if (instruction->shape().IsArray()) {
- TF_ASSIGN_OR_RETURN(
- output_buffer,
- points_to_analysis_->GetBufferDefinedAt(instruction, {}));
- } else {
- TF_ASSIGN_OR_RETURN(
- output_buffer,
- points_to_analysis_->GetBufferDefinedAt(instruction, {i}));
- }
- TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
- }
- } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
- // TODO(phawkins): Ideally we would relax this constraint. What we
- // actually want is that:
- // a) the batch dimensions are major, in no particular order.
- // b) the two minor dimensions are in fortran (column-major) order,
- // although for the 'a' argument we could potentially accept row-major
- // order and fold the transpose into the operator.
- Shape op0_shape = instruction->operand(0)->shape();
- Shape op1_shape = instruction->operand(1)->shape();
- Shape output_shape = instruction->shape();
- SetFortranLayout(&op0_shape);
- SetFortranLayout(&op1_shape);
- SetFortranLayout(&output_shape);
- TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
- TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1));
- TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
- } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
- // XLA:GPU can only support reduce-scatter where the scatter dimension
- // is the most major dimension in the layout.
- auto ars = Cast<HloReduceScatterInstruction>(instruction);
- TF_RETURN_IF_ERROR(SetInstructionLayout(
- ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
- ars));
- } else if (instruction->opcode() == HloOpcode::kAllGather) {
- // XLA:GPU can only support all-gathers where the gather dimension is the
- // most major dimension in the layout.
- auto ag = Cast<HloAllGatherInstruction>(instruction);
- TF_RETURN_IF_ERROR(SetInstructionLayout(
- ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()),
- ag));
- } else if (instruction->opcode() == HloOpcode::kAllToAll &&
- instruction->shape().IsArray()) {
- // XLA:GPU can only support all-to-all with split dimensions where the
- // split dimension is the most major dimension in the layout.
- auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
- TF_RETURN_IF_ERROR(SetInstructionLayout(
- ShapeUtil::MoveDimToMajor(all_to_all->shape(),
- *all_to_all->split_dimension()),
- all_to_all));
- } else if (instruction->opcode() == HloOpcode::kSend) {
- Shape s = instruction->operand(0)->shape();
- LayoutUtil::SetToDefaultLayout(&s);
- TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction->operand(0)));
- TF_RETURN_IF_ERROR(
- SetArrayOperandLayout(s.layout(), instruction->operand(0), 0));
- } else if (instruction->opcode() == HloOpcode::kRecv) {
- Shape s = instruction->shape();
- ShapeUtil::ForEachMutableSubshape(
- &s, [&](Shape* subshape, const ShapeIndex& index) {
- LayoutUtil::SetToDefaultLayout(subshape);
- });
- TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction));
- }
- }
- return absl::OkStatus();
-}
-
-absl::Status GpuLayoutAssignment::SetDotOperandLayout(
- const HloInstruction* instruction, int64_t operand,
- absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
- absl::Span<const int64_t> col_dims) {
- Shape shape = instruction->operand(operand)->shape();
-
- // First, try to use the existing layout, if present.
- if (shape.has_layout() &&
- MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
- // Re-set the operand layout, so it becomes mandatory.
- return SetOperandLayout(shape, instruction, operand);
-
- // Next, try the default layout (for the sake of everybody's sanity).
- LayoutUtil::SetToDefaultLayout(&shape);
- if (MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
- return SetOperandLayout(shape, instruction, operand);
-
- // Otherwise, fallback to forcing (batch, rows, cols) layout.
- return SetOperandMajorToMinorLayout(
- instruction, operand,
- /*dim_groups=*/{batch_dims, row_dims, col_dims});
-}
-
-absl::Status GpuLayoutAssignment::SetOperandMajorToMinorLayout(
- const HloInstruction* instruction, int64_t operand,
- std::initializer_list<absl::Span<const int64_t>> dim_groups) {
- size_t size = 0;
- for (auto group : dim_groups) size += group.size();
- std::vector<int64_t> major_to_minor;
- major_to_minor.reserve(size);
- for (const auto& group : dim_groups) {
- major_to_minor.insert(major_to_minor.end(), group.begin(), group.end());
- }
-
- Shape shape = instruction->operand(operand)->shape();
- *shape.mutable_layout() =
- LayoutUtil::MakeLayoutFromMajorToMinor(major_to_minor);
- return SetOperandLayout(shape, instruction, operand);
-}
-
-absl::Status GpuLayoutAssignment::SetDotLayout(
- const HloInstruction* instruction, LayoutConstraints* constraints) {
- // If a user has requested a layout that we can support, use that.
- for (const HloInstruction* user : instruction->users()) {
- for (int64_t i = 0; i < user->operand_count(); ++i) {
- if (user->operand(i) != instruction) {
- continue;
- }
-
- const ShapeLayout* constraint = constraints->OperandLayout(user, i);
- if ((constraint != nullptr) &&
- DotCanSupportShapeWithLayout(instruction, constraint->shape())) {
- return SetInstructionLayout(constraint->shape(), instruction);
- }
- }
- }
-
- // Otherwise, use the default layout.
- return SetInstructionLayout(
- LayoutUtil::GetWithDefaultLayout(instruction->shape()), instruction);
-}
-
-bool GpuLayoutAssignment::PropagateReductionLayoutToOperand(
- const HloInstruction* user) {
- // We try to propagate a layout to make the reduction a row reduction. But
- // propagating the layout is only beneficial if the reduction emitter would be
- // used for the row reduction.
- int64_t reduction_size = 1;
- for (int64_t reduction_dim : user->dimensions()) {
- reduction_size *= user->operand(0)->shape().dimensions(reduction_dim);
- }
- int64_t kept_dimension_size = ShapeUtil::ElementsIn(user->shape());
- return IsUnnestedReductionFasterThanElemental(
- {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}});
-}
-
-bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance(
- const HloInstruction* instruction) {
- // The host offloading custom calls will be eventually removed
- // by the offloader, so we need to make sure that the calls do not change
- // the layout and thus cause layout mismatches after the removal.
- const HloCustomCallInstruction* custom_call =
- DynCast<HloCustomCallInstruction>(instruction);
- if (custom_call != nullptr &&
- (custom_call->custom_call_target() ==
- host_memory_offload_annotations::kMoveToHostCustomCallTarget ||
- custom_call->custom_call_target() ==
- host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
- return false;
- }
-
- return LayoutAssignment::InstructionCanChangeLayoutInstance(instruction);
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h b/third_party/xla/xla/service/gpu/gpu_layout_assignment.h
deleted file mode 100644
index 70741fe..0000000
--- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h
+++ /dev/null
@@ -1,81 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
-#define XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
-
-#include <cstdint>
-#include <initializer_list>
-
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/computation_layout.h"
-#include "xla/service/layout_assignment.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-// GPU-specific layout assignment pass which preassigns layouts to satisfy
-// layout constraints for operands and results of library calls.
-class GpuLayoutAssignment : public LayoutAssignment {
- public:
- explicit GpuLayoutAssignment(
- ComputationLayout* entry_computation_layout,
- const se::GpuComputeCapability& gpu_version,
- const se::dnn::VersionInfo& dnn_version,
- ChannelLayoutConstraints* channel_constraints = nullptr)
- : LayoutAssignment(entry_computation_layout, channel_constraints),
- gpu_version_(gpu_version),
- dnn_version_(dnn_version) {}
- ~GpuLayoutAssignment() override = default;
-
- protected:
- absl::Status AddBackendConstraints(LayoutConstraints* constraints) override;
-
- private:
- absl::Status AddBackendConstraintsToDnnConvCustomCall(
- HloCustomCallInstruction* instr, LayoutConstraints* constraints);
-
- // dim_groups are ordered from major to minor dimensions.
- absl::Status SetOperandMajorToMinorLayout(
- const HloInstruction* instruction, int64_t operand,
- std::initializer_list<absl::Span<const int64_t>> dim_groups);
-
- absl::Status SetDotOperandLayout(const HloInstruction* instruction,
- int64_t operand,
- absl::Span<const int64_t> batch_dims,
- absl::Span<const int64_t> row_dims,
- absl::Span<const int64_t> col_dims);
-
- absl::Status SetDotLayout(const HloInstruction* instruction,
- LayoutConstraints* constraints);
-
- bool PropagateReductionLayoutToOperand(const HloInstruction* user) override;
-
- bool InstructionCanChangeLayoutInstance(
- const HloInstruction* instruction) override;
-
- const se::GpuComputeCapability gpu_version_;
- const se::dnn::VersionInfo dnn_version_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc b/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc
deleted file mode 100644
index 81f9e00..0000000
--- a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc
+++ /dev/null
@@ -1,677 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_layout_assignment.h"
-
-#include <cstdint>
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/service/computation_layout.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_layout.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-using ::tsl::testing::IsOkAndHolds;
-
-class LayoutAssignmentTest : public HloTestBase {
- public:
- se::CudaComputeCapability GetCudaComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .cuda_compute_capability();
- }
-
- se::GpuComputeCapability GetGpuComputeCapability() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .gpu_compute_capability();
- }
-
- se::dnn::VersionInfo GetDnnVersion() {
- // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but
- // none of the tests trigger this heuristic.
- return GetDnnVersionInfoOrDefault(backend().default_stream_executor(),
- se::dnn::VersionInfo{8, 3, 0});
- }
-};
-
-TEST_F(LayoutAssignmentTest, Elementwise) {
- Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
- Shape ashape_in_row_major(ashape);
- Shape ashape_in_col_major(ashape);
- *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
- *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
-
- // Enumerate all possible combinations of layouts.
- for (const Shape& lhs_shape_with_layout :
- {ashape_in_row_major, ashape_in_col_major}) {
- for (const Shape& rhs_shape_with_layout :
- {ashape_in_row_major, ashape_in_col_major}) {
- for (const Shape& result_shape_with_layout :
- {ashape_in_row_major, ashape_in_col_major}) {
- // GpuLayoutAssignment should assign the same layout to "add" and its
- // two operands.
- auto builder = HloComputation::Builder(TestName());
- auto x = builder.AddInstruction(
- HloInstruction::CreateParameter(0, ashape, "x"));
- auto y = builder.AddInstruction(
- HloInstruction::CreateParameter(1, ashape, "y"));
- auto add = builder.AddInstruction(
- HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y));
- auto module = CreateNewVerifiedModule();
- HloComputation* computation =
- module->AddEntryComputation(builder.Build(add));
-
- ComputationLayout computation_layout(
- computation->ComputeProgramShape());
- *computation_layout.mutable_parameter_layout(0) =
- ShapeLayout(lhs_shape_with_layout);
- *computation_layout.mutable_parameter_layout(1) =
- ShapeLayout(rhs_shape_with_layout);
- *computation_layout.mutable_result_layout() =
- ShapeLayout(result_shape_with_layout);
-
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-
- for (const HloInstruction* operand : add->operands()) {
- EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
- operand->shape().layout()));
- }
- }
- }
- }
-}
-
-TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) {
- const char* hlo_text = R"(
- HloModule DotLayout
- ENTRY dot {
- p0 = f32[5,2,3]{1,2,0} parameter(0)
- p1 = f32[5,3,4]{1,2,0} parameter(1)
- ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
- lhs_batch_dims={0}, lhs_contracting_dims={2},
- rhs_batch_dims={0}, rhs_contracting_dims={1}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {1, 2, 0}),
- m::Op().WithShape(F32, {5, 3, 4}, {1, 2, 0}))
- .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
-}
-
-TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) {
- const char* hlo_text = R"(
- HloModule DotLayout
- ENTRY dot {
- p0 = f32[5,3,2] parameter(0)
- p1 = f32[5,4,3]{0,1,2} parameter(1)
- ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
- lhs_batch_dims={0}, lhs_contracting_dims={1},
- rhs_batch_dims={0}, rhs_contracting_dims={2}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 3, 2}, {2, 1, 0}),
- m::Op().WithShape(F32, {5, 4, 3}, {2, 1, 0}))
- .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
-}
-
-TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) {
- const char* hlo_text = R"(
- HloModule DotLayout
- ENTRY dot {
- p0 = f32[2,3,5]{2,1,0} parameter(0)
- p1 = f32[3,4,5] parameter(1)
- ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
- lhs_batch_dims={2}, lhs_contracting_dims={1},
- rhs_batch_dims={2}, rhs_contracting_dims={0}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(m::Op().WithShape(F32, {2, 3, 5}, {0, 1, 2}),
- m::Op().WithShape(F32, {3, 4, 5}, {1, 0, 2}))));
-}
-
-TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) {
- const char* hlo_text = R"(
- HloModule DotLayout
- ENTRY dot {
- p0 = f32[5,6,2,3] parameter(0)
- p1 = f32[6,5,3,4] parameter(1)
- ROOT dot.1330.10585 = f32[5,6,2,4] dot(p0, p1),
- lhs_batch_dims={0,1}, lhs_contracting_dims={3},
- rhs_batch_dims={1,0}, rhs_contracting_dims={2}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 6, 2, 3}, {3, 2, 1, 0}),
- m::Op().WithShape(F32, {6, 5, 3, 4}, {3, 2, 0, 1}))));
-}
-
-TEST_F(LayoutAssignmentTest, TransposedDotLayout) {
- const char* hlo_text = R"(
- HloModule DotLayout
- ENTRY dot {
- p0 = f32[5,2,3] parameter(0)
- p1 = f32[5,3,4,6] parameter(1)
- dot = f32[5,2,4,6] dot(p0, p1),
- lhs_batch_dims={0}, lhs_contracting_dims={2},
- rhs_batch_dims={0}, rhs_contracting_dims={1}
- ROOT out = f32[2,5,4,6] transpose(dot), dimensions={1,0,2,3}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Transpose(
- m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {2, 1, 0}),
- m::Op().WithShape(F32, {5, 3, 4, 6}, {3, 2, 1, 0}))
- .WithShape(F32, {5, 2, 4, 6}, {3, 2, 0, 1}))
- .WithShape(F32, {2, 5, 4, 6}, {3, 2, 1, 0})));
-}
-
-TEST_F(LayoutAssignmentTest, TransposedDotOfDotLayout) {
- const char* hlo_text = R"(
- HloModule DotLayout
- ENTRY dot {
- p0 = f32[8,50] parameter(0)
- p1 = f32[2,8,4,4] parameter(1)
- p2 = f32[4,38] parameter(2)
- dot.1 = f32[50,2,4,4]{3,2,1,0} dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={1}
- dot.2 = f32[50,2,4,38]{3,2,1,0} dot(dot.1, p2),
- lhs_contracting_dims={2}, rhs_contracting_dims={0}
- ROOT out = f32[2,50,38,4]{2,3,0,1} transpose(dot.2), dimensions={1,0,3,2}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- // The transpose layout is not supported by dot.2. Also, we need a copy
- // between dot.1 and dot.2, because the needed operand layout for the lhs
- // of dot.2 cannot be used as the layout for dot.1.
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Transpose(
- m::Dot(m::Copy(m::Dot(m::Op().WithShape(F32, {8, 50}, {1, 0}),
- m::Op().WithShape(F32, {2, 8, 4, 4},
- {3, 2, 0, 1}))
- .WithShape(F32, {50, 2, 4, 4}, {3, 2, 1, 0}))
- .WithShape(F32, {50, 2, 4, 4}, {3, 1, 0, 2}),
- m::Op().WithShape(F32, {4, 38}, {1, 0}))
- .WithShape(F32, {50, 2, 4, 38}, {3, 2, 1, 0}))
- .WithShape(F32, {2, 50, 38, 4}, {2, 3, 0, 1})));
-}
-
-TEST_F(LayoutAssignmentTest, DotLayoutS8) {
- const char* hlo_text = R"(
- HloModule DotLayout
- ENTRY int8_t {
- p0 = s8[32,64] parameter(0)
- p1 = s8[64,96] parameter(1)
- ROOT out = s32[32,96] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Dot(m::Op().WithShape(S8, {32, 64}, {1, 0}),
- m::Op().WithShape(S8, {64, 96}, {0, 1}))));
-}
-
-TEST_F(LayoutAssignmentTest, SortLayout) {
- const char* hlo_text = R"(
- HloModule SortLayout
-
- compare {
- p.0.lhs = f32[] parameter(0)
- p.0.rhs = f32[] parameter(1)
- p.1.lhs = f32[] parameter(2)
- p.1.rhs = f32[] parameter(3)
- ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
- }
-
- ENTRY sort {
- keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
- values = f32[2,3]{1,0} parameter(0)
- transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
- ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
- dimensions={1}, to_apply=compare
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Sort(m::Op().WithShape(F32, {3, 2}, {1, 0}),
- m::Op().WithShape(F32, {3, 2}, {1, 0}))));
-}
-
-TEST_F(LayoutAssignmentTest, FftLayout) {
- const char* hlo_text = R"(
- HloModule Fft_module
-
- ENTRY Fft {
- input = c64[8,32]{0,1} parameter(0)
- fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
- ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
-
- ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Copy(
- m::Transpose(m::Fft(m::Op().WithShape(C64, {8, 32}, {1, 0}))
- .WithShape(C64, {8, 32}, {1, 0})))));
-}
-
-TEST_F(LayoutAssignmentTest, CustomCallConstrainedAlias) {
- const char* module_str = R"(
-HloModule TestModule
-
-ENTRY entry {
- Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
- Arg_1 = f32[2,5,5]{2,1,0} parameter(1)
- Arg_2 = f32[2,5,5]{2,1,0} parameter(2)
- dot.0 = f32[2,5,5]{2,1,0} dot(Arg_1, Arg_2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}
- custom-call.0 = (f32[2,5,5]{1,2,0}, s8[16]{0}, s8[16]{0}) custom-call(Arg_0, dot.0), custom_call_target="dummy_call", operand_layout_constraints={f32[2,5,5]{1,2,0}, f32[2,5,5]{1,2,0}}, output_to_operand_aliasing={{0}: (1, {})}
- ROOT get-tuple-element.0 = f32[2,5,5]{1,2,0} get-tuple-element(custom-call.0), index=0
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
- ParseAndReturnVerifiedModule(module_str));
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-
- const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
- auto expect_layout = [](const Shape& shape,
- absl::Span<const int64_t> minor_to_major) {
- const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
- EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
- << "Expected layout " << expected << ", actual " << shape.layout();
- };
- expect_layout(ShapeUtil::GetSubshape(call_0->shape(), {0}), {1, 2, 0});
- expect_layout(call_0->operand(0)->shape(), {1, 2, 0});
- expect_layout(call_0->operand(1)->shape(), {1, 2, 0});
-}
-
-TEST_F(LayoutAssignmentTest, MoveToHostCustomCallConstrained) {
- const char* module_str = R"(
-HloModule TestModule
-
-ENTRY entry {
- Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
- custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToHost"
- ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
- ParseAndReturnVerifiedModule(module_str));
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-
- const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
- const Layout input_layout = call_0->operand(0)->shape().layout();
- const Layout output_layout = call_0->shape().layout();
- EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
- << "Expected the same input/output layouts. Input: " << input_layout
- << ". Output: " << output_layout;
-}
-
-TEST_F(LayoutAssignmentTest, MoveToDeviceCustomCallConstrained) {
- const char* module_str = R"(
-HloModule TestModule
-
-ENTRY entry {
- Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
- custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToDevice"
- ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
- ParseAndReturnVerifiedModule(module_str));
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-
- const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
- const Layout input_layout = call_0->operand(0)->shape().layout();
- const Layout output_layout = call_0->shape().layout();
- EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
- << "Expected the same input/output layouts. Input: " << input_layout
- << ". Output: " << output_layout;
-}
-
-TEST_F(LayoutAssignmentTest, ConvCuDNNF8) {
- if (!GetCudaComputeCapability().IsAtLeast(
- se::CudaComputeCapability::HOPPER)) {
- GTEST_SKIP() << "FP8 convolutions require HOPPER or newer architecture.";
- }
-
- const char* hlo = R"(
-
- HloModule jit_conv_general_dilated
-
- ENTRY main.4 {
- Arg_0 = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
- Arg_1 = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
- ROOT conv = f8e4m3fn[1,64,64,32]{3,2,1,0} convolution(Arg_0, Arg_1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
- }
-)";
-
- MatchOptimizedHlo(hlo, R"(
- // CHECK: [[P0:%[^ ]+]] = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
- // CHECK: [[P1:%[^ ]+]] = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
- // CHECK-NEXT: [[P2:%[^ ]+]] = f8e4m3fn[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
- // CHECK-NEXT: [[CONV:%[^ ]+]] = (f8e4m3fn[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
- )");
-}
-
-TEST_F(LayoutAssignmentTest, ConvCuDNNBF16) {
- if (!GetCudaComputeCapability().IsAtLeast(
- se::CudaComputeCapability::AMPERE)) {
- GTEST_SKIP() << "Conv with Bfloat16 uses NHWC layout for "
- "architectures with Tensor Cores.";
- }
-
- const char* hlo = R"(
-
- HloModule jit_conv_general_dilated
-
- ENTRY main.4 {
- Arg_0.1 = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
- ROOT convolution.3 = bf16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 64, 64, 16) rhs_shape=(3, 3, 16, 32) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.8/dist-packages/flax/linen/linear.py" source_line=438}
- }
-)";
-
- MatchOptimizedHlo(hlo, R"(
- // CHECK: [[P0:%[^ ]+]] = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
- // CHECK: [[P1:%[^ ]+]] = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
- // CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
- // CHECK-NEXT: %cudnn-conv.1 = (bf16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
- )");
-}
-
-TEST_F(LayoutAssignmentTest, ConvCuDNNFP16) {
- if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
- GTEST_SKIP() << "Conv with FP16 uses NHWC layout for "
- "architectures with Tensor Cores.";
- }
-
- const char* hlo = R"(
-
- HloModule jit_conv_general_dilated
-
- ENTRY main.4 {
- Arg_0.1 = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
- Arg_1.2 = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
- ROOT convolution.3 = f16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
- }
-)";
-
- MatchOptimizedHlo(hlo, R"(
- // CHECK: [[P0:%[^ ]+]] = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
- // CHECK: [[P1:%[^ ]+]] = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
- // CHECK-NEXT: [[P2:%[^ ]+]] = f16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
- // CHECK-NEXT: %cudnn-conv.1 = (f16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
- )");
-}
-
-TEST_F(LayoutAssignmentTest, ReduceOperandLayout) {
- const char* module_str = R"(
-scalar_add_computation {
- scalar_lhs = c64[] parameter(0)
- scalar_rhs = c64[] parameter(1)
- ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
- param_0 = c64[512,64,1024,32,128]{4,3,2,1,0} parameter(0)
- negate = c64[512,64,1024,32,128]{4,3,2,1,0} negate(param_0)
- constant_7 = c64[] constant((0, 0))
- ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1,3}, to_apply=scalar_add_computation
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
- ParseAndReturnVerifiedModule(module_str));
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
- auto reduce = m->entry_computation()->root_instruction();
- EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
- LayoutUtil::MakeLayout({3, 1, 4, 2, 0}).minor_to_major());
-}
-
-TEST_F(LayoutAssignmentTest, ReduceOperandLayoutDivisorOfWarpSize) {
- // Same as ReduceOperandLayout, but with a small reduction dimension that
- // is a divisor of the warp size.
- const char* module_str = R"(
-scalar_add_computation {
- scalar_lhs = c64[] parameter(0)
- scalar_rhs = c64[] parameter(1)
- ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
- param_0 = c64[512,16,1024,128]{3,2,1,0} parameter(0)
- negate = c64[512,16,1024,128]{3,2,1,0} negate(param_0)
- constant_7 = c64[] constant((0, 0))
- ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1}, to_apply=scalar_add_computation
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
- ParseAndReturnVerifiedModule(module_str));
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
- GpuLayoutAssignment layout_assignment(
- &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
- EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
- auto reduce = m->entry_computation()->root_instruction();
- EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
- LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major());
-}
-
-TEST_F(LayoutAssignmentTest, SendRcvLayout) {
- const char* hlo = R"(
-HloModule Module
-
-condition {
- p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
- ROOT lt = pred[] constant(1)
-}
-
-body {
- p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
-
- t1 = f32[100,100] get-tuple-element(p), index=0
- t = (f32[100,100], u32[], token[]) get-tuple-element(p), index=1
- sdone = token[] send-done(t), channel_id=3, frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- tk = token[] after-all()
-
-
- rcvd = (f32[100,100]{0,1}, u32[], token[]) recv(tk), channel_id=2
- zz = (f32[100,100]{0,1}, token[]) recv-done(rcvd), channel_id=2
-
- rcvd_d = get-tuple-element(zz), index=0
-
- snd = (f32[100,100]{0,1}, u32[], token[]) send(t1, tk), channel_id=3, frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- a = add(t1, t1)
-
- b = add(rcvd_d, a)
-
- ROOT tup = tuple(b, snd)
-}
-
-ENTRY %main {
- p0 = f32[100,100] parameter(0)
- tk = token[] after-all()
- snd = (f32[100,100]{0,1}, u32[], token[]) send(p0, tk), channel_id=1, frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- t = tuple(p0, snd)
- ROOT loop = while(t), condition=condition, body=body
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
- ParseAndReturnVerifiedModule(hlo));
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
-
- RunAndFilecheckHloRewrite(
- hlo,
- GpuLayoutAssignment{&computation_layout, GetGpuComputeCapability(),
- GetDnnVersion()},
- R"(
-// CHECK: (f32[100,100]{1,0}, u32[], token[]) recv
-// CHECK: (f32[100,100]{1,0}, token[]) recv-done
-// CHECK: (f32[100,100]{1,0}, u32[], token[]) send
- )");
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
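
The deleted tests above lean heavily on XLA's minor-to-major layout notation ({2,1,0} for a row-major rank-3 array, {0,1} for a column-major matrix, and so on). As a reference point, here is a minimal standalone C++ sketch of how a minor_to_major permutation translates into per-dimension strides; it uses only the standard library, and the helper name is illustrative rather than an XLA API.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative helper, not an XLA API: computes per-dimension strides for an
// array with the given dimensions and a minor_to_major layout, where
// minor_to_major[0] names the fastest-varying dimension. {1,0} is row-major
// for a matrix, {0,1} is column-major.
std::vector<int64_t> StridesForLayout(const std::vector<int64_t>& dims,
                                      const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> strides(dims.size(), 0);
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {  // from minor-most to major-most
    strides[dim] = stride;
    stride *= dims[dim];
  }
  return strides;
}

int main() {
  // [5,2,4] with layout {2,1,0}: dimension 2 is minor-most (row-major).
  for (int64_t s : StridesForLayout({5, 2, 4}, {2, 1, 0})) std::cout << s << ' ';
  std::cout << '\n';  // prints: 8 4 1
  // [5,2,4] with layout {1,2,0}: dimension 1 becomes the minor-most one.
  for (int64_t s : StridesForLayout({5, 2, 4}, {1, 2, 0})) std::cout << s << ' ';
  std::cout << '\n';  // prints: 8 1 2
}
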
diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
index 928011c..215609c 100644
--- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
@@ -31,7 +31,7 @@
#include "xla/layout.h"
#include "xla/service/buffer_value.h"
#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/stream_attribute_annotator.h"
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_memory_scheduler.h"
#include "xla/service/hlo_rematerialization.h"
diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc
index 1dcab0c..de3adb3 100644
--- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc
@@ -146,7 +146,7 @@
EXPECT_EQ(send1->channel_id(), send2->channel_id());
const char* kPeeledAttr = "_xla_send_recv_validation=\"invalid\"";
- const char* kRotatedAttr = "_xla_send_recv_validation=\"{{0,6}}\"";
+ const char* kRotatedAttr = "_xla_send_recv_validation={{0,6}}";
EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kPeeledAttr));
EXPECT_THAT(recv1->ToString(), ::testing::HasSubstr(kPeeledAttr));
EXPECT_THAT(send2->ToString(), ::testing::HasSubstr(kRotatedAttr));
diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc
deleted file mode 100644
index 7f1f800..0000000
--- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_opt_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/status_macros.h"
-#include "tsl/platform/errors.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> ReduceScatterCreator::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- const HloModuleConfig &config = module->config();
- int64_t next_channel_id = hlo_query::NextChannelId(*module);
-
- bool changed = false;
- for (HloComputation *computation :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction *instruction :
- computation->MakeInstructionPostOrder()) {
- if (instruction->opcode() != HloOpcode::kAllReduce) {
- continue;
- }
- auto *ar = Cast<HloAllReduceInstruction>(instruction);
- auto ar_spec = MatchReduceScatter(ar, config.num_partitions(),
- config.replica_count(),
- /*allow_multiple_split_dims=*/false,
- /*allow_intervening_reshape=*/true);
- if (!ar_spec) {
- VLOG(2) << "Cannot match reduce-scatter " << ar->ToString();
- continue;
- }
-
- HloInstruction *ds = ar_spec->dynamic_slice;
-
- // Convert to reduce-scatter. The output shape of the reduce-scatter
- // will be the same as the input shape, except the split dim size is
- // that of the result of the dynamic slice.
- const int64_t split_dim = ar_spec->split_dim;
- Shape scatter_shape = ar->shape();
- const int64_t split_dim_size = scatter_shape.dimensions(split_dim);
- HloInstruction *rs_input = ar->mutable_operand(0);
- const int64_t scatter_dim_size = split_dim_size / ar_spec->group_size;
- TF_RET_CHECK(scatter_dim_size * ar_spec->group_size <= split_dim_size);
- if (split_dim_size % ar_spec->group_size != 0) {
- // The dynamic-slice does not evenly split the scatter dim. In that
- // case, create a reduce-scatter with the relevant slice of the
- // all-reduce input.
- scatter_shape.set_dimensions(split_dim,
- scatter_dim_size * ar_spec->group_size);
- rs_input = computation->AddInstruction(HloInstruction::CreateSlice(
- scatter_shape, rs_input,
- std::vector<int64_t>(scatter_shape.rank(), 0),
- scatter_shape.dimensions(),
- std::vector<int64_t>(scatter_shape.rank(), 1)));
- }
- scatter_shape.set_dimensions(split_dim, scatter_dim_size);
-
- std::optional<int64_t> channel_id;
- if (ar->channel_id()) {
- // We cannot reuse the channel_id on all-reduce for reduce-scatter.
- channel_id = next_channel_id++;
- }
-
- HloInstruction *ars =
- computation->AddInstruction(HloInstruction::CreateReduceScatter(
- scatter_shape, {rs_input}, ar->to_apply(), ar->device_list(),
- ar->constrain_layout(), channel_id, ar->use_global_device_ids(),
- ar_spec->split_dim));
-
- // If there was an intervening reshape, reshape the non-split dimensions
- // to match that existing reshape. Basically we can just reshape the ars
- // result to the dynamic slice shape.
- HloInstruction *result = ars;
- HloInstruction *reshape = nullptr;
- if (ds->operand(0) != ar) {
- reshape = ds->mutable_operand(0);
- result = computation->AddInstruction(
- HloInstruction::CreateReshape(ds->shape(), result));
- }
-
- // Note that RemoveInstructionAndUnusedOperands may not always remove the
- // all-reduce operand of the dynamic-slice, so remove all the dead
- // instructions manually.
- TF_RETURN_IF_ERROR(ds->ReplaceAllUsesWith(result));
- TF_RETURN_IF_ERROR(computation->RemoveInstruction(ds));
- if (reshape) {
- TF_RETURN_IF_ERROR(computation->RemoveInstruction(reshape));
- }
- TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar));
- changed = true;
- }
- }
-
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
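
The conversion in the deleted pass above hinges on simple integer arithmetic on the split dimension: each participant keeps split_dim_size / group_size elements, and a leading slice is needed when the split does not divide evenly. A standalone sketch of just that arithmetic follows; the struct and function names are hypothetical and use no XLA types.

#include <cstdint>
#include <iostream>

// Illustrative stand-in, not an XLA type.
struct ScatterPlan {
  bool needs_slice;               // true if the split dim is not evenly divisible
  int64_t sliced_split_dim_size;  // extent actually fed to the reduce-scatter
  int64_t scatter_dim_size;       // per-participant output extent
};

ScatterPlan PlanReduceScatter(int64_t split_dim_size, int64_t group_size) {
  ScatterPlan plan;
  plan.scatter_dim_size = split_dim_size / group_size;
  plan.sliced_split_dim_size = plan.scatter_dim_size * group_size;
  plan.needs_slice = (split_dim_size % group_size != 0);
  return plan;
}

int main() {
  // Evenly divisible: no slice, each of 4 participants keeps 8 elements.
  ScatterPlan a = PlanReduceScatter(/*split_dim_size=*/32, /*group_size=*/4);
  std::cout << a.needs_slice << ' ' << a.sliced_split_dim_size << ' '
            << a.scatter_dim_size << '\n';  // prints: 0 32 8
  // Not divisible: slice the input down to 32 first, then scatter.
  ScatterPlan b = PlanReduceScatter(/*split_dim_size=*/34, /*group_size=*/4);
  std::cout << b.needs_slice << ' ' << b.sliced_split_dim_size << ' '
            << b.scatter_dim_size << '\n';  // prints: 1 32 8
}
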
diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h b/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h
deleted file mode 100644
index fcecb46..0000000
--- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_
-#define XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Transforms dynamic-slice(all-reduce) to a reduce-scatter.
-class ReduceScatterCreator : public HloModulePass {
- public:
- ReduceScatterCreator() = default;
- absl::string_view name() const override { return "reduce-scatter-creator"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc b/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc
deleted file mode 100644
index 771e8cb..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/llvm_ir/buffer_assignment_util.h"
-#include "xla/service/name_uniquer.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-
-namespace gpu {
-
-absl::StatusOr<bool> GpuSanitizeConstantNames::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
-
- NameUniquer instr_name_uniquer(/*separator=*/"_");
- // Collect the names used for the non-constant HLO instructions.
- for (HloComputation* computation : module->computations(execution_threads)) {
- for (HloInstruction* instr : computation->instructions()) {
- if (instr->opcode() == HloOpcode::kConstant) {
- continue;
- }
-
- // Record the non-constant HLO instruction name in uniquer, and keep
- // original instruction name unchanged.
- instr_name_uniquer.GetUniqueName(instr->name());
- }
- }
-
- // Sanitize the names for the constant HLO instructions and make them unique.
- // This is not merged into the above loop because we don't want this pass to
- // change the names of non-constant instructions; that is, if a constant HLO
- // conflicts with a non-constant HLO, we rename the constant HLO even though
- // the non-constant HLO comes later in the HLO module.
- for (HloComputation* computation : module->computations(execution_threads)) {
- for (HloInstruction* instr : computation->instructions()) {
- if (instr->opcode() != HloOpcode::kConstant) {
- continue;
- }
- std::string sanitized_name = llvm_ir::SanitizeConstantName(*instr);
- instr->SetAndSanitizeName(sanitized_name);
- instr->UniquifyName(&instr_name_uniquer);
- // Register this new name with the module's instruction_name_uniquer to
- // avoid name collision that might happen in future.
- module->instruction_name_uniquer().GetUniqueName(instr->name());
- changed = true;
- }
- }
-
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
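
The deleted pass above amounts to two string operations: replace '.' and '-' with '_', then uniquify the result against names that are already taken. A minimal standalone approximation using only the standard library is shown below; the real pass relies on llvm_ir::SanitizeConstantName and XLA's NameUniquer, so these helper names are illustrative only.

#include <iostream>
#include <string>
#include <unordered_set>

// Illustrative helpers, not the XLA/LLVM utilities.
std::string SanitizeName(std::string name) {
  for (char& c : name) {
    if (c == '.' || c == '-') c = '_';  // characters the PTX backend rejects
  }
  return name;
}

// Appends "_<n>" until the name no longer collides with an existing one.
std::string Uniquify(const std::string& base,
                     std::unordered_set<std::string>& taken) {
  std::string candidate = base;
  for (int suffix = 1; taken.count(candidate) > 0; ++suffix) {
    candidate = base + "_" + std::to_string(suffix);
  }
  taken.insert(candidate);
  return candidate;
}

int main() {
  std::unordered_set<std::string> taken = {"equal_to"};  // a non-constant name
  std::cout << Uniquify(SanitizeName("equal.to"), taken) << '\n';  // equal_to_1
  std::cout << Uniquify(SanitizeName("equal-to"), taken) << '\n';  // equal_to_2
}
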
diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h b/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h
deleted file mode 100644
index 08701a4..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_
-#define XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Sanitizes HLO instruction names for the GPU backend. Currently, it only
-// replaces . and - in the HLO constant instruction names with _ to please the
-// LLVM PTX backend.
-class GpuSanitizeConstantNames : public HloModulePass {
- public:
- absl::string_view name() const override { return "sanitize-constant-names"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc
deleted file mode 100644
index 17f45dc..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/literal_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-using SanitizeConstantNamesTest = HloTestBase;
-
-TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) {
- const char *const kHloString = R"(
- HloModule HyphenInInstructionName
- ENTRY kernelEntry {
- ROOT equal-to = s32[2]{0} constant({42, 73})
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
- HloInstruction *root = module->entry_computation()->root_instruction();
- EXPECT_EQ(root->name(), "equal_to");
-}
-
-TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) {
- const char *const kHloString = R"(
- HloModule HyphenInInstructionName
- ENTRY kernelEntry {
- ROOT equal.to = s32[2]{0} constant({42, 73})
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
- HloInstruction *root = module->entry_computation()->root_instruction();
- EXPECT_EQ(root->name(), "equal_to");
-}
-
-TEST_F(SanitizeConstantNamesTest, NewInstructionNameRegisteredWithModule) {
- const char *const kHloString = R"(
- HloModule HyphenInInstructionName
- ENTRY kernelEntry {
- ROOT equal.to = s32[2]{0} constant({42, 73})
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
- HloInstruction *root = module->entry_computation()->root_instruction();
- EXPECT_EQ(root->name(), "equal_to");
-
- auto constant_instr =
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1));
- constant_instr->SetAndSanitizeName("equal_to");
- module->entry_computation()->AddInstruction(std::move(constant_instr));
-
- EXPECT_THAT(FindInstruction(module.get(), "equal_to.1"),
- GmockMatch(m::Constant()));
-}
-
-TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) {
- const char *const kHloString = R"(
- HloModule BufferSanitizedName
- ENTRY kernelEntry {
- equal.to = s32[2]{0} constant({42, 73})
- equal-to = s32[2]{0} constant({67, 3})
- ROOT equal_to = s32[2]{0} add(equal.to, equal-to)
- })";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
- EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"),
- GmockMatch(m::Constant()));
- EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"),
- GmockMatch(m::Constant()));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc b/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc
deleted file mode 100644
index b03b340..0000000
--- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_scatter_expander.h"
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/primitive_util.h"
-
-namespace xla {
-
-bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
- // TODO(b/129698548): Scattering elements larger than 64 bits is not
- // supported by XLA:GPU.
- // TODO(b/227486631): Variadic scatter is not yet supported by GPU.
- return inst->opcode() == HloOpcode::kScatter &&
- (inst->shape().IsTuple() ||
- primitive_util::BitWidth(inst->shape().element_type()) > 64);
-}
-
-} // namespace xla
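
For reference, the predicate in the deleted expander above combines just two properties of a scatter: whether its result is a tuple (variadic scatter) and whether its element type is wider than 64 bits. A standalone sketch with illustrative types follows; it is not the XLA predicate itself.

#include <cstdint>
#include <iostream>

// Illustrative stand-in for the two properties the pass inspects.
struct ScatterInfo {
  bool result_is_tuple;
  int64_t element_bit_width;
};

// Expand (lower to basic HLO) the scatters the GPU backend cannot emit directly.
bool ShouldExpandScatter(const ScatterInfo& scatter) {
  return scatter.result_is_tuple || scatter.element_bit_width > 64;
}

int main() {
  std::cout << ShouldExpandScatter({false, 32}) << '\n';   // 0: emitted natively
  std::cout << ShouldExpandScatter({false, 128}) << '\n';  // 1: element too wide
  std::cout << ShouldExpandScatter({true, 32}) << '\n';    // 1: variadic scatter
}
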
diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h b/third_party/xla/xla/service/gpu/gpu_scatter_expander.h
deleted file mode 100644
index 100350c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_
-#define XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_
-
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/scatter_expander.h"
-
-namespace xla {
-
-// Legalizes scatters on the GPU.
-class GpuScatterExpander : public ScatterExpander {
- public:
- // Although we pass kEliminateAllScatters, we override this behavior in
- // InstructionMatchesPattern and select only some scatters to expand.
- GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {}
-
- absl::string_view name() const override { return "gpu_scatter_expander"; }
-
- protected:
- bool InstructionMatchesPattern(HloInstruction* inst) override;
-};
-
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc
deleted file mode 100644
index a0af798..0000000
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc
+++ /dev/null
@@ -1,165 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_schedule_postprocessing.h"
-
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-// Maps a computation to a boolean that indicates whether the computation may
-// invoke custom-calls directly or indirectly, which can eventually trigger gpu
-// synchronization.
-using CustomCallInComputation =
- absl::flat_hash_map<const HloComputation*, bool>;
-
-// Returns whether the hlo may invoke custom-calls which may trigger gpu
-// synchronization. Currently, we only check for custom-calls, because they are
-// the only operations that can be parallel with asynchronous collective
-// operations in an hlo-schedule and may trigger gpu synchronization.
-bool MayInvokeCustomCall(
- const HloInstruction* hlo,
- const CustomCallInComputation& custom_call_in_computation) {
- if (hlo->opcode() == HloOpcode::kCustomCall) {
- return true;
- }
-
- return absl::c_any_of(
- hlo->called_computations(), [&](const HloComputation* callee) {
- return custom_call_in_computation.find(callee)->second;
- });
-}
-
-// Returns true if this is an asynchronous collective start operation, excluding
-// P2P operations.
-absl::StatusOr<bool> IsRelevantAsynchronousStart(const HloInstruction* hlo) {
- if (!hlo_query::IsAsyncCollectiveStartOp(hlo,
- /*include_send_recv=*/false)) {
- return false;
- }
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- hlo->backend_config<GpuBackendConfig>());
- const CollectiveBackendConfig& collective_backend_config =
- gpu_config.collective_backend_config();
- return !collective_backend_config.is_sync();
-}
-
-// Returns true if this is a collective done operation, excluding P2P
-// operations.
-absl::StatusOr<bool> IsRelevantAsynchronousDone(const HloInstruction* hlo) {
- return hlo_query::IsAsyncCollectiveDoneOp(hlo,
- /*include_send_recv=*/false);
-}
-
-// For a given computation, finds all the asynchronous collective operations
-// that aren't parallel with custom-calls and sets their no_parallel_custom_call
-// attribute to true. Also records whether the given computation may invoke
-// custom-calls.
-absl::StatusOr<bool> ProcessComputation(
- const HloSchedule& schedule, HloComputation* computation,
- CustomCallInComputation& custom_call_in_computation) {
- bool changed = false;
- bool has_custom_call = false;
- absl::flat_hash_set<HloInstruction*> async_starts;
- const HloInstructionSequence& sequence = schedule.sequence(computation);
-
- // Visit instructions in the sequence. Collect relevant asynchronous
- // collective start ops. When we see a relevant asynchronous collective done
- // op, remove the corresponding start op from the collection and set its
- // attribute no_parallel_custom_call to true. When we see a custom-call, clear
- // the start ops from the collection and keep their attribute
- // no_parallel_custom_call as false.
- const std::vector<HloInstruction*>& all_instructions =
- sequence.instructions();
- for (HloInstruction* hlo : all_instructions) {
- if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
- async_starts.clear();
- has_custom_call = true;
- continue;
- }
- TF_ASSIGN_OR_RETURN(bool is_async_start, IsRelevantAsynchronousStart(hlo));
- if (is_async_start) {
- async_starts.insert(hlo);
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(bool is_async_done, IsRelevantAsynchronousDone(hlo));
- if (is_async_done) {
- HloInstruction* async_start = hlo->mutable_operand(0);
- if (async_starts.contains(async_start)) {
- changed = true;
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- async_start->backend_config<GpuBackendConfig>());
- CollectiveBackendConfig& collective_backend_config =
- *gpu_config.mutable_collective_backend_config();
- collective_backend_config.set_no_parallel_custom_call(true);
- TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
- async_starts.erase(async_start);
- }
- }
- }
-
- custom_call_in_computation[computation] = has_custom_call;
- return changed;
-}
-
-} // anonymous namespace
-
-absl::StatusOr<bool> GpuSchedulePostprocessing::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- if (!module->has_schedule()) return false;
- HloSchedule& schedule = module->schedule();
- bool changed = false;
- CustomCallInComputation custom_call_in_computation;
-
- // We visit computations in the order of callees to callers, as information is
- // propagated from callees to callers.
- std::vector<HloComputation*> all_computations =
- module->MakeComputationPostOrder(execution_threads);
- for (auto iter = all_computations.begin(); iter != all_computations.end();
- ++iter) {
- HloComputation* computation = *iter;
- if (computation->IsFusionComputation()) {
- custom_call_in_computation[computation] = false;
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(
- bool result,
- ProcessComputation(schedule, computation, custom_call_in_computation));
- changed |= result;
- }
-
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
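
The per-computation scan in the deleted pass above is a single pass over the scheduled sequence: async collective starts are collected, any possible custom-call clears the set, and a done whose start is still in the set proves that no custom-call was scheduled in between. The following standalone sketch reproduces that control flow with plain structs; the types and names are illustrative, not the XLA ones.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

enum class Kind { kAsyncStart, kAsyncDone, kCustomCall, kOther };

// Illustrative stand-in for a scheduled instruction.
struct Instr {
  std::string name;
  Kind kind;
  std::string start_name;  // for kAsyncDone: the start it completes
};

// Returns the names of async starts with no custom-call scheduled between
// the start and its matching done.
std::unordered_set<std::string> StartsWithNoParallelCustomCall(
    const std::vector<Instr>& sequence) {
  std::unordered_set<std::string> in_flight;
  std::unordered_set<std::string> safe;
  for (const Instr& instr : sequence) {
    switch (instr.kind) {
      case Kind::kCustomCall:
        in_flight.clear();  // anything still in flight ran parallel to it
        break;
      case Kind::kAsyncStart:
        in_flight.insert(instr.name);
        break;
      case Kind::kAsyncDone:
        if (in_flight.erase(instr.start_name) > 0) safe.insert(instr.start_name);
        break;
      case Kind::kOther:
        break;
    }
  }
  return safe;
}

int main() {
  std::vector<Instr> schedule = {
      {"ag-start", Kind::kAsyncStart, ""},
      {"my-custom-call", Kind::kCustomCall, ""},
      {"ag-done", Kind::kAsyncDone, "ag-start"},
  };
  // The custom-call sits between start and done, so nothing is marked safe.
  std::cout << StartsWithNoParallelCustomCall(schedule).size() << '\n';  // 0
}
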
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h
deleted file mode 100644
index d8eda81..0000000
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_
-#define XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Amends a schedule result with the needed information to support a runtime
-// implementation. Currently, this pass refines attribute
-// no_parallel_custom_call for asynchronous collective operations to support
-// runtime optimization, such as skipping rendezvous of all participating
-// threads for NCCL collective operations. In particular, it sets the attribute
-// value for Collective-start operations with is_sync=false; it also keeps the
-// attribute value untouched for the operations with is_sync=true and for P2P
-// operations, assuming the runtime won't use those values.
-//
-class GpuSchedulePostprocessing : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "gpu-schedule-postprocessing";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc
deleted file mode 100644
index 9d4956b..0000000
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_schedule_postprocessing.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using GpuSchedulePostprocessingTest = HloTestBase;
-
-TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) {
- constexpr absl::string_view kHloString = R"(
- HloModule module, is_scheduled=true
-
- ENTRY entry {
- pf32 = f32[1] parameter(0)
-
- all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
- ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
- }
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kHloString)));
- GpuSchedulePostprocessing pass;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(GpuSchedulePostprocessingTest, P2POpsNotChanged) {
- constexpr absl::string_view kHloString = R"(
- HloModule module, is_scheduled=true
-
- ENTRY main {
- f0 = f32[] constant(0.0)
- init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
- after-all = token[] after-all()
- recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}"
- }
- recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2
- ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0
- }
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kHloString)));
- GpuSchedulePostprocessing pass;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) {
- constexpr absl::string_view kHloString = R"(
- HloModule module, is_scheduled=true
-
- ENTRY entry {
- pf32 = f32[1] parameter(0)
- pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
- all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
- ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
- }
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kHloString)));
- GpuSchedulePostprocessing pass;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- start->backend_config<GpuBackendConfig>());
- const CollectiveBackendConfig& collective_backend_config =
- gpu_config.collective_backend_config();
- EXPECT_TRUE(collective_backend_config.no_parallel_custom_call());
-}
-
-TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) {
- constexpr absl::string_view kHloString = R"(
- HloModule module, is_scheduled=true
-
- ENTRY entry {
- pf32 = f32[1] parameter(0)
- all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
- pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
- all-gather-done = f32[2] all-gather-done(all-gather-start)
- ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
- }
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kHloString)));
- GpuSchedulePostprocessing pass;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
- EXPECT_FALSE(changed);
-
- HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- start->backend_config<GpuBackendConfig>());
- const CollectiveBackendConfig& collective_backend_config =
- gpu_config.collective_backend_config();
- EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
-}
-
-TEST_F(GpuSchedulePostprocessingTest,
- AsynchronousOpsWithParallelNestedCustomcall) {
- constexpr absl::string_view kHloString = R"(
- HloModule module, is_scheduled=true
- foo {
- v = f32[1] parameter(0)
- ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call"
- }
-
- ENTRY entry {
- pf32 = f32[1] parameter(0)
- all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
- pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo
- all-gather-done = f32[2] all-gather-done(all-gather-start)
- ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
- }
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnUnverifiedModule((kHloString)));
- GpuSchedulePostprocessing pass;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
- EXPECT_FALSE(changed);
-
- HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- start->backend_config<GpuBackendConfig>());
- const CollectiveBackendConfig& collective_backend_config =
- gpu_config.collective_backend_config();
- EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc
deleted file mode 100644
index 217387c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc
+++ /dev/null
@@ -1,343 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sort_rewriter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/runtime/cub_sort_thunk.h"
-#include "xla/service/stable_sort_expander.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Analyze sort comparer function.
-struct SortComputationAnalysis {
- int key_operand; // 0 or 1
- bool descending;
-};
-
-std::pair<int64_t, int64_t> ParametersFromCmpOperands(
- const HloCompareInstruction* cmp_op) {
- if (cmp_op == nullptr) {
- return std::pair<int64_t, int64_t>(-1, -1);
- }
- const HloParameterInstruction* param0 =
- DynCast<HloParameterInstruction>(cmp_op->operand(0));
- const HloParameterInstruction* param1 =
- DynCast<HloParameterInstruction>(cmp_op->operand(1));
- return (param0 && param1) ? std::make_pair(param0->parameter_number(),
- param1->parameter_number())
- : std::pair<int64_t, int64_t>(-1, -1);
-}
-
-// Returns sort info for compatible compare instructions. The instruction may
-// belong to a computation that has 2 or 4 parameters. If it is the root
-// instruction of a computation with 4 parameters, the analysis only succeeds
-// when 2 of the parameters are ignored.
-std::optional<SortComputationAnalysis> AnalyzeCompareOp(
- const HloInstruction* maybe_compare_op) {
- // Root instruction must be a comparison with a valid direction.
- const HloCompareInstruction* compare =
- DynCast<HloCompareInstruction>(maybe_compare_op);
- if (compare == nullptr || compare->direction() == ComparisonDirection::kEq ||
- compare->direction() == ComparisonDirection::kNe) {
- return std::nullopt;
- }
-
- // Compare should operate on the function parameters for a single tensor.
- auto [index0, index1] = ParametersFromCmpOperands(compare);
- if (index0 == -1 || index1 == -1) {
- return std::nullopt;
- }
-
- // When sorting a pair of tensors, the parameters should be adjacent.
- int first_index = std::min(index0, index1);
- if (first_index % 2 != 0 || std::max(index0, index1) != first_index + 1) {
- return std::nullopt;
- }
-
- // Return the tensor index and the sort direction.
- bool descending = compare->direction() == ComparisonDirection::kGt ||
- compare->direction() == ComparisonDirection::kGe;
- bool reverse = first_index != index0;
- return SortComputationAnalysis{first_index / 2, descending != reverse};
-}
-
-// Detects a sort with these properties:
-// - Has two operands -- one is an iota op
-// - Has a comparison computation that takes 4 inputs and compares them
-// hierarchically, so that the iota inputs are the final tie-breaker.
-//
-// The above is equivalent to a stable sort where the iota operand is completely
-// ignored. That simpler comparator is the one detected in AnalyzeCompareOp, but
-// that's insufficient, because the StableSortExpander pass expands it into the
-// more complex version detected below.
-std::optional<SortComputationAnalysis> AnalyzeComplexSortComputation(
- const HloSortInstruction& sort_op) {
- auto computation = sort_op.called_computations().front();
- if (computation->num_parameters() != 4) {
- return std::nullopt;
- }
-
- int64_t iota_operand_index =
- StableSortExpander::IotaOperandIndexForStableSort(sort_op);
- if (iota_operand_index < 0) {
- return std::nullopt;
- }
-
- auto root = computation->root_instruction();
- if (root->opcode() != HloOpcode::kSelect) {
- return std::nullopt;
- }
-
- // Check that the middle operand of the select compares the iota input.
- auto iota_cmp = DynCast<HloCompareInstruction>(root->operand(1));
- auto [iotap0, iotap1] = ParametersFromCmpOperands(iota_cmp);
- if (iota_cmp == nullptr ||
- iota_cmp->direction() != ComparisonDirection::kLt ||
- iotap0 != iota_operand_index * 2 ||
- iotap1 != iota_operand_index * 2 + 1) {
- return std::nullopt;
- }
-
- // Check that the first operand of the select is an EQ comparison of the
- // values (non-iota) input.
- auto eq_cmp = DynCast<HloCompareInstruction>(root->operand(0));
- if (eq_cmp == nullptr || eq_cmp->direction() != ComparisonDirection::kEq) {
- return std::nullopt;
- }
-
- // EQ comparison case 1: direct comparison of parameters
- auto [p0, p1] = ParametersFromCmpOperands(eq_cmp);
- if (p0 < 0 || p1 < 0) {
- // EQ comparison case 2: comparison of comparisons. This is what
- // the StableSortExpander pass currently generates.
- auto cmp = DynCast<HloCompareInstruction>(eq_cmp->operand(0));
- auto cmp_reverse = DynCast<HloCompareInstruction>(eq_cmp->operand(1));
- auto [a, b] = ParametersFromCmpOperands(cmp);
- auto [p, q] = ParametersFromCmpOperands(cmp_reverse);
- if (cmp == nullptr || cmp_reverse == nullptr || a < 0 || b < 0 || a != q ||
- b != p || cmp->direction() != cmp_reverse->direction() ||
- cmp->direction() == Comparison::Direction::kEq ||
- cmp->direction() == Comparison::Direction::kNe) {
- return std::nullopt;
- }
- }
-
- // At this point only the last operand of the select needs to be verified.
- return AnalyzeCompareOp(root->operand(2));
-}
-
-std::optional<SortComputationAnalysis> AnalyzeSortOp(
- const HloSortInstruction& sort_op) {
- auto computation = sort_op.called_computations().front();
-
- // First, check if the computation is a simple compare op on the operands.
- auto result = AnalyzeCompareOp(computation->root_instruction());
- if (!result.has_value()) {
- // If the above fails, check if the sort instruction and comparer are more
- // complex, like what is produced by the StableSortExpander pass.
- result = AnalyzeComplexSortComputation(sort_op);
- }
- return result;
-}
-
-// Create runner for CUB sort operation.
-absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner(
- HloSortInstruction* sort_op, const SortComputationAnalysis& sort_config) {
- int value_index = 1 - sort_config.key_operand;
- return CubSortRunnerInterface::Create(
- sort_op->operand(sort_config.key_operand)->shape().element_type(),
- sort_op->operand_count() == 2
- ? std::optional(sort_op->operand(value_index)->shape().element_type())
- : std::nullopt);
-}
-
-// Verify that the sort tensor shape is supported by CUB.
-bool IsCubCompatibleSort(HloSortInstruction* sort_op) {
- VLOG(1) << "Sort instruction: " << sort_op->name();
- if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) {
- VLOG(2) << "Unsupported operand count: " << sort_op->operand_count();
- return false;
- }
-
- const Shape& operand_shape = sort_op->operand(0)->shape();
- if (sort_op->sort_dimension() != operand_shape.rank() - 1) {
- VLOG(2) << "Sort dimension should be the minor one";
- return false;
- }
- if (Product(operand_shape.dimensions()) <
- GpuSortRewriter::SortSizeThreshold()) {
- VLOG(2) << "Tensor shape size is too small to see an improvement";
- return false;
- }
-
- auto sort_config = AnalyzeSortOp(*sort_op);
- if (!sort_config.has_value()) {
- VLOG(2) << "Only simple compare computations are supported";
- return false;
- }
- if (!CreateRunner(sort_op, *sort_config).ok()) {
- VLOG(2) << "Unsupported operand types (no compiled CUB kernels)";
- return false;
- }
- VLOG(2) << "Sort operation is compatible";
- return true;
-}
-
-// Restore the result shape after sorting a pair of tensors.
-// The last element of the custom call result tuple is the scratch buffer,
-// which is discarded.
-HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
- HloInstruction* custom_call, bool swap) {
- HloComputation* parent = sort_op->parent();
- HloInstruction* gte0 =
- parent->AddInstruction(HloInstruction::CreateGetTupleElement(
- sort_op->operand(0)->shape(), custom_call, swap ? 1 : 0));
- HloInstruction* gte1 =
- parent->AddInstruction(HloInstruction::CreateGetTupleElement(
- sort_op->operand(1)->shape(), custom_call, swap ? 0 : 1));
- return parent->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
-}
-
-} // namespace
-
-// Rewrites a single sort instruction with a custom call.
-absl::StatusOr<bool> GpuSortRewriter::RunOnInstruction(
- HloSortInstruction* sort_op) {
- // Get the sort tensor index and direction.
- SortComputationAnalysis sort_config = AnalyzeSortOp(*sort_op).value();
-
- // Get scratch size requirements from CUB.
- const Shape& operand_shape = sort_op->operand(0)->shape();
- int64_t batch_size = Product(operand_shape.dimensions()) /
- operand_shape.dimensions(sort_op->sort_dimension());
-
- TF_ASSIGN_OR_RETURN(auto runner, CreateRunner(sort_op, sort_config));
- TF_ASSIGN_OR_RETURN(
- int64_t scratch_size,
- runner->GetScratchSize(Product(operand_shape.dimensions()), batch_size));
-
- // Align and increase scratch size to fit the offsets.
- if (batch_size > 1) {
- scratch_size += sizeof(int) - scratch_size % sizeof(int);
- scratch_size += (batch_size + 1) * sizeof(int);
- }
-
- // Values are only present if sorting a pair of tensors.
- HloInstruction* keys = sort_op->mutable_operand(0);
- HloInstruction* values = nullptr;
- if (sort_op->operand_count() == 2) {
- values = sort_op->mutable_operand(1);
- if (sort_config.key_operand == 1) {
- std::swap(keys, values);
- }
- }
-
- // Build the resulting shape for the custom call.
- std::vector<Shape> shapes{keys->shape()};
- std::vector<HloInstruction*> operands{keys};
- if (values != nullptr) {
- shapes.push_back(values->shape());
- operands.push_back(values);
- }
- shapes.push_back(ShapeUtil::MakeShape(U8, {scratch_size}));
- Shape call_shape = ShapeUtil::MakeTupleShape(absl::MakeSpan(shapes));
-
- // Build the custom call instruction.
- HloInstruction* custom_call =
- sort_op->parent()->AddInstruction(HloInstruction::CreateCustomCall(
- call_shape, absl::MakeSpan(operands), kCubDeviceRadixSortTarget));
-
- xla::SortOptions backend_config;
- backend_config.set_descending(sort_config.descending);
- TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
-
- // Build the replacement instruction.
- HloInstruction* replacement;
- if (sort_op->operand_count() == 1) {
- replacement =
- sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
- sort_op->shape(), custom_call, 0));
- } else {
- replacement = UnpackResultPair(sort_op, custom_call,
- /*swap=*/sort_config.key_operand == 1);
- }
-
- // Replace sort operation with custom call followed by GTE.
- TF_RETURN_IF_ERROR(
- sort_op->parent()->ReplaceInstruction(sort_op, replacement));
- return true;
-}
-
-// Rewrites the sorts in the given computation into calls to CUB.
-absl::StatusOr<bool> GpuSortRewriter::RunOnComputation(
- HloComputation* computation) {
- std::vector<HloSortInstruction*> sort_ops;
- for (auto* inst : computation->instructions()) {
- HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
- if (sort != nullptr && IsCubCompatibleSort(sort)) {
- sort_ops.push_back(sort);
- }
- }
- bool changed = false;
- for (auto* sort : sort_ops) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(sort));
- changed |= result;
- }
- return changed;
-}
-
-// Replace compatible sort operations with custom calls.
-absl::StatusOr<bool> GpuSortRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), before:\n" + module->ToString());
- bool changed = false;
- for (HloComputation* computation :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
- changed |= result;
- }
- XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), after:\n" + module->ToString());
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h
deleted file mode 100644
index 51dba3c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
-#define XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites sort operations into CustomCall HLOs that call into CUB.
-// Only a subset of shapes is supported - either a single tensor with a simple
-// compare function or a pair of tensors where keys are unsigned integers.
-
-class GpuSortRewriter : public HloModulePass {
- public:
- absl::string_view name() const override { return "gpu-sort-rewriter"; }
-
- // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite
- // tensors with sizes below this limit.
- static int SortSizeThreshold() { return sort_size_threshold_; }
- static void SetSortSizeThresholdForTestingOnly(int threshold) {
- // We need to be able to reduce the threshold for testing, so that the tests
- // can run and compare against the reference interpreter, which is quite
- // slow.
- sort_size_threshold_ = threshold;
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- absl::StatusOr<bool> RunOnInstruction(HloSortInstruction* sort_op);
- absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
-
- static inline int sort_size_threshold_ = 16385;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
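
For orientation, a minimal sketch of how the pass declared in the header above is driven,
mirroring the RunModuleAndPass helper in the test file deleted further below. The include
path is the pre-move one from this hunk and the wrapper function name is illustrative only;
where the file lands after this change is not shown in this hunk.

  #include "absl/status/statusor.h"
  #include "xla/hlo/ir/hlo_module.h"
  #include "xla/service/gpu/gpu_sort_rewriter.h"  // pre-move path; assumed relocated by this change

  namespace xla::gpu {

  // Illustrative wrapper: runs GpuSortRewriter over a module, as the deleted
  // test's RunModuleAndPass helper did. Returns true if any sort was rewritten.
  absl::StatusOr<bool> RewriteSortsForTesting(HloModule* module) {
    // Lower the size threshold so that small test inputs are still rewritten.
    GpuSortRewriter::SetSortSizeThresholdForTestingOnly(1000);
    GpuSortRewriter rewriter;
    // Single-argument Run overload inherited from HloPassInterface, as in the test.
    return rewriter.Run(module);
  }

  }  // namespace xla::gpu
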
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc
deleted file mode 100644
index abacbc1..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sort_rewriter.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> GpuSortRewriter::RunOnInstruction(
- HloSortInstruction* sort_op) {
- return false;
-}
-
-absl::StatusOr<bool> GpuSortRewriter::RunOnComputation(
- HloComputation* computation) {
- return false;
-}
-
-absl::StatusOr<bool> GpuSortRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- return false;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc
deleted file mode 100644
index 69cdb92..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc
+++ /dev/null
@@ -1,453 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sort_rewriter.h"
-
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuSortRewriterTest : public HloTestBase {
- public:
- void SetUp() override {
- HloTestBase::SetUp();
- GpuSortRewriter::SetSortSizeThresholdForTestingOnly(1000);
- }
-
- bool RunModuleAndPass(HloModule* module) {
- auto cloned = module->Clone();
- bool changed = GpuSortRewriter().Run(module).value();
- if (changed) {
- // Here we run an end-to-end test to make sure that GpuSortRewriter does
- // not introduce an incorrect rewrite. To do this, we need to clone the
- // original module because the interpreter cannot process the already
- // optimized module.
- EXPECT_TRUE(RunAndCompare(std::move(cloned), ErrorSpec{0, 0}));
- }
- return changed;
- }
-
- void ExpectDirection(const HloInstruction* instruction, bool descending) {
- auto config = instruction->backend_config<xla::SortOptions>();
- EXPECT_EQ(config->descending(), descending);
- }
-};
-
-// Basic sort: ascending.
-TEST_F(GpuSortRewriterTest, SortKeysLessThan) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input = f32[1000] parameter(0)
- ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
- ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
- /*descending=*/false);
-}
-
-// Basic sort: descending.
-TEST_F(GpuSortRewriterTest, SortKeysGreaterThan) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
-}
-
-ENTRY %main {
- %input = f32[1000] parameter(0)
- ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
- ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
- /*descending=*/true);
-}
-
-// Comparer swaps the parameter order -> direction is reversed.
-TEST_F(GpuSortRewriterTest, SortKeysGreaterThanSwapped) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(1)
- %rhs = f32[] parameter(0)
- ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
-}
-
-ENTRY %main {
- %input = f32[1000] parameter(0)
- ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
- ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
- /*descending=*/false);
-}
-
-// Sort a pair of tensors, keys go first.
-TEST_F(GpuSortRewriterTest, SortPairs) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs_key = u32[] parameter(0)
- %rhs_key = u32[] parameter(1)
- %lhs_value = f32[] parameter(2)
- %rhs_value = f32[] parameter(3)
- ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
-}
-
-ENTRY %main {
- %input_keys = u32[1000] parameter(0)
- %input_values = f32[1000] parameter(1)
- ROOT %sort = (u32[1000], f32[1000]) sort(%input_keys, %input_values),
- dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
- m::GetTupleElement(m::CustomCall(), 1))));
-}
-
-// Sort a pair of tensors, keys go last.
-TEST_F(GpuSortRewriterTest, SortPairsSwapped) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs_value = f32[] parameter(0)
- %rhs_value = f32[] parameter(1)
- %lhs_key = u32[] parameter(2)
- %rhs_key = u32[] parameter(3)
- ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
-}
-
-ENTRY %main {
- %input_values = f32[1000] parameter(0)
- %input_keys = u32[1000] parameter(1)
- ROOT %sort = (f32[1000], u32[1000]) sort(%input_values, %input_keys),
- dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 1),
- m::GetTupleElement(m::CustomCall(), 0))));
-}
-
-// CUB sort doesn't support more than two tensors.
-TEST_F(GpuSortRewriterTest, NoRewriteManyTensors) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- %unused1 = f64[] parameter(2)
- %unused2 = f64[] parameter(3)
- %unused3 = u64[] parameter(4)
- %unused4 = u64[] parameter(5)
- ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input1 = f32[1000] parameter(0)
- %input2 = f64[1000] parameter(1)
- %input3 = u64[1000] parameter(2)
- ROOT %sort = (f32[1000], f64[1000], u64[1000]) sort(%input1, %input2, %input3),
- dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// The sort dimension must be the most minor (last) dimension.
-TEST_F(GpuSortRewriterTest, NoRewriteNonMinorSortDimension) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input = f32[1000,4] parameter(0)
- ROOT %sort = f32[1000,4] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Kernels are compiled for a subset of types.
-TEST_F(GpuSortRewriterTest, NoRewriteUnsupportedType) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = pred[] parameter(0)
- %rhs = pred[] parameter(1)
- ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input = pred[1000] parameter(0)
- ROOT %sort = pred[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Comparer must be a simple function.
-TEST_F(GpuSortRewriterTest, NoRewriteComplexComparer) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %lhs_scaled = f32[] multiply(%lhs, f32[] constant(2))
- %rhs = f32[] parameter(1)
- ROOT %lt = pred[] compare(%lhs_scaled, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input = f32[1000] parameter(0)
- ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Comparer must use adjacent input values.
-TEST_F(GpuSortRewriterTest, NoRewriteMixedKeysValues) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs_key = u32[] parameter(0)
- %rhs_key = u32[] parameter(1)
- %lhs_value = u32[] parameter(2)
- %rhs_value = u32[] parameter(3)
- ROOT %mixed = pred[] compare(%rhs_key, %lhs_value), direction=LT
-}
-
-ENTRY %main {
- %input_keys = u32[1000] parameter(0)
- %input_values = u32[1000] parameter(1)
- ROOT %sort = (u32[1000], u32[1000]) sort(%input_keys, %input_values),
- dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Small shapes do not see improvement from CUB sort.
-TEST_F(GpuSortRewriterTest, NoRewriteSmallSize) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input = f32[100] parameter(0)
- ROOT %sort = f32[100] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Basic sort: with batch dimension.
-TEST_F(GpuSortRewriterTest, SortWithBatchDim) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input = f32[10,100] parameter(0)
- ROOT %sort = f32[10,100] sort(%input), dimensions={1}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
- ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
- /*descending=*/false);
-}
-
-// Basic sort: with multiple batch dimensions.
-TEST_F(GpuSortRewriterTest, SortWithMultipleBatchDims) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
- %input = f32[10,10,10] parameter(0)
- ROOT %sort = f32[10,10,10] sort(%input), dimensions={2}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
- ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
- /*descending=*/false);
-}
-
-// Sort a pair of tensors (values, indices generated by iota) with a complex
-// compare.
-TEST_F(GpuSortRewriterTest, SortPairsIotaComparerSimple) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = u16[] parameter(0)
- %rhs = u16[] parameter(1)
- %lhs_index = s32[] parameter(2)
- %rhs_index = s32[] parameter(3)
-
- cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
- cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
- cmp_eq = pred[] compare(%lhs, %rhs), direction=EQ
-
- ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
-}
-
-ENTRY %main {
- %inputs = u16[1000] parameter(0)
- %iota = s32[1000] iota(), iota_dimension=0
- ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
- dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
- m::GetTupleElement(m::CustomCall(), 1))));
-}
-
-// Sort a pair of tensors (values, indices generated by iota) with a complex
-// compare computation that matches the output of the StableSortExpander pass.
-TEST_F(GpuSortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) {
- constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
- %lhs = u16[] parameter(0)
- %rhs = u16[] parameter(1)
- %lhs_index = s32[] parameter(2)
- %rhs_index = s32[] parameter(3)
-
- cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
- cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
- cmp_rl = pred[] compare(%rhs, %lhs), direction=GT
- cmp_eq = pred[] compare(cmp_lr, cmp_rl), direction=EQ
-
- ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
-}
-
-ENTRY %main {
- %inputs = u16[1000] parameter(0)
- %iota = s32[1000] iota(), iota_dimension=0
- ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
- dimensions={0}, to_apply=%compare
-})";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
- EXPECT_TRUE(RunModuleAndPass(module.get()));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
- m::GetTupleElement(m::CustomCall(), 1))));
-}
-
-TEST_F(GpuSortRewriterTest, SortSizeThresholdIsSet) {
- EXPECT_EQ(GpuSortRewriter::SortSizeThreshold(), 1000);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
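
One detail of the deleted rewriter worth restating: for batched sorts it pads the CUB
scratch allocation so that per-batch offsets can be stored behind the sort scratch. Below
is a standalone restatement of that arithmetic, with an illustrative helper name and
worked numbers; the logic mirrors the scratch-size handling in the deleted RunOnInstruction.

  #include <cstdint>

  // Mirrors the scratch padding in the deleted GpuSortRewriter::RunOnInstruction:
  // round the CUB scratch size up past the next int boundary, then reserve
  // batch_size + 1 int offsets delimiting the batches.
  int64_t PaddedScratchSize(int64_t scratch_size, int64_t batch_size) {
    if (batch_size > 1) {
      scratch_size += sizeof(int) - scratch_size % sizeof(int);
      scratch_size += (batch_size + 1) * sizeof(int);
    }
    return scratch_size;
  }

  // Example: scratch_size = 1021, batch_size = 10:
  //   1021 -> 1024 (int-aligned), then 1024 + 11 * 4 = 1068 bytes in total.
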
diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc
index d84797d..06e1e6f 100644
--- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc
+++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc
@@ -27,7 +27,7 @@
#include "xla/service/algebraic_simplifier.h"
#include "xla/service/conditional_simplifier.h"
#include "xla/service/gather_expander.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
#include "xla/service/hlo_constant_folding.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_module_config.h"
diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc
deleted file mode 100644
index 8f5e261..0000000
--- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc
+++ /dev/null
@@ -1,1150 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_windowed_einsum_handler.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/shape_inference.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = match;
-
-// Enables the creation of FP8 GEMM Custom Calls for all-gather and
-// reduce-scatter windowed einsums in gemm_rewriter.cc by moving the scalings
-// and type conversions of FP8 operands into the bodies of their while loops,
-// i.e. rewrites
-//
-// inputs --> dequant --> while loop {dynamic-slice/collective-permute/dot}
-//
-// into
-//
-// inputs --> while loop {dequant --> dynamic-slice/collective-permute/dot}.
-absl::Status ShiftDequantizationF8(const HloComputation* comp,
- const std::array<HloInstruction*, 2>& gte) {
- HloInstruction* while_instr = comp->WhileCallInstruction();
- if (!while_instr) {
- return absl::OkStatus();
- }
-
- // Identify the scalings and type conversions applied to the inputs of the
- // while loop.
- HloInstruction* param_tuple = while_instr->mutable_operand(0);
- std::array<HloInstruction*, 2> binaries, operands, scales;
- for (int k = 0; k < 2; ++k) {
- if (!Match(param_tuple->mutable_operand(k),
- m::AnyOf<HloInstruction>(
- m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])),
- m::Broadcast(m::Op(&scales[k]))),
- m::MultiplyAnyOrder(&binaries[k],
- m::Convert(m::Op(&operands[k])),
- m::Broadcast(m::Op(&scales[k])))))) {
- VLOG(5) << "Unable to identify FP8 dequantization pattern.";
- return absl::OkStatus();
- }
- }
-
- // For the dot to be rewritten by gemm_rewriter.cc into an FP8 GEMM, at most
- // one of the inputs can be F8E5M2.
- std::array<PrimitiveType, 2> operand_types{
- operands[0]->shape().element_type(), operands[1]->shape().element_type()};
- if (!((operand_types[0] == F8E4M3FN && operand_types[1] == F8E4M3FN) ||
- (operand_types[0] == F8E4M3FN && operand_types[1] == F8E5M2) ||
- (operand_types[0] == F8E5M2 && operand_types[1] == F8E4M3FN))) {
- VLOG(5) << "Unsupported types.";
- return absl::OkStatus();
- }
-
- // The dequantized types must be BF16, FP16 or FP32.
- for (int k = 0; k < 2; ++k) {
- if (binaries[k]->shape().element_type() != BF16 &&
- binaries[k]->shape().element_type() != F16 &&
- binaries[k]->shape().element_type() != F32) {
- VLOG(5) << "Unsupported types.";
- return absl::OkStatus();
- }
- }
-
- // The FP8 scaling operands must be scalars.
- if (!ShapeUtil::IsScalar(scales[0]->shape()) ||
- !ShapeUtil::IsScalar(scales[1]->shape())) {
- VLOG(5) << "Scaling factors must be scalars.";
- return absl::OkStatus();
- }
-
- // Identify the dot and collective-permute or dynamic-slice instructions in
- // the all-gather or reduce-scatter patterns in while's body.
- HloComputation* while_body = while_instr->while_body();
- HloComputation* while_condition = while_instr->while_condition();
- HloInstruction* while_root = while_body->root_instruction();
- std::array<HloInstruction*, 2> dots, dyn_slices{nullptr, nullptr},
- coll_perms{nullptr, nullptr};
- if (Match(
- while_root,
- m::Tuple(m::CollectivePermute(
- &coll_perms[1], m::CollectivePermute(
- &coll_perms[0], m::Op().Is(gte[0]))),
- m::Op().Is(gte[1]),
- m::DynamicUpdateSlice(
- m::DynamicUpdateSlice().WithOperand(
- 1, m::Dot(&dots[0], m::Op().Is(gte[0]),
- m::Op().Is(gte[1]))),
- m::Dot(&dots[1], m::Op(), m::Op().Is(gte[1])), m::Op(),
- m::Op(), m::Op()),
- m::Op(), m::Op()))) {
- VLOG(5) << "Identified all-gather windowed einsum pattern.";
- } else if (Match(
- while_root,
- m::Tuple(m::Op().Is(gte[0]), m::Op().Is(gte[1]),
- m::AddAnyOrder(
- m::Dot(&dots[0], m::DynamicSlice(&dyn_slices[0]),
- m::Op().Is(gte[1])),
- m::Op()),
- m::CollectivePermute(m::AddAnyOrder(
- m::Dot(&dots[1], m::DynamicSlice(&dyn_slices[1]),
- m::Op().Is(gte[1])),
- m::Op())),
- m::Op()))) {
- VLOG(5) << "Identified reduce-scatter windowed einsum pattern.";
- } else {
- VLOG(5) << "Unable to identify valid windowed einsum pattern.";
- return absl::OkStatus();
- }
-
- // Replace the dequantized dot operands in the parameter tuple used by while
- // with FP8 operands.
- for (int k = 0; k < 2; ++k) {
- TF_RETURN_IF_ERROR(
- param_tuple->ReplaceOperandWithDifferentShape(k, operands[k]));
- ShapeUtil::UpdateTupleShape(operands[k]->shape(), k,
- param_tuple->mutable_shape());
- param_tuple->AppendOperand(scales[k]);
- ShapeUtil::AppendShapeToTuple(scales[k]->shape(),
- param_tuple->mutable_shape());
- }
-
- // Update the parameter tuples of while's body and condition computations.
- for (HloComputation* while_comp : {while_body, while_condition}) {
- while_comp->ReplaceParameter(
- 0, HloInstruction::CreateParameter(
- 0, param_tuple->shape(),
- while_comp->parameter_instruction(0)->name()));
- }
-
- // In the while body, replace the existing get-tuple-element instructions
- // retrieving BF16/FP16/FP32 dot operands with get-tuple-element instructions
- // retrieving the FP8 dot operands from the input tuple; the dequantization is
- // re-created inside the loop below.
- HloInstruction* body_param = while_body->parameter_instruction(0);
- for (int k = 0; k < 2; ++k) {
- TF_ASSIGN_OR_RETURN(HloInstruction * operand_f8,
- MakeGetTupleElementHlo(body_param, k));
-
- if (while_root->operand(k) == gte[k]) {
- TF_RETURN_IF_ERROR(
- while_root->ReplaceOperandWithDifferentShape(k, operand_f8));
- ShapeUtil::UpdateTupleShape(operand_f8->shape(), k,
- while_root->mutable_shape());
- }
-
- TF_ASSIGN_OR_RETURN(
- HloInstruction * operand_scale,
- MakeGetTupleElementHlo(
- body_param, body_param->shape().tuple_shapes_size() - 2 + k));
-
- // Also add the scaling factor to the output tuple of the while body.
- while_root->AppendOperand(operand_scale);
- ShapeUtil::AppendShapeToTuple(operand_scale->shape(),
- while_root->mutable_shape());
-
- // Dequantize the operands of the dots and dynamic-slices.
- HloInstruction* operand_f32 =
- MakeConvertToHlo(operand_f8, gte[k]->shape().element_type());
- HloInstruction* broadcast_scale =
- MakeBroadcastHlo(operand_scale, {}, operand_f32->shape());
- TF_ASSIGN_OR_RETURN(
- HloInstruction * operand_scaled,
- MakeBinaryHlo(binaries[k]->opcode(), operand_f32, broadcast_scale));
-
- // Replace the original get-tuple-element instructions accessing the
- // operands of the dots and dynamic-slices with the dequantized FP8
- // operands. The order of dequantization and dynamic-slices will be
- // exchanged in gemm_rewriter.cc.
- for (int l = 0; l < 2; ++l) {
- if (dots[l]->operand(k) == gte[k]) {
- TF_RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled));
- }
- if (dyn_slices[l] && dyn_slices[l]->operand(0) == gte[k]) {
- TF_RETURN_IF_ERROR(
- dyn_slices[l]->ReplaceOperandWith(0, operand_scaled));
- }
- }
-
- // In the all-gather case, coll_perms[0] has two users, coll_perms[1] and
- // dots[1], which prevents it from being exchanged with dequantization in
- // gemm_rewriter.cc. Instead, directly insert the dequantization before
- // dots[1] here.
- if (coll_perms[0] && coll_perms[0]->operand(0) == gte[k]) {
- std::array<HloInstruction*, 2> coll_perms_f8{nullptr, nullptr};
- // Change the type of both collective-permutes to FP8.
- coll_perms_f8[0] =
- while_body->AddInstruction(coll_perms[0]->CloneWithNewOperands(
- operand_f8->shape(), {operand_f8}));
- coll_perms_f8[1] =
- while_body->AddInstruction(coll_perms[1]->CloneWithNewOperands(
- coll_perms_f8[0]->shape(), {coll_perms_f8[0]}));
-
- // Insert the dequantization between coll_perms[0] and dots[1].
- HloInstruction* coll_perm0_f32 =
- MakeConvertToHlo(coll_perms_f8[0], gte[k]->shape().element_type());
- TF_ASSIGN_OR_RETURN(HloInstruction * x_scaled,
- MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32,
- broadcast_scale));
- TF_RETURN_IF_ERROR(dots[1]->ReplaceOperandWith(0, x_scaled));
-
- // Update the output tuple.
- TF_RETURN_IF_ERROR(
- while_root->ReplaceOperandWithDifferentShape(0, coll_perms_f8[1]));
- ShapeUtil::UpdateTupleShape(coll_perms_f8[1]->shape(), 0,
- while_root->mutable_shape());
- }
- }
-
- // Update the shape of the while call in the parent computation.
- TF_RETURN_IF_ERROR(
- while_instr->ReplaceAllUsesWithDifferentShape(while_instr->AddInstruction(
- while_instr->CloneWithNewShape(while_root->shape()))));
- TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr));
-
- if (coll_perms[0]) {
- TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1]));
- TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0]));
- }
- TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[0]));
- TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[1]));
-
- VLOG(5) << "FP8 dequantization moved into while loop.";
- return absl::OkStatus();
-}
-
-int64_t NumberOfInstructionsInComp(const HloComputation* comp, HloOpcode op) {
- int64_t total_count = 0;
- for (const HloInstruction* inst : comp->instructions()) {
- if (inst->opcode() == op) {
- ++total_count;
- }
- }
- return total_count;
-}
-
-absl::Status UpdateDotAndConsumerConfig(HloInstruction* dot,
- int64_t stream_id) {
- auto dot_gpu_config = dot->backend_config<gpu::GpuBackendConfig>();
- HloInstruction* updater = dot->users()[0];
- auto updater_gpu_config = updater->backend_config<gpu::GpuBackendConfig>();
- dot_gpu_config->set_operation_queue_id(stream_id);
- updater_gpu_config->mutable_wait_on_operation_queues()->Add(stream_id);
-
- TF_RETURN_IF_ERROR(dot->set_backend_config(dot_gpu_config.value()));
- TF_RETURN_IF_ERROR(updater->set_backend_config(updater_gpu_config.value()));
- return absl::OkStatus();
-}
-
-absl::Status SetForceDelayForInstruction(HloInstruction* instr,
- bool force_delay) {
- auto gpu_config = instr->backend_config<gpu::GpuBackendConfig>();
-
- gpu_config->set_force_earliest_schedule(force_delay);
-
- TF_RETURN_IF_ERROR(instr->set_backend_config(gpu_config.value()));
- return absl::OkStatus();
-}
-
-absl::StatusOr<bool> HandleRsWindowedEinsumLoop(HloComputation* comp,
- int64_t stream_id) {
- bool changed = false;
- // If we have an einsum loop with only 1 dot, this means either
- // the loop is not unrolled or only 1 partition is available.
- // It's a no-op in either case.
- if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
- return changed;
- }
- for (auto inst : comp->MakeInstructionPostOrder()) {
- HloInstruction* matched_dot;
- std::array<HloInstruction*, 2> gte;
- // The dot we'd like to parallelize is consuming the second loop input
- // as RHS.
- if (Match(inst,
- m::Dot(&matched_dot,
- m::DynamicSlice().WithOperand(
- 0, m::GetTupleElement(&gte[0], m::Parameter(), 0)),
- m::GetTupleElement(&gte[1], m::Parameter(), 1)))) {
- // If present, move the dequantization of FP8 operands of the dot into the
- // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
- // Call.
- TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
-
- // Dispatch the dot to additional compute stream.
- TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
- ++stream_id;
- changed = true;
- }
-
- // We need to force the first collective-permute to always be scheduled
- // at the beginning of the loop.
- HloInstruction* matched_cp;
- if (Match(inst, m::CollectivePermute(
- &matched_cp, m::GetTupleElement(m::Parameter(), 2)))) {
- TF_RETURN_IF_ERROR(
- SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
- changed = true;
- }
- }
- return changed;
-}
-
-absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp,
- int64_t stream_id) {
- bool changed = false;
- // If we have an einsum loop with only 1 dot, this means either
- // the loop is not unrolled or only 1 partition is available.
- // It's a no-op in either case.
- if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
- return changed;
- }
- for (auto inst : comp->MakeInstructionPostOrder()) {
- HloInstruction* matched_dot;
- std::array<HloInstruction*, 2> gte;
- // The dot we'd like to parallelize is consuming the second loop input
- // as RHS and the first loop input as LHS.
- if (Match(inst, m::Dot(&matched_dot,
- m::GetTupleElement(&gte[0], m::Parameter(), 0),
- m::GetTupleElement(&gte[1], m::Parameter(), 1)))) {
- // If present, move the dequantization of FP8 operands of the dot into the
- // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
- // Call.
- TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
-
- // Dispatch the dot to additional compute stream.
- TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
- ++stream_id;
- TF_RETURN_IF_ERROR(
- SetForceDelayForInstruction(matched_dot, /*force_delay=*/true));
- changed = true;
- }
-
- // We need to force the first collective-permute to always be scheduled
- // at the beginning of the loop.
- HloInstruction* matched_cp;
- if (Match(inst, m::CollectivePermute(
- &matched_cp, m::GetTupleElement(m::Parameter(), 0)))) {
- TF_RETURN_IF_ERROR(
- SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
- changed = true;
- }
- }
- return changed;
-}
-
-static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) {
- const HloInstruction* loop_tuple = while_loop->operand(0);
- const Shape& tuple_shape = loop_tuple->shape();
- CHECK(tuple_shape.IsTuple());
- return tuple_shape.tuple_shapes_size();
-}
-
-absl::Status ProcessWindowedEinsumLoopForActivationCaching(
- GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop,
- HloInstruction* ag_with_shared_operand) {
- HloInstruction* loop = ag_loop.loop;
- // Transform the while body to cache the all-gathered result in the
- // output buffer to be consumed by the dot.
- HloComputation* while_body = loop->while_body();
- HloInstruction* input_gte;
- for (HloInstruction* gte : while_body->parameter_instruction(0)->users()) {
- if (gte->tuple_index() == 0) {
- input_gte = gte;
- }
- }
- // Get the output operand of the full buffer.
- HloInstruction* root = while_body->root_instruction();
- // Change loop body to include the new input and output element.
- HloInstruction* input_tuple = while_body->parameter_instruction(0);
- const Shape& input_shape = input_tuple->shape();
- // The full buffer that we will use to cache the accumulated activation
- // is the last operand in the output tuple.
- int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop);
- std::vector<Shape> new_input_shapes(input_shape.tuple_shapes().begin(),
- input_shape.tuple_shapes().end());
- new_input_shapes.push_back(ag_with_shared_operand->shape());
- // Update body input shape
- Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes);
- *input_tuple->mutable_shape() = new_input_shape;
- HloInstruction* full_buffer_output_gte =
- while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
- ag_with_shared_operand->shape(), input_tuple,
- full_cache_buffer_index));
-
- // Update condition input shape
- HloComputation* cond_comp = loop->while_condition();
- HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0);
- *cond_input_tuple->mutable_shape() = new_input_shape;
-
- // Update input to the while instruction in parent computation
- HloInstruction* original_while_input = loop->mutable_operand(0);
- HloComputation* parent_comp = loop->parent();
- std::vector<HloInstruction*> new_operands(
- original_while_input->operands().begin(),
- original_while_input->operands().end());
- new_operands.push_back(
- parent_comp->AddInstruction(HloInstruction::CreateBroadcast(
- ag_with_shared_operand->shape(),
- parent_comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(new_input_shapes[0].element_type()))),
- {})));
- HloInstruction* new_while_input =
- parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands));
- TF_RETURN_IF_ERROR(
- loop->ReplaceOperandWithDifferentShape(0, new_while_input));
- TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape(
- original_while_input, new_while_input));
- *loop->mutable_shape() = new_input_shape;
-
- HloInstruction* new_full_buffer_output = nullptr;
- // Find the DUS in the loop body and re-use the slice indices
- // This should just be a constant(0)
- HloInstruction* dus_boundary_constant;
- // The slice we need this time is the output of the first
- // collective-permute
- HloInstruction* first_cp_output;
- for (HloInstruction* gte_user : input_gte->users()) {
- if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
- first_cp_output = gte_user;
- break;
- }
- }
- for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
- HloInstruction* slice_indices;
- // If we have a DUS(PARAM,DS) pattern, we need to update the output
- // buffer with the first slice.
- if (Match(inst,
- m::DynamicUpdateSlice(
- m::GetTupleElement(m::Parameter()), m::Op(),
- m::Constant(&dus_boundary_constant),
- m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
- m::Op()))) {
- slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
- dus_boundary_constant->shape(), slice_indices));
- VLOG(5) << "Created slice op for first slice: "
- << slice_indices->ToString();
- full_buffer_output_gte =
- while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
- full_buffer_output_gte->shape(), full_buffer_output_gte,
- input_gte,
- {dus_boundary_constant, slice_indices, dus_boundary_constant}));
- }
- // If we have a DUS(DUS,DS) pattern, then the einsum loop is
- // unrolled, so we need to update the output buffer again with the
- // second slice. Since the second slice will have different indices,
- // we need to re-capture slice_indices.
- if (Match(inst,
- m::DynamicUpdateSlice(
- m::DynamicUpdateSlice(), m::Op(), m::Constant(),
- m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
- m::Op()))) {
- slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
- dus_boundary_constant->shape(), slice_indices));
- VLOG(5) << "Created slice op for second slice: "
- << slice_indices->ToString();
- new_full_buffer_output =
- while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
- full_buffer_output_gte->shape(), full_buffer_output_gte,
- first_cp_output,
- {dus_boundary_constant, slice_indices, dus_boundary_constant}));
- }
-
- // If we have a Dot(DS(parameter_index1)), then operands are sharded along
- // the contracting dim. Slice indices will be the contracting dim's slices.
- HloInstruction* slice_index;
- HloInstruction* ds_index_constant;
- HloInstruction* remainder;
- HloInstruction* ds_param;
- // There will be 2 dynamic-slices for unrolled loops; match each one to
- // get the slice index that will be used to write the corresponding
- // received shard into the cached activation buffer. For unrolled loops, we
- // need to write to the final buffer twice per iteration, so we need to
- // match the correct slice index based on each DS.
- if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) &&
- Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) {
- for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size();
- ds_op_i++) {
- if (!Match(
- ds_param->mutable_operand(ds_op_i),
- m::Reshape(&slice_index, m::DynamicSlice(m::Constant(),
- m::Op(&remainder)))) &&
- !Match(ds_param->mutable_operand(ds_op_i),
- m::Constant(&ds_index_constant))) {
- return absl::OkStatus();
- }
- }
- // First DS has slice index calculated based on loop iterator
- // Remainder(add(gte, partition_id))
- if (Match(remainder,
- m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) {
- full_buffer_output_gte =
- while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
- full_buffer_output_gte->shape(), full_buffer_output_gte,
- input_gte,
- {ds_index_constant, ds_index_constant, slice_index}));
- }
- // Second DS has slice index calculated based on loop iterator+1 hence
- // Remainder(add(add(gte, 1), partition_id))
- if (Match(remainder,
- m::Remainder(
- m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()),
- m::Op()))) {
- new_full_buffer_output =
- while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
- full_buffer_output_gte->shape(), full_buffer_output_gte,
- first_cp_output,
- {ds_index_constant, ds_index_constant, slice_index}));
- }
- }
- }
- std::vector<HloInstruction*> original_operands(root->operands().begin(),
- root->operands().end());
- original_operands.push_back(new_full_buffer_output);
- HloInstruction* new_output_tuple = while_body->AddInstruction(
- HloInstruction::CreateTuple(original_operands));
- TF_RETURN_IF_ERROR(
- while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));
- return absl::OkStatus();
-}
-
-bool HasReplicaGroups(const HloInstruction* inst) {
- return inst->replica_groups().size() > 0;
-}
-
-bool ShouldAddToChain(const HloInstruction* inst) {
- switch (inst->opcode()) {
- case HloOpcode::kTranspose:
- case HloOpcode::kReshape:
- case HloOpcode::kCopy:
- return inst->user_count() == 1;
- default:
- return false;
- }
-}
-
-struct MatchedGemmA2aResult {
- HloInstruction* producer_gemm;
- HloInstruction* lhs;
- HloInstruction* rhs;
- HloInstruction* a2a_replacement = nullptr;
- bool matched = false;
-};
-
-class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
- public:
- explicit WindowedEinsumVisitor(
- std::vector<GpuWindowedEinsumHandler::WindowedEinsumAgLoops>&
- all_ag_loops)
- : all_ag_loops_(all_ag_loops) {}
- absl::StatusOr<bool> MatchA2aGemmWithIntermediateReshapes(
- HloInstruction* dot, HloInstruction** lhs, HloInstruction** rhs) {
- if (Match(dot, m::Dot(m::AllToAll(lhs).WithOneUse().WithPredicate(
- HasReplicaGroups),
- m::Op(rhs))) &&
- !DynCast<HloAllToAllInstruction>((*lhs))->constrain_layout() &&
- !(*lhs)->shape().IsTuple()) {
- return true;
- }
- std::vector<HloInstruction*> allowed_intermediate_ops(
- {dot->mutable_operand(0)});
-
- HloAllToAllInstruction* matched_a2a = nullptr;
- // We keep pushing operands onto the chain until a condition is not met
- // or we have found the a2a.
- while (true) {
- HloInstruction* curr = allowed_intermediate_ops.back();
- if (ShouldAddToChain(curr)) {
- allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
- std::begin(curr->operands()),
- std::end(curr->operands()));
- } else if (curr->opcode() == HloOpcode::kAllToAll &&
- curr->user_count() == 1) {
- matched_a2a = DynCast<HloAllToAllInstruction>(curr);
- allowed_intermediate_ops.pop_back();
- break;
- } else {
- return false;
- }
- }
- CHECK(matched_a2a != nullptr);
- if (matched_a2a->constrain_layout() || matched_a2a->shape().IsTuple() ||
- !HasReplicaGroups(matched_a2a) || !matched_a2a->split_dimension()) {
- return false;
- }
- // We need to create a new a2a that is a direct producer of the dot and
- // replace the original a2a with it. A new reshape will be added to the
- // original a2a's input. We first need to determine the new split dimension
- // after all the reshape ops.
- int64_t split_dimension = *matched_a2a->split_dimension();
- for (int64_t i = allowed_intermediate_ops.size() - 1; i >= 0; i--) {
- HloInstruction* current_op = allowed_intermediate_ops[i];
- if (current_op->opcode() == HloOpcode::kReshape) {
- std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
- ShapeUtil::DimensionsUnmodifiedByReshape(
- current_op->operand(0)->shape(), current_op->shape());
- auto it = absl::c_find_if(
- unmodified_dims,
- [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
- return dim_pair.first == split_dimension;
- });
- // The split dimension of the a2a has been modified, so we cannot easily
- // deduce the new split dim; skip decomposition.
- if (it == unmodified_dims.end()) {
- VLOG(5) << "Split dimension of: " << matched_a2a->ToShortString()
- << " has been modified by reshapes. Skipping it for "
- "decomposition.";
- return false;
- }
- // Assign the new split dim.
- split_dimension = it->second;
- } else if (current_op->opcode() == HloOpcode::kTranspose) {
- const auto& transpose_dims = current_op->dimensions();
- for (int64_t j = 0; j < transpose_dims.size(); j++) {
- if ((int64_t)transpose_dims[j] == split_dimension) {
- split_dimension = j;
- break;
- }
- }
- }
- }
- TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
- 0, matched_a2a->mutable_operand(0)));
- HloInstruction* new_a2a =
- matched_a2a->parent()->AddInstruction(HloInstruction::CreateAllToAll(
- allowed_intermediate_ops.front()->shape(),
- {allowed_intermediate_ops.front()}, matched_a2a->replica_groups(),
- false, hlo_query::NextChannelId(*matched_a2a->GetModule()),
- split_dimension));
-
- TF_RETURN_IF_ERROR(dot->ReplaceOperandWith(0, new_a2a));
- TF_RETURN_IF_ERROR(
- matched_a2a->parent()->RemoveInstructionAndUnusedOperands(matched_a2a));
- MarkAsChanged();
- *lhs = new_a2a;
- *rhs = dot->mutable_operand(1);
- return true;
- }
-
- absl::Status HandleDot(HloInstruction* dot) override {
- CHECK_EQ(dot->opcode(), HloOpcode::kDot);
- HloComputation* comp = dot->parent();
- // Rewrites an all-gather+dot pattern that shares the same operand
- // with a windowed einsum loop so that the dot consumes the output of
- // the loop and the all-gather is removed.
- // Now that we have processed all loops, we can check whether there are
- // any all-gather+dot patterns that we can optimize. We want to transform:
- // input
- // / |
- // / |
- // AG windowed loop
- // /
- // /
- // dot
- // to:
- // input
- // |
- // |
- // windowed loop
- // |
- // |
- // dot
- // The windowed einsum loop will also be rewritten to output the full input
- // to be consumed by the dot. This is advantageous since the chained dot can
- // fully utilize all the resources on the GPU while comm is hidden by the
- // first collective matmul loop.
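- // Concretely, the rewritten loop caches the fully gathered activation in
- // an extra tuple element, and the dot below is rewired to read it through
- // a get-tuple-element instead of its own all-gather.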
- for (GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop :
- all_ag_loops_) {
- HloInstruction* loop = ag_loop.loop;
- HloInstruction* ag_operand = nullptr;
-
- if (Match(dot, m::Dot(m::AllGather(&ag_operand), m::Op())) ||
- Match(dot, m::Dot(m::Op(), m::AllGather(&ag_operand)))) {
- HloInstruction* windowed_lhs =
- loop->mutable_operand(0)->mutable_operand(0);
- HloInstruction* ag_with_shared_operand = nullptr;
- if (ag_operand && ag_operand->mutable_operand(0) == windowed_lhs) {
- ag_with_shared_operand = ag_operand;
- }
-
- if (!ag_with_shared_operand) {
- continue;
- }
-
- VLOG(5) << "Found all-gather that shares the same operand with a "
- "windowed einsum loop : "
- << loop->ToString();
-
- if (!ag_loop.consumed) {
- TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching(
- ag_loop, ag_with_shared_operand));
- ag_loop.consumed = true;
- }
- int64_t cache_output_index = dot->operand_index(ag_with_shared_operand);
- HloComputation* comp = dot->parent();
- HloInstruction* new_gte =
- comp->AddInstruction(HloInstruction::CreateGetTupleElement(
- loop, GetAgActivationCacheIndex(loop) - 1));
- TF_RETURN_IF_ERROR(
- dot->ReplaceOperandWith(cache_output_index, new_gte));
- TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand));
- }
- }
- // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
- // to minimize communication overhead. To do this, the original input is
- // split into group_size slices along the contracting dimension, and an
- // all-to-all+gemm is performed on each slice.
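- // Intuitively, the dot decomposes along its contracting dimension: with
- // the operands split into per-group slices, dot(lhs, rhs) is equal to
- // sum_i dot(lhs_i, rhs_i). Each lhs slice is taken from the a2a's input
- // and sent through its own smaller all-to-all before the partial dot, so
- // communication can overlap with the partial dots.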
- HloInstruction* lhs;
- HloInstruction* rhs;
- std::vector<xla::ReplicaGroup> replica_groups;
- TF_ASSIGN_OR_RETURN(bool matched,
- MatchA2aGemmWithIntermediateReshapes(dot, &lhs, &rhs));
- if (matched) {
- replica_groups = lhs->replica_groups();
- // We split the a2a+gemm along the contracting dimension into multiple
- // a2a+gemms and perform partial dots; the partial results are accumulated
- // into the final output buffer.
- int64_t group_size = replica_groups[0].replica_ids_size();
- if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
- return group.replica_ids_size() != group_size;
- }) != replica_groups.end()) {
- VLOG(5) << "All-to-all split groups don't have the same number of "
- "replicas.";
- return absl::OkStatus();
- }
-
- // Get the dimensions to slice for lhs and rhs; we slice on the
- // contracting dimensions to compute partial results.
- const DotDimensionNumbers& original_dot_dnums =
- dot->dot_dimension_numbers();
- const PrecisionConfig& original_precision = dot->precision_config();
- const auto& lhs_contracting_dims =
- dot->dot_dimension_numbers().lhs_contracting_dimensions();
- const auto& rhs_contracting_dims =
- dot->dot_dimension_numbers().rhs_contracting_dimensions();
-
- if (lhs_contracting_dims.size() != 1 ||
- rhs_contracting_dims.size() != 1) {
- VLOG(5) << "Contracting dimensions have multiple elements; all-to-all "
- "sharding will be skipped.";
- return absl::OkStatus();
- }
- int64_t lhs_contracting_dim = lhs_contracting_dims[0];
- int64_t rhs_contracting_dim = rhs_contracting_dims[0];
- HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(lhs);
- int64_t contracting_dim_value =
- rhs->shape().dimensions()[rhs_contracting_dim];
-
- // Each split is sliced out of the input buffer; we need to determine the
- // slice sizes and increments.
- std::vector<int64_t> lhs_slice_sizes(a2a->shape().rank(), 0);
- std::vector<int64_t> lhs_slice_increments(a2a->shape().rank(), 1);
- std::vector<int64_t> lhs_slice_max_range(
- a2a->shape().dimensions().begin(), a2a->shape().dimensions().end());
-
- std::vector<int64_t> rhs_slice_sizes(rhs->shape().rank(), 0);
- std::vector<int64_t> rhs_slice_increments(rhs->shape().rank(), 1);
- std::vector<int64_t> rhs_slice_max_range(
- rhs->shape().dimensions().begin(), rhs->shape().dimensions().end());
-
- // Create a zero-valued buffer to hold output.
- HloInstruction* output_buffer =
- comp->AddInstruction(HloInstruction::CreateBroadcast(
- dot->shape(),
- comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(dot->shape().element_type()))),
- {}));
- HloInstruction* a2a_operand = a2a->mutable_operand(0);
- if (contracting_dim_value % group_size) {
- VLOG(5) << absl::StrFormat(
- "Contracting dimension %d needs to be divisible by group_size %d",
- contracting_dim_value, group_size);
- return absl::OkStatus();
- }
- int64_t size_per_split = contracting_dim_value / group_size;
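- // For example, with a contracting dimension of 8192 and a group_size of
- // 4, size_per_split is 2048 and the slices below cover [0,2048),
- // [2048,4096), [4096,6144) and [6144,8192).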
-
- // Each split is sliced out of the input buffer; we need to determine the
- // slice sizes and increments.
- lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
- rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
-
- Shape lhs_slice_shape = a2a->shape();
- Shape rhs_slice_shape = rhs->shape();
-
- lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
- rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
-
- HloInstruction* lhs_slice;
- HloInstruction* rhs_slice;
-
- HloInstruction* partial_result = output_buffer;
-
- Shape partial_all_to_all_shape = lhs_slice_shape;
-
- TF_ASSIGN_OR_RETURN(
- Shape partial_dot_shape,
- ShapeInference::InferDotOpShape(
- partial_all_to_all_shape, rhs_slice_shape, original_dot_dnums,
- /*preferred_element_type=*/std::nullopt));
- int64_t stream_id = hlo_query::NextChannelId(*a2a->GetModule());
- for (int64_t i = 0; i < group_size; ++i) {
- lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
- lhs_slice_shape, a2a_operand, lhs_slice_sizes, lhs_slice_max_range,
- lhs_slice_increments));
- a2a->SetupDerivedInstruction(lhs_slice);
- lhs_slice_sizes[lhs_contracting_dim] =
- lhs_slice_max_range[lhs_contracting_dim];
- lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
-
- rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
- rhs_slice_shape, rhs, rhs_slice_sizes, rhs_slice_max_range,
- rhs_slice_increments));
- a2a->SetupDerivedInstruction(rhs_slice);
- rhs_slice_sizes[rhs_contracting_dim] =
- rhs_slice_max_range[rhs_contracting_dim];
- rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
-
- HloInstruction* partial_all_to_all =
- comp->AddInstruction(HloInstruction::CreateAllToAll(
- partial_all_to_all_shape, {lhs_slice}, a2a->device_list(),
- false, hlo_query::NextChannelId(*a2a->GetModule()),
- a2a->split_dimension()));
- a2a->SetupDerivedInstruction(partial_all_to_all);
-
- HloInstruction* partial_dot =
- comp->AddInstruction(HloInstruction::CreateDot(
- partial_dot_shape, partial_all_to_all, rhs_slice,
- original_dot_dnums, original_precision));
- partial_result = comp->AddInstruction(
- HloInstruction::CreateBinary(partial_dot->shape(), HloOpcode::kAdd,
- partial_dot, partial_result));
- a2a->SetupDerivedInstruction(partial_result);
- TF_RETURN_IF_ERROR(
- UpdateDotAndConsumerConfig(partial_dot, stream_id++));
- }
- TF_RETURN_IF_ERROR(ReplaceInstruction(dot, partial_result));
- }
- return absl::OkStatus();
- }
-
- absl::StatusOr<MatchedGemmA2aResult> MatchGemmA2aWithIntermediateReshapes(
- HloInstruction* inst) {
- MatchedGemmA2aResult result;
- HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(inst);
- if (!HasReplicaGroups(a2a) || a2a->constrain_layout() ||
- a2a->shape().IsTuple()) {
- return result;
- }
- if (Match(a2a, m::AllToAll(m::Dot(&result.producer_gemm, m::Op(&result.lhs),
- m::Op(&result.rhs))
- .WithOneUse()))) {
- result.matched = true;
- return result;
- }
- std::vector<HloInstruction*> allowed_intermediate_ops(
- {a2a->mutable_operand(0)});
-
- HloInstruction* matched_dot = nullptr;
- // We keep pushing until we hit an unmet condition or find the producer
- // dot.
- while (true) {
- HloInstruction* curr = allowed_intermediate_ops.back();
- if (ShouldAddToChain(curr)) {
- allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
- std::begin(curr->operands()),
- std::end(curr->operands()));
- } else if (curr->opcode() == HloOpcode::kDot && curr->user_count() == 1) {
- matched_dot = curr;
- allowed_intermediate_ops.pop_back();
- break;
- } else {
- return result;
- }
- }
- CHECK(matched_dot != nullptr);
- // We need to create a new a2a that is a direct consumer of the dot and
- // replace the original a2a with it. The chain of intermediate ops is
- // rewired to consume the new a2a's output. We first need to determine the
- // new split dimension after all the reshape ops.
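- // Note that the chain here is walked from the original a2a towards the
- // dot, so the split dimension is mapped backwards through each reshape,
- // from its output to its input; for example, if dimension 2 of a reshape's
- // output comes unmodified from dimension 1 of its input, a split dimension
- // of 2 becomes 1 after that step.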
- int64_t split_dimension = *a2a->split_dimension();
- for (int64_t i = 0; i < allowed_intermediate_ops.size(); i++) {
- HloInstruction* current_op = allowed_intermediate_ops[i];
- if (current_op->opcode() == HloOpcode::kReshape) {
- std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
- ShapeUtil::DimensionsUnmodifiedByReshape(
- current_op->operand(0)->shape(), current_op->shape());
- auto it = absl::c_find_if(
- unmodified_dims,
- [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
- return dim_pair.second == split_dimension;
- });
- // The split dimension of the a2a has been modified; we cannot easily
- // deduce the new split dim, so skip decomposition.
- if (it == unmodified_dims.end()) {
- VLOG(5) << "Split dimension of: " << a2a->ToShortString()
- << " has been modified by reshapes. Skipping it for "
- "decomposition.";
- return result;
- }
- // Assign the new split dim.
- split_dimension = it->first;
- } else if (current_op->opcode() == HloOpcode::kTranspose) {
- const auto& transpose_dims = current_op->dimensions();
- split_dimension = transpose_dims[split_dimension];
- }
- }
- result.a2a_replacement =
- matched_dot->parent()->AddInstruction(HloInstruction::CreateAllToAll(
- matched_dot->shape(), {matched_dot}, a2a->replica_groups(), false,
- hlo_query::NextChannelId(*matched_dot->GetModule()),
- split_dimension));
- TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
- 0, result.a2a_replacement));
- inst->SetupDerivedInstruction(result.a2a_replacement);
-
- TF_RETURN_IF_ERROR(
- ReplaceInstruction(inst, allowed_intermediate_ops.front()));
- result.lhs = matched_dot->mutable_operand(0);
- result.rhs = matched_dot->mutable_operand(1);
- result.producer_gemm = matched_dot;
- result.matched = true;
- return result;
- }
-
- // Rewrites a gemm+all-to-all into multiple independent partial gemm+a2as
- // to minimize communication overhead. To do this, the original input is
- // split into group_size slices along the contracting dimension, and a
- // gemm+all-to-all is performed on each slice.
- absl::Status HandleAllToAll(HloInstruction* inst) override {
- CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll);
- HloComputation* comp = inst->parent();
- // Rewrites a gemm+alltoall into multiple independent partial gemm+a2as
- // to minimize communication overhead.
- std::vector<xla::ReplicaGroup> replica_groups;
- TF_ASSIGN_OR_RETURN(MatchedGemmA2aResult matched_result,
- MatchGemmA2aWithIntermediateReshapes(inst));
- if (matched_result.matched) {
- HloInstruction* a2a = inst;
- if (matched_result.a2a_replacement) {
- a2a = matched_result.a2a_replacement;
- }
- replica_groups = a2a->replica_groups();
- // Similar to a2a+gemm, we split along the contracting dimensions
- // and accumulate the result at each step.
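- // Intuitively this works because an all-to-all only routes data and
- // therefore distributes over addition: a2a(sum_i dot(lhs_i, rhs_i)) is
- // equal to sum_i a2a(dot(lhs_i, rhs_i)), so each partial dot is followed
- // by its own all-to-all and the partial results are accumulated into the
- // output buffer.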
- int64_t group_size = replica_groups[0].replica_ids_size();
-
- if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
- return group.replica_ids_size() != group_size;
- }) != replica_groups.end()) {
- VLOG(5) << "All-to-all split groups don't have the same number of "
- "replicas.";
- return absl::OkStatus();
- }
-
- // Get the dimensions to slice for lhs and rhs; we slice on the
- // contracting dimensions to compute partial results.
- const DotDimensionNumbers& original_dot_dnums =
- matched_result.producer_gemm->dot_dimension_numbers();
- const PrecisionConfig& original_precision =
- matched_result.producer_gemm->precision_config();
- const auto& lhs_contracting_dims =
- matched_result.producer_gemm->dot_dimension_numbers()
- .lhs_contracting_dimensions();
- const auto& rhs_contracting_dims =
- matched_result.producer_gemm->dot_dimension_numbers()
- .rhs_contracting_dimensions();
-
- if (lhs_contracting_dims.size() != 1 ||
- rhs_contracting_dims.size() != 1) {
- VLOG(5) << "Contracting dimensions have multiple elements; all-to-all "
- "sharding will be skipped.";
- return absl::OkStatus();
- }
- int64_t lhs_contracting_dim = lhs_contracting_dims[0];
- int64_t rhs_contracting_dim = rhs_contracting_dims[0];
- HloAllToAllInstruction* all_to_all = DynCast<HloAllToAllInstruction>(a2a);
- int64_t contracting_dim_value =
- matched_result.rhs->shape().dimensions()[rhs_contracting_dim];
- // Each split is sliced out of the input buffer; we need to determine the
- // slice sizes and increments.
- std::vector<int64_t> lhs_slice_sizes(matched_result.lhs->shape().rank(),
- 0);
- std::vector<int64_t> lhs_slice_increments(
- matched_result.lhs->shape().rank(), 1);
- std::vector<int64_t> lhs_slice_max_range(
- matched_result.lhs->shape().dimensions().begin(),
- matched_result.lhs->shape().dimensions().end());
-
- std::vector<int64_t> rhs_slice_sizes(matched_result.rhs->shape().rank(),
- 0);
- std::vector<int64_t> rhs_slice_increments(
- matched_result.rhs->shape().rank(), 1);
- std::vector<int64_t> rhs_slice_max_range(
- matched_result.rhs->shape().dimensions().begin(),
- matched_result.rhs->shape().dimensions().end());
-
- // Create a zero-valued buffer to hold output.
- HloInstruction* output_buffer =
- comp->AddInstruction(HloInstruction::CreateBroadcast(
- all_to_all->shape(),
- comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(all_to_all->shape().element_type()))),
- {}));
- if (contracting_dim_value % group_size) {
- VLOG(5) << absl::StrFormat(
- "Contracting dimension %d needs to be divisible by group_size %d",
- contracting_dim_value, group_size);
- return absl::OkStatus();
- }
-
- int64_t size_per_split = contracting_dim_value / group_size;
- // Each split is sliced out of the input buffer; we need to determine the
- // slice sizes and increments.
- lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
- rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
-
- Shape lhs_slice_shape = matched_result.lhs->shape();
- Shape rhs_slice_shape = matched_result.rhs->shape();
-
- lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
- rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
-
- HloInstruction* lhs_slice;
- HloInstruction* rhs_slice;
-
- HloInstruction* partial_result = output_buffer;
- Shape partial_all_to_all_shape = all_to_all->shape();
-
- TF_ASSIGN_OR_RETURN(
- Shape partial_dot_shape,
- ShapeInference::InferDotOpShape(
- lhs_slice_shape, rhs_slice_shape, original_dot_dnums,
- /*preferred_element_type=*/std::nullopt));
- int64_t stream_id = hlo_query::NextChannelId(*all_to_all->GetModule());
- for (int64_t i = 0; i < group_size; ++i) {
- lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
- lhs_slice_shape, matched_result.lhs, lhs_slice_sizes,
- lhs_slice_max_range, lhs_slice_increments));
- all_to_all->SetupDerivedInstruction(lhs_slice);
- lhs_slice_sizes[lhs_contracting_dim] =
- lhs_slice_max_range[lhs_contracting_dim];
- lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
-
- rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
- rhs_slice_shape, matched_result.rhs, rhs_slice_sizes,
- rhs_slice_max_range, rhs_slice_increments));
-
- all_to_all->SetupDerivedInstruction(rhs_slice);
- rhs_slice_sizes[rhs_contracting_dim] =
- rhs_slice_max_range[rhs_contracting_dim];
- rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
-
- HloInstruction* partial_dot = comp->AddInstruction(
- HloInstruction::CreateDot(partial_dot_shape, lhs_slice, rhs_slice,
- original_dot_dnums, original_precision));
-
- HloInstruction* partial_all_to_all =
- comp->AddInstruction(HloInstruction::CreateAllToAll(
- partial_all_to_all_shape, {partial_dot},
- all_to_all->device_list(), false,
- hlo_query::NextChannelId(*all_to_all->GetModule()),
- all_to_all->split_dimension()));
- all_to_all->SetupDerivedInstruction(partial_all_to_all);
- partial_result = comp->AddInstruction(HloInstruction::CreateBinary(
- partial_all_to_all_shape, HloOpcode::kAdd, partial_all_to_all,
- partial_result));
- all_to_all->SetupDerivedInstruction(partial_result);
- TF_RETURN_IF_ERROR(
- UpdateDotAndConsumerConfig(partial_dot, stream_id++));
- }
- TF_RETURN_IF_ERROR(ReplaceInstruction(all_to_all, partial_result));
- }
-
- return absl::OkStatus();
- }
-
- private:
- std::vector<GpuWindowedEinsumHandler::WindowedEinsumAgLoops>& all_ag_loops_;
-};
-
-} // namespace
-
-absl::StatusOr<bool> GpuWindowedEinsumHandler::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_VLOG_LINES(
- 5, "GpuWindowedEinsumHandler::Run(), before:\n" + module->ToString());
- bool changed = false;
- int64_t stream_id = hlo_query::NextChannelId(*module);
-
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- if (comp->name().find(kWindowedEinsumRsLoopName) == 0) {
- VLOG(5) << "Processing computation: " << comp->name();
- TF_ASSIGN_OR_RETURN(bool comp_result,
- HandleRsWindowedEinsumLoop(comp, stream_id));
- changed |= comp_result;
- } else if (comp->name().find(kWindowedEinsumAgLoopName) == 0) {
- VLOG(5) << "Processing computation: " << comp->name();
- TF_ASSIGN_OR_RETURN(bool comp_result,
- HandleAgWindowedEinsumLoop(comp, stream_id));
- all_ag_loops_.push_back(
- WindowedEinsumAgLoops(comp->WhileCallInstruction()));
- changed |= comp_result;
- }
- }
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- WindowedEinsumVisitor visitor(all_ag_loops_);
- TF_RETURN_IF_ERROR(comp->Accept(&visitor));
- changed |= visitor.changed();
- }
-
- XLA_VLOG_LINES(
- 5, "GpuWindowedEinsumHandler::Run(), after:\n" + module->ToString());
- return changed;
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h
deleted file mode 100644
index b511920..0000000
--- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_
-#define XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_
-
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass targets the windowed einsum optimization in the SPMD pipeline,
-// where all-gather+gemm or gemm+reduce-scatter is rewritten into sharded
-// loops to overlap sharded gemms with communication. On GPU, this pass
-// optimizes those loops by annotating independent gemms with stream ids in
-// the backend config; by running them on different streams, we can
-// practically achieve overlap between the gemms as well.
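-// The loops themselves are recognized by computation name; see
-// kWindowedEinsumAgLoopName and kWindowedEinsumRsLoopName below.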
-class GpuWindowedEinsumHandler : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "gpu-windowed-einsum-handler";
- }
-
- struct WindowedEinsumAgLoops {
- explicit WindowedEinsumAgLoops(HloInstruction* loop) : loop(loop) {}
- HloInstruction* loop;
- bool consumed = false;
- };
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- constexpr static const char* kWindowedEinsumRsLoopName =
- "windowed_dot_general_body_rs";
- constexpr static const char* kWindowedEinsumAgLoopName =
- "windowed_dot_general_body_ag";
-
- private:
- std::vector<WindowedEinsumAgLoops> all_ag_loops_;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc
deleted file mode 100644
index 6f23319..0000000
--- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc
+++ /dev/null
@@ -1,918 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_windowed_einsum_handler.h"
-
-#include <memory>
-#include <string>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using GpuWindowedEinsumHanlderTest = HloTestBase;
-
-HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) {
- for (auto inst : comp->instructions()) {
- if (inst->name() == name) {
- return inst;
- }
- }
- return nullptr;
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsHaveStreamIds) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,512,24576]{2,1,0}, bf16[24576,24576]{1,0})->bf16[2048,24576]{1,0}}, num_partitions=4
-
-windowed_dot_general_body_ag.1 {
- param = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
- get-tuple-element = bf16[512,24576]{1,0} get-tuple-element(param), index=0
- collective-permute = bf16[512,24576]{1,0} collective-permute(get-tuple-element), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- get-tuple-element.1 = bf16[24576,24576]{1,0} get-tuple-element(param), index=1
- get-tuple-element.2 = bf16[2048,24576]{1,0} get-tuple-element(param), index=2
- dot.2 = bf16[512,24576]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- constant.1 = s32[4]{0} constant({0, 512, 1024, 1536})
- get-tuple-element.4 = u32[] get-tuple-element(param), index=4
- partition-id = u32[] partition-id()
- add = u32[] add(get-tuple-element.4, partition-id)
- constant = u32[] constant(4)
- remainder = u32[] remainder(add, constant)
- dynamic-slice = s32[1]{0} dynamic-slice(constant.1, remainder), dynamic_slice_sizes={1}
- reshape.4 = s32[] reshape(dynamic-slice)
- constant.2 = s32[] constant(0)
- dynamic-update-slice = bf16[2048,24576]{1,0} dynamic-update-slice(get-tuple-element.2, dot.2, reshape.4, constant.2), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- dot.3 = bf16[512,24576]{1,0} dot(collective-permute, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- constant.3 = u32[] constant(1)
- add.1 = u32[] add(get-tuple-element.4, constant.3)
- add.2 = u32[] add(add.1, partition-id)
- remainder.1 = u32[] remainder(add.2, constant)
- dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.1, remainder.1), dynamic_slice_sizes={1}
- reshape.5 = s32[] reshape(dynamic-slice.1)
- dynamic-update-slice.1 = bf16[2048,24576]{1,0} dynamic-update-slice(dynamic-update-slice, dot.3, reshape.5, constant.2)
- get-tuple-element.3 = bf16[2048,24576]{1,0} get-tuple-element(param), index=3
- add.3 = u32[] add(add.1, constant.3)
- ROOT tuple = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(collective-permute, get-tuple-element.1, dynamic-update-slice.1, get-tuple-element.3, add.3)
-} // windowed_dot_general_body_ag.1
-
-windowed_dot_general_cond_ag {
- param.1 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
- get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
- constant.8 = u32[] constant(4)
- ROOT compare = pred[] compare(get-tuple-element.5, constant.8), direction=LT
-}
-
-ENTRY test_main {
- param.4 = bf16[1,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
- reshape.8 = bf16[512,24576]{1,0} reshape(param.4)
- param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
- constant.18 = bf16[] constant(0)
- broadcast = bf16[2048,24576]{1,0} broadcast(constant.18), dimensions={}
- constant.20 = u32[] constant(0)
- tuple.2 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(reshape.8, param.5, broadcast, broadcast, constant.20)
- while = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag.1
- ROOT get-tuple-element.13 = bf16[2048,24576]{1,0} get-tuple-element(while), index=2
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloInstruction* ag_loop =
- module->entry_computation()->root_instruction()->mutable_operand(0);
- HloComputation* ag_loop_body = ag_loop->while_body();
- HloInstruction* inst = FindInstructionByName(ag_loop_body, "dot.2");
- EXPECT_GT(inst->backend_config<GpuBackendConfig>()->operation_queue_id(), 0);
- EXPECT_TRUE(
- inst->backend_config<GpuBackendConfig>()->force_earliest_schedule());
-
- HloInstruction* cp1 =
- FindInstructionByName(ag_loop_body, "collective-permute");
- EXPECT_TRUE(
- cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, RsLoopsHaveStreamIds) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[2048,24576]{1,0})->bf16[512,24576]{1,0}}, num_partitions=4
-
-windowed_dot_general_body_rs_clone.1 {
- param.2 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
- get-tuple-element.6 = bf16[2048,24576]{1,0} get-tuple-element(param.2), index=0
- get-tuple-element.7 = bf16[24576,24576]{1,0} get-tuple-element(param.2), index=1
- get-tuple-element.9 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=2
- collective-permute.1 = bf16[512,24576]{1,0} collective-permute(get-tuple-element.9), channel_id=4, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- constant.10 = s32[4]{0} constant({0, 512, 1024, 1536})
- get-tuple-element.11 = u32[] get-tuple-element(param.2), index=4
- constant.12 = u32[] constant(2)
- add.8 = u32[] add(get-tuple-element.11, constant.12)
- constant.13 = u32[] constant(1)
- add.9 = u32[] add(add.8, constant.13)
- partition-id.3 = u32[] partition-id()
- add.10 = u32[] add(add.9, partition-id.3)
- constant.9 = u32[] constant(4)
- remainder.3 = u32[] remainder(add.10, constant.9)
- dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.10, remainder.3), dynamic_slice_sizes={1}
- reshape.7 = s32[] reshape(dynamic-slice.4)
- constant.11 = s32[] constant(0)
- dynamic-slice.5 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.7, constant.11), dynamic_slice_sizes={512,24576}
- dot.7 = bf16[512,24576]{1,0} dot(dynamic-slice.5, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- add.11 = bf16[512,24576]{1,0} add(collective-permute.1, dot.7), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- get-tuple-element.10 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=3
- add.6 = u32[] add(get-tuple-element.11, partition-id.3)
- remainder.2 = u32[] remainder(add.6, constant.9)
- dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.10, remainder.2), dynamic_slice_sizes={1}
- reshape.6 = s32[] reshape(dynamic-slice.2)
- dynamic-slice.3 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.6, constant.11), dynamic_slice_sizes={512,24576}
- dot.5 = bf16[512,24576]{1,0} dot(dynamic-slice.3, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- add.7 = bf16[512,24576]{1,0} add(get-tuple-element.10, dot.5), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
- collective-permute.2 = bf16[512,24576]{1,0} collective-permute(add.7), channel_id=5, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
- ROOT tuple.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(get-tuple-element.6, get-tuple-element.7, add.11, collective-permute.2, add.8)
-}
-
-windowed_dot_general_cond_rs {
- param.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
- get-tuple-element.12 = u32[] get-tuple-element(param.3), index=4
- constant.17 = u32[] constant(4)
- ROOT compare.1 = pred[] compare(get-tuple-element.12, constant.17), direction=LT
-}
-
-ENTRY main.9_spmd {
- param.6 = bf16[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
- param.7 = bf16[512,24576]{1,0} parameter(1)
- param.8 = bf16[2048,24576]{1,0} parameter(2)
- constant.20 = u32[] constant(0)
- tuple.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(param.8, param.6, param.7, param.7, constant.20)
- while.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs_clone.1
- ROOT get-tuple-element.14 = bf16[512,24576]{1,0} get-tuple-element(while.1), index=2
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloInstruction* rs_loop =
- module->entry_computation()->root_instruction()->mutable_operand(0);
- HloComputation* rs_loop_body = rs_loop->while_body();
- HloInstruction* inst = FindInstructionByName(rs_loop_body, "dot.7");
- EXPECT_TRUE(inst->backend_config<GpuBackendConfig>()->operation_queue_id() >
- 0);
-
- HloInstruction* cp1 =
- FindInstructionByName(rs_loop_body, "collective-permute.1");
- EXPECT_TRUE(
- cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsMultipleConsumersAreChained) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[24576,24576]{1,0})->bf16[2,2048,24576]{2,1,0}}, num_partitions=4
-
-windowed_dot_general_body_ag {
- param.1 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
- get-tuple-element.1 = bf16[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
- collective-permute = bf16[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
- collective-permute.1 = bf16[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=3, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
- get-tuple-element.2 = bf16[24576,24576]{1,0} get-tuple-element(param.1), index=1
- get-tuple-element.3 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
- dot = bf16[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- constant.2 = s32[] constant(0)
- constant.3 = s32[4]{0} constant({0, 512, 1024, 1536})
- get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
- partition-id = u32[] partition-id()
- add = u32[] add(get-tuple-element.5, partition-id)
- constant.1 = u32[] constant(4)
- remainder = u32[] remainder(add, constant.1)
- dynamic-slice = s32[1]{0} dynamic-slice(constant.3, remainder), dynamic_slice_sizes={1}
- reshape = s32[] reshape(dynamic-slice)
- dynamic-update-slice = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.2, reshape, constant.2)
- dot.1 = bf16[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- constant.5 = u32[] constant(1)
- add.1 = u32[] add(get-tuple-element.5, constant.5)
- add.2 = u32[] add(add.1, partition-id)
- remainder.1 = u32[] remainder(add.2, constant.1)
- dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.3, remainder.1), dynamic_slice_sizes={1}
- reshape.1 = s32[] reshape(dynamic-slice.1)
- dynamic-update-slice.1 = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.2, reshape.1, constant.2)
- get-tuple-element.4 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
- add.3 = u32[] add(add.1, constant.5)
- ROOT tuple = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
-} // windowed_dot_general_body_ag
-
-windowed_dot_general_cond_ag {
- param = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
- get-tuple-element = u32[] get-tuple-element(param), index=4
- constant = u32[] constant(4)
- ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
-}
-
-ENTRY main.12_spmd {
- param.4 = bf16[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
- param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
- constant.22 = bf16[] constant(0)
- broadcast = bf16[2,2048,24576]{2,1,0} broadcast(constant.22), dimensions={}
- constant.24 = u32[] constant(0)
- tuple.2 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
- while = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
- get-tuple-element.13 = bf16[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
- copy.1 = bf16[2,2048,24576]{2,1,0} copy(get-tuple-element.13)
- all-gather = bf16[2,2048,24576]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
- param.6 = bf16[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
- ROOT dot.7 = bf16[2,2048,24576]{2,1,0} dot(all-gather, param.6), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- EXPECT_TRUE(changed);
- HloInstruction* ag_loop =
- FindInstructionByName(module->entry_computation(), "while");
- HloInstruction* inst =
- FindInstructionByName(module->entry_computation(), "dot.7");
- // dot.7 should now consume the output of the windowed einsum while loop.
- EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
- EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
- EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
-
- // The while loop's root should now have a chain of DUS.
- HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction();
- EXPECT_THAT(ag_while_root,
- GmockMatch(m::Tuple(
- m::Op(), m::Op(), m::Op(), m::Op(), m::Op(),
- m::DynamicUpdateSlice(
- m::DynamicUpdateSlice(
- m::GetTupleElement(m::Parameter())
- .WithPredicate([](const HloInstruction* instr) {
- return instr->tuple_index() == 5;
- }),
- m::Op(), m::Op(), m::Op(), m::Op()),
- m::Op(), m::Op(), m::Op(), m::Op()))));
-}
-TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,8192]{3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=8
-
-ENTRY main.9_spmd {
- param0 = bf16[1,8192,32768]{2,1,0} parameter(0)
- param1 = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
- all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(param1), channel_id=4, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={1}
- ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(all-to-all, param0), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-}
-)";
-
- const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
-CHECK: %[[A2A3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["5"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
-
-CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
- RunFileCheck(module->ToString(), kExpected));
- EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aHaveStreamIds) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,32768]{3,2,1,0})->bf16[1,4,2048,8192]{3,2,1,0}}, num_partitions=4
-
-ENTRY main.9_spmd {
- param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
- param.10 = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
- dot.12 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.10, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}
- ROOT all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(dot.12), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={1}
-}
-)";
-
- const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
-CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [24576:32768]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [16384:24576]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [8192:16384]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
-
-CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
- RunFileCheck(module->ToString(), kExpected));
- EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, A2aTransposeLoopsHaveStreamIds) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=4
-
-ENTRY main.9_spmd {
- param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
- param.10 = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
- all-to-all = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} all-to-all(param.10), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={3}
- transpose.15 = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(all-to-all), dimensions={0,3,1,2,4,5}
- reshape.2170 = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(transpose.15)
- reshape.2173 = bf16[4,8192,1,2048]{3,2,1,0} reshape(reshape.2170)
- transpose.16 = bf16[1,4,2048,8192]{2,0,3,1} transpose(reshape.2173), dimensions={2,0,3,1}
- copy.53 = bf16[1,4,2048,8192]{3,2,1,0} copy(transpose.16)
- ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(copy.53, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-}
-)";
-
- const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
-CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} %[[P1:.*]]), dimensions={0,3,1,2,4,5}
-CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} %[[TRANSPOSE0:.*]])
-CHECK-DAG: %[[RESHAPE1:.*]] = bf16[4,8192,1,2048]{3,2,1,0} reshape(bf16[1,4,8192,1,2048]{4,3,2,1,0} %[[RESHAPE0:.*]])
-CHECK-DAG: %[[TRANSPOSE1:.*]] = bf16[1,4,2048,8192]{2,0,3,1} transpose(bf16[4,8192,1,2048]{3,2,1,0} %[[RESHAPE1:.*]]), dimensions={2,0,3,1}
-CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{2,0,3,1} %[[TRANSPOSE1:.*]])
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
-CHECK: %[[A2A3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["8"],"force_earliest_schedule":false}
-
-CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- EXPECT_TRUE(changed);
- TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
- RunFileCheck(module->ToString(), kExpected));
- EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aTransposeLoopsHaveStreamIds) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,4,2048,32768]{3,2,1,0}, bf16[1,32768,8192]{2,1,0})->bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0}}, num_partitions=4
-
-ENTRY main.9_spmd {
- param.9 = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
- param.10 = bf16[1,32768,8192]{2,1,0} parameter(1)
- dot.13 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.9, param.10), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
- copy.55 = bf16[1,4,2048,8192]{3,2,1,0} copy(dot.13)
- transpose.17 = bf16[4,1,2048,8192]{3,2,0,1} transpose(copy.55), dimensions={1,0,2,3}
- copy.56 = bf16[4,1,2048,8192]{3,2,1,0} copy(transpose.17)
- reshape.2216 = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(copy.56)
- reshape.2219 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(reshape.2216)
- ROOT all-to-all.1 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} all-to-all(reshape.2219), channel_id=7, replica_groups={{0,1,2,3}}, dimensions={1}
-}
-)";
-
- const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
-CHECK-DAG: %[[P0:.*]] = bf16[1,32768,8192]{2,1,0} parameter(1)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [24576:32768], [0:8192]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"12","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [16384:24576], [0:8192]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"11","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [8192:16384], [0:8192]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"10","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
-CHECK: replica_groups={
-CHECK: {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
-CHECK-DAG: %[[ADD3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
-
-CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{3,2,1,0} %[[ADD3:.*]])
-CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[4,1,2048,8192]{3,2,0,1} transpose(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), dimensions={1,0,2,3}
-CHECK-DAG: %[[COPY1:.*]] = bf16[4,1,2048,8192]{3,2,1,0} copy(bf16[4,1,2048,8192]{3,2,0,1} %[[TRANSPOSE0:.*]])
-CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(bf16[4,1,2048,8192]{3,2,1,0} %[[COPY1:.*]])
-
-CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,2048,8192]{4,3,2,1,0} %[[RESHAPE0:.*]])
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- EXPECT_TRUE(changed);
- TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
- RunFileCheck(module->ToString(), kExpected));
- EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, AllGatherF8) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4
-
-windowed_dot_general_body_ag {
- param.1 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
- get-tuple-element.1 = f32[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
- collective-permute = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
- collective-permute.1 = f32[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
- get-tuple-element.2 = f32[24576,24576]{1,0} get-tuple-element(param.1), index=1
- get-tuple-element.3 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
- dot = f32[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- constant.12 = s32[] constant(0)
- constant.13 = s32[4]{0} constant({0, 512, 1024, 1536})
- get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
- partition-id = u32[] partition-id()
- add = u32[] add(get-tuple-element.5, partition-id)
- constant.11 = u32[] constant(4)
- remainder = u32[] remainder(add, constant.11)
- dynamic-slice = s32[1]{0} dynamic-slice(constant.13, remainder), dynamic_slice_sizes={1}
- reshape = s32[] reshape(dynamic-slice)
- dynamic-update-slice = f32[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.12, reshape, constant.12)
- dot.1 = f32[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- constant.15 = u32[] constant(1)
- add.1 = u32[] add(get-tuple-element.5, constant.15)
- add.2 = u32[] add(add.1, partition-id)
- remainder.1 = u32[] remainder(add.2, constant.11)
- dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.13, remainder.1), dynamic_slice_sizes={1}
- reshape.1 = s32[] reshape(dynamic-slice.1)
- dynamic-update-slice.1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.12, reshape.1, constant.12)
- get-tuple-element.4 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
- add.3 = u32[] add(add.1, constant.15)
- ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
-} // windowed_dot_general_body_ag
-
-windowed_dot_general_cond_ag {
- param = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
- get-tuple-element = u32[] get-tuple-element(param), index=4
- constant.10 = u32[] constant(4)
- ROOT compare = pred[] compare(get-tuple-element, constant.10), direction=LT
-}
-
-ENTRY test_main {
- param.4 = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
- reshape.8 = f8e4m3fn[2,512,24576]{2,1,0} reshape(param.4)
- param.5 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
- constant.18 = f32[] constant(0)
- broadcast = f32[2,2048,24576]{2,1,0} broadcast(constant.18), dimensions={}
- constant.20 = u32[] constant(0)
- scale_lhs = f32[] parameter(2)
- scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
- lhs_bf32 = f32[2,512,24576]{2,1,0} convert(reshape.8)
- lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_bf32, scale_lhs_bcast)
- scale_rhs = f32[] parameter(3)
- scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
- rhs_bf32 = f32[24576,24576]{1,0} convert(param.5)
- rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf32, scale_rhs_bcast)
- tuple.2 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, broadcast, broadcast, constant.20)
- while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
- ROOT get-tuple-element.13 = f32[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
-}
-)";
-
- RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(),
- R"(
-; CHECK-LABEL: windowed_dot_general_body_ag
-; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
-; CHECK-NEXT: [[GTE0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=0
-; CHECK-NEXT: [[CP0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[GTE0]]), channel_id=4
-; CHECK-NEXT: [[CP1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[CP0]]), channel_id=5
-; CHECK-NEXT: [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
-; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=2
-; CHECK-NEXT: [[CONVERT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[GTE0]])
-; CHECK-NEXT: [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
-; CHECK-NEXT: [[BCAST0:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
-; CHECK-NEXT: [[MUL0:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
-; CHECK-NEXT: [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
-; CHECK-NEXT: [[GTE4:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
-; CHECK-NEXT: [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE4]]), dimensions={}
-; CHECK-NEXT: [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
-; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL0]], [[MUL1]]),
-; CHECK-DAG: lhs_contracting_dims={2},
-; CHECK-DAG: rhs_contracting_dims={0},
-; CHECK-DAG: backend_config={
-; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
-; CHECK-DAG: "wait_on_operation_queues":[],
-; CHECK-DAG: "force_earliest_schedule":true}
-; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0)
-; CHECK-NEXT: [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
-; CHECK-NEXT: [[GTE5:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
-; CHECK-NEXT: [[PID:%[^ ]+]] = u32[] partition-id()
-; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[GTE5]], [[PID]])
-; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(4)
-; CHECK-NEXT: [[REM0:%[^ ]+]] = u32[] remainder([[ADD0]], [[C2]])
-; CHECK-NEXT: [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
-; CHECK-NEXT: [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
-; CHECK-NEXT: [[DUPDATESLICE0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[GTE2]], [[DOT0]], [[C0]], [[RESHAPE0]], [[C0]]),
-; CHECK-DAG: backend_config={
-; CHECK-DAG: "operation_queue_id":"0",
-; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"],
-; CHECK-DAG: "force_earliest_schedule":false}
-; CHECK-NEXT: [[CONVERT2:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[CP0]])
-; CHECK-NEXT: [[MUL2:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT2]], [[BCAST0]])
-; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL2]], [[MUL1]]),
-; CHECK-DAG: lhs_contracting_dims={2},
-; CHECK-DAG: rhs_contracting_dims={0}
-; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(1)
-; CHECK-NEXT: [[ADD1:%[^ ]+]] = u32[] add([[GTE5]], [[C3]])
-; CHECK-NEXT: [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
-; CHECK-NEXT: [[REM1:%[^ ]+]] = u32[] remainder([[ADD2]], [[C2]])
-; CHECK-NEXT: [[DSLICE1:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
-; CHECK-NEXT: [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE1]])
-; CHECK-NEXT: [[DUPDATESLICE1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[DUPDATESLICE0]], [[DOT1]], [[C0]], [[RESHAPE1]], [[C0]])
-; CHECK-NEXT: [[GTE6:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=3
-; CHECK-NEXT: [[ADD3:%[^ ]+]] = u32[] add([[ADD1]], [[C3]])
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[CP1]], [[GTE1]], [[DUPDATESLICE1]], [[GTE6]], [[ADD3]], /*index=5*/[[GTE3]], [[GTE4]])
-)");
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, ReduceScatterF8) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f8e4m3fn[2,2048,24576]{2,1,0}, f32[], f32[])->f32[2,512,24576]{2,1,0}}, num_partitions=4
-
-windowed_dot_general_body_rs {
- param.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
- get-tuple-element.7 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.3), index=0
- get-tuple-element.8 = f32[24576,24576]{1,0} get-tuple-element(param.3), index=1
- get-tuple-element.9 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=2
- collective-permute.2 = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.9), channel_id=9, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
- constant.23 = s32[] constant(0)
- constant.24 = s32[4]{0} constant({0, 512, 1024, 1536})
- get-tuple-element.11 = u32[] get-tuple-element(param.3), index=4
- constant.26 = u32[] constant(2)
- add.8 = u32[] add(get-tuple-element.11, constant.26)
- constant.27 = u32[] constant(1)
- add.9 = u32[] add(add.8, constant.27)
- partition-id.3 = u32[] partition-id()
- add.10 = u32[] add(add.9, partition-id.3)
- constant.22 = u32[] constant(4)
- remainder.3 = u32[] remainder(add.10, constant.22)
- dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.24, remainder.3), dynamic_slice_sizes={1}
- reshape.3 = s32[] reshape(dynamic-slice.4)
- dynamic-slice.5 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.3, constant.23), dynamic_slice_sizes={2,512,24576}
- dot.3 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.5, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- add.11 = f32[2,512,24576]{2,1,0} add(collective-permute.2, dot.3)
- get-tuple-element.10 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=3
- add.6 = u32[] add(get-tuple-element.11, partition-id.3)
- remainder.2 = u32[] remainder(add.6, constant.22)
- dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.24, remainder.2), dynamic_slice_sizes={1}
- reshape.2 = s32[] reshape(dynamic-slice.2)
- dynamic-slice.3 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.2, constant.23), dynamic_slice_sizes={2,512,24576}
- dot.2 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.3, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- add.7 = f32[2,512,24576]{2,1,0} add(get-tuple-element.10, dot.2)
- collective-permute.3 = f32[2,512,24576]{2,1,0} collective-permute(add.7), channel_id=10, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
- ROOT tuple.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(get-tuple-element.7, get-tuple-element.8, add.11, collective-permute.3, add.8)
-} // windowed_dot_general_body_rs
-
-windowed_dot_general_cond_rs {
- param.2 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
- get-tuple-element.6 = u32[] get-tuple-element(param.2), index=4
- constant.21 = u32[] constant(4)
- ROOT compare.1 = pred[] compare(get-tuple-element.6, constant.21), direction=LT
-}
-
-ENTRY main.9_spmd {
- param.6 = f8e4m3fn[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
- param.7 = f32[2,512,24576]{2,1,0} parameter(1)
- param.8 = f8e4m3fn[2,2048,24576]{2,1,0} parameter(2)
- constant.20 = u32[] constant(0)
- scale_lhs = f32[] parameter(3)
- scale_lhs_bcast = f32[2,2048,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
- lhs_bf16 = f32[2,2048,24576]{2,1,0} convert(param.8)
- lhs_scaled = f32[2,2048,24576]{2,1,0} multiply(lhs_bf16, scale_lhs_bcast)
- scale_rhs = f32[] parameter(4)
- scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
- rhs_bf16 = f32[24576,24576]{1,0} convert(param.6)
- rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf16, scale_rhs_bcast)
- tuple.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, param.7, param.7, constant.20)
- while.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs
- ROOT get-tuple-element.14 = f32[2,512,24576]{2,1,0} get-tuple-element(while.1), index=2
-}
-)";
-
- RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(),
- R"(
-; CHECK-LABEL: windowed_dot_general_body_rs
-; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
-; CHECK-NEXT: [[GTE0:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=0
-; CHECK-NEXT: [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
-; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=2
-; CHECK-NEXT: [[CP0:%[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[GTE2]]), channel_id=9
-; CHECK-NEXT: [[CONVERT0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[GTE0]])
-; CHECK-NEXT: [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
-; CHECK-NEXT: [[BCAST0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
-; CHECK-NEXT: [[MUL0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
-; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0)
-; CHECK-NEXT: [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
-; CHECK-NEXT: [[GTE4:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
-; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(2)
-; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[GTE4]], [[C2]])
-; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(1)
-; CHECK-NEXT: [[ADD1:%[^ ]+]] = u32[] add([[ADD0]], [[C3]])
-; CHECK-NEXT: [[PID:%[^ ]+]] = u32[] partition-id()
-; CHECK-NEXT: [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
-; CHECK-NEXT: [[C4:%[^ ]+]] = u32[] constant(4)
-; CHECK-NEXT: [[REM0:%[^ ]+]] = u32[] remainder([[ADD2]], [[C4]])
-; CHECK-NEXT: [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
-; CHECK-NEXT: [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
-; CHECK-NEXT: [[DSLICE1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE0]], [[C0]]), dynamic_slice_sizes={2,512,24576}
-; CHECK-NEXT: [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
-; CHECK-NEXT: [[GTE5:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
-; CHECK-NEXT: [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE5]]), dimensions={}
-; CHECK-NEXT: [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
-; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE1]], [[MUL1]]),
-; CHECK-DAG: lhs_contracting_dims={2},
-; CHECK-DAG: rhs_contracting_dims={0},
-; CHECK-DAG: backend_config={
-; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
-; CHECK-DAG: "wait_on_operation_queues":[],
-; CHECK-DAG: "force_earliest_schedule":false}
-; CHECK-NEXT: [[ADD3:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[CP0]], [[DOT0]]),
-; CHECK-DAG: backend_config={"
-; CHECK-DAG: operation_queue_id":"0",
-; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"],
-; CHECK-DAG: "force_earliest_schedule":false}
-; CHECK-NEXT: [[GTE6:[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=3
-; CHECK-NEXT: [[ADD4:%[^ ]+]] = u32[] add([[GTE4]], [[PID]])
-; CHECK-NEXT: [[REM1:%[^ ]+]] = u32[] remainder([[ADD4]], [[C4]])
-; CHECK-NEXT: [[DSLICE2:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
-; CHECK-NEXT: [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE2]])
-; CHECK-NEXT: [[DSLICE3:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE1]], [[C0]]), dynamic_slice_sizes={2,512,24576}
-; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE3]], [[MUL1]]),
-; CHECK-DAG: lhs_contracting_dims={2},
-; CHECK-DAG: rhs_contracting_dims={0}
-; CHECK-NEXT: [[ADD5:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[GTE6]], [[DOT1]])
-; CHECK-NEXT: [[CP1:[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[ADD5]]), channel_id=10
-; CHECK-NEXT: ROOT [[OUT:[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[GTE0]], [[GTE1]], [[ADD3]], [[CP1]], [[ADD0]], /*index=5*/[[GTE3]], [[GTE5]])
-)");
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest,
- AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) {
- constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8
-
-windowed_dot_general_body_ag {
- param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
- get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0
- collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
- collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
- get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1
- get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2
- constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584})
- get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4
- partition-id.194 = u32[] partition-id()
- add.4309 = u32[] add(get-tuple-element.592, partition-id.194)
- constant.11431 = u32[] constant(8)
- remainder.194 = u32[] remainder(add.4309, constant.11431)
- dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1}
- reshape.12959 = s32[] reshape(dynamic-slice.388)
- constant.11433 = s32[] constant(0)
- dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288}
- dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244)
- constant.11434 = u32[] constant(1)
- add.4312 = u32[] add(get-tuple-element.592, constant.11434)
- add.4313 = u32[] add(add.4312, partition-id.194)
- remainder.195 = u32[] remainder(add.4313, constant.11431)
- dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1}
- reshape.12960 = s32[] reshape(dynamic-slice.390)
- dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288}
- dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245)
- get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3
- add.4315 = u32[] add(add.4312, constant.11434)
- ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315)
-} // windowed_dot_general_body_ag
-
-windowed_dot_general_cond_ag {
- param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
- get-tuple-element = u32[] get-tuple-element(param), index=4
- constant = u32[] constant(4)
- ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
-}
-
-ENTRY main.12_spmd {
- param.4 = bf16[16,2048,512]{2,1,0} parameter(0)
- param.5 = bf16[4096,6288]{1,0} parameter(1)
- constant.22 = bf16[] constant(0)
- broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={}
- constant.24 = u32[] constant(0)
- tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
- while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
- get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2
- all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true
- param.6 = bf16[16,2048,6288]{2,1,0} parameter(2)
- ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- GpuWindowedEinsumHandler gpu_handler;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
- EXPECT_TRUE(changed);
-
- HloInstruction* ag_loop =
- FindInstructionByName(module->entry_computation(), "while");
- HloInstruction* inst =
- FindInstructionByName(module->entry_computation(), "dot.7");
-  // dot.7 should now consume the output of the windowed einsum while loop.
- EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
- EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
- EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
-}
-} // namespace
-} // namespace xla::gpu
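The deleted tests above all follow one recipe: parse an SPMD HLO module, run GpuWindowedEinsumHandler, and FileCheck that every decomposed dot carries its own "operation_queue_id" while the adds that combine the partial results list those ids in "wait_on_operation_queues", i.e. the dots may run on side streams and their consumers synchronize on them. The JSON matched by the CHECK lines is the GpuBackendConfig proto attached to each instruction; below is a minimal sketch of reading it back programmatically, where `dot` is a placeholder for any rewritten dot instruction and is not a name taken from the deleted tests.

```cpp
// Sketch only, not part of the deleted tests. `dot` stands for a dot
// instruction taken from the module after GpuWindowedEinsumHandler has run;
// the accessors mirror the backend_config JSON asserted by the CHECK lines.
TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig config,
                        dot->backend_config<GpuBackendConfig>());
// Each decomposed dot is assigned its own operation queue; the adds that
// combine the partial results name that queue in wait_on_operation_queues.
EXPECT_GT(config.operation_queue_id(), 0);
```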
diff --git a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc
index 8b1efb2..dcca8f3 100644
--- a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc
+++ b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc
@@ -25,8 +25,8 @@
#include "absl/strings/str_cat.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/backend_config.h"
+#include "xla/service/gpu/autotuning/gpu_autotuning.pb.h"
#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gpu_autotuning.pb.h"
#include "xla/stream_executor/dnn.h"
#include "tsl/platform/env.h"
#include "tsl/platform/protobuf.h"
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
index 345fd8c..e527a09 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
@@ -171,15 +171,40 @@
// static
HloFusionAnalysis HloFusionAnalysis::Create(
- const HloFusionInstruction* fusion,
- const se::DeviceDescription* device_info) {
- CHECK(device_info != nullptr);
- FusionBackendConfig backend_config =
- fusion->has_backend_config()
- ? fusion->backend_config<GpuBackendConfig>()->fusion_backend_config()
- : FusionBackendConfig::default_instance();
- return Create(std::move(backend_config),
- HloFusionAdaptor::ForInstruction(fusion), device_info);
+ const HloInstruction& instruction,
+ const se::DeviceDescription& device_info) {
+ absl::StatusOr<GpuBackendConfig> gpu_backend_config =
+ instruction.backend_config<GpuBackendConfig>();
+
+ FusionBackendConfig fusion_backend_config =
+ gpu_backend_config.ok() ? gpu_backend_config->fusion_backend_config()
+ : FusionBackendConfig::default_instance();
+ return Create(std::move(fusion_backend_config),
+ HloFusionAdaptor::ForInstruction(&instruction), &device_info);
+}
+
+// static
+HloFusionAnalysis HloFusionAnalysis::Create(
+ const HloInstruction& producer, const HloInstruction& consumer,
+ const se::DeviceDescription& device_info) {
+ absl::StatusOr<GpuBackendConfig> gpu_backend_config;
+
+ if (consumer.has_backend_config()) {
+ gpu_backend_config = consumer.backend_config<GpuBackendConfig>();
+ }
+
+ if (!gpu_backend_config.ok() && producer.has_backend_config()) {
+ gpu_backend_config = producer.backend_config<GpuBackendConfig>();
+ }
+
+ FusionBackendConfig fusion_backend_config =
+ gpu_backend_config.ok() ? gpu_backend_config->fusion_backend_config()
+ : FusionBackendConfig::default_instance();
+
+ return HloFusionAnalysis::Create(
+ std::move(fusion_backend_config),
+ HloFusionAdaptor::ForProducerConsumer(&producer, &consumer),
+ &device_info);
}
 // Returns true if the fusion has consistent transpose heroes.
@@ -264,7 +289,7 @@
}
// We expect that the last dimension is swapped with a different dimension.
- if (HasConsistentTransposeHeros() && tiled_transpose_->permutation[2] != 2) {
+ if (HasConsistentTransposeHeros()) {
return EmitterFusionKind::kTranspose;
}
@@ -305,24 +330,5 @@
LOG(FATAL) << "Did not find a hero reduction";
}
-HloFusionAnalysis AnalyzeProducerConsumerFusion(
- const HloInstruction& producer, const HloInstruction& consumer,
- const se::DeviceDescription& device_info) {
- return HloFusionAnalysis::Create(
- consumer.has_backend_config()
- ? consumer.backend_config<GpuBackendConfig>()->fusion_backend_config()
- : producer.backend_config<GpuBackendConfig>()
- ->fusion_backend_config(),
- HloFusionAdaptor::ForProducerConsumer(&producer, &consumer),
- &device_info);
-}
-
-HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer,
- const se::DeviceDescription& device_info) {
- return HloFusionAnalysis::Create(
- consumer.backend_config<GpuBackendConfig>()->fusion_backend_config(),
- HloFusionAdaptor::ForInstruction(&consumer), &device_info);
-}
-
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
index 146224b..c1b7e5b 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
@@ -58,8 +58,17 @@
static HloFusionAnalysis Create(FusionBackendConfig backend_config,
std::unique_ptr<HloFusionAdaptor> fusion,
const se::DeviceDescription* device_info);
- static HloFusionAnalysis Create(const HloFusionInstruction* fusion,
- const se::DeviceDescription* device_info);
+
+  // Creates a HloFusionAnalysis that analyzes just `instruction` as a
+  // standalone fusion.
+ static HloFusionAnalysis Create(const HloInstruction& instruction,
+ const se::DeviceDescription& device_info);
+
+ // Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer
+ // into consumer.
+ static HloFusionAnalysis Create(const HloInstruction& producer,
+ const HloInstruction& consumer,
+ const se::DeviceDescription& device_info);
const HloFusionAdaptor& fusion() const { return *fusion_; }
@@ -131,17 +140,6 @@
InputOutputInfo input_output_info_;
};
-// Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer
-// into consumer.
-HloFusionAnalysis AnalyzeProducerConsumerFusion(
- const HloInstruction& producer, const HloInstruction& consumer,
- const se::DeviceDescription& device_info);
-
-// Creates a HloFusionAnalysis that analyzes just consumer as a standalone
-// fusion.
-HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer,
- const se::DeviceDescription& device_info);
-
} // namespace gpu
} // namespace xla
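For callers, the two deleted free functions map one-to-one onto the new HloFusionAnalysis::Create overloads declared above; the test updates below show the mechanical rewrite. Here is a minimal caller-side sketch, assuming `producer`, `consumer`, and `device_info` are whatever the caller already has in scope; the helper name FusedEmitterKind is hypothetical and not part of this change.

```cpp
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Hypothetical helper showing the migration from the removed free functions
// to the new Create() overloads.
HloFusionAnalysis::EmitterFusionKind FusedEmitterKind(
    const HloInstruction& producer, const HloInstruction& consumer,
    const se::DeviceDescription& device_info) {
  // Was: AnalyzeProducerConsumerFusion(producer, consumer, device_info)
  auto analysis = HloFusionAnalysis::Create(producer, consumer, device_info);
  // The standalone form changes the same way:
  //   AnalyzeFusion(consumer, device_info)
  //   -> HloFusionAnalysis::Create(consumer, device_info)
  return analysis.GetEmitterFusionKind();
}

}  // namespace xla::gpu
```

Note that the new overloads take references rather than pointers, so the old `CHECK(device_info != nullptr)` guard in the pointer-based Create path disappears along with the free functions.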
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc
index 04c5819..7328bc6 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc
@@ -15,9 +15,11 @@
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include <gtest/gtest.h>
+#include "xla/protobuf_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_description.pb.h"
#include "xla/tests/hlo_test_base.h"
@@ -48,12 +50,12 @@
auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kLoop);
auto analysis_fused =
- AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+ HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
EXPECT_EQ(analysis_fused.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
@@ -155,7 +157,7 @@
auto* root = module->entry_computation()->root_instruction();
auto analysis =
- AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+ HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
@@ -186,7 +188,7 @@
auto* root = module->entry_computation()->root_instruction();
auto analysis =
- AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+ HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
@@ -223,7 +225,7 @@
auto* root = module->entry_computation()->root_instruction();
auto analysis =
- AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+ HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
@@ -255,7 +257,7 @@
auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
@@ -287,7 +289,7 @@
auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
auto* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info);
   // We expect to fall back to the loop emitter, because the two reductions are
// not compatible as they reduce over different dimensions.
EXPECT_EQ(analysis.GetEmitterFusionKind(),
@@ -319,7 +321,7 @@
auto* root = module->entry_computation()->root_instruction();
auto analysis_fused =
- AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+ HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
EXPECT_EQ(analysis_fused.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
@@ -352,5 +354,90 @@
HloFusionAnalysis::EmitterFusionKind::kConcatenate);
}
+TEST_F(HloFusionAnalysisTest, ExtractValidGpuBackendConfig) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation.1 {
+ %x = s32[64] parameter(0)
+ %y = s32[64] parameter(1)
+ ROOT %root = s32[64] add(%x, %y)
+ }
+
+ fused_computation.2 {
+ %x = s32[64] parameter(0)
+ %y = s32[64] parameter(1)
+ ROOT %root = s32[64] add(%x, %y)
+ }
+
+ ENTRY entry {
+ %x = s32[64] parameter(0)
+ %y = s32[64] parameter(1)
+ %fusion.1 = s32[64] fusion(%x, %y), kind=kLoop, calls=fused_computation.1, backend_config={"fusion_backend_config": {kind: "__triton"}}
+ ROOT %fusion.2 = s32[64] fusion(%fusion.1, %y), kind=kLoop, calls=fused_computation.2
+ })"));
+
+ auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ auto* consumer = module->entry_computation()->root_instruction();
+ auto* producer = consumer->operand(0);
+
+ auto producer_analysis = HloFusionAnalysis::Create(*producer, device_info);
+ EXPECT_EQ(producer_analysis.fusion_backend_config().kind(),
+ kTritonFusionKind);
+
+ auto producer_consumer_analysis =
+ HloFusionAnalysis::Create(*producer, *consumer, device_info);
+ EXPECT_EQ(producer_consumer_analysis.fusion_backend_config().kind(),
+ kTritonFusionKind);
+}
+
+TEST_F(HloFusionAnalysisTest,
+ InvalidGpuBackendConfig_SingleInstruction_Ignored) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ ENTRY entry {
+ %x = s32[64,64,64] parameter(0)
+ %y = s32[64,64,64] parameter(1)
+ ROOT %root = s32[64,128,64] concatenate(x, y), dimensions={1}, backend_config={"outer_dimension_partitions": ["1"]}
+ })"));
+
+ auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ auto* root = module->entry_computation()->root_instruction();
+ auto analysis = HloFusionAnalysis::Create(*root, device_info);
+
+ EXPECT_TRUE(
+ protobuf_util::ProtobufEquals(analysis.fusion_backend_config(),
+ FusionBackendConfig::default_instance()));
+}
+
+TEST_F(HloFusionAnalysisTest,
+ InvalidGpuBackendConfig_ProducerConsumer_Ignored) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation {
+ %x = s32[64] parameter(0)
+ %y = s32[64] parameter(1)
+ ROOT %root = s32[64] add(%x, %y)
+ }
+
+ ENTRY entry {
+ %x = s32[64] parameter(0)
+ %y = s32[64] parameter(1)
+ %fusion = s32[64] fusion(%x, %y), kind=kLoop, calls=fused_computation, backend_config={"invalid_field": "some_value"}
+ ROOT %root = s32[128] concatenate(fusion, y), dimensions={0}, backend_config={"invalid_field": "some_value"}
+ })"));
+
+ auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ auto* consumer = module->entry_computation()->root_instruction();
+ auto* producer = consumer->operand(0);
+ auto analysis = HloFusionAnalysis::Create(*producer, *consumer, device_info);
+
+ EXPECT_TRUE(
+ protobuf_util::ProtobufEquals(analysis.fusion_backend_config(),
+ FusionBackendConfig::default_instance()));
+}
+
} // namespace
} // namespace xla::gpu
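The two InvalidGpuBackendConfig_* tests added here pin down the fallback introduced in hlo_fusion_analysis.cc above: if an instruction's backend_config cannot be parsed as a GpuBackendConfig, Create() returns an analysis whose fusion_backend_config() is the default instance rather than failing. A hedged caller-side sketch of relying on that behavior follows; WantsTritonEmitter and its arguments are hypothetical names, not part of this change.

```cpp
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emission_utils.h"  // kTritonFusionKind
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Hypothetical caller: safe even when `instr` carries no (or an unparsable)
// backend_config, because Create() then falls back to
// FusionBackendConfig::default_instance(), whose kind() is empty.
bool WantsTritonEmitter(const HloInstruction& instr,
                        const se::DeviceDescription& device_info) {
  auto analysis = HloFusionAnalysis::Create(instr, device_info);
  return analysis.fusion_backend_config().kind() == kTritonFusionKind;
}

}  // namespace xla::gpu
```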
diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc
deleted file mode 100644
index c693856..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc
+++ /dev/null
@@ -1,192 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_input_fusion.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Gets the representative input shape of the multi-output fusion.
-Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
- // Get the HLO that determines the emitter used for lowering.
- const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
- if (real_hero->operands().empty()) {
- // Simply return an empty shape if the representative node has no input
- // operands.
- return Shape();
- } else {
- return real_hero->operand(0)->shape();
- }
-}
-
-class HorizontalInputFusionImpl {
- public:
- explicit HorizontalInputFusionImpl(HloComputation* computation,
- const se::DeviceDescription& d)
- : computation_(computation), device_info_(d) {}
-
- ~HorizontalInputFusionImpl() = default;
-
- absl::StatusOr<bool> Run();
-
- private:
- HloComputation* computation_;
- const se::DeviceDescription& device_info_;
-}; // HorizontalInputFusionImpl
-
-// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to
-// right.
-bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
- const Shape& shape_b) {
- if (shape_a.rank() != shape_b.rank()) {
- return shape_a.rank() < shape_b.rank();
- }
- auto dims_a = shape_a.dimensions();
- auto dims_b = shape_b.dimensions();
- for (size_t i = 0; i < dims_a.size(); ++i) {
- if (dims_a[i] != dims_b[i]) {
- return dims_a[i] < dims_b[i];
- }
- }
- return true;
-}
-
-std::vector<HloInstruction*> FindAndSortFusionCandidates(
- HloInstruction* consumer) {
- absl::flat_hash_set<HloInstruction*> fusion_instr_set;
- std::vector<HloInstruction*> fusion_instrs;
- for (HloInstruction* opnd : consumer->operands()) {
- HloInstruction* predecessor = opnd->LatestNonGteAncestor();
- // Find out the input fusion instructions whose only consumer is `consumer`.
- // This guarantees that fusing these candidates will never create cycles, as
- // there is no back edge.
- if (IsInputFusibleReduction(*predecessor) &&
- IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
- if (fusion_instr_set.insert(predecessor).second) {
- fusion_instrs.push_back(predecessor);
- }
- }
- }
-
- std::sort(fusion_instrs.begin(), fusion_instrs.end(),
- [&](const HloInstruction* a, const HloInstruction* b) {
- Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
- Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
- if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
- // Sort shapes according to dimensions, so that the same input
- // shapes will be placed adjacent each other.
- return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
- }
- // Sort `fusion_instrs` according to instruction counts, because
- // we'd like to fuse together computations of similar sizes.
- return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
- });
-
- return fusion_instrs;
-}
-
-absl::StatusOr<bool> HorizontalInputFusionImpl::Run() {
- bool changed = false;
- XLA_VLOG_LINES(3, computation_->ToString());
-
- // Using def-to-use order is sound since we do not modify users.
- std::vector<HloInstruction*> def_to_use_order =
- computation_->MakeInstructionPostOrder();
- for (HloInstruction* consumer : def_to_use_order) {
- auto candidates = FindAndSortFusionCandidates(consumer);
- if (candidates.size() <= 1) {
- continue;
- }
-
- // Convert candidates into fusions if needed.
- for (size_t j = 0; j < candidates.size(); ++j) {
- if (candidates[j]->opcode() != HloOpcode::kFusion) {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * fusion_instr,
- MakeFusionInstruction(candidates[j],
- HloInstruction::FusionKind::kInput));
- candidates[j] = fusion_instr;
- changed = true;
- }
- }
-
- size_t fusion_anchor_id = 0;
- for (size_t j = 1; j < candidates.size(); ++j) {
- HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
- HloInstruction* fused = candidates[j];
- if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
- FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) {
- VLOG(3) << "Fuse " << fused->ToString() << " into "
- << fusion_anchor->ToString();
- fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
- changed = true;
- } else {
- // Update the `fusion_anchor_id` since `fused` is either not
- // compatible or not beneficial to be fused with current fusion anchor.
- VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
- fusion_anchor_id = j;
- }
- }
- }
-
- return changed;
-}
-
-} // namespace
-
-absl::StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
- HloComputation* computation) {
- HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_);
- return horizontal_fusion_impl.Run();
-}
-
-absl::StatusOr<bool> GpuHorizontalInputFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- VLOG(2) << "Run horizontal input fusion.";
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp));
- }
-
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h b/third_party/xla/xla/service/gpu/horizontal_input_fusion.h
deleted file mode 100644
index 370ce7b..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
-#define XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// This optimization pass horizontally fuses kInput fusions to both reduce the
-// kernel launch overhead and increase parallelism degree. See
-// GpuHorizontalFusion for general description and motivation about horizontal
-// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals
-// with kInput fusions.
-//
-// Following GpuHorizontalFusion, a simple yet effective heuristic is used
-// to search the fusion candidates while avoiding creating cycles. That is,
-// we simply search for fusion candidates by looking for instructions whose
-// outputs are all consumed by the same instruction. This catches the typical
-// target cases; often, the candidate instructions are just consumed by the
-// ROOT tuple of the entry computation.
-class GpuHorizontalInputFusion : public HloModulePass {
- public:
- explicit GpuHorizontalInputFusion(const se::DeviceDescription& d)
- : device_info_(d) {}
-
- absl::string_view name() const override {
- return "gpu_horizontal_input_fusion";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- absl::StatusOr<bool> RunOnComputation(HloComputation*);
-
- const se::DeviceDescription& device_info_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc
deleted file mode 100644
index 2d458f9..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc
+++ /dev/null
@@ -1,270 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_input_fusion.h"
-
-#include <cstdint>
-#include <utility>
-#include <vector>
-
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class HorizontalInputFusionTest : public GpuCodegenTest {
- public:
- se::DeviceDescription device_description_{
- TestGpuDeviceInfo::RTXA6000DeviceInfo()};
- GpuHorizontalInputFusion horizontal_input_fusion_{device_description_};
-};
-
-TEST_F(HorizontalInputFusionTest, BasicTest) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule BasicTest
-
- %add_f16 {
- %x = f16[] parameter(0)
- %y = f16[] parameter(1)
- ROOT %add = f16[] add(%x, %y)
- }
-
- fused_computation.1 {
- arg.1 = f16[1024]{0} parameter(0)
- constant0 = f16[] constant(0)
- ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
- }
-
- fused_computation.2 {
- arg.1 = f16[1024]{0} parameter(0)
- constant0 = f16[] constant(0)
- ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
- }
-
- ENTRY entry_computation {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
- fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
- ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
- }
-)")
- .value();
-
- EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
-
- const HloInstruction* entry_root =
- module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(entry_root,
- GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
- (m::GetTupleElement(m::Fusion())))));
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
- auto module = CreateNewVerifiedModule();
-
- HloComputation* reduce_computation;
- {
- auto embedded_builder = HloComputation::Builder("add");
- auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
- auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
- embedded_builder.AddInstruction(
- HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
- reduce_computation =
- module->AddEmbeddedComputation(embedded_builder.Build());
- }
-
- HloComputation::Builder builder(TestName());
- std::vector<HloInstruction*> var_outs;
- auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
- auto output_shape = ShapeUtil::MakeShape(F32, {1024});
- for (int64_t i = 0; i < 130; ++i) {
- // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) ->
- // f32[1024] {
- // %param_0 = f32[1024,1024]{1,0} parameter(0)
- // %param_1 = f32[] parameter(1)
- // %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
- // dimensions={}
- // %multiply = f32[1024,1024]{1,0}
- // multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0}
- // %broadcast)
- // %constant0 = f32[] constant(0)
- // ROOT %reduce = f32[1024]{0}
- // reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0),
- // dimensions={1}, to_apply=%add
- // }
- HloInstruction* param_var_in = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
- HloInstruction* param_alpha =
- builder.AddInstruction(HloInstruction::CreateParameter(
- i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
- auto alpha_broadcasted = builder.AddInstruction(
- HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
- auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
- HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
- auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
- output_shape, mul, const0, {1}, reduce_computation));
- var_outs.push_back(reduce);
- }
- builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
- module->AddEntryComputation(builder.Build());
-
- // Verify that horizontal fusion is kicked in. Check that there are multiple
- // `reduce` instructions fused into the same fusion.
- if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() < 4) {
-    // 6 is just a randomly picked number; we don't know exactly how large the
-    // created fusion will be, due to the `FusionFitsInBudget` constraint.
- CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
- /*match_optimized_ir=*/false);
- } else {
- // Verify that we produced a multi-output reduction with independent groups.
- CompileAndVerifyIr(module->Clone(), R"(CHECK: switch {{.*}} label {{.*}} [
- CHECK-NEXT: label)",
- /*match_optimized_ir=*/false);
- }
-
- // Testing with the entire gpu optimization pipeline.
- EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
-}
-
-TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
- // This tests the below pattern. One known issue is that gtes (to fusions) can
- // be removed after their producer fusions are merged. In the below case, gte2
- // and gte6 will be gone if Fusion2 is fused into Fusion1.
- //
- // Fusion1 Fusion2
- // | | | |
- // | gte1 gte2 |
- // | | | |
- // | Fusion3 |
- // | | | |
- // gte3 gte4 gte5 gte6
- // \ | | /
- // =====ROOT=====
- //
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule MultiOutputFusionTest
-
- %add_f16 {
- %x = f16[] parameter(0)
- %y = f16[] parameter(1)
- ROOT %add = f16[] add(%x, %y)
- }
-
- fused_computation.1 {
- arg.1 = f16[1024]{0} parameter(0)
- constant0 = f16[] constant(0)
- reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
- add.0 = f16[1024] add(arg.1, arg.1)
- ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
- }
-
- fused_computation.2 {
- arg.1 = f16[1024]{0} parameter(0)
- constant0 = f16[] constant(0)
- reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
- add.0 = f16[1024] add(arg.1, arg.1)
- ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
- }
-
- fused_computation.3 {
- arg.0 = f16[1024]{0} parameter(0)
- arg.1 = f16[1024]{0} parameter(1)
- add.0 = f16[1024] add(arg.0, arg.1)
- mul.0 = f16[1024] multiply(arg.0, arg.1)
- ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
- }
-
- ENTRY entry_computation {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
- fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
- gte.3 = f16[] get-tuple-element(fusion.1), index=0
- gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
- gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
- gte.6 = f16[] get-tuple-element(fusion.2), index=0
- fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
- kind=kLoop, calls=fused_computation.3
- gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
- gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
- ROOT tuple.1 = (f16[], f16[1024], f16[1024]{0}, f16[])
- tuple(gte.3, gte.4, gte.5, gte.6)
- }
-)")
- .value();
-
- EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
-}
-
-TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NonfusionInstrs
-
- %add_f16 {
- %x = f16[] parameter(0)
- %y = f16[] parameter(1)
- ROOT %add = f16[] add(%x, %y)
- }
-
- ENTRY entry_computation {
- arg.0 = f16[1024]{0} parameter(0)
- arg.1 = f16[1024]{0} parameter(1)
- constant0 = f16[] constant(0)
- reduce.0 = f16[] reduce(arg.0, constant0), dimensions={0}, to_apply=%add_f16
- reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
- ROOT tuple.0 = (f16[], f16[]) tuple(reduce.0, reduce.1)
- }
-)")
- .value();
-
- EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
-
- const HloInstruction* entry_root =
- module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(entry_root,
- GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
- (m::GetTupleElement(m::Fusion())))));
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc
deleted file mode 100644
index 80c46cb..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc
+++ /dev/null
@@ -1,744 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_loop_fusion.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/sub_byte_normalization.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
- auto outputs = GetOutputsOfFusible(fusible);
- CHECK(!outputs.empty());
- PrimitiveType first_output_type = outputs[0]->shape().element_type();
- for (size_t i = 1; i < outputs.size(); ++i) {
- PrimitiveType cur_output_type = outputs[i]->shape().element_type();
- CHECK(first_output_type == cur_output_type)
- << "Output types are expected to be unique, but see "
- << PrimitiveType_Name(first_output_type) << " and "
- << PrimitiveType_Name(cur_output_type);
- }
-
- return first_output_type;
-}
-
-class HorizontalLoopFusionImpl {
- public:
- explicit HorizontalLoopFusionImpl(HloComputation* computation,
- absl::string_view prefix)
- : computation_(computation), prefix_(prefix) {}
-
- ~HorizontalLoopFusionImpl() = default;
-
- absl::StatusOr<bool> Run();
-
- private:
- absl::Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs,
- bool sliced_input_fusion,
- std::vector<HloInstruction*>& to_fuse_candidates);
-
- // If `sliced_input_fusion` is true, horizontally fuses `fused_fusion_instrs`
- // into a kInput computation; otherwise, fuses them into a kLoop computation.
- //
- // It is required that each of `fused_fusion_instrs` is a kLoop fusion. Also,
- // we require their numbers of outputs to be the same, so that each output
- // will be fused/concatenated with the same number of outputs from other fused
- // fusion instrs. Then, all the fused outputs still have the same shapes for
- // kernel generation.
- //
- // Returns the fused computation in `uniq_computation` and the operands that
- // are used by `uniq_computation`.
- absl::Status CreateFusedComputation(
- absl::Span<HloInstruction*> fused_fusion_instrs,
- std::unique_ptr<HloComputation>* uniq_computation,
- std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion);
-
- // Horizontally fuses the operands of the consumer instruction.
- // `sliced_input_fusion` controls whether a kInput or a kLoop fusion
- // instruction is created. `to_fuse_candidates` is a stack of instructions
- // whose operands we want to try to fuse horizontally; whenever a new fusion
- // instruction is created, it is pushed onto the stack in the hope of further
- // fusing its operands.
- absl::StatusOr<bool> FuseConsumerOperands(
- HloInstruction* consumer, bool sliced_input_fusion,
- std::vector<HloInstruction*>& to_fuse_candidates);
-
- // FusionCandidates collects profitable candidates for a given consumer
- // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
- // acquire the next set of fusion candidates based on some heuristics.
- class FusionCandidates {
- public:
- explicit FusionCandidates(HloInstruction* consumer,
- bool sliced_input_fusion)
- : fusible_instrs_(),
- pos_(0),
- sliced_input_fusion_(sliced_input_fusion) {
- Initialize(consumer);
- }
-
- // Gets a span of fusions to be fused.
- absl::Span<HloInstruction*> GetNextSpanOfFusions();
-
- private:
- void Initialize(HloInstruction*);
-
- std::vector<HloInstruction*> fusible_instrs_;
- // `pos_` points to the start position of the next span.
- size_t pos_;
- // The `sliced_input_fusion_` flag controls whether we fuse into a kLoop
- // (false) or a kInput (true) kernel.
- bool sliced_input_fusion_;
- };
-
- HloComputation* computation_;
- std::string prefix_;
-}; // HorizontalLoopFusionImpl
-
-bool IsFusibleCandidate(const HloInstruction& instr) {
- // For now, we do not support fusing instructions that have control dependencies.
- if (!instr.control_successors().empty() ||
- !instr.control_predecessors().empty()) {
- return false;
- }
-
- if (IsNestableVariadicReduction(instr)) {
- return false;
- }
-
- // Require no further check for element-wise instructions.
- if (instr.IsElementwise() && instr.operand_count() > 0) {
- return true;
- }
-
- // Exclude fusions other than kLoop.
- if (!instr.IsLoopFusion()) {
- return false;
- }
-
- // Fusions with multiple output types are not supported, because the
- // concatenate (inserted for horizontal fusion) requires all of its operands
- // to have the same type.
- auto outputs = GetOutputsOfFusible(instr);
- CHECK(!outputs.empty());
- const HloInstruction* first_output = outputs[0];
- for (size_t i = 1; i < outputs.size(); ++i) {
- if (first_output->shape().element_type() !=
- outputs[i]->shape().element_type()) {
- return false;
- }
- }
-
- return true;
-}
-
-// Returns whether `instr` is a profitable candidate to be horizontally fused.
-// Since the primary benefit of horizontal fusion comes from reducing kernel
-// launch overhead, we want to exclude instructions whose kernel launch
-// overhead is insignificant, i.e., instructions whose computation latency is
-// longer than their launch latency. We estimate the computation latency of an
-// instruction from its shapes and the instruction count of its fused
-// computation. Empirically, if a fusion instruction has shapes smaller than
-// `kShapeThreshold` and fewer instructions than `kInstrCountThreshold`, it is
-// launch-latency-bound and profitable to fuse horizontally.
-bool IsProfitableFusionCandidate(const HloInstruction& instr,
- bool sliced_input_fusion) {
- // In a kLoop fused kernel, each GPU thread processes one or more elements
- // from each horizontally fused operand, while in a kInput fused kernel each
- // GPU thread processes only one element. Based on experience, we allow a
- // larger tensor-size threshold for kLoop fusion.
- const int64_t kShapeThreshold =
- sliced_input_fusion ? 128 * 2048 : 8192 * 8192;
- const int64_t kInstrCountThreshold = sliced_input_fusion ? 30 : 128;
- const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
- ? instr.fused_expression_root()
- : &instr;
-
- // Overly large shapes are unlikely to be profitable.
- if (root->opcode() == HloOpcode::kTuple) {
- // Since all output shapes are the same, use the first shape as the
- // representative.
- Shape shape = root->operand(0)->shape();
- if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
- VLOG(2) << "Profitable check failed due to element count with "
- "sliced_input_fusion="
- << sliced_input_fusion;
- return false;
- }
- } else {
- Shape shape = root->shape();
- if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
- VLOG(2) << "Profiltable check failed due to element size with "
- "sliced_input_fusion="
- << sliced_input_fusion;
- return false;
- }
- }
-
- // Fusions with too many instructions are unlikely to be profitable.
- if (instr.opcode() == HloOpcode::kFusion &&
- instr.fused_instruction_count() > kInstrCountThreshold) {
- return false;
- }
-
- return true;
-}
-
-// Returns whether `instr` has only row-major layouts.
-// Horizontal fusion excludes computations with non-row-major layouts, because
-// fusing computations with different layouts can result in uncoalesced memory
-// accesses and cause significant performance overhead.
-bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
- if (instr.opcode() != HloOpcode::kFusion) {
- return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
- }
-
- auto fused_instrs = instr.fused_instructions_computation()->instructions();
- for (HloInstruction* i : fused_instrs) {
- if (!LayoutUtil::IsDenseArray(i->shape())) {
- continue;
- }
- if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
- return false;
- }
- }
- return true;
-}
-
-// Returns whether any operand of `instr` is a parameter instruction that is
-// also used by another instruction in `fusion_instrs`.
-bool AnyOpndIsParamSharedAmongFusions(
- const HloInstruction* instr,
- const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
- return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
- return opnd->opcode() == HloOpcode::kParameter &&
- absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
- return user != instr && fusion_instrs.contains(user);
- });
- });
-}
-
-void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
- HloInstruction* consumer) {
- // First, find out all potential target candidates. We will filter out
- // unsupported/non-profitable cases below.
- absl::flat_hash_set<HloInstruction*> fusible_candidates;
- std::vector<HloInstruction*> ordered_fusible_candidates;
- for (HloInstruction* opnd : consumer->operands()) {
- HloInstruction* predecessor = opnd->LatestNonGteAncestor();
- // We currently support kLoop fusions and element-wise HLOs. We may extend
- // the supported set if the need arises.
- if (IsFusibleCandidate(*predecessor)) {
- if (fusible_candidates.insert(predecessor).second) {
- // Add unseen fusion to ordered list.
- ordered_fusible_candidates.push_back(predecessor);
- }
- }
- }
-
- for (HloInstruction* instr : ordered_fusible_candidates) {
- if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
- VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
- << " rejects maybe illegal instr " << instr->ToString()
- << "; including it may create cycles in HLO.";
- continue;
- } else if (!IsProfitableFusionCandidate(*instr, sliced_input_fusion_)) {
- VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
- << " rejects may-not-be profitable fusion instr"
- << instr->ToString();
- continue;
- } else if (!HasOnlyRowMajorLayout(*instr)) {
- VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
- << " rejects non-row-major fusion instr " << instr->ToString();
- continue;
- } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
- // Don't fuse fusions whose operands are parameter instructions that are
- // shared among fusions because we cannot i/o alias the produced
- // horizontal fusion due to the concat insertion.
- VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
- << " rejects the fusion instr because it shares parameter with"
- << " other fusion candidates, instr: " << instr->ToString();
- continue;
- } else {
- VLOG(2) << "Find a fusion candidate " << instr->ToString();
- // Encapsulate it into a fusion computation for unified representation
- // for later processing.
- fusible_instrs_.push_back(instr);
- }
- }
-
- // Sort `fusible_instrs_` by output type, number of outputs, instruction
- // count, and output tensor element count. For sliced input fusion, we only
- // fuse instructions with the same number/type of outputs and whose
- // computations have the same instruction count. For kLoop fusion, we require
- // the fused instructions to have the same number/type of outputs and also
- // the same output shape. Sorting ensures that compatible fusion candidates
- // occupy a contiguous span.
- std::stable_sort(
- fusible_instrs_.begin(), fusible_instrs_.end(),
- [&](const HloInstruction* a, const HloInstruction* b) {
- if (GetUniqueOutputTypeOfFusible(*a) !=
- GetUniqueOutputTypeOfFusible(*b)) {
- return GetUniqueOutputTypeOfFusible(*a) <
- GetUniqueOutputTypeOfFusible(*b);
- } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
- return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
- } else if (GetInstrCountOfFusible(*a) != GetInstrCountOfFusible(*b)) {
- return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
- } else {
- return ShapeUtil::ElementsIn(GetOutputsOfFusible(*a)[0]->shape()) <
- ShapeUtil::ElementsIn(GetOutputsOfFusible(*b)[0]->shape());
- }
- });
-}
-
-// Gets a next span of fusion instructions to be fused.
-absl::Span<HloInstruction*>
-HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
- if (pos_ >= fusible_instrs_.size()) {
- return absl::Span<HloInstruction*>();
- }
-
- // Fusing too many computations at a time may not be profitable and may
- // increase compile time due to large kernels, so we set a limit. Profiling
- // showed that a large horizontally fused kernel can have worse E2E
- // performance even though the pure GPU kernel time is shorter; understanding
- // this regression is a TODO. The empirical maximum fusion batch size below
- // depends on whether the candidate operand is already a fusion instruction.
- const auto kMaxFusionBatchSize = [&]() -> int64_t {
- if (sliced_input_fusion_) {
- return 32;
- } else {
- if (fusible_instrs_[pos_]->opcode() == HloOpcode::kFusion) {
- return 32;
- } else {
- return 64;
- }
- }
- }();
-
- size_t left = pos_;
- size_t right = pos_ + 1;
- size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
- PrimitiveType first_output_type =
- GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
- // CUDA has a parameter size limit of ~4k bytes.
- constexpr int64_t kMaxCudaParamSize = 4000;
- size_t accum_io_size = 0;
- size_t accum_num_outputs = 0;
- for (; right < fusible_instrs_.size(); ++right) {
- PrimitiveType cur_output_type =
- GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
- if (first_output_type != cur_output_type) {
- // Cannot fuse computations that have different output types.
- break;
- }
- if (first_output_size != GetOutputSizeOfFusible(*fusible_instrs_[right])) {
- // Cannot fuse computations that have different numbers of outputs.
- break;
- }
- if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
- GetInstrCountOfFusible(*fusible_instrs_[right])) {
- // Do not fuse computations with different instruction counts, as doing so
- // may introduce control divergence. This is a very simple heuristic to
- // avoid fusing computations that are too dissimilar; it may be improved if
- // the need arises.
- break;
- }
- if (!sliced_input_fusion_ &&
- !ShapeUtil::EqualIgnoringElementType(
- GetOutputsOfFusible(*fusible_instrs_[left])[0]->shape(),
- GetOutputsOfFusible(*fusible_instrs_[right])[0]->shape())) {
- // This is for fusing into a kLoop kernel, so we require each fusion
- // operand to have the same shape.
- break;
- }
- size_t num_outputs = GetOutputSizeOfFusible(*fusible_instrs_[right]);
- accum_num_outputs += num_outputs;
- if (accum_num_outputs >= kMaxFusionBatchSize) {
- // Hit max fusion batch size.
- break;
- }
- accum_io_size += fusible_instrs_.at(right)->operand_count() + num_outputs;
- if (accum_io_size * 8 >= kMaxCudaParamSize) {
- break;
- }
- }
- VLOG(2) << "horizontal fuse get instruction span with " << (right - left)
- << " instructions for sliced_input_fusion=" << sliced_input_fusion_
- << " fusion";
- pos_ = right;
- return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
-}
-
-absl::StatusOr<bool> HorizontalLoopFusionImpl::FuseConsumerOperands(
- HloInstruction* consumer, bool sliced_input_fusion,
- std::vector<HloInstruction*>& to_fuse_candidates) {
- bool changed = false;
- FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion);
- while (true) {
- auto fusibles = loop_fusion_candidates.GetNextSpanOfFusions();
- if (fusibles.empty()) {
- break;
- } else if (fusibles.size() == 1) {
- // Skip; there is just one fused_instr.
- continue;
- }
-
- changed = true;
- // Convert each fusible into a fusion instruction to simplify the
- // implementation of `Fuse()`.
- std::vector<HloInstruction*> fusion_instrs;
- for (HloInstruction* instr : fusibles) {
- if (instr->opcode() == HloOpcode::kFusion) {
- fusion_instrs.push_back(instr);
- } else {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * fusion_instr,
- MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
- fusion_instrs.push_back(fusion_instr);
- }
- }
-
- TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs), sliced_input_fusion,
- to_fuse_candidates));
- }
- return changed;
-}
-
-absl::Status HorizontalLoopFusionImpl::CreateFusedComputation(
- absl::Span<HloInstruction*> fused_fusion_instrs,
- std::unique_ptr<HloComputation>* uniq_computation,
- std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion) {
- // First, build a computation with only params.
- HloComputation::Builder b(prefix_ + "horizontally_fused_computation");
- size_t fused_comp_param_id = 0;
- for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
- auto old_params = fused_fusion_instrs[i]->fused_parameters();
- for (size_t j = 0; j < old_params.size(); ++j) {
- HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
- // Parameters are named in the form param_i_j.
- b.AddInstruction(HloInstruction::CreateParameter(
- fused_comp_param_id++, bound_opnd->shape(),
- absl::StrCat("param_", i, "_", j)));
- bound_operands->push_back(bound_opnd);
- }
- }
- // Always create a dummy tuple instruction to serve as the root of the
- // computation, as the existence of a root instruction is required by the
- // HloComputation. The real root instruction will replace it below.
- HloInstruction* dummy_root = b.AddInstruction(
- HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
- *uniq_computation = b.Build(dummy_root);
- HloComputation* comp = uniq_computation->get();
-
- // Prepare clone_map, which maps old operands to new operands.
- absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
- size_t new_param_id = 0;
- for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
- auto old_params = fused_fusion_instrs[i]->fused_parameters();
- for (size_t j = 0; j < old_params.size(); ++j) {
- HloInstruction* old_param = old_params[j];
- HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
- clone_map.insert({old_param, new_param});
- }
- }
-
- // Clone every fused computation.
- const OpMetadata* metadata = nullptr;
- for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
- auto def_to_use_order = fused_fusion_instrs[i]
- ->fused_instructions_computation()
- ->MakeInstructionPostOrder();
- for (HloInstruction* old_instr : def_to_use_order) {
- if (old_instr->opcode() == HloOpcode::kParameter ||
- (sliced_input_fusion && old_instr->opcode() == HloOpcode::kTuple &&
- old_instr == fused_fusion_instrs[i]->fused_expression_root())) {
- // Parameters have been created, and we don't need tuples from
- // multi-output fusions, as we will directly reference the tuple
- // operands instead by using GetOutputsOfFusible().
- continue;
- }
- std::vector<HloInstruction*> new_opnds;
- const auto& old_opnds = old_instr->operands();
- new_opnds.reserve(old_opnds.size());
- for (HloInstruction* old_opnd : old_opnds) {
- CHECK(clone_map.find(old_opnd) != clone_map.end());
- new_opnds.push_back(clone_map[old_opnd]);
- }
- HloInstruction* new_instr = comp->AddInstruction(
- old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
- clone_map.insert({old_instr, new_instr});
- // Get the metadata from the last fused instruction.
- metadata = &old_instr->metadata();
- }
- }
-
- // Since we require each fusion to have the same number of outputs, we can
- // simply use the first fusion as the representative for output size.
- size_t fused_instr_output_size =
- GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
-
- if (sliced_input_fusion) {
- // Fusing into kInput fusion
- std::vector<HloInstruction*> concated_outputs;
- for (size_t i = 0; i < fused_instr_output_size; ++i) {
- std::vector<HloInstruction*> instr_outputs(fused_fusion_instrs.size());
- for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
- const HloInstruction* old_output =
- GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
- HloInstruction* new_output = clone_map[old_output];
- if (new_output->shape().dimensions_size() == 1) {
- instr_outputs[j] = new_output;
- } else {
- Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
- new_output->shape().element_type(),
- {ShapeUtil::ElementsIn(new_output->shape())},
- /*minor_to_major=*/std::vector<int64_t>(1, 0));
- TF_ASSIGN_OR_RETURN(instr_outputs[j],
- MakeReshapeHlo(new_shape, new_output));
- }
- }
- TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
- MakeConcatHlo(instr_outputs, 0));
- concated_outputs.push_back(concated_output);
- }
-
- // Make slices of outputs.
- std::vector<HloInstruction*> output_slices(concated_outputs.size() *
- fused_fusion_instrs.size());
- for (size_t i = 0; i < concated_outputs.size(); ++i) {
- HloInstruction* concated_output = concated_outputs[i];
- int64_t slice_start = 0;
- // Create a slice per fused computation.
- for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
- const HloInstruction* old_output =
- GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
- Shape shape = old_output->shape();
- int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
- TF_ASSIGN_OR_RETURN(
- output_slices[concated_outputs.size() * j + i],
- MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
- /*strides=*/{1}));
- slice_start = slice_limit;
- }
- }
-
- // Make a tuple of output_slices.
- HloInstruction* tuple = comp->AddInstruction(
- HloInstruction::CreateTuple(output_slices), metadata);
- comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
- TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
-
- } else {
- // Fusing into kLoop fusion
- std::vector<HloInstruction*> tuple_operands(fused_instr_output_size *
- fused_fusion_instrs.size());
- // When fusing into a kLoop fusion, the new fusion root is a tuple of the
- // fused fusion computations' roots.
- for (size_t i = 0; i < fused_instr_output_size; ++i) {
- for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
- const HloInstruction* old_output =
- GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
- HloInstruction* new_output = clone_map[old_output];
- tuple_operands[fused_instr_output_size * j + i] = new_output;
- }
- }
- // Make a tuple of the fused instruction outputs as the root of the fused
- // computation.
- HloInstruction* tuple =
- comp->AddInstruction(HloInstruction::CreateTuple(tuple_operands));
- comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
- TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
- }
-
- return absl::OkStatus();
-}
-
-absl::Status HorizontalLoopFusionImpl::Fuse(
- absl::Span<HloInstruction*> fused_fusion_instrs, bool sliced_input_fusion,
- std::vector<HloInstruction*>& to_fuse_candidates) {
- // Fuse fused_fusion_instrs and replace them with the new fused computation.
- std::unique_ptr<HloComputation> uniq_computation;
- std::vector<HloInstruction*> bound_operands;
-
- TF_RETURN_IF_ERROR(CreateFusedComputation(fused_fusion_instrs,
- &uniq_computation, &bound_operands,
- sliced_input_fusion));
-
- HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
- std::move(uniq_computation));
- HloInstruction* hori_fusion_instr = computation_->AddInstruction(
- HloInstruction::CreateFusion(fused_comp->root_instruction()->shape(),
- sliced_input_fusion
- ? HloInstruction::FusionKind::kInput
- : HloInstruction::FusionKind::kLoop,
- bound_operands, fused_comp, prefix_),
- &fused_comp->root_instruction()->metadata());
- fused_comp->SetFusionInstruction(hori_fusion_instr);
-
- // Push the newly fused instruction onto the fusion candidate stack, because
- // its operands may now themselves be horizontally fusible.
- to_fuse_candidates.push_back(hori_fusion_instr);
-
- // Insert bitcasts and replace the corresponding users. Note that we do not
- // insert the bitcasts in the fused computation, as they do not fit the
- // sliced input fusion pattern. However, inserting bitcasts outside the fused
- // computation incurs no performance cost.
- size_t total_output_id = 0;
- for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
- std::vector<HloInstruction*> bitcasts_or_gte;
- HloInstruction* fused_instr = fused_fusion_instrs[i];
- size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
- for (size_t j = 0; j < num_outputs; ++j) {
- const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
- TF_ASSIGN_OR_RETURN(
- HloInstruction * gep,
- MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
- // This pass runs late, so a useless bitcast would not be cleaned up; avoid
- // creating one when the output is already 1D.
- if (output->shape().dimensions_size() == 1) {
- bitcasts_or_gte.push_back(gep);
- } else {
- bitcasts_or_gte.push_back(computation_->AddInstruction(
- HloInstruction::CreateBitcast(output->shape(), gep)));
- }
- }
- HloInstruction* bitcast_or_tuple =
- (bitcasts_or_gte.size() == 1)
- ? bitcasts_or_gte.at(0)
- : computation_->AddInstruction(
- HloInstruction::CreateTuple(bitcasts_or_gte));
- HloComputation* old_computation =
- fused_instr->fused_instructions_computation();
- HloModule* module = old_computation->parent();
- TF_RETURN_IF_ERROR(
- computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
- TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(old_computation));
- }
-
- TF_RETURN_IF_ERROR(Cast<HloFusionInstruction>(hori_fusion_instr)
- ->DeduplicateFusionOperands());
-
- VLOG(1) << "Fused " << fused_fusion_instrs.size()
- << " instructions into: " << hori_fusion_instr->ToString();
- return absl::OkStatus();
-}
-
-absl::StatusOr<bool> HorizontalLoopFusionImpl::Run() {
- bool changed = false;
- XLA_VLOG_LINES(3, computation_->ToString());
-
- // Traverse from use to def. Bitcasts are placed after h-fusions to resolve
- // shape mismatches, but those bitcasts could prevent future h-fusions from
- // happening. So a bottom-up, use-to-def order is more favorable. It also
- // helps reduce the number of compiler iterations needed to reach the fixed
- // point.
- std::vector<HloInstruction*> to_fuse_candidates =
- computation_->MakeInstructionPostOrder();
-
- while (!to_fuse_candidates.empty()) {
- HloInstruction* consumer = to_fuse_candidates.back();
- to_fuse_candidates.pop_back();
-
- // The consumer may be an operand of a previously fused instruction and
- // therefore no longer valid; skip it.
- if (consumer->IsDead()) {
- continue;
- }
-
- // We first try to fuse operands that have the same shape into a kLoop
- // fusion instruction.
- TF_ASSIGN_OR_RETURN(
- bool loop_fusion_changed,
- FuseConsumerOperands(consumer, false, to_fuse_candidates));
-
- // For the remaining operands with different shapes, we then try to fuse
- // them into a kInput fusion instruction.
- TF_ASSIGN_OR_RETURN(
- bool sliced_input_fusion_changed,
- FuseConsumerOperands(consumer, true, to_fuse_candidates));
-
- changed = changed || loop_fusion_changed || sliced_input_fusion_changed;
- }
- return changed;
-}
-
-} // namespace
-
-absl::StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
- HloComputation* computation) {
- HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_);
- return horizontal_fusion_impl.Run();
-}
-
-absl::StatusOr<bool> GpuHorizontalLoopFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- VLOG(2) << "Run horizontal fusion.";
-
- // Running on the entry computation is sufficient.
- TF_ASSIGN_OR_RETURN(bool changed,
- RunOnComputation(module->entry_computation()));
-
- if (changed) {
- // Correctly set element_size_in_bits for any newly added sub-byte slice and
- // concatenate instructions.
- TF_ASSIGN_OR_RETURN(
- [[maybe_unused]] bool unused,
- SubByteNormalization{SubByteNormalization::SET_ELEMENT_SIZE}.Run(
- module));
- }
-
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
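
As an aside on the span-selection heuristic implemented in GetNextSpanOfFusions() above: each horizontal fusion batch is capped both by an output-count limit (32 for sliced input fusion or when the candidates are already fusions, 64 otherwise) and by an estimated kernel parameter footprint of roughly 8 bytes per operand/output against CUDA's ~4 KB parameter limit. Below is a minimal, standalone C++ sketch of just that budget arithmetic; the Fusible struct and the CountCandidatesInBudget helper are hypothetical stand-ins for the HLO queries, and the compatibility checks are omitted.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the per-candidate HLO queries used in the pass.
struct Fusible {
  int64_t operand_count;
  int64_t output_count;
  bool is_fusion;  // Whether the candidate is already a fusion instruction.
};

// Sketch: returns how many candidates, starting at `pos` (pos < size), would
// fit into one horizontally fused kernel under the batch-size and
// parameter-size budgets. Compatibility checks are omitted.
size_t CountCandidatesInBudget(const std::vector<Fusible>& candidates,
                               size_t pos, bool sliced_input_fusion) {
  const int64_t kMaxFusionBatchSize =
      (sliced_input_fusion || candidates[pos].is_fusion) ? 32 : 64;
  constexpr int64_t kMaxCudaParamSize = 4000;  // ~4 KB kernel parameter limit.
  int64_t accum_num_outputs = 0;
  int64_t accum_io_size = 0;
  size_t right = pos + 1;
  for (; right < candidates.size(); ++right) {
    accum_num_outputs += candidates[right].output_count;
    if (accum_num_outputs >= kMaxFusionBatchSize) break;  // Batch-size cap.
    accum_io_size +=
        candidates[right].operand_count + candidates[right].output_count;
    // Each operand/output contributes roughly an 8-byte kernel argument.
    if (accum_io_size * 8 >= kMaxCudaParamSize) break;
  }
  return right - pos;
}

In the real pass the same loop additionally requires that every candidate in a span share the output type, output count, and instruction count, and (for kLoop fusion) the output shape.
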
diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h
deleted file mode 100644
index 5daed03..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h
+++ /dev/null
@@ -1,148 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
-#define XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
-
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// This optimization pass horizontally fuses computations to reduce kernel
-// launch overhead while increasing the kernel launch dims on the GPU. The
-// initial motivation for horizontal fusion is the observation that the
-// training optimizer phase (e.g., AdamOptimizer, L2Loss, etc.) typically has
-// many small kernels as a result of applying the same formula to many
-// training parameters (or variables in TensorFlow). Fusing these small
-// kernels therefore provides a performance gain.
-//
-// In theory, we could implement a cycle detection algorithm to make sure no
-// cycles are created after fusion. However, a cycle detection check is
-// somewhat cumbersome; also, we observe that naive horizontal fusion of
-// arbitrary kernels may not be profitable due to control divergence and a
-// possible increase in memory bandwidth pressure from uncoalesced memory
-// accesses (note that horizontal fusion does not change the amount of memory
-// read and written at all). In practice, a simple yet effective heuristic is
-// used to avoid these issues while addressing the known beneficial cases:
-// we search for fusion candidates by looking for instructions whose outputs
-// are all consumed by the same instruction. This catches the cases in the
-// training optimizer phase, as the candidate instructions are typically
-// consumed only by the ROOT tuple of the entry computation.
-//
-// The following illustrates the mechanism of horizontal fusion. Before
-// fusion, there are two trivial kernels in the example below. One has only a
-// Mul op, while the other consists of only an Add op. Since they are consumed
-// only by the same (ROOT) tuple instruction, horizontal fusion is triggered.
-//
-// i0 i1 i2 i3
-// | | | |
-// v v v v
-// Mul Add
-// | |
-// v v
-// (ROOT) tuple
-//
-// We fuse into one of two possible patterns, depending on whether all the
-// fused operations have the same shape or not.
-//
-// case 1: if Mul and Add's output shape and type are the same, then we fuse
-// them into the below pattern:
-// i0 i1 i2 i3
-// | | | |
-// v v v v
-// Mul Add
-// | |
-// v v
-// (ROOT) tuple
-// the fused kernel will be kLoop type, and GPU code is emitted through
-// the LoopFusion class.
-//
-// case 2: if Mul's and Add's output shapes are different, then we fuse them
-// into the below pattern, which adds extra indexing:
-// i0 i1 i2 i3 +++ (Slice) Input Fusion
-// | | | | +
-// v v v v +
-// Mul Add +
-// | | +
-// v v +
-// Reshape0 Reshape1 +
-// | | +
-// v v +
-// Concatenate +
-// | | +
-// v v +
-// Slice0 Slice1 +++
-// | |
-// v v
-// Reshape2 Reshape3
-// | |
-// v v
-// (ROOT) tuple
-//
-// the fused kernel will be kInput type, and the GPU code is emitted through
-// the InputSlicesFusion class.
-//
-// In theory, the pattern in case 1 could also be fused into the case 2 target
-// graph, but we prefer to fuse into the kLoop type, because its codegen does
-// not incur the slicing range-check cost introduced by the case 2 pattern.
-//
-// Note that the case 2 fusion style provides an important advantage: kernels
-// of different shapes can be horizontally fused. The first pair of reshapes
-// (i.e., Reshape0 and Reshape1) flattens the dims to one dimension, so that
-// the outputs of the fused kernels can always be concatenated. The second
-// pair of reshapes (Reshape2 and Reshape3) restores the original shapes of
-// the output tensors.
-//
-// No extra copies are introduced by the horizontal fusion. Besides Reshape2
-// and Reshape3, the other instructions are fused into an input fusion; the
-// output dims of the concatenate will be used as the kernel launch dims.
-// Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the
-// outputs of Mul and Add are row-major.
-//
-// Note that reshapes are added only if a tensor isn't already a vector.
-class GpuHorizontalLoopFusion : public HloModulePass {
- public:
- GpuHorizontalLoopFusion() = default;
- explicit GpuHorizontalLoopFusion(absl::string_view prefix)
- : prefix_(prefix) {}
-
- absl::string_view name() const override {
- return "gpu_horizontal_loop_fusion";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- absl::StatusOr<bool> RunOnComputation(HloComputation*);
- std::string prefix_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
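
For context on how this pass is typically driven, the deleted tests in horizontal_loop_fusion_test.cc below exercise it both standalone (GpuHorizontalLoopFusion().Run(module.get())) and inside a fixed-point pipeline together with DCE, since each round of horizontal fusion can expose new candidates. A minimal C++ sketch of the pipeline-style usage follows; the wrapper function name is hypothetical, and `module` is assumed to be an already-built, verified HloModule.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/horizontal_loop_fusion.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_pass_fix.h"
#include "xla/service/hlo_pass_pipeline.h"

// Sketch: run horizontal loop fusion and DCE repeatedly until no pass in the
// pipeline reports a change, mirroring the IterativeHorizontalFusion test.
absl::StatusOr<bool> RunHorizontalFusionToFixedPoint(xla::HloModule* module) {
  xla::HloPassFix<xla::HloPassPipeline> pipeline("iterative_h_fusion");
  pipeline.AddPass<xla::gpu::GpuHorizontalLoopFusion>();
  pipeline.AddPass<xla::HloDCE>();
  return pipeline.Run(module);
}
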
diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc
deleted file mode 100644
index 4045183..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc
+++ /dev/null
@@ -1,851 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_loop_fusion.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/log/log.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/instruction_fusion.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/hlo_pass_fix.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/test.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class HorizontalLoopFusionTest : public HloTestBase {
- public:
- static bool IsFusion(const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kFusion;
- }
-};
-
-TEST_F(HorizontalLoopFusionTest, BasicTest) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule BasicTest
-
- fused_computation.1 {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- arg.3 = f16[123]{0} parameter(2)
- arg.4 = f16[123]{0} parameter(3)
- fusion.1 = f16[1024]{0}
- fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
- fusion.2 = f16[123]{0}
- fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
- ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
- tuple(fusion.1, fusion.2)
- }
-)")
- .value();
-
- EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
- TF_ASSERT_OK(verifier().Run(module.get()).status());
- EXPECT_FALSE(HloDCE().Run(module.get()).value());
-
- const HloInstruction* entry_root =
- module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(entry_root,
- GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement(m::Fusion()))));
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Slice(m::Concatenate()),
- m::Slice(m::Concatenate()))));
-}
-
-// Horizontal fusion should not be triggered as fusion will create cycles.
-TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NegativeTestForCycle
-
- fused_computation.1 {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- arg.3 = f16[123]{0} parameter(2)
- arg.4 = f16[123]{0} parameter(3)
- // fusion.1 and fusion.2 will not be horizontally fused as it will create
- // a cycle through fusion.1 -> add.2 -> fusion.2
- fusion.1 = f16[123]{0}
- fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
- add.2 = f16[123]{0} add(fusion.1, arg.4)
- fusion.2 = f16[123]{0}
- fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
- ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
- tuple(fusion.1, fusion.2, add.2)
- }
-)")
- .value();
-
- EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NegativeTestForIncompatibleTypes
-
- fused_computation.1 {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
- arg.1 = s32[123]{0} parameter(0)
- arg.2 = s32[123]{0} parameter(1)
- ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- arg.3 = s32[123]{0} parameter(2)
- arg.4 = s32[123]{0} parameter(3)
- // fusion.1 and fusion.2 will not be horizontally fused because their output
- // types are different.
- fusion.1 = f16[1024]{0}
- fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
- fusion.2 = s32[123]{0}
- fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
- ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
- tuple(fusion.1, fusion.2)
- }
-)")
- .value();
-
- EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule FusingIntoKLoopAndKInputTogether
-
- fused_computation.1 {
- arg.1 = f16[129, 2048]{1, 0} parameter(0)
- arg.2 = f16[129, 2048]{1, 0} parameter(1)
- ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
- arg.1 = f16[129, 2048]{1, 0} parameter(0)
- arg.2 = f16[129, 2048]{1, 0} parameter(1)
- ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.3 {
- arg.1 = f16[130, 2048]{1, 0} parameter(0)
- arg.2 = f16[130, 2048]{1, 0} parameter(1)
- ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.4 {
- arg.1 = f16[130, 2048]{1, 0} parameter(0)
- arg.2 = f16[130, 2048]{1, 0} parameter(1)
- ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.5 {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- fused_computation.6 {
- arg.1 = f16[128]{0} parameter(0)
- arg.2 = f16[128]{0} parameter(1)
- ROOT add.1 = f16[128]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
- arg.1 = f16[129, 2048]{1, 0} parameter(0)
- arg.2 = f16[129, 2048]{1, 0} parameter(1)
- arg.3 = f16[129, 2048]{1, 0} parameter(2)
- arg.4 = f16[129, 2048]{1, 0} parameter(3)
- arg.5 = f16[130, 2048]{1, 0} parameter(4)
- arg.6 = f16[130, 2048]{1, 0} parameter(5)
- arg.7 = f16[130, 2048]{1, 0} parameter(6)
- arg.8 = f16[130, 2048]{1, 0} parameter(7)
- arg.9 = f16[123]{0} parameter(8)
- arg.10 = f16[123]{0} parameter(9)
- arg.11 = f16[128]{0} parameter(10)
- arg.12 = f16[128]{0} parameter(11)
-
- // fusion.1 and fusion.2 will be fused into kLoop fusion
- // fusion.3 and fusion.4 will be fused into another kLoop fusion
- // fusion.5 and fusion.6 will be fused into kInput fusion
-
- fusion.1 = f16[129,2048]{1, 0}
- fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
-
- fusion.2 = f16[129,2048]{1, 0}
- fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
-
- fusion.3 = f16[130,2048]{1, 0}
- fusion(arg.5, arg.6), kind=kLoop, calls=fused_computation.3
-
- fusion.4 = f16[130,2048]{1, 0}
- fusion(arg.7, arg.8), kind=kLoop, calls=fused_computation.4
-
- fusion.5 = f16[123]{0}
- fusion(arg.9, arg.10), kind=kLoop, calls=fused_computation.5
-
- fusion.6 = f16[128]{0}
- fusion(arg.11, arg.12), kind=kLoop, calls=fused_computation.6
-
- ROOT tuple.1 = (f16[129,2048]{1, 0}, f16[129,2048]{1, 0},
- f16[130,2048]{1, 0}, f16[130,2048]{1, 0},
- f16[123]{0}, f16[128]{0})
- tuple(fusion.1, fusion.2, fusion.3, fusion.4, fusion.5, fusion.6)
- }
-)")
- .value();
-
- EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-
- int input_fusion_count = 0;
- int loop_fusion_count = 0;
- for (auto inst : module->entry_computation()->MakeInstructionPostOrder()) {
- if (inst->opcode() == HloOpcode::kFusion) {
- input_fusion_count +=
- (inst->fusion_kind() == HloInstruction::FusionKind::kInput) ? 1 : 0;
- loop_fusion_count +=
- (inst->fusion_kind() == HloInstruction::FusionKind::kLoop) ? 1 : 0;
- }
- }
- EXPECT_EQ(input_fusion_count, 1);
- EXPECT_EQ(loop_fusion_count, 2);
-}
-
-TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule MergeSharedFusionInstruction
-
- ENTRY MergeSharedFusionInstruction.Computation0 {
- param.1.1 = f32[4,1024]{1,0} parameter(0)
- param.1.2 = f32[4,1024]{1,0} parameter(1)
- param.1.3 = f32[4,1024]{1,0} parameter(2)
- param.2.1 = f32[321,5]{1,0} parameter(3)
- param.2.2 = f32[321,5]{1,0} parameter(4)
- param.2.3 = f32[321,5]{1,0} parameter(5)
- const.1 = f32[] constant(3)
- const.2 = f32[] constant(3)
- broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
- broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
- mul.1.1 = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
- mul.1.2 = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
- add.1 = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
- mul.2.1 = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
- mul.2.2 = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
- add.2 = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
- ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
-})")
- .value();
-
- HloPassPipeline fusion("fusion");
- const se::DeviceDescription device_info =
- TestGpuDeviceInfo::RTXA6000DeviceInfo();
- fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false,
- device_info);
- fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true,
- device_info);
- EXPECT_TRUE(fusion.Run(module.get()).value());
- EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
- TF_ASSERT_OK(verifier().Run(module.get()).status());
-
- VLOG(2) << "Dump after horizontal fusion:";
- VLOG(2) << module->ToString();
-
- const HloInstruction* entry_root =
- module->entry_computation()->root_instruction();
- const HloInstruction* fusion_instr = nullptr;
- // Check that we add bitcast when needed.
- ASSERT_THAT(entry_root,
- GmockMatch(m::Tuple(
- m::Bitcast(m::GetTupleElement(m::Fusion(&fusion_instr))),
- m::Bitcast(m::GetTupleElement(m::Fusion())))));
- ASSERT_TRUE(fusion_instr->IsMultiOutputFusion());
- EXPECT_THAT(fusion_instr->fused_expression_root(),
- GmockMatch(m::Tuple(
- m::Slice(m::Concatenate(m::Reshape(), m::Reshape())),
- m::Slice(m::Concatenate(m::Reshape(), m::Reshape())))));
-
- EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
-}
-
-TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
- HloComputation::Builder builder(TestName());
-
- std::vector<HloInstruction*> var_outs;
- for (int64_t i = 0; i < 128; ++i) {
- // For shapes {1, 1024}, {2, 1024}, ..., {128, 1024}
- Shape shape = ShapeUtil::MakeShape(F32, {i + 1, 1024});
- HloInstruction* param_var_in = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 3 + 0, shape, "var.in"));
- HloInstruction* param_alpha =
- builder.AddInstruction(HloInstruction::CreateParameter(
- i * 3 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
- HloInstruction* param_delta = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 3 + 2, shape, "delta"));
- HloInstruction* alpha_broadcasted = builder.AddInstruction(
- HloInstruction::CreateBroadcast(shape, param_alpha, {}));
- HloInstruction* alpha_delta =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, alpha_broadcasted, param_delta));
- HloInstruction* var_out =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kSubtract, param_var_in, alpha_delta));
- var_outs.push_back(var_out);
- }
- builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
-
- auto module = CreateNewVerifiedModule();
- module->AddEntryComputation(builder.Build());
-
- // Testing with the entire gpu optimization pipeline.
- EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
-}
-
-TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule HeterogeneousMultiOutputFusions
-
- fused_computation.1 {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- arg.3 = f16[1024]{0} parameter(2)
- arg.4 = f16[1024]{0} parameter(3)
- mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
- mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
- add.1 = f16[1024]{0} add(mul.1, mul.2)
- ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
- }
-
- fused_computation.2 {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- arg.3 = f16[123]{0} parameter(2)
- arg.4 = f16[123]{0} parameter(3)
- add.1 = f16[123]{0} add(arg.1, arg.2)
- add.2 = f16[123]{0} add(arg.3, arg.4)
- mul.1 = f16[123]{0} multiply(add.1, add.2)
- ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
- }
-
- ENTRY entry_computation {
- arg.1 = f16[1024]{0} parameter(0)
- arg.2 = f16[1024]{0} parameter(1)
- arg.3 = f16[1024]{0} parameter(2)
- arg.4 = f16[1024]{0} parameter(3)
- arg.5 = f16[123]{0} parameter(4)
- arg.6 = f16[123]{0} parameter(5)
- arg.7 = f16[123]{0} parameter(6)
- arg.8 = f16[123]{0} parameter(7)
- fusion.1 = (f16[1024]{0}, f16[1024]{0})
- fusion(arg.1, arg.2, arg.3, arg.4),
- kind=kLoop, calls=fused_computation.1
- fusion.2 = (f16[123]{0}, f16[123]{0})
- fusion(arg.5, arg.6, arg.7, arg.8),
- kind=kLoop, calls=fused_computation.2
- gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
- gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
- gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
- gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
- ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
- tuple(gte.1, gte.2, gte.3, gte.4)
- }
-)")
- .value();
-
- EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
- TF_ASSERT_OK(verifier().Run(module.get()).status());
- EXPECT_FALSE(HloDCE().Run(module.get()).value());
-
- VLOG(2) << "Dump after horizontal fusion:";
- VLOG(2) << module->ToString();
-
- EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
-}
-
-TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
- HloComputation::Builder builder(TestName());
-
- std::vector<HloInstruction*> all_outputs;
- for (int64_t i = 0; i < 48; ++i) {
- Shape shape = ShapeUtil::MakeShape(F32, {2, 1024 + i});
- // ms <- grad**2 * (1 - rho) + ms * rho
- HloInstruction* grad = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 9 + 0, shape, "grad"));
- HloInstruction* ms = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 9 + 1, shape, "ms"));
- HloInstruction* rho =
- builder.AddInstruction(HloInstruction::CreateParameter(
- i * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
- HloInstruction* one_minus_rho =
- builder.AddInstruction(HloInstruction::CreateParameter(
- i * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
- HloInstruction* rho_broadcasted =
- builder.AddInstruction(HloInstruction::CreateBroadcast(shape, rho, {}));
- HloInstruction* one_mins_rho_broadcasted = builder.AddInstruction(
- HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
- HloInstruction* grad_squared = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
- HloInstruction* ms_1st_term = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad_squared,
- one_mins_rho_broadcasted));
- HloInstruction* ms_2nd_term =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, ms, rho_broadcasted));
- HloInstruction* ms_out =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kAdd, ms_1st_term, ms_2nd_term));
-
- // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
- HloInstruction* momentum = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 9 + 4, shape, "momentum"));
- HloInstruction* mom = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 9 + 5, shape, "mom"));
- HloInstruction* lr = builder.AddInstruction(HloInstruction::CreateParameter(
- i * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
- HloInstruction* epsilon =
- builder.AddInstruction(HloInstruction::CreateParameter(
- i * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
- HloInstruction* lr_broadcasted =
- builder.AddInstruction(HloInstruction::CreateBroadcast(shape, lr, {}));
- HloInstruction* epsilon_broadcasted = builder.AddInstruction(
- HloInstruction::CreateBroadcast(shape, epsilon, {}));
- HloInstruction* mom_1st_term =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, momentum, mom));
- HloInstruction* ms_eps =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kAdd, ms_out, epsilon_broadcasted));
- HloInstruction* ms_eps_rsq = builder.AddInstruction(
- HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_eps));
- HloInstruction* grad_ms_eps_rsq =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, grad, ms_eps_rsq));
- HloInstruction* mom_2nd_term =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, lr_broadcasted, grad_ms_eps_rsq));
- HloInstruction* mom_out =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kAdd, mom_1st_term, mom_2nd_term));
-
- // var <- var - mom
- HloInstruction* var = builder.AddInstruction(
- HloInstruction::CreateParameter(i * 9 + 8, shape, "var"));
- HloInstruction* var_out =
- builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kSubtract, var, mom_out));
-
- all_outputs.push_back(ms_out);
- all_outputs.push_back(mom_out);
- all_outputs.push_back(var_out);
- }
- builder.AddInstruction(HloInstruction::CreateTuple(all_outputs));
-
- auto module = CreateNewVerifiedModule();
- module->AddEntryComputation(builder.Build());
-
- EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
-}
-
-TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NegativeTestForDynamicUpdateSlice
-
- fusion.1 {
- p.0 = f16[5,9,10]{2,1,0} parameter(0)
- p.1 = s32[] parameter(1)
- p.2 = f16[1,9,10]{2,1,0} parameter(2)
- c.0 = s32[] constant(0)
- ROOT %dynamic-update-slice =
- f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
- }
-
- fusion.2 {
- p.0 = f16[5,9,10]{2,1,0} parameter(0)
- p.1 = s32[] parameter(1)
- p.2 = f16[1,9,10]{2,1,0} parameter(2)
- c.0 = s32[] constant(0)
- ROOT %dynamic-update-slice =
- f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
- }
-
- ENTRY entry {
- p.00 = f16[5,9,10]{2,1,0} parameter(0)
- p.01 = f16[5,9,10]{2,1,0} parameter(1)
- p.10 = s32[] parameter(2)
- p.11 = s32[] parameter(3)
- p.20 = f16[1,9,10]{2,1,0} parameter(4)
- p.21 = f16[1,9,10]{2,1,0} parameter(5)
-
- f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
- f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
- ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
- })")
- .value();
-
- EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
- TF_ASSERT_OK(verifier().Run(module.get()).status());
- EXPECT_FALSE(HloDCE().Run(module.get()).value());
-
- VLOG(2) << "Dump after horizontal fusion:";
- VLOG(2) << module->ToString();
-
- EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
-}
-
-TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule BasicTest
-
- fused_computation.1 {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
- arg.1 = f16[123]{0} parameter(0)
- arg.2 = f16[123]{0} parameter(1)
- ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
- arg.1 = f16[123]{0} parameter(0)
- // arg.2 is shared by fusion.1 and fusion.2
- arg.2 = f16[123]{0} parameter(1)
- arg.3 = f16[123]{0} parameter(2)
- fusion.1 = f16[123]{0}
- fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
- fusion.2 = f16[123]{0}
- fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
- ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
- tuple(fusion.1, fusion.2)
- }
-)")
- .value();
-
- EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NonfusionInstrs
-
- fused_computation.0 {
- arg.0 = f16[] parameter(0)
- arg.1 = f16[123]{0} parameter(1)
- broadcast.0 = f16[123]{0} broadcast(arg.0), dimensions={}
- ROOT mul.1 = f16[123]{0} multiply(broadcast.0, arg.1)
- }
-
- fused_computation.1 {
- arg.0 = f16[] parameter(0)
- arg.1 = f16[456]{0} parameter(1)
- broadcast.0 = f16[456]{0} broadcast(arg.0), dimensions={}
- ROOT add.1 = f16[456]{0} add(broadcast.0, arg.1)
- }
-
- ENTRY entry_computation {
- arg.0 = f16[] parameter(0)
- arg.1 = f16[] parameter(1)
- arg.2 = f16[123]{0} parameter(2)
- arg.3 = f16[456]{0} parameter(3)
- // Test fusion of non-fusion instructions. sqrt.0 and sqrt.1 are to be
- // fused.
- sqrt.0 = f16[] sqrt(arg.0)
- sqrt.1 = f16[] sqrt(arg.1)
- // fusion.0 and fusion.1 are to be fused.
- fusion.0 = f16[123]{0}
- fusion(sqrt.0, arg.2), kind=kLoop, calls=fused_computation.0
- fusion.1 = f16[456]{0}
- fusion(sqrt.1, arg.3), kind=kLoop, calls=fused_computation.1
- ROOT tuple.1 = (f16[123]{0}, f16[456]{0}) tuple(fusion.0, fusion.1)
- }
-)")
- .value();
-
- HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
- iterative_h_fusion.AddPass<GpuHorizontalLoopFusion>();
- iterative_h_fusion.AddPass<HloDCE>();
- EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
-
- // Verify that fusion.0 and fusion.1 are fused.
- const HloInstruction* entry_root =
- module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(entry_root,
- GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement(m::Fusion()))));
- EXPECT_TRUE(fusion->IsMultiOutputFusion());
-
- // Verify that the total number of fusion instructions is 2 so that we
- // know sqrt.0 and sqrt.1 are fused.
- EXPECT_EQ(
- absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
- 2);
-}
-
-TEST_F(HorizontalLoopFusionTest, TraversalOrder) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule cluster
-
- %fused_computation (param_0: f32[256,256], param_1: f32[], param_2: f32[])
- -> f32[256,256] {
- %param_0 = f32[256,256]{1,0} parameter(0)
- %param_1 = f32[] parameter(1)
- %param_2 = f32[] parameter(2)
- %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
- %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
- ROOT %multiply.1 = f32[256,256]{1,0}
- multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
- }
-
- %fused_computation.1 (param_0: f32[256,256], param_1: f32[], param_2: f32[])
- -> f32[256,256] {
- %param_0 = f32[256,256]{1,0} parameter(0)
- %param_1 = f32[] parameter(1)
- %param_2 = f32[] parameter(2)
- %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
- %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
- ROOT %multiply.1 = f32[256,256]{1,0}
- multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
- }
-
- ENTRY %entry_computation (arg0: f32[256,256], arg1: f32[256,256], arg2: f32[],
- arg3: f32[], arg4: f32[], arg5: f32[])
- -> (f32[256,256], f32[256,256]) {
- %arg0 = f32[256,256]{1,0} parameter(0), parameter_replication={false}
- %arg1 = f32[256,256]{1,0} parameter(1), parameter_replication={false}
- %arg2 = f32[] parameter(2), parameter_replication={false}
- %arg3 = f32[] parameter(3), parameter_replication={false}
- %arg4 = f32[] parameter(4), parameter_replication={false}
- %arg5 = f32[] parameter(5), parameter_replication={false}
- %sqrt = f32[] sqrt(f32[] %arg2)
- %sqrt.1 = f32[] sqrt(f32[] %arg3)
- %fusion = f32[256,256]{1,0}
- fusion(f32[256,256]{1,0} %arg0, f32[] %sqrt, f32[] %sqrt.1),
- kind=kLoop, calls=%fused_computation
- %sqrt.2 = f32[] sqrt(f32[] %arg4)
- %sqrt.3 = f32[] sqrt(f32[] %arg5)
- %fusion.1 = f32[256,256]{1,0}
- fusion(f32[256,256]{1,0} %arg1, f32[] %sqrt.2, f32[] %sqrt.3),
- kind=kLoop, calls=%fused_computation.1
- ROOT %tuple.163 = (f32[256,256]{1,0}, f32[256,256]{1,0})
- tuple(f32[256,256]{1,0} %fusion.1, f32[256,256]{1,0} %fusion)
- }
-)")
- .value();
-
- HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
- iterative_h_fusion.AddPass<GpuHorizontalLoopFusion>();
- EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
-
- // Verify that the total number of fusion instructions is 2 so that we
- // know all the sqrt instructions are fused into a kernel. Note that if we
-  // traverse from def-to-use (i.e., top-down) instead of use-to-def, we
- // will end up having 3 fusions instead of 2.
- EXPECT_EQ(
- absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
- 2);
-}
-
-// Simplified reproducer for Google bug b/242287055.
-// Things that happened:
-// - horizontal loop fusion joined addition a0 and multiplication m0
-// - the resulting fusion had 4 inputs: (gte1, gte0, gte1, gte0)
-// - buffer assignment aliased outputs of this fusion with its inputs
-// - some threads did the addition while others concurrently did the
-//   multiplication
-// - as a result, some inputs were overwritten before being read
-// The conditional operation is meaningless (its branches are equivalent) and
-// is there only to confuse the buffer assignment.
-TEST_F(HorizontalLoopFusionTest, NoBufferAliasingOfDuplicateParameter) {
- const char* hlo_text = R"(
-HloModule m
-
-branch_a {
- p0 = s32[] parameter(0)
- c0 = s32[] constant(1)
- c1 = s32[] constant(2)
- b0 = s32[4096] broadcast(c0), dimensions={}
- b1 = s32[4096] broadcast(c1), dimensions={}
- ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
-}
-
-branch_b {
- p0 = s32[] parameter(0)
- c0 = s32[] constant(1)
- c1 = s32[] constant(2)
- b0 = s32[4096] broadcast(c0), dimensions={}
- b1 = s32[4096] broadcast(c1), dimensions={}
- ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
-}
-
-ENTRY e {
- p0 = s32[] parameter(0)
- c0 = s32[] constant(0)
- cond = (s32[4096], s32[4096]) conditional(p0, c0, c0), branch_computations={branch_a, branch_b}
- p1 = s32[4096] parameter(1)
- gte0 = s32[4096] get-tuple-element(cond), index=0
- gte1 = s32[4096] get-tuple-element(cond), index=1
- a0 = s32[4096] add(gte1, gte0)
- m0 = s32[4096] multiply(gte1, gte0)
- ROOT r = (s32[4096], s32[4096]) tuple(m0, a0)
-}
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
-}
-
-TEST_F(HorizontalLoopFusionTest, CopyInsertionFusionControlFlow) {
- const char* hlo_text = R"(
-HloModule cluster
-
-ENTRY main {
- cst = f32[1]{0} constant({0})
- cp1 = f32[1]{0} copy(cst)
- cp2 = f32[1]{0} copy(cst)
- cp3 = f32[1]{0} copy(cst)
- cp4 = f32[1]{0} copy(cst), control-predecessors={cp1}
- ROOT tuple_out = (f32[1]{0}, f32[1]{0}, f32[1]{0}, f32[1]{0}) tuple(cp1, cp2, cp3, cp4)
-}
-)";
-
- auto module = ParseAndReturnUnverifiedModule(hlo_text).value();
- EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-
- VLOG(2) << module->ToString();
-
- // Verify that the total number of fusion instructions is 1.
- EXPECT_EQ(
- absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
- 1);
-
- const HloInstruction* entry_root =
- module->entry_computation()->root_instruction();
- // Check that we fuse when supported.
- EXPECT_THAT(entry_root,
- GmockMatch(m::Tuple(m::Copy(), m::GetTupleElement(m::Fusion()),
- m::GetTupleElement(m::Fusion()), m::Copy())));
-}
-
-TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- fused_computation.94 {
- tmp_0 = f32[] parameter(0)
- tmp_1 = f32[] parameter(1)
- tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
- tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
- tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
- tmp_5 = s32[] parameter(2)
- tmp_6 = s32[] parameter(3)
- tmp_7 = s32[] minimum(tmp_5, tmp_6)
- tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
- tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
- ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
- }
-
- minmax_func.1536 {
- tmp_0 = f32[] parameter(0)
- tmp_1 = f32[] parameter(2)
- tmp_2 = s32[] parameter(1)
- tmp_3 = s32[] parameter(3)
- ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
- }
-
- fused_computation {
- tmp_0 = f32[554112,10]{1,0} parameter(0)
- tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
- tmp_2 = f32[] constant(-inf)
- tmp_3 = s32[] constant(0)
- ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
- }
-
- fused_computation2 {
- tmp_0 = f32[554112,10]{1,0} parameter(0)
- tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
- tmp_2 = f32[] constant(inf)
- tmp_3 = s32[] constant(1)
- ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
- }
-
- ENTRY e {
- tmp_0 = f32[554112,10]{1,0} parameter(0)
- tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
- tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
- tmp_3 = f32[554112,10]{1,0} parameter(1)
- tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_3), kind=kLoop, calls=fused_computation2
- tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
- ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
- })")
- .value();
-
- EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
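Taken together, the deleted tests above share one driver pattern: parse an HLO module, run GpuHorizontalLoopFusion on it, re-verify the module, and count the resulting fusion ops. A minimal sketch of that pattern, assuming the pre-removal header and the IsFusion/verifier() helpers defined earlier in this test file are still in scope:

  auto module = ParseAndReturnVerifiedModule(hlo_text).value();
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
  TF_ASSERT_OK(verifier().Run(module.get()).status());
  // The expected count depends on the input HLO; the tests above check 1 or 2.
  EXPECT_EQ(
      absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
      1);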
diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.cc b/third_party/xla/xla/service/gpu/instruction_fusion.cc
deleted file mode 100644
index 8751d44..0000000
--- a/third_party/xla/xla/service/gpu/instruction_fusion.cc
+++ /dev/null
@@ -1,187 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/instruction_fusion.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/meta/type_traits.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/fusion_node_indexing_evaluation.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-bool ElementIsF32OrF16(const Shape& shape) {
- PrimitiveType type = shape.element_type();
- return type == F32 || type == F16;
-}
-
-class EmptyFusionQueue : public FusionQueue {
- public:
- std::pair<HloInstruction*, std::vector<int64_t>>
- DequeueNextInstructionAndOperandsToFuseInOrder() override {
- return {nullptr, {}};
- }
- void RemoveInstruction(HloInstruction* instruction) override {};
- const std::vector<bool>* FusionConfiguration() override { return nullptr; };
-};
-
-} // namespace
-
-absl::StatusOr<bool> GpuInstructionFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- fusion_node_evaluations_.clear();
- auto fusible_computations =
- GetFusibleComputations(*module, execution_threads);
- fusible_computations_ = {fusible_computations.begin(),
- fusible_computations.end()};
- return InstructionFusion::Run(module, execution_threads);
-}
-
-/*static*/ bool GpuInstructionFusion::IsExpensive(
- const HloInstruction& instruction) {
- // Some floating-point math ops are cheap on the GPU.
- switch (instruction.opcode()) {
- case HloOpcode::kDivide:
- case HloOpcode::kSqrt:
- case HloOpcode::kRsqrt:
- case HloOpcode::kExp:
- if (ElementIsF32OrF16(instruction.shape())) {
- return false;
- }
- break;
- default:
- break;
- }
- return InstructionFusion::IsExpensive(instruction);
-}
-
-FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks(
- HloInstruction* consumer, int64_t operand_index) {
- HloInstruction* producer = consumer->mutable_operand(operand_index);
-
- // Output fusions are not currently supported on GPUs.
- if (producer->opcode() == HloOpcode::kFusion) {
- return "the producer is a fusion";
- }
-
- if (consumer->IsCustomFusion()) {
- return "the consumer is a custom fusion";
- }
-
-  // Cost condition: do not fuse when the producer is expensive and the
-  // consumer reuses operand elements.
- if (is_expensive(*producer) &&
- ReusesOperandElements(consumer, operand_index)) {
- return "the producer is expensive, and the consumer reuses inputs";
- }
-
- // Do not fuse into fusions if the resulting kernel would suffer from
- // uncoalesced reads due to a transposed memory access pattern.
- if (IsInputFusibleReduction(*consumer) &&
- IsPhysicallyTransposing(*producer)) {
- return "fusing the producer would break read coalescing";
- }
-
- RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer));
-
- if (CreatesHeavyComputation(*producer, *consumer)) {
- return "the fusion would create a heavy computation";
- }
-
- return InstructionFusion::ShouldFuse(consumer, operand_index);
-}
-
-FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
- int64_t operand_index) {
- RETURN_IF_NOT_FUSIBLE(ShouldFuseInexpensiveChecks(consumer, operand_index));
-
- auto producer = consumer->operand(operand_index);
-
- // The following checks are potentially expensive.
- RETURN_IF_NOT_FUSIBLE(
- FusionFitsInBudget(*consumer, *producer, device_info_,
- /*is_consumer_producer_fusion=*/true));
-
- if (consumer->opcode() != HloOpcode::kFusion) {
- return {};
- }
-
- // Also check that our emitter can handle the fusion node. We currently can
- // have exponential time/memory requirements for emitting certain fusion
- // kernels, in which case we don't want to fuse.
- // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
- if (fusion_node_evaluations_.find(consumer) ==
- fusion_node_evaluations_.end()) {
- // We have no cached results for this fusion node yet. This can happen when
- // we run the InstructionFusion pass more than once. We can only cache the
- // results within one run.
- fusion_node_evaluations_.emplace(consumer,
- FusionNodeIndexingEvaluation(consumer));
- }
- if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
- return "the fusion would result in an overly large code duplication";
- }
- return {};
-}
-
-HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
- const HloInstruction* producer, const HloInstruction* consumer) {
- return ChooseFusionKind(*producer, *consumer);
-}
-
-HloInstruction* GpuInstructionFusion::FuseInstruction(
- HloInstruction* fusion_instruction, HloInstruction* producer) {
- auto evaluation = fusion_node_evaluations_.find(fusion_instruction);
- if (evaluation == fusion_node_evaluations_.end()) {
- evaluation = fusion_node_evaluations_
- .emplace(fusion_instruction,
- FusionNodeIndexingEvaluation(fusion_instruction))
- .first;
- }
- auto indexing_users = evaluation->second.RemoveFusionOperand(producer);
- HloInstruction* new_producer =
- InstructionFusion::FuseInstruction(fusion_instruction, producer);
- evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
- return new_producer;
-}
-
-std::unique_ptr<FusionQueue> GpuInstructionFusion::GetFusionQueue(
- HloComputation* computation) {
- if (fusible_computations_.contains(computation)) {
- return InstructionFusion::GetFusionQueue(computation);
- }
- return std::make_unique<EmptyFusionQueue>();
-}
-
-} // namespace gpu
-} // namespace xla
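For orientation, the pass deleted here was constructed from a duplication flag and a device description and then run like any other HLO pass; a minimal sketch, assuming the header deleted below and the test-only TestGpuDeviceInfo::RTXA6000DeviceInfo() helper are still available:

  GpuInstructionFusion fusion(/*may_duplicate=*/true,
                              TestGpuDeviceInfo::RTXA6000DeviceInfo());
  // Run() returns absl::StatusOr<bool>: true iff any instructions were fused.
  EXPECT_TRUE(fusion.Run(module.get()).value());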
diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.h b/third_party/xla/xla/service/gpu/instruction_fusion.h
deleted file mode 100644
index 29eb032..0000000
--- a/third_party/xla/xla/service/gpu/instruction_fusion.h
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
-#define XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
-
-#include <stdint.h>
-
-#include <memory>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/fusion_node_indexing_evaluation.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-class GpuInstructionFusion : public InstructionFusion {
- public:
- GpuInstructionFusion(bool may_duplicate, const se::DeviceDescription& d)
- : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate),
- device_info_(d) {}
-
- static bool IsExpensive(const HloInstruction& instruction);
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- protected:
- std::unique_ptr<FusionQueue> GetFusionQueue(
- HloComputation* computation) override;
- FusionDecision ShouldFuse(HloInstruction* consumer,
- int64_t operand_index) override;
-
- HloInstruction::FusionKind ChooseKind(
- const HloInstruction* producer, const HloInstruction* consumer) override;
-
- private:
-  // This method is called by ShouldFuse() to run all of the computationally
-  // inexpensive checks on whether we should fuse the operand into 'consumer'.
- FusionDecision ShouldFuseInexpensiveChecks(HloInstruction* consumer,
- int64_t operand_index);
-
- HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
- HloInstruction* producer) override;
-
-  absl::flat_hash_set<const HloComputation*> fusible_computations_;
-
-  // Keep track of the number of times each instruction inside a fusion node is
-  // indexed with different index vectors.
-  absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
-      fusion_node_evaluations_;
-
- se::DeviceDescription device_info_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc b/third_party/xla/xla/service/gpu/instruction_fusion_test.cc
deleted file mode 100644
index fa96edf..0000000
--- a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc
+++ /dev/null
@@ -1,1006 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/instruction_fusion.h"
-
-#include <cstdint>
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_utils.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace m = ::xla::match;
-
-namespace xla {
-namespace gpu {
-
-class InstructionFusionTest : public HloTestBase {
- public:
- GpuInstructionFusion duplicating_instruction_fusion_{
- /*may_duplicate=*/true, TestGpuDeviceInfo::RTXA6000DeviceInfo()};
-};
-
-TEST_F(InstructionFusionTest, NoFusionIntoCustomFusionConsumer) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
- ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-c {
- p0 = bf16[3000,53]{1,0} parameter(0)
- p1 = bf16[22,53]{1,0} parameter(1)
- d = bf16[3000,22]{1,0} dot(p0, p1),
- lhs_contracting_dims={1}, rhs_contracting_dims={1}
- r = bf16[1,1,3000,22]{3,2,1,0} reshape(d)
- ROOT c = bf16[1,1,3000,22]{2,1,3,0} copy(r)
-}
-
-ENTRY e {
- p1 = bf16[3000,53]{1,0} parameter(1)
- p0 = bf16[22,53]{1,0} parameter(0)
- cp0 = bf16[22,53]{1,0} convert(p0)
- ROOT f = bf16[1,1,3000,22]{2,1,3,0} fusion(p1, cp0), kind=kCustom, calls=c
-})"));
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest,
- CostlyProducerAndOperandElementReusingConsumerNotFused) {
- HloComputation::Builder builder(TestName());
- HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
- HloInstruction* log1 = builder.AddInstruction(HloInstruction::CreateUnary(
- ShapeUtil::MakeShape(F32, {}), HloOpcode::kLog, const0));
- HloInstruction* broadcast2 =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- ShapeUtil::MakeShape(F32, {1}), log1, {}));
-
- auto module = CreateNewVerifiedModule();
- auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(broadcast2, computation->root_instruction());
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
- EXPECT_EQ(broadcast2, computation->root_instruction());
-}
-
-TEST_F(InstructionFusionTest,
- NonCostlyProducerAndOperandElementReusingConsumerFused) {
- HloComputation::Builder builder(TestName());
- HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
- HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
- ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
- HloInstruction* broadcast2 =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- ShapeUtil::MakeShape(S32, {1}), negate1, {}));
-
- auto module = CreateNewVerifiedModule();
- auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(broadcast2, computation->root_instruction());
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
- EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
-}
-
-TEST_F(InstructionFusionTest,
- CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
- HloComputation::Builder builder(TestName());
- HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
- HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
- ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
- HloInstruction* reshape2 = builder.AddInstruction(
- HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), exp1));
-
- auto module = CreateNewVerifiedModule();
- auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(reshape2, computation->root_instruction());
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
- EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
-}
-
-TEST_F(InstructionFusionTest,
- CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
- HloComputation::Builder builder(TestName());
- HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
- HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
- ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
- HloInstruction* transpose2 = builder.AddInstruction(
- HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {}));
-
- auto module = CreateNewVerifiedModule();
- auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(transpose2, computation->root_instruction());
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
- EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
-}
-
-TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotFused) {
- HloComputation::Builder builder(TestName());
- auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(
- CreateCanonicalDot(ShapeUtil::MakeShape(F32, {1, 1}), param0, param0));
- auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(F32, {1, 1, 1}), dot1));
- auto log = builder.AddInstruction(HloInstruction::CreateUnary(
- reshape2->shape(), xla::HloOpcode::kLog, reshape2));
-
- auto module = CreateNewVerifiedModule();
- auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(log, computation->root_instruction());
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
- HloComputation::Builder builder(TestName());
- auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(
- CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
- auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
-
- auto module = CreateNewVerifiedModule();
- auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(transpose2, computation->root_instruction());
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-// Tests that broadcasts are fused into a fusion with a reduce root.
-TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
-
- ENTRY BroadcastIntoReduce {
- constant = f32[] constant(1)
- broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={}
- constant.1 = f32[] constant(0)
- ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
- to_apply=add
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_THAT(
- root->fused_expression_root(),
- GmockMatch(m::Reduce(m::Broadcast(m::Constant()), m::Constant())));
-}
-
-TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
-
- ENTRY entry {
- p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
- copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
- constant.1 = f32[] constant(0)
- ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add
- })")
- .value();
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
-
- fused_reduce {
- p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0)
- mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1)
- c0.1 = f32[] constant(0)
- ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add
- }
-
- ENTRY entry {
- p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
- copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
- fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce
- ROOT root = (f32[]) tuple(fusion)
- })")
- .value();
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
-
- ENTRY entry {
- p0 = s32[512,512,2] parameter(0)
- p1 = f32[1,1,512,512] parameter(1)
- constant_1 = f32[] constant(1)
- reduce-window.1 = reduce-window(p1, constant_1),
- window={size=1x1x9x9}, to_apply=add
- ROOT ret = gather(reduce-window.1, p0), offset_dims={0,1,2,3},
- collapsed_slice_dims={}, start_index_map={1,2},
- index_vector_dim=2, slice_sizes={1,1,1,1}
- })")
- .value();
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- ENTRY entry {
- p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
- copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
- ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_THAT(root->fused_expression_root(),
- GmockMatch(m::Add(m::Copy(), m::Copy())));
-}
-
-TEST_F(InstructionFusionTest, BitcastIntoAdd) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- ENTRY BroadcastIntoAdd {
- p0 = f32[4,1,1]{2,1,0} parameter(0)
- p1 = f32[4,1]{1,0} parameter(1)
- bitcast = f32[4,1]{1,0} bitcast(p0)
- ROOT add = f32[4,1] add(bitcast, p1)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_THAT(root->fused_expression_root(),
- GmockMatch(m::Add(m::Bitcast(m::Parameter()), m::Parameter())));
-}
-
-TEST_F(InstructionFusionTest, AddIntoBitcast) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- ENTRY BroadcastIntoAdd {
- p0 = f32[4,1]{1,0} parameter(0)
- p1 = f32[4,1]{1,0} parameter(1)
- add = f32[4,1] add(p0, p1)
- ROOT bitcast = f32[4,1,1] bitcast(add)
- })")
- .value();
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, ConvertIntoBitcastBothConsumedByTuple) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test
-
- ENTRY main {
- param_0 = f32[2048,16000]{1,0} parameter(0)
- convert = bf16[2048,16000]{1,0} convert(param_0)
- bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
- ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
- })")
- .value();
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, DontFuseGTE) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- ENTRY DontFuseGTE {
- p0 = (f32[10], f32[10]) parameter(0)
- gte0 = f32[10] get-tuple-element(p0), index=0
- gte1 = f32[10] get-tuple-element(p0), index=1
- ROOT add = f32[10] add(gte0, gte1)
- })")
- .value();
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-// Compute sum(p1/p0), where p0 and p1 have type f32, twice. Check that the
-// division is duplicated and fused into both reduces.
-TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- ENTRY TestComputation {
- zero = f32[] constant(0)
- p0 = f32[100] parameter(0)
- p1 = f32[100] parameter(1)
- recip = f32[100] divide(p1, p0)
- sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
- sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
- ROOT root = (f32[], f32[]) tuple(sum1, sum2)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
- << module->ToString();
-}
-
-// Compute sum(p1/p0), where p0 and p1 have type s32, twice. Check that the
-// division is *not* duplicated and fused into both reduces, because we treat
-// integer division as not cheap.
-TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- Add {
- lhs = s32[] parameter(0)
- rhs = s32[] parameter(1)
- ROOT add = s32[] add(lhs, rhs)
- }
- ENTRY TestComputation {
- zero = s32[] constant(0)
- p0 = s32[100] parameter(0)
- p1 = s32[100] parameter(1)
- recip = s32[100] divide(p1, p0)
- sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
- sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
- ROOT mul = (s32[], s32[]) tuple(sum1, sum2)
- })")
- .value();
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value())
- << module->ToString();
-}
-
-TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- ENTRY NoOutputFusion {
- alpha = f32[] constant(3)
- broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
- p0 = f32[4,3]{1,0} parameter(0)
- p1 = f32[3,4]{1,0} parameter(1)
- dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- d = f32[4,4]{1,0} multiply(dot, dot)
- ROOT mul = f32[4,4] multiply(d, broadcast)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
- EXPECT_THAT(
- root->fused_expression_root(),
- GmockMatch(m::Multiply(m::Multiply(m::Parameter(), m::Parameter()),
- m::Broadcast(m::Constant()))));
-}
-
-// Counts the HLO ops with a given op code in the specified module.
-static int Count(const HloModule& module, HloOpcode op) {
- int count = 0;
- for (const auto* computation : module.computations()) {
- for (const auto* instruction : computation->instructions()) {
- if (instruction->opcode() == op) {
- ++count;
- }
- }
- }
- return count;
-}
-
-TEST_F(InstructionFusionTest, MultiOutputFusion) {
- // sub --> add --> tuple
- // \---------------/
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- ENTRY OutputFusion {
- p0 = f32[4,3]{1,0} parameter(0)
- p1 = f32[4,3]{1,0} parameter(1)
- p2 = f32[4,3]{1,0} parameter(2)
- sub = f32[4,3]{1,0} subtract(p0, p2)
- add = f32[4,3]{1,0} add(sub, p1)
- ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add)
- })")
- .value();
-
- // Multi-output fusion is disabled here and performed in the
- // GpuMultiOutputFusion pass instead.
- ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, FuseScalarConstant) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- ENTRY FuseScalarConstant {
- p0 = f32[] parameter(0)
- c0 = f32[] constant(1)
- add1 = f32[] add(p0, c0)
- b0 = f32[2]{0} broadcast(add1), dimensions={}
- c1 = f32[2]{0} constant({1, 2})
- ROOT add2 = f32[2]{0} add(b0, c1)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_THAT(
- root->fused_expression_root(),
- GmockMatch(m::Add(m::Broadcast(m::Add(m::Parameter(), m::Constant())),
- m::Parameter())));
-}
-
-// Check that we limit the number of operands to fusions we create.
-TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
- constexpr int64_t kNumParams = 200;
- ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
-
- // Compute p0 + p1 + ... + pN.
- HloComputation::Builder b(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
- auto param0 =
- b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
- auto sum = param0;
- for (int64_t i = 1; i < kNumParams; ++i) {
- auto param =
- b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"));
- sum = b.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param));
- }
- auto module = CreateNewVerifiedModule();
- auto computation = module->AddEntryComputation(b.Build());
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- for (const HloInstruction* instr : computation->instructions()) {
- EXPECT_LE(instr->operand_count(), MaxOperandsAndOutputsPerFusion())
- << instr->ToString();
- }
-}
-
-TEST_F(InstructionFusionTest, FuseIntoScatter) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = s32[] parameter(0)
- rhs = s32[] parameter(1)
- ROOT add = s32[] add(lhs, rhs)
- }
-
- ENTRY FuseIntoScatter {
- p0 = s32[3,3] parameter(0)
- p1 = s32[2] parameter(1)
- indices = s32[2] add(p1, p1)
- p2 = s32[2,3] parameter(2)
- updates = s32[2,3] add(p2, p2)
- scatter = s32[3,3] scatter(p0, indices, updates),
- to_apply=add,
- update_window_dims={1},
- inserted_window_dims={0},
- scatter_dims_to_operand_dims={0},
- index_vector_dim=1
- ROOT add = s32[3,3] add(scatter, scatter)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
- EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
-}
-
-TEST_F(InstructionFusionTest, DontFuseIntoFirstOperandOfScatter) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = s32[] parameter(0)
- rhs = s32[] parameter(1)
- ROOT add = s32[] add(lhs, rhs)
- }
-
- ENTRY FuseIntoScatter {
- p0 = s32[3,3] parameter(0)
- operand = s32[3,3] add(p0, p0)
- p1 = s32[2] parameter(1)
- indices = s32[2] add(p1, p1)
- p2 = s32[2,3] parameter(2)
- updates = s32[2,3] add(p2, p2)
- scatter = s32[3,3] scatter(operand, indices, updates),
- to_apply=add,
- update_window_dims={1},
- inserted_window_dims={0},
- scatter_dims_to_operand_dims={0},
- index_vector_dim=1
- ROOT add = s32[3,3] add(scatter, scatter)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
- EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
-}
-
-TEST_F(InstructionFusionTest, ScatterOpShouldNotFuseWithSharedOperand) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
-
- ENTRY Test {
- parameter.0 = f32[8,8] parameter(0)
- parameter.1 = s32[7] parameter(1)
- indices = s32[7] add(parameter.1, parameter.1)
- slice = f32[7,8] slice(parameter.0), slice={[0:7],[0:8]}
- ROOT scatter = f32[8,8] scatter(parameter.0, indices, slice),
- to_apply=add,
- update_window_dims={1},
- inserted_window_dims={0},
- scatter_dims_to_operand_dims={0},
- index_vector_dim=1
- })")
- .value();
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
- // Verify that we don't fuse scatter and slice together since
- // scatter modifies the input buffer in-place, which is also used
- // as slice's input, and we don't know where the scatter indices point to.
- HloInstruction* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(
- root, GmockMatch(m::Fusion(m::Parameter(), m::Slice(), m::Parameter())));
-}
-
-TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
-
- ENTRY BroadcastIntoReduce {
- constant = f32[16] constant({0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15})
- broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={0}
- constant.1 = f32[] constant(0)
- ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
- to_apply=add
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
- // The f32[16] constant should not be fused into the reduce, but the f32[]
- // constant should be.
- auto* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_THAT(
- root->fused_instructions_computation()->root_instruction(),
- GmockMatch(m::Reduce(m::Broadcast(m::Parameter()), m::Constant())));
-}
-
-TEST_F(InstructionFusionTest, FuseReverse) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- ENTRY Reverse {
- p0 = f32[50,96,1024]{2,1,0} parameter(0)
- add = f32[50,96,1024]{2,1,0} add(p0, p0)
- ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0}
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_THAT(root->fused_expression_root(),
- GmockMatch(m::Reverse(m::Add(m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveF32) {
- auto m = CreateNewVerifiedModule();
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
-
- HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
- HloInstruction* div = builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
- HloInstruction* rem = builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kRemainder, param0, one));
- HloInstruction* sqrt = builder.AddInstruction(
- HloInstruction::CreateUnary(r0f32, HloOpcode::kSqrt, param0));
- HloInstruction* rsqrt = builder.AddInstruction(
- HloInstruction::CreateUnary(r0f32, HloOpcode::kRsqrt, param0));
- HloInstruction* exp = builder.AddInstruction(
- HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
-
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
- EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*sqrt));
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rsqrt));
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*exp));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveF64) {
- auto m = CreateNewVerifiedModule();
- Shape r0f64 = ShapeUtil::MakeShape(F64, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f64, "param0"));
-
- HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
- HloInstruction* div = builder.AddInstruction(
- HloInstruction::CreateBinary(r0f64, HloOpcode::kDivide, param0, one));
- HloInstruction* rem = builder.AddInstruction(
- HloInstruction::CreateBinary(r0f64, HloOpcode::kRemainder, param0, one));
- HloInstruction* sqrt = builder.AddInstruction(
- HloInstruction::CreateUnary(r0f64, HloOpcode::kSqrt, param0));
- HloInstruction* rsqrt = builder.AddInstruction(
- HloInstruction::CreateUnary(r0f64, HloOpcode::kRsqrt, param0));
- HloInstruction* exp = builder.AddInstruction(
- HloInstruction::CreateUnary(r0f64, HloOpcode::kExp, param0));
-
- EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*div));
- EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
- EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*sqrt));
- EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rsqrt));
- EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*exp));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveS32) {
- auto m = CreateNewVerifiedModule();
- Shape r0s32 = ShapeUtil::MakeShape(S32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0s32, "param0"));
-
- HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
- HloInstruction* div = builder.AddInstruction(
- HloInstruction::CreateBinary(r0s32, HloOpcode::kDivide, param0, one));
- HloInstruction* rem = builder.AddInstruction(
- HloInstruction::CreateBinary(r0s32, HloOpcode::kRemainder, param0, one));
-
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveBroadcastS32) {
- auto m = CreateNewVerifiedModule();
- Shape r1s32 = ShapeUtil::MakeShape(S32, {10});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r1s32, "param0"));
-
- HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
- HloInstruction* one_broad =
- builder.AddInstruction(HloInstruction::CreateBroadcast(r1s32, one, {}));
-
- HloInstruction* div = builder.AddInstruction(HloInstruction::CreateBinary(
- r1s32, HloOpcode::kDivide, param0, one_broad));
- HloInstruction* rem = builder.AddInstruction(HloInstruction::CreateBinary(
- r1s32, HloOpcode::kRemainder, param0, one_broad));
-
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
- EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
-}
-
-TEST_F(InstructionFusionTest, FloatingPointExpIsCheap) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- Add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
- ENTRY TestComputation {
- zero = f32[] constant(0)
- p0 = f32[100] parameter(0)
- recip = f32[100] exponential(p0)
- sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
- sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
- ROOT root = (f32[], f32[]) tuple(sum1, sum2)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
- << module->ToString();
-}
-
-TEST_F(InstructionFusionTest, SmallReducedDimensionIsNotLoweredToLoop) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = s32[] parameter(0)
- rhs = s32[] parameter(1)
- ROOT add = s32[] add(lhs, rhs)
- }
-
- ENTRY FuseSmallReduction {
- p0 = s32[1048576,4] parameter(0)
- p1 = s32[1048576,4] parameter(1)
- sum = s32[1048576,4] add(p0, p1)
- init = s32[] constant(0)
- ROOT reduce = s32[1048576] reduce(sum, init), dimensions={1}, to_apply=add
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
-}
-
-TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule m
-
- f {
- tmp_0 = f32[] parameter(0)
- tmp_1 = f32[] parameter(1)
- tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
- tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
- tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
- tmp_5 = s32[] parameter(2)
- tmp_6 = s32[] parameter(3)
- tmp_7 = s32[] minimum(tmp_5, tmp_6)
- tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
- tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
- ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
- }
-
- minmax {
- tmp_0 = f32[] parameter(0)
- tmp_1 = f32[] parameter(2)
- tmp_2 = s32[] parameter(1)
- tmp_3 = s32[] parameter(3)
- ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=f
- }
-
- ENTRY e {
- tmp_0 = f32[554112,10]{1,0} parameter(0)
- tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
- tmp_2 = f32[] constant(-inf)
- tmp_3 = s32[] constant(0)
- ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax
- })")
- .value();
-
- EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
- TestGpuDeviceInfo::RTXA6000DeviceInfo())
- .Run(module.get())
- .value());
- ASSERT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter())));
- EXPECT_THAT(
- module->entry_computation()->root_instruction()->fused_expression_root(),
- GmockMatch(
- m::Reduce(m::Parameter(), m::Iota(), m::Constant(), m::Constant())));
-}
-
-TEST_F(InstructionFusionTest, InputReductionFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
- add.clone.13 {
- x.27 = f32[] parameter(0)
- y.27 = f32[] parameter(1)
- ROOT add.1036 = f32[] add(x.27, y.27)
- }
- add.clone.14 {
- x.28 = f32[] parameter(0)
- y.28 = f32[] parameter(1)
- ROOT add.1037 = f32[] add(x.28, y.28)
- }
- add {
- x = bf16[] parameter(0)
- convert.448 = f32[] convert(x)
- y = bf16[] parameter(1)
- convert.449 = f32[] convert(y)
- add.597 = f32[] add(convert.448, convert.449)
- ROOT convert.450 = bf16[] convert(add.597)
- }
- ENTRY FuseSmallReduction {
- param_2.7 = bf16[8,16,64,2048]{3,2,1,0} parameter(2)
- convert.1395 = f32[8,16,64,2048]{3,2,1,0} convert(param_2.7)
- param_0.85 = bf16[8,16,64,2048]{3,2,1,0} parameter(0)
- convert.1393 = f32[8,16,64,2048]{3,2,1,0} convert(param_0.85)
- multiply.1652 = f32[8,16,64,2048]{3,2,1,0} multiply(convert.1395, convert.1393)
- convert.1392 = bf16[8,16,64,2048]{3,2,1,0} convert(multiply.1652)
- bitcast.15934 = bf16[128,64,2048]{2,1,0} bitcast(convert.1392)
- convert.1391 = f32[128,64,2048]{2,1,0} convert(bitcast.15934)
- param_1.15 = bf16[] parameter(1)
- convert.1394 = f32[] convert(param_1.15)
- reduce.462 = f32[128,64]{1,0} reduce(convert.1391, convert.1394), dimensions={2}, to_apply=add.clone.13
- reduce.121 = f32[64]{0} reduce(reduce.462, convert.1394), dimensions={0}, to_apply=add.clone.14
- ROOT convert.890 = bf16[64]{0} convert(reduce.121)
- })")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* fused_convert_fusion =
- module->entry_computation()->root_instruction();
-
- ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
- SCOPED_TRACE(module->ToString());
- EXPECT_EQ(fused_convert_fusion->fusion_kind(),
- HloInstruction::FusionKind::kInput);
-}
-
-TEST_F(InstructionFusionTest, DotStrengthReductionFusion) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
-scalar_add_computation {
- scalar_rhs = f32[] parameter(1)
- scalar_lhs = f32[] parameter(0)
- ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
- param_1.3 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} parameter(1)
- param_0.6 = f16[16,64,96,1,2,16]{5,4,3,2,1,0} parameter(0)
- bitcast.26 = f16[16,64,96,2,16]{4,3,2,1,0} bitcast(param_0.6)
- broadcast.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} broadcast(bitcast.26), dimensions={0,1,2,4,5}
- multiply.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} multiply(broadcast.4, param_1.3)
- convert.8 = f32[16,64,96,6,2,16]{5,4,3,2,1,0} convert(multiply.4)
- constant_2 = f32[] constant(0)
- reduce.3 = f32[16,64,96,6,2]{3,4,2,1,0} reduce(convert.8, constant_2), dimensions={5}, to_apply=scalar_add_computation
- bitcast.25 = f32[16,64,96,2,6]{4,3,2,1,0} bitcast(reduce.3)
- convert.7 = f16[16,64,96,2,6]{4,3,2,1,0} convert(bitcast.25)
- ROOT bitcast.24 = f16[16,64,96,2,1,6]{5,4,3,2,1,0} bitcast(convert.7)
-})")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- const HloInstruction* fused_convert_fusion =
- module->entry_computation()->root_instruction()->operand(0);
-
- ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
- SCOPED_TRACE(module->ToString());
- EXPECT_EQ(fused_convert_fusion->fusion_kind(),
- HloInstruction::FusionKind::kInput);
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
-}
-
-TEST_F(InstructionFusionTest, ReductionFusionOtherUnaryElementwiseOpsAreFused) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
-scalar_add_computation {
- scalar_rhs = f32[] parameter(1)
- scalar_lhs = f32[] parameter(0)
- ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
- param_0 = f16[64,96,6,16]{3,2,1,0} parameter(0)
- constant_2 = f32[] constant(0)
- reduce.3 = f32[64,6,16]{2,1,0} reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
- negate = f32[64,6,16]{2,1,0} negate(reduce.3)
- ROOT sine = f16[64,6,16]{2,1,0} sine(negate)
-})")
- .value();
-
- EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
- HloInstruction* fused_convert_fusion =
- module->entry_computation()->root_instruction();
-
- ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
- SCOPED_TRACE(module->ToString());
- EXPECT_EQ(fused_convert_fusion->fusion_kind(),
- HloInstruction::FusionKind::kInput);
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
-}
-
-TEST_F(InstructionFusionTest, DoNotFuseInsideReducer) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
-scalar_add_computation {
- scalar_rhs = f32[] parameter(1)
- scalar_lhs = f32[] parameter(0)
- add.1 = f32[] add(scalar_lhs, scalar_rhs)
- ROOT add.2 = f32[] add(add.1, scalar_rhs)
-}
-
-ENTRY main {
- param_0 = f16[64,96] parameter(0)
- constant_2 = f32[] constant(0)
- ROOT reduce = f32[64] reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
-})")
- .value();
-
- EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc
index e2882a7..ee1140d 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc
@@ -15,6 +15,7 @@
#include "xla/service/gpu/ir_emission_utils.h"
+#include <algorithm>
#include <cstdint>
#include <functional>
#include <optional>
@@ -60,9 +61,9 @@
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/statusor.h"
@@ -81,6 +82,13 @@
return shape.rank() == batch_dimensions_size + 1;
}
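+// Returns true if the module that owns `hlo` has the MLIR transpose emitter
+// enabled (xla_gpu_mlir_emitter_level >= 3).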
+bool IsMlirTransposeEmitterEnabled(const HloInstruction& hlo) {
+ return hlo.GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_mlir_emitter_level() >= 3;
+}
+
} // namespace
bool IsMatrixMultiplication(const HloInstruction& dot) {
@@ -538,32 +546,53 @@
return std::nullopt;
}
- if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedTransposeShape(
- instr.operand(0)->shape(), instr.shape(), Vector3{0, 2, 1})) {
+ if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+ ShapeUtil::GetNormalizedTransposeShape(
+ instr.operand(0)->shape(), instr.shape(),
+ absl::InlinedVector<int64_t, 3>{0, 2, 1})) {
if ((tr->at(1) >= kMinDimensionToTransposeTiled &&
tr->at(2) >= kMinDimensionToTransposeTiled) ||
(tr->at(1) >= kMinDimensionToTransposeTiled2 &&
tr->at(2) >= kMinDimensionToTransposeTiled2 &&
tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
- return TransposeDescription{&instr, *tr,
- /*permutation=*/Vector3{0, 2, 1}};
+ return TransposeDescription{
+ &instr, *tr,
+ /*permutation=*/absl::InlinedVector<int64_t, 3>{0, 2, 1}};
}
}
- if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedTransposeShape(
- instr.operand(0)->shape(), instr.shape(), Vector3{2, 1, 0})) {
+ if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+ ShapeUtil::GetNormalizedTransposeShape(
+ instr.operand(0)->shape(), instr.shape(),
+ absl::InlinedVector<int64_t, 3>{2, 1, 0})) {
if ((tr->at(0) >= kMinDimensionToTransposeTiled &&
tr->at(2) >= kMinDimensionToTransposeTiled) ||
(tr->at(0) >= kMinDimensionToTransposeTiled2 &&
tr->at(2) >= kMinDimensionToTransposeTiled2 &&
tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
- return TransposeDescription{&instr, *tr,
- /*permutation=*/Vector3{2, 1, 0}};
+ return TransposeDescription{
+ &instr, *tr,
+ /*permutation=*/absl::InlinedVector<int64_t, 3>{2, 1, 0}};
+ }
+ }
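+  // 102 transposes are only detected when the MLIR transpose emitter is
+  // enabled, and only if the most minor dimension fits in
+  // kMaxBytesInMostMinorDimension bytes.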
+ if (IsMlirTransposeEmitterEnabled(instr)) {
+ if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+ ShapeUtil::GetNormalizedTransposeShape(
+ instr.operand(0)->shape(), instr.shape(),
+ absl::InlinedVector<int64_t, 3>{1, 0, 2})) {
+ auto byte_width = primitive_util::ByteWidth(instr.shape().element_type());
+ if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension &&
+ byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >=
+ kMinDimensionToTransposeTiled) {
+ return TransposeDescription{
+ &instr, *tr,
+ /*permutation=*/absl::InlinedVector<int64_t, 3>{1, 0, 2}};
+ }
}
}
return std::nullopt;
}
-// Find 021 or 210 transpose in logical + physical transposition.
+// Find 021, 210 or 102 transpose in logical + physical transposition.
static std::optional<TransposeDescription> FindTiledLogicalTranspose(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kTranspose) {
@@ -571,28 +600,47 @@
}
// TODO(cheshire): avoid code duplication.
- if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedLogicalTransposeShape(
- instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
- Vector3{0, 2, 1})) {
+ if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+ ShapeUtil::GetNormalizedLogicalTransposeShape(
+ instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
+ absl::InlinedVector<int64_t, 3>{0, 2, 1})) {
if ((tr->at(1) >= kMinDimensionToTransposeTiled &&
tr->at(2) >= kMinDimensionToTransposeTiled) ||
(tr->at(1) >= kMinDimensionToTransposeTiled2 &&
tr->at(2) >= kMinDimensionToTransposeTiled2 &&
tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
- return TransposeDescription{&instr, *tr,
- /*permutation=*/Vector3{0, 2, 1}};
+ return TransposeDescription{
+ &instr, *tr,
+ /*permutation=*/absl::InlinedVector<int64_t, 3>{0, 2, 1}};
}
}
- if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedLogicalTransposeShape(
- instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
- Vector3{2, 1, 0})) {
+ if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+ ShapeUtil::GetNormalizedLogicalTransposeShape(
+ instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
+ absl::InlinedVector<int64_t, 3>{2, 1, 0})) {
if ((tr->at(0) >= kMinDimensionToTransposeTiled &&
tr->at(2) >= kMinDimensionToTransposeTiled) ||
(tr->at(0) >= kMinDimensionToTransposeTiled2 &&
tr->at(2) >= kMinDimensionToTransposeTiled2 &&
tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
- return TransposeDescription{&instr, *tr,
- /*permutation=*/Vector3{2, 1, 0}};
+ return TransposeDescription{
+ &instr, *tr,
+ /*permutation=*/absl::InlinedVector<int64_t, 3>{2, 1, 0}};
+ }
+ }
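+  // Same 102-transpose handling as in the physical-transpose case above.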
+ if (IsMlirTransposeEmitterEnabled(instr)) {
+ if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+ ShapeUtil::GetNormalizedLogicalTransposeShape(
+ instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
+ absl::InlinedVector<int64_t, 3>{1, 0, 2})) {
+ auto byte_width = primitive_util::ByteWidth(instr.shape().element_type());
+ if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension &&
+ byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >=
+ kMinDimensionToTransposeTiled) {
+ return TransposeDescription{
+ &instr, *tr,
+ /*permutation=*/absl::InlinedVector<int64_t, 3>{1, 0, 2}};
+ }
}
}
return std::nullopt;
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h
index 9316ba9..3dcf0bc 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils.h
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h
@@ -23,6 +23,7 @@
#include <variant>
#include <vector>
+#include "absl/container/inlined_vector.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
@@ -54,6 +55,10 @@
// efficient.
inline constexpr int64_t kMinDimensionToTransposeTiled2 = 8;
inline constexpr int64_t kMinTotalDimensionsToTransposeTiled = 64 * 128;
+// As the amount of shared memory is limited, we need to make sure that we don't
+// detect 102 transposes that would require too many bytes for the most minor
+// dimension.
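+// With the 8-byte limit below, an f32 (4-byte) transpose may have a most minor
+// dimension of at most 2 elements, while an s8 transpose may have up to 8.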
+inline constexpr int64_t kMaxBytesInMostMinorDimension = 8;
// Matrix multiplication before the rewrite.
bool IsMatrixMultiplication(const HloInstruction& dot);
@@ -160,16 +165,18 @@
const HloInstruction* instr;
// Normalized transpose dimensions.
- Vector3 dimensions;
+ absl::InlinedVector<int64_t, 3> dimensions;
// Permutations of normalized transpose dimensions.
- Vector3 permutation;
+ absl::InlinedVector<int64_t, 3> permutation;
- TransposeDescription(Vector3 dimensions, Vector3 permutation)
+ TransposeDescription(absl::InlinedVector<int64_t, 3> dimensions,
+ absl::InlinedVector<int64_t, 3> permutation)
: TransposeDescription(/*instr=*/nullptr, dimensions, permutation) {}
- TransposeDescription(const HloInstruction* instr, Vector3 dimensions,
- Vector3 permutation)
+ TransposeDescription(const HloInstruction* instr,
+ absl::InlinedVector<int64_t, 3> dimensions,
+ absl::InlinedVector<int64_t, 3> permutation)
: instr(instr), dimensions(dimensions), permutation(permutation) {}
// Transpose instruction input shape.
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
index 0703f86..2d94235 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
@@ -20,6 +20,7 @@
#include <string>
#include <vector>
+#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "xla/hlo/ir/backend_config.h"
#include "xla/literal.h"
@@ -30,7 +31,6 @@
#include "xla/shape_util.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/types.h"
-#include "xla/util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
@@ -41,6 +41,7 @@
using ::tsl::testing::IsOkAndHolds;
using IrEmissionUtilsTest = HloTestBase;
+using InlinedVector = absl::InlinedVector<int64_t, 3>;
TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) {
const char* hlo = R"(
@@ -59,8 +60,94 @@
auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, tr);
- EXPECT_EQ(result->dimensions, Vector3({1, 64, 1536}));
- EXPECT_EQ(result->permutation, Vector3({0, 2, 1}));
+ EXPECT_EQ(result->dimensions, InlinedVector({1, 64, 1536}));
+ EXPECT_EQ(result->permutation, InlinedVector({0, 2, 1}));
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiledLogical102Transpose) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+ p = f32[32,48,2]{2,1,0} parameter(0)
+ ROOT t = f32[48,32,2]{2,1,0} transpose(p), dimensions={1,0,2}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+ HloInstruction* tr = module->entry_computation()->root_instruction();
+
+ auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+ EXPECT_TRUE(result.has_value());
+ EXPECT_EQ(result->instr, tr);
+ EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 2}));
+ EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2}));
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiledLogical102TransposeTooMuchMemoryRequired) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+ p = s8[32,48,9]{2,1,0} parameter(0)
+ ROOT t = s8[48,32,9]{2,1,0} transpose(p), dimensions={1,0,2}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+ HloInstruction* tr = module->entry_computation()->root_instruction();
+
+ auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+ EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiled102Transpose) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+ p = s16[32,48,4]{2,1,0} parameter(0)
+ ROOT t = s16[32,48,4]{2,0,1} copy(p)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+ HloInstruction* tr = module->entry_computation()->root_instruction();
+
+ auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+ EXPECT_TRUE(result.has_value());
+ EXPECT_EQ(result->instr, tr);
+ EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 4}));
+ EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2}));
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiled102TransposeTooMuchMemoryRequired) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+ p = s8[32,48,9]{2,1,0} parameter(0)
+ ROOT t = s8[32,48,9]{2,0,1} copy(p)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+ HloInstruction* tr = module->entry_computation()->root_instruction();
+
+ auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+ EXPECT_FALSE(result.has_value());
}
TEST_F(IrEmissionUtilsTest, FindAnyTiledTranspose) {
@@ -79,8 +166,8 @@
auto result = GetDescriptionForTiledTransposeEmitter(*r, *r);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, r);
- EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOp) {
@@ -100,8 +187,8 @@
auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, r->operand(0));
- EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOpS8) {
@@ -258,8 +345,8 @@
auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, r->operand(0));
- EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithTwoIntermediateBinaryOps) {
@@ -289,8 +376,8 @@
GetDescriptionForTiledTransposeEmitter(*r, FindNonTrivialHero(*r));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, r->operand(0)->operand(0));
- EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest,
@@ -475,8 +562,8 @@
GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, copy);
- EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOneSwapDimIsSmall) {
@@ -502,8 +589,8 @@
GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, tr);
- EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest, FindTiledTransposeOtherSwapDimIsSmall) {
@@ -529,8 +616,8 @@
GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, copy);
- EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOtherSwapDimIsSmall) {
@@ -556,8 +643,8 @@
GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, tr);
- EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8}));
- EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+ EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8}));
+ EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}
TEST_F(IrEmissionUtilsTest, IsContiguousSlice) {
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
index a964f6b..b73225a 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -26,7 +26,6 @@
#include <string>
#include <tuple>
#include <utility>
-#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
@@ -99,7 +98,6 @@
#include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h"
#include "xla/service/gpu/gpu_asm_opts_util.h"
#include "xla/service/gpu/gpu_conv_runner.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
#include "xla/service/gpu/gpu_norm_runner.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emission_utils.h"
@@ -122,7 +120,6 @@
#include "xla/service/gpu/runtime/copy_thunk.h"
#include "xla/service/gpu/runtime/custom_call_thunk.h"
#include "xla/service/gpu/runtime/fft_thunk.h"
-#include "xla/service/gpu/runtime/fused_mha_thunk.h"
#include "xla/service/gpu/runtime/gemm_thunk.h"
#include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h"
#include "xla/service/gpu/runtime/infeed_thunk.h"
@@ -173,6 +170,7 @@
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "xla/service/gpu/runtime/cholesky_thunk.h"
#include "xla/service/gpu/runtime/cub_sort_thunk.h"
+#include "xla/service/gpu/runtime/cudnn_thunk.h"
#include "xla/service/gpu/runtime/triangular_solve_thunk.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -955,221 +953,17 @@
return absl::OkStatus();
}
-absl::Status IrEmitterUnnested::EmitFusedMHAThunk(
+absl::Status IrEmitterUnnested::EmitCuDnnThunk(
const HloCustomCallInstruction* instr) {
- const HloInstruction* lhs_bmm1 = instr->operand(0);
- const HloInstruction* rhs_bmm1 = instr->operand(1);
- const HloInstruction* rhs_bmm2 = instr->operand(2);
-
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice,
- GetAllocationSliceForHlo(lhs_bmm1));
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice,
- GetAllocationSliceForHlo(rhs_bmm1));
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice,
- GetAllocationSliceForHlo(rhs_bmm2));
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
- GetAllocationSliceForHlo(instr, {0}));
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice,
- GetAllocationSliceForHlo(
- instr, {instr->shape().tuple_shapes_size() - 1}));
- BufferAllocation::Slice activation_slice;
- bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3;
- if (has_activation) {
- TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {1}));
- }
-
- TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind,
- xla::gpu::GetCudnnfMHAKind(instr));
- BufferAllocation::Slice mask_slice, bias_slice;
- BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice;
- std::optional<Shape> mask_shape, bias_shape;
- {
- bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax ||
- kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout;
-
- if (has_bias) {
- const HloInstruction* bias = instr->operand(3);
- TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias));
- bias_shape = bias->shape();
- }
- int64_t seqlen_qk_operand_index = 3 + has_bias;
- bool has_seqlen_qk = seqlen_qk_operand_index == instr->operand_count() - 2;
- if (has_seqlen_qk) {
- const HloInstruction* seqlen_q = instr->operand(seqlen_qk_operand_index);
- TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q));
- const HloInstruction* seqlen_k =
- instr->operand(seqlen_qk_operand_index + 1);
- TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k));
- }
- }
-
- TF_ASSIGN_OR_RETURN(const auto gpu_config,
- instr->backend_config<xla::gpu::GpuBackendConfig>());
- const xla::gpu::CudnnfMHABackendConfig& config =
- gpu_config.cudnn_fmha_backend_config();
- Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
- absl::InlinedVector<Shape, 2> output_shapes = {
- ShapeUtil::GetSubshape(instr->shape(), {0})};
- if (has_activation) {
- output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {1}));
- }
- TF_ASSIGN_OR_RETURN(const auto mask_type,
- AsCudnnFmhaMaskKind(config.mask_type()));
- GpufMHADescriptor descriptor = {kind,
- config,
- mask_type,
- lhs_bmm1->shape(),
- rhs_bmm1->shape(),
- rhs_bmm2->shape(),
- intermediate_tensor_shape,
- output_shapes,
- config.bmm1_dot_dimension_numbers(),
- config.bmm2_dot_dimension_numbers(),
- mask_shape,
- bias_shape};
-
- TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config,
- GpufMHAConfig::For(descriptor));
- AddThunkToThunkSequence(std::make_unique<FusedMHAThunk>(
- Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config),
- lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice,
- scratch_slice, mask_slice, bias_slice, activation_slice, seqlen_q_slice,
- seqlen_k_slice));
- return absl::OkStatus();
-}
-
-absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk(
- const HloCustomCallInstruction* instr) {
- TF_ASSIGN_OR_RETURN(const auto gpu_config,
- instr->backend_config<xla::gpu::GpuBackendConfig>());
- const xla::gpu::CudnnfMHABackendConfig& config =
- gpu_config.cudnn_fmha_backend_config();
-
- int input_index = 0;
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice,
- GetAllocationSliceForHlo(instr->operand(input_index)));
- Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape();
-
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice,
- GetAllocationSliceForHlo(instr->operand(input_index)));
- Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape();
-
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice,
- GetAllocationSliceForHlo(instr->operand(input_index)));
- Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape();
-
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice,
- GetAllocationSliceForHlo(instr->operand(input_index)));
- Shape bmm2_grad_gemm1_lhs_shape;
-
- Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
- bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape;
- input_index++;
-
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice,
- GetAllocationSliceForHlo(instr->operand(input_index)));
- Shape d_output_shape = instr->operand(input_index++)->shape();
-
- TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr));
- BufferAllocation::Slice mask_slice;
- std::optional<Shape> mask_shape;
-
- bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax ||
- kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout);
- BufferAllocation::Slice bias_slice;
- std::optional<Shape> bias_shape;
- if (has_bias) {
- TF_ASSIGN_OR_RETURN(bias_slice,
- GetAllocationSliceForHlo(instr->operand(input_index)));
- bias_shape = instr->operand(input_index++)->shape();
- }
-
- BufferAllocation::Slice fwd_output_slice;
- std::optional<Shape> fwd_output_shape;
-
- TF_ASSIGN_OR_RETURN(fwd_output_slice,
- GetAllocationSliceForHlo(instr->operand(input_index)));
- fwd_output_shape = instr->operand(input_index++)->shape();
-
- BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice;
- bool has_seqlen_qk = input_index == instr->operand_count() - 2;
- if (has_seqlen_qk) {
- const HloInstruction* seqlen_q = instr->operand(input_index);
- TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q));
- const HloInstruction* seqlen_k = instr->operand(input_index + 1);
- TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k));
- input_index += 2;
- }
- TF_RET_CHECK(input_index == instr->operand_count());
-
- int output_index = 0;
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice,
- GetAllocationSliceForHlo(instr, {output_index}));
- Shape d_bmm1_lhs_shape =
- ShapeUtil::GetSubshape(instr->shape(), {output_index++});
-
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice,
- GetAllocationSliceForHlo(instr, {output_index}));
- Shape d_bmm1_rhs_shape =
- ShapeUtil::GetSubshape(instr->shape(), {output_index++});
-
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice,
- GetAllocationSliceForHlo(instr, {output_index}));
- Shape d_bmm2_rhs_shape =
- ShapeUtil::GetSubshape(instr->shape(), {output_index++});
-
- BufferAllocation::Slice d_s_slice;
- std::optional<Shape> d_s_shape;
-
- bool has_dbias = instr->shape().tuple_shapes().size() == 5;
- BufferAllocation::Slice d_bias_slice;
- std::optional<Shape> d_bias_shape;
- if (has_dbias) {
- TF_ASSIGN_OR_RETURN(d_bias_slice,
- GetAllocationSliceForHlo(instr, {output_index}));
- d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++});
- }
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice,
- GetAllocationSliceForHlo(instr, {output_index++}));
- TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size());
- TF_ASSIGN_OR_RETURN(const auto mask_type,
- AsCudnnFmhaMaskKind(config.mask_type()));
- bool force_deterministic = config.force_deterministic();
- GpufMHABackwardDescriptor descriptor = {
- kind,
- config,
- mask_type,
- bmm1_grad_gemm1_rhs_shape,
- bmm1_grad_gemm2_rhs_shape,
- bmm2_grad_gemm1_lhs_shape,
- bmm2_grad_gemm2_rhs_shape,
- d_output_shape,
- d_bmm1_lhs_shape,
- d_bmm1_rhs_shape,
- d_bmm2_rhs_shape,
- config.bmm1_grad_gemm1_dot_dimension_numbers(),
- config.bmm1_grad_gemm2_dot_dimension_numbers(),
- config.bmm2_grad_gemm1_dot_dimension_numbers(),
- config.bmm2_grad_gemm2_dot_dimension_numbers(),
- d_s_shape,
- fwd_output_shape,
- mask_shape,
- d_bias_shape,
- bias_shape,
- force_deterministic};
-
- TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config,
- GpufMHABackwardConfig::For(descriptor));
-
- AddThunkToThunkSequence(std::make_unique<FusedMHABackwardThunk>(
- Thunk::ThunkInfo::WithProfileAnnotation(instr),
- std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice,
- bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice,
- bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice,
- d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice,
- mask_slice, d_bias_slice, fwd_output_slice, bias_slice, seqlen_q_slice,
- seqlen_k_slice));
-
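+  // Pass all operands and results as opaque kernel arguments; the CuDnnThunk
+  // is keyed by the instruction fingerprint, which includes the backend config.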
+ TF_ASSIGN_OR_RETURN(
+ auto kernel_arguments,
+ KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr,
+ instr->operands()));
+ TF_ASSIGN_OR_RETURN(const std::string fingerprint,
+ FingerprintWithBackendConfig<GpuBackendConfig>(*instr));
+ AddThunkToThunkSequence(std::make_unique<CuDnnThunk>(
+ fingerprint, Thunk::ThunkInfo::WithProfileAnnotation(instr),
+ kernel_arguments.args()));
return absl::OkStatus();
}
@@ -1698,7 +1492,7 @@
const se::DeviceDescription& device_info =
ir_emitter_context_->gpu_device_info();
const HloFusionAnalysis fusion_analysis =
- HloFusionAnalysis::Create(instr, &device_info);
+ HloFusionAnalysis::Create(*instr, device_info);
std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(HloFusionInfo(
fusion_analysis, instr, &ir_emitter_context_->buffer_assignment()));
@@ -2921,11 +2715,8 @@
if (IsCustomCallToDnnNorm(*instr)) {
return EmitNormThunk(custom_call);
}
- if (IsFwdCustomCallTofMHA(*instr)) {
- return EmitFusedMHAThunk(custom_call);
- }
- if (IsBwdCustomCallTofMHA(*instr)) {
- return EmitFusedMHABackwardThunk(custom_call);
+ if (IsCustomCallTofMHA(*instr)) {
+ return EmitCuDnnThunk(custom_call);
}
#endif // GOOGLE_CUDA
if (IsCustomCallToTopK(*instr)) {
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
index f97f106..d19dd5d 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
@@ -147,8 +147,7 @@
absl::Status EmitConvolutionReorderThunk(
const HloCustomCallInstruction* instr);
absl::Status EmitNormThunk(const HloCustomCallInstruction* instr);
- absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr);
- absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr);
+ absl::Status EmitCuDnnThunk(const HloCustomCallInstruction* instr);
#endif // GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr);
diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD
index 8c9b851..a1ce0cf 100644
--- a/third_party/xla/xla/service/gpu/kernels/BUILD
+++ b/third_party/xla/xla/service/gpu/kernels/BUILD
@@ -95,6 +95,7 @@
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
@@ -110,6 +111,7 @@
disabled_backends = ["gpu_h100"],
tags = ["no_rocm"],
deps = [
+ ":custom_kernel",
":custom_kernel_fusion_pattern",
":cutlass_gemm_custom_kernel",
":cutlass_gemm_fusion",
@@ -119,9 +121,11 @@
"//xla:error_spec",
"//xla:literal_util",
"//xla:types",
- "//xla/service/gpu:custom_kernel_fusion_rewriter",
+ "//xla:xla_data_proto_cc",
"//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
"//xla/tests:hlo_test_base",
+ "@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
],
@@ -355,6 +359,10 @@
deps = [
":cutlass_gemm_kernel_bf16xbf16_to_bf16",
":cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80",
+ ":cutlass_gemm_kernel_bf16xbf16_to_f32",
+ ":cutlass_gemm_kernel_bf16xf32_to_f32",
+ ":cutlass_gemm_kernel_bf16xs8_to_f32",
+ ":cutlass_gemm_kernel_f32xbf16_to_f32",
":cutlass_gemm_kernel_f32xf32_to_f32",
] + if_cuda_newer_than(
"12_0",
@@ -437,6 +445,68 @@
]),
)
+cuda_library(
+ name = "cutlass_gemm_kernel_bf16xbf16_to_f32",
+ srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc"]),
+ copts = [
+ "-mllvm",
+ "-unroll-threshold=100000",
+ ] + if_windows(
+ [],
+ ["-Wno-unknown-attributes"],
+ ),
+ deps = if_cuda_is_configured([
+ ":cutlass_gemm_adaptor",
+ "@cutlass_archive//:cutlass",
+ "@local_config_cuda//cuda:cuda_headers",
+ ]),
+)
+
+cuda_library(
+ name = "cutlass_gemm_kernel_bf16xf32_to_f32",
+ srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc"]),
+ copts = [
+ "-mllvm",
+ "-unroll-threshold=100000",
+ ] + if_windows(
+ [],
+ ["-Wno-unknown-attributes"],
+ ),
+ deps = if_cuda_is_configured([
+ ":cutlass_gemm_adaptor",
+ "@cutlass_archive//:cutlass",
+ "@local_config_cuda//cuda:cuda_headers",
+ ]),
+)
+
+cuda_library(
+ name = "cutlass_gemm_kernel_f32xbf16_to_f32",
+ srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc"]),
+ copts = [
+ "-mllvm",
+ "-unroll-threshold=100000",
+ ] + if_windows(
+ [],
+ ["-Wno-unknown-attributes"],
+ ),
+ deps = if_cuda_is_configured([
+ ":cutlass_gemm_adaptor",
+ "@cutlass_archive//:cutlass",
+ "@local_config_cuda//cuda:cuda_headers",
+ ]),
+)
+
+cuda_library(
+ name = "cutlass_gemm_kernel_bf16xs8_to_f32",
+ srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc"]),
+ copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
+ deps = if_cuda_is_configured([
+ ":cutlass_gemm_adaptor",
+ "@cutlass_archive//:cutlass",
+ "@local_config_cuda//cuda:cuda_headers",
+ ]),
+)
+
#===--------------------------------------------------------------------------------------------===#
# CUTLASS Gemm kernel libraries
#===--------------------------------------------------------------------------------------------===#
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h
index 37fb0ad..963b80c 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h
@@ -46,12 +46,27 @@
enum class Arch { kDefault, kSm80, kSm90 };
+// Keep in sync with cutlass::gemm::GemmUniversalMode.
+enum class GemmMode { kGemm, kGemmSplitKParallel, kBatched, kArray, kInvalid };
+
template <Arch arch>
struct Bf16xBf16ToBf16 {};
template <Arch arch>
struct F32xF32ToF32 {};
+template <Arch arch>
+struct Bf16xBf16ToF32 {};
+
+template <Arch arch>
+struct Bf16xF32ToF32 {};
+
+template <Arch arch>
+struct F32xBf16ToF32 {};
+
+template <Arch arch>
+struct Bf16xS8ToF32 {};
+
// A tag to specialize CUTLASS kernel adaptors for loading kernels from shared
// libraries using dlopen.
struct DlOpenedKernel {};
@@ -132,6 +147,12 @@
// Type-erased CUTLASS gemm arguments structure that has all of the details
// required for packing CUTLASS kernel parameters.
struct Arguments {
+ GemmMode mode;
+
+ // Number of batches when mode is `kBatched`.
+ // Number of k-slices when mode is `kGemmSplitKParallel`.
+ int32_t batch_count;
+
int32_t m;
int32_t n;
int32_t k;
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
index 53a6ac6..1478dc8 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
@@ -19,13 +19,17 @@
#include <cstddef>
#include <cstdint>
#include <memory>
+#include <optional>
+#include "third_party/gpus/cuda/include/vector_types.h"
#include "cute/layout.hpp"
#include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_enumerated_types.h"
#include "cutlass/gemm_coord.h"
+#include "cutlass/kernel_hardware_info.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/packed_stride.hpp"
#include "xla/service/gpu/kernels/cutlass_gemm.h"
@@ -137,6 +141,21 @@
cutlass::Status::kSuccess;
}
+inline cutlass::gemm::GemmUniversalMode ToGemmUniversalMode(GemmMode mode) {
+ switch (mode) {
+ case GemmMode::kGemm:
+ return cutlass::gemm::GemmUniversalMode::kGemm;
+ case GemmMode::kGemmSplitKParallel:
+ return cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel;
+ case GemmMode::kBatched:
+ return cutlass::gemm::GemmUniversalMode::kBatched;
+ case GemmMode::kArray:
+ return cutlass::gemm::GemmUniversalMode::kArray;
+ case GemmMode::kInvalid:
+ return cutlass::gemm::GemmUniversalMode::kInvalid;
+ }
+}
+
// Converts type-erased gemm arguments to the underlying CUTLASS operation
// arguments.
template <typename Tag>
@@ -148,7 +167,7 @@
auto ldb = LdB<typename Traits<Tag>::Operation>(problem_size);
auto ldc = LdC<typename Traits<Tag>::Operation>(problem_size);
- auto mode = cutlass::gemm::GemmUniversalMode::kGemm;
+ cutlass::gemm::GemmUniversalMode mode = ToGemmUniversalMode(args.mode);
// TODO(ezhulenev): We hardcode parameters for `LinearCombination`
// epilogue, however `Gemm` template can be compiled with arbitrary
@@ -160,7 +179,7 @@
return typename Traits<Tag>::Arguments( // CUTLASS Operation arguments
mode, problem_size, //
- 1, // batch
+ args.batch_count, // batch or k-split slices
{alpha, beta}, // epilogue
args.lhs, args.rhs, args.out, args.out, // pointers
0, 0, 0, 0, // batch strides
@@ -237,7 +256,9 @@
// TODO(ezhulenev): Pass device id and sm_count in arguments.
cutlass::KernelHardwareInfo hw_info{/*device_id=*/0, /*sm_count=*/128};
- auto mode = cutlass::gemm::GemmUniversalMode::kGemm;
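+  // GemmMode is kept in sync with cutlass::gemm::GemmUniversalMode, so the
+  // enum values can be cast directly.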
+ cutlass::gemm::GemmUniversalMode mode =
+ static_cast<cutlass::gemm::GemmUniversalMode>(
+ static_cast<int>(args.mode));
typename Kernel::ProblemShape problem_shape = {args.m, args.n, args.k,
/*batch=*/1};
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
index 81b2dbd..a97fe04 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
@@ -90,8 +90,8 @@
}
template <typename Tag>
-KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k,
- const ArgsIndices& indices,
+KernelArgsPacking ArgsPacking(GemmMode mode, int32_t batch_count, int32_t m,
+ int32_t n, int32_t k, const ArgsIndices& indices,
const DynamicSliceIndices& slices,
int32_t device_sms, Adaptor<Tag> adaptor) {
using Packed = absl::StatusOr<std::unique_ptr<se::KernelArgsPackedArrayBase>>;
@@ -111,7 +111,7 @@
return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed {
auto* mem_args = se::Cast<se::KernelArgsDeviceMemoryArray>(&args);
- Arguments arguments = {m, n, k};
+ Arguments arguments = {mode, batch_count, m, n, k};
arguments.lhs = const_cast<void*>(mem_args->device_memory_ptr(indices.lhs));
arguments.rhs = const_cast<void*>(mem_args->device_memory_ptr(indices.rhs));
arguments.out = const_cast<void*>(mem_args->device_memory_ptr(indices.out));
@@ -176,7 +176,8 @@
//===----------------------------------------------------------------------===//
template <typename Tag>
-static CustomKernel Load(std::string name, int32_t m, int32_t n, int32_t k,
+static CustomKernel Load(std::string name, GemmMode mode, int32_t batch_count,
+ int32_t m, int32_t n, int32_t k,
const ArgsIndices& indices,
const DynamicSliceIndices& slices,
const se::DeviceDescription& device,
@@ -188,8 +189,8 @@
auto thread_dim = As<se::ThreadDim>(adaptor.ThreadDim());
auto shared_memory_bytes = adaptor.SharedMemoryBytes();
- auto packing =
- ArgsPacking<Tag>(m, n, k, indices, slices, device.core_count(), adaptor);
+ auto packing = ArgsPacking<Tag>(mode, batch_count, m, n, k, indices, slices,
+ device.core_count(), adaptor);
se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing));
spec.AddInProcessSymbol(kernel.symbol(), name);
@@ -204,33 +205,83 @@
}
absl::StatusOr<std::vector<CustomKernel>> GetCutlassGemmKernels(
- std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k,
+ std::string name, PrimitiveType dot_type, PrimitiveType lhs_type,
+ PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k,
const ArgsIndices& indices, const DynamicSliceIndices& slices,
const se::DeviceDescription& device) {
auto& cuda_cc =
std::get<se::CudaComputeCapability>(device.gpu_compute_capability());
- switch (dtype) {
- case PrimitiveType::F32:
- return {{Load<F32xF32ToF32<Default>>(std::move(name), m, n, k, indices,
- slices, device)}};
- case PrimitiveType::BF16:
+ if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::F32 &&
+ rhs_type == PrimitiveType::F32) {
+ return {{Load<F32xF32ToF32<Default>>(std::move(name), GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k, indices,
+ slices, device)}};
+ }
+
+ if (dot_type == PrimitiveType::BF16 && lhs_type == PrimitiveType::BF16 &&
+ rhs_type == PrimitiveType::BF16) {
#if CUDA_VERSION >= 12000
if (cuda_cc.IsAtLeastHopper()) {
- return {{Load<Bf16xBf16ToBf16<Sm90>>(std::move(name), m, n, k, indices,
- slices, device)}};
+ return {{Load<Bf16xBf16ToBf16<Sm90>>(std::move(name), GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k,
+ indices, slices, device)}};
}
#endif
if (cuda_cc.IsAtLeastAmpere()) {
- return {{Load<Bf16xBf16ToBf16<Sm80>>(std::move(name), m, n, k, indices,
- slices, device)}};
+ return {{Load<Bf16xBf16ToBf16<Default>>(
+ std::move(name), GemmMode::kGemm, /*batch_count=*/1, m, n, k,
+ indices, slices, device)}};
}
- return {{Load<Bf16xBf16ToBf16<Default>>(std::move(name), m, n, k, indices,
- slices, device)}};
-
- default:
- return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type");
+ return {{Load<Bf16xBf16ToBf16<Default>>(std::move(name), GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k,
+ indices, slices, device)}};
}
+
+ if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 &&
+ rhs_type == PrimitiveType::BF16) {
+ return {{Load<Bf16xBf16ToF32<Default>>(std::move(name), GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k, indices,
+ slices, device)}};
+ }
+
+ if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 &&
+ rhs_type == PrimitiveType::F32) {
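+    // Provide both a plain GEMM and a split-K parallel variant with 16
+    // k-slices for this mixed-precision case.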
+ return {{Load<Bf16xF32ToF32<Default>>(name, GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k, indices,
+ slices, device),
+ Load<Bf16xF32ToF32<Default>>(name, GemmMode::kGemmSplitKParallel,
+ /*batch_count=*/16, m, n, k, indices,
+ slices, device)}};
+ }
+
+ if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::F32 &&
+ rhs_type == PrimitiveType::BF16) {
+ return {{Load<F32xBf16ToF32<Default>>(name, GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k, indices,
+ slices, device),
+ Load<F32xBf16ToF32<Default>>(name, GemmMode::kGemmSplitKParallel,
+ /*batch_count=*/16, m, n, k, indices,
+ slices, device)}};
+ }
+
+ if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 &&
+ rhs_type == PrimitiveType::S8) {
+ return {{
+ Load<Bf16xS8ToF32<Default>>(name, GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k, indices, slices,
+ device),
+ Load<Bf16xS8ToF32<Default>>(name, GemmMode::kGemmSplitKParallel,
+ /*batch_count=*/16, m, n, k, indices,
+ slices, device),
+ }};
+ }
+
+ std::string kernel_name = PrimitiveType_Name(lhs_type) + "x" +
+ PrimitiveType_Name(rhs_type) + "To" +
+ PrimitiveType_Name(dot_type);
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Unsupported CUTLASS gemm data type for kernel: ", kernel_name));
}
absl::StatusOr<CustomKernel> LoadCutlassGemmKernel(
@@ -250,8 +301,9 @@
"Failed to load CUTLASS kernel from a shared library: ", library_path));
}
- return Load<DlOpenedKernel>(std::move(name), m, n, k, indices, slices, device,
- *adaptor, *kernel);
+ return Load<DlOpenedKernel>(std::move(name), GemmMode::kGemm,
+ /*batch_count=*/1, m, n, k, indices, slices,
+ device, *adaptor, *kernel);
}
} // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h
index 37531ef..04b0925 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h
@@ -30,7 +30,8 @@
// Returns pre-compiled custom kernels for a given data type and problem size.
absl::StatusOr<std::vector<CustomKernel>> GetCutlassGemmKernels(
- std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k,
+ std::string name, PrimitiveType dot_type, PrimitiveType lhs_type,
+ PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k,
const ArgsIndices& indices, const DynamicSliceIndices& slices,
const se::DeviceDescription& device);
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc
index 22edaea..124569e 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc
@@ -54,7 +54,8 @@
TF_ASSERT_OK_AND_ASSIGN(
auto custom_kernels,
- GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::BF16, m, n, k,
+ GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::BF16,
+ PrimitiveType::BF16, PrimitiveType::BF16, m, n, k,
/*indices=*/{0, 1, 2}, /*slices=*/{}, device));
const auto& custom_kernel = custom_kernels[0];
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc
index 8e231ee..d95241b 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc
@@ -26,7 +26,8 @@
namespace xla::gpu::kernel::gemm_universal {
absl::StatusOr<std::vector<CustomKernel>> GetCutlassGemmKernels(
- std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k,
+ std::string name, PrimitiveType dot_type, PrimitiveType lhs_type,
+ PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k,
const ArgsIndices& indices, const DynamicSliceIndices& slices,
const se::DeviceDescription& device) {
return absl::InternalError("XLA compiled without CUDA support");
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc
index 4c0a586..7cdc950 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc
@@ -44,7 +44,8 @@
// Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS.
TF_ASSERT_OK_AND_ASSIGN(
auto custom_kernels,
- GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::F32, 4, 4, 4,
+ GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::F32,
+ PrimitiveType::F32, PrimitiveType::F32, 4, 4, 4,
/*indices=*/{0, 1, 2}, /*slices=*/{},
executor->GetDeviceDescription()));
auto custom_kernel = custom_kernels[0];
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
index a392801..946e1f8 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
@@ -15,6 +15,7 @@
#include "xla/service/gpu/kernels/cutlass_gemm_fusion.h"
+#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
@@ -23,6 +24,7 @@
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
@@ -136,14 +138,22 @@
return absl::InternalError("unsupported operands type");
}
-// Returns matched GEMM with one of the operands upcasted to the accumulator
-// data type with an HLO convert instruction.
+// Returns matched GEMM with one or both of the operands upcasted to the
+// accumulator data type with an HLO convert instruction.
static absl::StatusOr<GemmWithUpcast> MatchGemmWithUpcast(
HloDotInstruction* dot) {
TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot));
GemmWithUpcast match(dot);
+ // C <- convert(A) * convert(B)
+ if (Match(const_cast<HloInstruction*>(dot->operand(0)),
+ m::Convert(&match.lhs_upcast, m::Op())) &&
+ Match(const_cast<HloInstruction*>(dot->operand(1)),
+ m::Convert(&match.rhs_upcast, m::Op()))) {
+ return match;
+ }
+
// C <- convert(A) * B
if (Match(const_cast<HloInstruction*>(dot->operand(0)),
m::Convert(&match.lhs_upcast, m::Op()))) {
@@ -254,16 +264,19 @@
if (!dot) return std::nullopt;
auto matched = MatchGemmWithUpcast(dot);
- if (!matched.ok()) return std::nullopt;
- // Only one operand can be upcasted.
- DCHECK(matched->lhs_upcast == nullptr || matched->rhs_upcast == nullptr);
+ if (!matched.ok()) return std::nullopt;
CustomFusionConfig config;
config.set_name("cutlass_gemm_with_upcast");
- return matched->lhs_upcast ? Match{config, {matched->lhs_upcast, instr}}
- : Match{config, {matched->rhs_upcast, instr}};
+ if (matched->lhs_upcast != nullptr && matched->rhs_upcast == nullptr) {
+ return Match{config, {matched->lhs_upcast, instr}};
+ } else if (matched->rhs_upcast != nullptr && matched->lhs_upcast == nullptr) {
+ return Match{config, {matched->rhs_upcast, instr}};
+ } else {
+ return Match{config, {matched->lhs_upcast, matched->rhs_upcast, instr}};
+ }
}
//===----------------------------------------------------------------------===//
@@ -283,7 +296,7 @@
TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, {PrimitiveType::F32}));
- auto dtype = dot->shape().element_type();
+ PrimitiveType dot_type = dot->shape().element_type();
auto* lhs = Cast<HloParameterInstruction>(dot->operand(0));
auto* rhs = Cast<HloParameterInstruction>(dot->operand(1));
@@ -293,15 +306,19 @@
lhs->parameter_number(), rhs->parameter_number(),
computation->num_parameters()};
- auto& lhs_shape = lhs->shape();
- auto& rhs_shape = rhs->shape();
+ const Shape& lhs_shape = lhs->shape();
+ const Shape& rhs_shape = rhs->shape();
size_t m = lhs_shape.dimensions(0);
size_t k = lhs_shape.dimensions(1);
size_t n = rhs_shape.dimensions(1);
- return kernel::gemm_universal::GetCutlassGemmKernels(
- "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{}, device);
+ PrimitiveType lhs_type = lhs->shape().element_type();
+ PrimitiveType rhs_type = rhs->shape().element_type();
+
+ return GetCutlassGemmKernels("cutlass_gemm", dot_type, lhs_type, rhs_type,
+ m, n, k, indices,
+ /*slices=*/{}, device);
}
};
@@ -313,23 +330,44 @@
auto* dot = DynCast<HloDotInstruction>(computation->root_instruction());
if (dot == nullptr) {
return absl::InternalError(
- "cutlass_gemm requires ROOT operation to be a dot");
+ "cutlass_gemm_with_upcast requires ROOT operation to be a dot");
}
- TF_ASSIGN_OR_RETURN(auto matched, MatchGemmWithUpcast(dot));
+ TF_ASSIGN_OR_RETURN(GemmWithUpcast matched, MatchGemmWithUpcast(dot));
- // We only support upcasting of rhs operand.
- if (matched.lhs_upcast != nullptr)
- return absl::InternalError("only rhs upcasting is implemented");
+ const HloParameterInstruction* lhs;
+ const HloParameterInstruction* rhs;
- auto dot_dtype = dot->shape().element_type();
- auto upcast_dtype = matched.rhs_upcast->shape().element_type();
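+  // Look through the convert (when present) to find the parameter instruction
+  // feeding each side of the dot.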
+ if (matched.lhs_upcast == nullptr && matched.rhs_upcast != nullptr) {
+ lhs = Cast<HloParameterInstruction>(matched.dot->operand(0));
+ rhs = Cast<HloParameterInstruction>(matched.rhs_upcast->operand(0));
+ } else if (matched.lhs_upcast != nullptr && matched.rhs_upcast == nullptr) {
+ lhs = Cast<HloParameterInstruction>(matched.lhs_upcast->operand(0));
+ rhs = Cast<HloParameterInstruction>(matched.dot->operand(1));
+ } else {
+ lhs = Cast<HloParameterInstruction>(matched.lhs_upcast->operand(0));
+ rhs = Cast<HloParameterInstruction>(matched.rhs_upcast->operand(0));
+ }
- // We only support BF16 <- BF16 x S8 upcasted gemm.
- if (dot_dtype != PrimitiveType::BF16 || upcast_dtype != PrimitiveType::S8)
- return absl::InternalError("unsupported upcasting pattern");
+ const Shape& lhs_shape = lhs->shape();
+ const Shape& rhs_shape = rhs->shape();
- return absl::UnimplementedError("requires CUTLASS 3.3.0");
+ size_t m = lhs_shape.dimensions(0);
+ size_t k = lhs_shape.dimensions(1);
+ size_t n = rhs_shape.dimensions(1);
+
+ PrimitiveType dot_type = dot->shape().element_type();
+ PrimitiveType lhs_type = lhs_shape.element_type();
+ PrimitiveType rhs_type = rhs_shape.element_type();
+
+ // Mapping from fusion arguments to gemm kernel arguments.
+ kernel::gemm_universal::ArgsIndices args_indices = {
+ lhs->parameter_number(), rhs->parameter_number(),
+ computation->num_parameters()};
+
+ return GetCutlassGemmKernels("cutlass_gemm_with_upcast", dot_type, lhs_type,
+ rhs_type, m, n, k, args_indices, /*slices=*/{},
+ device);
}
};
@@ -353,7 +391,7 @@
MatchSimpleGemm(Cast<HloDotInstruction>(matched.dot),
{PrimitiveType::F32, PrimitiveType::BF16}));
- auto dtype = matched.dot->shape().element_type();
+ auto dot_type = matched.dot->shape().element_type();
auto* lhs = Cast<HloParameterInstruction>(matched.dot->operand(0));
auto* rhs = Cast<HloParameterInstruction>(matched.dot->operand(1));
@@ -370,21 +408,25 @@
kernel::gemm_universal::DynamicSliceIndices slices;
slices.out = offset->parameter_number();
- auto& lhs_shape = lhs->shape();
- auto& rhs_shape = rhs->shape();
+ const Shape& lhs_shape = lhs->shape();
+ const Shape& rhs_shape = rhs->shape();
size_t m = lhs_shape.dimensions(0);
size_t k = lhs_shape.dimensions(1);
size_t n = rhs_shape.dimensions(1);
- return kernel::gemm_universal::GetCutlassGemmKernels(
- "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, k, args_indices,
- slices, device);
+ PrimitiveType lhs_type = lhs->shape().element_type();
+ PrimitiveType rhs_type = rhs->shape().element_type();
+
+ return GetCutlassGemmKernels("cutlass_gemm_with_dynamic_update_slice",
+ dot_type, lhs_type, rhs_type, m, n, k,
+ args_indices, slices, device);
}
};
} // namespace xla::gpu
+XLA_REGISTER_CUSTOM_FUSION_PATTERN(::xla::gpu::CutlassGemmWithUpcastPattern);
XLA_REGISTER_CUSTOM_FUSION_PATTERN(
::xla::gpu::CutlassGemmWithDynamicUpdateSlicePattern);
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
index a6488e9..768feaf 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
@@ -17,18 +17,21 @@
#include <cstdint>
#include <utility>
+#include <vector>
+#include <gtest/gtest.h>
#include "xla/array.h"
#include "xla/array2d.h"
#include "xla/array3d.h"
#include "xla/error_spec.h"
#include "xla/literal_util.h"
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h"
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/types.h"
+#include "xla/xla_data.pb.h"
#include "tsl/platform/test.h"
namespace xla::gpu {
@@ -41,10 +44,12 @@
->GetDeviceDescription()
.shared_memory_per_block_optin();
}
- int CutlassGemmKernelSharedMemorySize(PrimitiveType dtype, int m, int n,
+ int CutlassGemmKernelSharedMemorySize(PrimitiveType dot_type,
+ PrimitiveType lhs_type,
+ PrimitiveType rhs_type, int m, int n,
int k) {
return kernel::gemm_universal::GetCutlassGemmKernels(
- "cutlass_gemm", dtype, m, n, k,
+ "cutlass_gemm", dot_type, lhs_type, rhs_type, m, n, k,
/*indices=*/{0, 1, 2}, /*slices=*/{},
backend().default_stream_executor()->GetDeviceDescription())
->at(0)
@@ -134,6 +139,48 @@
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}
+TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastOfBothOperands) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main (p0: bf16[15,19], p1: bf16[19,17]) -> f32[15,17] {
+ %p0 = bf16[15,19]{1,0} parameter(0)
+ %c1 = f32[15,19]{1,0} convert(%p0)
+ %p1 = bf16[19,17]{1,0} parameter(1)
+ %c2 = f32[19,17]{1,0} convert(%p1)
+ ROOT %r = f32[15,17]{1,0} dot(%c1, %c2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %cutlass_gemm_with_upcast {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = bf16[15,19]{1,0} parameter
+ ; CHECK: [[C1:%[^ ]+]] = f32[15,19]{1,0} convert([[P0]])
+ ; CHECK-DAG: [[P1:%[^ ]+]] = bf16[19,17]{1,0} parameter
+ ; CHECK: [[C2:%[^ ]+]] = f32[19,17]{1,0} convert([[P1]])
+ ; CHECK: ROOT [[DOT:%[^ ]+]] = f32[15,17]{1,0} dot([[C1]], [[C2]]),
+ ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main {{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[15,17]{1,0} fusion
+ ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_upcast,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm_with_upcast","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ CustomKernelFusionPatternRegistry patterns;
+ patterns.Emplace<CutlassGemmWithUpcastPattern>();
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ CustomKernelFusionRewriter pass(&device, &patterns);
+ RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
+}
+
TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) {
const char* hlo = R"(
HloModule test
@@ -329,9 +376,7 @@
error_spec, /*run_hlo_passes=*/false));
}
-TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) {
- GTEST_SKIP() << "Requires CUTLASS 3.3.0+";
-
+TEST_F(CutlassFusionTest, GemmWithLeftHandSideUpcastKernel) {
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
const char* hlo_text_cublas = R"(
@@ -339,12 +384,12 @@
ENTRY e {
p0 = bf16[16,32]{1,0} parameter(0)
- p1 = s8[32,8]{1,0} parameter(1)
- c1 = bf16[32,8]{1,0} convert(p1)
- gemm = (bf16[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1),
+ c0 = f32[16,32]{1,0} convert(p0)
+ p1 = f32[32,8]{1,0} parameter(1)
+ gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(c0, p1),
custom_call_target="__cublas$gemm",
backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}}
- ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element(gemm), index=0
+ ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0
})";
const char* hlo_text_custom_fusion = R"(
@@ -352,16 +397,97 @@
cutlass_gemm_with_upcast {
p0 = bf16[16,32]{1,0} parameter(0)
+ c0 = f32[16,32]{1,0} convert(p0)
+ p1 = f32[32,8]{1,0} parameter(1)
+ ROOT dot = f32[16,8]{1,0} dot(c0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+ ENTRY e {
+ p0 = bf16[16,32]{1,0} parameter(0)
+ p1 = f32[32,8]{1,0} parameter(1)
+ ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast,
+ backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}}
+ })";
+
+ EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion,
+ error_spec, /*run_hlo_passes=*/false));
+}
+
+TEST_F(CutlassFusionTest, GemmWithRightHandSideUpcastKernel) {
+ ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+
+ const char* hlo_text_cublas = R"(
+ HloModule cublas
+
+ ENTRY e {
+ p0 = f32[16,32]{1,0} parameter(0)
+ p1 = bf16[32,8]{1,0} parameter(1)
+ c1 = f32[32,8]{1,0} convert(p1)
+ gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}}
+ ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0
+ })";
+
+ const char* hlo_text_custom_fusion = R"(
+ HloModule cutlass
+
+ cutlass_gemm_with_upcast {
+ p0 = f32[16,32]{1,0} parameter(0)
+ p1 = bf16[32,8]{1,0} parameter(1)
+ c1 = f32[32,8]{1,0} convert(p1)
+ ROOT dot = f32[16,8]{1,0} dot(p0, c1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+ ENTRY e {
+ p0 = f32[16,32]{1,0} parameter(0)
+ p1 = bf16[32,8]{1,0} parameter(1)
+ ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom,
+ calls=cutlass_gemm_with_upcast,
+ backend_config={"fusion_backend_config":{kind: "__custom_fusion",
+ custom_fusion_config: {"name":"cutlass_gemm_with_upcast",
+ "kernel_index":0}}}
+ })";
+
+ EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion,
+ error_spec, /*run_hlo_passes=*/false));
+}
+
+TEST_F(CutlassFusionTest, GemmWithLeftHandAndRightHandSideUpcastKernel) {
+ ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+
+ const char* hlo_text_cublas = R"(
+ HloModule cublas
+
+ ENTRY e {
+ p0 = bf16[16,32]{1,0} parameter(0)
+ c0 = f32[16,32]{1,0} convert(p0)
p1 = s8[32,8]{1,0} parameter(1)
- c1 = bf16[32,8]{1,0} convert(p1)
- ROOT dot = bf16[16,8]{1,0} dot(p0, c1),
+ c1 = f32[32,8]{1,0} convert(p1)
+ gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(c0, c1),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}}
+ ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0
+ })";
+
+ const char* hlo_text_custom_fusion = R"(
+ HloModule cutlass
+
+ cutlass_gemm_with_upcast {
+ p0 = bf16[16,32]{1,0} parameter(0)
+ c0 = f32[16,32]{1,0} convert(p0)
+ p1 = s8[32,8]{1,0} parameter(1)
+ c1 = f32[32,8]{1,0} convert(p1)
+ ROOT dot = f32[16,8]{1,0} dot(c0, c1),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = bf16[16,32]{1,0} parameter(0)
p1 = s8[32,8]{1,0} parameter(1)
- ROOT _ = bf16[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast,
+ ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast,
backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}}
})";
@@ -371,7 +497,7 @@
TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) {
if (GpuSharedMemorySize() <
- CutlassGemmKernelSharedMemorySize(BF16, 8, 8, 8)) {
+ CutlassGemmKernelSharedMemorySize(BF16, BF16, BF16, 8, 8, 8)) {
GTEST_SKIP_("The GPU does not have sufficient shared memory");
}
@@ -445,7 +571,7 @@
TEST_F(CutlassFusionTest,
RowMajorGemmWithDynamicUpdateSliceKernelWithoutBitcast) {
if (GpuSharedMemorySize() <
- CutlassGemmKernelSharedMemorySize(BF16, 8, 8, 8)) {
+ CutlassGemmKernelSharedMemorySize(BF16, BF16, BF16, 8, 8, 8)) {
GTEST_SKIP_("The GPU does not have sufficient shared memory");
}
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc
new file mode 100644
index 0000000..ec08008
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = cutlass::bfloat16_t;
+using ElementB = cutlass::bfloat16_t;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+} // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+ ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+ cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+ cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>,
+ cutlass::gemm::GemmShape<1, 1, 1>,
+ cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+ 2, // stages
+ 1, // A alignment
+ 1, // B alignment
+ cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xBf16ToF32<Arch::kDefault>,
+ GemmOperation);
+template struct Adaptor<Bf16xBf16ToF32<Arch::kDefault>>;
+template struct DeviceKernel<Bf16xBf16ToF32<Arch::kDefault>>;
+
+} // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc
new file mode 100644
index 0000000..e117b1a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = cutlass::bfloat16_t;
+using ElementB = float;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+} // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+ ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+ cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+ cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>,
+ cutlass::gemm::GemmShape<1, 1, 1>,
+ cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+ 2, // stages
+ 1, // A alignment
+ 1, // B alignment
+ cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xF32ToF32<Arch::kDefault>,
+ GemmOperation);
+template struct Adaptor<Bf16xF32ToF32<Arch::kDefault>>;
+template struct DeviceKernel<Bf16xF32ToF32<Arch::kDefault>>;
+
+} // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc
new file mode 100644
index 0000000..527d369
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc
@@ -0,0 +1,50 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = cutlass::bfloat16_t;
+using ElementB = cutlass::int8_t;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+} // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+ ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+ cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+ cutlass::gemm::GemmShape<64, 128, 8>, cutlass::gemm::GemmShape<32, 64, 8>,
+ cutlass::gemm::GemmShape<1, 1, 1>,
+ cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+ 2, // stages
+ 1, // A alignment
+ 1, // B alignment
+ cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xS8ToF32<Arch::kDefault>, GemmOperation);
+template struct Adaptor<Bf16xS8ToF32<Arch::kDefault>>;
+template struct DeviceKernel<Bf16xS8ToF32<Arch::kDefault>>;
+
+} // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc
new file mode 100644
index 0000000..6ec6963
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = float;
+using ElementB = cutlass::bfloat16_t;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+} // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+ ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+ cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+ cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>,
+ cutlass::gemm::GemmShape<1, 1, 1>,
+ cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+ 2, // stages
+ 1, // A alignment
+ 1, // B alignment
+ cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(F32xBf16ToF32<Arch::kDefault>,
+ GemmOperation);
+template struct Adaptor<F32xBf16ToF32<Arch::kDefault>>;
+template struct DeviceKernel<F32xBf16ToF32<Arch::kDefault>>;
+
+} // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc
index 5aff534..119d724 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc
@@ -51,14 +51,14 @@
extern "C" bool xla_cutlass_kernel_can_implement(int32_t m, int32_t n,
int32_t k) {
Adaptor<CutlassGemm> adaptor;
- Arguments arguments = {m, n, k};
+ Arguments arguments = {GemmMode::kGemm, /*batch_count=*/1, m, n, k};
return adaptor.CanImplement(arguments);
}
extern "C" int64_t xla_cutlass_kernel_workspace_size(int32_t m, int32_t n,
int32_t k) {
Adaptor<CutlassGemm> adaptor;
- Arguments arguments = {m, n, k};
+ Arguments arguments = {GemmMode::kGemm, /*batch_count=*/1, m, n, k};
return adaptor.WorkspaceSize(arguments);
}
@@ -67,7 +67,9 @@
void* out, void* workspace, int32_t* out_offset, int32_t device_sms,
int32_t sm_occupancy) {
Adaptor<CutlassGemm> adaptor;
- Arguments arguments = {m, n, k, lhs, rhs, out, workspace, {out_offset}};
+ Arguments arguments = {
+ GemmMode::kGemm, /*batch_count=*/1, m, n, k, lhs, rhs, out,
+ workspace, {out_offset}};
adaptor.Initialize(params, arguments, device_sms, sm_occupancy);
}
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
index 8951c27..8fc3db5 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -3,6 +3,10 @@
"if_rocm_is_configured",
)
load(
+ "@local_config_sycl//sycl:build_defs.bzl",
+ "if_sycl_is_configured",
+)
+load(
"@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
@@ -88,6 +92,8 @@
"@local_config_rocm//rocm:rocm_headers",
"@llvm-project//llvm:AMDGPUCodeGen",
"@llvm-project//llvm:AMDGPUAsmParser",
+ ]) + if_sycl_is_configured([
+ "@spirv_llvm_translator//:spirv_llvm_translator",
]),
)
@@ -106,3 +112,16 @@
"@local_tsl//tsl/platform:test",
],
)
+
+xla_cc_test(
+ name = "gpu_backend_lib_test",
+ size = "small",
+ srcs = ["gpu_backend_lib_test.cc"],
+ deps = [
+ ":llvm_gpu_backend",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:xla_internal_test_main",
+ "@llvm-project//llvm:Core",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 13e17bb..696e360 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -17,6 +17,7 @@
#include <algorithm>
#include <cstdint>
+#include <cstdlib>
#include <fstream>
#include <functional>
#include <ios>
@@ -105,6 +106,11 @@
#include "xla/stream_executor/cuda/cuda_asm_compiler.h"
#endif
+#if TENSORFLOW_USE_SYCL
+#include "LLVMSPIRVLib.h"
+#include "LLVMSPIRVOpts.h"
+#endif // TENSORFLOW_USE_SYCL
+
namespace xla {
namespace gpu {
namespace {
@@ -117,41 +123,6 @@
// Default inline threshold value to use in llvm.
const int kDefaultInlineThreshold = 1100;
-// Gets the GPU name as it's known to LLVM for a given compute
-// capability. If we see an unrecognized compute capability, we
-// return the highest one that is known and below the selected device.
-static std::string GetSmName(se::CudaComputeCapability compute_capability) {
- int compute_capability_version =
- compute_capability.major * 10 + compute_capability.minor;
- int sm_version = 30;
- // If the current compute capability isn't known, fallback to the
- // most recent version before it.
- int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62,
- 61, 60, 53, 52, 50, 37, 35, 32, 30};
- for (int v : supported_versions) {
- if (v <= compute_capability_version) {
- sm_version = v;
- break;
- }
- }
-
- // If the current CC isn't supported by LLVM and it is newer then
- // the max supported LLVM version, do not warn about it. The end
- // user can't do anything about this. E.g., PTX compiled for SM75 will
- // run on SM80 too.
- if (sm_version != compute_capability_version &&
- compute_capability_version < supported_versions[0]) {
- LOG(WARNING) << "Unknown compute capability "
- << compute_capability.ToString()
- << ". Defaulting to telling LLVM that we're compiling for sm_"
- << sm_version;
- }
- // If the target is sm_90, hard code it to sm_90a so that all instructions
- // can be used. We don't need the portability that sm_90 gives.
- std::string_view extension = sm_version == 90 ? "a" : "";
- return absl::StrCat("sm_", sm_version, extension);
-}
-
// NOLINTBEGIN: clang-diagnostic-unused-function
// Convenience function for producing a name of a temporary compilation product
// from the input filename.
@@ -378,7 +349,7 @@
#else
std::string feature_str;
#endif // GOOGLE_CUDA
- return GetTargetMachine(target_triple, GetSmName(compute_capability),
+ return GetTargetMachine(target_triple, nvptx::GetSmName(compute_capability),
debug_options, feature_str);
}
@@ -452,7 +423,9 @@
llvm::CGSCCAnalysisManager cgam;
llvm::ModuleAnalysisManager mam;
- fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); });
+ if (target_machine) {
+ fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); });
+ }
llvm::PipelineTuningOptions pto;
pto.SLPVectorization = true;
@@ -569,6 +542,40 @@
namespace nvptx {
+std::string GetSmName(se::CudaComputeCapability compute_capability) {
+ int compute_capability_version =
+ compute_capability.major * 10 + compute_capability.minor;
+ int sm_version = 30;
+ // If the current compute capability isn't known, fall back to the
+ // most recent version before it.
+ int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62,
+ 61, 60, 53, 52, 50, 37, 35, 32, 30};
+ for (int v : supported_versions) {
+ if (v <= compute_capability_version) {
+ sm_version = v;
+ break;
+ }
+ }
+
+ // If the current CC isn't supported by LLVM and it is newer than
+ // the max supported LLVM version, do not warn about it. The end
+ // user can't do anything about this. E.g., PTX compiled for SM75 will
+ // run on SM80 too.
+ if (sm_version != compute_capability_version &&
+ compute_capability_version < supported_versions[0]) {
+ LOG(WARNING) << "Unknown compute capability "
+ << compute_capability.ToString()
+ << ". Defaulting to telling LLVM that we're compiling for sm_"
+ << sm_version;
+ }
+ // On Hopper, default to sm_90a so that all instructions can be used. But
+ // only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
+ std::string_view extension =
+ (compute_capability.major == 9 && sm_version == 90) ? "a" : "";
+ return absl::StrCat("sm_", sm_version, extension);
+}
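As a quick illustration of the selection logic above, here is a minimal self-contained sketch (a hypothetical PickSmName helper, not part of this change) that mirrors the version table and the Hopper-only "a" suffix:

#include <cstdio>
#include <string>

// Sketch of the SM-name fallback implemented by nvptx::GetSmName above.
static std::string PickSmName(int major, int minor) {
  const int cc = major * 10 + minor;
  int sm = 30;
  for (int v : {90, 89, 87, 86, 80, 75, 72, 70, 62,
                61, 60, 53, 52, 50, 37, 35, 32, 30}) {
    if (v <= cc) { sm = v; break; }
  }
  // Only Hopper itself gets the "a" suffix; later architectures fall back
  // to the forward-compatible plain sm_90.
  const char* suffix = (major == 9 && sm == 90) ? "a" : "";
  return "sm_" + std::to_string(sm) + suffix;
}

int main() {
  std::printf("%s\n", PickSmName(8, 9).c_str());   // sm_89
  std::printf("%s\n", PickSmName(9, 0).c_str());   // sm_90a
  std::printf("%s\n", PickSmName(10, 0).c_str());  // sm_90
  return 0;
}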
+
std::string CantFindCudaMessage(absl::string_view msg,
absl::string_view xla_gpu_cuda_data_dir) {
return absl::StrCat(
@@ -855,7 +862,12 @@
ir_fs->flush();
}
// Locate lld.
- std::string lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+ std::string lld_path;
+ if (std::getenv("LLVM_PATH")) {
+ lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin");
+ } else {
+ lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+ }
auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path});
if (!lld_program) {
return xla::Internal("unable to find ld.lld in PATH: %s",
@@ -1132,5 +1144,95 @@
} // namespace amdgpu
+namespace {
+
+std::unique_ptr<llvm::TargetMachine> SPIRGetTargetMachine(
+ llvm::Triple target_triple, se::GpuComputeCapability gpu_version,
+ const DebugOptions& debug_options) {
+ return nullptr;
+}
+
+absl::Status SPIRTargetModuleLinker(
+ llvm::Module* module, se::GpuComputeCapability gpu_version,
+ const DebugOptions& debug_options,
+ const std::string& device_bitcode_dir_path) {
+ return absl::OkStatus();
+}
+
+absl::StatusOr<std::string> EmitModuleToSpir(
+ llvm::Module* module, se::GpuComputeCapability gpu_version,
+ const DebugOptions& debug_options) {
+#if TENSORFLOW_USE_SYCL
+ SPIRV::TranslatorOpts::ExtensionsStatusMap ExtensionsStatus;
+ SPIRV::TranslatorOpts opts(SPIRV::VersionNumber::MaximumVersion,
+ ExtensionsStatus);
+ opts.enableAllExtensions(); // enable all SPIR-V extensions first
+
+ std::ostringstream oss;
+ std::string err;
+ bool success = llvm::writeSpirv(module, opts, oss, err);
+ if (!success) {
+ return xla::Internal("Failed to convert LLVM module to SPIR-V: %s", err);
+ }
+ return oss.str();
+#else
+ return absl::UnimplementedError("SPIR-V emission requires a SYCL-enabled build (TENSORFLOW_USE_SYCL)");
+#endif
+}
+
+void SPIRBackendInit(const DebugOptions& debug_options) {
+ FeedLLVMWithFlags({
+ "-slp-vectorize-hor=false",
+ "-slp-min-reg-size=64",
+ "-slp-max-reg-size=64",
+ });
+
+ llvm_ir::InitializeLLVMCommandLineOptions(
+ debug_options.xla_backend_extra_options());
+
+ llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
+ InitializePasses(registry);
+}
+
+} // namespace
+
+namespace spir {
+
+absl::StatusOr<std::vector<uint8_t>> CompileToSpir(
+ llvm::Module* module, se::GpuComputeCapability gpu_version,
+ const DebugOptions& debug_options) {
+ std::string libdevice_dir_path;
+ static absl::once_flag backend_init_flag;
+ absl::call_once(backend_init_flag, SPIRBackendInit, debug_options);
+
+ std::string spir;
+ {
+ XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
+
+ // If the module has no functions or globals, there's nothing to compile.
+ if (module->empty() && module->global_empty()) {
+ VLOG(2) << "Module '" << module->getName().str()
+ << "' is empty. Skipping compilation.";
+ return std::vector<uint8_t>();
+ }
+
+ llvm::Triple default_target_triple("spir64-unknown-unknown");
+ std::unique_ptr<llvm::TargetMachine> target_machine =
+ SPIRGetTargetMachine(default_target_triple, gpu_version, debug_options);
+
+ TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
+ module, gpu_version, debug_options, libdevice_dir_path,
+ SPIRTargetModuleLinker, default_target_triple, target_machine.get(),
+ kDefaultInlineThreshold));
+
+ // Lower optimized LLVM module to SPIR.
+ TF_ASSIGN_OR_RETURN(spir,
+ EmitModuleToSpir(module, gpu_version, debug_options));
+ }
+ return std::vector<uint8_t>(spir.begin(), spir.end());
+}
+
+} // namespace spir
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
index 3ab5d6d..1814291 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
@@ -37,6 +37,11 @@
namespace nvptx {
+// Gets the GPU name as it's known to LLVM for a given compute
+// capability. If we see an unrecognized compute capability, we
+// return the highest known capability at or below the requested one.
+std::string GetSmName(se::CudaComputeCapability compute_capability);
+
std::string CantFindCudaMessage(absl::string_view msg,
absl::string_view xla_gpu_cuda_data_dir);
@@ -73,6 +78,13 @@
const std::string& module_config_cache_key);
} // namespace amdgpu
+namespace spir {
+// Compiles the argument module to SPIR and returns the resulting binary.
+absl::StatusOr<std::vector<uint8_t>> CompileToSpir(
+ llvm::Module* module, se::GpuComputeCapability gpu_version,
+ const DebugOptions& debug_options);
+} // namespace spir
+
} // namespace gpu
} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc
new file mode 100644
index 0000000..9e65f34
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc
@@ -0,0 +1,38 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+namespace se = ::stream_executor;
+
+TEST(UtilsTest, TestGetSmName) {
+ se::CudaComputeCapability cc_hopper(9, 0);
+ ASSERT_EQ(nvptx::GetSmName(cc_hopper), "sm_90a");
+ // Do not default to sm_90a after Hopper, because it is not forward
+ // compatible.
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
+ se::CudaComputeCapability cc_next(10, 0);
+ ASSERT_EQ(nvptx::GetSmName(cc_next), "sm_90");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc
index fe4982e..49270de 100644
--- a/third_party/xla/xla/service/gpu/matmul_utils.cc
+++ b/third_party/xla/xla/service/gpu/matmul_utils.cc
@@ -456,7 +456,11 @@
const HloInstruction* gemm) {
TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
gemm->backend_config<GpuBackendConfig>());
- const GemmBackendConfig& config = gpu_config.gemm_backend_config();
+ return For(gemm, gpu_config.gemm_backend_config());
+}
+
+/*static*/ absl::StatusOr<GemmConfig> GemmConfig::For(
+ const HloInstruction* gemm, const GemmBackendConfig& config) {
std::optional<int64_t> algorithm;
if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) {
algorithm = config.selected_algorithm();
diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h
index 22d7f17..5f128e4 100644
--- a/third_party/xla/xla/service/gpu/matmul_utils.h
+++ b/third_party/xla/xla/service/gpu/matmul_utils.h
@@ -108,6 +108,11 @@
static absl::StatusOr<GemmConfig> For(const HloInstruction* gemm);
+ // Gets the GemmConfig of the `gemm` instruction with an overridden
+ // GemmBackendConfig.
+ static absl::StatusOr<GemmConfig> For(const HloInstruction* gemm,
+ const GemmBackendConfig& config);
+
static absl::StatusOr<GemmConfig> For(
const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims,
absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape,
diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD
index e4b5048..6173fa5 100644
--- a/third_party/xla/xla/service/gpu/model/BUILD
+++ b/third_party/xla/xla/service/gpu/model/BUILD
@@ -475,7 +475,6 @@
"//xla/service:gather_simplifier",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:matmul_utils",
- "//xla/service/gpu/fusions:tiling_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
@@ -549,7 +548,6 @@
":indexing_test_utils",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:hlo_traversal",
- "//xla/service/gpu/fusions:tiling_util",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings:string_view",
diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc
index 11ebb82..6a18288 100644
--- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc
@@ -52,7 +52,7 @@
std::vector<bool> IsReadCoalescedPerOperand(const HloInstruction* root) {
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root);
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
auto fusion = dynamic_cast<KernelFusionInterface*>(emitter.get());
EXPECT_NE(fusion, nullptr);
@@ -71,7 +71,7 @@
bool IsReadCoalescedHeuristic(absl::string_view hlo_string) {
auto module = ParseAndReturnVerifiedModule(hlo_string).value();
HloInstruction* root = module->entry_computation()->root_instruction();
- auto analysis = AnalyzeFusion(*root, device_info_);
+ auto analysis = HloFusionAnalysis::Create(*root, device_info_);
return xla::gpu::IsReadCoalescedHeuristic(analysis.GetEmitterFusionKind(),
root->operand(0), root);
}
diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc
index ba033fb..f4369df 100644
--- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc
+++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc
@@ -33,7 +33,8 @@
}
}
- HloFusionAnalysis analysis = AnalyzeFusion(instruction, device_info_);
+ HloFusionAnalysis analysis =
+ HloFusionAnalysis::Create(instruction, device_info_);
absl::MutexLock lock(&mutex_);
// If some other thread created an entry for this key concurrently, return
@@ -59,7 +60,7 @@
}
HloFusionAnalysis analysis =
- AnalyzeProducerConsumerFusion(producer, consumer, device_info_);
+ HloFusionAnalysis::Create(producer, consumer, device_info_);
absl::MutexLock lock(&mutex_);
// If some other thread created an entry for this key concurrently, return
diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
index 49b914e..f3417c9 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -250,7 +250,7 @@
/*exec_time=*/absl::ZeroDuration()};
}
- auto fusion_analysis = AnalyzeFusion(*producer, *device_info_);
+ auto fusion_analysis = HloFusionAnalysis::Create(*producer, *device_info_);
bool is_coalesced = IsReadCoalescedHeuristic(
fusion_analysis.GetEmitterFusionKind(), producer);
@@ -261,7 +261,7 @@
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer(
const HloInstruction* producer, const HloInstruction* consumer) {
auto fusion_analysis =
- AnalyzeProducerConsumerFusion(*producer, *consumer, *device_info_);
+ HloFusionAnalysis::Create(*producer, *consumer, *device_info_);
bool is_coalesced = IsReadCoalescedHeuristic(
fusion_analysis.GetEmitterFusionKind(), producer, consumer);
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
index c2057e0..6bb4071 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
@@ -56,7 +56,7 @@
// TODO(jreiffers): Remove this once all callers use a cache.
std::optional<HloFusionAnalysis> local_analysis;
if (!config.fusion_analysis_cache) {
- local_analysis = AnalyzeFusion(*instr, device_info);
+ local_analysis = HloFusionAnalysis::Create(*instr, device_info);
}
const auto& fusion_analysis = config.fusion_analysis_cache
? config.fusion_analysis_cache->Get(*instr)
@@ -144,7 +144,7 @@
// TODO(jreiffers): Remove this once all callers use a cache.
std::optional<HloFusionAnalysis> local_analysis;
if (!config.fusion_analysis_cache) {
- local_analysis = AnalyzeFusion(*fused_consumer, device_info);
+ local_analysis = HloFusionAnalysis::Create(*fused_consumer, device_info);
}
const auto& analysis_unfused =
config.fusion_analysis_cache
@@ -193,7 +193,7 @@
std::optional<HloFusionAnalysis> local_analysis_fused;
if (!config.fusion_analysis_cache) {
local_analysis_fused =
- AnalyzeProducerConsumerFusion(*producer, *consumer, device_info);
+ HloFusionAnalysis::Create(*producer, *consumer, device_info);
}
const auto& fusion_analysis =
config.fusion_analysis_cache
@@ -296,8 +296,8 @@
std::optional<HloFusionAnalysis> local_analysis_fused;
if (!config.fusion_analysis_cache) {
- local_analysis_fused = AnalyzeProducerConsumerFusion(
- *producer, *fused_consumer, device_info);
+ local_analysis_fused =
+ HloFusionAnalysis::Create(*producer, *fused_consumer, device_info);
}
const auto& analysis_fused =
config.fusion_analysis_cache
@@ -345,8 +345,9 @@
const GpuPerformanceModelOptions& config,
absl::Span<const HloInstruction* const> fused_consumers,
bool multi_output) {
- EstimateRunTimeData producer_runtime = EstimateRunTimeForInstructionCached(
- producer, device_info, cost_analysis, config);
+ auto cache_result = config.gpu_performance_model_cache->Get(*producer);
+ CHECK(cache_result.has_value());
+ EstimateRunTimeData producer_runtime = *cache_result;
absl::Duration time_unfused =
kKernelLaunchOverhead * (fused_consumers.size() + 1) +
@@ -357,8 +358,10 @@
for (auto fused_consumer : fused_consumers) {
VLOG(8) << "Fused consumer: " << fused_consumer->name();
- EstimateRunTimeData consumer_runtime = EstimateRunTimeForInstructionCached(
- fused_consumer, device_info, cost_analysis, config);
+ auto cache_result =
+ config.gpu_performance_model_cache->Get(*fused_consumer);
+ CHECK(cache_result.has_value());
+ EstimateRunTimeData consumer_runtime = *cache_result;
time_unfused += consumer_runtime.exec_time;
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
index 56e34c3..08ae2e5 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
@@ -88,8 +88,6 @@
std::optional<EstimateRunTimeData> GpuPerformanceModelCache::Get(
const HloInstruction& instruction) {
- absl::MutexLock lock(&mutex_);
-
auto it = instruction_runtime_data_.find(&instruction);
if (it != instruction_runtime_data_.end()) {
return it->second;
@@ -113,8 +111,6 @@
void GpuPerformanceModelCache::Set(const HloInstruction& instruction,
const EstimateRunTimeData& runtime_data) {
- absl::MutexLock lock(&mutex_);
-
instruction_runtime_data_[&instruction] = runtime_data;
}
@@ -126,8 +122,6 @@
}
void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) {
- absl::MutexLock lock(&mutex_);
-
// Remove runtime data for the instruction.
instruction_runtime_data_.erase(&instruction);
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc
index a3eb170..b645800 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc
@@ -211,7 +211,7 @@
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
- auto fusion_analysis = AnalyzeFusion(
+ auto fusion_analysis = HloFusionAnalysis::Create(
*module->entry_computation()->root_instruction(), device_info_);
auto launch_dimensions =
GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis);
@@ -247,7 +247,7 @@
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
- auto fusion_analysis = AnalyzeFusion(
+ auto fusion_analysis = HloFusionAnalysis::Create(
*module->entry_computation()->root_instruction(), device_info_);
auto launch_dimensions =
GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis);
@@ -276,7 +276,7 @@
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
- auto fusion_analysis = AnalyzeFusion(
+ auto fusion_analysis = HloFusionAnalysis::Create(
*module->entry_computation()->root_instruction(), device_info_);
auto launch_dimensions =
GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis);
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
index 4c0c35e..7c88789 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
@@ -68,9 +68,19 @@
GpuPerformanceModel::RunTimes EstimateRunTimesForPriorityFusion(
const HloInstruction* producer,
std::vector<HloInstruction*> fused_consumers = {}) {
+ auto config = GpuPerformanceModelOptions::PriorityFusion(
+ &fusion_analysis_cache_, &gpu_performance_model_cache_);
+
+ auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction(
+ producer, device_info_, &analysis_, config);
+ gpu_performance_model_cache_.Set(*producer, runtime_data);
+ for (auto consumer : fused_consumers) {
+ auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction(
+ consumer, device_info_, &analysis_, config);
+ gpu_performance_model_cache_.Set(*consumer, runtime_data);
+ }
return GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
- producer, device_info_, &analysis_,
- GpuPerformanceModelOptions::PriorityFusion(), fused_consumers);
+ producer, device_info_, &analysis_, config, fused_consumers);
}
mlir::MLIRContext mlir_context_;
@@ -82,6 +92,7 @@
se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
HloFusionAnalysisCache fusion_analysis_cache_{device_info_};
GpuHloCostAnalysis analysis_{options_, device_info_};
+ GpuPerformanceModelCache gpu_performance_model_cache_;
GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{
&device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(),
@@ -674,16 +685,16 @@
}
fused_computation.0 {
- p0 = f32[4,28672,32] parameter(0)
- tanh = f32[4,28672,32] tanh(p0)
+ p0 = f32[4,256,32] parameter(0)
+ tanh = f32[4,256,32] tanh(p0)
c1 = f32[] constant(72)
- broadcast = f32[4,28672,32] broadcast(c1), dimensions={}
- ROOT mul = f32[4,28672,32] multiply(tanh, broadcast)
+ broadcast = f32[4,256,32] broadcast(c1), dimensions={}
+ ROOT mul = f32[4,256,32] multiply(tanh, broadcast)
}
ENTRY fusion {
- p0 = f32[4,28672,32] parameter(0)
- fusion = f32[4,28672,32] fusion(p0), kind=kLoop, calls=fused_computation.0
+ p0 = f32[4,256,32] parameter(0)
+ fusion = f32[4,256,32] fusion(p0), kind=kLoop, calls=fused_computation.0
c0 = f32[] constant(0)
ROOT reduce = f32[4,32] reduce(fusion, c0), to_apply=add, dimensions={1}
})";
diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc
index 8f81cb4..e8842b3 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc
+++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc
@@ -48,7 +48,6 @@
#include "xla/layout.h"
#include "xla/permutation_util.h"
#include "xla/service/gather_simplifier.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/model/affine_map_printer.h"
@@ -1142,13 +1141,6 @@
return out;
}
-AffineMap GetTilingAffineMap(llvm::ArrayRef<AffineExpr> exprs,
- int64_t num_symbols) {
- return AffineMap::get(
- /*dimCount=*/6, /*symbolCount=*/num_symbols, exprs,
- exprs[0].getContext());
-}
-
} // namespace
IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) {
@@ -1219,83 +1211,6 @@
shape.dimensions(), {});
}
-AffineMap GetBlockOffsetsForTiling(
- absl::Span<const int64_t> num_blocks,
- absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
- MLIRContext* mlir_context) {
- auto offsets =
- DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), num_blocks);
- for (auto&& [offset, tile_size] : llvm::zip(offsets, tile_sizes_per_block)) {
- offset = offset * tile_size;
- }
- return GetTilingAffineMap(offsets, rank);
-}
-
-AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
- MLIRContext* mlir_context) {
- return GetBlockOffsetsForTiling(tiling.GetBlockCounts(),
- tiling.GetBlockTileSize(),
- tiling.GetShape().size(), mlir_context);
-}
-
-AffineMap GetThreadOffsetsForTiling(
- absl::Span<const int64_t> num_threads,
- absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
- MLIRContext* mlir_context) {
- auto offsets =
- DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), num_threads);
- for (int dim = 0; dim < rank; ++dim) {
- if (tile_sizes_per_thread[dim] > 1) {
- offsets[dim] = offsets[dim] +
- getAffineSymbolExpr(dim, mlir_context) * num_threads[dim];
- }
- }
- return GetTilingAffineMap(offsets, rank);
-}
-
-AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
- MLIRContext* mlir_context) {
- return GetThreadOffsetsForTiling(tiling.GetThreadsPerBlock(),
- tiling.GetThreadTileSize(),
- tiling.GetShape().size(), mlir_context);
-}
-
-IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
- MLIRContext* mlir_context) {
- return GetIndexingMapForTiling(
- GetBlockOffsetsForTiling(tiling, mlir_context),
- GetThreadOffsetsForTiling(tiling, mlir_context),
- tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(),
- tiling.GetThreadTileSize(), tiling.GetShape());
-}
-
-IndexingMap GetIndexingMapForTiling(AffineMap block_offsets,
- AffineMap thread_offsets,
- int64_t threads_per_block,
- int64_t num_blocks,
- absl::Span<const int64_t> thread_tile_sizes,
- absl::Span<const int64_t> tiled_shape) {
- auto* mlir_context = block_offsets.getContext();
- llvm::SmallVector<AffineExpr, 4> offsets;
- offsets.reserve(block_offsets.getNumResults());
- for (auto [block, thread] :
- llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) {
- offsets.push_back(block + thread);
- }
- std::vector<DimVar> dimension_ranges{
- {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {},
- };
- auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(),
- block_offsets.getNumSymbols(), offsets,
- mlir_context);
- IndexingMap map{affine_map, dimension_ranges,
- RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}};
- for (int i = 0; i < tiled_shape.size(); ++i) {
- map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1});
- }
- return map;
-}
-
bool HloInstructionIndexing::Simplify() {
bool any_simplified = false;
for (auto& operand_indexing : indexing_maps) {
diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h
index 201b8f6..e475b5a 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h
+++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h
@@ -18,7 +18,6 @@
#define XLA_SERVICE_GPU_MODEL_INDEXING_ANALYSIS_H_
#include <cstdint>
-#include <functional>
#include <ostream>
#include <string>
#include <vector>
@@ -31,7 +30,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/model/affine_map_printer.h"
#include "xla/service/gpu/model/indexing_map.h"
@@ -145,35 +143,6 @@
IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(
const Shape& shape, mlir::MLIRContext* mlir_context);
-// Creates an indexing map from thread and block IDs to elements of the tiled
-// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2
-// are thread indices (currently only 0 is used), dimensions 3 to 5 are block
-// indices (currently only 3 is used).
-mlir::AffineMap GetBlockOffsetsForTiling(
- absl::Span<const int64_t> num_blocks,
- absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
- mlir::MLIRContext* mlir_context);
-mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
- mlir::MLIRContext* mlir_context);
-mlir::AffineMap GetThreadOffsetsForTiling(
- absl::Span<const int64_t> num_threads,
- absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
- mlir::MLIRContext* mlir_context);
-mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
- mlir::MLIRContext* mlir_context);
-
-// Convenience functions for the two functions above
-// (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up
-// the ranges of dimensions and symbols.
-IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
- mlir::MLIRContext* mlir_context);
-IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets,
- mlir::AffineMap thread_offsets,
- int64_t threads_per_block,
- int64_t num_blocks,
- absl::Span<const int64_t> thread_tile_sizes,
- absl::Span<const int64_t> tiled_shape);
-
// Returns the shape of the output of the instruction.
const Shape& GetOutputShape(const HloInstruction* instr, int64_t output_id);
diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc
index 30fd805..d30c963 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc
@@ -20,7 +20,6 @@
#include "absl/strings/string_view.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/model/indexing_test_utils.h"
#include "xla/tests/hlo_test_base.h"
@@ -2564,32 +2563,6 @@
ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap())));
}
-TEST_F(IndexingAnalysisTest, TilingIndexing) {
- Tiling tiling{/*shape=*/{1022, 256, 16},
- /*tile_sizes=*/{8, 1, 4},
- /*num_threads=*/{1, 4, 4}};
- auto indexing_map = GetIndexingMapForTiling(tiling, &mlir_context_);
- indexing_map.Simplify();
- EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"(
- (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
- (d3 floordiv 64) * 8 + s0,
- (d3 mod 64) * 4 + d0 floordiv 4,
- d0 mod 4 + s2 * 4
- )
- domain:
- d0 in [0, 15]
- d1 in [0, 0]
- d2 in [0, 0]
- d3 in [0, 8191]
- d4 in [0, 0]
- d5 in [0, 0]
- s0 in [0, 7]
- s1 in [0, 0]
- s2 in [0, 3]
- (d3 floordiv 64) * 8 + s0 in [0, 1021]
- )"));
-}
-
TEST_F(IndexingAnalysisTest, EpilogueIndexing) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule m
diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc
index f0d34e8..da21c34 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_map.cc
+++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc
@@ -1113,6 +1113,20 @@
return eval.getConstantResults();
}
+bool IndexingMap::IsSymbolConstrained(int64_t symbol_id) const {
+ for (const auto& [expr, _] : constraints_) {
+ bool result = false;
+ expr.walk([&](mlir::AffineExpr leaf) {
+ auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(leaf);
+ if (sym && sym.getPosition() == symbol_id) {
+ result = true;
+ }
+ });
+ if (result) return true;
+ }
+ return false;
+}
+
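For reference, a rough usage sketch of the new query, assuming the ParseAffineMap test helper and the AddConstraint/Interval API used elsewhere in this change (variable names are illustrative only):

// Sketch only: constrain an expression that mentions s1 but not s0.
mlir::MLIRContext ctx;
IndexingMap map = IndexingMap::FromTensorSizes(
    ParseAffineMap("()[s0, s1] -> (s1 * 2 + s0)", &ctx),
    /*dim_upper_bounds=*/{}, /*symbol_upper_bounds=*/{2, 9});
map.AddConstraint(mlir::getAffineSymbolExpr(1, &ctx) * 2, {0, 15});
EXPECT_FALSE(map.IsSymbolConstrained(/*symbol_id=*/0));  // s0 is unconstrained.
EXPECT_TRUE(map.IsSymbolConstrained(/*symbol_id=*/1));   // s1 appears in a constraint.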
RangeEvaluator::RangeEvaluator(const IndexingMap& indexing_map,
MLIRContext* mlir_context, bool use_constraints)
: mlir_context_(mlir_context),
diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h
index 478e0ec..2e6cc13 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_map.h
+++ b/third_party/xla/xla/service/gpu/model/indexing_map.h
@@ -369,6 +369,9 @@
llvm::ArrayRef<mlir::AffineExpr> dim_const_exprs,
llvm::ArrayRef<mlir::AffineExpr> symbol_const_exprs) const;
+ // Returns true if there is a constraint on the given symbol.
+ bool IsSymbolConstrained(int64_t symbol_id) const;
+
// Returns true if the domain is empty. If it returns false, that does not
// mean that the domain is not effectively empty.
// For example, if there are two constraints 0 <= d0 mod 7 <= 0 and
diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
index dc5c2f2..1ea107c 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
@@ -602,23 +602,8 @@
AffineExpr strided_indexing, absl::Span<Interval const> dimension_intervals,
absl::Span<Interval const> symbol_intervals) {
MLIRContext* ctx = strided_indexing.getContext();
- // Deal with the symbol case (capturing a whole untiled dimension).
- // TODO(b/330906085): concatenating across a reduction dimension needs to be
- // handled by this code.
- if (auto symbol = llvm::dyn_cast<AffineSymbolExpr>(strided_indexing)) {
- const Interval& symbol_interval = symbol_intervals[symbol.getPosition()];
- if (symbol_interval.lower != 0) {
- return std::nullopt;
- }
-
- return SizeAndStrideExpression(
- /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx),
- /*stride=*/getAffineConstantExpr(1, ctx));
- }
-
AffineMapPrinter printer;
- // TODO(b/328427138): support multivariate size expressions.
switch (strided_indexing.getKind()) {
case AffineExprKind::DimId:
return SizeAndStrideExpression(/*size=*/strided_indexing,
@@ -626,23 +611,15 @@
case mlir::AffineExprKind::Mul: {
const auto mul = llvm::cast<mlir::AffineBinaryOpExpr>(strided_indexing);
AffineExpr lhs = mul.getLHS();
- // The stride may not be fully collapsed if it is negative; in that case,
- // we need to extract the negative multiplier first.
- if (const auto rhs = llvm::dyn_cast<AffineConstantExpr>(mul.getRHS());
- rhs && rhs.getValue() == -1) {
- std::optional<SizeAndStrideExpression> maybe_size_and_stride =
- ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals);
- if (!maybe_size_and_stride.has_value()) {
- return std::nullopt;
- }
-
- return SizeAndStrideExpression(
- /*size=*/maybe_size_and_stride->size,
- /*stride=*/maybe_size_and_stride->stride * rhs);
+ std::optional<SizeAndStrideExpression> maybe_size_and_stride =
+ ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals);
+ if (!maybe_size_and_stride.has_value()) {
+ return std::nullopt;
}
- CHECK(lhs.getKind() == AffineExprKind::DimId);
- return SizeAndStrideExpression(/*size=*/lhs,
- /*stride=*/mul.getRHS());
+
+ return SizeAndStrideExpression(
+ /*size=*/maybe_size_and_stride->size,
+ /*stride=*/maybe_size_and_stride->stride * mul.getRHS());
}
case mlir::AffineExprKind::Mod: {
auto mod = llvm::cast<mlir::AffineBinaryOpExpr>(strided_indexing);
@@ -656,15 +633,18 @@
case mlir::AffineExprKind::Constant:
return SizeAndStrideExpression(/*size=*/getAffineConstantExpr(1, ctx),
/*stride=*/getAffineConstantExpr(0, ctx));
- case mlir::AffineExprKind::SymbolId:
- VLOG(1) << "Encountered complex size expression involving symbol "
- << printer.ToString(strided_indexing);
- // It's currently not checked separately, but RTVars shouldn't appear in
- // the strided indexing expressions.
- return std::nullopt;
+ case mlir::AffineExprKind::SymbolId: {
+ auto symbol = llvm::cast<AffineSymbolExpr>(strided_indexing);
+ const Interval& symbol_interval = symbol_intervals[symbol.getPosition()];
+ if (symbol_interval.lower != 0) {
+ return std::nullopt;
+ }
+
+ return SizeAndStrideExpression(
+ /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx),
+ /*stride=*/getAffineConstantExpr(1, ctx));
+ }
case mlir::AffineExprKind::Add: {
- // TODO(b/328427138): this should only be necessary in the multivariate
- // case, and will be implemented later.
std::optional<std::vector<SizeAndStrideExpression>>
maybe_sizes_and_strides =
ExtractSizesAndStridesFromMultivariateSummation(
diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
index 1db5537..92c851d 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
@@ -549,6 +549,61 @@
)")));
}
+TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReductionOfSplittedAxis) {
+ // A split reshape followed by a split reduction creates a sum of strided symbols.
+ auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"(
+ HloModule m
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+ }
+
+ computation {
+ p0 = f32[18] parameter(0)
+ bitcast = f32[9,2] bitcast(p0)
+ c0 = f32[] constant(0)
+ reduce_0 = f32[9] reduce(bitcast, c0), dimensions={1}, to_apply=add
+ ROOT reduce_1 = f32[] reduce(reduce_0, c0), dimensions={0}, to_apply=add
+ }
+
+ ENTRY e {
+ p0 = f32[18] parameter(0)
+ ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=computation
+ }
+ )"));
+
+ EXPECT_THAT(
+ SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()),
+ Optional(MatchSymbolicTileString(R"(
+ Symbolic tile with
+ offset_map: () -> (0)
+ size_map: () -> (18)
+ stride_map: () -> (1)
+ )")));
+}
+
+TEST_F(SymbolicTileTest, CanPropagateTileThroughSummationOfSymbols) {
+ // Such an indexing map is representative of a sequence of HLOs containing a
+ // bitcast followed by two sequential reductions of the split axis, i.e.
+ // something like
+ // p0 = f32[18] parameter(0)
+ // bitcast = f32[9,2] bitcast(p0)
+ // reduce_0 = f32[9] reduce(bitcast), dimensions={1}
+ // reduce_1 = f32[] reduce(reduce_0), dimensions={0}
+ IndexingMap indexing_map = IndexingMap::FromTensorSizes(
+ ParseAffineMap("()[s0, s1] -> (s1 * 2 + s0)", &mlir_context_), {},
+ {2, 9});
+
+ EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map),
+ Optional(MatchSymbolicTileString(R"(
+ Symbolic tile with
+ offset_map: () -> (0)
+ size_map: () -> (18)
+ stride_map: () -> (1)
+ )")));
+}
+
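To spell out why the expected tile is (offset 0, size 18, stride 1): with s0 in [0, 1] and s1 in [0, 8], the term s0 covers 2 values with stride 1 and the term s1 * 2 covers 9 values with stride 2; because the stride of the second term equals the span of the first, the sum s1 * 2 + s0 enumerates every integer in [0, 17] exactly once, i.e. 2 * 9 = 18 contiguous elements.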
TEST_F(SymbolicTileTest,
FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshape) {
// TODO(b/349487906): constraints should allow us to unblock this use case.
diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.cc b/third_party/xla/xla/service/gpu/move_copy_to_users.cc
deleted file mode 100644
index acc10db..0000000
--- a/third_party/xla/xla/service/gpu/move_copy_to_users.cc
+++ /dev/null
@@ -1,240 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/move_copy_to_users.h"
-
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor {
- // Turn copy->pad into pad->copy
- absl::Status HandlePad(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- HloInstruction* c = hlo->mutable_operand(1);
- if (operand->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied = operand->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * earlier_pad,
- MakePadHlo(copied, c, hlo->padding_config(), &hlo->metadata()));
- // MakePadHlo fails to propagate layout.
- *earlier_pad->mutable_shape()->mutable_layout() =
- copied->shape().layout();
- HloInstruction* later_copy = MakeCopyHlo(earlier_pad, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- return absl::OkStatus();
- }
-
- // Turn copy->slice into slice->copy, as slice is layout-preserving.
- absl::Status HandleSlice(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied = operand->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * earlier_slice,
- MakeSliceHlo(copied, hlo->slice_starts(), hlo->slice_limits(),
- hlo->slice_strides(), &hlo->metadata()));
- *earlier_slice->mutable_shape()->mutable_layout() =
- copied->shape().layout();
- HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- return absl::OkStatus();
- }
-
- // Turn copy->dynamic-slice into dynamic-slice->copy, as dynamic-slice is
- // layout-preserving.
- absl::Status HandleDynamicSlice(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied = operand->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * earlier_slice,
- MakeDynamicSliceHlo(
- copied,
- absl::Span<HloInstruction* const>(hlo->operands()).subspan(1),
- hlo->dynamic_slice_sizes(), &hlo->metadata()));
- *earlier_slice->mutable_shape()->mutable_layout() =
- copied->shape().layout();
- HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- return absl::OkStatus();
- }
-
- // Turn copy->reduce_window into reduce_window->copy, as reduce_window is
- // layout-preserving.
- absl::Status HandleReduceWindow(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied = operand->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * earlier_reduce_window,
- MakeReduceWindowHlo(copied, hlo->mutable_operand(1), hlo->window(),
- hlo->called_computations()[0], &hlo->metadata()));
- *earlier_reduce_window->mutable_shape()->mutable_layout() =
- copied->shape().layout();
- HloInstruction* later_copy =
- MakeCopyHlo(earlier_reduce_window, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleReduce(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- // Reductions can handle transposes, e.g. via column reduction.
- if (operand->opcode() == HloOpcode::kCopy && !hlo->shape().IsTuple()) {
- HloInstruction* new_reduce = hlo->AddInstruction(
- hlo->CloneWithNewOperands(hlo->shape(), {operand->mutable_operand(0),
- hlo->mutable_operand(1)}));
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_reduce));
- }
- return absl::OkStatus();
- }
-
- absl::Status HandleBitcastConvert(HloInstruction* hlo) override {
- return absl::OkStatus();
- }
-
- // Sink kCopy across elementwise unary.
- absl::Status HandleElementwiseUnary(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- if (hlo->opcode() == HloOpcode::kReducePrecision) {
- return absl::OkStatus();
- }
- if (operand->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied = operand->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * earlier_elementwise,
- MakeUnaryHlo(hlo->opcode(), copied, &hlo->metadata()));
- HloInstruction* later_copy =
- MakeCopyHlo(earlier_elementwise, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- return absl::OkStatus();
- }
-
- // Sink kCopy across reverse
- absl::Status HandleReverse(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied = operand->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * earlier_reverse,
- MakeReverseHlo(copied, hlo->dimensions(), &hlo->metadata()));
- HloInstruction* later_copy = MakeCopyHlo(earlier_reverse, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- return absl::OkStatus();
- }
-
- // Sink kCopy across convert.
- absl::Status HandleConvert(HloInstruction* hlo) override {
- HloInstruction* operand = hlo->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied = operand->mutable_operand(0);
- HloInstruction* earlier_convert = MakeConvertToHlo(
- copied, hlo->shape().element_type(), &hlo->metadata());
- HloInstruction* later_copy = MakeCopyHlo(earlier_convert, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- return absl::OkStatus();
- }
-
- // Sink kCopy across elementwise binary.
- absl::Status HandleElementwiseBinary(HloInstruction* hlo) override {
- HloInstruction* a = hlo->mutable_operand(0);
- HloInstruction* b = hlo->mutable_operand(1);
- if (a->opcode() == HloOpcode::kCopy && b->opcode() == HloOpcode::kCopy) {
- HloInstruction* copied_a = a->mutable_operand(0);
- HloInstruction* copied_b = b->mutable_operand(0);
- if (copied_a->shape() == copied_b->shape()) {
- HloInstruction* earlier_elementwise;
- if (hlo->opcode() == HloOpcode::kCompare) {
- TF_ASSIGN_OR_RETURN(
- earlier_elementwise,
- MakeCompareHlo(hlo->comparison_direction(), copied_a, copied_b,
- &hlo->metadata()));
- } else {
- TF_ASSIGN_OR_RETURN(earlier_elementwise,
- MakeBinaryHlo(hlo->opcode(), copied_a, copied_b,
- &hlo->metadata()));
- }
- HloInstruction* later_copy =
- MakeCopyHlo(earlier_elementwise, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
- }
- }
- return absl::OkStatus();
- }
-
- // Move copy across kConcat if it occurs on all operands.
- absl::Status HandleConcatenate(HloInstruction* hlo) override {
- const HloInstruction* first = hlo->operand(0);
- if (first->opcode() != HloOpcode::kCopy) {
- return absl::OkStatus();
- }
- const HloInstruction* inner_op = first->operand(0);
- const Layout& inner_op_layout = inner_op->shape().layout();
-
- std::vector<HloInstruction*> new_operands;
- new_operands.reserve(hlo->operand_count());
- for (HloInstruction* op : hlo->mutable_operands()) {
- if (op->opcode() != HloOpcode::kCopy ||
- op->operand(0)->shape().layout() != inner_op_layout) {
- VLOG(3) << "Mismatch between " << op->ToString()
- << " and expected op layout " << inner_op_layout.ToString();
- return absl::OkStatus();
- }
- new_operands.push_back(op->mutable_operand(0));
- }
-
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_concat,
- MakeConcatHlo(new_operands, hlo->concatenate_dimension()));
- *new_concat->mutable_shape()->mutable_layout() = inner_op_layout;
-
- HloInstruction* new_copy = MakeCopyHlo(new_concat, hlo->shape());
- TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_copy));
- return absl::OkStatus();
- }
-};
-
-} // end namespace
-
-absl::StatusOr<bool> MoveCopyToUsers::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- return MoveCopyToUsersVisitor{}.RunOnModule(module, execution_threads);
-}
-
-} // end namespace xla
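Note, for orientation only: the deleted MoveCopyToUsers pass above rewrites user(copy(x)) into copy(user(x)) whenever the user is layout-preserving, so copies sink toward consumers. The following self-contained toy sketches that rewrite on a single-operand expression graph; all names (Node, Kind, SinkCopy) are hypothetical and this is not XLA code.

// Toy illustration only: the copy-sinking idea behind the deleted pass,
// i.e. op(copy(x)) -> copy(op(x)) for layout-preserving ops. Not XLA APIs.
#include <iostream>
#include <memory>
#include <string>

enum class Kind { kParameter, kCopy, kSqrt, kPad };

struct Node {
  Kind kind;
  std::string name;
  std::shared_ptr<Node> operand;  // single-operand toy graph
};

std::shared_ptr<Node> Make(Kind k, std::string name,
                           std::shared_ptr<Node> operand = nullptr) {
  return std::make_shared<Node>(Node{k, std::move(name), std::move(operand)});
}

bool IsLayoutPreserving(Kind k) { return k == Kind::kSqrt || k == Kind::kPad; }

// If `root` is a layout-preserving op fed by a copy, hoist the op above the
// copy: op(copy(x)) -> copy(op(x)).
std::shared_ptr<Node> SinkCopy(std::shared_ptr<Node> root) {
  if (root->operand && root->operand->kind == Kind::kCopy &&
      IsLayoutPreserving(root->kind)) {
    auto copied = root->operand->operand;                        // x
    auto earlier_op = Make(root->kind, root->name, copied);      // op(x)
    return Make(Kind::kCopy, "copy." + root->name, earlier_op);  // copy(op(x))
  }
  return root;
}

std::string ToString(const std::shared_ptr<Node>& n) {
  if (!n->operand) return n->name;
  return n->name + "(" + ToString(n->operand) + ")";
}

int main() {
  auto x = Make(Kind::kParameter, "x");
  auto copy = Make(Kind::kCopy, "copy", x);
  auto sqrt = Make(Kind::kSqrt, "sqrt", copy);
  std::cout << ToString(sqrt) << "  ->  " << ToString(SinkCopy(sqrt)) << "\n";
  // Prints: sqrt(copy(x))  ->  copy.sqrt(sqrt(x))
  return 0;
}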
diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.h b/third_party/xla/xla/service/gpu/move_copy_to_users.h
deleted file mode 100644
index 4a7dfb4..0000000
--- a/third_party/xla/xla/service/gpu/move_copy_to_users.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_
-#define XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Sink kCopy operations as far down the graph as possible.
-class MoveCopyToUsers : public HloModulePass {
- public:
- absl::string_view name() const override { return "move_copy_to_users"; }
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // end namespace xla
-
-#endif // XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_
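For orientation only: before this removal, a module pass such as MoveCopyToUsers would typically have been run through an HLO pass pipeline. The sketch below assumes the usual xla::HloPassPipeline API (AddPass<T>() and the single-argument Run(HloModule*) overload) and is not part of this change.

// Hedged sketch (assumption: standard HloPassPipeline API). Error handling
// and execution-thread arguments elided.
#include "absl/status/statusor.h"
#include "xla/service/gpu/move_copy_to_users.h"
#include "xla/service/hlo_pass_pipeline.h"

namespace xla {

absl::StatusOr<bool> RunMoveCopyToUsers(HloModule* module) {
  HloPassPipeline pipeline("move-copy-to-users");
  pipeline.AddPass<MoveCopyToUsers>();
  // Returns true if the pass changed the module.
  return pipeline.Run(module);
}

}  // namespace xla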
diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc b/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc
deleted file mode 100644
index 10179c1..0000000
--- a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc
+++ /dev/null
@@ -1,274 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/move_copy_to_users.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/service/layout_assignment.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace {
-
-class MoveCopyToUsersTest : public HloTestBase {
- public:
- MoveCopyToUsersTest()
- : HloTestBase(/*verifier_layout_sensitive=*/true,
- /*allow_mixed_precision_in_hlo_verifier=*/true,
- LayoutAssignment::InstructionCanChangeLayout) {}
- void CheckMoveCopyToUsers(absl::string_view hlo,
- std::optional<absl::string_view> expected) {
- RunAndFilecheckHloRewrite(hlo, MoveCopyToUsers{}, expected);
- }
-};
-
-TEST_F(MoveCopyToUsersTest, Pad) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = s8[1,17,9,9]{3,1,2,0} parameter(0)
- copy = s8[1,17,9,9]{1,3,2,0} copy(input)
- constant = s8[] constant(0)
- ROOT pad = s8[1,32,9,9]{1,3,2,0} pad(copy, constant), padding=0_0x0_15x0_0x0_0
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[constant_0:%[^ ]+]] = s8[] constant(0)
-// CHECK: [[pad_1_1:%[^ ]+]] = s8[1,32,9,9]{3,1,2,0} pad([[input_2:%[^ ]+]], [[constant_0]]), padding=0_0x0_15x0_0x0_0
-// CHECK: ROOT [[copy_1_3:%[^ ]+]] = s8[1,32,9,9]{1,3,2,0} copy([[pad_1_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Unary) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,1,0} parameter(0)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- ROOT pad = f32[1,17,9,9]{1,3,2,0} sqrt(copy)
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} sqrt([[input_0]])
-// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Reverse) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,1,0} parameter(0)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- ROOT pad = f32[1,17,9,9]{1,3,2,0} reverse(copy), dimensions={1,2}
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} reverse([[input_0]]), dimensions={1,2}
-// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Convert) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,1,0} parameter(0)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- ROOT converted = f16[1,17,9,9]{1,3,2,0} convert(copy)
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[sqrt_1:%[^ ]+]] = f16[1,17,9,9]{3,2,1,0} convert([[input_0]])
-// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f16[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Slice) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,1,0} parameter(0)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- ROOT slice = f32[1,4,6,6]{1,3,2,0} slice(copy), slice={[0:1],[0:4],[0:6],[0:6]}
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[slice_0:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} slice([[input_1:%[^ ]+]]), slice={[0:1], [0:4], [0:6], [0:6]}
-// CHECK-NEXT: ROOT [[copy_1_2:%[^ ]+]] = f32[1,4,6,6]{1,3,2,0} copy([[slice_0]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, DynamicSlice) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,1,0} parameter(0)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- p0 = s32[] parameter(1)
- p1 = s32[] parameter(2)
- p2 = s32[] parameter(3)
- p3 = s32[] parameter(4)
- ROOT ds = f32[1,4,6,6]{1,3,2,0} dynamic-slice(copy, p0, p1, p2, p3), dynamic_slice_sizes={1,4,6,6}
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[ds:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} dynamic-slice({{.*}}), dynamic_slice_sizes={1,4,6,6}
-// CHECK-NEXT: ROOT {{.*}} = f32[1,4,6,6]{1,3,2,0} copy([[ds]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, ReduceWindow) {
- const char* hlo = R"(
-HloModule R2Window
-
-mul {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT mul = f32[] multiply(lhs, rhs)
-}
-
-ENTRY R2Window {
- operand = f32[256,384]{1,0} parameter(0)
- c = f32[256,384]{0,1} copy(operand)
- constant = f32[] constant(1)
- ROOT reduce-window = f32[256,384]{0,1} reduce-window(c, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[reduce_window_1_0:%[^ ]+]] = f32[256,384]{1,0} reduce-window([[operand_1:%[^ ]+]], [[constant_2:%[^ ]+]]), window={size=2x3 pad=0_1x1_1}, to_apply=[[mul_3:%[^ ]+]]
-// CHECK-NEXT: ROOT [[copy_4:%[^ ]+]] = f32[256,384]{0,1} copy([[reduce_window_1_0]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Reduce) {
- const char* hlo = R"(
-HloModule R2
-
-mul {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT mul = f32[] multiply(lhs, rhs)
-}
-
-ENTRY R2 {
- operand = f32[256,384,10]{2,1,0} parameter(0)
- c = f32[256,384,10]{0,1,2} copy(operand)
- constant = f32[] constant(1)
- ROOT reduce = f32[384,10]{0,1} reduce(c, constant), dimensions={0}, to_apply=mul
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[operand:%[^ ]+]] = f32[256,384,10]{2,1,0} parameter(0)
-// CHECK: ROOT [[reduce:%[^ ]+]] = f32[384,10]{0,1} reduce([[operand]], [[constant_2:%[^ ]+]]), dimensions={0}, to_apply=[[mul_3:%[^ ]+]]
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Binary) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,1,0} parameter(0)
- input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
- ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[input2_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(1)
-// CHECK: [[add_1_2:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} add([[input_0]], [[input2_1]])
-// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[add_1_2]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, BinaryDifferentLayoutNoChange) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,0,1} parameter(0)
- input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
- ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
-}
-)";
-
- CheckMoveCopyToUsers(hlo, std::nullopt);
-}
-
-TEST_F(MoveCopyToUsersTest, Concat) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,1,0} parameter(0)
- input2 = f32[5,17,9,9]{3,2,1,0} parameter(1)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- copy2 = f32[5,17,9,9]{1,3,2,0} copy(input2)
- ROOT add = f32[6,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
-}
-)";
-
- CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[input2_1:%[^ ]+]] = f32[5,17,9,9]{3,2,1,0} parameter(1)
-// CHECK: [[concat:%[^ ]+]] = f32[6,17,9,9]{3,2,1,0} concatenate([[input_0]], [[input2_1]])
-// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[6,17,9,9]{1,3,2,0} copy([[concat]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, ConcatDifferentLayoutNoChange) {
- const char* hlo = R"(
-HloModule module
-
-ENTRY main {
- input = f32[1,17,9,9]{3,2,0,1} parameter(0)
- input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
- copy = f32[1,17,9,9]{1,3,2,0} copy(input)
- copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
- ROOT add = f32[2,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
-}
-)";
-
- CheckMoveCopyToUsers(hlo, std::nullopt);
-}
-
-} // namespace
-} // namespace xla
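The expected strings in the deleted test above use LLVM FileCheck syntax: a "// CHECK:" line matches in order, "// CHECK-NEXT:" must match the immediately following line, [[name:regex]] captures a match that later [[name]] uses must equal, and {{regex}} is an inline regular expression. A minimal reference pattern, hypothetical and kept in the same raw-string style as the tests:

// Reference only: the FileCheck idioms used by the deleted tests, shown on a
// hypothetical two-line expected pattern (not taken from this change).
constexpr char kExpectedPattern[] = R"(
// CHECK: [[op:%[^ ]+]] = f32[16]{0} sqrt([[in:%[^ ]+]])
// CHECK-NEXT: ROOT {{.*}} = f32[16]{0} copy([[op]])
)";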
diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/multi_output_fusion.cc
deleted file mode 100644
index 6ac1217..0000000
--- a/third_party/xla/xla/service/gpu/multi_output_fusion.cc
+++ /dev/null
@@ -1,521 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/multi_output_fusion.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_dfs_reachability.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/model/gpu_performance_model.h"
-#include "xla/service/gpu/model/gpu_performance_model_base.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-bool IsProfitableOperand(HloInstruction* instr) {
- // Effective scalars are not a profitable shared operand. Skip them.
- return !ShapeUtil::IsEffectiveScalar(instr->shape());
-}
-
-// Finds and returns the unique `slice` op where `parent` is used in `instr`.
-// Returns `nullptr` if no such `slice` exists.
-const HloSliceInstruction* FindUniqueSlice(const HloInstruction* parent,
- const HloInstruction* instr) {
- if (const auto* slice = DynCast<HloSliceInstruction>(instr)) {
- return slice;
- } else if (const auto* fusion = DynCast<HloFusionInstruction>(instr)) {
- const HloSliceInstruction* result = nullptr;
- for (size_t i = 0; i < fusion->operand_count(); ++i) {
- if (fusion->operand(i) == parent) {
- // Parameter used more than once -> there's no unique slice.
- if (result) return nullptr;
-
- auto* called_param = fusion->fused_parameter(i);
- if (called_param->user_count() != 1) return nullptr;
-
- result = FindUniqueSlice(called_param, called_param->users()[0]);
- if (!result) return nullptr;
- }
- }
- return result;
- } else {
- return nullptr;
- }
-}
-
-FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1,
- const HloInstruction& instr2,
- const HloInstruction* parent) {
- if (parent->shape().IsTuple()) return {};
- // Allow MOF if the parameter is small, even if there's no overlap. 1024 bytes
- // were arbitrarily chosen as the threshold.
- if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) return {};
-
- const HloSliceInstruction* slice1 = FindUniqueSlice(parent, &instr1);
- const HloSliceInstruction* slice2 = FindUniqueSlice(parent, &instr2);
- if (!slice1 || !slice2) return {};
-
- // TODO(jreiffers): Check strides as well.
- auto& starts1 = slice1->slice_starts();
- auto& starts2 = slice2->slice_starts();
- auto& limits1 = slice1->slice_limits();
- auto& limits2 = slice2->slice_limits();
-
- for (int64_t dim = 0; dim < parent->shape().rank(); ++dim) {
- bool overlap = starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim];
- if (!overlap) {
- return "slices are non-overlapping";
- }
- }
- return {};
-}
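The check above treats two slices of the same parameter as overlapping only if their [start, limit) ranges intersect in every dimension; strides are ignored, per the TODO. A self-contained toy of that rule follows (hypothetical names, not XLA code):

// Two axis-aligned slices overlap iff their intervals overlap in every
// dimension; being disjoint in any one dimension is enough to rule it out.
#include <cstdint>
#include <vector>

bool SlicesOverlap(const std::vector<int64_t>& starts1,
                   const std::vector<int64_t>& limits1,
                   const std::vector<int64_t>& starts2,
                   const std::vector<int64_t>& limits2) {
  for (int dim = 0; dim < static_cast<int>(starts1.size()); ++dim) {
    if (!(starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim])) {
      return false;  // disjoint in this dimension
    }
  }
  return true;
}

// Example: slice [0,4)x[0,8) vs. slice [4,8)x[0,8) are disjoint in dim 0, so
// SlicesOverlap(...) == false and the sibling fusion above would be rejected.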
-
-FusionDecision LegalToFuse(const HloInstruction& instr1,
- const HloInstruction& instr2,
- const se::DeviceDescription& device_info,
- FusionInfoCache* fusion_info_cache) {
- CHECK(instr1.opcode() == HloOpcode::kFusion);
-
- // The emitter only supports in-place DUS for fusions with a single DUS at the
- // root. Don't sibling fuse DUS for now.
- // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
- // share the input and output buffers and add support to the emitter.
- if (instr1.fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice ||
- (instr2.opcode() == HloOpcode::kFusion &&
- instr2.fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice)) {
- return "can't fuse multiple DUSs";
- }
-
- // Do this check last, as it may be expensive.
- return FusionFitsInBudget(instr1, instr2, device_info,
- /*is_consumer_producer_fusion=*/false,
- fusion_info_cache);
-}
-
-// We prefer multi-output fusions over other fusions over unfused ops, because
-// we want to preserve fusion opportunities if possible.
-int FusionPriority(const HloInstruction* instr) {
- if (instr->IsMultiOutputFusion()) {
- return 2;
- }
- if (instr->opcode() == HloOpcode::kFusion) {
- return 1;
- }
- return 0;
-}
-
-HloInstruction* SelectPreferredFusionCandidate(
- const std::vector<HloInstruction*> candidates) {
- if (candidates.empty()) {
- return nullptr;
- }
- return *std::max_element(
- candidates.begin(), candidates.end(),
- [](const HloInstruction* a, const HloInstruction* b) {
- return FusionPriority(a) < FusionPriority(b);
- });
-}
-
-// Do not fuse a producer if the other operands of the fusion are
-// reachable from the producer, as this would create a cycle.
-FusionDecision OperandReachableFromProducer(
- const HloInstruction& producer, const HloInstruction& consumer,
- const HloDfsReachability& reachability) {
- for (const auto* operand : consumer.operands()) {
- // If a get-tuple-element instruction is not in the reachability
- // map, it has been created by fusion in this pass. Simply move
- // on to its operand, which is in the reachability map.
- if (!reachability.IsPresent(operand) &&
- operand->opcode() == HloOpcode::kGetTupleElement) {
- operand = operand->operand(0);
- }
- CHECK(reachability.IsPresent(operand) && reachability.IsPresent(&producer))
- << "Reachability map is incomplete. This should never "
- "happen.";
- if (&producer != operand && reachability.IsReachable(&producer, operand)) {
- return {
- absl::StrCat(producer.name(), " would introduce a cycle when fused")};
- }
- }
- return {};
-}
-
-FusionDecision ProducerCandidateIsFusible(
- const HloInstruction& producer, const HloInstruction& consumer,
- const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache,
- const se::DeviceDescription& device_info,
- GpuHloCostAnalysis* cost_analysis) {
- if (!IsFusibleAsMultiOutputFusionRoot(consumer)) {
- return "consumer not eligible as multi-output fusion root.";
- }
-
- RETURN_IF_NOT_FUSIBLE(
- ShapesCompatibleForMultiOutputFusion(consumer, producer));
-
- RETURN_IF_NOT_FUSIBLE(
- OperandReachableFromProducer(producer, consumer, reachability));
-
- RETURN_IF_NOT_FUSIBLE(FusionFitsInBudget(
- producer, consumer, device_info,
- /*is_consumer_producer_fusion=*/false, fusion_info_cache));
-
- if (cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer)) {
- return "will generate too large IR";
- }
-
- GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
- &producer, device_info, cost_analysis,
- GpuPerformanceModelOptions::Default(),
- /*fused_consumers=*/{&consumer},
- /*multi_output=*/true);
- if (t.time_fused > t.time_unfused) {
- return "will execute slower if fused";
- }
-
- return {};
-}
-
-std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
- const HloInstruction* producer, const HloDfsReachability& reachability,
- FusionInfoCache* fusion_info_cache,
- const se::DeviceDescription& device_info,
- GpuHloCostAnalysis* cost_analysis) {
- std::vector<HloInstruction*> fusion_candidates;
- const HloComputation* computation = producer->parent();
- const HloModule* module = computation->parent();
- bool dump_fusion =
- module->config().debug_options().xla_dump_fusion_visualization();
-
- // If the producer is not a valid candidate for MOF, no need to check any of
- // its users.
- if (!IsProducerMultiOutputFusible(*producer)) {
- return fusion_candidates;
- }
-
- // If there is only one user, and it is not a multi-output fusion node, this
- // fusion possibility was already considered and rejected by the FusionMerger
- // pass. No need to try again!
- if (producer->user_count() == 1 &&
- !producer->users()[0]->IsMultiOutputFusion()) {
- return fusion_candidates;
- }
-
- for (HloInstruction* consumer : producer->users()) {
- VLOG(3) << "Looking at producer " << producer->name()
- << " and its consumer " << consumer->name();
-
- if (auto decision = ProducerCandidateIsFusible(
- *producer, *consumer, reachability, fusion_info_cache, device_info,
- cost_analysis)) {
- fusion_candidates.push_back(consumer);
- } else if (dump_fusion) {
- RegisterFusionState(
- *computation,
- absl::StrCat("Not considering fusion of producer |", producer->name(),
- "| into consumer |", consumer->name(),
- "| due to: ", decision.Explain()),
- *consumer, producer);
- }
- }
- return fusion_candidates;
-}
-
-bool IsSiblingFusionCandidate(const HloInstruction* instr) {
- if (instr->users().empty() || !IsFusibleAsMultiOutputFusionRoot(*instr) ||
- IsNestableVariadicReduction(*instr)) {
- return false;
- }
- // If the instruction is already a multi-output fusion, require all of its
- // users to be get-tuple-element ops; otherwise bail out, because the
- // transformation below assumes that they are.
- return (!instr->IsMultiOutputFusion() ||
- absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
- return user->opcode() == HloOpcode::kGetTupleElement;
- }));
-}
-
-FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1,
- const HloInstruction& sibling_consumer_2,
- const HloInstruction& common_producer,
- const HloDfsReachability& reachability,
- FusionInfoCache* fusion_info_cache,
- const se::DeviceDescription& device_info) {
- if (reachability.IsConnected(&sibling_consumer_1, &sibling_consumer_2)) {
- return {absl::StrCat(sibling_consumer_1.name(), " and ",
- sibling_consumer_2.name(), " are connected")};
- }
-
- RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion(
- sibling_consumer_1, sibling_consumer_2));
-
- // Technically, this check is order-dependent (e.g. siblings A, B, C where
- // {A, B} and {B, C} overlap, but {A, C} do not). If the priority order is
- // [C, A, B], only {C, B} will be fused, and A will only be fused in the
- // next iteration of the fusion pipeline, potentially requiring several
- // iterations to converge. We assume this case to be very rare in
- // practice.
- RETURN_IF_NOT_FUSIBLE(ParameterSlicesAreNonOverlapping(
- sibling_consumer_1, sibling_consumer_2, &common_producer));
-
- // This check should be last, as it may be expensive.
- RETURN_IF_NOT_FUSIBLE(LegalToFuse(sibling_consumer_1, sibling_consumer_2,
- device_info, fusion_info_cache));
- return {};
-}
-
-} // namespace
-
-void GpuMultiOutputFusion::RecomputeReachability() {
- reachability_ = HloDfsReachability::Build(computation_);
-}
-
-bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent,
- FusionInfoCache* fusion_info_cache,
- GpuHloCostAnalysis* cost_analysis) {
- const HloComputation* computation = parent->parent();
- const HloModule* module = computation->parent();
- bool dump_fusion =
- module->config().debug_options().xla_dump_fusion_visualization();
-
- if (!IsProfitableOperand(parent)) {
- VLOG(3) << "Operand " << parent->ToShortString() << " is not profitable";
- return false;
- }
- bool changed = false;
- std::vector<HloInstruction*> siblings;
- // Only consider siblings that are fusion candidates.
- absl::c_copy_if(parent->users(), std::back_inserter(siblings),
- IsSiblingFusionCandidate);
- // Sort the siblings such that multi-output fusion ops occur first, followed
- // by fusion ops, followed by unfused ops.
- absl::c_stable_sort(siblings,
- [](const HloInstruction* a, const HloInstruction* b) {
- return FusionPriority(a) > FusionPriority(b);
- });
-
- for (auto i = siblings.begin(); i != siblings.end(); ++i) {
- VLOG(3) << "Considering " << (*i)->name();
- if ((*i)->opcode() != HloOpcode::kFusion) {
- continue;
- }
- for (auto j = i + 1; j != siblings.end();) {
- VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
-
- if (auto fusible = CanFuseSiblings(**i, **j, *parent, *reachability_,
- fusion_info_cache, device_info_);
- !fusible) {
- // For the fusion-state dump we arbitrarily pick one sibling as the consumer.
- if (dump_fusion) {
- RegisterFusionState(
- *computation,
- absl::StrCat("Not fusing siblings |", (**i).name(), "| and |",
- (**j).name(), "| due to: ", fusible.Explain()),
- // Randomly pick one consumer.
- /*consumer=*/**i,
- /*producer=*/parent);
- }
- ++j;
- continue;
- }
- if (!ConsumeFuel(name(), [&] {
- return absl::StrFormat("Not fusing siblings %s and %s.",
- (*i)->name(), (*j)->name());
- })) {
- ++j;
- continue;
- }
- VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
- fusion_info_cache->Invalidate(*i);
- fusion_info_cache->Invalidate(*j);
- HloInstruction* remaining = *i;
- HloInstruction* fused = *j;
- TF_CHECK_OK(cost_analysis->RemoveInstruction(remaining));
- TF_CHECK_OK(cost_analysis->RemoveInstruction(fused));
-
- DumpFusionState(*remaining,
- absl::StrCat("About to fuse sibling |", fused->name(),
- "| into sibling |", remaining->name(),
- "| inside multi-output fusion"),
- /*producer=*/fused);
-
- if (fused->opcode() == HloOpcode::kFusion) {
- remaining->MergeFusionInstructionIntoMultiOutput(fused);
- if (fused->IsInputFusion()) {
- remaining->set_fusion_kind(HloInstruction::FusionKind::kInput);
- }
- } else {
- remaining->FuseInstructionIntoMultiOutput(fused);
- CHECK_EQ(0, fused->user_count());
- TF_CHECK_OK(computation_->RemoveInstruction(fused));
- }
- DumpFusionState(*remaining,
- absl::StrCat("Fused into |", remaining->name(),
- "| inside multi-output fusion"));
- TF_CHECK_OK(cost_analysis->RevisitInstruction(remaining));
- changed = true;
- siblings.erase(j);
- RecomputeReachability();
- }
- }
- return changed;
-}
-
-absl::StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
- bool changed = false;
- RecomputeReachability();
- GpuHloCostAnalysis cost_analysis({shape_size_function_,
- /*per_second_rates=*/{},
- /*count_multiple_input_accesses=*/true},
- device_info_);
- TF_RETURN_IF_ERROR(computation_->Accept(&cost_analysis));
- std::vector<HloInstruction*> defs_before_uses =
- computation_->MakeInstructionPostOrder();
-
- FusionInfoCache fusion_info_cache;
- // Traverse the HLO in uses-before-defs order.
- for (auto it = defs_before_uses.rbegin(); it != defs_before_uses.rend();
- ++it) {
- auto* producer = *it;
- // Never multi-output fuse constants. To the extent that we want to fuse
- // constants, that should be handled by the regular fusion pass.
- if (producer->opcode() == HloOpcode::kConstant) {
- VLOG(3) << producer->name() << " is a constant.";
- continue;
- }
- if (producer->IsCustomFusion()) {
- continue;
- }
- // First, fuse the consumer ops of the current op, which are siblings.
- if (FuseSiblings(/*parent=*/producer, &fusion_info_cache, &cost_analysis)) {
- changed = true;
- }
- // Second, perform producer-consumer multi-output fusion. This order will
- // ensure that all get-tuple-element ops inserted as a by-product of
- // multi-output fusion will occur before the current op in the order of
- // traversal, and hence, not get in the way of subsequent fusion attempts.
- const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
- producer, *reachability_, &fusion_info_cache, device_info_,
- &cost_analysis);
- auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
- if (consumer_for_fusion == nullptr) {
- continue;
- }
- if (!ConsumeFuel(name(), [&] {
- return absl::StrFormat("Not fusing %s and %s.", producer->name(),
- consumer_for_fusion->name());
- })) {
- continue;
- }
- changed = true;
- fusion_info_cache.Invalidate(producer);
- fusion_info_cache.Invalidate(consumer_for_fusion);
- TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(producer));
- TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion));
-
- HloInstruction* input_fusion;
- if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
- input_fusion = consumer_for_fusion;
- VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
- << consumer_for_fusion->name();
- } else {
- input_fusion = computation_->AddInstruction(HloInstruction::CreateFusion(
- consumer_for_fusion->shape(),
- ChooseFusionKind(*producer, *consumer_for_fusion),
- consumer_for_fusion));
- VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
- << consumer_for_fusion->name() << " into "
- << input_fusion->name();
- TF_CHECK_OK(
- computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
- }
-
- DumpFusionState(*input_fusion,
- absl::StrCat("About to fuse producer |", producer->name(),
- "| into consumer |", input_fusion->name(),
- "| inside multi-output fusion"),
- /*producer=*/producer);
-
- if (producer->opcode() == HloOpcode::kFusion) {
- input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
- } else {
- input_fusion->FuseInstructionIntoMultiOutput(producer);
- CHECK_EQ(0, producer->user_count());
- TF_CHECK_OK(computation_->RemoveInstruction(producer));
- }
- TF_RETURN_IF_ERROR(cost_analysis.RevisitInstruction(input_fusion));
-
- DumpFusionState(*input_fusion,
- absl::StrCat("Fused into |", input_fusion->name(),
- "| inside multi-output fusion"));
- RecomputeReachability();
- }
- return changed;
-}
-
-void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer,
- absl::string_view label,
- const HloInstruction* producer) {
- if (consumer.GetModule()
- ->config()
- .debug_options()
- .xla_dump_fusion_visualization()) {
- RegisterFusionState(*computation_, label, consumer, producer);
- }
-}
-
-absl::StatusOr<bool> GpuMultiOutputFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (auto* computation : GetFusibleComputations(*module, execution_threads)) {
- computation_ = computation;
- TF_ASSIGN_OR_RETURN(bool computation_changed, DoMultiOutputFusion());
- changed |= computation_changed;
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.h b/third_party/xla/xla/service/gpu/multi_output_fusion.h
deleted file mode 100644
index 82789d3..0000000
--- a/third_party/xla/xla/service/gpu/multi_output_fusion.h
+++ /dev/null
@@ -1,134 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
-#define XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
-
-#include <memory>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_dfs_reachability.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Multi-output fusion of sibling and producer-consumer instructions for the
-// GPU backend to reduce memory bandwidth requirements.
-//
-// 0) Before multi- 1) Sibling multi- 2) Producer-consumer
-// output fusion output fusion multi-output fusion
-//
-// p p p
-// | | |
-// v v v
-// A A +-fusion--+
-// / \ | | A |
-// | | +-fusion--+ | / \ |
-// v v | / \ | | B | |
-// B C | B C | | | | |
-// \ / | | | | | v v |
-// v v | v v | | tuple |
-// ROOT | tuple | +---------+
-// +---------+ / \
-// / \ gte_b gte_a
-// gte_b gte_c | |
-// | | | v
-// \ / | C
-// v v \ /
-// ROOT v v
-// ROOT
-//
-// Multi-output fusion ops have a tuple op at their root containing multiple
-// elements as outputs. GetTupleElement ops (depicted as gte_* above) are
-// inserted to extract tuple elements for consumers.
-//
-// The two different flavors of multi-output fusion this pass performs are
-// depicted above.
-// 1) Fusion of sibling ops reduces memory bandwidth requirements, because
-// common input parameters have to be read only once.
-// 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by
-// saving one read from memory. In the example above, B does not need to read
-// the output of A from memory, while C still does (using gte_a).
-// Note that sibling (1) and producer-consumer (2) multi-output fusion can be
-// combined.
-//
-// The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (uses
-// before defs). First, it attempts to fuse the consumer ops of the current op,
-// which are siblings (1). Hereafter, it attempts to fuse the current op with
-// one of its consumers (2). This order avoids a phase ordering issue (described
-// in go/fusionfusion). It ensures that all GetTupleElement ops inserted as a
-// by-product of multi-output fusion will occur before the current op in the
-// order of traversal, and hence, not get in the way of subsequent fusion
-// attempts.
-//
-// The GpuMultiOutputFusion pass ensures several conditions are met for fusion.
-// Some of them are relevant for correctness. In particular, no cycles must be
-// introduced into the HLO module. Moreover, the code emitters for multi-output
-// fusion must support the combination of ops and their shapes. Other
-// restrictions are rather arbitrary and lifting them could be beneficial.
-// * Sibling fusion (1) requires at least one op to be a kFusion.
-// * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e.
-// the fusion kinds must match.
-
-class GpuMultiOutputFusion : public HloModulePass {
- public:
- explicit GpuMultiOutputFusion(
- const se::DeviceDescription& device_info,
- HloCostAnalysis::ShapeSizeFunction shape_size_function)
- : device_info_(device_info), shape_size_function_(shape_size_function) {}
-
- absl::string_view name() const override { return "multi_output_fusion"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- bool FuseSiblings(HloInstruction* parent, FusionInfoCache* fusion_info_cache,
- GpuHloCostAnalysis* cost_analysis);
-
- absl::StatusOr<bool> DoMultiOutputFusion();
-
- // Recompute reachability for the current computation.
- void RecomputeReachability();
-
- void DumpFusionState(const HloInstruction& consumer, absl::string_view label,
- const HloInstruction* producer = nullptr);
-
- // Computation for the pass.
- HloComputation* computation_;
-
- // The reachability map of the current computation.
- std::unique_ptr<HloDfsReachability> reachability_;
-
- se::DeviceDescription device_info_;
- HloCostAnalysis::ShapeSizeFunction shape_size_function_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
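The comment above describes the tuple-root plus get-tuple-element shape of a multi-output fusion. For orientation, here is a hypothetical HLO module in that shape, embedded the same way the tests below embed HLO; it is illustrative only and not taken from this change.

// Orientation only: a sibling multi-output fusion with a tuple root and
// get-tuple-element consumers (hypothetical module, not from this diff).
constexpr char kSiblingMofShape[] = R"(
HloModule example

add {
  x = f32[] parameter(0)
  y = f32[] parameter(1)
  ROOT s = f32[] add(x, y)
}

fused_computation {
  p0 = f32[128,512]{1,0} parameter(0)
  c0 = f32[] constant(0)
  mul = f32[128,512]{1,0} multiply(p0, p0)
  reduce.1 = f32[512]{0} reduce(mul, c0), dimensions={0}, to_apply=add
  reduce.2 = f32[512]{0} reduce(p0, c0), dimensions={0}, to_apply=add
  ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(reduce.1, reduce.2)
}

ENTRY entry {
  p0 = f32[128,512]{1,0} parameter(0)
  mof = (f32[512]{0}, f32[512]{0}) fusion(p0), kind=kInput, calls=fused_computation
  gte.0 = f32[512]{0} get-tuple-element(mof), index=0
  gte.1 = f32[512]{0} get-tuple-element(mof), index=1
  ROOT root = (f32[512]{0}, f32[512]{0}) tuple(gte.0, gte.1)
})";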
diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc
deleted file mode 100644
index b333a04..0000000
--- a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc
+++ /dev/null
@@ -1,2236 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/multi_output_fusion.h"
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-
-namespace m = ::xla::match;
-
-class MultiOutputFusionTest : public HloTestBase {
- HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
- return [&](const Shape& shape) {
- constexpr int64_t kPointerSize = 8;
- return ShapeUtil::ByteSizeOf(shape, kPointerSize);
- };
- }
-
- public:
- GpuMultiOutputFusion mof_{
- TestGpuDeviceInfo::RTXA6000DeviceInfo(),
- ShapeSizeBytesFunction()};
-
- void CheckGpuMultiOutputFusion(absl::string_view hlo,
- std::optional<absl::string_view> expected) {
- RunAndFilecheckHloRewrite(
- hlo,
- GpuMultiOutputFusion{
- TestGpuDeviceInfo::RTXA6000DeviceInfo(),
- ShapeSizeBytesFunction()},
- expected);
- }
-};
-
-const char kModulePrefix[] = R"(
- HloModule test_module
-
- scalar_add_computation {
- scalar_lhs.0 = f32[] parameter(0)
- scalar_rhs.0 = f32[] parameter(1)
- ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
- }
- scalar_mul_computation {
- scalar_lhs.1 = f32[] parameter(0)
- scalar_rhs.1 = f32[] parameter(1)
- ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
- })";
-
-static int64_t CountMultiOutputFusions(const HloModule* module) {
- int multi_output_fusion_count = 0;
- for (auto* computation : module->MakeNonfusionComputations()) {
- for (auto* instr : computation->instructions()) {
- if (instr->IsMultiOutputFusion()) {
- multi_output_fusion_count++;
- }
- }
- }
- return multi_output_fusion_count;
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
- // Fusion with reduce instruction root and a sibling reduce instruction
- // sharing the same input param.
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- const.1 = f32[] parameter(0)
- ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
- }
-
- ENTRY entry {
- p0 = f32[] parameter(0)
- p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- const.2 = f32[] constant(1)
- fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
- reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
- ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* fusion =
- module->entry_computation()->root_instruction()->operand(0)->operand(0);
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p1.1 = f32[6400]{0} parameter(1)
- mul = f32[6400]{0} multiply(p1.1, p1.1)
- const.1 = f32[] parameter(0)
- ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation
- }
-
- fused_computation_2 {
- p1.2 = f32[6400]{0} parameter(1)
- r1 = f32[64,100]{0,1} reshape(p1.2)
- const.2 = f32[] parameter(0)
- ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
- }
-
- ENTRY entry {
- p0 = f32[] parameter(0)
- p1 = f32[6400]{0} parameter(1)
- fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
- fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
- ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, ReduceMofDifferentTypes) {
- // Fusion with reduce instruction root and a sibling reduce instruction
- // sharing the same input param.
- const char* hlo = R"(
-HloModule module
-
-scalar_add_computation {
- scalar_lhs.1 = f32[] parameter(0)
- scalar_rhs.1 = f32[] parameter(1)
- ROOT add.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
-}
-
-scalar_add_computation_f16 {
- scalar_lhs.0 = f16[] parameter(0)
- scalar_rhs.0 = f16[] parameter(1)
- ROOT add.0 = f16[] add(scalar_lhs.0, scalar_rhs.0)
-}
-
-fused_computation {
- param_0.2 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- c.1 = f16[128,512,28,28]{3,2,1,0} convert(param_0.2)
- const.0 = f16[] constant(0)
- ROOT reduce.0 = f16[512]{0} reduce(c.1, const.0), dimensions={0,2,3}, to_apply=scalar_add_computation_f16
-}
-
-ENTRY entry {
- p0 = f32[] parameter(0)
- p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- const.2 = f32[] constant(0)
- reduce.1 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
- fusion = f16[512]{0} fusion(p1), kind=kInput, calls=fused_computation
- ROOT root = (f32[512]{0}, f16[512]{0}) tuple(reduce.1, fusion)
-})";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation
-// CHECK-NEXT: [[param_0_2_0:%[^ ]+]] = f32[128,512,28,28]{3,2,1,0} parameter(0)
-// CHECK-NEXT: [[c_1_1:%[^ ]+]] = f16[128,512,28,28]{3,2,1,0} convert([[param_0_2_0]])
-// CHECK-NEXT: [[const_0_2:%[^ ]+]] = f16[] constant(0)
-// CHECK-NEXT: [[reduce_0_3:%[^ ]+]] = f16[512]{0} reduce([[c_1_1]], [[const_0_2]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_f16_4:%[^ ]+]]
-// CHECK-NEXT: [[param_1_5:%[^ ]+]] = f32[] parameter(1)
-// CHECK-NEXT: [[reduce_2_6:%[^ ]+]] = f32[512]{0} reduce([[param_0_2_0]], [[param_1_5]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_7:%[^ ]+]]
-// CHECK-NEXT: ROOT [[tuple_8:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) tuple([[reduce_0_3]], [[reduce_2_6]])
-// CHECK: [[fusion_9:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) fusion([[p1_10:%[^ ]+]], [[const_2_11:%[^ ]+]]), kind=kInput, calls=[[fused_computation_12:%[^ ]+]]
-)");
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p1.1 = f32[10,10]{1,0} parameter(1)
- mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
- const.1 = f32[] parameter(0)
- ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation
- }
-
- fused_computation_2 {
- p1.2 = f32[10,10]{1,0} parameter(1)
- const.2 = f32[] parameter(0)
- ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
- }
-
- ENTRY entry {
- p0 = f32[] parameter(0)
- p1.3 = f32[10,10]{1,0} parameter(1)
- fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1
- p2 = f32[] parameter(2)
- fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2
- ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
- // Two sibling fusions with reduce instruction roots sharing the same input
- // param.
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
- const.1 = f32[] parameter(0)
- ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
- }
-
- fused_computation_2 {
- p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- const.2 = f32[] parameter(0)
- ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
- }
-
- ENTRY entry {
- p0 = f32[] parameter(0)
- p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
- fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
- fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
- ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* fusion =
- module->entry_computation()->root_instruction()->operand(0)->operand(0);
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionNoSiblingFusionForCommonScalar) {
- // Two sibling fusions with bitcast roots sharing the same scalar input param.
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- param_0.87 = bf16[32,4096,16384]{2,1,0} parameter(0)
- param_1.4620 = s32[] parameter(1)
- constant_3949 = s32[] constant(0)
- compare.1026 = pred[] compare(param_1.4620, constant_3949), direction=LT
- constant_5437 = s32[] constant(32)
- add.6859 = s32[] add(param_1.4620, constant_5437)
- select.1599 = s32[] select(compare.1026, add.6859, param_1.4620)
- dynamic-slice.59 = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0.87, select.1599, constant_3949, constant_3949), dynamic_slice_sizes={1,4096,16384}
- ROOT bitcast.41089 = bf16[4096,16384]{1,0} bitcast(dynamic-slice.59)
- }
-
- fused_computation_2 {
- param_0 = bf16[32,4096,16384]{2,1,0} parameter(0)
- param_1 = s32[] parameter(1)
- constant = s32[] constant(0)
- compare = pred[] compare(param_1, constant), direction=LT
- constant.32 = s32[] constant(32)
- add = s32[] add(param_1, constant.32)
- select = s32[] select(compare, add, param_1)
- dynamic-slice = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0, select, constant, constant), dynamic_slice_sizes={1,4096,16384}
- ROOT bitcast.41087 = bf16[4096,16384]{1,0} bitcast(dynamic-slice)
- }
-
- ENTRY entry {
- p0 = s32[] parameter(0)
- p1 = bf16[32,4096,16384]{2,1,0} parameter(1)
- p2 = bf16[32,4096,16384]{2,1,0} parameter(2)
- fusion.1 = bf16[4096,16384]{1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation_1
- fusion.2 = bf16[4096,16384]{1,0} fusion(p2, p0), kind=kLoop, calls=fused_computation_2
- ROOT root = (bf16[4096,16384]{1,0}, bf16[4096,16384]{1,0}) tuple(fusion.1, fusion.2)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest,
- MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
- // Multi-output fusion with two reduce instructions at the root, and a sibling
- // reduce instruction sharing the same input param.
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
- const.1 = f32[] constant(1)
- p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
- reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
- reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
- ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
- }
-
- ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
- p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
- const = f32[] constant(1)
- fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
- get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
- get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
- reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
- ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* fusion =
- module->entry_computation()->root_instruction()->operand(0)->operand(0);
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest,
- MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
- // Verify that if we already have a multi-output fusion, we prefer to pick
- // a reduce op from its operands for checking shape compatibility.
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p1.1 = f32[10,10]{1,0} parameter(1)
- mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
- const.1 = f32[] parameter(0)
- reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation
- ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1)
- }
-
- fused_computation_2 {
- p1.2 = f32[10,10]{1,0} parameter(1)
- const.2 = f32[] parameter(0)
- ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
- }
-
- ENTRY entry {
- p0 = f32[] parameter(0)
- p1 = f32[10,10]{1,0} parameter(1)
- p2 = f32[] parameter(2)
- fusion.1 = (f32[10,10], f32[]) fusion(p0, p1), kind=kInput, calls=fused_computation_1
- get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[]) fusion.1), index=0
- get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[]) fusion.1), index=1
- fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2
- ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, LoopVariadicReductionFusions) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation.94 {
- tmp_0 = f32[] parameter(0)
- tmp_1 = f32[] parameter(1)
- tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
- tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
- tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
- tmp_5 = s32[] parameter(2)
- tmp_6 = s32[] parameter(3)
- tmp_7 = s32[] minimum(tmp_5, tmp_6)
- tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
- tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
- ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
- }
-
- minmax_func.1536 {
- tmp_0 = f32[] parameter(0)
- tmp_1 = f32[] parameter(2)
- tmp_2 = s32[] parameter(1)
- tmp_3 = s32[] parameter(3)
- ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
- }
-
- fused_computation {
- tmp_0 = f32[554112,10]{1,0} parameter(0)
- tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
- tmp_2 = f32[] constant(-inf)
- tmp_3 = s32[] constant(0)
- ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
- }
-
- fused_computation2 {
- tmp_0 = f32[554112,10]{1,0} parameter(0)
- tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
- tmp_2 = f32[] constant(inf)
- tmp_3 = s32[] constant(1)
- ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
- }
-
- ENTRY e {
- tmp_0 = f32[554112,10]{1,0} parameter(0)
- tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
- tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
- tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation2
- tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
- ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
- })"))
- .value();
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, InputVariadicReductionFusions) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation.1117 {
- param_0.2433 = f32[] parameter(0)
- param_1.2571 = f32[] parameter(1)
- compare.1770 = pred[] compare(param_0.2433, param_1.2571), direction=LE
- select.682 = f32[] select(compare.1770, param_0.2433, param_1.2571)
- compare.1303.clone.1 = pred[] compare(param_0.2433, param_1.2571), direction=EQ
- param_2.6460 = s32[] parameter(2)
- param_3.6755 = s32[] parameter(3)
- minimum.633.clone.1 = s32[] minimum(param_2.6460, param_3.6755)
- select.398.clone.1 = s32[] select(compare.1770, param_2.6460, param_3.6755)
- select.397.clone.1 = s32[] select(compare.1303.clone.1, minimum.633.clone.1, select.398.clone.1)
- ROOT tuple.151 = (f32[], s32[]) tuple(select.682, select.397.clone.1)
- }
-
- minmax_func.223 {
- lhs_value.224 = f32[] parameter(0)
- rhs_value.226 = f32[] parameter(2)
- lhs_index.225 = s32[] parameter(1)
- rhs_index.227 = s32[] parameter(3)
- ROOT fusion.1117 = (f32[], s32[]) fusion(lhs_value.224, rhs_value.226, lhs_index.225, rhs_index.227), kind=kLoop, calls=fused_computation.1117
- }
-
- fused_computation.73 {
- bitcast.86661 = f32[3,1024,300]{2,1,0} parameter(0)
- iota.734 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
- bitcast.97555 = s32[3,1024,300]{2,1,0} bitcast(iota.734)
- constant_3917 = f32[] constant(inf)
- constant_3918 = s32[] constant(0)
- ROOT reduce.1069 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86661, bitcast.97555, constant_3917, constant_3918), dimensions={2}, to_apply=minmax_func.223
- }
-
- fused_computation.84 {
- bitcast.86676 = f32[3,1024,300]{2,1,0} parameter(0)
- iota.732 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
- bitcast.97553 = s32[3,1024,300]{2,1,0} bitcast(iota.732)
- constant_3915 = f32[] constant(inf)
- constant_3916 = s32[] constant(0)
- ROOT reduce.1070 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86676, bitcast.97553, constant_3915, constant_3916), dimensions={2}, to_apply=minmax_func.223
- }
-
- ENTRY e {
- p0 = f32[3,1024,300]{2,1,0} parameter(0)
- fusion.84 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.84
- gte.391 = s32[3,1024]{1,0} get-tuple-element(fusion.84), index=1
- fusion.73 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.73
- gte.393 = s32[3,1024]{1,0} get-tuple-element(fusion.73), index=1
- ROOT r = s32[3,1024]{1,0} add(gte.391, gte.393)
- })"))
- .value();
- EXPECT_TRUE(mof_.Run(module.get()).value());
- EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->user_count(),
- 1);
- const HloInstruction* fusion =
- module->entry_computation()->parameter_instruction(0)->users()[0];
- EXPECT_THAT(fusion, GmockMatch(m::Fusion()));
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p0.1 = f32[6400]{0} parameter(0)
- ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
- }
-
- fused_computation_2 {
- p0.2 = f32[6400]{0} parameter(0)
- const.2 = f32[] constant(1)
- broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
- ROOT div = f32[6400]{0} divide(p0.2, broadcast)
- }
-
- ENTRY entry {
- p0 = f32[6400]{0} parameter(0)
- fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
- fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
- ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* fusion =
- module->entry_computation()->root_instruction()->operand(0)->operand(0);
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p0.1 = f32[6400]{0} parameter(0)
- ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
- }
-
- ENTRY entry {
- p0 = f32[6400]{0} parameter(0)
- fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
- const.2 = f32[] constant(1)
- broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
- div = f32[6400]{0} divide(p0, broadcast)
- ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* fusion =
- module->entry_computation()->root_instruction()->operand(0)->operand(0);
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
- ROOT mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
- }
-
- fused_computation_2 {
- p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
- const.2 = f32[] constant(0)
- ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), dimensions={0,3}, to_apply=scalar_add_computation
- }
-
- ENTRY entry {
- p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
- fusion.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
- fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
- ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0}) tuple(fusion.1, fusion.2)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
- mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
- exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
- ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
- f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
- }
-
- fused_computation_2 {
- p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
- const.2 = f32[] constant(0)
- broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2),
- dimensions={}
- ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
- }
-
- ENTRY entry {
- p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
- fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
- f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
- calls=fused_computation_1
- fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop,
- calls=fused_computation_2
- gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
- gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
- ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
- f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0})
- tuple(gte0, gte1, fusion.2)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* fusion =
- module->entry_computation()->root_instruction()->operand(0)->operand(0);
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add())));
-}
-
-TEST_F(MultiOutputFusionTest,
- MultiOutputFusionSiblingMultiOutputLoopAndMultiOutputLoop) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p0.1 = f32[8,16]{1,0} parameter(0)
- mul = f32[8,16]{1,0} multiply(p0.1, p0.1)
- exp = f32[8,16]{1,0} exponential(p0.1)
- ROOT tuple = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(mul, exp)
- }
-
- fused_computation_2 {
- p0.2 = f32[8,16]{1,0} parameter(0)
- const.2 = f32[] constant(0)
- broadcast = f32[8,16]{1,0} broadcast(const.2),
- dimensions={}
- add = f32[8,16]{1,0} add(p0.2, broadcast)
- ROOT tuple.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(add, broadcast)
- }
-
- ENTRY entry {
- p0 = f32[8,16]{1,0} parameter(0)
- fusion.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
- calls=fused_computation_1
- fusion.2 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
- calls=fused_computation_2
- gte0 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=0
- gte1 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=1
- gte2 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=0
- gte3 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=1
- ROOT root = (f32[8,16]{1,0}, f32[8,16]{1,0}, f32[8,16]{1,0},
- f32[8,16]{1,0})
- tuple(gte0, gte1, gte2, gte3)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* fusion =
- module->entry_computation()->root_instruction()->operand(0)->operand(0);
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(
- fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add(), m::Broadcast())));
-}
-
-TEST_F(MultiOutputFusionTest,
- MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation_1 {
- p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
- mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
- exp = f32[8,1,5,16,1,2]{5,4,3,2,1,0} exponential(p0.1)
- ROOT tuple = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
- f32[8,1,5,16,1,2]{5,4,3,2,1,0}) tuple(mul, exp)
- }
-
- fused_computation_2 {
- p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
- const.2 = f32[] constant(0)
- ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2),
- dimensions={0,3}, to_apply=scalar_add_computation
- }
-
- ENTRY entry {
- p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
- fusion.1 = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
- f32[8,1,5,16,1,2]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
- calls=fused_computation_1
- fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop,
- calls=fused_computation_2
- gte0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
- gte1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
- ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
- f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0})
- tuple(gte0, gte1, fusion.2)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, SiblingFusionBitcastAndLoopFusionNotFused) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-
-fused_computation_1 {
- p0.1 = f32[2048,16000]{1,0} parameter(0)
- bitcast = f32[2048,1,16000]{2,1,0} bitcast(p0.1)
- ROOT exp = f32[2048,1,16000]{2,1,0} exponential(bitcast)
-}
-
-ENTRY main {
- param_0 = f32[2048,16000]{1,0} parameter(0)
- fusion = f32[2048,1,16000]{2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation_1
- bitcast = f32[16000,1,2048]{2,1,0} bitcast(param_0)
- ROOT tuple.143 = (f32[16000,1,2048]{2,1,0}, f32[2048,1,16000]{2,1,0}) tuple(bitcast, fusion)
-})")
- .value();
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest,
- ProducerConsumerFusionBitcastAndElementwiseNotFused) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-
-ENTRY main {
- param_0 = f32[2048,16000]{1,0} parameter(0)
- convert = bf16[2048,16000]{1,0} convert(param_0)
- bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
- ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
-})")
- .value();
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- ENTRY reduce {
- p0 = f32[32,32,32]{2,1,0} parameter(0)
- c0 = f32[] constant(0)
- exp = f32[32,32,32]{2,1,0} exponential(p0)
- reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
- to_apply=scalar_add_computation
- ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement())));
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Exp())));
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_add {
- p0.1 = f32[32,32,32]{2,1,0} parameter(0)
- p1.1 = f32[32,32,32]{2,1,0} parameter(1)
- ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
- }
-
- ENTRY reduce {
- p0 = f32[32,32,32]{2,1,0} parameter(0)
- p1 = f32[32,32,32]{2,1,0} parameter(1)
- c0 = f32[] constant(0)
- add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
- reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
- to_apply=scalar_add_computation
- ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement())));
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Add())));
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_select {
- p1.1 = f32[32,32,32]{2,1,0} parameter(1)
- c0 = f32[] constant(0)
- broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
- greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
- f32[32,32,32]{2,1,0} broadcast), direction=GT
- p0.1 = f32[32,32,32]{2,1,0} parameter(0)
- ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
- greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
- }
-
- fused_reduce {
- p0.2 = f32[32,32,32]{2,1,0} parameter(0)
- c1 = f32[] constant(0)
- r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
- to_apply=scalar_add_computation
- mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
- r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
- to_apply=scalar_add_computation
- ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
- }
-
- ENTRY reduce {
- p0 = f32[32,32,32]{2,1,0} parameter(0)
- p1 = f32[32,32,32]{2,1,0} parameter(1)
- select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
- fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
- calls=fused_reduce
- gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
- gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
- ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
- tuple(gte1, gte1, select)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root,
- GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement(), m::GetTupleElement())));
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_element_wise {
- p0.1 = f32[2,2,2]{2,1,0} parameter(0)
- p1.1 = f32[2,2,2]{2,1,0} parameter(1)
- ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
- }
-
- fused_reduce {
- p0.2 = f32[2,2,2]{2,1,0} parameter(0)
- mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
- f32[2,2,2]{2,1,0} p0.2)
- broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
- c1 = f32[] constant(0)
- ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
- f32[] c1), dimensions={1,3}, to_apply=scalar_add_computation
- }
-
- ENTRY reduce {
- p0 = f32[2,2,2]{2,1,0} parameter(0)
- p1 = f32[2,2,2]{2,1,0} parameter(1)
- element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
- fusion = f32[2,2]{1,0} fusion(element_wise), kind=kLoop, calls=fused_reduce
- ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest,
- ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_select {
- p1.1 = f16[32,32,32]{2,1,0} parameter(1)
- c0 = f16[] constant(0)
- broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
- greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
- f16[32,32,32]{2,1,0} broadcast), direction=GT
- p0.1 = f16[32,32,32]{2,1,0} parameter(0)
- ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
- greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
- }
- fused_reduce {
- p0.2 = f16[32,32,32]{2,1,0} parameter(0)
- convert = f32[32,32,32]{2,1,0} convert(p0.2)
- c1 = f32[] constant(0)
- r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
- to_apply=scalar_add_computation
- mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
- r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
- to_apply=scalar_add_computation
- ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
- }
- ENTRY reduce {
- p0 = f16[32,32,32]{2,1,0} parameter(0)
- p1 = f16[32,32,32]{2,1,0} parameter(1)
- select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
- fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
- calls=fused_reduce
- gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
- gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
- ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
- tuple(gte1, gte1, select)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root,
- GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
- m::GetTupleElement(), m::GetTupleElement())));
- ASSERT_TRUE(fusion->IsMultiOutputFusion());
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
-}
-
-TEST_F(MultiOutputFusionTest,
- ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- mixed_input_layouts_computation {
- p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
- p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
- copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
- c0 = f16[] constant(0)
- broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
- greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
- ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
- }
- fused_reduce {
- p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
- convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
- c0.2 = f32[] constant(0)
- ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
- }
- ENTRY reduce {
- p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
- p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
- loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
- reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
- ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
- })"))
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionAvoidsCycles) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_add {
- p0 = f32[32,32,32]{2,1,0} parameter(0)
- p1 = f32[32,32,32]{2,1,0} parameter(1)
- ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
- }
-
- fused_mul {
- p2 = f32[64,64,64]{2,1,0} parameter(0)
- p3 = f32[64,64,64]{2,1,0} parameter(1)
- ROOT multiply = f32[64,64,64]{2,1,0} multiply(p2, p3)
- }
-
- fused_reduce_1 {
- p4 = f32[32,32,32]{2,1,0} parameter(0)
- p5 = f32[64,64,64]{2,1,0} parameter(1)
- slice = f32[32,32,32]{2,1,0} slice(p5), slice={[0:32], [0:32], [0:32]}
- add = f32[32,32,32]{2,1,0} add(p4, slice)
- c0 = f32[] constant(0)
- ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
- to_apply=scalar_add_computation
- }
-
- fused_reduce_2 {
- p6 = f32[32,32,32]{2,1,0} parameter(0)
- p7 = f32[64,64,64]{2,1,0} parameter(1)
- c0 = f32[] constant(0)
- pad = f32[64,64,64]{2,1,0} pad(p6, c0), padding=16_16x16_16x16_16
- mul = f32[64,64,64]{2,1,0} multiply(pad, p7)
- ROOT r1 = f32[64,64]{1,0} reduce(mul, c0), dimensions={2},
- to_apply=scalar_add_computation
- }
-
- ENTRY reduce {
- p8 = f32[32,32,32]{2,1,0} parameter(0)
- p9 = f32[64,64,64]{2,1,0} parameter(1)
-    // `add` and `mul` can be multi-output fused with `reduce1` and `reduce2`,
-    // respectively. However, fusing both is not possible: the second
-    // multi-output fusion would introduce an extra dependency from `add` to
-    // `mul` (or vice versa) and hence a cycle.
- add = f32[32,32,32]{2,1,0} fusion(p8, p8), kind=kLoop, calls=fused_add
- mul = f32[64,64,64]{2,1,0} fusion(p9, p9), kind=kLoop, calls=fused_mul
-
- reduce1 = f32[32,32]{1,0} fusion(add, mul), kind=kInput,
- calls=fused_reduce_1
- reduce2 = f32[64,64]{1,0} fusion(add, mul), kind=kInput,
- calls=fused_reduce_2
- ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[64,64]{1,0},
- f32[64,64,64]{2,1,0}) tuple(add, reduce1, reduce2, mul)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
-}
-
-TEST_F(MultiOutputFusionTest, PreferFuseProducerIntoFusionConsumer) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_add {
- p0 = f32[32,32,32]{2,1,0} parameter(0)
- p1 = f32[32,32,32]{2,1,0} parameter(1)
- ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
- }
- fused_reduce {
- p0 = f32[32,32,32]{2,1,0} parameter(0)
- p1 = f32[64,64,64]{2,1,0} parameter(1)
- slice = f32[32,32,32]{2,1,0} slice(p1), slice={[0:32], [0:32], [0:32]}
- add = f32[32,32,32]{2,1,0} add(p0, slice)
- c0 = f32[] constant(0)
- ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
- to_apply=scalar_add_computation
- }
- ENTRY reduce {
- p0 = f32[32,32,32]{2,1,0} parameter(0)
- p1 = f32[64,64,64]{2,1,0} parameter(1)
- add = f32[32,32,32]{2,1,0} fusion(p0, p0), kind=kLoop, calls=fused_add
- c0 = f32[] constant(0)
- reduce2 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
- to_apply=scalar_add_computation
- reduce = f32[32,32]{1,0} fusion(add, p1), kind=kInput, calls=fused_reduce
- ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[32,32]{1,0})
- tuple(add, reduce, reduce2)
- })"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- int multi_output_fusion_count = 0;
- for (auto* computation : module->MakeNonfusionComputations()) {
- for (auto* instr : computation->instructions()) {
- if (instr->IsMultiOutputFusion()) {
- multi_output_fusion_count++;
- }
- }
- }
- EXPECT_EQ(1, multi_output_fusion_count);
-}
-
-// Check that we limit the number of operands of the fusions we create.
-TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) {
- constexpr int64_t kNumParams = 200;
- ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
-
- // Compute
- // p0 * p1,
- // p0 * p1 + p1 * p2
- // p0 * p1 + p1 * p2 + p2 * p3
- // ...
- // where each of the (pi * pj)'s is represented as a fusion node so that
- // multi-output fusion will pay attention to it.
- auto module = CreateNewVerifiedModule();
- HloComputation::Builder b(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
-
- std::vector<HloInstruction*> params;
- for (int64_t i = 0; i < kNumParams; ++i) {
- params.push_back(
- b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
- }
-
- // Creates a fusion node that calculates x*y.
- auto make_fusion = [&](HloInstruction* x, HloInstruction* y) {
- HloComputation::Builder sub_builder("subcomp");
- auto* p0 = sub_builder.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "p"));
- auto* p1 = sub_builder.AddInstruction(
- HloInstruction::CreateParameter(1, shape, "p"));
- sub_builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
- HloComputation* subcomp =
- module->AddEmbeddedComputation(sub_builder.Build());
- return HloInstruction::CreateFusion(
- shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp);
- };
-
- auto* sum = b.AddInstruction(make_fusion(params[0], params[1]));
- for (int64_t i = 2; i < kNumParams; ++i) {
- sum = b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kAdd, sum,
- b.AddInstruction(make_fusion(params[i - 1], params[i]))));
- }
- auto computation = module->AddEntryComputation(b.Build());
- EXPECT_TRUE(mof_.Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- for (const HloInstruction* instr : computation->instructions()) {
- EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()),
- MaxOperandsAndOutputsPerFusion())
- << instr->ToString();
- }
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
- auto module = ParseAndReturnVerifiedModule(R"(HloModule dus_mof
- fusion.1 {
- p.0 = f16[50,96,1024]{2,1,0} parameter(0)
- p.1 = f16[1,96,1024]{2,1,0} parameter(1)
- c.0 = s32[3]{0} constant({0, 0, 0})
- ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
- }
-
- fusion.2 {
- p.0 = f16[50,96,1024]{2,1,0} parameter(0)
- p.1 = f16[1,96,1024]{2,1,0} parameter(1)
- c.0 = s32[3]{0} constant({0, 0, 0})
- ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
- }
-
- ENTRY entry {
- p.00 = f16[50,96,1024]{2,1,0} parameter(0)
- p.01 = f16[50,96,1024]{2,1,0} parameter(1)
- p.1 = f16[1,96,1024]{2,1,0} parameter(2)
-
- f1 = f16[50,96,1024] fusion(p.00, p.1), kind=kLoop, calls=fusion.1
- f2 = f16[50,96,1024] fusion(p.01, p.1), kind=kLoop, calls=fusion.2
- ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2)
- })")
- .value();
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-// Check that we don't fuse too many reductions together.
-TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation0 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation1 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation2 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation3 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation4 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation5 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation6 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation7 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation8 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- fused_computation9 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
- to_apply=scalar_add_computation
- }
- ENTRY computation {
- zero = f32[] constant(0)
- param0 = f32[64,64] parameter(0)
- param1 = f32[64,64] parameter(1)
- param2 = f32[64,64] parameter(2)
- param3 = f32[64,64] parameter(3)
- param4 = f32[64,64] parameter(4)
- param5 = f32[64,64] parameter(5)
- param6 = f32[64,64] parameter(6)
- param7 = f32[64,64] parameter(7)
- param8 = f32[64,64] parameter(8)
- param9 = f32[64,64] parameter(9)
- out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
- out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
- out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
- out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
- out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
- out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
- out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
- out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
- out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
- out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
- ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
- }
- )"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
-
- EXPECT_EQ(5, CountMultiOutputFusions(module.get()));
-}
-
-TEST_F(MultiOutputFusionTest, DoNotGroupTooManyReductions) {
- auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
- fused_computation0 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation1 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation2 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation3 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation4 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation5 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation6 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation7 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation8 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- fused_computation9 {
- p0 = f32[64,64] parameter(0)
- p1 = f32[64,64] parameter(1)
- p2 = f32[] parameter(2)
- add = f32[64,64] add(p0, p1)
- ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
- to_apply=scalar_add_computation
- }
- ENTRY computation {
- zero = f32[] constant(0)
- param0 = f32[64,64] parameter(0)
- param1 = f32[64,64] parameter(1)
- param2 = f32[64,64] parameter(2)
- param3 = f32[64,64] parameter(3)
- param4 = f32[64,64] parameter(4)
- param5 = f32[64,64] parameter(5)
- param6 = f32[64,64] parameter(6)
- param7 = f32[64,64] parameter(7)
- param8 = f32[64,64] parameter(8)
- param9 = f32[64,64] parameter(9)
- out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
- out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
- out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
- out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
- out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
- out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
- out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
- out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
- out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
- out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
- ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
- }
- )"))
- .value();
- ASSERT_TRUE(mof_.Run(module.get()).value());
-
- EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
-}
-
-TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule xla_computation_update_step.10931
-
-%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
- %scalar_lhs.1 = f64[] parameter(0)
- %scalar_rhs.1 = f64[] parameter(1)
- ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
-}
-
-%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
- %param_0.8 = f64[64,64]{1,0} parameter(0)
- %param_1.11 = f64[64,64]{1,0} parameter(1)
- %multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
- %constant_5217.3 = f64[] constant(0)
- %broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
- %multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
- %reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
- %param_2.9 = f64[64,64]{1,0} parameter(2)
- %multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
- %constant_5217.1.clone.1 = f64[] constant(0)
- %broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
- %multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
- %reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
- ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
-}
-
-%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
- %parameter.6427 = f64[] parameter(0)
- %parameter.6428 = f64[] parameter(1)
- ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
-}
-
-%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
- %param_0.7 = f64[64,64]{1,0} parameter(0)
- %param_1.9 = f64[64,64]{1,0} parameter(1)
- %multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
- %constant_5217.2 = f64[] constant(0)
- ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
-}
-
-ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
- %param_0.1090 = f64[64,64]{1,0} parameter(0)
- %param_1.1377 = f64[64,64]{1,0} parameter(1)
- %param_2.1948 = f64[64,64]{1,0} parameter(2)
- %fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
- %get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
- %fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
- %get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
- ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
-}
- )")
- .value();
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-and.reduce_sub_computation {
- x = pred[] parameter(0)
- y = pred[] parameter(1)
- ROOT and = pred[] and(x, y)
-}
-
-fused_computation.1 {
- param_4.658 = f32[2,20,256]{2,0,1} parameter(4)
- slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]}
- constant.6847 = s32[] constant(0)
- broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={}
- param_9.415 = s32[3]{0} parameter(9)
- compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE
- constant.6846 = pred[] constant(true)
- reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation
- broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={}
- param_5.528 = f32[2,512]{1,0} parameter(5)
- slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]}
- bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384)
- constant.5418 = f32[] constant(0)
- broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={}
- select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227)
- add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173)
- param_0.299 = s32[] parameter(0)
- constant.5157 = s32[] constant(11)
- dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299)
- slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]}
- constant.6800 = s32[] constant(0)
- broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={}
- param_8.484 = s32[3]{0} parameter(8)
- compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE
- constant.6798 = pred[] constant(true)
- reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation
- broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={}
- param_3.1169 = f32[2,512]{1,0} parameter(3)
- slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]}
- bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382)
- select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227)
- add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172)
- constant.5154 = s32[] constant(10)
- dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299)
- slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]}
- constant.6794 = s32[] constant(0)
- broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={}
- param_7.478 = s32[3]{0} parameter(7)
- compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE
- constant.6793 = pred[] constant(true)
- reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation
- broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={}
- param_2.1685 = f32[2,512]{1,0} parameter(2)
- slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]}
- bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380)
- select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227)
- add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171)
- constant.5153 = s32[] constant(9)
- dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299)
- slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]}
- constant.6788 = s32[] constant(0)
- broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={}
- param_6.495 = s32[3]{0} parameter(6)
- compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE
- constant.6786 = pred[] constant(true)
- reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation
- broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={}
- param_1.1408 = f32[2,512]{1,0} parameter(1)
- slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]}
- bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378)
- select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227)
- add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170)
- constant.5152 = s32[] constant(8)
- ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299)
-}
-
-fused_computation.2 {
- param_4.655 = f32[2,20,256]{2,0,1} parameter(4)
- slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]}
- param_6.483 = pred[] parameter(6)
- broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={}
- param_5.525 = f32[2,512]{1,0} parameter(5)
- slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]}
- bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368)
- constant.5415 = f32[] constant(0)
- broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={}
- select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225)
- add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161)
- param_0.265 = s32[] parameter(0)
- constant.5151 = s32[] constant(7)
- dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265)
- slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]}
- constant.6782 = s32[] constant(0)
- broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={}
- param_9.391 = s32[3]{0} parameter(9)
- compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE
- constant.6781 = pred[] constant(true)
- reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation
- broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={}
- param_3.1167 = f32[2,512]{1,0} parameter(3)
- slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]}
- bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366)
- select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225)
- add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160)
- constant.5150 = s32[] constant(6)
- dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265)
- slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]}
- constant.6776 = s32[] constant(0)
- broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={}
- param_8.464 = s32[3]{0} parameter(8)
- compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE
- constant.6775 = pred[] constant(true)
- reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation
- broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={}
- param_2.1684 = f32[2,512]{1,0} parameter(2)
- slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]}
- bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364)
- select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225)
- add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159)
- constant.5149 = s32[] constant(5)
- dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265)
- slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]}
- constant.6770 = s32[] constant(0)
- broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={}
- param_7.458 = s32[3]{0} parameter(7)
- compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE
- constant.6769 = pred[] constant(true)
- reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation
- broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={}
- param_1.1405 = f32[2,512]{1,0} parameter(1)
- slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]}
- bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362)
- select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225)
- add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158)
- constant.5148 = s32[] constant(4)
- ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265)
-}
-
-ENTRY main {
- param_0.0 = s32[] parameter(0)
- param_1.0 = f32[2,512]{1,0} parameter(1)
- param_2.0 = f32[2,512]{1,0} parameter(2)
- param_3.0 = f32[2,512]{1,0} parameter(3)
- param_4.0 = f32[2,20,256]{2,1,0} parameter(4)
- param_5.0 = f32[2,512]{1,0} parameter(5)
- param_6.0 = s32[3]{0} parameter(6)
- param_7.0 = s32[3]{0} parameter(7)
- param_8.0 = s32[3]{0} parameter(8)
- param_9.0 = s32[3]{0} parameter(9)
- fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1
- param_10 = pred[] parameter(10)
- fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2
- ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2)
-}
- )")
- .value();
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, DoNotFuseRoot) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-no_op {
- arg_empty_tuple = () parameter(0)
- ROOT tuple = () tuple()
-}
-
-fused_computation {
- param_0 = f32[] parameter(0)
- ROOT convert = s32[] convert(param_0)
-}
-
-ENTRY main {
- param_0 = f32[] parameter(0)
- fusion = s32[] fusion(param_0), kind=kLoop, calls=fused_computation
- tuple = () tuple()
- conditional = () conditional(fusion, tuple, tuple), branch_computations={no_op, no_op}
- constant = f32[] constant(1)
- ROOT root = f32[] add(param_0, constant)
-}
- )")
- .value();
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, CostBasedNoMerge) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-region_3.63 {
- Arg_0.64 = f32[] parameter(0)
- Arg_1.65 = f32[] parameter(1)
- ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
-}
-
-fused_computation.29 {
- param_0.161 = f32[5,32,32,1]{3,2,1,0} parameter(0)
- multiply.208 = f32[5,32,32,1]{3,2,1,0} multiply(param_0.161, param_0.161)
- bitcast.67 = f32[5,32,32]{2,1,0} bitcast(multiply.208)
- constant.265 = f32[] constant(0)
- reduce-window.81 = f32[5,30,31]{2,1,0} reduce-window(bitcast.67, constant.265), window={size=1x3x2}, to_apply=region_3.63
- constant.264 = f32[] constant(0.166666672)
- broadcast.204 = f32[5,30,31]{2,1,0} broadcast(constant.264), dimensions={}
- multiply.205 = f32[5,30,31]{2,1,0} multiply(reduce-window.81, broadcast.204)
- constant.263 = f32[] constant(0)
- reduce-window.80 = f32[5,30,31]{2,1,0} reduce-window(multiply.205, constant.263), window={size=1x2x3 pad=0_0x0_1x1_1}, to_apply=region_3.63
- constant.262 = f32[] constant(0.0138888899)
- broadcast.201 = f32[5,30,31]{2,1,0} broadcast(constant.262), dimensions={}
- multiply.204 = f32[5,30,31]{2,1,0} multiply(reduce-window.80, broadcast.201)
- constant.261 = f32[] constant(0)
- reduce-window.78 = f32[5,30,31]{2,1,0} reduce-window(multiply.204, constant.261), window={size=1x1x2 pad=0_0x0_0x0_1}, to_apply=region_3.63
- constant.113 = f32[] constant(0.5)
- broadcast.137 = f32[5,30,31]{2,1,0} broadcast(constant.113), dimensions={}
- multiply.125 = f32[5,30,31]{2,1,0} multiply(reduce-window.78, broadcast.137)
- constant.114 = f32[] constant(0)
- ROOT reduce-window.17 = f32[5,30,31]{2,1,0} reduce-window(multiply.125, constant.114), window={size=1x2x1 pad=0_0x0_1x0_0}, to_apply=region_3.63
-}
-
-fused_computation.15 {
- constant.108 = f32[] constant(0.5)
- broadcast.105 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.108), dimensions={}
- param_3.126 = f32[5,30,31]{2,1,0} parameter(3)
- constant.295 = f32[] constant(0.25)
- broadcast.234 = f32[5,30,31]{2,1,0} broadcast(constant.295), dimensions={}
- multiply.242 = f32[5,30,31]{2,1,0} multiply(param_3.126, broadcast.234)
- broadcast.233 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.242), dimensions={0,2,3}
- param_2.154 = f32[5,30,31]{2,1,0} parameter(2)
- multiply.241 = f32[5,30,31]{2,1,0} multiply(param_2.154, broadcast.234)
- broadcast.232 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.241), dimensions={1,2,3}
- multiply.240 = f32[5,5,30,31]{3,2,1,0} multiply(broadcast.233, broadcast.232)
- param_1.188 = f32[5,5,30,31]{3,2,1,0} parameter(1)
- constant.294 = f32[] constant(0.159154937)
- broadcast.231 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.294), dimensions={}
- multiply.239 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.231)
- param_0.164 = f32[5,5,30,31]{3,2,1,0} parameter(0)
- add.19 = f32[5,5,30,31]{3,2,1,0} add(multiply.239, param_0.164)
- constant.293 = f32[] constant(0)
- reduce-window.90 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.19, constant.293), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
- constant.292 = f32[] constant(0.5)
- broadcast.230 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.292), dimensions={}
- multiply.238 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.90, broadcast.230)
- constant.291 = f32[] constant(0)
- reduce-window.89 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.238, constant.291), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
- constant.290 = f32[] constant(0.25)
- broadcast.229 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.290), dimensions={}
- multiply.237 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.89, broadcast.229)
- multiply.236 = f32[5,5,30,31]{3,2,1,0} multiply(multiply.237, multiply.237)
- subtract.10 = f32[5,5,30,31]{3,2,1,0} subtract(multiply.240, multiply.236)
- constant.289 = f32[] constant(0)
- broadcast.228 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.289), dimensions={}
- maximum.6 = f32[5,5,30,31]{3,2,1,0} maximum(subtract.10, broadcast.228)
- sqrt.6 = f32[5,5,30,31]{3,2,1,0} sqrt(maximum.6)
- constant.110 = f32[] constant(0)
- broadcast.107 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.110), dimensions={}
- compare.4 = pred[5,5,30,31]{3,2,1,0} compare(sqrt.6, broadcast.107), direction=EQ
- constant.243 = f32[] constant(0.159154937)
- broadcast.193 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.243), dimensions={}
- multiply.194 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.193)
- add.15 = f32[5,5,30,31]{3,2,1,0} add(multiply.194, param_0.164)
- constant.242 = f32[] constant(0)
- reduce-window.66 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.15, constant.242), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
- constant.241 = f32[] constant(0.5)
- broadcast.192 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.241), dimensions={}
- multiply.193 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.66, broadcast.192)
- constant.240 = f32[] constant(0)
- reduce-window.65 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.193, constant.240), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
- constant.239 = f32[] constant(0.25)
- broadcast.191 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.239), dimensions={}
- multiply.192 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.65, broadcast.191)
- compare.3 = pred[5,5,30,31]{3,2,1,0} compare(multiply.192, broadcast.107), direction=EQ
- and.1 = pred[5,5,30,31]{3,2,1,0} and(compare.4, compare.3)
- constant.109 = f32[] constant(1.57079637)
- broadcast.104 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.109), dimensions={}
- atan2.1 = f32[5,5,30,31]{3,2,1,0} atan2(sqrt.6, multiply.192)
- select.4 = f32[5,5,30,31]{3,2,1,0} select(and.1, broadcast.104, atan2.1)
- constant.107 = f32[] constant(0.159154937)
- broadcast.106 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.107), dimensions={}
- multiply.100 = f32[5,5,30,31]{3,2,1,0} multiply(select.4, broadcast.106)
- ROOT subtract.3 = f32[5,5,30,31]{3,2,1,0} subtract(broadcast.105, multiply.100)
-}
-
-fused_computation.4 {
- param_0.172 = f32[5,30,31]{2,1,0} parameter(0)
- constant.315 = f32[] constant(0.125)
- broadcast.242 = f32[5,30,31]{2,1,0} broadcast(constant.315), dimensions={}
- multiply.250 = f32[5,30,31]{2,1,0} multiply(param_0.172, broadcast.242)
- constant.314 = f32[] constant(0)
- reduce-window.100 = f32[5,30,31]{2,1,0} reduce-window(multiply.250, constant.314), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
- constant.79 = f32[] constant(0.055555556)
- broadcast.85 = f32[5,30,31]{2,1,0} broadcast(constant.79), dimensions={}
- multiply.80 = f32[5,30,31]{2,1,0} multiply(reduce-window.100, broadcast.85)
- constant.81 = f32[] constant(0)
- reduce-window.1 = f32[5,30,31]{2,1,0} reduce-window(multiply.80, constant.81), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
- constant.80 = f32[] constant(0.111111112)
- broadcast.86 = f32[5,30,31]{2,1,0} broadcast(constant.80), dimensions={}
- multiply.79 = f32[5,30,31]{2,1,0} multiply(reduce-window.1, broadcast.86)
- bitcast.26 = f32[5,930]{1,0} bitcast(multiply.79)
- ROOT reduce.8 = f32[5]{0} reduce(bitcast.26, constant.81), dimensions={1}, to_apply=region_3.63
-}
-
-ENTRY e {
- Arg_0.1 = f32[5,32,32,1]{3,2,1,0} parameter(0)
- p1 = f32[5,5,30,31]{3,2,1,0} parameter(1)
- p2 = f32[5,5,30,31]{3,2,1,0} parameter(2)
- p3 = f32[5,30,31]{2,1,0} parameter(3)
- fusion.29 = f32[5,30,31]{2,1,0} fusion(Arg_0.1), kind=kLoop, calls=fused_computation.29
- fusion.15 = f32[5,5,30,31]{3,2,1,0} fusion(p2, p1, p3, fusion.29), kind=kLoop, calls=fused_computation.15
- ROOT fusion.4 = f32[5]{0} fusion(fusion.29), kind=kInput, calls=fused_computation.4
-})")
- .value();
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, NoOverlappingRead) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fused_computation_1 {
- p0.1 = f32[100,200]{1,0} parameter(0)
- slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[0:100]}
- mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
- exp = f32[50,100]{1,0} exponential(slice.0)
- ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
- }
-
- fused_computation_2 {
- p0.2 = f32[100,200]{1,0} parameter(0)
- slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[0:50],[100:200]}
- const.2 = f32[] constant(0)
- broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
- ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
- }
-
- ENTRY entry {
- p0 = f32[100,200]{1,0} parameter(0)
- fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
- calls=fused_computation_1
- gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
- gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
- fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
- calls=fused_computation_2
- ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
- tuple(gte0, gte1, fusion.2)
- })")
- .value();
-
- EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, OverlappingRead) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fused_computation_1 {
- p0.1 = f32[100,200]{1,0} parameter(0)
- slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[50:150]}
- mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
- exp = f32[50,100]{1,0} exponential(slice.0)
- ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
- }
-
- fused_computation_2 {
- p0.2 = f32[100,200]{1,0} parameter(0)
- slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[30:80],[20:120]}
- const.2 = f32[] constant(0)
- broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
- ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
- }
-
- ENTRY entry {
- p0 = f32[100,200]{1,0} parameter(0)
- fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
- calls=fused_computation_1
- gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
- gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
- fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
- calls=fused_computation_2
- ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
- tuple(gte0, gte1, fusion.2)
- })")
- .value();
-
- EXPECT_TRUE(mof_.Run(module.get()).value());
-}
-
-class TransposeMultiOutputFusionTest : public MultiOutputFusionTest {};
-
-TEST_F(TransposeMultiOutputFusionTest, MultipleCopies) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f32[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
- p = f32[16,32]{1,0} parameter(0)
- fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
- c1 = f32[16,32]{0,1} copy(p)
- ROOT t = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple(fusion, c1)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
-// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
-// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[16,32]{0,1} copy([[param_0_1_0]])
-// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK-NEXT: }
-
-// CHECK: [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, MultipleTransposes) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f32[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
-}
-
-ENTRY main {
- p = f32[16,32]{1,0} parameter(0)
- fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
- c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
- ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(fusion, c1)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[32,16], f32[32,16]) {
-// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0}
-// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[32,16]{1,0} transpose([[param_0_1_0]]), dimensions={1,0}
-// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK-NEXT: }
-// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, CopyAndTranspose) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f32[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
- p = f32[16,32]{1,0} parameter(0)
- fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
- c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
- ROOT t = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple(fusion, c1)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, R"(
- // CHECK: %fused_computation ({{[^ ]+}} f32[16,32]) -> (f32[16,32], f32[32,16]) {
- // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
- // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0]])
- // CHECK-NEXT: [[copy:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1]])
- // CHECK-NEXT: [[transpose:[^ ]+]] = f32[32,16]{1,0} transpose([[param_0]]), dimensions={1,0}
- // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple([[copy]], [[transpose]])
- // CHECK: %fusion = (f32[16,32]{0,1}, f32[32,16]{1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, MultipleCopiesDifferentTypes) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f16[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} convert(param_0.1)
- ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
- p = f16[16,32]{1,0} parameter(0)
- fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
- c1 = f16[16,32]{0,1} copy(p)
- ROOT t = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple(fusion, c1)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f16[16,32]) -> (f32[16,32], f16[16,32]) {
-// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f16[16,32]{1,0} parameter(0)
-// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} convert([[param_0_1_0]])
-// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
-// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f16[16,32]{0,1} copy([[param_0_1_0]])
-// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK: [[fusion_5:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) fusion([[p_6:%[^ ]+]]), kind=kInput, calls=[[fused_computation_7:%[^ ]+]]
-)");
-}
-
-// Do not group copy and reduction.
-TEST_F(TransposeMultiOutputFusionTest, TiledReduceCopy) {
- const char* hlo = R"(
-HloModule module
-
-add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = add(lhs, rhs)
-}
-
-fused_computation {
- param_0.1 = f32[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
- p = f32[16,32]{1,0} parameter(0)
- fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
- z = f32[] constant(0)
- r1 = f32[32]{0} reduce(p, z), dimensions={0}, to_apply=add
- ROOT t = (f32[16,32]{0,1}, f32[32]{0}) tuple(fusion, r1)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, std::nullopt);
-}
-
-// Do not group incompatible transposes.
-TEST_F(TransposeMultiOutputFusionTest, IncompatibleTransposes) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
- param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
- s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
- t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
- sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
- exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
- ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
-}
-
-fused_computation.2 {
- param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
- s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
- ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
-}
-
-ENTRY main {
- p = f32[18,16,32]{2,1,0} parameter(0)
- p2 = f32[32,16,18]{2,1,0} parameter(1)
- fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
- fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
- ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, std::nullopt);
-}
-
-// A variation of the test above, where no CSE was run.
-TEST_F(TransposeMultiOutputFusionTest, TransposesNoCSE) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
- param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
- s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
- t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
- sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
- exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
- exp.2 = f32[32,16,18]{2,1,0} exponential(sub.1)
- ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.2)
-}
-
-fused_computation.2 {
- param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
- s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
- ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
-}
-
-ENTRY main {
- p = f32[18,16,32]{2,1,0} parameter(0)
- p2 = f32[32,16,18]{2,1,0} parameter(1)
- fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
- fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
- ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, std::nullopt);
-}
-
-TEST_F(TransposeMultiOutputFusionTest, CopyAndInput) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f32[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
- p = f32[16,32]{1,0} parameter(0)
- fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
- c1 = exponential(p)
- ROOT t = tuple(fusion, c1)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
-// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
-// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
-// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK-NEXT: }
-// CHECK: [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, TransposeAndInputEpilogueFusion) {
- const char* hlo = R"(
-HloModule module
-
-fused_computation {
- param_0.1 = f32[16,32]{1,0} parameter(0)
- s.1 = f32[16,32]{1,0} sqrt(param_0.1)
- t.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
- ROOT out = f32[32,16,1]{2,1,0} bitcast(t.1)
-}
-
-ENTRY main {
- p = f32[16,32]{1,0} parameter(0)
- fusion = f32[32,16,1]{2,1,0} fusion(p), kind=kInput, calls=fused_computation
- c1 = exponential(p)
- ROOT t = tuple(fusion, c1)
-}
- )";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation
-// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]])
-// CHECK-NEXT: [[out_3:%[^ ]+]] = f32[32,16,1]{2,1,0} bitcast([[c_1_2]])
-// CHECK-NEXT: [[c1_1_4:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
-// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]])
-// CHECK-NEXT: }
-// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-class ReduceMultiOutputFusionTest : public MultiOutputFusionTest {};
-
-TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoop) {
- const char* hlo = R"(
-HloModule module
-
-add {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a, b)
-}
-
-fused_reduction {
- p = f32[200] parameter(0)
- z = f32[] constant(0)
- e = f32[200] exponential(p)
- ROOT r = f32[] reduce(e, z), dimensions={0}, to_apply=add
-}
-
-fused_elementwise {
- p = f32[200] parameter(0)
- ROOT r = f32[200] sqrt(p)
-}
-
-ENTRY computation {
- p = f32[200] parameter(0)
- o1 = f32[200] fusion(p), kind=kLoop, calls=fused_elementwise
- o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
- ROOT out = (f32[200], f32[]) tuple(o1, o2)
-}
-
-)";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_elementwise
-// CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[200]{0} parameter(0)
-// CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[200]{0} sqrt([[p_1_0]])
-// CHECK-NEXT: [[e_2:%[^ ]+]].clone.1 = f32[200]{0} exponential([[p_1_0]])
-// CHECK-NEXT: [[z_3:%[^ ]+]].clone.1 = f32[] constant(0)
-// CHECK-NEXT: [[r_4:%[^ ]+]].clone.1 = f32[] reduce([[e_2]].clone.1, [[z_3]].clone.1), dimensions={0}, to_apply=[[add_5:%[^ ]+]]
-// CHECK-NEXT: ROOT [[tuple_6:%[^ ]+]] = (f32[200]{0}, f32[]) tuple([[r_1_1]], [[r_4]].clone.1)
-// CHECK-NEXT:}
-// CHECK: [[o1_0:%[^ ]+]] = (f32[200]{0}, f32[]) fusion([[p_2_1:%[^ ]+]]), kind=kInput, calls=[[fused_elementwise_2:%[^ ]+]]
- )");
-}
-
-TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShape) {
- const char* hlo = R"(
-HloModule module
-
-add {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] add(a, b)
-}
-
-fused_reduction {
- p = f32[10,20] parameter(0)
- z = f32[] constant(0)
- e = f32[10,20] exponential(p)
- b = f32[200] bitcast(e)
- ROOT r = f32[] reduce(b, z), dimensions={0}, to_apply=add
-}
-
-fused_elementwise {
- p = f32[10,20] parameter(0)
- ROOT r = f32[10,20] sqrt(p)
-}
-
-ENTRY computation {
- p = f32[10,20] parameter(0)
- o1 = f32[10,20] fusion(p), kind=kLoop, calls=fused_elementwise
- o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
- ROOT out = (f32[10,20], f32[]) tuple(o1, o2)
-}
-)";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_elementwise (p.1: f32[10,20]) -> (f32[10,20], f32[]) {
-// CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[10,20]{1,0} parameter(0)
-// CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[10,20]{1,0} sqrt([[p_1_0]])
-// CHECK-NEXT: [[e_2:%[^ ]+]].clone.1 = f32[10,20]{1,0} exponential([[p_1_0]])
-// CHECK-NEXT: [[b_1_3:%[^ ]+]].clone.1 = f32[200]{0} bitcast([[e_2]].clone.1)
-// CHECK-NEXT: [[z_4:%[^ ]+]].clone.1 = f32[] constant(0)
-// CHECK-NEXT: [[r_5:%[^ ]+]].clone.1 = f32[] reduce([[b_1_3]].clone.1, [[z_4]].clone.1), dimensions={0}, to_apply=[[add_6:%[^ ]+]]
-// CHECK-NEXT: ROOT [[tuple_7:%[^ ]+]] = (f32[10,20]{1,0}, f32[]) tuple([[r_1_1]], [[r_5]].clone.1)
-// CHECK-NEXT: }
- )");
-}
-
-TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShapeDifferentType) {
- const char* hlo = R"(
-HloModule module, entry_computation_layout={(f16[100,200]{1,0},f32[],f32[])->(f16[100,200]{1,0}, f32[])}
-
-max {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] maximum(a, b)
-}
-
-fused_computation {
- one_5 = f32[] constant(1)
- one_b.5 = f32[100,200]{1,0} broadcast(one_5), dimensions={}
- param_1.15 = f16[100,200]{1,0} parameter(1)
- c.6 = f32[100,200]{1,0} convert(param_1.15)
- param_0.11 = f32[] parameter(0)
- b.6 = f32[100,200]{1,0} broadcast(param_0.11), dimensions={}
- d.5 = f32[100,200]{1,0} divide(c.6, b.6)
- a.6 = f32[100,200]{1,0} add(one_b.5, d.5)
- bitcast.1 = f32[20000]{0} bitcast(a.6)
- z_1 = f32[] constant(0)
- ROOT r.1 = f32[] reduce(bitcast.1, z_1), dimensions={0}, to_apply=max
-}
-
-fused_computation.1 {
- one_3 = f32[] constant(1)
- one_b.3 = f32[100,200]{1,0} broadcast(one_3), dimensions={}
- param_2.7 = f16[100,200]{1,0} parameter(2)
- c.4 = f32[100,200]{1,0} convert(param_2.7)
- param_1.10 = f32[] parameter(1)
- b.4 = f32[100,200]{1,0} broadcast(param_1.10), dimensions={}
- d.3 = f32[100,200]{1,0} divide(c.4, b.4)
- a.4 = f32[100,200]{1,0} add(one_b.3, d.3)
- param_0.8 = f32[] parameter(0)
- output_scale_broadcast.1 = f32[100,200]{1,0} broadcast(param_0.8), dimensions={}
- a_scaled.1 = f32[100,200]{1,0} multiply(a.4, output_scale_broadcast.1)
- ROOT a_scaled_converted.1 = f16[100,200]{1,0} convert(a_scaled.1)
-}
-
-ENTRY computation {
- output_scale = f32[] parameter(2)
- input_scale = f32[] parameter(1)
- p = f16[100,200]{1,0} parameter(0)
- fusion.1 = f16[100,200]{1,0} fusion(output_scale, input_scale, p), kind=kLoop, calls=fused_computation.1
- fusion = f32[] fusion(input_scale, p), kind=kInput, calls=fused_computation
- ROOT out = (f16[100,200]{1,0}, f32[]) tuple(fusion.1, fusion)
-}
-)";
-
- CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation.1 (param_0.8: f32[], param_1.10: f32[], param_2.7: f16[100,200]) -> (f16[100,200], f32[]) {
-// CHECK-NEXT: [[one_3_0:%[^ ]+]] = f32[] constant(1)
-// CHECK-NEXT: [[one_b_3_1:%[^ ]+]] = f32[100,200]{1,0} broadcast([[one_3_0]]), dimensions={}
-// CHECK-NEXT: [[param_2_7_2:%[^ ]+]] = f16[100,200]{1,0} parameter(2)
-// CHECK-NEXT: [[c_4_3:%[^ ]+]] = f32[100,200]{1,0} convert([[param_2_7_2]])
-// CHECK-NEXT: [[param_1_10_4:%[^ ]+]] = f32[] parameter(1)
-// CHECK-NEXT: [[b_4_5:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
-// CHECK-NEXT: [[d_3_6:%[^ ]+]] = f32[100,200]{1,0} divide([[c_4_3]], [[b_4_5]])
-// CHECK-NEXT: [[a_4_7:%[^ ]+]] = f32[100,200]{1,0} add([[one_b_3_1]], [[d_3_6]])
-// CHECK-NEXT: [[param_0_8_8:%[^ ]+]] = f32[] parameter(0)
-// CHECK-NEXT: [[output_scale_broadcast_1_9:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_0_8_8]]), dimensions={}
-// CHECK-NEXT: [[a_scaled_1_10:%[^ ]+]] = f32[100,200]{1,0} multiply([[a_4_7]], [[output_scale_broadcast_1_9]])
-// CHECK-NEXT: [[a_scaled_converted_1_11:%[^ ]+]] = f16[100,200]{1,0} convert([[a_scaled_1_10]])
-// CHECK-NEXT: [[one_5_12:%[^ ]+]].clone.1 = f32[] constant(1)
-// CHECK-NEXT: [[one_b_5_13:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[one_5_12]].clone.1), dimensions={}
-// CHECK-NEXT: [[c_6_14:%[^ ]+]].clone.1 = f32[100,200]{1,0} convert([[param_2_7_2]])
-// CHECK-NEXT: [[b_6_15:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
-// CHECK-NEXT: [[d_5_16:%[^ ]+]].clone.1 = f32[100,200]{1,0} divide([[c_6_14]].clone.1, [[b_6_15]].clone.1)
-// CHECK-NEXT: [[a_6_17:%[^ ]+]].clone.1 = f32[100,200]{1,0} add([[one_b_5_13]].clone.1, [[d_5_16]].clone.1)
-// CHECK-NEXT: [[bitcast_1_18:%[^ ]+]].clone.1 = f32[20000]{0} bitcast([[a_6_17]].clone.1)
-// CHECK-NEXT: [[z_1_19:%[^ ]+]].clone.1 = f32[] constant(0)
-// CHECK-NEXT: [[r_1_20:%[^ ]+]].clone.1 = f32[] reduce([[bitcast_1_18]].clone.1, [[z_1_19]].clone.1), dimensions={0}, to_apply=[[max_21:%[^ ]+]]
-// CHECK-NEXT: ROOT [[tuple_22:%[^ ]+]] = (f16[100,200]{1,0}, f32[]) tuple([[a_scaled_converted_1_11]], [[r_1_20]].clone.1)
-// CHECK-NEXT: }
- )");
-}
-
-TEST_F(ReduceMultiOutputFusionTest, GetTupleElementMakeTupleSequence) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- fusion {
- p0 = s32[] parameter(0)
- p1 = s32[32] parameter(1)
- custom-call = (bf16[], s32[], u32[]) custom-call(p1), custom_call_target="my_custom_call"
- get-tuple-element.0 = bf16[] get-tuple-element(custom-call), index=0
- get-tuple-element.1 = s32[] get-tuple-element(custom-call), index=1
- bitcast = s32[1] bitcast(get-tuple-element.1)
- dynamic-update-slice = s32[32] dynamic-update-slice(p1, bitcast, p0)
- get-tuple-element.2 = u32[] get-tuple-element(custom-call), index=2
- ROOT tuple.30 = (bf16[], s32[32], u32[]) tuple(get-tuple-element.0, dynamic-update-slice, get-tuple-element.2)
- }
-
- ENTRY entry{
- p0 = s32[] parameter(0)
- bitcast = s32[32] bitcast(p0)
- ROOT address_computation.7.0 = (bf16[], s32[32], u32[]) fusion(p0, bitcast), kind=kCustom, calls=fusion
- }
- )")
- .value();
-
- ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
index 40eff87..7925ebc 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
@@ -15,6 +15,7 @@
#include "xla/service/gpu/nvptx_compiler.h"
+#include <algorithm>
#include <array>
#include <cstdint>
#include <fstream>
@@ -52,34 +53,35 @@
#include "xla/service/dump.h"
#include "xla/service/float_normalization.h"
#include "xla/service/float_support.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
+#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h"
#include "xla/service/gpu/buffer_sharing.h"
-#include "xla/service/gpu/conv_algorithm_picker.h"
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-#include "xla/service/gpu/cudnn_fused_mha_rewriter.h"
-#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h"
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
-#include "xla/service/gpu/cudnn_norm_rewriter.h"
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-#include "xla/service/gpu/cudnn_simplify_padding.h"
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-#include "xla/service/gpu/cudnn_workspace_rewriter.h"
-#include "xla/service/gpu/cusolver_rewriter.h"
-#include "xla/service/gpu/dot_sparsity_rewriter.h"
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-#include "xla/service/gpu/gemm_fusion_autotuner.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
#include "xla/service/gpu/gpu_asm_opts_util.h"
#include "xla/service/gpu/gpu_compiler.h"
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/gpu_sort_rewriter.h"
+#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "xla/service/gpu/metrics.h"
#include "xla/service/gpu/target_constants.h"
-#include "xla/service/gpu/triangular_solve_rewriter.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h"
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h"
+#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h"
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
+#include "xla/service/gpu/transforms/cudnn_norm_rewriter.h"
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+#include "xla/service/gpu/transforms/cudnn_simplify_padding.h"
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h"
+#include "xla/service/gpu/transforms/gpusolver_rewriter.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+#include "xla/service/gpu/transforms/triangular_solve_rewriter.h"
#include "xla/service/hlo_constant_folding.h"
#include "xla/service/hlo_cse.h"
#include "xla/service/hlo_dataflow_analysis.h"
@@ -94,6 +96,8 @@
#include "xla/stream_executor/cuda/cuda_asm_compiler.h"
#include "xla/stream_executor/cuda/cuda_diagnostics.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/cuda/nvjitlink.h"
+#include "xla/stream_executor/cuda/nvjitlink_support.h"
#include "xla/stream_executor/cuda/ptx_compilation_method.h"
#include "xla/stream_executor/cuda/ptx_compiler.h"
#include "xla/stream_executor/cuda/ptx_compiler_support.h"
@@ -187,7 +191,7 @@
auto cuda_compute_capability =
std::get<se::CudaComputeCapability>(gpu_version);
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
- // (GpuConvPaddingLegalization). Also expand cuSolver calls.
+ // (ConvPaddingLegalization). Also expand cuSolver calls.
HloPassPipeline pipeline("conv_canonicalization");
pipeline.AddInvariantCheckerDebug<HloVerifier>(
/*layout_sensitive=*/false,
@@ -202,10 +206,10 @@
pipeline.AddPass<FloatNormalization>(&matmul_bf16_support);
pipeline.AddPass<GpusolverRewriter>();
- pipeline.AddPass<GpuConvRewriter>(cuda_compute_capability);
+ pipeline.AddPass<ConvRewriter>(cuda_compute_capability);
pipeline.AddPass<CudnnFusedConvRewriter>(cuda_compute_capability, dnn_version,
GetToolkitVersion());
- pipeline.AddPass<GpuConvPaddingLegalization>();
+ pipeline.AddPass<ConvPaddingLegalization>();
pipeline.AddPass<CudnnPadForConvolutions>(cuda_compute_capability);
pipeline.AddPass<CudnnVectorizeConvolutions>(cuda_compute_capability,
dnn_version);
@@ -230,7 +234,7 @@
// e.g. clean up unnecessary nop `convert`s.
pipeline.AddPass<CudnnSimplifyPadding>();
- // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
+ // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and
// CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
// to a fixed point. Include algsimp because ReshapeMover relies on it.
[&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
@@ -252,7 +256,7 @@
pipeline.AddPass<GpuAlgebraicSimplifier>(algsimp_options, gpu_version);
}();
- // GpuConvRewriter, GpuConvPaddingLegalization and
+ // ConvRewriter, ConvPaddingLegalization and
// CudnnConvPadForTensorCores may add instructions which can be simplified
// by constant folding.
pipeline.AddPass<HloConstantFolding>();
@@ -338,9 +342,6 @@
// Transform TriangularSolve ops into custom-calls, so we can add temp
// memory.
post_pipeline.AddPass<TriangularSolveRewriter>();
- if (stream_exec) {
- post_pipeline.AddPass<CuDnnWorkspaceRewriter>(*stream_exec);
- }
TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());
return absl::OkStatus();
@@ -386,20 +387,22 @@
absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses(
HloPassPipeline* pipeline, const DebugOptions& debug_options) {
if (debug_options.xla_gpu_enable_cub_radix_sort()) {
- pipeline->AddPass<GpuSortRewriter>();
+ pipeline->AddPass<SortRewriter>();
}
return absl::OkStatus();
}
-absl::Status NVPTXCompiler::RunCudnnFusionCompilerPass(
+absl::Status NVPTXCompiler::RunCudnnCompilerPasses(
HloModule* module, se::StreamExecutor* stream_exec,
BinaryMap* dnn_compiled_graphs) {
tsl::profiler::ScopedAnnotation annotation([&] {
return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#",
module->name(), module->unique_id());
});
- CuDnnFusionCompiler cudnn_compiler(*stream_exec, *dnn_compiled_graphs);
- return cudnn_compiler.Run(module).status();
+ CuDnnFusionCompiler fusion_compiler(*stream_exec, *dnn_compiled_graphs);
+ TF_RETURN_IF_ERROR(fusion_compiler.Run(module).status());
+ CuDnnCustomCallCompiler call_compiler(*stream_exec, *dnn_compiled_graphs);
+ return call_compiler.Run(module).status();
}
namespace {
@@ -531,6 +534,8 @@
return &CanShareBufferHint;
}
+constexpr const uint8_t kPtxPrefix[] = {'P', 'T', 'X', ':', ' '};
+
absl::StatusOr<GpuCompiler::BackendCompileResult>
NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
llvm::Module* llvm_module,
@@ -568,6 +573,20 @@
RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs);
}
+ TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
+ ChooseLinkingMethod(module_config.debug_options()));
+
+ if (linking_method == se::PtxLinkingMethod::kNvJitLink && relocatable) {
+ VLOG(2) << "Deferring the PTX to CUBIN compilation of the relocatable "
+ "module to the linking step.";
+ std::vector<uint8_t> binary;
+ binary.reserve(sizeof(kPtxPrefix) + ptx.size() + 1);
+ binary.insert(binary.end(), kPtxPrefix, kPtxPrefix + sizeof(kPtxPrefix));
+ binary.insert(binary.end(), ptx.begin(), ptx.end());
+ binary.emplace_back('\0');
+ return BackendCompileResult{std::move(ptx), std::move(binary)};
+ }
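+  // Illustrative note (added for clarity, not part of the upstream change):
+  // the prefixed buffer returned above is only an internal hand-off format
+  // between CompileTargetBinary and LinkModules. Assuming the PTX text "foo",
+  // the deferred module would hold the bytes
+  //   {'P', 'T', 'X', ':', ' ', 'f', 'o', 'o', '\0'}
+  // and LinkModules strips the sizeof(kPtxPrefix) marker bytes again, e.g.
+  //   absl::Span<const uint8_t>(module).subspan(sizeof(kPtxPrefix))
+  // before handing the PTX to nvJitLink in its kNvJitLink branch below.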
+
absl::StatusOr<std::vector<uint8_t>> maybe_cubin =
CompileGpuAsmOrGetCachedResult(
ptx, std::get<se::CudaComputeCapability>(gpu_version), module_config,
@@ -588,6 +607,9 @@
if (se::IsLibNvPtxCompilerSupported()) {
methods.emplace_back(PtxCompilationMethod::kNvPtxCompiler);
}
+ if (se::IsLibNvJitLinkSupported()) {
+ methods.emplace_back(PtxCompilationMethod::kNvJitLink);
+ }
methods.emplace_back(PtxCompilationMethod::kPtxas);
return methods;
}
@@ -608,11 +630,26 @@
}
};
+ if (!debug_options.xla_gpu_enable_libnvjitlink()) {
+ VLOG(3) << "Discarding NvJitLink since it is disabled.";
+ remove_compilation_method(PtxCompilationMethod::kNvJitLink);
+ }
if (!debug_options.xla_gpu_enable_libnvptxcompiler()) {
VLOG(3) << "Discarding NvPtxCompiler since it is disabled.";
remove_compilation_method(PtxCompilationMethod::kNvPtxCompiler);
}
+ VLOG(2) << "Supported and enabled compilation methods: "
+ << absl::StrJoin(compilation_methods, ", ");
+
+ if (relocatable && absl::c_linear_search(compilation_methods,
+ PtxCompilationMethod::kNvJitLink)) {
+ // NvJitLink can't produce relocatable CUBINs.
+ VLOG(3) << "Discarding NvJitLink since it can't produce the requested "
+ "relocatable CUBIN.";
+ remove_compilation_method(PtxCompilationMethod::kNvJitLink);
+ }
+
VLOG(2) << "Considered compilation methods: "
<< absl::StrJoin(compilation_methods, ", ");
@@ -655,6 +692,16 @@
absl::StatusOr<std::vector<uint8_t>> maybe_cubin = [&] {
switch (compilation_method) {
+ case PtxCompilationMethod::kNvJitLink:
+ return se::CompileAndLinkUsingLibNvJitLink(
+ cc.major, cc.minor,
+ {se::NvJitLinkInput{
+ se::NvJitLinkInput::Type::kPtx,
+ absl::Span<const uint8_t>{
+ reinterpret_cast<const uint8_t*>(ptx.c_str()),
+ ptx.size() + 1 /* We need the null terminator. */}}},
+ ptxas_config, cancel_if_reg_spill);
+
case PtxCompilationMethod::kNvPtxCompiler:
return se::CompileGpuAsmUsingLibNvPtxCompiler(
cc.major, cc.minor, ptx.c_str(), ptxas_config, cancel_if_reg_spill);
@@ -815,6 +862,12 @@
std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir;
using LinkingMethod = se::PtxLinkingMethod;
+
+ if (stream_executor::IsLibNvJitLinkSupported() &&
+ debug_options.xla_gpu_enable_libnvjitlink()) {
+ return se::PtxLinkingMethod::kNvJitLink;
+ }
+
TF_ASSIGN_OR_RETURN(auto asm_compiler_version,
GetAsmCompilerVersion(debug_options, preferred_cuda_dir));
@@ -859,28 +912,60 @@
}
absl::StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
- se::GpuComputeCapability cc, se::StreamExecutor* stream_exec,
- std::vector<std::vector<uint8_t>> modules,
+ se::GpuComputeCapability compute_capability,
+ se::StreamExecutor* stream_exec, std::vector<std::vector<uint8_t>> modules,
const DebugOptions& debug_options) {
if (modules.empty()) return std::vector<uint8_t>{};
+ auto cc =
+ std::get<stream_executor::CudaComputeCapability>(compute_capability);
+
TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
ChooseLinkingMethod(debug_options));
VLOG(1) << "Linking " << modules.size()
<< " modules with linking method: " << linking_method;
- std::vector<stream_executor::CubinOrPTXImage> images;
- images.reserve(modules.size());
- for (std::vector<uint8_t>& module : modules) {
- images.push_back({"", std::move(module)});
+ if (linking_method == se::PtxLinkingMethod::kNvJitLink) {
+ const auto module_contains_ptx =
+ [](const std::vector<uint8_t>& module) -> bool {
+ return module.size() >= sizeof(kPtxPrefix) &&
+ std::equal(std::begin(kPtxPrefix), std::end(kPtxPrefix),
+ std::begin(module));
+ };
+
+ std::vector<stream_executor::NvJitLinkInput> nvjitlink_inputs;
+ nvjitlink_inputs.reserve(modules.size());
+ for (std::vector<uint8_t>& module : modules) {
+ if (module_contains_ptx(module)) {
+ nvjitlink_inputs.push_back(
+ {se::NvJitLinkInput::Type::kPtx,
+ absl::Span<const uint8_t>(module).subspan(sizeof(kPtxPrefix))});
+ } else {
+ nvjitlink_inputs.push_back({se::NvJitLinkInput::Type::kCubin, module});
+ }
+ }
+
+ se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options);
+ return stream_executor::CompileAndLinkUsingLibNvJitLink(
+ cc.major, cc.minor, nvjitlink_inputs, ptxas_config,
+ /*cancel_if_reg_spill=*/false);
}
+
+ std::vector<stream_executor::CubinOrPTXImage> cubin_images;
+ cubin_images.reserve(modules.size());
+ for (std::vector<uint8_t>& module : modules) {
+    std::string profile = absl::StrCat("sm_", cc.major, cc.minor);
+    cubin_images.push_back({std::move(profile), std::move(module)});
+ }
+
auto context = se::gpu::ExtractGpuExecutor(stream_exec)->gpu_context();
if (linking_method == se::PtxLinkingMethod::kNvLink) {
- return LinkUsingNvlink(std::get<se::CudaComputeCapability>(cc),
- debug_options.xla_gpu_cuda_data_dir(), context,
- images);
+ return LinkUsingNvlink(cc, debug_options.xla_gpu_cuda_data_dir(), context,
+ cubin_images);
}
- return LinkGpuAsm(std::get<se::CudaComputeCapability>(cc), context, images);
+ return LinkGpuAsm(cc, context, cubin_images);
}
} // namespace gpu
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h
index 25fa268..6d84deb 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.h
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h
@@ -32,7 +32,7 @@
#include "xla/autotune_results.pb.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/hlo_dataflow_analysis.h"
@@ -84,9 +84,9 @@
absl::Status AddCustomKernelReplacementPasses(
HloPassPipeline* pipeline, const DebugOptions& debug_options) override;
- absl::Status RunCudnnFusionCompilerPass(
- HloModule* module, se::StreamExecutor* stream_exec,
- BinaryMap* dnn_compiled_graphs) override;
+ absl::Status RunCudnnCompilerPasses(HloModule* module,
+ se::StreamExecutor* stream_exec,
+ BinaryMap* dnn_compiled_graphs) override;
HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override;
@@ -100,7 +100,8 @@
private:
absl::StatusOr<std::vector<uint8_t>> LinkModules(
- se::GpuComputeCapability cc, se::StreamExecutor* stream_exec,
+ se::GpuComputeCapability gpu_compute_capability,
+ se::StreamExecutor* stream_exec,
std::vector<std::vector<uint8_t>> modules,
const DebugOptions& debug_options) override;
diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc
deleted file mode 100644
index d0e841c..0000000
--- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc
+++ /dev/null
@@ -1,703 +0,0 @@
-
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/pipelined_p2p_rewriter.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-// Maps a computation to a boolean that indicates whether any collective
-// operations are directly or indirectly invoked in the computation.
-using CollectiveInComputation =
- absl::flat_hash_map<const HloComputation*, bool>;
-
-using InstructionVector = HloInstruction::InstructionVector;
-
-// Records the starting index and the ending index of a pipelined while-op.
-// They are indices into the operands of the while-init tuple.
-struct PipelinedP2PInfo {
- int64_t opnd_start;
- int64_t opnd_end;
-};
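-// Illustrative example (added for clarity, not from the original source):
-// for a while-init such as
-//   init = tuple(c0, send-done.0, recv-done.0, send-done.1, recv-done.1, p0)
-// the pipelined block spans operands 1 through 4, so opnd_start = 1 and
-// opnd_end = 5 (opnd_end is exclusive; see `opnd_end = i + 1` below).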
-
-// Returns whether the instruction is a collective operation.
-bool IsCollectiveOp(const HloInstruction* op) {
- HloOpcode opcode = op->opcode();
-  // TODO(NVIDIA/4364298): The information is recorded in b/309639264. We need
-  // to prevent custom-calls from overlapping with Send/Recv to work around the
-  // bug. Remove custom-calls here when the bug is fixed.
- if (opcode == HloOpcode::kCustomCall) {
- return true;
- }
-
- return hlo_query::IsCollectiveCommunicationOp(opcode) ||
- opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv;
-}
-
-// Returns whether the instruction may invoke collective operations directly
-// or indirectly.
-bool MayInvokeCollectiveOp(
- const HloInstruction* hlo,
- const CollectiveInComputation& collective_in_computation) {
- if (IsCollectiveOp(hlo)) {
- return true;
- }
- for (HloComputation* callee : hlo->called_computations()) {
- auto collective_in_comp = collective_in_computation.find(callee);
- CHECK(collective_in_comp != collective_in_computation.end());
- if (collective_in_comp->second) {
- return true;
- }
- }
- return false;
-}
-
-// Returns the unique get-tuple-element user with the given idx or nullptr if
-// there isn't such a unique user.
-HloInstruction* FindUniqueGTEUserWithIndex(const HloInstruction* op,
- int64_t idx) {
- CHECK(op->shape().IsTuple());
-
- HloInstruction* gte = nullptr;
- for (auto user : op->users()) {
- if (user->opcode() != HloOpcode::kGetTupleElement) {
- continue;
- }
- if (user->tuple_index() == idx) {
- if (gte == nullptr) {
- gte = user;
- } else {
- return nullptr;
- }
- }
- }
- return gte;
-}
-
-// Returns whether there is any get-tuple-element user with the given idx.
-bool HasGTEUserWithIndex(const HloInstruction* op, int64_t idx) {
- CHECK(op->shape().IsTuple());
-
- for (auto user : op->users()) {
- if (user->opcode() != HloOpcode::kGetTupleElement) {
- continue;
- }
- if (user->tuple_index() == idx) {
- return true;
- }
- }
- return false;
-}
-
-// Returns the instruction hidden behind a trivial tuple or `op`. This allows
-// the discovery of recv-done for the following case, for which the indirection
-// would have been removed by tuple-simplification.
-//   gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
-//   gte.1 = token[] get-tuple-element(recv-done.p), index=1
-//   op = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
-//
-// TODO(bixia): investigate the possibility of implementing
-// m::TrivialTuple(m::RecvDone(&instr)) as suggested by code review.
-HloInstruction* MaySkipTrivialTuple(HloInstruction* op) {
- if (op->opcode() != HloOpcode::kTuple) {
- return op;
- }
- HloInstruction* hidden_op = nullptr;
- for (auto opnd : op->mutable_operands()) {
- if (opnd->opcode() != HloOpcode::kGetTupleElement) {
- return op;
- }
- if (hidden_op == nullptr) {
- hidden_op = opnd->mutable_operand(0);
- } else if (opnd->mutable_operand(0) != hidden_op) {
- return op;
- }
- }
- return hidden_op;
-}
-
-// This routine is similar to the non-const version above, except that the
-// given instruction is used for pattern checking only and can't be mutated.
-const HloInstruction* MaySkipTrivialTuple(const HloInstruction* op) {
-  // Use const_cast to avoid duplicating the non-const version above, which
-  // would only differ by traversing operands() instead of mutable_operands().
- return MaySkipTrivialTuple(const_cast<HloInstruction*>(op));
-}
-
-// Finds a consecutive block of balanced SendDone/RecvDone in the while_init
-// of a while-loop, assuming its while_init is a tuple.
-std::optional<PipelinedP2PInfo>
-FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(
- const HloInstruction* while_init) {
- PipelinedP2PInfo pipelined_p2p_info{0, 0};
- // Return whether the first SendDone/RecvDone has been seen.
- auto has_started = [&]() {
- return pipelined_p2p_info.opnd_start != pipelined_p2p_info.opnd_end;
- };
- // Record the difference between the number of SendDone and RecvDone in a
- // consecutive block.
- int difference = 0;
-  // If SendDone/RecvDone ops exist in a consecutive block in the while_init
-  // tuple, find such a block.
- for (int64_t i = 0; i < while_init->operand_count(); ++i) {
- const HloInstruction* op = while_init->operand(i);
- if ((op->opcode() == HloOpcode::kRecvDone ||
- op->opcode() == HloOpcode::kSendDone) &&
- op->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) {
- if (op->opcode() == HloOpcode::kRecvDone) {
- difference++;
- } else {
- difference--;
- }
- if (!has_started()) {
- pipelined_p2p_info.opnd_start = i;
- }
- pipelined_p2p_info.opnd_end = i + 1;
- } else {
- if (has_started()) {
- VLOG(10) << "End a consecutive block";
- break;
- }
- }
- }
-
- if (difference != 0) {
- VLOG(10) << "Mismatch number of SendDone and RecvDone: " << difference;
- return std::nullopt;
- }
-
- if (has_started()) {
- // Check for SendDone/RecvDone outside the consecutive block.
- for (int64_t i = pipelined_p2p_info.opnd_end;
- i < while_init->operand_count(); ++i) {
- const HloInstruction* op = while_init->operand(i);
- if (op->opcode() == HloOpcode::kRecvDone ||
- op->opcode() == HloOpcode::kSendDone) {
- VLOG(10) << "SendDone/RecvDone outside the consecutive block";
- return std::nullopt;
- }
- }
- }
-
- if (!has_started()) {
- VLOG(10) << "No SendDone/RecvDone in while-init ";
- return std::nullopt;
- }
-
- return pipelined_p2p_info;
-}
-
-// Checks whether the while-op, its while-body and while-condition have a
-// recognized pipelined pattern. If a pipelined pattern is found, returns the
-// first and last indices of the pipelined instructions in the while-init tuple.
-// For pipelined Send/Recv to work, the SendDone/RecvDone ops don't have to be
-// in a consecutive block, but requiring one simplifies the implementation and
-// matches the pattern that the current gpu-p2p-pipeliner generates.
-//
-// As a summary, this is what the routine looks for:
-//
-// . The while-init is a tuple with a single user.
-// . The while-init has a consecutive block of SendDone and RecvDone. The
-// numbers of SendDone and RecvDone are the same, and there aren't any other
-// SendDone or RecvDone ops outside the block.
-// . The while-body has a single tuple parameter.
-// . For the while-op result tuple and the while-body parameter tuple:
-// The index corresponding to the index of SendDone in while-init should not
-// correspond to any get-tuple-element user.
-// The index corresponding to the index of RecvDone in while-init should
-// correspond to a single get-tuple-element user.
-// . In the while-body result tuple, the operand at an index corresponding to
-// a SendDone/RecvDone in the while-init should also be a SendDone or
-// RecvDone.
-//
-// TODO(bixia): support pipelined SendDone/RecvDone not in a consecutive block
-// if the gpu-p2p-pipeliner will ever generate such code in the future.
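-//
-// Illustrative sketch of the recognized shape (hypothetical HLO, added for
-// clarity and not taken from the original source):
-//
-//   init = (...) tuple(..., send-done.0, recv-done.0, ...)
-//   while-result = (...) while(init), condition=cond, body=body
-//   gte = f32[...] get-tuple-element(while-result), index=<recv-done.0 index>
-//
-// That is, the while-init carries the balanced SendDone/RecvDone block, only
-// the RecvDone positions are read back via get-tuple-element, and the
-// while-body root recreates SendDone/RecvDone at the same tuple positions.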
-std::optional<PipelinedP2PInfo> FindPipelinedP2P(
- const HloInstruction* while_op) {
- VLOG(10) << "while_op: " << while_op->ToString();
- const HloInstruction* while_init = while_op->while_init();
- if (while_init->opcode() != HloOpcode::kTuple ||
- while_init->user_count() != 1) {
- return std::nullopt;
- }
-
- // The while-body and while-condition should have one parameter of a tuple
- // shape.
- const HloComputation* while_body = while_op->while_body();
- const HloComputation* while_condition = while_op->while_condition();
- if (while_body->num_parameters() != 1 ||
- while_condition->num_parameters() != 1) {
- return std::nullopt;
- }
-
- std::optional<PipelinedP2PInfo> pipelined_p2p_info =
- FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(while_init);
- if (!pipelined_p2p_info.has_value()) {
- return std::nullopt;
- }
-
- VLOG(10) << "opnd_start " << pipelined_p2p_info->opnd_start << " opnd_end "
- << pipelined_p2p_info->opnd_end;
-
- // In the while-result or while-body parameter, the index for RecvDone should
- // correspond to one get-tuple-element user and the index for SendDone should
- // not correspond to any get-tuple-element user.
- for (int64_t i = pipelined_p2p_info->opnd_start;
- i < pipelined_p2p_info->opnd_end; ++i) {
- const HloInstruction* op = while_init->operand(i);
- if (op->opcode() == HloOpcode::kRecvDone) {
- if (!FindUniqueGTEUserWithIndex(while_op, i)) {
- VLOG(10) << "While result get-tuple-element user with index " << i
- << " not unique";
- return std::nullopt;
- }
- if (!FindUniqueGTEUserWithIndex(while_body->parameter_instruction(0),
- i)) {
- VLOG(10) << "While-body parameter get-tuple-element user with index "
- << i << " not unique";
- return std::nullopt;
- }
- } else {
- CHECK(op->opcode() == HloOpcode::kSendDone);
- if (HasGTEUserWithIndex(while_op, i) ||
- HasGTEUserWithIndex(while_body->parameter_instruction(0), i)) {
- VLOG(10) << "SendDone with index " << i << " has unexpected users";
- return std::nullopt;
- }
- }
- }
-
-  // Each element in the while-body result tuple that corresponds to a
-  // pipelined SendDone/RecvDone in the while-init must have the same opcode
-  // as the matching while-init operand.
- const HloInstruction* root = while_body->root_instruction();
- for (int64_t i = pipelined_p2p_info->opnd_start;
- i < pipelined_p2p_info->opnd_end; ++i) {
- const HloInstruction* op_init = while_init->operand(i);
- const HloInstruction* op_root = root->operand(i);
- op_root = MaySkipTrivialTuple(op_root);
- if (op_init->opcode() != op_root->opcode()) {
- VLOG(10) << "Mismatching opcode, op_init: " << op_init->ToString()
- << " op_root: " << op_root->ToString();
- return std::nullopt;
- }
- }
-
- return pipelined_p2p_info.value();
-}
-
-absl::Status RemoveOpFromParent(HloInstruction* op) {
- TF_RETURN_IF_ERROR(op->DropAllControlDeps());
- TF_RETURN_IF_ERROR(op->parent()->RemoveInstruction(op));
- return absl::OkStatus();
-}
-
-absl::Status ReplaceOpInSequence(HloInstruction* old_op, HloInstruction* new_op,
- HloInstructionSequence& instruction_sequence) {
- VLOG(10) << "old_op: " << old_op->ToString();
- VLOG(10) << "new_op: " << new_op->ToString();
- instruction_sequence.replace_instruction(old_op, new_op);
- return RemoveOpFromParent(old_op);
-}
-
-absl::Status ReplaceUsesAndUpdateSequence(
- HloInstruction* old_op, HloInstruction* new_op,
- HloInstructionSequence& instruction_sequence, bool diff_shape = false) {
- VLOG(10) << "old_op: " << old_op->ToString();
- VLOG(10) << "new_op: " << new_op->ToString();
- if (diff_shape) {
- TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWithDifferentShape(new_op));
- } else {
- TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWith(new_op));
- }
- return ReplaceOpInSequence(old_op, new_op, instruction_sequence);
-}
-
-absl::Status ReplaceUsesAndUpdateSequence(
- const InstructionVector& old_ops, const InstructionVector& new_ops,
- HloInstructionSequence& instruction_sequence) {
- CHECK(old_ops.size() == new_ops.size());
- for (int64_t i = 0; i < old_ops.size(); ++i) {
- TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(old_ops[i], new_ops[i],
- instruction_sequence));
- }
- return absl::OkStatus();
-}
-
-absl::Status RemoveDoneOpsAndUpdateSequence(
- const InstructionVector& ops,
- HloInstructionSequence& instruction_sequence) {
- auto remove_op = [&](HloInstruction* op) {
- VLOG(10) << "op: " << op->ToString();
- TF_RETURN_IF_ERROR(RemoveOpFromParent(op));
- instruction_sequence.remove_instruction(op);
- return absl::OkStatus();
- };
- for (auto op : ops) {
- if (op->opcode() == HloOpcode::kTuple) {
- InstructionVector to_remove;
- HloInstruction* tuple_op = op;
- op = MaySkipTrivialTuple(tuple_op);
- to_remove.push_back(tuple_op);
- for (auto opnd : tuple_op->mutable_operands()) {
- to_remove.push_back(opnd);
- }
- for (auto opnd : to_remove) {
- TF_RETURN_IF_ERROR(remove_op(opnd));
- }
- }
- TF_RETURN_IF_ERROR(remove_op(op));
- }
- return absl::OkStatus();
-}
-
-bool InsertBeforeFirstCollectiveOp(
- const InstructionVector& ops,
- const CollectiveInComputation& collective_in_computation,
- HloInstructionSequence& instruction_sequence, int64_t& idx,
- int64_t& idx_tot) {
- bool inserted = false;
- while (idx < idx_tot) {
- HloInstruction* hlo = instruction_sequence.instructions()[idx];
- if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
- for (auto op : ops) {
- instruction_sequence.insert_instruction(op, idx);
- idx++;
- idx_tot++;
- }
- inserted = true;
- break;
- }
- idx++;
- }
- return inserted;
-}
-
-void CopyInstructionInfo(const HloInstruction* old_op, HloInstruction* new_op) {
- new_op->SetAndSanitizeName(absl::StrCat(old_op->name(), ".clone"));
- new_op->set_metadata(old_op->metadata());
- new_op->add_frontend_attributes(old_op->frontend_attributes());
- new_op->CopyBackendConfigFrom(old_op);
-}
-
-HloInstruction* CreateRecvDoneFrom(const HloInstruction* old_recv_done,
- HloInstruction* recv,
- HloComputation* computation) {
- HloInstruction* recv_done =
- computation->AddInstruction(HloInstruction::CreateRecvDone(
- recv, old_recv_done->channel_id().value()));
- CopyInstructionInfo(old_recv_done, recv_done);
- return recv_done;
-}
-
-HloInstruction* CreateSendDoneFrom(const HloInstruction* old_send_done,
- HloInstruction* send,
- HloComputation* computation) {
- HloInstruction* send_done =
- computation->AddInstruction(HloInstruction::CreateSendDone(
- send, old_send_done->channel_id().value()));
- CopyInstructionInfo(old_send_done, send_done);
- return send_done;
-}
-
-absl::Status RewritePipelinedP2PWhileBody(
- const CollectiveInComputation& collective_in_computation,
- const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op,
- int64_t opnd_start, int64_t opnd_end) {
- HloComputation* computation = while_op->while_body();
- HloInstruction* while_init = while_op->while_init();
- HloInstruction* root = computation->root_instruction();
- HloInstructionSequence& instruction_sequence =
- computation->parent()->schedule().GetOrCreateSequence(computation);
-
- HloInstruction* param = computation->parameter_instruction(0);
- *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
-
- InstructionVector recv_dones;
- InstructionVector new_recv_dones;
- InstructionVector new_send_dones;
- for (int64_t i = opnd_start; i < opnd_end; ++i) {
- const HloInstruction* op = root->operand(i);
- op = MaySkipTrivialTuple(op);
- if (op->opcode() == HloOpcode::kRecvDone) {
- HloInstruction* gte = FindUniqueGTEUserWithIndex(param, i);
- CHECK(gte != nullptr);
- recv_dones.push_back(gte);
-
- // Create the new RecvDone using the new while-body parameter.
- HloInstruction* recv = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(param, i));
-
- HloInstruction* recv_done = CreateRecvDoneFrom(op, recv, computation);
- new_recv_dones.push_back(recv_done);
- continue;
- }
- CHECK(op->opcode() == HloOpcode::kSendDone);
-    // Create the new SendDone using the new while-body parameter.
- HloInstruction* send = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(param, i));
- HloInstruction* send_done = CreateSendDoneFrom(op, send, computation);
- new_send_dones.push_back(send_done);
- }
- TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
- instruction_sequence));
-
- // Create a new root tuple.
- InstructionVector done_ops;
- InstructionVector new_opnds;
- for (int64_t i = 0; i < while_init->operand_count(); ++i) {
- HloInstruction* op = root->mutable_operand(i);
- if (i >= opnd_start && i < opnd_end) {
- new_opnds.push_back(MaySkipTrivialTuple(op)->mutable_operand(0));
- done_ops.push_back(op);
- } else {
- new_opnds.push_back(op);
- }
- }
- HloInstruction* new_root =
- computation->AddInstruction(HloInstruction::CreateTuple(new_opnds));
- computation->set_root_instruction(new_root,
- /*accept_different_shape=*/true);
- TF_RETURN_IF_ERROR(computation->RemoveInstruction(root));
- instruction_sequence.replace_instruction(root, new_root);
-
- TF_RETURN_IF_ERROR(
- RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
-
-  // Find a place to put the new SendDones. The insertion point will be either
-  // the first may-invoke-collective op that is not in the pipelined Send/Recv
-  // chain or the first op in the pipelined Send/Recv chain.
- int64_t idx = 0;
- int64_t idx_end = instruction_sequence.size();
- bool inserted =
- InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
- instruction_sequence, idx, idx_end);
-  CHECK(inserted);  // The while-body has Send/Recv, so insertion is expected.
- CHECK(idx_end == instruction_sequence.size());
-
- // The module schedule will be updated at the end of the pass.
- return absl::OkStatus();
-}
-
-void RewritePipelinedP2PWhileCond(
- const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op) {
- HloComputation* computation = while_op->while_condition();
- HloInstruction* param = computation->parameter_instruction(0);
- *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
- VLOG(10) << computation->ToString();
-}
-
-// Rewrites the while-op with a recognized pipelined SendDone/RecvDone pattern
-// to pipeline Send/Recv instead.
-absl::Status TransformLoop(
- const PipelinedP2PInfo& pipelined_info,
- const CollectiveInComputation& collective_in_computation, int64_t& idx,
- int64_t& idx_end, HloInstructionSequence& instruction_sequence,
- HloInstruction* while_op) {
- HloComputation* computation = while_op->parent();
- int64_t opnd_start = pipelined_info.opnd_start;
- int64_t opnd_end = pipelined_info.opnd_end;
- VLOG(10) << "Transform pipelined while-op " << while_op->ToString();
- HloInstruction* while_init = while_op->while_init();
- InstructionVector new_while_init_opnds;
- std::vector<Shape> new_parameter_shapes;
- for (int64_t i = 0; i < while_init->operand_count(); ++i) {
- HloInstruction* op = while_init->mutable_operand(i);
- if (i >= opnd_start && i < opnd_end) {
- // Get Send/Recv from SendDone/RecvDone.
- new_while_init_opnds.push_back(op->mutable_operand(0));
- } else {
- new_while_init_opnds.push_back(op);
- }
- new_parameter_shapes.push_back(new_while_init_opnds.back()->shape());
- }
-
- RewritePipelinedP2PWhileCond(new_parameter_shapes, while_op);
- TF_RETURN_IF_ERROR(RewritePipelinedP2PWhileBody(
- collective_in_computation, new_parameter_shapes, while_op, opnd_start,
- opnd_end));
- HloInstruction* new_while_init = computation->AddInstruction(
- HloInstruction::CreateTuple(new_while_init_opnds), "while-init");
- VLOG(10) << "new_while_init: " << new_while_init->ToString();
- HloInstruction* new_while_op = computation->AddInstruction(
- HloInstruction::CreateWhile(
- while_op->while_body()->root_instruction()->shape(),
- while_op->while_condition(), while_op->while_body(), new_while_init),
- "while-result");
- CopyInstructionInfo(while_op, new_while_op);
- VLOG(10) << "new_while_op: " << new_while_op->ToString();
-
- InstructionVector recv_dones;
- InstructionVector new_recv_dones;
- InstructionVector new_send_dones;
- InstructionVector done_ops;
- for (int64_t i = opnd_start; i < opnd_end; ++i) {
- HloInstruction* op = while_init->mutable_operand(i);
- done_ops.push_back(op);
- if (op->opcode() == HloOpcode::kRecvDone) {
- HloInstruction* gte = FindUniqueGTEUserWithIndex(while_op, i);
- CHECK(gte != nullptr);
- recv_dones.push_back(gte);
-
- // Create the new RecvDone using the new while-op result.
- HloInstruction* recv = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(new_while_op, i));
- HloInstruction* recv_done = computation->AddInstruction(
- HloInstruction::CreateRecvDone(recv, op->channel_id().value()));
- new_recv_dones.push_back(recv_done);
- CopyInstructionInfo(op, recv_done);
- continue;
- }
- CHECK(op->opcode() == HloOpcode::kSendDone);
- // Create the new SendDone using the new while-op result.
- HloInstruction* send = computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(new_while_op, i));
- HloInstruction* send_done = computation->AddInstruction(
- HloInstruction::CreateSendDone(send, op->channel_id().value()));
- new_send_dones.push_back(send_done);
- CopyInstructionInfo(op, send_done);
- }
-
- TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(
- while_op, new_while_op, instruction_sequence, /*diff_shape*/ true));
- TF_RETURN_IF_ERROR(
- ReplaceOpInSequence(while_init, new_while_init, instruction_sequence));
- TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
- instruction_sequence));
- TF_RETURN_IF_ERROR(
- RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
-
- int64_t opnd_tot = opnd_end - opnd_start;
-  // Verify that the number of ops we have removed from the sequence is
-  // opnd_tot and that they were before the position of the new while-op.
- CHECK(idx_end == instruction_sequence.size() + opnd_tot);
- CHECK(instruction_sequence.instructions()[idx - opnd_tot] == new_while_op);
-
- // Update idx_end to reflect the current size of the instruction sequence.
- // Update idx to right after the new while-op.
- idx_end -= opnd_tot;
- idx = idx - opnd_tot + 1;
- bool inserted =
- InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
- instruction_sequence, idx, idx_end);
- CHECK(idx_end == instruction_sequence.size());
-  // If there aren't any may-invoke-collective ops after the while-op, add
- // the new SendDone ops before the last instruction in the sequence.
- if (!inserted) {
- CHECK(idx_end == idx);
- idx--;
- for (auto send_done : new_send_dones) {
- instruction_sequence.insert_instruction(send_done, idx++);
- }
- }
- return absl::OkStatus();
-}
-
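As a concrete illustration of the index bookkeeping in TransformLoop above (the numbers are hypothetical): suppose the scheduled sequence originally holds 100 instructions, idx points at the old while-op at position 40, and opnd_tot is 4 (two RecvDone/SendDone pairs). After the four done ops are removed, the sequence holds 96 instructions, so the check idx_end == instruction_sequence.size() + opnd_tot compares 100 against 96 + 4; the new while-op now sits at position idx - opnd_tot = 36; idx_end is lowered to 96 and idx becomes 37, the slot right after the new while-op, which is where the search for a place to insert the new SendDone ops resumes.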
-// Finds while-loops with pipelined Send/Recv and rotates the SendDone/RecvDone
-// ops for such while-loops.
-absl::StatusOr<bool> ProcessComputation(
- HloModule* module, HloComputation* computation,
- CollectiveInComputation& collective_in_computation) {
- VLOG(10) << "Process compuation " << computation->name();
- bool changed = false;
- HloInstructionSequence& instruction_sequence =
- module->schedule().GetOrCreateSequence(computation);
- int64_t idx = 0;
- int64_t idx_end = instruction_sequence.size();
- while (idx < idx_end) {
- HloInstruction* hlo = instruction_sequence.instructions()[idx];
-
- if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
- collective_in_computation[computation] = true;
- }
-
- if (hlo->opcode() != HloOpcode::kWhile) {
- idx++;
- continue;
- }
-
- std::optional<PipelinedP2PInfo> pipelined_info = FindPipelinedP2P(hlo);
- if (!pipelined_info.has_value()) {
- idx++;
- continue;
- }
- TF_RETURN_IF_ERROR(TransformLoop(pipelined_info.value(),
- collective_in_computation, idx, idx_end,
- instruction_sequence, hlo));
- changed = true;
- }
- return changed;
-}
-} // namespace
-
-absl::StatusOr<bool> PipelinedP2PRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- if (!module->has_schedule()) return changed;
- CollectiveInComputation collective_in_computation;
-  // Visit the computations in order from callees to callers, so that a
-  // while-body is processed before the while-op that calls it.
- for (auto* computation :
- module->MakeComputationPostOrder(execution_threads)) {
- if (computation->IsFusionComputation()) {
- collective_in_computation[computation] = false;
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(
- bool cur_changed,
- ProcessComputation(module, computation, collective_in_computation));
- changed |= cur_changed;
- }
-
- if (changed) {
- TF_RETURN_IF_ERROR(module->schedule().Update());
- }
-
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
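The callees-before-callers ordering used in Run is what guarantees that a while-body's sequence is rewritten (and its collective usage recorded) before the computation that calls it is scanned. A small, self-contained sketch of that ordering on a toy call graph, using plain STL types rather than MakeComputationPostOrder; all names are illustrative:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy call graph: the entry computation calls a while-body and a while-cond.
using CallGraph = std::unordered_map<std::string, std::vector<std::string>>;

void PostOrder(const CallGraph& calls, const std::string& node,
               std::unordered_set<std::string>& visited,
               std::vector<std::string>& order) {
  if (!visited.insert(node).second) return;
  auto it = calls.find(node);
  if (it != calls.end()) {
    for (const std::string& callee : it->second) {
      PostOrder(calls, callee, visited, order);  // Visit callees first.
    }
  }
  order.push_back(node);  // A caller is emitted only after all its callees.
}

int main() {
  CallGraph calls = {{"entry", {"while_body", "while_cond"}}};
  std::unordered_set<std::string> visited;
  std::vector<std::string> order;
  PostOrder(calls, "entry", visited, order);
  // Prints while_body and while_cond before entry, mirroring the
  // callees-to-callers order the pass relies on.
  for (const std::string& name : order) std::cout << name << "\n";
}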
diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h
deleted file mode 100644
index 88b6bb6..0000000
--- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_
-#define XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_
-
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// PipelinedP2PRewriter is a pass that rewrites pipelined Send/Recv-related
-// code for point-to-point communication, rotating SendDone and RecvDone from
-// the end of a loop iteration to the beginning of the next iteration. This
-// pass operates on a scheduled module and updates the instruction sequence.
-//
-// In particular, a pipelined Send/Recv chain with one channel group with this
-// code pattern:
-//
-// main:
-// recv
-// send
-// recv-done
-// send-done
-// while-init = (recv-done, send-done, ...)
-//    while-op = while(while-init) ...
-//
-// while-body:
-// ...
-// recv
-// send
-// recv-done
-// send-done
-// ROOT tuple(recv-done, send-done, ...)
-//
-// Will be transformed to:
-//
-// main:
-// recv
-// send
-// while-init = (recv, send, ...)
-//    while-op = while(while-init) ...
-// recv-done
-// send-done
-//
-// while-body:
-// recv-done
-// ...
-// send-done
-// recv
-// send
-// ROOT tuple(recv, send, ...)
-//
-// A pipelined Send/Recv chain with two channel groups with this code pattern:
-//
-// main:
-// recv.0
-// send.0
-// recv.1
-// send.1
-// recv-done.0
-// send-done.0
-// recv-done.1
-// send-done.1
-// while-init = (recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
-//    while-op = while(while-init) ...
-//
-// while-body:
-// ...
-// recv.0
-// send.0
-// recv.1
-// send.1
-// recv-done.0
-// send-done.0
-// recv-done.1
-// send-done.1
-// ROOT = tuple(recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
-//
-// Will be transformed to:
-//
-// main:
-//
-// recv.0
-// send.0
-// recv.1
-// send.1
-// while-init = (recv.0, send.0, recv.1, send.1, ...)
-// while-op = while(while-init) ...
-// recv-done.0
-// send-done.0
-// recv-done.1
-// send-done.1
-//
-// while-body:
-// recv-done.0
-// recv-done.1
-// ...
-// send-done.0
-// send-done.1
-// recv.0
-//    send.0
-// recv.1
-// send.1
-// ROOT tuple(recv.0, send.0, recv.1, send.1, ...)
-//
-class PipelinedP2PRewriter : public HloModulePass {
- public:
- absl::string_view name() const override { return "pipelined-p2p-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_
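As a usage note, the (now removed) tests below drive the pass the same way any other HloModulePass is driven. A short sketch, assuming a parsed and scheduled module held in `module` and the usual status-checking macros:

  PipelinedP2PRewriter rewriter;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
  // `changed` stays false when the module has no schedule or when no
  // pipelined Send/Recv pattern is recognized.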
diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc
deleted file mode 100644
index a0d5830..0000000
--- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc
+++ /dev/null
@@ -1,674 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/pipelined_p2p_rewriter.h"
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class PipelinedP2pRewriterTest : public HloTestBase {
- protected:
- void DoFileCheck(const HloModule* module, absl::string_view expected) {
- HloPrintOptions options;
- options.set_print_operand_shape(false);
- options.set_print_result_shape(false);
- TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
- RunFileCheck(module->ToString(options), expected));
- EXPECT_TRUE(filecheck_matched);
- }
-};
-
-TEST_F(PipelinedP2pRewriterTest, SendRecvUnpipelinedNotTransform) {
- const char* kModuleStr = R"(
-HloModule test
-
-cond {
- param = (u32[], u32[2]) parameter(0)
- count = get-tuple-element(%param), index=0
- ub = u32[] constant(11)
- ROOT result = pred[] compare(count, ub), direction=LT
- }
-
-body {
- param = (u32[], u32[2]) parameter(0)
- count = get-tuple-element(param), index=0
- send-data = u32[2] get-tuple-element(param), index=1
-
- after-all.0.n = token[] after-all()
- recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{3,0}}",
- _xla_send_recv_pipeline="0"
- }
- send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.n),
- channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{3,0}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.0 = token[] send-done(send.0), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
-
- recv-data = u32[2] get-tuple-element(recv-done.0), index=0
-
- c1 = u32[] constant(1)
- new_count = u32[] add(count, c1)
-
- r = u32[2] broadcast(c1), dimensions={}
- s = u32[2] add(r, recv-data)
-
- ROOT result = (u32[], u32[2]) tuple(new_count, s)
- }
-
- ENTRY test_computation {
- c0 = u32[] constant(0)
- c1 = u32[] constant(1)
- r = u32[] replica-id()
- a = u32[] add(c1, r)
- init = u32[2] broadcast(a), dimensions={}
- while_init = (u32[], u32[2]) tuple(c0, init)
- while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond,
- backend_config={"known_trip_count":{"n":"11"}}
- ROOT recv-data = u32[2] get-tuple-element(while_result), index=1
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr));
- PipelinedP2PRewriter rewriter;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
- EXPECT_FALSE(changed);
-}
-
-// Tests the rewrite for a pipelined Send/Recv chain with only one channel
-// group.
-TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) {
- const char* kModuleStr = R"(
- HloModule test, is_scheduled=true
-
- while-cond {
- param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
- ub = u32[] constant(25)
- ROOT cond-result = pred[] compare(count, ub), direction=LT
- }
-
- while-body {
- param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
-
- recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
- recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
-
- c1 = u32[] constant(1)
- new-count = u32[] add(count, c1)
- replica = u32[] replica-id()
- c10 = u32[] constant(10)
- sum = u32[] add(replica, c10)
- sum2 = u32[] add(sum, count)
- conv = f32[] convert(sum2)
- p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
- b = f32[1, 1024, 1024] add(p, recv-data)
- c = f32[1, 1024, 1024] multiply(b, b)
- d = f32[1, 1024, 1024] tan(c)
- s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
- lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
- send-data = f32[1, 1024, 1024] add(c, s)
-
- after-all = token[] after-all()
- recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
- channel_id=1, frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.p = token[] send-done(send), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
- gte.1 = token[] get-tuple-element(recv-done.p), index=1
- recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
- ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
- tuple(new-count, recv-done-tuple, send-done.p)
- }
-
- ENTRY main {
- c0 = u32[] constant(0)
- f0 = f32[] constant(0.0)
- init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
- after-all.1 = token[] after-all()
- recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.1.p = token[] send-done(send.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- while-init.p = (u32[], (f32[1,1024,1024], token[]), token[])
- tuple(c0, recv-done.1.p, send-done.1.p)
- while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
- while(while-init.p),
- body=while-body, condition=while-cond,
- backend_config={"known_trip_count":{"n":"25"}}
-
- recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
-
- ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
- }
- )";
-
- const char* kExpected = R"(
- CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
- CHECK: %param.1 = parameter(0)
- CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
- CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
- CHECK: %count.1 = get-tuple-element(%param.1), index=0
- CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK: %recv-data = get-tuple-element(%recv-done.p.clone), index=0
- CHECK: %c1 = constant(1)
- CHECK: %new-count = add(%count.1, %c1)
- CHECK: %replica = replica-id()
- CHECK: %c10 = constant(10)
- CHECK: %sum = add(%replica, %c10)
- CHECK: %sum2 = add(%sum, %count.1)
- CHECK: %conv = convert(%sum2)
- CHECK: %p = broadcast(%conv), dimensions={}
- CHECK: %b = add(%p, %recv-data)
- CHECK: %c = multiply(%b, %b)
- CHECK: %d = tan(%c)
- CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
- CHECK: %send-data = add(%c, %s)
- CHECK: %after-all = after-all()
- CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
- CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
- CHECK: ROOT %tuple = tuple(%new-count, %recv, %send)
- CHECK: }
-
- CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
- CHECK: %param = parameter(0)
- CHECK: %count = get-tuple-element(%param), index=0
- CHECK: %ub = constant(25)
- CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
- CHECK: }
-
- CHECK: ENTRY %main () -> f32[1,1024,1024] {
- CHECK: %c0 = constant(0)
- CHECK: %f0 = constant(0)
- CHECK: %init = broadcast(%f0), dimensions={}
- CHECK: %after-all.1 = after-all()
- CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
- CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
- CHECK: %while-init = tuple(%c0, %recv.1, %send.1)
- CHECK: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body,
- CHECK-SAME{LITERAL}: backend_config={"known_trip_count":{"n":"25"}}
- CHECK: %get-tuple-element.2 = get-tuple-element(%while-result.p.clone), index=1
- CHECK: %get-tuple-element.3 = get-tuple-element(%while-result.p.clone), index=2
- CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK: ROOT %entry-result = get-tuple-element(%recv-done.1.p.clone), index=0
- CHECK: })";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr));
- PipelinedP2PRewriter rewriter;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
- EXPECT_TRUE(changed);
-
- DoFileCheck(module.get(), kExpected);
-}
-
-// Repeats the Send/Recv pattern in the previous test, to test that we can
-// rewrite a routine with multiple pipelined loops without crashing.
-TEST_F(PipelinedP2pRewriterTest, SendRecvTwoPipelinedWhileLoops) {
- const char* kModuleStr = R"(
- HloModule test, is_scheduled=true
-
- while-cond {
- param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
- ub = u32[] constant(25)
- ROOT cond-result = pred[] compare(count, ub), direction=LT
- }
-
- while-body {
- param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
-
- recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
- send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
-
- c1 = u32[] constant(1)
- new-count = u32[] add(count, c1)
-
- after-all = token[] after-all()
- recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
- channel_id=1, frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.p = token[] send-done(send), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
- gte.1 = token[] get-tuple-element(recv-done.p), index=1
- recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
- ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
- tuple(new-count, recv-done-tuple, send-done.p)
- }
-
- while-cond-2 {
- param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
- ub = u32[] constant(25)
- ROOT cond-result = pred[] compare(count, ub), direction=LT
- }
-
- while-body-2 {
- param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
-
- recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
- send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
-
- c1 = u32[] constant(1)
- new-count = u32[] add(count, c1)
-
- after-all = token[] after-all()
- recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
- channel_id=1, frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.p = token[] send-done(send), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
- gte.1 = token[] get-tuple-element(recv-done.p), index=1
- recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
- ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
- tuple(new-count, recv-done-tuple, send-done.p)
- }
-
- ENTRY main {
- c0 = u32[] constant(0)
- f0 = f32[] constant(0.0)
- init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
- after-all.1 = token[] after-all()
- recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.1.p = token[] send-done(send.1), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- while-init.p = (u32[], (f32[1,1024,1024], token[]), token[])
- tuple(c0, recv-done.1.p, send-done.1.p)
- while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
- while(while-init.p),
- body=while-body, condition=while-cond,
- backend_config={"known_trip_count":{"n":"25"}}
-
- recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
-
- after-all-2.1 = token[] after-all()
- recv-2.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-2.1), channel_id=2,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- send-2.1 = (f32[1, 1024, 1024], u32[], token[]) send(recv-done.1.q, after-all-2.1), channel_id=2,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done-2.1.p = (f32[1,1024,1024], token[]) recv-done(recv-2.1), channel_id=2,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done-2.1.p = token[] send-done(send-2.1), channel_id=2,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- while-init-2.p = (u32[], (f32[1,1024,1024], token[]), token[])
- tuple(c0, recv-done-2.1.p, send-done-2.1.p)
- while-result-2.p = (u32[], (f32[1,1024,1024], token[]), token[])
- while(while-init-2.p),
- body=while-body-2, condition=while-cond-2,
- backend_config={"known_trip_count":{"n":"25"}}
-
- recv-done-2.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result-2.p), index=1
-
- ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done-2.1.q), index=0
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr));
- PipelinedP2PRewriter rewriter;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
- // Check that we transform the module without crashing.
- EXPECT_TRUE(changed);
-}
-
-// Tests the rewrite for a pipelined Send/Recv chain with two channel groups.
-TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) {
- const char* kModuleStr = R"(
- HloModule test, is_scheduled=true
-
- while-cond {
- param = (u32[], (f32[1,1024,1024], token[]), token[],
- (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
- ub = u32[] constant(25)
- ROOT cond-result = pred[] compare(count, ub), direction=LT
- }
-
- while-body {
- param = (u32[], (f32[1,1024,1024], token[]), token[],
- (f32[1,1024,1024], token[]), token[]) parameter(0)
- count = get-tuple-element(param), index=0
-
- recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
- recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0
- recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3
- recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
-
- replica = u32[] replica-id()
- constant0 = u32[] constant(0)
- compare0 = pred[] compare(replica, constant0), direction=EQ
- compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
- recv-data = f32[1, 1024, 1024] select(compare, recv-data.0, recv-data.1)
-
- c1 = u32[] constant(1)
- new-count = u32[] add(count, c1)
- c10 = u32[] constant(10)
- sum = u32[] add(replica, c10)
- sum2 = u32[] add(sum, count)
- conv = f32[] convert(sum2)
- p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
- b = f32[1, 1024, 1024] add(p, recv-data)
- c = f32[1, 1024, 1024] multiply(b, b)
- d = f32[1, 1024, 1024] tan(c)
- s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
- lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
- send-data = f32[1, 1024, 1024] add(c, s)
-
- after-all = token[] after-all()
- recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{3,0}}",
- _xla_send_recv_pipeline="0"
- }
- send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
- channel_id=1, frontend_attributes={
- _xla_send_recv_source_target_pairs="{{3,0}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.p = token[] send-done(send), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
-
- after-all.1 = token[] after-all()
- recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
- _xla_send_recv_pipeline="1"
- }
- send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1),
- channel_id=2, frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
- _xla_send_recv_pipeline="1"
- }
- recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2,
- frontend_attributes={
- _xla_send_recv_pipeline="1"
- }
- send-done.1.p = token[] send-done(send.1), channel_id=2,
- frontend_attributes={
- _xla_send_recv_pipeline="1"
- }
-
- ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[],
- (f32[1,1024,1024], token[]), token[])
- tuple(new-count, recv-done.p, send-done.p, recv-done.1.p, send-done.1.p)
- }
-
- ENTRY main {
- c0 = u32[] constant(0)
- f0 = f32[] constant(0.0)
- init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
- after-all.2 = token[] after-all()
- recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{3,0}}",
- _xla_send_recv_pipeline="0"
- }
- send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{3,0}}",
- _xla_send_recv_pipeline="0"
- }
- recv-done.2.p = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
- send-done.2.p = token[] send-done(send.2), channel_id=1,
- frontend_attributes={
- _xla_send_recv_pipeline="0"
- }
-
- after-all.3 = token[] after-all()
- recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
- _xla_send_recv_pipeline="1"
- }
- send.3 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.3), channel_id=2,
- frontend_attributes={
- _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
- _xla_send_recv_pipeline="1"
- }
- recv-done.3.p = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2,
- frontend_attributes={
- _xla_send_recv_pipeline="1"
- }
- send-done.3.p = token[] send-done(send.3), channel_id=2,
- frontend_attributes={
- _xla_send_recv_pipeline="1"
- }
-
- while-init.p = (u32[], (f32[1,1024,1024], token[]), token[],
- (f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2.p, send-done.2.p, recv-done.3.p, send-done.3.p)
- while-result.p = (u32[], (f32[1,1024,1024], token[]), token[],
- (f32[1,1024,1024], token[]), token[]) while(while-init.p),
- body=while-body, condition=while-cond,
- backend_config={"known_trip_count":{"n":"25"}}
-
- recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
- recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0
- recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=3
- recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0
-
- replica = u32[] replica-id()
- constant0 = u32[] constant(0)
- compare0 = pred[] compare(replica, constant0), direction=EQ
- compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
- ROOT entry-result = f32[1, 1024, 1024] select(compare, recv-data.2, recv-data.3)
- }
- )";
-
- const char* kExpected = R"(
- CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
- CHECK: %param.1 = parameter(0)
- CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
- CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
- CHECK: %get-tuple-element.2 = get-tuple-element(%param.1), index=3
- CHECK: %get-tuple-element.3 = get-tuple-element(%param.1), index=4
- CHECK: %count.1 = get-tuple-element(%param.1), index=0
- CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK: %recv-data.0 = get-tuple-element(%recv-done.p.clone), index=0
- CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
- CHECK: %recv-data.1 = get-tuple-element(%recv-done.1.p.clone), index=0
- CHECK: %replica = replica-id()
- CHECK: %constant0 = constant(0)
- CHECK: %compare0 = compare(%replica, %constant0), direction=EQ
- CHECK: %compare = broadcast(%compare0), dimensions={}
- CHECK: %recv-data.2 = select(%compare, %recv-data.0, %recv-data.1)
- CHECK: %c1 = constant(1)
- CHECK: %new-count = add(%count.1, %c1)
- CHECK: %c10 = constant(10)
- CHECK: %sum = add(%replica, %c10)
- CHECK: %sum2 = add(%sum, %count.1)
- CHECK: %conv = convert(%sum2)
- CHECK: %p = broadcast(%conv), dimensions={}
- CHECK: %b = add(%p, %recv-data.2)
- CHECK: %c = multiply(%b, %b)
- CHECK: %d = tan(%c)
- CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
- CHECK: %send-data = add(%c, %s)
- CHECK: %after-all = after-all()
- CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
- CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
- CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
- CHECK: %after-all.1 = after-all()
- CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
- CHECK{LITERAL}: %send.1 = send(%send-data, %after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
- CHECK: ROOT %tuple = tuple(%new-count, %recv, %send, %recv.1, %send.1)
- CHECK: }
-
- CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
- CHECK: %param = parameter(0)
- CHECK: %count = get-tuple-element(%param), index=0
- CHECK: %ub = constant(25)
- CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
- CHECK: }
-
- CHECK: ENTRY %main () -> f32[1,1024,1024] {
- CHECK: %c0 = constant(0)
- CHECK: %f0 = constant(0)
- CHECK: %init = broadcast(%f0), dimensions={}
- CHECK: %after-all.2 = after-all()
- CHECK{LITERAL}: %recv.2 = recv(%after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
- CHECK{LITERAL}: %send.2 = send(%init, %after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
- CHECK: %after-all.3 = after-all()
- CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
- CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
- CHECK: %while-init = tuple(%c0, %recv.2, %send.2, %recv.3, %send.3)
- CHECK{LITERAL}: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}}
- CHECK: %get-tuple-element.4 = get-tuple-element(%while-result.p.clone), index=1
- CHECK: %get-tuple-element.5 = get-tuple-element(%while-result.p.clone), index=2
- CHECK: %get-tuple-element.6 = get-tuple-element(%while-result.p.clone), index=3
- CHECK: %get-tuple-element.7 = get-tuple-element(%while-result.p.clone), index=4
- CHECK: %recv-done.2.p.clone = recv-done(%get-tuple-element.4), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK: %recv-data.3 = get-tuple-element(%recv-done.2.p.clone), index=0
- CHECK: %recv-done.3.p.clone = recv-done(%get-tuple-element.6), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
- CHECK: %recv-data.4 = get-tuple-element(%recv-done.3.p.clone), index=0
- CHECK: %replica.1 = replica-id()
- CHECK: %constant0.1 = constant(0)
- CHECK: %compare0.1 = compare(%replica.1, %constant0.1), direction=EQ
- CHECK: %compare.1 = broadcast(%compare0.1), dimensions={}
- CHECK: %send-done.2.p.clone = send-done(%get-tuple-element.5), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
- CHECK: %send-done.3.p.clone = send-done(%get-tuple-element.7), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
- CHECK: ROOT %entry-result = select(%compare.1, %recv-data.3, %recv-data.4)
- CHECK: })";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr));
- PipelinedP2PRewriter rewriter;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
- EXPECT_TRUE(changed);
-
- DoFileCheck(module.get(), kExpected);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc
index 928f4bf..7ced52c 100644
--- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc
+++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc
@@ -22,10 +22,10 @@
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/copy_insertion.h"
#include "xla/service/cpu_gpu_shape_verifier.h"
-#include "xla/service/gpu/alias_passthrough_params.h"
-#include "xla/service/gpu/copy_fusion.h"
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-#include "xla/service/gpu/horizontal_loop_fusion.h"
+#include "xla/service/gpu/transforms/alias_passthrough_params.h"
+#include "xla/service/gpu/transforms/copy_fusion.h"
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
#include "xla/service/hlo_dataflow_analysis.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_pass_pipeline.h"
@@ -78,14 +78,14 @@
}
// We are using a sub-pipeline here, so that the verifier only runs after both
- // GpuHorizontalLoopFusion and HloDCE.
+ // HorizontalLoopFusion and HloDCE.
auto& sub_pipeline =
pipeline.AddPass<HloPassPipeline>("horizontal-loop-fusion-for-copy");
// To fuse the copy.
sub_pipeline.AddPass<CopyFusion>();
- sub_pipeline.AddPass<GpuHorizontalLoopFusion>("copy_");
+ sub_pipeline.AddPass<HorizontalLoopFusion>("copy_");
sub_pipeline.AddPass<HloDCE>();
- pipeline.AddPass<GpuSanitizeConstantNames>();
+ pipeline.AddPass<SanitizeConstantNames>();
return pipeline;
}
diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc
deleted file mode 100644
index d57a83c..0000000
--- a/third_party/xla/xla/service/gpu/priority_fusion.cc
+++ /dev/null
@@ -1,871 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/priority_fusion.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <iterator>
-#include <limits>
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/meta/type_traits.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/dump.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusion_process_dump.pb.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/model/fusion_analysis_cache.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/model/gpu_performance_model.h"
-#include "xla/service/gpu/model/gpu_performance_model_base.h"
-#include "xla/service/gpu/model/symbolic_tile_analysis.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/blocking_counter.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-bool ElementIsF32OrF16(const Shape& shape) {
- PrimitiveType type = shape.element_type();
- return type == F32 || type == F16;
-}
-
-bool IsFusible(const HloInstruction& instr) {
- // Side-effecting operations are not fusible.
- if (!instr.IsFusible()) {
- return false;
- }
-
- // Element-wise operations are always fusible.
- if (instr.IsElementwise()) {
- return true;
- }
-
-  // Other non-elementwise ops are also supported by elemental fusion.
- switch (instr.opcode()) {
- case HloOpcode::kFusion:
- return instr.fusion_kind() != HloInstruction::FusionKind::kCustom;
-
- case HloOpcode::kCopy:
- case HloOpcode::kIota:
- case HloOpcode::kConstant:
- case HloOpcode::kReduce:
- case HloOpcode::kBitcast:
- case HloOpcode::kBroadcast:
- case HloOpcode::kConcatenate:
- case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- case HloOpcode::kGather:
- case HloOpcode::kPad:
- case HloOpcode::kReduceWindow:
- case HloOpcode::kReshape:
- case HloOpcode::kReverse:
- case HloOpcode::kScatter:
- case HloOpcode::kSlice:
- case HloOpcode::kTranspose:
- return true;
- default:
- return false;
- }
-}
-
-// An implementation of FusionQueue that determines whether to fuse instructions
-// according to a cost model, and chooses the next fusion candidate according to
-// dynamically updated priorities. The elements in the queue are producer nodes
-// that could be fused, and the priority of a producer is the benefit in
-// performance when fusing it to all of its fusible users. We greedily pick the
-// max-benefit producer to fuse, and update the estimated benefits of the fused
-// nodes and their operands.
-class GpuPriorityFusionQueue {
- using Priority = int64_t;
- using CanFuseCallback = std::function<FusionDecision(
- HloInstruction* /*producer*/, int64_t /*consumer operand_index*/)>;
-
- public:
- GpuPriorityFusionQueue(
- HloComputation* computation,
- const GpuHloCostAnalysis::Options& cost_analysis_options,
- const se::DeviceDescription* device_info,
- FusionProcessDumpProto* fusion_process_dump,
- tsl::thread::ThreadPool* thread_pool, mlir::MLIRContext* mlir_context,
- HloFusionAnalysisCache& fusion_analysis_cache,
- bool triton_softmax_priority_fusion_enabled)
- : computation_(computation),
- device_info_(device_info),
- cost_analysis_(cost_analysis_options, *device_info),
- fusion_process_dump_(fusion_process_dump),
- thread_pool_(thread_pool),
- mlir_context_(mlir_context),
- fusion_analysis_cache_(fusion_analysis_cache),
- triton_softmax_priority_fusion_enabled_(
- triton_softmax_priority_fusion_enabled) {
- VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
- TF_CHECK_OK(computation_->Accept(&cost_analysis_));
-
- dump_fusion_visualization_ = computation->parent()
- ->config()
- .debug_options()
- .xla_dump_fusion_visualization();
-
- // Initializes the priority queue.
- std::vector<HloInstruction*> instructions;
- for (auto* instruction : computation->MakeInstructionPostOrder()) {
- if (instruction->opcode() == HloOpcode::kParameter ||
- instruction->user_count() == 0 || !instruction->IsFusible() ||
- instruction->opcode() == HloOpcode::kTuple ||
- instruction->opcode() == HloOpcode::kGetTupleElement) {
- continue;
- }
- instructions.push_back(instruction);
- }
-
- ComputeAndSetPriorities(instructions);
- }
-
- void ComputeAndSetPriorities(
- const std::vector<HloInstruction*>& instructions) {
- std::vector<Priority> priorities = ComputePriorities(instructions);
-
- for (auto [instruction, priority] : llvm::zip(instructions, priorities)) {
- auto key = std::make_pair(priority, instruction->unique_id());
-
- // Remove instruction with the old priority from the queue.
- auto reverse_it = reverse_map_.find(instruction);
- if (reverse_it != reverse_map_.end()) {
- const PriorityQueue::iterator& queue_it = reverse_it->second;
- // Priority didn't change. Nothing to do.
- if (key == queue_it->first) {
- continue;
- }
- producer_priority_queue_.erase(queue_it);
- reverse_map_.erase(reverse_it);
- }
-
- // If the priority is negative, it's not helpful to perform fusion on this
- // instruction.
- if (priority < 0) {
- continue;
- }
-
- auto emplace_result = producer_priority_queue_.emplace(key, instruction);
- reverse_map_.emplace(instruction, emplace_result.first);
- }
- }
-
- std::vector<Priority> ComputePriorities(
- const std::vector<HloInstruction*>& instructions) {
- auto schedule_or_run = [this](std::function<void()> fn) {
- if (thread_pool_) {
- thread_pool_->Schedule(std::move(fn));
- } else {
- fn();
- }
- };
- tsl::BlockingCounter counter(instructions.size());
- std::vector<Priority> priorities(instructions.size());
-
- for (size_t i = 0; i < instructions.size(); ++i) {
- schedule_or_run([&, i] {
- priorities[i] = CalculateProducerPriority(instructions[i]);
- counter.DecrementCount();
- });
- }
- counter.Wait();
- return priorities;
- }
-
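The thread-pool-plus-blocking-counter pattern in ComputePriorities above amounts to a parallel map over the instruction list. A minimal, self-contained sketch of the same idea using plain STL futures instead of tsl::thread::ThreadPool and tsl::BlockingCounter; the cost function here is a placeholder assumption:

#include <cstdint>
#include <future>
#include <vector>

// Computes one priority per cost entry in parallel. std::async stands in for
// the thread pool, and waiting on the futures plays the role of the blocking
// counter; each result lands in its own slot, so no locking is needed.
std::vector<int64_t> ComputeAllPriorities(const std::vector<int64_t>& costs) {
  std::vector<std::future<int64_t>> futures;
  futures.reserve(costs.size());
  for (int64_t cost : costs) {
    futures.push_back(std::async(std::launch::async, [cost] {
      return cost * 2;  // Placeholder for CalculateProducerPriority.
    }));
  }
  std::vector<int64_t> priorities;
  priorities.reserve(costs.size());
  for (auto& f : futures) {
    priorities.push_back(f.get());  // Blocks until that task finishes.
  }
  return priorities;
}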
- // Gets the next pair of (producer, consumers) from the queue for fusion.
-  // Returns true if there is a next producer to fuse, otherwise false. Stores
- // the producer and consumers in `current_producer_` and `current_consumers_`.
- bool DequeueNextProducer() {
- current_producer_ = nullptr;
- current_consumers_.clear();
-
- while (!producer_priority_queue_.empty() && current_consumers_.empty()) {
- auto next_it = std::prev(producer_priority_queue_.end());
-
- current_producer_ = next_it->second;
- producer_priority_queue_.erase(next_it);
- reverse_map_.erase(current_producer_);
-
- current_consumers_ = current_producer_->users();
-
- if (current_producer_->opcode() == HloOpcode::kBitcast) {
- // We don't check if bitcasts can be fused with all consumers, so we
- // have to do it here.
- llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) {
- return !CanFuseCached(current_producer_, consumer);
- });
- }
- }
-
- return !current_consumers_.empty();
- }
-
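The producer queue used above (an ordered map keyed by (priority, unique id), plus a reverse map so entries can be found and replaced when a priority changes) can be sketched with plain STL types as follows; the names are simplified assumptions, not the class's real members:

#include <cstdint>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>

// A priority queue implemented as an ordered map, with a reverse map so a
// producer's entry can be located and erased when its priority is updated.
struct ProducerQueue {
  using Key = std::pair<int64_t, int>;  // (priority, unique id to break ties)

  void SetPriority(const std::string& producer, int unique_id,
                   int64_t priority) {
    auto rit = reverse.find(producer);
    if (rit != reverse.end()) {  // Drop the stale entry, if any.
      queue.erase(rit->second);
      reverse.erase(rit);
    }
    auto it = queue.emplace(Key{priority, unique_id}, producer).first;
    reverse.emplace(producer, it);
  }

  // Pops the max-priority producer (the last entry of the ordered map).
  bool PopBest(std::string& producer) {
    if (queue.empty()) return false;
    auto it = std::prev(queue.end());
    producer = it->second;
    reverse.erase(producer);
    queue.erase(it);
    return true;
  }

  std::map<Key, std::string> queue;
  std::unordered_map<std::string, std::map<Key, std::string>::iterator> reverse;
};

int main() {
  ProducerQueue q;
  q.SetPriority("mul", /*unique_id=*/1, /*priority=*/30);
  q.SetPriority("add", /*unique_id=*/2, /*priority=*/50);
  q.SetPriority("add", /*unique_id=*/2, /*priority=*/10);  // Re-prioritize.
  std::string best;
  while (q.PopBest(best)) std::cout << best << "\n";  // "mul" then "add".
}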
- // Update priorities of all affected ops.
- void UpdatePriorities() {
- // Revisit costs of all updated ops. It's important to update cost analysis
- // before recalculating priorities.
- for (auto instruction : to_update_priority_) {
- TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction));
- }
-
- ComputeAndSetPriorities(std::vector<HloInstruction*>{
- to_update_priority_.begin(), to_update_priority_.end()});
-
- to_update_priority_.clear();
- }
-
-  // Prepares the producer and consumer instructions to be fused. Invalidates caches
- // and writes logs.
- void PreFusion(HloInstruction* producer, HloInstruction* consumer) {
- if (dump_fusion_visualization_) {
- RegisterFusionState(
- *computation_,
- absl::StrCat("About to fuse |", producer->name(), "| into |",
- consumer->name(), "| inside PriorityFusion"),
- *consumer, producer);
- }
-
- InvalidateCaches(producer);
- InvalidateCaches(consumer);
- }
-
-  // Invalidates all cached values related to this instruction. Called before the
- // instruction is fused. The instruction can be either producer or consumer.
- void InvalidateCaches(HloInstruction* instruction) {
- can_fuse_cache_.erase(instruction);
- for (const HloInstruction* operand : instruction->operands()) {
- auto it = can_fuse_cache_.find(operand);
- if (it != can_fuse_cache_.end()) {
- it->second.erase(instruction);
- }
- }
-
- gpu_performance_model_cache_.Invalidate(*instruction);
- fusion_analysis_cache_.Invalidate(*instruction);
- fusion_info_cache_.Invalidate(instruction);
- }
-
- // Updates data for the new fusion instruction and its users and operands.
- void OnFusingInstruction(HloInstruction* fusion,
- HloInstruction* original_producer,
- HloInstruction* original_consumer) {
- if (fusion_process_dump_) {
- auto* fusion_step =
- fusion_process_dump_->add_fusion_steps()->mutable_fusion();
-
- // Explicit std::string is needed for OSS proto implementation.
- fusion_step->set_fusion_name(std::string(fusion->name()));
- fusion_step->set_producer_name(std::string(original_producer->name()));
- fusion_step->set_consumer_name(std::string(original_consumer->name()));
- }
-
- if (dump_fusion_visualization_) {
- RegisterFusionState(
- *computation_,
- absl::StrCat("Fused |", original_producer->name(), "| into |",
- fusion->name(), "| inside PriorityFusion"),
- *fusion);
- }
-
-    // The original consumer was replaced with the fusion, but its pointer can
- // still be referenced somewhere, for example, in to_update_priority_.
- // Priority recomputation is called before DCE. Remove all references to
- // the original consumer here.
- if (fusion != original_consumer) {
- RemoveInstruction(original_consumer);
- }
-
- // Detach 'original_producer' from its operands if it has no users.
- // This avoids having it appear as a "phantom" user in subsequent priority
- // calculations on 'fusion.operands' below, before it is finally removed
- // in 'RemoveInstruction'.
- if (original_producer->user_count() == 0) {
- original_producer->DetachFromOperandsAndUsers();
- }
-
- // Collect the instructions whose priorities need to be updated.
- for (HloInstruction* operand : fusion->operands()) {
- if (operand == original_producer ||
- operand->opcode() == HloOpcode::kConstant ||
- operand->opcode() == HloOpcode::kGetTupleElement) {
- continue;
- }
-      // Need to consider only instructions that are fusible, e.g., an rng with
-      // more than one user is not fusible.
- if (!operand->IsFusible()) {
- continue;
- }
-
- to_update_priority_.insert(operand);
- }
- to_update_priority_.insert(fusion);
- }
-
- // Removes data for the instruction.
- void RemoveInstruction(HloInstruction* instruction) {
- to_update_priority_.erase(instruction);
- fusion_analysis_cache_.Invalidate(*instruction);
-
- auto reverse_it = reverse_map_.find(instruction);
- if (reverse_it == reverse_map_.end()) {
- return;
- }
- producer_priority_queue_.erase(reverse_it->second);
- reverse_map_.erase(reverse_it);
- }
-
- HloInstruction* current_producer() { return current_producer_; }
-
- const std::vector<HloInstruction*>& current_consumers() {
- return current_consumers_;
- }
-
- private:
- // Returns the priority of the producer based on its current operands and
- // users.
- Priority CalculateProducerPriority(HloInstruction* producer) {
- // Bitcasts should always be fused first, since they are no-ops.
- if (producer->opcode() == HloOpcode::kBitcast) {
- return std::numeric_limits<Priority>::max();
- }
- // We always fuse constants, but the cost model doesn't handle them very
- // well: fusing constants changes costs significantly. Also, there's no
- // point recomputing priorities. Therefore, we fuse all of them at the end.
- if (producer->opcode() == HloOpcode::kConstant) {
- return std::numeric_limits<Priority>::min();
- }
-
-    // Don't fuse if we can't fuse into all users.
- if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer);
- !fusion_decision) {
- if (fusion_process_dump_) {
- absl::MutexLock lock(&fusion_process_dump_mutex_);
- auto* step = fusion_process_dump_->add_fusion_steps()
- ->mutable_producer_ineligible();
- step->set_producer_name(std::string(producer->name()));
- step->set_reason(fusion_decision.Explain());
- }
- return std::numeric_limits<Priority>::min();
- }
-
- GpuPerformanceModel::RunTimes run_times =
- GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
- producer, *device_info_, &cost_analysis_,
- GpuPerformanceModelOptions::PriorityFusion(
- &fusion_analysis_cache_, &gpu_performance_model_cache_),
- producer->users());
-
- if (fusion_process_dump_) {
- absl::MutexLock lock(&fusion_process_dump_mutex_);
- auto* step =
- fusion_process_dump_->add_fusion_steps()->mutable_update_priority();
- step->set_producer_name(std::string(producer->name()));
- for (auto* consumer : producer->users()) {
- step->add_consumer_names(std::string(consumer->name()));
- }
- step->set_us_fused(absl::ToDoubleMicroseconds(run_times.time_fused));
- step->set_us_unfused(absl::ToDoubleMicroseconds(run_times.time_unfused));
- }
- return absl::ToInt64Nanoseconds(run_times.time_unfused -
- run_times.time_fused);
- }
-
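Condensing the rules in CalculateProducerPriority above: bitcasts are pinned to the highest priority, constants and producers that cannot fuse with all their users to the lowest, and everything else gets the estimated run time saved by fusing. A small sketch with plain types; the opcode strings and the duration inputs are assumptions for illustration, not the real cost-model plumbing:

#include <chrono>
#include <cstdint>
#include <limits>
#include <string>

using Priority = int64_t;

// A producer's priority is the estimated run time saved by fusing it into all
// of its users; bitcasts always go first and constants (or producers that
// cannot be fused everywhere) always go last.
Priority ProducerPriority(const std::string& opcode,
                          bool fusible_with_all_users,
                          std::chrono::nanoseconds time_unfused,
                          std::chrono::nanoseconds time_fused) {
  if (opcode == "bitcast") return std::numeric_limits<Priority>::max();
  if (opcode == "constant" || !fusible_with_all_users) {
    return std::numeric_limits<Priority>::min();
  }
  // Positive when fusion is estimated to save time.
  return (time_unfused - time_fused).count();
}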
- FusionDecision CanFuseTriton(HloInstruction* producer,
- HloInstruction* consumer) {
- if (!triton_softmax_priority_fusion_enabled_) {
- return "triton softmax fusion is not enabled";
- }
-
- if (IsGenericTritonFusion(*producer)) {
- if (!IsFusible(*consumer)) {
- return "the consumer is not fusible";
- }
- } else {
- if (!IsFusible(*producer)) {
- return "the producer is not fusible";
- }
- }
-
- auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer);
-
- SymbolicTileAnalysisOrError symbolic_tile_analysis_or =
- SymbolicTileAnalysis::AnalyzeFusion(*fusion, mlir_context_);
-
- if (const auto* fusion_decision =
- std::get_if<FusionDecision>(&symbolic_tile_analysis_or)) {
- return {
- absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ",
- fusion_decision->Explain())};
- }
-
- return {};
- }
-
- FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) {
- if (IsGenericTritonFusion(*producer) || IsGenericTritonFusion(*consumer)) {
- return CanFuseTriton(producer, consumer);
- }
-
- if (!IsFusible(*producer)) {
- return "the producer is not fusible";
- }
-
- if (!IsFusible(*consumer)) {
- return "the consumer is not fusible";
- }
-
- if (consumer->opcode() == HloOpcode::kBitcast) {
- return "not fusing into a single bitcast as consumer";
- }
-
- // Scatter is special as it has no elemental version but is still input
- // fusible. Block attempts to create scatter fusions we can't codegen.
- if (auto can_fuse = CanEmitInputFusedScatter(*producer, *consumer);
- !can_fuse) {
- return can_fuse;
- }
-
- // Avoid fusing reduce into reduce. Our cost model doesn't currently
- // understand this case due to a lack of tiling analysis.
- // TODO(b/312200883): Remove this.
- auto contains_significant_reduce = [&](const HloInstruction* instr) {
- auto fusion = HloFusionAdaptor::ForInstruction(instr);
- return HloAnyOf(*fusion, [](auto node) {
- if (!(node.opcode() == HloOpcode::kReduce && node.shape().IsArray())) {
- return false;
- }
-
- int64_t reduction_size =
- ShapeUtil::ElementsIn(node.instruction().operand(0)->shape()) /
- ShapeUtil::ElementsIn(node.shape());
-
- // Small reductions are emitted using the elemental emitter anyway.
- return reduction_size >= 16;
- });
- };
- if (contains_significant_reduce(producer) &&
- contains_significant_reduce(consumer)) {
- return "both the producer and the consumer contain a reduce";
- }
-
- // Avoid doing fusions into the output of an "input" fusion when it would
- // switch it to the loop emitter. This often occurs during epilog fusion for
- // reductions, which suffer from limited emitter support.
- // TODO(b/312686229): Cost model should handle this.
- const auto& analysis = fusion_analysis_cache_.Get(*producer);
- if (analysis.GetEmitterFusionKind() ==
- HloFusionAnalysis::EmitterFusionKind::kReduction) {
- const auto& analysis_fused =
- fusion_analysis_cache_.Get(*producer, *consumer);
- if (analysis_fused.GetEmitterFusionKind() ==
- HloFusionAnalysis::EmitterFusionKind::kLoop) {
- return "fusion into output of a reduce fusion would create a loop "
- "fusion";
- }
- }
-
- // Avoid cases where we'd create a fusion that hits limitations in ptxas.
- // It would be nice to model this with cost instead.
- if (auto fits_budget = FusionFitsInBudget(
- *consumer, *producer, *device_info_,
- /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
- !fits_budget) {
- return fits_budget;
- }
-
- // Also check that our emitter can handle the fusion node. We currently can
- // have exponential time/memory requirements for emitting certain fusion
- // kernels, in which case we don't want to fuse.
- // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
- if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) {
- return "the fusion would result in an overly large code duplication";
- }
-
- // Don't fuse across a root instruction. There are situations when a root
- // instruction is not the last in the computation. Instructions after the
- // root are not necessarily dead. They can be inputs to instructions with
- // side effects, like outfeed.
- if (producer == producer->parent()->root_instruction()) {
- return "not fusing into the output of the root instruction";
- }
-
- return InstructionFusion::ShouldFuseInPlaceOp(producer, consumer);
- }
-
- FusionDecision CanFuseCached(HloInstruction* producer,
- HloInstruction* consumer) {
- {
- absl::MutexLock lock(&can_fuse_cache_mutex_);
- auto& producer_cache = can_fuse_cache_[producer];
-
- auto it = producer_cache.find(consumer);
- if (it != producer_cache.end()) {
- return it->second;
- }
- }
-
- auto fusion_decision = CanFuse(producer, consumer);
-
- // The lock is required because writing to a flat_hash_map is not
- // thread-safe even for different keys. We never run this computation
- // concurrently for the same producer, so it's guaranteed that we don't
- // overwrite any value.
- {
- absl::MutexLock lock(&can_fuse_cache_mutex_);
- can_fuse_cache_[producer][consumer] = fusion_decision;
- }
-
- return fusion_decision;
- }
-
- FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) {
- if (producer->users().empty()) {
- return "No users to fuse";
- }
-
- FusionDecision result;
- bool has_non_bitcast_user = false;
- for (const auto& user : producer->users()) {
- if (user->opcode() == HloOpcode::kBitcast) {
- continue;
- }
- has_non_bitcast_user = true;
- if (auto fusion_decision = CanFuseCached(producer, user);
- !fusion_decision) {
- VLOG(10) << "Cannot fuse " << producer->name() << " with "
- << user->name() << ", because: " << fusion_decision.Explain();
- return fusion_decision;
- }
- }
- if (!has_non_bitcast_user) {
- return "not fusing because there are only bitcast users";
- }
- return {};
- }
-
- // The computation being processed; stored for cost analysis.
- HloComputation* computation_;
-
- const se::DeviceDescription* device_info_;
-
- // Cost model that defines priorities in the queue.
- GpuHloCostAnalysis cost_analysis_;
-
- // The priority queue of producers, implemented as an ordered map, where a
- // key is a pair: the first element is the priority and the second element is
- // the unique ID of the instruction to break ties.
- using PriorityQueue = std::map<std::pair<Priority, int>, HloInstruction*>;
- PriorityQueue producer_priority_queue_;
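- // Illustrative entry (hypothetical values): {{priority, unique ID}, instr};
- // the unique ID only breaks ties between producers with equal priority.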
-
- // A reverse map that helps find an instruction in the priority queue.
- absl::flat_hash_map<HloInstruction*, PriorityQueue::iterator> reverse_map_;
-
- // The current producer being visited.
- HloInstruction* current_producer_;
-
- // The current consumers being visited.
- std::vector<HloInstruction*> current_consumers_;
-
- // The set of producers whose priorities need to be updated. Their
- // priorities are changed because their neighbors got fused, but we delay
- // the priority updates until current_consumers_ becomes empty. This is to
- // avoid recomputing priorities multiple times before we dequeue a new
- // producer.
- absl::flat_hash_set<HloInstruction*> to_update_priority_;
-
- // Proto with structured logs of fusion decisions. Used only for debugging. If
- // null, logging is disabled.
- FusionProcessDumpProto* fusion_process_dump_;
- absl::Mutex fusion_process_dump_mutex_;
-
- tsl::thread::ThreadPool* thread_pool_;
-
- mlir::MLIRContext* mlir_context_;
-
- HloFusionAnalysisCache& fusion_analysis_cache_;
-
- // Caches the result of CanFuse for a (producer, consumer) pair. A cache
- // entry is invalidated if the producer or consumer is modified.
- absl::flat_hash_map<
- const HloInstruction*,
- absl::flat_hash_map<const HloInstruction*, FusionDecision>>
- can_fuse_cache_;
- absl::Mutex can_fuse_cache_mutex_;
-
- GpuPerformanceModelCache gpu_performance_model_cache_;
-
- // Cache for `FusionFitsInBudget` to avoid recomputing expensive properties
- // like shared memory usage or number of unnested reductions of fusion nodes.
- FusionInfoCache fusion_info_cache_;
-
- bool triton_softmax_priority_fusion_enabled_;
-
- bool dump_fusion_visualization_;
-};
-
-} // namespace
-
-/*static*/ bool GpuPriorityFusion::IsExpensive(
- const HloInstruction& instruction) {
- // Some floating-point math ops are cheap on the GPU.
- switch (instruction.opcode()) {
- case HloOpcode::kDivide:
- case HloOpcode::kSqrt:
- case HloOpcode::kRsqrt:
- case HloOpcode::kExp:
- if (ElementIsF32OrF16(instruction.shape())) {
- return false;
- }
- break;
- // Loop fusions are cheap.
- case HloOpcode::kFusion:
- return false;
- default:
- break;
- }
- return InstructionFusion::IsExpensive(instruction);
-}
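-
-// For example, per the switch above: a `divide` on f32 or f16 data is not
-// considered expensive, while an f64 `divide` falls through to the base
-// InstructionFusion::IsExpensive heuristic.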
-
-// Returns true if instr is a small constant.
-//
-// There is no single definition of what counts as a small constant in XLA.
-// IrEmitterContext::emit_constant treats only constants of 1 element as small.
-// HloPrintOptions::print_large_constants is effective for constants larger
-// than 10 elements.
-//
-// This function matches the emitter logic.
-bool IsSmallConstant(const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kConstant && instr->shape().IsArray() &&
- ShapeUtil::ElementsIn(instr->shape()) <= 1;
-}
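-
-// For example, a scalar like `f32[] constant(1)` is considered small and gets
-// fused, while a `f32[32,32] constant({...})` is not (see the
-// FuseOnlySmallConstant test).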
-
-bool GpuPriorityFusion::ConsumeFuel(HloInstruction* producer,
- HloInstruction* consumer) {
- return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] {
- return absl::StrFormat("Not fusing producer %s with consumer %s",
- producer->name(), consumer->name());
- });
-}
-
-absl::StatusOr<bool> GpuPriorityFusion::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool dump_enabled =
- DumpingEnabledForHloPass(name(), module->config().debug_options());
- if (dump_enabled) {
- fusion_process_dump_ = std::make_unique<FusionProcessDumpProto>();
- *fusion_process_dump_->mutable_gpu_device_info() =
- device_info_.ToGpuProto();
- }
-
- // Collect the computations within which more fusion is possible.
- auto fusible_computations =
- GetFusibleComputations(*module, execution_threads);
-
- // Appends ".0" suffix to all instructions.
- //
- // Every time an instruction is duplicated, the last integer suffix is
- // incremented.
- // Before: broadcast.123 -> broadcast.124
- // After: broadcast.123.0 -> broadcast.123.1
- //
- // With this modification it will be easier to match instructions before and
- // after fusion passes, because they will have the same unique prefix. Names
- // are not used in the pipeline, but it makes debugging much easier.
- for (auto* computation : fusible_computations) {
- for (auto* instruction : computation->instructions()) {
- module->SetAndUniquifyInstrName(instruction,
- absl::StrCat(instruction->name(), ".0"));
- }
- }
-
- if (dump_enabled) {
- fusion_process_dump_->set_hlo_module_before_fusion(
- module->ToString(HloPrintOptions::ShortParsable()));
- }
-
- bool triton_softmax_priority_fusion_enabled =
- module->config()
- .debug_options()
- .xla_gpu_enable_triton_softmax_priority_fusion();
-
- bool changed = false;
- for (auto* computation : fusible_computations) {
- CHECK(!computation->IsFusionComputation());
-
- auto fusion_queue = std::make_unique<GpuPriorityFusionQueue>(
- computation, cost_analysis_options_, &device_info_,
- fusion_process_dump_.get(), thread_pool_, &mlir_context_,
- fusion_analysis_cache_, triton_softmax_priority_fusion_enabled);
-
- while (fusion_queue->DequeueNextProducer()) {
- auto producer = fusion_queue->current_producer();
-
- for (auto* consumer : fusion_queue->current_consumers()) {
- // Don't fuse into single bitcasts. We ignore them in
- // CanFuseWithAllNonBitcastUsers(), so we need to check for them here.
- if (consumer->opcode() == HloOpcode::kBitcast) {
- continue;
- }
- if (!ConsumeFuel(producer, consumer)) continue;
-
- VLOG(5) << "next: " << consumer->name() << "(" << consumer << ") + "
- << producer->name() << "(" << producer << ")";
-
- fusion_queue->PreFusion(producer, consumer);
- auto fusion_instruction = Fuse(producer, consumer, computation);
- fusion_queue->OnFusingInstruction(fusion_instruction, producer,
- consumer);
-
- changed = true;
- }
-
- if (producer->user_count() == 0) {
- fusion_queue->RemoveInstruction(producer);
- // Remove from computation.
- TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer));
- }
-
- fusion_queue->UpdatePriorities();
- }
-
- // Fuse small constants.
- std::vector<HloInstruction*> constants;
- for (auto* instruction : computation->instructions()) {
- // Small constants should be fused, because they can be folded and
- // codegened efficiently.
- // Fusing large constants doesn't give much benefit, because they're
- // treated like parameters and read from global memory anyway. Fusion
- // and duplication of large constants can, however, cause problems if we
- // want to dump the HLO and parse it back, because the duplicated
- // constants will then be filled with different data.
- if (IsSmallConstant(instruction)) {
- constants.push_back(instruction);
- }
- }
- for (auto* constant : constants) {
- auto users = constant->users();
- for (auto* user : users) {
- if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) {
- Fuse(constant, user, computation);
- changed = true;
- }
- }
- }
- }
-
- // The FusionAnalysis cache uses unique_id as the key. IDs are only unique
- // inside one module, so it's important to fully clear the cache if the same
- // instance of the pass is called on a different module.
- fusion_analysis_cache_.Clear();
-
- if (dump_enabled) {
- DumpPerModuleProtobufToFile(*module, *fusion_process_dump_,
- module->config().debug_options(),
- "priority_fusion_dump");
- }
-
- return changed;
-}
-
-FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer,
- int64_t operand_index) {
- // This method is called in `InstructionFusion::Run` right before fusion, but
- // it always returns true. Fusion decisions are fully controlled by the
- // PriorityQueue. If the queue returns a producer that shouldn't be fused,
- // it's a bug and should be fixed in the queue logic.
- return {};
-}
-
-HloInstruction::FusionKind GpuPriorityFusion::ChooseKind(
- const HloInstruction* producer, const HloInstruction* consumer) {
- // Derive kInput/kLoop fusion kinds from fusion analysis. This shouldn't
- // matter, but some passes downstream still query these instead of the
- // fusion analysis.
- const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer);
- switch (analysis.GetEmitterFusionKind()) {
- case HloFusionAnalysis::EmitterFusionKind::kLoop:
- return HloInstruction::FusionKind::kLoop;
- case HloFusionAnalysis::EmitterFusionKind::kTriton:
- case HloFusionAnalysis::EmitterFusionKind::kCustomFusion:
- case HloFusionAnalysis::EmitterFusionKind::kCuDnn:
- return HloInstruction::FusionKind::kCustom;
- case HloFusionAnalysis::EmitterFusionKind::kConcatenate:
- case HloFusionAnalysis::EmitterFusionKind::kReduction:
- case HloFusionAnalysis::EmitterFusionKind::kTranspose:
- case HloFusionAnalysis::EmitterFusionKind::kInputSlices:
- case HloFusionAnalysis::EmitterFusionKind::kScatter:
- return HloInstruction::FusionKind::kInput;
- }
-}
-
-HloInstruction* GpuPriorityFusion::FuseInstruction(
- HloInstruction* fusion_instruction, HloInstruction* producer) {
- HloInstruction* result = fusion_instruction;
- if (producer->opcode() == HloOpcode::kFusion) {
- if (IsGenericTritonFusion(*producer)) {
- TF_CHECK_OK(fusion_instruction->set_backend_config(
- *producer->backend_config<GpuBackendConfig>()));
- }
-
- fusion_instruction->MergeFusionInstruction(producer);
- } else {
- result = InstructionFusion::FuseInstruction(fusion_instruction, producer);
- }
- return result;
-}
-
-std::unique_ptr<FusionQueue> GpuPriorityFusion::GetFusionQueue(
- HloComputation* computation) {
- return nullptr;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h
deleted file mode 100644
index 999eb78..0000000
--- a/third_party/xla/xla/service/gpu/priority_fusion.h
+++ /dev/null
@@ -1,100 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_PRIORITY_FUSION_H_
-#define XLA_SERVICE_GPU_PRIORITY_FUSION_H_
-
-#include <stdint.h>
-
-#include <memory>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/gpu/fusion_process_dump.pb.h"
-#include "xla/service/gpu/model/fusion_analysis_cache.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-
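-// Priority-based producer/consumer fusion pass for GPU. A usage sketch,
-// mirroring how the tests construct the pass (illustrative only):
-//
-//   GpuPriorityFusion fusion(/*thread_pool=*/nullptr, device_info,
-//                            std::move(cost_analysis_options));
-//   TF_ASSIGN_OR_RETURN(bool changed, fusion.Run(module));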
-class GpuPriorityFusion : public InstructionFusion {
- public:
- GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool,
- const se::DeviceDescription& device,
- GpuHloCostAnalysis::Options cost_analysis_options)
- : InstructionFusion(GpuPriorityFusion::IsExpensive),
- thread_pool_(thread_pool),
- device_info_(device),
- cost_analysis_options_(std::move(cost_analysis_options)),
- fusion_analysis_cache_(device_info_) {}
-
- absl::string_view name() const override { return "priority-fusion"; }
-
- static bool IsExpensive(const HloInstruction& instruction);
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- protected:
- std::unique_ptr<FusionQueue> GetFusionQueue(
- HloComputation* computation) override;
-
- FusionDecision ShouldFuse(HloInstruction* consumer,
- int64_t operand_index) override;
-
- HloInstruction::FusionKind ChooseKind(
- const HloInstruction* producer, const HloInstruction* consumer) override;
-
- private:
- HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
- HloInstruction* producer) override;
-
- // Consumes a unit of compiler fuel and returns true if we should
- // continue with the transformation.
- bool ConsumeFuel(HloInstruction* producer, HloInstruction* consumer);
-
- tsl::thread::ThreadPool* thread_pool_;
- se::DeviceDescription device_info_;
-
- // Cost model options that define priorities in the queue.
- GpuHloCostAnalysis::Options cost_analysis_options_;
-
- // Proto with structured logs of fusion decisions. Used only for debugging. If
- // null, logging is disabled.
- std::unique_ptr<FusionProcessDumpProto> fusion_process_dump_;
-
- HloFusionAnalysisCache fusion_analysis_cache_;
-
- mlir::MLIRContext mlir_context_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_PRIORITY_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc
deleted file mode 100644
index 4f71a51..0000000
--- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc
+++ /dev/null
@@ -1,941 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/priority_fusion.h"
-
-#include <stdint.h>
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace m = ::xla::match;
-
-using ::testing::UnorderedElementsAre;
-using ::tsl::testing::IsOk;
-using ::tsl::testing::IsOkAndHolds;
-
-namespace xla {
-namespace gpu {
-
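-// Test fixture that runs GpuPriorityFusion with the RTX A6000 test device
-// info. RunAndGetFusionKinds is a helper that runs the pass and returns the
-// emitter fusion kind of every fusion computation left in the module.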
-class PriorityFusionTest : public HloTestBase {
- HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
- return [&](const Shape& shape) {
- constexpr int64_t kPointerSize = 8;
- return ShapeUtil::ByteSizeOf(shape, kPointerSize);
- };
- }
-
- public:
- std::vector<HloFusionAnalysis::EmitterFusionKind> RunAndGetFusionKinds(
- absl::string_view hlo) {
- auto module = ParseAndReturnVerifiedModule(hlo).value();
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
- EXPECT_THAT(module->RemoveUnusedComputations(), IsOk());
- std::vector<HloFusionAnalysis::EmitterFusionKind> kinds;
- for (auto computation : module->computations()) {
- if (!computation->FusionInstruction()) continue;
-
- auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
- auto analysis = HloFusionAnalysis::Create(
- Cast<HloFusionInstruction>(computation->FusionInstruction()),
- &device_info);
- kinds.push_back(analysis.GetEmitterFusionKind());
- }
- return kinds;
- }
-
- GpuPriorityFusion priority_fusion_{
- /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(),
- GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(),
- /*per_second_rates=*/{},
- /*count_multiple_input_accesses=*/true}};
-};
-
-TEST_F(PriorityFusionTest, FuseWithSharedArgument) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- ENTRY main {
- %p0 = f32[] parameter(0)
- %p1 = f32[] parameter(1)
- %subtract = f32[] subtract(%p0, %p1)
- %compare = pred[] compare(%subtract, %subtract), direction=NE
- %add = f32[] add(%p0, %p1)
- %abs = f32[] abs(%subtract)
- ROOT %select = f32[] select(%compare, %add, %abs)
- })")
- .value();
-
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::Fusion()));
- EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
-}
-
-TEST_F(PriorityFusionTest, FusionFusionWithDuplication) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- square {
- p = f32[16384]{0} parameter(0)
- ROOT m = f32[16384]{0} multiply(p, p)
- }
-
- exp {
- p = f32[16384]{0} parameter(0)
- ROOT e = f32[16384]{0} exponential(p)
- }
-
- log {
- p = f32[16384]{0} parameter(0)
- ROOT l = f32[16384]{0} log(p)
- }
-
- ENTRY main {
- p = f32[16384]{0} parameter(0)
- s = f32[16384]{0} fusion(p), kind=kLoop, calls=square
- e = f32[16384]{0} fusion(s), kind=kLoop, calls=exp
- l = f32[16384]{0} fusion(s), kind=kInput, calls=log
- ROOT t = (f32[16384], f32[16384]) tuple(l, e)
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ENTRY
-CHECK-NEXT: %[[PARAM:.*]] = f32[16384]{0} parameter(0)
-CHECK-NEXT: %[[FUSION_0:.*]] = f32[16384]{0} fusion(%[[PARAM]])
-CHECK-NEXT: %[[FUSION_1:.*]] = f32[16384]{0} fusion(%[[PARAM]])
-CHECK-NEXT: ROOT {{.*}} tuple(%[[FUSION_0]], %[[FUSION_1]])
- )");
-}
-
-TEST_F(PriorityFusionTest, FuseBroadcastIntoBitcastConsumers) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- ENTRY main {
- param_0 = f32[96]{0} parameter(0)
- broadcast = f32[8,96,128,7]{3,2,1,0} broadcast(param_0), dimensions={1}
- bitcast.6079.2 = f32[8,24,4,128,7]{4,3,2,1,0} bitcast(broadcast)
- ROOT transpose.1990.2 = f32[8,24,128,7,4]{4,3,2,1,0} transpose(bitcast.6079.2), dimensions={0,1,3,4,2}
- }
- )";
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ENTRY
-CHECK-NEXT: %[[PARAM:.*]] = f32[96]{0} parameter(0)
-CHECK-NEXT: ROOT %{{.*}} fusion(%[[PARAM]])
- )");
-}
-
-TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- ENTRY main {
- p = f16[512]{0} parameter(0)
- a = f16[512]{0} add(p, p)
- c = f32[512]{0} convert(a)
- s = f32[512]{0} multiply(c, c)
- bc = s32[512]{0} bitcast(c)
- ROOT t = (f32[512], s32[512]) tuple(s, bc)
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ENTRY
-CHECK-NEXT: %[[PARAM:.*]] = f16[512]{0} parameter(0)
-CHECK-NEXT: %[[FUSION_F32:.*]] = f32[512]{0} fusion(%[[PARAM]])
-CHECK-NEXT: %[[CONVERT_FUSION:.*]] = f32[512]{0} fusion(%[[PARAM]])
-CHECK-NEXT: %[[BITCAST:.*]] = s32[512]{0} bitcast(%[[CONVERT_FUSION]])
-CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[BITCAST]])
- )");
-}
-
-TEST_F(PriorityFusionTest, FuseConvertIntoReduce) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add.13235 = f32[] add(p0, p1)
- }
-
- ENTRY main {
- param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
- param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
- param_2.483 = f32[8192]{0} parameter(2)
- param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
- convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
- convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
- constant_7773 = f32[] constant(0)
- broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
- multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
- reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
- convert.13970 = bf16[1024]{0} convert(reduce.4813)
- convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
- multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
- reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
- convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
- multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
- reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
- convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
- ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK-COUNT-3: ROOT {{.*}} convert(
-CHECK: ENTRY %main
-CHECK-COUNT-3: fusion
- )");
-}
-
-TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) {
- // Regression test for epilogue fusion of convert into a reduction, even if
- // the convert has a bitcast as a consumer.
- absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- rhs.407 = f32[] parameter(1)
- lhs.407 = f32[] parameter(0)
- ROOT add.24451 = f32[] add(lhs.407, rhs.407)
- }
-
- ENTRY main {
- param_1.15162 = f32[2752]{0} parameter(1)
- convert.44829 = bf16[2752]{0} convert(param_1.15162)
- bitcast.24686 = bf16[1,1,2752]{2,1,0} bitcast(convert.44829)
- convert.44468 = f32[1,1,2752]{2,1,0} convert(bitcast.24686)
- constant_13722 = bf16[] constant(1)
- convert.17451 = f32[] convert(constant_13722)
- broadcast.17565 = f32[1,1,2752]{2,1,0} broadcast(convert.17451), dimensions={}
- negate.167 = f32[1,1,2752]{2,1,0} negate(convert.44468)
- exponential.569 = f32[1,1,2752]{2,1,0} exponential(negate.167)
- add.1850 = f32[1,1,2752]{2,1,0} add(broadcast.17565, exponential.569)
- divide.1376 = f32[1,1,2752]{2,1,0} divide(broadcast.17565, add.1850)
- multiply.9709 = f32[1,1,2752]{2,1,0} multiply(convert.44468, divide.1376)
- param_0.15005 = f32[2752]{0} parameter(0)
- convert.44826 = bf16[2752]{0} convert(param_0.15005)
- bitcast.24683 = bf16[1,1,2752]{2,1,0} bitcast(convert.44826)
- convert.44467 = f32[1,1,2752]{2,1,0} convert(bitcast.24683)
- multiply.9708 = f32[1,1,2752]{2,1,0} multiply(multiply.9709, convert.44467)
- convert.16959 = bf16[1,1,2752]{2,1,0} convert(multiply.9708)
- fusion.3203 = bf16[2752]{0} bitcast(convert.16959)
- convert.15093 = f32[2752]{0} convert(fusion.3203)
- broadcast.13841 = f32[8192,2752]{1,0} broadcast(convert.15093), dimensions={1}
- param_0.15525 = bf16[8192,2752]{1,0} parameter(2)
- convert.13738 = f32[8192,2752]{1,0} convert(param_0.15525)
- multiply.6422 = f32[8192,2752]{1,0} multiply(broadcast.13841, convert.13738)
- constant_14382 = f32[] constant(0)
- fusion.339 = f32[8192]{0} reduce(multiply.6422, constant_14382), dimensions={1}, to_apply=add
- convert.44633 = bf16[8192]{0} convert(fusion.339)
- ROOT bitcast.24487 = bf16[1,1,8192]{2,1,0} bitcast(convert.44633)
- }
- )";
-
- EXPECT_THAT(
- RunAndGetFusionKinds(kHlo),
- UnorderedElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop,
- HloFusionAnalysis::EmitterFusionKind::kReduction));
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ENTRY
-CHECK: ROOT {{.*}} bitcast({{.*}}fusion{{.*}})
- )");
-}
-
-TEST_F(PriorityFusionTest, DoNotChangeReductionFusionToLoopFusion) {
- // Regression test for epilogue fusion of slice into a reduction. The fusion
- // kind for the reduction fusion is intentionally set to kLoop, as we cannot
- // rely on reductions always having fusion kind kInput.
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- rhs.407 = f32[] parameter(1)
- lhs.407 = f32[] parameter(0)
- ROOT add.24451 = f32[] add(lhs.407, rhs.407)
- }
-
- fused_computation {
- p0 = f32[16,64]{1,0} parameter(0)
- zero = f32[] constant(0.0)
- ROOT reduce = f32[16]{0} reduce(p0, zero), dimensions={1}, to_apply=add
- }
-
- ENTRY main {
- param0 = f32[16,64]{1,0} parameter(0)
- fusion = f32[16]{0} fusion(param0), kind=kLoop, calls=fused_computation
- ROOT slice = f32[8]{0} slice(fusion), slice={[0:8]}
- })");
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- Arg_1.1046 = f32[] parameter(1)
- Arg_0.1045 = f32[] parameter(0)
- ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
- }
-
- ENTRY main {
- param_0.17323 = pred[2048,2048]{1,0} parameter(0)
- broadcast.22829 = pred[1,12,2048,2048]{3,2,1,0} broadcast(param_0.17323), dimensions={2,3}
- param_1.19761 = bf16[2048,24576]{1,0} parameter(1)
- convert.29880.clone.1 = f32[2048,24576]{1,0} convert(param_1.19761)
- constant_10033_clone_1 = bf16[] constant(0.02002)
- convert.30056.clone.1 = f32[] convert(constant_10033_clone_1)
- broadcast.18898.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30056.clone.1), dimensions={}
- multiply.13451.clone.1 = f32[2048,24576]{1,0} multiply(convert.29880.clone.1, broadcast.18898.clone.1)
- tanh.798.clone.1 = f32[2048,24576]{1,0} tanh(multiply.13451.clone.1)
- constant_10244_clone_1 = bf16[] constant(50)
- convert.30039.clone.1 = f32[] convert(constant_10244_clone_1)
- broadcast.18310.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30039.clone.1), dimensions={}
- multiply.12550.clone.1 = f32[2048,24576]{1,0} multiply(tanh.798.clone.1, broadcast.18310.clone.1)
- convert.29370.clone.1 = bf16[2048,24576]{1,0} convert(multiply.12550.clone.1)
- bitcast.22330 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
- transpose.6582 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22330), dimensions={0,3,2,1}
- convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6582)
- constant_10212 = f32[] constant(-2.38197633e+38)
- broadcast.22828 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10212), dimensions={}
- select.589 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22829, convert.33705, broadcast.22828)
- bitcast.22075 = f32[12,2048,2048]{2,1,0} bitcast(select.589)
- constant_10192 = f32[] constant(-inf)
- reduce.1614 = f32[12,2048]{1,0} reduce(bitcast.22075, constant_10192), dimensions={2}, to_apply=add
-
- predarg = pred[1,1,2048,2048]{3,2,1,0} parameter(2)
- bitcast.11069 = pred[2048,2048]{1,0} bitcast(predarg)
-
- broadcast.22825 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
- bitcast.22331 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
- transpose.6580 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22331), dimensions={0,3,2,1}
- convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6580)
- constant_10213 = f32[] constant(-2.38197633e+38)
- broadcast.22824 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10213), dimensions={}
- select.587 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22825, convert.33703, broadcast.22824)
- broadcast.22819 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
- subtract.1129 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.587, broadcast.22819)
- exponential.418 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1129)
- bitcast.22074 = f32[12,2048,2048]{2,1,0} bitcast(exponential.418)
- constant_10490 = f32[] constant(0)
- reduce.1613 = f32[12,2048]{1,0} reduce(bitcast.22074, constant_10490), dimensions={2}, to_apply=add
-
- constant_468 = f32[] constant(-2.38197633e+38)
- broadcast.22833 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
- bitcast.22332 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
- transpose.6584 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22332), dimensions={0,3,2,1}
- convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6584)
- broadcast.22832 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_468), dimensions={}
- select.591 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22833, convert.33707, broadcast.22832)
- broadcast.22821 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
- subtract.1131 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.591, broadcast.22821)
- exponential.420 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1131)
- broadcast.18351 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1613), dimensions={1,2}
- divide.340 = f32[1,12,2048,2048]{3,2,1,0} divide(exponential.420, broadcast.18351)
- ROOT convert.29418 = bf16[1,12,2048,2048]{3,2,1,0} convert(divide.340)
- })";
-
- using Kind = HloFusionAnalysis::EmitterFusionKind;
- EXPECT_THAT(
- RunAndGetFusionKinds(kHlo),
- UnorderedElementsAre(Kind::kLoop, Kind::kLoop, Kind::kLoop,
- Kind::kReduction, Kind::kReduction, Kind::kTranspose,
- Kind::kTranspose, Kind::kTranspose));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduce) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add.13235 = f32[] add(p0, p1)
- }
-
- ENTRY main {
- p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
- c0 = f32[] constant(0)
- r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
- ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} reduce(
-CHECK: ROOT {{.*}} reduce(
- )");
-}
-
-TEST_F(PriorityFusionTest, ConvertFusedIntoReduce) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add.13235 = f32[] add(p0, p1)
- }
-
- ENTRY main {
- param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
- param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
- param_2.483 = f32[8192]{0} parameter(2)
- param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
- convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
- convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
- constant_7773 = f32[] constant(0)
- broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
- multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
- reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
- convert.13970 = bf16[1024]{0} convert(reduce.4813)
- convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
- multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
- reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
- convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
- multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
- reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
- convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
- ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK-COUNT-3: ROOT {{.*}} convert(
-CHECK: ENTRY %main
-CHECK-COUNT-3: fusion(
-CHECK-NOT: fusion(
- )");
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseDynamicUpdateSliceIntoReduce) {
- GTEST_SKIP() << "b/294198633";
- absl::string_view kHlo = R"(
- HloModule test_module
-
-add {
- Arg_1.1046 = f32[] parameter(1)
- Arg_0.1045 = f32[] parameter(0)
- ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
-}
-
-ENTRY main {
- param_0.10549 = f32[4,2112]{1,0} parameter(0)
- param_5.2561 = pred[] parameter(5)
- broadcast.19725 = pred[4,1]{1,0} broadcast(param_5.2561), dimensions={}
- param_1.11587 = pred[4]{0} parameter(1)
- constant_5837 = f32[] constant(1)
- broadcast.19723 = f32[4]{0} broadcast(constant_5837), dimensions={}
- param_2.5952 = f32[4,8000]{1,0} parameter(2)
- param_3.4004 = f32[4]{0} parameter(3)
- broadcast.19718 = f32[4,8000]{1,0} broadcast(param_3.4004), dimensions={0}
- subtract.1112 = f32[4,8000]{1,0} subtract(param_2.5952, broadcast.19718)
- exponential.418 = f32[4,8000]{1,0} exponential(subtract.1112)
- constant_6254 = f32[] constant(0)
- reduce.1154 = f32[4]{0} reduce(exponential.418, constant_6254), dimensions={1}, to_apply=add
- log.38 = f32[4]{0} log(reduce.1154)
- broadcast.19717 = f32[4,8000]{1,0} broadcast(log.38), dimensions={0}
- subtract.1111 = f32[4,8000]{1,0} subtract(subtract.1112, broadcast.19717)
- iota.170 = s32[4,1]{1,0} iota(), iota_dimension=0
- constant_6281 = s32[] constant(0)
- broadcast.19735 = s32[4]{0} broadcast(constant_6281), dimensions={}
- param_4.3400 = s32[4,8000]{1,0} parameter(4)
- slice.3186 = s32[4,40]{1,0} slice(param_4.3400), slice={[0:4], [0:40]}
- iota.168 = s32[4,1]{1,0} iota(), iota_dimension=0
- param_7.1596 = s32[4]{0} parameter(7)
- compare.341 = pred[4]{0} compare(param_7.1596, broadcast.19735), direction=LT
- constant_5833 = s32[] constant(40)
- broadcast.19731 = s32[4]{0} broadcast(constant_5833), dimensions={}
- add.8348 = s32[4]{0} add(param_7.1596, broadcast.19731)
- select.418 = s32[4]{0} select(compare.341, add.8348, param_7.1596)
- bitcast.20942 = s32[4,1]{1,0} bitcast(select.418)
- concatenate.1337 = s32[4,2]{1,0} concatenate(iota.168, bitcast.20942), dimensions={1}
- gather.43 = s32[4,1,1]{2,1,0} gather(slice.3186, concatenate.1337), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
- bitcast.20941 = s32[4]{0} bitcast(gather.43)
- select.398 = s32[4]{0} select(param_1.11587, broadcast.19735, bitcast.20941)
- compare.334 = pred[4]{0} compare(select.398, broadcast.19735), direction=LT
- constant_6260 = s32[] constant(8000)
- broadcast.19720 = s32[4]{0} broadcast(constant_6260), dimensions={}
- add.8336 = s32[4]{0} add(select.398, broadcast.19720)
- select.396 = s32[4]{0} select(compare.334, add.8336, select.398)
- bitcast.20830 = s32[4,1]{1,0} bitcast(select.396)
- concatenate.1308 = s32[4,2]{1,0} concatenate(iota.170, bitcast.20830), dimensions={1}
- gather.41 = f32[4,1,1]{2,1,0} gather(subtract.1111, concatenate.1308), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
- bitcast.20824 = f32[4]{0} bitcast(gather.41)
- select.389 = f32[4]{0} select(param_1.11587, broadcast.19723, bitcast.20824)
- bitcast.20823 = f32[4,1]{1,0} bitcast(select.389)
- param_6.1719 = s32[] parameter(6)
- constant_6323 = s32[] constant(2048)
- add.8549 = s32[] add(param_6.1719, constant_6323)
- compare.388 = pred[] compare(add.8549, constant_6281), direction=LT
- constant_5436 = s32[] constant(4160)
- add.8339 = s32[] add(param_6.1719, constant_5436)
- select.409 = s32[] select(compare.388, add.8339, add.8549)
- dynamic-slice.36 = f32[4,1]{1,0} dynamic-slice(param_0.10549, constant_6281, select.409), dynamic_slice_sizes={4,1}
- select.388 = f32[4,1]{1,0} select(broadcast.19725, bitcast.20823, dynamic-slice.36)
- ROOT dynamic-update-slice.307 = f32[4,2112]{1,0} dynamic-update-slice(param_0.10549, select.388, constant_6281, select.409)
-})";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} dynamic-update-slice(
-CHECK: %[[REDUCE:.*]] = {{.*}} reduce(
-CHECK: ROOT {{.*}} log(%[[REDUCE]])
-CHECK: ENTRY
-CHECK-COUNT-2: fusion(
- )");
-}
-
-TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) {
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = s32[] parameter(0)
- rhs = s32[] parameter(1)
- ROOT add = s32[] add(lhs, rhs)
- }
-
- ENTRY FuseIntoScatter {
- p0 = s32[3,3] parameter(0)
- operand = s32[3,3] add(p0, p0)
- p1 = s32[2] parameter(1)
- indices = s32[2] add(p1, p1)
- p2 = s32[2,3] parameter(2)
- updates = s32[2,3] add(p2, p2)
- scatter = s32[3,3] scatter(operand, indices, updates),
- to_apply=add,
- update_window_dims={1},
- inserted_window_dims={0},
- scatter_dims_to_operand_dims={0},
- index_vector_dim=1
- ROOT add = s32[3,3] add(scatter, scatter)
- })");
-
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- const HloInstruction* fusion = nullptr;
- ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
- EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
- EXPECT_THAT(fusion->fused_expression_root(),
- GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
-}
-
-// This test is similar to DontFuseIntoFirstOperandOfScatter, but PriorityFusion
-// has a separate run to fuse constants. Fusing anything into a scatter fusion
-// will fail in the emitter.
-TEST_F(PriorityFusionTest, DontFuseConstantIntoFirstOperandOfScatter) {
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- lhs = s32[] parameter(0)
- rhs = s32[] parameter(1)
- ROOT add = s32[] add(lhs, rhs)
- }
-
- ENTRY FuseIntoScatter {
- operand = s32[1] constant({0})
- indices = s32[24,1] parameter(0)
- constant = s32[] constant(1)
- updates = s32[24,1] broadcast(constant)
- ROOT scatter = s32[1] scatter(operand, indices, updates),
- to_apply=add,
- update_window_dims={1},
- inserted_window_dims={},
- scatter_dims_to_operand_dims={0},
- index_vector_dim=1
- })");
-
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
- EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
- EXPECT_THAT(root->fused_expression_root(),
- GmockMatch(m::Scatter(m::Parameter(), m::Parameter(),
- m::Broadcast(m::Constant()))));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduceEvenIfOccupancyIsHigh) {
- constexpr absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
- }
-
- ENTRY main {
- p0 = f32[4,3584,128,168]{3,2,1,0} parameter(0)
- c = f32[] constant(0)
- r1 = f32[4,3584,128]{2,1,0} reduce(p0, c), dimensions={3}, to_apply=add
- ROOT r2 = f32[4,3584]{1,0} reduce(r1, c), dimensions={2}, to_apply=add
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} reduce(
-CHECK: ROOT {{.*}} reduce(
- )");
-}
-
-TEST_F(PriorityFusionTest, FuseReductionEpilogueWithMultipleUsers) {
- // Regression test that verifies we correctly fuse the `log` into the reduce.
- constexpr absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- x = f32[] parameter(0)
- y = f32[] parameter(1)
- ROOT add = f32[] add(x, y)
- }
-
- fused_computation {
- p0 = f32[64,16384]{1,0} parameter(0)
- c0 = f32[] constant(0)
- ROOT reduce.858 = f32[64]{0} reduce(p0, c0), dimensions={1}, to_apply=add
- }
-
- ENTRY main {
- p0 = f32[64,16384]{1,0} parameter(0)
- fusion = f32[64]{0} fusion(p0), kind=kInput, calls=fused_computation
- log = f32[64]{0} log(fusion)
- negate = f32[64]{0} custom-call(log), custom_call_target="negate"
- ROOT add = f32[64]{0} add(negate, log)
- }
- )";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
- CHECK: ENTRY
- CHECK: %[[PARAM:.*]] = {{.*}} parameter(0)
- CHECK: %[[FUSION:.*]] = {{.*}} fusion(%[[PARAM]])
- CHECK: custom-call(%[[FUSION]])
- )");
-}
-
-TEST_F(PriorityFusionTest, EpilogueFusion) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add.13235 = f32[] add(p0, p1)
- }
-
- fused_computation.1 {
- p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
- c0 = f32[] constant(0)
- ROOT r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
- }
-
- fused_computation.2 {
- p0 = f32[8,4,128]{2,1,0} parameter(0)
- r1 = f32[8,4,128]{2,1,0} log(p0)
- ROOT r2 = f32[8,4,128]{2,1,0} log(r1)
- }
-
- ENTRY main {
- p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
- f1 = f32[8,4,128]{2,1,0} fusion(p0), kind=kInput, calls=%fused_computation.1
- ROOT fusion = f32[8,4,128]{2,1,0} fusion(f1), kind=kLoop, calls=%fused_computation.2
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} = f32[8,4,128]{2,1,0} fusion(%p{{.*}}), kind=kInput, calls=%fused_computation)");
-}
-
-TEST_F(PriorityFusionTest, EpilogueFusionFails) {
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add.13235 = f32[] add(p0, p1)
- }
-
- fused_computation.1 {
- p0 = f32[28672,4096]{1,0} parameter(0)
- c0 = f32[] constant(0)
- ROOT r = f32[28672]{0} reduce(p0, c0), dimensions={1}, to_apply=add
- }
-
- fused_computation.2 {
- p0 = f32[28672]{0} parameter(0)
- p1 = f32[28672]{0} parameter(1)
- ROOT a = f32[28672]{0} add(p0, p1)
- }
-
- ENTRY main {
- p0 = f32[28672,4096]{1,0} parameter(0)
- p1 = f32[28672]{0} parameter(1)
- f = f32[28672]{0} fusion(p0), kind=kInput, calls=%fused_computation.1
- ROOT fusion = f32[28672]{0} fusion(f,p1), kind=kLoop, calls=%fused_computation.2
- })");
-
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseIntoRoot) {
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule test_module
-
- ENTRY %main (p.0: u32[2], p.1: u32[]) -> u32[2] {
- %p.0 = u32[2]{0} parameter(0)
- %p.1 = u32[] parameter(1)
- ROOT %broadcast = u32[2]{0} broadcast(u32[] %p.1), dimensions={}, sharding={replicated}
- %add = u32[2]{0} add(u32[2]{0} %p.0, u32[2]{0} %broadcast)
- %tuple.1 = (u32[2]{0}) tuple(u32[2]{0} %add)
- %token.0 = token[] after-all()
- %outfeed.6 = token[] outfeed((u32[2]{0}) %tuple.1, token[] %token.0), outfeed_shape=(u32[2]{0}), sharding={maximal device=0}
- })");
-
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, DontFuseConcat) {
- // Regression test that verifies we don't fuse concat into a column reduction.
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- %maximum (param_0: f32[], param_1: f32[]) -> f32[] {
- %param_0 = f32[] parameter(0)
- %param_1 = f32[] parameter(1)
- ROOT %maximum = f32[] maximum(f32[] %param_0, f32[] %param_1)
- }
-
- %fused_concat (param_0: f32[1,4,401,8,8], param_1: f32[1,1,4,1023,8], param_2: bf16[1,4,1023,8,8]) -> f32[1,4,1424,8,8] {
- %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
- %convert = f32[1,4,1023,8,8]{4,3,2,1,0} convert(bf16[1,4,1023,8,8]{4,3,2,1,0} %param_2)
- %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
- %bitcast = f32[4,1023,8]{2,1,0} bitcast(f32[1,1,4,1023,8]{4,3,2,1,0} %param_1)
- %broadcast = f32[1,4,1023,8,8]{4,3,2,1,0} broadcast(f32[4,1023,8]{2,1,0} %bitcast), dimensions={1,2,4}
- %add = f32[1,4,1023,8,8]{4,3,2,1,0} add(f32[1,4,1023,8,8]{4,3,2,1,0} %convert, f32[1,4,1023,8,8]{4,3,2,1,0} %broadcast)
- %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
- ROOT %concatenate = f32[1,4,1424,8,8]{4,3,2,1,0} concatenate(f32[1,4,1023,8,8]{4,3,2,1,0} %add, f32[1,4,401,8,8]{4,3,2,1,0} %param_0), dimensions={2}
- }
-
- %fused_reduce (param_0: f32[], param_1: f32[1,4,1424,8,8]) -> f32[4,8,8] {
- %param_1 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(1)
- %bitcast = f32[4,1424,8,8]{3,2,1,0} bitcast(f32[1,4,1424,8,8]{4,3,2,1,0} %param_1)
- %param_0 = f32[] parameter(0)
- ROOT %reduce = f32[4,8,8]{2,1,0} reduce(f32[4,1424,8,8]{3,2,1,0} %bitcast, f32[] %param_0), dimensions={1}, to_apply=%maximum
- }
-
- %fused_broadcast (param_0: f32[1,4,1424,8,8], param_1: f32[4,8,8]) -> f32[1,4,1424,8,8] {
- %param_0 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(0)
- %param_1 = f32[4,8,8]{2,1,0} parameter(1)
- %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} broadcast(f32[4,8,8]{2,1,0} %param_1), dimensions={1,3,4}
- ROOT %subtract = f32[1,4,1424,8,8]{4,3,2,1,0} subtract(f32[1,4,1424,8,8]{4,3,2,1,0} %param_0, f32[1,4,1424,8,8]{4,3,2,1,0} %broadcast)
- }
-
- ENTRY fusion {
- %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
- %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
- %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
- %concat = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%param_0, %param_1, %param_2), kind=kLoop, calls=fused_concat
- %param_3 = f32[] parameter(3)
- %reduce = f32[4,8,8]{2,1,0} fusion(%param_3, %concat), kind=kLoop, calls=fused_reduce
- %param_4 = f32[4,8,8]{2,1,0} parameter(4)
- %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%concat, %param_4), kind=kLoop, calls=fused_broadcast
- ROOT tuple = (f32[4,8,8]{2,1,0}, f32[1,4,1424,8,8]{4,3,2,1,0}) tuple(%reduce, %broadcast)
- }
- )");
-
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, FuseOnlySmallConstant) {
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- ENTRY main {
- param_0 = f32[32,32]{1,0} parameter(0)
- c_1 = f32[] constant(1)
- c_2 = f32[32,32] constant({...})
- broadcast = f32[32,32]{1,0} broadcast(c_1), dimensions={}
- add = f32[32,32]{1,0} add(param_0, broadcast)
- ROOT mul = f32[32,32]{1,0} multiply(c_2, add)
- }
- )");
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
- EXPECT_THAT(root->fused_expression_root(),
- GmockMatch(m::Multiply(
- m::Parameter(),
- m::Add(m::Parameter(), m::Broadcast(m::Constant())))));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) {
- auto module = *ParseAndReturnVerifiedModule(R"(
- HloModule module
-
- fused_computation.1 {
- iota.9.7 = s32[3,1,1]{2,1,0} iota(), iota_dimension=0
- param_3.29 = s32[] parameter(2)
- pad.2.7 = s32[3,1,2]{2,1,0} pad(iota.9.7, param_3.29), padding=0_0x0_0x0_1
- param_2.39 = s32[] parameter(1)
- broadcast.76.1 = s32[3,1,2]{2,1,0} broadcast(param_2.39), dimensions={}
- compare.9.1 = pred[3,1,2]{2,1,0} compare(pad.2.7, broadcast.76.1), direction=GE
- param_1.73 = s32[2]{0} parameter(0)
- broadcast.78.1 = s32[3,2]{1,0} broadcast(param_1.73), dimensions={1}
- bitcast.1 = s32[3,2]{1,0} bitcast(pad.2.7)
- compare.10.1 = pred[3,2]{1,0} compare(bitcast.1, broadcast.78.1), direction=LE
- bitcast.2 = pred[3,1,2]{2,1,0} bitcast(compare.10.1)
- ROOT and.3.1 = pred[3,1,2]{2,1,0} and(compare.9.1, bitcast.2)
- }
-
- and {
- x = pred[] parameter(0)
- y = pred[] parameter(1)
- ROOT and = pred[] and(x, y)
- }
-
- fused_computation.2 {
- param0 = pred[3,1,2]{2,1,0} parameter(0)
- slice = pred[1,1,2]{2,1,0} slice(param0), slice={[0:1], [0:1], [0:2]}
- bitcast = pred[2]{0} bitcast(slice)
- init = pred[] constant(true)
- reduce = pred[2]{0} reduce(param0, init), dimensions={0,1}, to_apply=and
- and = pred[2]{0} and(bitcast, reduce)
- pad = pred[3]{0} pad(and, init), padding=0_1
- broadcast = pred[3,2]{1,0} broadcast(pad), dimensions={0}
- bitcast2 = pred[6]{0} bitcast(broadcast)
- broadcast2 = pred[2,3]{1,0} broadcast(pad), dimensions={1}
- bitcast3 = pred[6]{0} bitcast(broadcast2)
- ROOT and2 = pred[6]{0} and(bitcast2, bitcast3)
- }
-
- ENTRY main {
- p0 = s32[2]{0} parameter(0)
- p1 = s32[] parameter(1)
- p2 = s32[] parameter(2)
- fusion1 = pred[3,1,2]{2,1,0} fusion(p0, p1, p2), kind=kLoop, calls=fused_computation.1
- ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2
- }
- )");
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, CanMergeTritonFusionWithBothProducerAndConsumer) {
-#ifndef GOOGLE_CUDA
- GTEST_SKIP() << "Triton fusion only enable for CUDA devices.";
-#endif
-
- const std::string kHloText = R"(
-HloModule t
-add {
- Arg_0 = f32[] parameter(0)
- Arg_1 = f32[] parameter(1)
- ROOT add = f32[] add(Arg_0, Arg_1)
-}
-
-producer_computation {
- parameter_0 = f32[125]{0} parameter(0)
- ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0}
-}
-
-consumer_computation {
- parameter_0 = f32[125,127]{1,0} parameter(0)
- parameter_1 = f32[125,127]{1,0} parameter(1)
- ROOT multiply = f32[125,127]{1,0} multiply(parameter_1, parameter_0)
-}
-
-triton_softmax_computation {
- parameter_0 = f32[125,127]{1,0} parameter(0)
- multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0)
- constant_0 = f32[] constant(0)
- reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
- broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0}
- ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4)
-}
-
-ENTRY main {
- param_0 = f32[125]{0} parameter(0)
- param_1 = f32[125,127]{1,0} parameter(1)
- producer_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=producer_computation
- triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
- ROOT consumer_fusion = f32[125,127]{1,0} fusion(param_1, triton_softmax), kind=kLoop, calls=consumer_computation
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
- auto debug_options = module->config().debug_options();
- debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true);
- module->mutable_config().set_debug_options(debug_options);
-
- EXPECT_TRUE(priority_fusion_.Run(module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-
- HloInstruction* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
- EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom);
- EXPECT_TRUE(IsGenericTritonFusion(*root));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) {
- auto module = *ParseAndReturnVerifiedModule(R"(
- %reducer {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- add = f32[] add(p0, p1)
- ROOT max = f32[] maximum(add, p0)
- }
-
- %fused_reduce {
- p0 = f32[256] parameter(0)
- p1 = f32[] parameter(1)
- ROOT reduce = f32[] reduce(p0, p1), dimensions={0}, to_apply=%reducer
- }
-
- ENTRY fusion {
- p0 = f32[256] parameter(0)
- p1 = f32[] parameter(1)
- ROOT %reduce = f32[] fusion(p0, p1), kind=kInput, calls=fused_reduce
- }
- )");
- EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc
index b8496b3..cd0be0b 100644
--- a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc
+++ b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc
@@ -41,6 +41,7 @@
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/nvptx_compiler.h"
#include "xla/service/hlo_module_config.h"
+#include "xla/stream_executor/cuda/nvjitlink_support.h"
#include "xla/stream_executor/cuda/ptx_compilation_method.h"
#include "xla/stream_executor/cuda/ptx_compiler_support.h"
#include "xla/stream_executor/cuda/ptx_linking_method.h"
@@ -155,6 +156,29 @@
// Compiled without libnvptxcompiler support
GTEST_SKIP() << "libnvptxcompiler is not supported in this build.";
}
+
+ if (!stream_executor::IsLibNvJitLinkSupported() &&
+ (compilation_method == PtxCompilationMethod::kNvJitLink ||
+ linking_method == PtxLinkingMethod::kNvJitLink)) {
+ // Compiled without libnvjitlink support
+ GTEST_SKIP() << "libnvjitlink is not supported in this build.";
+ }
+
+ if (compilation_method == PtxCompilationMethod::kNvJitLink &&
+ linking_method != PtxLinkingMethod::kNvJitLink) {
+ // When compilation method is NvJitLink, linking method must be NvJitLink
+ // as well.
+ GTEST_SKIP() << "Compilation method NvJitLink is only supported if the "
+ "linking method is NvJitLink as well.";
+ }
+
+ if (compilation_method == PtxCompilationMethod::kPtxas &&
+ linking_method == PtxLinkingMethod::kNvJitLink) {
+ // We could support this combination, but it would require some
+ // refactoring of the flags.
+ GTEST_SKIP() << "Compilation method Ptxas is not supported with linking "
+ "method NvJitLink.";
+ }
}
void SetDebugOptionsFromPtxSettings(DebugOptions* debug_options,
@@ -163,6 +187,10 @@
debug_options->set_xla_gpu_enable_libnvptxcompiler(
compilation_method == PtxCompilationMethod::kNvPtxCompiler);
+ debug_options->set_xla_gpu_enable_libnvjitlink(
+ compilation_method == PtxCompilationMethod::kNvJitLink ||
+ linking_method == PtxLinkingMethod::kNvJitLink);
+
debug_options->set_xla_gpu_enable_llvm_module_compilation_parallelism(
linking_method != PtxLinkingMethod::kNone);
debug_options->set_xla_gpu_force_compilation_parallelism(12);
@@ -316,9 +344,11 @@
::testing::Combine(
::testing::Values("simple", "parallel_compilation", "requires_sm90a"),
::testing::Values(PtxCompilationMethod::kNvPtxCompiler,
- PtxCompilationMethod::kPtxas),
+ PtxCompilationMethod::kPtxas,
+ PtxCompilationMethod::kNvJitLink),
::testing::Values(PtxLinkingMethod::kNone, PtxLinkingMethod::kNvLink,
- PtxLinkingMethod::kDriver)),
+ PtxLinkingMethod::kDriver,
+ PtxLinkingMethod::kNvJitLink)),
[](const ::testing::TestParamInfo<std::tuple<
std::string_view, PtxCompilationMethod, PtxLinkingMethod>>& info) {
return GenerateParametrizedTestname(std::get<0>(info.param),
diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc
deleted file mode 100644
index ac5419c..0000000
--- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleReduce(HloInstruction *hlo) override {
- auto instr = Cast<HloReduceInstruction>(hlo);
- absl::InlinedVector<HloInstruction *, 2> input_reshapes;
- absl::InlinedVector<Shape, 2> canonical_reduce_shapes;
-
- int idx = -1;
- std::vector<int64_t> updated_reduced_dimensions;
- for (HloInstruction *reduced_op : instr->inputs()) {
- idx++;
- const Shape &input_shape = reduced_op->shape();
- const Shape &reduce_shape = instr->shape().IsTuple()
- ? instr->shape().tuple_shapes(idx)
- : instr->shape();
-
- if (!ShapeUtil::HasDegenerateDimensions(reduced_op->shape())) {
- return absl::OkStatus();
- }
- Shape canonical_input_shape =
- ShapeUtil::DropDegenerateDimensions(input_shape);
-
- Shape canonical_reduce_shape =
- ShapeUtil::DropDegenerateDimensions(reduce_shape);
-
- auto reduced_dimensions = instr->dimensions();
- int64_t shift = 0;
-
- for (int dim = 0; dim < input_shape.rank(); dim++) {
- if (input_shape.dimensions(dim) == 1) {
- shift++;
- } else {
- if (absl::c_linear_search(reduced_dimensions, dim) && idx == 0) {
- // Only populate on first iteration.
- updated_reduced_dimensions.push_back(dim - shift);
- }
- }
- }
-
- if (updated_reduced_dimensions.empty()) {
- std::unique_ptr<HloInstruction> reshape =
- HloInstruction::CreateBitcast(reduce_shape, reduced_op);
- return ReplaceWithNewInstruction(instr, std::move(reshape));
- }
-
- input_reshapes.push_back(instr->parent()->AddInstruction(
- HloInstruction::CreateBitcast(canonical_input_shape, reduced_op)));
- canonical_reduce_shapes.push_back(canonical_reduce_shape);
- }
-
- Shape canonical_reduce_shape =
- ShapeUtil::MakeMaybeTupleShape(canonical_reduce_shapes);
- const Shape &orig_reduce_shape = instr->shape();
- std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
- canonical_reduce_shape, input_reshapes, instr->init_values(),
- updated_reduced_dimensions, instr->to_apply());
- instr->SetupDerivedInstruction(new_reduce.get());
-
- if (canonical_reduce_shape != instr->shape()) {
- HloInstruction *wrapped_reduce =
- instr->parent()->AddInstruction(std::move(new_reduce));
- absl::InlinedVector<HloInstruction *, 2> out;
- if (!canonical_reduce_shape.IsTuple()) {
- new_reduce =
- HloInstruction::CreateBitcast(orig_reduce_shape, wrapped_reduce);
- } else {
- for (int oidx = 0; oidx < instr->input_count(); oidx++) {
- HloInstruction *gte = instr->parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
- out.push_back(
- instr->parent()->AddInstruction(HloInstruction::CreateBitcast(
- orig_reduce_shape.tuple_shapes(oidx), gte)));
- }
- new_reduce = HloInstruction::CreateTuple(out);
- }
- }
-
- return ReplaceWithNewInstruction(instr, std::move(new_reduce));
- }
-};
-
-absl::StatusOr<bool> ReductionDegenerateDimRemover::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- TF_ASSIGN_OR_RETURN(bool changed,
- ReductionDegenerateDimRemoverVisitor().RunOnModule(
- module, execution_threads));
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h b/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h
deleted file mode 100644
index 03d6819..0000000
--- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_
-#define XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Enforces the invariant that reduction input and output have no degenerate
-// (size 1) dimension. Since these dimensions are physically meaningless, they
-// are removed using bitcasts.
-//
-// For example,
-//
-// f[1] out = reduce(f[100, 1, 1] input, dimensions={0, 1})
-//
-// becomes:
-//
-//
-// f[100] tmp1 = f[100] bitcast(f[100, 1, 1], input)
-// f[] tmp2 = reduce(f[100] tmp1, dimensions={0})
-// f[1] out = f[] bitcast(tmp2)
-//
-class ReductionDegenerateDimRemover : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "reduction-degenerate-dim-remover";
- }
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_
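The removed header above describes the rewrite at the HLO level. For reference, a minimal C++ sketch of the same shape canonicalization, using only the ShapeUtil helpers the deleted implementation already calls (the function name is illustrative, not from the patch):

#include "absl/log/check.h"
#include "xla/shape_util.h"

// Shape-level effect of ReductionDegenerateDimRemover: size-1 dimensions carry
// no data, so the operand is bitcast to a shape without them before the reduce
// is re-emitted.
void DegenerateDimExampleSketch() {
  xla::Shape input = xla::ShapeUtil::MakeShape(xla::F32, {100, 1, 1});
  xla::Shape canonical = xla::ShapeUtil::DropDegenerateDimensions(input);
  CHECK(xla::ShapeUtil::Equal(canonical,
                              xla::ShapeUtil::MakeShape(xla::F32, {100})));
}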
diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc b/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc
deleted file mode 100644
index 8ab4fcf..0000000
--- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_dimension_grouper.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/layout_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleReduce(HloInstruction *hlo) override {
- auto reduce = Cast<HloReduceInstruction>(hlo);
-
- VLOG(4) << "Input: " << reduce->ToString();
-
- absl::InlinedVector<HloInstruction *, 2> reduce_inputs_grouped;
- std::vector<int64_t> reduced_dims_grouped;
-
- int idx = -1;
- for (HloInstruction *operand : reduce->inputs()) {
- idx++;
- std::vector<int64_t> new_grouped_dims;
- const Shape &shape = operand->shape();
- CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
- << "Default layout should be enforced on reduction operand";
- auto is_reduced = [&](int dim) {
- return absl::c_linear_search(reduce->dimensions(), dim);
- };
-
- bool changed = false;
- int64_t next_dim_size = 1;
-
- // Since we have enforced the standard layout, iteration over logical
- // dimensions is equivalent to iteration over the major-to-minor order.
- for (int logical_dim = 0; logical_dim < shape.rank(); logical_dim++) {
- VLOG(5) << "Processing dimension " << logical_dim << " of size "
- << shape.dimensions(logical_dim);
- if (is_reduced(logical_dim) && logical_dim < shape.rank() - 1 &&
- is_reduced(logical_dim + 1)) {
- VLOG(5) << "This and consecutive dimension are reduced, merging";
- changed = true;
- next_dim_size *= shape.dimensions(logical_dim);
- continue;
- }
-
- if (is_reduced(logical_dim)) {
- new_grouped_dims.push_back(next_dim_size *
- shape.dimensions(logical_dim));
- if (idx == 0) {
- // Only populate for first argument.
- reduced_dims_grouped.push_back(new_grouped_dims.size() - 1);
- }
- next_dim_size = 1;
- } else {
- new_grouped_dims.push_back(shape.dimensions(logical_dim));
- }
- }
-
-      if (!changed) {  // Since all inputs have the same shape dimensions.
- return absl::OkStatus();
- }
-
- Shape grouped_shape =
- ShapeUtil::MakeShape(shape.element_type(), new_grouped_dims);
- reduce_inputs_grouped.push_back(reduce->parent()->AddInstruction(
- HloInstruction::CreateBitcast(grouped_shape, operand),
- &operand->metadata()));
- VLOG(5) << "Adding bitcast: " << reduce_inputs_grouped.back()->ToString();
- }
-
- std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
- reduce->shape(), reduce_inputs_grouped, reduce->init_values(),
- reduced_dims_grouped, reduce->to_apply());
- VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
- return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
- }
-};
-
-absl::StatusOr<bool> ReductionDimensionGrouper::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- TF_ASSIGN_OR_RETURN(bool changed, ReduceDimensionGroupVisitor().RunOnModule(
- module, execution_threads));
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h b/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h
deleted file mode 100644
index 8ee4efd..0000000
--- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_
-#define XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Groups adjacent (logically and physically) reduced dimensions in reduction
-// input.
-//
-// Precondition: ReductionLayoutNormalizer has been run (physical proximity and
-// logical proximity become the same).
-//
-// For example,
-//
-// f[] out = reduce(f[10,20,30] input, dimensions={0,1,2})
-//
-// becomes:
-//
-//   f[6000] tmp = f[6000] bitcast(f[10,20,30] input)
-//   f[] out = reduce(f[6000] tmp, dimensions={0})
-//
-class ReductionDimensionGrouper : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "reduction-dimension-grouper";
- }
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_
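The grouping described in the removed header only reshapes: adjacent reduced dimensions collapse into their product. A small sketch under the same ShapeUtil-only assumption (names are illustrative):

#include "absl/log/check.h"
#include "xla/shape_util.h"

// reduce(f32[10,20,30], dimensions={0,1,2}) has all reduced dimensions
// adjacent, so the pass bitcasts the operand to f32[6000] and reduces over the
// single remaining dimension.
void DimensionGroupingSketch() {
  xla::Shape input = xla::ShapeUtil::MakeShape(xla::F32, {10, 20, 30});
  xla::Shape grouped =
      xla::ShapeUtil::MakeShape(xla::F32, {xla::ShapeUtil::ElementsIn(input)});
  CHECK_EQ(grouped.dimensions(0), 10 * 20 * 30);
}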
diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc
deleted file mode 100644
index a91fdf7..0000000
--- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc
+++ /dev/null
@@ -1,203 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_layout_normalizer.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
- absl::Status HandleReduce(HloInstruction *hlo) override {
- auto reduce = Cast<HloReduceInstruction>(hlo);
- VLOG(5) << "Input: " << reduce->ToString();
-
- int operand_idx = -1;
-
- absl::InlinedVector<HloInstruction *, 2> canonical_reduce_inputs;
- absl::InlinedVector<Shape, 2> new_reduce_shapes;
-
- DimensionVector out_reduce_dimensions;
- const Shape &first_instruction_shape = reduce->inputs()[0]->shape();
-
- for (HloInstruction *operand : reduce->inputs()) {
- operand_idx++;
-
- if (operand_idx != 0 &&
- operand->shape().layout() != first_instruction_shape.layout()) {
- HloInstruction *copy =
- reduce->parent()->AddInstruction(HloInstruction::CreateUnary(
- operand->shape(), HloOpcode::kCopy, operand));
-
- LayoutUtil::ClearLayout(copy->mutable_shape());
- TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
- first_instruction_shape, copy->mutable_shape()));
-
- copy->set_metadata(operand->metadata());
- operand = copy;
- VLOG(3) << "Copying to establish consistent inputs layout: "
- << copy->ToString();
- }
-
- const Shape &operand_shape = operand->shape();
- const Layout &operand_layout = operand_shape.layout();
-
- const Shape &reduce_shape =
- reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(operand_idx)
- : reduce->shape();
-
- DimensionVector new_reduce_dimensions;
- DimensionVector new_operand_shape_data;
- DimensionVector new_reduce_shape_data;
-
- // The layout order of the reduction output can be different to the
- // ordering of kept dimensions in the input operand, thus we need to
- // calculate the new layout.
- DimensionVector new_reduce_shape_layout(reduce_shape.rank());
- std::vector<int64_t> reduce_shape_logical_to_physical =
- LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout());
-
- auto to_reduce_logical_dim = [&](int64_t op_logical_dim) {
- return op_logical_dim -
- absl::c_count_if(reduce->dimensions(), [&](int64_t dim) {
- CHECK(dim != op_logical_dim);
- return dim < op_logical_dim;
- });
- };
-
- for (int i = 0; i < operand_shape.rank(); i++) {
- // Process the dimensions in the major-to-minor order in order to
- // enforce the default layout.
- int64_t major_to_minor_dim_idx = operand_shape.rank() - i - 1;
- int64_t logical_dim =
- operand_layout.minor_to_major(major_to_minor_dim_idx);
- int64_t dim_size = operand_shape.dimensions(logical_dim);
- VLOG(5) << "Processing logical dimension " << logical_dim << " of size "
- << dim_size;
- new_operand_shape_data.push_back(dim_size);
-
- if (absl::c_linear_search(reduce->dimensions(), logical_dim)) {
- new_reduce_dimensions.push_back(i);
- } else {
- new_reduce_shape_data.push_back(dim_size);
- int64_t logical_reduce_dim = to_reduce_logical_dim(logical_dim);
- int64_t physical_reduce_dim =
- reduce_shape_logical_to_physical[logical_reduce_dim];
- VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", "
- << "physical_reduce_dim = " << physical_reduce_dim;
- new_reduce_shape_layout[reduce_shape.rank() - physical_reduce_dim -
- 1] = new_reduce_shape_data.size() - 1;
- }
- }
-
- Shape new_operand_shape = ShapeUtil::MakeShape(
- operand_shape.element_type(), new_operand_shape_data);
- Shape new_reduce_shape = ShapeUtil::MakeShapeWithDenseLayout(
- reduce_shape.element_type(), new_reduce_shape_data,
- new_reduce_shape_layout);
-
- if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) {
- return absl::OkStatus();
- }
-
- HloInstruction *canonical_reduce_input =
- new_operand_shape != operand_shape
- ? reduce->parent()->AddInstruction(
- HloInstruction::CreateBitcast(new_operand_shape, operand))
- : operand;
- canonical_reduce_input->set_metadata(operand->metadata());
- VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();
-
- new_reduce_shapes.push_back(new_reduce_shape);
- canonical_reduce_inputs.push_back(canonical_reduce_input);
-
- if (out_reduce_dimensions.empty()) {
- out_reduce_dimensions = new_reduce_dimensions;
- } else {
- TF_RET_CHECK(out_reduce_dimensions == new_reduce_dimensions);
- }
- }
-
- Shape new_reduce_shape = ShapeUtil::MakeMaybeTupleShape(new_reduce_shapes);
-
- std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
- new_reduce_shape, canonical_reduce_inputs, reduce->init_values(),
- out_reduce_dimensions, reduce->to_apply());
- VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
- const Shape &orig_reduce_shape = reduce->shape();
-
- if (new_reduce_shape != orig_reduce_shape) {
- HloInstruction *wrapped_reduce =
- reduce->parent()->AddInstruction(std::move(new_reduce));
-
- if (!new_reduce_shape.IsTuple()) {
- new_reduce =
- HloInstruction::CreateBitcast(reduce->shape(), wrapped_reduce);
- } else {
- // Bitcast each element of the tuple.
- absl::InlinedVector<HloInstruction *, 2> out;
- for (int oidx = 0; oidx < reduce->input_count(); oidx++) {
- HloInstruction *gte = reduce->parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
- out.push_back(
- reduce->parent()->AddInstruction(HloInstruction::CreateBitcast(
- orig_reduce_shape.tuple_shapes(oidx), gte)));
- }
- new_reduce = HloInstruction::CreateTuple(out);
- }
- }
-
- VLOG(5) << "Generated output: " << new_reduce->ToString();
- return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
- }
-};
-
-absl::StatusOr<bool> ReductionLayoutNormalizer::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- TF_ASSIGN_OR_RETURN(bool changed,
- EnforceMinorToMajorReduceOpVisitor().RunOnModule(
- module, execution_threads));
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h b/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h
deleted file mode 100644
index 7d2d207..0000000
--- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_
-#define XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Enforces default (minor-to-major) layout on all reduction inputs.
-// Note that since reduction output can request a custom layout,
-// this pass only guarantees standard layout for the input.
-//
-// For example,
-//
-// f[20,30]{0,1} out = reduce(f[10,20,30]{2,0,1} input, dimensions={0})
-//
-// becomes:
-//
-// f[20,10,30] tmp = f[20,10,30] bitcast(f[10,20,30]{2,0,1} input)
-// f[20,30]{0,1} out = reduce(f[20,10,30]{2,1,0} tmp, dimensions={1})
-class ReductionLayoutNormalizer : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "reduction-layout-normalizer";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_
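The example in the removed header can be made concrete with the same shape helpers the pass itself uses; this is an illustrative sketch, not code from the patch:

#include "absl/log/check.h"
#include "xla/layout_util.h"
#include "xla/shape_util.h"

// f32[10,20,30] with minor-to-major layout {2,0,1} lists its dimensions in
// major-to-minor order as 20, 10, 30. The pass bitcasts the operand to
// f32[20,10,30] with the default {2,1,0} layout, and the reduced logical
// dimension 0 becomes dimension 1 of the normalized operand.
void LayoutNormalizationSketch() {
  xla::Shape input = xla::ShapeUtil::MakeShapeWithDenseLayout(
      xla::F32, /*dimensions=*/{10, 20, 30}, /*minor_to_major=*/{2, 0, 1});
  xla::Shape normalized = xla::LayoutUtil::GetWithDefaultLayout(
      xla::ShapeUtil::MakeShape(xla::F32, {20, 10, 30}));
  // Dimension 1 (size 20) is the most major dimension of the input layout.
  CHECK_EQ(normalized.dimensions(0), input.dimensions(1));
}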
diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.cc b/third_party/xla/xla/service/gpu/reduction_splitter.cc
deleted file mode 100644
index cd37319..0000000
--- a/third_party/xla/xla/service/gpu/reduction_splitter.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_splitter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <cstdlib>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class ReductionSplitterVisitor : public DfsHloRewriteVisitor {
- public:
- explicit ReductionSplitterVisitor(bool ignore_small_dims)
- : ignore_small_dims_(ignore_small_dims) {}
- absl::Status HandleReduce(HloInstruction *reduce) override {
- VLOG(4) << "Input: " << reduce->ToString();
-
- // Reductions with contiguous dimensions are lowered to efficient code. No
- // need to split such ops.
- if (IsReductionFromOrToContiguousDimensions(*reduce)) {
- VLOG(4) << "Reduction with contiguous dimensions. Return.";
- return absl::OkStatus();
- }
- if (reduce->dimensions().size() < 2) {
- return absl::OkStatus();
- }
- if (!reduce->shape().IsArray()) {
- // TODO(cheshire): Handle variadic reduction.
- return absl::OkStatus();
- }
-
- HloInstruction *operand = reduce->mutable_operand(0);
- const Shape &shape = operand->shape();
- CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
- << "Default layout should be enforced on reduction operand";
- // Verify that contiguous dimensions have been grouped by the
- // ReductionDimensionGrouper pass.
- for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
- for (int64_t j = i + 1; j < reduce->dimensions().size(); ++j) {
- CHECK(abs(reduce->dimensions(i) - reduce->dimensions(j)) > 1)
- << "Reduction dimensions must not be consecutive";
- }
- }
-
- // The reduce op has non-contiguous dimensions. Look for the dimension with
- // the largest shape dimension. Reducing along this dimension first will
- // reduce the output size most effectively.
- int64_t max_shape_dim = 0;
- int64_t max_reduce_dim = 0;
- const auto &input_shape = reduce->operand(0)->shape();
- for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
- if (input_shape.dimensions(reduce->dimensions(i)) > max_shape_dim) {
- max_reduce_dim = reduce->dimensions(i);
- max_shape_dim = input_shape.dimensions(max_reduce_dim);
- }
- }
- if (ignore_small_dims_ && max_shape_dim <= 8) {
- return absl::OkStatus();
- }
-
- // Split the reduction into a pre-reduction and a final reduction.
- VLOG(3) << "Splitting reduction " << reduce->name() << " at dimension "
- << max_reduce_dim;
- std::vector<int64_t> pre_reduce_dims;
- pre_reduce_dims.push_back(max_reduce_dim);
- std::vector<int64_t> pre_reduce_shape_dims(input_shape.dimensions().begin(),
- input_shape.dimensions().end());
- pre_reduce_shape_dims.erase(pre_reduce_shape_dims.begin() + max_reduce_dim);
- Shape pre_reduce_shape = ShapeUtil::MakeShape(
- reduce->shape().element_type(), pre_reduce_shape_dims);
- std::unique_ptr<HloInstruction> pre_reduce = HloInstruction::CreateReduce(
- pre_reduce_shape, reduce->mutable_operand(0),
- reduce->mutable_operand(1), pre_reduce_dims, reduce->to_apply());
- pre_reduce->set_metadata(reduce->metadata());
-
- std::vector<int64_t> final_reduce_dims(reduce->dimensions().begin(),
- reduce->dimensions().end());
- final_reduce_dims.erase(
- std::remove(final_reduce_dims.begin(), final_reduce_dims.end(),
- max_reduce_dim),
- final_reduce_dims.end());
- for (int64_t i = 0; i < final_reduce_dims.size(); ++i) {
- if (final_reduce_dims[i] > max_reduce_dim) {
- final_reduce_dims[i]--;
- }
- }
- std::unique_ptr<HloInstruction> final_reduce = HloInstruction::CreateReduce(
- reduce->shape(),
- reduce->parent()->AddInstruction(std::move(pre_reduce)),
- reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply());
- return ReplaceWithNewInstruction(reduce, std::move(final_reduce));
- }
-
- private:
- bool ignore_small_dims_;
-};
-
-absl::StatusOr<bool> ReductionSplitter::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- TF_ASSIGN_OR_RETURN(bool changed,
- ReductionSplitterVisitor(ignore_small_dims_)
- .RunOnModule(module, execution_threads));
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.h b/third_party/xla/xla/service/gpu/reduction_splitter.h
deleted file mode 100644
index 7e76525..0000000
--- a/third_party/xla/xla/service/gpu/reduction_splitter.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
-#define XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Splits a reduce op into two consecutive reduce ops if the reduce dimensions
-// are not contiguous. Ignores small reduce dimensions if `ignore_small_dims` is
-// set.
-//
-// Reductions with non-contiguous dimensions are emitted as simple element-wise
-// loops. This is inefficient when reducing large input shape dimensions.
-// Splitting such reductions allows using more efficient reduction emitters.
-//
-// This pass splits reduce ops into two consecutive reduce ops. Run it to a
-// fixpoint to split reduce ops along multiple dimensions.
-//
-// Precondition: ReductionDimensionGrouper has been run and adjacent reduce
-// dimensions have been grouped. Reduction layouts have been normalized.
-
-class ReductionSplitter : public HloModulePass {
- public:
- explicit ReductionSplitter(bool ignore_small_dims)
- : ignore_small_dims_(ignore_small_dims) {}
- absl::string_view name() const override { return "reduction-splitter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- bool ignore_small_dims_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
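The precondition note above implies an ordering between the three reduction passes. A hypothetical wiring sketch, assuming the standard HloPassPipeline API; the pipeline name and function are made up, and the include paths are the pre-move ones this diff deletes:

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/reduction_dimension_grouper.h"
#include "xla/service/gpu/reduction_layout_normalizer.h"
#include "xla/service/gpu/reduction_splitter.h"
#include "xla/service/hlo_pass_pipeline.h"

// Normalize layouts first, then group adjacent reduced dimensions, then split
// non-contiguous reductions; ReductionSplitter documents the first two as
// preconditions.
absl::StatusOr<bool> RunReductionCanonicalization(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("reduction-canonicalization");
  pipeline.AddPass<xla::gpu::ReductionLayoutNormalizer>();
  pipeline.AddPass<xla::gpu::ReductionDimensionGrouper>();
  pipeline.AddPass<xla::gpu::ReductionSplitter>(/*ignore_small_dims=*/true);
  return pipeline.Run(module);
}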
diff --git a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/reduction_splitter_test.cc
deleted file mode 100644
index 13a5210..0000000
--- a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_splitter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape_util.h"
-#include "xla/test.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class ReductionSplitterTest : public HloTestBase {};
-
-TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test
-
- add_computation {
- x = f32[] parameter(0)
- y = f32[] parameter(1)
- ROOT add = f32[] add(x, y)
- }
-
- ENTRY entry_computation {
- param_0 = f16[6,16,512,64]{3,2,1,0} parameter(0)
- transpose.1781 = f16[6,512,16,64]{3,1,2,0} transpose(param_0), dimensions={0,2,1,3}
- convert.6986 = f32[6,512,16,64]{3,1,2,0} convert(transpose.1781)
- bitcast.2136 = f32[6,16,512,64]{3,2,1,0} bitcast(convert.6986)
- constant_11111 = f32[] constant(0)
- ROOT reduce.982 = f32[16,64]{1,0} reduce(bitcast.2136, constant_11111), dimensions={0,2}, to_apply=add_computation
- }
- )")
- .value();
- ASSERT_TRUE(
- ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root_reduction =
- module->entry_computation()->root_instruction();
- ASSERT_THAT(root_reduction,
- GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
-
- auto* pre_reduction = root_reduction->operand(0);
- EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({2}));
- EXPECT_THAT(pre_reduction->shape(), ShapeUtil::MakeShape(F32, {6, 16, 64}));
- EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({0}));
- EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
-}
-
-TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test
-
- add_computation {
- x = f32[] parameter(0)
- y = f32[] parameter(1)
- ROOT add = f32[] add(x, y)
- }
-
- ENTRY entry_computation {
- param_0 = f32[1024,16,512,64,128]{4,3,2,1,0} parameter(0)
- constant_11111 = f32[] constant(0)
- ROOT reduce.982 = f32[16,64]{1,0} reduce(param_0, constant_11111), dimensions={2,0,4}, to_apply=add_computation
- }
- )")
- .value();
- ASSERT_TRUE(
- ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
- SCOPED_TRACE(module->ToString());
- const HloInstruction* root_reduction =
- module->entry_computation()->root_instruction();
- ASSERT_THAT(root_reduction,
- GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
-
- auto* pre_reduction = root_reduction->operand(0);
- EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({0}));
- EXPECT_THAT(pre_reduction->shape(),
- ShapeUtil::MakeShape(F32, {16, 512, 64, 128}));
- EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({1, 3}));
- EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
-}
-
-TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test
-
- add_computation {
- x = f32[] parameter(0)
- y = f32[] parameter(1)
- ROOT add = f32[] add(x, y)
- }
-
- ENTRY entry_computation {
- param_0 = f32[16,8,1024,8]{3,2,1,0} parameter(0)
- constant_11111 = f32[] constant(0)
- ROOT reduce.982 = f32[16,1024]{1,0} reduce(param_0, constant_11111), dimensions={3,1}, to_apply=add_computation
- }
- )")
- .value();
- EXPECT_FALSE(
- ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
- EXPECT_TRUE(
- ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
-}
-
-TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule test
-
- add_computation {
- x = f32[] parameter(0)
- y = f32[] parameter(1)
- ROOT add = f32[] add(x, y)
- }
-
- ENTRY entry_computation {
- param_0 = f32[128,128,64,128]{3,2,1,0} parameter(0)
- constant_11111 = f32[] constant(0)
-    // The dimensions to keep (1 and 2) are contiguous.
- ROOT reduce.982 = f32[128,64]{1,0} reduce(param_0, constant_11111), dimensions={3,0}, to_apply=add_computation
- }
- )")
- .value();
- EXPECT_FALSE(
- ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/rename_fusions.cc b/third_party/xla/xla/service/gpu/rename_fusions.cc
deleted file mode 100644
index a2a3048..0000000
--- a/third_party/xla/xla/service/gpu/rename_fusions.cc
+++ /dev/null
@@ -1,92 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/rename_fusions.h"
-
-#include <memory>
-#include <string>
-
-#include "absl/container/btree_set.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr absl::string_view FusionKindToString(
- HloInstruction::FusionKind kind) {
- switch (kind) {
- case HloInstruction::FusionKind::kCustom:
- return "custom";
- case HloInstruction::FusionKind::kLoop:
- return "loop";
- case HloInstruction::FusionKind::kInput:
- return "input";
- case HloInstruction::FusionKind::kOutput:
- return "output";
- }
-}
-
-std::string MakeFusionHeroNames(const HloInstruction* instruction) {
- std::unique_ptr<HloFusionAdaptor> fusion_adaptor =
- HloFusionAdaptor::ForInstruction(instruction);
- absl::btree_set<absl::string_view> heroes;
-
- for (auto root : fusion_adaptor->GetRoots()) {
- heroes.insert(HloOpcodeString(FindNonTrivialHero(root).opcode()));
- }
- return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}});
-}
-
-void RenameFusion(HloModule* module, HloInstruction* instruction) {
- std::string hero_names = MakeFusionHeroNames(instruction);
- module->SetAndUniquifyInstrName(
- instruction, absl::StrCat(FusionKindToString(instruction->fusion_kind()),
- "_", hero_names, "_fusion"));
- module->SetAndUniquifyComputationName(
- instruction->fused_instructions_computation(),
- absl::StrCat("fused_", hero_names));
-}
-
-} // namespace
-
-absl::StatusOr<bool> RenameFusions::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- for (HloComputation* computation : module->MakeNonfusionComputations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kFusion ||
- instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) {
- continue;
- }
- RenameFusion(module, instruction);
- }
- }
- return true;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/rename_fusions.h b/third_party/xla/xla/service/gpu/rename_fusions.h
deleted file mode 100644
index c3065a4..0000000
--- a/third_party/xla/xla/service/gpu/rename_fusions.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_RENAME_FUSIONS_H_
-#define XLA_SERVICE_GPU_RENAME_FUSIONS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// An HLO pass that gives fusions and fused computations descriptive names.
-//
-// The name is based on hero instructions and the fusion kind, i.e.
-// Fusions get the name "<fusion kind>_<hero instructions>_fusion",
-// and fused computations get the name "fused_<hero instructions>".
-// In the case of multiple roots, the hero instructions in the name are
-// underscore-separated and alphabetically sorted.
-
-class RenameFusions : public HloModulePass {
- absl::string_view name() const override { return "rename_fusions"; }
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_RENAME_FUSIONS_H_
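The naming scheme above can be illustrated with the absl string helpers the deleted implementation uses; the opcode names and fusion kind below are made-up example values:

#include <string>

#include "absl/container/btree_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"

// Hero opcode names collected in a btree_set come out deduplicated and
// alphabetically sorted; '-' in opcode names is mapped to '_'.
std::string ExampleFusionName() {
  absl::btree_set<std::string> heroes = {"reduce", "scatter"};
  std::string hero_names =
      absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}});
  // For an input fusion this yields "input_reduce_scatter_fusion"; the fused
  // computation would be named "fused_reduce_scatter".
  return absl::StrCat("input", "_", hero_names, "_fusion");
}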
diff --git a/third_party/xla/xla/service/gpu/rename_fusions_test.cc b/third_party/xla/xla/service/gpu/rename_fusions_test.cc
deleted file mode 100644
index 60c97cf..0000000
--- a/third_party/xla/xla/service/gpu/rename_fusions_test.cc
+++ /dev/null
@@ -1,83 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/rename_fusions.h"
-
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-
-class RenameFusionsTest : public HloTestBase {
- protected:
- RenameFusions rename_fusions_;
-};
-
-TEST_F(RenameFusionsTest, FusionInstructionNames) {
- absl::string_view kHlo = R"(
- HloModule test_module
-
- square {
- p = f32[16384] parameter(0)
- ROOT m = f32[16384] multiply(p, p)
- }
-
- exp {
- p = f32[16384] parameter(0)
- ROOT e = f32[16384] exponential(p)
- }
-
- log {
- p = f32[16384] parameter(0)
- ROOT l = f32[16384] log(p)
- }
-
- add {
- p0 = f32[] parameter(0)
- p1 = f32[] parameter(1)
- ROOT add = f32[] add(p0, p1)
- }
-
- ENTRY main {
- p0 = bf16[1024,8192] parameter(0)
- p1 = f32[8192] parameter(1)
- p2 = f32[16384] parameter(2)
- convert = f32[1024,8192] convert(p0)
- broadcast = f32[1024,8192] broadcast(p1), dimensions={1}
- c0 = f32[] constant(0)
- multiply = f32[1024,8192] multiply(broadcast, convert)
- reduce = f32[1024] reduce(multiply, c0), dimensions={1}, to_apply=add
- convert.1 = bf16[1024] convert(reduce)
- s = f32[16384] fusion(p2), kind=kLoop, calls=square
- e = f32[16384] fusion(s), kind=kLoop, calls=exp
- l = f32[16384] fusion(s), kind=kInput, calls=log
- ROOT result = (bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(convert.1, l, e)
- })";
-
- RunAndFilecheckHloRewrite(kHlo, std::move(rename_fusions_), R"(
-CHECK: ENTRY %main
-CHECK: %loop_multiply_fusion{{.*}} calls=%fused_multiply
-CHECK: %input_log_fusion{{.*}} calls=%fused_log
-CHECK: %loop_exponential_fusion{{.*}} calls=%fused_exponential
-CHECK: ROOT %result
- )");
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD
index 861fcca..d6b6f1f 100644
--- a/third_party/xla/xla/service/gpu/runtime/BUILD
+++ b/third_party/xla/xla/service/gpu/runtime/BUILD
@@ -79,7 +79,6 @@
"//xla/service:executable",
"//xla/service:global_device_id",
"//xla/service/gpu:buffer_allocations",
- "//xla/service/gpu:gpu_fused_mha_runner",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:stream_executor_util",
@@ -121,7 +120,6 @@
":copy_thunk",
":cudnn_thunk",
":custom_call_thunk",
- ":fused_mha_thunk",
":gemm_thunk",
":gpublas_lt_matmul_thunk",
":kernel_thunk",
@@ -661,28 +659,6 @@
)
cc_library(
- name = "fused_mha_thunk",
- srcs = ["fused_mha_thunk.cc"],
- hdrs = ["fused_mha_thunk.h"],
- deps = [
- ":thunk",
- "//xla:util",
- "//xla:xla_data_proto_cc",
- "//xla/service:buffer_assignment",
- "//xla/service/gpu:buffer_allocations",
- "//xla/service/gpu:gpu_fused_mha_runner",
- "//xla/stream_executor",
- "//xla/stream_executor:lazy_op_runner",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/synchronization",
- "@local_tsl//tsl/platform:errors",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-cc_library(
name = "gemm_thunk",
srcs = ["gemm_thunk.cc"],
hdrs = ["gemm_thunk.h"],
diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc
index 5a05bb4..ce99623 100644
--- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc
+++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc
@@ -1168,314 +1168,6 @@
}
//===----------------------------------------------------------------------===//
-// FusedMHACmd
-//===----------------------------------------------------------------------===//
-
-FusedMHACmd::FusedMHACmd(
- ExecutionStreamId execution_stream_id, GpufMHAConfig config,
- BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1,
- BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output,
- BufferAllocation::Slice scratch, BufferAllocation::Slice mask,
- BufferAllocation::Slice bias, BufferAllocation::Slice activation,
- BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k)
- : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHACmd,
- execution_stream_id),
- config_(std::move(config)),
- lhs_bmm1_buffer_(lhs_bmm1),
- rhs_bmm1_buffer_(rhs_bmm1),
- rhs_bmm2_buffer_(rhs_bmm2),
- output_buffer_(output),
- scratch_buffer_(scratch),
- bias_buffer_(bias),
- activation_buffer_(activation),
- seqlen_q_buffer_(seqlen_q),
- seqlen_k_buffer_(seqlen_k) {}
-
-FusedMultiHeadedAttentionRunner& FusedMHACmd::GetOrCreateRunner(
- const stream_executor::Stream* stream) {
- absl::MutexLock lock(&mutex_);
- auto it = runner_cache_.find(stream);
- if (it == runner_cache_.end()) {
- it = runner_cache_
- .insert({stream, std::make_unique<FusedMultiHeadedAttentionRunner>(
- config_)})
- .first;
- }
- return *it->second;
-}
-
-absl::Status FusedMHACmd::Initialize(const Thunk::InitializeParams& params,
- StateManager& state) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
- GetOrCreateRunner(params.command_buffer_trace_stream).AsFusedMHARunner();
- TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig());
- return lazy_runner
- ->GetOrCreateRunner(config, params.command_buffer_trace_stream)
- .status();
-}
-
-absl::Status FusedMHACmd::Record(const Thunk::ExecuteParams& execute_params,
- const RecordParams& record_params,
- se::CommandBuffer* command_buffer) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
- GetOrCreateRunner(execute_params.command_buffer_trace_stream)
- .AsFusedMHARunner();
- CHECK(lazy_runner) << "FusedMHA lazy runner cache should have been populated";
-
- const auto& buffer_allocations = *execute_params.buffer_allocations;
- se::DeviceMemoryBase lhs_bmm1_buffer =
- buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_);
- se::DeviceMemoryBase rhs_bmm1_buffer =
- buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_);
- se::DeviceMemoryBase rhs_bmm2_buffer =
- buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_);
- se::DeviceMemoryBase output_buffer =
- buffer_allocations.GetDeviceAddress(output_buffer_);
- se::DeviceMemoryBase scratch_buffer =
- buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
- std::optional<se::DeviceMemoryBase> bias_buffer =
- AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
- std::optional<se::DeviceMemoryBase> activation_buffer =
- AssignBufferIfNotNull(buffer_allocations, activation_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
-
- ExecutionScopeId execution_scope_id = GetExecutionScope(record_params);
- VLOG(5) << "FusedMHACmd with execution_scope_id: "
- << execution_scope_id.value();
- VLOG(5) << " lhs_bmm1_buffer: " << lhs_bmm1_buffer_.ToString();
- VLOG(5) << " rhs_bmm1_buffer: " << rhs_bmm1_buffer_.ToString();
- VLOG(5) << " rhs_bmm2_buffer: " << rhs_bmm2_buffer_.ToString();
- VLOG(5) << " output_buffer: " << output_buffer_.ToString();
- VLOG(5) << " scratch_buffer: " << scratch_buffer_.ToString();
- VLOG(5) << " bias_buffer: " << bias_buffer_.ToString();
- VLOG(5) << " activation_buffer: " << activation_buffer_.ToString();
- VLOG(5) << " seqlen_q_buffer: " << seqlen_q_buffer_.ToString();
- VLOG(5) << " seqlen_k_buffer: " << seqlen_k_buffer_.ToString();
-
- RunFusedMHAOptions opts;
- opts.runner_cache =
- &GetOrCreateRunner(execute_params.command_buffer_trace_stream);
- return AddTracedCommandBuffer(
- execute_params, record_params, command_buffer, [&](se::Stream* stream) {
- return RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer,
- rhs_bmm2_buffer, output_buffer, scratch_buffer,
- bias_buffer, activation_buffer, seqlen_q_buffer,
- seqlen_k_buffer, stream, opts);
- });
-}
-
-FusedMHACmd::BufferUsageVector FusedMHACmd::buffers() {
- BufferUsageVector buffer_usage;
- buffer_usage.reserve(9);
- buffer_usage.push_back({lhs_bmm1_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({rhs_bmm1_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({rhs_bmm2_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({output_buffer_, MemoryAccess::kWrite});
- buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite});
- if (bias_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead});
- }
- if (activation_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({activation_buffer_, MemoryAccess::kRead});
- }
- if (seqlen_q_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead});
- }
- if (seqlen_k_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead});
- }
- return buffer_usage;
-}
-
-//===----------------------------------------------------------------------===//
-// FusedMHABackwardCmd
-//===----------------------------------------------------------------------===//
-
-FusedMHABackwardCmd::FusedMHABackwardCmd(
- ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config,
- BufferAllocation::Slice bmm1_grad_gemm1_rhs,
- BufferAllocation::Slice bmm1_grad_gemm2_rhs,
- BufferAllocation::Slice bmm2_grad_gemm1_lhs,
- BufferAllocation::Slice bmm2_grad_gemm2_rhs,
- BufferAllocation::Slice d_output, BufferAllocation::Slice scratch,
- BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs,
- BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s,
- BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output,
- BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q,
- BufferAllocation::Slice seqlen_k)
- : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHABackwardCmd,
- execution_stream_id),
- config_(std::move(config)),
- bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs),
- bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs),
- bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs),
- bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs),
- d_output_buffer_(d_output),
- scratch_buffer_(scratch),
- d_bmm1_lhs_buffer_(d_bmm1_lhs),
- d_bmm1_rhs_buffer_(d_bmm1_rhs),
- d_bmm2_rhs_buffer_(d_bmm2_rhs),
- d_s_buffer_(d_s),
- d_bias_buffer_(d_bias),
- fwd_output_buffer_(fwd_output),
- bias_buffer_(bias),
- seqlen_q_buffer_(seqlen_q),
- seqlen_k_buffer_(seqlen_k) {}
-
-FusedMultiHeadedAttentionBackwardRunner& FusedMHABackwardCmd::GetOrCreateRunner(
- const stream_executor::Stream* stream) {
- absl::MutexLock lock(&mutex_);
- auto it = runner_cache_.find(stream);
- if (it == runner_cache_.end()) {
- it = runner_cache_
- .insert({stream,
- std::make_unique<FusedMultiHeadedAttentionBackwardRunner>(
- config_)})
- .first;
- }
- return *it->second;
-}
-
-absl::Status FusedMHABackwardCmd::Initialize(
- const Thunk::InitializeParams& params, StateManager& state) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
- GetOrCreateRunner(params.command_buffer_trace_stream)
- .AsFusedMHABackwardRunner();
- TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig());
- return lazy_runner
- ->GetOrCreateRunner(config, params.command_buffer_trace_stream)
- .status();
-}
-
-absl::Status FusedMHABackwardCmd::Record(
- const Thunk::ExecuteParams& execute_params,
- const RecordParams& record_params, se::CommandBuffer* command_buffer) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
- GetOrCreateRunner(execute_params.command_buffer_trace_stream)
- .AsFusedMHABackwardRunner();
- CHECK(lazy_runner)
- << "FusedMHABackward lazy runner cache should have been populated";
-
- const auto& buffer_allocations = *execute_params.buffer_allocations;
- se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_);
-
- se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_);
-
- se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_);
-
- se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_);
-
- se::DeviceMemoryBase d_output_buffer =
- buffer_allocations.GetDeviceAddress(d_output_buffer_);
-
- se::DeviceMemoryBase scratch_buffer =
- buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
- se::DeviceMemoryBase d_bmm1_lhs_buffer =
- buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_);
-
- se::DeviceMemoryBase d_bmm1_rhs_buffer =
- buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_);
-
- se::DeviceMemoryBase d_bmm2_rhs_buffer =
- buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_);
-
- std::optional<se::DeviceMemoryBase> d_s_buffer =
- AssignBufferIfNotNull(buffer_allocations, d_s_buffer_);
- std::optional<se::DeviceMemoryBase> d_bias_buffer =
- AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_);
- std::optional<se::DeviceMemoryBase> fwd_output_buffer =
- AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_);
- std::optional<se::DeviceMemoryBase> bias_buffer =
- AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
-
- ExecutionScopeId execution_scope_id = GetExecutionScope(record_params);
- VLOG(5) << "FusedMHABackwardCmd with execution_scope_id: "
- << execution_scope_id.value();
- VLOG(5) << "bmm1_grad_gemm1_rhs_buffer"
- << bmm1_grad_gemm1_rhs_buffer_.ToString();
- VLOG(5) << "bmm1_grad_gemm2_rhs_buffer"
- << bmm1_grad_gemm2_rhs_buffer_.ToString();
- VLOG(5) << "bmm2_grad_gemm1_lhs_buffer"
- << bmm2_grad_gemm1_lhs_buffer_.ToString();
- VLOG(5) << "bmm2_grad_gemm2_rhs_buffer"
- << bmm2_grad_gemm2_rhs_buffer_.ToString();
- VLOG(5) << "d_output_buffer" << d_output_buffer_.ToString();
- VLOG(5) << "scratch_buffer" << scratch_buffer_.ToString();
- VLOG(5) << "d_bmm1_lhs_buffer" << d_bmm1_lhs_buffer_.ToString();
- VLOG(5) << "d_bmm1_rhs_buffer" << d_bmm1_rhs_buffer_.ToString();
- VLOG(5) << "d_bmm2_rhs_buffer" << d_bmm2_rhs_buffer_.ToString();
- VLOG(5) << "d_s_buffer" << d_s_buffer_.ToString();
- VLOG(5) << "d_bias_buffer" << d_bias_buffer_.ToString();
- VLOG(5) << "fwd_output_buffer" << fwd_output_buffer_.ToString();
- VLOG(5) << "bias_buffer" << bias_buffer_.ToString();
- VLOG(5) << "seqlen_q_buffer" << seqlen_q_buffer_.ToString();
- VLOG(5) << "seqlen_k_buffer" << seqlen_k_buffer_.ToString();
-
- RunFusedMHABackwardOptions opts;
- opts.runner_cache =
- &GetOrCreateRunner(execute_params.command_buffer_trace_stream);
- return AddTracedCommandBuffer(
- execute_params, record_params, command_buffer, [&](se::Stream* stream) {
- return RunGpuFMHABackward(
- config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
- bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer,
- d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer,
- d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer,
- fwd_output_buffer, bias_buffer, seqlen_q_buffer, seqlen_k_buffer,
- stream, opts);
- });
-}
-
-FusedMHABackwardCmd::BufferUsageVector FusedMHABackwardCmd::buffers() {
- BufferUsageVector buffer_usage;
- buffer_usage.reserve(15);
-
- buffer_usage.push_back({bmm1_grad_gemm1_rhs_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({bmm1_grad_gemm2_rhs_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({bmm2_grad_gemm1_lhs_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({bmm2_grad_gemm2_rhs_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({d_output_buffer_, MemoryAccess::kWrite});
- buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite});
- buffer_usage.push_back({d_bmm1_lhs_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({d_bmm1_rhs_buffer_, MemoryAccess::kRead});
- buffer_usage.push_back({d_bmm2_rhs_buffer_, MemoryAccess::kRead});
-
- if (d_s_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({d_s_buffer_, MemoryAccess::kRead});
- };
- if (d_bias_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({d_bias_buffer_, MemoryAccess::kRead});
- };
- if (fwd_output_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({fwd_output_buffer_, MemoryAccess::kRead});
- };
- if (bias_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead});
- };
- if (seqlen_q_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead});
- };
- if (seqlen_k_buffer_.allocation() != nullptr) {
- buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead});
- };
-
- return buffer_usage;
-}
-
-//===----------------------------------------------------------------------===//
// CublasLtCmd
//===----------------------------------------------------------------------===//
@@ -1835,8 +1527,9 @@
execute_params.command_buffer_trace_stream, [&](se::Stream* stream) {
ffi::CallOptions options = {
execute_params.buffer_allocations->device_ordinal(),
- execute_params.stream,
- execute_params.buffer_allocations->memory_allocator(),
+ ffi::CallOptions::GpuOptions{
+ execute_params.stream,
+ execute_params.buffer_allocations->memory_allocator()},
/*called_computation=*/nullptr, // TODO(b/342285364)
execute_params.ffi_execution_context};
return ffi::Call(handler_, call_frame, options);
diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h
index b7a077e..27e8fea 100644
--- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h
+++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h
@@ -40,7 +40,6 @@
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/buffer_allocations.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
#include "xla/service/gpu/kernels/custom_kernel.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/matmul_utils.h"
@@ -81,8 +80,6 @@
V(kReduceScatter, "ReduceScatterCmd") \
V(kAllGatherCmd, "AllGatherCmd") \
V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \
- V(kFusedMHACmd, "FusedMHACmd") \
- V(kFusedMHABackwardCmd, "FusedMHABackwardCmd") \
V(kUnknownCmd, "UnknownCmd") \
// clang-format on
@@ -783,112 +780,6 @@
};
//===----------------------------------------------------------------------===//
-// FusedMHACmd
-//===----------------------------------------------------------------------===//
-
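-// Command-buffer counterpart of FusedMHAThunk: records the fused MHA call
-// into a nested command buffer via stream tracing and caches one
-// FusedMultiHeadedAttentionRunner per trace stream (guarded by a mutex).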
-class FusedMHACmd : public TracedCommandBufferCmd {
- public:
- FusedMHACmd(ExecutionStreamId execution_stream_id, GpufMHAConfig config,
- BufferAllocation::Slice lhs_bmm1,
- BufferAllocation::Slice rhs_bmm1,
- BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output,
- BufferAllocation::Slice scratch, BufferAllocation::Slice mask,
- BufferAllocation::Slice bias, BufferAllocation::Slice activation,
- BufferAllocation::Slice seqlen_q,
- BufferAllocation::Slice seqlen_k);
-
- absl::Status Initialize(const Thunk::InitializeParams& params,
- StateManager& state) override;
-
- absl::Status Record(const Thunk::ExecuteParams& execute_params,
- const RecordParams& record_params,
- se::CommandBuffer* command_buffer) override;
-
- BufferUsageVector buffers() override;
-
- bool IsNestedCommandBuffer() const final { return true; }
-
- private:
- FusedMultiHeadedAttentionRunner& GetOrCreateRunner(
- const stream_executor::Stream* stream);
-
- const GpufMHAConfig config_;
- BufferAllocation::Slice lhs_bmm1_buffer_;
- BufferAllocation::Slice rhs_bmm1_buffer_;
- BufferAllocation::Slice rhs_bmm2_buffer_;
- BufferAllocation::Slice output_buffer_;
- BufferAllocation::Slice scratch_buffer_;
- BufferAllocation::Slice bias_buffer_;
- BufferAllocation::Slice activation_buffer_;
- BufferAllocation::Slice seqlen_q_buffer_;
- BufferAllocation::Slice seqlen_k_buffer_;
-
- // FusedMHA config
- absl::Mutex mutex_;
- absl::flat_hash_map<const stream_executor::Stream*,
- std::unique_ptr<FusedMultiHeadedAttentionRunner>>
- runner_cache_ ABSL_GUARDED_BY(mutex_);
-};
-
-//===----------------------------------------------------------------------===//
-// FusedMHABackwardCmd
-//===----------------------------------------------------------------------===//
-
-class FusedMHABackwardCmd : public TracedCommandBufferCmd {
- public:
- FusedMHABackwardCmd(
- ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config,
- BufferAllocation::Slice bmm1_grad_gemm1_rhs,
- BufferAllocation::Slice bmm1_grad_gemm2_rhs,
- BufferAllocation::Slice bmm2_grad_gemm1_lhs,
- BufferAllocation::Slice bmm2_grad_gemm2_rhs,
- BufferAllocation::Slice d_output, BufferAllocation::Slice scratch,
- BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs,
- BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s,
- BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output,
- BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q,
- BufferAllocation::Slice seqlen_k);
-
- absl::Status Initialize(const Thunk::InitializeParams& params,
- StateManager& state) override;
-
- absl::Status Record(const Thunk::ExecuteParams& execute_params,
- const RecordParams& record_params,
- se::CommandBuffer* command_buffer) override;
-
- BufferUsageVector buffers() override;
-
- bool IsNestedCommandBuffer() const final { return true; }
-
- private:
- FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner(
- const stream_executor::Stream* stream);
-
- const GpufMHABackwardConfig config_;
- BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_;
- BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_;
- BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_;
- BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_;
- BufferAllocation::Slice d_output_buffer_;
- BufferAllocation::Slice scratch_buffer_;
- BufferAllocation::Slice d_bmm1_lhs_buffer_;
- BufferAllocation::Slice d_bmm1_rhs_buffer_;
- BufferAllocation::Slice d_bmm2_rhs_buffer_;
- BufferAllocation::Slice d_s_buffer_;
- BufferAllocation::Slice d_bias_buffer_;
- BufferAllocation::Slice fwd_output_buffer_;
- BufferAllocation::Slice bias_buffer_;
- BufferAllocation::Slice seqlen_q_buffer_;
- BufferAllocation::Slice seqlen_k_buffer_;
-
- // FusedMHA config
- absl::Mutex mutex_;
- absl::flat_hash_map<const stream_executor::Stream*,
- std::unique_ptr<FusedMultiHeadedAttentionBackwardRunner>>
- runner_cache_ ABSL_GUARDED_BY(mutex_);
-};
-
-//===----------------------------------------------------------------------===//
// CublasLtCmd
//===----------------------------------------------------------------------===//
diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc
index 54e01fa..230d050 100644
--- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc
+++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc
@@ -29,7 +29,6 @@
#include "xla/service/gpu/runtime/copy_thunk.h"
#include "xla/service/gpu/runtime/cudnn_thunk.h"
#include "xla/service/gpu/runtime/custom_call_thunk.h"
-#include "xla/service/gpu/runtime/fused_mha_thunk.h"
#include "xla/service/gpu/runtime/gemm_thunk.h"
#include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h"
#include "xla/service/gpu/runtime/kernel_thunk.h"
@@ -143,27 +142,6 @@
thunk.workspace().value());
}
-static absl::StatusOr<Command> Convert(const FusedMHAThunk& thunk) {
- return std::make_unique<FusedMHACmd>(
- thunk.execution_stream_id(), thunk.config(), thunk.lhs_bmm1_buffer(),
- thunk.rhs_bmm1_buffer(), thunk.rhs_bmm2_buffer(), thunk.output_buffer(),
- thunk.scratch_buffer(), BufferAllocation::Slice(), thunk.bias_buffer(),
- thunk.activation_buffer(), thunk.seqlen_q_buffer(),
- thunk.seqlen_k_buffer());
-}
-
-static absl::StatusOr<Command> Convert(const FusedMHABackwardThunk& thunk) {
- return std::make_unique<FusedMHABackwardCmd>(
- thunk.execution_stream_id(), thunk.config(),
- thunk.bmm1_grad_gemm1_rhs_buffer(), thunk.bmm1_grad_gemm2_rhs_buffer(),
- thunk.bmm2_grad_gemm1_lhs_buffer(), thunk.bmm2_grad_gemm2_rhs_buffer(),
- thunk.d_output_buffer(), thunk.scratch_buffer(),
- thunk.d_bmm1_lhs_buffer(), thunk.d_bmm1_rhs_buffer(),
- thunk.d_bmm2_rhs_buffer(), thunk.d_s_buffer(), thunk.d_bias_buffer(),
- thunk.fwd_output_buffer(), thunk.bias_buffer(), thunk.seqlen_q_buffer(),
- thunk.seqlen_k_buffer());
-}
-
static absl::StatusOr<Command> Convert(
const ConditionalThunk& thunk,
CommandBufferCmdSequence::SynchronizationMode synchronization_mode) {
@@ -276,10 +254,6 @@
return append(Convert<CustomCallThunk>(thunk));
case Thunk::Kind::kCustomKernel:
return append(Convert<CustomKernelThunk>(thunk));
- case Thunk::Kind::kFusedMHA:
- return append(Convert<FusedMHAThunk>(thunk));
- case Thunk::Kind::kFusedMHABackward:
- return append(Convert<FusedMHABackwardThunk>(thunk));
case Thunk::Kind::kKernel:
return append(Convert<KernelThunk>(thunk));
case Thunk::Kind::kGemm:
diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
index e994fac..f77653e 100644
--- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
@@ -198,9 +198,9 @@
builder.AddAttributes(attrs.Build());
CallFrame call_frame = builder.Build();
- CallOptions options = {device_ordinal, stream,
- allocator, called_computation_,
- execution_context, execution_state_.get()};
+ CallOptions options = {
+ device_ordinal, CallOptions::GpuOptions{stream, allocator},
+ called_computation_, execution_context, execution_state_.get()};
return Call(handler, call_frame, options, stage);
}
diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc
index 604bce1..eda2bc6 100644
--- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc
+++ b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc
@@ -74,8 +74,6 @@
case Thunk::kCustomKernel:
case Thunk::kCuDnn:
case Thunk::kFft:
- case Thunk::kFusedMHA:
- case Thunk::kFusedMHABackward:
case Thunk::kGemm:
case Thunk::kInfeed:
case Thunk::kKernel:
diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc
deleted file mode 100644
index ee13689..0000000
--- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc
+++ /dev/null
@@ -1,230 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/runtime/fused_mha_thunk.h"
-
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "absl/status/status.h"
-#include "absl/synchronization/mutex.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/gpu/buffer_allocations.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-FusedMHAThunk::FusedMHAThunk(
- ThunkInfo thunk_info, GpufMHAConfig config,
- BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1,
- BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output,
- BufferAllocation::Slice scratch, BufferAllocation::Slice mask,
- BufferAllocation::Slice bias, BufferAllocation::Slice activation,
- BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k)
- : Thunk(Kind::kFusedMHA, thunk_info),
- lhs_bmm1_buffer_(lhs_bmm1),
- rhs_bmm1_buffer_(rhs_bmm1),
- rhs_bmm2_buffer_(rhs_bmm2),
- output_buffer_(output),
- scratch_buffer_(scratch),
- bias_buffer_(bias),
- activation_buffer_(activation),
- seqlen_q_buffer_(seqlen_q),
- seqlen_k_buffer_(seqlen_k),
- config_(std::move(config)) {}
-
-FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner(
- const stream_executor::Stream* stream) {
- absl::MutexLock lock(&mu_);
- auto it = runner_cache_.find(stream);
- if (it == runner_cache_.end()) {
- it = runner_cache_
- .insert({stream, std::make_unique<FusedMultiHeadedAttentionRunner>(
- config_)})
- .first;
- }
- return *it->second;
-}
-
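-// Resolves `slice` to its device address if it is backed by an allocation,
-// and returns std::nullopt for optional operands (e.g. bias or seqlen
-// buffers) that were left unset.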
-std::optional<se::DeviceMemoryBase> AssignBufferIfNotNull(
- const BufferAllocations& buffer_allocations,
- BufferAllocation::Slice& slice) {
- return slice.allocation() != nullptr
- ? std::optional<se::DeviceMemoryBase>{buffer_allocations
- .GetDeviceAddress(slice)}
- : std::nullopt;
-}
-
-absl::Status FusedMHAThunk::Initialize(const InitializeParams& params) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
- GetOrCreateRunner(params.stream).AsFusedMHARunner();
- TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig());
- return lazy_runner->GetOrCreateRunner(config, params.stream).status();
-}
-
-absl::Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) {
- const auto& buffer_allocations = *params.buffer_allocations;
- se::DeviceMemoryBase lhs_bmm1_buffer =
- buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_);
- se::DeviceMemoryBase rhs_bmm1_buffer =
- buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_);
- se::DeviceMemoryBase rhs_bmm2_buffer =
- buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_);
- se::DeviceMemoryBase output_buffer =
- buffer_allocations.GetDeviceAddress(output_buffer_);
- se::DeviceMemoryBase scratch_buffer =
- buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
- std::optional<se::DeviceMemoryBase> bias_buffer =
- AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
- std::optional<se::DeviceMemoryBase> activation_buffer =
- AssignBufferIfNotNull(buffer_allocations, activation_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
- RunFusedMHAOptions opts;
- opts.runner_cache = &GetOrCreateRunner(params.stream);
- TF_RETURN_IF_ERROR(RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer,
- rhs_bmm2_buffer, output_buffer, scratch_buffer,
- bias_buffer, activation_buffer, seqlen_q_buffer,
- seqlen_k_buffer, params.stream, opts));
-
- if (!params.stream->ok()) {
- return Internal("FusedMHAThunk::ExecuteOnStream failed.");
- }
- return absl::OkStatus();
-}
-FusedMHABackwardThunk::FusedMHABackwardThunk(
- ThunkInfo thunk_info, GpufMHABackwardConfig config,
- BufferAllocation::Slice bmm1_grad_gemm1_rhs,
- BufferAllocation::Slice bmm1_grad_gemm2_rhs,
- BufferAllocation::Slice bmm2_grad_gemm1_lhs,
- BufferAllocation::Slice bmm2_grad_gemm2_rhs,
- BufferAllocation::Slice d_output, BufferAllocation::Slice scratch,
- BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs,
- BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s,
- BufferAllocation::Slice mask, BufferAllocation::Slice d_bias,
- BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias,
- BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k)
- : Thunk(Kind::kFusedMHABackward, thunk_info),
- bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs),
- bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs),
- bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs),
- bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs),
- d_output_buffer_(d_output),
- scratch_buffer_(scratch),
- d_bmm1_lhs_buffer_(d_bmm1_lhs),
- d_bmm1_rhs_buffer_(d_bmm1_rhs),
- d_bmm2_rhs_buffer_(d_bmm2_rhs),
- d_s_buffer_(d_s),
- d_bias_buffer_(d_bias),
- fwd_output_buffer_(fwd_output),
- bias_buffer_(bias),
- seqlen_q_buffer_(seqlen_q),
- seqlen_k_buffer_(seqlen_k),
- config_(std::move(config)) {}
-
-FusedMultiHeadedAttentionBackwardRunner&
-FusedMHABackwardThunk::GetOrCreateRunner(
- const stream_executor::Stream* stream) {
- absl::MutexLock lock(&mu_);
- auto it = runner_cache_.find(stream);
- if (it == runner_cache_.end()) {
- it = runner_cache_
- .insert({stream,
- std::make_unique<FusedMultiHeadedAttentionBackwardRunner>(
- config_)})
- .first;
- }
- return *it->second;
-}
-
-absl::Status FusedMHABackwardThunk::Initialize(const InitializeParams& params) {
- se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
- GetOrCreateRunner(params.stream).AsFusedMHABackwardRunner();
- TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig());
- return lazy_runner->GetOrCreateRunner(config, params.stream).status();
-}
-
-absl::Status FusedMHABackwardThunk::ExecuteOnStream(
- const ExecuteParams& params) {
- const auto& buffer_allocations = *params.buffer_allocations;
- se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_);
-
- se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_);
-
- se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_);
-
- se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer =
- buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_);
-
- se::DeviceMemoryBase d_output_buffer =
- buffer_allocations.GetDeviceAddress(d_output_buffer_);
-
- se::DeviceMemoryBase scratch_buffer =
- buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
- se::DeviceMemoryBase d_bmm1_lhs_buffer =
- buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_);
-
- se::DeviceMemoryBase d_bmm1_rhs_buffer =
- buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_);
-
- se::DeviceMemoryBase d_bmm2_rhs_buffer =
- buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_);
-
- std::optional<se::DeviceMemoryBase> d_s_buffer =
- AssignBufferIfNotNull(buffer_allocations, d_s_buffer_);
- std::optional<se::DeviceMemoryBase> d_bias_buffer =
- AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_);
- std::optional<se::DeviceMemoryBase> fwd_output_buffer =
- AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_);
- std::optional<se::DeviceMemoryBase> bias_buffer =
- AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
- std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
- AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
- RunFusedMHABackwardOptions opts;
-
- opts.runner_cache = &GetOrCreateRunner(params.stream);
-
- TF_RETURN_IF_ERROR(RunGpuFMHABackward(
- config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
- bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer,
- scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer,
- d_s_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer,
- seqlen_q_buffer, seqlen_k_buffer, params.stream, opts));
- if (!params.stream->ok()) {
- return Internal("FusedMHABackwardThunk::ExecuteOnStream failed.");
- }
- return absl::OkStatus();
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h
deleted file mode 100644
index 99a8327..0000000
--- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h
+++ /dev/null
@@ -1,184 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_
-#define XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_
-
-#include <memory>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
-#include "absl/synchronization/mutex.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// This class stores everything that StreamExecutor needs to launch a DNN
-// fMHA. It is generated by IrEmitter.
-//
-// This is thread-compatible.
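-//
-// Initialize() warms the per-stream runner cache ahead of execution;
-// ExecuteOnStream() resolves the buffer slices to device memory and invokes
-// RunGpuFMHA on the stream from the execute params.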
-class FusedMHAThunk : public Thunk {
- public:
- // Constructs a thunk for launching a DNN FMHA.
- FusedMHAThunk(ThunkInfo thunk_info, GpufMHAConfig config,
- BufferAllocation::Slice lhs_bmm1_slice,
- BufferAllocation::Slice rhs_bmm1_slice,
- BufferAllocation::Slice rhs_bmm2_slice,
- BufferAllocation::Slice output_slice,
- BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice mask_slice, /* may be null */
- BufferAllocation::Slice bias_slice /* may be null */,
- BufferAllocation::Slice activation_slice /* may be null */,
- BufferAllocation::Slice seqlen_q_slice /* may be null */,
- BufferAllocation::Slice seqlen_k_slice /* may be null */);
-
- FusedMHAThunk(const FusedMHAThunk&) = delete;
- FusedMHAThunk& operator=(const FusedMHAThunk&) = delete;
-
- BufferAllocation::Slice lhs_bmm1_buffer() const { return lhs_bmm1_buffer_; }
- BufferAllocation::Slice rhs_bmm1_buffer() const { return rhs_bmm1_buffer_; }
- BufferAllocation::Slice rhs_bmm2_buffer() const { return rhs_bmm2_buffer_; }
- BufferAllocation::Slice output_buffer() const { return output_buffer_; }
- BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; }
- BufferAllocation::Slice bias_buffer() const { return bias_buffer_; }
- BufferAllocation::Slice activation_buffer() const {
- return activation_buffer_;
- }
- BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; }
- BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; }
-
- GpufMHAConfig config() const { return config_; }
- absl::Status Initialize(const InitializeParams& params) override;
- absl::Status ExecuteOnStream(const ExecuteParams& params) override;
-
- private:
- BufferAllocation::Slice lhs_bmm1_buffer_;
- BufferAllocation::Slice rhs_bmm1_buffer_;
- BufferAllocation::Slice rhs_bmm2_buffer_;
- BufferAllocation::Slice output_buffer_;
- BufferAllocation::Slice scratch_buffer_;
- BufferAllocation::Slice bias_buffer_;
- BufferAllocation::Slice activation_buffer_;
- BufferAllocation::Slice seqlen_q_buffer_;
- BufferAllocation::Slice seqlen_k_buffer_;
-
- FusedMultiHeadedAttentionRunner& GetOrCreateRunner(
- const stream_executor::Stream* stream);
-
- // FusedMHA config
- const GpufMHAConfig config_;
- absl::Mutex mu_;
- absl::flat_hash_map<const stream_executor::Stream*,
- std::unique_ptr<FusedMultiHeadedAttentionRunner>>
- runner_cache_ ABSL_GUARDED_BY(mu_);
-};
-
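-// This class stores everything that StreamExecutor needs to launch the
-// backward (gradient) pass of a DNN fMHA. Like FusedMHAThunk, it is
-// thread-compatible and caches one runner per stream under a mutex.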
-class FusedMHABackwardThunk : public Thunk {
- public:
- // Constructs a thunk for launching a DNN FMHA backward.
- FusedMHABackwardThunk(ThunkInfo thunk_info, GpufMHABackwardConfig config,
- BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice,
- BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice,
- BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice,
- BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice,
- BufferAllocation::Slice d_output_slice,
- BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice d_bmm1_lhs_slice,
- BufferAllocation::Slice d_bmm1_rhs_slice,
- BufferAllocation::Slice d_bmm2_rhs_slice,
- BufferAllocation::Slice d_s_slice,
- BufferAllocation::Slice mask_slice,
- BufferAllocation::Slice d_bias_slice,
- BufferAllocation::Slice fwd_output_slice,
- BufferAllocation::Slice bias_slice,
- BufferAllocation::Slice seqlen_q_slice,
- BufferAllocation::Slice seqlen_k_slice);
-
- FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete;
- FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete;
-
- BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer() const {
- return bmm1_grad_gemm1_rhs_buffer_;
- }
- BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer() const {
- return bmm1_grad_gemm2_rhs_buffer_;
- }
- BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer() const {
- return bmm2_grad_gemm1_lhs_buffer_;
- }
- BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer() const {
- return bmm2_grad_gemm2_rhs_buffer_;
- }
- BufferAllocation::Slice d_output_buffer() const { return d_output_buffer_; }
- BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; }
- BufferAllocation::Slice d_bmm1_lhs_buffer() const {
- return d_bmm1_lhs_buffer_;
- }
- BufferAllocation::Slice d_bmm1_rhs_buffer() const {
- return d_bmm1_rhs_buffer_;
- }
- BufferAllocation::Slice d_bmm2_rhs_buffer() const {
- return d_bmm2_rhs_buffer_;
- }
- BufferAllocation::Slice d_s_buffer() const { return d_s_buffer_; }
- BufferAllocation::Slice d_bias_buffer() const { return d_bias_buffer_; }
- BufferAllocation::Slice fwd_output_buffer() const {
- return fwd_output_buffer_;
- }
- BufferAllocation::Slice bias_buffer() const { return bias_buffer_; }
- BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; }
- BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; }
-
- GpufMHABackwardConfig config() const { return config_; }
-
- absl::Status Initialize(const InitializeParams& params) override;
- absl::Status ExecuteOnStream(const ExecuteParams& params) override;
-
- private:
- BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_;
- BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_;
- BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_;
- BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_;
- BufferAllocation::Slice d_output_buffer_;
- BufferAllocation::Slice scratch_buffer_;
- BufferAllocation::Slice d_bmm1_lhs_buffer_;
- BufferAllocation::Slice d_bmm1_rhs_buffer_;
- BufferAllocation::Slice d_bmm2_rhs_buffer_;
- BufferAllocation::Slice d_s_buffer_;
- BufferAllocation::Slice d_bias_buffer_;
- BufferAllocation::Slice fwd_output_buffer_;
- BufferAllocation::Slice bias_buffer_;
- BufferAllocation::Slice seqlen_q_buffer_;
- BufferAllocation::Slice seqlen_k_buffer_;
-
- FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner(
- const stream_executor::Stream* stream);
-
- // FusedMHA backward config
- const GpufMHABackwardConfig config_;
- absl::Mutex mu_;
- absl::flat_hash_map<const stream_executor::Stream*,
- std::unique_ptr<FusedMultiHeadedAttentionBackwardRunner>>
- runner_cache_ ABSL_GUARDED_BY(mu_);
-};
-} // namespace gpu
-} // namespace xla
-#endif // XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_
diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc
index fc3c0cf..6f3081a 100644
--- a/third_party/xla/xla/service/gpu/runtime/thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc
@@ -286,8 +286,6 @@
CASE(kSequential);
CASE(kTriangularSolve);
CASE(kWhile);
- CASE(kFusedMHA);
- CASE(kFusedMHABackward);
CASE(kWaitForStreams);
CASE(kCuDnn);
}
diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h
index 3466649..cd26323 100644
--- a/third_party/xla/xla/service/gpu/runtime/thunk.h
+++ b/third_party/xla/xla/service/gpu/runtime/thunk.h
@@ -165,8 +165,6 @@
kSendDone,
kTriangularSolve,
kWhile,
- kFusedMHA,
- kFusedMHABackward,
kWaitForStreams,
kCuDnn
};
diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc b/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc
deleted file mode 100644
index 9672bf2..0000000
--- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc
+++ /dev/null
@@ -1,264 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scatter_slice_simplifier.h"
-
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-// Returns whether the instruction could be an operand for a slice instruction.
-bool IsValidIntermediaryUser(const HloInstruction* instruction) {
-  // Allow elementwise instructions, as they don't depend on the truncated
-  // elements. For multi-output scatters the result is a tuple, so also allow
-  // get-tuple-element instructions.
- return instruction->IsElementwise() ||
- instruction->opcode() == HloOpcode::kGetTupleElement;
-}
-
-// Matches the "Scatter -> Elementwise (zero or more) -> Slice" pattern.
-// Calculates the resulting scatter dimensions from the slice users.
-class ScatterSliceMatcher {
- public:
- explicit ScatterSliceMatcher(const HloScatterInstruction* scatter)
- : scatter_(scatter),
- operand_dimensions_(
- scatter->scatter_operands()[0]->shape().dimensions()),
- result_dimensions_(operand_dimensions_.begin(),
- operand_dimensions_.end()) {}
-
- // Determine the scatter shape from the user slice instructions.
- // If any of the users are not truncation slices, return `nullopt`.
- std::optional<Shape> InferShape() {
- VLOG(10) << "Evaluating scatter " << scatter_->name();
- if (!AreAllUsersValid(scatter_)) {
- return std::nullopt;
- }
- std::vector<Shape> result_shapes;
- absl::c_transform(scatter_->scatter_operands(),
- std::back_inserter(result_shapes),
- [&](const HloInstruction* op) {
- return ShapeUtil::MakeShape(op->shape().element_type(),
- result_dimensions_);
- });
- return ShapeUtil::MakeMaybeTupleShape(result_shapes);
- }
-
- private:
- // Update the resulting scatter dimensions from the slice configuration and
- // the original scatter dimensions. Return `false` if the update is not
- // possible.
- bool UpdateDimensions(const HloSliceInstruction* slice) {
- int64_t rank = slice->shape().rank();
- for (int64_t i = 0; i < rank; ++i) {
- if (slice->slice_starts(i) != 0 || slice->slice_strides(i) != 1) {
- return false; // The slice is not a truncation.
- }
- if (slice->slice_limits(i) != result_dimensions_[i]) {
- if (result_dimensions_[i] != operand_dimensions_[i]) {
- return false; // Another slice has incompatible dimensions.
- }
- auto& update_window_dims =
- scatter_->scatter_dimension_numbers().update_window_dims();
- if (absl::c_binary_search(update_window_dims, i)) {
- return false; // Update dimensions cannot be truncated.
- }
- result_dimensions_[i] = slice->slice_limits(i);
- VLOG(10) << "Dimension " << i << " truncated to size "
- << result_dimensions_[i];
- }
- }
- return true;
- }
-
- // Verify that the instruction is a valid scatter user, i.e. is either a slice
- // operation or is an elementwise operation that has slice users (recursive).
- bool IsUserValid(const HloInstruction* op) {
- VLOG(10) << "Visiting user " << op->name();
-
- // If the user is a slice operation, verify the configuration and update
- // the resulting dimensions.
- if (auto* slice = DynCast<HloSliceInstruction>(op)) {
- return UpdateDimensions(slice);
- }
- // If the user is an elementwise operation, verify the users recursively
- // (unless already visited).
- bool is_valid = visited_set_.contains(op) ||
- (IsValidIntermediaryUser(op) && AreAllUsersValid(op));
- if (is_valid) {
- visited_set_.emplace(op);
- }
- return is_valid;
- }
-
-  // Verify that all users are valid (see the definition of IsUserValid).
- // If we reach the root instruction, fail the matching (slice is not found).
- bool AreAllUsersValid(const HloInstruction* op) {
- if (op->user_count() == 0) {
- return !op->IsRoot();
- }
- return absl::c_all_of(op->users(), [this](const HloInstruction* user) {
- return IsUserValid(user);
- });
- }
-
- const HloScatterInstruction* scatter_;
- absl::flat_hash_set<const HloInstruction*> visited_set_;
- absl::Span<const int64_t> operand_dimensions_;
- DimensionVector result_dimensions_;
-};
-
-// Create a replacement operand for the scatter instruction.
-HloInstruction* CreateSliceFrom(HloInstruction* operand, const Shape& shape) {
- std::vector<int64_t> start_indices(shape.rank(), 0);
- std::vector<int64_t> limit_indices(shape.rank());
- std::vector<int64_t> strides(shape.rank(), 1);
- for (int64_t i = 0; i < shape.rank(); ++i) {
- limit_indices[i] = shape.dimensions(i);
- }
- return operand->AddInstruction(HloInstruction::CreateSlice(
- shape, operand, start_indices, limit_indices, strides));
-}
-
-// Create a replacement for the scatter instruction.
-HloInstruction* CreateScatterFrom(HloScatterInstruction* scatter,
- const Shape& shape) {
- std::vector<HloInstruction*> operands(scatter->scatter_operand_count());
- for (int64_t i = 0; i < operands.size(); ++i) {
- operands[i] =
- CreateSliceFrom(scatter->scatter_operands()[i],
- shape.IsTuple() ? shape.tuple_shapes(i) : shape);
- }
- return scatter->AddInstruction(HloInstruction::CreateScatter(
- shape, absl::MakeSpan(operands), scatter->scatter_indices(),
- scatter->scatter_updates(), scatter->called_computations()[0],
- scatter->scatter_dimension_numbers(), scatter->indices_are_sorted(),
- scatter->unique_indices()));
-}
-
-class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleScatter(HloInstruction* instruction) override {
- auto* scatter = Cast<HloScatterInstruction>(instruction);
-
- // Infer scatter shape from the slice users.
- std::optional<Shape> result_shape =
- ScatterSliceMatcher(scatter).InferShape();
- if (!result_shape.has_value()) {
- return absl::OkStatus();
- }
- VLOG(2) << "Matched scatter " << scatter->name() << " with shape "
- << scatter->shape().ToString() << ", inferred result shape "
- << result_shape->ToString() << " (from the slice users)";
-
- // Replace slice user instructions.
- HloInstruction* new_scatter = CreateScatterFrom(scatter, *result_shape);
- return ReplaceAllUsersRecursive(scatter, new_scatter);
- }
-
- private:
- // Create a replacement for every user. If the user is a slice operation,
-  // replace it in the computation graph; the old branch will be removed.
- absl::Status ReplaceAllUsersRecursive(HloInstruction* old_instruction,
- HloInstruction* new_instruction) {
- // Maintain the replacement map, needed for non-unary elementwise users.
- replacements_[old_instruction] = new_instruction;
-
-    // It's important to make a copy of the users list, as it may be modified
- // during the iteration.
- std::vector<HloInstruction*> users = old_instruction->users();
- for (HloInstruction* user : users) {
- if (user->parent() == nullptr) {
- VLOG(3) << "Skipping user " << user->name() << " (already replaced)";
- continue;
- }
- TF_RETURN_IF_ERROR(ReplaceUserRecursive(user, new_instruction));
- }
- return absl::OkStatus();
- }
-
- // Replace the slice user with a new scatter (or a new chain of operations
- // starting with a scatter). For elementwise operations, create a new user
- // with updated operands (build the chain).
- absl::Status ReplaceUserRecursive(HloInstruction* user,
- HloInstruction* operand) {
- VLOG(3) << "Replacing scatter user " << user->name();
- if (user->opcode() == HloOpcode::kSlice) {
- return ReplaceInstruction(user, operand);
- }
-
- // Create the replacement instruction with new shape.
- HloInstruction* new_user = nullptr;
- if (user->IsElementwise()) {
- auto new_shape = [operand](HloInstruction* from) {
- return ShapeUtil::MakeShape(from->shape().element_type(),
- operand->shape().dimensions());
- };
- std::vector<HloInstruction*> new_operands;
- absl::c_transform(user->operands(), std::back_inserter(new_operands),
- [&](HloInstruction* op) {
- auto it = replacements_.find(op);
- return it != replacements_.end()
- ? it->second
- : CreateSliceFrom(op, new_shape(op));
- });
- new_user = user->AddInstruction(
- user->CloneWithNewOperands(new_shape(user), new_operands));
- } else {
- auto* gte = Cast<HloGetTupleElementInstruction>(user);
- TF_ASSIGN_OR_RETURN(new_user,
- MakeGetTupleElementHlo(operand, gte->tuple_index(),
- &user->metadata()));
- }
-
- // Replace slice user instructions recursively.
- return ReplaceAllUsersRecursive(user, new_user);
- }
-
- absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements_;
-};
-
-} // namespace
-
-absl::StatusOr<bool> ScatterSliceSimplifier::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- return ScatterSliceSimplifierVisitor{}.RunOnModule(module, execution_threads);
-}
-
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h b/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h
deleted file mode 100644
index 3498377..0000000
--- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_
-#define XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Replaces a scatter followed by truncation slices with a new scatter that
-// uses the truncated output shape, eliminating the slices.
-//
-// (a) Single output (b) Multiple outputs (c) Elementwise users
-//
-// T[N+1] scatter (T1, T2) scatter T scatter T constant
-// v v v v v
-// T[N] slice T1 gte T2 gte T maximum
-// v v v
-// T1 slice T2 slice T slice
-//
-// This pattern is used when the last element of the scatter output is intended
-// to accumulate updates from the input elements that should be ignored.
-// This is slow if there are many indices mapped to the last output index and
-// the scatter is implemented using atomics, so everything collides on that one
-// memory location.
-// As OOB scatter indices are dropped by the GPU implementation, we can remove
-// the slice step entirely and avoid the memory congestion in the scatter step.
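-//
-// For example, an f32[9] scatter whose only user is an f32[8] slice [0:8] is
-// rewritten into an f32[8] scatter over sliced operands; updates aimed at the
-// truncated index become out-of-bounds and are dropped.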
-
-class ScatterSliceSimplifier : public HloModulePass {
- public:
- absl::string_view name() const override { return "scatter-slice-simplifier"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc
deleted file mode 100644
index 281a4f0..0000000
--- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc
+++ /dev/null
@@ -1,336 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scatter_slice_simplifier.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace {
-
-namespace m = ::xla::match;
-
-using ScatterSliceSimplifierTest = HloTestBase;
-
-TEST_F(ScatterSliceSimplifierTest, Scatter1D) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[9] constant(0)
- %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- ROOT %slice = f32[8] slice(%scatter), slice={[0:8]}
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
- m::Parameter(1))
- .WithShape(F32, {8})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, Scatter3D) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
- %indices = s32[2] parameter(0)
- %updates = f32[2,4,4] parameter(1)
- %operands = f32[5,4,4] constant(0)
- %scatter = f32[5,4,4] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- ROOT %slice = f32[4,4,4] slice(%scatter), slice={[0:4], [0:4], [0:4]}
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
- m::Parameter(1))
- .WithShape(F32, {4, 4, 4})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, ScatterMultiOutput) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32_add_F16 {
- %lhs.0 = f32[] parameter(0)
- %rhs.0 = f32[] parameter(2)
- %add.0 = f32[] add(%lhs.0, %rhs.0)
- %lhs.1 = f16[] parameter(1)
- %rhs.1 = f16[] parameter(3)
- %add.1 = f16[] add(%lhs.1, %rhs.1)
- ROOT %tuple = (f32[], f16[]) tuple(%add.0, %add.1)
-}
-
-ENTRY main {
- %indices = s32[4] parameter(0)
- %updates.0 = f32[4] parameter(1)
- %updates.1 = f16[4] parameter(2)
- %operands.0 = f32[9] constant(0)
- %operands.1 = f16[9] constant(0)
- %scatter = (f32[9], f16[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_add_F16
- %gte.0 = f32[9] get-tuple-element(%scatter), index=0
- %slice.0 = f32[8] slice(%gte.0), slice={[0:8]}
- %gte.1 = f16[9] get-tuple-element(%scatter), index=1
- %slice.1 = f16[8] slice(%gte.1), slice={[0:8]}
- ROOT %tuple = (f32[8], f16[8]) tuple(%slice.0, %slice.1)
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
- auto expected_scatter =
- m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
- m::Parameter(0), m::Parameter(1), m::Parameter(2));
-
- Shape expected_shape = ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F16, {8})});
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::GetTupleElement(expected_scatter),
- m::GetTupleElement(expected_scatter))
- .WithShapeEqualTo(&expected_shape)));
-}
-
-TEST_F(ScatterSliceSimplifierTest, NotMatching) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-slice_not_truncation {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[9] constant(0)
- %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- ROOT %slice = f32[8] slice(%scatter), slice={[1:9]}
-}
-
-slice_with_stride {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[9] constant(0)
- %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- ROOT %slice = f32[4] slice(%scatter), slice={[0:8:2]}
-}
-
-scatter_multiple_users {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[9] constant(0)
- %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- %slice = f32[8] slice(%scatter), slice={[0:8]}
- ROOT %tuple = (f32[9], f32[8]) tuple(%scatter, %slice)
-}
-
-scatter_incompatible_slices {
- %indices = s32[2] parameter(0)
- %updates = f32[2,4] parameter(1)
- %operands = f32[4,4] constant(0)
- %scatter = f32[4,4] scatter(%operands, %indices, %updates), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- %slice.0 = f32[3,4] slice(%scatter), slice={[0:3], [0:4]}
- %slice.1 = f32[4,3] slice(%scatter), slice={[0:4], [0:3]}
- ROOT %tuple = (f32[3,4], f32[4,3]) tuple(%slice.0, %slice.1)
-}
-
-slice_not_found {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[8] constant(0)
- %scatter = f32[8] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- ROOT %exp = f32[8] exponential(%scatter)
-}
-
-slice_update_dimensions {
- %indices = s32[10] parameter(0)
- %updates = f32[10,1,128] parameter(1)
- %operands = f32[100,128] constant(0)
- %scatter = f32[100,128] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- ROOT %slice = f32[100,64] slice(%scatter), slice={[0:100], [0:64]}
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_FALSE(RunHloPass(&test_pass, module.get()).value());
-}
-
-TEST_F(ScatterSliceSimplifierTest, IntermediaryUsers) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[9] constant(0)
- %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- %unary = f32[9] abs(%scatter)
- %slice.0 = f32[8] slice(%unary), slice={[0:8]}
- %binary = f32[9] maximum(%scatter, %operands)
- %slice.1 = f32[8] slice(%binary), slice={[0:8]}
- ROOT %tuple = (f32[8], f32[8]) tuple(%slice.0, %slice.1)
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
- auto expected_scatter =
- m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
-
- Shape expected_shape = ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})});
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(m::Abs(expected_scatter),
- m::Maximum(expected_scatter, m::Slice(m::Constant())))
- .WithShapeEqualTo(&expected_shape)));
-}
-
-TEST_F(ScatterSliceSimplifierTest, IntermediaryChain) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[9] constant(0)
- %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- %elementwise.0 = f32[9] abs(%scatter)
- %elementwise.1 = f32[9] exponential(%elementwise.0)
- %elementwise.2 = f32[9] add(%elementwise.0, %elementwise.1)
- ROOT %result = f32[8] slice(%elementwise.2), slice={[0:8]}
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
- auto expected_scatter =
- m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Add(m::Abs(expected_scatter),
- m::Exp(m::Abs(expected_scatter)))
- .WithShape(F32, {8})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, DiamondShape) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32_mul_F32 {
- %lhs.0 = f32[] parameter(0)
- %rhs.0 = f32[] parameter(2)
- %add.0 = f32[] add(%lhs.0, %rhs.0)
- %lhs.1 = f32[] parameter(1)
- %rhs.1 = f32[] parameter(3)
- %mul.1 = f32[] multiply(%lhs.1, %rhs.1)
- ROOT %tuple = (f32[], f32[]) tuple(%add.0, %mul.1)
-}
-
-ENTRY main {
- %indices = s32[4] parameter(0)
- %updates.0 = f32[4] parameter(1)
- %updates.1 = f32[4] parameter(2)
- %operands.0 = f32[9] constant(0)
- %operands.1 = f32[9] constant(0)
- %scatter = (f32[9], f32[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_mul_F32
- %gte.0 = f32[9] get-tuple-element(%scatter), index=0
- %gte.1 = f32[9] get-tuple-element(%scatter), index=1
- %consumer = f32[9] add(%gte.0, %gte.1)
- ROOT %slice = f32[8] slice(%consumer), slice={[0:8]}
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
- auto expected_scatter =
- m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
- m::Parameter(0), m::Parameter(1), m::Parameter(2));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Add(m::GetTupleElement(expected_scatter),
- m::GetTupleElement(expected_scatter))
- .WithShape(F32, {8})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, ElementwiseSelect) {
- auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
- %indices = s32[4] parameter(0)
- %updates = f32[4] parameter(1)
- %operands = f32[9] constant(0)
- %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
- %pred_ = pred[9] parameter(2)
- %select = f32[9] select(%pred_, %scatter, %operands)
- ROOT %slice = f32[8] slice(%select), slice={[0:8]}
-}
- )")
- .value();
- ScatterSliceSimplifier test_pass;
- ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
- auto expected_scatter =
- m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Select(m::Slice(m::Parameter(2)), expected_scatter,
- m::Slice(m::Constant()))
- .WithShape(F32, {8})));
-}
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc b/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc
deleted file mode 100644
index fbf1b2c..0000000
--- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scheduling_instruction_annotator.h"
-
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-// Populates `OpMetadata`'s `scheduling_name` field for all of the instructions
-// belonging to `computation`.
-absl::StatusOr<bool> AnnotateSchedulingInstructionNames(
- HloComputation& computation) {
- bool changed = false;
- for (HloInstruction* inst : computation.instructions()) {
- if (!inst->metadata().scheduling_name().empty()) {
- continue;
- }
- inst->set_metadata_scheduling_name(std::string(inst->name()));
- changed = true;
- }
- return changed;
-}
-
-} // namespace
-
-absl::StatusOr<bool> SchedulingInstructionAnnotator::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- CHECK(module->has_schedule())
- << "The pass is supposed to run at the beginning of post-scheduling!";
- bool changed = false;
-
- // We visit computations in the order of callees to callers, as information is
- // propagated from callees to callers.
- for (HloComputation* computation :
- module->MakeComputationPostOrder(execution_threads)) {
- TF_ASSIGN_OR_RETURN(bool result,
- AnnotateSchedulingInstructionNames(*computation));
- changed |= result;
- }
-
- return changed;
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h b/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h
deleted file mode 100644
index 3f9b769..0000000
--- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// The pass amends the `OpMetadata` with the instruction name present at
-// scheduling time. This is later used to make sure instructions are not
-// renamed post-scheduling: without the recorded name, such renames could not
-// be detected.
-class SchedulingInstructionAnnotator : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "scheduling-instruction-annotator";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc b/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc
deleted file mode 100644
index 146607f..0000000
--- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scheduling_instruction_annotator.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using SchedulingInstructionAnnotatorTest = HloTestBase;
-
-TEST_F(SchedulingInstructionAnnotatorTest,
- AnnotatesAllInstructionsWithTheirRespectiveNames) {
- constexpr absl::string_view kHloString = R"(
- HloModule module, is_scheduled=true
-
- ENTRY entry {
- p0 = f32[1] parameter(0)
- p1 = f32[1] parameter(1)
- add0 = f32[1] add(p0,p1)
- ROOT exp0 = f32[1] exponential(add0)
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- SchedulingInstructionAnnotator pass;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-
- ASSERT_TRUE(changed);
- for (const auto* comp : module->computations()) {
- for (const auto* instruction : comp->instructions()) {
- EXPECT_EQ(instruction->name(), instruction->metadata().scheduling_name());
- }
- }
- constexpr absl::string_view kExpected = R"(
-// CHECK: %[[P0:.+]] = {{.*}} parameter(0)
-// CHECK-SAME: scheduling_name="[[P0]]"
-// CHECK: %[[P1:.+]] = {{.*}} parameter(1)
-// CHECK-SAME: scheduling_name="[[P1]]"
-// CHECK: %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[P1]])
-// CHECK-SAME: scheduling_name="[[ADD0]]"
-// CHECK: ROOT %[[EXP0:.+]] = {{.*}} exponential(%[[ADD0]])
-// CHECK-SAME: scheduling_name="[[EXP0]]"
- )";
- TF_ASSERT_OK_AND_ASSIGN(
- bool filecheck_matches,
- RunFileCheck(
- module->ToString(HloPrintOptions().set_print_operand_shape(false)),
- kExpected));
- EXPECT_TRUE(filecheck_matches);
-}
-
-TEST_F(SchedulingInstructionAnnotatorTest,
- DoesNotAnnotateAllInstructionsWithTheirRespectiveNames) {
- constexpr absl::string_view kHloString = R"(
- HloModule module, is_scheduled=true
-
- ENTRY entry {
- p0 = f32[1] parameter(0), metadata={scheduling_name="p0"}
- p1 = f32[1] parameter(1), metadata={scheduling_name="p1"}
- add0 = f32[1] add(p0,p1), metadata={scheduling_name="add0"}
- ROOT exp0 = f32[1] exponential(add0), metadata={scheduling_name="exp0"}
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- SchedulingInstructionAnnotator pass;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-
- EXPECT_FALSE(changed);
-}
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc
deleted file mode 100644
index c6bd796..0000000
--- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc
+++ /dev/null
@@ -1,797 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/softmax_rewriter_triton.h"
-
-#include <cstdint>
-#include <functional>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/model/fusion_analysis_cache.h"
-#include "xla/service/gpu/model/gpu_indexing_performance_model.h"
-#include "xla/service/gpu/model/symbolic_tile_analysis.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using hlo_query::IsBroadcastOfParameter;
-using hlo_query::IsBroadcastOfScalarConstant;
-
-bool HasDefaultLayout(const Shape& shape) {
- return shape.has_layout() &&
- LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
-}
-
-// Returns true if a trivially connected producer of 'consumer' with opcode
-// 'opcode' exists. If such an instruction is found, the value of 'producer' is
-// set to it. The definition of "trivial" operations is as given in
-// 'IsTriviallyFusible'.
-bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
- HloOpcode opcode, const se::GpuComputeCapability& gpu_version);
-
-bool BitcastIsTilingNoop(HloInstruction* bitcast,
- const se::GpuComputeCapability& gpu_version) {
- CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
-
- if (ShapeUtil::IsEffectiveScalar(bitcast->shape())) {
- return true;
- }
-
- // In the Softmax rewriter for now, tiling is derived from a hero reduction
- // operation, which should be reducing its input on the last axis. Therefore,
- // a bitcast is always a no-op with regard to a tile if
- // (1) it does not change the size of the reduction dimension of its input
- // (the last one); if its input is already reduced, then (1) is true
- // by default
- // (2) the layout of its output is ordered in the same way as the layout of
- // its input. This is a fuzzy definition, but since we assume fusible
- // ops to always have a default layout, we can just check if both the
- // bitcast and its input have a default layout
- auto last_dimension = [](const HloInstruction* instr) {
- return instr->shape().dimensions().back();
- };
-
- HloInstruction* reduce = nullptr;
- TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce,
- gpu_version);
-
- return (HasDefaultLayout(bitcast->shape()) &&
- HasDefaultLayout(bitcast->operand(0)->shape()) &&
- (reduce != nullptr ||
- last_dimension(bitcast->operand(0)) == last_dimension(bitcast)));
-}
-
-inline bool HasOneUse(const HloInstruction* instr) {
- return instr->user_count() == 1;
-}
-
-// Supports two types of parameter broadcasts: along one batch dim or along
-// one reduction dim. For example, the following cases are supported:
-//
-// Case #1:
-// p = f32[a] parameter(0)
-// b = f32[a,x] broadcast(p), dimensions={0}
-//
-// Case #2:
-// p = f32[a] parameter(0)
-// b = f32[x,a] broadcast(p), dimensions={1}
-//
-// Case #3:
-// p = f32[a,b] parameter(0)
-// b = f32[x,a,b] broadcast(p), dimensions={1,2}
-//
-// Other broadcast tiling patterns are currently unsupported.
-// See b/328049138 for details.
-//
-// Unsupported case #1:
-// p = f32[a] parameter(0)
-// b = f32[x,a,y] broadcast(p), dimensions={1}
-//
-// Unsupported case #2:
-// p = f32[a,b] parameter(0)
-// b = f32[x,a,y,b] broadcast(p), dimensions={1,3}
-//
-// Unsupported case #3:
-// p = f32[a] parameter(0)
-// b = f32[x,y,a] broadcast(p), dimensions={2}
-//
-// Unsupported case #4:
-// p = f32[a,b] parameter(0)
-// b = f32[a,x,b] broadcast(p), dimensions={0,2}
-bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) {
- CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
- << "Expected broadcast " << hlo.ToShortString();
- CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
- << "Expected parameter " << hlo.operand(0)->ToShortString();
-
- const HloBroadcastInstruction* broadcast =
- Cast<HloBroadcastInstruction>(&hlo);
-
- const HloParameterInstruction* parameter =
- Cast<HloParameterInstruction>(hlo.operand(0));
-
- // Support only one dim broadcast.
- if (parameter->shape().dimensions_size() + 1 !=
- broadcast->shape().dimensions_size()) {
- return false;
- }
-
- // It is enough to ensure that the broadcast does not preserve both the first
- // and the last dimension of the parameter at the same time. Otherwise the
- // broadcast is the unsupported case #4.
- //
- // Preserve the first dim:
- // p = f32[a,b] parameter(0)
- // b1 = f32[a,b,c] broadcast(p), dimensions={0,1}
- bool preserve_first_dim = broadcast->dimensions().front() == 0;
- // Preserve the last dim:
- // p = f32[a,b] parameter(0)
- // b1 = f32[c,a,b] broadcast(p), dimensions={1,2}
- bool preserve_last_dim = broadcast->dimensions().back() ==
- broadcast->shape().dimensions_size() - 1;
- // We do not want to preserve both first and last dim, as it means the
- // broadcast is not expanding on outermost dims.
- return !(preserve_first_dim && preserve_last_dim);
-}
-
-bool IsBroadcastOfAScalar(const HloInstruction& hlo) {
- CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
- << "Expected broadcast " << hlo.ToShortString();
- return ShapeUtil::IsScalar(hlo.operand(0)->shape());
-}
-
-bool IsSingleRowParameterBroadcast(const HloInstruction& hlo) {
- CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
- << "Expected broadcast " << hlo.ToShortString();
- CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
- << "Expected parameter " << hlo.operand(0)->ToShortString();
-
- const HloBroadcastInstruction* broadcast =
- Cast<HloBroadcastInstruction>(&hlo);
- const HloParameterInstruction* parameter =
- Cast<HloParameterInstruction>(hlo.operand(0));
-
- if (parameter->shape().dimensions_size() != 1) {
- return false;
- }
- return broadcast->dimensions()[0] == broadcast->shape().dimensions_size() - 1;
-}
-
-bool IsSupportedBroadcastOfParameter(const HloInstruction& hlo) {
- return IsBroadcastOfParameter(hlo) &&
- (IsBatchOrReductionDimBroadcast(hlo) || IsBroadcastOfAScalar(hlo) ||
- IsSingleRowParameterBroadcast(hlo));
-}
-
-// Chooses which operand to use for fusion processing. Taking in a unary or
-// binary instruction, returns the first non-splat operand. If none is
-// present, returns any operand.
-HloInstruction* ChooseOperandForFusionProcessing(HloInstruction* instr) {
- CHECK_GT(instr->operand_count(), 0);
- CHECK_LE(instr->operand_count(), 2);
-
- // TODO(b/326217416): Extend the broadcast of splat constants/parameters to a
- // broadcast of any op.
- if (instr->operand_count() > 1 &&
- (IsBroadcastOfScalarConstant(*instr->operand(0)) ||
- IsSupportedBroadcastOfParameter(*instr->operand(0)))) {
- return instr->mutable_operand(1);
- }
- return instr->mutable_operand(0);
-}
-
-bool IsTriviallyFusible(HloInstruction* instr,
- const se::GpuComputeCapability& gpu_version,
- int num_allowed_users = 1) {
- // Checks whether an op is trivially fusible. An op is said to be trivially
- // fusible if it does not increase the amount of memory read/written by the
- // resulting fusion, is compatible with any chosen tiling, and can be
- // codegen'd using Triton. The op is allowed to have up to num_allowed_users
- // users.
- if (instr->user_count() > num_allowed_users ||
- !HasDefaultLayout(instr->shape())) {
- return false;
- }
-
- if (instr->opcode() == HloOpcode::kBitcast &&
- BitcastIsTilingNoop(instr, gpu_version)) {
- return true;
- }
-
- if (instr->IsElementwise() && instr->operand_count() == 1) {
- return static_cast<bool>(IsTritonSupportedInstruction(*instr, gpu_version));
- }
-
- // Elementwise binary ops are trivially fusible if both operands are the same
- // instruction, or if exactly one of the operands is a supported broadcast of
- // a scalar constant or of a parameter.
- if (instr->IsElementwiseBinary()) {
- const HloInstruction* operand_0 = instr->operand(0);
- const HloInstruction* operand_1 = instr->operand(1);
-
- // Elementwise binary ops should be fused if both operands are the same and
- // if the operand is triton supported.
- if (operand_0 == operand_1) {
- return static_cast<bool>(
- IsTritonSupportedInstruction(*instr, gpu_version));
- }
-
- // For simplicity we only fuse elementwise binary ops with splat operands
- // if they contain one non-splat operand.
- // TODO(b/326217416): Extend the broadcast of splat constants/parameters to
- // a broadcast of any op.
- if ((IsBroadcastOfScalarConstant(*operand_0) ||
- IsSupportedBroadcastOfParameter(*operand_0)) ^
- (IsBroadcastOfScalarConstant(*operand_1) ||
- IsSupportedBroadcastOfParameter(*operand_1))) {
- return static_cast<bool>(
- IsTritonSupportedInstruction(*instr, gpu_version));
- }
- }
-
- return false;
-}
-
-bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
- HloOpcode opcode,
- const se::GpuComputeCapability& gpu_version) {
- while (consumer->opcode() != opcode) {
- if (IsTriviallyFusible(consumer, gpu_version)) {
- consumer = ChooseOperandForFusionProcessing(consumer);
- } else {
- return false;
- }
- }
-
- *producer = consumer;
- return true;
-}
-
-bool IsTriviallyConnectedProducerOf(
- HloInstruction* producer, HloInstruction* consumer,
- const se::GpuComputeCapability& gpu_version) {
- if (producer == consumer) {
- return true;
- }
-
- HloInstruction* found_producer = consumer;
- while (
- TrivialEdge(&found_producer, consumer, producer->opcode(), gpu_version)) {
- if (found_producer == producer) {
- return true;
- }
-
- if (!IsTriviallyFusible(found_producer, gpu_version)) {
- return false;
- }
-
- consumer = found_producer->mutable_operand(0);
- }
-
- return false;
-}
-
-// Finds the first non-fusible producer of a diamond. This instruction is either
-// 1. the direct producer of the diamond, if that producer is used more than
-// twice and/or is not otherwise trivially fusible
-// 2. the first parent instruction of the producer of the diamond that is
-//    used more than once, and/or is not trivially fusible.
-HloInstruction* FindFirstNonFusibleDiamondProducer(
- HloInstruction* diamond_producer,
- const se::GpuComputeCapability& gpu_version) {
- if (IsTriviallyFusible(diamond_producer, gpu_version,
- /*num_allowed_users=*/2)) {
- diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
- while (IsTriviallyFusible(diamond_producer, gpu_version)) {
- diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
- }
- }
-
- return diamond_producer;
-}
-
-// Creates a fusion corresponding to the input diamond chain. The resulting
-// fusion instruction is added to the module, but is not yet inserted into the
-// graph as a replacement of the original instructions.
-//
-// TODO(b/347956491): this awkward abstraction is needed to work around
-// limitations of HloFusionAdaptor, which underpins the implementation of
-// SymbolicTileAnalysis. We need to come up with a better solution.
-absl::StatusOr<HloFusionInstruction*> MakeFusionForDiamondChain(
- const DiamondChainDescriptor& diamond_chain) {
- auto [root, producer] = diamond_chain;
-
- std::string suggested_name = "triton_softmax";
- HloComputation::Builder builder(absl::StrCat(suggested_name, "_computation"));
- // Original instruction -> fused one.
- absl::flat_hash_map<const HloInstruction*, HloInstruction*>
- old_to_new_mapping;
-
- int param = 0;
- old_to_new_mapping[producer] =
- builder.AddInstruction(HloInstruction::CreateParameter(
- param, producer->shape(), absl::StrCat("parameter_", param)));
- param++;
-
- std::vector<HloInstruction*> parameters = {producer};
-
- std::function<void(HloInstruction*)> create_computation =
- [&](HloInstruction* instr) -> void {
- if (old_to_new_mapping.contains(instr)) {
- return;
- }
- std::vector<HloInstruction*> new_operands;
- for (HloInstruction* operand : instr->mutable_operands()) {
- create_computation(operand);
- new_operands.push_back(old_to_new_mapping[operand]);
- }
- if (instr->opcode() == HloOpcode::kParameter) {
- old_to_new_mapping[instr] =
- builder.AddInstruction(HloInstruction::CreateParameter(
- param, instr->shape(), absl::StrCat("parameter_", param)));
- parameters.push_back(instr);
- param++;
- } else {
- old_to_new_mapping[instr] = builder.AddInstruction(
- instr->CloneWithNewOperands(instr->shape(), new_operands));
- }
- };
- create_computation(root);
-
- HloComputation* computation =
- root->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
- /*is_entry=*/false);
-
- HloInstruction* softmax_fusion =
- root->parent()->AddInstruction(HloInstruction::CreateFusion(
- root->shape(), HloInstruction::FusionKind::kCustom, parameters,
- computation));
-
- softmax_fusion->GetModule()->SetAndUniquifyInstrName(softmax_fusion,
- "triton_softmax");
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- softmax_fusion->backend_config<GpuBackendConfig>());
- FusionBackendConfig& backend_config =
- *gpu_config.mutable_fusion_backend_config();
- backend_config.set_kind(std::string(kTritonFusionKind));
- TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(gpu_config));
- return xla::Cast<HloFusionInstruction>(softmax_fusion);
-}
-
-absl::Status FuseDiamondChainImpl(
- const DiamondChainDescriptor& diamond_chain,
- GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model) {
- TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
- MakeFusionForDiamondChain(diamond_chain));
- HloInstruction* root = diamond_chain.root;
-
- auto fusion_adaptor = HloFusionAdaptor::ForInstruction(softmax_fusion);
-
- TF_ASSIGN_OR_RETURN(
- TiledRunTimeDataOrError tiled_runtime_data_or,
- indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor));
-
- if (const auto* fusion_decision =
- std::get_if<FusionDecision>(&tiled_runtime_data_or)) {
- return absl::FailedPreconditionError(absl::StrCat(
- "SymbolicTileAnalysis failed. ", fusion_decision->Explain()));
- }
-
- TiledRunTimeData tiled_runtime_data =
- std::get<TiledRunTimeData>(std::move(tiled_runtime_data_or));
-
- TF_ASSIGN_OR_RETURN(auto backend_config,
- softmax_fusion->backend_config<GpuBackendConfig>());
- *backend_config.mutable_fusion_backend_config()
- ->mutable_block_level_fusion_config() =
- tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig();
- TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config));
-
- if (root->IsRoot()) {
- root->parent()->set_root_instruction(softmax_fusion);
- TF_RETURN_IF_ERROR(
- root->parent()->RemoveInstructionAndUnusedOperands(root));
- } else {
- TF_RETURN_IF_ERROR(
- root->parent()->ReplaceInstruction(root, softmax_fusion));
- }
-
- VLOG(5) << softmax_fusion->ToString();
- return absl::OkStatus();
-}
-
-// Returns `true` if the diamond chain passed as a parameter can be tiled
-// correctly using `SymbolicTileAnalysis`.
-absl::StatusOr<bool> CanSymbolicTileAnalysisTileDiamondChain(
- const DiamondChainDescriptor& diamond_chain) {
- TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
- MakeFusionForDiamondChain(diamond_chain));
- mlir::MLIRContext context;
- SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error =
- SymbolicTileAnalysis::AnalyzeComputation(
- *softmax_fusion->called_computation(), &context);
-
- bool can_tile = std::holds_alternative<SymbolicTileAnalysis>(
- symbolic_tile_analysis_or_error);
-
- TF_RETURN_IF_ERROR(diamond_chain.root->GetModule()->RemoveEmbeddedComputation(
- softmax_fusion->called_computation()));
- TF_RETURN_IF_ERROR(
- diamond_chain.root->parent()->RemoveInstruction(softmax_fusion));
-
- return can_tile;
-}
-
-FusionDecision ShouldFuseReduction(const HloInstruction& reduce,
- const se::GpuComputeCapability& cc) {
- if (CodegenDecision is_supported = IsTritonSupportedInstruction(reduce, cc);
- !is_supported) {
- return FusionDecision(is_supported.Explain());
- }
-
- // Ensure that the reduction's identity is either a constant or a supported
- // convert of a constant.
- const HloInstruction* identity = reduce.operand(1);
- bool should_fuse_identity =
- identity->opcode() == HloOpcode::kConstant ||
- (identity->opcode() == HloOpcode::kConvert &&
- identity->operand(0)->opcode() == HloOpcode::kConstant &&
- IsTritonSupportedInstruction(*identity, cc));
- if (!should_fuse_identity) {
- return "Reduction identity is not a constant or a supported convert of a "
- "constant.";
- }
-
- return {};
-}
-
-DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl(
- HloInstruction* instr, const se::GpuComputeCapability& cc) {
- if (!instr->IsElementwiseBinary()) {
- return "Root is not elementwise binary.";
- }
-
- if (!IsTritonSupportedInstruction(*instr, cc)) {
- return "Root is not a Triton-supported instruction.";
- }
-
- HloInstruction* producer;
- HloInstruction* broadcast;
- HloInstruction* reduce;
-
- if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast,
- cc)) {
- return "Could not find a trivial connection from root to a broadcast.";
- }
-
- if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce,
- cc)) {
- return "Could not find a trivial connection from matched broadcast to a "
- "reduction.";
- }
-
- if (!(HasDefaultLayout(broadcast->shape()) &&
- HasDefaultLayout(reduce->shape()))) {
- return "Broadcast or reduce have non-default layouts.";
- }
-
- if (FusionDecision should_fuse_reduction = ShouldFuseReduction(*reduce, cc);
- !should_fuse_reduction) {
- VLOG(2) << should_fuse_reduction.Explain();
- return should_fuse_reduction;
- }
-
- // Ensure that the reduction's identity is either a constant or a supported
- // convert of a constant.
- const HloInstruction* identity = reduce->operand(1);
- bool should_fuse_identity =
- identity->opcode() == HloOpcode::kConstant ||
- (identity->opcode() == HloOpcode::kConvert &&
- identity->operand(0)->opcode() == HloOpcode::kConstant &&
- IsTritonSupportedInstruction(*identity, cc));
- if (!should_fuse_identity) {
- return "Reduction identity is not a constant or a supported convert of a "
- "constant.";
- }
-
- if (!HasOneUse(broadcast) || !HasOneUse(reduce)) {
- return "More than one use of broadcast or reduce.";
- }
-
- producer = reduce->mutable_operand(0);
-
- if (absl::c_linear_search(broadcast->dimensions(),
- broadcast->shape().rank() - 1)) {
- return "Broadcast is not along the reduction dimension.";
- }
-
- while (IsTriviallyFusible(producer, cc)) {
- producer = ChooseOperandForFusionProcessing(producer);
- }
-
- if (!HasDefaultLayout(producer->shape())) {
- return "Producer has non-default layout.";
- }
-
- if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0),
- cc)) {
- return "Producer is not trivially connected.";
- }
-
- if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) {
- return "Unsupported root-producer connection.";
- }
-
- VLOG(5) << "Matched Softmax diamond with: ";
- VLOG(5) << "root: " << instr->ToString();
- VLOG(5) << "producer: " << producer->ToString();
- VLOG(5) << "broadcast: " << broadcast->ToString();
- VLOG(5) << "reduce: " << reduce->ToString();
-
- return producer;
-}
-
-// Returns a vector containing all the single diamonds in the parameter module.
-// The diamonds are returned in def-before-use order, and grouped by
-// computation.
-absl::StatusOr<std::vector<DiamondChainDescriptor>> FindAllFusibleDiamonds(
- HloModule& module,
- const absl::flat_hash_set<absl::string_view>& execution_threads,
- const se::GpuComputeCapability& cc) {
- std::vector<DiamondChainDescriptor> matched_diamonds;
-
- for (HloComputation* comp :
- module.MakeNonfusionComputations(execution_threads)) {
- if (comp->IsCustomCallComputation()) {
- continue;
- }
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- auto producer =
- MatchesTritonCompatibleClosedReductionDiamondImpl(instr, cc);
- if (std::holds_alternative<HloInstruction*>(producer)) {
- DiamondChainDescriptor diamond_chain{
- /*root=*/instr, /*producer=*/std::get<HloInstruction*>(producer)};
- // We filter out the diamond chains that cannot be tiled correctly using
- // `SymbolicTileAnalysis`.
- TF_ASSIGN_OR_RETURN(
- bool can_tile_diamond_chain,
- CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
- if (can_tile_diamond_chain) {
- matched_diamonds.push_back(diamond_chain);
- } else {
- VLOG(5) << "Cannot tile the diamond pattern described by "
- << "instructions " << instr->ToString() << " and "
- << std::get<HloInstruction*>(producer)->ToString() << ".";
- continue;
- }
-
- } else {
- VLOG(5) << "Cannot match the diamond pattern for instruction "
- << instr->ToString()
- << ". Reason: " << std::get<FusionDecision>(producer).Explain();
- }
- }
- }
-
- return std::move(matched_diamonds);
-}
-
-// Returns the size of the reduction dimension of the input diamond.
-int64_t GetReductionDimensionSizeForDiamond(
- const DiamondChainDescriptor& diamond_chain) {
- HloInstruction* diamond_root = diamond_chain.root;
- HloInstruction* instr = diamond_root->mutable_operand(1);
- while (instr->opcode() != HloOpcode::kReduce) {
- instr = ChooseOperandForFusionProcessing(instr);
- }
-
- int operand_rank = instr->operand(0)->shape().rank();
- CHECK_EQ(instr->dimensions().size(), 1);
- CHECK_EQ(instr->dimensions(0), operand_rank - 1);
- return instr->operand(0)->shape().dimensions(operand_rank - 1);
-}
-
-// Returns a pointer to the last user of `instr` that is trivially fusible.
-HloInstruction* GetLastTriviallyFusibleUser(
- HloInstruction* instr, const se::GpuComputeCapability& cc) {
- while (HasOneUse(instr) && !instr->IsRoot() &&
- IsTriviallyFusible(instr->users().front(), cc)) {
- instr = instr->users().front();
- }
-
- // We do not care about the number of users for the last instruction of the
- // fusion, so attempt to fuse one more instruction with this relaxed
- // restriction.
- if (HasOneUse(instr) && !instr->IsRoot() &&
- IsTriviallyFusible(
- instr->users().front(), cc,
- /*num_allowed_users=*/instr->users().front()->user_count())) {
- instr = instr->users().front();
- }
- return instr;
-}
-
-} // anonymous namespace
-
-DiamondMatchingDecision
-SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
- HloInstruction* instr) const {
- return MatchesTritonCompatibleClosedReductionDiamondImpl(
- instr, device_info_.gpu_compute_capability());
-}
-
-absl::StatusOr<std::vector<DiamondChainDescriptor>>
-SoftmaxRewriterTriton::FindAllFusibleDiamondChains(
- HloModule& module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) const {
- const se::GpuComputeCapability& cc = device_info_.gpu_compute_capability();
- TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> matched_diamonds,
- FindAllFusibleDiamonds(module, execution_threads, cc));
-
- if (matched_diamonds.empty()) {
- return std::vector<DiamondChainDescriptor>();
- }
-
- // If we matched several diamonds, it may be possible for some of them to be
- // fused together. This is the case if the following conditions hold:
- // 1. The path from the root of diamond n to the producer of diamond n+1
- //    is composed only of trivially fusible operations. In that
- // case, the first non-trivially fusible producer of diamond n+1 must be
- // exactly the root of diamond n.
- // 2. The root of diamond n/first non-fusible producer of diamond n+1 must
- // have
- // a. exactly one user if it is not exactly the producer of diamond
- // n+1;
- // b. exactly two users otherwise.
- // 3. The axis being reduced must have the same length in all the diamonds
- // being fused together.
- //
- // Crucially, this approach relies on a diamond root never being considered a
- // trivially fusible operation.
- std::vector<DiamondChainDescriptor> diamond_chains;
- diamond_chains.reserve(matched_diamonds.size());
-
- HloInstruction* current_fusion_producer =
- FindFirstNonFusibleDiamondProducer(matched_diamonds.front().producer, cc);
- int current_reduce_dimension_size =
- GetReductionDimensionSizeForDiamond(matched_diamonds.front());
-
- for (int diamond_idx = 1; diamond_idx < matched_diamonds.size();
- ++diamond_idx) {
- HloInstruction* diamond_producer = matched_diamonds[diamond_idx].producer;
- HloInstruction* previous_diamond_root =
- matched_diamonds[diamond_idx - 1].root;
-
- HloInstruction* first_non_fusible_diamond_producer =
- FindFirstNonFusibleDiamondProducer(diamond_producer, cc);
-
- int diamond_reduce_dimension_size =
- GetReductionDimensionSizeForDiamond(matched_diamonds[diamond_idx]);
-
- if (first_non_fusible_diamond_producer == previous_diamond_root && // 1
- ((first_non_fusible_diamond_producer != diamond_producer &&
- HasOneUse(first_non_fusible_diamond_producer)) || // 2.a
- (first_non_fusible_diamond_producer == diamond_producer &&
- first_non_fusible_diamond_producer->user_count() == 2)) && // 2.b
- diamond_reduce_dimension_size == current_reduce_dimension_size) { // 3
- continue;
- }
-
- // The "last trivially fusible user" chain of diamond chain n should never
- // intersect with the "first non fusible diamond producer" chain of diamond
- // chain n+1: if these chains intersected, then all the intermediate ops
- // between the diamond chains could be trivially fused, and both diamond
- // chains could be fused into a single diamond chain. Note that this only
- // holds insofar as we do not allow fusing in bitcasts that modify the last
- // dimension of the input array. It is however possible for the last
- // trivially fusible user of diamond chain n to be the first non fusible
- // diamond producer of diamond chain n+1.
- diamond_chains.push_back(DiamondChainDescriptor{
- GetLastTriviallyFusibleUser(previous_diamond_root, cc),
- current_fusion_producer,
- });
-
- current_fusion_producer = first_non_fusible_diamond_producer;
- current_reduce_dimension_size = diamond_reduce_dimension_size;
- }
-
- // The last diamond chain is still open; close it.
- diamond_chains.push_back(DiamondChainDescriptor{
- GetLastTriviallyFusibleUser(matched_diamonds.back().root, cc),
- current_fusion_producer});
-
- // We filter out the diamond chains that cannot be tiled correctly using
- // `SymbolicTileAnalysis`.
- std::vector<DiamondChainDescriptor> filtered_diamond_chains;
- for (const DiamondChainDescriptor& diamond_chain : diamond_chains) {
- TF_ASSIGN_OR_RETURN(bool can_tile_diamond_chain,
- CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
- if (can_tile_diamond_chain) {
- filtered_diamond_chains.push_back(diamond_chain);
- }
- }
- return filtered_diamond_chains;
-}
-
-absl::Status SoftmaxRewriterTriton::FuseDiamondChain(
- const DiamondChainDescriptor& diamond_chain) {
- HloFusionAnalysisCache fusion_analysis_cache(device_info_);
- GpuPerformanceModelWithIndexingAnalysis indexing_performance_model(
- &device_info_, &fusion_analysis_cache, shape_size_, &mlir_context_);
-
- return FuseDiamondChainImpl(diamond_chain, indexing_performance_model);
-}
-
-absl::StatusOr<bool> SoftmaxRewriterTriton::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability(
- device_info_.gpu_compute_capability()));
-
- TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
- FindAllFusibleDiamondChains(*module, execution_threads));
-
- if (diamond_chains.empty()) {
- return false;
- }
-
- // The diamond chains must be emitted in reverse order, to make sure that
- // producer instructions are emitted correctly when the root of
- // diamond chain n is exactly the producer of diamond chain n+1.
- for (auto diamond_chain = diamond_chains.rbegin();
- diamond_chain != diamond_chains.rend(); ++diamond_chain) {
- TF_RET_CHECK(FuseDiamondChain(*diamond_chain).ok());
- }
- return true;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h
deleted file mode 100644
index 9da8cc5..0000000
--- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h
+++ /dev/null
@@ -1,101 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_
-#define XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_
-
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-struct DiamondChainDescriptor {
- HloInstruction* root = nullptr;
- HloInstruction* producer = nullptr;
-};
-
-using DiamondMatchingDecision = std::variant<FusionDecision, HloInstruction*>;
-
-// Rewrite compatible Softmax into a custom fusion region to be code-generated
-// with the Triton-based Softmax emitter.
-class SoftmaxRewriterTriton : public HloModulePass {
- public:
- explicit SoftmaxRewriterTriton(const se::DeviceDescription& device_info,
- HloCostAnalysis::ShapeSizeFunction shape_size)
- : device_info_(device_info), shape_size_(shape_size) {}
-
- absl::string_view name() const override { return "triton-softmax-rewriter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- // Finds and returns all the fusible diamond chains in the module. The
- // resulting vector is sorted according to a post-order matching (i.e. within
- // the same computation, producer diamonds appear before consumer diamonds).
- absl::StatusOr<std::vector<DiamondChainDescriptor>>
- FindAllFusibleDiamondChains(
- HloModule& module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) const;
-
- // Constructs a Softmax fusion containing all the instructions between the
- // root and the producer of a diamond chain. The producer is excluded from the
- // fusion.
- absl::Status FuseDiamondChain(const DiamondChainDescriptor& diamond_chain);
-
- // Return the producer of the following pattern:
- //
- // producer
- // | \
- // | reduce_{max,sum,...}
- // | |
- // | broadcast
- // | /
- // binop (elementwise)
- //
- // where each edge is also allowed to contain trivial operations that can be
- // generated by Triton. By "trivial" we mean operations that do not
- // increase the amount of memory read/written by the fusion, and that are
- // compatible with any chosen tiling.
- //
- // We also assume that the reduction is done on the last axis of the producer
- // array.
- DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamond(
- HloInstruction* instr) const;
-
- private:
- const se::DeviceDescription& device_info_;
- const HloCostAnalysis::ShapeSizeFunction shape_size_;
- mlir::MLIRContext mlir_context_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_
diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc
deleted file mode 100644
index 8488031..0000000
--- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc
+++ /dev/null
@@ -1,1590 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/softmax_rewriter_triton.h"
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <variant>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using ::testing::HasSubstr;
-
-GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() {
- return [&](const Shape& shape) {
- constexpr int64_t kPointerSize = 8;
- return ShapeUtil::ByteSizeOf(shape, kPointerSize);
- };
-}
-
-bool HasBlockLevelFusionConfig(const HloInstruction* fusion) {
- return fusion->opcode() == HloOpcode::kFusion &&
- fusion->has_backend_config() &&
- fusion->backend_config<GpuBackendConfig>().ok() &&
- fusion->backend_config<GpuBackendConfig>()
- ->fusion_backend_config()
- .has_block_level_fusion_config();
-}
-
-// Wrapper around SoftmaxRewriterTriton(gpu_version).Run(module) that finds
-// and fuses as many diamond chains as possible without invoking any kind of
-// cost analysis.
-absl::StatusOr<bool> SoftmaxRewriterTritonMatchAndRewrite(
- const se::DeviceDescription& device_info, HloModule* module) {
- CHECK_NE(module, nullptr);
- SoftmaxRewriterTriton softmax_rewriter_triton(device_info,
- ShapeSizeBytesFunction());
- TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
- softmax_rewriter_triton.FindAllFusibleDiamondChains(
- *module, /*execution_threads=*/{}));
-
- for (auto diamond_chain = diamond_chains.rbegin();
- diamond_chain != diamond_chains.rend(); ++diamond_chain) {
- TF_RETURN_IF_ERROR(
- softmax_rewriter_triton.FuseDiamondChain(*diamond_chain));
- }
-
- return !diamond_chains.empty();
-}
-
-class SoftmaxRewriterTritonTest
- : public HloTestBase,
- public ::testing::WithParamInterface<PrimitiveType> {
- protected:
- se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
-};
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseExactSoftmaxF32) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- exponential = f32[127,125]{1,0} exponential(subtract)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- VLOG(2) << module->ToString();
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseSoftmaxLikeComputationWithNonF32DataType) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f16[] parameter(0)
- arg_1 = f16[] parameter(1)
- ROOT maximum = f16[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f16[] parameter(0)
- arg_1.1 = f16[] parameter(1)
- ROOT add = f16[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f16[127,125]{1,0} parameter(0)
- constant_neg_inf = f16[] constant(-inf)
- reduce = f16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f16[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f16[127,125]{1,0} subtract(param_0, broadcast)
- // Replace Softmax exponential with abs, because Triton doesn't support
- // non-f32 exponentials.
- abs = f16[127,125]{1,0} abs(subtract)
- constant_zero = f16[] constant(0)
- second_reduce = f16[127]{0} reduce(abs, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f16[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- // Replace divide with multiply, because Triton doesn't support f16
- // divisions.
- ROOT multiply = f16[127,125]{1,0} multiply(abs, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseSingleNormalizationDiamond) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- DoesNotFuseDiamondInvolvingUnsupportedTritonInstruction) {
- const std::string hlo_string = R"(
-HloModule softmax
-add_computation {
- arg_0.1 = bf16[] parameter(0)
- arg_1.1 = bf16[] parameter(1)
- ROOT add = bf16[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = bf16[127,125]{1,0} parameter(0)
- constant_zero = bf16[] constant(0)
- reduce = bf16[127]{0} reduce(param_0, constant_zero), dimensions={1}, to_apply=add_computation
- broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT divide = bf16[127,125]{1,0} divide(param_0, broadcast)
-})";
-
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- const HloInstruction* bf16_divide =
- module->entry_computation()->root_instruction();
- EXPECT_FALSE(IsTritonSupportedInstruction(
- *bf16_divide, device_info_.gpu_compute_capability()));
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- DoesNotFuseInstructionsUnsupportedByTritonIntoDiamonds) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = bf16[] parameter(0)
- arg_1 = bf16[] parameter(1)
- ROOT maximum = bf16[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = bf16[127,125]{1,0} parameter(0)
- constant_neg_inf = bf16[] constant(-inf)
- reduce = bf16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = bf16[127,125]{1,0} subtract(param_0, broadcast)
- ROOT exponential = bf16[127,125]{1,0} exponential(subtract)
-})";
-
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- const HloInstruction* bf16_exponential =
- hlo_query::GetFirstInstructionWithOpcode(*module->entry_computation(),
- HloOpcode::kExp);
- EXPECT_FALSE(IsTritonSupportedInstruction(
- *bf16_exponential, device_info_.gpu_compute_capability()));
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Exp(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithWrongLayout) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{0,1} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseSoftmaxDiamondWithWrongReduceDimension) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={0}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={1}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseSoftmaxDiamondWithWrongBroadcastDimension) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[125,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[125,125]{1,0} broadcast(reduce), dimensions={1}
- ROOT subtract = f32[125,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseSoftmaxDiamondWithExtraBroadcastUsage) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- ROOT multiply = f32[127,125]{1,0} multiply(broadcast, subtract)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseSoftmaxWithIntermediateUnaryElementwise) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- abs = f32[127,125]{1,0} abs(subtract)
- exponential = f32[127,125]{1,0} exponential(abs)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(subtract, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- ROOT divide = f32[127,125]{1,0} divide(subtract, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseDiamondWithTrailingUnaryElementwiseAtTheRoot) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- ROOT abs = f32[127,125]{1,0} abs(subtract)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseDiamondWithUnaryElementwisePrefix) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- abs = f32[127,125]{1,0} abs(param_0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(abs, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseDiamondWithMultipleBroadcastDimensions) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[1,3,125,125]{3,2,1,0} parameter(0)
- bitcast = f32[3,125,125]{2,1,0} bitcast(f32[1,3,125,125]{3,2,1,0} param_0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[3,125]{1,0} reduce(f32[3,125,125]{2,1,0} bitcast, f32[] constant_neg_inf), dimensions={2}, to_apply=max_computation
- broadcast = f32[1,3,125,125]{3,2,1,0} broadcast(f32[3,125]{1,0} reduce), dimensions={1,2}
- ROOT subtract = f32[1,3,125,125]{3,2,1,0} subtract(f32[1,3,125,125]{3,2,1,0} param_0, f32[1,3,125,125]{3,2,1,0} broadcast)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseSoftmaxDiamondWithParameterReducerIdentity) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- identity = f32[] parameter(1)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseSoftmaxDiamondWithTritonIncompatibleReducer) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- if_0 = pred[] is-finite(arg_0)
- c = f32[] convert(if_0)
- ROOT maximum = f32[] maximum(c, arg_1)
-}
-
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseSoftmaxDiamondWithLastDimensionBitcastAfterReduce) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
- param_0 = f32[3,127,125]{2,1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[3,127]{1,0} reduce(param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
- bitcasted_reduce = f32[381]{0} bitcast(reduce)
- broadcast = f32[381,125]{1,0} broadcast(bitcasted_reduce), dimensions={0}
- bitcasted_broadcast = f32[3,127,125]{2,1,0} bitcast(broadcast)
- ROOT subtract = f32[3,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseSoftmaxDiamondWithTransposeBitcast) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
- param_0 = f32[1,127,125]{2,1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- bitcasted_param_0 = f32[127,1,125]{2,0,1} bitcast(param_0)
- reduce = f32[127,1]{0,1} reduce(bitcasted_param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
- broadcast = f32[127,1,125]{2,0,1} broadcast(reduce), dimensions={0,1}
- bitcasted_broadcast = f32[1,127,125]{2,1,0} bitcast(broadcast)
- ROOT subtract = f32[1,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseTwoDiamondsWithDifferentReductionAxisSizeTogether) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,625]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,625]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,625]{1,0} subtract(param_0, broadcast)
- bitcasted_subtract = f32[127,5,125] bitcast(subtract)
- exponential = f32[127,5,125] exponential(bitcasted_subtract)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127,5] reduce(exponential, constant_zero), dimensions={2}, to_apply=add_computation
- second_broadcast = f32[127,5,125] broadcast(second_reduce), dimensions={0,1}
- ROOT divide = f32[127,5,125] divide(exponential, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Bitcast(m::Fusion(m::Parameter())
- .WithPredicate(HasBlockLevelFusionConfig)))
- .WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseTwoDiamondsWithExtraUsageForFirstDiamondRoot) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- exponential = f32[127,125]{1,0} exponential(subtract)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
- ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, subtract)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
- m::Fusion(m::Parameter())
- .WithPredicate(HasBlockLevelFusionConfig))));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseTwoDiamondsWithExtraUsageForSecondDiamondProducer) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- exponential = f32[127,125]{1,0} exponential(subtract)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
- ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, exponential)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseSoftmaxDiamondWithTritonIncompatibleProducer) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
- param_0 = f16[127,125]{1,0} parameter(0)
- exponential = f16[127,125] exponential(param_0)
- convert = f32[127,125] convert(exponential)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(convert, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(convert, broadcast)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Exp(m::Parameter()))
- .WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanNotFuseSoftmaxDiamondWithNonFusibleBitcastBetweenReduceAndProducer) {
- const std::string hlo_string = R"(
-HloModule softmax
-
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
- param_0 = f32[1,127,5,25]{3,2,1,0} parameter(0)
- bitcast_0 = f32[127,125] bitcast(param_0)
- bitcast_1 = f32[127,125] bitcast(param_0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseSoftmaxDiamondWithBitcastProducerFollowedByBitcastsOnEachUse) {
- const std::string hlo_string = R"(
-HloModule softmax
-
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
- param_0 = f32[1,1,127,125]{3,2,1,0} parameter(0)
- bitcast_parent = f32[127,125]{1,0} bitcast(param_0)
- bitcast_0 = f32[127,125]{1,0} bitcast(bitcast_parent)
- bitcast_1 = f32[127,125]{1,0} bitcast(bitcast_parent)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnPreAmpereCudaGpu) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = bf16[127,125]{1,0} parameter(0)
- param_0_f32 = f32[127,125]{1,0} convert(param_0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
-})";
-
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_THAT(
- SoftmaxRewriterTriton(
- TestGpuDeviceInfo::RTXA6000DeviceInfo(
- se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}),
- ShapeSizeBytesFunction())
- .Run(module.get()),
- tsl::testing::StatusIs(
- tsl::error::FAILED_PRECONDITION,
- ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
- "(compute capability 8.0) and up, but got")));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, RewriterSucceedsOnNonCudaGpu) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = bf16[127,125]{1,0} parameter(0)
- param_0_f32 = f32[127,125]{1,0} convert(param_0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
-})";
-
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
- EXPECT_TRUE(SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(),
- ShapeSizeBytesFunction())
- .Run(module.get())
- .ok());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) {
- const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- multiply = f32[127,125]{1,0} multiply(param_0, param_0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(multiply, broadcast)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- CanFuseIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- multiply = f32[127]{0} multiply(reduce, reduce)
- broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) {
- const std::string hlo_string = R"(
-HloModule fusible_diamonds
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- multiply = f32[127,125]{1,0} multiply(subtract, subtract)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- ROOT subtract_second = f32[127,125]{1,0} subtract(multiply, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) {
- const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- ROOT multiply = f32[127,125]{1,0} multiply(subtract, subtract)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- DoesNotFuseIntermediateBinaryElementwiseWithBothSplatOperandsIntoDiamond) {
- const std::string hlo_string = R"(
-HloModule nonfusible_splat
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- constant_0 = f32[] constant(0.333333343)
- splat_0 = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
- constant_1 = f32[] constant(0.66666)
- splat_1 = f32[127,125]{1,0} broadcast(constant_1), dimensions={}
- param_0 = f32[127,125]{1,0} parameter(0)
- multiply_splats = f32[127,125]{1,0} multiply(splat_0, splat_1)
- multiply_splat_param = f32[127,125]{1,0} multiply(multiply_splats, param_0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(multiply_splat_param, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- DoesNotFuseIntermediateBinaryElementwiseWithSameSplatOperandsIntoDiamond) {
- const std::string hlo_string = R"(
-HloModule nonfusible_splat_diamond
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- constant_0 = f32[] constant(0.333333343)
- splat = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
- param_0 = f32[127,125]{1,0} parameter(0)
- multiply = f32[127,125]{1,0} multiply(splat, splat)
- add = f32[127,125]{1,0} add(param_0, multiply)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(add, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseRMSNormDiamond) {
- const std::string hlo_string = R"(
-HloModule rms_norm
-add_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT add.1 = f32[] add(arg_0, arg_1)
-}
-ENTRY main.30 {
- param_0 = f32[10,10,10,128]{3,2,1,0} parameter(0)
- multiply_param = f32[10,10,10,128]{3,2,1,0} multiply(param_0, param_0)
- constant_0 = f32[] constant(0)
- reduce = f32[10,10,10]{2,1,0} reduce(multiply_param, constant_0), dimensions={3}, to_apply=add_computation
- constant_1 = f32[] constant(0.333333343)
- splat = f32[10,10,10]{2,1,0} broadcast(constant_1), dimensions={}
- multiply_splat = f32[10,10,10]{2,1,0} multiply(reduce, splat)
- epsilon = f32[] constant(1e-06)
- splat_epsilon = f32[10,10,10]{2,1,0} broadcast(epsilon), dimensions={}
- add = f32[10,10,10]{2,1,0} add(multiply_splat, splat_epsilon)
- rsqrt = f32[10,10,10]{2,1,0} rsqrt(add)
- broadcast = f32[10,10,10,128]{3,2,1,0} broadcast(rsqrt), dimensions={0,1,2}
- ROOT multiply = f32[10,10,10,128]{3,2,1,0} multiply(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get())
- .value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Fusion(m::Parameter())
- .WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule fusible_diamonds
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- constant = f32[] constant(0.333333343)
- broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
- multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule fusible_diamonds
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- constant = f32[] constant(0.333333343)
- broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
- multiply = f32[127,125]{1,0} multiply(subtract, broadcast_splat)
- constant_zero = f32[] constant(0)
- second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- CanFuseBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) {
- const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- constant = f32[] constant(0.333333343)
- broadcast_splat = f32[127]{0} broadcast(constant), dimensions={}
- multiply = f32[127]{0} multiply(broadcast_splat, reduce)
- broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstant) {
- const std::string hlo_string = R"(
-HloModule fusible_diamond
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
- constant = f32[] constant(0.333333343)
- broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
- ROOT multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
- CanFuseBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducer) {
- const std::string hlo_string = R"(
-HloModule nonfusible_diamond
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT max = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_2 = f32[] constant(0.333333343)
- broadcast_splat = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
- param_1 = f32[127,125]{1,0} parameter(1)
- multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, param_1)
- multiply = f32[127,125]{1,0} multiply(param_0, broadcast_splat)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
- EXPECT_TRUE(verifier().Run(module.get()).status().ok());
- VLOG(2) << module->ToString();
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- DoesNotFuseBinaryElementwiseOperationWhereFirstOperandIsASplatAndSecondOperandIsASharedSplatProducer) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule nonfusible_diamond
-add_computation {
- arg_0.1 = f32[] parameter(0)
- arg_1.1 = f32[] parameter(1)
- ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- constant_2 = f32[] constant(0.333333343)
- broadcast_splat_shared = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
- param_1 = f32[127,125]{1,0} parameter(1)
- multiply_splat_shared = f32[127,125]{1,0} multiply(broadcast_splat_shared, param_1)
- constant_3 = f32[] constant(0.5)
- broadcast_splat = f32[127,125]{1,0} broadcast(constant_3), dimensions={}
- multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, broadcast_splat_shared)
- multiply = f32[127,125]{1,0} multiply(param_0, multiply_splat)
- constant_neg_inf = f32[] constant(-inf)
- reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=add_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest, FusionDecisionIsCapturedExplicitly) {
- const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
- arg_0 = f32[] parameter(0)
- arg_1 = f32[] parameter(1)
- ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
- param_0 = f32[127,125]{1,0} parameter(0)
- identity_f8 = f8e5m2[] parameter(1)
- identity = f32[] convert(identity_f8)
- reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
- broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
- ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- SoftmaxRewriterTriton softmax_rewriter_triton(device_info_,
- ShapeSizeBytesFunction());
- int unmatched = 0, matched = 0;
- for (HloInstruction* instruction :
- module->entry_computation()->MakeInstructionPostOrder()) {
- DiamondMatchingDecision decision =
- softmax_rewriter_triton.MatchesTritonCompatibleClosedReductionDiamond(
- instruction);
- if (std::holds_alternative<FusionDecision>(decision)) {
- std::string actual_decision =
- std::get<FusionDecision>(decision).Explain();
- EXPECT_THAT(
- actual_decision,
- AnyOf(
- HasSubstr("Root is not elementwise binary"),
- HasSubstr("identity is not a constant or a supported convert")));
- unmatched++;
- } else {
- matched++;
- }
- }
- EXPECT_EQ(unmatched, 6);
- EXPECT_EQ(matched, 0);
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongReductionDimAsParameter) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- p0 = f32[32]{0} parameter(0)
- p1 = f32[32,16]{1,0} parameter(1)
- c = f32[] constant(0)
-
- r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
- b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
- b1 = f32[32,16]{1,0} broadcast(p0), dimensions={0}
- add0 = f32[32,16]{1,0} add(b1, p1)
- ROOT add1 = f32[32,16]{1,0} add(add0, b0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- p0 = f32[16]{0} parameter(0)
- p1 = f32[32,16]{1,0} parameter(1)
- c = f32[] constant(0)
-
- r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
- b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
- b1 = f32[32,16]{1,0} broadcast(p0), dimensions={1}
- add0 = f32[32,16]{1,0} add(b1, p1)
- ROOT add1 = f32[32,16]{1,0} add(add0, b0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- FusesBinaryElementwiseIfIntermediateDiamondOpWithMultiDimTensorBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- p0 = f32[32,16]{1,0} parameter(0)
- p1 = f32[64,32,16]{2,1,0} parameter(1)
- c = f32[] constant(0)
-
- r0 = f32[64,32]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
- b0 = f32[64,32,16]{2,1,0} broadcast(r0), dimensions={0,1}
- b1 = f32[64,32,16]{2,1,0} broadcast(p0), dimensions={1,2}
- add0 = f32[64,32,16]{2,1,0} add(b1, p1)
- ROOT add1 = f32[64,32,16]{2,1,0} add(add0, b0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- FusesBinaryElementwiseIfIntermediateDiamondOpWithZeroDimTensorBroadcastAsParameter) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- parameter_0 = f32[] parameter(0)
- parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
- c = f32[] constant(0)
-
- reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
- broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
- broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={}
- add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
- ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- FusesBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongNonReductionDimensions) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- parameter_0 = f32[16] parameter(0)
- parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
- c = f32[] constant(0)
-
- reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
- broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
- broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={2}
- add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
- ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_TRUE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongBothBatchAndReductionDimensions) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- parameter_0 = f32[64] parameter(0)
- parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
- c = f32[] constant(0)
-
- reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
- broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
- broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={0}
- add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
- ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchAndReductionDimAsParameter) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- p0 = f32[8]{0} parameter(0)
- p1 = f32[32,8,16]{2,1,0} parameter(1)
- c = f32[] constant(0)
-
- r0 = f32[32,8]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
- b0 = f32[32,8,16]{2,1,0} broadcast(r0), dimensions={0,1}
- b1 = f32[32,8,16]{2,1,0} broadcast(p0), dimensions={1}
- add0 = f32[32,8,16]{2,1,0} add(b1, p1)
- ROOT add1 = f32[32,8,16]{2,1,0} add(add0, b0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithPartialBroadcastToBatchDim) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- p0 = f32[16,64]{1,0} parameter(0)
- p1 = f32[8,16,32,64]{3,2,1,0} parameter(1)
- c = f32[] constant(0)
-
- r0 = f32[8,16,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
- b0 = f32[8,16,32,64]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
- b1 = f32[8,16,32,64]{3,2,1,0} broadcast(p0), dimensions={1,3}
- add0 = f32[8,16,32,64]{3,2,1,0} add(b1, p1)
- ROOT add1 = f32[8,16,32,64]{3,2,1,0} add(add0, b0)
-}
-)";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
- SoftmaxRewriterTritonTest,
- DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithMultiDimBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length)
- const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
- y = f32[] parameter(1)
- x = f32[] parameter(0)
- ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
- p0 = f32[32,16]{1,0} parameter(0)
- p1 = f32[128,64,32,16]{3,2,1,0} parameter(1)
- c = f32[] constant(0)
-
- r0 = f32[128,64,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
- b0 = f32[128,64,32,16]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
- b1 = f32[128,64,32,16]{3,2,1,0} broadcast(p0), dimensions={2,3}
- add0 = f32[128,64,32,16]{3,2,1,0} add(b1, p1)
- ROOT add1 = f32[128,64,32,16]{3,2,1,0} add(add0, b0)
-})";
- auto module = ParseAndReturnVerifiedModule(hlo_string).value();
- EXPECT_FALSE(
- SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-} // anonymous namespace
-} // namespace gpu
-} // namespace xla
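The deleted tests above all share one harness shape: parse an HLO string into a verified module, run the rewriter through the `SoftmaxRewriterTritonMatchAndRewrite` helper, re-verify the module, and match the new root against a fusion carrying a block-level fusion config. A minimal sketch of that shared pattern, assuming the helper and the `HasBlockLevelFusionConfig` predicate defined earlier in the deleted file (the test name and HLO below are illustrative only, not part of the patch):

TEST_F(SoftmaxRewriterTritonTest, SketchOfSharedHarness) {
  const std::string hlo_string = R"(
HloModule softmax
max_computation {
  arg_0 = f32[] parameter(0)
  arg_1 = f32[] parameter(1)
  ROOT maximum = f32[] maximum(arg_0, arg_1)
}
ENTRY main {
  param_0 = f32[127,125]{1,0} parameter(0)
  constant_neg_inf = f32[] constant(-inf)
  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
})";
  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
  // The helper returns true iff the softmax diamond was rewritten into a fusion.
  EXPECT_TRUE(
      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
  // The rewritten module must still pass verification, and its root must be a
  // fusion tagged with a block-level fusion config.
  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(
          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
}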
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc
deleted file mode 100644
index 7e54ea5..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc
+++ /dev/null
@@ -1,219 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_annotator.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-bool IsOnlyRootNonDefaultStream(HloComputation* computation) {
- HloInstruction* root = computation->root_instruction();
- auto root_gpu_config = root->backend_config<GpuBackendConfig>();
- if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple) {
- return false;
- }
- int64_t root_stream_id = root_gpu_config->operation_queue_id();
- VLOG(2) << "Found fusion computation's root stream id to be "
- << root_stream_id;
- if (root_stream_id == Thunk::kDefaultExecutionStreamId.value()) {
- return false;
- }
- for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
- if (instr == root) {
- continue;
- }
- int64_t instr_stream_id =
- instr->backend_config<GpuBackendConfig>()->operation_queue_id();
- if (instr_stream_id != Thunk::kDefaultExecutionStreamId.value() &&
- instr_stream_id != root_stream_id) {
- return false;
- }
- }
- return true;
-}
-
-absl::StatusOr<bool> AnnotateStreamAttributesForInstruction(
- HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
- if (instr->called_computations().size() != 1) {
- return false;
- }
- HloComputation* called_comp = instr->called_computations()[0];
- int64_t stream_id = instr_gpu_config.operation_queue_id();
-
- if (!IsOnlyRootNonDefaultStream(called_comp) ||
- stream_id != Thunk::kDefaultExecutionStreamId.value()) {
- return false;
- }
-
- auto comp_root_gpu_config =
- called_comp->root_instruction()->backend_config<GpuBackendConfig>();
-
- instr_gpu_config.set_operation_queue_id(
- comp_root_gpu_config->operation_queue_id());
- *instr_gpu_config.mutable_wait_on_operation_queues() =
- comp_root_gpu_config->wait_on_operation_queues();
- TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
- return true;
-}
-
-absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
- HloInstruction* instr, int64_t channel_id,
- GpuBackendConfig& instr_gpu_config) {
- // Do nothing if copy-start has already been annotated
- if (instr_gpu_config.operation_queue_id() !=
- Thunk::kDefaultExecutionStreamId.value()) {
- return false;
- }
- instr_gpu_config.set_operation_queue_id(channel_id);
- TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
- VLOG(3) << "Add copy-start's backend config: " << channel_id;
- return true;
-}
-
-absl::StatusOr<bool> WrapIntoFusionAndAnnotateStreamAttributes(
- HloInstruction* instruction, int64_t channel_id,
- GpuBackendConfig& instr_gpu_config) {
- auto* computation = instruction->parent();
- auto* module = computation->parent();
- auto* fusion_instruction =
- computation->AddInstruction(HloInstruction::CreateFusion(
- instruction->shape(), ChooseFusionKind(*instruction, *instruction),
- instruction));
- const absl::string_view wrapped_opcode =
- HloOpcodeString(instruction->opcode());
- module->SetAndUniquifyInstrName(fusion_instruction,
- absl::StrCat("wrapped_", wrapped_opcode));
- module->SetAndUniquifyComputationName(
- fusion_instruction->fused_instructions_computation(),
- absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
- if (module->has_schedule()) {
- module->schedule().replace_instruction(computation, instruction,
- fusion_instruction);
- }
- TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
- TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
- TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
- TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
-
- instr_gpu_config.set_operation_queue_id(channel_id);
- TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config));
- VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction "
- << instruction->ToString();
- VLOG(3) << " Fusion wrapper: " << fusion_instruction->ToString();
- return true;
-}
-
-absl::StatusOr<bool> AnnotateStreamAttributesForUsers(
- HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
- bool changed = false;
- int64_t stream_id = instr_gpu_config.operation_queue_id();
- if (stream_id == Thunk::kDefaultExecutionStreamId.value()) {
- return changed;
- }
- std::vector<HloInstruction*> all_consumers;
- for (auto user : instr->users()) {
- if (user->opcode() == HloOpcode::kGetTupleElement) {
- user = user->users()[0];
- }
- all_consumers.push_back(user);
- }
-
- for (auto user : all_consumers) {
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- user->backend_config<GpuBackendConfig>());
- auto it = absl::c_find(gpu_config.wait_on_operation_queues(), stream_id);
- if (it == gpu_config.wait_on_operation_queues().end() &&
- gpu_config.operation_queue_id() != stream_id) {
- gpu_config.mutable_wait_on_operation_queues()->Add(stream_id);
- TF_RETURN_IF_ERROR(user->set_backend_config(gpu_config));
- changed = true;
- }
- }
-
- return changed;
-}
-} // namespace
-
-absl::StatusOr<bool> StreamAttributeAnnotator::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_VLOG_LINES(
- 5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString());
- bool changed = false;
- int64_t channel_id = hlo_query::NextChannelId(*module);
- for (const HloComputation* comp :
- module->MakeComputationPostOrder(execution_threads)) {
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
- if (!instr_gpu_config.ok()) {
- continue;
- }
-      // For fusion instructions, only annotate when the fusion root is a
-      // single instruction running on a non-default stream.
- if (instr->opcode() == HloOpcode::kFusion) {
- TF_ASSIGN_OR_RETURN(bool comp_result,
- AnnotateStreamAttributesForInstruction(
- instr, instr_gpu_config.value()));
- changed |= comp_result;
- } else if (instr->opcode() == HloOpcode::kCopyStart) {
- TF_ASSIGN_OR_RETURN(bool comp_result,
- AnnotateStreamAttributesForCopyStart(
- instr, channel_id, instr_gpu_config.value()));
- changed |= comp_result;
- continue;
- } else if (comp->IsAsyncComputation() &&
- (instr->opcode() == HloOpcode::kDynamicSlice ||
- instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
- TF_ASSIGN_OR_RETURN(bool comp_result,
- WrapIntoFusionAndAnnotateStreamAttributes(
- instr, channel_id, instr_gpu_config.value()));
- changed |= comp_result;
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(
- bool user_result,
- AnnotateStreamAttributesForUsers(instr, instr_gpu_config.value()));
- changed |= user_result;
- }
- }
- XLA_VLOG_LINES(
- 5, "StreamAttributeAnnotator::Run(), after:\n" + module->ToString());
- return changed;
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/stream_attribute_annotator.h
deleted file mode 100644
index 8a0284a..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass checks whether:
-// 1. any instruction that consumes data from other compute streams is
-//    missing the "wait_on_operation_queues" attribute;
-// 2. any fusion instruction has a fusion root on a non-default stream.
-// It annotates the corresponding instruction with the correct attribute in
-// GpuBackendConfig. Instructions annotated with operation_queue_id > 0
-// will be wrapped with AsyncInstruction and split into AsyncStart and
-// AsyncDone in the StreamAttributeAsyncWrapper pass.
-// We also check whether any user of a non-default-stream instruction is
-// missing the correct "wait_on_operation_queues" attribute, and if so set
-// it to the producer's operation_queue_id.
-// "wait_on_operation_queues" will need to be used by the emitter to emit
-// the correct WaitForStreams thunk.
-
-class StreamAttributeAnnotator : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "stream-attribute-annotator";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_
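To make the pass description in the deleted header above concrete: when a producer carries a non-default `operation_queue_id`, the annotator adds that queue ID to each consumer's `wait_on_operation_queues` list. A minimal sketch of that behavior, mirroring the deleted `stream_attribute_annotator_test.cc` below (the module text, instruction names, and queue ID here are illustrative, not taken from the patch):

TEST_F(StreamAttributeAnnotatorTest, SketchProducerQueueIsPropagatedToUsers) {
  constexpr absl::string_view kHloString = R"(
  HloModule Sketch

  ENTRY entry {
    p0 = f32[1] parameter(0)
    p1 = f32[1] parameter(1)
    add = f32[1] add(p0, p1), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
    ROOT neg = f32[1] negate(add)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));
  StreamAttributeAnnotator attr_annotator;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, attr_annotator.Run(module.get()));
  EXPECT_TRUE(changed);
  // The user of `add` now waits on queue 1, as described in the pass comment.
  const HloInstruction* neg = FindInstruction(module.get(), "neg");
  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
                          neg->backend_config<GpuBackendConfig>());
  EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
}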
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc
deleted file mode 100644
index 17d9b2f..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc
+++ /dev/null
@@ -1,287 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_annotator.h"
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using StreamAttributeAnnotatorTest = HloTestBase;
-
-TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) {
- constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithAsync
-
- ENTRY entry {
- p1_32 = f32[1] parameter(0)
- p2_32 = f32[1] parameter(1)
- add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
- exp_32 = f32[1] exponential(add_32)
-
- neg32 = f32[1] negate(add_32)
- ROOT add_out_32 = f32[1] add(neg32, exp_32)
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- StreamAttributeAnnotator attr_annotator;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
- EXPECT_TRUE(changed);
-
- const HloInstruction* add = FindInstruction(module.get(), "add_32");
- for (auto user : add->users()) {
- // Every user should have an annotation.
- EXPECT_TRUE(user->has_backend_config());
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- user->backend_config<GpuBackendConfig>());
- EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
- }
-}
-
-TEST_F(StreamAttributeAnnotatorTest, MultipleStreamsAreCombined) {
- constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithAsync
-
- ENTRY entry {
- p1_32 = f32[1] parameter(0)
- p2_32 = f32[1] parameter(1)
- add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
- exp_32 = f32[1] exponential(p2_32), backend_config={"operation_queue_id":"2", "wait_on_operation_queues":[]}
-
- ROOT add_out_32 = f32[1] add(add_32, exp_32)
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- StreamAttributeAnnotator attr_annotator;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
- EXPECT_TRUE(changed);
-
- const HloInstruction* root = module->entry_computation()->root_instruction();
- // Root should wait on 2 streams.
- EXPECT_TRUE(root->has_backend_config());
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- root->backend_config<GpuBackendConfig>());
- std::vector<int64_t> expected_stream_ids = {1, 2};
- for (auto id : expected_stream_ids) {
- auto it = absl::c_find(gpu_config.wait_on_operation_queues(), id);
- EXPECT_NE(it, gpu_config.wait_on_operation_queues().end());
- }
-}
-
-TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) {
- constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithAsync
-
- ENTRY entry {
- p1_32 = f32[16,32] parameter(0)
- p2_32 = f32[32,16] parameter(1)
-
- custom-call.3 = (f32[16,16], s8[1028]{0}) custom-call(p1_32, p2_32), custom_call_target="__cublas$gemm", backend_config={"operation_queue_id":"1","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","grad_x":false,"grad_y":false}}
- get-tuple-element.24 = f32[16,16] get-tuple-element(custom-call.3), index=0
-
- exp_32 = f32[16,16] exponential(get-tuple-element.24)
-
- ROOT neg32 = f32[16,16] negate(exp_32)
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- StreamAttributeAnnotator attr_annotator;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
- EXPECT_TRUE(changed);
-
- const HloInstruction* exp = FindInstruction(module.get(), "exp_32");
- EXPECT_TRUE(exp->has_backend_config());
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- exp->backend_config<GpuBackendConfig>());
- EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
-}
-
-TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) {
- constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithFusion
-
- fused_computation.1 {
- fusion_p0_32 = f32[16,16] parameter(0)
- fusion_p2_32 = f32[16,16] parameter(1)
- ROOT add = f32[16,16] add(fusion_p0_32, fusion_p2_32), backend_config={"operation_queue_id":"1","wait_on_operation_queues":[]}
- }
-
- ENTRY entry {
- p1_32 = f32[16,16] parameter(0)
- p2_32 = f32[16,16] parameter(1)
- ROOT fusion.1 = f32[16,16] fusion(p1_32, p2_32), kind=kLoop, calls=fused_computation.1
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- StreamAttributeAnnotator attr_annotator;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
- EXPECT_TRUE(changed);
-
- const HloInstruction* fusion = FindInstruction(module.get(), "fusion.1");
- EXPECT_TRUE(fusion->has_backend_config());
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- fusion->backend_config<GpuBackendConfig>());
- EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-}
-
-TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
- constexpr absl::string_view kHloString = R"(
- HloModule offloading
- ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] {
- %param_1 = f32[1024]{0} parameter(1)
- %param_0 = f32[1024]{0} parameter(0)
- %res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1)
- %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3)
- %res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3)
- %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4)
- %res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4)
- %copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start)
- %res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5)
- %copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2)
- %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2)
- %res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6)
- %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done)
- %res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5)
- %copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3)
- %res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3)
- %copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1)
- %res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1)
- ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10)
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- StreamAttributeAnnotator attr_annotator;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
- EXPECT_TRUE(changed);
-
- for (std::string i : {"", ".1", ".2", ".3"}) {
- const HloInstruction* cp_start =
- FindInstruction(module.get(), "copy-start" + i);
- EXPECT_TRUE(cp_start->has_backend_config());
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- cp_start->backend_config<GpuBackendConfig>());
- EXPECT_EQ(gpu_config.operation_queue_id(), 1);
- }
-}
-
-TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) {
- constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithAsyncDynamicUpdateSlice
-
- ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] {
- param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
- param_1 = f32[1,128,128]{2,1,0} parameter(1)
- izero = s32[] constant(0)
- dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[])
- dynamic-update-slice-start(param_0, param_1, izero, izero, izero)
- ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)}
- dynamic-update-slice-done(dynamic-update-slice-start.2)
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- StreamAttributeAnnotator().Run(module.get()));
- EXPECT_TRUE(changed);
-
- // Check that the dynamic-update-slice instruction is wrapped in a fusion
- // and the fusion is annotated with the correct operation_queue_id.
- const HloInstruction* dus =
- FindInstruction(module.get(), HloOpcode::kDynamicUpdateSlice);
- const HloComputation* computation = dus->parent();
- EXPECT_TRUE(computation->IsFusionComputation());
- const HloInstruction* fusion = computation->FusionInstruction();
- EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
- EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
-
- EXPECT_TRUE(fusion->has_backend_config());
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- fusion->backend_config<GpuBackendConfig>());
- EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-}
-
-TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) {
- constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithAsyncDynamicSlice
-
- ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] {
- param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
- izero = s32[] constant(0)
- dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[])
- dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128}
- ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0}
- dynamic-slice-done(dynamic-slice-start.2)
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- StreamAttributeAnnotator().Run(module.get()));
- EXPECT_TRUE(changed);
-
- // Check that the dynamic-slice instruction is wrapped in a fusion
- // and the fusion is annotated with the correct operation_queue_id.
- const HloInstruction* ds =
- FindInstruction(module.get(), HloOpcode::kDynamicSlice);
- const HloComputation* computation = ds->parent();
- EXPECT_TRUE(computation->IsFusionComputation());
- const HloInstruction* fusion = computation->FusionInstruction();
- EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
- EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
-
- EXPECT_TRUE(fusion->has_backend_config());
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- fusion->backend_config<GpuBackendConfig>());
- EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-}
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc b/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc
deleted file mode 100644
index 822c647..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_async_wrapper.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-namespace {
-static absl::StatusOr<bool> AsynchronizeInstruction(HloInstruction* instr) {
- auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
- if (!instr_gpu_config.ok() || instr_gpu_config->operation_queue_id() ==
- Thunk::kDefaultExecutionStreamId.value()) {
- return false;
- }
- HloComputation* computation = instr->parent();
- TF_ASSIGN_OR_RETURN(
- HloInstruction * done,
- computation->CreateAsyncInstructions(
- instr, {}, StreamAttributeAsyncWrapper::kParallelExecutionThread,
- /*replace=*/true));
- TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
- done->backend_config<GpuBackendConfig>());
-  // Set force_earliest_schedule of the done op to false so it can be
-  // scheduled far apart from the start op.
- gpu_config.set_force_earliest_schedule(false);
- TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config));
- VLOG(5) << "Created async instruction: " << done->ToString();
- return true;
-}
-} // namespace
-
-absl::StatusOr<bool> StreamAttributeAsyncWrapper::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- XLA_VLOG_LINES(
- 2, "StreamAttributeAsyncWrapper::Run(), before:\n" + module->ToString());
- bool changed = false;
- for (const HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* instr : comp->instructions()) {
- TF_ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr));
- changed |= result;
- }
- }
- XLA_VLOG_LINES(
- 2, "StreamAttributeAsyncWrapper::Run(), after:\n" + module->ToString());
- return changed;
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h b/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h
deleted file mode 100644
index 95fe7bb..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
-#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass finds the instructions that the StreamAttributeAnnotator pass
-// has annotated with a non-default stream id in their backend configs and
-// wraps them in async-start/async-done pairs to achieve asynchronous
-// execution.
-class StreamAttributeAsyncWrapper : public HloModulePass {
- public:
- inline static constexpr char kParallelExecutionThread[] = "parallel";
-
- absl::string_view name() const override {
- return "async-stream-attribute-wrapper";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
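The header comment above describes this wrapper as the consumer of the stream-id annotations produced by StreamAttributeAnnotator. Below is a minimal sketch, not part of this diff, that mirrors the deleted tests and shows the intended annotate-then-wrap order; it assumes the class names and include paths of the files deleted here, which may have been relocated elsewhere (for example under xla/service/gpu/transforms/) by this change.

#include <memory>

#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/stream_attribute_annotator.h"
#include "xla/service/gpu/stream_attribute_async_wrapper.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
namespace {

using StreamAttributePipelineSketch = HloTestBase;

TEST_F(StreamAttributePipelineSketch, AnnotateThenWrap) {
  // Same shape as the annotator's AllUsersAreAnnotated test: `add` is pinned
  // to operation queue 1 via its backend config.
  constexpr absl::string_view kHloString = R"(
  HloModule ModuleWithAsync

  ENTRY entry {
    p0 = f32[1] parameter(0)
    p1 = f32[1] parameter(1)
    add = f32[1] add(p0, p1), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
    ROOT exp = f32[1] exponential(add)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  // Step 1: propagate the stream attribute from `add` to its users.
  TF_ASSERT_OK_AND_ASSIGN(bool annotated,
                          StreamAttributeAnnotator().Run(module.get()));
  EXPECT_TRUE(annotated);

  // Step 2: wrap the ops that carry a non-default operation_queue_id into
  // async-start/async-done pairs on the "parallel" execution thread.
  TF_ASSERT_OK_AND_ASSIGN(bool wrapped,
                          StreamAttributeAsyncWrapper().Run(module.get()));
  EXPECT_TRUE(wrapped);
}

}  // namespace
}  // namespace xla::gpu

The annotator alone only updates backend configs; the actual stream overlap comes from the wrapper moving the annotated ops onto the parallel execution thread.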
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc b/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc
deleted file mode 100644
index 8b3dcb2..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_async_wrapper.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using StreamAttributeAsyncWrapperTest = HloTestBase;
-
-TEST_F(StreamAttributeAsyncWrapperTest, NonDefaultOpIsWrapped) {
- constexpr absl::string_view kHloString = R"(
- HloModule ModuleWithAsync
-
- ENTRY entry {
- p1_32 = f32[1] parameter(0)
- p2_32 = f32[1] parameter(1)
- add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[], "force_earliest_schedule":true}
- ROOT exp_32 = f32[1] exponential(add_32), backend_config={"operation_queue_id":"0", "wait_on_operation_queues":[1]}
- }
- )";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(kHloString));
-
- StreamAttributeAsyncWrapper async_wrapper;
- bool changed;
- TF_ASSERT_OK_AND_ASSIGN(changed, async_wrapper.Run(module.get()));
- EXPECT_TRUE(changed);
- const HloInstruction* producer =
- module->entry_computation()->root_instruction()->operand(0);
- EXPECT_EQ(producer->opcode(), HloOpcode::kAsyncDone);
- // Verify that the force_earliest_schedule is set to false for the done op.
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig done_gpu_config,
- producer->backend_config<GpuBackendConfig>());
- EXPECT_EQ(done_gpu_config.force_earliest_schedule(), false);
-
- const HloInstruction* producer_start = producer->operand(0);
- EXPECT_EQ(producer_start->opcode(), HloOpcode::kAsyncStart);
-
- const xla::HloAsyncInstruction* async =
- Cast<HloAsyncInstruction>(producer_start);
- EXPECT_EQ(async->async_wrapped_opcode(), HloOpcode::kAdd);
- // Verify that the backend config is kept intact
- TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
- async->backend_config<GpuBackendConfig>());
- EXPECT_EQ(gpu_config.operation_queue_id(), 1);
- EXPECT_EQ(gpu_config.force_earliest_schedule(), true);
- EXPECT_EQ(async->async_execution_thread(), "parallel");
-}
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD
index 487b6d9..e18dd57 100644
--- a/third_party/xla/xla/service/gpu/tests/BUILD
+++ b/third_party/xla/xla/service/gpu/tests/BUILD
@@ -162,44 +162,6 @@
],
)
-xla_cc_test(
- name = "gpu_reduce_scatter_creator_test",
- srcs = ["gpu_reduce_scatter_creator_test.cc"],
- deps = [
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_module_config",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service/gpu:gpu_reduce_scatter_creator",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/log",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings:string_view",
- "@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:statusor",
- ],
-)
-
-xla_cc_test(
- name = "gpu_all_gather_optimizer_test",
- srcs = ["gpu_all_gather_optimizer_test.cc"],
- deps = [
- "//xla:util",
- "//xla/hlo/ir:hlo",
- "//xla/service:hlo_module_config",
- "//xla/service/gpu:gpu_all_gather_optimizer",
- "//xla/tests:hlo_test_base",
- "//xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- ],
-)
-
xla_test(
name = "gpu_spmd_e2e_compile_test",
size = "small",
@@ -221,65 +183,6 @@
)
xla_test(
- name = "gemm_rewrite_test",
- srcs = ["gemm_rewrite_test.cc"],
- backends = ["gpu"],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
- "TENSORFLOW_USE_ROCM=1",
- ]),
- deps = [
- ":gpu_codegen_test",
- "//xla:error_spec",
- "//xla:test",
- "//xla:xla_proto_cc",
- "//xla/hlo/ir:hlo",
- "//xla/service:buffer_assignment",
- "//xla/service:executable",
- "//xla/service:hlo_module_config",
- "//xla/service:hlo_pass",
- "//xla/service:pattern_matcher",
- "//xla/service:pattern_matcher_gmock",
- "//xla/service/gpu:gemm_rewriter",
- "//xla/service/gpu:gpu_executable",
- "//xla/stream_executor:device_description",
- "//xla/stream_executor:device_memory_allocator",
- "//xla/stream_executor:stream_executor_memory_allocator",
- "//xla/tests:filecheck",
- "//xla/tests:verified_hlo_module",
- "//xla/tsl/lib/core:status_test_util",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/status:statusor",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test_main",
- ] + if_cuda_is_configured([
- "@local_config_cuda//cuda:cuda_headers",
- ]) + if_rocm_is_configured([
- "@local_config_rocm//rocm:rocm_headers",
- ]),
-)
-
-xla_test(
- name = "gemm_broadcast_folding_rewrite_test",
- srcs = ["gemm_broadcast_folding_rewrite_test.cc"],
- backends = ["gpu"],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
- "TENSORFLOW_USE_ROCM=1",
- ]),
- deps = [
- ":gpu_codegen_test",
- "//xla:error_spec",
- "//xla/hlo/ir:hlo",
- "//xla/service/gpu:gemm_broadcast_folding_rewriter",
- "//xla/service/gpu:gemm_rewriter",
- "@local_tsl//tsl/platform:statusor",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-xla_test(
name = "gpu_too_many_blocks_test",
srcs = [
"gpu_too_many_blocks_test.cc",
@@ -297,51 +200,6 @@
],
)
-xla_cc_test(
- name = "reduction_degenerate_dim_remover_test",
- srcs = [
- "reduction_degenerate_dim_remover_test.cc",
- ],
- deps = [
- "//xla/service/gpu:reduction_degenerate_dim_remover",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-xla_test(
- name = "reduction_layout_normalizer_test",
- srcs = [
- "reduction_layout_normalizer_test.cc",
- ],
- backends = ["gpu"],
- deps = [
- "//xla:error_spec",
- "//xla/service/gpu:reduction_layout_normalizer",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
-xla_cc_test(
- name = "tree_reduction_rewriter_test",
- srcs = [
- "tree_reduction_rewriter_test.cc",
- ],
- deps = [
- "//xla/service/gpu:tree_reduction_rewriter",
- "//xla/stream_executor:device_description",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
xla_test(
name = "swap_conv_operands_test",
srcs = [
@@ -375,20 +233,6 @@
],
)
-xla_cc_test(
- name = "reduction_dimension_grouper_test",
- srcs = [
- "reduction_dimension_grouper_test.cc",
- ],
- deps = [
- "//xla/service/gpu:reduction_dimension_grouper",
- "//xla/tests:hlo_test_base",
- "@com_google_absl//absl/strings:string_view",
- "@local_tsl//tsl/platform:test",
- "@local_tsl//tsl/platform:test_main",
- ],
-)
-
xla_test(
name = "parallel_reduction_test",
srcs = [
@@ -640,7 +484,7 @@
"//xla/hlo/ir:hlo",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu:gpu_fusible",
- "//xla/service/gpu:instruction_fusion",
+ "//xla/service/gpu/transforms:instruction_fusion",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:test_main",
],
@@ -655,10 +499,10 @@
"//xla:shape_util",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_pass_pipeline",
- "//xla/service/gpu:fusion_merger",
"//xla/service/gpu:gpu_device_info_for_tests",
- "//xla/service/gpu:instruction_fusion",
- "//xla/service/gpu:multi_output_fusion",
+ "//xla/service/gpu/transforms:fusion_merger",
+ "//xla/service/gpu/transforms:instruction_fusion",
+ "//xla/service/gpu/transforms:multi_output_fusion",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:test_main",
@@ -857,8 +701,7 @@
# name = "xla-opt",
# srcs = ["xla-opt.cc"],
# deps = [
-# "//xla/service/gpu/fusions/triton:prevent_mmav3_loop_unrolling",
-# "//xla/service/gpu/fusions/triton:sparse_extensions",
+# "//xla/service/gpu/fusions/triton:passes",
# "@llvm-project//mlir:AllExtensions",
# "@llvm-project//mlir:MlirOptLib",
# "@triton//:AllPassesAndDialects",
@@ -972,7 +815,7 @@
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
- "//xla/service/gpu:gpu_sort_rewriter",
+ "//xla/service/gpu/transforms:sort_rewriter",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
diff --git a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
deleted file mode 100644
index 7e98e97..0000000
--- a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
+++ /dev/null
@@ -1,229 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest {
- protected:
- const auto& GpuComputeComp() {
- return backend()
- .default_stream_executor()
- ->GetDeviceDescription()
- .gpu_compute_capability();
- }
-
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // These tests exercise the cuBLAS rewriter, so make sure cuBLAS is used
-    // for them.
- debug_options.set_xla_gpu_enable_triton_gemm(false);
- debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
- return debug_options;
- }
-};
-
-TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteRhs) {
- const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
- x = f32[3,2,2]{2,1,0} parameter(0)
- y = f32[2,2]{1,0} parameter(1)
- y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
- ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,2], {{.*}}: f32[2,2]) -> f32[3,2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["2"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":["0"]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteLhs) {
- const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
- x = f32[2,2]{1,0} parameter(0)
- y = f32[3,2,2]{2,1,0} parameter(1)
- x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
- ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[3,2,2]) -> f32[3,2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:    custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
-; CHECK:    backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest,
- BroadcastedStridedRewriteRhsPassChanged) {
- const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
- x = f32[3,2,2]{2,1,0} parameter(0)
- y = f32[2,2]{1,0} parameter(1)
- y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
- ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- // Use GemmRewriter to generate cublasGemm call.
- GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- this->RunHloPass(&gemm_rewriter, module.get()));
- EXPECT_TRUE(changed);
- GemmBroadcastFoldingRewriter pass;
- TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest,
- BroadcastedStridedRewriteLhsPassChanged) {
- const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
- x = f32[2,2]{1,0} parameter(0)
- y = f32[3,2,2]{2,1,0} parameter(1)
- x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
- ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- // Use GemmRewriter to generate cublasGemm call.
- GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- this->RunHloPass(&gemm_rewriter, module.get()));
- EXPECT_TRUE(changed);
- GemmBroadcastFoldingRewriter pass;
- TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest, LHSBatchDimNonZero) {
- const char* hlo_text = R"(
-HloModule LHSBatchDimNonZero
-
-ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
- %Arg_1 = f32[4,3]{1,0} parameter(0)
- %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
- %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
- ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[7,4,3]{2,1,0} %broadcast.22, f32[4,7,3]{2,1,0} %Arg_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
-}
-)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- // Use GemmRewriter to generate cublasGemm call.
- GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- this->RunHloPass(&gemm_rewriter, module.get()));
- EXPECT_TRUE(changed);
- GemmBroadcastFoldingRewriter pass;
- TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest, RHSBatchDimNonZero) {
- const char* hlo_text = R"(
-HloModule RHSBatchDimNonZero
-
-ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
- %Arg_1 = f32[4,3]{1,0} parameter(0)
- %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
- %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
- ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[4,7,3]{2,1,0} %Arg_2, f32[7,4,3]{2,1,0} %broadcast.22), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2}
-}
-)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- this->RunHloPass(&gemm_rewriter, module.get()));
- EXPECT_TRUE(changed);
- GemmBroadcastFoldingRewriter pass;
- TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
- EXPECT_FALSE(changed);
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
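The deleted broadcast-folding tests above rely on an ordering that is easy to miss: GemmRewriter must run first so the dot becomes a cuBLAS custom call, and only then can GemmBroadcastFoldingRewriter fold the broadcast into it. A minimal sketch of that ordering as a pass pipeline, assuming the include paths of the deleted test (the passes may have been relocated by this change) and a hypothetical helper name:

#include <cstdint>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h"
#include "xla/service/gpu/gemm_rewriter.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Hypothetical helper (illustration only): runs the two rewrites in the order
// the deleted tests use, GemmRewriter first and then the broadcast folding
// pass. Returns true if either pass changed the module.
absl::StatusOr<bool> RewriteGemmsThenFoldBroadcasts(
    HloModule* module,
    const stream_executor::GpuComputeCapability& compute_capability,
    int32_t toolkit_version) {
  HloPassPipeline pipeline("gemm-rewrite-then-broadcast-fold");
  // toolkit_version is e.g. 12040, as in the deleted tests.
  pipeline.AddPass<GemmRewriter>(compute_capability, toolkit_version);
  pipeline.AddPass<GemmBroadcastFoldingRewriter>();
  return pipeline.Run(module);
}

}  // namespace xla::gpu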
diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
deleted file mode 100644
index f412f1f..0000000
--- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ /dev/null
@@ -1,8313 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <array>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/gpu_executable.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/test.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#elif TENSORFLOW_USE_ROCM
-#include "rocm/rocm_config.h"
-#endif
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-namespace m = ::xla::match;
-
-class GemmRewriteTest : public GpuCodegenTest {
- const auto& device_desc() {
- return backend().default_stream_executor()->GetDeviceDescription();
- }
-
- protected:
- const se::GpuComputeCapability& Capability() {
- return device_desc().gpu_compute_capability();
- }
-
- int32_t GetToolkitVersion() const {
-#if GOOGLE_CUDA
- return CUDA_VERSION;
-#elif TENSORFLOW_USE_ROCM
- return TF_ROCM_VERSION;
-#endif
- return 0;
- }
-
- bool IsCuda() {
- return std::holds_alternative<se::CudaComputeCapability>(Capability());
- }
-
- se::GpuComputeCapability CudaHopperOrRocmMI300() {
- if (IsCuda()) {
- return se::CudaComputeCapability::Hopper();
- } else {
- return se::RocmComputeCapability{"gfx942"};
- }
- }
-
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // These tests exercise the cuBLAS rewriter, so make sure cuBLAS is used
-    // for them.
- debug_options.set_xla_gpu_enable_triton_gemm(false);
- debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
- return debug_options;
- }
-
- bool SkipGpuBlasLtTest() {
- return !IsCuda() &&
- !std::get<se::RocmComputeCapability>(Capability()).has_hipblaslt() &&
- GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
- }
-
- bool HasFp8Support() {
- if (IsCuda()) {
- return std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(8, 9);
- }
- return std::get<se::RocmComputeCapability>(Capability()).has_fp8_support();
- }
-
- bool HasCudaComputeCapability(const se::CudaComputeCapability& cc) {
- return IsCuda() &&
- std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(cc);
- }
-};
-
-TEST_F(GemmRewriteTest, CheckCustomCallTarget) {
- if (SkipGpuBlasLtTest()) {
- GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
- }
-
- const char* hlo_text = R"(
-HloModule SimpleGemm
-
-ENTRY AddDotsFunc {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
- DebugOptions debug_options = GetDebugOptionsForTest();
- if (debug_options.xla_gpu_enable_cublaslt()) {
- MatchOptimizedHlo(hlo_text,
- R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
- } else {
- MatchOptimizedHlo(hlo_text,
- R"(; CHECK: custom_call_target="__cublas$gemm")");
- }
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-TEST_F(GemmRewriteTest, TestBatchedAutotuning) {
- if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP()
- << "There is no autotuning starting with the Nvidia Ampere generation";
- }
-
- const char* hlo_text = R"(
-HloModule ComplexDotMultipleNonContracting
-
-ENTRY %test {
- %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
- %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
- ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
-}
-
-)";
-
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: selected_algorithm
- )");
-}
-#endif
-
-TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) {
- if (SkipGpuBlasLtTest()) {
- GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
- }
-
- const char* hlo_text = R"(
-HloModule SimpleGemm
-
-ENTRY AddDotsFunc {
- x = f32[128,128] parameter(0)
- y = f32[128,128] parameter(1)
- ROOT dot_a = f32[128,128] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  // Both the cublasLt and legacy cuBLAS code paths use the same tolerance.
-  ErrorSpec error_spec{1e-3, 1e-3};
-
- auto get_module = [&]() {
- HloModuleConfig config;
- DebugOptions debug_options = GetDebugOptionsForTest();
- debug_options.set_xla_gpu_exclude_nondeterministic_ops(true);
- config.set_debug_options(debug_options);
- return ParseAndReturnVerifiedModule(hlo_text, config);
- };
-
- se::StreamExecutorMemoryAllocator allocator(
- backend().default_stream_executor());
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> optimized_module,
- backend().compiler()->RunHloPasses(
- *get_module(), backend().default_stream_executor(), &allocator));
-
- absl::StatusOr<bool> filecheck_result =
- RunFileCheck(optimized_module->ToString(),
- R"(
-; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}"
- )");
- TF_ASSERT_OK(filecheck_result.status());
- EXPECT_TRUE(filecheck_result.value());
- EXPECT_TRUE(RunAndCompare(*get_module(), error_spec));
-}
-
-TEST_F(GemmRewriteTest, BF16GemmCodeGen) {
- const char* hlo_text = R"(
-HloModule bf16codegendgemm
-
-ENTRY bf16gemm {
- %parameter.1 = bf16[3]{0} parameter(0)
- %parameter.2 = bf16[3]{0} parameter(1)
- ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-}
- )";
-
- if (HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
- // The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
- // native BF16 multiply support.
- MatchOptimizedHlo(hlo_text, R"(
- ; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
- ; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
- ; CHECK: [[INSTR_2:%[^ ]+]] = bf16[3]{0} multiply([[P0]], [[P1]])
- ; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[INSTR_2]])
- ; CHECK: [[INSTR_4:%[^ ]+]] = f32[] constant(0)
- ; CHECK: [[INSTR_5:%[^ ]+]] = f32[] reduce([[INSTR_3]], [[INSTR_4]]), dimensions={0}, to_apply=[[INSTR_6:%[^ ]+]]
- ; CHECK: ROOT [[INSTR_7:%[^ ]+]] = bf16[] convert([[INSTR_5]])
- )");
- } else {
- MatchOptimizedHlo(hlo_text, R"(
- ; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
- ; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
- ; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
- ; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
- ; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
- ; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
- ; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
- ; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
- )");
- }
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-4}));
-}
-
-TEST_F(GemmRewriteTest, BF16Transpose) {
- const char* hlo_text = R"(
-HloModule broadcast
-
-ENTRY broadcast {
- p = bf16[9] parameter(0)
- ROOT out = bf16[1,9] broadcast(p), dimensions={1}
-}
-)";
-
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK: bf16[1,9]{1,0} bitcast
-; CHECK: bf16[1,9]{1,0} copy
-)");
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-// A test fixture class for tests which should have similar results with legacy
-// cublas and cublasLt
-class ParameterizedGemmRewriteTest
- : public GemmRewriteTest,
- public ::testing::WithParamInterface<bool> {
- public:
- ParameterizedGemmRewriteTest() {
- const bool kUsingCublasLt = GetParam();
- replacements_[kCustomCallTargetPlaceholder] =
- kUsingCublasLt ? "__cublas$lt$matmul" : "__cublas$gemm";
- }
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_cublaslt(GetParam());
- debug_options.set_xla_gpu_enable_triton_gemm(false);
- return debug_options;
- }
- void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
- bool print_operand_shape = false) {
- GemmRewriteTest::MatchOptimizedHlo(
- hlo, absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
- }
- absl::string_view CustomCallTarget() {
- return replacements_[kCustomCallTargetPlaceholder];
- }
-
- protected:
- void SetUp() override {
- if (SkipGpuBlasLtTest()) {
- GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
- }
- }
-
- protected:
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
-
- private:
- static constexpr const char* kCustomCallTargetPlaceholder{
- "<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"};
-};
-
-TEST_P(ParameterizedGemmRewriteTest, Simple) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, SimpleRewrite) {
- const char* hlo_text = R"(
-HloModule SimpleGemm
-
-ENTRY AddDotsFunc {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, MultipleContractingDims) {
- const char* hlo_text = R"(
-HloModule MultipleContractingCheckGemm
-
-ENTRY AddDotsFunc {
- x = f32[3,4,2] parameter(0)
- y = f32[3,4,5] parameter(1)
- ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, operand_precision={highest,highest}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-NOT: copy
-;
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,4,2], {{.*}}: f32[3,4,5]) -> f32[2,5] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1)
-; CHECK-DAG: [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]])
-; CHECK-DAG: [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]])
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[BITCAST0]], [[BITCAST1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, ArgTransposeFoldCheck) {
- const char* hlo_text = R"(
-HloModule ArgTransposeFoldGemm
-
-ENTRY AddDotsFunc {
- x = f32[3,2] parameter(0)
- y = f32[3,4] parameter(1)
- x_transposed = f32[2,3] transpose(x), dimensions={1, 0}
- ROOT dot_a = f32[2,4] dot(x_transposed, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["0"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchedArgRowColTransposeFoldCheck) {
- const char* hlo_text = R"(
-HloModule BatchedArgRowColTransposeFoldGemm
-
-ENTRY AddDotsFunc {
- x = f32[5,3,2] parameter(0)
- y = f32[5,3,4] parameter(1)
- x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1}
- ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,3,2], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":["0"]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchRowTransposeFoldCheck) {
- const char* hlo_text = R"(
-HloModule BatchRowTransposeFoldCheck
-
-ENTRY AddDotsFunc {
- x = f32[2,5,3] parameter(0)
- y = f32[5,3,4] parameter(1)
- x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2}
- ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,5,3], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["2"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":["1"]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) {
- const char* hlo_text = R"(
-HloModule BatchFromMinorDimTransposeDoesntFold
-
-ENTRY AddDotsFunc {
- x = f32[3,2,5] parameter(0)
- y = f32[5,3,4] parameter(1)
- x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0}
- ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,5], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]])
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[FUSION]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["2"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":["0"]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, LargeBatch) {
- const char* hlo_text = R"(
-HloModule BatchedArgRowColTransposeFoldGemm
-
-ENTRY AddDotsFunc {
- x = f32[20000,4,3,2] parameter(0)
- y = f32[20000,4,3,4] parameter(1)
- ROOT dot_a = f32[20000,4,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-}
-
-)";
-
- // Batch sizes larger than 2^16-1 are not supported by cublasLt. Ensure that
- // the custom_call_target is __cublas$gemm.
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[20000,4,3,2], {{.*}}: f32[20000,4,3,4]) -> f32[20000,4,2,4] {
-; CHECK: [[P0:%[^ ]+]] = f32[20000,4,3,2]{3,2,1,0} parameter(0)
-; CHECK: [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]])
-; CHECK: [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1)
-; CHECK: [[BC1:%[^ ]+]] = f32[80000,3,4]{2,1,0} bitcast([[P1]])
-; CHECK: [[GEMM:%[^ ]+]] = (f32[80000,2,4]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[BC0]], [[BC1]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":["0"]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK: }
-; CHECK: [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} get-tuple-element([[GEMM]]), index=0
-; CHECK: ROOT {{[^ ]+}} = f32[20000,4,2,4]{3,2,1,0} bitcast([[OUT]])
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, InstrTransposeFoldCheck) {
- const char* hlo_text = R"(
-HloModule InstrTransposeFoldGemm
-
-ENTRY AddDotsFunc {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[4,2] {
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P1]], [[P0]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["0"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutTransposed) {
- const char* hlo_text = R"(
-HloModule BatchedInstrLayoutCheck
-
-ENTRY AddDotsFunc {
- x = f32[5,2,3] parameter(0)
- y = f32[5,3,4] parameter(1)
- dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
- ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,5,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["2"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":["0"]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) {
- const char* hlo_text = R"(
-HloModule BatchedInstrLayoutBatchNotInMinorDim
-
-ENTRY AddDotsFunc {
- x = f32[5,2,3] parameter(0)
- y = f32[5,3,4] parameter(1)
- dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
- ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,4,5] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["2"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":["0"]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, AlphaSimpleRewrite) {
- const char* hlo_text = R"(
-HloModule AlphaSimpleRewrite
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- k = f32[] constant(3.0)
- k_broadcast = f32[2, 2] broadcast(k), dimensions={}
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, F64C64_CublasLtSupportTest) {
- // This test should fail if gemm rewriter does not correctly rewrite
- // F64/C64 dots to cublas-lt or legacy cublas calls
- {
- const char* hlo_text = R"(
-HloModule F64_rewrite
-
-ENTRY AddDotsFunc {
- x = f64[2,2] parameter(0)
- y = f64[2,2] parameter(1)
- k = f64[] constant(3.0)
- k_broadcast = f64[2, 2] broadcast(k), dimensions={}
- dot_a = f64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT dot_a_multiplied = f64[2, 2] multiply(dot_a, k_broadcast)
-}
-)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
- }
- {
- const char* hlo_text = R"(
-HloModule C64_rewrite
-
-ENTRY AddDotsFunc {
- x = c64[2,2] parameter(0)
- y = c64[2,2] parameter(1)
- k = c64[] constant((3.0, 3.0))
- k_broadcast = c64[2, 2] broadcast(k), dimensions={}
- dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
-}
-)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) {
- if (!IsCuda() && GetDebugOptionsForTest().xla_gpu_enable_cublaslt()) {
- GTEST_SKIP() << "TODO: Unsupported C64 gpublas-lt datatype on ROCM";
- }
- const char* hlo_text = R"(
-HloModule ComplexAlphaSimpleRewrite
-
-ENTRY AddDotsFunc {
- x = c64[2,2] parameter(0)
- y = c64[2,2] parameter(1)
- k = c64[] constant((3.0, 3.0))
- k_broadcast = c64[2, 2] broadcast(k), dimensions={}
- dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: c64[2,2], {{.*}}: c64[2,2]) -> c64[2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":3
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, AlphaMultipleUsersNoRewrite) {
- const char* hlo_text = R"(
-HloModule AlphaMultipleUsersNoRewrite
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- k = f32[] constant(3.0)
- k_broadcast = f32[2, 2] broadcast(k), dimensions={}
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
- ROOT out = f32[2,2] add(dot_a_multiplied, dot_a)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{[^ ]+}} = {{.*}} custom-call({{[^,]+}}, {{[^)]+}}),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, AlphaVectorNoRewrite) {
- const char* hlo_text = R"(
-HloModule AlphaVectorNoRewrite
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- alpha = f32[2] constant({1, 2})
- alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1}
- dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast)
-}
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BF16Gemm) {
- const char* hlo_text = R"(
-HloModule bf16gemm
-
-ENTRY bf16gemm {
- %parameter.1 = bf16[12,4]{1,0} parameter(0)
- %parameter.2 = bf16[4,8]{1,0} parameter(1)
- ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
- )";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
- )",
- /*print_operand_shape=*/true);
- } else {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BF16GemmStrided) {
- const char* hlo_text = R"(
-HloModule bf16gemm
-
-ENTRY bf16gemm {
- %parameter.1 = bf16[3,3,4] parameter(0)
- %parameter.2 = bf16[3,3,2] parameter(1)
- ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest}
-}
-
- )";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- MatchOptimizedHlo(hlo_text,
- R"(
- ; CHECK: {{.*}} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
- )",
- /*print_operand_shape=*/true);
- } else {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8Gemm) {
- const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
- %parameter.1 = s8[12,4]{1,0} parameter(0)
- %parameter.2 = s8[4,8]{1,0} parameter(1)
- ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
- )";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
- )",
- /*print_operand_shape=*/true);
- } else {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
- )",
- /*print_operand_shape=*/true);
- }
-}
-
-TEST_F(GemmRewriteTest, Int8GemmRankGreaterThanTwo) {
- if (!IsCuda()) {
- GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
- }
-
- const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY main.4 {
- Arg_0.1 = s8[1,8,2]{2,1,0} parameter(0)
- Arg_1.2 = s8[2,4]{1,0} parameter(1)
- ROOT dot.3 = s32[1,8,4]{2,1,0} dot(Arg_0.1, Arg_1.2),
- lhs_contracting_dims={2}, rhs_contracting_dims={0}
-}
- )";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %{{.*}}, s8[4,4]{0,1} %{{.*}}), custom_call_target="__cublas$gemm",
- )",
- /*print_operand_shape=*/true);
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoAlphaRewrite) {
- const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
- %parameter.1 = s8[12,4]{1,0} parameter(0)
- %parameter.2 = s8[4,8]{1,0} parameter(1)
- k = s32[] constant(2)
- k_broadcast = s32[12,8] broadcast(k), dimensions={}
- %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT dot_multiplied = s32[12,8] multiply(%dot.8, k_broadcast)
-}
- )";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
- )",
- /*print_operand_shape=*/true);
- } else {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
- )",
- /*print_operand_shape=*/true);
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoBetaRewrite) {
- const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
- %parameter.1 = s8[12,4]{1,0} parameter(0)
- %parameter.2 = s8[4,8]{1,0} parameter(1)
- bias = s32[12,8] parameter(2)
- %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = s32[12,8] add(%dot.8, bias)
-}
- )";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
- )",
- /*print_operand_shape=*/true);
- } else {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
- )",
- /*print_operand_shape=*/true);
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8GemmNotMultipleOfFour) {
- if (!IsCuda()) {
- GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
- }
-
- const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
- %parameter.1 = s8[13,4]{1,0} parameter(0)
- %parameter.2 = s8[4,9]{1,0} parameter(1)
- ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
- )";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
- )",
- /*print_operand_shape=*/true);
- } else {
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: {{.*}} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
- )",
- /*print_operand_shape=*/true);
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) {
- if (!IsCuda()) {
- GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
- }
-
- std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
- type_combinations = {{"s8", "s8", true},
- {"s32", "s32", true},
- {"bf16", "bf16", true},
- {"f16", "f16", true},
- {"f32", "f32", true},
- {"f64", "f64", true},
- {"c64", "c64", true},
- {"c128", "c128", true},
- // add mix type gemm
- {"s8", "s32", true},
- {"s8", "f32", true},
- {"f16", "f32", true},
- {"bf16", "f32", true}};
-
- if (!IsCuda() ||
- HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    // For compute capabilities before Ampere, we may do upcasting, so it
-    // would be impossible for this test to fail. That is why we only add
-    // these cases when the compute capability is at least Ampere.
- std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
- more_type_combinations = {
- {"s8", "bf16", false}, {"s8", "f16", false},
- {"s8", "f64", false}, {"s8", "c64", false},
- {"s8", "c128", false},
-
- {"s32", "f32", false}, {"s32", "f64", false},
- {"s32", "c64", false}, {"s32", "c128", false},
-
- {"f16", "bf16", false}, {"f16", "f64", false},
- {"f16", "c64", false}, {"f16", "c128", false},
-
- {"bf16", "f16", false}, {"bf16", "f64", false},
- {"bf16", "c64", false}, {"bf16", "c128", false},
-
- {"f32", "f64", false}, {"f32", "c64", false},
- {"f32", "c128", false},
-
- {"f64", "c64", false}, {"f64", "c128", false},
- };
- type_combinations.insert(type_combinations.end(),
- more_type_combinations.begin(),
- more_type_combinations.end());
- }
-
- for (const auto& type_combination : type_combinations) {
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<ABType>>"] = std::get<0>(type_combination);
- replacements["<<DType>>"] = std::get<1>(type_combination);
- const char* hlo_template = R"(
- HloModule type_combo
-
- ENTRY type_combo {
- %parameter.1 = <<ABType>>[4,4]{1,0} parameter(0)
- %parameter.2 = <<ABType>>[4,4]{1,0} parameter(1)
- ROOT %dot = <<DType>>[4,4] dot(%parameter.1, %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
- )";
- const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
- if (std::get<2>(type_combination)) {
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- } else {
- EXPECT_FALSE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- }
- }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingBf16ToF64) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- Arg_0.1 = bf16[4,3]{1,0} parameter(0)
- Arg_1.2 = bf16[3,6]{1,0} parameter(1)
- ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
-  // This is a type combination that is not supported by cublasLt, so we expect
-  // GemmRewriter to choose legacy cublas.
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- Arg_0.1 = c64[4,3]{1,0} parameter(0)
- Arg_1.2 = c64[3,6]{1,0} parameter(1)
- ROOT dot.3 = c128[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
-  // This is a type combination that is not supported by cublasLt, so we expect
-  // GemmRewriter to choose legacy cublas.
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- Arg_0.1 = f16[4,3]{1,0} parameter(0)
- Arg_1.2 = f16[3,6]{1,0} parameter(1)
- ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest, highest}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- Arg_0.1 = f16[4,3]{1,0} parameter(0)
- Arg_1.2 = f16[3,6]{1,0} parameter(1)
- ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
-  // This is a type combination that is not supported by cublasLt, so we expect
-  // GemmRewriter to choose legacy cublas.
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- Arg_0.1 = f32[4,3]{1,0} parameter(0)
- Arg_1.2 = f32[3,6]{1,0} parameter(1)
- ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
-  // This is a type combination that is not supported by cublasLt, so we expect
-  // GemmRewriter to choose legacy cublas.
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, DoNotUpconvertOutput) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
- param_0 = f16[240,88]{1,0} parameter(0)
- param_1 = f16[88,4]{1,0} parameter(1)
- dot = f16[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- constant_255 = f16[] constant(255)
- broadcast = f16[240,4]{1,0} broadcast(constant_255), dimensions={}
- multiply = f16[240,4]{1,0} multiply(dot, broadcast)
- ROOT result = f32[240,4]{1,0} convert(multiply)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
-  // The input fp16 and output fp32 combination is supported by legacy cublas
-  // and cublasLt, so we expect GemmRewriter to fuse the convert into the gemm.
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Convert(
- m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UnsupportedMixTypeGemm) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
- param_0 = f32[240,88]{1,0} parameter(0)
- param_1 = f32[88,4]{1,0} parameter(1)
- dot = f32[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- constant_255 = f32[] constant(255)
- broadcast = f32[240,4]{1,0} broadcast(constant_255), dimensions={}
- multiply = f32[240,4]{1,0} multiply(dot, broadcast)
- ROOT result = u8[240,4]{1,0} convert(multiply)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
-  // u8 is not supported by legacy cublas and cublasLt, so we expect
-  // GemmRewriter not to fuse the convert into the gemm.
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Convert(
- m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, CheckIsGemmAliasedBeforeFusion) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
- Arg_0.1 = f16[8,16]{1,0} parameter(0)
- Arg_1.2 = f16[16,32]{1,0} parameter(1)
- dot.8 = f16[8,32]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- Arg_2.3 = f16[8,32]{1,0} parameter(2)
- constant.5 = f16[] constant(1)
- broadcast.6 = f16[8,32]{1,0} broadcast(constant.5), dimensions={}
- add.7 = f16[8,32]{1,0} add(Arg_2.3, broadcast.6)
- add.9 = f16[8,32]{1,0} add(dot.8, add.7)
- convert.10 = f32[8,32]{1,0} convert(add.9)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
-  // The input fp16 and output fp32 combination is supported by legacy cublas
-  // and cublasLt, but the gemm output is already aliased with one of the
-  // inputs, so we expect GemmRewriter not to fuse the convert into the gemm.
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Convert(
- m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
-}
-
-INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt,
- ParameterizedGemmRewriteTest, ::testing::Bool());
-#endif
-
-// A test fixture class for tests which are specific to legacy cublas
-class LegacyCublasGemmRewriteTest : public GemmRewriteTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_triton_gemm(false);
- debug_options.set_xla_gpu_enable_cublaslt(false);
- return debug_options;
- }
-};
-
-TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplication) {
- const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
- p0 = f32[2048] parameter(0)
- p1 = f32[2048, 16384] parameter(1)
- ROOT d = f32[16384] dot(p0, p1),
- lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})";
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(
- se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
- /*toolkit_version=*/12040),
- R"(
-; CHECK: %[[P0:.+]] = f32[2048]{0} parameter(0)
-; CHECK: %[[P1:.+]] = f32[2048,16384]{1,0} parameter(1)
-; CHECK: %[[CUSTOM_CALL:.+]] = (f32[16384]{0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplicationWithBatch) {
- const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
- p0 = f32[10, 10, 2048] parameter(0)
- p1 = f32[10, 10, 2048, 16384] parameter(1)
- ROOT d = f32[10, 10, 16384] dot(p0, p1),
- lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1},
- lhs_contracting_dims={2}, rhs_contracting_dims={2}
-})";
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(
- se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
- /*toolkit_version=*/12040),
- R"(
-; CHECK: %[[P0:.+]] = f32[10,10,2048]{2,1,0} parameter(0)
-; CHECK: %[[P1:.+]] = f32[10,10,2048,16384]{3,2,1,0} parameter(1)
-; CHECK: %[[CUSTOM_CALL:.+]] = (f32[10,10,16384]{2,1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, SparseDotNotSupported) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
- lhs = f16[5,16] parameter(0)
- rhs = f16[32,10] parameter(1)
- meta = u16[5,2] parameter(2)
- ROOT dot = f32[5,10] dot(lhs, rhs, meta),
- lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
-})";
- auto hlo_pass = GemmRewriter(
- se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
- /*toolkit_version=*/12040);
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get()));
- EXPECT_FALSE(changed);
-}
-
-// Test that the alpha and beta fields of the GemmBackendConfig are updated.
-// A bias must be present for the beta value to be set.
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value which can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it simply
-// exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, AlphaBetaRewrite) {
- const char* hlo_text = R"(
-HloModule NonZeroAlphaBeta
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- param_2 = f32[2,2] parameter(2)
- bias = f32[2,2] negate(param_2)
- k = f32[] constant(3.0)
- k_broadcast = f32[2, 2] broadcast(k), dimensions={}
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
- ROOT out = f32[2,2] add(dot_a_multiplied, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK: [[O:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: output_to_operand_aliasing={
-; CHECK-SAME: {0}: (2, {})
-; CHECK-SAME: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[O]]), index=0
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
- const char* hlo_text = R"(
-HloModule BiasMultipleUsersNoOverwrite
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] parameter(2)
- k = f32[] constant(3.0)
- k_broadcast = f32[2, 2] broadcast(k), dimensions={}
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
- biased_out = f32[2,2] add(dot_a_multiplied, bias)
- ROOT out = f32[2,2] add(biased_out, bias)
-}
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: [[CUSTOM_CALL:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, BiasParameterNoOverwrite) {
- const char* hlo_text = R"(
-HloModule BiasParameterNoOverwrite
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] parameter(2)
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,2] add(dot_a, bias)
-}
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, BiasTupleParameterOverwrite) {
- const char* hlo_text = R"(
-HloModule BiasTupleParameterOverwrite
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- param_2 = (f32[2,2], f32[3,3]) parameter(2)
- bias = f32[2,2] get-tuple-element(param_2), index=0
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,2] add(dot_a, bias)
-}
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: (f32[2,2], f32[3,3])) -> f32[2,2] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG: [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2)
-; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[P2]]), index=0
-; CHECK-DAG: [[BIAS_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[BIAS]])
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS_COPY]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: output_to_operand_aliasing={
-; CHECK-SAME: {0}: (2, {})
-; CHECK-SAME: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, AliasedBiasOverwrite) {
- const char* hlo_text = R"(
-HloModule AliasedBiasOverwrite, input_output_alias={ {}: (2, {}, must-alias) }
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] parameter(2)
- k = f32[] constant(3.0)
- k_broadcast = f32[2, 2] broadcast(k), dimensions={}
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
- ROOT out = f32[2,2] add(dot_a_multiplied, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
-; CHECK: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: output_to_operand_aliasing={
-; CHECK-SAME: {0}: (2, {})
-; CHECK-SAME: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
- const char* hlo_text = R"(
-HloModule LargerBiasMultipleUsersNoRewrite
-
-ENTRY AddDotsFunc {
- x = f32[1024,1024] parameter(0)
- y = f32[1024,1024] parameter(1)
- bias = f32[1024,1024] parameter(2)
- dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- biased_out = f32[1024,1024] add(dot_a, bias)
- ROOT out = f32[1024,1024] add(biased_out, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value which can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it simply
-// exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, BF16GemmWithBias) {
- const char* hlo_text = R"(
-HloModule BF16GemmWithBias
-
-ENTRY BF16GemmWithBias {
- x = bf16[8,8]{1,0} parameter(0)
- y = bf16[8,8]{1,0} parameter(1)
- dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- param_2 = bf16[8,8]{1,0} parameter(2)
- bias = bf16[8,8]{1,0} negate(param_2)
- ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
-}
- )";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 2e-3}));
-
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
-; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
-; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
-; CHECK: [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: output_to_operand_aliasing={
-; CHECK-SAME: {0}: (2, {})
-; CHECK-SAME: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value which can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it simply
-// exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- param_2 = f32[2,4] parameter(2)
- bias = f32[2,4] negate(param_2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,4] add(dot_a, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK: [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], {{[^,)]+}}),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: output_to_operand_aliasing={
-; CHECK-SAME: {0}: (2, {})
-; CHECK-SAME: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- w = f32[2,3] parameter(0)
- x = f32[3,4] parameter(1)
- first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- y = f32[2,3] parameter(2)
- z = f32[3,4] parameter(3)
- second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,4] add(second_dot, first_dot)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
-; CHECK-NEXT: [[FIRST_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[FIRST_GEMM_OUT:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM]]), index=0
-; CHECK-NEXT: [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM_OUT]]),
-; CHECK: custom_call_target="__cublas$gemm",
-; CHECK: output_to_operand_aliasing={
-; CHECK-SAME: {0}: (2, {})
-; CHECK-SAME: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-// Test gemm matrix bias add fusion with mixed types.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixType) {
- std::vector<std::tuple<absl::string_view, absl::string_view>>
- type_combinations = {
- {"f16", "f32"},
- {"bf16", "f32"},
- };
-
- const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
- x = <<ABType>>[16,32] parameter(0)
- y = <<ABType>>[32,16] parameter(1)
- z = <<DType>>[16,16] parameter(2)
- dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias = <<DType>>[16,16] negate(z)
- convert = <<DType>>[16,16] convert(dot_a)
- ROOT out = <<DType>>[16,16] add(convert, bias)
-}
-
-)";
- for (const auto& type_combination : type_combinations) {
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<ABType>>"] = std::get<0>(type_combination);
- replacements["<<DType>>"] = std::get<1>(type_combination);
- const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- continue;
- }
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1),
- m::Negate(m::Parameter(2))),
- 0)));
- }
-}
-
-// Test batched gemm matrix bias add fusion with mixed types.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeBatched) {
- std::vector<std::tuple<absl::string_view, absl::string_view>>
- type_combinations = {
- {"f16", "f32"},
- {"bf16", "f32"},
- };
-
- const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
- x = <<ABType>>[4,16,32] parameter(0)
- y = <<ABType>>[4,32,16] parameter(1)
- z = <<DType>>[4,16,16] parameter(2)
- dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
- bias = <<DType>>[4,16,16] negate(z)
- convert = <<DType>>[4,16,16] convert(dot_a)
- ROOT out = <<DType>>[4,16,16] add(convert, bias)
-})";
- for (const auto& type_combination : type_combinations) {
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<ABType>>"] = std::get<0>(type_combination);
- replacements["<<DType>>"] = std::get<1>(type_combination);
- const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- continue;
- }
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1),
- m::Negate(m::Parameter(2))),
- 0)));
- }
-}
-#endif
-
-// Test gemm matrix bias add fusion with a mixed-type combination that is not supported.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[16,32] parameter(0)
- y = bf16[32,16] parameter(1)
- z = f64[16,16] parameter(2)
- dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias = f64[16,16] negate(z)
- convert = f64[16,16] convert(dot_a)
- ROOT out = f64[16,16] add(convert, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
-; CHECK: %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
-; CHECK: ROOT {{.*}} fusion({{.*}}%[[gte]]
-)");
-}
-
-// Test gemm matrix bias add fusion with mixed types that is not supported
-// because the bias add has additional consumers.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeAddWithMoreConsumers) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[16,32] parameter(0)
- y = bf16[32,16] parameter(1)
- z = f32[16,16] parameter(2)
- dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias = f32[16,16] negate(z)
- convert = f32[16,16] convert(dot_a)
- add_bias = f32[16,16] add(convert, bias)
- ROOT out = f32[16,16] negate(add_bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
-; CHECK: %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
-; CHECK: ROOT {{.*}} fusion({{.*}}%[[gte]]
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) {
- const char* hlo_text = R"(
-HloModule test
-ENTRY test {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[4] parameter(2)
- dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[4] add(f32[4] bitcast(dot), bias)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Bitcast(
- m::GetTupleElement(
- m::CustomCall(
- {"__cublas$gemm"}, m::Parameter(0), m::Parameter(1),
- m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})),
- 0))
- .WithShape(F32, {4})));
-}
-
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value which can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it simply
-// exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, FoldConstantBias) {
- const char* hlo_text = R"(
-HloModule test
-ENTRY test {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
-
- dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- param_2 = f32[2,2] parameter(2)
- bias1 = f32[2,2] negate(param_2)
- sum1 = add(dot1, bias1)
-
- dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- sum2 = add(dot2, f32[2,2] reshape(bias))
-
- dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias3 = f32[2,2] transpose(bias), dimensions={1,0}
- sum3 = add(dot3, bias3)
-
- dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- sum4 = add(dot4, f32[2,2] bitcast(bias))
-
- ROOT root = tuple(sum1, sum2, sum3, sum4)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- SCOPED_TRACE(module->ToString());
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(m::Parameter(0), m::Parameter(1),
- m::Negate(m::Parameter(2))),
- 0),
- m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
- 0),
- m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
- 0),
- m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
- 0))));
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-// A test fixture class for tests which are specific to cublasLt
-class CublasLtGemmRewriteTest : public GemmRewriteTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_enable_cublaslt(true);
- debug_options.set_xla_gpu_enable_triton_gemm(false);
- return debug_options;
- }
-
- protected:
- void SetUp() override {
- if (SkipGpuBlasLtTest()) {
- GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
- }
- }
-};
-
-TEST_F(CublasLtGemmRewriteTest, AlphaBetaRewrite) {
- const char* hlo_text = R"(
-HloModule NonZeroAlphaBeta
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] parameter(2)
- k = f32[] constant(3.0)
- k_broadcast = f32[2, 2] broadcast(k), dimensions={}
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
- ROOT out = f32[2,2] add(dot_a_multiplied, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[GEMM]]), index=0
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
- const char* hlo_text = R"(
-HloModule BiasMultipleUsersNoOverwrite
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] parameter(2)
- k = f32[] constant(3.0)
- k_broadcast = f32[2, 2] broadcast(k), dimensions={}
- dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
- biased_out = f32[2,2] add(dot_a_multiplied, bias)
- ROOT out = f32[2,2] add(biased_out, bias)
-}
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK-NOT: output_to_operand_aliasing
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
- const char* hlo_text = R"(
-HloModule LargerBiasMultipleUsersNoRewrite
-
-ENTRY AddDotsFunc {
- x = f32[1024,1024] parameter(0)
- y = f32[1024,1024] parameter(1)
- bias = f32[1024,1024] parameter(2)
- dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- biased_out = f32[1024,1024] add(dot_a, bias)
- ROOT out = f32[1024,1024] add(biased_out, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
-; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[1024,1024]{1,0} parameter(2)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[1024,1024]{1,0} add([[GEMM]], [[BIAS]])
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BF16GemmWithBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY BF16GemmWithBias {
- x = bf16[8,8]{1,0} parameter(0)
- y = bf16[8,8]{1,0} parameter(1)
- dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias = bf16[8,8]{1,0} parameter(2)
- ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
-}
- )";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
-; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
-; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
-; CHECK-DAG: [[BIAS:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[2,4] parameter(2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,4] add(dot_a, z)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- w = f32[2,3] parameter(0)
- x = f32[3,4] parameter(1)
- first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- y = f32[2,3] parameter(2)
- z = f32[3,4] parameter(3)
- second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,4] add(second_dot, first_dot)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
-; CHECK-NEXT: [[FIRST_GEMM_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM_TUPLE]]), index=0
-; CHECK-NEXT: [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: output_to_operand_aliasing={
-; CHECK: {0}: (2, {})
-; CHECK: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[4] parameter(2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z), dimensions={1}
- ROOT out = f32[2,4] add(dot_a, z_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
-)");
-}
-
-// Epilogue Fusion disabled when GEMM has multiple users.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasMultipleUsers) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[4,4] parameter(0)
- y = f32[4,4] parameter(1)
- z = f32[4] parameter(2)
- c = f32[] constant(5)
- dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- z_bcast = f32[4,4] broadcast(z), dimensions={1}
- add_a = f32[4,4] add(dot_a, z_bcast)
- c_bcast = f32[4,4] broadcast(c), dimensions={}
- dot_b = f32[4,4] dot(dot_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- ROOT out = f32[4,4] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[4,4], [[DUMMY1:[^ ]+]]: f32[4]) -> f32[4,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} broadcast([[P1]]), dimensions={1}
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} add([[P0]], [[P2]])
-}
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4]) -> f32[4,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
-; CHECK-NEXT: [[MATMUL0_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK-NEXT: [[MATMUL0:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL0_TUPLE]]), index=0
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[4,4]{1,0} fusion([[MATMUL0]], [[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]]
-; CHECK: [[MATMUL1_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[MATMUL0]]
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[FUSION]], [[MATMUL1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedVectorBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3,4] parameter(0)
- y = f32[4,5,6] parameter(1)
- z = f32[3,5,6] parameter(2)
- dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
- ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
-; CHECK: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: output_to_operand_aliasing={
-; CHECK: {0}: (2, {})
-; CHECK: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedSharedVectorBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3,4] parameter(0)
- y = f32[4,5,6] parameter(1)
- z = f32[6] parameter(2)
- dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- z_bcast = f32[2,3,5,6] broadcast(z), dimensions={3}
- ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[6]) -> f32[2,3,5,6] {
-; CHECK: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: output_to_operand_aliasing={
-; CHECK: {0}: (2, {})
-; CHECK: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[2] parameter(2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z), dimensions={0}
- add = f32[2,4] add(dot_a, z_bcast)
- ROOT out = f32[4,2] transpose(add), dimensions={1,0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
-; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
-; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasSliced) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[4,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[3] parameter(2)
- dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- slice_a = f32[2,3] slice(dot_a), slice={[0:2], [0:3]}
- z_bcast = f32[2,3] broadcast(z), dimensions={1}
- ROOT out = f32[2,3] add(slice_a, z_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,3], {{.*}}: f32[3,4], {{.*}}: f32[3]) -> f32[2,3] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3]{0} parameter(2)
-; CHECK-NEXT: [[MATMUL:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
-; CHECK-NEXT: [[GETTUPLE:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3]{1,0} slice([[GETTUPLE]]), slice={[0:2], [0:3]}
- )");
-}
-
-// Epilogue Fusion disabled when slice has multiple users.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasSlicedMultipleUsers) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[2] parameter(2)
- c = f32[] constant(5)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
- z_bcast = f32[2,2] broadcast(z), dimensions={1}
- add_a = f32[2,2] add(slice_a, z_bcast)
- c_bcast = f32[2,2] broadcast(c), dimensions={}
- dot_b = f32[2,2] dot(slice_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,2] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[2,2] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
-; CHECK-NEXT: [[MATMUL0_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[MATMUL1_TUPLE:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call{{.*}}[[MATMUL1]]
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposed) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[2] parameter(2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] parameter(3)
- ROOT out = f32[2,4] add(dot_a, z_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2_BCAST]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-)");
-; CHECK: }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[4] parameter(2)
- z2 = f32[2,4] parameter(3)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z), dimensions={1}
- add0 = f32[2,4] add(dot_a, z_bcast)
- ROOT add1 = f32[2,4] add(add0, z2)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG: [[VECTOR_BIAS:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-DAG: [[MATRIX_BIAS:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[MATRIX_BIAS]], [[VECTOR_BIAS]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BF16VectorBias) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[16,24] parameter(0)
- y = bf16[24,32] parameter(1)
- z = bf16[32] parameter(2)
- dot_a = bf16[16,32] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = bf16[16,32] broadcast(z), dimensions={1}
- ROOT out = bf16[16,32] add(dot_a, z_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{3e-3, 1e-3}));
-
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[16,24], {{.*}}: bf16[24,32], {{.*}}: bf16[32]) -> bf16[16,32] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[16,24]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[24,32]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32]{0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BF16VectorBiasPadded) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
- "architectures with bf16 Tensor Cores.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[2,3] parameter(0)
- y = bf16[3,4] parameter(1)
- z = bf16[4] parameter(2)
- dot_a = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = bf16[2,4] broadcast(z), dimensions={1}
- ROOT out = bf16[2,4] add(dot_a, z_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] {
-; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
-; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- c = f32[] constant(0)
- c_bcast = f32[2,4] broadcast(c), dimensions={}
- ROOT out = f32[2,4] maximum(dot_a, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3,4] parameter(0)
- y = f32[4,5,6] parameter(1)
- dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- c = f32[] constant(0)
- c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
- ROOT out = f32[2,3,5,6] maximum(dot_a, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6]) -> f32[2,3,5,6] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1)
-; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0}
-; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
-; CHECK: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivationSliced) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- c = f32[] constant(0)
- c_bcast = f32[2,2] broadcast(c), dimensions={}
- slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
- ROOT out = f32[2,2] maximum(slice_a, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
-; CHECK: [[MATMUL:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL]]), slice={[0:2], [0:2]}
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[2,4] parameter(2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- add = f32[2,4] add(dot_a, z)
- c = f32[] constant(0)
- c_bcast = f32[2,4] broadcast(c), dimensions={}
- ROOT out = f32[2,4] maximum(add, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, SquareMatrixBiasReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[4,4] parameter(0)
- y = f32[4,4] parameter(1)
- z = f32[4,4] parameter(2)
- dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- add = f32[4,4] add(dot_a, z)
- c = f32[] constant(0)
- c_bcast = f32[4,4] broadcast(c), dimensions={}
- ROOT out = f32[4,4] maximum(add, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4,4]) -> f32[4,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[4] parameter(2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z), dimensions={1}
- add = f32[2,4] add(dot_a, z_bcast)
- c = f32[] constant(0)
- c_bcast = f32[2,4] broadcast(c), dimensions={}
- ROOT out = f32[2,4] maximum(add, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_RELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedVectorBiasReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3,4] parameter(0)
- y = f32[4,5,6] parameter(1)
- z = f32[3,5,6] parameter(2)
- dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
- add = f32[2,3,5,6] add(dot_a, z_bcast)
- c = f32[] constant(0)
- c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
- ROOT out = f32[2,3,5,6] maximum(add, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
-; CHECK: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
-; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposedReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[2] parameter(2)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z), dimensions={0}
- add = f32[2,4] add(dot_a, z_bcast)
- c = f32[] constant(0)
- c_bcast = f32[2,4] broadcast(c), dimensions={}
- maximum = f32[2,4] maximum(add, c_bcast)
- ROOT out = f32[4,2] transpose(maximum), dimensions={1,0}
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
-; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_RELU"
-; CHECK: }
-; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBiasReluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z_vec = f32[4] parameter(2)
- z_matrix = f32[2,4] parameter(3)
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z_vec), dimensions={1}
- add0 = f32[2,4] add(dot_a, z_bcast)
- add1 = f32[2,4] add(add0, z_matrix)
- c = f32[] constant(0)
- c_bcast = f32[2,4] broadcast(c), dimensions={}
- ROOT out = f32[2,4] maximum(add1, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P3]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_RELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- mul.0 = f32[2,4] multiply(dot, dot)
- mul.1 = f32[2,4] multiply(dot, mul.0)
- const.0 = f32[] constant(0.044715)
- bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
- mul.2 = f32[2,4] multiply(mul.1, bcast.0)
- add.0 = f32[2,4] add(dot, mul.2)
- const.1 = f32[] constant(0.797884583)
- bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
- mul.3 = f32[2,4] multiply(add.0, bcast.1)
- tanh = f32[2,4] tanh(mul.3)
- const.2 = f32[] constant(1)
- bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
- add.2 = f32[2,4] add(tanh, bcast.2)
- const.3 = f32[] constant(0.5)
- bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
- mul.4 = f32[2,4] multiply(add.2, bcast.3)
- ROOT out = f32[2,4] multiply(dot, mul.4)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"GELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWrongConstant) {
- // Modify one constant slightly, so it should no longer pattern match.
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- mul.0 = f32[2,4] multiply(dot, dot)
- mul.1 = f32[2,4] multiply(dot, mul.0)
- const.0 = f32[] constant(0.05)
- bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
- mul.2 = f32[2,4] multiply(mul.1, bcast.0)
- add.0 = f32[2,4] add(dot, mul.2)
- const.1 = f32[] constant(0.797884583)
- bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
- mul.3 = f32[2,4] multiply(add.0, bcast.1)
- tanh = f32[2,4] tanh(mul.3)
- const.2 = f32[] constant(1)
- bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
- add.2 = f32[2,4] add(tanh, bcast.2)
- const.3 = f32[] constant(0.5)
- bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
- mul.4 = f32[2,4] multiply(add.2, bcast.3)
- ROOT out = f32[2,4] multiply(dot, mul.4)
-}
-
-)";
-
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-NOT: GELU
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) {
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60000
- auto rocm_switch = false; // GELU is only available from ROCM 6.0
-#else
- auto rocm_switch = true;
-#endif
- if (!IsCuda() && rocm_switch) {
- GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[4] parameter(2)
- dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z), dimensions={1}
- add = f32[2,4] add(dot, z_bcast)
- mul.0 = f32[2,4] multiply(add, add)
- mul.1 = f32[2,4] multiply(add, mul.0)
- const.0 = f32[] constant(0.044715)
- bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
- mul.2 = f32[2,4] multiply(mul.1, bcast.0)
- add.0 = f32[2,4] add(add, mul.2)
- const.1 = f32[] constant(0.797884583)
- bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
- mul.3 = f32[2,4] multiply(add.0, bcast.1)
- tanh = f32[2,4] tanh(mul.3)
- const.2 = f32[] constant(1)
- bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
- add.2 = f32[2,4] add(tanh, bcast.2)
- const.3 = f32[] constant(0.5)
- bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
- mul.4 = f32[2,4] multiply(add.2, bcast.3)
- ROOT out = f32[2,4] multiply(add, mul.4)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_GELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) {
- if (!IsCuda()) {
- GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- mul.0 = f32[2,4] multiply(dot, dot)
- mul.1 = f32[2,4] multiply(dot, mul.0)
- const.0 = f32[] constant(0.044715)
- bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
- mul.2 = f32[2,4] multiply(mul.1, bcast.0)
- add.0 = f32[2,4] add(dot, mul.2)
- const.1 = f32[] constant(0.797884583)
- bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
- mul.3 = f32[2,4] multiply(add.0, bcast.1)
- tanh = f32[2,4] tanh(mul.3)
- const.2 = f32[] constant(1)
- bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
- add.2 = f32[2,4] add(tanh, bcast.2)
- const.3 = f32[] constant(0.5)
- bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
- mul.4 = f32[2,4] multiply(add.2, bcast.3)
- mul.5 = f32[2,4] multiply(dot, mul.4)
- ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, dot)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> (f32[2,4], f32[2,4]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"GELU_AUX"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) {
- if (!IsCuda()) {
- GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[4] parameter(2)
- dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f32[2,4] broadcast(z), dimensions={1}
- add = f32[2,4] add(dot, z_bcast)
- mul.0 = f32[2,4] multiply(add, add)
- mul.1 = f32[2,4] multiply(add, mul.0)
- const.0 = f32[] constant(0.044715)
- bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
- mul.2 = f32[2,4] multiply(mul.1, bcast.0)
- add.0 = f32[2,4] add(add, mul.2)
- const.1 = f32[] constant(0.797884583)
- bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
- mul.3 = f32[2,4] multiply(add.0, bcast.1)
- tanh = f32[2,4] tanh(mul.3)
- const.2 = f32[] constant(1)
- bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
- add.2 = f32[2,4] add(tanh, bcast.2)
- const.3 = f32[] constant(0.5)
- bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
- mul.4 = f32[2,4] multiply(add.2, bcast.3)
- mul.5 = f32[2,4] multiply(add, mul.4)
- ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, add)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> (f32[2,4], f32[2,4]) {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_GELU_AUX"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBF16) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
- "architectures with bf16 Tensor Cores.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[2,3] parameter(0)
- y = bf16[3,4] parameter(1)
- dot = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- mul.0 = bf16[2,4] multiply(dot, dot)
- mul.1 = bf16[2,4] multiply(dot, mul.0)
- const.0 = bf16[] constant(0.044715)
- bcast.0 = bf16[2,4] broadcast(const.0), dimensions={}
- mul.2 = bf16[2,4] multiply(mul.1, bcast.0)
- add.0 = bf16[2,4] add(dot, mul.2)
- const.1 = bf16[] constant(0.797884583)
- bcast.1 = bf16[2,4] broadcast(const.1), dimensions={}
- mul.3 = bf16[2,4] multiply(add.0, bcast.1)
- tanh = bf16[2,4] tanh(mul.3)
- const.2 = bf16[] constant(1)
- bcast.2 = bf16[2,4] broadcast(const.2), dimensions={}
- add.2 = bf16[2,4] add(tanh, bcast.2)
- const.3 = bf16[] constant(0.5)
- bcast.3 = bf16[2,4] broadcast(const.3), dimensions={}
- mul.4 = bf16[2,4] multiply(add.2, bcast.3)
- ROOT out = bf16[2,4] multiply(dot, mul.4)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{5e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] {
-; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
-; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBitcast) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_bitcast = f32[2,2,2] bitcast(dot)
- mul.0 = f32[2,2,2] multiply(dot_bitcast, dot_bitcast)
- mul.1 = f32[2,2,2] multiply(dot_bitcast, mul.0)
- const.0 = f32[] constant(0.044715)
- bcast.0 = f32[2,2,2] broadcast(const.0), dimensions={}
- mul.2 = f32[2,2,2] multiply(mul.1, bcast.0)
- add.0 = f32[2,2,2] add(dot_bitcast, mul.2)
- const.1 = f32[] constant(0.797884583)
- bcast.1 = f32[2,2,2] broadcast(const.1), dimensions={}
- mul.3 = f32[2,2,2] multiply(add.0, bcast.1)
- tanh = f32[2,2,2] tanh(mul.3)
- const.2 = f32[] constant(1)
- bcast.2 = f32[2,2,2] broadcast(const.2), dimensions={}
- add.2 = f32[2,2,2] add(tanh, bcast.2)
- const.3 = f32[] constant(0.5)
- bcast.3 = f32[2,2,2] broadcast(const.3), dimensions={}
- mul.4 = f32[2,2,2] multiply(add.2, bcast.3)
- ROOT out = f32[2,2,2] multiply(dot_bitcast, mul.4)
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Bitcast(m::GetTupleElement(
- m::CustomCall({"__cublas$lt$matmul"},
- m::Parameter(0).WithShape(F32, {2, 3}),
- m::Parameter(1).WithShape(F32, {3, 4})),
- 0))
- .WithShape(F32, {2, 2, 2})));
-}
-
-// For F16, the sizes of all dimensions of the operands are required to be
-// multiples of 8 to allow matrix bias fusion.
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasF16) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[8,16] parameter(0)
- y = f16[16,8] parameter(1)
- z = f16[8,8] parameter(2)
- dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f16[8,8] add(dot_a, z)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasF32UnpaddedWithBitcast) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3]{1,0} parameter(0)
- y = f32[3,4]{1,0} parameter(1)
- z = f32[2]{0} parameter(2)
- dot_a = f32[2,4]{0,1} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitc = f32[4,2]{1,0} bitcast(f32[2,4]{0,1} dot_a)
- z_bcast = f32[4,2] broadcast(z), dimensions={1}
- ROOT add = f32[4,2]{1,0} add(bitc, z_bcast)
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Bitcast(m::GetTupleElement(
- m::CustomCall({"__cublas$lt$matmul"}, m::Parameter(0),
- m::Parameter(1),
- m::Parameter(2).WithShape(F32, {2})),
- 0)
- .WithShape(F32, {2, 4}))
- .WithShape(F32, {4, 2})));
-}
-
-// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
-// newer architectures) so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Unpadded) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[8,16] parameter(0)
- y = f16[16,8] parameter(1)
- z = f16[8] parameter(2)
- dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f16[8,8] broadcast(z), dimensions={1}
- ROOT add = f16[8,8] add(dot_a, z_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT: pad("
-; CHECK: custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Padded) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- GTEST_SKIP() << "Padding of GEMM operands only implemented on "
- "architectures with Tensor Cores.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[6,12] parameter(0)
- y = f16[12,6] parameter(1)
- z = f16[6] parameter(2)
- dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f16[6,6] broadcast(z), dimensions={1}
- ROOT add = f16[6,6] add(dot_a, z_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
-; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
- )");
-}
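-
-// A minimal sketch, for illustration only, of the dimension rounding that the
-// padded F16 tests above and below expect: every operand dimension is rounded
-// up to the next multiple of 8, so the f16[6,12] LHS is padded to f16[8,16]
-// (padding=0_2x0_4) and the f16[12,6] RHS to f16[16,8] (padding=0_4x0_2).
-// `RoundUpToMultipleOf8` is a hypothetical local helper, not part of the
-// GemmRewriter API; it assumes <cstdint> is already included by this file.
-constexpr int64_t RoundUpToMultipleOf8(int64_t dim) {
-  return (dim + 7) / 8 * 8;  // 6 -> 8, 12 -> 16, 16 -> 16
-}
-static_assert(RoundUpToMultipleOf8(6) == 8, "6 rounds up to 8");
-static_assert(RoundUpToMultipleOf8(12) == 16, "12 rounds up to 16");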
-
-// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
-// newer architectures) so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Unpadded) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[8,16] parameter(0)
- y = f16[16,8] parameter(1)
- dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- c = f16[] constant(0)
- c_bcast = f16[8,8] broadcast(c), dimensions={}
- ROOT out = f16[8,8] maximum(dot_a, c_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT: pad("
-; CHECK: custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Padded) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- GTEST_SKIP() << "Padding of GEMM operands only implemented on "
- "architectures with Tensor Cores.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[6,12] parameter(0)
- y = f16[12,6] parameter(1)
- dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- c = f16[] constant(0)
- c_bcast = f16[6,6] broadcast(c), dimensions={}
- ROOT out = f16[6,6] maximum(dot_a, c_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] {
-; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivationF16) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[8,16] parameter(0)
- y = f16[16,8] parameter(1)
- z = f16[8,8] parameter(2)
- dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- add = f16[8,8] add(dot_a, z)
- c = f16[] constant(0)
- c_bcast = f16[8,8] broadcast(c), dimensions={}
- ROOT out = f16[8,8] maximum(add, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
- )");
-}
-
-// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
-// newer architectures) so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Unpadded) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[8,16] parameter(0)
- y = f16[16,8] parameter(1)
- z = f16[8] parameter(2)
- dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f16[8,8] broadcast(z), dimensions={1}
- add = f16[8,8] add(dot_a, z_bcast)
- c = f16[] constant(0)
- c_bcast = f16[8,8] broadcast(c), dimensions={}
- ROOT out = f16[8,8] maximum(add, c_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT: pad("
-; CHECK: custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Padded) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
- GTEST_SKIP() << "Padding of GEMM operands only implemented on "
- "architectures with Tensor Cores.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f16[6,12] parameter(0)
- y = f16[12,6] parameter(1)
- z = f16[6] parameter(2)
- dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f16[6,6] broadcast(z), dimensions={1}
- add = f16[6,6] add(dot_a, z_bcast)
- c = f16[] constant(0)
- c_bcast = f16[6,6] broadcast(c), dimensions={}
- ROOT out = f16[6,6] maximum(add, c_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
-; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
- )");
-}
-
-// For bfloat16, the sizes of all dimensions of the operands are required to be
-// multiples of 8 to allow matrix bias fusion.
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasBF16) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[8,16] parameter(0)
- y = bf16[16,8] parameter(1)
- z = bf16[8,8] parameter(2)
- dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = bf16[8,8] add(dot_a, z)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[8,16], {{.*}}: bf16[16,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
-; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0)
-; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1)
-; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasBitcastBF16) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[8,16] parameter(0)
- y = bf16[16,8] parameter(1)
- bias = bf16[2,4,8] parameter(2)
- dot = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bitcast = bf16[2,4,8] bitcast(dot)
- ROOT out = bf16[2,4,8] add(bitcast, bias)
-}
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Bitcast(
- m::GetTupleElement(
- m::CustomCall(
- {"__cublas$lt$matmul"},
- m::Parameter(0).WithShape(BF16, {8, 16}),
- m::Parameter(1).WithShape(BF16, {16, 8}),
- m::Bitcast(m::Parameter(2)).WithShape(BF16, {8, 8})),
- 0))
- .WithShape(BF16, {2, 4, 8})));
-}
-
-// For bfloat16, the operands are padded if necessary on Ampere and newer
-// architectures so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Unpadded) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[8,16] parameter(0)
- y = bf16[16,8] parameter(1)
- z = bf16[8] parameter(2)
- dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = bf16[8,8] broadcast(z), dimensions={1}
- ROOT add = bf16[8,8] add(dot_a, z_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT: pad("
-; CHECK: custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Padded) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
- "Ampere and newer architectures.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[6,12] parameter(0)
- y = bf16[12,6] parameter(1)
- z = bf16[6] parameter(2)
- dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = bf16[6,6] broadcast(z), dimensions={1}
- ROOT add = bf16[6,6] add(dot_a, z_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
-; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
- )");
-}
-
-// For bfloat16, the operands are padded if necessary on Ampere and newer
-// architectures so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Unpadded) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[8,16] parameter(0)
- y = bf16[16,8] parameter(1)
- dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- c = bf16[] constant(0)
- c_bcast = bf16[8,8] broadcast(c), dimensions={}
- ROOT out = bf16[8,8] maximum(dot_a, c_bcast)
-}
-
-)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT: pad("
-; CHECK: custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Padded) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
- "Ampere and newer architectures.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[6,12] parameter(0)
- y = bf16[12,6] parameter(1)
- dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- c = bf16[] constant(0)
- c_bcast = bf16[6,6] broadcast(c), dimensions={}
- ROOT out = bf16[6,6] maximum(dot_a, c_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] {
-; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
- )");
-}
-
-// For bfloat16, the operands are padded if necessary on Ampere and newer
-// architectures so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Unpadded) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[8,16] parameter(0)
- y = bf16[16,8] parameter(1)
- z = bf16[8] parameter(2)
- dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = bf16[8,8] broadcast(z), dimensions={1}
- add = bf16[8,8] add(dot_a, z_bcast)
- c = bf16[] constant(0)
- c_bcast = bf16[8,8] broadcast(c), dimensions={}
- ROOT out = bf16[8,8] maximum(add, c_bcast)
-})";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-NOT: pad("
-; CHECK: custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Padded) {
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
- "Ampere and newer architectures.";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[6,12] parameter(0)
- y = bf16[12,6] parameter(1)
- z = bf16[6] parameter(2)
- dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = bf16[6,6] broadcast(z), dimensions={1}
- add = bf16[6,6] add(dot_a, z_bcast)
- c = bf16[] constant(0)
- c_bcast = bf16[6,6] broadcast(c), dimensions={}
- ROOT out = bf16[6,6] maximum(add, c_bcast)
-}
-
-)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
-; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) {
- if (!IsCuda()) {
- GTEST_SKIP() << "TODO: Unsupported blas-lt F64 datatype on ROCM";
- }
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f64[2,3] parameter(0)
- y = f64[3,4] parameter(1)
- z = f64[4] parameter(2)
- dot_a = f64[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- z_bcast = f64[2,4] broadcast(z), dimensions={1}
- add = f64[2,4] add(dot_a, z_bcast)
- c = f64[] constant(0)
- c_bcast = f64[2,4] broadcast(c), dimensions={}
- ROOT out = f64[2,4] maximum(add, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-10, 1e-10}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f64[2,3], {{.*}}: f64[3,4], {{.*}}: f64[4]) -> f64[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f64[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f64[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f64[4]{0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f64[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_RELU"
-; CHECK: }
- )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, AlphaSimpleRewriteBiasAddActivation) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = f32[2,3] parameter(0)
- y = f32[3,4] parameter(1)
- z = f32[4] parameter(2)
- k = f32[] constant(3.0)
- k_bcast = f32[2,4] broadcast(k), dimensions={}
- dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
- dot_a_multiplied = f32[2, 4] multiply(dot_a, k_bcast)
- z_bcast = f32[2,4] broadcast(z), dimensions={1}
- add = f32[2,4] add(dot_a_multiplied, z_bcast)
- c = f32[] constant(0)
- c_bcast = f32[2,4] broadcast(c), dimensions={}
- ROOT out = f32[2,4] maximum(add, c_bcast)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHlo(hlo_text,
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_RELU"
-; CHECK: }
- )");
-}
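-
-// Summary of the attribute mapping exercised by the tests above, derived from
-// their CHECK patterns (kept here as a readability aid only):
-//   dot + matrix-bias add               -> "beta":1, "epilogue":"DEFAULT"
-//   dot + matrix-bias add + relu        -> "beta":1, "epilogue":"RELU"
-//   dot + vector-bias add + relu        -> "beta":0, "epilogue":"BIAS_RELU"
-//   dot * constant(3) + bias + relu     -> "alpha_real":3, "epilogue":"BIAS_RELU"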
-
-TEST_F(CublasLtGemmRewriteTest, FoldConstantBias) {
- const char* hlo_text = R"(
-HloModule test
-ENTRY test {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
-
- dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias1 = f32[2,2] parameter(2)
- sum1 = add(dot1, bias1)
-
- dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- sum2 = add(dot2, f32[2,2] reshape(bias))
-
- dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias3 = f32[2,2] transpose(bias), dimensions={1,0}
- sum3 = add(dot3, bias3)
-
- dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- sum4 = add(dot4, f32[2,2] bitcast(bias))
-
- ROOT root = tuple(sum1, sum2, sum3, sum4)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(Capability(), GetToolkitVersion());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- SCOPED_TRACE(module->ToString());
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter()),
- 0),
- m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
- 0),
- m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
- 0),
- m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
- 0))));
-}
-
-TEST_F(CublasLtGemmRewriteTest, MultipleMaximumUsers) {
- const char* hlo_text = R"(
-HloModule multiple_maximum_users
-
-relu {
- Arg_0 = f32[3,896,54]{2,1,0} parameter(0)
- constant = f32[] constant(0)
- broadcast = f32[3,896,54]{2,1,0} broadcast(constant), dimensions={}
- ROOT maximum = f32[3,896,54]{2,1,0} maximum(Arg_0, broadcast)
-}
-
-ENTRY main {
- constant = f32[] constant(1)
- broadcast_1 = f32[3,896,1024]{2,1,0} broadcast(constant), dimensions={}
- Arg_2 = f32[1024,54]{1,0} parameter(2)
- dot = f32[3,896,54]{2,1,0} dot(broadcast_1, Arg_2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
- Arg_1 = f32[54]{0} parameter(1)
- broadcast_2 = f32[3,896,54]{2,1,0} broadcast(Arg_1), dimensions={2}
- add = f32[3,896,54]{2,1,0} add(dot, broadcast_2)
- call = f32[3,896,54]{2,1,0} call(add), to_apply=relu
- Arg_0 = f32[1]{0} parameter(0)
- reshape_1 = f32[1,1,1]{2,1,0} reshape(Arg_0)
- broadcast_3 = f32[1,1,1]{2,1,0} broadcast(reshape_1), dimensions={0,1,2}
- reshape_2 = f32[] reshape(broadcast_3)
- broadcast_4 = f32[3,896,54]{2,1,0} broadcast(reshape_2), dimensions={}
- multiply = f32[3,896,54]{2,1,0} multiply(call, broadcast_4)
- ROOT tuple = (f32[3,896,54]{2,1,0}, f32[3,896,54]{2,1,0}) tuple(multiply, call)
-}
-)";
-
- // TODO(cjfj): Why do we need to relax the error constraint here?!
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-4}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK: custom_call_target="__cublas$lt$matmul",
- )");
-}
-
-// Test GEMM matrix bias add fusion with mixed types and out-of-place update
-// (C != D).
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlace) {
- if (!IsCuda()) {
- GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
- }
- std::vector<std::tuple<absl::string_view, absl::string_view>>
- type_combinations = {
- {"f16", "f32"},
- {"bf16", "f32"},
- };
-
- const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
- x = <<ABType>>[16,32] parameter(0)
- y = <<ABType>>[32,16] parameter(1)
- z = <<DType>>[16,16] parameter(2)
- dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- convert = <<DType>>[16,16] convert(dot_a)
- ROOT out = <<DType>>[16,16] add(convert, z)
-})";
- for (const auto& type_combination : type_combinations) {
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<ABType>>"] = std::get<0>(type_combination);
- replacements["<<DType>>"] = std::get<1>(type_combination);
- const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- continue;
- }
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- EXPECT_THAT(
- optimized_module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)));
- }
-}
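-
-// For reference, with the {"f16", "f32"} combination the template above
-// expands (via absl::StrReplaceAll) to:
-//   x = f16[16,32] parameter(0)
-//   y = f16[32,16] parameter(1)
-//   z = f32[16,16] parameter(2)
-//   dot_a = f16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-//   convert = f32[16,16] convert(dot_a)
-//   ROOT out = f32[16,16] add(convert, z)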
-
-// Test batched GEMM matrix bias add fusion with mixed types and out-of-place
-// update (C != D).
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlaceBatched) {
- if (!IsCuda()) {
- GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
- }
- std::vector<std::tuple<absl::string_view, absl::string_view>>
- type_combinations = {
- {"f16", "f32"},
- {"bf16", "f32"},
- };
-
- const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
- x = <<ABType>>[4,16,32] parameter(0)
- y = <<ABType>>[4,32,16] parameter(1)
- z = <<DType>>[4,16,16] parameter(2)
- dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
- convert = <<DType>>[4,16,16] convert(dot_a)
- ROOT out = <<DType>>[4,16,16] add(convert, z)
-})";
- for (const auto& type_combination : type_combinations) {
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<ABType>>"] = std::get<0>(type_combination);
- replacements["<<DType>>"] = std::get<1>(type_combination);
- const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- continue;
- }
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- EXPECT_THAT(
- optimized_module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)));
- }
-}
-
-// Test GEMM matrix bias add fusion with mixed types and in-place update (C = D).
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeInPlace) {
- if (!IsCuda()) {
- GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
- }
- std::vector<std::tuple<absl::string_view, absl::string_view>>
- type_combinations = {
- {"f16", "f32"},
- {"bf16", "f32"},
- };
- const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
- x = <<ABType>>[16,32] parameter(0)
- y = <<ABType>>[32,16] parameter(1)
- z = <<DType>>[16,16] parameter(2)
- dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias = <<DType>>[16,16] negate(z)
- convert = <<DType>>[16,16] convert(dot_a)
- ROOT out = <<DType>>[16,16] add(convert, bias)
-})";
-
- for (const auto& type_combination : type_combinations) {
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<ABType>>"] = std::get<0>(type_combination);
- replacements["<<DType>>"] = std::get<1>(type_combination);
- const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- continue;
- }
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(m::Parameter(0), m::Parameter(1),
- m::Negate(m::Parameter(2))),
- 0)));
- }
-}
-
-// Test GEMM matrix bias add fusion with a type combination that is not supported.
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
- const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
- x = bf16[16,32] parameter(0)
- y = bf16[32,16] parameter(1)
- z = f64[16,16] parameter(2)
- dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- bias = f64[16,16] negate(z)
- convert = f64[16,16] convert(dot_a)
- ROOT out = f64[16,16] add(convert, bias)
-}
-
-)";
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
- if (IsCuda() &&
- !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
- GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
- }
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo_text));
- MatchOptimizedHlo(hlo_text, R"(
-; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$lt$matmul
-; CHECK: %[[tuple:.*]] = bf16[16,16]{1,0} get-tuple-element(%[[custom_call]]), index=0
-; CHECK: ROOT {{.*}} fusion({{.*}}%[[tuple]]
-)");
-}
-
-class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest {
- public:
- ParameterizedFp8GemmRewriteTest() {
- replacements_[kF8E4M3DatatypePlaceholder] =
-#if GOOGLE_CUDA
- "f8e4m3fn";
-#else
- "f8e4m3fnuz";
-#endif
- replacements_[kF8E5M2DatatypePlaceholder] =
-#if GOOGLE_CUDA
- "f8e5m2";
-#else
- "f8e5m2fnuz";
-#endif
- replacements_[kF8E4M3AmaxPlaceholder] =
-#if GOOGLE_CUDA
- "448.";
-#else
- "240.";
-#endif
- }
-
- protected:
-  // Check that the HLO runs and contains an FP8 cuBLAS LT custom call on
-  // supported architectures (Ada, Hopper, and later).
- void CheckFp8IfSupported(absl::string_view hlo_text,
- ErrorSpec error_spec = ErrorSpec{1e-2, 1e-2}) {
- if (!HasFp8Support()) {
- return;
- }
- std::string replaced_hlo_text =
- absl::StrReplaceAll(hlo_text, replacements_);
-    EXPECT_TRUE(RunAndCompare(replaced_hlo_text, error_spec));
-
- // Most FP8 tests directly create a GemmRewriter and check the output.
- // Here, also run the entire HLO pass pipeline to ensure no other passes
- // interfere with GemmRewriter's pattern matching.
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(replaced_hlo_text));
- const HloInstruction* call =
- FindInstruction(optimized_module.get(), HloOpcode::kCustomCall);
- ASSERT_NE(call, nullptr);
- EXPECT_EQ(call->custom_call_target(), "__cublas$lt$matmul$f8");
- }
-
- void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
- bool print_operand_shape = false) {
- GemmRewriteTest::MatchOptimizedHlo(
- absl::StrReplaceAll(hlo, replacements_),
- absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
- }
-
- void RunAndFilecheckHloRewrite(
- absl::string_view hlo, HloPassInterface&& hlo_pass,
- std::optional<absl::string_view> expected,
- std::function<void(HloModule*)> after_pass_checks = nullptr,
- const HloModuleConfig* config = nullptr) {
- if (expected.has_value()) {
- std::string replaced_pattern =
- absl::StrReplaceAll(expected.value(), replacements_);
- GemmRewriteTest::RunAndFilecheckHloRewrite(
- absl::StrReplaceAll(hlo, replacements_), std::move(hlo_pass),
- replaced_pattern, after_pass_checks, config);
- }
- }
-
- absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
- ParseAndReturnVerifiedModule(absl::string_view hlo_text,
- int64_t replica_count = 1,
- int64_t num_partitions = 1) {
- return GemmRewriteTest::ParseAndReturnVerifiedModule(
- absl::StrReplaceAll(hlo_text, replacements_));
- }
-
- private:
- static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
- static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
- static constexpr const char* kF8E4M3AmaxPlaceholder{"<<F8E4M3_AMAX>>"};
-};
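-
-// Illustration only: with the replacements installed in the constructor above,
-// an HLO snippet such as "x = <<F8E4M3>>[16,32] parameter(0)" expands to
-// "x = f8e4m3fn[16,32] parameter(0)" when building with CUDA and to
-// "x = f8e4m3fnuz[16,32] parameter(0)" when building for ROCm, e.g. via
-//   absl::StrReplaceAll("x = <<F8E4M3>>[16,32] parameter(0)", replacements_);
-// Likewise, <<F8E4M3_AMAX>> becomes "448." on CUDA and "240." on ROCm.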
-
-TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) {
- if (HasFp8Support()) {
- GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY PreAdaTest {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
- ErrorSpec{1e-2, 1e-2}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
-; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
-; CHECK-DAG: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteOnPreAdaWithF32Output) {
- if (HasFp8Support()) {
- GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
- }
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY PreAdaTest {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- ROOT out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
- ErrorSpec{1e-2, 1e-2}));
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
-; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
-; CHECK-DAG: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- // Test with types unsupported by cuBLAS LT when FP8 is used. cuBLAS LT with
- // FP8 requires one of the operands to be F8E4M3FN.
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY unsupported_types {
- x = <<F8E5M2>>[16,16] parameter(0)
- y = <<F8E5M2>>[16,16] parameter(1)
- ROOT out = <<F8E5M2>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-)";
- EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
- ErrorSpec{1e-2, 1e-2}));
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(Capability(), GetToolkitVersion(), /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <<F8E5M2>>[16,16], {{.*}}: <<F8E5M2>>[16,16]) -> <<F8E5M2>>[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(0)
-; CHECK-NEXT: [[P0_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P1]])
-; CHECK-NEXT: [[DOT:%[^ ]+]] = f16[16,16]{1,0} dot([[P0_CONVERT]], [[P1_CONVERT]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} convert([[DOT]])
- )");
-}
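-
-// As the pattern above shows, an <<F8E5M2>> x <<F8E5M2>> dot is not rewritten
-// into an FP8 custom call: both operands are converted to f16, a plain dot is
-// emitted, and the result is converted back to <<F8E5M2>>.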
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1)
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
- R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#else
- R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#endif
- R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-// Do not fuse FP8 matrix bias.
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
- GTEST_SKIP() << "F8 gemm rewrite for D to be fp8 with Matrix Bias is only "
- "supported in ROCm 6.2 and above.";
-#endif // TF_ROCM_VERSION < 60200
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- dot_a = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- b = <<F8E4M3>>[16,16] parameter(2)
- ROOT out = <<F8E4M3>>[16,16] add(dot_a, b)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: <<F8E4M3>>[16,16]) -> <<F8E4M3>>[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[DOT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
-; CHECK-NEXT: [[P2:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} parameter(2)
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} add([[DOT]], [[P2]])
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[13,17] parameter(0)
- y = <<F8E4M3>>[17,31] parameter(1)
- x_f32 = f32[13,17] convert(x)
- y_f32 = f32[17,31] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[13,17] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[17,31] broadcast(y_scale), dimensions={}
- x_unscaled = f32[13,17] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[17,31] multiply(y_f32, y_scale_bcast)
- ROOT out = f32[13,31] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[13,17], {{.*}}: <<F8E4M3>>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[13,17]{1,0} parameter(0)
-; CHECK-NEXT: [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[17,31]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,17]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C4:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]], [[C4]], /*index=5*/[[C4]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK-NEXT: [[DOT:%[^ ]+]] = f32[16,32]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[13,31]{1,0} slice([[DOT]]), slice={[0:13], [0:31]}
- )");
-}
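-
-// The CHECK pattern above is consistent with rounding each FP8 operand
-// dimension up to the next multiple of 16 and slicing the result back down:
-//   lhs [13,17]            -> pad to [16,32] (padding=0_3x0_15)
-//   rhs transpose [31,17]  -> pad to [32,32]
-//   result [16,32]         -> slice back to [13,31]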
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[2,8,16] parameter(0)
- y = <<F8E4M3>>[16,16] parameter(1)
- x_f32 = f32[2,8,16] convert(x)
- y_f32 = f32[16,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[2,8,16] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[2,8,16] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,16] multiply(y_f32, y_scale_bcast)
- x_bitcast = f32[16,16] bitcast(x_unscaled)
- ROOT out = f32[16,16] dot(x_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
- .WithShape(F32, {16, 16})));
-}
-
-// Test case where F8 inputs are converted to F32 before the dot, but without
-// any scaling.
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- ROOT out = f32[16,16] dot(x_f32, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[3] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[3] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[3] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[3] multiply(x_f32, x_scale_bcast)
- zero = f32[] constant(0)
- x_unscaled_padded = f32[30] pad(x_unscaled, zero), padding=0_27
- x_unscaled_padded_bcast = f32[30,8,5] broadcast(x_unscaled_padded), dimensions={0}
- x_unscaled_padded_bcast_sliced = f32[16,8,4] slice(x_unscaled_padded_bcast), slice={[2:18], [0:8], [0:4]}
- x_unscaled_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_unscaled_padded_bcast_sliced)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- ROOT out = f32[16,16] dot(x_unscaled_padded_bcast_sliced_reshaped, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
-; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
-; CHECK-NEXT: [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
-; CHECK-NEXT: [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
-; CHECK-NEXT: [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
-; CHECK-NEXT: [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C2]], /*index=5*/[[C2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- UnscaledABUnscaledDUnaryOpsWithConvertF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[3] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[3] convert(x)
- y_f32 = f32[32,16] convert(y)
- zero = f32[] constant(0)
- x_padded = f32[30] pad(x_f32, zero), padding=0_27
- x_padded_bcast = f32[30,8,5] broadcast(x_padded), dimensions={0}
- x_padded_bcast_sliced = f32[16,8,4] slice(x_padded_bcast), slice={[2:18], [0:8], [0:4]}
- x_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_padded_bcast_sliced)
- ROOT out = f32[16,16] dot(x_padded_bcast_sliced_reshaped, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
-; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
-; CHECK-NEXT: [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
-; CHECK-NEXT: [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
-; CHECK-NEXT: [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
-; CHECK-NEXT: [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]], [[C2]], /*index=5*/[[C2]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[32,32] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- zero = s32[] constant(0)
- x_f32 = f32[32,32] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[32,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[32,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- dyn_slice = f32[16,32]{1,0} dynamic-slice(x_unscaled, zero, zero), dynamic_slice_sizes={16,32}
- ROOT dot_a = f32[16,16] dot(dyn_slice, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- }
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[32,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} parameter(0)
-; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0)
-; CHECK-NEXT: [[DYN_SLICE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32}
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- k = pred[16,32] parameter(4)
- c = f32[] constant(0)
- c_bcast = f32[16,32] broadcast(c), dimensions={}
- select_a = f32[16,32] select(k, y_unscaled, c_bcast)
- ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- }
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P4:%[^ ]+]] = pred[16,32]{1,0} parameter(4)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[16,32]{1,0} broadcast([[C0]]), dimensions={}
-; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} convert([[C0_BCAST]])
-; CHECK-NEXT: [[SELECT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDSelectNonzeroConstantF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- k = pred[16,32] parameter(4)
- c = f32[] constant(1)
- c_bcast = f32[16,32] broadcast(c), dimensions={}
- select_a = f32[16,32] select(k, y_unscaled, c_bcast)
- ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- }
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_FALSE(changed);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[10,16,32] parameter(0)
- y = <<F8E4M3>>[10,32,16] parameter(1)
- x_f32 = f32[10,16,32] convert(x)
- y_f32 = f32[10,32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[10,16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[10,32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[10,16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[10,32,16] multiply(y_f32, y_scale_bcast)
- ROOT out = f32[10,16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[10,16,32], {{.*}}: <<F8E4M3>>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[10,32,16]{2,1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["2"]
-; CHECK-DAG: "rhs_contracting_dimensions":["2"]
-; CHECK-DAG: "lhs_batch_dimensions":["0"]
-; CHECK-DAG: "rhs_batch_dimensions":["0"]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- k = f32[] constant(3.0)
- k_bcast = f32[16,16] broadcast(k), dimensions={}
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[16,16] multiply(dot_a, k_bcast)
- }
-
-)";
-
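-  // Multiplying the dot by the constant 3.0 is expected to be folded into the
-  // custom call's "alpha_real":3 backend config rather than kept as a
-  // separate multiply.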
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":3
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- c = f32[] constant(0)
- c_bcast = f32[16,16] broadcast(c), dimensions={}
- ROOT out = f32[16,16] maximum(dot_a, c_bcast)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_bf16 = bf16[16,32] convert(x)
- y_bf16 = bf16[32,16] convert(y)
- x_scale = bf16[] parameter(2)
- y_scale = bf16[] parameter(3)
- bias = bf16[16] parameter(4)
- x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
- y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
- dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- b_bcast = bf16[16,16] broadcast(bias), dimensions={1}
- dot = bf16[16,16] add(dot1, b_bcast)
- mul.0 = bf16[16,16] multiply(dot, dot)
- mul.1 = bf16[16,16] multiply(dot, mul.0)
- const.0 = bf16[] constant(0.044715)
- bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
- mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
- add.0 = bf16[16,16] add(dot, mul.2)
- const.1 = bf16[] constant(0.797884583)
- bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
- mul.3 = bf16[16,16] multiply(add.0, bcast.1)
- tanh = bf16[16,16] tanh(mul.3)
- const.2 = bf16[] constant(1)
- bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
- add.2 = bf16[16,16] add(tanh, bcast.2)
- const.3 = bf16[] constant(0.5)
- bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
- mul.4 = bf16[16,16] multiply(add.2, bcast.3)
- ROOT out = bf16[16,16] multiply(dot, mul.4)
- }
-)";
-
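-  // The multiply/add/tanh chain above spells out the tanh-based GELU
-  // approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
-  // applied to the biased dot, which should be matched as a BIAS_GELU
-  // epilogue where supported.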
- CheckFp8IfSupported(hlo_text);
-
-// Fusing GELU into FP8 cuBLAS matmuls is disabled on CUDA versions below 12.4.
-#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2)
-; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
-; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
- R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#else
- R"(; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]),
-)"
-#endif
- R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
- R"(; CHECK-GCN-DAG: "epilogue":"DEFAULT"
-)"
-#else
- R"(; CHECK-DAG: "epilogue":"BIAS_GELU"
-)"
-#endif
- R"(; CHECK: }
- )");
-#endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDApproxGeluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_bf16 = bf16[16,32] convert(x)
- y_bf16 = bf16[32,16] convert(y)
- x_scale = bf16[] parameter(2)
- y_scale = bf16[] parameter(3)
- x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
- y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
- dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- mul.0 = bf16[16,16] multiply(dot, dot)
- mul.1 = bf16[16,16] multiply(dot, mul.0)
- const.0 = bf16[] constant(0.044715)
- bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
- mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
- add.0 = bf16[16,16] add(dot, mul.2)
- const.1 = bf16[] constant(0.797884583)
- bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
- mul.3 = bf16[16,16] multiply(add.0, bcast.1)
- tanh = bf16[16,16] tanh(mul.3)
- const.2 = bf16[] constant(1)
- bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
- add.2 = bf16[16,16] add(tanh, bcast.2)
- const.3 = bf16[] constant(0.5)
- bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
- mul.4 = bf16[16,16] multiply(add.2, bcast.3)
- ROOT out = bf16[16,16] multiply(dot, mul.4)
- }
-)";
-
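-  // Same tanh-based GELU approximation as above, but without the vector bias,
-  // so the expected epilogue is GELU instead of BIAS_GELU.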
- CheckFp8IfSupported(hlo_text);
-
-// Fusing GELU into FP8 cuBLAS matmuls is disabled on CUDA versions below 12.4.
-#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
-  // Currently, hipBlasLt does not support bf16 as the output type of fp8
-  // matmuls, so no fusion is performed in that case.
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2)
-; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
-; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
- R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#else
- R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#endif
- R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
- R"(; CHECK-GCN-DAG: "epilogue":"DEFAULT"
-)"
-#else
- R"(; CHECK-DAG: "epilogue":"GELU"
-)"
-#endif
- R"(; CHECK: }
- )");
-#endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] divide(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] divide(y_f32, y_scale_bcast)
- ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- b = f32[16,16] parameter(2)
- one = f32[] constant(1)
- ones = f32[16,16] broadcast(one), dimensions={}
- b_ones = f32[16,16] add(b, ones)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(3)
- y_scale = f32[] parameter(4)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = add(dot_a, b_ones)
- }
-
-)";
-
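-  // Adding the f32 matrix bias is expected to be fused as the C operand of
-  // the custom call with "beta":1, aliasing the output buffer with that
-  // operand.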
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK: [[C0:%[^ ]+]] = f32[16,16]{1,0} add({{.*}})
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: output_to_operand_aliasing={
-; CHECK-SAME: {0}: (2, {})
-; CHECK-SAME: }
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[14,31] parameter(0)
- y = <<F8E4M3>>[31,14] parameter(1)
- b = f32[14,14] parameter(2)
- x_f32 = f32[14,31] convert(x)
- y_f32 = f32[31,14] convert(y)
- x_scale = f32[] parameter(3)
- y_scale = f32[] parameter(4)
- x_scale_bcast = f32[14,31] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[31,14] broadcast(y_scale), dimensions={}
- x_unscaled = f32[14,31] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[31,14] multiply(y_f32, y_scale_bcast)
- dot_a = f32[14,14] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = add(dot_a, b)
- }
-
-)";
-
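-  // The expected rewrite pads the [14,31] and [31,14] operands and the
-  // [14,14] bias up to [16,32] and [16,16] before the FP8 custom call, then
-  // slices the [16,16] result back down to [14,14].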
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[14,31], {{.*}}: <<F8E4M3>>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} parameter(0)
-; CHECK-NEXT: [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[31,14]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[14,14]{1,0} parameter(2)
-; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]], /*index=5*/[[C3]], [[C3]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[DOT:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[14,14]{1,0} slice([[DOT]]), slice={[0:14], [0:14]}
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- z_scale = f32[] parameter(2)
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
-    c1 = f32[] constant(-<<F8E4M3_AMAX>>)
-    c1_bcast = f32[16,16] broadcast(c1), dimensions={}
-    c2 = f32[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f32[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- }
-
-)";
-
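-  // Dividing the dot output by z_scale is expected to be rewritten into
-  // passing 1 / z_scale as a scale operand of the custom call instead of a
-  // separate divide; on ROCm the F8 output conversion is not fused and the
-  // custom call keeps an f32 output.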
- CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- z_scale = f32[] parameter(2)
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- z_scale = f32[] parameter(2)
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-// Do not fuse the output scaling when a matrix bias has been fused and the
-// output is not converted to a narrower type.
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- b = f32[16,16] parameter(2)
- z_scale = f32[] parameter(3)
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bias = f32[16,16] add(dot_a, b)
- ROOT dot_a_scaled = f32[16,16] divide(dot_a_bias, z_scale_bcast)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2)
-; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK-PTX-NEXT: [[GEMM:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-PTX-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-PTX-NEXT: [[P3_BCAST:%[^ ]+]] = f32[16,16]{1,0} broadcast([[P3]]), dimensions={}
-; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} divide([[GEMM]], [[P3_BCAST]])
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- z_scale = f32[] parameter(4)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
- c1 = f32[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f32[16,16] broadcast(c1), dimensions={}
- c2 = f32[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f32[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- z_scale = f32[] parameter(4)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
- c1 = f32[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f32[16,16] broadcast(c1), dimensions={}
- c2 = f32[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f32[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-NOT: divide
-
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- z_scale = f32[] parameter(4)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- c = f32[] constant(0)
- c_bcast = f32[16,16] broadcast(c), dimensions={}
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- relu_a = f32[16,16] maximum(dot_a, c_bcast)
- relu_a_scaled = f32[16,16] divide(relu_a, z_scale_bcast)
- c1 = f32[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f32[16,16] broadcast(c1), dimensions={}
- c2 = f32[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f32[16,16] broadcast(c2), dimensions={}
- relu_a_clamped = f32[16,16] clamp(c1_bcast, relu_a_scaled, c2_bcast)
- ROOT out = <<F8E4M3>>[16,16] convert(relu_a_clamped)
- }
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f16[] parameter(0)
- b = f16[] parameter(1)
- ROOT c = f16[] maximum(a, b)
- }
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f16 = f16[16,32] convert(x)
- y_f16 = f16[32,16] convert(y)
- b = f16[16,16] parameter(2)
- one = f16[] constant(1)
- ones = f16[16,16] broadcast(one), dimensions={}
- b_ones = f16[16,16] add(b, ones)
- x_scale = f16[] parameter(3)
- y_scale = f16[] parameter(4)
- z_scale = f16[] parameter(5)
- x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
- y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
- dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bias = f16[16,16] add(dot_a, b_ones)
- abs_dot_a = f16[16,16] abs(dot_a_bias)
- c0 = f16[] constant(-inf)
- amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
- dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
- c1 = f16[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f16[16,16] broadcast(c1), dimensions={}
- c2 = f16[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f16[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- ROOT result = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
- }
-)";
-
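-  // The reduce computing the maximum of |dot + bias| is expected to be fused
-  // as a D amax output, so on CUDA the custom call returns an extra f32
-  // scalar in its result tuple.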
- CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK: [[C0:%[^ ]+]] = f16[16,16]{1,0} add({{.*}})
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK: [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5)
-; CHECK-PTX: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
-; CHECK-NOT: output_to_operand_aliasing
-; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f16 = f16[16,32] convert(x)
- y_f16 = f16[32,16] convert(y)
- b = f16[16] parameter(2)
- b_bcast = f16[16,16] broadcast(b), dimensions={1}
- x_scale = f16[] parameter(3)
- y_scale = f16[] parameter(4)
- z_scale = f16[] parameter(5)
- x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
- y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
- dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bias = f16[16,16] add(dot_a, b_bcast)
- dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
- c1 = f16[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f16[16,16] broadcast(c1), dimensions={}
- c2 = f16[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f16[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <<F8E4M3>>[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1)
-; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5)
-; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]])
-; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK-PTX: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]),
-; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- b = f32[16] parameter(2)
- b_bf16 = bf16[16] convert(b)
- b_f32 = f32[16] convert(b_bf16)
- b_bcast = f32[16,16] broadcast(b_f32), dimensions={1}
- x_scale = f32[] parameter(3)
- y_scale = f32[] parameter(4)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[16,16] add(dot_a, b_bcast)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[VB:%[^ ]+]] = f32[16]{0} parameter(2)
-; CHECK-NEXT: [[VBC:%[^ ]+]] = bf16[16]{0} convert([[VB]])
-; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]], [[VBC]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDVectorBiasThenReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- b = f16[16] parameter(2)
- b_bcast = f16[16,16] broadcast(b), dimensions={1}
-    x_f16 = f16[16,32] convert(x)
-    y_f16 = f16[32,16] convert(y)
-    x_scale = f16[] parameter(3)
-    y_scale = f16[] parameter(4)
-    x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
-    y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
-    x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
-    y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
- c = f16[] constant(0)
- c_bcast = f16[16,16] broadcast(c), dimensions={}
- dot_a0 = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a = f16[16,16] add(dot_a0, b_bcast)
- ROOT out = f16[16,16] maximum(dot_a, c_bcast)
- }
-)";
-
- CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS_RELU"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[4,16,16] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- b = f32[32] parameter(2)
- b_f16 = f16[32] convert(b)
- b_bcast = f16[4,16,32] broadcast(b_f16), dimensions={2}
- x_f16 = f16[4,16,16] convert(x)
- y_f16 = f16[16,32] convert(y)
- x_scale = f16[] parameter(3)
- y_scale = f16[] parameter(4)
- x_scale_bcast = f16[4,16,16] broadcast(x_scale), dimensions={}
- y_scale_bcast = f16[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f16[4,16,16] multiply(x_f16, x_scale_bcast)
- x_unscaled_bitcast = f16[64,16] bitcast(x_unscaled)
- y_unscaled = f16[16,32] multiply(y_f16, y_scale_bcast)
- dot_a = f16[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bitcast = f16[4,16,32]{2,1,0} bitcast(dot_a)
- ROOT out = f16[4,16,32] add(dot_a_bitcast, b_bcast)
- }
-)";
-
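-  // The rank-3 operand is expected to be bitcast to 2-D ([64,16]) before the
-  // FP8 custom call and the [64,32] result bitcast back to [4,16,32], as the
-  // matcher below verifies.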
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Bitcast(m::GetTupleElement(
- m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
- .WithShape(F16, {64, 32}))
- .WithShape(F16, {4, 16, 32})));
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[B:%[^ ]+]] = f32[32]{0} parameter(2)
-; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]])
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
-; CHECK: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK: ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]])
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- Rank3ScaledABUnscaledDVectorBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[4,15,15] parameter(0)
- y = <<F8E4M3>>[15,31] parameter(1)
- b = f32[31] parameter(2)
- b_f16 = f16[31] convert(b)
- b_bcast = f16[4,15,31] broadcast(b_f16), dimensions={2}
- x_f16 = f16[4,15,15] convert(x)
- y_f16 = f16[15,31] convert(y)
- x_scale = f16[] parameter(3)
- y_scale = f16[] parameter(4)
- x_scale_bcast = f16[4,15,15] broadcast(x_scale), dimensions={}
- y_scale_bcast = f16[15,31] broadcast(y_scale), dimensions={}
- x_unscaled = f16[4,15,15] multiply(x_f16, x_scale_bcast)
- x_unscaled_bitcast = f16[60,15] bitcast(x_unscaled)
- y_unscaled = f16[15,31] multiply(y_f16, y_scale_bcast)
- dot_a = f16[60,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bitcast = f16[4,15,31]{2,1,0} bitcast(dot_a)
- ROOT out = f16[4,15,31] add(dot_a_bitcast, b_bcast)
- }
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Bitcast(m::Slice(m::GetTupleElement(
- m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
- .WithShape(F16, {64, 32}))
- .WithShape(F16, {60, 31}))
- .WithShape(F16, {4, 15, 31})));
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[4,15,15]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[60,15]{1,0} bitcast([[P0]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,15]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[C2:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[B:%[^ ]+]] = f32[31]{0} parameter(2)
-; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]])
-; CHECK-NEXT: [[C3:%[^ ]+]] = f16[] constant(0)
-; CHECK-NEXT: [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"BIAS"
-; CHECK: }
-; CHECK: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT: [[SLICE:%[^ ]+]] = f16[60,31]{1,0} slice([[GEMM]]), slice={[0:60], [0:31]}
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[4,15,31]{2,1,0} bitcast([[SLICE]])
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[4,16,16] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- b = f32[4,16,32] parameter(2)
- x_f32 = f32[4,16,16] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(3)
- y_scale = f32[] parameter(4)
- x_scale_bcast = f32[4,16,16] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[4,16,16] multiply(x_f32, x_scale_bcast)
- x_unscaled_bitcast = f32[64,16] bitcast(x_unscaled)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- dot_a = f32[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bitcast = f32[4,16,32]{2,1,0} bitcast(dot_a)
- ROOT out = f32[4,16,32] add(dot_a_bitcast, b)
- }
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(m::Bitcast(m::GetTupleElement(
- m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
- .WithShape(F32, {64, 32}))
- .WithShape(F32, {4, 16, 32})));
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2)
-; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[GEMM:%[^ ]+]] = f32[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK: ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]])
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- Rank3ScaledABUnscaledDMatrixBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[3,15,15] parameter(0)
- y = <<F8E4M3>>[15,31] parameter(1)
- b = f32[3,15,31] parameter(2)
- x_f32 = f32[3,15,15] convert(x)
- y_f32 = f32[15,31] convert(y)
- x_scale = f32[] parameter(3)
- y_scale = f32[] parameter(4)
- x_scale_bcast = f32[3,15,15] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[15,31] broadcast(y_scale), dimensions={}
- x_unscaled = f32[3,15,15] multiply(x_f32, x_scale_bcast)
- x_unscaled_bitcast = f32[45,15] bitcast(x_unscaled)
- y_unscaled = f32[15,31] multiply(y_f32, y_scale_bcast)
- dot_a = f32[45,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bitcast = f32[3,15,31]{2,1,0} bitcast(dot_a)
- ROOT out = f32[3,15,31] add(dot_a_bitcast, b)
- }
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- GmockMatch(
- m::Bitcast(m::Slice(m::GetTupleElement(
- m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
- .WithShape(F32, {48, 32}))
- .WithShape(F32, {45, 31}))
- .WithShape(F32, {3, 15, 31})));
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[3,15,15]{2,1,0} parameter(0)
-; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[45,15]{1,0} bitcast([[P0]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
-  // Make sure the dot is lowered to a custom call. An algebraic simplifier
-  // rewrite late in the pipeline could otherwise turn the dot into a
-  // non-canonical form that the GemmRewriter does not support.
-; CHECK-NEXT: [[B:%[^ ]+]] = f32[3,15,31]{2,1,0} parameter(2)
-; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[45,31]{1,0} bitcast([[B]])
-; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[45,31]{1,0} slice([[GEMM]]), slice={[0:45], [0:31]}
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[3,15,31]{2,1,0} bitcast([[SLICE]])
- )");
-}
-
-// Do not fuse the matrix bias when there is a slice that does not chop off the
-// ends of dimensions.
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDMatrixBiasWithSliceF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>>[48,16] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- b = f32[32,16] parameter(2)
- x_f32 = f32[48,16] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(3)
- y_scale = f32[] parameter(4)
- x_scale_bcast = f32[48,16] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[48,16] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- dot_a = f32[48,32] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_sliced = f32[32,16] slice(dot_a), slice={[16:48], [16:32]}
- ROOT out = f32[32,16] add(dot_a_sliced, b)
- }
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_TRUE(changed);
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[48,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[GEMM:%[^_]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[32,16]{1,0} slice([[GEMM]]), slice={[16:48], [16:32]}
-; CHECK-NEXT: [[B:%[^ ]+]] = f32[32,16]{1,0} parameter(2)
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[32,16]{1,0} add([[SLICE]], [[B]])
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- absl::string_view hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- all_gather = f32[16,64]{1,0} all-gather(x_unscaled), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
- all_gather1 = f32[64,32]{1,0} all-gather(y_unscaled), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
- ROOT dot_a = f32[16,32] dot(all_gather, all_gather1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-)";
-
- HloModuleConfig config = GetModuleConfigForTest();
- config.set_use_spmd_partitioning(true);
- config.set_num_partitions(8);
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK: [[AG:%[^ ]+]] = <<F8E4M3>>[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}}
-; CHECK: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK: [[AG1:%[^ ]+]] = <<F8E4M3>>[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}}
-; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0}
-; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: ROOT [[GEMM:%[^_]+]] = f32[16,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
- )",
- nullptr, &config);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- absl::string_view hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- all_to_all = f32[16,32]{1,0} all-to-all(x_unscaled), channel_id=1, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={0}
- ROOT dot_a = f32[16,16] dot(all_to_all, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- }
-)";
-
- HloModuleConfig config = GetModuleConfigForTest();
- config.set_use_spmd_partitioning(true);
- config.set_num_partitions(8);
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK: [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}}
-; CHECK: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )",
- nullptr, &config);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDWithCollectivePermuteF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- absl::string_view hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[16,32] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[16,32] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
- collective_permute = f32[16,32]{1,0} collective-permute(x_unscaled), source_target_pairs={{0,0}, {1,1}, {2,4}, {3,5}, {4,2}, {5,3}, {6,6}, {7,7}}
- ROOT dot_a = f32[16,16] dot(collective_permute, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
- }
-)";
-
- HloModuleConfig config = GetModuleConfigForTest();
- config.set_use_spmd_partitioning(true);
- config.set_num_partitions(8);
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK: [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}}
-; CHECK: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )",
- nullptr, &config);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDMatrixBiasThenVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f16 = f16[16,32] convert(x)
- y_f16 = f16[32,16] convert(y)
- b = f16[16] parameter(2)
- b_bcast = f16[16,16] broadcast(b), dimensions={1}
- b2 = f16[16,16] parameter(3)
- x_scale = f16[] parameter(4)
- y_scale = f16[] parameter(5)
- x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
- y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
- dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- dot_a_bias1 = f16[16,16] add(dot_a, b2)
- ROOT dot_a_bias = f16[16,16] add(dot_a_bias1, b_bcast)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
-; CHECK-DAG: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT: [[MB:%[^ ]+]] = f16[16,16]{1,0} parameter(3)
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT: [[CV0:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(5)
-; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]], /*index=5*/[[C1]], [[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":1
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
-; CHECK: [[GEMMOUT:%[^ ]+]] = f16[16,16]{1,0} get-tuple-element([[GEMMOUT_TUPLE]]), index=0
-; CHECK: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK: [[VBC:%[^ ]+]] = f16[16,16]{1,0} broadcast([[VB]]), dimensions={1}
-; CHECK: ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} add([[GEMMOUT]], [[VBC]])
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] maximum(a, b)
- }
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- z_scale = f32[] parameter(4)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- abs_dot_a = f32[16,16] abs(dot_a)
- c0 = f32[] constant(-inf)
- amax = f32[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
- dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
- c1 = f32[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f32[16,16] broadcast(c1), dimensions={}
- c2 = f32[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f32[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABScaledDWithDAmaxF8WithF16Intermediates) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate
- // values instead of F32 intermediate values.
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f16[] parameter(0)
- b = f16[] parameter(1)
- ROOT c = f16[] maximum(a, b)
- }
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f16 = f16[16,32] convert(x)
- y_f16 = f16[32,16] convert(y)
- x_scale = f16[] parameter(2)
- y_scale = f16[] parameter(3)
- z_scale = f16[] parameter(4)
- x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
- y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
- dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- abs_dot_a = f16[16,16] abs(dot_a)
- c0 = f16[] constant(-inf)
- amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
- dot_a_scaled = f16[16,16] divide(dot_a, z_scale_bcast)
- c1 = f16[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f16[16,16] broadcast(c1), dimensions={}
- c2 = f16[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f16[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- ROOT out = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(2)
-; CHECK-NEXT: [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT: [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1)
-; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4)
-; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABScaledDReluActivationWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- apply {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- ROOT c = f32[] maximum(a, b)
- }
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E4M3>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- z_scale = f32[] parameter(4)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- czero = f32[] constant(0)
- czero_bcast = f32[16,16] broadcast(czero), dimensions={}
- dot_a_relu = f32[16,16] maximum(dot_a, czero_bcast)
- c0 = f32[] constant(-inf)
- amax = f32[] reduce(dot_a_relu, c0), dimensions={0,1}, to_apply=apply
- dot_a_scaled = f32[16,16] divide(dot_a_relu, z_scale_bcast)
- c1 = f32[] constant(-<<F8E4M3_AMAX>>)
- c1_bcast = f32[16,16] broadcast(c1), dimensions={}
- c2 = f32[] constant(<<F8E4M3_AMAX>>)
- c2_bcast = f32[16,16] broadcast(c2), dimensions={}
- dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
- dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
- ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
-; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"RELU"
-; CHECK: }
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* raw_hlo_template = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[1600,3200] parameter(0)
- y = <<F8E4M3>>[3200,1600] parameter(1)
- x_f32 = f32[1600,3200] convert(x)
- y_f32 = f32[3200,1600] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
- x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
- ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
- }
-)";
-
- std::string hlo_template =
- absl::StrReplaceAll(raw_hlo_template, replacements_);
-
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<precision>>"] = "default";
- const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
- EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
-
- replacements["<<precision>>"] = "highest";
- const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
- EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- std::array<std::array<absl::string_view, 7>, 32> combinations;
- int i = 0;
-
- for (bool d_is_col : {false, true}) {
- for (bool a_is_col : {false, true}) {
- for (bool b_is_col : {false, true}) {
- for (int lhs_contracting_dim : {0, 1}) {
- for (int rhs_contracting_dim : {0, 1}) {
- const absl::string_view lcd =
- lhs_contracting_dim == 1 ? "{1}" : "{0}";
- const absl::string_view rcd =
- rhs_contracting_dim == 1 ? "{1}" : "{0}";
- const absl::string_view a_shape =
- lhs_contracting_dim == 1 ? "[64,32]" : "[32,64]";
- const absl::string_view b_shape =
- rhs_contracting_dim == 0 ? "[32,16]" : "[16,32]";
- const absl::string_view a_layout = a_is_col ? "{0,1}" : "{1,0}";
- const absl::string_view b_layout = b_is_col ? "{0,1}" : "{1,0}";
- const absl::string_view output_layout =
- d_is_col ? "{0,1}" : "{1,0}";
- combinations[i++] = std::array{
- lcd, rcd, a_shape, b_shape, a_layout, b_layout, output_layout};
- }
- }
- }
- }
- }
-
- const char* hlo_template = R"(
- HloModule test
- ENTRY test {
- x = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
- x_f32 = f32<<Ashape>><<Alayout>> convert(x)
- x_scale = f32[] parameter(2)
- x_scale_bcast = f32<<Ashape>> broadcast(x_scale), dimensions={}
- x_unscaled = f32<<Ashape>> multiply(x_f32, x_scale_bcast)
- y = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
- y_f32 = f32<<Bshape>><<Blayout>> convert(y)
- y_scale = f32[] parameter(3)
- y_scale_bcast = f32<<Bshape>> broadcast(y_scale), dimensions={}
- y_unscaled = f32<<Bshape>> multiply(y_f32, y_scale_bcast)
- ROOT out = f32[64,16]<<Olayout>> dot(x_unscaled, y_unscaled), lhs_contracting_dims=<<Lcd>>, rhs_contracting_dims=<<Rcd>>
- }
- )";
- for (const auto& combination : combinations) {
- absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
- replacements["<<Lcd>>"] = std::get<0>(combination);
- replacements["<<Rcd>>"] = std::get<1>(combination);
- replacements["<<Ashape>>"] = std::get<2>(combination);
- replacements["<<Bshape>>"] = std::get<3>(combination);
- replacements["<<Alayout>>"] = std::get<4>(combination);
- replacements["<<Blayout>>"] = std::get<5>(combination);
- replacements["<<Olayout>>"] = std::get<6>(combination);
- const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
- CheckFp8IfSupported(hlo_text);
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
- ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
- )");
- }
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
- ScaledABUnscaledDF8ParameterizedBatched) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
-  // TODO(wenscarl): For batched matmuls, not all combinations of A, B and
-  // output layouts get pattern-matched to the FP8 custom call. Only a handful
-  // of cases are tested here.
- std::array<std::array<std::string, 7>, 32> combinations;
- std::string lcd, rcd, a_shape, b_shape, a_layout, b_layout, o_layout;
- int i = 0;
- for (bool o_is_col : {false, true}) {
- for (int lhs_contracting_dim : {2, 1}) {
- for (int rhs_contracting_dim : {2, 1}) {
- lcd = lhs_contracting_dim == 2 ? "{2}" : "{1}";
- rcd = rhs_contracting_dim == 2 ? "{2}" : "{1}";
- a_shape = lhs_contracting_dim == 2 ? "[2,64,32]" : "[2,32,64]";
- b_shape = rhs_contracting_dim == 1 ? "[2,32,16]" : "[2,16,32]";
- o_layout = o_is_col ? "{2, 0, 1}" : "{2, 1, 0}";
- for (std::string a_layout : {"{2,1,0}", "{1,2,0}"}) {
- for (std::string b_layout : {"{2,1,0}", "{1,2,0}"}) {
- combinations[i++] = std::array{lcd, rcd, a_shape, b_shape,
- a_layout, b_layout, o_layout};
- }
- }
- }
- }
- }
-
- const char* hlo_template = R"(
- HloModule m
-ENTRY f {
- x_q = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
- x_scale = f32[] parameter(2)
- x_scale_broadcast = f32<<Ashape>><<Alayout>> broadcast(x_scale), dimensions={}
- x_q_convert = f32<<Ashape>><<Alayout>> convert(x_q)
- x_qdq = f32<<Ashape>><<Alayout>> multiply(x_q_convert, x_scale_broadcast)
-
- y_q = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
- y_scale = f32[] parameter(3)
- y_scale_broadcast = f32<<Bshape>><<Blayout>> broadcast(y_scale), dimensions={}
- y_q_convert = f32<<Bshape>><<Blayout>> convert(y_q)
- y_qdq = f32<<Bshape>><<Blayout>> multiply(y_q_convert, y_scale_broadcast)
-
- ROOT out = f32[2,64,16]<<Olayout>> dot(x_qdq, y_qdq), lhs_batch_dims={0}, lhs_contracting_dims=<<Lcd>>, rhs_batch_dims={0}, rhs_contracting_dims=<<Rcd>>
-}
- )";
-
- for (const auto& combination : combinations) {
- absl::flat_hash_map<std::string, std::string> replacements;
- replacements["<<Lcd>>"] = std::get<0>(combination);
- replacements["<<Rcd>>"] = std::get<1>(combination);
- replacements["<<Ashape>>"] = std::get<2>(combination);
- replacements["<<Bshape>>"] = std::get<3>(combination);
- replacements["<<Alayout>>"] = std::get<4>(combination);
- replacements["<<Blayout>>"] = std::get<5>(combination);
- replacements["<<Olayout>>"] = std::get<6>(combination);
-
- const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
- CheckFp8IfSupported(hlo_text);
-
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
- ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
- )");
- }
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = <<F8E4M3>>[16,32] parameter(0)
- y = <<F8E5M2>>[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-
-)";
-
- CheckFp8IfSupported(hlo_text);
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
- ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
- )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
- GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif // TF_ROCM_VERSION < 60000
-
-  // Test that FNUZ FP8 gemms are not rewritten on CUDA, as cuBLAS does not
-  // support the FNUZ types; on ROCm, the rewrite is expected to succeed.
- const char* hlo_text = R"(
- HloModule test
-
- ENTRY test {
- x = f8e4m3fnuz[16,32] parameter(0)
- y = f8e4m3fnuz[32,16] parameter(1)
- x_f32 = f32[16,32] convert(x)
- y_f32 = f32[32,16] convert(y)
- x_scale = f32[] parameter(2)
- y_scale = f32[] parameter(3)
- x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
- y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
- x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
- y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
- ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- }
-)";
-#if GOOGLE_CUDA
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_text));
- GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
- EXPECT_FALSE(changed);
-#endif
-#if TENSORFLOW_USE_ROCM
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2}));
- RunAndFilecheckHloRewrite(
- hlo_text,
- GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
- /*f8_rewrite=*/true),
- R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} parameter(0)
-; CHECK-PTX-NEXT: [[P0_CV:%[^ ]+]] = f32[16,32]{1,0} convert([[P0]])
-; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-PTX-NEXT: [[P2_B:%[^ ]+]] = f32[16,32]{1,0} broadcast([[P2]]), dimensions={}
-; CHECK-PTX-NEXT: [[P0_UNSCALED:%[^ ]+]] = f32[16,32]{1,0} multiply([[P0_CV]], [[P2_B]])
-; CHECK-PTX-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
-; CHECK-PTX-NEXT: [[P1_CV:%[^ ]+]] = f32[32,16]{1,0} convert([[P1]])
-; CHECK-PTX-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-PTX-NEXT: [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={}
-; CHECK-PTX-NEXT: [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]])
-; CHECK-PTX-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]),
-; CHECK-GCN-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
-; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8",
-; CHECK: backend_config={
-; CHECK-DAG: "alpha_real":1
-; CHECK-DAG: "alpha_imag":0
-; CHECK-DAG: "beta":0
-; CHECK-DAG: "dot_dimension_numbers":{
-; CHECK-DAG: "lhs_contracting_dimensions":["1"]
-; CHECK-PTX-DAG: "rhs_contracting_dimensions":["0"]
-; CHECK-GCN-DAG: "rhs_contracting_dimensions":["1"]
-; CHECK-DAG: "lhs_batch_dimensions":[]
-; CHECK-DAG: "rhs_batch_dimensions":[]
-; CHECK-DAG: }
-; CHECK-DAG: "precision_config":{
-; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG: }
-; CHECK-DAG: "epilogue":"DEFAULT"
-; CHECK: }
- )");
-#endif
-}
-
-INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt,
- ParameterizedFp8GemmRewriteTest, ::testing::Bool());
-#endif
-
-TEST_F(GemmRewriteTest, NoFuseBiasBroadcast) {
- const char* hlo = R"(
-
-HloModule module
-
-ENTRY main.10 {
- Arg_0.1 = f16[384,128]{1,0} parameter(0)
- Arg_1.2 = f16[128,256]{1,0} parameter(1)
- dot.4 = f16[384,256]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- Arg_2.3 = f16[256]{0} parameter(2)
- reshape.5 = f16[1,256]{1,0} reshape(Arg_2.3)
- broadcast.6 = f16[1,256]{1,0} broadcast(reshape.5), dimensions={0,1}
- reshape.7 = f16[256]{0} reshape(broadcast.6)
- broadcast.8 = f16[384,256]{1,0} broadcast(reshape.7), dimensions={1}
- ROOT add.9 = f16[384,256]{1,0} add(dot.4, broadcast.8)
-})";
-
- MatchOptimizedHlo(hlo, R"(
-// CHECK: "beta":0
- )");
-}
-
-TEST_F(GemmRewriteTest, ReduceOfBatchDot) {
- absl::string_view hlo_string =
- R"(
-HloModule test
-
-region_5.50 {
- Arg_0.51 = f32[] parameter(0)
- Arg_1.52 = f32[] parameter(1)
- ROOT add.53 = f32[] add(Arg_0.51, Arg_1.52)
-}
-
-ENTRY main {
- p0 = bf16[3,32,3,13]{3,2,1,0} parameter(0)
- p1 = bf16[3,32,3,64]{3,2,1,0} parameter(1)
- dot.95 = bf16[3,3,13,64]{3,2,1,0} dot(p0, p1), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}, operand_precision={highest,highest}
- transpose.96 = bf16[3,64,3,13]{1,3,2,0} transpose(dot.95), dimensions={0,3,1,2}
- convert.101 = f32[3,64,3,13]{1,3,2,0} convert(transpose.96)
- constant.66 = f32[] constant(0.0)
- ROOT reduce.102 = f32[3,64,13]{2,1,0} reduce(convert.101, constant.66), dimensions={2}, to_apply=region_5.50
-}
-)";
- // Make sure the dot is lowered to a custom call. There is an algebraic
- // simplifier simplification which could turn the dot into a non-canonical dot
- // late in the pipeline, which will make it unsupported by the GemmRewriter.
- MatchOptimizedHlo(hlo_string, R"(
- // CHECK: custom_call_target="__cublas$gemm"
- )");
-}
-
-class GemmRewriteAllocationTest : public GpuCodegenTest {
- public:
- void CheckNumberOfAllocations(const std::string& hlo,
- int expected_number_of_allocations) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
- GetOptimizedModule(hlo));
- if (allocator_ == nullptr) {
- allocator_ = std::make_unique<se::StreamExecutorMemoryAllocator>(
- backend().default_stream_executor());
- }
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Executable> executable,
- backend().compiler()->RunBackend(std::move(optimized_module),
- backend().default_stream_executor(),
- allocator_.get()));
- GpuExecutable* gpu_executable =
- static_cast<GpuExecutable*>(executable.get());
- absl::Span<const BufferAllocation> allocations =
- gpu_executable->GetAllocations();
- ASSERT_EQ(allocations.size(), expected_number_of_allocations);
- }
-
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // Make sure the rewriter does not skip the rewrite for a too-small GEMM.
- debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
- return debug_options;
- }
-
- private:
- std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
-};
-
-TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) {
- const char* hlo_text = R"(
-HloModule SharedBufferAssignment
-
-ENTRY AddDotsFunc {
- x = f32[2,2] parameter(0)
- y = f32[2,2] parameter(1)
- bias = f32[2,2] add(x, y)
- dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- ROOT out = f32[2,2] add(dot, bias)
-}
-
-)";
-
- // Bias should be fused into the multiplication.
- CheckNumberOfAllocations(hlo_text, 4);
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-}
-
-class SmallDotGemmRewriteTest : public GemmRewriteTest {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
- return debug_options;
- }
-};
-
-TEST_F(SmallDotGemmRewriteTest, SkipSmallMatrixMultiplicationRewrite) {
- const char* hlo_text = R"(
-HloModule SkipSmallMatrixRewrite
-
-ENTRY DotFunc {
- x = f32[3,3] parameter(0)
- y = f32[3,3] parameter(1)
- ROOT out = f32[3,3] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[3,3], {{.*}}: f32[3,3]) -> f32[3,3] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,3]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,3]{1,0} parameter(1)
-; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} dot([[P0]], [[P1]]),
-; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0}
-)");
-}
-
-TEST_F(SmallDotGemmRewriteTest, LargeMatrixMultiplicationIsRewritten) {
- const char* hlo_text = R"(
-HloModule SkipSmallMatrixRewrite
-
-ENTRY DotFunc {
- x = f32[8,8] parameter(0)
- y = f32[8,8] parameter(1)
- ROOT out = f32[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
- MatchOptimizedHlo(hlo_text,
- R"(
-; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[8,8], {{.*}}: f32[8,8]) -> f32[8,8] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8,8]{1,0} parameter(1)
-; CHECK: {{[^ ]+}} = {{.*}} custom-call([[P0]], [[P1]])
-)");
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc
deleted file mode 100644
index 5db5ffd..0000000
--- a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc
+++ /dev/null
@@ -1,232 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_all_gather_optimizer.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class GpuAllGatherOptimizerTest : public HloTestBase {
- public:
- absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
- absl::string_view hlo_module, int64_t num_replicas,
- int64_t num_partitions, bool expect_change) {
- HloModuleConfig config = GetModuleConfigForTest(
- /*replica_count=*/num_replicas,
- /*num_partitions=*/num_partitions);
- config.set_use_spmd_partitioning(num_partitions > 1);
- TF_ASSIGN_OR_RETURN(auto module,
- ParseAndReturnVerifiedModule(hlo_module, config));
-
- auto changed = AllGatherOptimizer().Run(module.get());
- if (!changed.ok()) {
- return changed.status();
- }
- EXPECT_EQ(changed.value(), expect_change);
- return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
- }
-
- template <HloOpcode oc>
- size_t CollectiveCount(std::unique_ptr<HloModule> &module) {
- return absl::c_count_if(module->entry_computation()->instructions(),
- HloPredicateIsOp<oc>);
- }
-};
-
-TEST_F(GpuAllGatherOptimizerTest, BranchesOptimized) {
- absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
- x = bf16[] parameter(0)
- y = bf16[] parameter(1)
- ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/true));
- // The graph should contain one all-gather, but since node removal is
- // deferred, the old ones still exist at this stage.
- EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
- EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 2);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, DisbledSPMDPartitioningJAXBug) {
- absl::string_view hlo_string = R"(
-HloModule pjit_f, entry_computation_layout={(f32[4,8]{1,0}, f32[4,8]{1,0})->f32[8,8]{1,0}}
-
-ENTRY %main.6_spmd (param: f32[4,8], param.1: f32[4,8]) -> f32[8,8] {
- %param = f32[4,8]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}
- %all-gather = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
- %param.1 = f32[4,8]{1,0} parameter(1), sharding={devices=[2,1]<=[2]}
- %all-gather.1 = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param.1), channel_id=2, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
- ROOT %add.0 = f32[8,8]{1,0} add(f32[8,8]{1,0} %all-gather, f32[8,8]{1,0} %all-gather.1), metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/1,
- /*num_partitions=*/2,
- /*expect_change=*/true));
- EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, MoreThanSingleUserForAllGather) {
- absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
- x = bf16[] parameter(0)
- y = bf16[] parameter(1)
- ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
-reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
-add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/false));
- // see the comment for BranchesOptimized test
- EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
- EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, AllGatherWithOpInBetweenOnRightBranch) {
- absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
- x = bf16[] parameter(0)
- y = bf16[] parameter(1)
- ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
-reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-add.1 = bf16[8,64,1024]{2,1,0} add(reduce-scatter.1, reduce-scatter.2)
-all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(add.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/true));
- EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
- EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, AllGatherOneSided) {
- absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
- x = bf16[] parameter(0)
- y = bf16[] parameter(1)
- ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
-
-add.1 = bf16[8,128,1024]{2,1,0} add(param.1, param.2)
-reduce-scatter = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.2 = bf16[8,128,1024]{2,1,0} add(all-gather, add.1)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/false));
- EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
- EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 1);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, DifferentOperandShapes) {
- absl::string_view hlo_string = R"(
-HloModule TestModule
-
-ENTRY main {
-param.1 = bf16[8,64,128]{2,1,0} parameter(0)
-param.2 = bf16[8,128,64]{2,1,0} parameter(1)
-all-gather.1 = bf16[8,128,128]{2,1,0} all-gather(param.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-all-gather.2 = bf16[8,128,128]{2,1,0} all-gather(param.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-add.1 = bf16[8,128,128]{2,1,0} add(all-gather.1, all-gather.2)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/false));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
index 2e5db53..b4124f2 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
@@ -24,7 +24,7 @@
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/primitive_util.h"
-#include "xla/service/gpu/gpu_sort_rewriter.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"
@@ -35,7 +35,7 @@
bool HloWasRewrittenToUseCubSort(const HloModule& module) {
for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) {
- if (pass_metadata.pass_name() == "gpu-sort-rewriter") {
+ if (pass_metadata.pass_name() == "sort-rewriter") {
return pass_metadata.module_changed();
}
}
@@ -50,13 +50,13 @@
public:
void SetUp() override {
HloTestBase::SetUp();
- GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000);
+ SortRewriter::SetSortSizeThresholdForTestingOnly(33000);
}
};
TEST_P(CubSortKeysTest, CompareToReference) {
int batch_size = std::get<2>(GetParam());
- int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+ int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
const char* kHloTpl = R"(
HloModule TestSortKeys
@@ -103,7 +103,7 @@
})";
int batch_size = std::get<2>(GetParam());
- int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+ int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
std::string hlo_str = absl::Substitute(
kHloTpl,
primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
@@ -138,13 +138,13 @@
public:
void SetUp() override {
HloTestBase::SetUp();
- GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000);
+ SortRewriter::SetSortSizeThresholdForTestingOnly(33000);
}
};
TEST_P(CubSortPairsTest, CompareToReference) {
int batch_size = std::get<3>(GetParam());
- int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+ int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
const char* kHloTpl = R"(
HloModule TestSortPairs
@@ -216,7 +216,7 @@
})";
int batch_size = std::get<3>(GetParam());
- int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+ int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
std::string hlo_str = absl::Substitute(
kHloTpl,
primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc
index 639cf51..aed017c 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc
@@ -134,7 +134,7 @@
EXPECT_TRUE(
LiteralTestUtil::Near(expected_result, actual_result, mha_error_spec_));
- // Run FusedMHA/FuseMHABackward thunk through command buffer
+ // Run through command buffer
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
debug_options.set_xla_gpu_graph_min_graph_size(1);
@@ -393,8 +393,8 @@
void TestImpl_Flash_Attention_BMM1_CausalMask_Softmax_BMM2() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 4)) {
- GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+ se::dnn::VersionInfo(9, 0, 0)) {
+ GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -414,8 +414,8 @@
void TestImpl_Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 4)) {
- GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+ se::dnn::VersionInfo(9, 0, 0)) {
+ GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -709,8 +709,8 @@
void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 4)) {
- GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+ se::dnn::VersionInfo(9, 0, 0)) {
+ GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -731,8 +731,8 @@
void TestImpl_Flash_Attention_Training_BMM1_Bias_Softmax_BMM2() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 4)) {
- GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+ se::dnn::VersionInfo(9, 0, 0)) {
+ GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -756,9 +756,9 @@
void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2_Cross_Attention() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 4)) {
+ se::dnn::VersionInfo(9, 0, 0)) {
GTEST_SKIP() << "Flash Attention cross attention requires "
- "cuDNN >= 8.9.4.";
+ "cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -780,10 +780,10 @@
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
auto cc = GetCudaComputeCapability();
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 6) ||
+ se::dnn::VersionInfo(9, 0, 0) ||
!cc.IsAtLeastHopper() || cc.minor != 0) {
GTEST_SKIP()
- << "Flash Attention dbias requires cuDNN >= 8.9.6 and Hopper arch.";
+ << "Flash Attention dbias requires cuDNN >= 9.0.0 and Hopper arch.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -900,8 +900,8 @@
void TestImpl_Flash_Attention_Training_BMM1_Softmax_BMM2() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 4)) {
- GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+ se::dnn::VersionInfo(9, 0, 0)) {
+ GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -925,10 +925,10 @@
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
auto cc = GetCudaComputeCapability();
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 4) ||
+ se::dnn::VersionInfo(9, 0, 0) ||
!cc.IsAtLeastHopper() || cc.minor != 0) {
GTEST_SKIP() << "Flash Attention deterministic kernels requires cuDNN >= "
- "8.9.4 and Hopper arch.";
+ "9.0.0 and Hopper arch.";
}
XlaBuilder builder(TestName());
auto lhs_bmm1_literal =
@@ -1085,8 +1085,8 @@
void TestImpl_Flash_Attention_Training_BMM1_PaddingMask_Softmax_BMM2() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
- se::dnn::VersionInfo(8, 9, 3)) {
- GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3.";
+ se::dnn::VersionInfo(9, 0, 0)) {
+ GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
// pass padding mask as bias
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc
index 3e573eb..f1577f0 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc
@@ -18,11 +18,11 @@
#include <utility>
#include "absl/strings/string_view.h"
-#include "xla/service/gpu/fusion_merger.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/instruction_fusion.h"
-#include "xla/service/gpu/multi_output_fusion.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/fusion_merger.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/shape.h"
@@ -51,8 +51,7 @@
device_info);
pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true, device_info);
pipeline.AddPass<FusionMerger>(device_info, ShapeSizeBytesFunction());
- pipeline.AddPass<GpuMultiOutputFusion>(device_info,
- ShapeSizeBytesFunction());
+ pipeline.AddPass<MultiOutputFusion>(device_info, ShapeSizeBytesFunction());
RunAndFilecheckHloRewrite(hlo, std::move(pipeline), expected);
}
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc
index 849cf1d..8328156 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc
@@ -23,8 +23,8 @@
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/instruction_fusion.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index e86f2c0..45e7af6 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -34,7 +34,11 @@
// Most tests in this file want to skip layout assignment, but a few need it
// enabled.
HloModuleConfig ConfigWithLayoutAssignment() {
- return GetModuleConfigForTest();
+ HloModuleConfig config;
+ auto debug_options = HloTestBase::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+ config.set_debug_options(debug_options);
+ return config;
}
HloModuleConfig ConfigWithoutLayoutAssignment() {
@@ -42,6 +46,7 @@
auto debug_options = HloTestBase::GetDebugOptionsForTest();
// Disable layout_assignment to use the preassigned layouts.
debug_options.add_xla_disable_hlo_passes("layout-assignment");
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
config.set_debug_options(debug_options);
return config;
}
@@ -635,6 +640,8 @@
}
)";
auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value();
+ auto &debug_options = hlo_module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
auto expected_ir = is_built_with_rocm_ ? R"(
; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] }
; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc
deleted file mode 100644
index b1d2734..0000000
--- a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc
+++ /dev/null
@@ -1,572 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuReduceScatterCreatorTest : public HloTestBase {
- public:
- absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
- absl::string_view hlo_module, int64_t num_replicas,
- int64_t num_partitions, bool expect_change) {
- HloModuleConfig config = GetModuleConfigForTest(
- /*replica_count=*/num_replicas,
- /*num_partitions=*/num_partitions);
- config.set_use_spmd_partitioning(num_partitions > 1);
- TF_ASSIGN_OR_RETURN(auto module,
- ParseAndReturnVerifiedModule(hlo_module, config));
- auto changed = ReduceScatterCreator().Run(module.get());
- if (!changed.ok()) {
- return changed.status();
- }
- EXPECT_EQ(changed.value(), expect_change);
- return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
- }
-
- size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
- return CollectiveCount(module, HloOpcode::kAllReduce);
- }
-
- size_t ReduceScatterCount(std::unique_ptr<HloModule> &module) {
- return CollectiveCount(module, HloOpcode::kReduceScatter);
- }
-
- private:
- size_t CollectiveCount(std::unique_ptr<HloModule> &module, HloOpcode opcode) {
- return absl::c_count_if(
- module->entry_computation()->instructions(),
- [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
- }
-};
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicas) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={}, to_apply=%sum
- %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
- %rid = u32[] replica-id()
- %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%id)
- %slice_size = s32[] constant(4)
- %offset = s32[] multiply(%reshape, %slice_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
- dynamic_slice_sizes={4,8,128}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/true));
- ASSERT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- const auto *rs = Cast<HloReduceScatterInstruction>(
- module->entry_computation()->root_instruction());
- EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithOffsetReshape) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={}, to_apply=%sum
- %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
- %rid = u32[] replica-id()
- %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
- %slice_size = s32[1] constant({4})
- %offset = s32[1] multiply(%id, %slice_size)
- %reshape = s32[] reshape(%offset)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %reshape, %zero, %zero),
- dynamic_slice_sizes={4,8,128}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/true));
- ASSERT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- const auto *rs = Cast<HloReduceScatterInstruction>(
- module->entry_computation()->root_instruction());
- EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={}, to_apply=%sum
- %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
- %rid = u32[] replica-id()
- %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%id)
- %slice_size = s32[] constant(4)
- %offset = s32[] multiply(%reshape, %slice_size)
- %zero = s32[] constant(0)
- %reshape.1 = f32[32,16,64] reshape(%all-reduce)
- ROOT %dynamic-slice = f32[4,16,64] dynamic-slice(%reshape.1, %offset, %zero, %zero),
- dynamic_slice_sizes={4,16,64}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshapeSplitDimModified) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[336,1024] parameter(0)
- %all-reduce = f32[336,1024] all-reduce(%param), replica_groups={}, to_apply=%sum
- %rid = u32[] replica-id()
- %id = s32[] convert(%rid)
- %slice_size = s32[] constant(128)
- %offset = s32[] multiply(%id, %slice_size)
- %zero = s32[] constant(0)
- %reshape.1 = f32[4,84,1024] reshape(%all-reduce)
- ROOT %dynamic-slice = f32[4,84,128] dynamic-slice(%reshape.1, %zero, %zero, %offset),
- dynamic_slice_sizes={4,84,128}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasDim2) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={}, to_apply=%sum
- %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
- %rid = u32[] replica-id()
- %rid_s32 = s32[] convert(%rid)
- %slice_size = s32[] constant(16)
- %offset = s32[] multiply(%rid_s32, %slice_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[32,8,16] dynamic-slice(%all-reduce, %zero, %zero, %offset),
- dynamic_slice_sizes={32,8,16}
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/true));
- ASSERT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- const auto *rs = Cast<HloReduceScatterInstruction>(
- module->entry_computation()->root_instruction());
- EXPECT_EQ(rs->scatter_dimension(), 2) << rs->ToString();
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWrongOffsets) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={}, to_apply=%sum
- %table = s32[8]{0} constant({0,1,2,3,4,5,6,8})
- %rid = u32[] replica-id()
- %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%id)
- %slice_size = s32[] constant(4)
- %offset = s32[] multiply(%reshape, %slice_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
- dynamic_slice_sizes={4,8,128}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/1,
- /*expect_change=*/false));
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasIotaTable) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={}, to_apply=%sum
- %table = s32[8]{0} iota(), iota_dimension=0
- %rid = u32[] replica-id()
- %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%id)
- %slice_size = s32[] constant(4)
- %offset = s32[] multiply(%reshape, %slice_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
- dynamic_slice_sizes={4,8,128}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/2,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupedReplicas) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum
- %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
- %rid = u32[] replica-id()
- %id = s32[1] dynamic-slice(%gtable, %rid), dynamic_slice_sizes={1}
- %reshape.0 = s32[] reshape(%id)
- %table = s32[4]{0} constant({0,8,16,24})
- %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
- %reshape.1 = s32[] reshape(%offset)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
- dynamic_slice_sizes={8,8,128}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/8,
- /*num_partitions=*/2,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllPartitions) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={{0},{1}}, to_apply=%sum, channel_id=1
- %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
- %pid = u32[] partition-id()
- %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%id)
- %slice_size = s32[] constant(4)
- %offset = s32[] multiply(%reshape, %slice_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
- dynamic_slice_sizes={4,8,128}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/2,
- /*num_partitions=*/8,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReduceFollowedByAllReduce) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce.scattered = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=1
- %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
- %pid = u32[] partition-id()
- %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%id)
- %slice_size = s32[] constant(4)
- %offset = s32[] multiply(%reshape, %slice_size)
- %zero = s32[] constant(0)
- %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce.scattered, %offset, %zero, %zero),
- dynamic_slice_sizes={4,8,128}
- ROOT %all-reduce.sync = f32[4,8,128]{2,1,0} all-reduce(%dynamic-slice),
- replica_groups={{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=2
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/2,
- /*num_partitions=*/8,
- /*expect_change=*/true));
- EXPECT_EQ(AllReduceCount(module), 1);
- EXPECT_EQ(ReduceScatterCount(module), 1);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobals) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
- %pid = u32[] partition-id()
- %rid = u32[] replica-id()
- %pcount = u32[] constant(4)
- %ridxp = u32[] multiply(%rid, %pcount)
- %gid = u32[] add(%ridxp, %pid)
- %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
- %id = s32[1] dynamic-slice(%gtable, %gid), dynamic_slice_sizes={1}
- %reshape.0 = s32[] reshape(%id)
- %table = s32[4]{0} constant({0,8,16,24})
- %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
- %reshape.1 = s32[] reshape(%offset)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
- dynamic_slice_sizes={8,8,128}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/2,
- /*num_partitions=*/4,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsOrthogonalReplicas) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={{1,3,2,0},{5,7,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
- %pid = u32[] partition-id()
- %pid_table = s32[4]{0} constant({3,0,2,1})
- %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%offset)
- %shard_size = s32[] constant(8)
- %mul = s32[] multiply(%reshape, %shard_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
- dynamic_slice_sizes={8,8,128}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/2,
- /*num_partitions=*/4,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Parameter(0))));
- EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsNonOrthogonalReplicas) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[32,8,128]{2,1,0} parameter(0)
- %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
- replica_groups={{1,3,2,0},{7,5,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
- %pid = u32[] partition-id()
- %pid_table = s32[4]{0} constant({3,0,2,1})
- %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%offset)
- %shard_size = s32[] constant(8)
- %mul = s32[] multiply(%reshape, %shard_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
- dynamic_slice_sizes={8,8,128}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/2,
- /*num_partitions=*/4,
- /*expect_change=*/false));
-}
-
-TEST_F(GpuReduceScatterCreatorTest, NonUniformSplit) {
- absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
- %a = f32[] parameter(0)
- %b = f32[] parameter(1)
- ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
- %param = f32[1,7]{1,0} parameter(0)
- %all-reduce = f32[1,7]{1,0} all-reduce(%param),
- replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
- %pid = u32[] partition-id()
- %pid_table = s32[8]{0} constant({0, 1, 0, 1, 0, 1, 0, 1})
- %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
- %reshape = s32[] reshape(%offset)
- %shard_size = s32[] constant(3)
- %mul = s32[] multiply(%reshape, %shard_size)
- %zero = s32[] constant(0)
- ROOT %dynamic-slice = f32[1,3] dynamic-slice(%all-reduce, %zero, %mul),
- dynamic_slice_sizes={1,3}
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
- /*num_replicas=*/1,
- /*num_partitions=*/8,
- /*expect_change=*/true));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::ReduceScatter(m::Slice(m::Parameter(0)))));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
deleted file mode 100644
index bb6eb63..0000000
--- a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class ReductionDegenerateDimRemoverTest : public HloTestBase {
- public:
- void CheckDegenerateDimRemover(absl::string_view hlo,
- std::optional<absl::string_view> expected) {
- RunAndFilecheckHloRewrite(hlo, gpu::ReductionDegenerateDimRemover{},
- expected);
- }
-};
-
-TEST_F(ReductionDegenerateDimRemoverTest, ReductionWithDegenerateDimensions) {
- const char* hlo = R"(
-HloModule ReduceWithDegenerateDimensions
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1,3,1,4,1,5,1] parameter(0)
- zero = f32[] constant(0)
-
- ROOT out = f32[1,1,1,1] reduce(input, zero), dimensions={1,3,5}, to_apply=add
-}
-
-)";
-
- CheckDegenerateDimRemover(hlo, R"(
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[reduce_2:%[^ ]+]] = f32[] reduce([[bitcast_0]], [[zero_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[reduce_2]])
- )");
-}
-
-TEST_F(ReductionDegenerateDimRemoverTest,
- ReductionWithDegenerateDimensionsVariadic) {
- const char* hlo = R"(
-HloModule ReduceWithDegenerateDimensions
-
-argmax {
- running_max = f32[] parameter(0)
- running_max_idx = u32[] parameter(1)
- current_value = f32[] parameter(2)
- current_value_idx = u32[] parameter(3)
-
- current = (f32[], u32[]) tuple(running_max, running_max_idx)
- potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
- cmp_code = pred[] compare(current_value, running_max), direction=GT
-
- new_max = f32[] select(cmp_code, current_value, running_max)
- new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
- ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
- input = f32[1,3,1,4,1,5,1] parameter(0)
- idxs = u32[1,3,1,4,1,5,1] parameter(1)
- zero = f32[] constant(0)
- zero_idx = u32[] constant(0)
-
- ROOT out = (f32[1,1,1,1], u32[1,1,1,1]) reduce(input, idxs, zero, zero_idx), dimensions={1,3,5}, to_apply=argmax
-}
-
-)";
-
- CheckDegenerateDimRemover(hlo, R"(
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[bitcast_1_2:%[^ ]+]] = u32[3,4,5]{2,1,0} bitcast([[idxs_3:%[^ ]+]])
-// CHECK: [[reduce_4:%[^ ]+]] = (f32[], u32[]) reduce([[bitcast_0]], [[bitcast_1_2]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
-// CHECK-NEXT: [[get_tuple_element_8:%[^ ]+]] = f32[] get-tuple-element([[reduce_4]]), index=0
-// CHECK-NEXT: [[bitcast_2_9:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_8]])
-// CHECK-NEXT: [[get_tuple_element_1_10:%[^ ]+]] = u32[] get-tuple-element([[reduce_4]]), index=1
-// CHECK-NEXT: [[bitcast_3_11:%[^ ]+]] = u32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_1_10]])
-// CHECK-NEXT: ROOT [[tuple_12:%[^ ]+]] = (f32[1,1,1,1]{3,2,1,0}, u32[1,1,1,1]{3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
-)");
-}
-
-TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
- const char* hlo = R"(
-HloModule ReduceWithDegenerateDimensions
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1,3,1,4,1,5,1] parameter(0)
- zero = f32[] constant(0)
-
- ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
-}
-)";
-
- CheckDegenerateDimRemover(hlo,
- R"(
-// CHECK: ROOT [[bitcast_0:%[^ ]+]] = f32[3,4,5,1]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
- )");
-}
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
deleted file mode 100644
index fa149a1..0000000
--- a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_dimension_grouper.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class ReductionDimensionGrouperTest : public HloTestBase {
- public:
- void CheckDimensionGrouper(absl::string_view hlo,
- std::optional<absl::string_view> expected) {
- RunAndFilecheckHloRewrite(hlo, gpu::ReductionDimensionGrouper{}, expected);
- }
-};
-
-TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) {
- const char* hlo = R"(
-HloModule ReductionWithGrouping
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[100,10,32,3]{3,2,1,0} parameter(0)
- zero = f32[] constant(0)
-
- ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add
-}
-)";
-
- CheckDimensionGrouper(hlo,
- R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
-// CHECK: ROOT [[out_1_2:%[^ ]+]] = f32[100,10]{0,1} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
- )");
-}
-
-TEST_F(ReductionDimensionGrouperTest, ReductionWithGroupingVariadic) {
- const char* hlo = R"(
-HloModule ReductionWithGrouping
-
-argmax {
- running_max = f32[] parameter(0)
- running_max_idx = u32[] parameter(1)
- current_value = f32[] parameter(2)
- current_value_idx = u32[] parameter(3)
-
- current = (f32[], u32[]) tuple(running_max, running_max_idx)
- potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
- cmp_code = pred[] compare(current_value, running_max), direction=GT
-
- new_max = f32[] select(cmp_code, current_value, running_max)
- new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
- ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
- input = f32[100,10,32,3]{3,2,1,0} parameter(0)
- idxs = u32[100,10,32,3]{3,2,1,0} parameter(1)
- zero = f32[] constant(0)
- zero_idx = u32[] constant(0)
-
- ROOT out = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(input, idxs, zero, zero_idx), dimensions={2,3}, to_apply=argmax
-}
-)";
-
- CheckDimensionGrouper(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
-// CHECK: [[idxs_2:%[^ ]+]] = u32[100,10,32,3]{3,2,1,0} parameter(1)
-// CHECK: [[bitcast_1_3:%[^ ]+]] = u32[100,10,96]{2,1,0} bitcast([[idxs_2]])
-// CHECK: ROOT [[out_1_4:%[^ ]+]] = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_7:%[^ ]+]]
-)");
-}
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
deleted file mode 100644
index 817d9c7..0000000
--- a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
+++ /dev/null
@@ -1,164 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_layout_normalizer.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/error_spec.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class ReductionLayoutNormalizerTest : public HloTestBase {
- public:
- void CheckReductionLayoutNormalizer(
- absl::string_view hlo, std::optional<absl::string_view> expected) {
- RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
- }
-};
-
-TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
- const char* hlo = R"(
-HloModule ReduceWithLayoutChange
-
-add {
- x0 = f32[] parameter(0)
- y0 = f32[] parameter(1)
- ROOT add0 = f32[] add(x0, y0)
-}
-
-ENTRY main {
- arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
- constant0 = f32[] constant(0)
- ROOT reduce0 = f32[4,5,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
- dimensions={1,6,7}, to_apply=add
-}
-
-)";
-
- CheckReductionLayoutNormalizer(hlo,
- R"(
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
-// CHECK: [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
- )");
-}
-
-TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
- const char* hlo = R"(
-HloModule ReduceWithLayoutChangeVariadic
-
-
-argmax {
- running_max = f32[] parameter(0)
- running_max_idx = u32[] parameter(1)
- current_value = f32[] parameter(2)
- current_value_idx = u32[] parameter(3)
-
- current = (f32[], u32[]) tuple(running_max, running_max_idx)
- potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
- cmp_code = pred[] compare(current_value, running_max), direction=GT
-
- new_max = f32[] select(cmp_code, current_value, running_max)
- new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
- ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
- arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
- idxs = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
- constant0 = f32[] constant(0)
- constant1 = u32[] constant(0)
- ROOT reduce0 = (
- f32[4,5,16,12,12]{4,3,2,1,0},
- u32[4,5,16,12,12]{4,3,2,1,0}
- ) reduce(arg0, idxs, constant0,constant1), dimensions={1,6,7}, to_apply=argmax
-}
-
-
-)";
-
- CheckReductionLayoutNormalizer(hlo,
- R"(
-// CHECK: [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
-// CHECK: [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
-// CHECK: [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
-// CHECK: [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
-// CHECK: [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
-// CHECK: [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
-// CHECK: [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
-// CHECK: [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
-// CHECK: ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
- )");
-}
-
-TEST_F(ReductionLayoutNormalizerTest,
- LayoutCanonicalizerTestVariadicDifferentLayouts) {
- const char* hlo = R"(
-HloModule ReduceWithLayoutChangeVariadicDifferent
-
-argmax {
- running_max = f32[] parameter(0)
- running_max_idx = u32[] parameter(1)
- current_value = f32[] parameter(2)
- current_value_idx = u32[] parameter(3)
-
- current = (f32[], u32[]) tuple(running_max, running_max_idx)
- potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
- cmp_code = pred[] compare(current_value, running_max), direction=GT
-
- new_max = f32[] select(cmp_code, current_value, running_max)
- new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
- ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
- arg0 = f32[2,3,4,7]{2,1,0,3} parameter(0)
- idxs = u32[2,3,4,7]{3,2,1,0} parameter(1)
- constant0 = f32[] constant(0)
- constant1 = u32[] constant(0)
- ROOT reduce0 = (
- f32[2,3,4]{2,1,0},
- u32[2,3,4]{2,1,0}
- ) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
-}
-
-
-)";
-
- CheckReductionLayoutNormalizer(hlo,
- R"(
-// CHECK: [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
-// CHECK: [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
-// CHECK: [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
-// CHECK: [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
-// CHECK: ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
- )");
- EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5}));
-}
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir
new file mode 100644
index 0000000..d8abd2d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir
@@ -0,0 +1,46 @@
+// RUN: xla-opt %s \
+// RUN: -convert-triton-to-tritongpu='target=cuda:80' \
+// RUN: -sparse-add-encoding -canonicalize \
+// RUN: | FileCheck %s
+
+// Note: 'canonicalize' folds redundant (back-and-forth) convert_layout ops.
+
+// CHECK-DAG: #[[BLOCKED4x4:.*]] = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK-DAG: #[[BLOCKED1x1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+
+module {
+ // CHECK: @sparse_dot
+ tt.func @sparse_dot() {
+ // CHECK-NEXT: %[[A:.*]] = arith.constant dense<1.000000e+00>
+ // CHECK-SAME: : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>>
+ // CHECK-NEXT: %[[B:.*]] = arith.constant dense<2.000000e+00>
+ // CHECK-SAME: : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>>
+ // CHECK-NEXT: %[[C:.*]] = arith.constant dense<0.000000e+00>
+ // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]>
+ // CHECK-NEXT: %[[META:.*]] = arith.constant dense<13107>
+ // CHECK-SAME: : tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>>
+ %a = arith.constant dense<1.00e+00> : tensor<64x32xf16>
+ %b = arith.constant dense<2.00e+00> : tensor<64x64xf16>
+ %c = arith.constant dense<0.00e+00> : tensor<64x64xf32>
+ %meta = arith.constant dense<0x3333> : tensor<64x4xi16>
+
+ // CHECK-NEXT: %[[D:.*]] = triton_gpu.sparse_dot %[[A]], %[[B]], %[[C]], %[[META]]
+ // CHECK-SAME: : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>>
+ // CHECK-SAME: meta tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>>
+ // CHECK-SAME: * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>>
+ // CHECK-SAME: -> tensor<64x64xf32, #[[BLOCKED4x4]]>
+ %d = triton_gpu.sparse_dot %a, %b, %c, %meta
+ : tensor<64x32xf16> meta tensor<64x4xi16> * tensor<64x64xf16> -> tensor<64x64xf32>
+
+ // CHECK-NEXT: %[[CVT:.*]] = triton_gpu.convert_layout %[[D]]
+ // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]>
+ // CHECK-SAME: -> tensor<64x64xf32, #[[BLOCKED1x1]]>
+ // CHECK-NEXT: tt.print "" {hex = false, isSigned = array<i32: 0>} : %[[CVT]]
+ // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED1x1]]>
+ // A use with side effects so we don't DCE the whole function.
+ tt.print "" { hex = false, isSigned = array<i32: 0>} : %d : tensor<64x64xf32>
+
+ // CHECK-NEXT: tt.return
+ tt.return
+ }
+}
diff --git a/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir b/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir
deleted file mode 100644
index 10b3e45..0000000
--- a/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir
+++ /dev/null
@@ -1,46 +0,0 @@
-// RUN: xla-opt %s \
-// RUN: -convert-triton-to-tritongpu='target=cuda:80' \
-// RUN: -add-sparse-encoding -canonicalize \
-// RUN: | FileCheck %s
-
-// Note: 'canonicalize' folds redundant (back-and-forth) convert_layout ops.
-
-// CHECK-DAG: #[[BLOCKED4x4:.*]] = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-// CHECK-DAG: #[[BLOCKED1x1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
-
-module {
- // CHECK: @sparse_dot
- tt.func @sparse_dot() {
- // CHECK-NEXT: %[[A:.*]] = arith.constant dense<1.000000e+00>
- // CHECK-SAME: : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>>
- // CHECK-NEXT: %[[B:.*]] = arith.constant dense<2.000000e+00>
- // CHECK-SAME: : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>>
- // CHECK-NEXT: %[[C:.*]] = arith.constant dense<0.000000e+00>
- // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]>
- // CHECK-NEXT: %[[META:.*]] = arith.constant dense<13107>
- // CHECK-SAME: : tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>>
- %a = arith.constant dense<1.00e+00> : tensor<64x32xf16>
- %b = arith.constant dense<2.00e+00> : tensor<64x64xf16>
- %c = arith.constant dense<0.00e+00> : tensor<64x64xf32>
- %meta = arith.constant dense<0x3333> : tensor<64x4xi16>
-
- // CHECK-NEXT: %[[D:.*]] = triton_gpu.sparse_dot %[[A]], %[[B]], %[[C]], %[[META]]
- // CHECK-SAME: : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>>
- // CHECK-SAME: meta tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>>
- // CHECK-SAME: * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>>
- // CHECK-SAME: -> tensor<64x64xf32, #[[BLOCKED4x4]]>
- %d = triton_gpu.sparse_dot %a, %b, %c, %meta
- : tensor<64x32xf16> meta tensor<64x4xi16> * tensor<64x64xf16> -> tensor<64x64xf32>
-
- // CHECK-NEXT: %[[CVT:.*]] = triton_gpu.convert_layout %[[D]]
- // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]>
- // CHECK-SAME: -> tensor<64x64xf32, #[[BLOCKED1x1]]>
- // CHECK-NEXT: tt.print "" {hex = false, isSigned = array<i32: 0>} : %[[CVT]]
- // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED1x1]]>
- // A use with side effects so we don't DCE the whole function.
- tt.print "" { hex = false, isSigned = array<i32: 0>} : %d : tensor<64x64xf32>
-
- // CHECK-NEXT: tt.return
- tt.return
- }
-}
diff --git a/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir b/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir
new file mode 100644
index 0000000..ad61162
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir
@@ -0,0 +1,25 @@
+// RUN: xla-opt %s -convert-triton-to-tritongpu='target=cuda:80' | FileCheck %s
+
+module attributes {} {
+ tt.func @gemm_fusion_dot_1_impl() {
+ %c0_i32 = arith.constant 0 : i32
+ %c32_i32 = arith.constant 32 : i32
+ %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
+ %a = arith.constant dense<0.000000e+00> : tensor<32x16xbf16>
+ // CHECK: %[[A:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x16xbf16, {{.+}}> -> tensor<32x16xbf16>
+ %b = arith.constant dense<0.000000e+00> : tensor<32x32xbf16>
+ // CHECK: %[[B:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xbf16, {{.+}}> -> tensor<32x32xbf16>
+ %meta = arith.constant dense<0> : tensor<32x2xi16>
+ // CHECK: %[[META:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x2xi16, {{.+}}> -> tensor<32x2xi16>
+ %35:1 = scf.for %arg4 = %c0_i32 to %c32_i32 step %c32_i32 iter_args(%arg8 = %acc) -> (tensor<32x32xf32>) : i32 {
+ // CHECK: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xf32, {{.+}}> -> tensor<32x32xf32>
+ // CHECK-NEXT: %[[D:.*]] = triton_gpu.sparse_dot %[[A]], %[[B]], %[[ACC]], %[[META]]
+ // CHECK-SAME: : tensor<32x16xbf16> meta tensor<32x2xi16>
+ // CHECK-SAME: * tensor<32x32xbf16> -> tensor<32x32xf32>
+ %74 = triton_gpu.sparse_dot %a, %b, %arg8, %meta : tensor<32x16xbf16> meta tensor<32x2xi16> * tensor<32x32xbf16> -> tensor<32x32xf32>
+ // CHECK: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xf32> -> tensor<32x32xf32, {{.+}}>
+ scf.yield %74 : tensor<32x32xf32>
+ }
+ tt.return
+ }
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
deleted file mode 100644
index ef6e189..0000000
--- a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
+++ /dev/null
@@ -1,557 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/tree_reduction_rewriter.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class TreeReductionRewriterTest : public HloTestBase {
- public:
- void CheckTreeRewriter(absl::string_view hlo,
- std::optional<absl::string_view> expected) {
- RunAndFilecheckHloRewrite(
- hlo,
-#if TENSORFLOW_USE_ROCM
- gpu::GpuTreeReductionRewriter{se::RocmComputeCapability {
- "908"
- }},
-#else
- gpu::GpuTreeReductionRewriter{se::CudaComputeCapability{8, 1}},
-#endif
- expected);
- }
-};
-
-TEST_F(TreeReductionRewriterTest, RowReductionSingleDimensionNoBatched) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[50021] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[pad_0:%[^ ]+]] = f32[50022]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1
-// CHECK: [[bitcast_3:%[^ ]+]] = f32[126,397]{1,0} bitcast([[pad_0]])
-// CHECK: [[reduce_4:%[^ ]+]] = f32[126]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]]
-// CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionWeirdOutputLayout) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[2,4,17000]{2,1,0} parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[2,4]{0,1} reduce(input, zero), dimensions={2}, to_apply=add
-}
-)";
-
- // Check that we preserve the layout.
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: f32[2,4]{0,1} reduce(
- )");
-}
-
-TEST_F(TreeReductionRewriterTest,
- RowReductionSingleDimensionNoBatchedDivisible) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[50048] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[50048]{0} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,391]{1,0} bitcast([[input_0]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[128]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionNoBatched) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[100,10,65536] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[100,10] reduce(input, zero), dimensions={2}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[100,10,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[100,10,256]{2,1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100,10]{1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={2}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest,
- RowReductionSingleDimensionNoBatchedLargeInput) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1048576] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1048576]{0} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[1024,1024]{1,0} bitcast([[input_0]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[1024]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionFits) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[8,100,65536] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[8,100,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[100,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={0,3}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[32,100,90000] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[reduce_0:%[^ ]+]] = f32[32,100]{1,0} reduce([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), dimensions={2}, to_apply=[[add_3:%[^ ]+]]
-// CHECK: ROOT [[out_1_4:%[^ ]+]] = f32[100]{0} reduce([[reduce_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_3]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionSimple) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[16384,100] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-
-// CHECK: [[input_0:%[^ ]+]] = f32[16384,100]{1,0} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,128,100]{2,1,0} bitcast([[input_0]])
-// CHECK: [[reduce_2:%[^ ]+]] = f32[128,100]{1,0} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_2]], [[zero_3]]), dimensions={0}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionSimpleNoDivisible) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[10303,100] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[10303,100]{1,0} parameter(0)
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1x0_0
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[64,161,100]{2,1,0} bitcast([[pad_0]])
-// CHECK: [[reduce_3:%[^ ]+]] = f32[64,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionOtherIndex) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[16384,2,2,2] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[16384,2,2,2]{3,2,1,0} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,128,2,2,2]{4,3,2,1,0} bitcast([[input_0]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[128,2,2,2]{3,2,1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionVeryLargeInput) {
- const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1048576,5] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[1024,1024,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,5]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[5]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, VariadicReductionLargeRow) {
- const char* hlo = R"(
-HloModule Reduce_R1x2_to_R0x2_argmax
-
-argmax {
- running_max = f32[] parameter(0)
- running_max_idx = u32[] parameter(1)
- current_value = f32[] parameter(2)
- current_value_idx = u32[] parameter(3)
-
- current = (f32[], u32[]) tuple(running_max, running_max_idx)
- potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
- cmp_code = pred[] compare(current_value, running_max), direction=GT
-
- new_max = f32[] select(cmp_code, current_value, running_max)
- new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
- ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
- input = f32[2,100003] parameter(0)
- idxs = u32[2,100003] iota(), iota_dimension=0
- zero = f32[] constant(0)
- zero_idx = u32[] constant(0)
-
- ROOT out = (f32[2], u32[2]) reduce(
- input, idxs, zero, zero_idx),
- dimensions={1},
- to_apply=%argmax
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[pad_0:%[^ ]+]] = f32[2,100005]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_2
-// CHECK: [[bitcast_3:%[^ ]+]] = f32[2,295,339]{2,1,0} bitcast([[pad_0]])
-// CHECK: [[zero_idx_4:%[^ ]+]] = u32[] constant(0)
-// CHECK: [[pad_1_5:%[^ ]+]] = u32[2,100005]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_2
-// CHECK: [[bitcast_1_7:%[^ ]+]] = u32[2,295,339]{2,1,0} bitcast([[pad_1_5]])
-// CHECK: [[reduce_8:%[^ ]+]] = (f32[2,295]{1,0}, u32[2,295]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]]
-// CHECK: [[get_tuple_element_10:%[^ ]+]] = f32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=0
-// CHECK: [[get_tuple_element_1_11:%[^ ]+]] = u32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=1
-// CHECK: ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, VariadicReductionLargeBatchSize) {
- const char* hlo = R"(
-HloModule Reduce_R1x2_to_R0x2_argmax
-
-argmax {
- running_max = f32[] parameter(0)
- running_max_idx = u32[] parameter(1)
- current_value = f32[] parameter(2)
- current_value_idx = u32[] parameter(3)
-
- current = (f32[], u32[]) tuple(running_max, running_max_idx)
- potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
- cmp_code = pred[] compare(current_value, running_max), direction=GT
-
- new_max = f32[] select(cmp_code, current_value, running_max)
- new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
- ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
- input = f32[20,2,100] parameter(0)
- idxs = u32[20,2,100] iota(), iota_dimension=0
- zero = f32[] constant(0)
- zero_idx = u32[] constant(0)
-
- ROOT out = (f32[2], u32[2]) reduce(
- input, idxs, zero, zero_idx),
- dimensions={0,2},
- to_apply=%argmax
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-// CHECK: [[reduce_0:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce([[input_1:%[^ ]+]], [[idxs_2:%[^ ]+]], [[zero_3:%[^ ]+]], [[zero_idx_4:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_5:%[^ ]+]]
-// CHECK: [[get_tuple_element_6:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=0
-// CHECK: [[get_tuple_element_1_7:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=1
-// CHECK: ROOT [[out_1_8:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_6]], [[get_tuple_element_1_7]], [[zero_3]], [[zero_idx_4]]), dimensions={0}, to_apply=[[argmax_5]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, KeepInnerReductionVectorized) {
- const char* hlo = R"(
-HloModule KeepInnerRowReductionVectorized
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1024,73984] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[1024,289,256]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,289]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, PreferLargeVectorizedDimension) {
- const char* hlo = R"(
-HloModule PreferLargeVectorizedDimension
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1024,98304] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[1024,256,384]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, SwapIfNonAlignedBeforePadding) {
- const char* hlo = R"(
-HloModule SwapIfNonAlignedBeforePadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1024,19739] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-
-// CHECK-DAG: [[bitcast_0:%[^ ]+]] = f32[1024,140,141]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK-DAG: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, DontSwapIfNonAlignedBeforePadding) {
- const char* hlo = R"(
-HloModule DontSwapIfNonAlignedBeforePadding
-
-add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
- input = f32[1024,19459] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
- CheckTreeRewriter(hlo,
- R"(
-
-// CHECK-DAG: [[bitcast_0:%[^ ]+]] = f32[1024,140,139]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK-DAG: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
- )");
-}
-
-TEST_F(TreeReductionRewriterTest, NonCosequtiveReductionDims) {
- const char* hlo = R"(
- HloModule NonCosequtiveReductionDims
-
- add {
- accum = f32[] parameter(0)
- op = f32[] parameter(1)
- ROOT out = f32[] add(accum, op)
- }
-
- ENTRY main {
- input = f32[5,3,4,5] parameter(0)
- zero = f32[] constant(0)
- ROOT out = f32[5,4] reduce(input, zero), dimensions={1,3}, to_apply=add
- }
- )";
-
- CheckTreeRewriter(hlo, std::nullopt);
-}
-
-} // namespace
-} // namespace xla
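The deleted file above tested GpuTreeReductionRewriter, which splits a long reduction dimension into two shorter ones (padding with the identity when the length does not factor evenly) and then reduces in two passes, as the CHECK lines show (e.g. 50021 is padded to 50022 and reshaped to 126x397). The following is a minimal standalone sketch of that two-pass idea on a plain array; the near-square split chosen here is an illustrative assumption, not the pass's actual tiling heuristic.

#include <cmath>
#include <cstddef>
#include <vector>

// Two-pass ("tree") sum reduction over n elements: pad to rows * cols,
// reduce each row, then reduce the per-row partial sums.
double TreeReduceSum(const std::vector<double>& input) {
  const size_t n = input.size();
  const size_t rows = static_cast<size_t>(std::ceil(std::sqrt(n)));
  const size_t cols = (n + rows - 1) / rows;  // pad up to rows * cols
  std::vector<double> partial(rows, 0.0);
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      const size_t i = r * cols + c;
      partial[r] += (i < n) ? input[i] : 0.0;  // out-of-range slots act as identity padding
    }
  }
  double total = 0.0;
  for (double p : partial) total += p;  // second (outer) reduction
  return total;
}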
diff --git a/third_party/xla/xla/service/gpu/tests/xla-opt.cc b/third_party/xla/xla/service/gpu/tests/xla-opt.cc
index 30bd45f..cd5eeb8 100644
--- a/third_party/xla/xla/service/gpu/tests/xla-opt.cc
+++ b/third_party/xla/xla/service/gpu/tests/xla-opt.cc
@@ -15,16 +15,14 @@
#include "mlir/InitAllExtensions.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
-#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h"
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
#include "third_party/triton/bin/RegisterTritonDialects.h"
int main(int argc, char **argv) {
mlir::DialectRegistry registry;
mlir::registerAllExtensions(registry);
registerTritonDialects(registry); // This registers all passes as well.
- xla::gpu::RegisterSparsePasses();
- xla::gpu::RegisterPreventMmaV3LoopUnrollingPass();
+ xla::gpu::registerTritonFusionTransformsPasses();
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "xla-opt modular optimizer driver\n", registry));
diff --git a/third_party/xla/xla/service/gpu/topk_specializer.cc b/third_party/xla/xla/service/gpu/topk_specializer.cc
deleted file mode 100644
index bd01a07..0000000
--- a/third_party/xla/xla/service/gpu/topk_specializer.cc
+++ /dev/null
@@ -1,113 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/topk_specializer.h"
-
-#include <stddef.h>
-
-#include <initializer_list>
-#include <string>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/primitive_util.h"
-#include "xla/service/hlo.pb.h"
-#include "xla/service/tuple_util.h"
-#include "xla/shape.h"
-#include "xla/status_macros.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-absl::StatusOr<HloInstruction*> SmallBufferOptimization(
- HloCustomCallInstruction* topk) {
- Shape data_shape = topk->operand(0)->shape();
- auto supported_dtypes = {F32, BF16};
- if (!absl::c_linear_search(supported_dtypes, data_shape.element_type())) {
- return InvalidArgument(
- "Invalid Dtype: %s",
- primitive_util::LowercasePrimitiveTypeName(data_shape.element_type()));
- }
- // We only support topk of the shape [x] or [batch, x].
- if (data_shape.dimensions_size() > 2) {
- return InvalidArgument("Invalid input dimensions: %s",
- data_shape.ToString());
- }
- bool has_batch = data_shape.dimensions_size() == 2;
- constexpr size_t max_k = 16;
- constexpr size_t min_n = 1024;
- size_t n = data_shape.dimensions(has_batch ? 1 : 0);
- size_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
- if (k > max_k) {
- return InvalidArgument("k too large (%d), must be <= %d", k, max_k);
- }
- if (n < min_n) {
- return InvalidArgument("Input too small (n=%d, min_n=%d)", n, min_n);
- }
- HloComputation* comp = topk->parent();
- HloInstruction* new_topk =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- topk->shape(), topk->operands(),
- // We don't need the original to_apply, but keeping it around allows
- // us to round-trip this CustomCall on tests.
- topk->to_apply(), "__gpu$TopK",
- /*opaque=*/"", CustomCallApiVersion::API_VERSION_TYPED_FFI));
- return TupleUtil::ExtractPrefix(new_topk, 2);
-}
-
-class SpecializeTopkVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleCustomCall(HloInstruction* inst) override {
- HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
- if (topk == nullptr || topk->custom_call_target() != "TopK") {
- return absl::OkStatus();
- }
- TF_RET_CHECK(topk->operand_count() == 1);
-
- if (auto small_topk = SmallBufferOptimization(topk); small_topk.ok()) {
- return ReplaceInstruction(topk, *small_topk);
- } else {
- VLOG(2) << "Small TopK optimization doesn't match: "
- << small_topk.status();
- }
-
- return absl::OkStatus();
- }
-};
-
-} // namespace
-
-absl::StatusOr<bool> TopkSpecializer::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- return SpecializeTopkVisitor().RunOnModule(module, execution_threads);
-}
-
-} // namespace gpu
-} // namespace xla
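SmallBufferOptimization above only fires for F32/BF16 inputs of rank at most 2 with k <= 16 and n >= 1024; otherwise it returns InvalidArgument and the CustomCall is left untouched. A minimal restatement of those numeric bounds as a standalone predicate follows; the dtype check is omitted so the sketch stays free of XLA types, and this helper is not part of the deleted pass.

#include <cstddef>

// Eligibility bounds mirrored from SmallBufferOptimization above:
// rank at most 2, k no larger than 16, n at least 1024.
bool IsSmallTopkEligible(size_t rank, size_t n, size_t k) {
  constexpr size_t kMaxK = 16;
  constexpr size_t kMinN = 1024;
  return rank <= 2 && k <= kMaxK && n >= kMinN;
}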
diff --git a/third_party/xla/xla/service/gpu/topk_specializer.h b/third_party/xla/xla/service/gpu/topk_specializer.h
deleted file mode 100644
index 5b57f57..0000000
--- a/third_party/xla/xla/service/gpu/topk_specializer.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_
-#define XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass transforms eligible TopK CustomCall into a call to be executed by
-// runtime/topk.cc.
-class TopkSpecializer : public HloModulePass {
- public:
- absl::string_view name() const override { return "topk-specializer"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_
diff --git a/third_party/xla/xla/service/gpu/topk_splitter.cc b/third_party/xla/xla/service/gpu/topk_splitter.cc
deleted file mode 100644
index d20116d..0000000
--- a/third_party/xla/xla/service/gpu/topk_splitter.cc
+++ /dev/null
@@ -1,154 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/topk_splitter.h"
-
-#include <algorithm>
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/numeric/bits.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr size_t kRequiredAlignment = 1024;
-constexpr size_t kMaximumBatchSize = 1024;
-
-class TopkSplitterVisitor : public DfsHloRewriteVisitor {
- public:
- explicit TopkSplitterVisitor(size_t split_threshold)
- : split_threshold_(split_threshold) {}
-
- absl::Status HandleCustomCall(HloInstruction* inst) override {
- HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
- if (topk == nullptr || topk->custom_call_target() != "TopK") {
- return absl::OkStatus();
- }
- HloComputation* comp = inst->parent();
- Shape data_shape = topk->operand(0)->shape();
- bool has_batch = data_shape.dimensions_size() == 2;
- // TODO(doak): Support multiple batches.
- if (has_batch && data_shape.dimensions(0) != 1) {
- return absl::OkStatus();
- }
- size_t n = data_shape.dimensions(has_batch ? 1 : 0);
- int64_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
- // If K approaches N, splitting the input will not be beneficial anymore.
- if (k > sqrt(n)) {
- return absl::OkStatus();
- }
- // TODO(doak): Relax this alignment requirement.
- if (n % kRequiredAlignment != 0) {
- return absl::OkStatus();
- }
- if (n < split_threshold_) return absl::OkStatus();
- int new_batch =
- std::min(absl::bit_floor(n / split_threshold_), kMaximumBatchSize);
- int new_n = n / new_batch;
- // Split the input into B batches and compute TopK over the batched arrays.
- Shape split_input_shape =
- ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, new_n});
- TF_ASSIGN_OR_RETURN(
- HloInstruction * reshaped,
- MakeReshapeHlo(split_input_shape, topk->mutable_operand(0)));
- Shape batch_topk_shape = ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, k}),
- ShapeUtil::MakeShape(S32, {new_batch, k})});
- HloInstruction* batch_topk =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- batch_topk_shape, {reshaped}, topk->to_apply(), "TopK",
- /*opaque=*/""));
- // Fix indices, adding j*split_N to the j-th batch of indices.
- TF_ASSIGN_OR_RETURN(HloInstruction * indices,
- MakeGetTupleElementHlo(batch_topk, 1));
- TF_ASSIGN_OR_RETURN(HloInstruction * values,
- MakeGetTupleElementHlo(batch_topk, 0));
- Shape iota_shape = ShapeUtil::MakeShape(S32, {new_batch});
- TF_ASSIGN_OR_RETURN(
- HloInstruction * fix,
- MakeBinaryHlo(
- HloOpcode::kMultiply, MakeIotaHlo(comp, iota_shape, 0),
- MakeBroadcastHlo(MakeR0ConstantHlo<int32_t>(comp, new_n),
- /*broadcast_dimensions=*/{}, iota_shape)));
- TF_ASSIGN_OR_RETURN(
- indices, MakeBinaryHlo(HloOpcode::kAdd, indices,
- MakeBroadcastHlo(fix, {0}, indices->shape())));
- // With the indices restored, compute a final top-k. Since this topk uses
- // arbitrary indices, we need to use sort+slice.
- Shape linear_index_shape = ShapeUtil::MakeShape(S32, {k * new_batch});
- Shape linear_shape = ShapeUtil::ChangeElementType(
- linear_index_shape, data_shape.element_type());
- Shape linear_sort_shape =
- ShapeUtil::MakeTupleShape({linear_shape, linear_index_shape});
- // Assuming the outputs of the TopK above are stably sorted, using a stable
- // sort here is enough to guarantee global stable sorting:
-    // - Within a block, elements are stably sorted by TopK.
- // - Since blocks are organized linearly from smallest to largest, the
- // index used on the stable sort below will also respect block ordering.
- HloInstruction* aggregated_sort =
- comp->AddInstruction(HloInstruction::CreateSort(
- linear_sort_shape, 0,
- {*MakeReshapeHlo(linear_shape, values),
- *MakeReshapeHlo(linear_index_shape, indices)},
- topk->to_apply(), /*is_stable=*/true));
- auto slice_tuple = [&](HloInstruction* sort, const size_t index) {
- return *MakeReshapeHlo(
- topk->shape().tuple_shapes(index),
- *MakeSliceHlo(*MakeGetTupleElementHlo(sort, index), {0}, {k}, {1}));
- };
- return ReplaceInstruction(topk,
- comp->AddInstruction(HloInstruction::CreateTuple({
- slice_tuple(aggregated_sort, 0),
- slice_tuple(aggregated_sort, 1),
- })));
- }
-
- private:
- size_t split_threshold_;
-};
-
-} // namespace
-
-absl::StatusOr<bool> TopKSplitter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- return TopkSplitterVisitor(split_threshold_)
- .RunOnModule(module, execution_threads);
-}
-
-} // namespace gpu
-} // namespace xla
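A key step in TopkSplitterVisitor above is the index fix-up: after TopK runs over the reshaped [new_batch, new_n] view, the j-th batch of indices is offset by j * new_n so they point back into the original flat input, and a final stable sort plus slice recovers the global top k. Below is a minimal sketch of just that offset bookkeeping on plain vectors; it is an illustrative helper only, not part of the pass.

#include <cstdint>
#include <vector>

// Given per-batch TopK indices over a [new_batch, new_n] reshaped view,
// shift the j-th batch by j * new_n so every index refers to the original
// flat array.
std::vector<int64_t> RestoreFlatIndices(
    const std::vector<std::vector<int64_t>>& per_batch_indices,
    int64_t new_n) {
  std::vector<int64_t> flat;
  for (int64_t j = 0; j < static_cast<int64_t>(per_batch_indices.size()); ++j) {
    for (int64_t idx : per_batch_indices[j]) {
      flat.push_back(j * new_n + idx);
    }
  }
  return flat;
}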
diff --git a/third_party/xla/xla/service/gpu/topk_splitter.h b/third_party/xla/xla/service/gpu/topk_splitter.h
deleted file mode 100644
index 8fee2dc..0000000
--- a/third_party/xla/xla/service/gpu/topk_splitter.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TOPK_SPLITTER_H_
-#define XLA_SERVICE_GPU_TOPK_SPLITTER_H_
-
-#include <cstddef>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Splits large TopK into batches of smaller TopKs, followed by sorting and
-// slicing the results of those smaller topks. We consider TopKs to be 'large'
-// when the last dimension of the TopK is larger than `split_threshold`.
-class TopKSplitter : public HloModulePass {
- public:
- explicit TopKSplitter(size_t split_threshold = 1024 * 1024)
- : split_threshold_(split_threshold) {}
- absl::string_view name() const override { return "topk-splitter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- const size_t split_threshold_;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_TOPK_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/topk_splitter_test.cc
deleted file mode 100644
index 834185f..0000000
--- a/third_party/xla/xla/service/gpu/topk_splitter_test.cc
+++ /dev/null
@@ -1,210 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/topk_splitter.h"
-
-#include <stdint.h>
-
-#include <cstddef>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/topk_rewriter.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace m = ::xla::match;
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::tsl::testing::IsOkAndHolds;
-using TopkSplitterTest = HloTestBase;
-
-constexpr absl::string_view kComparator = R"(
- %compare {
- %p.1.lhs.40628 = s32[] parameter(2)
- %p.1.rhs.40629 = s32[] parameter(3)
- %constant.40630 = pred[] constant(true)
- %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
- %p.0.lhs.40626 = f32[] parameter(0)
- %p.0.rhs.40627 = f32[] parameter(1)
- %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
- ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
- })";
-
-TEST_F(TopkSplitterTest, SplitsTopK) {
- const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
- %arg.1 = f32[1,1073741824] parameter(0)
- ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
- kComparator);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
- auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
- auto slice_result = [&](auto input, size_t i) {
- return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
- };
- auto index_correction =
- m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
- auto sorted = m::Sort(
- m::Reshape(m::GetTupleElement(first_topk, 0)),
- m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
- EXPECT_TRUE(
- Match(module->entry_computation()->root_instruction(),
- m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
-}
-
-TEST_F(TopkSplitterTest, SplitsTopKNoBatchDimension) {
- const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
- %arg.1 = f32[1073741824] parameter(0)
- ROOT %cc.2 = (f32[5], s32[5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
- kComparator);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
- auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
- auto slice_result = [&](auto input, size_t i) {
- return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
- };
- auto index_correction =
- m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
- auto sorted = m::Sort(
- m::Reshape(m::GetTupleElement(first_topk, 0)),
- m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
- EXPECT_TRUE(
- Match(module->entry_computation()->root_instruction(),
- m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
-}
-
-TEST_F(TopkSplitterTest, SplitFailsUnderThreshold) {
- const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
- %arg.1 = f32[1,524288] parameter(0)
- ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
- kComparator);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_THAT(
- RunHloPass(TopKSplitter(/*split_threshold=*/1048576), module.get()),
- IsOkAndHolds(false));
-}
-
-TEST_F(TopkSplitterTest, SplitFailsUnaligned) {
- const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
- %arg.1 = f32[1,524289] parameter(0)
- ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
- kComparator);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
- IsOkAndHolds(false));
-}
-
-TEST_F(TopkSplitterTest, SplitFailsLargeK) {
- const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
- %arg.1 = f32[1,524288] parameter(0)
- ROOT %cc.2 = (f32[1,1024], s32[1,1024]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
- kComparator);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
- IsOkAndHolds(false));
-}
-
-TEST_F(TopkSplitterTest, Equivalent) {
- const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
- %arg.1 = f32[1,16384] parameter(0)
- ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
- kComparator);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
- auto round_trip = [](HloModule* module) {
- EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
- return true;
- }).Run(module),
- IsOkAndHolds(true));
- EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
- EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
- EXPECT_TRUE(HloDCE().Run(module).status().ok());
- };
- EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
-}
-
-TEST_F(TopkSplitterTest, StableSorts) {
- const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
- %constant.1 = f32[] constant(42)
- %broadcast.2= f32[1,16384] broadcast(f32[] %constant.1), dimensions={}
- ROOT %cc.3 = (f32[1,5], s32[1,5]) custom-call(%broadcast.2), custom_call_target= "TopK", to_apply=%compare
-})",
- kComparator);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
- auto round_trip = [](HloModule* module) {
- EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
- return true;
- }).Run(module),
- IsOkAndHolds(true));
- EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
- EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
- EXPECT_TRUE(HloDCE().Run(module).status().ok());
- };
- EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/topk_test.cc b/third_party/xla/xla/service/gpu/topk_test.cc
deleted file mode 100644
index 43e25b8..0000000
--- a/third_party/xla/xla/service/gpu/topk_test.cc
+++ /dev/null
@@ -1,159 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <stddef.h>
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <string_view>
-#include <tuple>
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/topk_specializer.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/platform_util.h"
-#include "xla/service/topk_rewriter.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-using ::testing::Combine;
-using ::testing::Values;
-
-// Params:
-// - n_kb: number of elements in kilobytes.
-// - k: number of elements to return.
-// - batch_size
-// - dtype
-using ParameterizedInterface =
- ::testing::WithParamInterface<std::tuple<int, int, int, std::string_view>>;
-
-class TopkTest : public HloTestBase, public ParameterizedInterface {
- public:
- TopkTest()
- : HloTestBase(*PlatformUtil::GetPlatform("gpu"),
- *PlatformUtil::GetPlatform("gpu"), true, true, {}) {}
-
- protected:
- absl::StatusOr<std::unique_ptr<HloModule>> TopkHlo(int n, int k,
- int batch_size,
- std::string_view dtype) {
- return ParseAndReturnVerifiedModule(absl::Substitute(
- R"(
- %compare {
- %p.1.lhs.40628 = s32[] parameter(2)
- %p.1.rhs.40629 = s32[] parameter(3)
- %constant.40630 = pred[] constant(true)
- %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
- %p.0.lhs.40626 = f32[] parameter(0)
- %p.0.rhs.40627 = f32[] parameter(1)
- %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
- ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
- }
-
- ENTRY top_k {
- %arg = $3[$2,$0] parameter(0)
- ROOT %result = ($3[$2,$1], s32[$2,$1]) custom-call(%arg), custom_call_target="TopK", to_apply=%compare
- }
- )",
- n, k, batch_size, dtype));
- }
-};
-
-class GeneralizeTopkVisitor : public DfsHloRewriteVisitor {
- public:
- absl::Status HandleCustomCall(HloInstruction* inst) override {
- HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
- if (topk == nullptr || topk->custom_call_target() != "__gpu$TopK") {
- return absl::OkStatus();
- }
- HloComputation* comp = topk->parent();
- auto original_shape = ShapeUtil::SliceTuple(topk->shape(), 0, 2);
- HloInstruction* original_topk =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- original_shape, topk->operands(), topk->to_apply(), "TopK"));
- // TupleUtil::ExtractPrefix creates the following structure:
- // TopK
- // -------------
- // | | |
- // Get Get Get
- // \ | /
- // CreateTuple
-    // Here we walk down to the CreateTuple and replace it with the original topk.
- HloInstruction* new_tuple = topk->users()[0]->users()[0];
- return ReplaceInstruction(new_tuple, original_topk);
- }
-};
-
-class GeneralizeTopk : public HloModulePass {
- public:
- absl::string_view name() const override { return "generalized-topk"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(HloModule* module,
- const absl::flat_hash_set<absl::string_view>&
- execution_threads) override {
- return GeneralizeTopkVisitor().RunOnModule(module, execution_threads);
- }
-};
-
-void ToSortAndSlice(HloModule* module) {
- TF_ASSERT_OK_AND_ASSIGN(bool changed, GeneralizeTopk().Run(module));
- ASSERT_TRUE(changed);
- TF_ASSERT_OK_AND_ASSIGN(changed, TopkDecomposer().Run(module));
- ASSERT_TRUE(changed);
-}
-
-TEST_P(TopkTest, ProducesCorrectResult) {
- const auto [n_kb, k, batch_size, dtype] = GetParam();
- const size_t n = n_kb * 1024;
- TF_ASSERT_OK_AND_ASSIGN(auto topk_module, TopkHlo(n, k, batch_size, dtype));
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- gpu::TopkSpecializer().Run(topk_module.get()));
- ASSERT_TRUE(changed);
- EXPECT_TRUE(
- RunAndCompare(std::move(topk_module), std::nullopt, ToSortAndSlice));
-}
-
-INSTANTIATE_TEST_SUITE_P(
- TopkTests, TopkTest,
- Combine(
- /*n_kb=*/Values(1, 8, 12, 32),
- /*k=*/Values(1, 2, 4, 8, 16, 7, 12),
- /*batch_size=*/Values(1, 16, 32, 64, 128),
- /*dtype=*/Values(absl::string_view("f32"), "bf16")),
- [](const auto& info) {
- return absl::Substitute("n$0KiB_k$1_batch_size$2_$3",
- std::get<0>(info.param), std::get<1>(info.param),
- std::get<2>(info.param), std::get<3>(info.param));
- });
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD
new file mode 100644
index 0000000..21a841a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/BUILD
@@ -0,0 +1,2963 @@
+load(
+ "@local_config_rocm//rocm:build_defs.bzl",
+ "if_rocm_is_configured",
+)
+load(
+ "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
+ "if_cuda_is_configured",
+)
+load("//xla:xla.bzl", "xla_cc_test")
+load(
+ "//xla/stream_executor:build_defs.bzl",
+ "if_gpu_is_configured",
+)
+load("//xla/tests:build_defs.bzl", "xla_test")
+load("//xla/tsl:tsl.bzl", "if_google", "if_oss")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//xla/service/gpu:__subpackages__"],
+ licenses = ["notice"],
+)
+
+cc_library(
+ name = "algebraic_simplifier",
+ srcs = [
+ "algebraic_simplifier.cc",
+ ],
+ hdrs = [
+ "algebraic_simplifier.h",
+ ],
+ deps = [
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:algebraic_simplifier",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu/fusions/triton:triton_support",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ ],
+)
+
+xla_cc_test(
+ name = "algebraic_simplifier_test",
+ srcs = ["algebraic_simplifier_test.cc"],
+ deps = [
+ ":algebraic_simplifier",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:algebraic_simplifier",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+# End-to-end tested via //third_party/tensorflow/compiler/xla/service/gpu:dot_algorithm_support_test
+cc_library(
+ name = "algorithm_checker",
+ srcs = ["algorithm_checker.cc"],
+ hdrs = ["algorithm_checker.h"],
+ deps = [
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:algorithm_util",
+ "//xla/service:hlo_pass",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ ],
+)
+
+cc_library(
+ name = "alias_passthrough_params",
+ srcs = ["alias_passthrough_params.cc"],
+ hdrs = ["alias_passthrough_params.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ ],
+)
+
+xla_cc_test(
+ name = "alias_passthrough_params_test",
+ srcs = ["alias_passthrough_params_test.cc"],
+ tags = [
+ "nomsan",
+ ],
+ deps = [
+ ":alias_passthrough_params",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "//xla/tsl/lib/core:status_test_util",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+cc_library(
+ name = "all_gather_optimizer",
+ srcs = ["all_gather_optimizer.cc"],
+ hdrs = ["all_gather_optimizer.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ ],
+)
+
+xla_cc_test(
+ name = "all_gather_optimizer_test",
+ srcs = ["all_gather_optimizer_test.cc"],
+ deps = [
+ ":all_gather_optimizer",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_module_config",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+cc_library(
+ name = "all_reduce_blueconnect",
+ srcs = ["all_reduce_blueconnect.cc"],
+ hdrs = ["all_reduce_blueconnect.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:computation_placer_hdr",
+ "//xla/service:global_device_id",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:btree",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "all_reduce_blueconnect_test",
+ srcs = ["all_reduce_blueconnect_test.cc"],
+ deps = [
+ ":all_reduce_blueconnect",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:computation_placer_hdr",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "all_reduce_splitter",
+ srcs = ["all_reduce_splitter.cc"],
+ hdrs = ["all_reduce_splitter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:collective_opt_utils",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/cleanup",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "all_reduce_splitter_test",
+ srcs = ["all_reduce_splitter_test.cc"],
+ deps = [
+ ":all_reduce_splitter",
+ ":reduce_scatter_creator",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass_pipeline",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "async_collective_annotator",
+ srcs = ["async_collective_annotator.cc"],
+ hdrs = ["async_collective_annotator.h"],
+ deps = [
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "async_collective_annotator_test",
+ srcs = ["async_collective_annotator_test.cc"],
+ deps = [
+ ":async_collective_annotator",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:test_macros_header",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "async_wrapper",
+ srcs = ["async_wrapper.cc"],
+ hdrs = ["async_wrapper.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:hlo_proto_cc",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+xla_cc_test(
+ name = "async_wrapper_test",
+ srcs = ["async_wrapper_test.cc"],
+ deps = [
+ ":async_wrapper",
+ "//xla:literal",
+ "//xla:literal_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:hlo_proto_cc",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:literal_test_util",
+ "//xla/tests:verified_hlo_module",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "collective_permute_cycle_decomposer",
+ srcs = ["collective_permute_cycle_decomposer.cc"],
+ hdrs = ["collective_permute_cycle_decomposer.h"],
+ deps = [
+ "//xla:comparison_util",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:hlo_parser",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+xla_cc_test(
+ name = "collective_permute_cycle_decomposer_test",
+ srcs = ["collective_permute_cycle_decomposer_test.cc"],
+ deps = [
+ ":collective_permute_cycle_decomposer",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_parser",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "collective_permute_valid_iteration_annotator",
+ srcs = ["collective_permute_valid_iteration_annotator.cc"],
+ hdrs = ["collective_permute_valid_iteration_annotator.h"],
+ deps = [
+ "//xla:literal_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service:while_loop_analysis",
+ ],
+)
+
+xla_cc_test(
+ name = "collective_permute_valid_iteration_annotator_test",
+ srcs = ["collective_permute_valid_iteration_annotator_test.cc"],
+ deps = [
+ ":collective_permute_valid_iteration_annotator",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:hlo_pass_pipeline",
+ "//xla/service:while_loop_trip_count_annotator",
+ "//xla/tests:hlo_test_base",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "command_buffer_scheduling",
+ srcs = ["command_buffer_scheduling.cc"],
+ hdrs = ["command_buffer_scheduling.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/ffi:ffi_api",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:variant_visitor",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "command_buffer_scheduling_test",
+ srcs = ["command_buffer_scheduling_test.cc"],
+ deps = [
+ ":command_buffer_scheduling",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_parser",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "conv_padding_legalization",
+ srcs = ["conv_padding_legalization.cc"],
+ hdrs = ["conv_padding_legalization.h"],
+ deps = [
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:window_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:shape_inference",
+ "//xla/service/gpu:cublas_cudnn",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "conv_padding_legalization_test",
+ srcs = ["conv_padding_legalization_test.cc"],
+ deps = [
+ ":conv_padding_legalization",
+ "//xla:shape_util",
+ "//xla:test",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+cc_library(
+ name = "conv_rewriter",
+ srcs = ["conv_rewriter.cc"],
+ hdrs = ["conv_rewriter.h"],
+ deps = [
+ "//xla:permutation_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:window_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "conv_rewriter_test",
+ srcs = ["conv_rewriter_test.cc"],
+ deps = [
+ ":conv_rewriter",
+ "//xla:array4d",
+ "//xla:literal_util",
+ "//xla:protobuf_util",
+ "//xla:shape_util",
+ "//xla:test",
+ "//xla:test_helpers",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service:shape_inference",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings:str_format",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+cc_library(
+ name = "convert_async_collectives_to_sync",
+ srcs = ["convert_async_collectives_to_sync.cc"],
+ hdrs = ["convert_async_collectives_to_sync.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:convert_async_collectives_to_sync",
+ "//xla/service/gpu:backend_configs_cc",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "convert_async_collectives_to_sync_test",
+ srcs = ["convert_async_collectives_to_sync_test.cc"],
+ deps = [
+ ":convert_async_collectives_to_sync",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "copy_fusion",
+ srcs = ["copy_fusion.cc"],
+ hdrs = ["copy_fusion.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:reduction_utils",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ ],
+)
+
+xla_cc_test(
+ name = "copy_fusion_test",
+ srcs = ["copy_fusion_test.cc"],
+ deps = [
+ ":copy_fusion",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "cublas_pad_for_gemms",
+ srcs = ["cublas_pad_for_gemms.cc"],
+ hdrs = ["cublas_pad_for_gemms.h"],
+ deps = [
+ ":gemm_fusion",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu/fusions/triton:triton_support",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "cublas_pad_for_gemms_test",
+ srcs = ["cublas_pad_for_gemms_test.cc"],
+ tags = [
+ "nomsan",
+ ],
+ deps = [
+ ":cublas_pad_for_gemms",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "cudnn_custom_call_converter",
+ srcs = ["cudnn_custom_call_converter.cc"],
+ hdrs = ["cudnn_custom_call_converter.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:ir_emission_utils",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+xla_cc_test(
+ name = "cudnn_custom_call_converter_test",
+ srcs = ["cudnn_custom_call_converter_test.cc"],
+ deps = [
+ ":cudnn_custom_call_converter",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "cudnn_fused_conv_rewriter",
+ srcs = ["cudnn_fused_conv_rewriter.cc"],
+ hdrs = ["cudnn_fused_conv_rewriter.h"],
+ deps = [
+ "//xla:comparison_util",
+ "//xla:debug_options_flags",
+ "//xla:literal",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/stream_executor",
+ "//xla/stream_executor:dnn",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:ml_dtypes",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "cudnn_fused_conv_rewriter_test",
+ srcs = ["cudnn_fused_conv_rewriter_test.cc"],
+ backend_tags = {
+ "gpu_a100": [
+ "noasan",
+ "nomsan",
+ "no_rocm",
+ ],
+ },
+ backends = [
+ "gpu_a100",
+ "gpu_amd_any",
+ ] + if_oss(["gpu_any"]),
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+ "TENSORFLOW_USE_ROCM=1",
+ ]),
+ shard_count = 10,
+ deps = [
+ ":conv_rewriter",
+ ":cudnn_fused_conv_rewriter",
+ "//xla:comparison_util",
+ "//xla:error_spec",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:algebraic_simplifier",
+ "//xla/service:convert_mover",
+ "//xla/service:hlo_constant_folding",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service:hlo_pass_pipeline",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service:reshape_mover",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/service/gpu/tests:gpu_codegen_test",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:dnn",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ] + if_cuda_is_configured([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cudnn_header",
+ ]) + if_rocm_is_configured([
+ "@local_config_rocm//rocm:rocm_headers",
+ ]),
+)
+
+cc_library(
+ name = "cudnn_fused_mha_rewriter",
+ srcs = ["cudnn_fused_mha_rewriter.cc"],
+ hdrs = ["cudnn_fused_mha_rewriter.h"],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+ deps = [
+ "//xla:permutation_util",
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:types",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor",
+ "//xla/stream_executor:dnn",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ] + if_cuda_is_configured([
+ "@local_config_cuda//cuda:cuda_headers",
+ ]),
+)
+
+xla_test(
+ name = "cudnn_fused_mha_rewriter_test",
+ srcs = ["cudnn_fused_mha_rewriter_test.cc"],
+ backend_tags = {"gpu": [
+ "requires-gpu-nvidia",
+ "no_rocm",
+ ]},
+ backends = [
+ "gpu",
+ ],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+ deps = [
+ ":cudnn_fused_mha_rewriter",
+ ":cudnn_fused_mha_transpose_fusion",
+ "//xla:error_spec",
+ "//xla:test_helpers",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:algebraic_simplifier",
+ "//xla/service:computation_layout",
+ "//xla/service:hlo_cse",
+ "//xla/service:hlo_dce",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_parser",
+ "//xla/service:hlo_verifier",
+ "//xla/service:layout_normalization",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service:reshape_decomposer",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:dnn",
+ "//xla/tests:hlo_test_base",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ] + if_cuda_is_configured([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cudnn_header",
+ ]),
+)
+
+# Tested via cudnn_fused_mha_rewriter_test.
+cc_library(
+ name = "cudnn_fused_mha_transpose_fusion",
+ srcs = ["cudnn_fused_mha_transpose_fusion.cc"],
+ hdrs = ["cudnn_fused_mha_transpose_fusion.h"],
+ deps = [
+ "//xla:permutation_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:matmul_utils",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+# Tested via //third_party/tensorflow/compiler/xla/service/gpu/fusions:cudnn_test
+cc_library(
+ name = "cudnn_fusion_compiler",
+ srcs = if_cuda_is_configured(["cudnn_fusion_compiler.cc"]),
+ hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]),
+ deps = if_cuda_is_configured([
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cudnn_support_utils",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:kernel_reuse_cache",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/service/gpu:triton_fusion_analysis",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_config_cuda//cuda:cudnn_header",
+ "//xla:shape_util",
+ "//xla:comparison_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:hlo_pass",
+ "//xla/stream_executor:dnn",
+ "//xla/stream_executor:stream_executor_h",
+ "//xla/service:dump",
+ "//xla/stream_executor/cuda:cudnn_frontend_helpers",
+ "//xla/stream_executor/cuda:cudnn_plugin",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ]),
+)
+
+cc_library(
+ name = "cudnn_norm_rewriter",
+ srcs = ["cudnn_norm_rewriter.cc"],
+ hdrs = ["cudnn_norm_rewriter.h"],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+ deps = [
+ "//xla:shape_util",
+ "//xla:types",
+ "//xla:util",
+ "//xla:window_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/stream_executor",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/protobuf:dnn_proto_cc",
+ ] + if_cuda_is_configured([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cudnn_header",
+ ]) + if_google([
+ "@com_google_protobuf//:wrappers_cc_proto",
+ ]),
+)
+
+xla_test(
+ name = "cudnn_norm_rewriter_test",
+ srcs = ["cudnn_norm_rewriter_test.cc"],
+ backends = ["gpu"],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+ deps = [
+ ":cudnn_norm_rewriter",
+ "//xla:error_spec",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu/tests:gpu_codegen_test",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:filecheck",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_googletest//:gtest_main",
+ ] + if_cuda_is_configured([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cudnn_header",
+ ]),
+)
+
+cc_library(
+ name = "cudnn_pad_for_convolutions",
+ srcs = ["cudnn_pad_for_convolutions.cc"],
+ hdrs = ["cudnn_pad_for_convolutions.h"],
+ deps = [
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:cudnn_support_utils",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/functional:bind_front",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "cudnn_pad_for_convolutions_test",
+ srcs = ["cudnn_pad_for_convolutions_test.cc"],
+ deps = [
+ ":cudnn_pad_for_convolutions",
+ "//xla/service:hlo_parser",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "cudnn_simplify_padding",
+ srcs = ["cudnn_simplify_padding.cc"],
+ hdrs = ["cudnn_simplify_padding.h"],
+ deps = [
+ "//xla:literal",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "cudnn_simplify_padding_test",
+ srcs = ["cudnn_simplify_padding_test.cc"],
+ deps = [
+ ":cudnn_pad_for_convolutions",
+ ":cudnn_simplify_padding",
+ ":cudnn_vectorize_convolutions",
+ "//xla:literal",
+ "//xla:util",
+ "//xla/service:algebraic_simplifier",
+ "//xla/service:call_inliner",
+ "//xla/service:hlo_pass",
+ "//xla/service:hlo_pass_pipeline",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service:reshape_mover",
+ "//xla/service:tuple_simplifier",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:dnn",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/functional:function_ref",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "cudnn_vectorize_convolutions",
+ srcs = ["cudnn_vectorize_convolutions.cc"],
+ hdrs = ["cudnn_vectorize_convolutions.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/client:xla_builder",
+ "//xla/client:xla_computation",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:cudnn_support_utils",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor",
+ "//xla/stream_executor:dnn",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "cudnn_vectorize_convolutions_test",
+ srcs = ["cudnn_vectorize_convolutions_test.cc"],
+ deps = [
+ ":cudnn_vectorize_convolutions",
+ "//xla:util",
+ "//xla/service:call_inliner",
+ "//xla/service:hlo_parser",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:dnn",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+ name = "cudnn_custom_call_compiler",
+ srcs = if_cuda_is_configured(["cudnn_custom_call_compiler.cc"]),
+ hdrs = if_cuda_is_configured(["cudnn_custom_call_compiler.h"]),
+ deps = if_cuda_is_configured([
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_config_cuda//cuda:cudnn_header",
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu/runtime:cudnn_thunk",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor:dnn",
+ "//xla/stream_executor:stream_executor_h",
+ "//xla/stream_executor/cuda:cudnn_frontend_helpers",
+ "//xla/stream_executor/cuda:cudnn_plugin",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ]),
+)
+
+cc_library(
+ name = "custom_kernel_fusion_rewriter",
+ srcs = ["custom_kernel_fusion_rewriter.cc"],
+ hdrs = ["custom_kernel_fusion_rewriter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu/kernels:custom_fusion_library",
+ "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "custom_kernel_fusion_rewriter_test",
+ srcs = ["custom_kernel_fusion_rewriter_test.cc"],
+ deps = [
+ ":custom_kernel_fusion_rewriter",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "dot_dimension_sorter",
+ srcs = ["dot_dimension_sorter.cc"],
+ hdrs = ["dot_dimension_sorter.h"],
+ deps = [
+ "//xla:permutation_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ ],
+)
+
+xla_test(
+ name = "dot_dimension_sorter_test",
+ srcs = ["dot_dimension_sorter_test.cc"],
+ backends = ["gpu"],
+ deps = [
+ ":dot_dimension_sorter",
+ "//xla:error_spec",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu/tests:gpu_codegen_test",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "dot_operand_converter",
+ srcs = ["dot_operand_converter.cc"],
+ hdrs = ["dot_operand_converter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:op_expander_pass",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+xla_test(
+ name = "dot_operand_converter_test",
+ srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]),
+ backends = [
+ "gpu_a100",
+ "gpu_p100",
+ "gpu_v100",
+ "gpu_amd_any",
+ ],
+ deps = if_gpu_is_configured(
+ [
+ ":dot_operand_converter",
+ "@com_google_googletest//:gtest",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_matchers",
+ "//xla/service:pattern_matcher",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+ ) + [
+ # b/317293391
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "dot_sparsity_rewriter",
+ srcs = ["dot_sparsity_rewriter.cc"],
+ hdrs = ["dot_sparsity_rewriter.h"],
+ deps = [
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "dot_sparsity_rewriter_test",
+ srcs = ["dot_sparsity_rewriter_test.cc"],
+ deps = [
+ ":dot_sparsity_rewriter",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "double_buffer_loop_unrolling",
+ srcs = ["double_buffer_loop_unrolling.cc"],
+ hdrs = ["double_buffer_loop_unrolling.h"],
+ deps = [
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/ir:hlo_instruction_utils",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:flatten_call_graph",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "double_buffer_loop_unrolling_test",
+ srcs = ["double_buffer_loop_unrolling_test.cc"],
+ deps = [
+ ":double_buffer_loop_unrolling",
+ "//xla:test",
+ "//xla:xla_data_proto_cc",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:tuple_simplifier",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "dynamic_slice_fusion_rewriter",
+ srcs = ["dynamic_slice_fusion_rewriter.cc"],
+ hdrs = ["dynamic_slice_fusion_rewriter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/ffi:ffi_api",
+ "//xla/ffi/api:c_api",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:custom_call_target_registry",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:gpu_constants",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu/kernels:custom_fusion_library",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "dynamic_slice_fusion_rewriter_test",
+ srcs = if_cuda_is_configured(["dynamic_slice_fusion_rewriter_test.cc"]),
+ deps = [
+ ":dynamic_slice_fusion_rewriter",
+ "//xla:shape_util",
+ "//xla/client:xla_builder",
+ "//xla/client/lib:constants",
+ "//xla/ffi",
+ "//xla/ffi:ffi_api",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:buffer_value",
+ "//xla/service:custom_call_target_registry",
+ "//xla/service:executable",
+ "//xla/service:hlo_memory_scheduler",
+ "//xla/service:hlo_module_config",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/stream_executor",
+ "//xla/stream_executor/gpu:gpu_types_header",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/status",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "fusion_merger",
+ srcs = ["fusion_merger.cc"],
+ hdrs = ["fusion_merger.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_cost_analysis",
+ "//xla/service:hlo_graph_dumper",
+ "//xla/service:hlo_pass",
+ "//xla/service:instruction_fusion",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+ "//xla/service/gpu/model:gpu_performance_model",
+ "//xla/service/gpu/model:gpu_performance_model_base",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:status",
+ ],
+)
+
+xla_cc_test(
+ name = "fusion_merger_test",
+ srcs = ["fusion_merger_test.cc"],
+ tags = [
+ "nomsan",
+ ],
+ deps = [
+ ":fusion_merger",
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_cost_analysis",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/types:span",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "fusion_wrapper",
+ srcs = ["fusion_wrapper.cc"],
+ hdrs = ["fusion_wrapper.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:gpu_fusible",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+xla_cc_test(
+ name = "fusion_wrapper_test",
+ srcs = ["fusion_wrapper_test.cc"],
+ deps = [
+ ":fusion_wrapper",
+ "//xla/tests:hlo_test_base",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "gemm_broadcast_folding_rewriter",
+ srcs = ["gemm_broadcast_folding_rewriter.cc"],
+ hdrs = ["gemm_broadcast_folding_rewriter.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "gemm_broadcast_folding_rewriter_test",
+ srcs = ["gemm_broadcast_folding_rewriter_test.cc"],
+ backends = ["gpu"],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+ "TENSORFLOW_USE_ROCM=1",
+ ]),
+ deps = [
+ ":gemm_broadcast_folding_rewriter",
+ ":gemm_rewriter",
+ "//xla:error_spec",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu/tests:gpu_codegen_test",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "gemm_fusion",
+ srcs = ["gemm_fusion.cc"],
+ hdrs = ["gemm_fusion.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:instruction_fusion",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_padding_requirements",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu:triton_fusion_analysis",
+ "//xla/service/gpu:triton_tiling_propagation",
+ "//xla/service/gpu/fusions/triton:triton_support",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:tensor_float_32_utils",
+ ],
+)
+
+xla_cc_test(
+ name = "gemm_fusion_test",
+ srcs = ["gemm_fusion_test.cc"],
+ deps = [
+ ":gemm_fusion",
+ "//xla:autotuning_proto_cc",
+ "//xla:xla_data_proto_cc",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:cublas_padding_requirements",
+ "//xla/service/gpu:triton_fusion_analysis",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "gemm_rewriter",
+ srcs = ["gemm_rewriter.cc"],
+ hdrs = ["gemm_rewriter.h"],
+ deps = [
+ "//xla:literal",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:types",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/evaluator:hlo_evaluator",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:algorithm_util",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/stream_executor:blas",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor/gpu:gpu_blas_lt",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:ml_dtypes",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/protobuf:dnn_proto_cc",
+ ],
+)
+
+xla_test(
+ name = "gemm_rewriter_test",
+ srcs = ["gemm_rewriter_test.cc"],
+ backends = ["gpu"],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+ "TENSORFLOW_USE_ROCM=1",
+ ]),
+ deps = [
+ ":gemm_rewriter",
+ "//xla:error_spec",
+ "//xla:test",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:buffer_assignment",
+ "//xla/service:executable",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:gpu_executable",
+ "//xla/service/gpu/tests:gpu_codegen_test",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:device_memory_allocator",
+ "//xla/stream_executor:stream_executor_memory_allocator",
+ "//xla/tests:filecheck",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ] + if_cuda_is_configured([
+ "@local_config_cuda//cuda:cuda_headers",
+ ]) + if_rocm_is_configured([
+ "@local_config_rocm//rocm:rocm_headers",
+ ]),
+)
+
+cc_library(
+ name = "gemv_rewriter",
+ srcs = ["gemv_rewriter.cc"],
+ hdrs = ["gemv_rewriter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "gemv_rewriter_test",
+ srcs = ["gemv_rewriter_test.cc"],
+ deps = [
+ ":gemv_rewriter",
+ "//xla/hlo/ir:hlo",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+ name = "gpusolver_rewriter",
+ srcs = if_gpu_is_configured(["gpusolver_rewriter.cc"]),
+ hdrs = if_gpu_is_configured(["gpusolver_rewriter.h"]),
+ deps = if_gpu_is_configured([
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "//xla:comparison_util",
+ "//xla:literal",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:cusolver_context",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/stream_executor",
+ "//xla/stream_executor:blas",
+ "//xla/stream_executor:device_memory_allocator",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ]),
+)
+
+cc_library(
+ name = "horizontal_input_fusion",
+ srcs = ["horizontal_input_fusion.cc"],
+ hdrs = ["horizontal_input_fusion.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "horizontal_input_fusion_test",
+ srcs = ["horizontal_input_fusion_test.cc"],
+ backends = ["gpu"],
+ deps = [
+ ":horizontal_input_fusion",
+ "//xla:error_spec",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:test",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu/tests:gpu_codegen_test",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
+ name = "horizontal_loop_fusion",
+ srcs = ["horizontal_loop_fusion.cc"],
+ hdrs = ["horizontal_loop_fusion.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:sub_byte_normalization",
+ "//xla/service/gpu:gpu_fusible",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "horizontal_loop_fusion_test",
+ srcs = ["horizontal_loop_fusion_test.cc"],
+ backends = ["gpu"],
+ deps = [
+ ":horizontal_loop_fusion",
+ ":instruction_fusion",
+ "//xla:error_spec",
+ "//xla:shape_util",
+ "//xla:test",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_dce",
+ "//xla/service:hlo_parser",
+ "//xla/service:hlo_pass",
+ "//xla/service:hlo_pass_pipeline",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/log",
+ ],
+)
+
+cc_library(
+ name = "instruction_fusion",
+ srcs = ["instruction_fusion.cc"],
+ hdrs = ["instruction_fusion.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:fusion_node_indexing_evaluation",
+ "//xla/service:fusion_queue",
+ "//xla/service:hlo_pass",
+ "//xla/service:instruction_fusion",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/meta:type_traits",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_cc_test(
+ name = "instruction_fusion_test",
+ srcs = ["instruction_fusion_test.cc"],
+ tags = [
+ "nomsan",
+ "not_run:arm",
+ ],
+ deps = [
+ ":instruction_fusion",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:test_utils",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "layout_assignment",
+ srcs = ["layout_assignment.cc"],
+ hdrs = ["layout_assignment.h"],
+ deps = [
+ "//xla:shape_layout",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:window_util",
+ "//xla:xla_data_proto_cc",
+ "//xla:xla_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:computation_layout",
+ "//xla/service:host_memory_offload_annotations_hdr",
+ "//xla/service:layout_assignment",
+ "//xla/service:logical_buffer",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu:matmul_utils",
+ "//xla/service/gpu:reduction_utils",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor",
+ "//xla/stream_executor:dnn",
+ "//xla/tsl/util:env_var",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "layout_assignment_test",
+ srcs = ["layout_assignment_test.cc"],
+ deps = [
+ ":layout_assignment",
+ "//xla:shape_layout",
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:computation_layout",
+ "//xla/service:hlo_parser",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:stream_executor_util",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:dnn",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_absl//absl/types:span",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "move_copy_to_users",
+ srcs = ["move_copy_to_users.cc"],
+ hdrs = ["move_copy_to_users.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "move_copy_to_users_test",
+ srcs = ["move_copy_to_users_test.cc"],
+ deps = [
+ ":move_copy_to_users",
+ "//xla/service:layout_assignment",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "multi_output_fusion",
+ srcs = ["multi_output_fusion.cc"],
+ hdrs = ["multi_output_fusion.h"],
+ deps = [
+ "//xla:debug_options_flags",
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/ir:hlo_dfs_reachability",
+ "//xla/service:hlo_cost_analysis",
+ "//xla/service:hlo_graph_dumper",
+ "//xla/service:hlo_pass",
+ "//xla/service:instruction_fusion",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+ "//xla/service/gpu/model:gpu_performance_model",
+ "//xla/service/gpu/model:gpu_performance_model_base",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:status",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "multi_output_fusion_test",
+ srcs = ["multi_output_fusion_test.cc"],
+ tags = [
+ "nomsan",
+ ],
+ deps = [
+ ":multi_output_fusion",
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_cost_analysis",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "pipelined_p2p_rewriter",
+ srcs = ["pipelined_p2p_rewriter.cc"],
+ hdrs = ["pipelined_p2p_rewriter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "pipelined_p2p_rewriter_test",
+ srcs = ["pipelined_p2p_rewriter_test.cc"],
+ deps = [
+ ":pipelined_p2p_rewriter",
+ "//xla/hlo/ir:hlo",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "priority_fusion",
+ srcs = ["priority_fusion.cc"],
+ hdrs = ["priority_fusion.h"],
+ deps = [
+ "//xla:debug_options_flags",
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:dump",
+ "//xla/service:fusion_queue",
+ "//xla/service:hlo_cost_analysis",
+ "//xla/service:hlo_graph_dumper",
+ "//xla/service:hlo_pass",
+ "//xla/service:instruction_fusion",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:fusion_process_dump_proto_cc",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu/model:fusion_analysis_cache",
+ "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+ "//xla/service/gpu/model:gpu_performance_model",
+ "//xla/service/gpu/model:gpu_performance_model_base",
+ "//xla/service/gpu/model:symbolic_tile_analysis",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/meta:type_traits",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:blocking_counter",
+ "@local_tsl//tsl/platform:env",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:status",
+ ],
+)
+
+xla_cc_test(
+ name = "priority_fusion_test",
+ srcs = ["priority_fusion_test.cc"],
+ local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+ tags = ["no_pip"],
+ deps = [
+ ":priority_fusion",
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_cost_analysis",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu:hlo_fusion_analysis",
+ "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "reduce_scatter_creator",
+ srcs = ["reduce_scatter_creator.cc"],
+ hdrs = ["reduce_scatter_creator.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:collective_opt_utils",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+xla_cc_test(
+ name = "reduce_scatter_creator_test",
+ srcs = ["reduce_scatter_creator_test.cc"],
+ deps = [
+ ":reduce_scatter_creator",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_module_config",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "reduction_degenerate_dim_remover",
+ srcs = ["reduction_degenerate_dim_remover.cc"],
+ hdrs = ["reduction_degenerate_dim_remover.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "reduction_degenerate_dim_remover_test",
+ srcs = [
+ "reduction_degenerate_dim_remover_test.cc",
+ ],
+ deps = [
+ ":reduction_degenerate_dim_remover",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "reduction_dimension_grouper",
+ srcs = ["reduction_dimension_grouper.cc"],
+ hdrs = ["reduction_dimension_grouper.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "reduction_dimension_grouper_test",
+ srcs = [
+ "reduction_dimension_grouper_test.cc",
+ ],
+ deps = [
+ ":reduction_dimension_grouper",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "reduction_layout_normalizer",
+ srcs = ["reduction_layout_normalizer.cc"],
+ hdrs = ["reduction_layout_normalizer.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "reduction_layout_normalizer_test",
+ srcs = [
+ "reduction_layout_normalizer_test.cc",
+ ],
+ backends = ["gpu"],
+ deps = [
+ ":reduction_layout_normalizer",
+ "//xla:error_spec",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "reduction_splitter",
+ srcs = ["reduction_splitter.cc"],
+ hdrs = ["reduction_splitter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:reduction_utils",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "reduction_splitter_test",
+ srcs = ["reduction_splitter_test.cc"],
+ deps = [
+ ":reduction_splitter",
+ "//xla:shape_util",
+ "//xla:test",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_parser",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
+ name = "rename_fusions",
+ srcs = ["rename_fusions.cc"],
+ hdrs = ["rename_fusions.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "@com_google_absl//absl/container:btree",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_cc_test(
+ name = "rename_fusions_test",
+ srcs = ["rename_fusions_test.cc"],
+ deps = [
+ ":rename_fusions",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "sanitize_constant_names",
+ srcs = ["sanitize_constant_names.cc"],
+ hdrs = ["sanitize_constant_names.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:name_uniquer",
+ "//xla/service/llvm_ir:buffer_assignment_util",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:logging",
+ ],
+)
+
+xla_cc_test(
+ name = "sanitize_constant_names_test",
+ srcs = ["sanitize_constant_names_test.cc"],
+ deps = [
+ ":sanitize_constant_names",
+ "//xla:literal_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+cc_library(
+ name = "scatter_slice_simplifier",
+ srcs = ["scatter_slice_simplifier.cc"],
+ hdrs = ["scatter_slice_simplifier.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+ name = "scatter_expander",
+ srcs = ["scatter_expander.cc"],
+ hdrs = ["scatter_expander.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:scatter_expander",
+ "@com_google_absl//absl/strings:string_view",
+ ],
+)
+
+xla_cc_test(
+ name = "scatter_slice_simplifier_test",
+ srcs = ["scatter_slice_simplifier_test.cc"],
+ deps = [
+ ":scatter_slice_simplifier",
+ "//xla:shape_util",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "schedule_postprocessing",
+ srcs = ["schedule_postprocessing.cc"],
+ hdrs = ["schedule_postprocessing.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "schedule_postprocessing_test",
+ srcs = ["schedule_postprocessing_test.cc"],
+ deps = [
+ ":schedule_postprocessing",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_parser",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "scheduling_instruction_annotator",
+ srcs = ["scheduling_instruction_annotator.cc"],
+ hdrs = ["scheduling_instruction_annotator.h"],
+ deps = [
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "scheduling_instruction_annotator_test",
+ srcs = ["scheduling_instruction_annotator_test.cc"],
+ deps = [
+ ":scheduling_instruction_annotator",
+ "//xla/hlo/ir:hlo",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "softmax_rewriter_triton",
+ srcs = ["softmax_rewriter_triton.cc"],
+ hdrs = ["softmax_rewriter_triton.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:hlo_cost_analysis",
+ "//xla/service:hlo_pass",
+ "//xla/service:instruction_fusion",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:hlo_traversal",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu/fusions/triton:triton_support",
+ "//xla/service/gpu/model:fusion_analysis_cache",
+ "//xla/service/gpu/model:gpu_indexing_performance_model",
+ "//xla/service/gpu/model:symbolic_tile_analysis",
+ "//xla/service/gpu/model:tiled_hlo_computation",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "softmax_rewriter_triton_test",
+ srcs = ["softmax_rewriter_triton_test.cc"],
+ deps = [
+ ":softmax_rewriter_triton",
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:instruction_fusion",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:gpu_device_info_for_tests",
+ "//xla/service/gpu/fusions/triton:triton_support",
+ "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "sort_rewriter",
+ srcs = if_gpu_is_configured(
+ ["sort_rewriter.cc"],
+ ["sort_rewriter_stub.cc"],
+ ),
+ hdrs = ["sort_rewriter.h"],
+ deps = [
+ "//xla:comparison_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:stable_sort_expander",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/service/gpu/runtime:cub_sort_thunk",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "sort_rewriter_test",
+ srcs = if_cuda_is_configured(["sort_rewriter_test.cc"]),
+ backends = ["gpu"],
+ tags = ["no_oss"],
+ deps = [
+ ":sort_rewriter",
+ "//xla:error_spec",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:cublas_cudnn",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+cc_library(
+ name = "stream_attribute_annotator",
+ srcs = ["stream_attribute_annotator.cc"],
+ hdrs = ["stream_attribute_annotator.h"],
+ deps = [
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:gpu_fusible",
+ "//xla/service/gpu/runtime:thunk",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "stream_attribute_annotator_test",
+ srcs = ["stream_attribute_annotator_test.cc"],
+ deps = [
+ ":stream_attribute_annotator",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "stream_attribute_async_wrapper",
+ srcs = ["stream_attribute_async_wrapper.cc"],
+ hdrs = ["stream_attribute_async_wrapper.h"],
+ deps = [
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu/runtime:thunk",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "stream_attribute_async_wrapper_test",
+ srcs = ["stream_attribute_async_wrapper_test.cc"],
+ deps = [
+ ":stream_attribute_async_wrapper",
+ "//xla/hlo/ir:hlo",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "topk_specializer",
+ srcs = ["topk_specializer.cc"],
+ hdrs = ["topk_specializer.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:hlo_proto_cc",
+ "//xla/service:tuple_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_test(
+ name = "topk_specializer_test",
+ srcs = ["topk_specializer_test.cc"],
+ backends = ["gpu"],
+ deps = [
+ ":topk_specializer",
+ "//xla:shape_util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "//xla/service:platform_util",
+ "//xla/service:topk_rewriter",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
+ name = "topk_splitter",
+ srcs = ["topk_splitter.cc"],
+ hdrs = ["topk_splitter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/numeric:bits",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "topk_splitter_test",
+ srcs = ["topk_splitter_test.cc"],
+ deps = [
+ ":topk_splitter",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_dce",
+ "//xla/service:pattern_matcher",
+ "//xla/service:topk_rewriter",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:verified_hlo_module",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ ],
+)
+
+cc_library(
+ name = "tree_reduction_rewriter",
+ srcs = ["tree_reduction_rewriter.cc"],
+ hdrs = ["tree_reduction_rewriter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:collective_ops_utils",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:reduction_utils",
+ "//xla/stream_executor:device_description",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/numeric:bits",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "tree_reduction_rewriter_test",
+ srcs = [
+ "tree_reduction_rewriter_test.cc",
+ ],
+ deps = [
+ ":tree_reduction_rewriter",
+ "//xla/stream_executor:device_description",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+ name = "triangular_solve_rewriter",
+ srcs = ["triangular_solve_rewriter.cc"],
+ hdrs = ["triangular_solve_rewriter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service/gpu:cublas_cudnn",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+cc_library(
+ name = "triton_fusion_numerics_verifier",
+ srcs = ["triton_fusion_numerics_verifier.cc"],
+ hdrs = ["triton_fusion_numerics_verifier.h"],
+ tags = ["gpu"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:status_macros",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:executable",
+ "//xla/service:hlo_module_config",
+ "//xla/service:hlo_pass",
+ "//xla/service:shaped_buffer",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/service/gpu:buffer_comparator",
+ "//xla/service/gpu:ir_emission_utils",
+ "//xla/service/gpu/autotuning:autotuner_compile_util",
+ "//xla/service/gpu/autotuning:autotuner_util",
+ "//xla/stream_executor:stream",
+ "//xla/tools:hlo_decomposer_lib",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/functional:any_invocable",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_test(
+ name = "triton_fusion_numerics_verifier_test",
+ srcs = ["triton_fusion_numerics_verifier_test.cc"],
+ backend_tags = {"gpu": [
+ "requires-gpu-sm80",
+ ]},
+ backends = ["gpu"],
+ deps = [
+ ":triton_fusion_numerics_verifier",
+ "//xla:shape_util",
+ "//xla:test_helpers",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:platform_util",
+ "//xla/service/gpu/autotuning:autotuner_compile_util",
+ "//xla/service/gpu/autotuning:autotuner_util",
+ "//xla/stream_executor:platform",
+ "//xla/tests:hlo_test_base",
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "variadic_op_splitter",
+ srcs = ["variadic_op_splitter.cc"],
+ hdrs = ["variadic_op_splitter.h"],
+ deps = [
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "variadic_op_splitter_test",
+ srcs = ["variadic_op_splitter_test.cc"],
+ tags = [
+ "nomsan",
+ ],
+ deps = [
+ ":variadic_op_splitter",
+ "//xla:literal_util",
+ "//xla:shape_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_parser",
+ "//xla/service:pattern_matcher",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "windowed_einsum_handler",
+ srcs = ["windowed_einsum_handler.cc"],
+ hdrs = ["windowed_einsum_handler.h"],
+ deps = [
+ "//xla:literal_util",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/hlo/utils:hlo_query",
+ "//xla/service:hlo_creation_utils",
+ "//xla/service:hlo_pass",
+ "//xla/service:pattern_matcher",
+ "//xla/service:shape_inference",
+ "//xla/service/gpu:backend_configs_cc",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "windowed_einsum_handler_test",
+ srcs = ["windowed_einsum_handler_test.cc"],
+ deps = [
+ ":windowed_einsum_handler",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:pattern_matcher",
+ "//xla/service:pattern_matcher_gmock",
+ "//xla/service/gpu:backend_configs_cc",
+ "//xla/tests:filecheck",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings:string_view",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc
new file mode 100644
index 0000000..bba27a0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc
@@ -0,0 +1,66 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+
+#include "absl/log/check.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::gpu {
+
+bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
+ const HloInstruction* hlo) {
+ if (!options_.enable_dot_strength_reduction()) {
+ return false;
+ }
+
+ const HloDotInstruction* dot = DynCast<HloDotInstruction>(hlo);
+ if (dot == nullptr) {
+ return false;
+ }
+
+ const HloInstruction* lhs = dot->operand(0);
+ const HloInstruction* rhs = dot->operand(1);
+ DotDimensionNumbers dnums = dot->dot_dimension_numbers();
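+  // An operand counts as a vector here if every one of its dimensions is a
+  // batch or contracting dimension.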
+ bool lhs_is_vector = (dnums.lhs_batch_dimensions_size() +
+ dnums.lhs_contracting_dimensions_size() ==
+ lhs->shape().rank());
+ bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() +
+ dnums.rhs_contracting_dimensions_size() ==
+ rhs->shape().rank());
+ // Strength-reduce vector-vector dots since they are not supported by
+ // GemmFusion.
+ if (lhs_is_vector && rhs_is_vector) {
+ return true;
+ }
+
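+  // Matmuls below the size threshold are not worth rewriting into a GEMM
+  // fusion, so strength-reduce them as well.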
+ absl::StatusOr<bool> is_too_small =
+ IsMatrixMultiplicationTooSmallForRewriting(*hlo, /*threshold=*/10000000);
+ CHECK_OK(is_too_small.status());
+ if (is_too_small.value()) {
+ return true;
+ }
+
+ // If GemmFusion cannot handle this dot, we should strength-reduce it so that
+ // it can be handled by the fusion pipeline.
+ return !legacy_triton::CanTritonHandleGEMM(*dot, compute_capability_);
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h
new file mode 100644
index 0000000..f29b31e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h
@@ -0,0 +1,78 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+
+namespace xla::gpu {
+
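+// AlgebraicSimplifierVisitor with a GPU-specific policy for deciding when a
+// dot should be strength-reduced to a reduce: vector-vector dots, dots that
+// are too small for the GEMM rewriters, and dots that GemmFusion cannot
+// handle are reduced.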
+class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor {
+ public:
+ explicit GpuAlgebraicSimplifierVisitor(
+ const AlgebraicSimplifierOptions& options,
+ se::GpuComputeCapability compute_capability,
+ AlgebraicSimplifier* simplifier)
+ : AlgebraicSimplifierVisitor(options, simplifier),
+ compute_capability_(std::move(compute_capability)) {}
+
+ bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override;
+
+ private:
+ se::GpuComputeCapability compute_capability_;
+};
+
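+// AlgebraicSimplifier that runs GpuAlgebraicSimplifierVisitor over every
+// non-fusion computation of the module.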
+class GpuAlgebraicSimplifier : public AlgebraicSimplifier {
+ public:
+ explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options,
+ se::GpuComputeCapability compute_capability)
+ : AlgebraicSimplifier(options),
+ compute_capability_(std::move(compute_capability)) {}
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(HloModule* module,
+ const absl::flat_hash_set<absl::string_view>&
+ execution_threads) override {
+ XLA_VLOG_LINES(
+ 2, "GpuAlgebraicSimplifier::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ GpuAlgebraicSimplifierVisitor visitor(options_, compute_capability_, this);
+ for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
+ if (visitor.Run(comp, options_, this)) {
+ changed = true;
+ }
+ }
+ XLA_VLOG_LINES(
+ 2, "GpuAlgebraicSimplifier::Run(), after:\n" + module->ToString());
+ return changed;
+ }
+
+ private:
+ se::GpuComputeCapability compute_capability_;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc
new file mode 100644
index 0000000..c1e52e9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc
@@ -0,0 +1,141 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+
+#include <string>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+class GpuAlgebraicSimplifierTest : public HloTestBase {};
+
+TEST_F(GpuAlgebraicSimplifierTest, VectorVectorDotShouldBeStrengthReduced) {
+ const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+ p0 = f32[32, 500] parameter(0)
+ p1 = f32[32, 500] parameter(1)
+ ROOT dot = f32[32] dot(p0, p1), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ const HloInstruction* dot = module->entry_computation()->root_instruction();
+ AlgebraicSimplifierOptions options;
+ options.set_enable_dot_strength_reduction(true);
+ se::CudaComputeCapability ampere(8, 0);
+ GpuAlgebraicSimplifier simplifier(options, ampere);
+ GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+ EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest, MatrixVectorDotShouldNotBeStrengthReduced) {
+ const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+ p0 = f32[32, 5000, 7000] parameter(0)
+ p1 = f32[32, 5000] parameter(1)
+ ROOT dot = f32[32,7000] dot(p0, p1), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
+ algorithm=dot_bf16_bf16_f32_x6
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ const HloInstruction* dot = module->entry_computation()->root_instruction();
+ AlgebraicSimplifierOptions options;
+ options.set_enable_dot_strength_reduction(true);
+ se::CudaComputeCapability ampere(8, 0);
+ GpuAlgebraicSimplifier simplifier(options, ampere);
+ GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+ EXPECT_FALSE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest,
+ DotWithTypeUnsupportedByGemmFusionShouldBeStrengthReduced) {
+ const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+ p0 = c64[32, 5000, 7000] parameter(0)
+ p1 = c64[32, 5000] parameter(1)
+ ROOT dot = c64[32,7000] dot(p0, p1), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ const HloInstruction* dot = module->entry_computation()->root_instruction();
+ AlgebraicSimplifierOptions options;
+ options.set_enable_dot_strength_reduction(true);
+ se::CudaComputeCapability ampere(8, 0);
+ GpuAlgebraicSimplifier simplifier(options, ampere);
+ GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+ EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced) {
+ const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+ p0 = f32[32, 50, 70] parameter(0)
+ p1 = f32[32, 50] parameter(1)
+ ROOT dot = f32[32,70] dot(p0, p1), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
+ algorithm=dot_bf16_bf16_f32_x6
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ const HloInstruction* dot = module->entry_computation()->root_instruction();
+ AlgebraicSimplifierOptions options;
+ options.set_enable_dot_strength_reduction(true);
+ se::CudaComputeCapability ampere(8, 0);
+ GpuAlgebraicSimplifier simplifier(options, ampere);
+ GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+ EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced2) {
+ const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+ p0 = f32[2000, 3000] parameter(0)
+ p1 = f32[2000] parameter(1)
+ ROOT dot = f32[3000] dot(p0, p1), lhs_contracting_dims={0},
+ rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32_x6
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ const HloInstruction* dot = module->entry_computation()->root_instruction();
+ AlgebraicSimplifierOptions options;
+ options.set_enable_dot_strength_reduction(true);
+ se::CudaComputeCapability ampere(8, 0);
+ GpuAlgebraicSimplifier simplifier(options, ampere);
+ GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+ EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc
new file mode 100644
index 0000000..664d7b2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc
@@ -0,0 +1,117 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/algorithm_checker.h"
+
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/algorithm_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
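+// Returns true if any operand precision is not PrecisionConfig::DEFAULT.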
+bool HasNonDefaultOperandPrecision(const PrecisionConfig& config) {
+ return absl::c_any_of(config.operand_precision(), [](int precision) {
+ return static_cast<PrecisionConfig::Precision>(precision) !=
+ PrecisionConfig::DEFAULT;
+ });
+}
+
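+// Checks the precision-config algorithm of every dot against the target GPU
+// compute capability.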
+class AlgorithmCheckerVisitor : public ConstDfsHloVisitorWithDefault {
+ public:
+ explicit AlgorithmCheckerVisitor(
+ se::GpuComputeCapability gpu_compute_capability)
+ : gpu_compute_capability_(std::move(gpu_compute_capability)) {}
+
+ absl::Status RunOnModule(
+ const HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_RETURN_IF_ERROR(computation->Accept(this));
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleDot(const HloInstruction* hlo) override {
+ VLOG(1) << "Handling dot: " << hlo->ToString();
+ const PrecisionConfig& config = hlo->precision_config();
+
+ if (config.algorithm() != PrecisionConfig::ALG_UNSET &&
+ HasNonDefaultOperandPrecision(config)) {
+ LOG(WARNING)
+ << "There is no need to set precisions when we set the algorithm: "
+ << hlo->ToString();
+ }
+
+ if (config.algorithm() == PrecisionConfig::ALG_UNSET) {
+ return absl::OkStatus();
+ }
+
+ PrimitiveType lhs_storage_type = hlo->operand(0)->shape().element_type();
+ PrimitiveType rhs_storage_type = hlo->operand(1)->shape().element_type();
+ PrimitiveType output_storage_type = hlo->shape().element_type();
+
+ if (lhs_storage_type != rhs_storage_type) {
+ return absl::UnimplementedError(absl::StrFormat(
+ "Dot operands must have the same type when using an algorithm: %s",
+ hlo->ToString()));
+ }
+
+ return algorithm_util::IsSupportedDotAlgorithmOnGpu(
+ config.algorithm(), gpu_compute_capability_, lhs_storage_type,
+ output_storage_type)
+ ? absl::OkStatus()
+ : absl::UnimplementedError(absl::StrFormat(
+ "Unsupported algorithm on the current device(s): %s",
+ PrecisionConfig::Algorithm_Name(config.algorithm())));
+ }
+
+ absl::Status DefaultAction(const HloInstruction* hlo) override {
+ return absl::OkStatus();
+ }
+
+ private:
+ se::GpuComputeCapability gpu_compute_capability_;
+};
+
+} // namespace
+
+absl::StatusOr<bool> AlgorithmChecker::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ TF_RETURN_IF_ERROR(AlgorithmCheckerVisitor(gpu_compute_capability_)
+ .RunOnModule(module, execution_threads));
+ // No change was made.
+ return false;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h
new file mode 100644
index 0000000..c2cf0d2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h
@@ -0,0 +1,54 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Checks whether the requested dot algorithms are supported, giving an early
+// and specific error if an unsupported algorithm is requested.
+//
+// Note: Maybe we can make this more generic and move it outside of GPU.
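+//
+// For example (illustrative), a dot annotated with
+// `algorithm=dot_bf16_bf16_f32_x6` results in an UnimplementedError if the
+// target GPU does not support that algorithm for the operands' storage types.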
+class AlgorithmChecker : public HloModulePass {
+ public:
+ explicit AlgorithmChecker(se::GpuComputeCapability gpu_compute_capability)
+      : gpu_compute_capability_(std::move(gpu_compute_capability)) {}
+
+ absl::string_view name() const override { return "algorithm-checker"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::GpuComputeCapability gpu_compute_capability_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc
new file mode 100644
index 0000000..0d6bff3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc
@@ -0,0 +1,69 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/alias_passthrough_params.h"
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> AliasPassthroughParams::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ if (module->entry_computation()->num_parameters() == 0 ||
+ root->opcode() != HloOpcode::kTuple) {
+ return false;
+ }
+ bool changed = false;
+ absl::flat_hash_set<int64_t> used_params;
+ for (int64_t i = 0; i < root->operand_count(); ++i) {
+ if (root->operand(i)->opcode() == HloOpcode::kParameter &&
+ used_params.count(root->operand(i)->parameter_number()) == 0) {
+ VLOG(2) << "Parameter " << root->operand(i)->parameter_number()
+ << " with shape " << root->operand(i)->shape().ToString()
+ << " in module " << module->name()
+ << " is passed-through to root tuple element " << i << ": "
+ << root->shape().ToString();
+
+ if (module->input_output_alias_config().OutputHasAlias({i}) ||
+ module->input_output_alias_config().ParameterHasAlias(
+ root->operand(i)->parameter_number(), /*param_index=*/{})) {
+ VLOG(2) << "Skip setting the above pass-through alias as an alias may"
+ << " have been set up for alising resource update.";
+ continue;
+ }
+
+ TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{i},
+ /*param_number=*/root->operand(i)->parameter_number(),
+ /*param_index=*/{}));
+ used_params.insert(root->operand(i)->parameter_number());
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h
new file mode 100644
index 0000000..4cd4ab4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h
@@ -0,0 +1,50 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// This pass aliases input and output buffers that are associated with a
+// parameter that is passed through to the module root unmodified.
+//
+// This pass assumes that parameters and the root use unnested shapes, which is
+// the case for XLA:GPU.
+//
+// This pass must run prior to copy insertion.
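+//
+// For example (illustrative), given
+//
+//   ENTRY e {
+//     p0 = f16[8] parameter(0)
+//     p1 = f16[8] parameter(1)
+//     sum = f16[8] add(p0, p1)
+//     ROOT root = (f16[8], f16[8]) tuple(p0, sum)
+//   }
+//
+// the pass aliases root tuple element {0} with parameter 0, while element {1}
+// is left alone because `sum` is not a pass-through parameter.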
+class AliasPassthroughParams : public HloModulePass {
+ public:
+ AliasPassthroughParams() = default;
+ ~AliasPassthroughParams() override = default;
+ absl::string_view name() const override { return "alias_passthrough_params"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc
new file mode 100644
index 0000000..32c1e5b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/alias_passthrough_params.h"
+
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+class AliasPassthroughParamsTest : public HloTestBase {};
+
+TEST_F(AliasPassthroughParamsTest, AliasPassThroughParams) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ p0 = f16[2048,1024] parameter(0)
+ p1 = f16[2048,1024] parameter(1)
+ sum = f16[2048,1024] add(p0, p1)
+ ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
+ })")
+ .value();
+ EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
+ const auto& alias_config = module->input_output_alias_config();
+ EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
+ EXPECT_FALSE(alias_config.OutputHasAlias({1}));
+ EXPECT_EQ(1, alias_config.GetAliasedParameter({2})->parameter_number);
+}
+
+TEST_F(AliasPassthroughParamsTest, DoNotAliasPassThroughParamsMoreThanOnce) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ p0 = f16[2048,1024] parameter(0)
+ ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p0)
+ })")
+ .value();
+ EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
+ const auto& alias_config = module->input_output_alias_config();
+ EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
+ EXPECT_FALSE(alias_config.OutputHasAlias({1}));
+}
+
+TEST_F(AliasPassthroughParamsTest, PresetAliases) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ p0 = f16[2048,1024] parameter(0)
+ p1 = f16[2048,1024] parameter(1)
+ sum = f16[2048,1024] add(p0, p1)
+ ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
+ })")
+ .value();
+
+  // Preset an alias from p0 to the `sum` output. This could happen in the
+  // case of `alias_resource_update`.
+ auto& preset_alias = module->input_output_alias_config();
+ TF_EXPECT_OK(preset_alias.SetUpAlias(/*output_index=*/{1},
+ /*param_number=*/0,
+ /*param_index=*/{}));
+
+ EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
+ const auto& alias_result = module->input_output_alias_config();
+  // Assert that `AliasPassthroughParams` establishes an alias from p1 to root
+  // tuple element {2}, and that element {0} remains unaliased.
+ EXPECT_EQ(1, alias_result.GetAliasedParameter({2})->parameter_number);
+ EXPECT_FALSE(alias_result.OutputHasAlias({0}));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc
new file mode 100644
index 0000000..2f7c130
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc
@@ -0,0 +1,109 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_gather_optimizer.h"
+
+#include <cstdint>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> AllGatherOptimizer::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ if (!HloOpcodeIsBinaryCommutative(instruction->opcode())) {
+ continue;
+ }
+
+ HloInstruction* left_op = instruction->mutable_operand(0);
+ HloInstruction* right_op = instruction->mutable_operand(1);
+
+ if (right_op->opcode() != HloOpcode::kAllGather ||
+ left_op->opcode() != HloOpcode::kAllGather) {
+ VLOG(2) << "Binary op's operands are not all-gather deduced types.";
+ continue;
+ }
+
+ auto* left_all_gather = Cast<HloAllGatherInstruction>(left_op);
+ auto* right_all_gather = Cast<HloAllGatherInstruction>(right_op);
+
+ if (right_all_gather->constrain_layout() !=
+ left_all_gather->constrain_layout() ||
+ right_all_gather->use_global_device_ids() !=
+ left_all_gather->use_global_device_ids() ||
+ !ReplicaGroupsEqual(right_all_gather->replica_groups(),
+ left_all_gather->replica_groups())) {
+ VLOG(2) << "The right and left all-gather ops are not compatible "
+ "to merge. ";
+ continue;
+ }
+
+ if (!ShapeUtil::Equal(left_all_gather->operand(0)->shape(),
+ right_all_gather->operand(0)->shape())) {
+ VLOG(2) << "all-gather operands have different shapes";
+ continue;
+ }
+
+ if (right_all_gather->user_count() != 1 ||
+ left_all_gather->user_count() != 1) {
+ VLOG(2) << "all-gather user_count > 1 ";
+ continue;
+ }
+ auto index_in_full_shape =
+ computation->AddInstruction(HloInstruction::CreateBinary(
+ right_all_gather->operand(0)->shape(), instruction->opcode(),
+ left_all_gather->mutable_operand(0),
+ right_all_gather->mutable_operand(0)));
+
+ int64_t all_gather_dimension =
+ Cast<HloAllGatherInstruction>(right_all_gather)
+ ->all_gather_dimension();
+
+ auto combined = HloInstruction::CreateAllGather(
+ left_all_gather->shape(), {index_in_full_shape}, all_gather_dimension,
+ left_all_gather->device_list(),
+ /*constrain_layout=*/false, left_all_gather->channel_id(),
+ Cast<HloAllGatherInstruction>(left_all_gather)
+ ->use_global_device_ids());
+
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ instruction, std::move(combined)));
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h
new file mode 100644
index 0000000..988c1f6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h
@@ -0,0 +1,46 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Transforms binary_op(all-gather(reduce_scatter(a)),
+// all-gather(reduce_scatter(b))) into all-gather(binary_op(reduce_scatter(a),
+// reduce_scatter(b))).
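+//
+// For illustration only, a before/after HLO sketch (instruction names are
+// hypothetical; see the BranchesOptimized test in all_gather_optimizer_test.cc
+// for a concrete module):
+//
+//   Before:
+//     rs.1 = reduce-scatter(a)
+//     rs.2 = reduce-scatter(b)
+//     ag.1 = all-gather(rs.1)
+//     ag.2 = all-gather(rs.2)
+//     out = add(ag.1, ag.2)
+//
+//   After:
+//     sum = add(rs.1, rs.2)
+//     out = all-gather(sum)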
+
+class AllGatherOptimizer : public HloModulePass {
+ public:
+ AllGatherOptimizer() = default;
+ absl::string_view name() const override { return "all-gather-optimizer"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc
new file mode 100644
index 0000000..27f6d65
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc
@@ -0,0 +1,232 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_gather_optimizer.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class GpuAllGatherOptimizerTest : public HloTestBase {
+ public:
+ absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
+ absl::string_view hlo_module, int64_t num_replicas,
+ int64_t num_partitions, bool expect_change) {
+ HloModuleConfig config = GetModuleConfigForTest(
+ /*replica_count=*/num_replicas,
+ /*num_partitions=*/num_partitions);
+ config.set_use_spmd_partitioning(num_partitions > 1);
+ TF_ASSIGN_OR_RETURN(auto module,
+ ParseAndReturnVerifiedModule(hlo_module, config));
+
+ auto changed = AllGatherOptimizer().Run(module.get());
+ if (!changed.ok()) {
+ return changed.status();
+ }
+ EXPECT_EQ(changed.value(), expect_change);
+ return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
+ }
+
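+  // Counts the entry-computation instructions with opcode `oc`; used to check
+  // how many collectives remain after the pass runs.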
+ template <HloOpcode oc>
+ size_t CollectiveCount(std::unique_ptr<HloModule> &module) {
+ return absl::c_count_if(module->entry_computation()->instructions(),
+ HloPredicateIsOp<oc>);
+ }
+};
+
+TEST_F(GpuAllGatherOptimizerTest, BranchesOptimized) {
+ absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+ x = bf16[] parameter(0)
+ y = bf16[] parameter(1)
+ ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/true));
+  // The graph should contain 1 all-gather, but since node removal is
+  // deferred, the original all-gathers still exist at this stage.
+ EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
+ EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 2);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, DisabledSPMDPartitioningJAXBug) {
+ absl::string_view hlo_string = R"(
+HloModule pjit_f, entry_computation_layout={(f32[4,8]{1,0}, f32[4,8]{1,0})->f32[8,8]{1,0}}
+
+ENTRY %main.6_spmd (param: f32[4,8], param.1: f32[4,8]) -> f32[8,8] {
+ %param = f32[4,8]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}
+ %all-gather = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
+ %param.1 = f32[4,8]{1,0} parameter(1), sharding={devices=[2,1]<=[2]}
+ %all-gather.1 = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param.1), channel_id=2, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
+ ROOT %add.0 = f32[8,8]{1,0} add(f32[8,8]{1,0} %all-gather, f32[8,8]{1,0} %all-gather.1), metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/1,
+ /*num_partitions=*/2,
+ /*expect_change=*/true));
+ EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, MoreThanSingleUserForAllGather) {
+ absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+ x = bf16[] parameter(0)
+ y = bf16[] parameter(1)
+ ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
+reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
+add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/false));
+  // See the comment in the BranchesOptimized test.
+ EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
+ EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, AllGatherWithOpInBetweenOnRightBranch) {
+ absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+ x = bf16[] parameter(0)
+ y = bf16[] parameter(1)
+ ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
+reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+add.1 = bf16[8,64,1024]{2,1,0} add(reduce-scatter.1, reduce-scatter.2)
+all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(add.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/true));
+ EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
+ EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, AllGatherOneSided) {
+ absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+ x = bf16[] parameter(0)
+ y = bf16[] parameter(1)
+ ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
+
+add.1 = bf16[8,128,1024]{2,1,0} add(param.1, param.2)
+reduce-scatter = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.2 = bf16[8,128,1024]{2,1,0} add(all-gather, add.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/false));
+ EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
+ EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 1);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, DifferentOperandShapes) {
+ absl::string_view hlo_string = R"(
+HloModule TestModule
+
+ENTRY main {
+param.1 = bf16[8,64,128]{2,1,0} parameter(0)
+param.2 = bf16[8,128,64]{2,1,0} parameter(1)
+all-gather.1 = bf16[8,128,128]{2,1,0} all-gather(param.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+all-gather.2 = bf16[8,128,128]{2,1,0} all-gather(param.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+add.1 = bf16[8,128,128]{2,1,0} add(all-gather.1, all-gather.2)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/false));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc
new file mode 100644
index 0000000..0e0e67a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc
@@ -0,0 +1,373 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/btree_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/computation_placer.h"
+#include "xla/service/global_device_id.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
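+// Returns `instruction` itself if its result is an array, or one
+// get-tuple-element per tuple element if its result is a tuple.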
+std::vector<HloInstruction*> GetOutputs(HloInstruction& instruction) {
+ if (!instruction.shape().IsTuple()) {
+ return {&instruction};
+ }
+
+ std::vector<HloInstruction*> outputs;
+ outputs.reserve(instruction.shape().tuple_shapes_size());
+
+ HloComputation& computation = *instruction.parent(); // never null
+ for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
+ outputs.push_back(computation.AddInstruction(
+ HloInstruction::CreateGetTupleElement(&instruction, i)));
+ }
+ return outputs;
+}
+
+struct DecomposedReplicaGroups {
+ std::vector<ReplicaGroup> scatter_gather_groups;
+ std::vector<ReplicaGroup> new_all_reduce_groups;
+};
+
+// Returns the global device id for the given replica id. Returns nullopt if
+// the replica id can refer to multiple devices, or if the pass does not
+// support the CollectiveOpGroupMode.
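+//
+// For example (illustrative values): under kFlattenedID with 2 partitions,
+// flattened replica_id 5 maps to replica 5 / 2 = 2 and partition 5 % 2 = 1,
+// i.e. device_assignment(2, 1).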
+std::optional<GlobalDeviceId> TryConvertingReplicaIdToDeviceId(
+ int64_t replica_id, const DeviceAssignment& device_assignment,
+ CollectiveOpGroupMode collective_group_mode) {
+ if (collective_group_mode == CollectiveOpGroupMode::kCrossReplica) {
+ if (device_assignment.computation_count() != 1) {
+ // If there are multiple partitions, the replica_id may refer to multiple
+ // devices on different partitions.
+ return std::nullopt;
+ }
+ return GlobalDeviceId{device_assignment(replica_id, /*computation_id=*/0)};
+ } else if (collective_group_mode == CollectiveOpGroupMode::kFlattenedID) {
+ int partition_count = device_assignment.computation_count();
+ int64_t actual_replica_id = replica_id / partition_count;
+ int64_t partition_id = replica_id % partition_count;
+ return GlobalDeviceId{device_assignment(actual_replica_id, partition_id)};
+ }
+
+ // kCrossPartition and kCrossReplicaAndPartition are unsupported.
+ VLOG(1) << "Skip AllReduceBlueConnect because of unsupported "
+ "CollectiveOpGroupMode "
+ << CollectiveOpGroupModeToString(collective_group_mode);
+ return std::nullopt;
+}
+
+absl::StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroup(
+ const ReplicaGroup& replica_group,
+ const DeviceAssignment& device_assignment, size_t num_devices_per_host,
+ CollectiveOpGroupMode collective_group_mode) {
+ int group_size = replica_group.replica_ids_size();
+ TF_RET_CHECK(group_size > 0);
+
+ absl::btree_map<int, std::vector<int64_t>> replica_ids_by_host;
+ for (int64_t replica_id : replica_group.replica_ids()) {
+ std::optional<GlobalDeviceId> device_id = TryConvertingReplicaIdToDeviceId(
+ replica_id, device_assignment, collective_group_mode);
+ if (!device_id.has_value()) {
+ return {std::nullopt};
+ }
+ TF_RET_CHECK(*device_id >= 0);
+ // We assume that devices are ordered by host.
+ int host_id = device_id->value() / num_devices_per_host;
+ replica_ids_by_host[host_id].push_back(replica_id);
+ }
+
+ size_t num_local_devices = replica_ids_by_host.begin()->second.size();
+ bool same_num_devices_on_each_host =
+ absl::c_all_of(replica_ids_by_host, [&](const auto& entry) {
+ return entry.second.size() == num_local_devices;
+ });
+
+ if (!same_num_devices_on_each_host) {
+ return {std::nullopt};
+ }
+
+ std::vector<int64_t> sorted_replica_group;
+ sorted_replica_group.reserve(group_size);
+ for (const auto& entry : replica_ids_by_host) {
+ absl::c_copy(entry.second, std::back_inserter(sorted_replica_group));
+ }
+
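+  // Use scatter/gather groups of at least two devices so the reduce-scatter
+  // and all-gather actually split the data, even when each host contributes a
+  // single device.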
+ size_t scatter_group_size = std::max(num_local_devices, size_t(2));
+ size_t num_scatter_groups = group_size / scatter_group_size;
+
+ if ((group_size % scatter_group_size != 0) || (num_scatter_groups < 2)) {
+ return {std::nullopt};
+ }
+
+ std::vector<ReplicaGroup> scatter_gather_groups(num_scatter_groups);
+ std::vector<ReplicaGroup> new_all_reduce_groups(scatter_group_size);
+
+ for (size_t i = 0; i < group_size; ++i) {
+ int64_t replica_id = sorted_replica_group[i];
+ scatter_gather_groups[i / scatter_group_size].add_replica_ids(replica_id);
+ new_all_reduce_groups[i % scatter_group_size].add_replica_ids(replica_id);
+ }
+
+ return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
+ std::move(new_all_reduce_groups)}};
+}
+
+absl::StatusOr<std::optional<DecomposedReplicaGroups>>
+TryDecomposeReplicaGroups(const HloAllReduceInstruction& all_reduce,
+ size_t num_devices_per_host) {
+ const DeviceAssignment& device_assignment =
+ all_reduce.GetModule()->config().static_device_assignment();
+
+ absl::Span<const ReplicaGroup> replica_groups = all_reduce.replica_groups();
+
+ ReplicaGroup all_replicas; // only populated if replica groups not present.
+ if (replica_groups.empty()) {
+ for (int i = 0; i < device_assignment.replica_count(); ++i) {
+ all_replicas.add_replica_ids(i);
+ }
+ replica_groups = absl::MakeSpan(&all_replicas, 1);
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ CollectiveOpGroupMode collective_op_group_mode,
+ GetCollectiveOpGroupMode(all_reduce.channel_id().has_value(),
+ all_reduce.use_global_device_ids()));
+
+ std::vector<ReplicaGroup> scatter_gather_groups;
+ std::vector<ReplicaGroup> new_all_reduce_groups;
+
+ // Try to find a valid decomposition for each replica group.
+ for (const ReplicaGroup& replica_group : replica_groups) {
+ TF_ASSIGN_OR_RETURN(
+ std::optional<DecomposedReplicaGroups> decomposed_groups,
+ TryDecomposeReplicaGroup(replica_group, device_assignment,
+ num_devices_per_host,
+ collective_op_group_mode));
+
+ if (!decomposed_groups) return {std::nullopt};
+
+ int scatter_group_size =
+ decomposed_groups->scatter_gather_groups[0].replica_ids_size();
+
+ if (scatter_gather_groups.empty()) {
+ // Check that every operand is exactly divisible by scatter group sizes.
+ for (const HloInstruction* operand : all_reduce.operands()) {
+ TF_RET_CHECK(operand->shape().IsArray());
+ int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
+ if (num_elements % scatter_group_size != 0) {
+ return {std::nullopt};
+ }
+ }
+
+ scatter_gather_groups.reserve(
+ replica_groups.size() *
+ decomposed_groups->scatter_gather_groups.size());
+ new_all_reduce_groups.reserve(
+ replica_groups.size() *
+ decomposed_groups->new_all_reduce_groups.size());
+ } else if (scatter_group_size !=
+ scatter_gather_groups[0].replica_ids_size()) {
+ // Reduce-scatter would have different output shapes on different devices.
+ return {std::nullopt};
+ }
+
+ absl::c_move(decomposed_groups->scatter_gather_groups,
+ std::back_inserter(scatter_gather_groups));
+ absl::c_move(decomposed_groups->new_all_reduce_groups,
+ std::back_inserter(new_all_reduce_groups));
+ }
+
+ return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
+ std::move(new_all_reduce_groups)}};
+}
+
+// Attempts to decompose all-reduces as described by the BlueConnect paper.
+//
+// If possible, the all-reduce will be transformed into:
+// 1. reduce-scatter
+// 2. all-reduce
+// 3. all-gather
+//
+// If the all-reduce replica groups have more than one device within the same
+// host, the reduce-scatter will be performed over all devices within each
+// Otherwise, the reduce-scatter will be performed between pairs of devices on
+// different hosts.
+//
+// When applied repeatedly, this transformation will reproduce the same pattern
+// as described in the BlueConnect paper.
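+//
+// For example, with 8 replicas and 4 devices per host (as in the OneStage test
+// in all_reduce_blueconnect_test.cc), an all-reduce over {0,...,7} becomes a
+// reduce-scatter over {{0,1,2,3},{4,5,6,7}}, an all-reduce over
+// {{0,4},{1,5},{2,6},{3,7}}, and an all-gather over {{0,1,2,3},{4,5,6,7}}.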
+absl::StatusOr<bool> TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce,
+ size_t num_devices_per_host) {
+ TF_RET_CHECK(all_reduce);
+ TF_RET_CHECK(!all_reduce->has_sharding());
+
+ HloComputation& computation = *all_reduce->parent(); // never null
+ PrimitiveType element_type = all_reduce->operand(0)->shape().element_type();
+
+ TF_ASSIGN_OR_RETURN(
+ std::optional<DecomposedReplicaGroups> decomposed_groups,
+ TryDecomposeReplicaGroups(*all_reduce, num_devices_per_host));
+
+ if (!decomposed_groups) return false;
+
+ // Bitcast operands to 1D to guarantee that first dimension is divisible by
+ // scatter group size (we checked num elements was divisible above).
+ std::vector<HloInstruction*> flat_operands;
+ flat_operands.reserve(all_reduce->operand_count());
+ std::vector<Shape> flat_shapes;
+ flat_shapes.reserve(all_reduce->operand_count());
+ std::vector<Shape> scattered_shapes;
+ scattered_shapes.reserve(all_reduce->operand_count());
+
+ int scatter_group_size =
+ decomposed_groups->scatter_gather_groups[0].replica_ids_size();
+
+ for (HloInstruction* operand : all_reduce->operands()) {
+ TF_RET_CHECK(operand->shape().IsArray());
+ int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
+ Shape flat_shape = ShapeUtil::MakeShape(element_type, {num_elements});
+ flat_operands.push_back(computation.AddInstruction(
+ HloInstruction::CreateBitcast(flat_shape, operand)));
+ flat_shapes.push_back(std::move(flat_shape));
+ scattered_shapes.push_back(ShapeUtil::MakeShape(
+ element_type, {num_elements / scatter_group_size}));
+ }
+
+ Shape reduce_scatter_shape = ShapeUtil::MakeMaybeTupleShape(scattered_shapes);
+
+ int64_t next_channel_id = hlo_query::NextChannelId(*computation.parent());
+ auto get_channel_id = [&]() -> std::optional<int64_t> {
+ if (all_reduce->channel_id().has_value()) {
+ return next_channel_id++;
+ }
+ return std::nullopt;
+ };
+
+ HloInstruction* reduce_scatter =
+ computation.AddInstruction(HloInstruction::CreateReduceScatter(
+ reduce_scatter_shape, flat_operands, all_reduce->to_apply(),
+ CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
+ /*constrain_layout=*/false, get_channel_id(),
+ all_reduce->use_global_device_ids(),
+ /*scatter_dimension=*/0));
+
+ HloInstruction* new_all_reduce =
+ computation.AddInstruction(HloInstruction::CreateAllReduce(
+ reduce_scatter_shape, GetOutputs(*reduce_scatter),
+ all_reduce->to_apply(),
+ CollectiveDeviceList(decomposed_groups->new_all_reduce_groups),
+ /*constrain_layout=*/false, all_reduce->channel_id(),
+ all_reduce->use_global_device_ids()));
+
+ HloInstruction* all_gather =
+ computation.AddInstruction(HloInstruction::CreateAllGather(
+ ShapeUtil::MakeMaybeTupleShape(flat_shapes),
+ GetOutputs(*new_all_reduce),
+ /*all_gather_dimension=*/0,
+ CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
+ /*constrain_layout=*/false, get_channel_id(),
+ all_reduce->use_global_device_ids()));
+
+ // Bitcast back to the original shapes and replace all-reduce with decomposed
+ // implementation.
+ std::vector<HloInstruction*> outputs = GetOutputs(*all_gather);
+ for (int64_t i = 0; i < outputs.size(); ++i) {
+ outputs[i] = computation.AddInstruction(HloInstruction::CreateBitcast(
+ all_reduce->operand(i)->shape(), outputs[i]));
+ }
+ HloInstruction* replacement = MaybeMakeTuple(outputs);
+
+ TF_RETURN_IF_ERROR(
+ all_reduce->CopyAllControlDepsTo(reduce_scatter, replacement));
+
+ TF_RETURN_IF_ERROR(all_reduce->DropAllControlDeps());
+ TF_RETURN_IF_ERROR(computation.ReplaceInstruction(all_reduce, replacement));
+
+ // Try to apply decomposition recursively.
+ TF_RETURN_IF_ERROR(
+ TryDecomposeAllReduce(Cast<HloAllReduceInstruction>(new_all_reduce),
+ num_devices_per_host)
+ .status());
+ return true;
+}
+
+} // namespace
+
+absl::StatusOr<bool> AllReduceBlueConnect::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ VLOG(1) << "Running AllReduceBlueConnect";
+
+ if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
+ VLOG(1)
+ << "Skip AllReduceBlueConnect because the module contains all-reduce "
+ "with constrained layouts";
+ return false;
+ }
+ if (!module->config().has_static_device_assignment()) {
+ VLOG(1)
+ << "Skip AllReduceBlueConnect because the module doesn't have static "
+ "device assignment";
+ return false;
+ }
+
+ std::vector<HloAllReduceInstruction*> all_reduces;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kAllReduce) {
+ all_reduces.push_back(Cast<HloAllReduceInstruction>(instruction));
+ }
+ }
+ }
+
+ bool changed = false;
+ for (HloAllReduceInstruction* all_reduce : all_reduces) {
+ TF_ASSIGN_OR_RETURN(
+ bool all_reduce_changed,
+ TryDecomposeAllReduce(all_reduce, num_devices_per_host_));
+ changed |= all_reduce_changed;
+ }
+
+ return changed;
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h
new file mode 100644
index 0000000..6da0bbe
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h
@@ -0,0 +1,56 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_
+
+#include <cstddef>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Decomposes all-reduce operations using the BlueConnect algorithm.
+//
+// Paper: "BLUECONNECT: DECOMPOSING ALL-REDUCE FOR DEEP LEARNING ON
+// HETEROGENEOUS NETWORK HIERARCHY"
+// https://mlsys.org/Conferences/2019/doc/2019/130.pdf
+//
+// This algorithm attempts to minimize the number of levels of network hierarchy
+// traversed for as much data transfer as possible. This implementation assumes
+// that host IDs are ordered corresponding to network hierarchy.
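+//
+// Typical usage, as a sketch (mirrors the unit tests in
+// all_reduce_blueconnect_test.cc):
+//
+//   AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+//   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));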
+class AllReduceBlueConnect : public HloModulePass {
+ public:
+ explicit AllReduceBlueConnect(size_t num_devices_per_host)
+ : num_devices_per_host_(num_devices_per_host) {}
+
+ absl::string_view name() const override { return "all-reduce-blueconnect"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ size_t num_devices_per_host_;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc
new file mode 100644
index 0000000..cafbf24
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc
@@ -0,0 +1,414 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/computation_placer.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+namespace m = ::xla::match;
+
+using AllReduceBlueConnectTest = HloTestBase;
+
+HloPredicate MatchChannelId(std::optional<int64_t> channel_id) {
+ return [channel_id](const HloInstruction* instruction) {
+ return instruction->channel_id() == channel_id;
+ };
+}
+
+void SetModuleConfig(HloModuleConfig* module_config, size_t replica_count,
+ size_t partition_count = 1) {
+ DeviceAssignment device_assignment(replica_count,
+ /*computation_count=*/partition_count);
+ device_assignment.FillIota(0);
+ module_config->set_replica_count(replica_count);
+ module_config->set_num_partitions(partition_count);
+ module_config->set_static_device_assignment(device_assignment);
+}
+
+void SetModuleConfig(HloModule& module, size_t replica_count,
+ size_t partition_count = 1) {
+ SetModuleConfig(&module.mutable_config(), replica_count, partition_count);
+}
+
+TEST_F(AllReduceBlueConnectTest, OneStage) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[4,4] parameter(0)
+ ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/8);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+ // clang-format off
+ std::vector<std::vector<int64_t>> scatter_gather_groups = {
+ {0, 1, 2, 3}, {4, 5, 6, 7}};
+ std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+ {0, 4}, {1, 5}, {2, 6}, {3, 7}};
+ // clang-format on
+
+ auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+ auto reduce_scatter = m::ReduceScatter(bitcast)
+ .WithShape(F32, {4})
+ .WithReplicaGroups(scatter_gather_groups)
+ .WithPredicate(MatchChannelId(std::nullopt));
+ auto all_reduce = m::AllReduce(reduce_scatter)
+ .WithShape(F32, {4})
+ .WithReplicaGroups(new_all_reduce_groups)
+ .WithPredicate(MatchChannelId(std::nullopt));
+ auto all_gather = m::AllGather(all_reduce)
+ .WithShape(F32, {16})
+ .WithReplicaGroups(scatter_gather_groups)
+ .WithPredicate(MatchChannelId(std::nullopt));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Bitcast(all_gather).WithShape(F32, {4, 4})));
+}
+
+TEST_F(AllReduceBlueConnectTest, TwoStage) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[4,4] parameter(0)
+ ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/16);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+ std::vector<std::vector<int64_t>> outer_scatter_gather_groups = {
+ {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
+ std::vector<std::vector<int64_t>> inner_scatter_gather_groups = {
+ {0, 4}, {8, 12}, {1, 5}, {9, 13}, {2, 6}, {10, 14}, {3, 7}, {11, 15}};
+ std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+ {0, 8}, {4, 12}, {1, 9}, {5, 13}, {2, 10}, {6, 14}, {3, 11}, {7, 15}};
+
+ auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+ auto reduce_scatter0 =
+ m::ReduceScatter(bitcast0).WithShape(F32, {4}).WithReplicaGroups(
+ outer_scatter_gather_groups);
+ auto bitcast1 = m::Bitcast(reduce_scatter0).WithShape(F32, {4});
+ auto reduce_scatter1 =
+ m::ReduceScatter(bitcast1).WithShape(F32, {2}).WithReplicaGroups(
+ inner_scatter_gather_groups);
+ auto all_reduce = m::AllReduce(reduce_scatter1)
+ .WithShape(F32, {2})
+ .WithReplicaGroups(new_all_reduce_groups);
+ auto all_gather0 = m::AllGather(all_reduce)
+ .WithShape(F32, {4})
+ .WithReplicaGroups(inner_scatter_gather_groups);
+ auto bitcast2 = m::Bitcast(all_gather0).WithShape(F32, {4});
+ auto all_gather1 =
+ m::AllGather(bitcast2).WithShape(F32, {16}).WithReplicaGroups(
+ outer_scatter_gather_groups);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Bitcast(all_gather1).WithShape(F32, {4, 4})));
+}
+
+TEST_F(AllReduceBlueConnectTest, TwoOperands) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[4,4] parameter(0)
+ p1 = f32[4,4,2] parameter(1)
+ ROOT crs = (f32[4,4], f32[4,4,2]) all-reduce(p0, p1), to_apply=add
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/8);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+ // clang-format off
+ std::vector<std::vector<int64_t>> scatter_gather_groups = {
+ {0, 1, 2, 3}, {4, 5, 6, 7}};
+ std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+ {0, 4}, {1, 5}, {2, 6}, {3, 7}};
+ // clang-format on
+
+ auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+ auto bitcast1 = m::Bitcast(m::Parameter(1)).WithShape(F32, {32});
+
+ Shape expected0 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});
+ Shape expected1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {16}), ShapeUtil::MakeShape(F32, {32})});
+ auto reduce_scatter = m::ReduceScatter(bitcast0, bitcast1)
+ .WithShapeEqualTo(&expected0)
+ .WithReplicaGroups(scatter_gather_groups);
+ auto all_reduce = m::AllReduce(m::GetTupleElement(reduce_scatter, 0),
+ m::GetTupleElement(reduce_scatter, 1))
+ .WithShapeEqualTo(&expected0)
+ .WithReplicaGroups(new_all_reduce_groups);
+ auto all_gather = m::AllGather(m::GetTupleElement(all_reduce, 0),
+ m::GetTupleElement(all_reduce, 1))
+ .WithShapeEqualTo(&expected1)
+ .WithReplicaGroups(scatter_gather_groups);
+ auto bitcast2 =
+ m::Bitcast(m::GetTupleElement(all_gather, 0)).WithShape(F32, {4, 4});
+ auto bitcast3 =
+ m::Bitcast(m::GetTupleElement(all_gather, 1)).WithShape(F32, {4, 4, 2});
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(bitcast2, bitcast3)));
+}
+
+TEST_F(AllReduceBlueConnectTest, MultiplePartitionsFilecheck) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[8,8] parameter(0)
+ ROOT crs = f32[8,8] all-reduce(p0), channel_id=1,
+ replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=add
+})";
+ HloModuleConfig module_config;
+ SetModuleConfig(&module_config, /*replica_count=*/1, /*partition_count=*/8);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ // Note: When matching strings like "replica_groups={{0,1,2,3}}", FileCheck
+ // interprets the string inside the double braces as regex. So to match such
+ // strings, we use "replica_groups={{..0,1,2,3..}}", where the dots match the
+ // opening and closing braces.
+ RunAndFilecheckHloRewrite(hlo_string, std::move(pass), R"(
+ CHECK: %p0 = f32[8,8]{1,0} parameter(0)
+ CHECK-NEXT: [[bitcast:%[^ ]+]] = f32[64]{0} bitcast(%p0)
+ CHECK-NEXT: [[reduce_scatter:%[^ ]+]] = f32[16]{0} reduce-scatter([[bitcast]]), channel_id=2, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+ CHECK-NEXT: [[all_reduce:%[^ ]+]] = f32[16]{0} all-reduce([[reduce_scatter]]), channel_id=1, replica_groups={{..0,4.,.1,5.,.2,6.,.3,7..}}, use_global_device_ids=true, to_apply=%add
+ CHECK-NEXT: [[all_gather:%[^ ]+]] = f32[64]{0} all-gather([[all_reduce]]), channel_id=3, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, dimensions={0}, use_global_device_ids=true
+ CHECK-NEXT: ROOT [[output:%[^ ]+]] = f32[8,8]{1,0} bitcast([[all_gather]])
+}
+)",
+ /*after_pass_checks=*/nullptr, &module_config);
+}
+
+TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesWithinReplicaGroup) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[4,4] parameter(0)
+ ROOT crs = f32[4,4] all-reduce(p0),
+ replica_groups={{0,1,2,7},{3,4,5,6}}, to_apply=add
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/8);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesAcrossReplicaGroups) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[4,4] parameter(0)
+ ROOT crs = f32[4,4] all-reduce(p0),
+ replica_groups={{0,1,4,5},{2,3,6,7},{8,9,10,11},{12,13,14,15}}, to_apply=add
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/16);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(AllReduceBlueConnectTest, OperandIndivisible) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[4,4] parameter(0)
+ p1 = f32[9] parameter(1)
+ ROOT crs = (f32[4,4], f32[9]) all-reduce(p0, p1), to_apply=add
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/8);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(AllReduceBlueConnectTest, ControlDeps) {
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[4,4] parameter(0)
+ p1 = f32[4,4] parameter(1)
+ add = f32[4,4] add(p0, p1)
+ crs = f32[4,4] all-reduce(p0), to_apply=add, control-predecessors={add}
+ ROOT add1 = f32[4,4] add(crs, add), control-predecessors={crs}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/8);
+
+ // Remember all-reduce's control succ and preds.
+ const HloInstruction* ar =
+ module->entry_computation()->root_instruction()->operand(0);
+ auto expected_preds = ar->control_predecessors();
+ auto expected_succs = ar->control_successors();
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+ // clang-format off
+ std::vector<std::vector<int64_t>> scatter_gather_groups = {
+ {0, 1, 2, 3}, {4, 5, 6, 7}};
+ std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+ {0, 4}, {1, 5}, {2, 6}, {3, 7}};
+ // clang-format on
+
+ const HloInstruction *matched_rs, *matched_bitcast;
+ auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+ auto reduce_scatter = m::ReduceScatter(&matched_rs, bitcast)
+ .WithShape(F32, {4})
+ .WithReplicaGroups(scatter_gather_groups);
+ auto all_reduce = m::AllReduce(reduce_scatter)
+ .WithShape(F32, {4})
+ .WithReplicaGroups(new_all_reduce_groups);
+ auto all_gather = m::AllGather(all_reduce)
+ .WithShape(F32, {16})
+ .WithReplicaGroups(scatter_gather_groups);
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Add()));
+
+ EXPECT_THAT(
+ root->operand(0),
+ GmockMatch(
+ m::Bitcast(&matched_bitcast, all_gather).WithShape(F32, {4, 4})));
+
+ // Verify that control dependencies are transferred correctly.
+ EXPECT_THAT(matched_rs, GmockMatch(m::Op().WithControlDeps(
+ absl::MakeSpan(expected_preds), {})));
+ EXPECT_THAT(matched_bitcast, GmockMatch(m::Op().WithControlDeps(
+ {}, absl::MakeSpan(expected_succs))));
+}
+
+TEST_F(AllReduceBlueConnectTest, ReduceScatterUnchanged) {
+  // Tests that this pass does not affect reduce-scatter. In principle, the
+  // BlueConnect algorithm could be applied to reduce-scatter, but this pass
+  // currently does not do so.
+ constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+ p0 = f32[8,4] parameter(0)
+ ROOT crs = f32[1,4] reduce-scatter(p0), dimensions={0}, to_apply=add
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ SetModuleConfig(*module, /*replica_count=*/8);
+
+ AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+ EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc
new file mode 100644
index 0000000..51f71c0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc
@@ -0,0 +1,436 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_splitter.h"
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include "absl/cleanup/cleanup.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/collective_device_list.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_opt_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+// Structure containing the newly calculated replica groups.
+struct ARReplicaGroups {
+ // First AR's replica group.
+ std::vector<ReplicaGroup> first_ar_replica_groups;
+ // Second AR's replica group.
+ std::vector<ReplicaGroup> second_ar_replica_groups;
+};
+
+// Contains relevant data to rewrite the AR + DS into AR + DS + AR.
+struct AllReduceRewriteSpec {
+ // Determines a dimension on which DS occurs.
+ int split_dim;
+ // Determines the size of the process group.
+ int group_size;
+ // AllReduce instruction to be rewritten.
+ HloAllReduceInstruction* all_reduce;
+ // DynamicSlice following the `all_reduce` indicating logical RS.
+ HloDynamicSliceInstruction* dynamic_slice;
+ // New replica groups for an `all_reduce`.
+ ARReplicaGroups replica_groups;
+
+ std::string ToString() {
+ return absl::Substitute(
+ "{\n split_dim=$0\n group_size=$1\n all_reduce=$2\n "
+ "dynamic_slice=$3\n}\n",
+ split_dim, group_size, all_reduce->ToString(),
+ dynamic_slice->ToString());
+ }
+};
+
+// Contains the relevant metadata for debugging why rewrite is infeasible.
+struct RewriteInfeasibleReason {
+ // Instruction for which it is infeasible to do a rewrite.
+ const HloInstruction* ar;
+ // Describes a reason of infeasibility.
+ std::string message;
+};
+
+// Hashable container to hold replica groups.
+struct ReplicaGroups {
+ std::vector<ReplicaGroup> replica_groups;
+
+ template <typename H>
+ friend H AbslHashValue(H h, const ReplicaGroups& rg) {
+ return H::combine(std::move(h), rg.replica_groups.size());
+ }
+
+ friend bool operator==(const ReplicaGroups& item,
+ const ReplicaGroups& other) {
+ if (item.replica_groups.size() != other.replica_groups.size()) {
+ return false;
+ }
+    for (int i = 0; i < item.replica_groups.size(); i++) {
+      const ReplicaGroup& item_replica_group = item.replica_groups[i];
+      const ReplicaGroup& other_replica_group = other.replica_groups[i];
+      if (item_replica_group.replica_ids_size() !=
+          other_replica_group.replica_ids_size()) {
+        return false;
+      }
+      for (int j = 0; j < item_replica_group.replica_ids_size(); j++) {
+        if (item_replica_group.replica_ids(j) !=
+            other_replica_group.replica_ids(j)) {
+          return false;
+        }
+      }
+    }
+ return true;
+ }
+};
+
+using ARReplicaGroupMap =
+ absl::flat_hash_map<ReplicaGroups,
+ std::vector<const HloAllReduceInstruction*>>;
+
+using RewriteDecision =
+ std::variant<AllReduceRewriteSpec, RewriteInfeasibleReason>;
+
+// Returns a single dimension which is being split by `ds`. Returns
+// std::nullopt if there are more, or no dimension to be split.
+std::optional<int> GetSplitDim(const HloAllReduceInstruction& ar,
+ const HloDynamicSliceInstruction& ds) {
+ int split_dim = -1;
+ int num_dims = 0;
+ for (int64_t dim = 0; dim < ar.shape().rank(); ++dim) {
+ if (ar.shape().dimensions(dim) != ds.shape().dimensions(dim)) {
+ num_dims++;
+ split_dim = dim;
+ }
+ }
+ if (num_dims != 1) {
+    VLOG(2) << "Expected exactly one split dim; found zero or multiple.";
+ return std::nullopt;
+ }
+ return split_dim;
+}
+
+// For the input collective instruction `ar`, returns the process group size
+// (number of shards).
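+// For example (matching the module sketch in all_reduce_splitter.h), an AR of
+// shape bf16[32] followed by a DS with dynamic_slice_sizes={8} on the split
+// dimension yields a group size of 32 / 8 = 4.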
+std::optional<int> GetProcessGroupSize(const HloAllReduceInstruction& ar,
+ const HloDynamicSliceInstruction& ds) {
+ CHECK(ds.operand(0) == &ar) << "Irrelevant AR + DS pair.";
+ std::optional<int> split_dim = GetSplitDim(ar, ds);
+ if (!split_dim.has_value()) {
+ return std::nullopt;
+ }
+
+ return ar.shape().dimensions(*split_dim) /
+ ds.dynamic_slice_sizes()[*split_dim];
+}
+
+ARReplicaGroupMap GetReplicaGroupsMap(HloComputation& computation) {
+ ARReplicaGroupMap map;
+ hlo_query::ForEachInstructionWithOpcode(
+ computation, HloOpcode::kAllReduce,
+ [&map](const HloInstruction* instruction) {
+ const HloAllReduceInstruction* ar =
+ Cast<HloAllReduceInstruction>(instruction);
+ auto rgs = ReplicaGroups{ar->replica_groups()};
+ map[rgs].push_back(ar);
+ });
+ return map;
+}
+
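+// Builds the replica groups for the two all-reduces after the split. For
+// example (illustrative values), group_size=4 and num_partitions=8 yield
+// first-AR groups {{0,1,2,3},{4,5,6,7}} and second-AR groups
+// {{0,4},{1,5},{2,6},{3,7}}.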
+ARReplicaGroups GetNewReplicaGroups(int group_size, int num_partitions) {
+ CHECK_EQ(num_partitions % group_size, 0);
+
+ std::vector<ReplicaGroup> first_ar_rgs, second_ar_rgs;
+ int num_units = num_partitions / group_size;
+ first_ar_rgs.reserve(num_units);
+ second_ar_rgs.reserve(group_size);
+
+ // Construct first AR replica groups.
+ for (int u = 0; u < group_size * num_units; u += group_size) {
+ ReplicaGroup& group = first_ar_rgs.emplace_back();
+ for (int r = u; r < u + group_size; r++) {
+ group.add_replica_ids(r);
+ }
+ }
+
+ // Construct second AR replica groups.
+ for (int g = 0; g < group_size; g++) {
+ ReplicaGroup& group = second_ar_rgs.emplace_back();
+ for (int r = g; r < group_size * num_units; r += group_size) {
+ group.add_replica_ids(r);
+ }
+ }
+ return {
+ /*first_ar_replica_groups=*/first_ar_rgs,
+ /*second_ar_replica_groups=*/second_ar_rgs,
+ };
+}
+
+// Returns true if `spec` can be transformed into a logical reduce scatter.
+// False otherwise.
+bool IsLogicalReduceScatter(const HloModule& module,
+ const AllReduceRewriteSpec& spec,
+ HloComputation& computation) {
+ HloAllReduceInstruction& ar = *spec.all_reduce;
+ CHECK_EQ(ar.user_count(), 1);
+ CHECK_EQ(module.config().replica_count(), 1);
+
+ HloInstruction* first_ar =
+ computation.AddInstruction(HloInstruction::CreateAllReduce(
+ ar.shape(), ar.operands(), ar.to_apply(),
+ CollectiveDeviceList(spec.replica_groups.first_ar_replica_groups),
+ ar.constrain_layout(), hlo_query::NextChannelId(module),
+ ar.use_global_device_ids()));
+
+ HloInstruction* ds = ar.users()[0];
+ auto* old_operand = ds->mutable_operand(0);
+ if (!ds->ReplaceOperandWith(0, first_ar).ok()) {
+ return false;
+ }
+ absl::Cleanup _ = [&] {
+ CHECK_OK(ds->ReplaceOperandWith(0, old_operand));
+ CHECK_OK(computation.RemoveInstruction(first_ar));
+ };
+ return MatchReduceScatter(Cast<HloAllReduceInstruction>(first_ar),
+ module.config().num_partitions(),
+ module.config().replica_count(),
+ /*allow_multiple_split_dims=*/false,
+ /*allow_intervening_reshape=*/true)
+ .has_value();
+}
+
+// Determines whether the given `spec`'s AllReduce instruction is profitable to
+// split. Currently it employs a simple heuristic: it checks whether there
+// exists at least one all-reduce with the same replica groups as either of the
+// all-reduce's replica groups after the potential split.
+bool IsProfitableToSplit(const ARReplicaGroupMap& replica_map,
+ const AllReduceRewriteSpec& spec) {
+ auto new_rgs = spec.replica_groups;
+ bool first_replica_exists =
+ replica_map.contains(ReplicaGroups{new_rgs.first_ar_replica_groups});
+ bool second_replica_exists =
+ replica_map.contains(ReplicaGroups{new_rgs.second_ar_replica_groups});
+ return first_replica_exists || second_replica_exists;
+}
+
+RewriteDecision CanRewrite(const HloModule& module,
+ const ARReplicaGroupMap& replica_map,
+ HloComputation& computation,
+ HloInstruction& instruction) {
+ // We rely on SPMD partitioning enabled, thus asserting `replica_count` = 1.
+ const HloModuleConfig& config = module.config();
+ if (config.use_auto_spmd_partitioning() || !config.use_spmd_partitioning() ||
+ config.replica_count() != 1) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Supporting only SPMD partitioning scheme.",
+ };
+ }
+
+ if (instruction.opcode() != HloOpcode::kAllReduce) {
+ return RewriteInfeasibleReason{
+ &instruction,
+        "Cannot rewrite the instruction, since it is not an AllReduce.",
+ };
+ }
+
+ auto* ar = Cast<HloAllReduceInstruction>(&instruction);
+
+ if (!ar->use_global_device_ids()) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Only global ids are supported currently.",
+ };
+ }
+
+ if (ar->user_count() != 1 ||
+ ar->users().front()->opcode() != HloOpcode::kDynamicSlice) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Cannot rewrite AllReduce if it is not a logical reduce scatter.",
+ };
+ }
+
+ auto* ds = Cast<HloDynamicSliceInstruction>(ar->users().front());
+
+ if (ds->user_count() > 1) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Exactly one user of dynamic slice is required for a rewrite.",
+ };
+ }
+
+ int num_partitions = config.num_partitions();
+
+ std::vector<ReplicaGroup> rgs = ar->replica_groups();
+ if (rgs.size() != 1 || rgs.front().replica_ids_size() != num_partitions) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ absl::StrCat("Cannot determine a valid split with num_partitions: ",
+ num_partitions),
+ };
+ }
+
+ std::optional<int> split_dim = GetSplitDim(*ar, *ds);
+ if (!split_dim.has_value()) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Cannot get a split dim.",
+ };
+ }
+
+ std::optional<int> group_size = GetProcessGroupSize(*ar, *ds);
+ if (!group_size.has_value()) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Cannot determine a group size.",
+ };
+ }
+
+ if (num_partitions == group_size) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Nothing to rewrite",
+ };
+ }
+
+ if (num_partitions % *group_size != 0) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Group size does not evenly divide the number of partitions",
+ };
+ }
+
+ auto spec = AllReduceRewriteSpec{
+ /*split_dim=*/*split_dim,
+ /*group_size=*/*group_size,
+ /*all_reduce=*/ar,
+ /*dynamic_slice=*/ds,
+ /*replica_groups=*/GetNewReplicaGroups(*group_size, num_partitions),
+ };
+
+ if (!IsLogicalReduceScatter(module, spec, computation)) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Not a logical reduce scatter.",
+ };
+ }
+
+ if (!IsProfitableToSplit(replica_map, spec)) {
+ return RewriteInfeasibleReason{
+ &instruction,
+ "Splitting is not profitable.",
+ };
+ }
+
+ return spec;
+}
+
+absl::StatusOr<bool> SplitAllReduce(const HloModuleConfig& config,
+ AllReduceRewriteSpec spec,
+ HloComputation& computation) {
+ int64_t next_channel_id =
+ hlo_query::NextChannelId(*spec.all_reduce->GetModule());
+ VLOG(1) << "AR splitting spec: " << spec.ToString();
+ // Create first AR.
+ int num_partitions = config.num_partitions();
+ // # of shards within a replica
+ int group_size = spec.group_size;
+
+ CHECK_EQ(num_partitions % group_size, 0);
+
+ HloAllReduceInstruction& ar = *spec.all_reduce;
+ HloDynamicSliceInstruction& ds = *spec.dynamic_slice;
+
+ const auto& [first_ar_replica_groups, second_ar_replica_groups] =
+ spec.replica_groups;
+ int channel_id = next_channel_id++;
+ HloInstruction* first_ar =
+ computation.AddInstruction(HloInstruction::CreateAllReduce(
+ ar.shape(), ar.operands(), ar.to_apply(),
+ CollectiveDeviceList(first_ar_replica_groups), ar.constrain_layout(),
+ channel_id, ar.use_global_device_ids()));
+
+ // Create second AR.
+ channel_id = next_channel_id++;
+ HloInstruction* second_ar =
+ computation.AddInstruction(HloInstruction::CreateAllReduce(
+ ds.shape(), {&ds}, ar.to_apply(),
+ CollectiveDeviceList(second_ar_replica_groups), ar.constrain_layout(),
+ channel_id, ar.use_global_device_ids()));
+
+ // Rewire.
+ TF_RETURN_IF_ERROR(computation.ReplaceInstruction(&ar, first_ar));
+ if (ds.IsRoot()) {
+ computation.set_root_instruction(second_ar);
+ }
+ TF_RETURN_IF_ERROR(ds.ReplaceAllUsesWith(second_ar));
+ return true; // changed
+}
+
+// Splits `instruction` if it finds it is feasible and profitable to do so.
+// Returns true if `instruction` has been rewritten, or false otherwise.
+absl::StatusOr<bool> SplitAllReduce(const HloModule& module,
+ const ARReplicaGroupMap& replica_map,
+ HloComputation& computation,
+ HloInstruction& instruction) {
+ RewriteDecision spec =
+ CanRewrite(module, replica_map, computation, instruction);
+ if (std::holds_alternative<RewriteInfeasibleReason>(spec)) {
+ auto reason = std::get<RewriteInfeasibleReason>(spec);
+    VLOG(1) << "Cannot process {" << reason.ar->ToString()
+            << "} due to: " << reason.message;
+    return false;  // not changed
+ }
+ return SplitAllReduce(module.config(), std::get<AllReduceRewriteSpec>(spec),
+ computation); // changed
+}
+
+} // namespace
+
+absl::StatusOr<bool> AllReduceSplitter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+
+ for (auto* computation : module->computations(execution_threads)) {
+ ARReplicaGroupMap replica_map = GetReplicaGroupsMap(*computation);
+ for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
+ TF_ASSIGN_OR_RETURN(bool rewritten, SplitAllReduce(*module, replica_map,
+ *computation, *instr));
+ changed |= rewritten;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h
new file mode 100644
index 0000000..91e0811
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h
@@ -0,0 +1,77 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Rewrites a global all-reduce (AR) into a logical reduce-scatter (RS)
+// followed by an AR, if the AR is in the form of AR + dynamic-slice (DS) and
+// matches existing replica groups.
+//
+// If the pass detects an AR followed by a DS, it checks whether it is
+// profitable to break the pair down into a logical RS (still expressed as
+// AR + DS), followed by an AR that keeps the rewrite numerically equivalent.
+//
+// Consider the following example:
+//
+// Input program:
+// HloModule m, num_partitions=8
+// p = partition_id()
+// ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3,4,5,6,7}}
+// ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
+//
+// There is a global AR performing a reduction over 8 partitions.
+// However, the DS takes an 8-element slice of a 32-element tensor, which
+// implies only 4 distinct slices and therefore 2 replicas of each computed
+// slice. This can be expressed as an RS within the replicas followed by an AR
+// across the replicas. The transformation limits the collectives to the data
+// that is actually needed for the requested slice.
+//
+// Output program:
+// HloModule m, num_partitions=8
+// p = partition_id()
+// ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3},{4,5,6,7}}
+// ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
+// ar.2 = bf16[8] all-reduce(ds), replica_groups={{0,4},{1,5},{2,6},{3,7}}
+//
+// In addition, the pass performs the rewrite only if it finds it profitable to
+// do so. The profitability heuristic is simple: it checks whether there is any
+// other collective with the same replica groups. If there is, the combiner
+// pass can pick the new collective up and fuse it into the same NCCL call.
+//
+// While the solution is orthogonal to known distribution patterns, in practice
+// it is profitable for the HSDP-style communication pattern.
+// https://arxiv.org/pdf/2203.11014
+//
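+// A minimal usage sketch (assuming `module` is an `HloModule*`; the pipeline
+// name is illustrative). The pass is typically followed by the GPU
+// ReduceScatterCreator pass so that the logical RS produced here is
+// materialized as a real reduce-scatter:
+//
+//   HloPassPipeline pipeline("all-reduce-splitter-rewrite");
+//   pipeline.AddPass<AllReduceSplitter>();
+//   pipeline.AddPass<gpu::ReduceScatterCreator>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
+//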
+class AllReduceSplitter : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "all-reduce-splitter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc
new file mode 100644
index 0000000..ec2e66d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_splitter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+
+class AllReduceSplitterTest : public HloTestBase {
+ public:
+ absl::StatusOr<std::unique_ptr<HloModule>> PrepareModule(
+ absl::string_view hlo_module, int64_t num_replicas,
+ int64_t num_partitions) {
+ HloModuleConfig config = GetModuleConfigForTest(
+ /*replica_count=*/num_replicas,
+ /*num_partitions=*/num_partitions);
+ config.set_use_spmd_partitioning(num_partitions > 1);
+ return ParseAndReturnVerifiedModule(hlo_module, config);
+ }
+
+ size_t AllReduceCount(const HloModule &module) {
+ return CollectiveCount(module, HloOpcode::kAllReduce);
+ }
+
+ private:
+ size_t CollectiveCount(const HloModule &module, HloOpcode opcode) {
+ return absl::c_count_if(
+ module.entry_computation()->instructions(),
+ [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
+ }
+};
+
+class AllReduceSplitterFilecheckTest : public AllReduceSplitterTest {
+ public:
+ absl::Status FileCheck(const std::string &hlo_text,
+ absl::string_view pattern) {
+ TF_ASSIGN_OR_RETURN(bool matched, RunFileCheck(hlo_text, pattern));
+ if (!matched) {
+ return absl::InternalError("Filecheck failed.");
+ }
+ return absl::OkStatus();
+ }
+};
+
+TEST_F(
+ AllReduceSplitterFilecheckTest,
+ MatchBasicPatternIfDynamicSliceIsRootAndThereExistsAllReduceWithSameReplicaGroups) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
+ TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+ CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+ CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+ CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]}
+ CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
+ CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
+ CHECK: %[[AR1:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
+ CHECK-SAME: replica_groups={[[DESIRED_RGS]]}
+ CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR1]], s32[] %[[_:.*]])
+ CHECK-SAME: dynamic_slice_sizes={1024}
+ CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
+ CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+ )"));
+}
+
+TEST_F(
+ AllReduceSplitterTest,
+    DoesNotMatchBasicPatternIfDynamicSliceIsRootAndThereIsNoAllReduceWithSameReplicaGroups) {  // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+
+ EXPECT_EQ(AllReduceCount(*module), 1);
+}
+
+TEST_F(
+ AllReduceSplitterFilecheckTest,
+ MatchBasicPatternIfDynamicSliceIsNotRootAndThereExistsAllReduceWithSameReplicaGroups) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ zero = bf16[] constant(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+ broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
+ ROOT _ = tuple(broadcast, first.ar)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
+ TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+ CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+ CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
+ CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
+ CHECK: %[[AR0:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
+ CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]}
+ CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR0]], s32[] %[[_:.*]])
+ CHECK-SAME: dynamic_slice_sizes={1024}
+ CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
+ CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+ CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+ CHECK-SAME: replica_groups={[[DESIRED_RGS]]}
+ CHECK: ROOT
+ CHECK-NOT: %[[AR1]]
+ CHECK-SAME: %[[EXISTING_AR]]
+ )"));
+}
+
+TEST_F(
+ AllReduceSplitterTest,
+ DoesNotMatchBasicPatternIfDynamicSliceIsNotRootAndThereIsNoAllReduceWithSameReplicaGroups) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ p.1 = bf16[2,4096,4096] parameter(1)
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+ broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
+ add = bf16[2,4096,4096] add(p,p.1)
+ ROOT _ = tuple(broadcast, add)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+ EXPECT_EQ(AllReduceCount(*module), 1);
+}
+
+TEST_F(AllReduceSplitterTest,
+ DoesNotMatchBasicPatternIfDynamicSliceIsFullySharded) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(512)
+ offset = s32[] multiply(reshape, slice_size)
+ ROOT _ = bf16[512] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={512}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+ EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(AllReduceSplitterTest,
+ DoesNotMatchBasicPatternIfItIsNotCompiledWithSPMDPartitioning) {
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+ HloModuleConfig config =
+ GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/8);
+ config.set_use_spmd_partitioning(false);
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string, config));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+  EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(AllReduceSplitterTest,
+ DoesNotMatchBasicPatternIfUseGlobalDeviceIdsIsFalse) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, channel_id=1
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, channel_id=2
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+
+ EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(AllReduceSplitterTest,
+ DoesNotMatchBasicPatternIfIsNotCrossAllPartitionsAllReduce) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+
+ EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(
+ AllReduceSplitterFilecheckTest,
+ PipelineMatchesBasicPatternWithDynamicSliceAsRootAndRewritesToReduceScatter) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ zero = bf16[] constant(0)
+ reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ HloPassPipeline pipeline("all-reduce-splitter-rewrite");
+ pipeline.AddPass<AllReduceSplitter>();
+ pipeline.AddPass<ReduceScatterCreator>();
+ EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
+ TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+ CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+ CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+ CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]}
+ CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
+ CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
+ CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
+ CHECK-SAME: replica_groups={[[DESIRED_RGS]]}
+ CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
+ CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+ )"));
+}
+
+TEST_F(
+ AllReduceSplitterFilecheckTest,
+ PipelineMatchesBasicPatternWithDynamicSliceNotAsRootAndRewritesToReduceScatter) { // NOLINT
+ absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+ a = bf16[] parameter(0)
+ b = bf16[] parameter(1)
+ ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+ p = bf16[2,4096,4096] parameter(0)
+ zero = bf16[] constant(0)
+ first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+ all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+ table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+ pid = u32[] partition-id()
+ id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(id)
+ slice_size = s32[] constant(1024)
+ offset = s32[] multiply(reshape, slice_size)
+ dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+ broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
+ ROOT _ = tuple(broadcast, first.ar)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+ HloPassPipeline pipeline("all-reduce-splitter-rewrite");
+ pipeline.AddPass<AllReduceSplitter>();
+ pipeline.AddPass<ReduceScatterCreator>();
+ EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
+ TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+ CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+ CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
+ CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
+ CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
+ CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
+ CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+ CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+ CHECK: ROOT
+ CHECK-NOT: %[[AR1]]
+ CHECK-SAME: %[[EXISTING_AR]]
+ )"));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc
new file mode 100644
index 0000000..aa76aff
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc
@@ -0,0 +1,55 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_collective_annotator.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> AsyncCollectiveAnnotator::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (!hlo_query::IsAsyncCollectiveStartOp(instruction)) {
+ continue;
+ }
+ CollectiveBackendConfig config;
+ config.set_is_sync(!is_collective_async_(instruction));
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ instruction->backend_config<GpuBackendConfig>());
+ *gpu_config.mutable_collective_backend_config() = config;
+ TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h
new file mode 100644
index 0000000..1b41d50
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h
@@ -0,0 +1,52 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+// Annotates asynchronous collective start ops with a CollectiveBackendConfig
+// that records whether each collective should run synchronously or
+// asynchronously, as decided by the `is_collective_async` predicate.
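+//
+// A minimal usage sketch (assuming `module` is an `HloModule*`; the predicate
+// is illustrative): mark every all-reduce-start as asynchronous and all other
+// collective starts as synchronous.
+//
+//   AsyncCollectiveAnnotator annotator(
+//       HloPredicateIsOp<HloOpcode::kAllReduceStart>);
+//   TF_ASSIGN_OR_RETURN(bool changed, annotator.Run(module));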
+class AsyncCollectiveAnnotator : public HloModulePass {
+ public:
+ explicit AsyncCollectiveAnnotator(HloPredicate is_collective_async)
+ : is_collective_async_(std::move(is_collective_async)) {}
+ absl::string_view name() const override {
+ return "async-collective-annotator";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ HloPredicate is_collective_async_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc
new file mode 100644
index 0000000..6622a7b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_collective_annotator.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_macros.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithAsync
+
+ addf32 {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+ }
+
+ addf16 {
+ p0 = f16[] parameter(0)
+ p1 = f16[] parameter(1)
+ ROOT add = f16[] add(p0, p1)
+ }
+
+ reduce_scatterf32 {
+ p0 = f32[2] parameter(0)
+ ROOT result = f32[1] reduce-scatter(p0), replica_groups={},
+ dimensions={0}, to_apply=addf32
+ }
+
+ ENTRY entry {
+ pf32 = f32[1] parameter(0)
+ pf16 = f16[1] parameter(1)
+
+ arf32-start = f32[1] all-reduce-start(pf32), to_apply=addf32
+ arf32-done = f32[1] all-reduce-done(arf32-start)
+
+ arf16-start = f16[1] all-reduce-start(pf16), to_apply=addf16
+ arf16-done = f16[1] all-reduce-done(arf16-start)
+
+ agf32-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}
+ agf32-done = f32[2] all-gather-done(agf32-start)
+
+ agf16-start = (f16[1], f16[2]) all-gather-start(pf16), dimensions={0}
+ agf16-done = f16[2] all-gather-done(agf16-start)
+
+ cpf32-start = (f32[1], f32[1], u32[], u32[]) collective-permute-start(pf32),
+ source_target_pairs={{0,1}, {1,0}}
+ cpf32-done = f32[1] collective-permute-done(cpf32-start)
+
+ cpf16-start = (f16[1], f16[1], u32[], u32[]) collective-permute-start(pf16),
+ source_target_pairs={{0,1}, {1,0}}
+ cpf16-done = f16[1] collective-permute-done(cpf16-start)
+
+ rsf32-start = ((f32[2]), f32[1]) async-start(agf32-done), calls=reduce_scatterf32
+ rsf32-done = f32[1] async-done(rsf32-start), calls=reduce_scatterf32
+
+ ROOT tuple = (f32[1], f16[1], f32[2], f16[2], f32[1], f16[1], f32[1])
+ tuple(arf32-done, arf16-done, agf32-done, agf16-done, cpf32-done,
+ cpf16-done, rsf32-done)
+ }
+)";
+
+struct TestCase {
+ std::string test_name;
+ HloPredicate is_async_predicate;
+ absl::flat_hash_set<absl::string_view> expected_async;
+ absl::flat_hash_set<absl::string_view> expected_sync;
+};
+
+class AsyncCollectiveAnnotatorTest
+ : public HloTestBase,
+ public ::testing::WithParamInterface<TestCase> {};
+
+XLA_TEST_P(AsyncCollectiveAnnotatorTest, Test) {
+ const TestCase& test_case = GetParam();
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString, /*replica_count=*/2));
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool changed,
+ AsyncCollectiveAnnotator(test_case.is_async_predicate).Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ // Assert that all async collectives are annotated with the backend config.
+ for (const HloInstruction* hlo :
+ module->entry_computation()->instructions()) {
+ if (!hlo_query::IsAsyncCollectiveStartOp(hlo)) {
+ continue;
+ }
+ auto gpu_config = hlo->backend_config<GpuBackendConfig>();
+ ASSERT_TRUE(gpu_config.ok());
+
+ const CollectiveBackendConfig& backend_config =
+ gpu_config.value().collective_backend_config();
+ if (test_case.expected_async.contains(hlo->name())) {
+ EXPECT_FALSE(backend_config.is_sync());
+ }
+
+ if (test_case.expected_sync.contains(hlo->name())) {
+ EXPECT_TRUE(backend_config.is_sync());
+ }
+ }
+}
+
+std::vector<TestCase> TestCases() {
+ HloPredicate is_f16 = [](const HloInstruction* hlo) {
+ return hlo->operand(0)->shape().element_type() == PrimitiveType::F16;
+ };
+
+ return {
+ {"all_async",
+ HloPredicateTrue, /*expected_async=*/
+ {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
+ "cpf32-start", "cpf16-start", "rsf32-start"},
+ /*expected_sync=*/{}},
+ {"all_sync",
+ HloPredicateFalse,
+ /*expected_async=*/{},
+ /*expected_sync=*/
+ {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
+ "cpf32-start", "cpf16-start", "rsf32-start"}},
+ {"ar_async",
+ HloPredicateIsOp<HloOpcode::kAllReduceStart>,
+ /*expected_async=*/
+ {"arf32-start", "arf16-start"},
+ /*expected_sync=*/
+ {"agf32-start", "agf16-start", "cpf32-start", "cpf16-start",
+ "rsf32-start"}},
+ {"cp_async",
+ HloPredicateIsOp<HloOpcode::kCollectivePermuteStart>,
+ /*expected_async=*/
+ {"cpf32-start", "cpf16-start"},
+ /*expected_sync=*/
+ {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
+ "rsf32-start"}},
+ {"f16_async",
+ is_f16,
+ /*expected_async=*/{"arf16-start", "agf16-start", "cpf16-start"},
+ /*expected_sync=*/
+ {"arf32-start", "agf32-start", "cpf32-start", "rsf32-start"}},
+ };
+}
+
+std::string TestCaseName(const ::testing::TestParamInfo<TestCase>& test_case) {
+ return test_case.param.test_name;
+}
+
+INSTANTIATE_TEST_SUITE_P(AsyncCollectiveAnnotatorTest,
+ AsyncCollectiveAnnotatorTest,
+ ::testing::ValuesIn(TestCases()), TestCaseName);
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc
new file mode 100644
index 0000000..baeecba
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc
@@ -0,0 +1,70 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_wrapper.h"
+
+#include <algorithm>
+#include <deque>
+#include <iterator>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+
+namespace xla::gpu {
+
+absl::StatusOr<bool> AsyncWrapper::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+
+ std::deque<HloComputation*> computations;
+ computations.push_back(module->entry_computation());
+ while (!computations.empty()) {
+ HloComputation* computation = computations.front();
+ computations.pop_front();
+
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ if (predicate_(instruction)) {
+ // If the predicate matches, then wrap the instructions in async blocks.
+ TF_RETURN_IF_ERROR(
+ computation
+ ->CreateAsyncInstructions(instruction,
+ {ShapeUtil::MakeScalarShape(U32)})
+ .status());
+ changed = true;
+ continue;
+ }
+
+ // Otherwise, follow any `calls` to discover other instructions that can
+ // potentially be made async.
+ if (instruction->opcode() == HloOpcode::kCall) {
+ std::copy(instruction->called_computations().begin(),
+ instruction->called_computations().end(),
+ std::back_inserter(computations));
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.h b/third_party/xla/xla/service/gpu/transforms/async_wrapper.h
new file mode 100644
index 0000000..d6cefe8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.h
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_
+
+#include <functional>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// AsyncWrapper wraps instructions that match the given `predicate` into async
+// blocks (i.e. `async-start` and `async-done` instructions) so that they can
+// run concurrently.
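+//
+// A minimal usage sketch (assuming `module` is an `HloModule*`; the predicate
+// is illustrative): wrap every fusion into an async-start/async-done pair.
+//
+//   AsyncWrapper wrapper([](const HloInstruction* instruction) {
+//     return instruction->opcode() == HloOpcode::kFusion;
+//   });
+//   TF_ASSIGN_OR_RETURN(bool changed, wrapper.Run(module));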
+class AsyncWrapper : public HloModulePass {
+ public:
+ using Predicate = std::function<bool(HloInstruction*)>;
+ explicit AsyncWrapper(Predicate predicate)
+ : predicate_(std::move(predicate)) {}
+
+ absl::string_view name() const override { return "async-wrapper"; }
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const Predicate predicate_;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc
new file mode 100644
index 0000000..9d69899
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_wrapper.h"
+
+#include <memory>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/literal_test_util.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "tsl/platform/status_matchers.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+
+class AsyncWrapperTest : public HloTestBase {};
+
+int CountAsyncInstructions(HloComputation* computation) {
+ int count = 0;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ if (instruction->IsAsynchronous()) ++count;
+ }
+ return count;
+}
+
+TEST_F(AsyncWrapperTest, BasicFusion) {
+ const char* hlo_text = R"(
+ HloModule m
+
+ double1 {
+ p0 = f32[1] parameter(0)
+ ROOT add = f32[1] add(p0, p0)
+ }
+
+ double2 {
+ p0 = f32[1] parameter(0)
+ ROOT add = f32[1] add(p0, p0)
+ }
+
+ ENTRY main {
+ p0 = f32[1] parameter(0)
+ agg1 = f32[1] fusion(p0), kind=kLoop, calls=double1
+ agg2 = f32[1] fusion(p0), kind=kLoop, calls=double2
+ ROOT done = f32[1] add(agg1, agg2)
+ })";
+
+ std::unique_ptr<VerifiedHloModule> module =
+ ParseAndReturnVerifiedModule(hlo_text).value();
+
+ AsyncWrapper wrapper([](const HloInstruction* instruction) {
+ return instruction->opcode() == HloOpcode::kFusion;
+ });
+ EXPECT_THAT(wrapper.HloModulePass::Run(module.get()), IsOkAndHolds(true));
+ EXPECT_EQ(CountAsyncInstructions(module->entry_computation()), 4);
+
+ Literal argument = LiteralUtil::CreateR1<float>({1.0});
+ Literal expected = LiteralUtil::CreateR1<float>({4.0});
+
+ Literal result = ExecuteNoHloPasses(std::move(module), {&argument});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
new file mode 100644
index 0000000..07f52fd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
@@ -0,0 +1,231 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h"
+
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/literal_util.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+
+namespace {
+using SourceTargetPair = std::pair<int64_t, int64_t>;
+using SourceTargetPairs = std::vector<SourceTargetPair>;
+enum class CycleType { kUnknown, kForward, kBackward };
+
+// Returns the cycle type (forward or backward) if the CollectivePermute
+// instruction has a cycle in its source-target pairs and should be
+// decomposed, or CycleType::kUnknown otherwise.
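+// For example, {{0,1},{1,2},{2,3},{3,0}} forms a forward cycle and
+// {{0,3},{1,0},{2,1},{3,2}} forms a backward cycle.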
+CycleType ShouldDecomposeWithCycleType(
+ const HloCollectivePermuteInstruction& collective_permute,
+ int64_t threshold_in_bytes) {
+ if (!collective_permute.channel_id().has_value()) {
+ return CycleType::kUnknown;
+ }
+
+ if (collective_permute.operand_count() != 1) {
+ return CycleType::kUnknown;
+ }
+
+ const Shape& result_shape = collective_permute.shape();
+ // Skip the transformation if there is any context data.
+ if (result_shape.IsTuple()) {
+ return CycleType::kUnknown;
+ }
+
+ CHECK(result_shape.IsArray());
+ if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
+ return CycleType::kUnknown;
+ }
+
+ const SourceTargetPairs& pairs = collective_permute.source_target_pairs();
+ if (pairs.size() == 1) {
+ return CycleType::kUnknown;
+ }
+
+ return IsForwardCycle(pairs) ? CycleType::kForward
+ : IsBackwardCycle(pairs) ? CycleType::kBackward
+ : CycleType::kUnknown;
+}
+
+// Constructs the frontend attributes for the two decomposed CollectivePermute
+// instructions.
+absl::Status GetFrontendAttributes(HloCollectivePermuteInstruction* cp,
+ CycleType cycle_type,
+ xla::FrontendAttributes& cp1_attr,
+ xla::FrontendAttributes& cp2_attr) {
+ cp1_attr = cp->frontend_attributes();
+ cp2_attr = cp->frontend_attributes();
+ auto validation_it =
+ cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+ if (validation_it == cp->frontend_attributes().map().end() ||
+ validation_it->second == "invalid") {
+ return absl::OkStatus();
+ }
+
+ auto statusor_bounds = ParseReplicaGroupsOnly(validation_it->second);
+ if (!statusor_bounds.ok()) {
+ return statusor_bounds.status();
+ }
+ const std::vector<ReplicaGroup>& bounds = statusor_bounds.value();
+ if (bounds.size() < 2) {
+ return Internal("Invalid number of replica groups");
+ }
+
+ int64_t num_pairs = bounds.size();
+ // A forward cycle has its backedge at the end while a backward cycle has its
+ // backedge at the beginning.
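+  // For example, with forward-cycle bounds {{0,7},{1,8},{2,9},{3,10}} (values
+  // are illustrative), the backedge CollectivePermute gets validation bounds
+  // {{3,10}} and the other CollectivePermute gets {{0,7},{1,8},{2,9}}.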
+ auto backedge_start = cycle_type == CycleType::kBackward
+ ? bounds.begin()
+ : bounds.begin() + num_pairs - 1;
+ auto other_edges_start =
+ cycle_type == CycleType::kBackward ? bounds.begin() + 1 : bounds.begin();
+ std::vector<ReplicaGroup> cp1_bounds(backedge_start, backedge_start + 1);
+ std::vector<ReplicaGroup> cp2_bounds(other_edges_start,
+ other_edges_start + num_pairs - 1);
+  auto bounds_to_string = [](const std::vector<ReplicaGroup>& groups) {
+ return "{" +
+ absl::StrJoin(groups, ",",
+ [](std::string* out, const ReplicaGroup& value) {
+ absl::StrAppend(out, "{", value.replica_ids(0), ",",
+ value.replica_ids(1), "}");
+ }) +
+ "}";
+ };
+ std::string cp1_validation_str = bounds_to_string(cp1_bounds);
+ std::string cp2_validation_str = bounds_to_string(cp2_bounds);
+ (*cp1_attr.mutable_map())[kSendRecvValidationAttr] = cp1_validation_str;
+ (*cp2_attr.mutable_map())[kSendRecvValidationAttr] = cp2_validation_str;
+ return absl::OkStatus();
+}
+
+// Decomposes a CollectivePermute instruction with a cycle in its source-target
+// pairs into two CollectivePermute instructions.
+absl::Status DecomposeCollectivePermuteCycle(
+ HloCollectivePermuteInstruction* cp, HloComputation* computation,
+ HloModule* module, int64_t next_channel_id, CycleType cycle_type) {
+ const SourceTargetPairs& pairs = cp->source_target_pairs();
+ int64_t num_pairs = pairs.size();
+ // A forward cycle has its backedge at the end as in
+ // {{0,1},{1,2},{2,3},{3,0}} while a backward cycle has its backedge at the
+ // beginning as in {{0,3},{1,0},{2,1},{3,2}}.
+ auto backedge_start = cycle_type == CycleType::kBackward
+ ? pairs.begin()
+ : pairs.begin() + num_pairs - 1;
+ auto other_edges_start =
+ cycle_type == CycleType::kBackward ? pairs.begin() + 1 : pairs.begin();
+ SourceTargetPairs backedge(backedge_start, backedge_start + 1);
+ SourceTargetPairs other_edges(other_edges_start,
+ other_edges_start + num_pairs - 1);
+ const OpMetadata& metadata = cp->metadata();
+ xla::FrontendAttributes cp1_attr, cp2_attr;
+ TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr));
+
+ // Create the CollectivePermute instruction for the communication represented
+ // by the backedge.
+ HloInstruction* cp1 =
+ computation->AddInstruction(HloInstruction::CreateCollectivePermute(
+ cp->shape(), cp->mutable_operand(0), backedge,
+ cp->channel_id().value()));
+ cp1->set_metadata(metadata);
+ cp1->set_frontend_attributes(cp1_attr);
+ int64_t cp1_receiver = backedge.back().second;
+
+  // Create the CollectivePermute instruction for the communication represented
+  // by the other edges.
+ HloInstruction* cp2 =
+ computation->AddInstruction(HloInstruction::CreateCollectivePermute(
+ cp->shape(), cp->mutable_operand(0), other_edges, next_channel_id));
+ cp2->set_metadata(metadata);
+ cp2->set_frontend_attributes(cp2_attr);
+
+ // Calculate the received data as follows:
+ // partition = u32[] partition-id()
+ // constant = u32[] constant(cp1_receiver)
+  // compare0 = pred[] compare(partition, constant), direction=EQ
+  // compare = pred[?] broadcast(compare0), dimensions={}
+  // recv-data = type[?] select(compare, cp1, cp2)
+ HloInstruction* partition =
+ computation->AddInstruction(HloInstruction::CreatePartitionId());
+ HloInstruction* constant = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(U32, cp1_receiver)));
+ HloInstruction* compare0 = computation->AddInstruction(
+ HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), partition,
+ constant, Comparison::Direction::kEq));
+ HloInstruction* compare =
+ computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(PRED, cp1->shape().dimensions()), compare0, {}));
+ HloInstruction* recv_data =
+ computation->AddInstruction(HloInstruction::CreateTernary(
+ cp1->shape(), HloOpcode::kSelect, compare, cp1, cp2));
+
+ TF_RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data));
+ TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp));
+
+ return absl::OkStatus();
+}
+} // namespace
+
+absl::StatusOr<bool> CollectivePermuteCycleDecomposer::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ int64_t next_channel_id;
+ for (auto comp : module->computations(execution_threads)) {
+ for (auto hlo : comp->MakeInstructionPostOrder()) {
+ if (hlo->opcode() != HloOpcode::kCollectivePermute) {
+ continue;
+ }
+ auto collective_permute = Cast<HloCollectivePermuteInstruction>(hlo);
+ CycleType cycle_type = ShouldDecomposeWithCycleType(*collective_permute,
+ threshold_in_bytes_);
+ if (cycle_type != CycleType::kUnknown) {
+        if (!changed) {
+ next_channel_id = hlo_query::NextChannelId(*module);
+ changed = true;
+ }
+ TF_RETURN_IF_ERROR(DecomposeCollectivePermuteCycle(
+ collective_permute, comp, module, next_channel_id++, cycle_type));
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
new file mode 100644
index 0000000..cfacd66
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
@@ -0,0 +1,73 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// CollectivePermuteCycleDecomposer is a pass that converts CollectivePermute
+// instructions with all participants forming either a forward cycle (such as
+// {{0,1},{1,2},{2,3},{3,0}}) or a backward cycle (such as {{3,2},{2,1},{1,0},
+// {0,3}}) into two CollectivePermute instructions. We currently restrict
+// this transformation to CollectivePermute using partition mode, with one
+// input, without any context data. Here is an example.
+//
+// before transformation:
+// start = (<rt>, <rt>) collective-permute(data),
+// source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+//
+// after transformation:
+// partition-id = u32[] partition-id()
+// constant = u32[] constant(0)
+// compare = pred[] compare(u32[] partition-id, u32[] constant),
+// direction=EQ
+// pred = pred[] broadcast(pred[] compare), dimensions={}
+// cp1 = (<rt>, <rt>) collective-permute(data), source_target_pairs={{3,0}}
+// cp2 = (<rt>, <rt>) collective-permute(data),
+// source_target_pairs={{0,1},{1,2},{2,3}}
+// data = <rt> select(pred, cp1, cp2)
+//
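+// A minimal usage sketch (assuming `module` is an `HloModule*`; the threshold
+// is illustrative and decomposes cycles of any size):
+//
+//   CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+//   TF_ASSIGN_OR_RETURN(bool changed, decomposer.Run(module));
+//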
+class CollectivePermuteCycleDecomposer : public HloModulePass {
+ public:
+ explicit CollectivePermuteCycleDecomposer(int64_t threshold_in_bytes)
+ : threshold_in_bytes_(threshold_in_bytes) {}
+ absl::string_view name() const override {
+ return "collective-permute-cycle-decomposer";
+ }
+
+ using HloPassInterface::Run;
+ // Runs CollectivePermuteCycleDecomposer pass on computations in 'module'.
+ // Returns whether the 'module' was changed.
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ // Transform only if the size of the CollectivePermute data >= threshold.
+ int64_t threshold_in_bytes_;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc
new file mode 100644
index 0000000..ae537d9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc
@@ -0,0 +1,234 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+using ::testing::HasSubstr;
+
+using CollectivePermuteCycleDecomposerTest = HloTestBase;
+using CollectivePermuteDecomposerTest = HloTestBase;
+
+TEST_F(CollectivePermuteDecomposerTest, DefaultChannelNotTransformed) {
+ const absl::string_view kModuleStr = R"(
+ HloModule test
+ ENTRY test_computation {
+ p = u32[] replica-id()
+ ROOT start = u32[] collective-permute(p),
+ source_target_pairs={{0,1},{1,0}}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kModuleStr)));
+ CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
+ const absl::string_view kModuleStr = R"(
+ HloModule test
+ ENTRY test_computation {
+ p = u32[] partition-id()
+ ROOT start = u32[] collective-permute(p), channel_id=1,
+ source_target_pairs={{0,0}}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kModuleStr)));
+ CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
+ const absl::string_view kModuleStr = R"(
+ HloModule test
+ ENTRY test_computation {
+ p = u32[] partition-id()
+ ROOT start = u32[] collective-permute(p), channel_id=1,
+ source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kModuleStr)));
+ CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
+ const absl::string_view kModuleStr = R"(
+ HloModule test
+ ENTRY test_computation {
+ p = u32[] partition-id()
+ ROOT start = u32[3,2] collective-permute(p), channel_id=1,
+ source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
+ frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
+ metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kModuleStr)));
+ CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ auto check_metadata = [](const HloInstruction* inst) {
+ EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
+ EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
+ EXPECT_EQ(inst->metadata().source_line(), 35);
+ };
+
+ HloCollectivePermuteInstruction* cp1 =
+ DynCast<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), "collective-permute"));
+ HloCollectivePermuteInstruction* cp2 =
+ DynCast<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), "collective-permute.1"));
+ EXPECT_NE(cp1, nullptr);
+ EXPECT_NE(cp2, nullptr);
+ EXPECT_EQ(cp1->operand(0), cp2->operand(0));
+ EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
+ EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
+ EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{3,10}}"));
+ EXPECT_THAT(cp2->ToString(),
+ HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
+ EXPECT_THAT(cp2->ToString(),
+ HasSubstr("_xla_send_recv_validation={{0,7},{1,8},{2,9}}"));
+ check_metadata(cp1);
+ check_metadata(cp2);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
+ const absl::string_view kModuleStr = R"(
+ HloModule test
+
+ while_cond {
+ param = (u32[], f32[2,2], f32[2,2]) parameter(0)
+ iter = u32[] get-tuple-element(param), index=0
+ max_iter = u32[] constant(3)
+ ROOT cmp = pred[] compare(iter, max_iter), direction=LT
+ }
+
+ while_body {
+ param = (u32[], f32[2,2], f32[2,2]) parameter(0)
+ iter = u32[] get-tuple-element(param), index=0
+ data = f32[2,2] get-tuple-element(param), index=1
+ weights = f32[2,2] get-tuple-element(param), index=2
+ cp = f32[2,2] collective-permute(data),
+ channel_id=1,
+ source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}},
+ frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}
+ matmul = f32[2,2] dot(weights, cp), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ iter_increment = u32[] constant(1)
+ next_iter = u32[] add(iter, iter_increment)
+ ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights)
+ }
+
+ ENTRY test_computation {
+ iter = u32[] constant(0)
+ data = f32[2,2] parameter(0)
+ weights = f32[2,2] parameter(1)
+ input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights)
+ while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
+ ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kModuleStr)));
+ CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+ EXPECT_TRUE(changed);
+ HloCollectivePermuteInstruction* cp1 =
+ DynCast<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), "collective-permute"));
+ HloCollectivePermuteInstruction* cp2 =
+ DynCast<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), "collective-permute.1"));
+ EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
+ EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{3,10}}"));
+ EXPECT_THAT(cp2->ToString(),
+ HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
+ EXPECT_THAT(cp2->ToString(),
+ HasSubstr("_xla_send_recv_validation={{0,7},{1,8},{2,9}}"));
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
+ const absl::string_view kModuleStr = R"(
+ HloModule test
+ ENTRY test_computation {
+ p = u32[] partition-id()
+ ROOT start = u32[] collective-permute(p), channel_id=1,
+ source_target_pairs={{0,3},{1,0},{2,1},{3,2}},
+ frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
+ metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kModuleStr)));
+ CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+ EXPECT_TRUE(changed);
+ auto check_metadata = [](const HloInstruction* inst) {
+ EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
+ EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
+ EXPECT_EQ(inst->metadata().source_line(), 35);
+ };
+
+ HloCollectivePermuteInstruction* cp1 =
+ DynCast<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), "collective-permute"));
+ HloCollectivePermuteInstruction* cp2 =
+ DynCast<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), "collective-permute.1"));
+ EXPECT_NE(cp1, nullptr);
+ EXPECT_NE(cp2, nullptr);
+ EXPECT_EQ(cp1->operand(0), cp2->operand(0));
+ EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
+ EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{0,3}}"));
+ EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{0,7}}"));
+ EXPECT_THAT(cp2->ToString(),
+ HasSubstr("source_target_pairs={{1,0},{2,1},{3,2}}"));
+ EXPECT_THAT(cp2->ToString(),
+ HasSubstr("_xla_send_recv_validation={{1,8},{2,9},{3,10}}"));
+ check_metadata(cp1);
+ check_metadata(cp2);
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc
new file mode 100644
index 0000000..e9df22a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc
@@ -0,0 +1,163 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h"
+
+#include "xla/literal_util.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/while_loop_analysis.h"
+
+namespace xla {
+
+// Finds and returns the non-constant operand in instr.
+//
+// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
+static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
+ const HloInstruction* result = nullptr;
+ for (const HloInstruction* operand : instr->operands()) {
+ if (!operand->IsConstant()) {
+ if (result != nullptr) {
+ CHECK_EQ(result, operand);
+ }
+ result = operand;
+ }
+ }
+ CHECK_NE(result, nullptr);
+ return result;
+}
+
+// Finds the step (k) for the while instruction, if the loop is of the form:
+//
+// while(cond) {
+// ind_var = ind_var + k
+// }
+//
+// If this pattern is not found, it returns std::nullopt.
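+//
+// For example, for a while body whose induction variable update is
+//   %i_plus_one = s32[] add(s32[] %i, s32[] %one)  // %one = s32[] constant(1)
+// this returns 1 (this is the shape of the loops in the unit tests for this
+// pass).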
+std::optional<int64_t> GetStep(HloInstruction* while_inst) {
+ // Get the update operation
+ std::optional<int64_t> indvar_tuple_idx =
+ GetLoopInductionVarTupleIdx(while_inst);
+ if (!indvar_tuple_idx) {
+ return std::nullopt;
+ };
+ auto* while_body_indvar_update =
+ while_inst->while_body()->root_instruction()->mutable_operand(
+ *indvar_tuple_idx);
+ auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
+
+ HloInstruction* trip_count_increase_step_instr = nullptr;
+ if (!Match(while_body_indvar_update,
+ match::AddAnyOrder(match::Op().Is(while_body_indvar),
+ match::Op(&trip_count_increase_step_instr)))) {
+ return std::nullopt;
+ }
+ return LiteralUtil::LiteralAsScalarInt64(
+ trip_count_increase_step_instr->literal());
+}
+
+absl::StatusOr<bool> CollectivePermuteValidIterationAnnotator::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* comp : module->computations(execution_threads)) {
+ for (HloInstruction* inst : comp->instructions()) {
+ if (inst->opcode() != HloOpcode::kCollectivePermute) {
+ continue;
+ }
+
+ if (inst->frontend_attributes().map().find(kSendRecvValidationAttr) !=
+ inst->frontend_attributes().map().end()) {
+ continue;
+ }
+ auto sourceTargetPairs = inst->source_target_pairs();
+ if (!IsForwardCycle(sourceTargetPairs) &&
+ !IsBackwardCycle(sourceTargetPairs)) {
+ continue;
+ }
+
+ VLOG(2) << "Collective permute with cycle: " << inst->ToString();
+
+ int64_t max_device_num = -1;
+ for (auto [source, target] : sourceTargetPairs) {
+ max_device_num = std::max(std::max(source, target), max_device_num);
+ }
+ int64_t num_devices = max_device_num + 1;
+
+ HloInstruction* whileOp = inst->parent()->WhileCallInstruction();
+ if (whileOp == nullptr) {
+ VLOG(2) << "No surrounding while op found. Ignoring " << inst->name();
+ continue;
+ }
+ if (!whileOp->frontend_attributes().map().contains(
+ "is_pipelined_while_loop"))
+ continue;
+ TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
+ whileOp->backend_config<WhileLoopBackendConfig>());
+ if (!config.has_known_trip_count()) {
+ VLOG(2) << "Trip count for while loop (" << whileOp->name()
+ << "): unknown";
+ continue;
+ }
+
+ int64_t trip_count = config.known_trip_count().n();
+ std::optional<int64_t> step = GetStep(whileOp);
+ VLOG(2) << "Trip count for while loop (" << whileOp->name()
+ << "): " << trip_count;
+ if (!step) {
+ VLOG(2) << "Could not find step for while operation";
+ continue;
+ }
+ VLOG(2) << "Step for while loop (" << whileOp->name() << "): " << *step;
+ if (*step != 1) {
+ VLOG(2) << "Step is not 1. Skipping...";
+ continue;
+ }
+
+      // For each source i, the send/recv iteration instances are {i, i+offset},
+      // where offset is `number_of_microbatches * CR - 1`. We know that
+      // `trip_count = number_of_microbatches * CR + num_devices - 1`, so
+      // offset = number_of_microbatches * CR - 1 = trip_count - num_devices.
+ int64_t offset = trip_count - num_devices;
+
+ std::vector<std::pair<int64_t, int64_t>> sendRecvValidation(
+ sourceTargetPairs.size());
+
+ for (size_t currIdx = 0; currIdx < sourceTargetPairs.size(); currIdx++) {
+ sendRecvValidation[currIdx] = {currIdx, currIdx + offset};
+ }
+
+ if (IsBackwardCycle(sourceTargetPairs)) {
+ std::reverse(sendRecvValidation.begin(), sendRecvValidation.end());
+ }
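+
+      // Worked example (matching the unit tests for this pass): with
+      // trip_count=10 and num_devices=4, offset = 10 - 4 = 6, so a forward
+      // cycle is annotated with {{0,6},{1,7},{2,8},{3,9}} and a backward cycle
+      // with the reversed list {{3,9},{2,8},{1,7},{0,6}}.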
+
+ xla::FrontendAttributes attributes;
+ std::string iteration_instances =
+ "{" +
+ absl::StrJoin(sendRecvValidation, ",",
+ [](std::string* out, std::pair<int64_t, int64_t> item) {
+ absl::StrAppend(out, "{", item.first, ",",
+ item.second, "}");
+ }) +
+ "}";
+ (*attributes.mutable_map())[kSendRecvValidationAttr] =
+ iteration_instances;
+
+ inst->add_frontend_attributes(attributes);
+ VLOG(1) << "Adding " << kSendRecvValidationAttr << " to " << inst->name()
+ << ": " << iteration_instances;
+ changed = true;
+ }
+ }
+ return changed;
+}
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h
new file mode 100644
index 0000000..f8999a9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h
@@ -0,0 +1,58 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
+
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This is an unsafe transformation that is triggered only if the attribute
+// `is_pipelined_while_loop` is present on a while loop.
+//
+// If a while loop is known to be a pipelined while loop, has a known trip count
+// and increments with step=1, then this pass annotates the `collective-permute`
+// operations within the while loop with valid iterations for each GPU. This is
+// only done when the source-target pairs of the `collective-permute` operation
+// form a forward or backward cycle.
+//
+// For example, if the trip count is 10 (iteration 0 to 9), with step=1, and the
+// source-target pairs of a `collective-permute` operation are
+// `{{0,1},{1,2},{2,3},{3,0}}`, then this pass would annotate such operation
+// with `_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"`. This annotation
+// means that
+// - for GPU index 0, the valid iterations are 0,1,2,3,4,5,6.
+// - for GPU index 1, the valid iterations are 1,2,3,4,5,6,7.
+// - for GPU index 2, the valid iterations are 2,3,4,5,6,7,8.
+// - for GPU index 3, the valid iterations are 3,4,5,6,7,8,9.
+//
+// The index in the list denotes the device index and the bounds {start,end} are
+// inclusive. For more examples, look at
+// `xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc`.
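+//
+// A minimal usage sketch, mirroring the unit tests for this pass (the pipeline
+// name is illustrative; the pass relies on a known trip count, so the tests run
+// it after WhileLoopTripCountAnnotator):
+//
+//   HloPassPipeline pipeline("annotate-collective-permutes");
+//   pipeline.AddPass<WhileLoopTripCountAnnotator>();
+//   pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
+//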
+class CollectivePermuteValidIterationAnnotator : public HloModulePass {
+ public:
+ CollectivePermuteValidIterationAnnotator() = default;
+ absl::string_view name() const override {
+ return "collective-permute-valid-iteration-annotator";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc
new file mode 100644
index 0000000..3585acc
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc
@@ -0,0 +1,174 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h"
+
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/while_loop_trip_count_annotator.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+using CollectivePermuteValidIterationAnnotatorTest = HloTestBase;
+
+TEST_F(CollectivePermuteValidIterationAnnotatorTest, NoChange) {
+ // We expect no changes here because the while loop is not labelled as
+ // `is_pipelined_while_loop`.
+ absl::string_view hlo_string = R"(
+ HloModule test, entry_computation_layout={()->(s32[], s32[])}
+ %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
+ %param = (s32[], s32[]) parameter(0)
+ %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
+ %one = s32[] constant(1)
+ %i_plus_one = s32[] add(s32[] %i, s32[] %one)
+ %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+ ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %permute)
+ }
+ %Cond (param.1: (s32[], s32[])) -> pred[] {
+ %param.1 = (s32[], s32[]) parameter(0)
+ %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
+ %trip_count = s32[] constant(10)
+ ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
+ }
+ ENTRY %test () -> (s32[], s32[]) {
+ %i_start = s32[] constant(0)
+ %p_start = s32[] constant(0)
+ %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
+ ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string, 1, 4));
+
+ HloPassPipeline pipeline("my-pass-pipeline");
+
+ pipeline.AddPass<WhileLoopTripCountAnnotator>();
+ pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+
+ HloCollectivePermuteInstruction* cp =
+ DynCastOrNull<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), HloOpcode::kCollectivePermute));
+
+ ASSERT_NE(cp, nullptr);
+
+ auto sendRecvValidationIt =
+ cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+ ASSERT_EQ(sendRecvValidationIt, cp->frontend_attributes().map().end());
+}
+
+TEST_F(CollectivePermuteValidIterationAnnotatorTest, ForwardCycle) {
+ absl::string_view hlo_string = R"(
+ HloModule test, entry_computation_layout={()->(s32[], s32[])}
+ %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
+ %param = (s32[], s32[]) parameter(0)
+ %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
+ %one = s32[] constant(1)
+ %i_plus_one = s32[] add(s32[] %i, s32[] %one)
+ %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+ ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
+ }
+ %Cond (param.1: (s32[], s32[])) -> pred[] {
+ %param.1 = (s32[], s32[]) parameter(0)
+ %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
+ %trip_count = s32[] constant(10)
+ ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
+ }
+ ENTRY %test () -> (s32[], s32[]) {
+ %i_start = s32[] constant(0)
+ %p_start = s32[] constant(0)
+ %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
+ ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string, 1, 4));
+
+ HloPassPipeline pipeline("my-pass-pipeline");
+
+ pipeline.AddPass<WhileLoopTripCountAnnotator>();
+ pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ HloCollectivePermuteInstruction* cp =
+ DynCastOrNull<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), HloOpcode::kCollectivePermute));
+
+ ASSERT_NE(cp, nullptr);
+
+ auto sendRecvValidationIt =
+ cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+ ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
+ std::string sendRecvValidationAttr = sendRecvValidationIt->second;
+ EXPECT_EQ(sendRecvValidationAttr, "{{0,6},{1,7},{2,8},{3,9}}");
+}
+
+TEST_F(CollectivePermuteValidIterationAnnotatorTest, BackwardCycle) {
+ absl::string_view hlo_string = R"(
+ HloModule test, entry_computation_layout={()->(s32[], s32[])}
+ %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
+ %param = (s32[], s32[]) parameter(0)
+ %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
+ %one = s32[] constant(1)
+ %i_plus_one = s32[] add(s32[] %i, s32[] %one)
+ %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+ ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
+ }
+ %Cond (param.1: (s32[], s32[])) -> pred[] {
+ %param.1 = (s32[], s32[]) parameter(0)
+ %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
+ %trip_count = s32[] constant(10)
+ ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
+ }
+ ENTRY %test () -> (s32[], s32[]) {
+ %i_start = s32[] constant(0)
+ %p_start = s32[] constant(0)
+ %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
+ ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string, 1, 4));
+
+ HloPassPipeline pipeline("my-pass-pipeline");
+
+ pipeline.AddPass<WhileLoopTripCountAnnotator>();
+ pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ HloCollectivePermuteInstruction* cp =
+ DynCastOrNull<HloCollectivePermuteInstruction>(
+ FindInstruction(module.get(), HloOpcode::kCollectivePermute));
+
+ ASSERT_NE(cp, nullptr);
+
+ auto sendRecvValidationIt =
+ cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+ ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
+ std::string sendRecvValidationAttr = sendRecvValidationIt->second;
+ EXPECT_EQ(sendRecvValidationAttr, "{{3,9},{2,8},{1,7},{0,6}}");
+}
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc
new file mode 100644
index 0000000..2d4aa52
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc
@@ -0,0 +1,811 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/command_buffer_scheduling.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/variant_visitor.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+using CommandBuffer = CommandBufferScheduling::CommandBuffer;
+using CommandBufferConfig = CommandBufferScheduling::CommandBufferConfig;
+
+// Returns true if HLO computation can be executed as a command buffer.
+static bool IsCommand(const HloComputation* computation,
+ const CommandBufferConfig& config);
+
+//===----------------------------------------------------------------------===//
+// No-op HLO operations.
+//===----------------------------------------------------------------------===//
+
+// Some HLO operations do not have corresponding operations at run time, so
+// they can be safely wrapped into command buffers together with load-bearing
+// commands.
+
+static bool IsConstant(const HloInstruction* hlo) {
+ return hlo->opcode() == HloOpcode::kConstant;
+}
+
+static bool IsParameter(const HloInstruction* hlo) {
+ return hlo->opcode() == HloOpcode::kParameter;
+}
+
+// Returns true if the instruction is a no-op at run time and doesn't have a
+// corresponding Thunk or Command (metadata-only operation).
+static bool IsNoOp(const HloInstruction* hlo) {
+ return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
+ HloOpcode::kGetTupleElement>(hlo);
+};
+
+//===----------------------------------------------------------------------===//
+// Synchronous HLO operations mapped to commands.
+//===----------------------------------------------------------------------===//
+
+// Synchronous HLO operations can be wrapped into command buffers when they
+// have corresponding commands.
+
+// This is a template to define pattern matching functions for HLO instructions
+// that do not have a corresponding class for them.
+template <HloOpcode op>
+static bool IsCommand(const HloInstruction*, const CommandBufferConfig&);
+
+// While loops can be executed inside command buffers only if condition and body
+// regions can be executed as command buffers.
+template <>
+bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
+ const CommandBufferConfig& config) {
+ return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
+ IsCommand(hlo->while_body(), config) &&
+ IsCommand(hlo->while_condition(), config);
+}
+
+// Conditional can be executed inside command buffers only if all regions of its
+// branches can be executed as command buffers.
+template <>
+bool IsCommand<HloOpcode::kConditional>(const HloInstruction* hlo,
+ const CommandBufferConfig& config) {
+ return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
+ absl::c_all_of(hlo->branch_computations(),
+ [&](const HloComputation* comp) {
+ return IsCommand(comp, config);
+ });
+}
+
+static bool IsCommand(const HloCustomCallInstruction* hlo,
+ const CommandBufferConfig& config) {
+  // cuBLAS gemms are represented in the HLO as custom call instructions.
+ if (config.enabled_commands.contains(DebugOptions::CUBLAS) &&
+ IsLegacyCublasMatmul(*hlo)) {
+ return true;
+ }
+
+ if (config.enabled_commands.contains(DebugOptions::CUBLASLT) &&
+ (IsCublasLtMatmul(*hlo) || IsCublasLtMatmulF8(*hlo))) {
+ return true;
+ }
+
+ if (config.enabled_commands.contains(DebugOptions::CUDNN) &&
+ IsCustomCallTofMHA(*hlo)) {
+ VLOG(3) << "Recording FusedMHA, target " << hlo->custom_call_target()
+ << " into command buffer.";
+ return true;
+ }
+
+ if (!config.enabled_commands.contains(DebugOptions::CUSTOM_CALL)) {
+ return false;
+ }
+
+ if (config.enabled_legacy_custom_call_targets.contains(
+ hlo->custom_call_target())) {
+ VLOG(3) << "Recording legacy custom call target "
+ << hlo->custom_call_target() << " into command buffer.";
+ return true;
+ }
+
+ // A special case for jax-triton kernel while it is not ported to FFI.
+ if (hlo->custom_call_target() == "triton_kernel_call" &&
+ // TODO(b/327718087): This is an ugly hack to prevent capturing triton
+ // custom calls that might do autotuning at run time.
+ !absl::StrContains(hlo->metadata().op_name(), "Autotuner")) {
+ return true;
+ }
+
+ // Check if FFI handler is compatible with command buffers.
+ auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
+ return registration.ok()
+ ? ffi::IsCommandBufferCompatible(registration->traits)
+ : false;
+}
+
+static bool IsCommand(const HloInstruction* hlo,
+ const CommandBufferConfig& config) {
+ if (auto* fusion = DynCast<HloFusionInstruction>(hlo)) {
+ auto gpu_config = fusion->backend_config<GpuBackendConfig>();
+ const FusionBackendConfig& backend_config =
+ gpu_config->fusion_backend_config();
+ if (backend_config.kind() == kCuDnnFusionKind) {
+ return config.enabled_commands.contains(DebugOptions::CUDNN);
+ }
+ const auto& custom_config = backend_config.custom_fusion_config();
+ if (custom_config.name() == "address_computation") {
+ auto fusion_analysis =
+ HloFusionAnalysis::Create(*hlo, config.device_description);
+ const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
+ auto custom_call_adaptor = HloBfsFindIf(
+ adaptor.GetRoots(), adaptor,
+ [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
+ const auto* custom_call = static_cast<const HloCustomCallInstruction*>(
+ &custom_call_adaptor->instruction());
+ return IsCommand(custom_call, config);
+ }
+ if (custom_config.name() == "dynamic_address_computation") {
+ return false;
+ }
+ return config.enabled_commands.contains(DebugOptions::FUSION);
+ }
+
+ if (auto* sort = DynCast<HloSortInstruction>(hlo))
+ return config.enabled_commands.contains(DebugOptions::FUSION);
+
+ if (hlo->opcode() == HloOpcode::kPartitionId ||
+ hlo->opcode() == HloOpcode::kReplicaId) {
+ return config.enabled_commands.contains(DebugOptions::FUSION);
+ }
+
+ if (auto* custom_call = DynCast<HloCustomCallInstruction>(hlo))
+ return IsCommand(custom_call, config);
+
+ if (hlo->opcode() == HloOpcode::kWhile)
+ return IsCommand<HloOpcode::kWhile>(hlo, config);
+
+ if (hlo->opcode() == HloOpcode::kConditional)
+ return IsCommand<HloOpcode::kConditional>(hlo, config);
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Asynchronous HLO operations mapped to commands.
+//===----------------------------------------------------------------------===//
+
+// Asynchronous HLO operations can be wrapped into command buffers only when
+// both start and done operations can be put into the same command buffer.
+// Command buffer semantics imply that when command buffer execution completes,
+// all recorded commands are also completed, which means that if the done
+// operation is not part of the same command buffer, we would change the
+// execution semantics and create an additional synchronization point.
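+//
+// For example, an `all-reduce-start` is captured only if its matching
+// `all-reduce-done` ends up in the same command buffer sequence; otherwise the
+// candidate sequence is rejected (see `check_async_region` below).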
+
+static bool IsAsyncStartCommand(const HloInstruction* hlo,
+ const CommandBufferConfig& config) {
+ if (hlo->opcode() == HloOpcode::kAllReduceStart ||
+ hlo->opcode() == HloOpcode::kAllGatherStart) {
+ return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+ }
+
+ if (hlo->opcode() == HloOpcode::kAsyncStart) {
+ if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
+ return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+ }
+ }
+
+ return false;
+}
+
+static bool IsAsyncDoneCommand(const HloInstruction* hlo,
+ const CommandBufferConfig& config) {
+ if (hlo->opcode() == HloOpcode::kAllReduceDone ||
+ hlo->opcode() == HloOpcode::kAllGatherDone) {
+ return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+ }
+
+ if (hlo->opcode() == HloOpcode::kAsyncDone) {
+ if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
+ return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+ }
+ }
+
+ return false;
+}
+
+// Finds the async-done HLO operation corresponding to an async-start one.
+static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) {
+ if (start->opcode() == HloOpcode::kAllReduceStart ||
+ start->opcode() == HloOpcode::kAllGatherStart) {
+ CHECK(start->users().size() == 1); // NOLINT, checked by HLO verifier
+ return start->users().front();
+ } else if (start->opcode() == HloOpcode::kAsyncStart) {
+ return start->async_chain_done();
+ }
+
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// HLO computations mapped to command buffers.
+//===----------------------------------------------------------------------===//
+
+// Returns true if HLO computation can be executed as a command buffer.
+static bool IsCommand(const HloComputation* computation,
+ const CommandBufferConfig& config) {
+ return absl::c_all_of(
+ computation->instructions(), [&](const HloInstruction* inst) {
+ return IsNoOp(inst) || IsConstant(inst) || IsParameter(inst) ||
+ IsCommand(inst, config) || IsAsyncStartCommand(inst, config) ||
+ IsAsyncDoneCommand(inst, config);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+
+static void RemoveTrailingNoOps(HloInstructionSequence& seq) {
+ std::vector<HloInstruction*> instructions = seq.instructions();
+ for (int i = instructions.size() - 1; i >= 0; i--) {
+ if (HloInstruction* inst = instructions[i]; IsNoOp(inst)) {
+ seq.remove_instruction(inst);
+ } else {
+ break;
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Discovering sequences of compatible Hlo instructions
+//===----------------------------------------------------------------------===//
+
+// The input is a scheduled sequence of instructions. This function collects
+// subsequences that will be extracted as command buffers.
+std::vector<HloInstructionSequence>
+CommandBufferScheduling::CollectCommandBufferSequences(
+ const HloInstructionSequence schedule, const CommandBufferConfig& config,
+ int32_t min_num_commands) {
+ std::vector<HloInstructionSequence> sequences;
+
+ HloInstructionSequence current_seq;
+ int64_t num_commands_in_current_seq = 0;
+
+ // Adds `current_seq` to `sequences` if it has enough commands in it.
+ auto collect_current_seq = [&]() {
+ if (num_commands_in_current_seq >= std::max(1, min_num_commands)) {
+ RemoveTrailingNoOps(current_seq);
+ sequences.push_back(std::move(current_seq));
+ }
+ current_seq = HloInstructionSequence();
+ num_commands_in_current_seq = 0;
+ };
+
+ auto& instructions = schedule.instructions();
+
+ // Collect the sequence of instructions that contains the async start and its
+ // corresponding done instruction. If there is another start instruction
+ // between the original start and done, we may potentially extend the sequence
+ // to include its corresponding done instruction. For example, if we call this
+  // function on async_start_a in the following sequence:
+ //
+ // async_start_a
+ // async_start_b
+ // async_done_a
+ // async_done_b
+ //
+  // The returned sequence will contain async_done_b, so that all async pairs
+  // are captured by the same command buffer.
+ auto collect_async_region = [&](const HloInstruction* start) {
+ auto get_index = [&](const HloInstruction* inst) -> size_t {
+ auto it = std::find(instructions.begin(), instructions.end(), inst);
+ return std::distance(instructions.begin(), it);
+ };
+
+ HloInstructionSequence seq;
+ size_t done_index = get_index(FindAsyncDoneCommand(start));
+ for (size_t i = get_index(start); i <= done_index; i++) {
+ HloInstruction* inst = instructions.at(i);
+ if (IsAsyncStartCommand(inst, config)) {
+ const HloInstruction* done = FindAsyncDoneCommand(inst);
+ done_index = std::max(done_index, get_index(done));
+ }
+ seq.push_back(inst);
+ }
+ return seq;
+ };
+
+  // Check that instructions are safe to be captured by a command buffer, and
+  // that we do not capture an unmatched async done instruction.
+ auto check_async_region = [&](const HloInstructionSequence& seq) {
+ if (!absl::c_all_of(seq.instructions(), [&](HloInstruction* inst) {
+ return IsNoOp(inst) || IsCommand(inst, config) ||
+ IsAsyncStartCommand(inst, config) ||
+ IsAsyncDoneCommand(inst, config);
+ })) {
+ return false;
+ }
+
+ absl::flat_hash_set<HloInstruction*> done_instructions;
+ for (const HloInstruction* inst : seq.instructions()) {
+ if (IsAsyncStartCommand(inst, config)) {
+ done_instructions.insert(FindAsyncDoneCommand(inst));
+ }
+ if (IsAsyncDoneCommand(inst, config)) {
+ if (!done_instructions.contains(inst)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ };
+
+ for (size_t i = 0; i < instructions.size(); i++) {
+ HloInstruction* inst = instructions.at(i);
+
+    // We add no-op instructions to the current sequence only if they act as
+    // glue between commands. We do not create command sequences consisting
+    // only of no-op instructions. The first and last instructions in a command
+    // buffer are always load-bearing commands.
+ if (IsNoOp(inst) && num_commands_in_current_seq) {
+ current_seq.push_back(inst);
+ continue;
+ }
+
+    // Synchronous commands can always be added to the instruction sequence.
+ if (IsCommand(inst, config)) {
+ num_commands_in_current_seq++;
+ current_seq.push_back(inst);
+ continue;
+ }
+
+    // We capture async commands if all instructions between start and done can
+    // be outlined into a command buffer.
+ if (IsAsyncStartCommand(inst, config)) {
+ HloInstructionSequence seq = collect_async_region(inst);
+ if (check_async_region(seq)) {
+ num_commands_in_current_seq += seq.instructions().size();
+ for (HloInstruction* inst : seq.instructions()) {
+ current_seq.push_back(inst);
+ }
+ i += seq.instructions().size() - 1;
+ continue;
+ }
+ }
+
+ // If we didn't find the next command, collect the current sequence and
+ // start a new one.
+ collect_current_seq();
+ }
+
+ // Don't forget to collect the final command sequence.
+ collect_current_seq();
+ return sequences;
+}
+
+// This function moves kParameter and kConstant instructions in a computation to
+// the beginning of the computation. This simplifies the construction of command
+// buffer computations because we don't need to deal with parameters and
+// constants that have users outside of a command buffer.
+absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
+ HloComputation* computation) {
+ HloInstructionSequence new_sequence;
+ HloSchedule& schedule = computation->parent()->schedule();
+ HloInstructionSequence& sequence = schedule.GetOrCreateSequence(computation);
+
+ for (HloInstruction* inst : sequence.instructions()) {
+ if (IsParameter(inst) || IsConstant(inst)) {
+ new_sequence.push_back(inst);
+
+      // Because we move the instruction to the front of the computation, it
+      // can't keep any control predecessors. Silently dropping them is unsafe,
+      // however, as they can carry transitive dependencies that define the
+      // schedule order, so we forward control predecessors to all users.
+ for (HloInstruction* control_predecessor : inst->control_predecessors()) {
+ for (HloInstruction* user : inst->users()) {
+ TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user));
+ }
+ }
+ TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
+ }
+ }
+
+ for (HloInstruction* inst : sequence.instructions()) {
+ if (!IsParameter(inst) && !IsConstant(inst)) {
+ new_sequence.push_back(inst);
+ }
+ }
+
+ schedule.set_sequence(computation, new_sequence);
+ return absl::OkStatus();
+}
+
+//===----------------------------------------------------------------------===//
+// Prepares command buffer from sequence of instructions
+//===----------------------------------------------------------------------===//
+
+absl::StatusOr<CommandBuffer> CommandBufferScheduling::PrepareCommandBuffer(
+ const HloInstructionSequence& seq, HloModule* module) {
+ auto builder = HloComputation::Builder("command_buffer");
+
+ absl::Span<HloInstruction* const> instructions =
+ absl::MakeSpan(seq.instructions());
+
+ // A set of instructions that will be moved into command buffer computation.
+ absl::flat_hash_set<HloInstruction*> in_command_buffer(instructions.begin(),
+ instructions.end());
+
+ // The sequence might use results of instructions that are not captured by the
+ // sequence. We pass those results as parameters and map the producers of the
+ // results to their corresponding parameter instructions.
+ absl::flat_hash_map<HloInstruction*, HloParameterInstruction*> parameters;
+
+ // Mapping from command buffer instructions to their clones in the command
+ // buffer computation body.
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
+
+  // Maps HLO instructions in the original computation to instructions in the
+  // command buffer: (a) a parameter corresponding to a captured value, or
+  // (b) a cloned instruction corresponding to a command.
+ auto mapped_operands = [&](HloInstruction* instr) {
+ absl::InlinedVector<HloInstruction*, 4> operands;
+ for (HloInstruction* operand : instr->operands()) {
+ if (auto it = inst_mapping.find(operand); it != inst_mapping.end())
+ operands.push_back(it->second);
+ }
+ return operands;
+ };
+
+ // Create parameters in the command buffer computation for captured values.
+ for (HloInstruction* inst : instructions) {
+ for (HloInstruction* operand : inst->operands()) {
+      // We already mapped this operand to a parameter.
+ if (parameters.contains(operand)) continue;
+
+ // Operand instruction is a part of the command buffer.
+ if (in_command_buffer.contains(operand)) continue;
+
+ // Create a new parameter for value defined outside of a command buffer.
+ int64_t parameter_id = parameters.size();
+ auto* parameter = Cast<HloParameterInstruction>(
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ parameter_id, operand->shape(), "p")));
+
+ parameter->UniquifyName(module);
+ parameter->UniquifyId(module);
+ inst_mapping[operand] = parameters[operand] = parameter;
+ }
+ }
+
+ // Clone commands into the command buffer body with mapped operands.
+ for (HloInstruction* inst : seq.instructions()) {
+ HloCloneContext ctx(inst->GetModule());
+
+    // Cloned instructions should call the same computations as the originals,
+    // since the original instructions will be dead code eliminated.
+ for (HloComputation* called_computation : inst->called_computations()) {
+ // Async computations can only be referenced by a single async chain at
+ // a time. Detach the current chain to let its copy bind to the
+ // computation.
+ if (called_computation->IsAsyncComputation()) {
+ called_computation->RemoveAsyncStart();
+ }
+ ctx.MapComputation(called_computation, called_computation);
+ }
+
+ inst_mapping[inst] = builder.AddInstruction(
+ inst->CloneWithNewOperands(inst->shape(), mapped_operands(inst), &ctx));
+ inst_mapping[inst]->UniquifyId(module);
+ }
+
+ // Convert parameters to command buffer arguments.
+ std::vector<HloInstruction*> arguments(parameters.size());
+ for (auto& [argument, parameter] : parameters) {
+ arguments[parameter->parameter_number()] = argument;
+ }
+
+  // Collect command buffer `results` (instructions replaced in the original
+  // computation) and `returned` (their clones inside the command buffer).
+ std::vector<HloInstruction*> results;
+ std::vector<HloInstruction*> returned;
+
+ auto has_external_users = [&](HloInstruction* inst) {
+ return inst->IsRoot() || absl::c_any_of(inst->users(), [&](auto* user) {
+ return !in_command_buffer.contains(user);
+ });
+ };
+
+ for (HloInstruction* inst : instructions) {
+ if (has_external_users(inst)) {
+ results.push_back(inst);
+ returned.push_back(inst_mapping[inst]);
+ }
+ }
+
+  // If we return multiple results, wrap them into a tuple.
+ if (returned.size() > 1) {
+ HloInstruction* inst =
+ builder.AddInstruction(HloInstruction::CreateTuple(returned));
+ inst->UniquifyName(module);
+ inst->UniquifyId(module);
+ }
+
+ std::unique_ptr<HloComputation> comp = builder.Build();
+ comp->UniquifyName(module);
+ comp->SetUniqueId(comp->root_instruction()->unique_id());
+
+ return CommandBuffer{std::move(arguments), std::move(results),
+ std::move(comp), std::move(inst_mapping)};
+}
+
+//===----------------------------------------------------------------------===//
+// Rewrites original computation into command buffer call
+//===----------------------------------------------------------------------===//
+
+absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
+ HloComputation* parent, const HloInstructionSequence& seq,
+ CommandBuffer command_buffer) {
+ if (command_buffer.results.empty())
+ return absl::InternalError("command buffer results must not be empty");
+
+  // If we have more than one result, we return them as a tuple and get
+  // individual values using `get-tuple-element` instructions. Otherwise we
+  // simply return the result from the command buffer computation.
+ Shape cmd_buffer_result_shape;
+ bool has_single_result = command_buffer.results.size() == 1;
+
+ if (has_single_result) {
+ cmd_buffer_result_shape = command_buffer.results[0]->shape();
+ } else {
+ absl::InlinedVector<Shape, 4> shapes;
+ shapes.reserve(command_buffer.results.size());
+ for (auto* res : command_buffer.results) shapes.push_back(res->shape());
+ cmd_buffer_result_shape = ShapeUtil::MakeTupleShape(shapes);
+ }
+
+ HloComputation* computation =
+ parent->parent()->AddComputation(std::move(command_buffer.computation),
+ /*is_entry=*/false);
+
+ HloInstruction* call = parent->AddInstruction(HloInstruction::CreateCall(
+ cmd_buffer_result_shape, command_buffer.arguments, computation));
+
+  // Replace all users of the original results with the command buffer results.
+ if (has_single_result) {
+ TF_RETURN_IF_ERROR(command_buffer.results[0]->ReplaceAllUsesWith(call));
+ } else {
+ for (int i = 0; i < command_buffer.results.size(); i++) {
+ TF_RETURN_IF_ERROR(
+ command_buffer.results[i]->ReplaceAllUsesWith(parent->AddInstruction(
+ HloInstruction::CreateGetTupleElement(call, i))));
+ }
+ }
+
+  // As we are running after scheduling, we have to keep the schedule valid.
+ HloSchedule& schedule = parent->parent()->schedule();
+
+ // Update schedule to replace the last instruction with a command buffer call.
+ // Removal of the rest of the instructions in the sequence is handled by
+ // schedule update below.
+ HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent);
+ sequence.replace_instruction(seq.instructions().back(), call);
+
+ // Rebuild original instruction sequence schedule in a newly created
+ // command buffer computation to guarantee that we'll get exactly the same
+ // buffer assignment result as if we were running without command buffers.
+ HloInstructionSequence cmd_buffer_schedule;
+ for (auto* argument : command_buffer.arguments) {
+ cmd_buffer_schedule.push_back(command_buffer.inst_mapping[argument]);
+ }
+ for (auto* inst : seq.instructions()) {
+ cmd_buffer_schedule.push_back(command_buffer.inst_mapping[inst]);
+ }
+ if (!has_single_result) {
+ cmd_buffer_schedule.push_back(computation->root_instruction());
+ }
+ schedule.set_sequence(computation, cmd_buffer_schedule);
+
+  // Forward control dependencies between original instructions to the
+  // corresponding instructions in the command buffer computation.
+ auto& inst_mapping = command_buffer.inst_mapping;
+ for (HloInstruction* inst : seq.instructions()) {
+ HloInstruction* cmd_inst = inst_mapping[inst];
+
+    // Forward control dependencies to the new instruction inside the command
+    // buffer. If the dependent instruction is not captured by the command
+ // buffer, forward the dependency to the command buffer call instead.
+ for (HloInstruction* predecessor : inst->control_predecessors()) {
+ if (auto it = inst_mapping.find(predecessor); it != inst_mapping.end()) {
+        // If the predecessor is mapped to a parameter instruction, it means
+        // that we need to forward the control dependency to the call
+        // operation; otherwise we add a control dependency between commands
+        // in the command buffer.
+ HloInstruction* cmd_predecessor = it->second;
+ if (IsParameter(cmd_predecessor)) {
+ TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
+ } else {
+ TF_RETURN_IF_ERROR(cmd_predecessor->AddControlDependencyTo(cmd_inst));
+ }
+ } else {
+ TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
+ }
+ }
+
+ for (HloInstruction* successor : inst->control_successors()) {
+ if (auto it = inst_mapping.find(successor); it != inst_mapping.end()) {
+ HloInstruction* cmd_successor = it->second;
+ TF_RETURN_IF_ERROR(cmd_inst->AddControlDependencyTo(cmd_successor));
+ } else {
+ TF_RETURN_IF_ERROR(call->AddControlDependencyTo(successor));
+ }
+ }
+
+ TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
+ }
+
+  // Traverse in reverse order, as the original sequence was topologically
+  // sorted and we can't remove instructions that still have users.
+ for (int32_t i = seq.instructions().size() - 1; i >= 0; i--) {
+ TF_RETURN_IF_ERROR(parent->RemoveInstruction(seq.instructions()[i]));
+ }
+
+ return computation;
+}
+
+//===----------------------------------------------------------------------===//
+
+CommandBufferScheduling::CommandBufferScheduling(
+ const se::DeviceDescription& device_description,
+ int32_t gpu_toolkit_version, int32_t gpu_driver_version)
+ : device_description_(device_description),
+ gpu_toolkit_version_(gpu_toolkit_version),
+ gpu_driver_version_(gpu_driver_version) {}
+
+absl::StatusOr<bool> CommandBufferScheduling::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  // We run command buffer scheduling after regular scheduling to guarantee
+  // that command buffers will not change the execution order and buffer
+  // assignment compared to a regular execution. Some operations (e.g. async
+  // collectives) can't be captured into command buffers, and forming too-large
+  // command buffers too early can impact the scheduling of async operations.
+ if (!module->has_schedule()) return Internal("module is not scheduled");
+
+ const DebugOptions& debug_options = module->config().debug_options();
+
+ absl::flat_hash_set<DebugOptions::CommandBufferCmdType> commands;
+ for (auto cmd_type : debug_options.xla_gpu_enable_command_buffer()) {
+ commands.insert(static_cast<DebugOptions::CommandBufferCmdType>(cmd_type));
+ }
+
+ absl::flat_hash_set<std::string> legacy_custom_call_targets;
+ for (const auto& target :
+ debug_options.legacy_command_buffer_custom_call_targets()) {
+ legacy_custom_call_targets.insert(target);
+ }
+
+ CommandBufferConfig config{std::move(commands),
+ std::move(legacy_custom_call_targets),
+ device_description_};
+
+ // Erase command buffer cmd types that are not supported by the gpu runtime.
+ static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS};
+ static constexpr auto kRequireTracing = {
+ DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN,
+ DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES};
+
+ auto erase = [&](absl::Span<const DebugOptions::CommandBufferCmdType> cmds) {
+ for (auto cmd : cmds) {
+ if (config.enabled_commands.erase(cmd)) {
+ VLOG(1) << "Removed command buffer support for "
+ << DebugOptions::CommandBufferCmdType_Name(cmd)
+ << " as it's not supported with gpu toolkit version "
+ << gpu_toolkit_version_ << " and driver version "
+ << gpu_driver_version_
+                << ". This might negatively impact performance. To enable "
+ << DebugOptions::CommandBufferCmdType_Name(cmd)
+ << " support in command buffers use cuda-compat package: "
+#if defined(PLATFORM_GOOGLE)
+ << "set CUDA_COMPAT_LOAD=1 env variable.";
+#else
+ << "https://docs.nvidia.com/deploy/cuda-compatibility/.";
+#endif
+ }
+ }
+ };
+
+ // Check if CUDA/ROCM driver supports required features.
+ auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+ if (std::min(gpu_toolkit_version_, gpu_driver_version_) < 12030) {
+ erase(kRequireTracing); // cuStreamBeginCaptureToGraph
+ erase(kRequireConditionals); // on-device control flow
+ }
+ };
+ auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+ erase(kRequireConditionals); // on-device control flow
+ };
+
+ std::visit(VariantVisitor{erase_cuda, erase_rocm},
+ device_description_.gpu_compute_capability());
+
+ auto order = module->MakeComputationPostOrder();
+ std::reverse(order.begin(), order.end());
+ absl::flat_hash_set<HloComputation*> processed_command_buffers;
+
+ for (HloComputation* comp : order) {
+    // Skip special computations that do not have a lowering to thunks.
+ if (comp->IsFusionComputation() || comp->IsAsyncComputation() ||
+ comp->IsCustomCallComputation())
+ continue;
+
+    // Skip computations that are already part of command buffers.
+ if (processed_command_buffers.contains(comp)) continue;
+
+ TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp));
+
+ std::vector<HloInstructionSequence> sequences =
+ CollectCommandBufferSequences(
+ module->schedule().sequence(comp), config,
+ debug_options.xla_gpu_graph_min_graph_size());
+
+ for (const HloInstructionSequence& seq : sequences) {
+ TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer,
+ PrepareCommandBuffer(seq, comp->parent()));
+ TF_ASSIGN_OR_RETURN(
+ HloComputation * command_buffer_computation,
+ RewriteCommandBuffer(comp, seq, std::move(command_buffer)));
+
+      // All computations reachable from a command buffer computation are
+      // nested command buffers (e.g. body computations attached to a while
+      // operation).
+ for (HloComputation* called :
+ command_buffer_computation->MakeEmbeddedComputationsList()) {
+ processed_command_buffers.insert(called);
+ }
+ }
+ }
+ TF_RETURN_IF_ERROR(module->schedule().Update());
+
+ return true;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h
new file mode 100644
index 0000000..30e0249
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h
@@ -0,0 +1,143 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla::gpu {
+
+// Lift fusion instructions to command buffers.
+//
+// Before the pass:
+// %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+// ...
+// }
+//
+// ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+// %a = s32[] parameter(0)
+// %b = s32[] parameter(1)
+// ROOT %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop,
+// calls=%fused_computation
+// }
+//
+// After the pass:
+// %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+// ...
+// }
+//
+// %command_buffer (param_0: s32[], param_1: s32[]) -> s32[] {
+// %param_0 = s32[] parameter(0)
+// %param_1 = s32[] parameter(1)
+// ROOT %fusion = s32[] fusion(s32[] %param_0, s32[] %param_1), kind=kLoop,
+// calls=%fused_computation
+// }
+//
+// ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+// %a = s32[] parameter(0)
+// %b = s32[] parameter(1)
+// ROOT %call = s32[] call(s32[] %a, s32[] %b), to_apply=%command_buffer
+// }
+//
+// We currently do not have a command_buffer HLO operation, so we start with a
+// kCall opcode and an attached HLO computation. We may consider graduating it
+// to a first-class operation later.
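+//
+// A minimal usage sketch (assuming a surrounding HloPassPipeline named
+// `pipeline` and a se::DeviceDescription `device_desc`; the version numbers
+// are illustrative):
+//
+//   pipeline.AddPass<CommandBufferScheduling>(
+//       device_desc, /*gpu_toolkit_version=*/12030,
+//       /*gpu_driver_version=*/12030);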
+class CommandBufferScheduling : public HloModulePass {
+ public:
+ struct CommandBufferConfig {
+    // DebugOptions controls which commands are enabled. Long term, we want to
+    // remove this flag and enable all supported commands by default.
+ absl::flat_hash_set<DebugOptions::CommandBufferCmdType> enabled_commands;
+ absl::flat_hash_set<std::string> enabled_legacy_custom_call_targets;
+ const se::DeviceDescription& device_description;
+ };
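+  // For example, the unit tests in this change construct a config by
+  // aggregate initialization, with a device description as the last member:
+  //   CommandBufferConfig config{{DebugOptions::FUSION},
+  //                              /*enabled_legacy_custom_call_targets=*/{},
+  //                              device_description};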
+
+ CommandBufferScheduling(const se::DeviceDescription& device_description,
+ int32_t gpu_toolkit_version,
+ int32_t gpu_driver_version);
+
+ absl::string_view name() const override {
+ return "command-buffer-scheduling";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ static std::vector<HloInstructionSequence> CollectCommandBufferSequences(
+ HloInstructionSequence schedule, const CommandBufferConfig& config,
+ int32_t min_num_commands = 1);
+
+ // Moves kParameter and kConstant instructions in a computation to
+ // the beginning of the computation. This simplifies the construction of
+ // command buffer computations because we don't need to deal with parameters
+ // and constants that have users outside of a command buffer.
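+  //
+  // For example (a sketch; see the MoveParametersToFront test in this change):
+  //   before: %a = parameter(0), %fusion = fusion(...), %c = parameter(2)
+  //   after:  %a = parameter(0), %c = parameter(2), %fusion = fusion(...)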
+ static absl::Status MoveParametersAndConstantsToFront(
+ HloComputation* computation);
+
+ struct CommandBuffer {
+ // Command buffer arguments (call instruction arguments).
+ std::vector<HloInstruction*> arguments;
+
+ // Command buffer result (call instruction result tuple).
+ std::vector<HloInstruction*> results;
+
+    // HLO computation corresponding to the command buffer body.
+ std::unique_ptr<HloComputation> computation;
+
+    // Mapping from original instructions to their clones in the command buffer.
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
+ };
+
+  // Prepares a command buffer from the instruction sequence. Values that are
+  // used in the sequence but constructed outside of it are passed in as
+  // parameters. Results of instructions in the sequence are returned in a
+  // tuple (if the command buffer has a single result, we don't wrap it into a
+  // tuple).
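+  //
+  // For example (a sketch based on the PrepareCommandBuffer test in this
+  // change): for a sequence {fusion, get-tuple-element, fusion.1}, values
+  // produced outside the sequence become command buffer parameters, and the
+  // root is tuple(get-tuple-element, fusion.1) because both values have users
+  // outside the sequence.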
+ static absl::StatusOr<CommandBuffer> PrepareCommandBuffer(
+ const HloInstructionSequence& seq, HloModule* module);
+
+  // Rewrites the prepared command buffer computation into HLO operations in
+  // the parent computation (calls the command buffer and replaces all users).
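+  //
+  // For example (see the SingleCommandBuffer test in this change), a command
+  // buffer with two results is rewritten as
+  //   %call = (s32[], s32[]) call(%a, %b), to_apply=%command_buffer
+  // and each original user reads its value via a get-tuple-element of %call.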
+ static absl::StatusOr<HloComputation*> RewriteCommandBuffer(
+ HloComputation* parent, const HloInstructionSequence& seq,
+ CommandBuffer command_buffer);
+
+ private:
+ se::DeviceDescription device_description_;
+  // For NVIDIA GPUs, XLA can be compiled with a CUDA toolkit version that is
+  // newer than the version supported by the driver, e.g. we can compile for
+  // CUDA 12.3 but have a 12.1 driver installed. When deciding which command
+  // buffer features we can use, we have to consider both versions.
+ int32_t gpu_toolkit_version_;
+ int32_t gpu_driver_version_;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
new file mode 100644
index 0000000..43d0dae
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
@@ -0,0 +1,1018 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/command_buffer_scheduling.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+class CommandBufferSchedulingTest : public HloTestBase {
+ public:
+  // Use CUDA version 12.3 for testing, as it has all the features we rely on.
+ static constexpr int32_t kCudaVersion = 12030;
+
+ const se::DeviceDescription& device_desc() {
+ return backend().default_stream_executor()->GetDeviceDescription();
+ }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ auto debug_options = HloTestBase::GetDebugOptionsForTest();
+ debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+ debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS);
+ debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
+ debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
+ debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
+ debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
+ debug_options.set_xla_gpu_graph_min_graph_size(2);
+ return debug_options;
+ }
+};
+
+using CommandBuffer = CommandBufferScheduling::CommandBuffer;
+
+TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+ %a = s32[] parameter(0)
+ %b = s32[] parameter(1)
+ %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+ %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1
+ ROOT %custom-call = s32[] custom-call(s32[] %fusion, s32[] %fusion.1), custom_call_target="some target"
+ })";
+
+ const char* expected = R"(
+// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
+// CHECK: %[[P0]] = s32[] parameter(0)
+// CHECK: %[[P1]] = s32[] parameter(1)
+// CHECK: %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+// CHECK: %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
+// CHECK: ROOT %tuple = (s32[], s32[]) tuple(%fusion, %fusion.1)
+// CHECK: }
+//
+// CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+// CHECK: %a = s32[] parameter(0)
+// CHECK: %b = s32[] parameter(1)
+// CHECK: %call = (s32[], s32[]) call(%a, %b), to_apply=%command_buffer
+// CHECK: %get-tuple-element = s32[] get-tuple-element(%call), index=0
+// CHECK: %get-tuple-element.1 = s32[] get-tuple-element(%call), index=1
+// CHECK: ROOT %custom-call = s32[] custom-call(%get-tuple-element, %get-tuple-element.1), custom_call_target="some target"
+// CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
+ %a = s32[] parameter(0)
+ %b = s32[] parameter(1)
+ %c = (s32[], s32[]) parameter(2)
+ %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+ %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
+ %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
+ %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
+ %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
+ %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
+ %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
+ ROOT %custom-call.1 = s32[] custom-call(s32[] %fusion.3), custom_call_target="some target"
+ })";
+
+ const char* expected = R"(
+// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[], [[P2:.+]]: (s32[], s32[])) -> s32[] {
+// CHECK: %[[P0]] = s32[] parameter(0)
+// CHECK: %[[P1]] = s32[] parameter(1)
+// CHECK: %[[P2]] = (s32[], s32[]) parameter(2)
+// CHECK: %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%[[P2]]), index=0
+// CHECK: ROOT {{.*}} = s32[] fusion(%[[F0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
+// CHECK: }
+
+// CHECK: %command_buffer.2 ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
+// CHECK: %[[P0]] = s32[] parameter(0)
+// CHECK: %[[P1]] = s32[] parameter(1)
+// CHECK: %[[F2:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.2
+// CHECK: ROOT {{.*}} = s32[] fusion(%[[P0]], %[[F2]]), kind=kLoop, calls=%fused_computation.3
+// CHECK: }
+
+// CHECK: ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
+// CHECK: %a = s32[] parameter(0)
+// CHECK: %b = s32[] parameter(1)
+// CHECK: %c = (s32[], s32[]) parameter(2)
+// CHECK: %[[CMD0:.+]] = s32[] call(%a, %b, %c), to_apply=%command_buffer
+// CHECK: %e = s32[] get-tuple-element(%c), index=1
+// CHECK: %[[CALL:.+]] = s32[] custom-call(%[[CMD0]], %e), custom_call_target="some target"
+// CHECK: %[[CMD1:.+]] = s32[] call(%[[CALL]], %a), to_apply=%command_buffer.2
+// CHECK: ROOT {{.*}} = s32[] custom-call(%[[CMD1]]), custom_call_target="some target"
+// CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+ %p0 = s32[4] parameter(0)
+ %p1 = s32[4] parameter(1)
+ ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+ }
+
+ ENTRY %main (a: s32[4]) -> s32[4] {
+ %a = s32[4] parameter(0)
+ %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+ replica_groups={{0,1}}, to_apply=%add,
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+ ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
+ CHECK: %[[P0]] = s32[4]{0} parameter(0)
+ CHECK: %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+ CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
+ CHECK: }
+
+ CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
+ CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
+ CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+ CHECK: to_apply=%command_buffer
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ ENTRY %main (a: s32[2]) -> s32[4] {
+ %a = s32[2] parameter(0)
+
+ %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a),
+ channel_id=555, replica_groups={{0,1}}, dimensions={0},
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+
+ ROOT %done = s32[4]{0} all-gather-done(%start)
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s32[2]) -> s32[4] {
+ CHECK: %[[P0]] = s32[2]{0} parameter(0)
+ CHECK: %[[START:.+]] = {{.*}} all-gather-start(%[[P0]])
+ CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-gather-done(%[[START]])
+ CHECK: }
+
+ CHECK: ENTRY %main (a: s32[2]) -> s32[4] {
+ CHECK: %[[A:.+]] = s32[2]{0} parameter(0)
+ CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+ CHECK: to_apply=%command_buffer
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %add (p0: s32[], p1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[4]) -> s32[2] {
+ %a = s32[4] parameter(0)
+
+ %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a),
+ channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add,
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+
+ ROOT %done = s32[2]{0} reduce-scatter-done(%start)
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[2] {
+ CHECK: %[[P0]] = s32[4]{0} parameter(0)
+ CHECK: %[[START:.+]] = {{.*}} reduce-scatter-start(%[[P0]])
+ CHECK: ROOT %[[DONE:.+]] = s32[2]{0} reduce-scatter-done(%[[START]])
+ CHECK: }
+
+ CHECK: ENTRY %main (a: s32[4]) -> s32[2] {
+ CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
+ CHECK: ROOT %[[CALL:.+]] = s32[2]{0} call(%[[A]]),
+ CHECK: to_apply=%command_buffer
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+ %p0 = s32[4] parameter(0)
+ %p1 = s32[4] parameter(1)
+ ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+ }
+
+ ENTRY %main (a: s32[4]) -> s32[4] {
+ %a = s32[4] parameter(0)
+ %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+ replica_groups={{0,1}}, to_apply=%add,
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+ %bitcast = s32[4] bitcast(s32[4]{0} %a)
+ ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
+ CHECK: %[[P0]] = s32[4]{0} parameter(0)
+ CHECK: %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+ CHECK: %[[BITCAST:.+]] = s32[4]{0} bitcast(%[[P0]])
+ CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
+ CHECK: }
+
+ CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
+ CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
+ CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+ CHECK: to_apply=%command_buffer
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+ %p0 = s32[4] parameter(0)
+ %p1 = s32[4] parameter(1)
+ ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+ }
+
+ ENTRY %main (a: s32[4]) -> s32[4] {
+ %a = s32[4] parameter(0)
+ %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+ replica_groups={{0,1}}, to_apply=%add,
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+ %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+ replica_groups={{0,1}}, to_apply=%add,
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+ %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
+ ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
+ CHECK: %[[P0]] = s32[4]{0} parameter(0)
+ CHECK: %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+ CHECK: %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+ CHECK: %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
+ CHECK: ROOT %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
+ CHECK: }
+
+ CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
+ CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
+ CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+ CHECK: to_apply=%command_buffer
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+ %p0 = s32[4] parameter(0)
+ %p1 = s32[4] parameter(1)
+ ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+ }
+
+ ENTRY %main (a: s32[4], b:s32[]) -> s32[] {
+ %a = s32[4] parameter(0)
+ %b = s32[] parameter(1)
+ %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+ replica_groups={{0,1}}, to_apply=%add,
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+ %c = s32[] custom-call(), custom_call_target="target"
+ %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+ replica_groups={{0,1}}, to_apply=%add,
+ backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+ %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
+ %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
+ %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation
+ ROOT %fusion.1 = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation.1
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
+ CHECK: %[[P0]] = s32[] parameter(0)
+ CHECK: %[[P1]] = s32[] parameter(1)
+ CHECK: %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+ CHECK: ROOT %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
+ CHECK: }
+
+ CHECK: ENTRY %main (a: s32[4], b: s32[]) -> s32[] {
+ CHECK: %[[A:.+]] = s32[4]{0} parameter(0)
+ CHECK: %[[B:.+]] = s32[] parameter(1)
+ CHECK: %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[A]])
+ CHECK: %[[C:.+]] = s32[] custom-call()
+ CHECK: %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[A]])
+ CHECK: %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
+ CHECK: %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
+ CHECK: %call = s32[] call(%b, %c), to_apply=%command_buffer
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
+ %a = s32[] parameter(0)
+ %b = s32[] parameter(1)
+ %c = (s32[], s32[]) parameter(2)
+ %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+ %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
+ %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
+ %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
+ %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
+ %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
+ ROOT %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+
+ HloInstructionSequence seq;
+ for (HloInstruction* x : module->entry_computation()->instructions()) {
+ seq.push_back(x);
+ }
+ EXPECT_EQ(seq.size(), 10);
+
+ CommandBufferScheduling::CommandBufferConfig config{
+ {DebugOptions::FUSION}, {}, device_desc()};
+
+ std::vector<HloInstructionSequence> command_buffer_sequences =
+ CommandBufferScheduling::CollectCommandBufferSequences(seq, config);
+ EXPECT_EQ(command_buffer_sequences.size(), 2);
+
+ std::vector<HloInstruction*> seq_0 =
+ command_buffer_sequences[0].instructions();
+ EXPECT_EQ(seq_0.size(), 3);
+ EXPECT_EQ(seq_0[0]->opcode(), HloOpcode::kFusion);
+ EXPECT_EQ(seq_0[1]->opcode(), HloOpcode::kGetTupleElement);
+ EXPECT_EQ(seq_0[2]->opcode(), HloOpcode::kFusion);
+
+ std::vector<HloInstruction*> seq_1 =
+ command_buffer_sequences[1].instructions();
+ EXPECT_EQ(seq_1.size(), 2);
+ EXPECT_EQ(seq_1[0]->opcode(), HloOpcode::kFusion);
+ EXPECT_EQ(seq_1[1]->opcode(), HloOpcode::kFusion);
+}
+
+TEST_F(CommandBufferSchedulingTest, MoveParametersToFront) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
+ %a = s32[] parameter(0)
+ %b = s32[] parameter(1)
+ %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+ %c = s32[] parameter(2)
+ ROOT %fusion.1 = s32[] fusion(s32[] %a, s32[] %c), kind=kLoop, calls=%fused_computation.1
+ })";
+
+ const char* expected = R"(
+// CHECK: ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
+// CHECK: %a = s32[] parameter(0)
+// CHECK: %b = s32[] parameter(1)
+// CHECK: %c = s32[] parameter(2)
+// CHECK: %fusion = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation
+// CHECK: ROOT %fusion.1 = s32[] fusion(%a, %c), kind=kLoop, calls=%fused_computation.1
+// CHECK: })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ TF_ASSERT_OK(CommandBufferScheduling::MoveParametersAndConstantsToFront(
+ module->entry_computation()));
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(
+ module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+ expected));
+ EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(CommandBufferSchedulingTest, PrepareCommandBuffer) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation(param_0: s32[], param_1: s32[]) -> (s32[], s32[]) {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %tuple.1 = (s32[], s32[]) tuple(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+ %a = s32[] parameter(0)
+ %b = s32[] custom-call(), custom_call_target="target"
+ %fusion = (s32[], s32[]) fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+ %d = s32[] get-tuple-element((s32[], s32[]) %fusion), index=0
+ %fusion.1 = s32[] fusion(s32[] %a, s32[] %d), kind=kLoop, calls=%fused_computation.1
+ ROOT %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %d), custom_call_target="some target"
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule(hlo));
+
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 6);
+ std::vector<HloInstruction*> instructions;
+ HloInstructionSequence seq;
+ for (HloInstruction* inst : module->entry_computation()->instructions()) {
+ if (inst->opcode() == HloOpcode::kFusion ||
+ inst->opcode() == HloOpcode::kGetTupleElement) {
+ seq.push_back(inst);
+ }
+ instructions.push_back(inst);
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ CommandBuffer command_buffer,
+ CommandBufferScheduling::PrepareCommandBuffer(seq, module.get()));
+ HloComputation* computation = module->AddComputation(
+ std::move(command_buffer.computation), /*is_entry=*/false);
+
+ const char* expected = R"(
+// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
+// CHECK: %[[P0]] = s32[] parameter(0)
+// CHECK: %[[P1]] = s32[] parameter(1)
+// CHECK: %fusion = (s32[], s32[]) fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%fusion), index=0
+// CHECK: %fusion.1 = s32[] fusion(%[[P0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
+// CHECK: ROOT {{.*}} = (s32[], s32[]) tuple(%[[V0]], %fusion.1)
+// CHECK:})";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(computation->ToString(
+ HloPrintOptions{}.set_print_operand_shape(false)),
+ expected));
+ EXPECT_TRUE(filecheck_matches);
+
+ auto& arguments = command_buffer.arguments;
+ ASSERT_EQ(arguments.size(), 2);
+ EXPECT_EQ(arguments[0], instructions[0]);
+ EXPECT_EQ(arguments[1], instructions[1]);
+
+ auto& results = command_buffer.results;
+ ASSERT_EQ(results.size(), 2);
+ EXPECT_EQ(results[0], instructions[3]);
+ EXPECT_EQ(results[1], instructions[4]);
+}
+
+TEST_F(CommandBufferSchedulingTest, ForwardControlDependencies) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.2 (param_0: s32[], param_1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+ %a = s32[] parameter(0)
+ %b = s32[] parameter(1)
+ %custom-call = s32[] custom-call(), custom_call_target="some target"
+ %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation, control-predecessors={%custom-call}
+ %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion}
+ %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
+ %fusion.2 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%fusion.1}
+ ROOT %custom-call.2 = s32[] custom-call(s32[] %fusion.1, s32[] %fusion.2), custom_call_target="some target"
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
+ CHECK: %[[P0]] = s32[] parameter(0)
+ CHECK: %[[P1]] = s32[] parameter(1)
+ CHECK: %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]])
+ CHECK: ROOT {{.*}} = s32[] fusion(%[[P0]], %[[P1]]), {{.*}} control-predecessors={%[[F0]]}
+ CHECK: }
+
+ CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+ CHECK: %a = s32[] parameter(0)
+ CHECK: %b = s32[] parameter(1)
+ CHECK: %custom-call = s32[] custom-call(), custom_call_target="some target"
+ CHECK: %call = s32[] call(%a, %b), to_apply=%command_buffer, control-predecessors={%custom-call}
+ CHECK: %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
+ CHECK: %[[F3:.+]] = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%call}
+ CHECK: ROOT %custom-call.2 = s32[] custom-call(%call, %[[F3]]), custom_call_target="some target"
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, ForwardControlDependenciesToParams) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation.0 (p0: s32[], p1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ %fused_computation.1 (p0: s32[], p1: s32[]) -> s32[] {
+ %p0 = s32[] parameter(0)
+ %p1 = s32[] parameter(1)
+ ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+ }
+
+ ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+ %a = s32[] parameter(0)
+ %b = s32[] parameter(1)
+ %custom-call = s32[] custom-call(), custom_call_target="some target"
+ %fusion = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.0, control-predecessors={%custom-call}
+ ROOT %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %b), kind=kLoop, calls=%fused_computation.1
+ })";
+
+ const char* expected = R"(
+ CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+ CHECK: %a = s32[] parameter(0)
+ CHECK: %b = s32[] parameter(1)
+ CHECK: %[[CUSTOM_CALL:.+]] = s32[] custom-call(), custom_call_target="some target"
+ CHECK: ROOT {{.*}} call(%[[CUSTOM_CALL]], %a, %b), to_apply=%command_buffer, control-predecessors={%[[CUSTOM_CALL]]}
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, WhileNotCommand) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation (param_0: f32[1]) -> f32[1] {
+ %param_0 = f32[1]{0} parameter(0)
+ ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
+ }
+
+ %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
+ %param_0.1 = f32[1]{0} parameter(0)
+ %param_1 = f32[1]{0} parameter(1)
+ ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
+ }
+
+ %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
+ %param_0.2 = f32[1]{0} parameter(0)
+ %param_1.1 = f32[1]{0} parameter(1)
+ ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
+ }
+
+ %fused_computation.3 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
+ %param_0.1 = f32[1]{0} parameter(0)
+ %param_1 = f32[1]{0} parameter(1)
+ ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
+ }
+
+ %body (Arg_.3: f32[1]) -> f32[1] {
+ %constant_4 = f32[1]{0} constant({1})
+ %Arg_.3 = f32[1]{0} parameter(0)
+ %custom-call = s32[] custom-call(), custom_call_target="some target"
+ %add = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1, control-predecessors={%custom-call}
+ ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %add, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.3, control-predecessors={%custom-call}
+ }
+
+ %cond (Arg_.11: f32[1]) -> pred[] {
+ %constant = f32[1]{0} constant({100})
+ %Arg_.11 = f32[1]{0} parameter(0)
+ %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
+ ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
+ }
+
+ ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
+ %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
+ %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
+ %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
+ ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: f32[1], [[P1:.+]]: f32[1]) -> f32[1] {
+ CHECK: %[[P0]] = f32[1]{0} parameter(0)
+ CHECK: %[[P1]] = f32[1]{0} parameter(1)
+ CHECK: %[[ADD:.*]] = f32[1]{0} fusion(%[[P0]], %[[P1]]), kind=kLoop
+ CHECK: ROOT {{.*}} = f32[1]{0} fusion(%[[ADD]], %[[P1]]), kind=kLoop
+ CHECK: }
+
+ CHECK: %[[BODY:[a-z_0-9.]+]] ([[P0:.+]]: f32[1]) -> f32[1] {
+ CHECK: %[[C1:.*]] = f32[1]{0} constant({1})
+ CHECK: %[[P0]] = f32[1]{0} parameter(0)
+ CHECK: %[[CC:.*]] = s32[] custom-call(), custom_call_target="some target"
+ CHECK: ROOT %call = f32[1]{0} call(%[[P0]], %[[C1]]), to_apply=%command_buffer, control-predecessors={%[[CC]]}
+ CHECK: }
+
+ CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
+ CHECK: %[[ARG0]] = f32[1]{0} parameter(0)
+ CHECK: %[[COPY:.*]] = f32[1]{0} fusion(%[[ARG0]]), kind=kLoop
+ CHECK: %[[WHILE:.*]] = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY]]
+ CHECK: ROOT %[[BC:.+]] = f32[] bitcast(%[[WHILE]])
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, While) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation (param_0: f32[1]) -> f32[1] {
+ %param_0 = f32[1]{0} parameter(0)
+ ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
+ }
+
+ %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
+ %param_0.1 = f32[1]{0} parameter(0)
+ %param_1 = f32[1]{0} parameter(1)
+ ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
+ }
+
+ %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
+ %param_0.2 = f32[1]{0} parameter(0)
+ %param_1.1 = f32[1]{0} parameter(1)
+ ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
+ }
+
+ %body (Arg_.3: f32[1]) -> f32[1] {
+ %constant_4 = f32[1]{0} constant({1})
+ %Arg_.3 = f32[1]{0} parameter(0)
+ ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1
+ }
+
+ %cond (Arg_.11: f32[1]) -> pred[] {
+ %constant = f32[1]{0} constant({100})
+ %Arg_.11 = f32[1]{0} parameter(0)
+ %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
+ ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
+ }
+
+ ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
+ %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
+ %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
+ %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
+ ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: f32[1]) -> f32[1] {
+ CHECK: %[[P0]] = f32[1]{0} parameter(0)
+ CHECK: %[[COPY:.*]] = f32[1]{0} fusion(%[[P0]]), kind=kLoop
+ CHECK: ROOT {{.*}} = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY:[a-z_0-9.]+]]
+ CHECK: }
+
+ CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
+ CHECK: %[[ARG0]] = f32[1]{0} parameter(0)
+ CHECK: %call = f32[1]{0} call(%[[ARG0]]), to_apply=%command_buffer
+ CHECK: ROOT %[[BC:.+]] = f32[] bitcast(%call)
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, Conditional) {
+ const char* hlo = R"(
+ HloModule TestModule, is_scheduled=true
+
+ %fused_computation.1 (param_0.2: s32[5]) -> s32[5] {
+ %param_0.2 = s32[5]{0} parameter(0)
+ ROOT %negate.2 = s32[5]{0} negate(s32[5]{0} %param_0.2)
+ }
+
+ %region_0.7 (Arg_.8: s32[5]) -> (s32[5]) {
+ %Arg_.8 = s32[5]{0} parameter(0)
+ %wrapped_negate.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.8), kind=kLoop, calls=%fused_computation.1
+ ROOT %tuple.3 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_negate.1)
+ }
+
+ %fused_computation.2 (param_0.3: s32[5]) -> s32[5] {
+ %param_0.3 = s32[5]{0} parameter(0)
+ ROOT %not.2 = s32[5]{0} not(s32[5]{0} %param_0.3)
+ }
+
+ %region_1.10 (Arg_.11: s32[5]) -> (s32[5]) {
+ %Arg_.11 = s32[5]{0} parameter(0)
+ %wrapped_not.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.11), kind=kLoop, calls=%fused_computation.2
+ ROOT %tuple.4 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_not.1)
+ }
+
+ %fused_computation.3 (param_0.4: s32[5]) -> s32[5] {
+ %param_0.4 = s32[5]{0} parameter(0)
+ ROOT %multiply.2 = s32[5]{0} multiply(s32[5]{0} %param_0.4, s32[5]{0} %param_0.4)
+ }
+
+ %region_2.13 (Arg_.14: s32[5]) -> (s32[5]) {
+ %Arg_.14 = s32[5]{0} parameter(0)
+ %wrapped_multiply.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.14), kind=kLoop, calls=%fused_computation.3
+ ROOT %tuple.5 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_multiply.1)
+ }
+
+ %fused_computation (param_0.1: s64[]) -> s32[] {
+ %constant_1 = s32[] constant(0)
+ %param_0.1 = s64[] parameter(0)
+ %convert.2 = s32[] convert(s64[] %param_0.1)
+ %constant_0 = s32[] constant(2)
+ ROOT %clamp.2 = s32[] clamp(s32[] %constant_1, s32[] %convert.2, s32[] %constant_0)
+ }
+
+ ENTRY %main.17 (Arg_0.1: s64[], Arg_1.2: s32[5]) -> s32[5] {
+ %Arg_0.1 = s64[] parameter(0), sharding={replicated}
+ %fusion = s32[] fusion(s64[] %Arg_0.1), kind=kLoop, calls=%fused_computation
+ %Arg_1.2 = s32[5]{0} parameter(1), sharding={replicated}
+ %conditional.16.clone = (s32[5]{0}) conditional(s32[] %fusion, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2), branch_computations={%region_0.7, %region_1.10, %region_2.13}
+ ROOT %get-tuple-element = s32[5]{0} get-tuple-element((s32[5]{0}) %conditional.16.clone), index=0
+ })";
+
+ const char* expected = R"(
+ CHECK: %command_buffer ([[P0:.+]]: s64[], [[P1:.+]]: s32[5]) -> (s32[5]) {
+ CHECK: %[[P0]] = s64[] parameter(0)
+ CHECK: %[[P1]] = s32[5]{0} parameter(1)
+ CHECK: %[[FUSION:.*]] = s32[] fusion(%[[P0]]), kind=kLoop
+ CHECK: ROOT {{.*}} = (s32[5]{0}) conditional(%[[FUSION]], %[[P1]], %[[P1]], %[[P1]]), branch_computations={%[[B1:[a-z_0-9.]+]], %[[B2:[a-z_0-9.]+]], %[[B3:[a-z_0-9.]+]]}
+ CHECK: }
+
+ CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: s64[], [[ARG1:.+]]: s32[5]) -> s32[5] {
+ CHECK: %[[ARG0]] = s64[] parameter(0)
+ CHECK: %[[ARG1]] = s32[5]{0} parameter(1)
+ CHECK: %call = (s32[5]{0}) call(%[[ARG0]], %[[ARG1]]), to_apply=%command_buffer
+ CHECK: ROOT %[[GEP:.+]] = s32[5]{0} get-tuple-element(%call)
+ CHECK: })";
+
+ RunAndFilecheckHloRewrite(
+ hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ expected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+TEST_F(CommandBufferSchedulingTest, CuDnnFusionGraphCaptureWorks) {
+ const std::string kHloText = R"(
+HloModule m, is_scheduled=true
+
+fusion0 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ ROOT d = f32[64,64] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+fusion1 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ ROOT d = f32[64,64] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+}
+
+fusion_a {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ ROOT a = f32[64,64] add(p0, p1)
+}
+
+ENTRY e {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ d0 = f32[64,64] fusion(p0, p1), kind=kCustom,
+ calls=fusion0,
+ backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
+ a = f32[64,64] fusion(d0, d0), kind=kLoop, calls=fusion_a
+ ROOT d1 = f32[64,64] fusion(a, p1), kind=kCustom,
+ calls=fusion1,
+ backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
+})";
+
+ const std::string kExpected = R"(
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: ROOT
+; CHECK-SAME: call(
+; CHECK-SAME: to_apply=%command_buffer
+})";
+
+ RunAndFilecheckHloRewrite(
+ kHloText,
+ CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+ kExpected, [](HloModule* module) {
+ EXPECT_TRUE(module->has_schedule());
+ TF_CHECK_OK(module->schedule().Verify());
+ });
+}
+
+} // namespace
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc
new file mode 100644
index 0000000..f072a91
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc
@@ -0,0 +1,461 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/shape_inference.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
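+// A forward convolution is canonical when cuDNN can consume its window
+// directly, i.e. the padding is symmetric and non-negative and there is no
+// dilation.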
+bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
+ CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
+ conv.custom_call_target() ==
+ kCudnnConvBiasActivationForwardCallTarget ||
+ conv.custom_call_target() == kCudnnConvForwardGraphCallTarget);
+ return window_util::HasSymmetricPadding(conv.window()) &&
+ !window_util::HasNegativePadding(conv.window()) &&
+ !window_util::HasDilation(conv.window());
+}
+
+// If the (positive and negative) padding on the input operand of a convolution
+// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
+// dilation), returns kPad and/or kSlice instructions that explicitly apply the
+// padding; otherwise returns the original input operand. When there is both
+// positive padding (including dilation) and negative padding, we insert both
+// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
+// into a kPad or kSlice op.
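+//
+// For example (a sketch): padding_low=2, padding_high=0 on a spatial dimension
+// is rewritten as edge_padding_low=2 on that dimension of the inserted kPad,
+// and the corresponding convolution window padding is reset to 0.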
+HloInstruction* MaybePaddedAndSlicedInput(
+ Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
+ HloInstruction* input) {
+ HloComputation* computation = input->parent();
+ if (!window_util::HasSymmetricPadding(*conv_window) ||
+ window_util::HasBaseDilation(*conv_window)) {
+ // If padding is uneven or has dilation, we insert a kPad instruction that
+ // applies positive padding and dilation.
+ //
+ // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
+ // moving all the padding into an explicit pad op, we should keep as much
+ // padding inside of cudnn as possible, on the assumption that padding
+ // within cudnn is basically free, whereas a kPad's cost increases as the
+ // amount of padding increases.
+ PaddingConfig padding_config =
+ MakeNoPaddingConfig(input->shape().dimensions_size());
+ for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
+ int64_t dim = conv_dnums.input_spatial_dimensions(i);
+ if (conv_window->dimensions(i).padding_low() > 0) {
+ padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+ conv_window->dimensions(i).padding_low());
+ conv_window->mutable_dimensions(i)->set_padding_low(0);
+ }
+ if (conv_window->dimensions(i).padding_high() > 0) {
+ padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+ conv_window->dimensions(i).padding_high());
+ conv_window->mutable_dimensions(i)->set_padding_high(0);
+ }
+ if (conv_window->dimensions(i).base_dilation() != 1) {
+ padding_config.mutable_dimensions(dim)->set_interior_padding(
+ conv_window->dimensions(i).base_dilation() - 1);
+ conv_window->mutable_dimensions(i)->set_base_dilation(1);
+ }
+ }
+ PrimitiveType element_type = input->shape().element_type();
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+ input =
+ MakePadHlo(input, padding, padding_config, &input->metadata()).value();
+ }
+
+ if (window_util::HasNegativePadding(*conv_window)) {
+ // If the window has negative padding, insert a kSlice that explicitly
+ // applies negative padding.
+ //
+ // For each dimension, initialize the start index to 0 and the limit index
+ // to the size of that dimension.
+ std::vector<int64_t> start_indices(input->shape().dimensions_size(), 0);
+ std::vector<int64_t> limit_indices(input->shape().dimensions().begin(),
+ input->shape().dimensions().end());
+ std::vector<int64_t> strides(input->shape().dimensions_size(), 1);
+ for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
+ int64_t dim = conv_dnums.input_spatial_dimensions(i);
+ // If dimension "dim" has negative padding, increase the start index or
+ // decrement the limit index by the amount of negative padding.
+ if (conv_window->dimensions(i).padding_low() < 0) {
+ start_indices[dim] += -conv_window->dimensions(i).padding_low();
+ conv_window->mutable_dimensions(i)->set_padding_low(0);
+ }
+ if (conv_window->dimensions(i).padding_high() < 0) {
+ limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
+ conv_window->mutable_dimensions(i)->set_padding_high(0);
+ }
+ }
+
+ input = MakeSliceHlo(input, start_indices, limit_indices, strides).value();
+ }
+
+ return input;
+}
+
+// If the padding on the kernel operand of a convolution can't be folded into a
+// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
+// explicitly applies the padding; otherwise returns the original kernel
+// operand.
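+//
+// For example (a sketch): window_dilation=2 on a kernel spatial dimension is
+// rewritten as interior_padding=1 in the inserted kPad; the caller later
+// resets the window dilation to 1.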
+HloInstruction* MaybePaddedKernel(const Window& conv_window,
+ const ConvolutionDimensionNumbers& conv_dnums,
+ HloInstruction* kernel) {
+ if (!window_util::HasWindowDilation(conv_window)) {
+ return kernel;
+ }
+
+ // Compute the shape and padding config of the pad to be inserted.
+ PaddingConfig padding_config;
+ padding_config.mutable_dimensions()->Reserve(
+ kernel->shape().dimensions_size());
+ for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
+ padding_config.add_dimensions();
+ }
+ for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
+ int64_t dim = conv_dnums.kernel_spatial_dimensions(i);
+ padding_config.mutable_dimensions(dim)->set_interior_padding(
+ conv_window.dimensions(i).window_dilation() - 1);
+ }
+
+ HloComputation* computation = kernel->parent();
+ PrimitiveType element_type = kernel->shape().element_type();
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+ return MakePadHlo(kernel, padding, padding_config, &kernel->metadata())
+ .value();
+}
+} // namespace
+
+bool ConvPaddingLegalization::CanonicalizeForwardConvolution(
+ HloInstruction* conv) {
+ if (IsForwardConvolutionCanonical(*conv)) {
+ return false;
+ }
+
+ // Insert slices and/or pads between the convolution and its input and/or
+ // kernel operand.
+ Window new_conv_window = conv->window();
+ HloInstruction* new_input = MaybePaddedAndSlicedInput(
+ &new_conv_window, conv->convolution_dimension_numbers(),
+ conv->mutable_operand(0));
+ HloInstruction* new_kernel =
+ MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
+ conv->mutable_operand(1));
+
+  // Remove the window dilation from the convolution's window field. The
+  // dilation is made explicit by the pad inserted by MaybePaddedKernel().
+ for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
+ WindowDimension* dim = new_conv_window.mutable_dimensions(i);
+
+    // The size of the kernel may have changed, so update the Window to match.
+ dim->set_size(new_kernel->shape().dimensions(
+ conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
+ dim->set_window_dilation(1);
+ }
+
+ // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
+ // out the shape of conv_result.
+ VLOG(1) << "Canonicalizing forward conv";
+ std::vector<HloInstruction*> operands(conv->operands().begin(),
+ conv->operands().end());
+ operands[0] = new_input;
+ operands[1] = new_kernel;
+ auto new_conv = conv->parent()->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), operands));
+ new_conv->set_window(new_conv_window);
+ VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
+ << new_conv->ToString();
+ TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
+ return true;
+}
+
+namespace {
+void IncreasePaddingLowBy(int64_t delta, WindowDimension* window_dim) {
+ window_dim->set_padding_low(window_dim->padding_low() + delta);
+}
+
+void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) {
+ window_dim->set_padding_high(window_dim->padding_high() + delta);
+}
+} // namespace
+
+bool ConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
+ HloInstruction* backward_conv) {
+ CHECK_EQ(backward_conv->custom_call_target(),
+ kCudnnConvBackwardFilterCallTarget);
+ if (window_util::HasSymmetricPadding(backward_conv->window())) {
+ return false;
+ }
+
+ // A backward filter convolution with uneven padding can be canonicalized to
+ // one with even padding by padding the activations (input) beforehand. For
+ // example,
+ // BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
+ // is equivalent to
+ // ABCD0 = Pad(ABCD, padding_high=1)
+ // BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
+ // We choose the lesser of padding_low and padding_high as the new padding.
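+  // For example (a sketch): padding_low=1, padding_high=3 becomes a kPad with
+  // edge_padding_high=2 on the activations and a backward filter convolution
+  // with padding_low=padding_high=1.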
+ HloInstruction* input = backward_conv->mutable_operand(0);
+ Window new_backward_conv_window = backward_conv->window();
+ // input_padding_config is the config of the kPad to be inserted.
+ PaddingConfig input_padding_config =
+ MakeNoPaddingConfig(input->shape().rank());
+ ConvolutionDimensionNumbers backward_conv_dnums =
+ backward_conv->convolution_dimension_numbers();
+ for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+ int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
+ int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
+ if (padding_low < 0 || padding_high < 0) {
+ // TODO(b/32744257): The following canonicalization wouldn't remove
+ // negative padding in a backward convolution, and would therefore cause
+ // cuDNN convolution (which doesn't support negative padding) to fail.
+ return false;
+ }
+ // Compute the new, even padding for the backward conv operation.
+ int64_t new_conv_padding = std::min(padding_low, padding_high);
+ int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
+ input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+ padding_low - new_conv_padding);
+ input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+ padding_high - new_conv_padding);
+
+ // Since we move some padding from the backward convolution to the kPad, we
+ // need to accordingly reduce the padding amount of the backward convolution
+ // and its inner forward convolution.
+ auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
+ new_dim->set_padding_low(new_conv_padding);
+ new_dim->set_padding_high(new_conv_padding);
+ }
+
+ // Create a new backward convolution replacing the old one.
+ HloComputation* computation = backward_conv->parent();
+ HloInstruction* output = backward_conv->mutable_operand(1);
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(input->shape().element_type())));
+ HloInstruction* padded_input =
+ MakePadHlo(input, padding, input_padding_config).value();
+
+ // The shape of the backward_conv CustomCall is a tuple (conv_result,
+ // scratch_buffer). Extract out the shape of conv_result.
+ HloInstruction* new_backward_conv =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ backward_conv->shape(), {padded_input, output}));
+ new_backward_conv->set_window(new_backward_conv_window);
+
+ VLOG(1) << "Canonicalizing backward filter conv";
+ VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
+ << new_backward_conv->ToString();
+
+ TF_CHECK_OK(
+ computation->ReplaceInstruction(backward_conv, new_backward_conv));
+ return true;
+}
+
+bool ConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
+ HloInstruction* backward_conv) {
+ if (window_util::HasSymmetricPadding(backward_conv->window())) {
+ return false;
+ }
+
+ Window new_backward_conv_window = backward_conv->window();
+ ConvolutionDimensionNumbers backward_conv_dnums =
+ backward_conv->convolution_dimension_numbers();
+
+ // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
+ // Get the shape of conv_result.
+ Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+
+ Shape new_backward_conv_shape = backward_conv_shape;
+ for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+ int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
+ int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
+ if (padding_low < 0 || padding_high < 0) {
+ // TODO(b/32744257): The following canonicalization wouldn't remove
+ // negative padding in a backward convolution, and would therefore cause
+ // cuDNN convolution (which doesn't support negative padding) to fail.
+ return false;
+ }
+ // If the backward convolution has uneven padding on the activations, we
+ // move some padding on the larger end to "internal" padding, so that the
+ // backward convolution produces larger activations which get sliced later.
+ //
+ // For example, suppose we have a non-canonical HLO
+ // [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
+ // where the amount of padding low is larger, we can canonicalize it to
+ // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
+ // [A] = Slice([B A])
+ if (padding_low > padding_high) {
+ IncreasePaddingLowBy(padding_high - padding_low,
+ new_backward_conv_window.mutable_dimensions(i));
+ } else if (padding_low < padding_high) {
+ IncreasePaddingHighBy(padding_low - padding_high,
+ new_backward_conv_window.mutable_dimensions(i));
+ }
+ // Decreasing the padding by X *increases* the size of our output by X.
+ // Note that we have swapped input spatial dimensions with output spatial
+ // dimensions to be compatible with the cuDNN API, so
+ // input_spatial_dimensions(i) gives the i-th spatial dimension of the
+ // output.
+ int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
+ new_backward_conv_shape.set_dimensions(
+ dim, new_backward_conv_shape.dimensions(dim) +
+ std::abs(padding_low - padding_high));
+ }
+
+ // Create a new backward convolution replacing the old one.
+ HloComputation* computation = backward_conv->parent();
+ HloInstruction* output = backward_conv->mutable_operand(0);
+ HloInstruction* filter = backward_conv->mutable_operand(1);
+
+ HloInstruction* new_backward_conv_call =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ ShapeUtil::MakeTupleShape(
+ {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
+ {output, filter}));
+ new_backward_conv_call->set_window(new_backward_conv_window);
+
+ // The CustomCall created above returns a tuple (conv_result, scratch_memory).
+ // Extract out the two elements.
+ HloInstruction* new_backward_conv =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_backward_conv_shape, new_backward_conv_call, 0));
+ HloInstruction* new_backward_conv_scratch =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_backward_conv_call->shape().tuple_shapes(1),
+ new_backward_conv_call, 1));
+
+ // Slice the new backward convolution.
+ //
+ // Initialize start_indices and limit_indices as no slicing.
+ std::vector<int64_t> start_indices(
+ new_backward_conv->shape().dimensions_size(), 0LL);
+ std::vector<int64_t> limit_indices(
+ new_backward_conv->shape().dimensions().begin(),
+ new_backward_conv->shape().dimensions().end());
+ std::vector<int64_t> strides(new_backward_conv->shape().dimensions_size(),
+ 1LL);
+ for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+ int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
+ int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
+ // Note that we have swapped input spatial dimensions with output spatial
+ // dimensions to be compatible with the cuDNN API, so
+ // input_spatial_dimensions(i) gives the i-th spatial dimension of the
+ // output.
+ int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
+ if (padding_low > padding_high) {
+ // If the amount of low padding (of the old backward convolution) is
+ // larger, we internally pad the low end of the activations and slice
+ // internal padding out here.
+ start_indices[dim] += padding_low - padding_high;
+ } else if (padding_low < padding_high) {
+ // If the amount of high padding is larger, we slice out the internal
+ // padding on the high end.
+ limit_indices[dim] -= padding_high - padding_low;
+ }
+ }
+
+ // Replace the old backward convolution with the slice.
+ Shape slice_shape =
+ ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
+ limit_indices, strides)
+ .value();
+ CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
+ << ShapeUtil::HumanString(slice_shape) << " vs "
+ << ShapeUtil::HumanString(backward_conv_shape);
+
+ HloInstruction* slice = computation->AddInstruction(
+ HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
+ start_indices, limit_indices, strides));
+ HloInstruction* new_tuple = computation->AddInstruction(
+ HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
+
+ VLOG(1) << "Canonicalizing backward input conv";
+ VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
+ << new_tuple->ToString();
+
+ TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
+ return true;
+}
+
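+// Canonicalizes every cuDNN convolution custom call in `computation`. Returns
+// true if any convolution was replaced.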
+absl::StatusOr<bool> ConvPaddingLegalization::RunOnComputation(
+ HloComputation* computation) {
+ bool changed = false;
+ std::vector<HloCustomCallInstruction*> convs;
+ for (auto* instr : computation->instructions()) {
+ if (IsCustomCallToDnnConvolution(*instr)) {
+ convs.push_back(Cast<HloCustomCallInstruction>(instr));
+ }
+ }
+ for (HloCustomCallInstruction* instruction : convs) {
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
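+    // Dispatch on the convolution kind; each canonicalization returns true
+    // iff it modified the computation.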
+ changed |= [&] {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ case CudnnConvKind::kForwardGraph:
+ return CanonicalizeForwardConvolution(instruction);
+ case CudnnConvKind::kBackwardInput:
+ return CanonicalizeBackwardInputConvolution(instruction);
+ case CudnnConvKind::kBackwardFilter:
+ return CanonicalizeBackwardFilterConvolution(instruction);
+ }
+ }();
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> ConvPaddingLegalization::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h
new file mode 100644
index 0000000..1841c92
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
+// inserts Pad instructions before Convolution instructions with uncanonicalized
+// padding, so that they can be lowered to Cudnn/Miopen convolution.
+class ConvPaddingLegalization : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "conv-padding-legalization";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
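+  // Runs the pass on a single computation. Returns true if anything changed.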
+ absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+  // Each returns true if any changes were made to the parent computation.
+ bool CanonicalizeForwardConvolution(HloInstruction* conv);
+ bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv);
+ bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv);
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc
new file mode 100644
index 0000000..06682e7
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc
@@ -0,0 +1,96 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using ConvPaddingLegalizationTest = HloTestBase;
+
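+// Verifies that a backward-input convolution with uneven padding
+// (pad=0_0x0_1) is rewritten into an even-padding cuDNN call whose result is
+// then sliced back to the original shape.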
+TEST_F(ConvPaddingLegalizationTest, BackwardInputConvolve) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule convolution_module
+ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) {
+ %operand = f64[2,2,2,3]{3,2,1,0} parameter(0)
+ %kernel = f64[2,3,2,3]{3,2,1,0} constant(
+ {
+ { /*i0=0*/
+ { /*i1=0*/
+ { 0.29629629629629628, 0.30246913580246915, 0.30864197530864196 },
+ { 0.31481481481481483, 0.32098765432098764, 0.3271604938271605 }
+ },
+ { /*i1=1*/
+ { 0.25925925925925924, 0.26543209876543211, 0.27160493827160492 },
+ { 0.27777777777777779, 0.2839506172839506, 0.29012345679012347 }
+ },
+ { /*i1=2*/
+ { 0.22222222222222221, 0.22839506172839505, 0.23456790123456789 },
+ { 0.24074074074074073, 0.24691358024691357, 0.25308641975308643 }
+ }
+ },
+ { /*i0=1*/
+ { /*i1=0*/
+ { 0.18518518518518517, 0.19135802469135801, 0.19753086419753085 },
+ { 0.20370370370370369, 0.20987654320987653, 0.21604938271604937 }
+ },
+ { /*i1=1*/
+ { 0.14814814814814814, 0.15432098765432098, 0.16049382716049382 },
+ { 0.16666666666666666, 0.1728395061728395, 0.17901234567901234 }
+ },
+      { /*i1=2*/
+ { 0.1111111111111111, 0.11728395061728394, 0.12345679012345678 },
+ { 0.12962962962962962, 0.13580246913580246, 0.1419753086419753 }
+ }
+ }
+ })
+ %reverse = f64[2,3,2,3]{3,2,1,0} reverse(%kernel), dimensions={0,1}
+ ROOT %custom-call = (f64[2,2,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f64[2,2,2,3]{3,2,1,0} %operand, f64[2,3,2,3]{3,2,1,0} %reverse), window={size=2x3 stride=2x2 pad=0_0x0_1}, dim_labels=bf01_01io->b01f, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
+}
+ )")
+ .value();
+ ASSERT_TRUE(ConvPaddingLegalization().Run(module.get()).value());
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::Tuple(
+ m::Slice(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardInputCallTarget},
+ m::Op(), m::Reverse(m::Constant())),
+ 0)),
+ m::GetTupleElement())));
+ auto slice = root->operand(0);
+ Shape expected_slice_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 4});
+ EXPECT_TRUE(ShapeUtil::Equal(slice->shape(), expected_slice_shape));
+ auto conv = slice->operand(0);
+ Shape expected_conv_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 5});
+ EXPECT_TRUE(ShapeUtil::Equal(conv->shape(), expected_conv_shape));
+}
+
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc
new file mode 100644
index 0000000..e19622d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc
@@ -0,0 +1,869 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
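+// Checks that the convolution's operand and result element types are ones the
+// cuDNN/MIOpen rewrite can handle (floating-point or integral, with FP8
+// restricted to f8e4m3fn/f8e5m2 on CUDA GPUs of compute capability >= 9.0).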
+absl::Status CheckTypes(HloInstruction* conv,
+ const se::GpuComputeCapability cc) {
+ auto valid_shape = [conv, &cc](const Shape& shape) -> absl::Status {
+ PrimitiveType type = shape.element_type();
+ if (!primitive_util::IsFloatingPointType(type) &&
+ !primitive_util::IsIntegralType(type)) {
+ // Among integral types, only S8 is supported. But CudnnFusedConvRewriter
+ // may rewrite convolutions of wider types into S8 convolutions, so allow
+ // all integral convolutions here.
+ return Unimplemented(
+ "Convolutions must have floating-point or integral operands/outputs, "
+ "but got convolution with type %s: %s",
+ primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
+ }
+ if (primitive_util::IsF8Type(type)) {
+ if (type != F8E4M3FN && type != F8E5M2) {
+ return Unimplemented(
+ "The only FP8 types supported in convolutions are f8e5m2 and "
+ "f8e4m3, "
+ "but got convolution with FP8 type %s: %s",
+ primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
+ }
+ if (!std::holds_alternative<se::CudaComputeCapability>(cc)) {
+ return Unimplemented(
+ "FP8 convolutions are only supported on CUDA GPUs, but got "
+ "FP8 convolution on ROCm GPU: %s",
+ conv->ToString());
+ } else if (!std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper()) {
+ return Unimplemented(
+ "FP8 convolutions are only supported on CUDA GPUs with compute "
+ "capability at least 9.0, but got "
+ "FP8 convolution on GPU with compute capability %s: %s",
+ std::get<se::CudaComputeCapability>(cc).ToString(),
+ conv->ToString());
+ }
+ }
+ return absl::OkStatus();
+ };
+
+ TF_RETURN_IF_ERROR(valid_shape(conv->shape()));
+ TF_RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape()));
+ TF_RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape()));
+ return absl::OkStatus();
+}
+
+using ConvolutionMatch = std::optional<
+ std::tuple<Window, ConvolutionDimensionNumbers, HloInstruction*>>;
+
+// Determines whether this conv2d was produced by reshaping a conv1d's filter,
+// i.e. the 2D convolution is really a 1D convolution in disguise.
+bool MaybeConv1dToConv2d(HloInstruction* conv) {
+ if (conv->window().dimensions().size() != 2) {
+ return false;
+ }
+ if (conv->operand(1)->opcode() != HloOpcode::kReshape) {
+ return false;
+ }
+ auto filter = conv->operand(1);
+ std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
+ filter->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+ if (reshape_degenerate.has_value() &&
+ reshape_degenerate->deleted_dimensions.empty() &&
+ reshape_degenerate->inserted_dimensions.size() == 1) {
+ const auto& dnums = conv->convolution_dimension_numbers();
+ for (auto dim : dnums.kernel_spatial_dimensions()) {
+ if (dim == reshape_degenerate->inserted_dimensions[0]) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
+ const ConvolutionDimensionNumbers& dnums =
+ conv->convolution_dimension_numbers();
+ if (dnums.input_spatial_dimensions_size() > 3) {
+ return false;
+ }
+
+ // CuDNN does not accept zero-element arguments
+ if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
+ return false;
+ }
+
+ // CuDNN can perform either cross correlation (no reversal),
+ // or convolution (all dimensions reversed).
+ if (dnums.input_spatial_dimensions_size() == 2
+ ? !window_util::AllOrNoneReversed(conv->window())
+ : window_util::HasWindowReversal(conv->window())) {
+ return false;
+ }
+ return true;
+}
+
+// Try to match a backward filter pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
+ConvolutionMatch MatchBackwardFilter(HloInstruction* conv) {
+ VLOG(2) << "Trying to match convolution backward filter.";
+
+ if (conv->feature_group_count() > 1) {
+    VLOG(1) << conv->ToString()
+            << " is a forward convolution. All grouped backward filters are "
+               "mapped to batch-grouped convolutions in the tf2xla bridge, so "
+               "backward filter convolutions cannot have a feature group count "
+               "greater than 1 at this point. No need to fold to a backward "
+               "filter.";
+ return std::nullopt;
+ }
+
+ // Step 1: match the instruction pattern without considering the paddings and
+ // dimension numbers just yet. We may need some generic pattern matcher
+ // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
+ //
+ // Backward filter convolution is implemented in XLA as the forward
+ // convolution of padded activations and dilated gradients. Padding on
+ // activations and dilation on gradients are specified in the "window" field
+ // of the forward convolution.
+ //
+ // activations gradients
+ // \ /
+ // v v
+ // Convolution
+ // conv
+ CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+
+ // Step 2: match paddings and dimension numbers of the forward convolution.
+ const ConvolutionDimensionNumbers& conv_dnums =
+ conv->convolution_dimension_numbers();
+ auto input_batch_dim = conv_dnums.input_batch_dimension();
+ auto input_feature_dim = conv_dnums.input_feature_dimension();
+ auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
+ auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
+ auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
+ auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
+ auto output_batch_dim = conv_dnums.output_batch_dimension();
+ auto output_feature_dim = conv_dnums.output_feature_dimension();
+ auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
+ for (const WindowDimension& window_dim : conv->window().dimensions()) {
+ if (window_dim.stride() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have stride of 1.";
+ return std::nullopt;
+ }
+ if (window_dim.base_dilation() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have no base (LHS) dilation.";
+ return std::nullopt;
+ }
+ if (window_dim.padding_low() < 0) {
+ VLOG(1) << "Padding low should be non-negative.";
+ return std::nullopt;
+ }
+ if (window_dim.window_reversal()) {
+ VLOG(1) << "Window reversal field not supported";
+ return std::nullopt;
+ }
+ // Padding high will be checked in Step 3.
+ }
+ // Mathematically, there is no difference between convolution forward vs
+ // backward filter. A backward filter:
+ // [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
+ // Can be treated as a forward convolution with `N` treated as the new
+ // contracting (feature) dimension, `O` treated as the new batch dimension,
+ // and `C` treated as the new output feature dimension. The only difference is
+ // layouts and performance.
+ //
+  // Since there is no way to tell precisely whether we want a forward conv or
+  // a backward filter conv, we have to rely on heuristics. Empirically, forward
+  // convolutions have very small kernel dimensions, while in the backward pass
+  // the "kernel dimensions" are large. If the kernel dimensions are smaller than
+  // the output dimensions, return forward conv; otherwise proceed with backward
+  // filter conv. Conv1d is an exception: because a 1D filter is always reshaped
+  // into a 2D filter, one kernel dimension is small in both the forward and the
+  // backward case, which the check below handles specially.
+ int small_kernel_dimension_num = 0;
+ for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
+ if (conv->operand(1)->shape().dimensions(kernel_spatial_dims[i]) <=
+ conv->shape().dimensions(output_spatial_dims[i])) {
+ small_kernel_dimension_num += 1;
+ }
+ }
+ if ((kernel_spatial_dims.empty() || small_kernel_dimension_num > 1 ||
+ (!MaybeConv1dToConv2d(conv) && small_kernel_dimension_num == 1)) &&
+ !window_util::HasWindowDilation(conv->window())) {
+ VLOG(1) << conv->ToString()
+ << " is a regular forward convolution. No need "
+ "to fold it to a backward filter convolution....";
+ return std::nullopt;
+ }
+
+ // Step 3: fuse the matched HLOs into a backward convolution instruction.
+ //
+ // Compute the window of the backward convolution.
+ Window backward_conv_window;
+ for (int i = 0; i < input_spatial_dims.size(); ++i) {
+ WindowDimension* dim = backward_conv_window.add_dimensions();
+ // The window size of the backward convolution equals the output size of the
+ // forward convolution.
+ int64_t filter_size = conv->shape().dimensions(output_spatial_dims[i]);
+ dim->set_size(filter_size);
+ // The window stride equals the window dilation of the forward convolution.
+ dim->set_stride(conv->window().dimensions(i).window_dilation());
+ // The window's low padding is the same as the low padding of the
+ // activations.
+ dim->set_padding_low(conv->window().dimensions(i).padding_low());
+ dim->set_base_dilation(1);
+ dim->set_window_dilation(1);
+
+ int64_t input_size =
+ conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
+ int64_t output_size = conv->window().dimensions(i).size();
+ // Compute the range of the amount of valid high padding. We first compute
+ // min_padding_high, the amount of padding on the right/bottom to ensure the
+ // last patch ends at the border, i.e.,
+ //
+ // input_size + dim->padding_low() + min_padding_high
+ // = (output_size - 1) * stride + filter_size
+ //
+ // Because convolution ignores trailing incomplete windows, any amount of
+ // padding high from min_padding_high to min_padding_high+stride-1
+ // (max_padding_high) has the same effect.
+ int64_t padded_input_size = filter_size + (output_size - 1) * dim->stride();
+ int64_t min_padding_high =
+ padded_input_size - input_size - dim->padding_low();
+ int64_t max_padding_high = min_padding_high + dim->stride() - 1;
+ CHECK_GE(dim->padding_low(), 0);
+ // In practice, since cuDNN convolution only supports even padding, we make
+ // the amount of high padding the same as the amount of low padding as long
+ // as it is between min_padding_high and max_padding_high. If it is not in
+ // that range, we pick the one that's closest to dim->padding_low() and let
+ // GpuConvPaddingLegalization canonicalize the resultant backward
+ // convolution later. Picking the closest one minimizes the cost of the kPad
+ // instruction to be inserted by GpuConvPaddingLegalization.
+ if (dim->padding_low() >= min_padding_high &&
+ dim->padding_low() <= max_padding_high) {
+ dim->set_padding_high(dim->padding_low());
+ } else {
+ if (dim->padding_low() < min_padding_high) {
+ dim->set_padding_high(min_padding_high);
+ } else {
+ dim->set_padding_high(max_padding_high);
+ }
+ }
+ if (dim->padding_high() < 0) {
+ LOG(WARNING)
+ << "Fusing this pattern to backward filter convolution would cause "
+ "negative padding ("
+ << dim->padding_high()
+ << ") on right/bottom of the weight gradients, which is not "
+ "supported by GpuConvPaddingLegalization (b/32744257). "
+ "Falling back to "
+ "unfused convolution for instruction: "
+ << conv->ToString();
+ return std::nullopt;
+ }
+ }
+
+ // Restore the dimension numbers of the backward convolution from the forward
+ // convolution. The two activation dimensions are reversed (batch and
+ // feature).
+ ConvolutionDimensionNumbers backward_conv_dnums;
+ backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
+ backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
+ for (int i = 0; i < input_spatial_dims.size(); ++i) {
+ backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
+ }
+ backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
+ backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
+ for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
+ backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
+ }
+ // The dimension numbering of the output of the forward convolution (before
+ // transposition) is the same as that of the activations (according to the
+ // semantics of kConvolution). The batch dimension of the activations should
+ // be treated as the input feature dimension, and the feature dimension should
+ // be treated as the output feature.
+ backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
+ backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
+ for (int i = 0; i < output_spatial_dims.size(); ++i) {
+ backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
+ }
+
+ HloInstruction* lhs = conv->mutable_operand(0);
+ return std::make_tuple(backward_conv_window, backward_conv_dnums, lhs);
+}
+
+// Try to match a backward input pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
+ConvolutionMatch MatchBackwardInput(HloInstruction* conv) {
+ VLOG(2) << "Trying to match convolution backward input.";
+
+  // TODO(timshen): Theoretically cuDNN also supports grouped convolutions for
+  // the backward input convolution, but in cuDNN's current state there is not
+  // much performance improvement from using the cuDNN backward input API for
+  // grouped convs. This needs to be re-evaluated for future cuDNN versions.
+  // Note that we already have the necessary code down below; the only thing
+  // needed to enable it is to remove the following early return.
+ if (conv->feature_group_count() > 1) {
+ return std::nullopt;
+ }
+
+ // Match instruction pattern.
+ CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+ HloInstruction* reverse_filter = conv->mutable_operand(1);
+ ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
+
+  // Match BackwardInput for a depthwise convolution and treat it as a forward
+  // convolution. The output feature dimension and the input feature dimension
+  // have been swapped in the bridge, so to get the actual input feature count
+  // we need to query the output feature dimension.
+ auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
+ auto kernel_out_features =
+ reverse_filter->shape().dimensions(kernel_out_feature_dim);
+
+  // For a depthwise convolution, the input feature count must equal the
+  // feature_group_count. We can leverage this property to detect a depthwise
+  // convolution and treat it as a forward convolution instead.
+ if (conv->feature_group_count() > 1 &&
+ kernel_out_features == conv->feature_group_count()) {
+ return std::nullopt;
+ }
+
+ // We pattern-match to a backwards input conv if:
+ //
+ // - all spatial dims of the filter are reversed
+ //
+ // OR
+ //
+ // - filter is 1x1 or a constant AND
+ // - conv has base dilation (otherwise this is just a regular forward conv).
+ //
+ // The final criterion above is just for canonicalization; cudnn seems to run
+ // just as fast if we canonicalize 1x1/constant filters without base dilation
+ // to forward or backward convs. We canonicalize to forward conv because (a)
+ // it's more natural (constant filters usually show up when doing inference,
+ // and having backwards convolutions in inference graphs would be weird), and
+ // (b) cudnn has special fusions for forward conv plus bias and activation,
+ // and we want to pattern-match to that after running this pass.
+ bool is_reversed_filter =
+ reverse_filter->opcode() == HloOpcode::kReverse &&
+ absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+ reverse_filter->dimensions());
+  // For a conv1d that has been reshaped to a conv2d, the filter-reverse pattern
+  // is reshape(reverse(filter)). It seems we could reuse the conv2d backward
+  // input pattern matcher, but after the algsimp pass this pattern changes to
+  // reverse(reshape(filter)) and fails to match. So matching a conv1d backward
+  // input needs different processing logic.
+ bool is_reversed_conv1d_filter =
+ MaybeConv1dToConv2d(conv) &&
+ reverse_filter->operand(0)->opcode() == HloOpcode::kReverse;
+ bool is_1x1_filter =
+ absl::c_all_of(conv->window().dimensions(),
+ [](const WindowDimension& d) { return d.size() == 1; });
+ if (!is_reversed_filter && !is_reversed_conv1d_filter &&
+ !(window_util::HasBaseDilation(conv->window()) &&
+ (reverse_filter->IsConstant() || is_1x1_filter))) {
+ VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+ "kReverse, or it's not a base-dilated conv with a 1x1 or "
+ "constant filter.";
+ return std::nullopt;
+ }
+
+ // Match padding and dilation of the forward convolution.
+ for (const WindowDimension& window_dim : conv->window().dimensions()) {
+ if (window_dim.stride() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have stride of 1.";
+ return std::nullopt;
+ }
+ if (window_dim.window_dilation() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have no window dilation.";
+ return std::nullopt;
+ }
+ if (window_dim.window_reversal()) {
+ VLOG(1) << "Window reversal field not supported";
+ return std::nullopt;
+ }
+ }
+
+ const auto& input_spatial_dims = dnums.input_spatial_dimensions();
+ const auto& output_spatial_dims = dnums.output_spatial_dimensions();
+ CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
+ CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());
+
+ const Window& old_window = conv->window();
+ Window new_window = old_window;
+ for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
+ // Restore backward convolution's padding config from the matched pattern.
+ // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
+ // convert backward input convolution to a variant of forward convolution.
+ //
+ // The stride of the backward convolution
+ // = the base dilation factor of the forward convolution
+ auto dim = new_window.mutable_dimensions(i);
+ dim->set_stride(old_window.dimensions(i).base_dilation());
+ dim->set_base_dilation(1);
+
+ // The low padding = kernel_size - 1 - low padding on the gradients
+ // Make sure the low padding is not negative.
+ auto kernel_size = old_window.dimensions(i).size();
+ auto backward_padding_low =
+ kernel_size - 1 - old_window.dimensions(i).padding_low();
+ if (backward_padding_low < 0) {
+ LOG(WARNING)
+ << "The low padding of the backward convolution would be negative ("
+ << backward_padding_low
+ << "), which isn't supported by GpuConvPaddingLegalization "
+ "for now (b/32744257).";
+ return std::nullopt;
+ }
+ dim->set_padding_low(backward_padding_low);
+
+ // Compute the range of the amount of padding on the right/bottom of the
+ // activations. XLA's convolution requires all patches to be within the
+    // padded base. This gives us flexibility to choose the amount of high
+ // padding from a set of values without changing the result of the backward
+ // convolution. The minimum amount (min_padding_high) makes the last patch
+ // end at the border. The maximum amount (max_padding_high) equals
+ // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
+ // size to change.
+ auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
+ auto output_size =
+ conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
+ auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
+ auto total_pad_size = padded_input_size - unpadded_input_size;
+ auto min_padding_high = total_pad_size - backward_padding_low;
+ auto max_padding_high = min_padding_high + dim->stride() - 1;
+
+ if (backward_padding_low >= min_padding_high &&
+ backward_padding_low <= max_padding_high) {
+ // In the best case (most likely), if backward_padding_low is in the range
+ // of the amounts of valid high padding, we choose backward_padding_low
+ // because cuDNN supports even padding only.
+ dim->set_padding_high(backward_padding_low);
+ } else {
+ // Otherwise, we choose the amount that's closest to backward_padding_low,
+ // and GpuConvPaddingLegalization will later insert kSlice
+ // instructions to enforce even padding.
+ //
+ // For example, consider the backward convolution pattern
+ //
+ // ab xy
+ // | pad | reverse
+ // .a.b yx
+ // \ /
+ // ABC
+ //
+ // The amount of low padding on activations (in backward convolution) is
+ // backward_padding_low = kernel_size - 1 - forward_padding_low
+ // = 2 - 1 - 1 = 0
+ //
+ // The amount of padding high must be between 1 and 2, in order to make
+ // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
+ // the range of [1,2], so we pick the closest valid amount of padding
+ // high, which is 1 in this case. Therefore, we fuse the above pattern to
+ //
+ // ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
+ if (backward_padding_low < min_padding_high) {
+ dim->set_padding_high(min_padding_high);
+ } else {
+ dim->set_padding_high(max_padding_high);
+ }
+ }
+ // GpuConvPaddingLegalization doesn't handle backward input
+ // convolution with negative padding for now. So fall back to unfused
+ // convolution in case of negative padding. For example,
+ // ABCD = Conv(abc, reverse(xy), padding_high=2)
+ // could be fused to
+ // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
+ // with positive padding low but negative padding high.
+ if (dim->padding_high() < 0) {
+ LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
+ "negative padding ("
+ << dim->padding_high()
+ << ") on right/bottom of the activations, which is not "
+ "supported by GpuConvPaddingLegalization (b/32744257). "
+ "Falling back to unfused convolution for instruction: "
+ << conv->ToString();
+ return std::nullopt;
+ }
+ }
+
+ // OK, it's a match! Switch the input feature dimension with the output
+ // feature dimension. Also switch the output with the input. This is the way
+ // cuDNN expects it to be.
+ auto conv_dnums = conv->convolution_dimension_numbers();
+ dnums.set_kernel_input_feature_dimension(
+ conv_dnums.kernel_output_feature_dimension());
+ dnums.set_kernel_output_feature_dimension(
+ conv_dnums.kernel_input_feature_dimension());
+ for (int i = 0; i < input_spatial_dims.size(); ++i) {
+ dnums.set_input_spatial_dimensions(i,
+ conv_dnums.output_spatial_dimensions(i));
+ dnums.set_output_spatial_dimensions(i,
+ conv_dnums.input_spatial_dimensions(i));
+ }
+ dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
+ dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
+ dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
+ dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());
+
+ // If we matched against a constant, we need to add a reverse op that can be
+ // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+ // unnecessary reverses.
+ if (reverse_filter->opcode() != HloOpcode::kReverse &&
+ reverse_filter->IsConstant()) {
+ // Create a double-reverse, which is a nop.
+ HloComputation* c = conv->parent();
+ reverse_filter = c->AddInstruction(
+ HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+ dnums.kernel_spatial_dimensions()));
+ reverse_filter = c->AddInstruction(
+ HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+ dnums.kernel_spatial_dimensions()));
+ TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
+ }
+
+ // Calculate the 'rhs' that goes into the backward input convolution.
+ HloInstruction* rhs = reverse_filter;
+ // One reverse is subsumed by the cuDNN call.
+ if (rhs->opcode() == HloOpcode::kReverse) {
+ rhs = rhs->mutable_operand(0);
+ } else if (is_reversed_conv1d_filter) {
+ auto src = rhs->mutable_operand(0)->mutable_operand(0);
+ rhs = conv->parent()->AddInstruction(
+ HloInstruction::CreateReshape(rhs->shape(), src));
+ }
+ if (conv->feature_group_count() == 1) {
+ return std::make_tuple(new_window, dnums, rhs);
+ }
+
+ // Handle grouped convolutions. Because we swapped the input feature dimension
+ // with the output feature dimension, we need to also reshape the kernel so
+ // that the 'feature_group_count' parameter still makes sense. The
+ // 'feature_group_count' parameter essentially specifies how often the
+ // 'kernel_input_feature_dimension' is repeated. So when we swap these
+ // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+ // 'feature_group_count' and multiply the new
+ // 'kernel_output_feature_dimension' by 'feature_group_count'.
+ int64_t input_feature_dimension = dnums.kernel_input_feature_dimension();
+ int64_t output_feature_dimension = dnums.kernel_output_feature_dimension();
+ // The following code assumes that input_feature_dimension and
+ // output_feature_dimension are adjacent.
+ if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
+ return std::nullopt;
+ }
+
+ int64_t input_features = rhs->shape().dimensions(input_feature_dimension);
+ int64_t output_features = rhs->shape().dimensions(output_feature_dimension);
+
+ // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
+ // out_depth / G]
+ std::vector<int64_t> reshape_dims = SpanToVector(rhs->shape().dimensions());
+ auto num_groups = conv->feature_group_count();
+ CHECK_EQ(input_features % num_groups, 0)
+ << "Input feature count should be an exact multiple of feature group "
+ "count";
+ reshape_dims[input_feature_dimension] =
+ reshape_dims[input_feature_dimension] / num_groups;
+ reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
+ num_groups);
+
+ HloComputation* c = conv->parent();
+ rhs = c->AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));
+
+ // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
+ // in_depth/G, G, out_depth / G]
+ std::vector<int64_t> transpose_dims(rhs->shape().dimensions_size());
+ std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
+ transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
+ transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
+ input_feature_dimension);
+ std::vector<int64_t> transpose_reshape_dims =
+ SpanToVector(rhs->shape().dimensions());
+ transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
+ input_feature_dimension);
+ transpose_reshape_dims.insert(
+ transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
+ rhs = c->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
+ rhs, transpose_dims));
+
+ // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
+ // in_depth/G, out_depth]
+ Shape new_shape = rhs->shape();
+ new_shape.DeleteDimension(output_feature_dimension);
+ new_shape.set_dimensions(output_feature_dimension,
+ output_features * num_groups);
+ rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+ return std::make_tuple(new_window, dnums, rhs);
+}
+
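+// Builds the cuDNN/MIOpen convolution CustomCall for the given call target,
+// operands, window, and dimension numbers.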
+HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
+ HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64_t feature_group_count,
+ const PrecisionConfig& precision_config,
+ const OpMetadata& metadata) {
+ HloComputation* computation = lhs->parent();
+
+ // This call returns a tuple of (conv_result, scratch_memory), where
+ // conv_result is the actual result of the convolution, and scratch_memory is
+ // temporary memory used by cudnn.
+ //
+ // At the moment, we don't know how much scratch memory this conv is going to
+ // use, so we put u8[0] in this place. Later on another pass will choose
+ // which conv algorithm to use, and at that point we'll modify the shape of
+ // this second tuple element.
+ Shape call_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+ HloInstruction* custom_call = computation->AddInstruction(
+ HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
+ custom_call->set_window(window);
+ custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
+ *custom_call->mutable_precision_config() = precision_config;
+ custom_call->set_metadata(metadata);
+
+  // Give the custom call a user-friendly name.
+ std::optional<std::string> name;
+ if (call_target == kCudnnConvForwardCallTarget) {
+ name = "cudnn-conv";
+ } else if (call_target == kCudnnConvBackwardInputCallTarget) {
+ name = "cudnn-conv-bw-input";
+ } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
+ name = "cudnn-conv-bw-filter";
+ } else if (call_target == kCudnnConvBiasActivationForwardCallTarget) {
+ name = "cudnn-conv-bias-activation";
+ }
+ if (name.has_value()) {
+ computation->parent()->SetAndUniquifyInstrName(custom_call, *name);
+ }
+
+ return custom_call;
+}
+
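+// Rewrites a batch-grouped convolution (batch_group_count > 1) into an
+// equivalent feature-grouped convolution by reshaping and transposing the
+// activations so the group dimension is merged into the feature dimension.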
+HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
+ HloInstruction* conv) {
+ CHECK_EQ(conv->feature_group_count(), 1);
+ int64_t num_groups = conv->batch_group_count();
+ auto dim_numbers = conv->convolution_dimension_numbers();
+ auto lhs = conv->mutable_operand(0);
+ auto rhs = conv->mutable_operand(1);
+
+ int64_t input_batch_dimension = dim_numbers.input_batch_dimension();
+
+ Shape output_shape = conv->shape();
+ int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
+ int64_t input_feature = lhs->shape().dimensions(input_feature_dimension);
+
+ HloComputation* computation = lhs->parent();
+ auto add = [&](std::unique_ptr<HloInstruction> inst) {
+ return computation->AddInstruction(std::move(inst));
+ };
+ // Reshape batch_dim N -> [G, N/G]
+ std::vector<int64_t> reshape_dims = SpanToVector(lhs->shape().dimensions());
+ reshape_dims[input_batch_dimension] =
+ reshape_dims[input_batch_dimension] / num_groups;
+ reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
+ lhs = add(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));
+
+  // Transpose G to the axis before C. For example: [G, N/G, H, W, C] ->
+  // [N/G, H, W, G, C]
+ std::vector<int64_t> transpose_dims(lhs->shape().dimensions_size());
+ std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
+ transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
+ transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
+ input_batch_dimension);
+ std::vector<int64_t> transpose_reshape_dims =
+ ComposePermutations(lhs->shape().dimensions(), transpose_dims);
+ lhs = add(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
+ lhs, transpose_dims));
+
+ // Merge [G,C] -> [C*G]
+ Shape new_shape = lhs->shape();
+ new_shape.DeleteDimension(input_feature_dimension);
+ new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
+ lhs = add(HloInstruction::CreateReshape(new_shape, lhs));
+
+ std::vector<HloInstruction*> new_operands = {lhs, rhs};
+ auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
+ new_conv->set_feature_group_count(num_groups);
+ new_conv->set_batch_group_count(1);
+ new_conv->set_convolution_dimension_numbers(dim_numbers);
+ return computation->AddInstruction(std::move(new_conv));
+}
+
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+ CudnnConvBackendConfig config;
+ config.set_conv_result_scale(1);
+ return config;
+}
+
+// Helper function to create a custom-call instruction to replace the given
+// conv instruction. Returns nullptr if the conv cannot be lowered.
+static absl::StatusOr<HloInstruction*> CreateCustomCallHelper(
+ HloInstruction* conv, const se::GpuComputeCapability& cc) {
+ TF_RETURN_IF_ERROR(CheckTypes(conv, cc));
+ if (ConvolutionMatch m = MatchBackwardInput(conv)) {
+ auto& [window, dnums, rhs] = *m;
+ return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
+ conv->mutable_operand(0), rhs, window, dnums,
+ conv->feature_group_count(), conv->precision_config(),
+ conv->metadata());
+ }
+
+ if (ConvolutionMatch m = MatchBackwardFilter(conv)) {
+ auto& [window, dnums, lhs] = *m;
+ return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
+ conv->mutable_operand(1), window, dnums,
+ conv->batch_group_count(), conv->precision_config(),
+ conv->metadata());
+ }
+
+ // If all else fails, try a forward convolution.
+ if (CanImplementAsGpuForwardConv(conv)) {
+ if (conv->batch_group_count() > 1) {
+ conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
+ }
+
+ return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
+ conv->mutable_operand(0), conv->mutable_operand(1),
+ conv->window(), conv->convolution_dimension_numbers(),
+ conv->feature_group_count(), conv->precision_config(),
+ conv->metadata());
+ }
+
+ return nullptr;
+}
+
+// Tries to rewrite a single convolution into a call to cudnn/miopen.
+absl::StatusOr<bool> RunOnInstruction(HloInstruction* conv,
+ const se::GpuComputeCapability& cc) {
+ CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
+ CreateCustomCallHelper(conv, cc));
+ if (custom_call == nullptr) {
+ return false;
+ }
+
+ GpuBackendConfig gpu_backend_config;
+ *gpu_backend_config.mutable_cudnn_conv_backend_config() =
+ GetDefaultBackendConfig();
+ TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
+
+ VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
+ << custom_call->ToString();
+
+ // The CustomCall returns a tuple (conv_result, scratch_memory). Extract
+ // out the conv result and replace `conv` with it.
+ TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
+ conv,
+ HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
+ return true;
+}
+
+// Rewrites the convolutions in the given computation into calls to
+// cudnn/miopen.
+// Returns true if it made any changes.
+absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
+ const se::GpuComputeCapability& cc) {
+ std::vector<HloInstruction*> convs;
+ for (auto* hlo : computation->instructions()) {
+ if (hlo->opcode() == HloOpcode::kConvolution) {
+ convs.push_back(hlo);
+ }
+ }
+
+ bool changed = false;
+ for (HloInstruction* conv : convs) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc));
+ changed |= result;
+ }
+ return changed;
+}
+} // namespace
+
+absl::StatusOr<bool> ConvRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_VLOG_LINES(2, "ConvRewriter::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result,
+ RunOnComputation(computation, compute_capability_));
+ changed |= result;
+ }
+ XLA_VLOG_LINES(2, "ConvRewriter::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+/*static*/ bool ConvRewriter::ConvIsLowerable(HloInstruction* conv) {
+ return CanImplementAsGpuForwardConv(conv) || MatchBackwardFilter(conv) ||
+ MatchBackwardInput(conv);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h
new file mode 100644
index 0000000..69369f1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites plain convolutions, backwards-filter convolutions, and
+// backwards-input convolutions into CustomCall HLOs that call into
+// Cudnn/Miopen.
+//
+// This pass does not fuse other ops into the convolution. Instead, specific
+// patterns of ops will be matched and fused into the custom call in
+// CudnnFusedConvRewriter.
+
+class ConvRewriter : public HloModulePass {
+ public:
+ explicit ConvRewriter(const se::GpuComputeCapability& compute_capability)
+      : compute_capability_(compute_capability) {}
+
+ absl::string_view name() const override { return "conv-rewriter"; }
+
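+  // Returns true if `conv` can be lowered by this pass, i.e. it matches a
+  // forward, backward-filter, or backward-input convolution pattern.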
+ static bool ConvIsLowerable(HloInstruction* conv);
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::GpuComputeCapability compute_capability_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc
new file mode 100644
index 0000000..d01ffd1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc
@@ -0,0 +1,807 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/log/check.h"
+#include "absl/strings/str_format.h"
+#include "xla/array4d.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/literal_util.h"
+#include "xla/protobuf_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/shape_inference.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/test.h"
+#include "xla/test_helpers.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class ConvRewriterTest : public HloTestBase {
+ public:
+ ConvRewriterTest()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {
+ for (int i = 0; i < 2; ++i) {
+ WindowDimension* window_dim = default_conv_window_.add_dimensions();
+ window_dim->set_size(1);
+ window_dim->set_stride(1);
+ window_dim->set_padding_low(0);
+ window_dim->set_padding_high(0);
+ window_dim->set_window_dilation(1);
+ window_dim->set_base_dilation(1);
+ }
+ // TF data shapes are by default in the NHWC order, and filter shape is by
+ // default in HWIO order. For backward filter convolution, we need to swap
+ // the batch and feature dimension in the activations, and treat the batch
+ // dimension in gradients as the input feature dimension in the filter.
+ //
+ // TODO(jingyue): Add more tests on NCHW input order, which TF also
+ // supports.
+ tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
+ tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
+ tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
+ tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
+ tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
+ 3);
+ tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
+ tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
+ tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
+ tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);
+
+ tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
+ tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
+ tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
+ tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
+ tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
+ tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
+ tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
+ tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
+ tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
+ tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
+ tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
+ tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
+ }
+
+ protected:
+ const se::GpuComputeCapability& GetComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .gpu_compute_capability();
+ }
+
+ bool RunPass(HloModule* module) {
+ return ConvRewriter(GetComputeCapability()).Run(module).value();
+ }
+
+ // A convolution window with stride 1 and zero padding. The size fields are
+ // not set.
+ Window default_conv_window_;
+ ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
+ ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
+};
+
+TEST_F(ConvRewriterTest, BackwardFilterConvolve) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
+ Window conv_window = default_conv_window_;
+ conv_window.mutable_dimensions(1)->set_size(2);
+ conv_window.mutable_dimensions(1)->set_window_dilation(2);
+ auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_,
+ /*preferred_element_type=*/std::nullopt)
+ .value(),
+ activations, gradients, /*feature_group_count=*/1,
+ /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+ OpMetadata metadata;
+ metadata.set_op_name("foo");
+ conv->set_metadata(metadata);
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ ASSERT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+
+ // Check that metadata was preserved.
+ const auto& md_after_opt =
+ entry_computation->root_instruction()->operand(0)->metadata();
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
+ << md_after_opt.DebugString() << " vs " << metadata.DebugString();
+}
+
+TEST_F(ConvRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
+ Window conv_window = default_conv_window_;
+ conv_window.mutable_dimensions(1)->set_size(3);
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_,
+ /*preferred_element_type=*/std::nullopt)
+ .value(),
+ activations, gradients, /*feature_group_count=*/1,
+ /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Extracted from block35 training.
+TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(35);
+ conv_window.mutable_dimensions(i)->set_padding_low(1);
+ conv_window.mutable_dimensions(i)->set_padding_high(1);
+ }
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+}
+
+// Extracted from inception v3 training.
+TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(4);
+ conv_window.mutable_dimensions(i)->set_padding_high(-1);
+ conv_window.mutable_dimensions(i)->set_window_dilation(2);
+ }
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+}
+
+TEST_F(ConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(35);
+ // Uneven padding: padding_low=0, padding_high=1
+ conv_window.mutable_dimensions(i)->set_padding_high(1);
+ }
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+}
+
+TEST_F(ConvRewriterTest, BackwardInputConvolveEvenPadding) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(7);
+ conv_window.mutable_dimensions(i)->set_padding_low(3);
+ conv_window.mutable_dimensions(i)->set_padding_high(3);
+ }
+ ConvolutionDimensionNumbers conv_dnums;
+ conv_dnums.set_input_batch_dimension(0);
+ conv_dnums.set_output_batch_dimension(0);
+ conv_dnums.set_input_feature_dimension(1);
+ conv_dnums.set_output_feature_dimension(1);
+ conv_dnums.add_input_spatial_dimensions(2);
+ conv_dnums.add_output_spatial_dimensions(2);
+ conv_dnums.add_input_spatial_dimensions(3);
+ conv_dnums.add_output_spatial_dimensions(3);
+ conv_dnums.set_kernel_input_feature_dimension(0);
+ conv_dnums.set_kernel_output_feature_dimension(1);
+ conv_dnums.add_kernel_spatial_dimensions(2);
+ conv_dnums.add_kernel_spatial_dimensions(3);
+
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
+ /*rhs=*/reverse_kernel, /*feature_group_count=*/1,
+ /*batch_group_count=*/1, conv_window, conv_dnums,
+ DefaultPrecisionConfig(2)));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ conv_dnums, /*preferred_element_type=*/std::nullopt)
+ .value()));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+
+ ASSERT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+ const HloInstruction* custom_call =
+ entry_computation->root_instruction()->operand(0);
+ for (int i = 0; i < 2; ++i) {
+ const WindowDimension& window_dim = custom_call->window().dimensions(i);
+ // Low padding of the backward input convolution
+ // = kernel_size - 1 - low padding on gradients.
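+    //   Here: kernel_size 7 - 1 - padding_low 3 = 3.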
+ EXPECT_EQ(3, window_dim.padding_low());
+ EXPECT_EQ(3, window_dim.padding_high());
+ EXPECT_EQ(1, window_dim.stride());
+ EXPECT_EQ(1, window_dim.base_dilation());
+ }
+}
+
+// Convolve([abc], [x], base_dilation=2)
+// = Convolve([abc], Reverse([x]), base_dilation=2)
+// = BackwardInputConvolve([abc], [x], stride=2)
+TEST_F(ConvRewriterTest, BackwardInputConvolve1x1Filter) {
+ auto builder = HloComputation::Builder(TestName());
+ // NHWC dimension order.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+ // HWOI dimension order.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
+
+ Window conv_window = default_conv_window_;
+ conv_window.mutable_dimensions(1)->set_base_dilation(2);
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(
+ output->shape(), kernel->shape(),
+ /*feature_group_count=*/1,
+ /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_,
+ /*preferred_element_type=*/std::nullopt)
+ .value(),
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+ /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+}
+
+// BackwardInputConvolve([abc], [x], stride=1) is equivalent to
+// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
+// convolution.
+TEST_F(ConvRewriterTest,
+ BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
+ auto builder = HloComputation::Builder(TestName());
+ // NHWC dimension order.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+ // HWOI dimension order.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(
+ output->shape(), kernel->shape(), /*feature_group_count=*/1,
+ /*batch_group_count=*/1, default_conv_window_,
+ tf_default_dnums_for_backward_input_,
+ /*preferred_element_type=*/std::nullopt)
+ .value(),
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+ /*batch_group_count=*/1, default_conv_window_,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Extracted from Inception V3 training.
+//
+// filter(HWIO)
+// 3x3x192x320
+// |
+// v
+// gradients(NHWC) reverse
+// 20x4x4x320 3x3x192x320
+// \ /
+// \ /
+// conv (NHWC) with padding (low=2,high=3,interior=1)
+// 20x10x10x192
+//
+// Gradients are padded unevenly.
+TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(3);
+ conv_window.mutable_dimensions(i)->set_padding_low(2);
+ conv_window.mutable_dimensions(i)->set_padding_high(3);
+ // Interior padding = 1.
+ conv_window.mutable_dimensions(i)->set_base_dilation(2);
+ }
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, /*batch_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_,
+ /*preferred_element_type=*/std::nullopt)
+ .value()));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ ASSERT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+ const HloInstruction* custom_call =
+ entry_computation->root_instruction()->operand(0);
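+  // Per the formula above, the backward conv's low padding is
+  // kernel_size - 1 - low padding on the gradients = 3 - 1 - 2 = 0, and the
+  // gradients' base dilation of 2 becomes the backward conv's stride.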
+ for (int i = 0; i < 2; ++i) {
+ const WindowDimension& window_dim = custom_call->window().dimensions(i);
+ EXPECT_EQ(0, window_dim.padding_low());
+ EXPECT_EQ(0, window_dim.padding_high());
+ EXPECT_EQ(2, window_dim.stride());
+ EXPECT_EQ(1, window_dim.base_dilation());
+ }
+}
+
+// Similar to BackwardInputConvolveUnevenPaddingOnGradients, but the low padding
+// of the gradients (3) exceeds kernel_size - 1 (= 2). Therefore, this pattern
+// cannot be fused.
+TEST_F(ConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(3);
+ conv_window.mutable_dimensions(i)->set_padding_low(3);
+ conv_window.mutable_dimensions(i)->set_padding_high(2);
+ conv_window.mutable_dimensions(i)->set_base_dilation(2);
+ }
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, /*batch_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_,
+ /*preferred_element_type=*/std::nullopt)
+ .value()));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Extracted from Resnet-50.
+//
+// For simplicity, we focus on the column dimension and ignore other dimensions.
+// We use [?] to represent the shape instead of the content.
+//
+// Suppose operator FC does
+// [4] = conv([14], [3], stride=2, padding_high=1) // Padding::kSame
+//
+// BC = BackwardInput(FC) does:
+// [14] = conv([7], reverse([3]),
+// padding_low=2, padding_high=1, base_dilation=2)
+//
+// We should fuse BC even though padding on activations is uneven, because
+// GpuConvPaddingLegalization will canonicalize the fusion HLO.
+TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
+ auto builder = HloComputation::Builder(TestName());
+ // The gradients are in NCHW layout.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
+ // The kernel is in HWIO layout.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
+ forward_conv_col_dim->set_size(3);
+ forward_conv_col_dim->set_padding_low(2);
+ forward_conv_col_dim->set_padding_high(1);
+ forward_conv_col_dim->set_base_dilation(2);
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, /*batch_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_,
+ /*preferred_element_type=*/std::nullopt)
+ .value()));
+
+ auto module = CreateNewVerifiedModule();
+ const HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ ASSERT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+ const WindowDimension& backward_conv_col_dim =
+ entry_computation->root_instruction()->operand(0)->window().dimensions(1);
+ EXPECT_EQ(0, backward_conv_col_dim.padding_low());
+ EXPECT_EQ(1, backward_conv_col_dim.padding_high());
+}
+
+// For simplicity, we focus on the column dimension and ignore other dimensions.
+// We use [?] to represent the shape instead of the content.
+//
+// Suppose operator FC does
+// [3] = conv([4], [2], padding_low=1, padding_high=-1)
+//
+// BC = BackwardInput(FC) does:
+// [4] = conv([3], reverse([2]), padding_high=2)
+//
+// We currently don't fuse BC because GpuConvPaddingLegalization
+// doesn't support negative padding on the gradients of backward convolution
+// (b/32744257).
+TEST_F(ConvRewriterTest,
+ BackwardInputConvolveNegativePaddingHighOnActivations) {
+ auto builder = HloComputation::Builder(TestName());
+ // The gradients are in NCHW layout.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+ // The kernel is in HWIO layout.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
+ forward_conv_col_dim->set_size(2);
+ forward_conv_col_dim->set_padding_high(2);
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
+ /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, /*batch_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_,
+ /*preferred_element_type=*/std::nullopt)
+ .value()));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_THAT(entry_computation->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Check that we will materialize a reversed version of a constant in order to
+// pattern-match a backwards input convolution.
+TEST_F(ConvRewriterTest, BackwardInputConvolveConstantFilter) {
+ Array4D<float> constant_arr(4, 4, 2, 2);
+ constant_arr.FillIota(0);
+ std::string constant_str =
+ LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape();
+
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule test
+
+ ENTRY entry_computation {
+ param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
+ constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
+ ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
+ window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
+ dim_labels=bf01_01oi->bf01, feature_group_count=1
+ })",
+ constant_str);
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ EXPECT_TRUE(RunPass(m.get()));
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardInputCallTarget},
+ m::Parameter(), m::Reverse(m::Constant())),
+ 0)));
+}
+
+TEST_F(ConvRewriterTest, TestBackwardFilterPatternMatch) {
+ // All filter dimensions are larger than the corresponding output dimensions.
+ // This must be a backward filter convolution.
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f32[8,120,256,256] parameter(0)
+ filter = f32[8,120,256,256] parameter(1)
+
+ ROOT conv = f32[120,120,3,3] convolution(input, filter), window={size=256x256 pad=1_1x1_1}, dim_labels=fb01_io01->fb01
+ })");
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ EXPECT_TRUE(RunPass(m.get()));
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardFilterCallTarget},
+ m::Parameter(0), m::Parameter(1)),
+ 0)));
+}
+
+TEST_F(ConvRewriterTest, TestBackwardFilterPatternNoMatch) {
+ // At least one filter dimension is smaller than the corresponding output
+ // dimension. This must be a forward convolution.
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f32[8,128,2,32] parameter(0)
+ filter = f32[3,3,128,128] parameter(1)
+
+ ROOT conv = f32[8,128,2,32] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+ })");
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ EXPECT_TRUE(RunPass(m.get()));
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvForwardCallTarget}, m::Parameter(0),
+ m::Parameter(1)),
+ 0)));
+}
+
+TEST_F(ConvRewriterTest, TestConv1dBackwardFilterPatternMatch) {
+  // One kernel dimension is equal to the corresponding output dimension;
+  // regard it as a backward filter convolution if the conv is 1D.
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f32[8,256,128] parameter(0)
+ filter = f32[8,254,128] parameter(1)
+ reshape.1 = f32[8,1,256,128] reshape(input)
+ reshape.2 = f32[8,1,254,128] reshape(filter)
+ ROOT conv = f32[1,3,128,128] convolution(reshape.1, reshape.2), window={size=1x254}, dim_labels=f01b_i01o->01bf
+ })");
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ EXPECT_TRUE(RunPass(m.get()));
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardFilterCallTarget},
+ m::Reshape(), m::Reshape()),
+ 0)));
+}
+
+TEST_F(ConvRewriterTest, TestConv1dBackwardInputPatternMatch) {
+  // For conv1d backward input, the filter may be reversed first and then
+  // reshaped.
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f32[8,254,128] parameter(0)
+ filter = f32[3,128,128] parameter(1)
+ reverse = f32[3,128,128] reverse(filter), dimensions={0}
+ reshape.1 = f32[8,1,254,128] reshape(input)
+ reshape.2 = f32[1,3,128,128] reshape(reverse)
+ ROOT conv = f32[8,1,256,128] convolution(reshape.1, reshape.2), window={size=1x3 pad=0_0x2_2}, dim_labels=b01f_01oi->b01f
+ })");
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ EXPECT_TRUE(RunPass(m.get()));
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBackwardInputCallTarget},
+ m::Reshape(), m::Reshape()),
+ 0)));
+}
+
+TEST_F(ConvRewriterTest, TestInvalidTypes) {
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+ ROOT conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ })");
+
+ // Test complex types
+ for (std::string_view type : {"c64", "c128"}) {
+ const std::string module_with_type =
+ absl::StrReplaceAll(module_str, {{"TYPE", type}});
+ TF_ASSERT_OK_AND_ASSIGN(auto m,
+ ParseAndReturnVerifiedModule(module_with_type));
+
+ absl::Status s = ConvRewriter(GetComputeCapability()).Run(m.get()).status();
+ EXPECT_THAT(
+ s, tsl::testing::StatusIs(
+ absl::StatusCode::kUnimplemented,
+ ::testing::HasSubstr("Convolutions must have floating-point or "
+ "integral operands/outputs")));
+ }
+
+ // Test FP8 type on unsupported GPUs
+ std::string module_with_type =
+ absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fn"}});
+ TF_ASSERT_OK_AND_ASSIGN(auto m,
+ ParseAndReturnVerifiedModule(module_with_type));
+ absl::Status s =
+ ConvRewriter(se::CudaComputeCapability::Ampere()).Run(m.get()).status();
+ EXPECT_THAT(s, tsl::testing::StatusIs(
+ absl::StatusCode::kUnimplemented,
+ ::testing::HasSubstr(
+ "FP8 convolutions are only supported on CUDA "
+ "GPUs with compute capability at least 9.0")));
+ s = ConvRewriter(se::RocmComputeCapability{"gfx942"}).Run(m.get()).status();
+ EXPECT_THAT(s, tsl::testing::StatusIs(
+ absl::StatusCode::kUnimplemented,
+ ::testing::HasSubstr(
+ "FP8 convolutions are only supported on CUDA GPUs")));
+
+ // Test unsupported FP8 type
+ module_with_type = absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fnuz"}});
+ TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(module_with_type));
+ s = ConvRewriter(GetComputeCapability()).Run(m.get()).status();
+ EXPECT_THAT(s,
+ tsl::testing::StatusIs(
+ absl::StatusCode::kUnimplemented,
+ ::testing::HasSubstr("The only FP8 types supported in "
+ "convolutions are f8e5m2 and f8e4m3")));
+}
+
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc
new file mode 100644
index 0000000..a7dc96e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc
@@ -0,0 +1,80 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h"
+
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync(
+ HloComputation* computation,
+ absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
+ const {
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> replaced_ops;
+ CollectiveBackendConfig sync_config;
+ sync_config.set_is_sync(true);
+ for (auto& [async_start, async_done] : async_pairs) {
+ // Tag the async start with is_sync = true.
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ async_start->backend_config<GpuBackendConfig>());
+ *gpu_config.mutable_collective_backend_config() = sync_config;
+ TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
+ replaced_ops[async_start] = nullptr;
+ replaced_ops[async_done] = async_start;
+ }
+
+ // Update schedule.
+ HloModule* module = computation->parent();
+ const HloInstructionSequence& sequence =
+ module->schedule().sequence(computation);
+ std::vector<HloInstruction*> new_sequence;
+ new_sequence.reserve(sequence.size());
+ for (HloInstruction* instr : sequence.instructions()) {
+ auto it = replaced_ops.find(instr);
+    // If it's not a start or a done op, add it to the new schedule.
+ if (it == replaced_ops.end()) {
+ new_sequence.push_back(instr);
+ continue;
+ }
+
+    // If it's a start op, do not add it to the schedule yet.
+ if (it->second == nullptr) {
+ continue;
+ }
+
+    // It's a done op. First add the start and then the done.
+ new_sequence.push_back(it->second);
+ new_sequence.push_back(instr);
+ }
+ module->schedule().set_sequence(computation, new_sequence);
+ return absl::OkStatus();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h
new file mode 100644
index 0000000..6507080
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h
@@ -0,0 +1,47 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
+
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/convert_async_collectives_to_sync.h"
+
+namespace xla {
+namespace gpu {
+
+class GpuConvertAsyncCollectivesToSync : public ConvertAsyncCollectivesToSync {
+ public:
+ using ConvertAsyncCollectivesToSync::ConvertAsyncCollectivesToSync;
+ absl::string_view name() const override {
+ return "gpu-convert-async-collectives-to-sync";
+ }
+
+ absl::Status ConvertAsyncInstructionsToSync(
+ HloComputation* computation,
+ absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
+ const override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc
new file mode 100644
index 0000000..d38ab70
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc
@@ -0,0 +1,347 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h"
+
+#include <string_view>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::IsFalse;
+using ::testing::IsTrue;
+
+// Note: The pass only processes modules that are already scheduled. If a test
+// does not behave as expected, check that "is_scheduled=true" is set in the
+// HLO module string.
+class GpuConvertAsyncCollectivesToSyncTest : public HloTestBase {
+ public:
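+  // Runs the pass on `module` and verifies that it reports `expect_change`;
+  // `is_nop` is forwarded to GpuConvertAsyncCollectivesToSync.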
+ absl::Status RunPass(HloModule *module, bool expect_change,
+ HloPredicate is_nop = {}) {
+ TF_ASSIGN_OR_RETURN(bool changed,
+ GpuConvertAsyncCollectivesToSync{is_nop}.Run(module));
+ EXPECT_EQ(changed, expect_change);
+ return absl::OkStatus();
+ }
+
+ // Returns true if the instruction with the given name is synchronous.
+ bool IsSync(HloModule *module, std::string_view name) {
+ const HloInstruction *inst = FindInstruction(module, name);
+ if (inst == nullptr) {
+ return false;
+ }
+ auto backend_config = inst->backend_config<GpuBackendConfig>()
+ .value()
+ .collective_backend_config();
+ return backend_config.is_sync();
+ }
+
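+  // Predicate treating bitcast, get-tuple-element, and parameter ops as no-ops
+  // between a collective start and its matching done.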
+ HloPredicate is_nop_simple_ =
+ HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kGetTupleElement,
+ HloOpcode::kParameter>;
+};
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduce) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ apply_op {
+ x = u32[] parameter(0)
+ y = u32[] parameter(1)
+ ROOT apply_op = u32[] add(x, y)
+ }
+
+ ENTRY test_computation {
+ id = u32[] replica-id()
+ start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+ ROOT done = u32[] all-reduce-done(start)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNop) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ apply_op {
+ x = u32[] parameter(0)
+ y = u32[] parameter(1)
+ ROOT apply_op = u32[] add(x, y)
+ }
+
+ ENTRY test_computation {
+ id = u32[] replica-id()
+ start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3, replica_groups={{0,1}, {2,3}}
+ id2 = f32[] bitcast(id)
+ ROOT done = u32[] all-reduce-done(start)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true, is_nop_simple_));
+ EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectiveBroadcast) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ collective_broadcast {
+ p0 = u32[8] parameter(0)
+ ROOT result = u32[8] collective-broadcast(p0), replica_groups={{0,1}, {2,3}}
+ }
+
+ ENTRY main {
+ data = u32[8] parameter(0)
+ cb-start = ((u32[8]{0}), u32[8]{0}) async-start(u32[8]{0} %data), calls=collective_broadcast
+ ROOT %ars = u32[8]{0} async-done(((u32[8]{0}), u32[8]{0}) %cb-start), calls=collective_broadcast
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "cb-start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNonNop) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ apply_op {
+ x = u32[] parameter(0)
+ y = u32[] parameter(1)
+ ROOT apply_op = u32[] add(x, y)
+ }
+
+ ENTRY test_computation {
+ id = u32[] replica-id()
+ start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+ id2 = u32[] add(id, id)
+ ROOT done = u32[] all-reduce-done(start)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/false));
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllGather) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+ ENTRY test_computation {
+ a1 = u32[1, 2] parameter(0)
+ ags = (u32[1, 2], u32[2, 2]) all-gather-start(a1), dimensions={0}, channel_id=3
+ ROOT allgather = u32[2,2] all-gather-done(ags)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "ags"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectivePermute) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ ENTRY test_computation {
+ p = u32[2] parameter(0)
+ start = (u32[2], u32[2], u32[], u32[]) collective-permute-start(p), source_target_pairs={{0,1}, {1,0}}
+ ROOT done = u32[2] collective-permute-done(start)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleReduceScatter) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ add {
+ lhs = u32[] parameter(0)
+ rhs = u32[] parameter(1)
+ ROOT add = u32[] add(lhs, rhs)
+ }
+
+ reduce_scatter {
+ p0 = u32[8] parameter(0)
+ ROOT result = u32[4] reduce-scatter(p0), replica_groups={{0,3}, {1,2}},
+ dimensions={0}, to_apply=add
+ }
+
+ ENTRY main {
+ data = u32[8] parameter(0)
+ rs-start = ((u32[8]{0}), u32[4]{0}) async-start(u32[8]{0} %data), calls=reduce_scatter
+ ROOT %ars = u32[4]{0} async-done(((u32[8]{0}), u32[4]{0}) %rs-start), calls=reduce_scatter
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "rs-start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllToAll) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ all_to_all {
+ p0 = u32[2] parameter(0)
+ ROOT result = u32[2] all-to-all(p0), dimensions={0}, replica_groups={{0,1},{2,3}}
+ }
+
+ ENTRY test_computation {
+ a1 = u32[2] parameter(0)
+ a2a-start = ((u32[2]), u32[2]) async-start(u32[2] a1), calls=all_to_all
+ ROOT a2s = u32[2] async-done(a2a-start), calls=all_to_all
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "a2a-start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, ControlDeps) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ apply_op {
+ x = u32[] parameter(0)
+ y = u32[] parameter(1)
+ ROOT apply_op = u32[] add(x, y)
+ }
+
+ ENTRY test_computation {
+ id = u32[] replica-id()
+ start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+ done1 = u32[] all-reduce-done(start1)
+ start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4, control-predecessors={done1}
+ done2 = u32[] all-reduce-done(start2)
+ ROOT x = u32[] add(done1, done2)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
+ EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+// Test multiple in-flight collectives that are ordered in a streaming fashion:
+// i.e., ends are in start order (FIFO).
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightStreaming) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ apply_op {
+ x = u32[] parameter(0)
+ y = u32[] parameter(1)
+ ROOT apply_op = u32[] add(x, y)
+ }
+
+ ENTRY test_computation {
+ id = u32[] replica-id()
+ start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+ start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
+ done1 = u32[] all-reduce-done(start1)
+ done2 = u32[] all-reduce-done(start2)
+ ROOT x = u32[] add(done1, done2)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
+ EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0}
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNested) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ apply_op {
+ x = u32[] parameter(0)
+ y = u32[] parameter(1)
+ ROOT apply_op = u32[] add(x, y)
+ }
+
+ ENTRY test_computation {
+ id = u32[] replica-id()
+ start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+ start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
+ done2 = u32[] all-reduce-done(start2)
+ done1 = u32[] all-reduce-done(start1)
+ ROOT x = u32[] add(done1, done2)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
+ EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0} where
+// inner pair can be converted but not outer.
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNestedPartial) {
+ const absl::string_view hlo_string = R"(
+ HloModule test, is_scheduled=true
+
+ apply_op {
+ x = u32[] parameter(0)
+ y = u32[] parameter(1)
+ ROOT apply_op = u32[] add(x, y)
+ }
+
+ ENTRY test_computation {
+ id = u32[] replica-id()
+ start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+ start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
+ done2 = u32[] all-reduce-done(start2)
+ id2 = u32[] add(done2, done2)
+ done1 = u32[] all-reduce-done(start1)
+ ROOT x = u32[] add(done1, done2)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+ EXPECT_THAT(IsSync(module.get(), "start1"), IsFalse());
+ EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc
new file mode 100644
index 0000000..eb43ca2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc
@@ -0,0 +1,197 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/copy_fusion.h"
+
+#include <cstdint>
+#include <queue>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+bool OnlyElementwiseOpsReachableFromParams(HloComputation* fused_computation) {
+ std::queue<const HloInstruction*> q;
+ absl::flat_hash_set<const HloInstruction*> visited;
+ for (auto param : fused_computation->parameter_instructions()) {
+ q.push(param);
+ visited.insert(param);
+ }
+ while (!q.empty()) {
+ const HloInstruction* hlo = q.front();
+ q.pop();
+ for (auto user : hlo->users()) {
+ if ((!user->IsElementwiseOnOperand(user->operand_index(hlo)) ||
+ user->opcode() == HloOpcode::kCopy) &&
+ user->opcode() != HloOpcode::kBitcast &&
+ user->opcode() != HloOpcode::kTuple) {
+ return false;
+ }
+ if (visited.insert(user).second) {
+ q.push(user);
+ }
+ }
+ }
+ return true;
+}
+
+absl::StatusOr<bool> CopyFusion::DoCopyFusion(HloComputation* computation) {
+ bool changed = false;
+ std::vector<HloInstruction*> defs_before_uses =
+ computation->MakeInstructionPostOrder();
+
+ for (HloInstruction* hlo : defs_before_uses) {
+ if (hlo->opcode() != HloOpcode::kFusion) {
+ continue;
+ }
+ std::vector<HloInstruction*> copies;
+ std::vector<HloInstruction*> other_users;
+ HloComputation* fused_computation = hlo->fused_instructions_computation();
+ if (!OnlyElementwiseOpsReachableFromParams(fused_computation)) {
+ continue;
+ }
+ HloInstruction* root = fused_computation->root_instruction();
+ if (IsReductionFromOrToContiguousDimensions(*root) ||
+ root->opcode() == HloOpcode::kScatter ||
+ (hlo->IsMultiOutputFusion() &&
+ absl::c_all_of(root->operands(), [](const HloInstruction* slice) {
+ return slice->opcode() == HloOpcode::kSlice;
+ }))) {
+ continue;
+ }
+ for (auto user : hlo->users()) {
+ HloInstruction* copy_user = user;
+ // Skip get-tuple-element ops.
+ if (copy_user->opcode() == HloOpcode::kGetTupleElement &&
+ copy_user->user_count() == 1) {
+ if (IsReductionFromOrToContiguousDimensions(
+ *(root->operand(copy_user->tuple_index())))) {
+ other_users.push_back(user);
+ continue;
+ }
+ copy_user = copy_user->users()[0];
+ }
+ // Skip bitcast ops.
+ if (copy_user->opcode() == HloOpcode::kBitcast &&
+ copy_user->user_count() == 1) {
+ copy_user = copy_user->users()[0];
+ }
+ if (copy_user->opcode() == HloOpcode::kCopy &&
+ copy_user->shape() == copy_user->operand(0)->shape() &&
+ !copy_user->shape().IsTuple() &&
+ !copy_user->HasControlDependencies()) {
+ copies.push_back(copy_user);
+ } else {
+ other_users.push_back(user);
+ }
+ }
+ if (copies.empty()) {
+ continue;
+ }
+ auto fusion_adaptor = HloFusionAdaptor::ForComputation(fused_computation);
+ auto dynamic_update_slices =
+ GetOutputDefiningDynamicUpdateSlices(fusion_adaptor->GetRoots());
+ // Skip dynamic update slice fusions which might be emitted in-place.
+ if (!dynamic_update_slices.empty() &&
+ (root->opcode() != HloOpcode::kTuple ||
+ dynamic_update_slices.size() == root->shape().tuple_shapes_size())) {
+ continue;
+ }
+ changed = true;
+
+ HloInstruction::InstructionVector tuple_elements;
+ int64_t num_outputs =
+ hlo->IsMultiOutputFusion() ? root->operand_count() : int64_t{1};
+ tuple_elements.reserve(copies.size() + num_outputs);
+ if (hlo->IsMultiOutputFusion()) {
+ for (HloInstruction* operand : root->operands()) {
+ tuple_elements.push_back(operand);
+ }
+ } else {
+ tuple_elements.push_back(root);
+ }
+
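+    // Clone each copy (and any intervening bitcast) into the fused
+    // computation, rooted at the corresponding fusion output, and record the
+    // clone as an extra tuple element.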
+ for (auto copy : copies) {
+ HloInstruction* user = copy;
+ std::vector<HloInstruction*> operand_chain;
+ operand_chain.push_back(user);
+ while (user->operand(0) != hlo) {
+ user = user->mutable_operand(0);
+ operand_chain.push_back(user);
+ }
+ HloInstruction* clone_operand = root;
+ if (hlo->IsMultiOutputFusion()) {
+ clone_operand = root->mutable_operand(user->tuple_index());
+ CHECK_EQ(operand_chain.back()->opcode(), HloOpcode::kGetTupleElement);
+ operand_chain.pop_back();
+ }
+ for (int64_t i = operand_chain.size() - 1; i >= 0; --i) {
+ HloInstruction* user = operand_chain[i];
+ clone_operand = fused_computation->AddInstruction(
+ user->CloneWithNewOperands(user->shape(), {clone_operand}));
+ }
+ tuple_elements.push_back(clone_operand);
+ }
+
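+    // Install the enlarged tuple as the new fusion root and reroute users
+    // outside the fusion: pre-existing users keep reading the original outputs
+    // (via a get-tuple-element when the fusion used to be single-output), and
+    // each original copy is replaced by a get-tuple-element of its new output.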
+ HloInstruction* new_root = fused_computation->AddInstruction(
+ HloInstruction::CreateTuple(tuple_elements));
+ fused_computation->set_root_instruction(new_root,
+ /*accept_different_shape=*/true);
+ *hlo->mutable_shape() = new_root->shape();
+
+ if (root->opcode() == HloOpcode::kTuple) {
+ TF_RETURN_IF_ERROR(fused_computation->RemoveInstruction(root));
+ } else {
+ auto get_tuple_element_root = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(hlo, 0));
+ TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(
+ other_users, get_tuple_element_root));
+ }
+ for (int64_t i = 0; i < copies.size(); ++i) {
+ auto get_tuple_element = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(hlo, num_outputs + i));
+ TF_RETURN_IF_ERROR(
+ computation->ReplaceInstruction(copies[i], get_tuple_element));
+ }
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> CopyFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  // Only for the entry computation can we be sure that the copies do not share
+  // a buffer with a parameter of the fusion they will be fused with. For
+  // example, while-loop computations have tuple parameters that need to share
+  // buffers with the output tuples, and copies inserted by the CopyInsertion
+  // pass will share a buffer with the tuple output (and thus with the tuple
+  // input as well).
+ return DoCopyFusion(module->entry_computation());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.h b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h
new file mode 100644
index 0000000..a6a1ae4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h
@@ -0,0 +1,49 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// CopyFusion checks if a fusion is followed by multiple copies and if so, adds
+// those copies to the fusion, replacing the copies with get_tuple_elements.
+class CopyFusion : public HloModulePass {
+ public:
+ CopyFusion() = default;
+
+ absl::string_view name() const override { return "copy_fusion"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ absl::StatusOr<bool> DoCopyFusion(HloComputation* computation);
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc
new file mode 100644
index 0000000..1bd2d11
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc
@@ -0,0 +1,500 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/copy_fusion.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+namespace m = ::xla::match;
+
+class CopyFusionTest : public HloTestBase {
+ public:
+ CopyFusion cf_;
+};
+
+const char kModulePrefix[] = R"(
+ HloModule test_module
+
+ scalar_add_computation {
+ scalar_lhs.0 = f32[] parameter(0)
+ scalar_rhs.0 = f32[] parameter(1)
+ ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
+ }
+ scalar_mul_computation {
+ scalar_lhs.1 = f32[] parameter(0)
+ scalar_rhs.1 = f32[] parameter(1)
+ ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
+ })";
+
+TEST_F(CopyFusionTest, CopyFusionTransposeOfBroadcastedConstantTwoCopies) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ two = f32[] constant(2.0)
+ broadcast = f32[16,32]{1,0} broadcast(two), dimensions={}
+ s.1 = f32[16,32]{1,0} sqrt(broadcast)
+ ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+ }
+
+ ENTRY main {
+ fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
+ copy.1 = f32[32,16]{1,0} copy(fusion)
+ copy.2 = f32[32,16]{1,0} copy(fusion)
+ ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement())));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionTransposeTwoCopies) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ param_0.1 = f32[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+ }
+
+ ENTRY main {
+ p = f32[16,32]{1,0} parameter(0)
+ fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
+ copy.1 = f32[32,16]{1,0} copy(fusion)
+ copy.2 = f32[32,16]{1,0} copy(fusion)
+ ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
+ })"))
+ .value();
+ ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopies) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ }
+
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+ copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+ copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+ ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement())));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithReduce) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
+ copy.1 = f32[512]{0} copy(fusion)
+ copy.2 = f32[512]{0} copy(fusion)
+ ROOT root = (f32[512]{0}, f32[512]{0}) tuple(copy.1, copy.2)
+ })"))
+ .value();
+ ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldRunWithUncopiedReduce) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ two = f32[] constant(2.0)
+ broadcast = f32[128,512,28,28]{3,2,1,0} broadcast(two)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(broadcast, broadcast)
+ const = f32[] constant(0.0)
+ reduce = f32[512]{0} reduce(mul, const), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(mul, reduce)
+ }
+
+ ENTRY entry {
+ fusion = (f32[128,512,28,28]{3,2,1,0}, f32[512]) fusion(), kind=kInput, calls=fused_computation
+ gte = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
+ gte.2 = f32[512]{0} get-tuple-element(fusion), index=1
+ copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte)
+ ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(copy.1, gte.2)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement())));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Multiply(), m::Reduce(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotFuseForSliceMultioutputFusion) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
+ slice1 = f32[128,100,28,28]{3,2,1,0} slice(mul), slice={[0:128],[0:100],[0:28],[0:28]}
+ slice2 = f32[128,200,28,28]{3,2,1,0} slice(mul), slice={[0:128],[50:250],[0:28],[0:28]}
+ ROOT tuple = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) tuple(slice1, slice2)
+ }
+
+ ENTRY entry {
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) fusion(p1), kind=kInput, calls=fused_computation
+ })"))
+ .value();
+ ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithScatter) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+ scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
+ input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} negate(p0)
+ ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(input_tensor, scatter_indices, updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=scalar_add_computation
+  }
+
+ ENTRY entry {
+ param.0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+ param.1 = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ param.2 = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
+ fusion = f32[50,49,48,47,46]{4,3,2,1,0} fusion(param.0, param.1, param.2), kind=kInput, calls=fused_computation
+ ROOT copy = f32[50,49,48,47,46]{4,3,2,1,0} copy(fusion)
+ })"))
+ .value();
+ ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunOutsideEntryComputation) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+fused_computation.549 {
+ param_0.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
+ bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.8511)
+ slice = bf16[15,1,2,48,128,1]{5,4,3,2,1,0} slice(bitcast.52601), slice={[0:15:1], [0:1:1], [0:2:1], [0:48:1], [0:128:1], [0:1:1]}
+ bitcast = bf16[15,1,2,48,128]{4,3,2,1,0} bitcast(slice)
+ ROOT broadcast = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} broadcast(bitcast), dimensions={0,1,2,3,4}
+}
+
+condition {
+ constant_6915 = s32[] constant(15)
+ param.218 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
+ get-tuple-element.3714 = s32[] get-tuple-element(param.218), index=1
+ ROOT compare.1738 = pred[] compare(get-tuple-element.3714, constant_6915), direction=LT
+}
+
+body {
+ tuple_param = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
+ param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(tuple_param), index=0
+ param_1 = s32[] get-tuple-element(tuple_param), index=1
+ fusion.549 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation.549
+ bitcast = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(fusion.549)
+ copy = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(bitcast)
+ constant_one = s32[] constant(1)
+ add = s32[] add(param_1, constant_one), control-predecessors={fusion.549}
+ ROOT tuple = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) tuple(copy, add)
+}
+
+ENTRY main {
+ param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
+ zero = s32[] constant(0)
+ copy.0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(param_0)
+ copy.1 = s32[] copy(zero)
+ tuple = tuple(copy.0, copy.1)
+ ROOT while = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) while(tuple), condition=condition, body=body, backend_config="{\"known_trip_count\":{\"n\":\"15\"}}"
+})"))
+ .value();
+ ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithDynamicUpdateSliceInplace) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p.0 = f16[50,96,1024]{2,1,0} parameter(0)
+ p.1 = f16[1,96,1024]{2,1,0} parameter(1)
+ c.0 = s32[3]{0} constant({0, 0, 0})
+ ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+ }
+
+ ENTRY entry {
+ p0 = f16[50,96,1024]{2,1,0} parameter(0)
+ p1 = f16[1,96,1024]{2,1,0} parameter(1)
+ fusion = f16[50,96,1024]{2,1,0} fusion(p0, p1), kind=kInput, calls=fused_computation
+ copy.1 = f16[50,96,1024]{2,1,0} copy(fusion)
+ copy.2 = f16[50,96,1024]{2,1,0} copy(fusion)
+ ROOT root = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy.1, copy.2)
+ })"))
+ .value();
+ ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionWithDynamicUpdateSliceNotInplace) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ one = f32[] constant(1.0)
+ zero = f32[] constant(0.0)
+ p.0 = f16[50,96,1024]{2,1,0} broadcast(one), dimensions={}
+ p.1 = f16[1,96,1024]{2,1,0} broadcast(zero), dimensions={}
+ c.0 = s32[3]{0} constant({0, 0, 0})
+ dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+ neg = f16[50,96,1024]{2,1,0} negate(dynamic-update-slice)
+ ROOT tuple = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(dynamic-update-slice, neg)
+ }
+
+ ENTRY entry {
+ fusion = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) fusion(), kind=kInput, calls=fused_computation
+ gte.0 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=0
+ gte.1 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=1
+ bitcast = f16[1,50,96,1024]{3,2,1,0} bitcast(gte.0)
+ copy = f16[1,50,96,1024]{3,2,1,0} copy(bitcast)
+ ROOT root = (f16[1,50,96,1024]{3,2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy, gte.1)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement())));
+ EXPECT_THAT(
+ fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::DynamicUpdateSlice(), m::Negate(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionTransposeAndThreeCopies) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ two = f32[] constant(2.0)
+ param_0.1 = f32[16,32]{1,0} broadcast(two), dimensions={}
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+ }
+
+ ENTRY entry {
+ fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
+ copy.1 = f32[32,16]{1,0} copy(fusion)
+ copy.2 = f32[32,16]{1,0} copy(fusion)
+ copy.3 = f32[32,16]{1,0} copy(fusion)
+ ROOT root = (f32[32,16]{1,0}, f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.1, copy.2, copy.3)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root,
+ GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement(), m::GetTupleElement())));
+ EXPECT_THAT(
+ fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneCopy) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ }
+
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+ ROOT copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::GetTupleElement(m::Fusion(&fusion))));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Negate(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopiesAndTransposeCopy) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ }
+
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+ copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+ transpose = f32[128,512,28,28]{2,3,0,1} copy(fusion)
+ bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
+ copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+ ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::Bitcast(), m::GetTupleElement())));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneNonTransposeCopy) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ }
+
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+ copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+ transpose.1 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
+ bitcast.1 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.1)
+ transpose.2 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
+ bitcast.2 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.2)
+ ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}) tuple(copy.1, bitcast.1, bitcast.2)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::Bitcast(), m::Bitcast())));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Negate(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionSkipTupleCopies) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
+ copy.1 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
+ copy.2 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
+ ROOT root = ((f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}),(f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0})) tuple(copy.1, copy.2)
+ })"))
+ .value();
+ ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionTupleAndGetTuple) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
+ gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
+ gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
+ copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
+ copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
+ ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement())));
+ EXPECT_THAT(
+ fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionWithFusionReturningTupleAndOtherUser) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+ ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
+ gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
+ gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
+ copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
+ copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
+ transpose = f32[128,512,28,28]{2,3,0,1} copy(gte.1)
+ bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
+ ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
+ })"))
+ .value();
+ ASSERT_TRUE(cf_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root,
+ GmockMatch(m::Tuple(m::Copy(), m::Bitcast(),
+ m::GetTupleElement(m::Fusion(&fusion)))));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy())));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc
new file mode 100644
index 0000000..43f1242
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc
@@ -0,0 +1,210 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+static absl::StatusOr<bool> PadForGemm(HloDotInstruction* dot,
+ PrimitiveType datatype,
+ int pad_to_multiple_of) {
+ auto* lhs = dot->mutable_operand(0);
+ auto* rhs = dot->mutable_operand(1);
+
+ Shape lshape = lhs->shape();
+ Shape rshape = rhs->shape();
+ Shape result_shape = dot->shape();
+
+ if (lshape.element_type() != datatype || rshape.element_type() != datatype) {
+ return false;
+ }
+
+ auto pad_dim = [&](Shape& s, int dim) {
+ s.set_dimensions(dim,
+ RoundUpTo<int64_t>(s.dimensions(dim), pad_to_multiple_of));
+ };
+
+ auto pad_matrix_dims = [&pad_dim](Shape s) {
+ // Since the dot instruction is canonicalized, the last two dimensions for
+ // each operand represent non-batch dimensions, and the others are the same
+ // for both operands and correspond to batch dimensions.
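+    // For example, with pad_to_multiple_of = 8 a shape like
+    // f16[3,5,2048,33708] becomes f16[3,5,2048,33712], while the batch
+    // dimensions {0,1} stay untouched.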
+ pad_dim(s, s.rank() - 2);
+ pad_dim(s, s.rank() - 1);
+ return s;
+ };
+
+ Shape new_lshape = pad_matrix_dims(lshape);
+ Shape new_rshape = pad_matrix_dims(rshape);
+ Shape new_result_shape = pad_matrix_dims(result_shape);
+
+ if (new_lshape == lshape && new_rshape == rshape) {
+ return false;
+ }
+
+ VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
+ VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
+ << new_result_shape;
+
+ auto create_padding_config = [](Shape& shape, Shape& new_shape) {
+ PaddingConfig padding_config;
+ for (int i = 0; i < shape.rank(); ++i) {
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_high(new_shape.dimensions()[i] -
+ shape.dimensions()[i]);
+ dimension->set_edge_padding_low(0);
+ dimension->set_interior_padding(0);
+ }
+ return padding_config;
+ };
+
+ auto l_padding_config = create_padding_config(lshape, new_lshape);
+ auto r_padding_config = create_padding_config(rshape, new_rshape);
+
+ HloComputation* parent = dot->parent();
+
+ HloInstruction* zero_float = parent->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(datatype)));
+ zero_float->set_metadata(dot->metadata());
+
+ HloInstruction* lpad = parent->AddInstruction(
+ HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
+ lpad->set_metadata(dot->metadata());
+
+ HloInstruction* rpad = parent->AddInstruction(
+ HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
+ rpad->set_metadata(dot->metadata());
+
+ HloInstruction* new_dot = parent->AddInstruction(
+ dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));
+
+ std::vector<int64_t> start_indices(result_shape.rank(), 0);
+ std::vector<int64_t> strides(result_shape.rank(), 1);
+ HloInstruction* slice = parent->AddInstruction(
+ HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
+ result_shape.dimensions(), strides));
+ slice->set_metadata(dot->metadata());
+
+ bool is_root = dot->user_count() == 0;
+
+ TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));
+
+ if (is_root) {
+ parent->set_root_instruction(slice);
+ }
+
+ return true;
+}
+
+namespace {
+
+// We need this check because PadForGemm works on the assumption that
+// the dot instruction is canonicalized.
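+// For a rank-4 dot this means batch dimensions {0, 1} on both operands, with
+// only the two trailing dimensions of each operand being non-batch dimensions.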
+bool CheckCanonical(HloDotInstruction* dot) {
+ const auto& dimension_numbers = dot->dot_dimension_numbers();
+
+ if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
+ dot->operand(0)->shape().rank() ||
+ dimension_numbers.rhs_batch_dimensions_size() + 2 !=
+ dot->operand(1)->shape().rank()) {
+ VLOG(2)
+ << dot->ToString()
+ << " is not canonical: Expected all dimensions but 2 to be "
+ "batch_dimensions. Hence, this dot is not a candidate for padding.";
+ return false;
+ }
+
+ std::vector<int64_t> canonical_batch_dims(
+ dimension_numbers.lhs_batch_dimensions_size());
+ absl::c_iota(canonical_batch_dims, 0);
+ if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
+ canonical_batch_dims) ||
+ !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
+ canonical_batch_dims)) {
+ VLOG(2)
+ << dot->ToString()
+ << " is not canonical: Expected batch dimensions to be all "
+ "dimensions except for the last 2 ones. Hence, this dot is not a "
+ "candidate for padding.";
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
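+// Returns the dot instructions in `comp` that are candidates for padding:
+// canonical dots of `datatype` that are not expected to be handled by the
+// Triton GEMM rewriter (and will therefore be lowered to cuBLAS).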
+static std::vector<HloDotInstruction*> GetRelevantDots(
+ const se::GpuComputeCapability& gpu_compute_capability,
+ HloComputation* comp, PrimitiveType datatype) {
+ std::vector<HloDotInstruction*> gemms;
+
+ for (HloInstruction* instr : comp->instructions()) {
+ if (IsMatrixMultiplication(*instr)) {
+ HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
+ if (instr->operand(0)->shape().element_type() == datatype &&
+ CheckCanonical(dot) &&
+ !(instr->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_enable_triton_gemm() &&
+ legacy_triton::IsTritonSupportedInstruction(
+ *dot, gpu_compute_capability) &&
+ ShouldTritonHandleGEMM(*dot, gpu_compute_capability))) {
+ gemms.push_back(dot);
+ }
+ }
+ }
+ return gemms;
+}
+
+absl::StatusOr<bool> CublasPadForGemms::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloDotInstruction* dot :
+ GetRelevantDots(gpu_compute_capability_, comp, datatype_)) {
+ TF_ASSIGN_OR_RETURN(bool result,
+ PadForGemm(dot, datatype_, pad_to_multiple_of_));
+ changed |= result;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h
new file mode 100644
index 0000000..8c7d8e5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h
@@ -0,0 +1,63 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Adds padding to dot operations to make them run faster on GPUs.
+//
+// This can be used to pad f16 dots so they can use tensor cores, or to pad s8
+// dots to multiples of four.
+//
+// This pass depends on the xla::DotDecomposer pass, so it must run strictly
+// after it.
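+//
+// A minimal usage sketch (mirroring the unit tests; the compute capability
+// chosen here is illustrative):
+//
+//   bool changed = CublasPadForGemms(se::CudaComputeCapability(7, 0),
+//                                    PrimitiveType::F16,
+//                                    /*pad_to_multiple_of=*/8)
+//                      .Run(module)
+//                      .value();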
+class CublasPadForGemms : public HloModulePass {
+ public:
+ CublasPadForGemms(const se::GpuComputeCapability gpu_compute_capability,
+ PrimitiveType datatype, int32_t pad_to_multiple_of)
+ : gpu_compute_capability_(gpu_compute_capability),
+ datatype_(datatype),
+ pad_to_multiple_of_(pad_to_multiple_of) {}
+
+ absl::string_view name() const override { return "cublas-pad-for-gemms"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const se::GpuComputeCapability gpu_compute_capability_;
+ PrimitiveType datatype_;
+ int32_t pad_to_multiple_of_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc
new file mode 100644
index 0000000..77a32c9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc
@@ -0,0 +1,306 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace m = ::xla::match;
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CublasGemmPadForTensorCoresTest : public HloTestBase {
+ protected:
+ bool PadForF16Gemms(HloModule* module) {
+ return CublasPadForGemms(se::CudaComputeCapability(7, 0),
+ PrimitiveType::F16, 8)
+ .Run(module)
+ .value();
+ }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+ // Some pads would not be added if we detect that Triton will handle the
+ // given dot operation.
+ debug_options.set_xla_gpu_triton_gemm_any(false);
+ return debug_options;
+ }
+};
+
+TEST_F(CublasGemmPadForTensorCoresTest, OneDotRootComputation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f16[2048,1024] parameter(0)
+ %param2 = f16[1024,33708] parameter(1)
+ ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
+ f16[1024,33708]{0,1} %param2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ })")
+ .value();
+
+ EXPECT_TRUE(PadForF16Gemms(module.get()));
+ SCOPED_TRACE(module->ToString());
+
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(
+ m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {2048, 1024}),
+ m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {1024, 33712}))
+ .WithShape(F16, {2048, 33712})
+ .WithContractingDims(/*lhs_contracting_dims=*/{1},
+ /*rhs_contracting_dims=*/{0}))
+ .WithShape(F16, {2048, 33708})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, OneDotS8RootComputation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = s8[2047,1023] parameter(0)
+ %param2 = s8[1023,33707] parameter(1)
+ ROOT %dot.2309 = s32[2047,33707]{1,0} dot(s8[2047,1023]{1,0} %param1,
+ s8[1023,33707]{0,1} %param2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ })")
+ .value();
+
+ EXPECT_TRUE(
+ CublasPadForGemms(se::CudaComputeCapability(7, 0), PrimitiveType::S8, 4)
+ .Run(module.get())
+ .value());
+ SCOPED_TRACE(module->ToString());
+
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(
+ m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(S8, {2047, 1023}),
+ m::Constant().WithShape(S8, {}))
+ .WithShape(S8, {2048, 1024}),
+ m::Pad(m::Parameter().WithShape(S8, {1023, 33707}),
+ m::Constant().WithShape(S8, {}))
+ .WithShape(S8, {1024, 33708}))
+ .WithShape(S32, {2048, 33708})
+ .WithContractingDims(/*lhs_contracting_dims=*/{1},
+ /*rhs_contracting_dims=*/{0}))
+ .WithShape(S32, {2047, 33707})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, TwoDotsComputation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f16[2048, 1024] parameter(0)
+ %param2 = f16[1024, 33708] parameter(1)
+ %param3 = f16[33708, 1] parameter(2)
+ %dot1 = f16[2048, 33708]{1,0} dot(f16[2048, 1024]{1,0} %param1,
+ f16[1024, 33708]{0,1} %param2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT %dot2 = f16[2048, 1]{1,0} dot(f16[2048, 33708]{1,0} %dot1,
+ f16[33708, 1]{0,1} %param3),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ })")
+ .value();
+
+ EXPECT_TRUE(PadForF16Gemms(module.get()));
+ SCOPED_TRACE(module->ToString());
+
+ auto* root = module->entry_computation()->root_instruction();
+ const HloInstruction* dot2 = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(
+ m::Slice(
+ m::Dot(
+ m::Pad(m::Slice(m::Dot(&dot2,
+ m::Pad().WithShape(F16, {2048, 1024}),
+ m::Pad().WithShape(F16, {1024, 33712}))
+ .WithContractingDims(
+ /*lhs_contracting_dims=*/{1},
+ /*rhs_contracting_dims=*/{0})
+ .WithShape(F16, {2048, 33712}))
+ .WithShape(F16, {2048, 33708}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {2048, 33712}),
+
+ m::Pad(m::Parameter().WithShape(F16, {33708, 1}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {33712, 8}))
+ .WithShape(F16, {2048, 8})
+ .WithContractingDims(/*lhs_contracting_dims=*/{1},
+ /*rhs_contracting_dims=*/{0}))
+ .WithShape(F16, {2048, 1})));
+
+ EXPECT_THAT(
+ dot2,
+ GmockMatch(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {2048, 1024}),
+ m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {1024, 33712}))
+ .WithContractingDims(/*lhs_contracting_dims=*/{1},
+ /*rhs_contracting_dims=*/{0})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, DotWithBatchDimensions) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f16[3, 5, 2048, 1024] parameter(0)
+ %param2 = f16[3, 5, 1024, 33708] parameter(1)
+ ROOT %dot.2309 = f16[3, 5, 2048, 33708]{3, 2, 1,0} dot(f16[3, 5, 2048, 1024]{3, 2, 1,0} %param1,
+ f16[3, 5, 1024, 33708]{2, 3, 0,1} %param2), lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1}, lhs_contracting_dims={3}, rhs_contracting_dims={2}})")
+ .value();
+
+ EXPECT_TRUE(PadForF16Gemms(module.get()));
+ SCOPED_TRACE(module->ToString());
+
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(
+ m::Slice(
+ m::Dot(m::Pad(m::Parameter().WithShape(F16, {3, 5, 2048, 1024}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {3, 5, 2048, 1024}),
+ m::Pad(m::Parameter().WithShape(F16, {3, 5, 1024, 33708}),
+ m::Constant().WithShape(F16, {}))
+ .WithShape(F16, {3, 5, 1024, 33712}))
+ .WithShape(F16, {3, 5, 2048, 33712})
+ .WithContractingDims(/*lhs_contracting_dims=*/{3},
+ /*rhs_contracting_dims=*/{2}))
+ .WithShape(F16, {3, 5, 2048, 33708})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, NoDotComputation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %maximum = f32[] maximum(f32[] %x, f32[] %y)
+ })")
+ .value();
+
+ EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, F32DotComputation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f32[2048,1024] parameter(0)
+ %param2 = f32[1024,33708] parameter(1)
+ ROOT %dot.2309 = f32[2048,33708]{1,0} dot(f32[2048,1024]{1,0} %param1,
+ f32[1024,33708]{0,1} %param2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
+ .value();
+
+ EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, F64DotComputation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f64[2048,1024] parameter(0)
+ %param2 = f64[1024,33708] parameter(1)
+ ROOT %dot.2309 = f64[2048,33708]{1,0} dot(f64[2048,1024]{1,0} %param1,
+ f64[1024,33708]{0,1} %param2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
+ .value();
+
+ EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, MultiplesOf8DotComputation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f16[2048,1024] parameter(0)
+ %param2 = f16[1024,33712] parameter(1)
+ ROOT %dot.2309 = f16[2048,33712]{1,0} dot(f16[2048,1024]{1,0} %param1,
+ f16[1024,33712]{0,1} %param2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
+ .value();
+
+ EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, CheckSavingMetadata) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f16[2048,1024] parameter(0)
+ %param2 = f16[1024,33708] parameter(1)
+ ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
+ f16[1024,33708]{0,1} %param2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0},
+ metadata={op_type="MatMul" op_name="transformer_v2/Transformer/decode/embedding_shared_weights_1/presoftmax_linear/MatMul"}
+ })")
+ .value();
+
+ SCOPED_TRACE(module->ToString());
+
+ EXPECT_TRUE(PadForF16Gemms(module.get()));
+ auto metadata = module->entry_computation()->root_instruction()->metadata();
+ EXPECT_EQ("MatMul", metadata.op_type());
+ EXPECT_EQ(
+ "transformer_v2/Transformer/decode/embedding_shared_weights_1/"
+ "presoftmax_linear/MatMul",
+ metadata.op_name());
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, NotCanonicalizedDot) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ %param1 = f16[3, 5, 2048, 1024] parameter(0)
+ %param2 = f16[3, 5, 1024, 33708] parameter(1)
+ ROOT %dot.2309 = f16[3,2048, 33708]{2, 1, 0} dot(f16[3, 5, 2048, 1024]{3, 2, 1, 0} %param1, f16[3, 5, 1024, 33708]{3, 2, 1, 0} %param2), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={3, 1}, rhs_contracting_dims={2, 1}})")
+ .value();
+
+ EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
new file mode 100644
index 0000000..00b73c9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -0,0 +1,660 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h"
+
+#include <optional>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/stream_executor/cuda/cuda_dnn.h"
+#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+inline absl::StatusOr<CudnnfMHAMaskKind> AsCudnnFmhaMaskKind(
+ CudnnfMHABackendConfig_MaskType mask_type) {
+ switch (mask_type) {
+ case CudnnfMHABackendConfig::NO_MASK:
+ return CudnnfMHAMaskKind::kNoMask;
+ case CudnnfMHABackendConfig::PADDING:
+ return CudnnfMHAMaskKind::kPadding;
+ case CudnnfMHABackendConfig::CAUSAL:
+ return CudnnfMHAMaskKind::kCausal;
+ case CudnnfMHABackendConfig::PADDING_CAUSAL:
+ return CudnnfMHAMaskKind::kPaddingCausal;
+ case CudnnfMHABackendConfig::ALIBI:
+ return CudnnfMHAMaskKind::kAlibi;
+ default:
+ return xla::Internal("Unknown fmha mask kind.");
+ }
+}
+
+// Interim structure that holds the parameters needed to construct a
+// GpufMHAConfig. It describes the properties of an FMHA without being tied to
+// a specific IR, and is used to help build FMHA thunks from either XLA HLO or
+// the LHLO GPU dialect in MLIR.
+struct GpufMHADescriptor {
+ CudnnfMHAKind kind;
+ CudnnfMHABackendConfig backend_config;
+ CudnnfMHAMaskKind mask_type;
+ Shape lhs_bmm1_shape;
+ Shape rhs_bmm1_shape;
+ Shape rhs_bmm2_shape;
+ Shape intermediate_lhs_bmm2_shape;
+ // This will contain both output shape and activation shape
+ absl::InlinedVector<Shape, 2> output_shapes;
+ DotDimensionNumbers bmm1_dnums;
+ DotDimensionNumbers bmm2_dnums;
+
+ std::optional<Shape> mask_shape;
+ std::optional<Shape> bias_shape;
+};
+
+struct GpufMHABackwardDescriptor {
+ CudnnfMHAKind kind;
+ CudnnfMHABackendConfig backend_config;
+ CudnnfMHAMaskKind mask_type;
+ Shape bmm1_grad_gemm1_rhs_shape;
+ Shape bmm1_grad_gemm2_rhs_shape;
+ Shape bmm2_grad_gemm1_lhs_shape;
+ Shape bmm2_grad_gemm2_rhs_shape;
+ Shape d_output_shape;
+ Shape d_bmm1_lhs_shape;
+ Shape d_bmm1_rhs_shape;
+ Shape d_bmm2_rhs_shape;
+ DotDimensionNumbers bmm1_grad_gemm1_dnums;
+ DotDimensionNumbers bmm1_grad_gemm2_dnums;
+ DotDimensionNumbers bmm2_grad_gemm1_dnums;
+ DotDimensionNumbers bmm2_grad_gemm2_dnums;
+
+ std::optional<Shape> d_s_shape;
+ std::optional<Shape> fwd_output_shape;
+ std::optional<Shape> mask_shape;
+ std::optional<Shape> d_bias_shape;
+ std::optional<Shape> bias_shape;
+ bool force_deterministic;
+};
+
+// Structure to describe static properties of a GPU fused Multi-Headed
+// Attention.
+struct GpufMHAConfig {
+ static absl::StatusOr<GpufMHAConfig> For(const GpufMHADescriptor &fmha_desc);
+ PrimitiveType
+ input_type; // Capture the primitive type of one of the inputs of BMM1
+ PrimitiveType output_type;
+ CudnnfMHAKind kind;
+ std::optional<double> fmha_scale;
+ std::optional<double> dropout_rate;
+ std::optional<int64_t> seed;
+
+ se::dnn::AlgorithmDesc algorithm;
+ CudnnfMHAMaskKind mask_type;
+ // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len]
+ // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
+ se::dnn::MatmulTensorDescriptor lhs_bmm1;
+ se::dnn::MatmulTensorDescriptor rhs_bmm1;
+ se::dnn::MatmulTensorDescriptor rhs_bmm2;
+ se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2;
+ se::dnn::TensorDescriptor output;
+
+ std::optional<se::dnn::TensorDescriptor> activation;
+ std::optional<se::dnn::TensorDescriptor> mask;
+ std::optional<se::dnn::TensorDescriptor> bias;
+};
+
+// Structure to describe static properties of a GPU fused Multi-Headed
+// Attention backward.
+struct GpufMHABackwardConfig {
+ static absl::StatusOr<GpufMHABackwardConfig> For(
+ const GpufMHABackwardDescriptor &fmha_desc);
+ PrimitiveType
+ input_type; // Capture the primitive type of one of the inputs of BMM1
+ PrimitiveType output_type;
+ CudnnfMHAKind kind;
+ std::optional<double> fmha_scale;
+ std::optional<double> dropout_rate;
+ std::optional<int64_t> seed;
+
+ se::dnn::AlgorithmDesc algorithm;
+ CudnnfMHAMaskKind mask_type;
+ // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
+ // d_bias -> [1, num_heads, q_seq_len, kv_seq_len]
+ se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs;
+ se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs;
+ se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs;
+ se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs;
+ se::dnn::MatmulTensorDescriptor d_output;
+ se::dnn::TensorDescriptor d_bmm1_lhs;
+ se::dnn::TensorDescriptor d_bmm1_rhs;
+ se::dnn::TensorDescriptor d_bmm2_rhs;
+ std::optional<se::dnn::TensorDescriptor> d_s;
+ std::optional<se::dnn::TensorDescriptor> mask;
+ std::optional<se::dnn::TensorDescriptor> d_bias;
+ std::optional<se::dnn::TensorDescriptor> fwd_output;
+ std::optional<se::dnn::TensorDescriptor> bias;
+};
+
+using se::DeviceMemory;
+using se::DeviceMemoryBase;
+using se::dnn::DataType;
+using se::dnn::MatmulTensorDescriptor;
+using se::dnn::TensorDescriptor;
+
+/*static*/ absl::StatusOr<GpufMHAConfig> GpufMHAConfig::For(
+ const GpufMHADescriptor &desc) {
+ // Get shapes from desc.
+ const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape;
+ const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape;
+ const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape;
+ const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape;
+ const Shape &output_shape = desc.output_shapes[0];
+
+  // Get DNN dtypes from primitive types.
+ TF_ASSIGN_OR_RETURN(
+ DataType lhs_bmm1_type,
+ GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type()));
+ TF_ASSIGN_OR_RETURN(
+ DataType rhs_bmm1_type,
+ GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(
+ DataType rhs_bmm2_type,
+ GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type()));
+ TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type,
+ GetDNNDataTypeFromPrimitiveType(
+ intermediate_lhs_bmm2_shape.element_type()));
+ TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType(
+ output_shape.element_type()));
+ GpufMHAConfig config;
+ config.input_type = lhs_bmm1_shape.element_type();
+ config.output_type = output_shape.element_type();
+
+ // Get MatmulTensorDescriptors for BMM1
+ config.lhs_bmm1 =
+ MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(),
+ desc.lhs_bmm1_shape.layout().minor_to_major(),
+ desc.bmm1_dnums.lhs_batch_dimensions(),
+ desc.bmm1_dnums.lhs_contracting_dimensions());
+ config.rhs_bmm1 =
+ MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(),
+ desc.rhs_bmm1_shape.layout().minor_to_major(),
+ desc.bmm1_dnums.rhs_batch_dimensions(),
+ desc.bmm1_dnums.rhs_contracting_dimensions());
+
+ // Get MatmulTensorDescriptors for BMM2
+ config.rhs_bmm2 =
+ MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(),
+ desc.rhs_bmm2_shape.layout().minor_to_major(),
+ desc.bmm2_dnums.rhs_batch_dimensions(),
+ desc.bmm2_dnums.rhs_contracting_dimensions());
+
+ config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For(
+ lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(),
+ desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(),
+ desc.bmm2_dnums.lhs_batch_dimensions(),
+ desc.bmm2_dnums.lhs_contracting_dimensions());
+
+ config.output = TensorDescriptor::For(output_type, output_shape.dimensions(),
+ output_shape.layout().minor_to_major());
+
+ if (desc.output_shapes.size() > 1) {
+ const Shape &activation_shape = desc.output_shapes.back();
+    // Generally, the activation should have the same type as the output, but
+    // set it explicitly just to be safe.
+ TF_ASSIGN_OR_RETURN(
+ DataType activation_type,
+ GetDNNDataTypeFromPrimitiveType(activation_shape.element_type()));
+ config.activation =
+ TensorDescriptor::For(activation_type, activation_shape.dimensions(),
+ activation_shape.layout().minor_to_major());
+ }
+
+ if (desc.mask_shape) {
+ const Shape &mask_shape = *desc.mask_shape;
+ TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
+ mask_shape.element_type()));
+ config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
+ mask_shape.layout().minor_to_major());
+ }
+
+ if (desc.bias_shape) {
+ const Shape &bias_shape = *desc.bias_shape;
+ TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
+ bias_shape.element_type()));
+ config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
+ bias_shape.layout().minor_to_major());
+ }
+ config.kind = desc.kind;
+ config.mask_type = desc.mask_type;
+ const CudnnfMHABackendConfig &backend_config = desc.backend_config;
+ config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
+ config.fmha_scale.emplace(backend_config.fmha_scale());
+ config.dropout_rate.emplace(backend_config.dropout_rate());
+ config.seed.emplace(backend_config.seed());
+ return config;
+}
+
+/*static*/ absl::StatusOr<GpufMHABackwardConfig> GpufMHABackwardConfig::For(
+ const GpufMHABackwardDescriptor &desc) {
+ // Get shapes from desc.
+ const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape;
+ const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape;
+ const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape;
+ const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape;
+ const Shape &d_output_shape = desc.d_output_shape;
+ const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape;
+ const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape;
+ const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape;
+  // Get DNN dtypes from primitive types.
+ TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type,
+ GetDNNDataTypeFromPrimitiveType(
+ bmm1_grad_gemm1_rhs_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type,
+ GetDNNDataTypeFromPrimitiveType(
+ bmm1_grad_gemm2_rhs_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type,
+ GetDNNDataTypeFromPrimitiveType(
+ bmm2_grad_gemm1_lhs_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type,
+ GetDNNDataTypeFromPrimitiveType(
+ bmm2_grad_gemm2_rhs_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(
+ DataType d_output_type,
+ GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(
+ DataType d_bmm1_lhs_type,
+ GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(
+ DataType d_bmm1_rhs_type,
+ GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type()));
+
+ TF_ASSIGN_OR_RETURN(
+ DataType d_bmm2_rhs_type,
+ GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type()));
+
+ GpufMHABackwardConfig config;
+ config.input_type = bmm1_grad_gemm1_rhs_shape.element_type();
+ config.output_type = d_bmm1_lhs_shape.element_type();
+
+ // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1
+ config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For(
+ bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(),
+ desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(),
+ desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(),
+ desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions());
+
+ // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2
+ config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For(
+ bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(),
+ desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(),
+ desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(),
+ desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions());
+
+ // Get MatmulTensorDescriptors for BMM2 grad GEMM 1
+ config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For(
+ bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
+ desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(),
+ desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(),
+ desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions());
+
+ config.d_output = MatmulTensorDescriptor::For(
+ d_output_type, d_output_shape.dimensions(),
+ desc.d_output_shape.layout().minor_to_major(),
+ desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(),
+ desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions());
+
+ // Get MatmulTensorDescriptors for BMM2 grad GEMM 2
+ config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For(
+ bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(),
+ desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(),
+ desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(),
+ desc.bmm2_grad_gemm2_dnums
+ .rhs_contracting_dimensions()); // FMHA TODO: transpose here?
+
+ config.d_bmm1_lhs =
+ TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(),
+ d_bmm1_lhs_shape.layout().minor_to_major());
+ config.d_bmm1_rhs =
+ TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(),
+ d_bmm1_rhs_shape.layout().minor_to_major());
+ config.d_bmm2_rhs =
+ TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(),
+ d_bmm2_rhs_shape.layout().minor_to_major());
+ config.d_s = TensorDescriptor::For(
+ bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
+ bmm2_grad_gemm1_lhs_shape.layout().minor_to_major());
+
+ if (desc.d_bias_shape) {
+ const Shape &d_bias_shape = *desc.d_bias_shape;
+    // Get the DNN dtype from the primitive type.
+ TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType(
+ d_bias_shape.element_type()));
+ config.d_bias =
+ TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(),
+ d_bias_shape.layout().minor_to_major());
+ }
+
+ if (desc.mask_shape) {
+ const Shape &mask_shape = *desc.mask_shape;
+ TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
+ mask_shape.element_type()));
+ config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
+ mask_shape.layout().minor_to_major());
+ }
+ if (desc.fwd_output_shape) {
+ const Shape &fwd_output_shape = *desc.fwd_output_shape;
+ TF_ASSIGN_OR_RETURN(
+ DataType fwd_output_type,
+ GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type()));
+ config.fwd_output =
+ TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(),
+ fwd_output_shape.layout().minor_to_major());
+ }
+
+ if (desc.bias_shape) {
+ const Shape &bias_shape = *desc.bias_shape;
+ TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
+ bias_shape.element_type()));
+ config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
+ bias_shape.layout().minor_to_major());
+ }
+
+ config.kind = desc.kind;
+ config.mask_type = desc.mask_type;
+ const CudnnfMHABackendConfig &backend_config = desc.backend_config;
+ config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
+ config.fmha_scale.emplace(backend_config.fmha_scale());
+ config.dropout_rate.emplace(backend_config.dropout_rate());
+ config.seed.emplace(backend_config.seed());
+ return config;
+}
+
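+// Builds a cuDNN frontend graph for a fused multi-headed attention custom
+// call; both the forward and the backward flavors of the call are handled.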
+absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
+ se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) {
+ if (IsFwdCustomCallTofMHA(*custom_call)) {
+ TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind,
+ xla::gpu::GetCudnnfMHAKind(custom_call));
+ std::optional<Shape> mask_shape, bias_shape;
+ {
+ bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax ||
+ kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout;
+
+ if (has_bias) {
+ const HloInstruction *bias = custom_call->operand(3);
+ bias_shape = bias->shape();
+ }
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ const auto gpu_config,
+ custom_call->backend_config<xla::gpu::GpuBackendConfig>());
+ const xla::gpu::CudnnfMHABackendConfig &config =
+ gpu_config.cudnn_fmha_backend_config();
+ Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
+ absl::InlinedVector<Shape, 2> output_shapes = {
+ ShapeUtil::GetSubshape(custom_call->shape(), {0})};
+
+ bool has_activation =
+ xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3;
+ if (has_activation) {
+ output_shapes.push_back(
+ ShapeUtil::GetSubshape(custom_call->shape(), {1}));
+ }
+
+ Shape q_shape = custom_call->operand(0)->shape();
+ Shape k_shape = custom_call->operand(1)->shape();
+ Shape v_shape = custom_call->operand(2)->shape();
+ TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
+ AsCudnnFmhaMaskKind(config.mask_type()));
+ GpufMHADescriptor descriptor = {kind,
+ config,
+ cudnn_mask_type,
+ q_shape,
+ k_shape,
+ v_shape,
+ intermediate_tensor_shape,
+ output_shapes,
+ config.bmm1_dot_dimension_numbers(),
+ config.bmm2_dot_dimension_numbers(),
+ mask_shape,
+ bias_shape};
+
+ TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config,
+ GpufMHAConfig::For(descriptor));
+ TF_ASSIGN_OR_RETURN(
+ se::dnn::FMHAMaskKind dnn_mask_type,
+ GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
+ TF_ASSIGN_OR_RETURN(
+ se::gpu::CudnnGraph graph,
+ se::gpu::GetCudnnFlashAttentionOperationGraph(
+ dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1,
+ fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias,
+ fmha_config.activation, static_cast<float>(*fmha_config.fmha_scale),
+ fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
+ fmha_config.dropout_rate, dnn_mask_type));
+ return std::move(graph);
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ auto gpu_config,
+ custom_call->backend_config<xla::gpu::GpuBackendConfig>());
+ xla::gpu::CudnnfMHABackendConfig &config =
+ *gpu_config.mutable_cudnn_fmha_backend_config();
+
+ int input_index = 0;
+ Shape bmm1_grad_gemm1_rhs_shape =
+ custom_call->operand(input_index++)->shape();
+ Shape bmm1_grad_gemm2_rhs_shape =
+ custom_call->operand(input_index++)->shape();
+ Shape bmm2_grad_gemm2_rhs_shape =
+ custom_call->operand(input_index++)->shape();
+ Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());
+ input_index++;
+ Shape d_output_shape = custom_call->operand(input_index++)->shape();
+
+ TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind,
+ GetCudnnfMHAKind(custom_call));
+ std::optional<Shape> mask_shape;
+
+ bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax ||
+ kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout);
+ std::optional<Shape> bias_shape;
+ if (has_bias) {
+ bias_shape = custom_call->operand(input_index++)->shape();
+ }
+
+ std::optional<Shape> fwd_output_shape =
+ custom_call->operand(input_index++)->shape();
+ if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING ||
+ config.mask_type() ==
+ xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) {
+ // skip q_seqlen and kv_seqlen
+ input_index += 2;
+ }
+ TF_RET_CHECK(input_index == custom_call->operand_count());
+
+ int output_index = 0;
+ Shape d_bmm1_lhs_shape =
+ ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+ Shape d_bmm1_rhs_shape =
+ ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+ Shape d_bmm2_rhs_shape =
+ ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+ std::optional<Shape> d_s_shape;
+ std::optional<Shape> d_bias_shape;
+ bool has_dbias = custom_call->shape().tuple_shapes().size() == 5;
+ if (has_dbias) {
+ d_bias_shape =
+ ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+ }
+ // The last one is the workspace.
+ TF_RET_CHECK(output_index ==
+ custom_call->shape().tuple_shapes().size() - 1);
+ TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
+ AsCudnnFmhaMaskKind(config.mask_type()));
+
+ const DebugOptions &debug_options =
+ custom_call->GetModule()->config().debug_options();
+ bool force_deterministic =
+ debug_options.xla_gpu_deterministic_ops() ||
+ debug_options.xla_gpu_exclude_nondeterministic_ops();
+ config.set_force_deterministic(force_deterministic);
+ TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));
+
+ GpufMHABackwardDescriptor descriptor = {
+ kind,
+ config,
+ cudnn_mask_type,
+ bmm1_grad_gemm1_rhs_shape,
+ bmm1_grad_gemm2_rhs_shape,
+ bmm2_grad_gemm1_lhs_shape,
+ bmm2_grad_gemm2_rhs_shape,
+ d_output_shape,
+ d_bmm1_lhs_shape,
+ d_bmm1_rhs_shape,
+ d_bmm2_rhs_shape,
+ config.bmm1_grad_gemm1_dot_dimension_numbers(),
+ config.bmm1_grad_gemm2_dot_dimension_numbers(),
+ config.bmm2_grad_gemm1_dot_dimension_numbers(),
+ config.bmm2_grad_gemm2_dot_dimension_numbers(),
+ d_s_shape,
+ fwd_output_shape,
+ mask_shape,
+ d_bias_shape,
+ bias_shape,
+ force_deterministic};
+
+ TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config,
+ GpufMHABackwardConfig::For(descriptor));
+ TF_ASSIGN_OR_RETURN(
+ se::dnn::FMHAMaskKind dnn_mask_type,
+ GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
+
+ TF_ASSIGN_OR_RETURN(
+ se::gpu::CudnnGraph graph,
+ se::gpu::GetCudnnFlashAttentionBackwardOperationGraph(
+ dnn_support, fmha_config.bmm1_grad_gemm1_rhs,
+ fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs,
+ fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output,
+ fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs,
+ fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate,
+ fmha_config.seed, *fmha_config.fmha_scale,
+ fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
+ fmha_config.bias != std::nullopt, dnn_mask_type,
+ force_deterministic));
+ return std::move(graph);
+ }
+}
+
+class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport &dnn_support,
+ BinaryMap &compilation_results)
+ : dnn_support_(dnn_support), compilation_results_(compilation_results) {}
+
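+  // Applies `workspace_size` to the custom call by resizing the last element
+  // of its result tuple, which by convention is the workspace buffer (e.g. a
+  // u8[0] placeholder becoming u8[workspace_size]; u8 is illustrative, only
+  // the leading dimension is rewritten here).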
+ void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) {
+ if (workspace_size == 0) {
+ return;
+ }
+ VLOG(4) << "Applying workspace size " << workspace_size << " to "
+ << hlo.ToString();
+ Shape *shape = hlo.mutable_shape();
+ shape->mutable_tuple_shapes()->back().set_dimensions(0, workspace_size);
+ }
+
+ absl::Status HandleCustomCall(HloInstruction *hlo) override {
+ if (!IsCustomCallTofMHA(*hlo)) {
+ return absl::OkStatus();
+ }
+
+ TF_ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace,
+ FingerprintWithBackendConfig<GpuBackendConfig>(*hlo));
+ auto workspace_size_it =
+ workspace_sizes_.find(fingerprint_without_workspace);
+ if (workspace_size_it == workspace_sizes_.cend()) {
+ TF_ASSIGN_OR_RETURN(
+ se::gpu::CudnnGraph graph,
+ HloCustomCallToCuDnnGraph(dnn_support_,
+ DynCast<HloCustomCallInstruction>(hlo)));
+
+ const int64_t workspace_size = graph.Graph().get_workspace_size();
+ workspace_sizes_.insert(workspace_size_it,
+ {fingerprint_without_workspace, workspace_size});
+ AddWorkspace(*hlo, workspace_size);
+
+ std::vector<uint8_t> serialized_graph;
+ RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph));
+      // Compute a new fingerprint that includes the potential workspace so
+      // that the compilation results match the fingerprint computed by the
+      // emitter.
+ TF_ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace,
+ FingerprintWithBackendConfig<GpuBackendConfig>(*hlo));
+ compilation_results_[fingerprint_with_workspace] =
+ std::string(reinterpret_cast<char *>(serialized_graph.data()),
+ serialized_graph.size());
+ } else {
+ VLOG(4) << "Cache hit.";
+ AddWorkspace(*hlo, workspace_size_it->second);
+ }
+
+ MarkAsChanged();
+ return absl::OkStatus();
+ }
+
+ private:
+ se::dnn::DnnSupport &dnn_support_;
+ BinaryMap &compilation_results_;
+ absl::flat_hash_map<std::string, int64_t> workspace_sizes_;
+};
+
+} // namespace
+
+absl::StatusOr<bool> CuDnnCustomCallCompiler::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ XLA_SCOPED_LOGGING_TIMER_LEVEL("cuDNN custom call compiler", 8);
+ return CuDnnCustomCallVisitor(dnn_support_, compilation_results_)
+ .RunOnModule(module, execution_threads);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h
new file mode 100644
index 0000000..810286f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h
@@ -0,0 +1,57 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// Compiles cuDNN custom calls to binaries and serializes them. It also
+// adjusts the custom calls in the HLO to carry the correct workspace size.
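+//
+// A minimal usage sketch (assuming a se::StreamExecutor* stream_exec and an
+// HloModule* module are in scope; where this pass is scheduled in the
+// pipeline is up to the caller):
+//
+//   BinaryMap compilation_results;
+//   CuDnnCustomCallCompiler pass(*stream_exec, compilation_results);
+//   TF_ASSIGN_OR_RETURN(bool changed,
+//                       pass.Run(module, /*execution_threads=*/{}));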
+class CuDnnCustomCallCompiler : public HloModulePass {
+ public:
+ explicit CuDnnCustomCallCompiler(se::StreamExecutor& stream_exec,
+ BinaryMap& compilation_results)
+ : dnn_support_(*stream_exec.AsDnn()),
+ compilation_results_(compilation_results) {}
+
+ absl::string_view name() const override {
+ return "cudnn-custom-call-compiler";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::dnn::DnnSupport& dnn_support_;
+ BinaryMap& compilation_results_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc
new file mode 100644
index 0000000..71ed08c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc
@@ -0,0 +1,65 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CustomCallVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleCustomCall(HloInstruction *hlo) override {
+ if (hlo->custom_call_target() != kCuDnnFusionKind) {
+ return absl::OkStatus();
+ }
+ HloComputation *computation = hlo->GetModule()->AddEmbeddedComputation(
+ hlo->called_computations()[0]->Clone());
+ HloInstruction *fusion =
+ hlo->parent()->AddInstruction(HloInstruction::CreateFusion(
+ hlo->shape(), HloInstruction::FusionKind::kCustom, hlo->operands(),
+ computation));
+ GpuBackendConfig gpu_config;
+ FusionBackendConfig &backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+ backend_config.set_kind(hlo->custom_call_target());
+ TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, fusion));
+ return absl::OkStatus();
+ }
+};
+
+} // namespace
+
+absl::StatusOr<bool> CuDnnCustomCallConverter::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ return CustomCallVisitor().RunOnModule(module, execution_threads);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h
new file mode 100644
index 0000000..5397a4d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h
@@ -0,0 +1,47 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Converts custom calls whose call target is kCuDnnFusionKind into fusions
+// with a fusion backend config of the same kind. Frameworks can pass
+// computations outlined this way through StableHLO; after the conversion they
+// can be processed by XLA using the existing pipeline for custom fusions.
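+//
+// A sketch of the rewrite, mirroring the accompanying test (shapes and names
+// are illustrative):
+//
+//   c = s8[] custom-call(b), custom_call_target="__cudnn$fusion",
+//            called_computations={f}
+// becomes
+//   c = s8[] fusion(b), kind=kCustom, calls=f,
+//            backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion"}}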
+class CuDnnCustomCallConverter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "cudnn-custom-call-converter";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc
new file mode 100644
index 0000000..ad29e15
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc
@@ -0,0 +1,47 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h"
+
+#include <gtest/gtest.h>
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ConverterTest = HloTestBase;
+
+TEST_F(ConverterTest, CustomCallGetsConvertedToCustomFusion) {
+ RunAndFilecheckHloRewrite(R"(
+f {
+ a = s8[] parameter(0)
+ ROOT r = s8[] add(a, a)
+}
+
+ENTRY e {
+ b = s8[] parameter(0)
+ ROOT c = s8[] custom-call(b),
+ custom_call_target="__cudnn$fusion", called_computations={f}
+})",
+ CuDnnCustomCallConverter(), R"(
+; CHECK: ROOT %fusion = s8[] fusion(%b), kind=kCustom, calls=%f
+; CHECK-SAME: "fusion_backend_config":{"kind":"__cudnn$fusion"}
+ )");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
new file mode 100644
index 0000000..c51a76f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
@@ -0,0 +1,1566 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+
+#include <algorithm>
+#include <array>
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/ml_dtypes.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = match;
+
+bool IsConvCustomCall(const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kCustomCall &&
+ (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
+ instr->custom_call_target() ==
+ kCudnnConvBiasActivationForwardCallTarget);
+}
+
+bool IsConvDepthwise(const HloInstruction* instr) {
+ int64_t feature_group_count = instr->feature_group_count();
+ if (feature_group_count == 1) {
+ return false;
+ }
+
+ const HloInstruction* input = instr->operand(0);
+ int64_t input_feature_dimension =
+ instr->convolution_dimension_numbers().input_feature_dimension();
+ int64_t input_feature_count =
+ input->shape().dimensions(input_feature_dimension);
+ return input_feature_count == feature_group_count;
+}
+
+// We don't want to upgrade depthwise convolutions to ConvBiasActivation,
+// because the fused CUDNN functions are slower for some of those.
+bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
+ return IsConvCustomCall(instr) && !IsConvDepthwise(instr);
+}
+
+bool IsROCm(se::GpuComputeCapability cc) {
+ return std::holds_alternative<se::RocmComputeCapability>(cc);
+}
+
+// elu, relu6, and leaky-relu activations are supported in cudnn via the
+// "runtime fusion" engine, which JIT compiles C++ code. This can be slow to
+// compile, so we guard it with a debug option.
+//
+// nvidia currently recommends that we enable this only on Ampere+, but we've
+// tested on Turing (sm75) and it seems to work fine.
+//
+// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default
+// due to apparent bugs in cudnn 8.9.0. See debug_options_flags.cc for details.
+bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts,
+ se::GpuComputeCapability cc) {
+ const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&cc);
+ if (cuda_cc != nullptr)
+ return debug_opts.xla_gpu_use_runtime_fusion() && cuda_cc->IsAtLeast(7, 5);
+ else
+ return true;
+}
+
+bool IsSuitableForCudnnRuntimeFusion(HloInstruction* conv) {
+ // cudnn runtime fusion is pathologically slow on convs with side-inputs.
+ // TODO(kaixih@nvidia): remove this check when cuDNN fixes it.
+ if (conv->operands().size() > 3) {
+ return false;
+ }
+
+  // cuDNN runtime fusion kernels require 32-bit aligned data access, which
+ // means that the number of in/out channels must be divisible by 2 for fp16.
+ // (We don't currently do runtime fusion for int8.)
+ if (conv->operand(0)->shape().element_type() != F16) {
+ return false;
+ }
+ const Shape& shape = conv->operand(1)->shape();
+ int64_t num_input_features = shape.dimensions(
+ conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+ int64_t num_output_features = shape.dimensions(
+ conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+ if (num_input_features % 2 != 0 || num_output_features % 2 != 0) {
+ return false;
+ }
+
+ return true;
+}
+
+// Can instr be converted to type `dst_ty` without losing any precision? For
+// our purposes, this is true if:
+//
+// - instr already has type dst_ty, or
+// - instr is convert<wider type>(op_with_dst_ty), or
+// - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
+// get back exactly the original value, or
+// - instr is a broadcast, reshape, or transpose of one of the above.
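+//
+// For example, convert<f32>(x_s8) is losslessly convertible to S8, and an f32
+// constant 3.0 is losslessly convertible to S8 (3.0 -> 3 -> 3.0), whereas 3.5
+// is not.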
+bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
+ PrimitiveType dst_ty) {
+ if (instr->shape().element_type() == dst_ty) {
+ return true;
+ }
+
+ if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
+ // Check that the convert from dst_ty to instr->element_type() doesn't lose
+ // precision. Otherwise, this convert is not lossless.
+ return primitive_util::CastPreservesValues(dst_ty,
+ instr->shape().element_type());
+ }
+
+ if (instr->opcode() == HloOpcode::kConstant) {
+ if (!instr->shape().IsArray()) {
+ return false;
+ }
+    // Check if instr's literal roundtrips to dst_ty and back to its original
+    // type without modification.
+ PrimitiveType orig_ty = instr->shape().element_type();
+
+ // The only reason Convert() should fail is if we don't support converting
+ // from x to y, which indeed means it's not losslessly-convertible.
+ absl::StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
+ if (!converted1.ok()) {
+ return false;
+ }
+ absl::StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
+ if (!converted2.ok()) {
+ return false;
+ }
+
+ return instr->literal() == *converted2;
+ }
+
+ if (instr->opcode() == HloOpcode::kBroadcast ||
+ instr->opcode() == HloOpcode::kReshape ||
+ instr->opcode() == HloOpcode::kTranspose) {
+ return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
+ }
+
+ return false;
+}
+
+// Helpers suitable for use in m::Op().WithPredicate(...).
+bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
+ return IsLosslesslyConvertibleTo(instr, S8);
+}
+bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
+ return IsLosslesslyConvertibleTo(instr, F16);
+}
+
+// If `conv` is a vanilla forward conv, transforms it into a
+// conv-bias-activation. If it's already a conv-bias-activation, does nothing.
+//
+// If `conv` is anything else, returns an error.
+absl::StatusOr<HloInstruction*> EnsureIsConvBiasActivation(
+ HloInstruction* conv) {
+ CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);
+
+ if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) {
+ return conv;
+ }
+
+ if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
+ HloComputation* comp = conv->parent();
+
+ const Shape& shape = conv->shape().tuple_shapes(0);
+ int64_t num_output_features = shape.dimensions(
+ conv->convolution_dimension_numbers().output_feature_dimension());
+
+ // bias for integer convs is always f32, see
+ // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+ PrimitiveType bias_ty;
+ if (primitive_util::IsIntegralType(shape.element_type())) {
+ bias_ty = F32;
+ } else {
+ bias_ty = shape.element_type();
+ }
+ auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});
+
+ absl::InlinedVector<HloInstruction*, 3> new_operands(
+ conv->operands().begin(), conv->operands().end());
+ new_operands.push_back(bias);
+
+ HloInstruction* new_conv = comp->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), new_operands));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
+ new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
+ comp->parent()->SetAndUniquifyInstrName(new_conv,
+ "cudnn-conv-bias-activation");
+ return new_conv;
+ }
+
+ return FailedPrecondition("Unsupported conv: %s", conv->ToString());
+}
+
+// convert<cvt_type>(gte(custom-call<conv_type>(int8_x, int8_w))) ->
+// gte(custom-call<cvt_type>(int8_x, int8_w))
+absl::StatusOr<bool> FuseConvertTypeIntoConv(HloComputation* comp,
+ PrimitiveType conv_type,
+ PrimitiveType cvt_type) {
+ bool changed = false;
+ for (auto instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* conv = nullptr;
+ auto tuple_elem =
+ m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
+ .WithElementType(conv_type);
+ auto pattern =
+ m::Convert(tuple_elem.WithOneUser()).WithElementType(cvt_type);
+ if (!Match(instr, pattern)) {
+ continue;
+ }
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseConvertTypeIntoConv: ", conv->ToString());
+ })) {
+ continue;
+ }
+
+ Shape new_shape = conv->shape();
+ new_shape.mutable_tuple_shapes(0)->set_element_type(cvt_type);
+ HloInstruction* new_conv =
+ comp->AddInstruction(conv->CloneWithNewShape(new_shape));
+ comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
+ MakeGetTupleElementHlo(new_conv, 0));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));
+
+ changed = true;
+ }
+
+ return changed;
+}
+
+struct ConvConvertTypes {
+ PrimitiveType convolution_type;
+ PrimitiveType conversion_type;
+};
+
+// Removes the convert around a convolution by making the convolution's
+// (custom call's) result type the same as the conversion result type.
+// For example: convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
+// gte(custom-call<float>(int8_x, int8_w))
+absl::StatusOr<bool> FuseRemoveConvertInConv(HloComputation* comp) {
+ bool changed = false;
+  // Note: The F16->F32 combination is intentionally omitted because it fails
+  // internal tests.
+ std::array<ConvConvertTypes, 3> types{{
+ {S32, F32},
+ {S8, F32},
+ {F32, S8},
+ }};
+ for (auto [conv_type, cvt_type] : types) {
+ TF_ASSIGN_OR_RETURN(bool curr_change,
+ FuseConvertTypeIntoConv(comp, conv_type, cvt_type));
+ changed |= curr_change;
+ }
+ return changed;
+}
+
+// alpha * gte(custom-call(...)) ->
+// gte(custom-call(..., backend_config={alpha})).
+absl::StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
+ bool changed = false;
+ for (auto instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* conv = nullptr;
+ HloInstruction* gte = nullptr;
+ HloInstruction* alpha = nullptr;
+
+ auto pattern = m::MultiplyAnyOrder(
+ m::GetTupleElement(
+ >e, m::Op(&conv).WithPredicate(IsNonDepthwiseConvCustomCall), 0)
+ .WithOneUse(),
+ m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
+ if (!Match(instr, pattern)) {
+ continue;
+ }
+
+ // alpha is f32 except for f64 convs, where it's f64. See
+ // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+ PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
+ if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+
+ if (config.conv_result_scale() != 1) {
+ continue;
+ }
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseConvAlpha: ", conv->ToString());
+ })) {
+ continue;
+ }
+
+ // StreamExecutor doesn't support the alpha parameter on non-bias-activation
+ // convs, so we have to upgrade `conv`.
+ TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+
+ TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
+ config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());
+ TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));
+
+ changed = true;
+ }
+ return changed;
+}
+
+// The format of the serialized graph describing a sequence of ops fused
+// into the cuDNN convolution Custom Call is
+// "UID:[output_type]conv();UID[output_type]:op_name(operand
+// UID);UID:[output_type]op_name(operand UID);..." with the convolution assumed
+// to be the first op in the graph. Operand UIDs identifying ops outside the
+// serialized graph are elided. Currently, multiplication and division by a
+// broadcast scalar, addition of a matrix bias, the application of a ReLU
+// activation and the calculation of the maximum of the absolute value are
+// supported.
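+//
+// For example, a convolution followed by a bias add and a ReLU might
+// serialize to (UIDs are illustrative):
+//   "1:[f32]conv();2:[f32]add(1);3:[f32]relu(2);"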
+class GraphString {
+ public:
+ GraphString() = default;
+
+ bool AppendOp(std::string op_name, HloInstruction* op,
+ std::vector<HloInstruction*> operands = {}) {
+ std::optional<int64_t> operand_uid;
+ int num_operands_in_graph = 0;
+ for (HloInstruction* operand : operands) {
+ if (OpInGraph(operand->unique_id())) {
+ num_operands_in_graph++;
+ // Ops with more than one operand in the graph are not supported.
+ if (num_operands_in_graph > 1) {
+ return false;
+ }
+ operand_uid = operand->unique_id();
+ }
+ }
+ graph_.emplace_back(OpDescriptor(
+ {op->unique_id(), op->shape().element_type(), op_name, operand_uid}));
+ return true;
+ }
+
+ void ChangeDataType(PrimitiveType type) {
+ DCHECK(!graph_.empty());
+ graph_.back().output_type = type;
+ }
+
+ std::string Graph() const {
+ std::string graph;
+    for (const OpDescriptor& op : graph_) {
+ graph.append(std::to_string(op.uid));
+ graph.append(":[" +
+ primitive_util::LowercasePrimitiveTypeName(op.output_type) +
+ "]");
+ graph.append(op.name);
+ graph.append("(");
+ if (op.operand.has_value()) {
+ graph.append(std::to_string(*op.operand));
+ }
+ graph.append(");");
+ }
+ return graph;
+ }
+
+ bool OpInGraph(int64_t uid, std::string op_name = "") const {
+ auto op_filter = [&](OpDescriptor op) -> bool {
+ if (op_name.empty()) {
+ return op.uid == uid;
+ } else {
+ return op.uid == uid && op.name == op_name;
+ }
+ };
+ return std::find_if(graph_.begin(), graph_.end(), op_filter) !=
+ graph_.end();
+ }
+
+ private:
+ struct OpDescriptor {
+ int64_t uid;
+ PrimitiveType output_type;
+ std::string name;
+ std::optional<int64_t> operand;
+ };
+
+ std::vector<OpDescriptor> graph_;
+};
+
+bool IsF8Type(const HloInstruction* instr) {
+ return primitive_util::IsF8Type(instr->shape().element_type());
+}
+
+bool IsScalar(const HloInstruction* instr) {
+ return ShapeUtil::IsScalar(instr->shape());
+}
+
+std::optional<PrimitiveType> IsSaturatingCastToF8(HloInstruction* instr) {
+ HloInstruction *op, *clamp_lower, *clamp_upper;
+ if (Match(instr,
+ m::Convert(
+ &op,
+ m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), m::Op(),
+ m::Broadcast(m::ConstantScalar(&clamp_upper))))) &&
+ ((op->shape().element_type() == F8E4M3FN &&
+ clamp_lower->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e4m3fn>::lowest())) &&
+ clamp_upper->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e4m3fn>::max()))) ||
+ (op->shape().element_type() == F8E5M2 &&
+ clamp_lower->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e5m2>::lowest())) &&
+ clamp_upper->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e5m2>::max()))))) {
+ return op->shape().element_type();
+ }
+ return std::nullopt;
+}
+
+// Returns whether the HLO Computation applied by `op` calculates the largest
+// element.
+bool AppliesMaxReduce(HloInstruction* op) {
+ HloComputation* reduce_comp = op->to_apply();
+ HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
+ return ShapeUtil::IsScalar(op->shape()) &&
+ ShapeUtil::IsScalar(op->operand(1)->shape()) &&
+ op->operand(1)->IsConstant() &&
+ op->operand(1)->literal().GetAsDouble({}) <= 0. &&
+ reduce_comp_root->opcode() == HloOpcode::kMaximum &&
+ reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
+ reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
+}
+
+// Recursively captures and serializes the graph of pointwise operations
+// operating on the convolution.
+void CaptureConvGraphRecursive(HloInstruction* instr,
+ std::vector<HloInstruction*>& operands,
+ std::vector<HloInstruction*>& aux_outputs,
+ GraphString& graph_string,
+ absl::flat_hash_set<int>& visited_instrs,
+ HloInstruction*& final_instr) {
+ // Avoid visiting the same instruction more than once.
+ if (!visited_instrs.emplace(instr->unique_id()).second) {
+ return;
+ }
+ final_instr = instr;
+
+  // Copy the current state in case the fusion turns out to be unsuccessful
+  // or unfavorable.
+ GraphString init_graph_string = graph_string;
+ std::vector<HloInstruction*> init_operands = operands,
+ init_aux_outputs = aux_outputs;
+ // The loop adds each user of `instr` that supports fusion into the
+ // cuDNN convolution Custom Call to GraphString. Most ops following the
+ // convolution describe a linear sequence that generates a single return
+ // tensor. The identification of one of these linear ops is followed by a
+ // recursive call of CaptureConvGraphRecursive to match and potentially fuse
+ // its users. The calculation of the scalar maximum of the absolute value
+ // (Amax) of a preceding op is considered a nonlinear user as it adds a
+ // return value to the convolution. The users of a nonlinear op are
+ // not considered for fusion into the Custom Call. The numbers of linear and
+ // nonlinear users of `instr` are stored in `num_linear_users` and
+ // `num_nonlinear_users`.
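+  //
+  // For example (ops illustrative), in conv -> scale -> relu both the scale
+  // and the relu are linear users and extend the graph, while an amax reduce
+  // hanging off the relu is a nonlinear user that only adds an auxiliary
+  // output.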
+ int num_linear_users = 0, num_nonlinear_users = 0;
+ for (HloInstruction* user : instr->users()) {
+ HloInstruction *op, *operand0, *operand1;
+ // Add
+ if (Match(user, m::AddAnyOrder(&op, m::Op(&operand0), m::Op(&operand1)))) {
+ if (graph_string.AppendOp("add", op, {operand0, operand1})) {
+ operands.push_back(operand0 == instr ? operand1 : operand0);
+ num_linear_users++;
+ CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+ visited_instrs, final_instr);
+ }
+ continue;
+ }
+ // Scale
+ if (Match(user, m::MultiplyAnyOrder(&op, m::Op(&operand0),
+ m::Broadcast(m::Op(&operand1)))) &&
+ ShapeUtil::IsScalar(operand1->shape())) {
+ if (graph_string.AppendOp("scale", op, {operand0, operand1})) {
+ operands.push_back(operand1);
+ num_linear_users++;
+ CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+ visited_instrs, final_instr);
+ }
+ continue;
+ }
+ // Inverse Scale
+ if (Match(user, m::Divide(&op, m::Op(&operand0),
+ m::Broadcast(m::Op(&operand1)))) &&
+ ShapeUtil::IsScalar(operand1->shape())) {
+ if (graph_string.AppendOp("invscale", op, {operand0, operand1})) {
+ operands.push_back(operand1);
+ num_linear_users++;
+ CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+ visited_instrs, final_instr);
+ }
+ continue;
+ }
+ // ReLU
+ if (Match(user, m::MaximumAnyOrder(&op, m::Op(&operand0),
+ m::Broadcast(m::ConstantScalar(0))))) {
+ if (graph_string.AppendOp("relu", op, {operand0})) {
+ num_linear_users++;
+ CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+ visited_instrs, final_instr);
+ }
+ continue;
+ }
+ // Maximum of the absolute value (Amax) following ReLU (elided Abs) -- not
+ // a linear user
+ if (Match(user, m::Reduce(&op, m::Op(&operand0), m::Op())) &&
+ graph_string.OpInGraph(operand0->unique_id(), "relu") &&
+ AppliesMaxReduce(op)) {
+ if (graph_string.AppendOp("amax", op, {operand0})) {
+ aux_outputs.emplace_back(op);
+ num_nonlinear_users++;
+ }
+ continue;
+ }
+
+ // The following patterns match the user of `user`.
+ if (!user->users().empty()) {
+ HloInstruction* users_user = user->users()[0];
+ // Convert with Clamp to FP8 types
+ std::optional<PrimitiveType> f8_type = IsSaturatingCastToF8(users_user);
+ if (f8_type.has_value()) {
+ graph_string.ChangeDataType(f8_type.value());
+ num_linear_users++;
+ CaptureConvGraphRecursive(users_user, operands, aux_outputs,
+ graph_string, visited_instrs, final_instr);
+ continue;
+ }
+ // Maximum of the absolute value (Amax) -- not a linear user
+ if (Match(users_user,
+ m::Reduce(&op, m::Abs(m::Op(&operand0)), m::Op())) &&
+ AppliesMaxReduce(op)) {
+ if (graph_string.AppendOp("amax", op, {operand0})) {
+ aux_outputs.emplace_back(op);
+ num_nonlinear_users++;
+ }
+ continue;
+ }
+ }
+ }
+  // Do not fuse into the cuDNN convolution Custom Call when there is more
+  // than one linear or nonlinear user, or when the number of users eligible
+  // for fusion is less than the total number of users.
+ if (num_linear_users > 1 || num_nonlinear_users > 1 ||
+ num_linear_users + num_nonlinear_users < instr->user_count()) {
+ graph_string = init_graph_string;
+ operands = init_operands;
+ aux_outputs = init_aux_outputs;
+ final_instr = instr;
+ }
+}
+
+// Captures in a GraphString the subgraph of pointwise operations operating on
+// the convolution that will be fused into the cuDNN convolution Custom Call.
+absl::StatusOr<
+ std::tuple<std::vector<HloInstruction*>, std::vector<HloInstruction*>,
+ GraphString, HloInstruction*>>
+CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution,
+ HloInstruction* wide_input, HloInstruction* wide_filter,
+ HloInstruction* input_scale, HloInstruction* filter_scale,
+ bool x_mult_scale, bool w_mult_scale) {
+ GraphString graph_string;
+ graph_string.AppendOp("conv", instr);
+
+ // Shift the scaling of the input and filter to the output of the convolution.
+ HloInstruction *input_scaled_conv, *filter_scaled_conv;
+ if (input_scale) {
+ TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(0, wide_input));
+ HloInstruction* bcast_input_scale = instr->AddInstruction(
+ HloInstruction::CreateBroadcast(instr->shape(), input_scale, {}));
+ input_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
+ instr->shape(),
+ x_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide, instr,
+ bcast_input_scale));
+ TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(input_scaled_conv));
+ }
+ if (filter_scale) {
+ TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(1, wide_filter));
+ HloInstruction* bcast_filter_scale = instr->AddInstruction(
+ HloInstruction::CreateBroadcast(instr->shape(), filter_scale, {}));
+ filter_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
+ instr->shape(),
+ w_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide,
+ input_scale ? input_scaled_conv : instr, bcast_filter_scale));
+ TF_RETURN_IF_ERROR((input_scale ? input_scaled_conv : instr)
+ ->ReplaceAllUsesWith(filter_scaled_conv));
+ }
+
+ std::vector<HloInstruction*> operands, aux_outputs;
+ absl::flat_hash_set<int> visited_instrs;
+ HloInstruction* final_instr;
+ CaptureConvGraphRecursive(instr, operands, aux_outputs, graph_string,
+ visited_instrs, final_instr);
+ return std::make_tuple(operands, aux_outputs, graph_string, final_instr);
+}
+
+// Matches convolutions operating on FP8 inputs and filters and rewrites them
+// into a ForwardGraph Custom Call. For scaled FP8 convolutions on Hopper
+// systems, the following steps are elided and fused into the Custom Call:
+//
+// 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32.
+// 2. Optionally unscale the filter and input by multiplying or dividing by
+// scalars.
+// 3. Evaluate the convolution based on the scaled filter and input.
+// 4. Apply a series of elementwise transformations, where a transformation can
+// be adding a matrix bias, applying a ReLU activation, or
+// multiplying or dividing by a broadcast scalar.
+// 5. Optionally calculate the maximum of the absolute of the result.
+// 6. Optionally cast the output back to FP8.
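+//
+// For a scaled FP8 convolution with a ReLU activation and a saturating cast
+// back to FP8, the serialized graph attached to the Custom Call might look
+// like (UIDs illustrative):
+//   "1:[f32]conv();2:[f32]scale(1);3:[f32]scale(2);4:[f8e4m3fn]relu(3);"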
+absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
+ se::CudaComputeCapability cc,
+ se::dnn::VersionInfo dnn_version,
+ int32_t toolkit_version) {
+ bool changed = false;
+
+ if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) {
+ return false;
+ }
+ if (toolkit_version < 12000) {
+ return false;
+ }
+ if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) {
+ return false;
+ }
+ for (auto instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction *convolution, *gte, *input, *filter,
+ *input_scale = nullptr, *filter_scale = nullptr,
+ *input_scale_op = nullptr, *filter_scale_op = nullptr,
+ *wide_input = nullptr, *wide_filter = nullptr;
+
+ auto conv_operand_maybe_scaled = [](HloInstruction** operand,
+ HloInstruction** wide_operand,
+ HloInstruction** scale_op,
+ HloInstruction** scale) {
+ return m::AnyOf<HloInstruction>(
+ m::Op(operand).WithPredicate(IsF8Type),
+ m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
+ m::Divide(
+ scale_op,
+ m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
+ m::Broadcast(m::Op(scale).WithPredicate(IsScalar))),
+ m::MultiplyAnyOrder(
+ scale_op,
+ m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
+ m::Broadcast(m::Op(scale).WithPredicate(IsScalar))));
+ };
+
+ // TODO(philipphack): Consider allowing ops between dequantization and
+ // convolution.
+ auto pattern = m::GetTupleElement(
+ >e,
+ m::CustomCall(
+ &convolution,
+ conv_operand_maybe_scaled(&input, &wide_input, &input_scale_op,
+ &input_scale),
+ conv_operand_maybe_scaled(&filter, &wide_filter, &filter_scale_op,
+ &filter_scale))
+ .WithPredicate(IsConvCustomCall),
+ 0);
+ if (Match(instr, pattern)) {
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("F8GraphConv: ", convolution->ToString());
+ })) {
+ continue;
+ }
+
+ std::vector<HloInstruction*> operands, aux_outputs;
+ GraphString graph_string;
+ HloInstruction* final_instr;
+
+ TF_ASSIGN_OR_RETURN(
+ std::tie(operands, aux_outputs, graph_string, final_instr),
+ CaptureConvGraph(
+ instr, convolution, wide_input, wide_filter, input_scale,
+ filter_scale,
+ input_scale_op ? input_scale_op->opcode() == HloOpcode::kMultiply
+ : false,
+ filter_scale_op
+ ? filter_scale_op->opcode() == HloOpcode::kMultiply
+ : false));
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ convolution->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+
+ config.set_serialized_graph(graph_string.Graph());
+ operands.insert(operands.begin(), input);
+ operands.insert(operands.begin() + 1, filter);
+
+ std::vector<Shape> output_shapes;
+ output_shapes.emplace_back(ShapeUtil::ChangeElementType(
+ ShapeUtil::GetTupleElementShape(convolution->shape(), 0),
+ final_instr->shape().element_type()));
+ for (HloInstruction* aux_output : aux_outputs) {
+ output_shapes.emplace_back(aux_output->shape());
+ }
+ output_shapes.emplace_back(
+ ShapeUtil::GetTupleElementShape(convolution->shape(), 1));
+
+ HloInstruction* new_convolution =
+ comp->AddInstruction(convolution->CloneWithNewOperands(
+ ShapeUtil::MakeTupleShape(output_shapes), operands));
+
+ new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget);
+ TF_RETURN_IF_ERROR(new_convolution->set_backend_config(gpu_config));
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
+ MakeGetTupleElementHlo(new_convolution, 0));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte));
+
+ for (int i = 0; i < aux_outputs.size(); ++i) {
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
+ MakeGetTupleElementHlo(new_convolution, i + 1));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(aux_outputs[i], new_gte));
+ }
+
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
+ bool changed = false;
+ for (auto instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* conv = nullptr;
+ HloInstruction* gte = nullptr;
+ HloInstruction* addend = nullptr;
+
+ auto pattern = m::AddAnyOrder(
+ m::GetTupleElement(>e,
+ m::Op(&conv)
+ .WithPredicate(IsNonDepthwiseConvCustomCall)
+ .WithOneUse(),
+ 0)
+ .WithOneUse(),
+ m::Op(&addend));
+ if (!Match(instr, pattern)) {
+ continue;
+ }
+
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
+ })) {
+ continue;
+ }
+
+ // If it's a vanilla forward conv, upgrade it to a bias-activation conv. We
+ // only want to do this if the fusion will succeed, but we're guaranteed
+ // that it will, because the only reason we'll bail at this point is if
+ // !can_accept_bias && !can_accept_side_input, and our shiny new
+ // bias-activation conv will be able to accept both.
+ if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
+ TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+ }
+
+ // Can't fuse bias or side-input if the conv already has a relu (or other
+ // activation), because bias and side-input are added before the activation
+ // is applied.
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+ if (config.activation_mode() != se::dnn::kNone) {
+ continue;
+ }
+
+ // Does `conv` already have a (nonzero) bias? Does it already have a
+ // side_input?
+ bool can_accept_bias =
+ Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
+ bool can_accept_side_input = conv->operand_count() < 4;
+
+ // The addend can be fused as a bias if
+ // - it is 1D broadcasted in the output feature dimension, and
+ // - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
+ // convs, and conv_ty for floating-point convs)
+ PrimitiveType conv_ty = gte->shape().element_type();
+ PrimitiveType bias_ty =
+ primitive_util::IsFloatingPointType(conv_ty) ? conv_ty : F32;
+ bool addend_may_be_rank1_bias =
+ addend->opcode() == HloOpcode::kBroadcast &&
+ addend->dimensions().size() == 1 &&
+ addend->dimensions(0) ==
+ conv->convolution_dimension_numbers().output_feature_dimension() &&
+ IsLosslesslyConvertibleTo(addend, bias_ty);
+
+ bool addend_may_be_rank0_bias = addend->opcode() == HloOpcode::kBroadcast &&
+ addend->dimensions().empty() &&
+ IsLosslesslyConvertibleTo(addend, bias_ty);
+
+ absl::InlinedVector<HloInstruction*, 4> new_operands(
+ conv->operands().begin(), conv->operands().end());
+ if (can_accept_bias && addend_may_be_rank1_bias) {
+ new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
+ &addend->operand(0)->metadata());
+ } else if (can_accept_bias && addend_may_be_rank0_bias) {
+ new_operands[2] = MakeBroadcastHlo(
+ MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
+ &addend->operand(0)->metadata()),
+ /*broadcast_dimensions=*/{},
+ /*result_shape_bounds=*/
+ {gte->shape().dimensions(conv->convolution_dimension_numbers()
+ .output_feature_dimension())});
+ } else if (can_accept_side_input) {
+ CHECK_EQ(new_operands.size(), 3);
+ new_operands.push_back(addend);
+ config.set_side_input_scale(1);
+ } else {
+ // Can't fuse; this op already has a bias and a side-input.
+ continue;
+ }
+
+ HloInstruction* new_conv = comp->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), new_operands));
+ comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
+ MakeGetTupleElementHlo(new_conv, 0));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
+ changed = true;
+ }
+ return changed;
+}
+
+// custom-call(..., alpha * side_input) ->
+// custom-call(..., side_input, backend_config={alpha}).
+//
+// We also have to support the more complicated case of
+//
+// custom-call(..., reshape(side_input * alpha)) -->
+// custom-call(..., reshape(side_input), backend_config={alpha}),
+//
+// where `reshape` can be an arbitrary chain of reshapes+transposes. This idiom
+// is created by the ReshapeMover pass.
+absl::StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* conv;
+ HloInstruction* side_input;
+ auto pattern = m::Op(&conv)
+ .WithPredicate(IsConvCustomCall)
+ .WithOperand(3, m::Op(&side_input));
+ if (!Match(instr, pattern)) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+ if (config.side_input_scale() != 1) {
+ continue;
+ }
+
+ // Given side_input, pattern match the following (working from bottom up).
+ //
+ // before_reshape = multiply(base, broadcast(alpha))
+ // side_input = chain_of_reshapes_and_transposes(before_reshape)
+ //
+ // where alpha is a scalar constant.
+ //
+ // alpha is f32 except for f64 convs, where it's f64. See
+ // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+ HloInstruction* before_reshape = side_input;
+ while (before_reshape->opcode() == HloOpcode::kReshape ||
+ before_reshape->opcode() == HloOpcode::kTranspose) {
+ before_reshape = before_reshape->mutable_operand(0);
+ }
+
+ PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
+ PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
+ HloInstruction* base;
+ HloInstruction* alpha;
+ if (!Match(
+ before_reshape,
+ m::MultiplyAnyOrder(
+ m::Op(&base),
+ m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
+ [&](const HloInstruction* instr) {
+ return IsLosslesslyConvertibleTo(instr, alpha_ty);
+ }))))) {
+ continue;
+ }
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
+ })) {
+ continue;
+ }
+
+ // Rewrite conv's operand 3 to
+ //
+ // chain_of_reshapes_and_transposes(before_reshape).
+ //
+ // and store alpha in the conv's backend config.
+ //
+ // We're going to do something bad here: We aren't going to check that the
+ // chain of reshapes/transposes has one use, so we're potentially
+ // duplicating all these instructions (once with alpha and once without).
+ //
+ // This is justified because
+ //
+ // - duplicating reshapes/transposes shouldn't be "that bad" -- these
+ // instructions can usually be fused, and
+ //
+ // - *not* fusing alpha can be catastrophic. For s8->s8 convolutions, the
+ // side-input must be s8. But the product side_input * alpha is f32, so
+ // we can only see that side-input is s8 if we fuse alpha. IOW not fusing
+ // alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
+ // slower than some extra transposes.
+
+ // Recursively clone chain_of_reshapes_and_transposes until we get to
+ // `before_reshape`, at which point we skip the multiply(base, alpha) and
+ // just return base.
+ std::function<HloInstruction*(const HloInstruction*)> clone =
+ [&](const HloInstruction* instr) {
+ if (instr == before_reshape) {
+ return base;
+ }
+ CHECK(instr->opcode() == HloOpcode::kReshape ||
+ instr->opcode() == HloOpcode::kTranspose)
+ << "Must be reshape or transpose: " << instr->ToString();
+ return comp->AddInstruction(instr->CloneWithNewOperands(
+ instr->shape(), {clone(instr->operand(0))}));
+ };
+ absl::InlinedVector<HloInstruction*, 4> new_operands(
+ conv->operands().begin(), conv->operands().end());
+ new_operands[3] = clone(side_input);
+
+ HloInstruction* new_conv = comp->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), new_operands));
+ comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+
+ TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
+ config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
+
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
+ changed = true;
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> FuseElu(HloComputation* comp,
+ se::GpuComputeCapability cc) {
+ if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
+ cc)) {
+ return false;
+ }
+
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction *gte1, *gte2, *gte3;
+ HloInstruction* conv;
+ HloInstruction* expm1;
+
+ if (!Match(instr,
+ m::Select(m::Compare(m::GetTupleElement(>e1, m::Op()),
+ m::Broadcast(m::ConstantEffectiveScalar(0)))
+ .WithComparisonDirection(ComparisonDirection::kGt)
+ .WithOneUse(),
+ m::GetTupleElement(
+ >e2,
+ m::Op(&conv)
+ .WithPredicate(IsNonDepthwiseConvCustomCall)
+ .WithOneUse(),
+ /*tuple_index=*/0)
+ // TODO(jlebar): Why only fp16?
+ .WithElementType(F16),
+ m::Op(&expm1)
+ .WithOpcode(HloOpcode::kExpm1)
+ .WithOperand(0, m::GetTupleElement(>e3, m::Op()))
+ .WithOneUse()))) {
+ continue;
+ }
+
+ // The three GTEs should be the same, and these should be the only uses.
+ if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
+ continue;
+ }
+
+ if (!IsSuitableForCudnnRuntimeFusion(conv)) {
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+ if (config.activation_mode() != se::dnn::kNone) {
+ continue;
+ }
+
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseElu: ", conv->ToString());
+ })) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+ config.set_activation_mode(se::dnn::kElu);
+ TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
+ changed = true;
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> FuseRelu(HloComputation* comp) {
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* gte;
+ HloInstruction* conv;
+ if (!Match(instr,
+ m::MaximumAnyOrder(
+ m::Broadcast(m::ConstantEffectiveScalar(0)),
+ m::GetTupleElement(
+ >e, m::Op(&conv)
+ .WithPredicate(IsNonDepthwiseConvCustomCall)
+ .WithOneUse())
+ .WithOneUse()))) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+ if (config.activation_mode() != se::dnn::kNone) {
+ continue;
+ }
+
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseRelu: ", conv->ToString());
+ })) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+ config.set_activation_mode(se::dnn::kRelu);
+ TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
+ changed = true;
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
+ se::GpuComputeCapability cc) {
+ if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
+ cc)) {
+ return false;
+ }
+
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction *gte, *conv;
+ if (!Match(
+ instr,
+ m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
+ m::GetTupleElement(
+ >e, m::Op(&conv)
+ .WithPredicate(IsNonDepthwiseConvCustomCall)
+ .WithOneUse())
+ // TODO(jlebar): Why only fp16?
+ .WithElementType(F16)
+ .WithOneUse(),
+ m::Broadcast(m::ConstantEffectiveScalar(6))))) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+ if (config.activation_mode() != se::dnn::kNone) {
+ continue;
+ }
+
+ if (!IsSuitableForCudnnRuntimeFusion(conv)) {
+ continue;
+ }
+
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseRelu6: ", conv->ToString());
+ })) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+ config.set_activation_mode(se::dnn::kRelu6);
+ TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
+ changed = true;
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> FuseLeakyRelu(HloComputation* comp,
+ se::GpuComputeCapability cc) {
+ if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
+ cc)) {
+ return false;
+ }
+
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction *gte1, *gte2, *gte3, *conv, *alpha;
+ if (!Match(instr,
+ m::Select(
+                   m::Compare(m::GetTupleElement(&gte1, m::Op()),
+ m::Broadcast(m::ConstantEffectiveScalar(0)))
+ .WithComparisonDirection(ComparisonDirection::kGt)
+ .WithOneUse(),
+ m::GetTupleElement(
+                       &gte2, m::Op(&conv)
+ .WithPredicate(IsNonDepthwiseConvCustomCall)
+ .WithOneUse())
+ // TODO(jlebar): Why only fp16?
+ .WithElementType(F16),
+                   m::Multiply(m::GetTupleElement(&gte3, m::Op()),
+ m::Broadcast(m::ConstantEffectiveScalar(&alpha)))
+ .WithOneUse()))) {
+ continue;
+ }
+
+ // The three GTEs should be the same, and these should be the only uses.
+ if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+ if (config.activation_mode() != se::dnn::kNone) {
+ continue;
+ }
+
+ if (!IsSuitableForCudnnRuntimeFusion(conv)) {
+ continue;
+ }
+
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseLeakyRelu: ", conv->ToString());
+ })) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+ config.set_activation_mode(se::dnn::kLeakyRelu);
+ TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
+ config.set_leakyrelu_alpha(alpha_f64.GetFirstElement<double>());
+ TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
+ changed = true;
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* gte = nullptr;
+ HloInstruction* conv = nullptr;
+
+ auto f32_convertible_to_f16_pat =
+ m::Op().WithElementType(F32).WithPredicate(
+ IsLosslesslyConvertibleToF16);
+ if (!MatchAndLogIfFailed(
+ instr, "f16 conv",
+ m::Convert(
+ m::GetTupleElement(
+                     &gte,
+ m::Op(&conv)
+ .WithPredicate(IsConvCustomCall)
+ .WithOperand(0, f32_convertible_to_f16_pat)
+ .WithOperand(1, f32_convertible_to_f16_pat)
+ .WithOperandIfPresent(2, f32_convertible_to_f16_pat)
+ .WithOperandIfPresent(3, f32_convertible_to_f16_pat),
+ 0)
+ .WithOneUse())
+ .WithElementType(F16),
+ VLOG_IS_ON(3),
+ m::Op().WithOperand(0, m::GetTupleElement(m::Op().WithPredicate(
+ IsConvCustomCall))))) {
+ continue;
+ }
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseConvertToF16: ", conv->ToString());
+ })) {
+ continue;
+ }
+
+ VLOG(2) << "Matched fp16 conv: " << conv->ToString();
+
+ // In fp16 convs, all operands, including `bias`, must be fp16. This is
+ // different from int8 convs, where the bias is fp32. See table of
+ // supported datatypes at
+ // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+ absl::InlinedVector<HloInstruction*, 4> new_operands;
+ for (HloInstruction* operand : conv->operands()) {
+ new_operands.push_back(
+ MakeConvertToHlo(operand, F16, &operand->metadata()));
+ }
+
+ Shape new_shape = conv->shape();
+ new_shape.mutable_tuple_shapes(0)->set_element_type(F16);
+
+ HloInstruction* new_conv = comp->AddInstruction(
+ conv->CloneWithNewOperands(new_shape, new_operands));
+ comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
+ MakeGetTupleElementHlo(new_conv, 0));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
+ changed = true;
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,
+ se::GpuComputeCapability cc) {
+ if (IsROCm(cc)) return false;
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* gte = nullptr;
+ HloInstruction* conv = nullptr;
+
+ auto conv_pattern =
+ m::Op(&conv)
+ .WithPredicate(IsConvCustomCall)
+ .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
+ .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));
+
+ PrimitiveType conv_output_ty;
+ if (MatchAndLogIfFailed(
+ instr, "s8->s8 conv",
+ m::Convert(m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(-128)),
+ m::GetTupleElement(
+                            &gte,
+ conv_pattern.WithOperandIfPresent(
+ 3, m::Op().WithPredicate(
+ IsLosslesslyConvertibleToS8)),
+ 0)
+ .WithOneUse(),
+ m::Broadcast(m::ConstantEffectiveScalar(127))))
+ .WithElementType(S8),
+ VLOG_IS_ON(3),
+ m::Convert(m::Clamp(m::Op(),
+ m::GetTupleElement(
+ m::Op().WithPredicate(IsConvCustomCall)),
+ m::Op()))
+ .WithElementType(S8))) {
+ conv_output_ty = S8;
+ } else if (MatchAndLogIfFailed(
+ instr, "s8->f32 conv",
+                       m::GetTupleElement(&gte,
+ conv_pattern.WithOperandIfPresent(
+ 3, m::Op().WithElementType(F32)),
+ 0)
+ .WithElementType(F32),
+ VLOG_IS_ON(3),
+ m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
+ .WithElementType(F32))) {
+ conv_output_ty = F32;
+ } else {
+ continue;
+ }
+ if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+ return absl::StrCat("FuseConvertToS8: ", conv->ToString());
+ })) {
+ continue;
+ }
+
+ absl::InlinedVector<HloInstruction*, 4> new_operands(
+ conv->operands().begin(), conv->operands().end());
+ new_operands[0] =
+ MakeConvertToHlo(new_operands[0], S8, &new_operands[0]->metadata());
+ new_operands[1] =
+ MakeConvertToHlo(new_operands[1], S8, &new_operands[1]->metadata());
+ // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn. See
+ // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+ if (new_operands.size() >= 4) {
+      // The side-input always matches the conv output type. We checked in the
+      // patterns above that it's losslessly-convertible to this type.
+ new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty,
+ &new_operands[3]->metadata());
+ }
+
+ Shape new_shape = conv->shape();
+ new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);
+
+ HloInstruction* new_conv = comp->AddInstruction(
+ conv->CloneWithNewOperands(new_shape, new_operands));
+ comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
+ MakeGetTupleElementHlo(new_conv, 0));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
+ changed = true;
+ }
+ return changed;
+}
+
+absl::Status CheckNoIllegalIntegerConvs(HloComputation* comp) {
+ auto is_integral_not_s8 = [](const Shape& s) {
+ return primitive_util::IsIntegralType(s.element_type()) &&
+ s.element_type() != S8;
+ };
+
+ std::vector<HloInstruction*> bad_convs;
+ for (HloInstruction* instr : comp->instructions()) {
+ if (!IsConvCustomCall(instr)) {
+ continue;
+ }
+ if (is_integral_not_s8(instr->shape().tuple_shapes(0)) ||
+ is_integral_not_s8(instr->operand(0)->shape()) ||
+ is_integral_not_s8(instr->operand(1)->shape()) ||
+ (instr->operand_count() >= 4 &&
+ is_integral_not_s8(instr->operand(3)->shape()))) {
+ bad_convs.push_back(instr);
+ }
+ }
+
+ if (bad_convs.empty()) {
+ return absl::OkStatus();
+ }
+
+ return Unimplemented(
+ R"(
+Can't lower one or more integer convolutions to idioms supported by CuDNN.
+
+CuDNN integer convolutions must have:
+
+ - s8 input and filter,
+ - f32 bias (if present),
+ - s8 or f32 output, and
+ - s8 side_input (if present) if output is s8.
+
+For each of the unsupported convs below, we weren't able to lower one of the
+operands or the output to the appropriate type.
+
+See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:
+
+https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
+https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
+
+Unsupported convs:
+%s
+
+******* Full HLO module *******
+%s
+)",
+ absl::StrJoin(bad_convs, "\n",
+ [](std::string* out, HloInstruction* instr) {
+ absl::StrAppend(out, " - ", instr->ToString());
+ }),
+ comp->parent()->ToString());
+}
+
+void VlogStats(HloModule* module) {
+ if (!VLOG_IS_ON(1)) {
+ return;
+ }
+
+ VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
+ absl::flat_hash_map<std::string, int> stats;
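+  // The two-digit key prefixes fix the sort order of the report; they are
+  // stripped (via substr(3)) before printing below.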
+ for (HloComputation* comp : module->MakeNonfusionComputations()) {
+ for (HloInstruction* instr : comp->instructions()) {
+ if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
+ continue;
+ }
+
+ VLOG(3) << instr->ToString();
+
+ if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
+ ++stats["01 non-fused forward convs"];
+ } else if (instr->custom_call_target() ==
+ kCudnnConvBiasActivationForwardCallTarget) {
+ ++stats["02 fused forward convs"];
+ }
+
+ PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
+ PrimitiveType conv_out_ty = instr->shape().tuple_shapes(0).element_type();
+ if (conv_in_ty == F32) {
+ ++stats["10 f32 convs"];
+ } else if (conv_in_ty == F16) {
+ ++stats["11 f16 convs"];
+ } else if (conv_in_ty == S8) {
+ if (conv_out_ty == S8) {
+ ++stats["12 s8->s8 convs"];
+ } else if (conv_out_ty == F32) {
+ ++stats["13 s8->f32 convs"];
+ } else {
+ LOG(ERROR) << "Unexpected conv: " << instr->ToString();
+ }
+ }
+
+ if (instr->operand_count() > 2) {
+ ++stats["20 convs with bias"];
+ if (Match(instr->operand(2),
+ m::Broadcast(m::ConstantEffectiveScalar(0)))) {
+ ++stats["21 convs with 0 bias"];
+ }
+ }
+ if (instr->operand_count() > 3) {
+ ++stats["22 convs with side-input"];
+ }
+
+ auto gpu_config = instr->backend_config<GpuBackendConfig>();
+ if (!gpu_config.ok()) {
+ LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString();
+ continue;
+ }
+ const CudnnConvBackendConfig& config =
+ gpu_config->cudnn_conv_backend_config();
+ if (config.conv_result_scale() != 1) {
+ ++stats["30 convs with result scale"];
+ }
+ if (config.side_input_scale() != 0 && config.side_input_scale() != 1) {
+ ++stats["31 convs with side-input scale"];
+ }
+ ++stats[absl::StrCat(
+ "32 convs with activation mode ",
+ se::dnn::ActivationMode_Name(config.activation_mode()))];
+ }
+ }
+
+ std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
+ stats.end());
+ absl::c_sort(stats_sorted);
+ for (const auto& kv : stats_sorted) {
+ VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
+ absl::string_view(kv.first).substr(3));
+ }
+}
+
+} // namespace
+
+absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool any_changed = false;
+
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ bool changed = false;
+ // Rewrite FP8 convolutions and supported adjacent pointwise ops into a
+ // ForwardGraph Custom Call.
+ if (!IsROCm(compute_capability_)) {
+ auto cc = std::get<se::CudaComputeCapability>(compute_capability_);
+ TF_ASSIGN_OR_RETURN(
+ changed, F8GraphConv(comp, cc, dnn_version_, toolkit_version_));
+ if (changed) {
+ return changed;
+ }
+ }
+ // Fuse "inside out" starting with the operations closest to the conv.
+ TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp));
+ any_changed |= changed;
+
+ TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
+ any_changed |= changed;
+
+ // s8 convs' bias and side-input appear before conversion to s8.
+ //
+ // Run FuseBiasOrSideInput twice, so we get both the bias and the side
+ // input, if both are present.
+ TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
+ any_changed |= changed;
+
+ // Relu might appear before or after convert-to-f16/s8, so we check in both
+ // cases.
+ TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
+ any_changed |= changed;
+
+ TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
+ any_changed |= changed;
+
+ TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_));
+ any_changed |= changed;
+
+ // f16 convs' bias+side-input can appear before or after conversion to f16.
+ TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
+ any_changed |= changed;
+
+ TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
+ any_changed |= changed;
+
+ // Check that we don't have any convs outputting integer types other than
+ // s8 - cudnn does not support these. They should have been transformed to
+ // int8->int8 or int8->float above.
+ TF_RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp));
+ }
+
+ VlogStats(module);
+
+ return any_changed;
+}
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h
new file mode 100644
index 0000000..5caf9c0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h
@@ -0,0 +1,135 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites custom-calls targeting cudnnConvolutionForward to
+// cudnnConvolutionBiasActivationForward by fusing operations following forward
+// convolution. This transform must run after GpuConvRewriter.
+//
+// Semantics of underlying cudnn ops:
+//
+// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#cudnnconvolutionforward
+// https://docs.nvidia.com/deeplearning/cudnn/latest/developer/misc.html#scaling-parameters
+//
+// ## Floating-point convs
+//
+// A "complete" fused floating-point conv has the form
+//
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
+//
+// which we fuse to
+//
+// cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
+//
+// You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
+// get a fused convolution. alpha1/2 must be broadcasts of scalar constants.
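+//
+// As an illustrative sketch (the shapes and names here are made up, not
+// required by the pass), the unfused HLO might look like
+//
+//   conv        = f32[1,3,3,64] convolution(input, filter), ...
+//   scaled_conv = multiply(conv, broadcast(alpha1))
+//   scaled_side = multiply(side_input, broadcast(alpha2))
+//   sum         = add(add(scaled_conv, broadcast(bias)), scaled_side)
+//   relu        = maximum(broadcast(0), sum)
+//
+// which this pass rewrites into a single
+// cudnnConvolutionBiasActivationForward custom-call.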
+//
+// f16 convs accumulate in f32. We represent this in HLO as an f32 convolution
+// whose inputs can be converted to f16 without loss of precision and whose
+// output is immediately converted to f16. A fused f16 conv must follow one of
+// the following idioms.
+//
+// 1. convert_f16(conv_f32(x_f32, w_f32)) +
+// side_input_f16 + broadcast(bias_f16)
+//
+// 2. convert_f16(conv_f32(x_f32, w_f32) +
+// side_input_f32 + broadcast(bias_f32))
+//
+// (These are not strictly mathematically equivalent, but cudnn doesn't tell us
+// which one it does, and we deem them "close enough".)
+//
+// The foo_f32 HLOs must all be losslessly-convertible to f16. Some valid
+// examples:
+//
+// - foo_f32 = convert_f32(foo_f16)
+// - foo_f32 = an f32 constant whose values all fit within f16
+// - foo_f32 = broadcast/transpose/reshape(one of the above)
+//
+// If you have a relu, it can appear before or after the convert_f16.
+//
+// Note that here `bias` must be losslessly-convertible to f16; this is
+// different from s8 convolutions, where bias is f32.
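+//
+// For example (a sketch; the concrete shapes are illustrative only), idiom 1
+// without a side input might appear as
+//
+//   x_f32    = f32[1,3,3,64] convert(x_f16)
+//   w_f32    = f32[3,3,64,64] convert(w_f16)
+//   conv     = f32[1,3,3,64] convolution(x_f32, w_f32), ...
+//   conv_f16 = f16[1,3,3,64] convert(conv)
+//   result   = f16[1,3,3,64] add(conv_f16, broadcast(bias_f16))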
+//
+// ## Integer convs
+//
+// In pure HLO, a "complete" integer conv is spelled as one of the following
+// `result`s.
+//
+// base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
+// alpha2_f32 * side_input +
+// bias_f32
+//
+// result_f32 = max(base, 0)
+// result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
+// result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
+//
+// The foo_s32 HLOs must be losslessly-convertible to s8. In the `result_s8`
+// case, side_input should be an f32 HLO that's losslessly-convertible to s8;
+// otherwise, it should be losslessly-convertible to f32.
+//
+// In the `result_s8` case where there's no bias, side-input, or alpha1, you can
+// skip the convert_f32 on conv.
+//
+// If you have an integer convolution that doesn't fit one of these idioms, this
+// pass returns an error -- cudnn will not be able to run it.
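+//
+// Example of adding this pass to a pipeline (a minimal sketch; `cc`,
+// `dnn_version`, and `toolkit_version` are assumed to come from the
+// surrounding compiler setup):
+//
+//   HloPassPipeline pipeline("conv-fusion");
+//   pipeline.AddPass<ConvRewriter>(cc);
+//   pipeline.AddPass<CudnnFusedConvRewriter>(cc, dnn_version, toolkit_version);
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));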
+class CudnnFusedConvRewriter : public HloModulePass {
+ public:
+ CudnnFusedConvRewriter(se::CudaComputeCapability cc,
+ se::dnn::VersionInfo dnn_version,
+ int32_t toolkit_version)
+ : compute_capability_(cc),
+ dnn_version_(dnn_version),
+ toolkit_version_(toolkit_version) {}
+ CudnnFusedConvRewriter(se::RocmComputeCapability cc,
+ se::dnn::VersionInfo dnn_version,
+ int32_t toolkit_version)
+ : compute_capability_(cc),
+ dnn_version_(dnn_version),
+ toolkit_version_(toolkit_version) {}
+
+ absl::string_view name() const override {
+ return "cudnn-fused-convolution-rewriter";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const se::GpuComputeCapability compute_capability_;
+ const se::dnn::VersionInfo dnn_version_;
+ const int32_t toolkit_version_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
new file mode 100644
index 0000000..5e22a6f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
@@ -0,0 +1,3169 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+
+#include <array>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <thread> // NOLINT
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/convert_mover.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/hlo_constant_folding.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_fix.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/reshape_mover.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#endif // GOOGLE_CUDA
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// TODO(b/210165681): The tests in this file are fragile to HLO op names.
+
+namespace m = match;
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+
+// TODO: Use constexpr vector once XLA is compiled with C++20.
+const auto* kf16f32f64 = new std::vector<std::string>({"f16", "f32", "f64"});
+const auto* kf16f32 = new std::vector<std::string>({"f16", "f32"});
+
+class CudnnFusedConvRewriterHloTest : public HloTestBase {
+ public:
+ bool IsCuda() {
+ return std::holds_alternative<se::CudaComputeCapability>(
+ backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .gpu_compute_capability());
+ }
+ se::CudaComputeCapability GetCudaComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability();
+ }
+ stream_executor::dnn::VersionInfo GetDnnVersion() {
+ return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
+ }
+
+ int32_t GetToolkitVersion() const {
+#if GOOGLE_CUDA
+ return CUDA_VERSION;
+#elif TENSORFLOW_USE_ROCM
+ return TF_ROCM_VERSION;
+#endif
+ return 0;
+ }
+
+ CudnnFusedConvRewriterHloTest()
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/false,
+ /*instruction_can_change_layout_func=*/{}) {}
+};
+
+class CudnnFusedConvRewriterTest : public GpuCodegenTest {
+ public:
+ bool IsCuda() {
+ return std::holds_alternative<se::CudaComputeCapability>(
+ backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .gpu_compute_capability());
+ }
+ se::CudaComputeCapability GetCudaComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability();
+ }
+ stream_executor::dnn::VersionInfo GetDnnVersion() {
+ return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
+ }
+
+ int32_t GetToolkitVersion() const {
+#if GOOGLE_CUDA
+ return CUDA_VERSION;
+#elif TENSORFLOW_USE_ROCM
+ return TF_ROCM_VERSION;
+#endif
+ return 0;
+ }
+
+ protected:
+ std::string GetOptimizedHlo(absl::string_view hlo_string) {
+ // cudnn_vectorize_convolutions transforms convolutions, making it hard to
+    // match them in this test. What's worse, the transforms it does
+    // depend on the GPU that's available! So just disable them for this
+ // function that gets the optimized HLO. When we actually run the module
+ // we'll still have this pass enabled.
+ HloModuleConfig config = GetModuleConfigForTest();
+ DebugOptions debug_opts = config.debug_options();
+ debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ config.set_debug_options(debug_opts);
+
+ auto result = backend().compiler()->RunHloPasses(
+ ParseAndReturnVerifiedModule(hlo_string, config).value(),
+ backend().default_stream_executor(), backend().memory_allocator());
+ if (!result.status().ok()) {
+ TF_EXPECT_OK(result.status())
+ << "HLO compilation failed: " << result.status();
+ return "";
+ }
+ HloPrintOptions print_opts;
+ print_opts.set_print_operand_shape(false);
+ return (*result)->ToString(print_opts);
+ }
+
+ void TestMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
+ const std::string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_THAT(optimized_hlo_string,
+ Not(HasSubstr(kCudnnConvForwardCallTarget)))
+ << optimized_hlo_string;
+ EXPECT_THAT(optimized_hlo_string,
+ HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_with_new_type));
+ DebugOptions debug_opts = module->config().debug_options();
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ module->mutable_config().set_debug_options(debug_opts);
+ EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01}))
+ << optimized_hlo_string;
+ }
+ }
+
+ void TestClamp(absl::string_view pre_hlo_string,
+ absl::string_view post_hlo_string) {
+ std::string alpha_conv_scalar, alpha_side_input_scalar;
+ std::string elementwise_type;
+
+ std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
+ EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
+ EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
+ EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01}))
+ << pre_hlo_string;
+
+ absl::StatusOr<bool> filecheck_result =
+ RunFileCheck(optimized_hlo_string, post_hlo_string);
+ ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
+ EXPECT_TRUE(*filecheck_result);
+ }
+
+ void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
+ const std::string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+ SCOPED_TRACE(optimized_hlo_string);
+ EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
+ EXPECT_THAT(optimized_hlo_string,
+ Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
+ }
+ }
+
+ void TestF8(std::string pre_hlo_string, std::string custom_call_string,
+ std::string serialized_graph_string) {
+ if (!IsCuda()) return;
+ if (GetCudaComputeCapability().IsAtLeast(
+ se::CudaComputeCapability::HOPPER)) {
+      // On Hopper and newer architectures, test numerical correctness and
+      // verify the HLO of the Custom Call, including its operand and result
+      // layouts and the serialized graph, based on the full compiler
+      // pipeline.
+ std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
+ EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
+ EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
+ EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.15, 0.15}))
+ << pre_hlo_string;
+
+ absl::StatusOr<bool> filecheck_result =
+ RunFileCheck(optimized_hlo_string, custom_call_string);
+ ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
+ EXPECT_TRUE(*filecheck_result);
+
+ filecheck_result =
+ RunFileCheck(optimized_hlo_string, serialized_graph_string);
+ ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
+ EXPECT_TRUE(*filecheck_result);
+ } else {
+      // On older architectures, disregard layout information and only verify
+      // the basic configuration of the convolution Custom Call (the number of
+      // operands and the window_size and serialized graph attributes), based
+      // on just the ConvRewriter and CudnnFusedConvRewriter passes.
+ std::string::size_type p0 = custom_call_string.find(':');
+ std::string::size_type p1 = custom_call_string.find("custom-call");
+ custom_call_string.erase(p0 + 1, p1 - p0 - 2);
+ p0 = custom_call_string.find(", dim_labels");
+ custom_call_string.erase(p0);
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(pre_hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool changed,
+ RunHloPass(ConvRewriter(GetCudaComputeCapability()), module.get()));
+ EXPECT_TRUE(changed);
+ RunAndFilecheckHloRewrite(
+ module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+ CudnnFusedConvRewriter(
+ se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
+ GetDnnVersion(), GetToolkitVersion()),
+ custom_call_string);
+ RunAndFilecheckHloRewrite(
+ module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+ CudnnFusedConvRewriter(
+ se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
+ GetDnnVersion(), GetToolkitVersion()),
+ serialized_graph_string);
+ }
+ }
+
+ void TestF8Parameterized(std::string template_pre_hlo_string,
+ std::string template_custom_call_string,
+ std::string template_serialized_graph_string) {
+ std::array<absl::string_view, 2> types = {"f8e4m3fn", "f8e5m2"};
+ std::array<absl::string_view, 2> clamp_lower = {"-448.", "-57344."};
+ std::array<absl::string_view, 2> clamp_upper = {"448.", "57344."};
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ for (int i = 0; i < 2; ++i) {
+ replacements["<<InputType>>"] = types[i];
+ for (int j = 0; j < 2; ++j) {
+ replacements["<<FilterType>>"] = types[j];
+ for (int k = 0; k < 2; ++k) {
+ replacements["<<OutputType>>"] = types[k];
+ replacements["<<ClampLower>>"] = clamp_lower[k];
+ replacements["<<ClampUpper>>"] = clamp_upper[k];
+ TestF8(absl::StrReplaceAll(template_pre_hlo_string, replacements),
+ absl::StrReplaceAll(template_custom_call_string, replacements),
+ absl::StrReplaceAll(template_serialized_graph_string,
+ replacements));
+ }
+ }
+ }
+ }
+};
+
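+// MAYBE_SKIP_TEST(CAUSE) skips a test when the current build cannot run the
+// named fusion: CUDA builds older than 12.0 / cuDNN 8.9 skip the "F8" tests,
+// and ROCm builds skip every cause.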
+#if GOOGLE_CUDA
+#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900)
+#define MAYBE_SKIP_TEST(CAUSE) \
+ do { \
+ if (absl::string_view(CAUSE) == "F8") \
+ GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; \
+ } while (0)
+#else
+#define MAYBE_SKIP_TEST(CAUSE)
+#endif
+#else
+#define MAYBE_SKIP_TEST(CAUSE) \
+ do { \
+ GTEST_SKIP() << "ROCm does not support " CAUSE " fusion"; \
+ } while (0)
+#endif
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) {
+ // max(0, conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseReluWithDepthwiseConv) {
+ // max(0, conv(x, w));
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,1,17] parameter(1)
+
+ conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
+ ROOT relu = TYPE[1,17,9,9] maximum(zeros, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBias) {
+ // max(0, conv(x, w) + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, Test3D) {
+ // max(0, conv(x, w) + bias);
+ std::string body = R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,5,7,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,5,7,64] parameter(0)
+ filter = TYPE[3,3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,5,7,64] convolution(input, filter), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,5,7,64] broadcast(bias), dimensions={4}
+ add1 = TYPE[1,3,5,7,64] add(conv, broadcasted_bias)
+ )";
+
+ std::string relu = R"(
+ ROOT relu = TYPE[1,3,5,7,64] maximum(zeros, add1)
+ })";
+
+ std::string elu = R"(
+ cmp = pred[1,3,5,7,64] compare(add1, zeros), direction=GT
+ expm1 = TYPE[1,3,5,7,64] exponential-minus-one(add1)
+ ROOT elu = TYPE[1,3,5,7,64] select(cmp, add1, expm1)
+ })";
+
+ TestMatchWithAllTypes(body + relu);
+ if (!IsCuda()) TestMatchWithAllTypes(body + elu);
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBiasMultiCall) {
+ // max(0, conv(x, w) + bias);
+ std::string code = R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,<<<format>>>,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,<<<format>>>,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,<<<format>>>,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,<<<format>>>,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,<<<format>>>,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,<<<format>>>,64] maximum(zeros, add1)
+ })";
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<<format>>>"] = "3,3";
+ TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
+ replacements["<<<format>>>"] = "5,5";
+ TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
+ replacements["<<<format>>>"] = "3,3";
+ TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBiasNoRelu) {
+ // conv(x, w) + bias;
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ ROOT add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseBiasWithDepthwiseConv) {
+ // conv(x, w) + bias;
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,1,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestElu) {
+ // sum = conv(x, w) + bias
+ // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
+ expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
+ ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) {
+ // sum = conv(x, w) + bias
+ // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,1,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
+ expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
+ ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestRelu6) {
+ if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
+ se::CudaComputeCapability::AMPERE)) {
+    GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended on "
+                    "Nvidia Ampere+ GPUs.";
+ }
+ // sum = conv(x, w) + bias
+ // clamp(0, sum, 6);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ six = TYPE[] constant(6)
+ sixes = TYPE[1,3,3,64] broadcast(six), dimensions={}
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu6 = TYPE[1,3,3,64] clamp(zeros, sum, sixes)
+ })");
+}
+
+// At time of writing, cudnn runtime fusion cannot handle f16 convs with an odd
+// number of input/output channels. Check that we don't try to run this conv
+// with runtime fusion (or, if we do, that it works!).
+TEST_F(CudnnFusedConvRewriterTest, TestRelu6OddChannels) {
+ if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
+ se::CudaComputeCapability::AMPERE)) {
+    GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended on "
+                    "Nvidia Ampere+ GPUs.";
+ }
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+ ENTRY Test {
+ zeros = TYPE[1,384,1024,32] broadcast(TYPE[] constant(0)), dimensions={}
+ sixes = TYPE[1,384,1024,32] broadcast(TYPE[] constant(6)), dimensions={}
+ input = TYPE[1,769,2049,3] parameter(0)
+ filter = TYPE[32,3,3,3] parameter(1)
+ bias = TYPE[32] parameter(2)
+ conv = TYPE[1,384,1024,32] convolution(input, filter), window={size=3x3 stride=2x2}, dim_labels=b01f_o01i->b01f
+ broadcasted_bias = TYPE[1,384,1024,32] broadcast(bias), dimensions={3}
+ sum = add(conv, broadcasted_bias)
+ ROOT relu6 = clamp(zeros, sum, sixes)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) {
+ if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
+ se::CudaComputeCapability::AMPERE)) {
+    GTEST_SKIP()
+        << "Conv-Bias-LeakyRelu fusion is supported and recommended on "
+           "Nvidia Ampere+ GPUs.";
+ }
+ // sum = conv(x, w) + bias
+ // select(compare(sum, 0, GT), sum, multiply(sum, alpha));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha = TYPE[] constant(0.2)
+ alphas = TYPE[1,3,3,64] broadcast(alpha), dimensions={}
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
+ mul = TYPE[1,3,3,64] multiply(sum, alphas)
+ ROOT elu = TYPE[1,3,3,64] select(cmp, sum, mul)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) {
+ // max(0, conv(x, w) + side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseSideInputWithDepthwiseConv) {
+ // max(0, conv(x, w) + side_input);
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,1,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+ add1 = TYPE[1,3,3,64] add(conv, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) {
+ // max(0, conv(x, w) + side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) {
+ // max(0, 0.999994934 * conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseScaledDepthwiseConv) {
+ // max(0, 0.999994934 * conv(x, w));
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,1,17] parameter(1)
+
+ conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
+ alpha_conv = TYPE[1,17,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = TYPE[1,17,9,9] multiply(conv, alpha_conv)
+ ROOT relu = TYPE[1,17,9,9] maximum(zeros, scaled_conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) {
+ EXPECT_TRUE(RunAndCompare(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = f32[] constant(inf)
+ zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = f32[] constant(0.999994934)
+
+ input = f32[1,17,9,9] parameter(0)
+ filter = f32[3,3,17,32] parameter(1)
+
+ conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv)
+ ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv)
+ })",
+ ErrorSpec{0.01}));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvAndScaledSideInput) {
+ // max(0, conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseDepthwiseConvWithScaledSideInput) {
+ // max(0, conv(x, w) + 0.899994934 * side_input);
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,1,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) {
+ // max(0.1, conv(x, w)) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ point_one = TYPE[] constant(0.1)
+ point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
+ const char* kHloString = R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = f32[] constant(0)
+ zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = f32[1,17,9,9] parameter(0)
+ filter = f32[3,3,17,32] parameter(1)
+
+ conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
+ ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
+ })";
+
+ const std::string optimized_hlo_string =
+ backend()
+ .compiler()
+ ->RunHloPasses(
+ ParseAndReturnVerifiedModule(kHloString, GetModuleConfigForTest())
+ .value(),
+ backend().default_stream_executor(), backend().memory_allocator())
+ .value()
+ ->ToString();
+ EXPECT_THAT(optimized_hlo_string,
+ ::testing::ContainsRegex(
+ R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
+ // The convolution below would crash if feature_count is not preserved.
+ const char* kHloString = R"(
+ HloModule jaxpr_computation__6.19
+
+ primitive_computation__1.4 {
+ parameter.5 = f32[] parameter(0)
+ parameter.6 = f32[] parameter(1)
+ ROOT add.7 = f32[] add(parameter.5, parameter.6)
+ }
+
+ ENTRY jaxpr_computation__7.8 {
+ parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1)
+ parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0)
+ convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53
+ constant.13 = f32[] constant(0)
+ broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={}
+ maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14)
+ ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4
+ }
+ )";
+ EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01}));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ ROOT conv_a = f8e4m3fn[1,16,6,6] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f8e4m3fn]conv();"
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_f32 = f32[1,128,6,6] convert(input)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ z_scale = f32[] parameter(2)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+ ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE_UID:[0-9]+]]:[f8e4m3fn]scale([[CONV_UID]]);"
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledOutputF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_f32 = f32[1,128,6,6] convert(input)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ z_scale = f32[] parameter(2)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ conv_a_scaled = f32[1,16,6,6] divide(conv_a, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+ ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]invscale([[CONV_UID]]);"
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8Parameterized) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8Parameterized(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = <<InputType>>[1,128,6,6] parameter(0)
+ filter = <<FilterType>>[3,3,128,16] parameter(1)
+ input_scale = f32[] parameter(2)
+ input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+ filter_scale = f32[] parameter(3)
+ filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+ input_f32 = f32[1,128,6,6] convert(input)
+ input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+ z_scale = f32[] parameter(4)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+ c1 = f32[] constant(<<ClampLower>>)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(<<ClampUpper>>)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+ ROOT conv_f8 = <<OutputType>>[1,16,6,6] convert(conv_a_clamped)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (<<OutputType>>[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[<<OutputType>>]scale([[SCALE1_UID]]);"
+ )");
+}
+
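+// A bias add between the input/filter scalings and the output scaling should
+// be fused as an add node.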
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_scale = f32[] parameter(2)
+ input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+ filter_scale = f32[] parameter(3)
+ filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+ input_f32 = f32[1,128,6,6] convert(input)
+ input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+ bias = f32[1,16,6,6] parameter(4)
+ z_scale = f32[] parameter(5)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ conv_a_bias = f32[1,16,6,6] add(conv_a, bias)
+ conv_a_scaled = f32[1,16,6,6] multiply(conv_a_bias, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+ ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]], /*index=5*/[[OPERAND5:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[ADD_UID:[0-9]+]]:[f32]add([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[ADD_UID]]);"
+ )");
+}
+
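+// A ReLU between the convolution and the output scaling should be fused as a
+// relu node.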
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_f32 = f32[1,128,6,6] convert(input)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ z_scale = f32[] parameter(2)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ c = f32[] constant(0)
+ c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
+ relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
+ ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[RELU_UID:[0-9]+]]:[f32]relu([[CONV_UID]]);[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);"
+ )");
+}
+
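+// The abs-max reduction over the conv result should be fused as an amax node,
+// returned alongside the scaled FP8 output.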
+TEST_F(CudnnFusedConvRewriterTest, TestConvAmaxF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] maximum(a, b)
+ }
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_scale = f32[] parameter(2)
+ input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+ filter_scale = f32[] parameter(3)
+ filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+ input_f32 = f32[1,128,6,6] convert(input)
+ input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+ conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ z_scale = f32[] parameter(4)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+ conv_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+ abs_conv_a = f32[1,16,6,6] abs(conv_a)
+ c0 = f32[] constant(-inf)
+ amax = f32[] reduce(abs_conv_a, c0), dimensions={0,1,2,3}, to_apply=apply
+ ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(conv_a_clamped_f8, amax)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[SCALE1_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[SCALE1_UID]]);"
+ )");
+}
+
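+// The abs-max reduction over the ReLU output should be fused as an amax node.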
+TEST_F(CudnnFusedConvRewriterTest, TestConvReluAmaxF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] maximum(a, b)
+ }
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_scale = f32[] parameter(2)
+ input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+ filter_scale = f32[] parameter(3)
+ filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+ input_f32 = f32[1,128,6,6] convert(input)
+ input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+ c = f32[] constant(0)
+ c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
+ z_scale = f32[] parameter(4)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
+ relu_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
+ abs_relu_a = f32[1,16,6,6] abs(relu_a)
+ c0 = f32[] constant(-inf)
+ amax = f32[] reduce(abs_relu_a, c0), dimensions={0,1,2,3}, to_apply=apply
+ ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(relu_a_clamped_f8, amax)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[RELU_UID:[0-9]+]]:[f32]relu([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[RELU_UID]]);"
+ )");
+}
+
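+// When the conv result has multiple scaled users, the output scalings are not
+// fused; only the convolution itself is lowered.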
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputMultipleUsersF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_f32 = f32[1,128,6,6] convert(input)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ z_scale0 = f32[] parameter(2)
+ z_scale0_bcast = f32[1,16,6,6] broadcast(z_scale0), dimensions={}
+ z_scale1 = f32[] parameter(3)
+ z_scale1_bcast = f32[1,16,6,6] broadcast(z_scale1), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ conv_a_scaled0 = f32[1,16,6,6] multiply(conv_a, z_scale0_bcast)
+ conv_a_scaled1 = f32[1,16,6,6] multiply(conv_a, z_scale1_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ conv_a_clamped0 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled0, c2_bcast)
+ conv_a_clamped1 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled1, c2_bcast)
+ conv_a_convert0 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped0)
+ conv_a_convert1 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped1)
+ ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f8e4m3fn[1,16,6,6]) tuple(conv_a_convert0, conv_a_convert1)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
+ )");
+}
+
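+// An unsupported user of the conv result (a cosine here) prevents the output
+// scaling from being fused.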
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputUnsupportedUserF8) {
+ MAYBE_SKIP_TEST("F8");
+ TestF8(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f8e4m3fn[1,128,6,6] parameter(0)
+ filter = f8e4m3fn[3,3,128,16] parameter(1)
+ input_f32 = f32[1,128,6,6] convert(input)
+ filter_f32 = f32[3,3,128,16] convert(filter)
+ z_scale = f32[] parameter(2)
+ z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+ conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ conv_a_cos = f32[1,16,6,6] cosine(conv_a)
+ conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+ conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+ conv_a_convert = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+ ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[1,16,6,6]) tuple(conv_a_convert, conv_a_cos)
+
+ })",
+ // custom_call
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )",
+ // serialized_graph
+ R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) {
+ MAYBE_SKIP_TEST("I8");
+  // clamp(conv(x, w)) for int8_t
+ TestClamp(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = s8[] constant(0)
+ zeros = s8[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = s8[1,17,9,9] parameter(0)
+ filter = s8[3,3,17,32] parameter(1)
+
+ inputs32 = s32[1,17,9,9] convert(input)
+ filters32 = s32[3,3,17,32] convert(filter)
+
+ conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+
+ lower = s32[] constant(-128)
+ lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
+ upper = s32[] constant(127)
+ uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
+
+ clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)
+
+ ROOT convert = s8[1,32,9,9] convert(clamp)
+ })",
+ // post_hlo
+ R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (s8[1,9,9,32]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[fusion_2_1:%[^ ]+]], [[fusion_1_2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
+ )");
+}
+
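+// conv(s8, s8) with an f32 result should lower to a forward conv custom call
+// returning f32.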
+TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s8[1,17,9,9] parameter(0)
+ filter = s8[3,3,17,32] parameter(1)
+
+ inputs32 = s32[1,17,9,9] convert(input)
+ filters32 = s32[3,3,17,32] convert(filter)
+
+ conv = s32[1,32,9,9] convolution(inputs32, filters32),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+
+ ROOT convert = f32[1,32,9,9] convert(conv)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvForwardCallTarget}), 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+ filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+ bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+ side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))
+
+ conv = s32[1,32,9,9] convolution(input, filter),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ conv_f32 = f32[1,32,9,9] convert(conv)
+ ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
+ add(add(conv_f32, bias), side_input),
+ f32[1,32,9,9] broadcast(f32[] constant(127))))
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s that may have been added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1),
+ m::Parameter(2), m::Parameter(3)),
+ 0)
+ .WithShape(S8, {1, 32, 9, 9})));
+}
+
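+// A ReLU applied after the convert to s8 should still be fused as the conv's
+// activation.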
+TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+ filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+
+ conv = s32[1,32,9,9] convolution(input, filter),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
+ conv,
+ s32[1,32,9,9] broadcast(s32[] constant(127))))
+ zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
+ ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s that may have been added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), //
+ m::Parameter(1), //
+ m::Broadcast(
+ m::ConstantEffectiveScalar(0).WithElementType(F32))),
+ 0)
+ .WithShape(S8, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s8[1,17,9,9] parameter(0)
+ filter = s8[3,3,17,32] parameter(1)
+ bias = f32[32] parameter(2)
+ bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+ side_input_f32 = f32[1,32,9,9] parameter(3)
+
+ inputs32 = s32[1,17,9,9] convert(input)
+ filters32 = s32[3,3,17,32] convert(filter)
+
+ conv = s32[1,32,9,9] convolution(inputs32, filters32),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ conv_f32 = f32[1,32,9,9] convert(conv)
+ sum1 = add(conv_f32, bias_broadcast)
+ ROOT sum2 = add(sum1, side_input_f32)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s that may have been added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1),
+ m::Parameter(2), m::Parameter(3)),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+}
+
+// The ReshapeMover pass changes
+// reshape(side_input) * alpha -->
+// reshape(side_input * alpha).
+// Make sure we can pattern-match this.
+TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+ filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+ bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+ side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
+ side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))
+
+ conv = s32[1,32,9,9] convolution(input, filter),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
+ add(add(f32[1,32,9,9] convert(conv), bias), side_input),
+ f32[1,32,9,9] broadcast(f32[] constant(127))))
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s and reshapes that may have been added to the graph.
+ HloPassFix<HloPassPipeline> simplify("simplify");
+ simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
+ simplify.AddPass<ReshapeMover>();
+ simplify.AddPass<ConvertMover>();
+ TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), //
+ m::Parameter(1), //
+ m::Parameter(2), //
+ m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
+ 0)
+ .WithShape(S8, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.conv_result_scale(), 1);
+ EXPECT_EQ(config.side_input_scale(), 0.25);
+}
+
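+// A constant multiplier on the conv result should be fused as
+// conv_result_scale.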
+TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s8[1,17,9,9] parameter(0)
+ filter = s8[3,3,17,32] parameter(1)
+ inputs32 = s32[1,17,9,9] convert(input)
+ filters32 = s32[3,3,17,32] convert(filter)
+ alpha = f32[] constant(42)
+ alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}
+
+ conv = s32[1,32,9,9] convolution(inputs32, filters32),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ convert = f32[1,32,9,9] convert(conv)
+ ROOT root = multiply(convert, alpha_broadcast)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.conv_result_scale(), 42);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[32] parameter(2)
+ bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+ zero = f32[] constant(0)
+ zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias_broadcast)
+ ROOT relu = maximum(sum, zeros)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+ zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias)
+ relu = maximum(sum, zeros)
+ not_relu = minimum(sum, zeros)
+ ROOT root = tuple(relu, not_relu)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::MaximumAnyOrder(
+ m::Broadcast(m::ConstantEffectiveScalar(0)),
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})),
+ m::Minimum())));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f16[1,16,9,9] parameter(0)
+ filters = f16[3,3,16,32] parameter(1)
+ bias = f16[32] parameter(2)
+ bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
+ zero = f16[] constant(0)
+ zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
+ conv = f16[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias_broadcast)
+ cmp = compare(sum, zeros), direction=GT
+ expm1 = exponential-minus-one(sum)
+ ROOT elu = select(cmp, sum, expm1)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ DebugOptions debug_opts = m->config().debug_options();
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ m->mutable_config().set_debug_options(debug_opts);
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  // Elu fusion is only enabled on Ampere+.
+ CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kElu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f16[1,16,9,9] parameter(0)
+ filters = f16[3,3,16,32] parameter(1)
+ bias = f16[32] parameter(2)
+ bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
+ zero = f16[] constant(0)
+ zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
+ conv = f16[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias_broadcast)
+ cmp = compare(sum, zeros), direction=GT
+ expm1 = exponential-minus-one(sum)
+ elu = select(cmp, sum, expm1)
+ not_elu = minimum(sum, zeros)
+ ROOT root = tuple(elu, not_elu)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ DebugOptions debug_opts = m->config().debug_options();
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ m->mutable_config().set_debug_options(debug_opts);
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ auto gte_pattern =
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9});
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::Select(m::Compare(gte_pattern,
+ m::Broadcast(m::ConstantEffectiveScalar(0)))
+ .WithComparisonDirection(ComparisonDirection::kGt),
+ gte_pattern,
+ m::Op()
+ .WithPredicate(HloPredicateIsOp<HloOpcode::kExpm1>)
+ .WithOperand(0, gte_pattern)),
+ m::Minimum())));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) {
+ const std::string module_str = R"(
+ HloModule Test
+ ENTRY Test {
+ inputs = f16[1,18,9,9] parameter(0)
+ filters = f16[3,3,18,32] parameter(1)
+ bias = f16[32] parameter(2)
+ bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
+ zero = f16[] constant(0)
+ zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
+ sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
+ conv = f16[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias_broadcast)
+ ROOT relu = clamp(zeros, sum, sixes)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ DebugOptions debug_opts = m->config().debug_options();
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ m->mutable_config().set_debug_options(debug_opts);
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ // relu6 fusion is only enabled on Ampere+.
+ CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kRelu6);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) {
+ const std::string module_str = R"(
+ HloModule Test
+ ENTRY Test {
+ inputs = f16[1,18,9,9] parameter(0)
+ filters = f16[3,3,18,32] parameter(1)
+ bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+ zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
+ sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
+ conv = f16[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias)
+ relu = clamp(zeros, sum, sixes)
+ not_relu = minimum(sum, zeros)
+ ROOT root = tuple(relu, not_relu)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ DebugOptions debug_opts = m->config().debug_options();
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ m->mutable_config().set_debug_options(debug_opts);
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9}),
+ m::Broadcast(m::ConstantEffectiveScalar(6))),
+ m::Minimum())));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) {
+ const std::string module_str = R"(
+ HloModule Test
+ ENTRY Test {
+ inputs = f16[1,16,9,9] parameter(0)
+ filters = f16[3,3,16,32] parameter(1)
+ bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+ zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
+ alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
+ conv = f16[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias)
+ cmp = compare(sum, zeros), direction=GT
+ mul = multiply(sum, alphas)
+ ROOT leaky_relu = select(cmp, sum, mul)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ DebugOptions debug_opts = m->config().debug_options();
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ m->mutable_config().set_debug_options(debug_opts);
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ // Leaky-relu fusion is only enabled on Ampere+.
+ CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kLeakyRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) {
+ const std::string module_str = R"(
+ HloModule Test
+ ENTRY Test {
+ inputs = f16[1,16,9,9] parameter(0)
+ filters = f16[3,3,16,32] parameter(1)
+ bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+ zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
+ alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
+ conv = f16[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, bias)
+ cmp = compare(sum, zeros), direction=GT
+ mul = multiply(sum, alphas)
+ leaky_relu = select(cmp, sum, mul)
+ not_leaky_relu = minimum(sum, zeros)
+ ROOT root = tuple(leaky_relu, not_leaky_relu)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ DebugOptions debug_opts = m->config().debug_options();
+ debug_opts.set_xla_gpu_use_runtime_fusion(true);
+ m->mutable_config().set_debug_options(debug_opts);
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ auto gte_pattern =
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9});
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::Select(m::Compare(gte_pattern,
+ m::Broadcast(m::ConstantEffectiveScalar(0)))
+ .WithComparisonDirection(ComparisonDirection::kGt)
+ .WithOneUse(),
+ gte_pattern,
+ m::Multiply(gte_pattern,
+ m::Broadcast(m::ConstantEffectiveScalar()))),
+ m::Minimum())));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+ alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(multiply(alpha, conv), bias)
+ ROOT root = tuple(conv, sum)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv1;
+ const HloInstruction* conv2;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(m::CustomCall(&conv1), 0),
+ m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
+ m::MultiplyAnyOrder(
+ m::Broadcast(m::Parameter(3)),
+ m::GetTupleElement(m::CustomCall(&conv2), 0))))));
+ EXPECT_EQ(conv1, conv2);
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv1->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.conv_result_scale(), 1);
+ EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = tuple(conv, add(conv, bias))
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv1;
+ const HloInstruction* conv2;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(m::CustomCall(&conv1), 0),
+ m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
+ m::GetTupleElement(m::CustomCall(&conv2), 0)))));
+ EXPECT_EQ(conv1, conv2);
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv1->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.conv_result_scale(), 1);
+ EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ side_input = f32[1,32,9,9] parameter(2)
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
+ ROOT root = add(relu, side_input)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::AddAnyOrder(
+ m::Parameter(2),
+ m::GetTupleElement(
+ m::CustomCall(&conv, m::Parameter(0), m::Parameter(1),
+ m::Broadcast(m::ConstantEffectiveScalar(0))),
+ 0))));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.conv_result_scale(), 1);
+ EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
+ ROOT root = add(relu, bias)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::AddAnyOrder(
+ m::Broadcast(m::Parameter(2)),
+ m::GetTupleElement(m::CustomCall(
+ &conv, m::Parameter(0), m::Parameter(1),
+ m::Broadcast(m::ConstantEffectiveScalar(0)))))));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.conv_result_scale(), 1);
+ EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ side_input = f32[1,32,9,9] parameter(2)
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = tuple(conv, add(conv, side_input))
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv1;
+ const HloInstruction* conv2;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(m::CustomCall(&conv1), 0),
+ m::AddAnyOrder(m::Parameter(2),
+ m::GetTupleElement(m::CustomCall(&conv2), 0)))));
+ EXPECT_EQ(conv1, conv2);
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv1->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.conv_result_scale(), 1);
+ EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
+ filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv1;
+ const HloInstruction* conv2;
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(m::CustomCall(&conv1), 0),
+ m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
+ EXPECT_EQ(conv1, conv2);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+ filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ conv_s8 = s8[1,32,9,9] convert(clamp(
+ f32[1,32,9,9] broadcast(f32[] constant(-128)),
+ conv,
+ f32[1,32,9,9] broadcast(f32[] constant(127))))
+ ROOT root = tuple(conv, conv_s8)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv1;
+ const HloInstruction* conv2;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(m::CustomCall(&conv1), 0),
+ m::Convert(m::Clamp(m::Op(), //
+ m::GetTupleElement(m::CustomCall(&conv2), 0),
+ m::Op())))));
+ EXPECT_EQ(conv1, conv2);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string_view module_str = R"(
+ HloModule Test
+
+ ENTRY test_entry {
+ inputs = s8[1, 17, 9, 9] parameter(0)
+ filters = s8[3, 3, 17, 32] parameter(1)
+ mult_op = f32[1, 32, 9, 9] parameter(2)
+ conv = s32[1, 32, 9, 9] convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+ ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+ SCOPED_TRACE(m->ToString());
+ HloInstruction* conv1 = nullptr;
+  // Checks that the convert wrapping conv inside the multiply was removed.
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
+ m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string_view module_str = R"(
+ HloModule Test
+
+ ENTRY test_entry {
+ inputs = s8[1, 17, 9, 9] parameter(0)
+ filters = s8[3, 3, 17, 32] parameter(1)
+ mult_op = f32[1, 32, 9, 9] parameter(2)
+ conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+ ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+ SCOPED_TRACE(m->ToString());
+ HloInstruction* conv1 = nullptr;
+  // Checks that the convert wrapping conv inside the multiply was removed.
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
+ m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string_view module_str = R"(
+ HloModule Test
+
+ ENTRY test_entry {
+ inputs = f32[1, 17, 9, 9] parameter(0)
+ filters = f32[3, 3, 17, 32] parameter(1)
+ mult_op = s8[1, 32, 9, 9] parameter(2)
+ conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+ ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+ SCOPED_TRACE(m->ToString());
+ HloInstruction* conv1 = nullptr;
+  // Checks that the convert wrapping conv inside the multiply was removed.
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
+ m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDueToMultipleUsers) {
+ const std::string_view module_str = R"(
+ HloModule Test
+
+ ENTRY test_entry {
+ inputs = f32[1, 17, 9, 9] parameter(0)
+ filters = f32[3, 3, 17, 32] parameter(1)
+ mult_op = s8[1, 32, 9, 9] parameter(2)
+ sub_op = s8[1, 32, 9, 9] parameter(3)
+ conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+ another = subtract(s8[1, 32, 9, 9] convert(conv), sub_op)
+ ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+ SCOPED_TRACE(m->ToString());
+ HloInstruction* conv1 = nullptr;
+  // Checks that the convert around conv is NOT removed, since conv has another user.
+ ASSERT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Multiply(
+ m::Convert(m::GetTupleElement(m::CustomCall(&conv1))),
+ m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[32] parameter(2)
+ bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = add(conv, bias_broadcast)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ side_input = f32[1,32,9,9] parameter(2)
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = add(conv, side_input)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1),
+ m::Broadcast(m::ConstantEffectiveScalar(0))
+ .WithShape(F32, {32}),
+ m::Parameter(2)),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ side_input = f32[1,32,9,9] parameter(2)
+ side_input_scale = f32[] constant(42)
+ side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
+ side_input_product = multiply(side_input, side_input_scale_broadcast)
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = add(conv, side_input_product)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1),
+ m::Broadcast(m::ConstantEffectiveScalar(0))
+ .WithShape(F32, {32}),
+ m::Parameter(2)),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 42);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[32] parameter(2)
+ side_input = f32[1,32,9,9] parameter(3)
+ bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, side_input)
+ ROOT sum2 = add(sum, bias_broadcast)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2),
+ m::Parameter(3)),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 1);
+}
+
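+// A bias that is a broadcast of an effective scalar should be rewritten into a
+// broadcast of shape f32[32] and fused as the bias operand.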
+TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] parameter(0)
+ filters = f32[3,3,17,32] parameter(1)
+ bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
+ conv = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ ROOT root = add(conv, bias)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1),
+ m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
+ 0)
+ .WithShape(F32, {1, 32, 9, 9})));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f16[1,17,9,9] parameter(0)
+ filters = f16[3,3,17,32] parameter(1)
+ bias = f16[32] parameter(2)
+ side_input = f16[1,32,9,9] parameter(3)
+
+ inputs_f32 = f32[1,17,9,9] convert(inputs)
+ filters_f32 = f32[3,3,17,32] convert(filters)
+ bias_f32 = f32[32] convert(bias)
+ bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
+ side_input_f32 = f32[1,32,9,9] convert(side_input)
+ conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, side_input_f32)
+ sum2 = add(sum, bias_broadcast)
+ ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ // Simplify new `convert`'s that may be added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Parameter(1), m::Parameter(2),
+ m::Parameter(3)),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+// We should be able to lower this to an f16 convolution even though the
+// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
+TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
+ filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
+ bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+ side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))
+
+ conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ conv_f16 = f16[1,32,9,9] convert(conv_f32)
+ ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ // Simplify new `convert`'s that may be added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
+ .WithElementType(F16),
+ m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
+ .WithElementType(F16),
+ m::Parameter(2), m::Reshape(m::Parameter(3))),
+ 0)
+ .WithShape(F16, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f16[1,17,9,9] parameter(0)
+ filters = f16[3,3,17,32] parameter(1)
+ bias = f32[32] parameter(2)
+ side_input = f16[1,32,9,9] parameter(3)
+
+ inputs_f32 = f32[1,17,9,9] convert(inputs)
+ filters_f32 = f32[3,3,17,32] convert(filters)
+ bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+ side_input_f32 = f32[1,32,9,9] convert(side_input)
+ conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=bf01_01io->bf01
+ sum = add(conv, side_input_f32)
+ sum2 = add(sum, bias_broadcast)
+ ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ // Simplify new `convert`'s that may be added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ // fp16 convs only support fp16 biases. Because bias is fp32, it doesn't get
+ // fused in, and we get an fp32 conv.
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Convert(m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Convert(m::Parameter(0)).WithElementType(F32),
+ m::Convert(m::Parameter(1)).WithElementType(F32),
+ m::Parameter(2),
+ m::Convert(m::Parameter(3)).WithElementType(F32)),
+ 0))
+ .WithShape(F16, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f16[1,2,2,2] parameter(0)
+ filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
+ bias = f16[2] parameter(1)
+ bias_f32 = f32[2] convert(bias)
+ side_input_f32 = f32[1,2,2,2] constant({{
+ {{0.5, 0.25}, {0.125, 0.0625}},
+ {{0.5, 0.25}, {0.125, 0.0625}}
+ }})
+
+ inputs_f32 = f32[1,2,2,2] convert(inputs)
+ bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
+ conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
+ window={size=1x1}, dim_labels=bf01_01io->bf01
+ sum = add(conv, side_input_f32)
+ sum2 = add(sum, bias_broadcast)
+ ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ // Simplify new `convert`'s that may be added to the graph, and fold
+ // convert back into constants.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+ HloConstantFolding constant_folding;
+ TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), m::Constant().WithElementType(F16),
+ m::Parameter(1), m::Constant().WithElementType(F16)),
+ 0)
+ .WithShape(F16, {1, 2, 2, 2})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ inputs = f16[1,2,2,2] parameter(0)
+ filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
+ bias = f16[2] parameter(1)
+ bias_f32 = f32[2] convert(bias)
+ side_input_f32 = f32[1,2,2,2] constant({{
+ {{0.1, 0.2}, {0.3, 0.4}},
+ {{0.5, 0.6}, {0.7, 0.8}}
+ }})
+
+ inputs_f32 = f32[1,2,2,2] convert(inputs)
+ bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
+ conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
+ window={size=1x1}, dim_labels=bf01_01io->bf01
+ sum = add(conv, side_input_f32)
+ sum2 = add(sum, bias_broadcast)
+ ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ // Simplify new `convert`'s that may be added to the graph, and fold
+ // convert back into constants.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+ HloConstantFolding constant_folding;
+ TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ // This doesn't get transformed into an f16 conv because the filters param is
+ // not losslessly expressible as f16.
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Convert(m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Convert(m::Parameter(0)).WithElementType(F32),
+ m::Constant().WithElementType(F32),
+ m::Convert(m::Parameter(1)).WithElementType(F32),
+ m::Constant().WithElementType(F32)),
+ 0)
+ .WithShape(F32, {1, 2, 2, 2}))
+ .WithElementType(F16)));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) {
+ MAYBE_SKIP_TEST("I8");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s8[1,17,9,9] parameter(0)
+ filter = s8[3,3,17,32] parameter(1)
+ inputs32 = s32[1,17,9,9] convert(input)
+ filters32 = s32[3,3,17,32] convert(filter)
+
+ conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+
+ zero = s32[] constant(0)
+ zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
+ relu = maximum(conv, zeros)
+
+ lower = s32[] constant(-128)
+ lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
+ upper = s32[] constant(127)
+ uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
+
+ clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)
+
+ ROOT convert = s8[1,32,9,9] convert(clamp)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ // Simplify new `convert`'s that may be added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), //
+ m::Parameter(1), //
+ m::Broadcast(m::ConstantEffectiveScalar(0))
+ .WithShape(F32, {32})),
+ 0)
+ .WithShape(S8, {1, 32, 9, 9})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+ EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) {
+ MAYBE_SKIP_TEST("F64");
+ const std::string module_str = R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = f64[1,17,9,9] parameter(0)
+ filter = f64[3,3,17,32] parameter(1)
+ bias = f64[1,32,9,9] broadcast(f64[32] convert(f32[32] parameter(2))), dimensions={1}
+ conv = f64[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT root = f64[1,32,9,9] add(conv, bias)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ConvRewriter rewriter{GetCudaComputeCapability()};
+ TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+ CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+ GetToolkitVersion()};
+ TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+ // Simplify new `convert`'s that may be added to the graph.
+ AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+ TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
+ const HloInstruction* conv;
+ ASSERT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+ m::Parameter(0), //
+ m::Parameter(1), //
+ m::Convert(m::Parameter(2)).WithShape(F64, {32})),
+ 0)
+ .WithShape(F64, {1, 32, 9, 9})));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
+ MAYBE_SKIP_TEST("I8");
+ // clamp(max(0, conv(x, w)+bias)); for int8_t
+ TestClamp(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = f32[] constant(0)
+ zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = s8[1,3,3,64] parameter(0)
+ filter = s8[3,3,64,64] parameter(1)
+ bias = f32[64] parameter(2)
+
+ inputs32 = s32[1,3,3,64] convert(input)
+ filters32 = s32[3,3,64,64] convert(filter)
+
+ conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+ convfloat = f32[1,3,3,64] convert(conv)
+ broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
+ relu = f32[1,3,3,64] maximum(zeros, add1)
+
+ lower = f32[] constant(-128)
+ lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+ upper = f32[] constant(127)
+ uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+ clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+
+ ROOT convert = s8[1,3,3,64] convert(clamp)
+ })",
+ // post_hlo
+ R"(
+// CHECK: [[cudnn_conv_bias_activation_7_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+ )");
+}
+
+// Disabled per b/190854862 or nvbugs/3326122.
+TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) {
+ MAYBE_SKIP_TEST("I8");
+  // max(0, convert<float>(conv<int32_t>(int8_x, int8_w)) + float_bias);
+  // int8_t to float via bias.
+ TestClamp(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = f32[] constant(0)
+ zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = s8[1,3,3,64] parameter(0)
+ filter = s8[3,3,64,64] parameter(1)
+ bias = f32[64] parameter(2)
+
+ inputs32 = s32[1,3,3,64] convert(input)
+ filters32 = s32[3,3,64,64] convert(filter)
+
+ conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+ convfloat = f32[1,3,3,64] convert(conv)
+ broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
+ ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
+ })",
+ // post_hlo
+ R"(
+ ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> f32[1,3,3,64] {
+ ; CHECK: [[custom_call_0:%[^ ]+]]{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call([[input_1:%[^ ]+]], [[copy_2:%[^ ]+]]{{(\.[0-9])?}}, [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
+ ; CHECK-NEXT: ROOT [[get_tuple_element_4:%[^ ]+]]{{(\.[0-9])?}} = f32[1,3,3,64]{3,2,1,0} get-tuple-element([[custom_call_0]]{{(\.[0-9])?}}), index=0
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest,
+ TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) {
+ MAYBE_SKIP_TEST("I8");
+  // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
+  // convert<float>(int8_side_input) + bias)); for int8_t
+ TestClamp(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = f32[] constant(0)
+ zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = f32[] constant(0.999994934)
+ alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = f32[] constant(0.899994934)
+ alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = s8[1,3,3,64] parameter(0)
+ filter = s8[3,3,64,64] parameter(1)
+ side_input = s8[1,3,3,64] parameter(2)
+ bias = f32[64] parameter(3)
+
+ inputs32 = s32[1,3,3,64] convert(input)
+ filters32 = s32[3,3,64,64] convert(filter)
+ side_input_f32 = f32[1,3,3,64] convert(side_input)
+
+ conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+ convfloat = f32[1,3,3,64] convert(conv)
+ scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
+ scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
+ broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = f32[1,3,3,64] add(add1, scaled_side_input)
+ relu = f32[1,3,3,64] maximum(zeros, add2)
+
+ lower = f32[] constant(-128)
+ lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+ upper = f32[] constant(127)
+ uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+ clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+
+ ROOT convert = s8[1,3,3,64] convert(clamp)
+ })",
+ // post_hlo
+ R"(
+// CHECK: [[cudnn_conv_bias_activation_11_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest,
+ TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) {
+ MAYBE_SKIP_TEST("I8");
+  // From:
+  //   convert<int8_t>(clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
+  //   float_side_input + bias)));
+  // To:
+  //   convert<int8_t>(clamp(conv(int8_x, int8_w, float_alpha_side,
+  //   float_side_input, float_bias)));
+ TestClamp(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = f32[] constant(0)
+ zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = f32[] constant(0.999994934)
+ alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = f32[] constant(0.899994934)
+ alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = s8[1,3,3,64] parameter(0)
+ filter = s8[3,3,64,64] parameter(1)
+ side_input = f32[1,3,3,64] parameter(2)
+ bias = f32[64] parameter(3)
+
+ inputs32 = s32[1,3,3,64] convert(input)
+ filters32 = s32[3,3,64,64] convert(filter)
+
+ conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+ convfloat = f32[1,3,3,64] convert(conv)
+ scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
+ scaled_side_input = f32[1,3,3,64] multiply(side_input, alpha_side_input)
+ broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = f32[1,3,3,64] add(add1, scaled_side_input)
+ relu = f32[1,3,3,64] maximum(zeros, add2)
+
+ lower = f32[] constant(-128)
+ lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+ upper = f32[] constant(127)
+ uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+ clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+
+ ROOT convert = s8[1,3,3,64] convert(clamp)
+ })",
+ // post_hlo
+ R"(
+// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest,
+ TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) {
+ MAYBE_SKIP_TEST("I8");
+  // From:
+  //   clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
+  //   convert<float>(int8_side_input) + bias));
+  // To:
+  //   clamp(conv(int8_x, int8_w, float_alpha_side,
+  //   convert<float>(int8_side_input), float_bias));
+ TestClamp(
+ // pre_hlo
+ R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = f32[] constant(0)
+ zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = f32[] constant(0.999994934)
+ alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = f32[] constant(0.899994934)
+ alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = s8[1,3,3,64] parameter(0)
+ filter = s8[3,3,64,64] parameter(1)
+ side_input = s8[1,3,3,64] parameter(2)
+ bias = f32[64] parameter(3)
+
+ inputs32 = s32[1,3,3,64] convert(input)
+ filters32 = s32[3,3,64,64] convert(filter)
+ side_input_f32 = f32[1,3,3,64] convert(side_input)
+
+ conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+ convfloat = f32[1,3,3,64] convert(conv)
+ scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
+ scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
+ broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = f32[1,3,3,64] add(add1, scaled_side_input)
+ relu = f32[1,3,3,64] maximum(zeros, add2)
+
+ lower = f32[] constant(-128)
+ lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+ upper = f32[] constant(127)
+ uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+ ROOT clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+ })",
+ // post_hlo
+ R"(
+// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[fusion_1_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+ )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) {
+ MAYBE_SKIP_TEST("I8");
+ // Check that integer convolution without clamp to int8_t is not allowed.
+ // convert<int8_t>(custom_call<int32_t>(int32_x, int32_w,
+ // cudnnConvolutionForward))
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
+ zero = s8[] constant(0)
+ zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
+ input = s8[1,17,9,9]{3,2,1,0} parameter(0)
+ filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
+ custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
+ get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
+ convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
+ ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
+ })");
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
+ GetDnnVersion(), GetToolkitVersion())
+ .Run(m.get())
+ .ok());
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) {
+ MAYBE_SKIP_TEST("I8");
+  // Although the bias and so on are fused with the forward convolution, the
+  // rewrite is still not allowed if the output is not clamped/converted to
+  // int8_t.
+ // max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias); for
+ // int8_t
+
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
+ zero = s8[] constant(0)
+ zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
+ input = s8[1,17,9,9]{3,2,1,0} parameter(0)
+ filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
+ custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
+ get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
+ convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
+ ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
+ })");
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
+ GetDnnVersion(), GetToolkitVersion())
+ .Run(m.get())
+ .ok());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc
new file mode 100644
index 0000000..7fbd589
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc
@@ -0,0 +1,1777 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <numeric>
+#include <optional>
+#include <queue>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/types.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#endif
+
+namespace xla {
+namespace gpu {
+namespace {
+namespace m = match;
+
+// A struct that contains all the matched nodes
+// and results from pattern matching the forward graph.
+struct MatchFwdResult {
+ HloInstruction* matched_bmm_1 = nullptr;
+ HloInstruction* matched_bmm_2 = nullptr;
+ HloInstruction* matched_bias = nullptr;
+ HloInstruction* matched_scale = nullptr;
+ HloInstruction* matched_softmax_input = nullptr;
+ HloInstruction* matched_reduce_sum = nullptr;
+
+ double matched_dropout_rate = 0.0;
+ bool need_canonicalization = false;
+ bool is_training = false;
+  // We use this to keep track of whether the bias applied to bmm1 is a causal
+  // mask; cuDNN can generate the causal mask inside the attention kernel to
+  // save I/O.
+ bool is_causal_mask = false;
+ bool has_match = false;
+ std::string matched_custom_call_name;
+};
+
+// A struct that contains all the matched nodes
+// and results from pattern matching the backward graph.
+struct MatchBwdResult {
+ HloInstruction* matched_bmm_1_grad_1 = nullptr;
+ HloInstruction* matched_bmm_1_grad_2 = nullptr;
+
+ HloInstruction* matched_bmm_2_grad_1 = nullptr;
+ HloInstruction* matched_bmm_2_grad_2 = nullptr;
+ HloInstruction* matched_dbias = nullptr;
+ // We use this to keep track of all gradient bmms that need
+ // canonicalization.
+ bool bmm_1_grad_1_need_canonicalization = false;
+ bool bmm_1_grad_2_need_canonicalization = false;
+ bool bmm_2_grad_1_need_canonicalization = false;
+ bool bmm_2_grad_2_need_canonicalization = false;
+
+ bool has_match = false;
+ std::string matched_custom_call_name;
+};
+
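+// Helpers that match `pattern` either directly or wrapped in a single
+// reshape/convert/bitcast/broadcast, so that pure shape or type ops sitting
+// between the nodes of interest do not break a match.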
+template <typename Pattern>
+auto OptionalReshape(Pattern pattern) {
+ auto shared = m::SharedSubpattern(pattern);
+ return m::AnyOf<HloInstruction>(m::Reshape(shared), shared);
+}
+
+template <typename Pattern>
+auto OptionalConvert(Pattern pattern) {
+ auto shared = m::SharedSubpattern(pattern);
+ return m::AnyOf<HloInstruction>(m::Convert(shared), shared);
+}
+
+template <typename Pattern>
+auto OptionalBitcast(Pattern pattern) {
+ auto shared = m::SharedSubpattern(pattern);
+ return m::AnyOf<HloInstruction>(m::Bitcast(shared), shared);
+}
+
+template <typename Pattern>
+auto OptionalBroadcast(Pattern pattern) {
+ auto shared = m::SharedSubpattern(pattern);
+ return m::AnyOf<HloInstruction>(m::Broadcast(shared), shared);
+}
+
+bool IsBatchedMatmul(const HloInstruction* instr) {
+ if (instr->opcode() != HloOpcode::kDot) return false;
+ if (Cast<HloDotInstruction>(instr)->sparse_operands()) return false;
+ const DotDimensionNumbers& dot_dims = instr->dot_dimension_numbers();
+ bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
+ !dot_dims.rhs_batch_dimensions().empty();
+ return is_batch_dot;
+}
+
+// We need to check whether the current gemm shares a parent node with a
+// forward fMHA call because, when we match backward gemms, the only way to be
+// sure a gemm is a backward gemm is that it shares an operand with the forward
+// fMHA call (i.e. the Q, K, V, or activation tensors). We can also use this
+// function to infer whether a gemm is a forward fMHA gemm. We check this by
+// doing a BFS over each operand's users, looking for a forward fMHA custom
+// call. We continue the traversal through shape ops (bitcast, reshape,
+// transpose) until we either see a forward fMHA call or run out of shape ops
+// in the path, which means the current node will never share an operand with
+// a forward fMHA call.
+bool IsSharingOperandWithFwdMha(HloInstruction* gemm) {
+ for (int64_t i = 0; i < gemm->operands().size(); i++) {
+ std::queue<HloInstruction*> visit_list;
+ visit_list.push(gemm->mutable_operand(i));
+ while (!visit_list.empty()) {
+ HloInstruction* current_instr = visit_list.front();
+ for (auto user : current_instr->users()) {
+ switch (user->opcode()) {
+ case HloOpcode::kBitcast:
+ case HloOpcode::kReshape:
+ case HloOpcode::kTranspose: {
+ visit_list.push(user);
+ break;
+ }
+ case HloOpcode::kCustomCall: {
+ if (IsFwdCustomCallTofMHA(*user)) {
+ return true;
+ }
+ } break;
+ default:
+ break;
+ }
+ }
+ visit_list.pop();
+ }
+ }
+ return false;
+}
+// When we reach a gemm instruction, it can be one of 3 cases:
+//   1. one of the 2 gemms in a forward fmha call
+//   2. one of the 4 gemms in a backward fmha call
+//   3. a gemm of some other, unrelated layer
+// Case 3 can easily be ruled out by the pattern matcher.
+// However, cases 1 and 2 have very similar bmm-bmm structures.
+// We need to determine that we exactly match case 1, i.e. a forward gemm,
+// which has the properties below:
+//  - It is a batched matmul.
+//  - None of its operands is a forward fmha call, which would make it a
+//    backward gemm.
+//  - It does not directly or indirectly share an operand with any other fmha
+//    call, which would also make it a backward gemm.
+bool IsFirstFwdMatmul(HloInstruction* gemm) {
+ return IsBatchedMatmul(gemm) && !IsFwdCustomCallTofMHA(*gemm->operand(0)) &&
+ !IsFwdCustomCallTofMHA(*gemm->operand(1)) &&
+ !IsSharingOperandWithFwdMha(gemm);
+}
+
+bool IsScalar(const HloInstruction* instr) {
+ return ShapeUtil::IsEffectiveScalar(instr->shape());
+}
+
+bool IsReduceMax(const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kReduce &&
+ instr->to_apply()->root_instruction()->opcode() == HloOpcode::kMaximum;
+}
+
+bool IsReduceSum(const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kReduce &&
+ instr->to_apply()->root_instruction()->opcode() == HloOpcode::kAdd;
+}
+
+// Set up subpatterns for re-use.
+// Matches softmax sub-pattern ->
+// divide(exp(Subtract(producer, reduce_max(producer))),
+// broadcast(reduce_add(exp(Subtract(...))))). There might be reshape and
+// convert nodes between reduce and Subtract.
+// TODO TJ: Make this more general to any pattern that has this structure once
+// the cudnn runner supports generic cudnnOpGraphs.
+//   producer
+//   |   \
+//   |   reduce
+//   |   |
+//   |   broadcast
+//   |   /
+//   root
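+// As a rough illustration (names made up), the unfused softmax HLO this aims
+// to match looks like:
+//   max     = reduce(producer), to_apply=maximum
+//   shifted = subtract(producer, broadcast(max))
+//   exp     = exponential(shifted)
+//   sum     = reduce(exp), to_apply=add
+//   root    = divide(exp, broadcast(sum))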
+auto GetUnfusedReduceMaxSumSoftmaxPattern(
+ HloInstruction** softmax_input = nullptr,
+ HloInstruction** softmax_reduce_sum = nullptr,
+ HloInstruction** softmax_reduce_sum_bcast = nullptr) {
+ // The reduce-max part of the softmax
+ // reduce_max and subtract will always have exactly 1 user
+ // in both training and inference
+ // softmax_input should always have exactly 2 users
+ auto unfused_softmax_max_subpattern = m::SharedSubpattern(
+ m::Subtract(
+ m::Op(),
+ m::Broadcast(OptionalConvert(
+ m::Op()
+ .WithPredicate(IsReduceMax)
+ .WithOneUse()
+ .WithOperand(0, OptionalBitcast(OptionalConvert(
+ m::Op(softmax_input).WithNumUser(2)))))))
+ .WithOneUse());
+ // The reduce-add part of the softmax
+ // reduce_sum and reduce_sum_broadcast should have 2 users in training
+ // and 1 user in inference
+ auto unfused_softmax_sum_subpattern = m::SharedSubpattern(m::Divide(
+ OptionalBitcast(m::Exp(unfused_softmax_max_subpattern)),
+ m::Broadcast(
+ softmax_reduce_sum_bcast,
+ OptionalConvert(
+ m::Op(softmax_reduce_sum)
+ .WithOperand(0, OptionalBitcast(OptionalConvert(
+ m::Exp(unfused_softmax_max_subpattern))))
+ .WithPredicate(IsReduceSum)
+ .WithAtMostNumUser(2)))
+ .WithAtMostNumUser(2)));
+ return unfused_softmax_sum_subpattern;
+}
+
+std::optional<double> GetConstantValue(const HloInstruction* inst) {
+ if (!IsScalar(inst)) {
+ return std::nullopt;
+ }
+ switch (inst->shape().element_type()) {
+ case F16:
+ return static_cast<float>(inst->literal().GetFirstElement<half>());
+ case BF16:
+ return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
+ case F32:
+ return inst->literal().GetFirstElement<float>();
+ case F64:
+ return inst->literal().GetFirstElement<double>();
+ default:
+ return std::nullopt;
+ }
+}
+
+double GetDropoutRateFromHlo(HloInstruction* dropout) {
+ std::optional<double> dropout_rate_inv;
+ dropout_rate_inv = GetConstantValue(dropout);
+ if (!dropout_rate_inv.has_value()) {
+ return 0.0;
+ }
+  // In dropout, inputs are divided by (1 - rate), so the constant in the
+  // dropout node is 1 / (1 - rate); we take its reciprocal and subtract it
+  // from 1 here to get the actual dropout rate.
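+  // E.g. for rate 0.1 the dropout node carries the constant
+  // 1 / (1 - 0.1) = 1.111..., and the expression below recovers
+  // 1 - 1 / 1.111... = 0.1.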
+ return (1.0 - (1.0 / *dropout_rate_inv));
+}
+
+bool IsComputeCapabilityAndCudnnSupported(
+ stream_executor::CudaComputeCapability cc,
+ stream_executor::dnn::VersionInfo cudnn_version,
+ stream_executor::dnn::VersionInfo supported_cudnn_version) {
+ if (cc.IsAtLeastAmpere() && cudnn_version >= supported_cudnn_version) {
+ return true;
+ }
+ VLOG(2) << absl::StrFormat(
+ "CudnnFusedMHARewriter did not run. Unsupported compute "
+ "capability(%s; major should be >= 8, minor should be 0) or cudnn version"
+ "(%s; should be >= %s)",
+ cc.ToString(), cudnn_version.ToString(),
+ supported_cudnn_version.ToString());
+ return false;
+}
+
+bool IsSupportedPrimitiveType(const HloInstruction* bmm) {
+ PrimitiveType dtype = bmm->shape().element_type();
+ return dtype == BF16 || dtype == F16;
+}
+
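+// Gathers the sizes of the given dimension numbers from `dimensions`, e.g.
+// GetDimensionVector({2, 8, 128, 64}, {0, 1}) returns {2, 8}.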
+std::vector<int64_t> GetDimensionVector(absl::Span<const int64_t> dimensions,
+ absl::Span<const int64_t> dim_nums) {
+ std::vector<int64_t> vec(dim_nums.size());
+ for (int i = 0; i < dim_nums.size(); i++) {
+ vec[i] = dimensions.at(dim_nums.at(i));
+ }
+ return vec;
+}
+
+struct QKVLayout {
+ int64_t batch;
+ int64_t num_heads;
+ int64_t seqlen_q;
+ int64_t seqlen_kv;
+ int64_t hidden_dim;
+};
+
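+// Extracts the attention layout (batch, num_heads, seqlen_q, seqlen_kv,
+// hidden_dim) from bmm1 and bmm2 and checks that the two gemms agree on it.
+// Roughly, in the intended case bmm1 computes Q[b,h,s_q,d] x K[b,h,s_kv,d]
+// contracting on d, and bmm2 computes P[b,h,s_q,s_kv] x V[b,h,s_kv,d]
+// contracting on s_kv (with bmm2's operands possibly swapped when
+// canonicalization is needed).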
+absl::StatusOr<std::optional<QKVLayout>> GetQKVLayout(
+ HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization) {
+ // get layout from bmm1
+ const DotDimensionNumbers& bmm1_dnums = bmm_1->dot_dimension_numbers();
+ TF_ASSIGN_OR_RETURN(
+ std::vector<int64_t> bmm1_s_q_dims,
+ GetNonContractingDims(bmm_1->operand(0)->shape(),
+ bmm1_dnums.lhs_batch_dimensions(),
+ bmm1_dnums.lhs_contracting_dimensions()));
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<int64_t> bmm1_s_kv_dims,
+ GetNonContractingDims(bmm_1->operand(1)->shape(),
+ bmm1_dnums.rhs_batch_dimensions(),
+ bmm1_dnums.rhs_contracting_dimensions()));
+
+ std::vector<int64_t> bmm1_bh =
+ GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
+ bmm1_dnums.lhs_batch_dimensions());
+
+ std::vector<int64_t> bmm1_s_q = GetDimensionVector(
+ bmm_1->operand(0)->shape().dimensions(), bmm1_s_q_dims);
+
+ std::vector<int64_t> bmm1_s_kv = GetDimensionVector(
+ bmm_1->operand(1)->shape().dimensions(), bmm1_s_kv_dims);
+
+ std::vector<int64_t> bmm1_d =
+ GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
+ bmm1_dnums.lhs_contracting_dimensions());
+
+ TF_RET_CHECK(bmm1_bh.size() == 2);
+ TF_RET_CHECK(bmm1_s_q.size() == 1);
+ TF_RET_CHECK(bmm1_s_kv.size() == 1);
+ TF_RET_CHECK(bmm1_d.size() == 1);
+
+ // get layout from bmm2
+ const DotDimensionNumbers& bmm2_dnums = bmm_2->dot_dimension_numbers();
+ TF_ASSIGN_OR_RETURN(
+ std::vector<int64_t> bmm2_lhs_non_contracting_dims,
+ GetNonContractingDims(bmm_2->operand(0)->shape(),
+ bmm2_dnums.lhs_batch_dimensions(),
+ bmm2_dnums.lhs_contracting_dimensions()));
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<int64_t> bmm2_rhs_non_contracting_dims,
+ GetNonContractingDims(bmm_2->operand(1)->shape(),
+ bmm2_dnums.rhs_batch_dimensions(),
+ bmm2_dnums.rhs_contracting_dimensions()));
+
+ std::vector<int64_t> bmm2_bh =
+ GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+ bmm2_dnums.lhs_batch_dimensions());
+
+ std::vector<int64_t> bmm2_s_kv =
+ GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+ bmm2_dnums.lhs_contracting_dimensions());
+
+ std::vector<int64_t> bmm2_s_q =
+ need_canonicalization
+ ? GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
+ bmm2_rhs_non_contracting_dims)
+ : GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+ bmm2_lhs_non_contracting_dims);
+
+ std::vector<int64_t> bmm2_d =
+ need_canonicalization
+ ? GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+ bmm2_lhs_non_contracting_dims)
+ : GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
+ bmm2_rhs_non_contracting_dims);
+
+ TF_RET_CHECK(bmm2_bh.size() == 2);
+ TF_RET_CHECK(bmm2_s_q.size() == 1);
+ TF_RET_CHECK(bmm2_s_kv.size() == 1);
+ TF_RET_CHECK(bmm2_d.size() == 1);
+
+ // check if bhsd is correct between bmm1 and bmm2
+ if (bmm1_bh[0] != bmm2_bh[0] || bmm1_bh[1] != bmm2_bh[1] ||
+ bmm1_s_q[0] != bmm2_s_q[0] || bmm1_s_kv[0] != bmm2_s_kv[0] ||
+ bmm1_d[0] != bmm2_d[0]) {
+ return std::nullopt;
+ }
+
+ QKVLayout qkv_layout;
+ qkv_layout.batch = bmm1_bh[0];
+ qkv_layout.num_heads = bmm1_bh[1];
+ qkv_layout.seqlen_q = bmm1_s_q[0];
+ qkv_layout.seqlen_kv = bmm1_s_kv[0];
+ qkv_layout.hidden_dim = bmm1_d[0];
+ return qkv_layout;
+}
+
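+// Decides whether the matched attention can be lowered to the cuDNN flash
+// attention kernels, based on sequence lengths, hidden (head) dimension,
+// compute capability, and cuDNN version; see the version-gated checks below.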
+absl::StatusOr<bool> IsFlashAttention(
+ QKVLayout qkv_layout, bool is_training,
+ stream_executor::CudaComputeCapability cc,
+ stream_executor::dnn::VersionInfo cudnn_version) {
+ int64_t s_q = qkv_layout.seqlen_q;
+ int64_t s_kv = qkv_layout.seqlen_kv;
+ int64_t hidden_dim = qkv_layout.hidden_dim;
+ // start with most relaxed constraint
+ bool is_seqlen_supported = (!is_training || (s_q % 2 == 0 && s_kv % 2 == 0));
+ bool is_hidden_dim_supported = hidden_dim <= 128 && hidden_dim % 8 == 0;
+ bool is_flash_attention = is_seqlen_supported && is_hidden_dim_supported;
+ if (!is_flash_attention) return false;
+
+ // going backwards to check compatibility
+ if ((is_training && (s_q < 64 || s_kv < 64)) &&
+ !IsComputeCapabilityAndCudnnSupported(
+ cc, cudnn_version, stream_executor::dnn::VersionInfo(9, 0, 0))) {
+ VLOG(2) << "Flash attention training with seq < 64 not supported cuDNN < "
+ "9.0.0.";
+ return false;
+ }
+
+ if ((hidden_dim != 64 && hidden_dim != 128) &&
+ !IsComputeCapabilityAndCudnnSupported(
+ cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 6))) {
+ VLOG(2) << "Flash attention head dim != 64 or 128 not supported with cuDNN "
+ "< 8.9.6.";
+ return false;
+ }
+
+ if ((is_training && s_kv % 64 != 0) &&
+ !IsComputeCapabilityAndCudnnSupported(
+ cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 5))) {
+ VLOG(2) << "Flash attention training with seq kv % 64 != 0 not supported "
+ "with cuDNN < 8.9.5.";
+ return false;
+ }
+
+ if (!IsComputeCapabilityAndCudnnSupported(
+ cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 4))) {
+ VLOG(2) << "Require cuDNN 8.9.4 to run flash attention.";
+ return false;
+ }
+ return is_flash_attention;
+}
+
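+// Recognizes a causal (triangular) mask materialized as
+// select(compare(iota, iota), broadcast(constant), broadcast(constant)),
+// possibly hidden behind bitcasts/broadcasts; for a mask threaded through a
+// while-loop parameter, it follows the loop operand to find the real mask.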
+bool IsCausalMaskPattern(HloInstruction* mask) {
+ auto causal_mask =
+ m::Select(m::Compare(m::Iota(), m::Iota()), m::Broadcast(m::Constant()),
+ m::Broadcast(m::Constant()));
+ auto causal_mask_pattern_fwd_remat =
+ m::Broadcast(OptionalBitcast(causal_mask));
+ auto causal_mask_pattern_bwd = m::Broadcast(m::Convert(OptionalBitcast(
+ m::Minimum(m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))));
+ HloInstruction* param = nullptr;
+ HloInstruction* gte = nullptr;
+ auto causal_mask_pattern_fwd = m::Broadcast(
+ OptionalBitcast(m::GetTupleElement(>e, m::Parameter(¶m))));
+ auto causal_mask_pattern = m::AnyOf<HloInstruction>(
+ causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd,
+ causal_mask_pattern_bwd);
+ if (Match(mask, causal_mask_pattern)) {
+ if (param != nullptr && param->parent()->IsWhileBodyComputation()) {
+      // We need to look outside the while loop body to find the real mask.
+ auto while_instr = param->parent()->WhileCallInstruction();
+ auto mask_index = gte->tuple_index();
+ auto actual_mask =
+ while_instr->mutable_operand(0)->mutable_operand(mask_index);
+ auto causal_mask_pattern_fwd =
+ OptionalBitcast(m::Convert(m::MinimumAnyOrder(
+ m::Op(),
+ OptionalBitcast(m::MinimumAnyOrder(
+ m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))));
+ return Match(actual_mask, causal_mask_pattern_fwd);
+ }
+ return true;
+ }
+ return false;
+}
+
+MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result,
+ int64_t bmm2_operand_position,
+ HloInstruction* instr) {
+ // Matches the dropout-softmax subpattern.
+  // The softmax output is a divide.
+  // Dropout can take multiple forms; we capture three forms here based on
+  // heuristics.
+  // Form 1 -> softmax - mul - select(dropout) - BMM2
+ MatchFwdResult match_result = previous_result;
+ HloInstruction* softmax_reduce_sum;
+ HloInstruction* softmax_reduce_sum_bcast;
+ HloInstruction* bmm_2;
+ HloInstruction* softmax_input;
+ HloInstruction* dropout = nullptr;
+ auto dropout_softmax_pattern_form_1 = m::Select(
+ m::Op(),
+ OptionalConvert(m::MultiplyAnyOrder(
+ OptionalBitcast(OptionalReshape(
+ OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
+ &softmax_input, &softmax_reduce_sum,
+ &softmax_reduce_sum_bcast)))),
+ m::Broadcast(
+ OptionalConvert(m::Constant(&dropout).WithPredicate(IsScalar))))),
+ m::Op());
+
+ // Form 2 -> softmax - mul - BMM2
+ // /
+ // /
+ // select(dropout)
+ auto dropout_softmax_pattern_form_2 =
+ OptionalBitcast(OptionalBitcast(OptionalConvert(m::MultiplyAnyOrder(
+ OptionalReshape(OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
+ &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast))),
+ m::Broadcast(
+ OptionalConvert(OptionalBitcast(OptionalReshape(m::Select(
+ m::Op(),
+ m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)),
+ m::Op())))))))));
+
+ // Form3 -> softmax - mul(dropout) - mul(scale) - BMM2
+ auto dropout_softmax_pattern_form_3 = m::MultiplyAnyOrder(
+ m::MultiplyAnyOrder(
+ OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
+ &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast)),
+ m::Op()),
+ m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)));
+
+  // Try matching BMM1 - (Scale) - (Bias) - Softmax - (Dropout) - BMM2.
+  // Dropout with non-zero drop rate has select(divide(softmax_output,
+  // broadcast(1-dropout_rate))).
+ auto softmax_dropout_bmm2_pattern =
+ m::Op(&bmm_2)
+ .WithPredicate(IsBatchedMatmul)
+ .WithOperand(bmm2_operand_position,
+ m::AnyOf<HloInstruction>(
+ OptionalBitcast(OptionalConvert(
+ GetUnfusedReduceMaxSumSoftmaxPattern(
+ &softmax_input, &softmax_reduce_sum,
+ &softmax_reduce_sum_bcast))),
+ dropout_softmax_pattern_form_1,
+ dropout_softmax_pattern_form_2,
+ dropout_softmax_pattern_form_3));
+
+ if (!Match(instr, softmax_dropout_bmm2_pattern) ||
+ !IsSupportedPrimitiveType(bmm_2)) {
+ match_result.has_match = false;
+ return match_result;
+ }
+ if (softmax_reduce_sum->users()[0]->opcode() == HloOpcode::kConvert) {
+ softmax_reduce_sum = softmax_reduce_sum->users()[0];
+ }
+ match_result.is_training = softmax_reduce_sum->user_count() == 2 &&
+ softmax_reduce_sum_bcast->user_count() == 2;
+ match_result.matched_bmm_2 = bmm_2;
+ if (dropout) {
+ match_result.matched_dropout_rate = GetDropoutRateFromHlo(dropout);
+ }
+ match_result.matched_softmax_input = softmax_input;
+ match_result.matched_reduce_sum = softmax_reduce_sum;
+ match_result.has_match = true;
+ return match_result;
+}
+
+MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
+ HloInstruction* softmax_input,
+ bool has_dropout) {
+ MatchFwdResult match_result = previous_result;
+ HloInstruction* bmm_1;
+ HloInstruction* bias = nullptr;
+ HloInstruction* scale = nullptr;
+  // The bmm1/scale/bias add should have 2 users if it is connected to softmax;
+  // otherwise it should have exactly 1 user.
+ auto first_bmm_pattern =
+ m::SharedSubpattern(m::Op(&bmm_1).WithPredicate(IsBatchedMatmul));
+ auto unfused_scaled_bmm_subpattern = m::MultiplyAnyOrder(
+ OptionalConvert(first_bmm_pattern.WithOneUse()),
+ OptionalConvert(
+ m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar))));
+ if (Match(softmax_input,
+ OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
+ first_bmm_pattern, unfused_scaled_bmm_subpattern))))) {
+ // bmm1 - (scale) - softmax
+ match_result.matched_bmm_1 = bmm_1;
+ match_result.matched_scale = scale;
+ match_result.matched_custom_call_name =
+ has_dropout ? kCudnnfMHASoftmaxDropoutCallTarget
+ : kCudnnfMHASoftmaxCallTarget;
+ match_result.has_match = true;
+ } else if (Match(softmax_input,
+ OptionalBitcast(m::AddAnyOrder(
+ OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
+ unfused_scaled_bmm_subpattern.WithOneUse(),
+ first_bmm_pattern.WithOneUse()))),
+ m::Op(&bias))))) {
+ // bmm1 - (scale) - bias - softmax
+ match_result.matched_bmm_1 = bmm_1;
+ match_result.matched_scale = scale;
+ match_result.matched_custom_call_name =
+ has_dropout ? kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget
+ : kCudnnfMHAScaleBiasSoftmaxCallTarget;
+ match_result.is_causal_mask |= IsCausalMaskPattern(bias);
+ if (!match_result.is_causal_mask &&
+ bias->opcode() == HloOpcode::kBroadcast) {
+ // we can take the bias before broadcast
+ auto dims = Cast<HloBroadcastInstruction>(bias)->dimensions();
+ if (dims == std::vector<int64_t>{2, 3} ||
+ dims == std::vector<int64_t>{0, 2, 3} ||
+ dims == std::vector<int64_t>{1, 2, 3}) {
+ // shapes [1, 1, s, s], [b, 1, s, s], [1, h, s, s] are supported
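+        // E.g. a broadcast with dimensions={2,3} has an [s, s] operand, which
+        // is re-expressed as a [1, 1, s, s] bitcast below.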
+ HloInstruction* bias_bc = bias->mutable_operand(0);
+ // bitcast bias_before_broadcast to be 4D
+ std::vector<int64_t> bitcast_dims(bias->shape().rank(), 1);
+ for (int dim : dims) {
+ bitcast_dims[dim] = bias->shape().dimensions()[dim];
+ }
+ bias = bias_bc->AddInstruction(HloInstruction::CreateBitcast(
+ ShapeUtil::MakeShape(bias->shape().element_type(), bitcast_dims),
+ bias_bc));
+ }
+ }
+ match_result.matched_bias = bias;
+ match_result.has_match = true;
+ } else {
+ match_result.has_match = false;
+ }
+ return match_result;
+}
+
+// We will try to match all the patterns below:
+// BMM1 - Scale - bias - Softmax - Dropout - BMM2
+// BMM1 - Scale - bias - Softmax - BMM2
+// BMM1 - Softmax - Dropout - BMM2
+// BMM1 - Softmax - BMM2
+MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
+ // We need to match 2 general cases:
+ // 1. bmm1 --> (intermediate nodes) --> bmm2 <-- V matrix
+ // 2. V matrix --> bmm2 <-- (intermediate nodes) <-- bmm1
+ // to determine if we need to canonicalize bmm2.
+ // So we go through both of bmm2's operands and see which one matches our
+ // desired patterns, if operand 1 consumes them, then we need to canonicalize.
+ MatchFwdResult match_result;
+ for (auto bmm2_operand_pos : {0, 1}) {
+ if (bmm2_operand_pos == 1) {
+ match_result.need_canonicalization = true;
+ }
+
+ bool has_dropout = false;
+    // We first check if bmm2 is connected to a softmax or dropout.
+ // If so, we set softmax input and dropout rate to their corresponding
+ // values.
+ match_result =
+ MatchSoftmaxDropoutBmm(match_result, bmm2_operand_pos, instr);
+ if (!match_result.has_match) {
+ continue;
+ }
+ has_dropout = match_result.matched_dropout_rate > 0.0;
+ match_result = MatchBmm1UnfusedBiasSoftmaxBmm2(
+ match_result, match_result.matched_softmax_input, has_dropout);
+ if (match_result.has_match) {
+ return match_result;
+ }
+ }
+ // Didn't find any match
+ match_result.need_canonicalization = false;
+ return match_result;
+}
+
+bool IsBmm2GradGemm2(HloInstruction* instr) {
+  // Check whether the input bmm is bmm2's gradient gemm2. It needs to have
+  // either:
+  //   1. exactly 1 user, in the dropout case, or
+  //   2. exactly 2 users, in all other cases.
+ return (instr->user_count() == 1) || (instr->user_count() == 2);
+}
+
+MatchBwdResult MatchBmm1GradGemm1(MatchBwdResult previous_result,
+ HloInstruction* bmm_1) {
+ MatchBwdResult match_result = previous_result;
+ match_result.has_match = false;
+ const HloInstruction* q_tensor = bmm_1->operand(0);
+ for (int64_t i = 0; i < q_tensor->user_count(); i++) {
+ HloInstruction* q_tensor_user_i = q_tensor->users()[i];
+ if (IsBatchedMatmul(q_tensor_user_i) && q_tensor_user_i != bmm_1) {
+ match_result.matched_bmm_1_grad_1 = q_tensor_user_i;
+ // Check for canonicalization.
+ if (match_result.matched_bmm_1_grad_1->operand_index(q_tensor) != 1) {
+ match_result.bmm_1_grad_1_need_canonicalization = true;
+ }
+ match_result.has_match = true;
+ }
+ }
+ return match_result;
+}
+
+MatchBwdResult MatchBmm1GradGemm2(MatchBwdResult previous_result,
+ HloInstruction* fwd_fmha_call) {
+ HloInstruction* bmm_1_grad_2 = nullptr;
+ MatchBwdResult match_result = previous_result;
+ match_result.has_match = false;
+  // bmm1 gradient gemm2 shares the same input d_s as bmm1 gradient gemm1.
+  // Check whether bmm1 grad gemm1 needs canonicalization; if not, the shared
+  // input is its first operand.
+ int64_t d_s_index = match_result.bmm_1_grad_1_need_canonicalization ? 1 : 0;
+ HloInstruction* d_s_user_0 = match_result.matched_bmm_1_grad_1;
+
+ HloInstruction* d_s = d_s_user_0->mutable_operand(d_s_index);
+ if (d_s->opcode() == HloOpcode::kBitcast && d_s->user_count() == 1) {
+ d_s = d_s->mutable_operand(0);
+ }
+
+ auto bmm_1_grad_2_it = std::find_if(
+ d_s->users().begin(), d_s->users().end(), [&](HloInstruction* instr) {
+ return instr != match_result.matched_bmm_1_grad_1 &&
+ instr->opcode() == HloOpcode::kDot;
+ });
+ if (bmm_1_grad_2_it != d_s->users().end()) {
+ bmm_1_grad_2 = *bmm_1_grad_2_it;
+ } else {
+ return match_result;
+ }
+
+ match_result.matched_bmm_1_grad_2 = bmm_1_grad_2;
+
+ if (match_result.matched_bmm_1_grad_2->operand_index(d_s) != 0) {
+ match_result.bmm_1_grad_2_need_canonicalization = true;
+ }
+ match_result.has_match = true;
+ return match_result;
+}
+
+MatchBwdResult MatchBmm2GradGemm1(HloInstruction* fwd_fmha_call) {
+ HloInstruction* bmm_2_grad_1 = nullptr;
+ MatchBwdResult matched_result;
+  // The second GTE of the forward MHA call is the input of bmm2's gradient
+  // gemm1; we check whether the current gemm satisfies that condition.
+ int64_t activation_out_gte_index = 1;
+ if (fwd_fmha_call->user_count() < 2 ||
+ fwd_fmha_call->users()[activation_out_gte_index]->opcode() !=
+ HloOpcode::kGetTupleElement ||
+ fwd_fmha_call->users()[activation_out_gte_index]->user_count() > 1 ||
+ !IsBatchedMatmul(
+ fwd_fmha_call->users()[activation_out_gte_index]->users()[0])) {
+ matched_result.has_match = false;
+ return matched_result;
+ }
+ // Found fmha->GTE->gemm, assign it to bmm_2_grad_1 and check to see if it
+ // needs canonicalization.
+ bmm_2_grad_1 = fwd_fmha_call->users()[activation_out_gte_index]->users()[0];
+ matched_result.matched_bmm_2_grad_1 = bmm_2_grad_1;
+ if (bmm_2_grad_1->operand_index(
+ fwd_fmha_call->users()[activation_out_gte_index]) != 0) {
+ matched_result.bmm_2_grad_1_need_canonicalization = true;
+ }
+
+ matched_result.has_match = true;
+ return matched_result;
+}
+
+MatchBwdResult MatchBmm2GradGemm2(MatchBwdResult previous_result,
+ HloInstruction* fwd_fmha_call,
+ bool v_transposed) {
+ MatchBwdResult match_result = previous_result;
+ match_result.has_match = false;
+  // If the v tensor is transposed by the forward fmha call, we need to look
+  // through the transpose, i.e. take the producer of fmha's v input.
+ const HloInstruction* v_tensor = v_transposed
+ ? fwd_fmha_call->operand(2)->operand(0)
+ : fwd_fmha_call->operand(2);
+ for (int64_t i = 0; i < v_tensor->user_count(); i++) {
+ HloInstruction* v_tensor_user_i = v_tensor->users()[i];
+ if (IsBatchedMatmul(v_tensor_user_i) && IsBmm2GradGemm2(v_tensor_user_i)) {
+ match_result.matched_bmm_2_grad_2 = v_tensor_user_i;
+ // Check for canonicalization.
+ if (match_result.matched_bmm_2_grad_2->operand_index(v_tensor) != 1) {
+ match_result.bmm_2_grad_2_need_canonicalization = true;
+ }
+ match_result.has_match = true;
+ }
+ }
+ return match_result;
+}
+
+MatchBwdResult MatchDbias(MatchBwdResult previous_result,
+ HloInstruction* d_intermediate,
+ const absl::flat_hash_set<HloInstruction*> users) {
+ MatchBwdResult match_result = previous_result;
+ auto user_count = d_intermediate->user_count();
+ HloInstruction* dbias_user = nullptr;
+ HloInstruction* dbias = nullptr;
+ for (auto user : d_intermediate->users()) {
+ if (users.contains(user)) {
+ user_count -= 1;
+ } else {
+ dbias_user = user;
+ }
+ }
+ auto ConsumeExtraConvert = [](HloInstruction* instr) {
+ Match(instr->users()[0], m::Convert(&instr, m::Op()).WithOneUse());
+ return true;
+ };
+ // user_count == 1 && (reduce-> {convert} ->bitcast)
+ match_result.has_match =
+ user_count == 1 &&
+ Match(dbias_user, m::Reduce(&dbias, m::Op(), m::Op()).WithOneUse()) &&
+ dbias->shape().rank() == 3 && ConsumeExtraConvert(dbias);
+ if (match_result.has_match) {
+    // cuDNN only supports dbias with shape [1, h, s, s], so make sure the
+    // reduce dimension is the batch dim.
+ auto reduce_dim = dbias->dimensions();
+ if (reduce_dim.size() == 1 && reduce_dim[0] == 0) {
+ match_result.matched_dbias = dbias;
+ } else {
+ match_result.has_match = false;
+ }
+ }
+ return match_result;
+}
+
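+// Matches the backward softmax subgraph (with optional scale and dropout)
+// that connects bmm2 gradient gemm 2 to bmm1 gradient gemm 1, picks the
+// corresponding backward custom-call target and, if present, matches dbias.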
+MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result,
+ HloInstruction* fwd_fmha_call) {
+ MatchBwdResult match_result = previous_result;
+ bool is_bmm1_grad1_canonicalized =
+ match_result.bmm_1_grad_1_need_canonicalization;
+ match_result.has_match = false;
+ bool has_scale = false;
+ bool has_dropout = false;
+ // Backward dropout pattern
+ // select(mask, bmm2_grad2, broadcast())
+ auto bwd_dropout_pattern_form_1 = m::SharedSubpattern(
+ OptionalBitcast(OptionalReshape(OptionalConvert(m::Select(
+ m::Op(), m::Op().WithPredicate([&](const HloInstruction* instr) {
+ return instr == match_result.matched_bmm_2_grad_2;
+ }),
+ m::Broadcast(
+ OptionalConvert(m::Constant().WithPredicate(IsScalar))))))));
+
+ // multiply(bmm2_grad2, broadcast(select(mask, broadcast(), op())))
+ auto bwd_dropout_pattern_form_2 =
+ m::SharedSubpattern(OptionalBitcast(m::MultiplyAnyOrder(
+ OptionalConvert(
+ m::Op().WithPredicate([&](const HloInstruction* instr) {
+ return instr == match_result.matched_bmm_2_grad_2;
+ })),
+ m::Broadcast(OptionalConvert(OptionalBitcast(OptionalReshape(
+ m::Select(m::Op(),
+ m::Broadcast(OptionalConvert(
+ m::Constant().WithPredicate(IsScalar))),
+ m::Op()))))))));
+ auto bwd_dropout_pattern_form_3 = OptionalConvert(m::MultiplyAnyOrder(
+ m::MultiplyAnyOrder(
+ m::Op().WithPredicate([&](const HloInstruction* instr) {
+ return instr == match_result.matched_bmm_2_grad_2;
+ }),
+ m::Broadcast(m::Constant().WithPredicate(IsScalar))),
+ m::Op()));
+ auto bwd_dropout_pattern = m::AnyOf<HloInstruction>(
+ bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2,
+ bwd_dropout_pattern_form_3);
+ // Backward softmax pattern
+ HloInstruction* bwd_softmax_input = nullptr;
+  HloInstruction* exp_1 = nullptr;
+  HloInstruction* exp_2 = nullptr;
+  HloInstruction* d_softmax = nullptr;
+
+ // d_softmax = exp * (dy / s_b - sum(dy * exp * 1 / s^2))
+  // There can be at most 3 users of d_softmax: bmm1grad1, bmm1grad2 and dbias.
+ auto bwd_softmax_pattern = OptionalBitcast(OptionalConvert(
+ m::MultiplyAnyOrder(
+ &d_softmax,
+ m::AddAnyOrder(
+ m::Divide().WithOneUse(),
+ m::Broadcast(OptionalBitcast(OptionalConvert(
+ m::Negate(
+ OptionalBitcast(
+ m::Op()
+ .WithPredicate(IsReduceSum)
+ .WithOneUse()
+ .WithOperand(
+ 0, OptionalBitcast(
+ m::MultiplyAnyOrder(
+ m::MultiplyAnyOrder(
+ m::Op(&bwd_softmax_input),
+ m::Broadcast())
+ .WithOneUse(),
+ m::Exp(&exp_2, m::Op()))
+ .WithOneUse()))))
+ .WithOneUse())))),
+ m::Exp(&exp_1, m::Op()))
+ .WithAtMostNumUser(3)));
+
+ // Backward scale input pattern
+ HloInstruction* bwd_scale_input = nullptr;
+ HloInstruction* bwd_scale = nullptr;
+
+ auto bwd_scale_pattern =
+ m::MultiplyAnyOrder(&bwd_scale, m::Op(&bwd_scale_input),
+ m::Broadcast(m::Constant().WithPredicate(IsScalar)))
+ .WithNumUser(2);
+ int intermediate_input_pos = is_bmm1_grad1_canonicalized ? 1 : 0;
+
+ HloInstruction* intermediate_input =
+ match_result.matched_bmm_1_grad_1->mutable_operand(
+ intermediate_input_pos);
+
+ has_scale = Match(intermediate_input, bwd_scale_pattern);
+
+ if (has_scale) {
+ intermediate_input = bwd_scale_input;
+ }
+
+ if (!Match(intermediate_input, bwd_softmax_pattern) || exp_1 != exp_2) {
+ return match_result;
+ }
+ has_dropout = Match(bwd_softmax_input, bwd_dropout_pattern);
+  // If there is no dropout and the softmax input does not come from bmm2
+  // gradient gemm 2, then it's not a pattern we care about.
+ if (!has_dropout &&
+ !Match(bwd_softmax_input,
+ OptionalConvert((OptionalBitcast(
+ m::Op().WithPredicate([&](const HloInstruction* instr) {
+ return instr == match_result.matched_bmm_2_grad_2;
+ })))))) {
+ return match_result;
+ }
+
+ if (has_dropout) {
+ // has bias
+ if (fwd_fmha_call->custom_call_target() ==
+ kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget)
+ match_result.matched_custom_call_name =
+ kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
+ // no bias
+ if (fwd_fmha_call->custom_call_target() ==
+ kCudnnfMHASoftmaxDropoutCallTarget)
+ match_result.matched_custom_call_name =
+ kCudnnfMHASoftmaxDropoutBackwardCallTarget;
+ } else {
+ // has bias
+ if (fwd_fmha_call->custom_call_target() ==
+ kCudnnfMHAScaleBiasSoftmaxCallTarget)
+ match_result.matched_custom_call_name =
+ kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget;
+ // no bias
+ if (fwd_fmha_call->custom_call_target() == kCudnnfMHASoftmaxCallTarget)
+ match_result.matched_custom_call_name =
+ kCudnnfMHASoftmaxBackwardCallTarget;
+ }
+ // try to pattern match dbias
+ HloInstruction* dS = d_softmax;
+ if (dS->users()[0]->opcode() == HloOpcode::kConvert) {
+ dS = dS->users()[0];
+ }
+ if (has_scale) {
+    // bmm1-(scale)-(bias)-softmax pattern: extra users may be dbias or scale bwd.
+ if (dS->user_count() == 1) {
+ // no dbias
+ match_result.has_match = true;
+ } else if (dS->user_count() == 2) {
+ match_result = MatchDbias(match_result, dS, {bwd_scale});
+ } else {
+ match_result.has_match = false;
+ }
+ } else {
+    // bmm1-(bias)-softmax pattern:
+    // besides bmm1grad1 and bmm1grad2, the extra user can be dbias.
+ if (dS->user_count() == 2) {
+ match_result.has_match = true;
+ } else if (dS->user_count() == 3) {
+ match_result = MatchDbias(match_result, dS,
+ {match_result.matched_bmm_1_grad_1,
+ match_result.matched_bmm_1_grad_2});
+ } else {
+ match_result.has_match = false;
+ }
+ }
+ return match_result;
+}
+// First, we look for bmm2 gradient gemm 1, which takes the activation output
+// from a forward fmha call.
+// Second, we look for bmm2 gradient gemm 2, which takes the v tensor as an
+// input. We take the v tensor from the third operand of the forward fmha
+// call; if the forward graph is canonicalized, we skip the additional
+// transpose in between.
+// Then we look for bmm1 gradient gemm 1 by searching for gemms that share the
+// q tensor with the current fmha call.
+// Finally, we look for bmm1 gradient gemm 2, which shares the softmax
+// gradient with bmm1 gradient gemm 1.
+MatchBwdResult MatchBackwardBmms(HloInstruction* fwd_fmha_call,
+ HloInstruction* bmm_1, bool v_transposed) {
+ MatchBwdResult matched_result = MatchBmm2GradGemm1(fwd_fmha_call);
+ if (!matched_result.has_match) {
+ return matched_result;
+ }
+
+ matched_result =
+ MatchBmm2GradGemm2(matched_result, fwd_fmha_call, v_transposed);
+ if (!matched_result.has_match) {
+ return matched_result;
+ }
+
+ matched_result = MatchBmm1GradGemm1(matched_result, bmm_1);
+ if (!matched_result.has_match) {
+ return matched_result;
+ }
+
+ matched_result = MatchBmm1GradGemm2(matched_result, fwd_fmha_call);
+ if (!matched_result.has_match) {
+ return matched_result;
+ }
+ return matched_result;
+}
+// We will match the backward graphs for all forward patterns defined in
+// MatchFwdMHAPatternsForCanonicalization
+MatchBwdResult MatchBwdMHAPatternsForCanonicalization(
+ HloInstruction* fwd_fmha_call, HloInstruction* bmm_1, bool v_transposed) {
+ MatchBwdResult match_result =
+ MatchBackwardBmms(fwd_fmha_call, bmm_1, v_transposed);
+ if (!match_result.has_match) {
+ return match_result;
+ }
+ match_result = MatchBwdBmmSoftmaxDropoutBmm(match_result, fwd_fmha_call);
+ return match_result;
+}
+
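+// Checks whether the matched attention block can be lowered to a cuDNN flash
+// attention call, based on dtype, rank, Q/K/V layout, compute capability and
+// cuDNN version. For causal masks, the bias is dropped from the custom call
+// name.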
+absl::StatusOr<bool> IsMHABlockSupported(
+ HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization,
+ bool is_training, bool is_causal_mask, std::string& custom_call_name,
+ const DebugOptions& debug_options,
+ stream_executor::CudaComputeCapability cc,
+ stream_executor::dnn::VersionInfo cudnn_version) {
+ if (MHACallHasDropout(custom_call_name) &&
+ !debug_options.xla_gpu_fused_attention_use_cudnn_rng()) {
+ VLOG(3) << "Using CUDNN RNG for fused attention dropout is not enabled.\n";
+ return false;
+ }
+
+  // cuDNN fused MHA currently only supports the BF16 and F16 data types.
+ if (!IsSupportedPrimitiveType(bmm_1) || !IsSupportedPrimitiveType(bmm_2)) {
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Unsupported primitive type for cuDNN MHA fusion:\n"
+ << bmm_1->ToString() << "\nOR\n"
+ << bmm_2->ToString() << "\n"
+ << "BF16 and F16 are the supported Dtypes.";
+ }
+ return false;
+ }
+
+ if (bmm_1->shape().rank() != 4 || bmm_2->shape().rank() != 4) {
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Unsupported bmm rank for cuDNN MHA fusion:\n"
+ << bmm_1->ToString() << "\nOR\n"
+ << bmm_2->ToString() << "\n"
+ << "Only bmm with rank 4 is supported.";
+ }
+ return false;
+ }
+
+  // Get batch / num heads / sequence length / hidden dim from bmm1 and bmm2,
+  // and make sure they match between the two.
+ TF_ASSIGN_OR_RETURN(std::optional<QKVLayout> qkv_layout,
+ GetQKVLayout(bmm_1, bmm_2, need_canonicalization));
+ if (!qkv_layout.has_value()) {
+ VLOG(2) << "bmm1 and bmm2 have different qkv layout.";
+ return false;
+ }
+
+ // check if matched attention block is supported by cuDNN flash attention.
+ TF_ASSIGN_OR_RETURN(
+ bool is_flash_attention,
+ IsFlashAttention(qkv_layout.value(), is_training, cc, cudnn_version));
+ if (is_flash_attention) {
+ if (is_causal_mask) {
+      // If the bias is a causal mask, remove the bias part from the name.
+ if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget) {
+ custom_call_name = kCudnnfMHASoftmaxDropoutCallTarget;
+ } else if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxCallTarget) {
+ custom_call_name = kCudnnfMHASoftmaxCallTarget;
+ }
+ }
+ }
+ return is_flash_attention;
+}
+
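+// Canonicalizes a batched gemm by swapping its operands (and the
+// corresponding dot dimension numbers) and transposing the result back to the
+// original shape; returns the new dot instruction.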
+absl::StatusOr<HloInstruction*> CanonicalizeBatchedGemmForcuDNNFMHA(
+ HloInstruction* bmm, HloComputation* comp) {
+ if (VLOG_IS_ON(3)) {
+    VLOG(3) << "Before FMHA Dot Canonicalization: \n"
+ << comp->parent()->ToString();
+ }
+ HloInstruction* lhs_bmm = bmm->mutable_operand(0);
+ HloInstruction* rhs_bmm = bmm->mutable_operand(1);
+ const DotDimensionNumbers& dnums = bmm->dot_dimension_numbers();
+
+ int64_t rank = bmm->shape().dimensions_size();
+ std::vector<int64_t> perm(rank);
+ std::iota(perm.begin(), perm.end(), 0);
+  // Swap the non-contracting dims of the BMM shape. By convention, the
+  // non-contracting dims in the output are the last two dimensions.
+ std::swap(perm[rank - 1], perm[rank - 2]);
+
+ DotDimensionNumbers new_dnums = dnums;
+ std::swap(*new_dnums.mutable_lhs_contracting_dimensions(),
+ *new_dnums.mutable_rhs_contracting_dimensions());
+ std::swap(*new_dnums.mutable_lhs_batch_dimensions(),
+ *new_dnums.mutable_rhs_batch_dimensions());
+ auto original_bmm_shape = bmm->shape();
+ HloInstruction* new_dot = comp->AddInstruction(HloInstruction::CreateDot(
+ ShapeUtil::MakeShape(original_bmm_shape.element_type(),
+ Permute(original_bmm_shape.dimensions(), perm)),
+ /* lhs */ rhs_bmm, /* rhs */ lhs_bmm, new_dnums,
+ bmm->precision_config()));
+
+ TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+ bmm, HloInstruction::CreateTranspose(original_bmm_shape, new_dot, perm)));
+ if (VLOG_IS_ON(2)) {
+    VLOG(2) << "After FMHA Dot Canonicalization: \n"
+ << comp->parent()->ToString();
+ }
+ return new_dot;
+}
+
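+// Makes the checked dimension (contracting or non-contracting, depending on
+// `should_contracting_be_fastest`) of the selected bmm operand its
+// fastest-moving dimension, inserting a transpose and updating the dot
+// dimension numbers if necessary. Returns the possibly transposed operand.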
+absl::StatusOr<HloInstruction*> ChangeCheckedDimToFastest(
+ HloComputation* comp, HloInstruction* bmm, bool is_lhs,
+ bool should_contracting_be_fastest) {
+ const DotDimensionNumbers& dot_dims_bmm = bmm->dot_dimension_numbers();
+ DotDimensionNumbers new_dot_dims_bmm = dot_dims_bmm;
+ int64_t bmm_operand = is_lhs ? 0 : 1;
+ absl::Span<const int64_t> contracting_dims =
+ is_lhs ? dot_dims_bmm.lhs_contracting_dimensions()
+ : dot_dims_bmm.rhs_contracting_dimensions();
+ absl::Span<const int64_t> batch_dims =
+ is_lhs ? dot_dims_bmm.lhs_batch_dimensions()
+ : dot_dims_bmm.rhs_batch_dimensions();
+ absl::Span<const int64_t> lhs_minor_to_major_bmm =
+ bmm->operand(0)->shape().layout().minor_to_major();
+ absl::Span<const int64_t> rhs_minor_to_major_bmm =
+ bmm->operand(1)->shape().layout().minor_to_major();
+
+ absl::Span<const int64_t>& minor_to_major_to_check =
+ is_lhs ? lhs_minor_to_major_bmm : rhs_minor_to_major_bmm;
+
+ CHECK_EQ(contracting_dims.size(), 1);
+ TF_ASSIGN_OR_RETURN(std::vector<int64_t> non_contracting_dims,
+ GetNonContractingDims(bmm->operand(bmm_operand)->shape(),
+ batch_dims, contracting_dims));
+ CHECK_EQ(non_contracting_dims.size(), 1);
+ HloInstruction* operand_bmm = bmm->mutable_operand(bmm_operand);
+ int64_t hidden_dim = should_contracting_be_fastest ? contracting_dims[0]
+ : non_contracting_dims[0];
+ int64_t minor_dim = minor_to_major_to_check[0];
+ // If the hidden dim of the target operand is not the fastest moving
+ // dimension, make it so.
+ if (minor_dim != hidden_dim) {
+ std::vector<int64_t> perm(bmm->shape().dimensions_size());
+ std::iota(perm.begin(), perm.end(), 0);
+ std::swap(perm[hidden_dim], perm[minor_dim]);
+
+ if (is_lhs) {
+ new_dot_dims_bmm.set_lhs_contracting_dimensions(0,
+ non_contracting_dims[0]);
+ } else {
+ new_dot_dims_bmm.set_rhs_contracting_dimensions(0,
+ non_contracting_dims[0]);
+ }
+
+ operand_bmm = comp->AddInstruction(
+ HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShapeWithDenseLayout(
+ bmm->shape().element_type(),
+ Permute(operand_bmm->shape().dimensions(), perm),
+ minor_to_major_to_check),
+ operand_bmm, perm),
+ &operand_bmm->metadata());
+ *((DynCast<HloDotInstruction>(bmm))->mutable_dot_dimension_numbers()) =
+ new_dot_dims_bmm;
+ }
+ return operand_bmm;
+}
+
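+// Replaces the matched forward attention subgraph
+// (bmm1 -> [scale] -> [bias] -> softmax -> [dropout] -> bmm2) with a single
+// forward cuDNN fMHA custom call and rewires its users.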
+absl::StatusOr<HloInstruction*> FuseFwdMultiHeadedAttentionBlock(
+ HloComputation* comp, HloInstruction* bmm_1, HloInstruction* bmm_2,
+ HloInstruction* bias, HloInstruction* scale, HloInstruction* reduce_sum,
+ HloInstruction* softmax_input, double dropout_rate,
+ std::string& custom_call_name, stream_executor::CudaComputeCapability cc,
+ bool is_training, bool& changed, bool& v_transposed, bool is_causal_mask) {
+ double scale_value = 1.0;
+ HloInstruction* lhs_bmm1;
+ HloInstruction* rhs_bmm1;
+ HloInstruction* rhs_bmm2;
+ DotDimensionNumbers orig_bmm1_dot_dim = bmm_1->dot_dimension_numbers();
+ DotDimensionNumbers orig_bmm2_dot_dim = bmm_2->dot_dimension_numbers();
+
+ TF_ASSIGN_OR_RETURN(rhs_bmm1, ChangeCheckedDimToFastest(
+ comp, bmm_1, false /*is_lhs*/,
+ true /*should_contracting_be_fastest*/));
+ TF_ASSIGN_OR_RETURN(lhs_bmm1, ChangeCheckedDimToFastest(
+ comp, bmm_1, true /*is_lhs*/,
+ true /*should_contracting_be_fastest*/));
+
+ TF_ASSIGN_OR_RETURN(rhs_bmm2, ChangeCheckedDimToFastest(
+ comp, bmm_2, false /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+
+ if (rhs_bmm2 != bmm_2->mutable_operand(1)) {
+ v_transposed = true;
+ }
+
+ GpuBackendConfig gpu_config;
+ CudnnfMHABackendConfig& fmha_config =
+ *gpu_config.mutable_cudnn_fmha_backend_config();
+
+ *fmha_config.mutable_bmm1_dot_dimension_numbers() =
+ bmm_1->dot_dimension_numbers();
+ *fmha_config.mutable_bmm2_dot_dimension_numbers() =
+ bmm_2->dot_dimension_numbers();
+
+ TF_RET_CHECK((dropout_rate >= 0.0 && dropout_rate <= 1.0));
+ // Restore original DotDimensionNumbers.
+ *((DynCast<HloDotInstruction>(bmm_1))->mutable_dot_dimension_numbers()) =
+ orig_bmm1_dot_dim;
+ *((DynCast<HloDotInstruction>(bmm_2))->mutable_dot_dimension_numbers()) =
+ orig_bmm2_dot_dim;
+
+ // If scale node is assigned, extract value from it.
+ if (scale != nullptr) {
+    std::optional<double> value = GetConstantValue(scale);
+    TF_RET_CHECK(value.has_value());
+    scale_value = *value;
+ }
+
+ fmha_config.set_fmha_scale(scale_value);
+ fmha_config.set_dropout_rate(dropout_rate);
+  // Set an arbitrary seed for now; the seed is not exposed to XLA in the HLO
+  // graph.
+  // TODO: Find a way to compute the original seed from the dropout keys.
+ fmha_config.set_seed(42);
+
+ *fmha_config.mutable_intermediate_tensor_shape() = bmm_1->shape().ToProto();
+ {
+ auto* algorithm = fmha_config.mutable_algorithm();
+ algorithm->set_algo_id(0); // engine id
+ algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+ std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
+ std::vector<int64_t> knob_vals = {1, 0};
+ for (int i = 0; i < knob_ids.size(); ++i) {
+ (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
+ }
+ algorithm->set_is_cudnn_frontend(true);
+ algorithm->mutable_workspace_size()->set_value(0);
+ }
+  // Set the mask type: whether cuDNN should generate the causal mask
+  // internally or not.
+ fmha_config.set_mask_type(is_causal_mask ? CudnnfMHABackendConfig::CAUSAL
+ : CudnnfMHABackendConfig::NO_MASK);
+
+ const Shape& output_shape = bmm_2->shape();
+
+ Shape call_shape;
+ // Activation output is used by backward gemm.
+ HloInstruction* activation_output = nullptr;
+
+ // Output Order: {O, Fwd act*, workspace}
+ std::vector<Shape> output_shapes = {output_shape};
+ if (is_training) {
+ activation_output = bmm_2->mutable_operand(0);
+    // Sometimes the activation output is a bitcast; the actual activation is
+    // then the other user of the producer of bmm_2's first operand.
+ if (activation_output->user_count() < 2 &&
+ activation_output->opcode() == HloOpcode::kBitcast) {
+ HloInstruction* producer = activation_output->mutable_operand(0);
+ TF_RET_CHECK(producer->user_count() == 2);
+ HloInstruction* bmm2_grad2_user =
+ producer->users()[0] == activation_output ? producer->users()[1]
+ : producer->users()[0];
+ // might be (transpose) - bmm2_grad2
+ if (IsBatchedMatmul(bmm2_grad2_user)) {
+ activation_output = producer;
+ } else if (bmm2_grad2_user->opcode() == HloOpcode::kTranspose) {
+ activation_output = bmm2_grad2_user;
+ } else {
+ return Internal("Unexpected activation patterns");
+ }
+ }
+    // For flash attention, output the softmax stats for the backward pass.
+ TF_RET_CHECK(reduce_sum != nullptr);
+ output_shapes.push_back(
+ ShapeUtil::MakeShape(F32, reduce_sum->shape().dimensions()));
+ }
+ output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
+ call_shape = ShapeUtil::MakeTupleShape(output_shapes);
+
+ // Input Order: {Q, K, V, bias*}
+ std::vector<HloInstruction*> operands = {lhs_bmm1, rhs_bmm1, rhs_bmm2};
+ if (!is_causal_mask && bias != nullptr) {
+ HloInstruction* original_bias;
+ HloInstruction* original_broadcast;
+    // There are cases where the bias is upcast to a wider float type; in that
+    // case we need to take the original bias node and broadcast it without
+    // the convert.
+ if (Match(bias, m::Broadcast(
+ &original_broadcast,
+ m::Convert(
+ m::Op(&original_bias)
+ .WithPredicate([](const HloInstruction* instr) {
+ return instr->shape().element_type() == F16 ||
+ instr->shape().element_type() == BF16;
+ }))
+ .WithPredicate([](const HloInstruction* instr) {
+ return instr->shape().element_type() == F32 ||
+ instr->shape().element_type() == F64;
+ })))) {
+ absl::Span<const int64_t> original_bcast_dims =
+ (DynCast<HloBroadcastInstruction>(original_broadcast))->dimensions();
+ // This is to deal with cases like paxml where an extra dimension of 1 is
+ // added to the left of the tensor.
+      // TODO: Make this logic more generic.
+ absl::Span<const int64_t> original_broadcast_shape_dims =
+ original_broadcast->shape().dimensions();
+ int64_t starting_index = original_broadcast_shape_dims.size() == 5 &&
+ original_broadcast_shape_dims[0] == 1
+ ? 1
+ : 0;
+ std::vector<int64_t> bcast_dimensions;
+ for (auto& dim : original_bcast_dims) {
+ bcast_dimensions.push_back(dim - starting_index);
+ }
+
+ const Shape& bcast_shape = bmm_1->shape();
+ bias = comp->AddInstruction(HloInstruction::CreateBroadcast(
+ bcast_shape, original_bias, bcast_dimensions));
+ }
+ operands.push_back(bias);
+ }
+
+ HloInstruction* fmha_call =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ call_shape, operands, absl::string_view(custom_call_name)));
+ TF_RETURN_IF_ERROR(fmha_call->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(SetFMHAInstructionName(bmm_1->GetModule(), fmha_call));
+
+ TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+ bmm_2,
+ HloInstruction::CreateGetTupleElement(bmm_2->shape(), fmha_call, 0)));
+
+ if (activation_output) {
+ HloInstruction* activation_gte =
+ comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+ activation_output->shape(), fmha_call, 1));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstructionWithDifferentShape(
+ activation_output, activation_gte,
+ /*preserve_sharding=*/false,
+ /*relay_control_dependency=*/false,
+ /*remove_unused_operands=*/false)
+ .status());
+ }
+
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "After CudnnFusedMHARewriter: \n" << comp->parent()->ToString();
+ }
+ changed = true;
+ return fmha_call;
+}
+
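+// Replaces the matched backward gemms (and dbias, if present) with a single
+// backward cuDNN fMHA custom call whose outputs are dQ, dK, dV and,
+// optionally, dbias.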
+absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
+ HloComputation* comp, HloInstruction* bmm_1_grad_1,
+ HloInstruction* bmm_1_grad_2, HloInstruction* bmm_2_grad_1,
+ HloInstruction* bmm_2_grad_2, HloInstruction* fwd_fmha_call,
+ HloInstruction* dbias, HloInstruction* bias,
+ std::string& bwd_custom_call_name) {
+ HloInstruction* rhs_bmm1_grad_gemm1;
+ HloInstruction* lhs_bmm1_grad_gemm2;
+ HloInstruction* rhs_bmm2_grad_gemm2;
+ HloInstruction* d_output_grad;
+
+ DotDimensionNumbers orig_bmm1_grad1_config =
+ bmm_1_grad_1->dot_dimension_numbers();
+ DotDimensionNumbers orig_bmm1_grad2_config =
+ bmm_1_grad_2->dot_dimension_numbers();
+ DotDimensionNumbers orig_bmm2_grad1_config =
+ bmm_2_grad_1->dot_dimension_numbers();
+ DotDimensionNumbers orig_bmm2_grad2_config =
+ bmm_2_grad_2->dot_dimension_numbers();
+
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ fwd_fmha_call->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& fwd_config =
+ gpu_config.cudnn_fmha_backend_config();
+ bool is_causal_mask =
+ fwd_config.mask_type() == CudnnfMHABackendConfig::CAUSAL;
+ CudnnfMHABackendConfig bwd_fmha_config;
+ // Q tensor
+ TF_ASSIGN_OR_RETURN(
+ rhs_bmm1_grad_gemm1,
+ ChangeCheckedDimToFastest(comp, bmm_1_grad_1, false /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+ // K tensor
+ TF_ASSIGN_OR_RETURN(
+ lhs_bmm1_grad_gemm2,
+ ChangeCheckedDimToFastest(comp, bmm_1_grad_2, false /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+
+ // Forward activation
+ // softmax_stats
+ HloInstruction* fwd_act;
+ int64_t fwd_act_index = 1;
+ fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+ fwd_fmha_call->shape().tuple_shapes(fwd_act_index), fwd_fmha_call,
+ fwd_act_index));
+
+ // V tensor
+ TF_ASSIGN_OR_RETURN(
+ rhs_bmm2_grad_gemm2,
+ ChangeCheckedDimToFastest(comp, bmm_2_grad_2, false /*is_lhs*/,
+ true /*should_contracting_be_fastest*/));
+  // d_output to bmm2_grad2.
+  // Since d_o is an input of two bmms, we set the dim numbers using the
+  // constraint that the contracting dimension of the lhs of bmm_2_grad_2
+  // needs to be the fastest-moving dimension.
+ TF_ASSIGN_OR_RETURN(
+ d_output_grad,
+ ChangeCheckedDimToFastest(comp, bmm_2_grad_2, true /*is_lhs*/,
+ true /*should_contracting_be_fastest*/));
+  // d_output to bmm2_grad1.
+  // We don't use this value, but we call this to make sure the dot dimension
+  // numbers are set correctly.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * bmm_2_grad_1_rhs,
+ ChangeCheckedDimToFastest(comp, bmm_2_grad_1, false /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+ (void)bmm_2_grad_1_rhs;
+ // Operand order: {Q, K, V, Fwd act, d_o, bias*, O*}
+ std::vector<HloInstruction*> operands = {
+ rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, fwd_act,
+ d_output_grad};
+
+  if (!is_causal_mask && bias) {
+    operands.push_back(bias);
+  }
+  // For flash attention, add the forward output to the input list.
+  HloInstruction* fwd_output = nullptr;
+ for (auto user : fwd_fmha_call->users()) {
+ if (user->opcode() == HloOpcode::kGetTupleElement &&
+ user->tuple_index() == 0) {
+ fwd_output = user;
+ }
+ }
+  // We should always be able to find the forward output GTE.
+  TF_RET_CHECK(fwd_output != nullptr);
+  // Check that dO and O have the same layout, as required by cuDNN.
+  TF_RET_CHECK(fwd_output->shape() == d_output_grad->shape());
+ operands.push_back(fwd_output);
+
+ *bwd_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
+ bmm_1_grad_1->dot_dimension_numbers();
+ *bwd_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
+ bmm_1_grad_2->dot_dimension_numbers();
+ *bwd_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
+ bmm_2_grad_1->dot_dimension_numbers();
+ *bwd_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
+ bmm_2_grad_2->dot_dimension_numbers();
+
+ // Restore original DotDimensionNumbers
+ *((DynCast<HloDotInstruction>(bmm_1_grad_1))
+ ->mutable_dot_dimension_numbers()) = orig_bmm1_grad1_config;
+ *((DynCast<HloDotInstruction>(bmm_1_grad_2))
+ ->mutable_dot_dimension_numbers()) = orig_bmm1_grad2_config;
+ *((DynCast<HloDotInstruction>(bmm_2_grad_1))
+ ->mutable_dot_dimension_numbers()) = orig_bmm2_grad1_config;
+ *((DynCast<HloDotInstruction>(bmm_2_grad_2))
+ ->mutable_dot_dimension_numbers()) = orig_bmm2_grad2_config;
+
+ bwd_fmha_config.set_fmha_scale(fwd_config.fmha_scale());
+ bwd_fmha_config.set_dropout_rate(fwd_config.dropout_rate());
+  // Use the same seed as the forward call; the seed is not exposed to XLA in
+  // the HLO graph.
+  // TODO: Find a way to compute the original seed from the dropout keys.
+ bwd_fmha_config.set_seed(fwd_config.seed());
+ bwd_fmha_config.set_mask_type(is_causal_mask
+ ? CudnnfMHABackendConfig::CAUSAL
+ : CudnnfMHABackendConfig::NO_MASK);
+
+ *bwd_fmha_config.mutable_intermediate_tensor_shape() =
+ fwd_config.intermediate_tensor_shape();
+ {
+ auto* algorithm = bwd_fmha_config.mutable_algorithm();
+ algorithm->set_algo_id(0); // engine id
+ algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+ std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
+ std::vector<int64_t> knob_vals = {1, 0};
+ for (int i = 0; i < knob_ids.size(); ++i) {
+ (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
+ }
+ algorithm->set_is_cudnn_frontend(true);
+ algorithm->mutable_workspace_size()->set_value(0);
+ }
+
+ // Output order:
+ // {dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), dbias*, workspace}
+ std::vector<Shape> output_shapes = {
+ bmm_1_grad_2->shape(), bmm_1_grad_1->shape(), bmm_2_grad_1->shape()};
+
+ if (dbias) {
+    // The cuDNN kernel only outputs dbias with shape [1, num_heads, seq, seq],
+    // so we add a leading dimension of 1 to the existing dbias shape.
+ std::vector<int64_t> dbias_shape_vector =
+ SpanToVector(dbias->shape().dimensions());
+ dbias_shape_vector.insert(dbias_shape_vector.begin(), 1);
+ Shape cudnn_dbias_shape =
+ ShapeUtil::MakeShape(dbias->shape().element_type(), dbias_shape_vector);
+ output_shapes.push_back(cudnn_dbias_shape);
+ }
+ // Reserved placeholder for workspace
+ output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
+ Shape call_shape = ShapeUtil::MakeTupleShape(output_shapes);
+ HloInstruction* fmha_bwd_call =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ call_shape, operands, absl::string_view(bwd_custom_call_name)));
+ GpuBackendConfig bwd_gpu_config;
+ *bwd_gpu_config.mutable_cudnn_fmha_backend_config() = bwd_fmha_config;
+ TF_RETURN_IF_ERROR(fmha_bwd_call->set_backend_config(bwd_gpu_config));
+ TF_RETURN_IF_ERROR(
+ SetFMHAInstructionName(bmm_1_grad_1->GetModule(), fmha_bwd_call));
+
+ // Q gradient
+ TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+ bmm_1_grad_2, HloInstruction::CreateGetTupleElement(bmm_1_grad_2->shape(),
+ fmha_bwd_call, 0)));
+ // K gradient
+ TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+ bmm_1_grad_1, HloInstruction::CreateGetTupleElement(bmm_1_grad_1->shape(),
+ fmha_bwd_call, 1)));
+ // V gradient
+ TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+ bmm_2_grad_1, HloInstruction::CreateGetTupleElement(bmm_2_grad_1->shape(),
+ fmha_bwd_call, 2)));
+
+ if (dbias) {
+    // Reshape the fmha dbias output to the shape expected by its original
+    // user. If the reshape doesn't involve physical data movement, the
+    // algebraic simplifier can change it to a no-op bitcast.
+ Shape original_shape = dbias->shape();
+ HloInstruction* dbias_user = dbias->users()[0];
+ HloInstruction* cudnn_dbias_output =
+ comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+ output_shapes[3], fmha_bwd_call, 3));
+ HloInstruction* reshape_dbias = comp->AddInstruction(
+ HloInstruction::CreateReshape(original_shape, cudnn_dbias_output));
+ TF_RETURN_IF_ERROR(dbias_user->ReplaceOperandWith(
+ dbias_user->operand_index(dbias), reshape_dbias));
+
+ TF_RETURN_IF_ERROR(
+ comp->ReplaceInstructionWithDifferentShape(dbias, cudnn_dbias_output));
+ }
+ return true;
+}
+
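+// Restores the original forward graph (bmm2 and activation) when the backward
+// pattern cannot be matched, undoing the effects of the forward rewrite.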
+absl::Status RestoreFwdGraph(
+ HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2,
+ HloInstruction* activation, HloInstruction* original_bmm2_producer0,
+ HloInstruction* original_bmm2_producer1,
+ std::vector<HloInstruction*>& original_activation_producers,
+ bool bmm_2_need_canonicalization) {
+  // If the backward pattern is not matched, we need to restore the original
+  // graph structure by replacing the new GTEs added by the forward FMHA call
+  // with clones of the old activation and bmm2.
+ HloInstruction* output_gte = fwd_fmha_call->users()[0];
+ HloInstruction* activation_gte = fwd_fmha_call->users()[1];
+ std::string suffix = "fmha_no_match_clone";
+ HloInstruction* cloned_activation =
+ comp->AddInstruction(activation->CloneWithNewOperands(
+ activation->shape(), original_activation_producers, suffix));
+
+  // Since the old activation was detached by the forward FMHA rewrite, we
+  // need to use the newly cloned activation.
+ HloInstruction* lhs = activation == original_bmm2_producer0
+ ? cloned_activation
+ : original_bmm2_producer0;
+ HloInstruction* rhs = activation == original_bmm2_producer0
+ ? original_bmm2_producer1
+ : cloned_activation;
+ HloInstruction* cloned_bmm2 = comp->AddInstruction(
+ bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));
+ if (bmm_2_need_canonicalization) {
+ TF_RET_CHECK(output_gte->users()[0]->opcode() == HloOpcode::kTranspose);
+ TF_RETURN_IF_ERROR(
+ comp->ReplaceInstruction(output_gte->users()[0], cloned_bmm2));
+ } else {
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+ }
+ TF_RETURN_IF_ERROR(
+ comp->ReplaceInstruction(activation_gte, cloned_activation));
+ return absl::OkStatus();
+}
+} // namespace
+
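+// For each non-fusion computation, matches forward MHA patterns and fuses
+// them into cuDNN fMHA custom calls; in training mode, also matches and fuses
+// the corresponding backward graph, restoring the forward graph if the
+// backward match fails.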
+absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool any_changed = false;
+  // We use this set to keep track of all already matched attention blocks.
+ absl::flat_hash_set<HloInstruction*> matched_bmm1;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ const DebugOptions& debug_options =
+ comp->parent()->config().debug_options();
+ const se::dnn::VersionInfo cudnn_version =
+ GetDnnVersionInfoOrDefault(stream_executor_, cudnn_version_);
+#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
+ // CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware.
+ // Some cuDNN versions work with CUDA 11, but it is impractical for us to
+ // test those combinations so just disable them.
+ return false;
+#endif
+ if (!debug_options.xla_gpu_enable_cudnn_fmha() ||
+ !IsComputeCapabilityAndCudnnSupported(
+ compute_capability_, cudnn_version,
+ stream_executor::dnn::VersionInfo(9, 0, 0))) {
+ return false;
+ }
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ bool v_transposed = false;
+ bool changed = false;
+ MatchFwdResult matched_result =
+ MatchFwdMHAPatternsForCanonicalization(instr);
+ if (!matched_result.has_match) {
+ continue;
+ }
+      // We check the validity of the bmms here before canonicalization so we
+      // don't modify the graph if MHA fusion is not possible.
+ TF_ASSIGN_OR_RETURN(
+ bool is_mha_module_supported,
+ IsMHABlockSupported(
+ matched_result.matched_bmm_1, matched_result.matched_bmm_2,
+ matched_result.need_canonicalization, matched_result.is_training,
+ matched_result.is_causal_mask,
+ matched_result.matched_custom_call_name, debug_options,
+ compute_capability_, cudnn_version));
+
+ if (!is_mha_module_supported) continue;
+      // If the activation has more than one user in non-training mode, we
+      // cannot rewrite the graph, so skip processing the rest.
+ HloInstruction* activation =
+ matched_result.need_canonicalization
+ ? matched_result.matched_bmm_2->mutable_operand(1)
+ : matched_result.matched_bmm_2->mutable_operand(0);
+ if (!matched_result.is_training && activation->user_count() > 1) {
+        VLOG(2)
+            << "Activation: " << activation->ToString()
+            << " cannot have more than one user in non-training mode. "
+               "Skipping.";
+ continue;
+ }
+ HloInstruction* original_bmm2_producer0 =
+ matched_result.matched_bmm_2->mutable_operand(0);
+ HloInstruction* original_bmm2_producer1 =
+ matched_result.matched_bmm_2->mutable_operand(1);
+
+ HloInstruction* original_bmm2 = matched_result.matched_bmm_2;
+ std::vector<HloInstruction*> original_activation_producers;
+ for (HloInstruction* operand : activation->mutable_operands()) {
+ original_activation_producers.push_back(operand);
+ }
+ // We make sure no attention block is matched and replaced twice here
+ if (!matched_bmm1.insert(matched_result.matched_bmm_1).second) {
+ continue;
+ }
+ // If we need to canonicalize the bmm, we will assign the newly
+ // canonicalized bmm to bmm_2.
+ if (matched_result.need_canonicalization) {
+ TF_ASSIGN_OR_RETURN(matched_result.matched_bmm_2,
+ CanonicalizeBatchedGemmForcuDNNFMHA(
+ matched_result.matched_bmm_2, comp));
+ }
+
+      // Fuse the bmms and intermediate nodes into an fMHA call; the fused
+      // call will replace bmm_2.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * fwd_fmha_call,
+ FuseFwdMultiHeadedAttentionBlock(
+ comp, matched_result.matched_bmm_1, matched_result.matched_bmm_2,
+ matched_result.matched_bias, matched_result.matched_scale,
+ matched_result.matched_reduce_sum,
+ matched_result.matched_softmax_input,
+ matched_result.matched_dropout_rate,
+ matched_result.matched_custom_call_name, compute_capability_,
+ matched_result.is_training, changed, v_transposed,
+ matched_result.is_causal_mask));
+ any_changed |= changed;
+ if (matched_result.is_training) {
+ MatchBwdResult matched_bwd_result =
+ MatchBwdMHAPatternsForCanonicalization(
+ fwd_fmha_call, matched_result.matched_bmm_1, v_transposed);
+ if (!matched_bwd_result.has_match) {
+ VLOG(2) << "Backward pattern not matching, skipping.";
+ // restore fwd graph if bwd pattern match failed
+ TF_RETURN_IF_ERROR(
+ RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+ original_bmm2_producer0, original_bmm2_producer1,
+ original_activation_producers,
+ matched_result.need_canonicalization));
+ continue;
+ }
+ if (matched_bwd_result.matched_dbias &&
+ !(compute_capability_.IsAtLeastHopper() &&
+ cudnn_version >= stream_executor::dnn::VersionInfo(9, 0, 0))) {
+ VLOG(2) << "Flash attention dbias requires cudnn 9.0.0 + hopper.";
+          // Restore the fwd graph if dbias is not supported here.
+ TF_RETURN_IF_ERROR(
+ RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+ original_bmm2_producer0, original_bmm2_producer1,
+ original_activation_producers,
+ matched_result.need_canonicalization));
+ continue;
+ }
+ // Canonicalize gemms
+ if (matched_bwd_result.bmm_1_grad_1_need_canonicalization) {
+ TF_ASSIGN_OR_RETURN(
+ matched_bwd_result.matched_bmm_1_grad_1,
+ CanonicalizeBatchedGemmForcuDNNFMHA(
+ matched_bwd_result.matched_bmm_1_grad_1, comp));
+ }
+ if (matched_bwd_result.bmm_1_grad_2_need_canonicalization) {
+ TF_ASSIGN_OR_RETURN(
+ matched_bwd_result.matched_bmm_1_grad_2,
+ CanonicalizeBatchedGemmForcuDNNFMHA(
+ matched_bwd_result.matched_bmm_1_grad_2, comp));
+ }
+ if (matched_bwd_result.bmm_2_grad_1_need_canonicalization) {
+ TF_ASSIGN_OR_RETURN(
+ matched_bwd_result.matched_bmm_2_grad_1,
+ CanonicalizeBatchedGemmForcuDNNFMHA(
+ matched_bwd_result.matched_bmm_2_grad_1, comp));
+ }
+ if (matched_bwd_result.bmm_2_grad_2_need_canonicalization) {
+ TF_ASSIGN_OR_RETURN(
+ matched_bwd_result.matched_bmm_2_grad_2,
+ CanonicalizeBatchedGemmForcuDNNFMHA(
+ matched_bwd_result.matched_bmm_2_grad_2, comp));
+ }
+
+        // Fuse the corresponding gradient graph into an fMHA backward call.
+ TF_ASSIGN_OR_RETURN(
+ changed,
+ FuseBwdMultiHeadedAttentionBlock(
+ comp, matched_bwd_result.matched_bmm_1_grad_1,
+ matched_bwd_result.matched_bmm_1_grad_2,
+ matched_bwd_result.matched_bmm_2_grad_1,
+ matched_bwd_result.matched_bmm_2_grad_2, fwd_fmha_call,
+ matched_bwd_result.matched_dbias, matched_result.matched_bias,
+ matched_bwd_result.matched_custom_call_name));
+ any_changed |= changed;
+ }
+ }
+ }
+
+ return any_changed;
+}
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h
new file mode 100644
index 0000000..6a985ee
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h
@@ -0,0 +1,59 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+class CudnnFusedMHARewriter : public HloModulePass {
+ public:
+ explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
+ se::StreamExecutor* stream_executor)
+ : compute_capability_(cc), stream_executor_(stream_executor) {}
+
+ explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
+ se::dnn::VersionInfo cudnn_version)
+ : compute_capability_(cc), cudnn_version_(cudnn_version) {}
+
+ absl::string_view name() const override {
+ return "cudnn-fused-multi-headed-attention-rewriter";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const se::CudaComputeCapability compute_capability_;
+ se::StreamExecutor* stream_executor_ = nullptr;
+ const se::dnn::VersionInfo cudnn_version_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc
new file mode 100644
index 0000000..a64fd062
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc
@@ -0,0 +1,3021 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h"
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/computation_layout.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h"
+#include "xla/service/hlo_cse.h"
+#include "xla/service/hlo_dce.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/hlo_verifier.h"
+#include "xla/service/layout_normalization.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/reshape_decomposer.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/test_helpers.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
+#endif
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = xla::match;
+
+class CudnnFusedMhaRewriterTestHloTest : public HloTestBase {
+ public:
+ se::CudaComputeCapability GetCudaComputeCapability() {
+    // Fake a supported compute capability to run tests;
+    // we don't run any kernels in these tests, so they
+    // should be safe to run anywhere.
+ return se::CudaComputeCapability(8, 0);
+ }
+
+ se::CudaComputeCapability GetRealCudaComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability();
+ }
+
+ se::dnn::VersionInfo GetCudnnVersion() {
+    // Fake a supported cuDNN version to run tests;
+    // we don't run any kernels in these tests, so they
+    // should be safe to run anywhere.
+ return se::dnn::VersionInfo(9, 0, 0);
+ }
+
+ CudnnFusedMhaRewriterTestHloTest()
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/false,
+ /*instruction_can_change_layout_func=*/{}) {
+#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
+ skip_reason_ = "cuDNN fused MHA requires CUDA 12 or later.";
+ return;
+#endif
+ }
+
+ protected:
+ size_t CountFusedAttentionCall(HloModule* module, bool is_backward = false) {
+ return absl::c_count_if(module->entry_computation()->instructions(),
+ [&](const HloInstruction* instr) {
+ if (is_backward) {
+ return IsBwdCustomCallTofMHA(*instr);
+ } else {
+ return IsFwdCustomCallTofMHA(*instr);
+ }
+ });
+ }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ auto debug_options = HloTestBase::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_cudnn_fmha(true);
+ debug_options.set_xla_gpu_fused_attention_use_cudnn_rng(true);
+ return debug_options;
+ }
+
+ HloModuleConfig GetModuleConfig() {
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ HloModuleConfig config_with_fmha;
+ config_with_fmha.set_debug_options(debug_options);
+ return config_with_fmha;
+ }
+
+  // Centralize skip checks in the constructor. Unfortunately we cannot call
+  // GTEST_SKIP from the constructor. Instead, we set `skip_reason_` (if
+  // needed) and then check it from all test fixtures.
+  // An alternative would be to use the SetUp() override, but for this to be
+  // correct we'd have to ensure that all the parents' SetUp() methods are
+  // called, which is error-prone.
+ std::optional<absl::string_view> skip_reason_;
+};
+
+constexpr absl::string_view
+ hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+
+region_0.7 {
+ Arg_0.8 = bf16[] parameter(0)
+ Arg_1.9 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+ Arg_0.20 = f32[] parameter(0)
+ Arg_1.21 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.6 {
+ Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+ Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
+ Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
+ dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+ constant = bf16[] constant(-inf)
+ reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
+ broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+ subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
+ exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+ convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
+ constant.1 = f32[] constant(0)
+ reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+ convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
+ broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+ divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
+ ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+})";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1SoftmaxBmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto m, ParseAndReturnVerifiedModule(
+ hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
+ EXPECT_TRUE(result);
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+ .WithShape(BF16, {16, 16, 256, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
+ 2);
+}
+
+constexpr absl::string_view
+ hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+
+region_0.7 {
+ Arg_0.8 = bf16[] parameter(0)
+ Arg_1.9 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+ Arg_0.20 = f32[] parameter(0)
+ Arg_1.21 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.6 {
+ Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+ Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
+ Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
+ dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+ constant = bf16[] constant(-inf)
+ reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
+ broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+ subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
+ exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+ convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
+ constant.1 = f32[] constant(0)
+ reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+ convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
+ broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+ divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
+ ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+})";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1SoftmaxBmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto m, ParseAndReturnVerifiedModule(
+ hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
+ EXPECT_TRUE(result);
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+ .WithShape(BF16, {16, 16, 256, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(config.bmm1_dot_dimension_numbers().lhs_contracting_dimensions()[0],
+ 2);
+ EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
+ 2);
+}
+
+constexpr absl::string_view
+ hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+
+region_0.7 {
+ Arg_0.8 = bf16[] parameter(0)
+ Arg_1.9 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+ Arg_0.20 = f32[] parameter(0)
+ Arg_1.21 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.6 {
+ Arg_2.3 = bf16[16,16,256,64]{2,3,1,0} parameter(2)
+ Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
+ Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
+ dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+ constant = bf16[] constant(-inf)
+ reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
+ broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+ subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
+ exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+ convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
+ constant.1 = f32[] constant(0)
+ reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+ convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
+ broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+ divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
+ ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+})";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1SoftmaxBmm2Pattern_bmm2_non_contracting_dim_not_most_minor) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto m, ParseAndReturnVerifiedModule(
+ hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
+ EXPECT_TRUE(result);
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+ .WithShape(BF16, {16, 16, 256, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(config.bmm2_dot_dimension_numbers().lhs_contracting_dimensions()[0],
+ 3);
+ EXPECT_EQ(config.bmm2_dot_dimension_numbers().rhs_contracting_dimensions()[0],
+ 3);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1CombinedMaskBiasSoftmaxBmm2) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_,
+entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
+
+region_0.32.clone {
+ Arg_0.0 = f32[] parameter(0)
+ Arg_1.0 = f32[] parameter(1)
+ ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
+}
+
+region_1.44 {
+ Arg_0.45 = f32[] parameter(0)
+ Arg_1.46 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+
+ENTRY main.61 {
+ Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+ transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
+ Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+ transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
+ Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+ transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
+ Arg_4.5 = pred[16,1,256,256]{3,2,1,0} parameter(4), sharding={replicated}
+ bitcast.35 = pred[16,256,256]{2,1,0} bitcast(Arg_4.5)
+ convert.49 = s32[16,256,256]{2,1,0} convert(bitcast.35)
+ constant.5 = s32[] constant(0)
+ broadcast.10 = s32[16,256,256]{2,1,0} broadcast(constant.5), dimensions={}
+ compare = pred[16,256,256]{2,1,0} compare(convert.49, broadcast.10), direction=GT
+ constant.7 = bf16[] constant(0)
+ broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.7), dimensions={}
+ constant.9 = bf16[] constant(-9.999e+09)
+ broadcast.13 = bf16[16,256,256]{2,1,0} broadcast(constant.9), dimensions={}
+ select = bf16[16,256,256]{2,1,0} select(compare, broadcast.12, broadcast.13)
+ convert.51 = f32[16,256,256]{2,1,0} convert(select)
+ broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.51), dimensions={0,2,3}
+ Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
+ bitcast.52 = bf16[16,256,256]{2,1,0} bitcast(Arg_3.4)
+ convert.52 = f32[16,256,256]{2,1,0} convert(bitcast.52)
+ broadcast.15 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.52), dimensions={1,2,3}
+ add.1 = f32[16,16,256,256]{3,2,1,0} add(broadcast.14, broadcast.15)
+ dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ convert.55 = f32[16,16,256,256]{3,2,1,0} convert(dot.2)
+ add.18 = f32[16,16,256,256]{3,2,1,0} add(convert.55, add.1)
+ constant.11 = f32[] constant(-inf)
+ reduce.36 = f32[16,16,256]{2,1,0} reduce(add.18, constant.11), dimensions={3}, to_apply=region_0.32.clone
+ broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+ subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.18, broadcast.17)
+ exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+ constant.14 = f32[] constant(0)
+ reduce.48 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.44
+ broadcast.18 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.48), dimensions={0,1,2}
+ divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.18)
+ convert.68 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
+ dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.68), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Transpose(
+ m::Transpose(m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
+ 0)))
+ .WithShape(BF16, {16, 256, 16, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(fmha->operands().size(), 4);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1UnfusedSoftmaxBmm2) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->f16[2,6,40,64]{3,2,1,0}}
+
+region_0.7 {
+ Arg_0.8 = f16[] parameter(0)
+ Arg_1.9 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+ Arg_0.20 = f32[] parameter(0)
+ Arg_1.21 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.31 {
+ Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
+ dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ constant = f16[] constant(-inf)
+ reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
+ broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+ subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
+ exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
+ convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
+ constant.1 = f32[] constant(0)
+ reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+ convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
+ broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+ divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
+ Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
+ ROOT dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+ .WithShape(F16, {2, 6, 40, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_FLOAT_EQ(config.fmha_scale(), 1.0);
+ EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0);
+ EXPECT_EQ(fmha->operands().size(), 3);
+}
+
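+// The boolean mask is converted into an additive 0 / large-negative bias that
+// is added after the first GEMM; the rewriter is expected to lower this to the
+// scale-bias-softmax fMHA call with four operands.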
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1ConvertedMaskAddedAfterFirstGemmSoftmaxBmm2) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
+
+region_0.27.clone {
+ Arg_0.0 = f32[] parameter(0)
+ Arg_1.0 = f32[] parameter(1)
+ ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
+}
+
+region_1.39 {
+ Arg_0.40 = f32[] parameter(0)
+ Arg_1.41 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.40, Arg_1.41)
+}
+
+ENTRY main.56 {
+ Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+ transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
+ Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+ transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
+ Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+ transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
+ dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ convert.47 = f32[16,16,256,256]{3,2,1,0} convert(dot)
+ Arg_3.4 = pred[16,1,256,256]{3,2,1,0} parameter(3), sharding={replicated}
+ bitcast.37 = pred[16,256,256]{2,1,0} bitcast(Arg_3.4)
+ convert.42 = s32[16,256,256]{2,1,0} convert(bitcast.37)
+ constant.6 = s32[] constant(0)
+ broadcast.9 = s32[16,256,256]{2,1,0} broadcast(constant.6), dimensions={}
+ compare = pred[16,256,256]{2,1,0} compare(convert.42, broadcast.9), direction=GT
+ constant.8 = bf16[] constant(0)
+ broadcast.11 = bf16[16,256,256]{2,1,0} broadcast(constant.8), dimensions={}
+ constant.10 = bf16[] constant(-9.999e+09)
+ broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.10), dimensions={}
+ select = bf16[16,256,256]{2,1,0} select(compare, broadcast.11, broadcast.12)
+ convert.48 = f32[16,256,256]{2,1,0} convert(select)
+ broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.48), dimensions={0,2,3}
+ add.2 = f32[16,16,256,256]{3,2,1,0} add(convert.47, broadcast.14)
+ constant.13 = f32[] constant(-inf)
+ reduce.31 = f32[16,16,256]{2,1,0} reduce(add.2, constant.13), dimensions={3}, to_apply=region_0.27.clone
+ broadcast.16 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.31), dimensions={0,1,2}
+ subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.2, broadcast.16)
+ exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+ constant.14 = f32[] constant(0)
+ reduce.43 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.39
+ broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.43), dimensions={0,1,2}
+ divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.17)
+ convert.63 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
+ dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.63), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Transpose(
+ m::Transpose(m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
+ 0)))
+ .WithShape(BF16, {16, 256, 16, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(fmha->operands().size(), 4);
+}
+
+// Negative test: the BMM1 contracting dimension is 32 instead of 64, so no
+// fMHA rewrite is expected and the two dots stay unfused.
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1Bmm2Pattern_bmm1_contracting_dim_not_equal_64) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+ENTRY main.6 {
+ Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+ Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
+ Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
+ dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+ ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(&fmha, m::Dot(m::Parameter(0), m::Parameter(1)),
+ m::Parameter(2))
+ .WithShape(BF16, {16, 16, 256, 64})));
+}
+
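+// Negative test: the BMM2 RHS non-contracting dimension is 32 instead of 64,
+// so no fMHA rewrite is expected.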
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1Bmm2Pattern_bmm2_rhs_non_contracting_dim_not_equal_64) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0})->bf16[16,16,256,32]{3,2,1,0}}
+ENTRY main.6 {
+ Arg_2.3 = bf16[16,16,256,32]{3,2,1,0} parameter(2)
+ Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
+ Arg_1.2 = bf16[16,16,256,64]{3,2,1,0} parameter(1)
+ dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+ ROOT dot.1 = bf16[16,16,256,32]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(&fmha, m::Op(), m::Parameter(2))
+ .WithShape(BF16, {16, 16, 256, 32})));
+}
+
+// Check that canonicalization does not kick in when the MHA pattern is
+// unsupported.
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1Bmm2PatternUncanonicalized_bmm1_contracting_dim_not_equal_64) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,64,256]{3,2,1,0}}
+
+ENTRY main.6 {
+ Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+ Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
+ Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
+ dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+ ROOT dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(Arg_2.3, dot.0), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(&fmha, m::Parameter(2), m::Op())
+ .WithShape(BF16, {16, 16, 64, 256})));
+}
+
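+// Bias + softmax + dropout forward graph. The dropout mask is produced by
+// threefry RNG custom calls with a keep probability of 0.9, so a dropout rate
+// of roughly 0.1 is expected in the fMHA backend config.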
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxDropoutBmm2) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
+
+region_0.34 {
+ Arg_0.35 = bf16[] parameter(0)
+ Arg_1.36 = bf16[] parameter(1)
+ ROOT maximum.37 = bf16[] maximum(Arg_0.35, Arg_1.36)
+}
+
+region_1.46 {
+ Arg_0.47 = f32[] parameter(0)
+ Arg_1.48 = f32[] parameter(1)
+ ROOT add.49 = f32[] add(Arg_0.47, Arg_1.48)
+}
+
+ENTRY main.82 {
+ Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+ copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
+ transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
+ Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+ copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+ transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
+ Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+ copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
+ transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
+ dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
+ reshape.31 = bf16[16,256,256]{2,1,0} reshape(Arg_3.4)
+ broadcast.32 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.31), dimensions={1,2,3}
+ add.33 = bf16[16,16,256,256]{3,2,1,0} add(dot, broadcast.32)
+ constant.21 = bf16[] constant(-inf)
+ reduce.38 = bf16[16,16,256]{2,1,0} reduce(add.33, constant.21), dimensions={3}, to_apply=region_0.34
+ broadcast.42 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.38), dimensions={0,1,2}
+ subtract.43 = bf16[16,16,256,256]{3,2,1,0} subtract(add.33, broadcast.42)
+ exponential.44 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.43)
+ convert.45 = f32[16,16,256,256]{3,2,1,0} convert(exponential.44)
+ constant.9 = f32[] constant(0)
+ reduce.50 = f32[16,16,256]{2,1,0} reduce(convert.45, constant.9), dimensions={3}, to_apply=region_1.46
+ convert.1 = bf16[16,16,256]{2,1,0} convert(reduce.50)
+ broadcast.55 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.1), dimensions={0,1,2}
+ divide.56 = bf16[16,16,256,256]{3,2,1,0} divide(exponential.44, broadcast.55)
+ constant.18 = u32[1]{0} constant({255383827})
+ constant.17 = u32[1]{0} constant({267815257})
+ constant.2 = u32[1]{0} constant({0})
+ constant.19 = u32[1]{0} constant({3213575472})
+ custom-call.26 = (u32[1]{0}, u32[1]{0}) custom-call(constant.18, constant.17, constant.2, constant.19), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
+ get-tuple-element.27 = u32[1]{0} get-tuple-element(custom-call.26), index=0
+ reshape.58 = u32[] reshape(get-tuple-element.27)
+ broadcast.62 = u32[32768]{0} broadcast(reshape.58), dimensions={}
+ get-tuple-element.28 = u32[1]{0} get-tuple-element(custom-call.26), index=1
+ reshape.59 = u32[] reshape(get-tuple-element.28)
+ broadcast.63 = u32[32768]{0} broadcast(reshape.59), dimensions={}
+ iota.57 = u32[65536]{0} iota(), iota_dimension=0
+ slice.60 = u32[32768]{0} slice(iota.57), slice={[0:32768]}
+ slice.61 = u32[32768]{0} slice(iota.57), slice={[32768:65536]}
+ custom-call.64 = (u32[32768]{0}, u32[32768]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[32768]{0}, u32[32768]{0}, u32[32768]{0}, u32[32768]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\000\000\000\000\000\000"
+ get-tuple-element.65 = u32[32768]{0} get-tuple-element(custom-call.64), index=0
+ get-tuple-element.66 = u32[32768]{0} get-tuple-element(custom-call.64), index=1
+ concatenate.67 = u32[65536]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
+ constant.15 = u32[] constant(9)
+ broadcast.3 = u32[65536]{0} broadcast(constant.15), dimensions={}
+ shift-right-logical.0 = u32[65536]{0} shift-right-logical(concatenate.67, broadcast.3)
+ constant.13 = u32[] constant(1065353216)
+ broadcast.11 = u32[65536]{0} broadcast(constant.13), dimensions={}
+ or.0 = u32[65536]{0} or(shift-right-logical.0, broadcast.11)
+ bitcast-convert.0 = f32[65536]{0} bitcast-convert(or.0)
+ constant.3 = f32[] constant(-1)
+ broadcast.17 = f32[65536]{0} broadcast(constant.3), dimensions={}
+ add.1 = f32[65536]{0} add(bitcast-convert.0, broadcast.17)
+ broadcast.18 = f32[65536]{0} broadcast(constant.9), dimensions={}
+ maximum.0 = f32[65536]{0} maximum(add.1, broadcast.18)
+ constant.7 = f32[] constant(0.9)
+ broadcast.19 = f32[65536]{0} broadcast(constant.7), dimensions={}
+ compare.0 = pred[65536]{0} compare(maximum.0, broadcast.19), direction=LT
+ constant = bf16[] constant(1.109)
+ broadcast.20 = bf16[65536]{0} broadcast(constant), dimensions={}
+ constant.4 = bf16[] constant(0)
+ broadcast.21 = bf16[65536]{0} broadcast(constant.4), dimensions={}
+ select.1 = bf16[65536]{0} select(compare.0, broadcast.20, broadcast.21)
+ reshape.19 = bf16[16,16,256]{2,1,0} reshape(select.1)
+ broadcast.9 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.19), dimensions={0,1,3}
+ multiply.79 = bf16[16,16,256,256]{3,2,1,0} multiply(divide.56, broadcast.9)
+ dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, multiply.79), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ transpose.81 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
+ ROOT copy.3 = bf16[16,256,16,64]{3,2,1,0} copy(transpose.81)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
+ m::CustomCall(
+ &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
+ 0))))
+ .WithShape(BF16, {16, 256, 16, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(fmha->operands().size(), 4);
+ EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
+}
+
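+// Dropout written in a second form: a select applied after the softmax divide
+// and scaling rather than a multiply by a precomputed mask. The pattern is
+// still expected to lower to the scale-bias-softmax-dropout fMHA call.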
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16Bmm1ScaleBiasSoftmaxDropoutForm2Bmm2) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0})->bf16[32,40,60,64]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
+
+region_0.29 {
+ Arg_0.30 = bf16[] parameter(0)
+ Arg_1.31 = bf16[] parameter(1)
+ ROOT maximum.32 = bf16[] maximum(Arg_0.30, Arg_1.31)
+}
+
+region_1.41 {
+ Arg_0.42 = f32[] parameter(0)
+ Arg_1.43 = f32[] parameter(1)
+ ROOT add.44 = f32[] add(Arg_0.42, Arg_1.43)
+}
+
+ENTRY main.79 {
+ Arg_2.3 = bf16[32,40,60,64]{3,2,1,0} parameter(2), sharding={replicated}
+ copy = bf16[32,40,60,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
+ transpose.2 = bf16[32,60,64,40]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
+ constant.19 = u32[1]{0} constant({2718843009})
+ constant.18 = u32[1]{0} constant({1272950319})
+ constant.2 = u32[1]{0} constant({0})
+ constant.20 = u32[1]{0} constant({2711844646})
+ custom-call.54 = (u32[1]{0}, u32[1]{0}) custom-call(constant.19, constant.18, constant.2, constant.20), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
+ get-tuple-element.55 = u32[1]{0} get-tuple-element(custom-call.54), index=0
+ reshape.58 = u32[] reshape(get-tuple-element.55)
+ broadcast.62 = u32[1536000]{0} broadcast(reshape.58), dimensions={}
+ get-tuple-element.56 = u32[1]{0} get-tuple-element(custom-call.54), index=1
+ reshape.59 = u32[] reshape(get-tuple-element.56)
+ broadcast.63 = u32[1536000]{0} broadcast(reshape.59), dimensions={}
+ iota.57 = u32[3072000]{0} iota(), iota_dimension=0
+ slice.60 = u32[1536000]{0} slice(iota.57), slice={[0:1536000]}
+ slice.61 = u32[1536000]{0} slice(iota.57), slice={[1536000:3072000]}
+ custom-call.64 = (u32[1536000]{0}, u32[1536000]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000p\027\000\000\000\000\000"
+ get-tuple-element.65 = u32[1536000]{0} get-tuple-element(custom-call.64), index=0
+ get-tuple-element.66 = u32[1536000]{0} get-tuple-element(custom-call.64), index=1
+ concatenate.67 = u32[3072000]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
+ constant.16 = u32[] constant(9)
+ broadcast.2 = u32[3072000]{0} broadcast(constant.16), dimensions={}
+ shift-right-logical.0 = u32[3072000]{0} shift-right-logical(concatenate.67, broadcast.2)
+ constant.14 = u32[] constant(1065353216)
+ broadcast.6 = u32[3072000]{0} broadcast(constant.14), dimensions={}
+ or.0 = u32[3072000]{0} or(shift-right-logical.0, broadcast.6)
+ bitcast-convert.0 = f32[3072000]{0} bitcast-convert(or.0)
+ constant.3 = f32[] constant(-1)
+ broadcast.8 = f32[3072000]{0} broadcast(constant.3), dimensions={}
+ add.1 = f32[3072000]{0} add(bitcast-convert.0, broadcast.8)
+ constant.10 = f32[] constant(0)
+ broadcast.10 = f32[3072000]{0} broadcast(constant.10), dimensions={}
+ maximum.0 = f32[3072000]{0} maximum(add.1, broadcast.10)
+ constant.8 = f32[] constant(0.9)
+ broadcast.12 = f32[3072000]{0} broadcast(constant.8), dimensions={}
+ compare.0 = pred[3072000]{0} compare(maximum.0, broadcast.12), direction=LT
+ reshape.18 = pred[32,60,40,40]{3,2,1,0} reshape(compare.0)
+ Arg_0.1 = bf16[32,40,60,64]{3,2,1,0} parameter(0), sharding={replicated}
+ copy.1 = bf16[32,40,60,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+ transpose = bf16[32,60,40,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
+ Arg_1.2 = bf16[32,40,60,64]{3,2,1,0} parameter(1), sharding={replicated}
+ copy.2 = bf16[32,40,60,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
+ transpose.1 = bf16[32,60,64,40]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
+ dot = bf16[32,60,40,40]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.25 = bf16[] constant(1)
+ broadcast.26 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.25), dimensions={}
+ add.28 = bf16[32,60,40,40]{3,2,1,0} add(dot, broadcast.26)
+ constant.24 = bf16[] constant(-inf)
+ reduce.33 = bf16[32,60,40]{2,1,0} reduce(add.28, constant.24), dimensions={3}, to_apply=region_0.29
+ broadcast.37 = bf16[32,60,40,40]{3,2,1,0} broadcast(reduce.33), dimensions={0,1,2}
+ subtract.38 = bf16[32,60,40,40]{3,2,1,0} subtract(add.28, broadcast.37)
+ exponential.39 = bf16[32,60,40,40]{3,2,1,0} exponential(subtract.38)
+ convert.40 = f32[32,60,40,40]{3,2,1,0} convert(exponential.39)
+ reduce.45 = f32[32,60,40]{2,1,0} reduce(convert.40, constant.10), dimensions={3}, to_apply=region_1.41
+ convert.0 = bf16[32,60,40]{2,1,0} convert(reduce.45)
+ broadcast.50 = bf16[32,60,40,40]{3,2,1,0} broadcast(convert.0), dimensions={0,1,2}
+ divide.51 = bf16[32,60,40,40]{3,2,1,0} divide(exponential.39, broadcast.50)
+ constant = bf16[] constant(1.109)
+ broadcast.1 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant), dimensions={}
+ multiply = bf16[32,60,40,40]{3,2,1,0} multiply(divide.51, broadcast.1)
+ constant.4 = bf16[] constant(0)
+ broadcast.5 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.4), dimensions={}
+ select.76 = bf16[32,60,40,40]{3,2,1,0} select(reshape.18, multiply, broadcast.5)
+ dot.1 = bf16[32,60,64,40]{3,2,1,0} dot(transpose.2, select.76), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ transpose.78 = bf16[32,40,60,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
+ ROOT copy.3 = bf16[32,40,60,64]{3,2,1,0} copy(transpose.78)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
+ m::CustomCall(
+ &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
+ 0))))
+ .WithShape(BF16, {32, 40, 60, 64})));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
+ EXPECT_EQ(fmha->operands().size(), 4);
+}
+
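+// Training graph that is only BMM1 followed by BMM2, with no softmax in
+// between; the rewriter is expected to report no change.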
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1Bmm2) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0})}
+
+ENTRY main.17 {
+ Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+ copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
+ transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
+ Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+ copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+ transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
+ Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+ copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
+ transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
+ dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ transpose.7 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
+ Arg_3.4 = bf16[16,256,16,64]{3,2,1,0} parameter(3), sharding={replicated}
+ copy.3 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_3.4), sharding={replicated}
+ transpose.4 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.3), dimensions={0,2,1,3}
+ dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.4, transpose.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ copy.4 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_1.2), sharding={replicated}
+ transpose.12 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.4), dimensions={0,2,1,3}
+ dot.4 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.15 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.4), dimensions={0,2,1,3}
+ dot.3 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.13 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.3), dimensions={0,2,1,3}
+ copy.5 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_3.4), sharding={replicated}
+ transpose.8 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.5), dimensions={0,2,3,1}
+ dot.10 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.8, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.11 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.10), dimensions={0,3,1,2}
+ tuple.16 = (bf16[16,256,16,64]{1,3,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{1,3,2,0}) tuple(transpose.7, transpose.15, transpose.13, transpose.11)
+ get-tuple-element = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=0
+ copy.6 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element)
+ get-tuple-element.1 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=1
+ copy.7 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.1)
+ get-tuple-element.2 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=2
+ copy.8 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.2)
+ get-tuple-element.3 = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=3
+ copy.9 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.3)
+ ROOT tuple = (bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}) tuple(copy.6, copy.7, copy.8, copy.9)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloPass(&fusedMhaRewriter, m.get()));
+  EXPECT_FALSE(changed);
+}
+
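+// Mini T5X-style training graph. Reshape decomposition, layout normalization,
+// CSE, and algebraic simplification run before the rewriter; exactly one
+// forward and one backward fused attention call are expected afterwards.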
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16MiniT5xTest) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__lambda_, entry_computation_layout={(bf16[12,512,32,64]{3,2,1,0},bf16[12,512,2,32,64]{4,3,2,1,0},f32[12,512]{1,0},f32[12,512]{1,0})->(bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true}
+
+region_0.51 {
+ Arg_0.52 = bf16[] parameter(0)
+ Arg_1.53 = bf16[] parameter(1)
+ ROOT maximum.54 = bf16[] maximum(Arg_0.52, Arg_1.53)
+}
+
+region_1.63 {
+ Arg_0.64 = f32[] parameter(0)
+ Arg_1.65 = f32[] parameter(1)
+ ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
+}
+
+region_3.99 {
+ Arg_0.100 = bf16[] parameter(0)
+ Arg_1.101 = bf16[] parameter(1)
+ ROOT add.102 = bf16[] add(Arg_0.100, Arg_1.101)
+}
+
+ENTRY main.129 {
+ Arg_1.2 = bf16[12,512,2,32,64]{4,3,2,1,0} parameter(1), sharding={replicated}
+ copy = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
+ slice.42 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy), slice={[0:12], [0:512], [1:2], [0:32], [0:64]}
+ reshape.44 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.42)
+ transpose.5 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.44), dimensions={0,2,3,1}
+ Arg_0.1 = bf16[12,512,32,64]{3,2,1,0} parameter(0), sharding={replicated}
+ copy.1 = bf16[12,512,32,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+ constant.5 = bf16[] constant(0.125)
+ broadcast.6 = bf16[12,512,32,64]{3,1,2,0} broadcast(constant.5), dimensions={}
+ multiply.45 = bf16[12,512,32,64]{3,1,2,0} multiply(copy.1, broadcast.6)
+ transpose = bf16[12,32,512,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
+ copy.2 = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
+ slice.41 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy.2), slice={[0:12], [0:512], [0:1], [0:32], [0:64]}
+ reshape.43 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.41)
+ transpose.1 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.43), dimensions={0,2,3,1}
+ dot = bf16[12,32,512,512]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_2.3 = f32[12,512]{1,0} parameter(2), sharding={replicated}
+ constant.14 = f32[] constant(0)
+ broadcast.19 = f32[12,512]{1,0} broadcast(constant.14), dimensions={}
+ compare.24 = pred[12,512]{1,0} compare(Arg_2.3, broadcast.19), direction=GT
+ broadcast.30 = pred[12,512,512]{2,1,0} broadcast(compare.24), dimensions={0,1}
+ Arg_3.4 = f32[12,512]{1,0} parameter(3), sharding={replicated}
+ compare.25 = pred[12,512]{1,0} compare(Arg_3.4, broadcast.19), direction=GT
+ broadcast.33 = pred[12,512,512]{2,1,0} broadcast(compare.25), dimensions={0,2}
+ and.34 = pred[12,512,512]{2,1,0} and(broadcast.30, broadcast.33)
+ convert.4 = s32[12,512,512]{2,1,0} convert(and.34)
+ constant.16 = s32[] constant(0)
+ broadcast.21 = s32[12,512,512]{2,1,0} broadcast(constant.16), dimensions={}
+ compare.0 = pred[12,512,512]{2,1,0} compare(convert.4, broadcast.21), direction=GT
+ constant.20 = bf16[] constant(0)
+ broadcast.22 = bf16[12,512,512]{2,1,0} broadcast(constant.20), dimensions={}
+ constant.11 = bf16[] constant(-9.999e+09)
+ broadcast.23 = bf16[12,512,512]{2,1,0} broadcast(constant.11), dimensions={}
+ select.0 = bf16[12,512,512]{2,1,0} select(compare.0, broadcast.22, broadcast.23)
+ broadcast.49 = bf16[12,32,512,512]{3,2,1,0} broadcast(select.0), dimensions={0,2,3}
+ add.50 = bf16[12,32,512,512]{3,2,1,0} add(dot, broadcast.49)
+ constant.22 = bf16[] constant(-inf)
+ reduce.55 = bf16[12,32,512]{2,1,0} reduce(add.50, constant.22), dimensions={3}, to_apply=region_0.51
+ broadcast.59 = bf16[12,32,512,512]{3,2,1,0} broadcast(reduce.55), dimensions={0,1,2}
+ subtract.60 = bf16[12,32,512,512]{3,2,1,0} subtract(add.50, broadcast.59)
+ exponential.61 = bf16[12,32,512,512]{3,2,1,0} exponential(subtract.60)
+ convert.62 = f32[12,32,512,512]{3,2,1,0} convert(exponential.61)
+ reduce.67 = f32[12,32,512]{2,1,0} reduce(convert.62, constant.14), dimensions={3}, to_apply=region_1.63
+ convert.5 = bf16[12,32,512]{2,1,0} convert(reduce.67)
+ broadcast.72 = bf16[12,32,512,512]{3,2,1,0} broadcast(convert.5), dimensions={0,1,2}
+ divide.73 = bf16[12,32,512,512]{3,2,1,0} divide(exponential.61, broadcast.72)
+ dot.1 = bf16[12,32,64,512]{3,2,1,0} dot(transpose.5, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ convert.6 = f32[12,32,64,512]{3,2,1,0} convert(dot.1)
+ reduce.83 = f32[] reduce(convert.6, constant.14), dimensions={0,3,1,2}, to_apply=region_1.63
+ convert.84 = bf16[] convert(reduce.83)
+ constant.2 = bf16[] constant(0.0007935)
+ multiply.86 = bf16[] multiply(convert.84, constant.2)
+ broadcast.9 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.2), dimensions={}
+ dot.2 = bf16[12,32,512,512]{3,2,1,0} dot(broadcast.9, transpose.5), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ divide.109 = bf16[12,32,512,512]{3,2,1,0} divide(dot.2, broadcast.72)
+ constant.10 = bf16[] constant(1)
+ broadcast.24 = bf16[12,32,512]{2,1,0} broadcast(constant.10), dimensions={}
+ multiply.4 = bf16[12,32,512]{2,1,0} multiply(convert.5, convert.5)
+ divide.0 = bf16[12,32,512]{2,1,0} divide(broadcast.24, multiply.4)
+ broadcast.96 = bf16[12,32,512,512]{3,2,1,0} broadcast(divide.0), dimensions={0,1,2}
+ multiply.97 = bf16[12,32,512,512]{3,2,1,0} multiply(dot.2, broadcast.96)
+ multiply.98 = bf16[12,32,512,512]{3,2,1,0} multiply(multiply.97, exponential.61)
+ reduce.103 = bf16[12,32,512]{2,1,0} reduce(multiply.98, constant.20), dimensions={3}, to_apply=region_3.99
+ negate.0 = bf16[12,32,512]{2,1,0} negate(reduce.103)
+ broadcast.10 = bf16[12,32,512,512]{3,2,1,0} broadcast(negate.0), dimensions={0,1,2}
+ add.118 = bf16[12,32,512,512]{3,2,1,0} add(divide.109, broadcast.10)
+ multiply.119 = bf16[12,32,512,512]{3,2,1,0} multiply(add.118, exponential.61)
+ transpose.9 = bf16[12,32,512,64]{2,3,1,0} transpose(reshape.43), dimensions={0,2,1,3}
+ copy.3 = bf16[12,32,512,64]{3,2,1,0} copy(transpose.9)
+ dot.4 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, copy.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ broadcast.12 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.5), dimensions={}
+ multiply.3 = bf16[12,32,512,64]{3,2,1,0} multiply(dot.4, broadcast.12)
+ transpose.11 = bf16[12,512,32,64]{3,1,2,0} transpose(multiply.3), dimensions={0,2,1,3}
+ broadcast.7 = bf16[12,32,64,512]{3,2,1,0} broadcast(constant.2), dimensions={}
+ dot.90 = bf16[12,32,64,512]{3,2,1,0} dot(broadcast.7, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.91 = bf16[12,512,32,64]{1,3,2,0} transpose(dot.90), dimensions={0,3,1,2}
+ reshape.92 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.91)
+ pad.93 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.92, constant.20), padding=0_0x0_0x1_0x0_0x0_0
+ dot.3 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ copy.4 = bf16[12,32,512,64]{2,3,1,0} copy(dot.3)
+ transpose.121 = bf16[12,512,32,64]{1,3,2,0} transpose(copy.4), dimensions={0,2,1,3}
+ reshape.124 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.121)
+ pad.125 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.124, constant.20), padding=0_0x0_0x0_1x0_0x0_0
+ add.126 = bf16[12,512,2,32,64]{1,4,3,0,2} add(pad.93, pad.125)
+ tuple.128 = (bf16[], bf16[12,512,32,64]{3,1,2,0}, bf16[12,512,2,32,64]{1,4,3,0,2}) tuple(multiply.86, transpose.11, add.126)
+ get-tuple-element = bf16[] get-tuple-element(tuple.128), index=0
+ get-tuple-element.1 = bf16[12,512,32,64]{3,1,2,0} get-tuple-element(tuple.128), index=1
+ copy.5 = bf16[12,512,32,64]{3,2,1,0} copy(get-tuple-element.1)
+ get-tuple-element.2 = bf16[12,512,2,32,64]{1,4,3,0,2} get-tuple-element(tuple.128), index=2
+ copy.6 = bf16[12,512,2,32,64]{4,3,2,1,0} copy(get-tuple-element.2)
+ ROOT tuple = (bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0}) tuple(get-tuple-element, copy.5, copy.6)
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ AlgebraicSimplifierOptions alg_sim_options;
+ alg_sim_options.set_supports_non_canonical_dots(false);
+ alg_sim_options.set_is_layout_sensitive(true);
+ alg_sim_options.set_enable_conv_operand_swap(false);
+ AlgebraicSimplifier alge_simp{alg_sim_options};
+ ReshapeDecomposer reshape_decomposer;
+ LayoutNormalization layout_normalizer;
+ HloCSE cse{/*is_layout_sensitive=*/true};
+ TF_ASSERT_OK(RunHloPass(&reshape_decomposer, m.get()).status());
+ TF_ASSERT_OK(RunHloPass(&layout_normalizer, m.get()).status());
+ TF_ASSERT_OK(RunHloPass(&cse, m.get()).status());
+ TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
+
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+ CudnnFusedMHATransposeFusion fmha_transpose_fusion;
+
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
+ TF_ASSERT_OK(RunHloPass(&fmha_transpose_fusion, m.get()).status());
+
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ EXPECT_EQ(CountFusedAttentionCall(m.get()), 1);
+  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 1);
+}
+
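+// The softmax output feeds BMM2 and is also returned directly in the root
+// tuple, so the activation has more than one user and the pattern should not
+// be lowered.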
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ ActivationHasMoreThan1UserShouldNotLower) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule test
+
+%region_50.2457 (Arg_0.2458: bf16[], Arg_1.2459: bf16[]) -> bf16[] {
+ %Arg_0.2458 = bf16[] parameter(0)
+ %Arg_1.2459 = bf16[] parameter(1)
+ ROOT %maximum.2 = bf16[] maximum(bf16[] %Arg_0.2458, bf16[] %Arg_1.2459)
+}
+
+%region_36.2316 (Arg_0.2317: f32[], Arg_1.2318: f32[]) -> f32[] {
+ %Arg_0.2317 = f32[] parameter(0)
+ %Arg_1.2318 = f32[] parameter(1)
+ ROOT %add.342 = f32[] add(f32[] %Arg_0.2317, f32[] %Arg_1.2318)
+}
+
+ENTRY main {
+ %transpose.482 = bf16[4,5,64]{2,1,0} parameter(0)
+ %transpose.484 = bf16[4,64,5]{2,1,0} parameter(1)
+ %dot.20 = bf16[4,5,5]{2,1,0} dot(bf16[4,5,64]{2,1,0} %transpose.482, bf16[4,64,5]{2,1,0} %transpose.484), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+ %constant.2515 = bf16[] constant(0.125)
+ %broadcast.789 = bf16[4,5,5]{2,1,0} broadcast(bf16[] %constant.2515), dimensions={}
+ %multiply.267 = bf16[4,5,5]{2,1,0} multiply(bf16[4,5,5]{2,1,0} %dot.20, bf16[4,5,5]{2,1,0} %broadcast.789)
+ %constant.287 = f32[] constant(-1)
+ %broadcast.792 = bf16[4,5,5]{2,1,0} parameter(3)
+ %add.348 = bf16[4,5,5]{2,1,0} add(bf16[4,5,5]{2,1,0} %multiply.267, bf16[4,5,5]{2,1,0} %broadcast.792)
+ %constant.2510 = bf16[] constant(-inf)
+ %reduce.2550 = bf16[4,5]{1,0} reduce(bf16[4,5,5]{2,1,0} %add.348, bf16[] %constant.2510), dimensions={2}, to_apply=%region_50.2457
+ %broadcast.793 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %reduce.2550), dimensions={0,1}
+ %subtract.81 = bf16[4,5,5]{2,1,0} subtract(bf16[4,5,5]{2,1,0} %add.348, bf16[4,5,5]{2,1,0} %broadcast.793)
+ %exponential.21 = bf16[4,5,5]{2,1,0} exponential(bf16[4,5,5]{2,1,0} %subtract.81)
+ %convert.180 = f32[4,5,5]{2,1,0} convert(bf16[4,5,5]{2,1,0} %exponential.21)
+ %constant.2509 = f32[] constant(0)
+ %reduce.2558 = f32[4,5]{1,0} reduce(f32[4,5,5]{2,1,0} %convert.180, f32[] %constant.2509), dimensions={2}, to_apply=%region_36.2316
+ %convert.182 = bf16[4,5]{1,0} convert(f32[4,5]{1,0} %reduce.2558)
+ %broadcast.794 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %convert.182), dimensions={0,1}
+ %divide.25 = bf16[4,5,5]{2,1,0} divide(bf16[4,5,5]{2,1,0} %exponential.21, bf16[4,5,5]{2,1,0} %broadcast.794)
+ %transpose.481 = bf16[4,64,5]{2,1,0} parameter(2)
+ %dot.21 = bf16[4,64,5]{2,1,0} dot(bf16[4,64,5]{2,1,0} %transpose.481, bf16[4,5,5]{2,1,0} %divide.25), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
+ ROOT %tuple.2668 = (bf16[4,5,5]{2,1,0}, bf16[4,64,5]{2,1,0}) tuple(bf16[4,5,5]{2,1,0} %divide.25, bf16[4,64,5]{2,1,0} %dot.21)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+ ASSERT_IS_OK(verifier.Run(m.get()).status());
+
+ EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxBmm2ShouldNotBeLowered) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+ Arg_0.22 = f16[] parameter(0)
+ Arg_1.23 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+ Arg_0.34 = f32[] parameter(0)
+ Arg_1.35 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+ Arg_0.56 = f16[] parameter(0)
+ Arg_1.57 = f16[] parameter(1)
+ ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+ constant.18 = pred[2,6,128,128]{3,2,1,0} constant({...})
+ Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.22 = f16[] constant(2)
+ broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+ multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+ constant.19 = f16[] constant(1)
+ broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+ add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+ constant.21 = f16[] constant(0)
+ broadcast.23 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.21), dimensions={}
+ select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.18, add.3, broadcast.23)
+ constant.15 = f16[] constant(-inf)
+ reduce.25 = f16[2,6,128]{2,1,0} reduce(select.1, constant.15), dimensions={3}, to_apply=region_0.21
+ broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+ subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.17)
+ exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+ constant.17 = f32[] constant(0)
+ reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+ convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+ broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+ Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+ broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+ multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+ broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+ multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+ multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+ reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+ broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.59), dimensions={0,1,2}
+ add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+ multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+ select.3 = f16[2,6,128,128]{3,2,1,0} select(constant.18, multiply.8, broadcast.23)
+ multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(select.3, broadcast.24)
+ dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+ ASSERT_IS_OK(verifier.Run(m.get()).status());
+
+  // The backward pattern in the graph is not a valid fMHA pattern, so we
+  // expect no rewrite to happen.
+ EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
+  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2ShouldNotLower) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.38 {
+ Arg_0.39 = f16[] parameter(0)
+ Arg_1.40 = f16[] parameter(1)
+ ROOT maximum.1 = f16[] maximum(Arg_0.39, Arg_1.40)
+}
+
+region_1.50 {
+ Arg_0.51 = f32[] parameter(0)
+ Arg_1.52 = f32[] parameter(1)
+ ROOT add.2 = f32[] add(Arg_0.51, Arg_1.52)
+}
+
+region_2.99 {
+ Arg_0.100 = f16[] parameter(0)
+ Arg_1.101 = f16[] parameter(1)
+ ROOT add.3 = f16[] add(Arg_0.100, Arg_1.101)
+}
+
+ENTRY main.126 {
+ constant.6 = u32[1]{0} constant({2718843009})
+ constant.8 = u32[1]{0} constant({1272950319})
+ constant.10 = u32[1]{0} constant({0})
+ constant.12 = u32[1]{0} constant({2711844646})
+ custom-call.65 = (u32[1]{0}, u32[1]{0}) custom-call(constant.6, constant.8, constant.10, constant.12), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
+ get-tuple-element.66 = u32[1]{0} get-tuple-element(custom-call.65), index=0
+ bitcast.343 = u32[] bitcast(get-tuple-element.66)
+ broadcast.27 = u32[98304]{0} broadcast(bitcast.343), dimensions={}
+ get-tuple-element.67 = u32[1]{0} get-tuple-element(custom-call.65), index=1
+ bitcast.344 = u32[] bitcast(get-tuple-element.67)
+ broadcast.28 = u32[98304]{0} broadcast(bitcast.344), dimensions={}
+ iota.68 = u32[196608]{0} iota(), iota_dimension=0
+ slice = u32[98304]{0} slice(iota.68), slice={[0:98304]}
+ slice.1 = u32[98304]{0} slice(iota.68), slice={[98304:196608]}
+ custom-call.75 = (u32[98304]{0}, u32[98304]{0}) custom-call(broadcast.27, broadcast.28, slice, slice.1), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[98304]{0}, u32[98304]{0}, u32[98304]{0}, u32[98304]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\001\000\000\000\000\000"
+ get-tuple-element.76 = u32[98304]{0} get-tuple-element(custom-call.75), index=0
+ get-tuple-element.77 = u32[98304]{0} get-tuple-element(custom-call.75), index=1
+ concatenate.2 = u32[196608]{0} concatenate(get-tuple-element.76, get-tuple-element.77), dimensions={0}
+ constant.56 = u32[] constant(9)
+ broadcast.63 = u32[196608]{0} broadcast(constant.56), dimensions={}
+ shift-right-logical.3 = u32[196608]{0} shift-right-logical(concatenate.2, broadcast.63)
+ constant.57 = u32[] constant(1065353216)
+ broadcast.64 = u32[196608]{0} broadcast(constant.57), dimensions={}
+ or.3 = u32[196608]{0} or(shift-right-logical.3, broadcast.64)
+ bitcast-convert.3 = f32[196608]{0} bitcast-convert(or.3)
+ constant.58 = f32[] constant(-1)
+ broadcast.65 = f32[196608]{0} broadcast(constant.58), dimensions={}
+ add.10 = f32[196608]{0} add(bitcast-convert.3, broadcast.65)
+ constant.48 = f32[] constant(0)
+ broadcast.66 = f32[196608]{0} broadcast(constant.48), dimensions={}
+ maximum.4 = f32[196608]{0} maximum(add.10, broadcast.66)
+ constant.59 = f32[] constant(0.9)
+ broadcast.67 = f32[196608]{0} broadcast(constant.59), dimensions={}
+ compare.3 = pred[196608]{0} compare(maximum.4, broadcast.67), direction=LT
+ bitcast.308 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
+ constant.44 = pred[2,6,128,128]{3,2,1,0} constant({...})
+ Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.34 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.55 = f16[] constant(2)
+ broadcast.61 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.55), dimensions={}
+ multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(dot.34, broadcast.61)
+ constant.52 = f16[] constant(1)
+ broadcast.39 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.52), dimensions={}
+ add.6 = f16[2,6,128,128]{3,2,1,0} add(multiply.8, broadcast.39)
+ constant.54 = f16[] constant(0)
+ broadcast.52 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.54), dimensions={}
+ select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.44, add.6, broadcast.52)
+ constant.41 = f16[] constant(-inf)
+ reduce.42 = f16[2,6,128]{2,1,0} reduce(select.1, constant.41), dimensions={3}, to_apply=region_0.38
+ broadcast.42 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.42), dimensions={0,1,2}
+ subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.42)
+ exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+ reduce.54 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.48), dimensions={3}, to_apply=region_1.50
+ convert.9 = f16[2,6,128]{2,1,0} convert(reduce.54)
+ broadcast.68 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.68)
+ constant.60 = f16[] constant(1.1113)
+ broadcast.69 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.60), dimensions={}
+ multiply.20 = f16[2,6,128,128]{3,2,1,0} multiply(divide.5, broadcast.69)
+ select.8 = f16[2,6,128,128]{3,2,1,0} select(bitcast.308, multiply.20, broadcast.52)
+ Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.88 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ bitcast.248 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
+ Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.91 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ select.6 = f16[2,6,128,128]{3,2,1,0} select(bitcast.248, dot.91, broadcast.52)
+ multiply.17 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.69)
+ divide.4 = f16[2,6,128,128]{3,2,1,0} divide(multiply.17, broadcast.68)
+ broadcast.55 = f16[2,6,128]{2,1,0} broadcast(constant.52), dimensions={}
+ multiply.11 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.55, multiply.11)
+ broadcast.56 = f16[2,6,128]{2,1,0} broadcast(constant.60), dimensions={}
+ multiply.12 = f16[2,6,128]{2,1,0} multiply(divide.3, broadcast.56)
+ broadcast.58 = f16[2,6,128,128]{3,2,1,0} broadcast(multiply.12), dimensions={0,1,2}
+ multiply.13 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.58)
+ multiply.14 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.13, exponential.1)
+ reduce.103 = f16[2,6,128]{2,1,0} reduce(multiply.14, constant.54), dimensions={3}, to_apply=region_2.99
+ broadcast.62 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.103), dimensions={0,1,2}
+ add.9 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.62)
+ multiply.18 = f16[2,6,128,128]{3,2,1,0} multiply(add.9, exponential.1)
+ select.7 = f16[2,6,128,128]{3,2,1,0} select(constant.44, multiply.18, broadcast.52)
+ multiply.19 = f16[2,6,128,128]{3,2,1,0} multiply(select.7, broadcast.61)
+ dot.124 = f16[2,6,128,64]{3,2,1,0} dot(multiply.19, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.19), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.125 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.88, dot.124, dot, dot.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+ ASSERT_IS_OK(verifier.Run(m.get()).status());
+
+  // The backward pattern in the graph is not a valid fMHA pattern, so we
+  // expect no rewrite to happen.
+ EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
+  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
+}
+
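+// Training scale+bias+softmax graph where Q is provided in a transposed
+// [batch, heads, head_dim, seq] layout; both the forward and backward patterns
+// are expected to lower to fused attention custom calls.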
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ F16TrainingBmm1ScaleBiasSoftmaxBmm2QTranspose) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+ Arg_0.22 = f16[] parameter(0)
+ Arg_1.23 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+ Arg_0.34 = f32[] parameter(0)
+ Arg_1.35 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+ Arg_0.56 = f16[] parameter(0)
+ Arg_1.57 = f16[] parameter(1)
+ ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+ Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.22 = f16[] constant(2)
+ broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+ multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+ constant.19 = f16[] constant(1)
+ broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+ add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+ constant.21 = f16[] constant(0)
+ constant.15 = f16[] constant(-inf)
+ reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
+ broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+ subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
+ exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+ constant.17 = f32[] constant(0)
+ reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+ convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+ broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+ Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+ broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+ multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+ broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+ multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+ multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+ reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+ negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
+ broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+ add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+ multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+ multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
+ dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
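+  // After the rewrite, every root output should be produced by the fused
+  // forward and backward scale-bias-softmax custom calls; the third output is
+  // transposed back to the original f16[2,6,64,128] layout.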
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
+ .WithShape(F16, {2, 6, 128, 64}),
+ m::GetTupleElement(
+ m::CustomCall(&fmha,
+ {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 0)
+ .WithShape(F16, {2, 6, 128, 64}),
+ m::Transpose(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 1))
+ .WithShape(F16, {2, 6, 64, 128}),
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
+ .WithShape(F16, {2, 6, 128, 64}))));
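+  // Sanity-check the backend config of the matched custom call: 7 operands
+  // and an effectively zero dropout rate.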
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(fmha->operands().size(), 7);
+ EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ F16Bmm1UnfusedSoftmaxBmm2IncorrectBmm1NumUsers) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
+
+region_0.7 {
+ Arg_0.8 = f16[] parameter(0)
+ Arg_1.9 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+ Arg_0.20 = f32[] parameter(0)
+ Arg_1.21 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.31 {
+ Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
+ dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ // extra user of bmm1
+ neg.1 = f16[2,6,40,40]{3,2,1,0} negate(dot)
+ constant = f16[] constant(-inf)
+ reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
+ broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+ subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
+ exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
+ convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
+ constant.1 = f32[] constant(0)
+ reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+ convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
+ broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+ divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
+ Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
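+  // The extra user of the first GEMM (neg.1) blocks fusion, so the root tuple
+  // should keep the unfused dot and negate.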
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::Dot(), m::Negate())));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ F16Bmm1UnfusedSoftmaxBmm2IncorrectSoftmaxNumUsers) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
+
+region_0.7 {
+ Arg_0.8 = f16[] parameter(0)
+ Arg_1.9 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+ Arg_0.20 = f32[] parameter(0)
+ Arg_1.21 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.31 {
+ Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
+ dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ constant = f16[] constant(-inf)
+ reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
+ broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+ subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
+ // extra user of softmax sub node
+ neg.1 = f16[2,6,40,40]{3,2,1,0} negate(subtract.1)
+ exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
+ convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
+ constant.1 = f32[] constant(0)
+ reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+ convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
+ broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+ divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
+ Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+ ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+ SCOPED_TRACE(m->ToString());
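+  // The extra user of the softmax subtract node blocks fusion, so the graph
+  // should stay unfused.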
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::Dot(), m::Negate())));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ F16TrainingBmm1ScaleBiasSoftmaxBmm2IncorrectSoftmaxBwdNumUsers) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+ Arg_0.22 = f16[] parameter(0)
+ Arg_1.23 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+ Arg_0.34 = f32[] parameter(0)
+ Arg_1.35 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+ Arg_0.56 = f16[] parameter(0)
+ Arg_1.57 = f16[] parameter(1)
+ ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+ Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.22 = f16[] constant(2)
+ broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+ multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+ constant.19 = f16[] constant(1)
+ broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+ add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+ constant.21 = f16[] constant(0)
+ constant.15 = f16[] constant(-inf)
+ reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
+ broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+ subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
+ exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+ constant.17 = f32[] constant(0)
+ reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+ convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+ broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+ Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+ // extra user of softmax bwd divide node
+ neg.1 = f16[2,6,128,128]{3,2,1,0} negate(divide.4)
+ broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+ multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+ broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+ multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+ multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+ reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+ negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
+ broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+ add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+ multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+ multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
+ dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ SCOPED_TRACE(m->ToString());
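+  // The extra user of the backward softmax divide node invalidates the
+  // backward pattern, so the graph should remain unfused.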
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
+ m::Negate())));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1SoftmaxBmm2IncorrectRank) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule reproducer, entry_computation_layout={(f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, /*index=5*/f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0})->f16[8,16,2,5,64]{4,3,2,1,0}}
+
+region_0.36 {
+ Arg_0.37 = f16[] parameter(0)
+ Arg_1.38 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.37, Arg_1.38)
+}
+
+region_1.48 {
+ Arg_0.49 = f32[] parameter(0)
+ Arg_1.50 = f32[] parameter(1)
+ ROOT add.1 = f32[] add(Arg_0.49, Arg_1.50)
+}
+
+ENTRY main {
+ arg2.3 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(2), parameter_replication={false}
+ bitcast.31 = f16[640,128]{1,0} bitcast(arg2.3)
+ arg5.6 = f32[128,2,64]{2,1,0} parameter(5), parameter_replication={false}
+ convert.3 = f16[128,2,64]{2,1,0} convert(arg5.6)
+ bitcast.36 = f16[128,128]{1,0} bitcast(convert.3)
+ dot = f16[640,128]{1,0} dot(bitcast.31, bitcast.36), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
+ bitcast.39 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot)
+ transpose.27 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.39), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
+ arg6.7 = f32[2,64]{1,0} parameter(6), parameter_replication={false}
+ convert.4 = f16[2,64]{1,0} convert(arg6.7)
+ broadcast.9 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.4), dimensions={3,5}
+ add.2 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.27, broadcast.9)
+ bitcast.49 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.2)
+ arg0.1 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(0), parameter_replication={false}
+ bitcast.53 = f16[640,128]{1,0} bitcast(arg0.1)
+ arg3.4 = f32[128,2,64]{2,1,0} parameter(3), parameter_replication={false}
+ convert.5 = f16[128,2,64]{2,1,0} convert(arg3.4)
+ bitcast.58 = f16[128,128]{1,0} bitcast(convert.5)
+ dot.1 = f16[640,128]{1,0} dot(bitcast.53, bitcast.58), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
+ bitcast.61 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.1)
+ transpose.28 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} transpose(bitcast.61), dimensions={0,1,2,4,5,3}, frontend_attributes={grad_x="false",grad_y="false"}
+ arg4.5 = f32[2,64]{1,0} parameter(4), parameter_replication={false}
+ convert.6 = f16[2,64]{1,0} convert(arg4.5)
+ broadcast.10 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(convert.6), dimensions={3,4}
+ add.3 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} add(transpose.28, broadcast.10)
+ constant.29 = f16[] constant(0.125)
+ broadcast.11 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(constant.29), dimensions={}
+ multiply = f16[1,8,16,2,64,5]{5,4,3,2,1,0} multiply(add.3, broadcast.11)
+ bitcast.74 = f16[8,16,2,64,5]{4,3,2,1,0} bitcast(multiply)
+ dot.6 = f16[8,16,2,5,5]{4,3,2,1,0} dot(bitcast.49, bitcast.74), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
+ constant.35 = f16[] constant(-inf)
+ reduce.1 = f16[8,16,2,5]{3,2,1,0} reduce(dot.6, constant.35), dimensions={3}, to_apply=region_0.36
+ broadcast.12 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(reduce.1), dimensions={0,1,2,4}
+ subtract.2 = f16[8,16,2,5,5]{4,3,2,1,0} subtract(dot.6, broadcast.12)
+ exponential.2 = f16[8,16,2,5,5]{4,3,2,1,0} exponential(subtract.2)
+ convert.7 = f32[8,16,2,5,5]{4,3,2,1,0} convert(exponential.2)
+ constant.34 = f32[] constant(0)
+ reduce.3 = f32[8,16,2,5]{3,2,1,0} reduce(convert.7, constant.34), dimensions={3}, to_apply=region_1.48
+ convert.8 = f16[8,16,2,5]{3,2,1,0} convert(reduce.3)
+ broadcast.13 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(convert.8), dimensions={0,1,2,4}
+ divide.2 = f16[8,16,2,5,5]{4,3,2,1,0} divide(exponential.2, broadcast.13)
+ bitcast.98 = f16[8,16,2,5,5]{3,4,2,1,0} bitcast(divide.2)
+ arg1.2 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(1), parameter_replication={false}
+ bitcast.102 = f16[640,128]{1,0} bitcast(arg1.2)
+ arg7.8 = f32[128,2,64]{2,1,0} parameter(7), parameter_replication={false}
+ convert.9 = f16[128,2,64]{2,1,0} convert(arg7.8)
+ bitcast.107 = f16[128,128]{1,0} bitcast(convert.9)
+ dot.3 = f16[640,128]{1,0} dot(bitcast.102, bitcast.107), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
+ bitcast.110 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.3)
+ transpose.30 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.110), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
+ arg8.9 = f32[2,64]{1,0} parameter(8), parameter_replication={false}
+ convert.10 = f16[2,64]{1,0} convert(arg8.9)
+ broadcast.14 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.10), dimensions={3,5}
+ add.4 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.30, broadcast.14)
+ bitcast.120 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.4)
+ ROOT dot.7 = f16[8,16,2,5,64]{4,3,2,1,0} dot(bitcast.98, bitcast.120), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
+} // main
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
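+  // The rank-5 BMMs do not form a supported fMHA pattern, so the pass should
+  // report that it made no changes.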
+ const auto status_or = RunHloPass(&fusedMhaRewriter, m.get());
+ TF_ASSERT_OK(status_or.status());
+ EXPECT_FALSE(status_or.value());
+
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot()));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm2Grad1IncorrectPattern) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+ Arg_0.22 = f16[] parameter(0)
+ Arg_1.23 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+ Arg_0.34 = f32[] parameter(0)
+ Arg_1.35 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+ Arg_0.56 = f16[] parameter(0)
+ Arg_1.57 = f16[] parameter(1)
+ ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+ Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.22 = f16[] constant(2)
+ broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+ multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+ constant.19 = f16[] constant(1)
+ broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+ add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+ constant.21 = f16[] constant(0)
+ constant.15 = f16[] constant(-inf)
+ reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
+ broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+ subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
+ exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+ constant.17 = f32[] constant(0)
+ reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+ convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+ broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+ Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+ broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+ multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+ broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+ multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+ multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+ reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+ negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
+ broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+ add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+ multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+ multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
+ dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  // Add another user of dS (multiply.9) here; with this extra user the pattern should not be matched as bmm2grad1.
+ neg.1 = f16[2,6,128,128]{3,2,1,0} negate(multiply.9)
+ dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ SCOPED_TRACE(m->ToString());
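+  // With the additional user of the dS node (multiply.9), the backward
+  // pattern is not matched, so the graph should remain unfused.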
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
+ m::Negate())));
+}
+
+// flash attention
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ FlashAttentionBF16TrainingBmm1CausalMaskSoftmaxBmm2Pattern) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+region_0.32 {
+ Arg_0.33 = bf16[] parameter(0)
+ Arg_1.34 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
+}
+region_1.44 {
+ Arg_0.45 = f32[] parameter(0)
+ Arg_1.46 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+region_2.66 {
+ Arg_0.67 = bf16[] parameter(0)
+ Arg_1.68 = bf16[] parameter(1)
+ ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
+}
+ENTRY main.92 {
+ Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.17 = bf16[] constant(2)
+ broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
+ multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
+ iota.2 = s32[2048,2048]{1,0} iota(), iota_dimension=0
+ iota.5 = s32[2048,2048]{1,0} iota(), iota_dimension=1
+ compare.1 = pred[2048,2048]{1,0} compare(iota.2, iota.5), direction=LT
+ constant.6 = bf16[] constant(-2.366e+38)
+ broadcast.16 = bf16[2048,2048]{1,0} broadcast(constant.6), dimensions={}
+ constant.16 = bf16[] constant(0)
+ broadcast.17 = bf16[2048,2048]{1,0} broadcast(constant.16), dimensions={}
+ select.2 = bf16[2048,2048]{1,0} select(compare.1, broadcast.16, broadcast.17)
+ broadcast.19 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(select.2), dimensions={2,3}
+ add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, broadcast.19)
+ constant.10 = bf16[] constant(-inf)
+ reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
+ broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+ subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
+ exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
+ constant.14 = f32[] constant(0)
+ reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
+ convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
+ broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
+ Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
+ constant.15 = bf16[] constant(1)
+ broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
+ multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
+ broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+ multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
+ multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
+ reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
+ negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
+ broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+ add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
+ multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
+ multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
+ dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ const HloInstruction* fwd_fmha;
+ const HloInstruction* bwd_fmha;
+ SCOPED_TRACE(m->ToString());
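+  // The causal-mask pattern should be rewritten into a flash-attention
+  // softmax forward custom call and a matching backward custom call.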
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(
+ m::CustomCall(&fwd_fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+ .WithShape(BF16, {2, 6, 2048, 128}),
+ m::GetTupleElement(
+ m::CustomCall(&bwd_fmha, {kCudnnfMHASoftmaxBackwardCallTarget}),
+ 0)
+ .WithShape(BF16, {2, 6, 2048, 128}),
+ m::Transpose(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
+ .WithShape(BF16, {2, 6, 128, 2048}),
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
+ .WithShape(BF16, {2, 6, 2048, 128}))));
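+  // The forward call should take 3 operands and the backward call 6, with the
+  // causal mask recorded in the backend config.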
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fwd_fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(fwd_fmha->operands().size(), 3);
+ EXPECT_EQ(bwd_fmha->operands().size(), 6);
+ EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+ EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ FlashAttentionBF16TrainingBmm1BiasSoftmaxBmm2Pattern) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+region_0.32 {
+ Arg_0.33 = bf16[] parameter(0)
+ Arg_1.34 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
+}
+region_1.44 {
+ Arg_0.45 = f32[] parameter(0)
+ Arg_1.46 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+region_2.66 {
+ Arg_0.67 = bf16[] parameter(0)
+ Arg_1.68 = bf16[] parameter(1)
+ ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
+}
+ENTRY main.92 {
+ Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.17 = bf16[] constant(2)
+ broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
+ multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
+ // bias
+ Arg_4.5 = bf16[2,6,2048,2048]{3,2,1,0} parameter(4), sharding={replicated}
+ add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, Arg_4.5)
+ constant.10 = bf16[] constant(-inf)
+ constant.16 = bf16[] constant(0)
+ reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
+ broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+ subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
+ exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
+ constant.14 = f32[] constant(0)
+ reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
+ convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
+ broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
+ Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
+ constant.15 = bf16[] constant(1)
+ broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
+ multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
+ broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+ multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
+ multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
+ reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
+ negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
+ broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+ add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
+ multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
+ multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
+ dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
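+  // With an explicit bias tensor, the pattern should be rewritten into the
+  // scale-bias-softmax flash-attention forward and backward custom calls.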
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
+ .WithShape(BF16, {2, 6, 2048, 128}),
+ m::GetTupleElement(
+ m::CustomCall(&fmha,
+ {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 0)
+ .WithShape(BF16, {2, 6, 2048, 128}),
+ m::Transpose(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 1))
+ .WithShape(BF16, {2, 6, 128, 2048}),
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
+ .WithShape(BF16, {2, 6, 2048, 128}))));
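+  // The matched call should take 7 operands, with no dropout and no mask
+  // recorded in the backend config.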
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(fmha->operands().size(), 7);
+ EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+ EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ FlashAttentionBF16TrainingBmm1SoftmaxBmm2Pattern) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+region_0.32 {
+ Arg_0.33 = bf16[] parameter(0)
+ Arg_1.34 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
+}
+region_1.44 {
+ Arg_0.45 = f32[] parameter(0)
+ Arg_1.46 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+region_2.66 {
+ Arg_0.67 = bf16[] parameter(0)
+ Arg_1.68 = bf16[] parameter(1)
+ ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
+}
+ENTRY main.92 {
+ Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
+ dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ constant.17 = bf16[] constant(2)
+ broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
+ multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
+ constant.10 = bf16[] constant(-inf)
+ constant.16 = bf16[] constant(0)
+ reduce.36 = bf16[2,6,2048]{2,1,0} reduce(multiply.2, constant.10), dimensions={3}, to_apply=region_0.32
+ broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+ subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(multiply.2, broadcast.21)
+ exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
+ convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
+ constant.14 = f32[] constant(0)
+ reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
+ convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
+ broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+ divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
+ Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
+ dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
+ dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
+ constant.15 = bf16[] constant(1)
+ broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
+ multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
+ divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
+ broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+ multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
+ multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
+ reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
+ negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
+ broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+ add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
+ multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
+ multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
+ dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ HloDCE dce;
+ TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
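+  // Without a bias or mask, the pattern should map to the plain softmax
+  // flash-attention custom calls.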
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+ .WithShape(BF16, {2, 6, 2048, 128}),
+ m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), 0)
+ .WithShape(BF16, {2, 6, 2048, 128}),
+ m::Transpose(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
+ .WithShape(BF16, {2, 6, 128, 2048}),
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
+ .WithShape(BF16, {2, 6, 2048, 128}))));
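+  // The matched call should take 6 operands; the backend config should record
+  // the scale of 2, no dropout, and no mask.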
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(fmha->operands().size(), 6);
+ EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+ EXPECT_FLOAT_EQ(config.fmha_scale(), 2);
+ EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
+}
+
+// GPT3 pattern
+TEST_F(CudnnFusedMhaRewriterTestHloTest, FlashAttentionBF16TrainingGPT3_5B) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={((s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}))->(s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0})}
+add {
+ x = bf16[] parameter(0)
+ y = bf16[] parameter(1)
+ ROOT add.580 = bf16[] add(x, y)
+}
+
+region_20.962 {
+ Arg_0.963 = f32[] parameter(0)
+ Arg_1.964 = f32[] parameter(1)
+ ROOT add.579 = f32[] add(Arg_0.963, Arg_1.964)
+}
+
+region_39.1120 {
+ Arg_0.1121 = f32[] parameter(0)
+ Arg_1.1122 = f32[] parameter(1)
+ ROOT maximum.21 = f32[] maximum(Arg_0.1121, Arg_1.1122)
+}
+
+main {
+ param.3 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) parameter(0)
+ get-tuple-element.31 = s32[] get-tuple-element(param.3), index=0
+ constant.1961 = s32[] constant(1)
+ add.581 = s32[] add(get-tuple-element.31, constant.1961)
+ get-tuple-element.32 = bf16[24,32,2048,2048]{2,1,3,0} get-tuple-element(param.3), index=25
+ bitcast.187 = bf16[24,2048,32,2048]{3,2,1,0} bitcast(get-tuple-element.32)
+ constant.1977 = s32[] constant(23)
+ subtract.221 = s32[] subtract(constant.1977, get-tuple-element.31)
+ constant.1980 = s32[] constant(0)
+ compare.210 = pred[] compare(subtract.221, constant.1980), direction=LT
+ constant.1979 = s32[] constant(47)
+ subtract.222 = s32[] subtract(constant.1979, get-tuple-element.31)
+ select.372 = s32[] select(compare.210, subtract.222, subtract.221)
+ dynamic-slice.324 = bf16[1,2048,32,2048]{3,2,1,0} dynamic-slice(bitcast.187, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,2048,32,2048}
+ bitcast.756 = bf16[2048,32,2048]{2,1,0} bitcast(dynamic-slice.324)
+ convert.282 = f32[2048,32,2048]{2,1,0} convert(bitcast.756)
+ constant.1991 = bf16[] constant(1)
+ broadcast.1270 = bf16[32,2048]{1,0} broadcast(constant.1991), dimensions={}
+ get-tuple-element.33 = bf16[32,2048]{1,0} get-tuple-element(param.3), index=27
+ subtract.229 = bf16[32,2048]{1,0} subtract(broadcast.1270, get-tuple-element.33)
+ convert.285 = f32[32,2048]{1,0} convert(subtract.229)
+ broadcast.1228 = f32[2048,32,2048]{2,1,0} broadcast(convert.285), dimensions={1,2}
+ multiply.656 = f32[2048,32,2048]{2,1,0} multiply(convert.282, broadcast.1228)
+ bitcast.367 = f32[32,2048,2048]{1,0,2} bitcast(multiply.656)
+ constant.1968 = f32[] constant(0)
+ reduce.84 = f32[] reduce(bitcast.367, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+ all-reduce.230 = f32[] all-reduce(reduce.84), channel_id=278, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ broadcast.1221 = f32[32,2048,4096]{2,1,0} broadcast(convert.285), dimensions={0,1}
+ reduce.85 = f32[] reduce(broadcast.1221, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+ all-reduce.14 = f32[] all-reduce(reduce.85), channel_id=49, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=region_20.962
+ constant.2005 = f32[] constant(1)
+ maximum.24 = f32[] maximum(all-reduce.14, constant.2005)
+ divide.96 = f32[] divide(all-reduce.230, maximum.24)
+ broadcast.1223 = f32[2048,32,2048]{2,1,0} broadcast(divide.96), dimensions={}
+ subtract.219 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1223)
+ multiply.644 = f32[2048,32,2048]{2,1,0} multiply(subtract.219, broadcast.1228)
+ multiply.645 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, multiply.644)
+ bitcast.271 = f32[32,2048,2048]{1,0,2} bitcast(multiply.645)
+ reduce.86 = f32[] reduce(bitcast.271, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+ all-reduce.231 = f32[] all-reduce(reduce.86), channel_id=279, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ divide.99 = f32[] divide(all-reduce.231, maximum.24)
+ rsqrt.16 = f32[] rsqrt(divide.99)
+ multiply.650 = f32[] multiply(rsqrt.16, constant.1968)
+ divide.100 = f32[] divide(multiply.650, maximum.24)
+ constant.1974 = f32[] constant(2)
+ multiply.652 = f32[] multiply(divide.100, constant.1974)
+ broadcast.1227 = f32[2048,32,2048]{2,1,0} broadcast(multiply.652), dimensions={}
+ multiply.653 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, broadcast.1227)
+ multiply.654 = f32[2048,32,2048]{2,1,0} multiply(multiply.653, broadcast.1228)
+ negate.56 = f32[2048,32,2048]{2,1,0} negate(multiply.654)
+ bitcast.321 = f32[32,2048,2048]{1,0,2} bitcast(negate.56)
+ reduce.87 = f32[] reduce(bitcast.321, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+ all-reduce.232 = f32[] all-reduce(reduce.87), channel_id=280, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ divide.101 = f32[] divide(all-reduce.232, maximum.24)
+ broadcast.1229 = f32[32,2048]{1,0} broadcast(divide.101), dimensions={}
+ multiply.655 = f32[32,2048]{1,0} multiply(broadcast.1229, convert.285)
+ broadcast.1230 = f32[2048,32,2048]{2,1,0} broadcast(multiply.655), dimensions={1,2}
+ add.582 = f32[2048,32,2048]{2,1,0} add(multiply.654, broadcast.1230)
+ broadcast.1236 = f32[2048,32,2048]{2,1,0} broadcast(constant.1968), dimensions={}
+ compare.208 = pred[2048,32,2048]{2,1,0} compare(multiply.656, broadcast.1236), direction=GE
+ abs.22 = f32[2048,32,2048]{2,1,0} abs(multiply.656)
+ bitcast.373 = f32[32,2048,2048]{1,0,2} bitcast(abs.22)
+ constant.1989 = f32[] constant(-inf)
+ reduce.88 = f32[] reduce(bitcast.373, constant.1989), dimensions={0,1,2}, to_apply=region_39.1120
+ all-reduce.233 = f32[] all-reduce(reduce.88), channel_id=281, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_39.1120
+ broadcast.1233 = f32[2048,32,2048]{2,1,0} broadcast(all-reduce.233), dimensions={}
+ compare.207 = pred[2048,32,2048]{2,1,0} compare(abs.22, broadcast.1233), direction=EQ
+ convert.286 = f32[2048,32,2048]{2,1,0} convert(compare.207)
+ bitcast.393 = f32[32,2048,2048]{1,0,2} bitcast(convert.286)
+ reduce.89 = f32[] reduce(bitcast.393, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+ all-reduce.234 = f32[] all-reduce(reduce.89), channel_id=282, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ divide.103 = f32[] divide(constant.1968, all-reduce.234)
+ broadcast.1238 = f32[2048,32,2048]{2,1,0} broadcast(divide.103), dimensions={}
+ select.370 = f32[2048,32,2048]{2,1,0} select(compare.207, broadcast.1238, broadcast.1236)
+ select.369 = f32[2048,32,2048]{2,1,0} select(compare.208, select.370, broadcast.1236)
+ constant.1976 = pred[] constant(false)
+ broadcast.1237 = pred[2048,32,2048]{2,1,0} broadcast(constant.1976), dimensions={}
+ compare.209 = pred[2048,32,2048]{2,1,0} compare(compare.208, broadcast.1237), direction=EQ
+ select.371 = f32[2048,32,2048]{2,1,0} select(compare.209, select.370, broadcast.1236)
+ negate.57 = f32[2048,32,2048]{2,1,0} negate(select.371)
+ add.583 = f32[2048,32,2048]{2,1,0} add(select.369, negate.57)
+ multiply.658 = f32[2048,32,2048]{2,1,0} multiply(add.583, broadcast.1228)
+ add.585 = f32[2048,32,2048]{2,1,0} add(add.582, multiply.658)
+ convert.287 = bf16[2048,32,2048]{2,1,0} convert(add.585)
+ get-tuple-element.34 = bf16[32,2048,2048]{1,0,2} get-tuple-element(param.3), index=1
+ bitcast.1652 = bf16[2048,32,2048]{2,1,0} bitcast(get-tuple-element.34)
+ get-tuple-element.35 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=22
+ bitcast.461 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.35)
+ dynamic-slice.325 = bf16[1,1024,3,16,128]{4,3,2,1,0} dynamic-slice(bitcast.461, select.372, constant.1980, constant.1980, constant.1980, /*index=5*/constant.1980), dynamic_slice_sizes={1,1024,3,16,128}
+ bitcast.485 = bf16[3,1024,16,128]{3,2,0,1} bitcast(dynamic-slice.325)
+ all-gather.7 = bf16[3,4096,16,128]{3,2,0,1} all-gather(bitcast.485), channel_id=60, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
+ bitcast.1420 = bf16[6144,4096]{0,1} bitcast(all-gather.7)
+ bitcast.500 = f32[32,2048,2048]{1,0,2} bitcast(convert.282)
+ reduce.90 = f32[32,2048]{1,0} reduce(bitcast.500, constant.1968), dimensions={2}, to_apply=region_20.962
+ all-reduce.23 = f32[32,2048]{1,0} all-reduce(reduce.90), channel_id=58, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ constant.1983 = f32[] constant(0.000244140625)
+ broadcast.1243 = f32[32,2048]{1,0} broadcast(constant.1983), dimensions={}
+ multiply.660 = f32[32,2048]{1,0} multiply(all-reduce.23, broadcast.1243)
+ broadcast.1242 = f32[2048,32,2048]{2,1,0} broadcast(multiply.660), dimensions={1,2}
+ subtract.224 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1242)
+ multiply.661 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, subtract.224)
+ bitcast.527 = f32[32,2048,2048]{1,0,2} bitcast(multiply.661)
+ reduce.91 = f32[32,2048]{1,0} reduce(bitcast.527, constant.1968), dimensions={2}, to_apply=region_20.962
+ all-reduce.24 = f32[32,2048]{1,0} all-reduce(reduce.91), channel_id=59, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ multiply.662 = f32[32,2048]{1,0} multiply(all-reduce.24, broadcast.1243)
+ constant.1990 = f32[] constant(1e-05)
+ broadcast.1264 = f32[32,2048]{1,0} broadcast(constant.1990), dimensions={}
+ add.587 = f32[32,2048]{1,0} add(multiply.662, broadcast.1264)
+ bitcast.1447 = f32[1,32,2048]{2,1,0} bitcast(add.587)
+ rsqrt.20 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1447)
+ bitcast.1892 = f32[32,2048]{1,0} bitcast(rsqrt.20)
+ broadcast.1337 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1892), dimensions={1,2}
+ multiply.754 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1337)
+ convert.314 = bf16[2048,32,2048]{2,1,0} convert(multiply.754)
+ get-tuple-element.36 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=20
+ dynamic-slice.326 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.36, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+ broadcast.1266 = bf16[1,2048]{1,0} broadcast(constant.1991), dimensions={}
+ add.588 = bf16[1,2048]{1,0} add(dynamic-slice.326, broadcast.1266)
+ bitcast.1992 = bf16[2048]{0} bitcast(add.588)
+ broadcast.1338 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1992), dimensions={0}
+ multiply.755 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, broadcast.1338)
+ get-tuple-element.37 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=19
+ dynamic-slice.327 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.37, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+ bitcast.1998 = bf16[2048]{0} bitcast(dynamic-slice.327)
+ broadcast.1339 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1998), dimensions={0}
+ add.640 = bf16[2048,32,2048]{2,1,0} add(multiply.755, broadcast.1339)
+ bitcast.2003 = bf16[32,2048,2048]{1,0,2} bitcast(add.640)
+ all-gather.8 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=61, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.597 = bf16[4096,65536]{1,0} bitcast(all-gather.8)
+ dot.42 = bf16[6144,65536]{1,0} dot(bitcast.1420, bitcast.597), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitcast.623 = bf16[3,16,128,32,2048]{4,3,2,1,0} bitcast(dot.42)
+ transpose.112 = bf16[3,32,16,128,2048]{4,3,2,1,0} transpose(bitcast.623), dimensions={0,3,1,2,4}
+ get-tuple-element.38 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=21
+ dynamic-slice.328 = bf16[1,3,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.38, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,3,16,128}
+ bitcast.626 = bf16[3,16,128]{2,1,0} bitcast(dynamic-slice.328)
+ broadcast.1250 = bf16[3,32,16,128,2048]{4,3,2,1,0} broadcast(bitcast.626), dimensions={0,2,3}
+ add.591 = bf16[3,32,16,128,2048]{4,3,2,1,0} add(transpose.112, broadcast.1250)
+ slice.87 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[2:3], [0:32], [0:16], [0:128], [0:2048]}
+ bitcast.1280 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.87)
+ slice.88 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[0:1], [0:32], [0:16], [0:128], [0:2048]}
+ constant.2007 = bf16[] constant(0.08838)
+ broadcast.1251 = bf16[1,32,16,128,2048]{4,3,2,1,0} broadcast(constant.2007), dimensions={}
+ multiply.666 = bf16[1,32,16,128,2048]{4,3,2,1,0} multiply(slice.88, broadcast.1251)
+ bitcast.1330 = bf16[32,16,128,2048]{3,2,1,0} bitcast(multiply.666)
+ transpose.113 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1330), dimensions={0,1,3,2}
+ slice.89 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[1:2], [0:32], [0:16], [0:128], [0:2048]}
+ bitcast.647 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.89)
+ dot.43 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.113, bitcast.647), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ convert.291 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.43)
+ get-tuple-element.39 = bf16[32,1,2048,2048]{3,2,0,1} get-tuple-element(param.3), index=26
+ bitcast.651 = bf16[1,32,2048,2048]{3,2,1,0} bitcast(get-tuple-element.39)
+ iota.38 = s32[2048,2048]{1,0} iota(), iota_dimension=0
+ iota.39 = s32[2048,2048]{1,0} iota(), iota_dimension=1
+ compare.211 = pred[2048,2048]{1,0} compare(iota.38, iota.39), direction=LT
+ constant.1987 = bf16[] constant(-2.366e+38)
+ broadcast.1252 = bf16[2048,2048]{1,0} broadcast(constant.1987), dimensions={}
+ constant.2006 = bf16[] constant(0)
+ broadcast.1253 = bf16[2048,2048]{1,0} broadcast(constant.2006), dimensions={}
+ select.373 = bf16[2048,2048]{1,0} select(compare.211, broadcast.1252, broadcast.1253)
+ broadcast.1254 = bf16[1,32,2048,2048]{3,2,1,0} broadcast(select.373), dimensions={2,3}
+ minimum.5 = bf16[1,32,2048,2048]{3,2,1,0} minimum(bitcast.651, broadcast.1254)
+ bitcast.673 = bf16[32,2048,2048]{2,1,0} bitcast(minimum.5)
+ convert.292 = f32[32,2048,2048]{2,1,0} convert(bitcast.673)
+ broadcast.1256 = f32[32,16,2048,2048]{3,2,1,0} broadcast(convert.292), dimensions={0,2,3}
+ add.593 = f32[32,16,2048,2048]{3,2,1,0} add(convert.291, broadcast.1256)
+ reduce.92 = f32[32,16,2048]{2,1,0} reduce(add.593, constant.1989), dimensions={3}, to_apply=region_39.1120
+ broadcast.1258 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.92), dimensions={0,1,2}
+ subtract.226 = f32[32,16,2048,2048]{3,2,1,0} subtract(add.593, broadcast.1258)
+ exponential.8 = f32[32,16,2048,2048]{3,2,1,0} exponential(subtract.226)
+ reduce.93 = f32[32,16,2048]{2,1,0} reduce(exponential.8, constant.1968), dimensions={3}, to_apply=region_20.962
+ broadcast.1309 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.93), dimensions={0,1,2}
+ divide.109 = f32[32,16,2048,2048]{3,2,1,0} divide(exponential.8, broadcast.1309)
+ convert.306 = bf16[32,16,2048,2048]{3,2,1,0} convert(divide.109)
+ dot.44 = bf16[32,16,128,2048]{3,2,1,0} dot(bitcast.1280, convert.306), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ transpose.116 = bf16[32,2048,16,128]{3,2,1,0} transpose(dot.44), dimensions={0,3,1,2}
+ bitcast.711 = bf16[65536,2048]{1,0} bitcast(transpose.116)
+ get-tuple-element.40 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=24
+ dynamic-slice.329 = bf16[1,1024,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.40, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,16,128}
+ bitcast.724 = bf16[1024,16,128]{2,1,0} bitcast(dynamic-slice.329)
+ all-gather.9 = bf16[4096,16,128]{2,1,0} all-gather(bitcast.724), channel_id=62, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
+ bitcast.729 = bf16[2048,4096]{0,1} bitcast(all-gather.9)
+ dot.57 = bf16[65536,4096]{0,1} dot(bitcast.711, bitcast.729), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitcast.733 = bf16[32,2048,4096]{1,0,2} bitcast(dot.57)
+ reduce-scatter = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.733), channel_id=322, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
+ bitcast.763 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter)
+ get-tuple-element.41 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=23
+ dynamic-slice.330 = bf16[1,1024]{1,0} dynamic-slice(get-tuple-element.41, select.372, constant.1980), dynamic_slice_sizes={1,1024}
+ bitcast.748 = bf16[1024]{0} bitcast(dynamic-slice.330)
+ collective-permute.1 = bf16[1024]{0} collective-permute(bitcast.748), channel_id=64, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+ all-gather.10 = bf16[2048]{0} all-gather(collective-permute.1), channel_id=65, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={0}, use_global_device_ids=true
+ broadcast.1261 = bf16[2048,32,2048]{2,1,0} broadcast(all-gather.10), dimensions={0}
+ add.596 = bf16[2048,32,2048]{2,1,0} add(bitcast.763, broadcast.1261)
+ add.597 = bf16[2048,32,2048]{2,1,0} add(add.596, bitcast.756)
+ convert.295 = f32[2048,32,2048]{2,1,0} convert(add.597)
+ bitcast.774 = f32[32,2048,2048]{1,0,2} bitcast(convert.295)
+ reduce.94 = f32[32,2048]{1,0} reduce(bitcast.774, constant.1968), dimensions={2}, to_apply=region_20.962
+ all-reduce.26 = f32[32,2048]{1,0} all-reduce(reduce.94), channel_id=66, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ multiply.668 = f32[32,2048]{1,0} multiply(all-reduce.26, broadcast.1243)
+ broadcast.1263 = f32[2048,32,2048]{2,1,0} broadcast(multiply.668), dimensions={1,2}
+ subtract.228 = f32[2048,32,2048]{2,1,0} subtract(convert.295, broadcast.1263)
+ multiply.669 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, subtract.228)
+ bitcast.809 = f32[32,2048,2048]{1,0,2} bitcast(multiply.669)
+ reduce.95 = f32[32,2048]{1,0} reduce(bitcast.809, constant.1968), dimensions={2}, to_apply=region_20.962
+ all-reduce.27 = f32[32,2048]{1,0} all-reduce(reduce.95), channel_id=67, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ multiply.670 = f32[32,2048]{1,0} multiply(all-reduce.27, broadcast.1243)
+ add.598 = f32[32,2048]{1,0} add(multiply.670, broadcast.1264)
+ bitcast.1148 = f32[1,32,2048]{2,1,0} bitcast(add.598)
+ rsqrt.19 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1148)
+ bitcast.1602 = f32[32,2048]{1,0} bitcast(rsqrt.19)
+ broadcast.1329 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1602), dimensions={1,2}
+ multiply.750 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1329)
+ convert.312 = bf16[2048,32,2048]{2,1,0} convert(multiply.750)
+ get-tuple-element.42 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=18
+ dynamic-slice.331 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.42, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+ add.599 = bf16[1,2048]{1,0} add(dynamic-slice.331, broadcast.1266)
+ bitcast.1609 = bf16[2048]{0} bitcast(add.599)
+ broadcast.1330 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1609), dimensions={0}
+ multiply.745 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, broadcast.1330)
+ get-tuple-element.43 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=17
+ dynamic-slice.332 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.43, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+ bitcast.1615 = bf16[2048]{0} bitcast(dynamic-slice.332)
+ broadcast.1331 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1615), dimensions={0}
+ add.636 = bf16[2048,32,2048]{2,1,0} add(multiply.745, broadcast.1331)
+ bitcast.1620 = bf16[32,2048,2048]{1,0,2} bitcast(add.636)
+ all-gather.12 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=69, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.877 = bf16[65536,4096]{0,1} bitcast(all-gather.12)
+ get-tuple-element.44 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=15
+ dynamic-slice.333 = bf16[1,1024,8192]{2,1,0} dynamic-slice(get-tuple-element.44, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
+ bitcast.890 = bf16[1024,8192]{1,0} bitcast(dynamic-slice.333)
+ all-gather.11 = bf16[4096,8192]{1,0} all-gather(bitcast.890), channel_id=68, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
+ dot.45 = bf16[65536,8192]{1,0} dot(bitcast.877, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ get-tuple-element.45 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=14
+ dynamic-slice.334 = bf16[1,8192]{1,0} dynamic-slice(get-tuple-element.45, select.372, constant.1980), dynamic_slice_sizes={1,8192}
+ bitcast.906 = bf16[8192]{0} bitcast(dynamic-slice.334)
+ broadcast.1269 = bf16[65536,8192]{1,0} broadcast(bitcast.906), dimensions={1}
+ add.601 = bf16[65536,8192]{1,0} add(dot.45, broadcast.1269)
+ bitcast.997 = bf16[32,2048,8192]{2,1,0} bitcast(add.601)
+ broadcast.1333 = bf16[2048,32,2048]{2,1,0} broadcast(subtract.229), dimensions={1,2}
+ multiply.746 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1652, broadcast.1333)
+ bitcast.1739 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.746)
+ all-gather.14 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=71, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.934 = bf16[65536,4096]{0,1} bitcast(all-gather.14)
+ get-tuple-element.46 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=16
+ bitcast.935 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.46)
+ dynamic-slice.335 = bf16[1,1024,8192]{2,1,0} dynamic-slice(bitcast.935, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
+ bitcast.947 = bf16[8192,1024]{0,1} bitcast(dynamic-slice.335)
+ all-gather.13 = bf16[8192,4096]{0,1} all-gather(bitcast.947), channel_id=70, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
+ dot.46 = bf16[65536,8192]{1,0} dot(bitcast.934, all-gather.13), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ bitcast.1092 = bf16[32,2048,8192]{2,1,0} bitcast(dot.46)
+ broadcast.1335 = bf16[32,2048,8192]{2,1,0} broadcast(subtract.229), dimensions={0,1}
+ multiply.703 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.1092, broadcast.1335)
+ multiply.685 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.703)
+ constant.2002 = bf16[] constant(0.5)
+ broadcast.1288 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2002), dimensions={}
+ multiply.686 = bf16[32,2048,8192]{2,1,0} multiply(multiply.685, broadcast.1288)
+ broadcast.1287 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1991), dimensions={}
+ multiply.700 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, bitcast.997)
+ multiply.693 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.700)
+ constant.1998 = bf16[] constant(0.04468)
+ broadcast.1282 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1998), dimensions={}
+ multiply.694 = bf16[32,2048,8192]{2,1,0} multiply(multiply.693, broadcast.1282)
+ add.605 = bf16[32,2048,8192]{2,1,0} add(bitcast.997, multiply.694)
+ constant.2010 = bf16[] constant(0.7969)
+ broadcast.1324 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2010), dimensions={}
+ multiply.695 = bf16[32,2048,8192]{2,1,0} multiply(add.605, broadcast.1324)
+ tanh.7 = bf16[32,2048,8192]{2,1,0} tanh(multiply.695)
+ subtract.231 = bf16[32,2048,8192]{2,1,0} subtract(broadcast.1287, tanh.7)
+ multiply.691 = bf16[32,2048,8192]{2,1,0} multiply(multiply.686, subtract.231)
+ multiply.737 = bf16[32,2048,8192]{2,1,0} multiply(multiply.691, tanh.7)
+ add.630 = bf16[32,2048,8192]{2,1,0} add(multiply.691, multiply.737)
+ multiply.738 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1324)
+ constant.2011 = bf16[] constant(0.03564)
+ broadcast.1326 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2011), dimensions={}
+ multiply.739 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1326)
+ constant.2012 = bf16[] constant(3)
+ broadcast.1327 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2012), dimensions={}
+ multiply.740 = bf16[32,2048,8192]{2,1,0} multiply(multiply.700, broadcast.1327)
+ multiply.741 = bf16[32,2048,8192]{2,1,0} multiply(multiply.739, multiply.740)
+ add.632 = bf16[32,2048,8192]{2,1,0} add(multiply.738, multiply.741)
+ add.637 = bf16[32,2048,8192]{2,1,0} add(tanh.7, broadcast.1287)
+ multiply.747 = bf16[32,2048,8192]{2,1,0} multiply(add.637, broadcast.1288)
+ multiply.743 = bf16[32,2048,8192]{2,1,0} multiply(multiply.703, multiply.747)
+ add.635 = bf16[32,2048,8192]{2,1,0} add(add.632, multiply.743)
+ bitcast.1629 = bf16[65536,8192]{1,0} bitcast(add.635)
+ dot.47 = bf16[65536,4096]{0,1} dot(bitcast.1629, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ bitcast.1130 = bf16[32,2048,4096]{1,0,2} bitcast(dot.47)
+ reduce-scatter.1 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1130), channel_id=323, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
+ bitcast.1766 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.1)
+ multiply.712 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1766, broadcast.1330)
+ convert.299 = f32[2048,32,2048]{2,1,0} convert(multiply.712)
+ multiply.707 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, convert.299)
+ bitcast.1135 = f32[32,2048,2048]{1,0,2} bitcast(multiply.707)
+ reduce.96 = f32[32,2048]{1,0} reduce(bitcast.1135, constant.1968), dimensions={2}, to_apply=region_20.962
+ all-reduce.29 = f32[32,2048]{1,0} all-reduce(reduce.96), channel_id=73, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ bitcast.1140 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.29)
+ divide.105 = f32[1,32,2048]{2,1,0} divide(rsqrt.19, bitcast.1148)
+ constant.2008 = f32[] constant(-0.5)
+ broadcast.1313 = f32[1,32,2048]{2,1,0} broadcast(constant.2008), dimensions={}
+ multiply.708 = f32[1,32,2048]{2,1,0} multiply(divide.105, broadcast.1313)
+ multiply.709 = f32[1,32,2048]{2,1,0} multiply(bitcast.1140, multiply.708)
+ constant.2009 = f32[] constant(0.00048828125)
+ broadcast.1315 = f32[1,32,2048]{2,1,0} broadcast(constant.2009), dimensions={}
+ multiply.710 = f32[1,32,2048]{2,1,0} multiply(multiply.709, broadcast.1315)
+ bitcast.1235 = f32[32,2048]{1,0} bitcast(multiply.710)
+ broadcast.1296 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1235), dimensions={1,2}
+ multiply.717 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1296)
+ multiply.718 = f32[2048,32,2048]{2,1,0} multiply(convert.299, broadcast.1329)
+ add.617 = f32[2048,32,2048]{2,1,0} add(multiply.717, multiply.718)
+ negate.58 = f32[2048,32,2048]{2,1,0} negate(multiply.717)
+ bitcast.1189 = f32[32,2048,2048]{1,0,2} bitcast(negate.58)
+ reduce.97 = f32[32,2048]{1,0} reduce(bitcast.1189, constant.1968), dimensions={2}, to_apply=region_20.962
+ negate.59 = f32[2048,32,2048]{2,1,0} negate(multiply.718)
+ bitcast.1203 = f32[32,2048,2048]{1,0,2} bitcast(negate.59)
+ reduce.98 = f32[32,2048]{1,0} reduce(bitcast.1203, constant.1968), dimensions={2}, to_apply=region_20.962
+ add.613 = f32[32,2048]{1,0} add(reduce.97, reduce.98)
+ all-reduce.274 = f32[32,2048]{1,0} all-reduce(add.613), channel_id=335, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ multiply.719 = f32[32,2048]{1,0} multiply(all-reduce.274, broadcast.1243)
+ broadcast.1297 = f32[2048,32,2048]{2,1,0} broadcast(multiply.719), dimensions={1,2}
+ add.618 = f32[2048,32,2048]{2,1,0} add(add.617, broadcast.1297)
+ convert.301 = bf16[2048,32,2048]{2,1,0} convert(add.618)
+ add.619 = bf16[2048,32,2048]{2,1,0} add(bitcast.1652, convert.301)
+ add.616 = bf16[2048,32,2048]{2,1,0} add(convert.287, add.619)
+ bitcast.2063 = bf16[32,2048,2048]{1,0,2} bitcast(add.619)
+ all-gather.15 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2063), channel_id=76, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.1263 = bf16[65536,4096]{0,1} bitcast(all-gather.15)
+ bitcast.1269 = bf16[4096,2048]{1,0} bitcast(all-gather.9)
+ dot.48 = bf16[65536,2048]{1,0} dot(bitcast.1263, bitcast.1269), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitcast.1381 = bf16[32,2048,16,128]{3,2,1,0} bitcast(dot.48)
+ transpose.122 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,1,3}
+ dot.49 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.122, bitcast.1280), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ convert.303 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.49)
+ broadcast.1298 = f32[32,16,2048]{2,1,0} broadcast(constant.2005), dimensions={}
+ multiply.720 = f32[32,16,2048]{2,1,0} multiply(reduce.93, reduce.93)
+ divide.106 = f32[32,16,2048]{2,1,0} divide(broadcast.1298, multiply.720)
+ broadcast.1299 = f32[32,16,2048,2048]{3,2,1,0} broadcast(divide.106), dimensions={0,1,2}
+ multiply.721 = f32[32,16,2048,2048]{3,2,1,0} multiply(convert.303, broadcast.1299)
+ multiply.722 = f32[32,16,2048,2048]{3,2,1,0} multiply(multiply.721, exponential.8)
+ reduce.99 = f32[32,16,2048]{2,1,0} reduce(multiply.722, constant.1968), dimensions={3}, to_apply=region_20.962
+ negate.61 = f32[32,16,2048]{2,1,0} negate(reduce.99)
+ broadcast.1305 = f32[32,16,2048,2048]{3,2,1,0} broadcast(negate.61), dimensions={0,1,2}
+ divide.108 = f32[32,16,2048,2048]{3,2,1,0} divide(convert.303, broadcast.1309)
+ add.622 = f32[32,16,2048,2048]{3,2,1,0} add(broadcast.1305, divide.108)
+ multiply.724 = f32[32,16,2048,2048]{3,2,1,0} multiply(add.622, exponential.8)
+ convert.305 = bf16[32,16,2048,2048]{3,2,1,0} convert(multiply.724)
+ dot.50 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.113), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ bitcast.1934 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.50)
+ pad.6 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1934, constant.2006), padding=1_1x0_0x0_0x0_0x0_0
+ transpose.120 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.647), dimensions={0,1,3,2}
+ dot.51 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.120), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ broadcast.1307 = bf16[32,16,2048,128]{3,2,1,0} broadcast(constant.2007), dimensions={}
+ multiply.725 = bf16[32,16,2048,128]{3,2,1,0} multiply(dot.51, broadcast.1307)
+ bitcast.1941 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(multiply.725)
+ pad.7 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1941, constant.2006), padding=0_2x0_0x0_0x0_0x0_0
+ add.638 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(pad.6, pad.7)
+ transpose.123 = bf16[32,16,128,2048]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,3,1}
+ dot.89 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.306, transpose.123), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ bitcast.1949 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.89)
+ pad.8 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1949, constant.2006), padding=2_0x0_0x0_0x0_0x0_0
+ add.639 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(add.638, pad.8)
+ transpose.127 = bf16[32,2048,3,16,128]{4,3,2,1,0} transpose(add.639), dimensions={1,3,0,2,4}
+ bitcast.1416 = bf16[65536,6144]{1,0} bitcast(transpose.127)
+ dot.52 = bf16[65536,4096]{0,1} dot(bitcast.1416, bitcast.1420), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitcast.1424 = bf16[32,2048,4096]{1,0,2} bitcast(dot.52)
+ reduce-scatter.2 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1424), channel_id=324, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
+ bitcast.1851 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.2)
+ multiply.732 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1851, broadcast.1338)
+ convert.308 = f32[2048,32,2048]{2,1,0} convert(multiply.732)
+ multiply.727 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, convert.308)
+ bitcast.1434 = f32[32,2048,2048]{1,0,2} bitcast(multiply.727)
+ reduce.100 = f32[32,2048]{1,0} reduce(bitcast.1434, constant.1968), dimensions={2}, to_apply=region_20.962
+ all-reduce.33 = f32[32,2048]{1,0} all-reduce(reduce.100), channel_id=78, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ bitcast.1439 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.33)
+ divide.110 = f32[1,32,2048]{2,1,0} divide(rsqrt.20, bitcast.1447)
+ multiply.728 = f32[1,32,2048]{2,1,0} multiply(divide.110, broadcast.1313)
+ multiply.729 = f32[1,32,2048]{2,1,0} multiply(bitcast.1439, multiply.728)
+ multiply.730 = f32[1,32,2048]{2,1,0} multiply(multiply.729, broadcast.1315)
+ bitcast.1485 = f32[32,2048]{1,0} bitcast(multiply.730)
+ broadcast.1321 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1485), dimensions={1,2}
+ multiply.734 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1321)
+ multiply.735 = f32[2048,32,2048]{2,1,0} multiply(convert.308, broadcast.1337)
+ add.625 = f32[2048,32,2048]{2,1,0} add(multiply.734, multiply.735)
+ negate.62 = f32[2048,32,2048]{2,1,0} negate(multiply.734)
+ bitcast.1491 = f32[32,2048,2048]{1,0,2} bitcast(negate.62)
+ reduce.101 = f32[32,2048]{1,0} reduce(bitcast.1491, constant.1968), dimensions={2}, to_apply=region_20.962
+ negate.63 = f32[2048,32,2048]{2,1,0} negate(multiply.735)
+ bitcast.1505 = f32[32,2048,2048]{1,0,2} bitcast(negate.63)
+ reduce.102 = f32[32,2048]{1,0} reduce(bitcast.1505, constant.1968), dimensions={2}, to_apply=region_20.962
+ add.626 = f32[32,2048]{1,0} add(reduce.101, reduce.102)
+ all-reduce.275 = f32[32,2048]{1,0} all-reduce(add.626), channel_id=336, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+ multiply.736 = f32[32,2048]{1,0} multiply(all-reduce.275, broadcast.1243)
+ broadcast.1323 = f32[2048,32,2048]{2,1,0} broadcast(multiply.736), dimensions={1,2}
+ add.628 = f32[2048,32,2048]{2,1,0} add(add.625, broadcast.1323)
+ convert.309 = bf16[2048,32,2048]{2,1,0} convert(add.628)
+ add.629 = bf16[2048,32,2048]{2,1,0} add(add.616, convert.309)
+ bitcast.1525 = bf16[32,2048,2048]{1,0,2} bitcast(add.629)
+ get-tuple-element.47 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=2
+ reduce.103 = bf16[8192]{0} reduce(add.635, constant.2006), dimensions={0,1}, to_apply=add
+ all-reduce.36 = bf16[8192]{0} all-reduce(reduce.103), channel_id=81, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ bitcast.1583 = bf16[1,8192]{1,0} bitcast(all-reduce.36)
+ dynamic-update-slice.28 = bf16[24,8192]{1,0} dynamic-update-slice(get-tuple-element.47, bitcast.1583, select.372, constant.1980)
+ get-tuple-element.48 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=3
+ all-gather.16 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=82, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.1625 = bf16[4096,65536]{1,0} bitcast(all-gather.16)
+ dot.53 = bf16[4096,8192]{1,0} dot(bitcast.1625, bitcast.1629), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ reduce-scatter.3 = bf16[1024,8192]{1,0} reduce-scatter(dot.53), channel_id=325, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+ bitcast.1634 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.3)
+ dynamic-update-slice.29 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(get-tuple-element.48, bitcast.1634, select.372, constant.1980, constant.1980)
+ get-tuple-element.49 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=4
+ collective-permute.2 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.49), channel_id=85, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+ all-gather.17 = bf16[24,2048]{0,1} all-gather(collective-permute.2), channel_id=86, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+ bitcast.1649 = bf16[2048,24]{1,0} bitcast(all-gather.17)
+ reduce.104 = bf16[2048]{0} reduce(bitcast.1739, constant.2006), dimensions={0,1}, to_apply=add
+ all-reduce.38 = bf16[2048]{0} all-reduce(reduce.104), channel_id=84, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ bitcast.1671 = bf16[2048,1]{1,0} bitcast(all-reduce.38)
+ dynamic-update-slice.30 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1649, bitcast.1671, constant.1980, select.372)
+ constant.2013 = s32[8]{0} constant({0, 2048, 0, 2048, 1024, 3072, 1024, 3072})
+ partition-id.3 = u32[] partition-id()
+ dynamic-slice.336 = s32[1]{0} dynamic-slice(constant.2013, partition-id.3), dynamic_slice_sizes={1}
+ constant.2014 = s32[8]{0} constant({0, 2048, 0, 2048, 0, 2048, 0, 2048})
+ dynamic-slice.337 = s32[1]{0} dynamic-slice(constant.2014, partition-id.3), dynamic_slice_sizes={1}
+ subtract.232 = s32[1]{0} subtract(dynamic-slice.336, dynamic-slice.337)
+ bitcast.2087 = s32[] bitcast(subtract.232)
+ dynamic-slice.338 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.30, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+ bitcast.1695 = bf16[24,1024]{0,1} bitcast(dynamic-slice.338)
+ collective-permute.9 = bf16[24,1024]{0,1} collective-permute(bitcast.1695), channel_id=109, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+ get-tuple-element.50 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=5
+ bitcast.1698 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.50)
+ multiply.748 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.747)
+ multiply.749 = bf16[32,2048,8192]{2,1,0} multiply(multiply.748, broadcast.1335)
+ bitcast.1735 = bf16[8192,65536]{0,1} bitcast(multiply.749)
+ all-gather.18 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=87, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.1743 = bf16[65536,4096]{0,1} bitcast(all-gather.18)
+ dot.54 = bf16[8192,4096]{0,1} dot(bitcast.1735, bitcast.1743), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ reduce-scatter.4 = bf16[8192,1024]{0,1} reduce-scatter(dot.54), channel_id=326, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+ bitcast.1748 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.4)
+ dynamic-update-slice.31 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(bitcast.1698, bitcast.1748, select.372, constant.1980, constant.1980)
+ bitcast.1758 = bf16[24,8192,1024]{1,2,0} bitcast(dynamic-update-slice.31)
+ get-tuple-element.51 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=6
+ collective-permute.3 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.51), channel_id=90, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+ all-gather.19 = bf16[24,2048]{0,1} all-gather(collective-permute.3), channel_id=91, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+ bitcast.1763 = bf16[2048,24]{1,0} bitcast(all-gather.19)
+ reduce.105 = bf16[2048]{0} reduce(reduce-scatter.1, constant.2006), dimensions={0,1}, to_apply=add
+ all-reduce.40 = bf16[2048]{0} all-reduce(reduce.105), channel_id=89, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ bitcast.1779 = bf16[2048,1]{1,0} bitcast(all-reduce.40)
+ dynamic-update-slice.32 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1763, bitcast.1779, constant.1980, select.372)
+ dynamic-slice.339 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.32, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+ bitcast.1794 = bf16[24,1024]{0,1} bitcast(dynamic-slice.339)
+ collective-permute.10 = bf16[24,1024]{0,1} collective-permute(bitcast.1794), channel_id=110, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+ get-tuple-element.52 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=7
+ collective-permute.4 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.52), channel_id=93, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+ all-gather.20 = bf16[24,2048]{0,1} all-gather(collective-permute.4), channel_id=94, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+ bitcast.1801 = bf16[2048,24]{1,0} bitcast(all-gather.20)
+ multiply.751 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, bitcast.1766)
+ bitcast.1817 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.751)
+ reduce.106 = bf16[2048]{0} reduce(bitcast.1817, constant.2006), dimensions={0,1}, to_apply=add
+ all-reduce.41 = bf16[2048]{0} all-reduce(reduce.106), channel_id=92, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ bitcast.1826 = bf16[2048,1]{1,0} bitcast(all-reduce.41)
+ dynamic-update-slice.33 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1801, bitcast.1826, constant.1980, select.372)
+ dynamic-slice.340 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.33, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+ bitcast.1841 = bf16[24,1024]{0,1} bitcast(dynamic-slice.340)
+ collective-permute.11 = bf16[24,1024]{0,1} collective-permute(bitcast.1841), channel_id=111, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+ get-tuple-element.53 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=8
+ collective-permute.5 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.53), channel_id=96, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+ all-gather.21 = bf16[24,2048]{0,1} all-gather(collective-permute.5), channel_id=97, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+ bitcast.1848 = bf16[2048,24]{1,0} bitcast(all-gather.21)
+ reduce.107 = bf16[2048]{0} reduce(reduce-scatter.2, constant.2006), dimensions={0,1}, to_apply=add
+ all-reduce.42 = bf16[2048]{0} all-reduce(reduce.107), channel_id=95, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ bitcast.1864 = bf16[2048,1]{1,0} bitcast(all-reduce.42)
+ dynamic-update-slice.34 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1848, bitcast.1864, constant.1980, select.372)
+ dynamic-slice.341 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.34, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+ bitcast.1879 = bf16[24,1024]{0,1} bitcast(dynamic-slice.341)
+ collective-permute.12 = bf16[24,1024]{0,1} collective-permute(bitcast.1879), channel_id=112, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+ get-tuple-element.54 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=9
+ collective-permute.6 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.54), channel_id=99, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+ all-gather.22 = bf16[24,2048]{0,1} all-gather(collective-permute.6), channel_id=100, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+ bitcast.1886 = bf16[2048,24]{1,0} bitcast(all-gather.22)
+ multiply.753 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, bitcast.1851)
+ bitcast.1905 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.753)
+ reduce.108 = bf16[2048]{0} reduce(bitcast.1905, constant.2006), dimensions={0,1}, to_apply=add
+ all-reduce.43 = bf16[2048]{0} all-reduce(reduce.108), channel_id=98, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ bitcast.1914 = bf16[2048,1]{1,0} bitcast(all-reduce.43)
+ dynamic-update-slice.35 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1886, bitcast.1914, constant.1980, select.372)
+ dynamic-slice.342 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.35, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+ bitcast.1929 = bf16[24,1024]{0,1} bitcast(dynamic-slice.342)
+ collective-permute.13 = bf16[24,1024]{0,1} collective-permute(bitcast.1929), channel_id=113, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+ get-tuple-element.55 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=10
+ bitcast.1979 = bf16[3,32,2048,16,128]{4,2,3,1,0} bitcast(add.639)
+ reduce.109 = bf16[3,16,128]{2,1,0} reduce(bitcast.1979, constant.2006), dimensions={1,2}, to_apply=add
+ all-reduce.44 = bf16[3,16,128]{2,1,0} all-reduce(reduce.109), channel_id=101, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ bitcast.1963 = bf16[1,3,16,128]{3,2,1,0} bitcast(all-reduce.44)
+ dynamic-update-slice.36 = bf16[24,3,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.55, bitcast.1963, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
+ get-tuple-element.56 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=11
+ bitcast.1974 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.56)
+ transpose.130 = bf16[3,16,128,32,2048]{4,3,2,1,0} transpose(add.639), dimensions={0,2,4,1,3}
+ bitcast.1983 = bf16[6144,65536]{1,0} bitcast(transpose.130)
+ all-gather.23 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=102, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.2007 = bf16[65536,4096]{0,1} bitcast(all-gather.23)
+ dot.55 = bf16[6144,4096]{0,1} dot(bitcast.1983, bitcast.2007), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitcast.2011 = bf16[3,16,128,4096]{2,1,0,3} bitcast(dot.55)
+ reduce-scatter.5 = bf16[3,16,128,1024]{2,1,0,3} reduce-scatter(bitcast.2011), channel_id=327, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={3}, to_apply=add
+ bitcast.2015 = bf16[1,1024,3,16,128]{4,3,2,1,0} bitcast(reduce-scatter.5)
+ dynamic-update-slice.37 = bf16[24,1024,3,16,128]{4,3,2,1,0} dynamic-update-slice(bitcast.1974, bitcast.2015, select.372, constant.1980, constant.1980, /*index=5*/constant.1980, constant.1980)
+ bitcast.2025 = bf16[24,3,1024,16,128]{4,3,1,2,0} bitcast(dynamic-update-slice.37)
+ get-tuple-element.57 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=12
+ reduce.110 = bf16[2048]{0} reduce(bitcast.2063, constant.2006), dimensions={0,1}, to_apply=add
+ all-reduce.46 = bf16[2048]{0} all-reduce(reduce.110), channel_id=104, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ dynamic-slice.343 = bf16[1024]{0} dynamic-slice(all-reduce.46, bitcast.2087), dynamic_slice_sizes={1024}
+ bitcast.2046 = bf16[1,1024]{1,0} bitcast(dynamic-slice.343)
+ collective-permute.7 = bf16[1,1024]{1,0} collective-permute(bitcast.2046), channel_id=105, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+ dynamic-update-slice.38 = bf16[24,1024]{1,0} dynamic-update-slice(get-tuple-element.57, collective-permute.7, select.372, constant.1980)
+ get-tuple-element.58 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=13
+ bitcast.2066 = bf16[2048,65536]{1,0} bitcast(add.619)
+ transpose.133 = bf16[16,32,2048,128]{3,2,1,0} transpose(dot.44), dimensions={1,0,3,2}
+ bitcast.2072 = bf16[32,2048,16,128]{3,1,0,2} bitcast(transpose.133)
+ all-gather.24 = bf16[32,2048,32,128]{3,1,0,2} all-gather(bitcast.2072), channel_id=106, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+ bitcast.2073 = bf16[32,32,2048,128]{3,2,1,0} bitcast(all-gather.24)
+ transpose.134 = bf16[32,2048,32,128]{3,2,1,0} transpose(bitcast.2073), dimensions={1,2,0,3}
+ bitcast.2077 = bf16[65536,4096]{1,0} bitcast(transpose.134)
+ dot.56 = bf16[2048,4096]{1,0} dot(bitcast.2066, bitcast.2077), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitcast.2081 = bf16[2048,32,128]{2,1,0} bitcast(dot.56)
+ all-reduce.47 = bf16[2048,32,128]{2,1,0} all-reduce(bitcast.2081), channel_id=107, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+ constant.2015 = s32[8]{0} constant({0, 0, 16, 16, 0, 0, 16, 16})
+ dynamic-slice.344 = s32[1]{0} dynamic-slice(constant.2015, partition-id.3), dynamic_slice_sizes={1}
+ bitcast.2095 = s32[] bitcast(dynamic-slice.344)
+ dynamic-slice.345 = bf16[1024,16,128]{2,1,0} dynamic-slice(all-reduce.47, bitcast.2087, bitcast.2095, constant.1980), dynamic_slice_sizes={1024,16,128}
+ bitcast.2102 = bf16[1,1024,16,128]{3,2,1,0} bitcast(dynamic-slice.345)
+ collective-permute.8 = bf16[1,1024,16,128]{3,2,1,0} collective-permute(bitcast.2102), channel_id=108, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+ dynamic-update-slice.39 = bf16[24,1024,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.58, collective-permute.8, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
+ ROOT tuple.2 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) tuple(add.581, bitcast.1525, dynamic-update-slice.28, dynamic-update-slice.29, collective-permute.9, /*index=5*/bitcast.1758, collective-permute.10, collective-permute.11, collective-permute.12, collective-permute.13, /*index=10*/dynamic-update-slice.36, bitcast.2025, dynamic-update-slice.38, dynamic-update-slice.39, get-tuple-element.45, /*index=15*/get-tuple-element.44, get-tuple-element.46, get-tuple-element.43, get-tuple-element.42, get-tuple-element.37, /*index=20*/get-tuple-element.36, get-tuple-element.38, get-tuple-element.35, get-tuple-element.41, get-tuple-element.40, /*index=25*/get-tuple-element.32, get-tuple-element.39, get-tuple-element.33)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ HloInstruction* fwd_instruction = nullptr;
+ HloInstruction* bwd_instruction = nullptr;
+ SCOPED_TRACE(m->ToString());
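+  // Walk the rewritten module and locate the forward and backward fMHA custom calls produced by the rewriter.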
+ for (HloInstruction* instr :
+ m->entry_computation()->MakeInstructionPostOrder()) {
+ if (instr->opcode() == HloOpcode::kCustomCall &&
+ instr->custom_call_target() == kCudnnfMHASoftmaxCallTarget) {
+ fwd_instruction = instr;
+ }
+ if (instr->opcode() == HloOpcode::kCustomCall &&
+ instr->custom_call_target() == kCudnnfMHASoftmaxBackwardCallTarget) {
+ bwd_instruction = instr;
+ }
+ }
+ EXPECT_NE(fwd_instruction, nullptr);
+ EXPECT_NE(bwd_instruction, nullptr);
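+  // The forward custom call should record a causal mask in its cuDNN fMHA backend config.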
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fwd_instruction->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+ BF16TrainingBmm2CanonicalizationRestoreFwdGraph) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ const char* module_str = R"(
+HloModule pjit__unnamed_function_, entry_computation_layout={(bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,4,256,256]{3,2,1,0})->(bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false,false}, num_partitions=4
+
+region_0.6 {
+ Arg_0.7 = bf16[] parameter(0)
+ Arg_1.8 = bf16[] parameter(1)
+ ROOT maximum.5 = bf16[] maximum(Arg_0.7, Arg_1.8)
+}
+
+region_1.10 {
+ Arg_0.11 = f32[] parameter(0)
+ Arg_1.12 = f32[] parameter(1)
+ ROOT add.14 = f32[] add(Arg_0.11, Arg_1.12)
+}
+
+add.clone {
+ x.1 = u32[] parameter(0)
+ y.1 = u32[] parameter(1)
+ ROOT add.15 = u32[] add(x.1, y.1)
+}
+
+region_2.65 {
+ Arg_0.66 = bf16[] parameter(0)
+ Arg_1.67 = bf16[] parameter(1)
+ ROOT add.16 = bf16[] add(Arg_0.66, Arg_1.67)
+}
+
+ENTRY main.164_spmd {
+ param = bf16[2,256,4,64]{3,2,1,0} parameter(2), sharding={devices=[2,1,2,1]<=[4]}
+ transpose.26 = bf16[2,4,64,256]{3,2,1,0} transpose(param), dimensions={0,2,3,1}
+ param.1 = bf16[2,256,4,64]{3,2,1,0} parameter(0), sharding={devices=[2,1,2,1]<=[4]}
+ transpose.27 = bf16[2,4,256,64]{3,2,1,0} transpose(param.1), dimensions={0,2,1,3}
+ constant.46 = bf16[] constant(0.5)
+ broadcast.126 = bf16[2,4,256,64]{3,2,1,0} broadcast(constant.46), dimensions={}
+ multiply.34 = bf16[2,4,256,64]{3,2,1,0} multiply(transpose.27, broadcast.126)
+ param.2 = bf16[2,256,4,64]{3,2,1,0} parameter(1), sharding={devices=[2,1,2,1]<=[4]}
+ transpose.29 = bf16[2,4,64,256]{3,2,1,0} transpose(param.2), dimensions={0,2,3,1}
+ dot.12 = bf16[2,4,256,256]{3,2,1,0} dot(multiply.34, transpose.29), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ param.3 = bf16[2,4,256,256]{3,2,1,0} parameter(4), sharding={devices=[2,2,1,1]<=[4]}
+ add.17 = bf16[2,4,256,256]{3,2,1,0} add(dot.12, param.3)
+ constant.47 = bf16[] constant(-inf)
+ reduce.4 = bf16[2,4,256]{2,1,0} reduce(add.17, constant.47), dimensions={3}, to_apply=region_0.6
+ broadcast.127 = bf16[2,4,256,256]{3,2,1,0} broadcast(reduce.4), dimensions={0,1,2}
+ subtract.14 = bf16[2,4,256,256]{3,2,1,0} subtract(add.17, broadcast.127)
+ exponential.2 = bf16[2,4,256,256]{3,2,1,0} exponential(subtract.14)
+ convert.46 = f32[2,4,256,256]{3,2,1,0} convert(exponential.2)
+ constant.48 = f32[] constant(0)
+ reduce.5 = f32[2,4,256]{2,1,0} reduce(convert.46, constant.48), dimensions={3}, to_apply=region_1.10
+ convert.47 = bf16[2,4,256]{2,1,0} convert(reduce.5)
+ broadcast.128 = bf16[2,4,256,256]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2}
+ divide.7 = bf16[2,4,256,256]{3,2,1,0} divide(exponential.2, broadcast.128)
+ broadcast.129 = f32[4096]{0} broadcast(constant.48), dimensions={}
+ constant.50 = u32[] constant(0)
+ broadcast.131 = u32[8192]{0} broadcast(constant.50), dimensions={}
+ broadcast.133 = u32[4096]{0} broadcast(constant.50), dimensions={}
+ iota.3 = u32[8192]{0} iota(), iota_dimension=0
+ slice.14 = u32[4096]{0} slice(iota.3), slice={[0:4096]}
+ slice.15 = u32[4096]{0} slice(iota.3), slice={[4096:8192]}
+ custom-call.3 = (u32[4096]{0}, u32[4096]{0}) custom-call(broadcast.133, broadcast.133, slice.14, slice.15), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[4096]{0}, u32[4096]{0}, u32[4096]{0}, u32[4096]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\020\000\000\000\000\000\000"
+ get-tuple-element.6 = u32[4096]{0} get-tuple-element(custom-call.3), index=0
+ constant.115 = u32[1]{0} constant({0})
+ constant.52 = u32[4]{0} constant({0, 0, 1, 1})
+ partition-id = u32[] partition-id()
+ dynamic-slice.21 = u32[1]{0} dynamic-slice(constant.52, partition-id), dynamic_slice_sizes={1}
+ constant.116 = u32[1]{0} constant({1})
+ clamp.3 = u32[1]{0} clamp(constant.115, dynamic-slice.21, constant.116)
+ convert.48 = s32[1]{0} convert(clamp.3)
+ constant.117 = s32[1]{0} constant({2048})
+ multiply.35 = s32[1]{0} multiply(convert.48, constant.117)
+ bitcast.105 = s32[] bitcast(multiply.35)
+ dynamic-slice.22 = u32[2048]{0} dynamic-slice(get-tuple-element.6, bitcast.105), dynamic_slice_sizes={2048}
+ constant.58 = s32[4]{0} constant({0, 0, 1, 1})
+ dynamic-slice.23 = s32[1]{0} dynamic-slice(constant.58, partition-id), dynamic_slice_sizes={1}
+ multiply.36 = s32[1]{0} multiply(dynamic-slice.23, constant.117)
+ bitcast.108 = s32[] bitcast(multiply.36)
+ dynamic-update-slice.2 = u32[8192]{0} dynamic-update-slice(broadcast.131, dynamic-slice.22, bitcast.108)
+ get-tuple-element.7 = u32[4096]{0} get-tuple-element(custom-call.3), index=1
+ dynamic-slice.24 = u32[2048]{0} dynamic-slice(get-tuple-element.7, bitcast.105), dynamic_slice_sizes={2048}
+ constant.65 = s32[] constant(4096)
+ add.18 = s32[] add(bitcast.108, constant.65)
+ dynamic-update-slice.3 = u32[8192]{0} dynamic-update-slice(dynamic-update-slice.2, dynamic-slice.24, add.18)
+ all-reduce = u32[8192]{0} all-reduce(dynamic-update-slice.3), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.clone
+ constant.118 = s32[1]{0} constant({4096})
+ multiply.37 = s32[1]{0} multiply(dynamic-slice.23, constant.118)
+ bitcast.119 = s32[] bitcast(multiply.37)
+ dynamic-slice.25 = u32[4096]{0} dynamic-slice(all-reduce, bitcast.119), dynamic_slice_sizes={4096}
+ constant.69 = u32[] constant(9)
+ broadcast.134 = u32[4096]{0} broadcast(constant.69), dimensions={}
+ shift-right-logical.6 = u32[4096]{0} shift-right-logical(dynamic-slice.25, broadcast.134)
+ constant.70 = u32[] constant(1065353216)
+ broadcast.135 = u32[4096]{0} broadcast(constant.70), dimensions={}
+ or.5 = u32[4096]{0} or(shift-right-logical.6, broadcast.135)
+ bitcast-convert.5 = f32[4096]{0} bitcast-convert(or.5)
+ constant.71 = f32[] constant(-1)
+ broadcast.136 = f32[4096]{0} broadcast(constant.71), dimensions={}
+ add.19 = f32[4096]{0} add(bitcast-convert.5, broadcast.136)
+ maximum.6 = f32[4096]{0} maximum(broadcast.129, add.19)
+ constant.72 = f32[] constant(0.5)
+ broadcast.137 = f32[4096]{0} broadcast(constant.72), dimensions={}
+ compare.4 = pred[4096]{0} compare(maximum.6, broadcast.137), direction=LT
+ bitcast.135 = pred[2,8,256]{2,1,0} bitcast(compare.4)
+ convert.49 = bf16[2,8,256]{2,1,0} convert(bitcast.135)
+ constant.80 = s32[] constant(0)
+ constant.78 = s32[4]{0} constant({0, 4, 0, 4})
+ dynamic-slice.26 = s32[1]{0} dynamic-slice(constant.78, partition-id), dynamic_slice_sizes={1}
+ bitcast.181 = s32[] bitcast(dynamic-slice.26)
+ dynamic-slice.27 = bf16[2,4,256]{2,1,0} dynamic-slice(convert.49, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
+ broadcast.139 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.27), dimensions={0,1,3}
+ multiply.38 = bf16[2,4,256,256]{3,2,1,0} multiply(divide.7, broadcast.139)
+ constant.93 = bf16[] constant(2)
+ broadcast.141 = bf16[2,4,256,256]{3,2,1,0} broadcast(constant.93), dimensions={}
+ multiply.39 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.38, broadcast.141)
+ dot.13 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.26, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ transpose.31 = bf16[4,2,64,256]{3,2,1,0} transpose(dot.13), dimensions={1,0,2,3}
+ bitcast.154 = bf16[2,256,4,64]{1,3,0,2} bitcast(transpose.31)
+ all-gather = bf16[2,256,8,64]{1,3,0,2} all-gather(bitcast.154), channel_id=2, replica_groups={{0,1},{2,3}}, dimensions={2}, use_global_device_ids=true
+ bitcast.155 = bf16[8,2,64,256]{3,2,1,0} bitcast(all-gather)
+ transpose.32 = bf16[2,8,64,256]{3,2,1,0} transpose(bitcast.155), dimensions={1,0,2,3}
+ bitcast.157 = bf16[2,256,8,64]{1,3,2,0} bitcast(transpose.32)
+ all-gather.1 = bf16[4,256,8,64]{1,3,2,0} all-gather(bitcast.157), channel_id=3, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true
+ bitcast.236 = bf16[4,8,64,256]{3,2,1,0} bitcast(all-gather.1)
+ transpose.38 = bf16[4,256,8,64]{3,2,1,0} transpose(bitcast.236), dimensions={0,3,1,2}
+ param.4 = bf16[2,256,4,64]{3,2,1,0} parameter(3), sharding={devices=[2,1,2,1]<=[4]}
+ transpose.33 = bf16[2,4,256,64]{3,2,1,0} transpose(param.4), dimensions={0,2,1,3}
+ dot.14 = bf16[2,4,256,256]{3,2,1,0} dot(transpose.33, transpose.26), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ broadcast.142 = bf16[4096]{0} broadcast(constant.93), dimensions={}
+ constant.95 = bf16[] constant(0)
+ broadcast.143 = bf16[4096]{0} broadcast(constant.95), dimensions={}
+ select.4 = bf16[4096]{0} select(compare.4, broadcast.142, broadcast.143)
+ bitcast.176 = bf16[2,8,256]{2,1,0} bitcast(select.4)
+ dynamic-slice.28 = bf16[2,4,256]{2,1,0} dynamic-slice(bitcast.176, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
+ broadcast.145 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.28), dimensions={0,1,3}
+ multiply.40 = bf16[2,4,256,256]{3,2,1,0} multiply(dot.14, broadcast.145)
+ divide.8 = bf16[2,4,256,256]{3,2,1,0} divide(multiply.40, broadcast.128)
+ constant.106 = bf16[] constant(1)
+ broadcast.146 = bf16[2,4,256]{2,1,0} broadcast(constant.106), dimensions={}
+ multiply.41 = bf16[2,4,256]{2,1,0} multiply(convert.47, convert.47)
+ divide.9 = bf16[2,4,256]{2,1,0} divide(broadcast.146, multiply.41)
+ broadcast.147 = bf16[2,4,256,256]{3,2,1,0} broadcast(divide.9), dimensions={0,1,2}
+ multiply.42 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.40, broadcast.147)
+ multiply.43 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.42, exponential.2)
+ reduce.6 = bf16[2,4,256]{2,1,0} reduce(multiply.43, constant.95), dimensions={3}, to_apply=region_2.65
+ negate.4 = bf16[2,4,256]{2,1,0} negate(reduce.6)
+ broadcast.148 = bf16[2,4,256,256]{3,2,1,0} broadcast(negate.4), dimensions={0,1,2}
+ add.20 = bf16[2,4,256,256]{3,2,1,0} add(divide.8, broadcast.148)
+ multiply.44 = bf16[2,4,256,256]{3,2,1,0} multiply(add.20, exponential.2)
+ transpose.34 = bf16[2,4,256,64]{3,2,1,0} transpose(param.2), dimensions={0,2,1,3}
+ dot.15 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, transpose.34), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ multiply.45 = bf16[2,4,256,64]{3,2,1,0} multiply(dot.15, broadcast.126)
+ transpose.39 = bf16[2,256,4,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
+ dot.16 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, multiply.34), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.40 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.16), dimensions={0,2,1,3}
+ transpose.36 = bf16[2,4,64,256]{3,2,1,0} transpose(param.4), dimensions={0,2,3,1}
+ dot.11 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.36, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.41 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.11), dimensions={0,3,1,2}
+ ROOT tuple.2 = (bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}) tuple(transpose.38, transpose.39, transpose.40, transpose.41)
+} // main.164_spmd
+)";
+  // The dropout backward pattern is not supported, so the forward pass should not be lowered either.
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+ GetCudnnVersion()};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+ SCOPED_TRACE(m->ToString());
+  // Check that the forward graph has been restored with the cloned activation.
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::Transpose(), m::Transpose(), m::Transpose(),
+ m::Transpose(m::Dot(
+ m::Op(), m::Op().WithPredicate([](const HloInstruction* instr) {
+ return instr->name() == "multiply.39.fmha_no_match_clone";
+ }))))));
+}
+
+constexpr absl::string_view hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})->(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true,true}
+
+region_0.14 {
+ Arg_0.15 = bf16[] parameter(0)
+ Arg_1.16 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(Arg_0.15, Arg_1.16)
+}
+
+region_1.27 {
+ Arg_0.28 = f32[] parameter(0)
+ Arg_1.29 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0.28, Arg_1.29)
+}
+
+region_2.56 {
+ Arg_0.57 = bf16[] parameter(0)
+ Arg_1.58 = bf16[] parameter(1)
+ ROOT add.1 = bf16[] add(Arg_0.57, Arg_1.58)
+}
+
+ENTRY main.87 {
+ Arg_2.3 = bf16[2,1024,4,64]{3,2,1,0} parameter(2)
+ transpose.12 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
+ Arg_0.1 = bf16[2,1024,4,64]{3,2,1,0} parameter(0)
+ transpose.13 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
+ Arg_1.2 = bf16[2,1024,4,64]{3,2,1,0} parameter(1)
+ transpose.15 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
+ dot = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.13, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ Arg_4.5 = bf16[4,1024,1024]{2,1,0} parameter(4)
+ broadcast.9 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(Arg_4.5), dimensions={1,2,3}
+ add.2 = bf16[2,4,1024,1024]{3,2,1,0} add(dot, broadcast.9)
+ constant.10 = bf16[] constant(-inf)
+ reduce.18 = bf16[2,4,1024]{2,1,0} reduce(add.2, constant.10), dimensions={3}, to_apply=region_0.14
+ broadcast.10 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(reduce.18), dimensions={0,1,2}
+ subtract = bf16[2,4,1024,1024]{3,2,1,0} subtract(add.2, broadcast.10)
+ exponential = bf16[2,4,1024,1024]{3,2,1,0} exponential(subtract)
+ convert.5 = f32[2,4,1024,1024]{3,2,1,0} convert(exponential)
+ constant.9 = f32[] constant(0)
+ reduce.31 = f32[2,4,1024]{2,1,0} reduce(convert.5, constant.9), dimensions={3}, to_apply=region_1.27
+ convert.6 = bf16[2,4,1024]{2,1,0} convert(reduce.31)
+ broadcast.11 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(convert.6), dimensions={0,1,2}
+ divide.2 = bf16[2,4,1024,1024]{3,2,1,0} divide(exponential, broadcast.11)
+ dot.1 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.12, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ transpose.22 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
+ Arg_3.4 = bf16[2,1024,4,64]{3,2,1,0} parameter(3)
+ transpose.17 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,1,3}
+ dot.2 = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.17, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ divide.3 = bf16[2,4,1024,1024]{3,2,1,0} divide(dot.2, broadcast.11)
+ constant.0 = bf16[] constant(1)
+ broadcast.13 = bf16[2,4,1024]{2,1,0} broadcast(constant.0), dimensions={}
+ multiply.2 = bf16[2,4,1024]{2,1,0} multiply(convert.6, convert.6)
+ divide.4 = bf16[2,4,1024]{2,1,0} divide(broadcast.13, multiply.2)
+ broadcast.14 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(divide.4), dimensions={0,1,2}
+ multiply.3 = bf16[2,4,1024,1024]{3,2,1,0} multiply(dot.2, broadcast.14)
+ multiply.4 = bf16[2,4,1024,1024]{3,2,1,0} multiply(multiply.3, exponential)
+ constant.8 = bf16[] constant(0)
+ reduce.60 = bf16[2,4,1024]{2,1,0} reduce(multiply.4, constant.8), dimensions={3}, to_apply=region_2.56
+ negate.1 = bf16[2,4,1024]{2,1,0} negate(reduce.60)
+ broadcast.15 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(negate.1), dimensions={0,1,2}
+ add.3 = bf16[2,4,1024,1024]{3,2,1,0} add(divide.3, broadcast.15)
+ multiply.5 = bf16[2,4,1024,1024]{3,2,1,0} multiply(add.3, exponential)
+ transpose.18 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,1,3}
+ dot.4 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.18), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.23 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.4), dimensions={0,2,1,3}
+ dot.3 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.13), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.24 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.3), dimensions={0,2,1,3}
+ transpose.20 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,3,1}
+ dot.49 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.20, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+ transpose.25 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.49), dimensions={0,3,1,2}
+ reduce.81 = bf16[4,1024,1024]{2,1,0} reduce(multiply.5, constant.8), dimensions={0}, to_apply=region_2.56
+ ROOT tuple = (bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0}) tuple(transpose.22, transpose.23, transpose.24, transpose.25, reduce.81)
+} // main.87
+)";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxBmm2PatternDbias) {
+ if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto m,
+ ParseAndReturnVerifiedModule(hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias));
+  // Requires cuDNN 8.9.6+ and Hopper for dbias.
+ CudnnFusedMHARewriter fusedMhaRewriter{se::CudaComputeCapability(9, 0),
+ se::dnn::VersionInfo(9, 0, 0)};
+ TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ const HloInstruction* fmha;
+
+ SCOPED_TRACE(m->ToString());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::Transpose(
+ m::Transpose(m::GetTupleElement(
+ m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
+ 0)))
+ .WithShape(BF16, {2, 1024, 4, 64}),
+ m::Transpose(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 0))
+ .WithShape(BF16, {2, 1024, 4, 64}),
+ m::Transpose(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 1))
+ .WithShape(BF16, {2, 1024, 4, 64}),
+ m::Transpose(
+ m::Transpose(m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 2)))
+ .WithShape(BF16, {2, 1024, 4, 64}),
+ m::Reshape(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+ 3))
+ .WithShape(BF16, {4, 1024, 1024}))));
+ TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+ EXPECT_EQ(fmha->operands().size(), 4);
+ EXPECT_EQ(fmha->operand(3)->shape(),
+ ShapeUtil::MakeShape(BF16, {1, 4, 1024, 1024}));
+ EXPECT_EQ(config.fmha_scale(), 1.0);
+ EXPECT_EQ(config.dropout_rate(), 0.0);
+ EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
+}
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc
new file mode 100644
index 0000000..7299643
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc
@@ -0,0 +1,668 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+namespace m = match;
+
+bool IsFMHACustomCall(const HloInstruction* instr) {
+ return IsCustomCallTofMHA(*instr);
+}
+
+bool IsFwdFMHACustomCall(const HloInstruction* instr) {
+ return IsFwdCustomCallTofMHA(*instr);
+}
+
+bool IsBwdFMHACustomCall(const HloInstruction* instr) {
+ return IsBwdCustomCallTofMHA(*instr);
+}
+
+absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
+ HloInstruction* fmha, int64_t operand_index, bool is_lhs,
+ bool should_contracting_be_fastest) {
+ HloInstruction* transpose_arg = fmha->mutable_operand(operand_index);
+ HloInstruction* transpose_arg_operand = transpose_arg->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ CudnnfMHABackendConfig config = gpu_config.cudnn_fmha_backend_config();
+ CudnnfMHABackendConfig& new_fmha_config =
+ *gpu_config.mutable_cudnn_fmha_backend_config();
+
+ std::vector<int64_t> inverse_perm =
+ InversePermutation(transpose_arg->dimensions());
+ DotDimensionNumbers new_bmm_dot_dims;
+ if (IsFwdCustomCallTofMHA(*fmha)) {
+ if (operand_index == 0 || operand_index == 1) {
+ new_bmm_dot_dims = config.bmm1_dot_dimension_numbers();
+ } else {
+ new_bmm_dot_dims = config.bmm2_dot_dimension_numbers();
+ }
+ } else {
+ switch (operand_index) {
+ case 0:
+ // Q
+ new_bmm_dot_dims = config.bmm1_grad_gemm1_dot_dimension_numbers();
+ break;
+ case 1:
+ // K
+ new_bmm_dot_dims = config.bmm1_grad_gemm2_dot_dimension_numbers();
+ break;
+ case 2:
+ // V
+ new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
+ break;
+ case 3:
+ // Forward activation
+ new_bmm_dot_dims = config.bmm2_grad_gemm1_dot_dimension_numbers();
+ break;
+ case 4:
+ // Output gradient
+ new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
+ break;
+ default:
+ return Internal("Invalid operand index.");
+ }
+ }
+ absl::Span<const int64_t> checked_dims;
+ std::vector<int64_t> checked_dims_vec;
+
+  // `should_contracting_be_fastest` indicates whether the contracting dim is
+  // the head dim. cuDNN requires the head dim to be the fastest dim. Forward
+  // bmm1 and backward bmm2grad1 should set this value to true.
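+  // For example (see FusePrologueTransposeWithcuDNNFMHA below): the forward
+  // Q and K operands contract over the head dim, so those call sites pass
+  // true, while the forward V operand's head dim is non-contracting, so that
+  // call site passes false.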
+ if (should_contracting_be_fastest) {
+ checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
+ : new_bmm_dot_dims.rhs_contracting_dimensions();
+ } else {
+ absl::Span<const int64_t> batch_dims =
+ is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
+ : new_bmm_dot_dims.rhs_batch_dimensions();
+ absl::Span<const int64_t> contracting_dims =
+ is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
+ : new_bmm_dot_dims.rhs_contracting_dimensions();
+
+ TF_ASSIGN_OR_RETURN(checked_dims_vec,
+ GetNonContractingDims(transpose_arg->shape(),
+ batch_dims, contracting_dims));
+ checked_dims = checked_dims_vec;
+ }
+
+ int64_t checked_dims_bmm_size = checked_dims.size();
+ std::vector<int64_t> new_bmm_checked_dims(checked_dims_bmm_size);
+ for (int i = 0; i < checked_dims_bmm_size; i++) {
+ auto itr =
+ std::find(inverse_perm.begin(), inverse_perm.end(), checked_dims[i]);
+ if (itr == inverse_perm.end()) {
+ return Internal("Invalid inverse perm");
+ }
+ new_bmm_checked_dims[i] = std::distance(inverse_perm.begin(), itr);
+ }
+  // We want to make sure that feeding the transpose's operand directly into
+  // the fmha doesn't break the cuDNN constraint that the head dim of the
+  // corresponding BMM operand is the fastest moving dimension. One exception
+  // is the forward activation, which doesn't have the constraint since it has
+  // no head dim.
+ absl::Span<const int64_t> minor_to_major_bmm =
+ transpose_arg_operand->shape().layout().minor_to_major();
+ if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) &&
+ !(IsBwdCustomCallTofMHA(*fmha) && operand_index == 3)) {
+ return false;
+ }
+ if (should_contracting_be_fastest) {
+ if (is_lhs) {
+ new_bmm_dot_dims.clear_lhs_contracting_dimensions();
+ *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
+ new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
+ } else {
+ new_bmm_dot_dims.clear_rhs_contracting_dimensions();
+ *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
+ new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
+ }
+ }
+ auto& batch_dims = is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
+ : new_bmm_dot_dims.rhs_batch_dimensions();
+ int64_t batch_dims_bmm_size = batch_dims.size();
+ std::vector<int64_t> new_bmm_batch_dims(batch_dims_bmm_size);
+ for (int i = 0; i < batch_dims_bmm_size; i++) {
+ auto itr =
+ std::find(inverse_perm.begin(), inverse_perm.end(), batch_dims[i]);
+ if (itr == inverse_perm.end()) {
+ return Internal("Invalid inverse perm");
+ }
+ new_bmm_batch_dims[i] = std::distance(inverse_perm.begin(), itr);
+ }
+
+ if (is_lhs) {
+ new_bmm_dot_dims.clear_lhs_batch_dimensions();
+ *new_bmm_dot_dims.mutable_lhs_batch_dimensions() = {
+ new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
+
+ } else {
+ new_bmm_dot_dims.clear_rhs_batch_dimensions();
+ *new_bmm_dot_dims.mutable_rhs_batch_dimensions() = {
+ new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
+ }
+
+ if (!should_contracting_be_fastest) {
+ // Given the non-contracting dimensions, we can use the same function,
+ // GetNonContractingDims, to find the new contracting dims. Simply pass the
+ // non-contracting dimensions as the second argument.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<int64_t> new_bmm_contracting_dims,
+ GetNonContractingDims(transpose_arg_operand->shape(),
+ new_bmm_batch_dims, new_bmm_checked_dims));
+ if (is_lhs) {
+ new_bmm_dot_dims.clear_lhs_contracting_dimensions();
+ *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
+ new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
+
+ } else {
+ new_bmm_dot_dims.clear_rhs_contracting_dimensions();
+ *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
+ new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
+ }
+ }
+ if (IsFwdCustomCallTofMHA(*fmha)) {
+ if (operand_index == 0 || operand_index == 1) {
+ // Q or K
+ *new_fmha_config.mutable_bmm1_dot_dimension_numbers() = new_bmm_dot_dims;
+ } else {
+ // V
+ *new_fmha_config.mutable_bmm2_dot_dimension_numbers() = new_bmm_dot_dims;
+ }
+ } else {
+ switch (operand_index) {
+ case 0:
+ // Q
+ *new_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
+ new_bmm_dot_dims;
+ break;
+ case 1:
+ // K
+ *new_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
+ new_bmm_dot_dims;
+ break;
+ case 2:
+ // V
+ *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
+ new_bmm_dot_dims;
+ break;
+ case 3:
+ // Forward activation
+ *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
+ new_bmm_dot_dims;
+ break;
+ case 4: {
+ // Output gradient
+ *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
+ new_bmm_dot_dims;
+ DotDimensionNumbers bmm2_grad_gemm1_dot_dims =
+ config.bmm2_grad_gemm1_dot_dimension_numbers();
+ absl::Span<const int64_t> bmm2_grad_gemm1_contracting_dims =
+ bmm2_grad_gemm1_dot_dims.rhs_contracting_dimensions();
+ CHECK_EQ(bmm2_grad_gemm1_contracting_dims.size(), 1);
+ absl::Span<const int64_t> transpose_permutation =
+ transpose_arg->dimensions();
+ auto itr = std::find(transpose_permutation.begin(),
+ transpose_permutation.end(),
+ bmm2_grad_gemm1_contracting_dims[0]);
+ if (itr == transpose_permutation.end()) {
+ return Internal(
+ "bmm2 gradident gemm1 contracting dimension not found.");
+ }
+ int64_t index = std::distance(transpose_permutation.begin(), itr);
+ std::vector<int64_t> new_bmm2_grad_gemm1_rhs_contracting_dims = {index};
+        // Find the new batch dimensions. This is done by passing the new
+        // contracting dimensions and the contracting dimensions of the lhs of
+        // bmm2_grad_gemm2 (which are the non-contracting dimensions of the
+        // rhs of bmm2_grad_gemm1) to GetNonContractingDims.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<int64_t> new_bmm2_grad_gemm1_rhs_batch_dims,
+ GetNonContractingDims(
+ transpose_arg_operand->shape(),
+ new_bmm2_grad_gemm1_rhs_contracting_dims,
+ new_bmm_dot_dims.lhs_contracting_dimensions()));
+ bmm2_grad_gemm1_dot_dims.clear_rhs_contracting_dimensions();
+ bmm2_grad_gemm1_dot_dims.clear_rhs_batch_dimensions();
+ *bmm2_grad_gemm1_dot_dims.mutable_rhs_contracting_dimensions() = {
+ new_bmm2_grad_gemm1_rhs_contracting_dims.begin(),
+ new_bmm2_grad_gemm1_rhs_contracting_dims.end()};
+ *bmm2_grad_gemm1_dot_dims.mutable_rhs_batch_dimensions() = {
+ new_bmm2_grad_gemm1_rhs_batch_dims.begin(),
+ new_bmm2_grad_gemm1_rhs_batch_dims.end()};
+ *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
+ bmm2_grad_gemm1_dot_dims;
+ break;
+ }
+ default:
+ return Internal("Invalid operand index.");
+ }
+ }
+
+ TF_RETURN_IF_ERROR(fmha->set_backend_config(gpu_config));
+
+ TF_RETURN_IF_ERROR(fmha->ReplaceOperandWithDifferentShape(
+ operand_index, transpose_arg_operand));
+
+ return true;
+}
+
+/* Let's say A is transposed to B with perm {3, 0, 2, 1} as shown below:
+A[16, 256, 32, 64]
+ |
+ |
+ | Transpose with perm = {3, 0, 2, 1}
+ |
+ \/
+B[64, 16, 32, 256]
+
+The inverse perm to obtain A from B would be {1, 3, 2, 0}. That is
+B[64, 16, 32, 256]
+ |
+ |
+ | Transpose' with inv_perm = {1, 3, 2, 0}
+ |
+ \/
+A[16, 256, 32, 64]
+
+Now, let's say B is the lhs of a BatchedMatmul and the lhs_contracting
+dim is 3 (i.e dim 256). In order to now make A the lhs to the
+batchedMatmul (thus consuming the Transpose from A->B), we need to find
+the dimension number in A that corresponds to dimension number 3 in B.
+This can be done by finding the index of dim num 3 in inv_perm. That
+would be 1. Hence, dim num 3 in B is equivalent to dim num 1 in A. Thus
+the new lhs_contracting dim, if A were to be the new lhs, would be 1.
+
+Similarly, we need to find corresponding batch dimensions as well.
+*/
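+// An illustrative sketch of that remapping (values from the example above;
+// for exposition only, it mirrors the lookup loops in the function below):
+//   std::vector<int64_t> inverse_perm = InversePermutation({3, 0, 2, 1});
+//   // inverse_perm == {1, 3, 2, 0}
+//   auto it = std::find(inverse_perm.begin(), inverse_perm.end(), 3);
+//   int64_t new_contracting_dim = std::distance(inverse_perm.begin(), it);
+//   // new_contracting_dim == 1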
+absl::StatusOr<bool> FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) {
+ bool changed = false;
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction *transpose_arg0, *transpose_arg0_operand;
+ HloInstruction *transpose_arg1, *transpose_arg1_operand;
+ HloInstruction *transpose_arg2, *transpose_arg2_operand;
+ HloInstruction *transpose_arg3, *transpose_arg3_operand;
+ HloInstruction *transpose_arg4, *transpose_arg4_operand;
+
+ HloInstruction* fmha;
+
+ // Arg0 is common between forward and backward fmha calls, so we match
+ // either of these.
+ auto pattern_arg0 =
+ m::Op(&fmha)
+ .WithPredicate(IsFMHACustomCall)
+ .WithOperand(0, m::Transpose(&transpose_arg0,
+ m::Op(&transpose_arg0_operand)));
+ if (Match(instr, pattern_arg0)) {
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 0: \n"
+ << comp->parent()->ToString();
+ }
+ if (IsFwdFMHACustomCall(fmha)) {
+ // Q tensor in forward graph is lhs with constraint on contracting dim.
+ TF_ASSIGN_OR_RETURN(changed,
+ FuseArgPrologueTransposeWithcuDNNFMHA(
+ fmha, 0, true /*is_lhs*/,
+ true /*should_contracting_be_fastest*/));
+ } else {
+ // Q tensor in backward graph is rhs with constraint on non-contracting
+ // dim.
+ TF_ASSIGN_OR_RETURN(changed,
+ FuseArgPrologueTransposeWithcuDNNFMHA(
+ fmha, 0, false /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+ }
+
+ if (changed && VLOG_IS_ON(2)) {
+ VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 0: \n"
+ << comp->parent()->ToString();
+ }
+ }
+
+ // Arg1 is common between forward and backward fmha calls, so we match
+ // either of these.
+ auto pattern_arg1 =
+ m::Op(&fmha)
+ .WithPredicate(IsFMHACustomCall)
+ .WithOperand(1, m::Transpose(&transpose_arg1,
+ m::Op(&transpose_arg1_operand)));
+ if (Match(instr, pattern_arg1)) {
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 1: \n"
+ << comp->parent()->ToString();
+ }
+ if (IsFwdFMHACustomCall(fmha)) {
+ // K tensor in forward graph is rhs with constraint on contracting dim.
+ TF_ASSIGN_OR_RETURN(changed,
+ FuseArgPrologueTransposeWithcuDNNFMHA(
+ fmha, 1, false /*is_lhs*/,
+ true /*should_contracting_be_fastest*/));
+ } else {
+ // K tensor in backward graph is rhs with constraint on non-contracting
+ // dim.
+ TF_ASSIGN_OR_RETURN(changed,
+ FuseArgPrologueTransposeWithcuDNNFMHA(
+ fmha, 1, false /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+ }
+
+ if (changed && VLOG_IS_ON(2)) {
+ VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 1: \n"
+ << comp->parent()->ToString();
+ }
+ }
+
+ // Arg2 is common between forward and backward fmha calls, so we match
+ // either of these.
+ auto pattern_arg2 =
+ m::Op(&fmha)
+ .WithPredicate(IsFMHACustomCall)
+ .WithOperand(2, m::Transpose(&transpose_arg2,
+ m::Op(&transpose_arg2_operand)));
+ if (Match(instr, pattern_arg2)) {
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 2: \n"
+ << comp->parent()->ToString();
+ }
+ if (IsFwdFMHACustomCall(fmha)) {
+ // V tensor in forward graph is rhs with constraint on non-contracting
+ // dim.
+ TF_ASSIGN_OR_RETURN(changed,
+ FuseArgPrologueTransposeWithcuDNNFMHA(
+ fmha, 2, false /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+ } else {
+ // V tensor in backward graph is rhs with constraint on contracting dim.
+ TF_ASSIGN_OR_RETURN(changed,
+ FuseArgPrologueTransposeWithcuDNNFMHA(
+ fmha, 2, false /*is_lhs*/,
+ true /*should_contracting_be_fastest*/));
+ }
+
+ if (changed && VLOG_IS_ON(2)) {
+ VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 2: \n"
+ << comp->parent()->ToString();
+ }
+ }
+
+ // We only care about arg3 of backward
+ auto pattern_arg3 =
+ m::Op(&fmha)
+ .WithPredicate(IsBwdFMHACustomCall)
+ .WithOperand(3, m::Transpose(&transpose_arg3,
+ m::Op(&transpose_arg3_operand)));
+ if (Match(instr, pattern_arg3)) {
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 3: \n"
+ << comp->parent()->ToString();
+ }
+ // Forward activation tensor in backward graph is lhs with constraint on
+ // non-contracting dim.
+ TF_ASSIGN_OR_RETURN(changed,
+ FuseArgPrologueTransposeWithcuDNNFMHA(
+ fmha, 3, true /*is_lhs*/,
+ false /*should_contracting_be_fastest*/));
+
+ if (changed && VLOG_IS_ON(2)) {
+ VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 3: \n"
+ << comp->parent()->ToString();
+ }
+ }
+
+ // We only care about arg4 of backward
+ auto pattern_arg4 =
+ m::Op(&fmha)
+ .WithPredicate(IsBwdFMHACustomCall)
+ .WithOperand(4, m::Transpose(&transpose_arg4,
+ m::Op(&transpose_arg4_operand)));
+ if (Match(instr, pattern_arg4)) {
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 4: \n"
+ << comp->parent()->ToString();
+ }
+ // D_output tensor in backward graph is lhs with constraint on
+ // contracting dim.
+      // Make sure we don't change the layout of dO in the flash attention
+      // case, as dO should have the same layout as O.
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ fmha->backend_config<GpuBackendConfig>());
+ if (changed && VLOG_IS_ON(2)) {
+ VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 4: \n"
+ << comp->parent()->ToString();
+ }
+ }
+ }
+ return changed;
+}
+
+/* Let's say FMHA out is transposed to result with perm {1, 2, 0, 3} as shown
+below: FMHA_out[b0, b1, n, m]{}
+ |
+ |
+ Transpose with perm = {1, 2, 0, 3}
+ |
+ \/
+result[b1, n, b0, m]{1, 0, 3, 2}
+The goal is to find the minor_to_major of 'FMHA_out' such that its physical
+layout matches the physical layout of 'result', thus eliminating the need for an
+explicit transpose. cuDNN can perform an implicit transpose by knowing the
+corresponding strides (inferred from the corresponding minor_to_major).
+
+In order to find the required minor_to_major of 'FMHA_out', we first determine
+the inverse perm to obtain 'FMHA_out' from 'result'. The function
+"ShapeUtil::PermuteDimensions" generates a transposed shape such that the
+physical layout of the transposed shape is equivalent to the input shape.
+Calling this function with 'result' shape as the input shape and the inverse
+perm as the permutation will generate an output shape whose dimensions match
+'FMHA_out' dimensions but the physical layout is equivalent to 'result'. This is
+exactly what we want.
+
+The FMHA output should have exactly one gte instruction for a given tuple
+index so we can safely fuse the transpose following that gte into the FMHA:
+
+FMHA_out = gte(FMHA, index=0)
+FMHA_out_t = transpose(FMHA_out)
+use(FMHA_out_t)
+
+after fusion:
+
+FMHA_out_t = gte(FMHA, index=0)
+use(FMHA_out_t)
+*/
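+// An illustrative sketch of the shape computation (values from the example
+// above; for exposition only, it mirrors the code in the function below):
+//   std::vector<int64_t> inverse_perm = InversePermutation({1, 2, 0, 3});
+//   // inverse_perm == {2, 0, 1, 3}
+//   Shape expected_fmha_shape =
+//       ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
+//   // Its dimensions match FMHA_out while its physical layout matches
+//   // 'result', so the transpose can later be replaced by a bitcast.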
+
+absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
+ bool changed = false;
+
+ auto only_one_gte_with_spec_index = [](const HloInstruction* instr,
+ int64_t index) {
+ int count = 0;
+ for (auto user : instr->users()) {
+ if (user->opcode() == HloOpcode::kGetTupleElement &&
+ user->tuple_index() == index) {
+ count += 1;
+ }
+ }
+ return count == 1;
+ };
+
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ HloInstruction* fmha;
+ HloInstruction* transpose;
+ HloInstruction* gte;
+ auto fwd_tuple_elem =
+ m::GetTupleElement(>e,
+ m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0)
+ .WithOneUser();
+ // Note that we don't match any specific tuple index in matcher for
+ // backward.
+ auto bwd_tuple_elem =
+ m::GetTupleElement(>e,
+ m::Op(&fmha).WithPredicate(IsBwdFMHACustomCall))
+ .WithOneUser();
+ auto fwd_pattern = m::Transpose(&transpose, fwd_tuple_elem);
+ auto bwd_pattern = m::Transpose(&transpose, bwd_tuple_elem);
+
+ if (Match(instr, fwd_pattern)) {
+      // Check that only one gte with this tuple index exists.
+ int64_t tuple_index = gte->tuple_index();
+ if (!only_one_gte_with_spec_index(fmha, tuple_index)) continue;
+
+ std::vector<int64_t> inverse_perm =
+ InversePermutation(transpose->dimensions());
+
+ auto expected_fmha_shape =
+ ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
+
+ // cuDNN requires the last dimension of the output to be the fastest
+ // moving.
+ if (expected_fmha_shape.layout().minor_to_major()[0] !=
+ expected_fmha_shape.dimensions_size() - 1) {
+ VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
+ "the fastest moving. The last dimension is dim: "
+ << expected_fmha_shape.dimensions_size() - 1
+ << " but the upon fusion with transpose, the fmha output shape "
+ "would have been "
+ << expected_fmha_shape.ToString(true)
+ << " and the fastest moving "
+ "dimension would be dim: "
+ << expected_fmha_shape.layout().minor_to_major()[0];
+ continue;
+ }
+ Shape call_shape = fmha->shape();
+ *call_shape.mutable_tuple_shapes(0) = expected_fmha_shape;
+ HloInstruction* new_fmha_custom_call =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ call_shape, fmha->operands(),
+ absl::string_view(fmha->custom_call_target())));
+
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
+ fmha->backend_config<GpuBackendConfig>());
+ TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
+ TF_RETURN_IF_ERROR(
+ SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
+ new_fmha_custom_call->set_metadata(fmha->metadata());
+
+ auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_fmha_custom_call->shape().tuple_shapes(0), new_fmha_custom_call,
+ 0));
+ TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+ instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
+ TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
+
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "After forward FuseEpilogueTransposeWithcuDNNFMHA: \n"
+ << comp->parent()->ToString();
+ }
+ changed |= true;
+ } else if (Match(instr, bwd_pattern)) {
+      // Check that only one gte with this tuple index exists.
+ int64_t operand_tuple_idx = gte->tuple_index();
+ if (!only_one_gte_with_spec_index(fmha, operand_tuple_idx)) continue;
+
+ std::vector<int64_t> inverse_perm =
+ InversePermutation(transpose->dimensions());
+
+ auto expected_fmha_shape =
+ ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
+
+ // cuDNN requires the last dimension of the output to be the fastest
+ // moving.
+ if (expected_fmha_shape.layout().minor_to_major()[0] !=
+ expected_fmha_shape.dimensions_size() - 1) {
+ VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
+ "the fastest moving. The last dimension is dim: "
+ << expected_fmha_shape.dimensions_size() - 1
+ << " but the upon fusion with transpose, the fmha output shape "
+ "would have been "
+ << expected_fmha_shape.ToString(true)
+ << " and the fastest moving "
+ "dimension would be dim: "
+ << expected_fmha_shape.layout().minor_to_major()[0];
+ continue;
+ }
+ Shape call_shape = fmha->shape();
+ *call_shape.mutable_tuple_shapes(operand_tuple_idx) = expected_fmha_shape;
+ HloInstruction* new_fmha_custom_call =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ call_shape, fmha->operands(),
+ absl::string_view(fmha->custom_call_target())));
+
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
+ fmha->backend_config<GpuBackendConfig>());
+ TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
+ TF_RETURN_IF_ERROR(
+ SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
+ new_fmha_custom_call->set_metadata(fmha->metadata());
+ TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
+
+ auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_fmha_custom_call->shape().tuple_shapes(operand_tuple_idx),
+ new_fmha_custom_call, operand_tuple_idx));
+ TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+ instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
+
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "After backward FuseEpilogueTransposeWithcuDNNFMHA: \n"
+ << comp->parent()->ToString();
+ }
+ changed |= true;
+ }
+ }
+ return changed;
+}
+} // namespace
+
+absl::StatusOr<bool> CudnnFusedMHATransposeFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool any_changed = false;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ bool changed = false;
+ TF_ASSIGN_OR_RETURN(changed, FusePrologueTransposeWithcuDNNFMHA(comp));
+ any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(changed, FuseEpilogueTransposeWithcuDNNFMHA(comp));
+ any_changed |= changed;
+ }
+
+ return any_changed;
+}
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h
new file mode 100644
index 0000000..825d97e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h
@@ -0,0 +1,45 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
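+// Rewrites transposes feeding the operands of cuDNN fused MHA custom calls,
+// and transposes consuming their outputs, into the calls themselves by
+// adjusting the dot dimension numbers in the backend config and the layouts
+// of the call outputs, so that the explicit transposes become redundant.
+//
+// Usage sketch (the pipeline name is illustrative):
+//   HloPassPipeline pipeline("fmha-transpose-fusion");
+//   pipeline.AddPass<CudnnFusedMHATransposeFusion>();
+//   TF_RETURN_IF_ERROR(pipeline.Run(module).status());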
+class CudnnFusedMHATransposeFusion : public HloModulePass {
+ public:
+ CudnnFusedMHATransposeFusion() = default;
+
+ absl::string_view name() const override {
+ return "cudnn-fused-multi-headed-attention-transpose-fusion";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc
new file mode 100644
index 0000000..519b495
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc
@@ -0,0 +1,739 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/primitive_util.h"
+#include "xla/service/dump.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cudnn_support_utils.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/kernel_reuse_cache.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/triton_fusion_analysis.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/cuda/cuda_dnn.h"
+#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace fe = cudnn_frontend;
+namespace graph = fe::graph;
+
+inline std::optional<fe::PointwiseMode_t> GetElementwiseMode(
+ const HloInstruction& instruction) {
+ const HloOpcode opcode = instruction.opcode();
+ using m = fe::PointwiseMode_t;
+ switch (opcode) {
+ case HloOpcode::kAbs:
+ return m::ABS;
+ case HloOpcode::kAdd:
+ return m::ADD;
+ case HloOpcode::kCeil:
+ return m::CEIL;
+ case HloOpcode::kCompare:
+ switch (instruction.comparison_direction()) {
+ case Comparison::Direction::kEq:
+ return m::CMP_EQ;
+ case Comparison::Direction::kNe:
+ return m::CMP_NEQ;
+ case Comparison::Direction::kGe:
+ return m::CMP_GE;
+ case Comparison::Direction::kGt:
+ return m::CMP_GT;
+ case Comparison::Direction::kLe:
+ return m::CMP_LE;
+ case Comparison::Direction::kLt:
+ return m::CMP_LT;
+ }
+ break;
+ case HloOpcode::kConvert:
+ return m::IDENTITY;
+ case HloOpcode::kCos:
+ return m::COS;
+ case HloOpcode::kDivide:
+ return m::DIV;
+ case HloOpcode::kExp:
+ return m::EXP;
+ case HloOpcode::kFloor:
+ return m::FLOOR;
+ case HloOpcode::kLog:
+ return m::LOG;
+ case HloOpcode::kMaximum:
+ return m::MAX;
+ case HloOpcode::kMinimum:
+ return m::MIN;
+ case HloOpcode::kMultiply:
+ return m::MUL;
+ case HloOpcode::kNegate:
+ return m::NEG;
+ case HloOpcode::kPower:
+ return m::POW;
+ case HloOpcode::kRsqrt:
+ return m::RSQRT;
+#if CUDNN_VERSION >= 90100
+ case HloOpcode::kSelect:
+ return m::BINARY_SELECT;
+#endif // CUDNN_VERSION
+ case HloOpcode::kSin:
+ return m::SIN;
+ case HloOpcode::kSqrt:
+ return m::SQRT;
+ case HloOpcode::kSubtract:
+ return m::SUB;
+ case HloOpcode::kTan:
+ return m::TAN;
+ case HloOpcode::kTanh:
+ return m::TANH_FWD;
+ default:
+ return std::nullopt;
+ }
+}
+
+inline std::optional<fe::DataType_t> ToCudnnDataType(const PrimitiveType type) {
+ using t = fe::DataType_t;
+ switch (type) {
+ case PrimitiveType::F32:
+ return t::FLOAT;
+ case PrimitiveType::F16:
+ return t::HALF;
+ case PrimitiveType::BF16:
+ return t::BFLOAT16;
+ case PrimitiveType::S32:
+ return t::INT32;
+ case PrimitiveType::S8:
+ return t::INT8;
+ case PrimitiveType::PRED:
+ return t::INT8;
+ case PrimitiveType::F8E5M2:
+ return t::FP8_E5M2;
+ case PrimitiveType::F8E4M3FN:
+ return t::FP8_E4M3;
+ default:
+ return std::nullopt;
+ }
+}
+
+inline std::optional<fe::DataType_t> GetComputeDataType(
+ const PrimitiveType type) {
+ fe::DataType_t compute_dtype = fe::DataType_t::FLOAT;
+ if (primitive_util::IsIntegralType(type)) {
+#if CUDNN_VERSION >= 90100
+ compute_dtype = fe::DataType_t::INT32;
+#else
+ VLOG(3) << "Integer math requires cuDNN 9.1+.";
+ return std::nullopt;
+#endif // CUDNN_VERSION
+ }
+ return compute_dtype;
+}
+
+int FusionLevel(const HloInstruction& hlo) {
+ return hlo.GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_cudnn_gemm_fusion_level();
+};
+
+// Extracts dimensions and strides from HLO tensors in the format expected by
+// cuDNN.
+class GemmDimensionAdapter {
+ explicit GemmDimensionAdapter(const HloDotInstruction& dot,
+ TritonFusionAnalysis analysis)
+ : analysis_(std::move(analysis)), dot_(dot) {};
+
+ public:
+ const TritonFusionAnalysis analysis_;
+
+ static absl::StatusOr<std::optional<GemmDimensionAdapter>> Create(
+ const HloComputation& computation) {
+ const HloInstruction* maybe_dot =
+ hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot);
+ if (maybe_dot == nullptr) {
+ VLOG(3) << "Not a GEMM fusion.";
+ return std::nullopt;
+ }
+ const HloDotInstruction* dot = DynCast<HloDotInstruction>(
+ hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot));
+ if (absl::c_any_of(dot->precision_config().operand_precision(),
+ [](int x) { return x != PrecisionConfig::DEFAULT; })) {
+ VLOG(3) << "Non-default precision is not supported.";
+ return std::nullopt;
+ }
+ TF_ASSIGN_OR_RETURN(auto analysis,
+ TritonFusionAnalysis::Execute(computation));
+ return GemmDimensionAdapter{*dot, std::move(analysis)};
+ }
+
+ bool DimensionsAndStrides(const HloInstruction& hlo,
+ const TritonFusionAnalysis::Scope scope,
+ std::vector<int64_t>& dimensions,
+ std::vector<int64_t>& strides) {
+ const DotDimensionNumbers& dims = dot_.dot_dimension_numbers();
+ // GEMM fusions require a specific canonical order of dimensions.
+ constexpr int kBatchDimensionIndex = 0;
+ constexpr int kOutputLHSNonContractingDimensionIndex = 1;
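+    // For example (illustrative): for a dot [B, M, K] x [B, K, N] -> [B, M, N]
+    // the canonical orders are LHS {B, M, K}, RHS {B, K, N} and
+    // OUTPUT {B, M, N}; a missing batch dimension is encoded as -1 below.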
+ std::vector<int64_t> dim_indices;
+ int lhs_noncontracting_index = -1;
+ switch (scope) {
+ case TritonFusionAnalysis::Scope::LHS:
+ lhs_noncontracting_index =
+ GetNonContractingDims(dot_.operand(0)->shape(),
+ dims.lhs_batch_dimensions(),
+ dims.lhs_contracting_dimensions())
+ .value()[0];
+ dim_indices = {
+ dims.lhs_batch_dimensions().empty() ? -1
+ : dims.lhs_batch_dimensions(0),
+ lhs_noncontracting_index, dims.lhs_contracting_dimensions(0)};
+ break;
+ case TritonFusionAnalysis::Scope::RHS:
+ dim_indices = {dims.rhs_batch_dimensions().empty()
+ ? -1
+ : dims.rhs_batch_dimensions(0),
+ dims.rhs_contracting_dimensions(0),
+ GetNonContractingDims(dot_.operand(1)->shape(),
+ dims.rhs_batch_dimensions(),
+ dims.rhs_contracting_dimensions())
+ .value()[0]};
+ break;
+ case TritonFusionAnalysis::Scope::OUTPUT:
+ lhs_noncontracting_index = dot_.shape().rank() - 2;
+ dim_indices = {dims.lhs_batch_dimensions().empty() ? -1 : 0,
+ lhs_noncontracting_index, dot_.shape().rank() - 1};
+ break;
+ case TritonFusionAnalysis::Scope::META:
+ LOG(FATAL) << "Unsupported scope.";
+ }
+ dimensions.reserve(dim_indices.size());
+ strides.reserve(dim_indices.size());
+ for (const int index : dim_indices) {
+ const auto* spec = analysis_.IterSpec(scope, &hlo, index);
+ if (spec == nullptr) {
+ dimensions.push_back(1);
+ strides.push_back(strides.empty() ? 1 : strides.back());
+ continue;
+ } else {
+ if (spec->size() == 1) {
+ // The dimension is not split, nothing to do.
+ } else if (spec->size() == 2) {
+ if (FusionLevel(hlo) < 3) {
+ return false;
+ }
+ if (!dims.lhs_batch_dimensions().empty()) {
+ VLOG(8) << "Noncontracting dimension split is not compatible with "
+ "batch dimensions.";
+ return false;
+ }
+ if (index != lhs_noncontracting_index) {
+ VLOG(8) << "Only LHS noncontracting dimension can be split.";
+ return false;
+ }
+ switch (scope) {
+ case TritonFusionAnalysis::Scope::LHS:
+ lhs_noncontracting_split_ = spec->back().count;
+ break;
+ case TritonFusionAnalysis::Scope::OUTPUT:
+ if (lhs_noncontracting_split_ != spec->back().count) {
+ VLOG(8) << "Output non-contracting dimension has to be split "
+ "the same way as the LHS input one if it is split.";
+ return false;
+ }
+ break;
+ default:
+ VLOG(8) << "Only LHS noncontracting dimension can be split.";
+ return false;
+ }
+ // Assign the major part of the noncontracting dimension to the
+ // unused batch one.
+ CHECK_EQ(dimensions[kBatchDimensionIndex], 1);
+ dimensions[kBatchDimensionIndex] = spec->back().count;
+ strides[kBatchDimensionIndex] = spec->back().stride;
+ } else {
+ VLOG(8) << "The dimension is split multiple times.";
+ return false;
+ }
+ dimensions.push_back(spec->front().count);
+ strides.push_back(spec->front().stride);
+ }
+ }
+ if (lhs_noncontracting_split_ > 1 &&
+ scope == TritonFusionAnalysis::Scope::OUTPUT &&
+ dimensions[kBatchDimensionIndex] == 1) {
+ // LHS input noncontracting dimension is split but the corresponding
+ // output one is not. Assign part of the output one to the unused batch
+ // dimension.
+ dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_;
+ dimensions[kOutputLHSNonContractingDimensionIndex] /=
+ lhs_noncontracting_split_;
+ strides[kBatchDimensionIndex] =
+ strides[kOutputLHSNonContractingDimensionIndex] *
+ dimensions[kOutputLHSNonContractingDimensionIndex];
+ }
+ return true;
+ }
+
+ private:
+ int64_t lhs_noncontracting_split_ = 1;
+ const HloDotInstruction& dot_;
+};
+
+template <PrimitiveType XlaT, typename T>
+std::shared_ptr<graph::Tensor_attributes> LiteralToCudnnTensor(
+ const HloInstruction& hlo, graph::Graph& graph) {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<XlaT>::type;
+ return graph.tensor(T(hlo.literal().GetFirstElement<NativeT>()));
+}
+
+std::optional<std::shared_ptr<graph::Tensor_attributes>>
+HandleConstantHloToCudnnGraph(const HloInstruction& hlo, graph::Graph& graph) {
+ CHECK(hlo.IsConstant()) << "HLO is not a constant: " << hlo.ToShortString();
+ if (!ShapeUtil::IsScalar(hlo.shape())) {
+ VLOG(3) << "Currently only support fusing scalar in the graph";
+ return std::nullopt;
+ }
+ PrimitiveType constant_type = hlo.shape().element_type();
+ switch (constant_type) {
+ case BF16:
+ return LiteralToCudnnTensor<BF16, __nv_bfloat16>(hlo, graph);
+ case F32:
+ return LiteralToCudnnTensor<F32, float>(hlo, graph);
+ case S32:
+ return LiteralToCudnnTensor<S32, int>(hlo, graph);
+ default:
+ VLOG(3) << "Unsupported constant type: "
+ << PrimitiveType_Name(constant_type);
+ return std::nullopt;
+ }
+}
+
+std::optional<std::shared_ptr<graph::Tensor_attributes>>
+HandleClampToCudnnGraph(
+ const HloInstruction& hlo, graph::Graph& graph,
+ absl::flat_hash_map<const HloInstruction*,
+ std::shared_ptr<graph::Tensor_attributes>>
+ hlo_to_cudnn,
+ fe::DataType_t data_type, fe::DataType_t compute_dtype) {
+ CHECK(hlo.opcode() == HloOpcode::kClamp)
+ << "HLO is not a clamp: " << hlo.ToShortString();
+ CHECK(hlo.operands().size() == 3)
+ << "Clamp requires to have 3 operands: " << hlo.ToShortString();
+ // clamp = max(lower, min(value, upper));
+ const auto min_attrs = graph::Pointwise_attributes()
+ .set_mode(fe::PointwiseMode_t::MIN)
+ .set_compute_data_type(compute_dtype);
+ std::shared_ptr<graph::Tensor_attributes> min_tensor = graph.pointwise(
+ hlo_to_cudnn[hlo.operand(1)], hlo_to_cudnn[hlo.operand(2)], min_attrs);
+ min_tensor->set_data_type(data_type).set_name(std::string(hlo.name()));
+ const auto max_attrs = graph::Pointwise_attributes()
+ .set_mode(fe::PointwiseMode_t::MAX)
+ .set_compute_data_type(compute_dtype);
+ return graph.pointwise(min_tensor, hlo_to_cudnn[hlo.operand(0)], max_attrs);
+}
+
+// Traverses fusion computations and creates cuDNN graphs out of them.
+absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
+ const HloFusionInstruction& fusion) {
+ const HloComputation& computation = *fusion.fused_instructions_computation();
+ VLOG(5) << fusion.ToString();
+ VLOG(5) << computation.ToString();
+ graph::Graph graph;
+ std::vector<HloInstruction*> instructions =
+ computation.MakeInstructionPostOrder();
+ absl::flat_hash_map<const HloInstruction*,
+ std::shared_ptr<graph::Tensor_attributes>>
+ hlo_to_cudnn;
+ TF_ASSIGN_OR_RETURN(std::optional<GemmDimensionAdapter> adapter,
+ GemmDimensionAdapter::Create(computation));
+ if (!adapter.has_value()) {
+ return std::nullopt;
+ }
+ auto add_parameter = [&](const HloInstruction& parameter,
+ std::vector<int64_t>& dimensions,
+ std::vector<int64_t> strides) {
+ const std::optional<fe::DataType_t> data_type =
+ ToCudnnDataType(parameter.shape().element_type());
+ if (!data_type.has_value()) {
+ VLOG(3) << "Unsupported data type.";
+ return false;
+ }
+ hlo_to_cudnn[¶meter] = graph.tensor(
+ graph::Tensor_attributes()
+ .set_dim(dimensions)
+ .set_stride(strides)
+ .set_data_type(*data_type)
+ .set_name(std::string(parameter.name()))
+ .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number())));
+ return true;
+ };
+ for (const TritonFusionAnalysis::Scope scope :
+ {TritonFusionAnalysis::Scope::LHS, TritonFusionAnalysis::Scope::RHS,
+ TritonFusionAnalysis::Scope::OUTPUT}) {
+ for (const HloInstruction* parameter :
+ adapter->analysis_.ScopeParameters(scope)) {
+ std::vector<int64_t> dimensions;
+ std::vector<int64_t> strides;
+ if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions,
+ strides)) {
+ VLOG(3) << "Unsupported dimensions.";
+ return std::nullopt;
+ }
+ if (!add_parameter(*parameter, dimensions, strides)) {
+ return std::nullopt;
+ }
+ }
+ }
+
+ for (const HloInstruction* hlo : instructions) {
+ VLOG(5) << hlo->ToShortString();
+ auto operand = [&hlo_to_cudnn, &hlo](int i) {
+ return hlo_to_cudnn[hlo->operand(i)];
+ };
+ const auto data_type = ToCudnnDataType(hlo->shape().element_type());
+ if (!data_type.has_value()) {
+ VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type();
+ return std::nullopt;
+ }
+ if (hlo->opcode() == HloOpcode::kParameter) {
+ CHECK(hlo_to_cudnn.contains(hlo));
+ continue;
+ } else if (hlo->opcode() == HloOpcode::kCustomCall) {
+ if (hlo->user_count() != 1 ||
+ !IsWorkspaceAllocationRoot(*hlo->users()[0])) {
+ VLOG(3) << "Custom calls are only expected to be used for workspace "
+ "allocation.";
+ return std::nullopt;
+ }
+ continue;
+ } else if (hlo->opcode() == HloOpcode::kTuple) {
+ if (!IsWorkspaceAllocationRoot(*hlo)) {
+ VLOG(3) << "Tuples are only expected at outputs for workspace "
+ "allocation.";
+ return std::nullopt;
+ }
+ continue;
+ } else if (FusionLevel(fusion) >= 2 &&
+ hlo->opcode() == HloOpcode::kConstant) {
+ if (const auto const_tensor = HandleConstantHloToCudnnGraph(*hlo, graph);
+ const_tensor.has_value()) {
+ hlo_to_cudnn[hlo] = const_tensor.value();
+ } else {
+ return std::nullopt;
+ }
+ } else if (hlo->opcode() == HloOpcode::kReshape ||
+ hlo->opcode() == HloOpcode::kBitcast ||
+ hlo->opcode() == HloOpcode::kTranspose ||
+ hlo->opcode() == HloOpcode::kCopy ||
+ (FusionLevel(fusion) >= 2 &&
+ hlo->opcode() == HloOpcode::kBroadcast)) {
+ // All these are accounted for separately as transformations of strides.
+ hlo_to_cudnn[hlo] = operand(0);
+ } else if (hlo->IsElementwise()) {
+ const auto compute_dtype =
+ GetComputeDataType(hlo->shape().element_type());
+ if (!compute_dtype.has_value()) {
+ return std::nullopt;
+ }
+ if (hlo->opcode() == HloOpcode::kClamp) {
+ const auto clamp =
+ HandleClampToCudnnGraph(*hlo, graph, hlo_to_cudnn,
+ data_type.value(), compute_dtype.value());
+ if (!clamp.has_value()) {
+ return std::nullopt;
+ }
+ hlo_to_cudnn[hlo] = clamp.value();
+ } else {
+ const auto mode = GetElementwiseMode(*hlo);
+ if (!mode.has_value()) {
+ VLOG(3) << "Unsupported elementwise operation.";
+ return std::nullopt;
+ }
+ const auto attrs = graph::Pointwise_attributes()
+ .set_mode(mode.value())
+ .set_compute_data_type(compute_dtype.value());
+ if (hlo->operand_count() == 1) {
+ hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
+          // Set the dimensions for unary ops whose operands are broadcast so
+          // that cuDNN can infer their input shapes. A scalar constant has
+          // dimension [1] while cuDNN requires it to have dimension [1,1,1].
+          // Not setting the output shape of the unary op results in rejection
+          // of the cuDNN graph.
+ if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
+ const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
+ std::vector<int64_t> dimensions;
+ std::vector<int64_t> strides;
+ if (!scope.has_value()) {
+ LOG(FATAL) << "No scope for instruction: "
+ << hlo->ToShortString();
+ }
+ if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
+ strides)) {
+ VLOG(3) << "Unsupported hlo for querying dimensions: "
+ << hlo->ToShortString();
+ } else {
+ hlo_to_cudnn[hlo]->set_dim(dimensions);
+ }
+ }
+ } else if (hlo->operand_count() == 2) {
+ hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
+ } else if (hlo->operand_count() == 3) {
+ if (hlo->opcode() != HloOpcode::kSelect) {
+ VLOG(3) << "Unexpected ternary operation: " << hlo->ToString();
+ return std::nullopt;
+ }
+ // Operand order for select differs between HLO and cuDNN.
+ hlo_to_cudnn[hlo] =
+ graph.pointwise(operand(1), operand(2), operand(0), attrs);
+ } else {
+ VLOG(3) << "Unimplemented elementwise operation.";
+ return std::nullopt;
+ }
+ }
+ } else if (hlo->opcode() == HloOpcode::kDot) {
+ const auto compute_dtype =
+ GetComputeDataType(hlo->shape().element_type());
+ if (!compute_dtype.has_value()) {
+ return std::nullopt;
+ }
+ hlo_to_cudnn[hlo] =
+ graph.matmul(operand(0), operand(1),
+ graph::Matmul_attributes().set_compute_data_type(
+ compute_dtype.value()));
+ } else {
+ VLOG(3) << "Unimplemented operation.";
+ return std::nullopt;
+ }
+ if (hlo_to_cudnn[hlo] == nullptr) {
+ VLOG(3) << "Creation of the operation failed.";
+ return std::nullopt;
+ }
+ hlo_to_cudnn[hlo]
+ ->set_data_type(data_type.value())
+ .set_name(std::string(hlo->name()));
+ }
+ const HloInstruction* output = instructions.back();
+ if (instructions.back()->shape().IsTuple()) {
+ output = instructions.back()->operand(0);
+ }
+ std::vector<int64_t> dimensions;
+ std::vector<int64_t> strides;
+ if (!adapter->DimensionsAndStrides(
+ *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) {
+ VLOG(3) << "Unsupported dimensions.";
+ return std::nullopt;
+ }
+ hlo_to_cudnn[output]
+ ->set_output(true)
+ .set_dim(dimensions)
+ .set_stride(strides)
+ .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count()));
+ if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) {
+ json dump;
+ graph.serialize(dump);
+ DumpToFileInDirOrStdout(
+ /*module=*/*fusion.GetModule(),
+ /*file_prefix=*/"",
+ /*file_suffix=*/
+ absl::StrCat("cudnn_fusion_", fusion.name(), ".json"),
+ /*contents=*/dump.dump(1));
+ }
+
+ return se::gpu::CudnnGraph(std::move(graph));
+}
+
+// Creates a cuDNN graph, queries cuDNN whether it is supported.
+absl::StatusOr<se::gpu::CudnnGraph> PrepareGraph(
+ se::dnn::DnnSupport& dnn_support, const HloFusionInstruction& hlo) {
+ TF_ASSIGN_OR_RETURN(std::optional<se::gpu::CudnnGraph> graph,
+ HloFusionToCuDnnGraph(hlo));
+ if (!graph.has_value()) {
+ return absl::InternalError("Construction of cuDNN graph failed.");
+ }
+ TF_RETURN_IF_ERROR(graph->Prepare(
+ dnn_support,
+ se::NumericOptions{RequireDeterminism(hlo.GetModule()->config()),
+ /*allow_tf32=*/true}));
+ return *graph;
+}
+
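+// Appends a scratch workspace to a cuDNN fusion: wraps the fusion root into a
+// tuple of (original root, s8[workspace_size] workspace custom call), clones
+// the fusion with the new shape, and redirects existing users to
+// get-tuple-element 0 of the clone.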
+absl::StatusOr<HloInstruction*> AddWorkspace(HloInstruction& fusion,
+ const int64_t workspace_size) {
+ HloComputation* computation = fusion.fused_instructions_computation();
+ HloInstruction* custom_call =
+ computation->AddInstruction(HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeShape(S8, {workspace_size}), {},
+ kWorkspaceAllocationCustomCallTarget));
+ HloInstruction* output_tuple =
+ computation->AddInstruction(HloInstruction::CreateTuple(
+ {computation->root_instruction(), custom_call}));
+ computation->set_root_instruction(output_tuple, true);
+ HloInstruction* new_fusion = fusion.parent()->AddInstruction(
+ fusion.CloneWithNewShape(output_tuple->shape()));
+ TF_RETURN_IF_ERROR(fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(new_fusion, 0))));
+ TF_RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion));
+ return new_fusion;
+}
+
+class CuDnnFusionVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit CuDnnFusionVisitor(se::dnn::DnnSupport& dnn_support,
+ BinaryMap& compilation_results)
+ : dnn_support_(dnn_support), compilation_results_(compilation_results) {}
+
+ absl::Status HandleFusion(HloInstruction* hlo) override {
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ hlo->backend_config<GpuBackendConfig>());
+ const auto& fusion_backend_config = gpu_config.fusion_backend_config();
+ if (fusion_backend_config.kind() != kCuDnnFusionKind) {
+ return absl::OkStatus();
+ }
+ int64_t plan_id = -1;
+ if (fusion_backend_config.has_cudnn_fusion_config()) {
+ plan_id = fusion_backend_config.cudnn_fusion_config().plan_id();
+ }
+
+ VLOG(4) << "Processing " << hlo->ToString();
+ VLOG(4) << "Plan ID: " << plan_id;
+
+ auto add_workspace = [&](const int64_t workspace_size) {
+ if (workspace_size > 0) {
+ TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size));
+ SetVisited(*hlo);
+ }
+ return absl::OkStatus();
+ };
+ const std::string fingerprint_without_workspace =
+ GetComputationFingerprint(hlo->fused_instructions_computation(), {});
+ auto workspace_size_it =
+ workspace_sizes_.find(fingerprint_without_workspace);
+ if (workspace_size_it == workspace_sizes_.cend()) {
+ TF_ASSIGN_OR_RETURN(
+ se::gpu::CudnnGraph graph,
+ PrepareGraph(dnn_support_, *DynCast<HloFusionInstruction>(hlo)));
+
+ if (plan_id >= 0) {
+ // Build single plan with given ID.
+ if (plan_id >= graph.Graph().get_execution_plan_count()) {
+ return absl::InternalError("cuDNN graph plan does not exist.");
+ }
+ TF_RETURN_IF_ERROR(graph.Build(dnn_support_, plan_id));
+ } else {
+        // When no plan_id was provided, build plans one by one until the
+        // first one succeeds.
+ for (plan_id = 0; plan_id < graph.Graph().get_execution_plan_count();
+ ++plan_id) {
+ VLOG(7) << "Trying plan ID " << plan_id;
+ if (graph.Build(dnn_support_, plan_id).ok()) {
+ VLOG(7) << "Successfully built plan ID " << plan_id;
+ break;
+ }
+ }
+ if (plan_id == graph.Graph().get_execution_plan_count()) {
+ return absl::InternalError("No cuDNN plans can be built.");
+ }
+ }
+ const int64_t workspace_size = graph.Graph().get_workspace_size();
+ workspace_sizes_.insert(workspace_size_it,
+ {fingerprint_without_workspace, workspace_size});
+ TF_RETURN_IF_ERROR(add_workspace(workspace_size));
+
+ std::vector<uint8_t> serialized_graph;
+ RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph));
+      // Compute a new fingerprint that includes the potential workspace so
+      // that the compilation results match the fingerprint computed by the
+      // emitter.
+ compilation_results_[GetComputationFingerprint(
+ hlo->fused_instructions_computation(), {})] =
+ std::string(reinterpret_cast<char*>(serialized_graph.data()),
+ serialized_graph.size());
+ } else {
+ VLOG(4) << "Cache hit.";
+ TF_RETURN_IF_ERROR(add_workspace(workspace_size_it->second));
+ }
+ auto cudnn_config = gpu_config.mutable_fusion_backend_config()
+ ->mutable_cudnn_fusion_config();
+ cudnn_config->set_plan_id(plan_id);
+ TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
+
+ MarkAsChanged();
+ return absl::OkStatus();
+ }
+
+ private:
+ se::dnn::DnnSupport& dnn_support_;
+ // <HLO computation fingerprint, serialized compiled cuDNN graph>.
+ BinaryMap& compilation_results_;
+ absl::flat_hash_map<std::string, int64_t> workspace_sizes_;
+};
+
+} // namespace
+
+absl::StatusOr<bool> CuDnnFusionCompiler::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_SCOPED_LOGGING_TIMER("cuDNN fusion compiler");
+ return CuDnnFusionVisitor(dnn_support_, compilation_results_)
+ .RunOnModule(module, execution_threads);
+}
+
+int CuDnnFusionCompiler::GetAvailablePlanCount(
+ se::StreamExecutor& stream_exec, const HloFusionInstruction& hlo) {
+ auto graph = PrepareGraph(*stream_exec.AsDnn(), hlo);
+ if (!graph.ok()) {
+ return 0;
+ }
+ return std::min(
+ static_cast<int32_t>(graph->Graph().get_execution_plan_count()),
+ hlo.GetModule()->config().debug_options().xla_gpu_cudnn_gemm_max_plans());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h
new file mode 100644
index 0000000..4917914
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h
@@ -0,0 +1,59 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// Converts HLO fusions with cuDNN backend config to cuDNN graphs,
+// compiles them using a cuDNN handle and serializes them.
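+//
+// Usage sketch (the surrounding setup is illustrative):
+//   BinaryMap compilation_results;
+//   CuDnnFusionCompiler compiler(*stream_exec, compilation_results);
+//   TF_ASSIGN_OR_RETURN(bool changed, compiler.Run(module, {}));
+//   // `compilation_results` now maps computation fingerprints to serialized
+//   // cuDNN graphs.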
+class CuDnnFusionCompiler : public HloModulePass {
+ public:
+ explicit CuDnnFusionCompiler(se::StreamExecutor& stream_exec,
+ BinaryMap& compilation_results)
+ : dnn_support_(*stream_exec.AsDnn()),
+ compilation_results_(compilation_results) {}
+
+ absl::string_view name() const override { return "cudnn-fusion-compiler"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ static int GetAvailablePlanCount(se::StreamExecutor& stream_exec,
+ const HloFusionInstruction& hlo);
+
+ private:
+ se::dnn::DnnSupport& dnn_support_;
+ BinaryMap& compilation_results_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc
new file mode 100644
index 0000000..5d5e089
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc
@@ -0,0 +1,1553 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_norm_rewriter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/wrappers.pb.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/types.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/protobuf/dnn.pb.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#endif
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace m = match;
+
+// Traverses the graph upward starting at instr and returns the
+// first instruction that is not a convert, bitcast or reshape.
+const HloInstruction* SkipUnaryOps(const HloInstruction* instr) {
+ while (instr->opcode() == HloOpcode::kConvert ||
+ instr->opcode() == HloOpcode::kBitcast ||
+ instr->opcode() == HloOpcode::kReshape) {
+ instr = instr->operand(0);
+ }
+ return instr;
+}
+
+// Recursively traverses the graph downward starting at instr and stores in
+// instrs the users that are not a convert, bitcast or reshape.
+void SkipUnaryOpsTopDownRecursive(HloInstruction* instr,
+ std::vector<HloInstruction*>& instrs) {
+ if (instr->opcode() == HloOpcode::kConvert ||
+ instr->opcode() == HloOpcode::kBitcast ||
+ instr->opcode() == HloOpcode::kReshape) {
+ for (HloInstruction* user : instr->users()) {
+ SkipUnaryOpsTopDownRecursive(user, instrs);
+ }
+ } else {
+ instrs.emplace_back(instr);
+ }
+}
+
+// Holds auxiliary information about individual layer norm patterns rewritten
+// into a cuDNN Custom Call.
+struct NormMetadata {
+ // Transposes applied to the input and output of the forward layer norm to
+ // order the normalization and non-normalization dimensions as required by
+ // cuDNN. Nullptr if no transposes were inserted.
+ HloInstruction *x_transpose, *y_transpose;
+ // The reduction and non-reduction dimensions of the input into the forward
+ // layer norm before the potential application of transposes and adjusted for
+ // the removal of any degenerate dimensions in the input to the norm.
+ std::vector<int64_t> norm_dims_adjusted, non_norm_dims_adjusted;
+};
+
+// Map from the instruction pointer of a layer norm Custom Call to its metadata.
+using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
+
+// Captures multiple HloInstruction pointers and verifies that their target
+// is identical.
+//
+// Example:
+// Pattern cos(x) / sin(x) with cos and sin intended to operate on the same
+// HloInstruction:
+// UniqueHloInstruction x;
+// bool m = Match(
+// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
+// m::Sin(m::Op().WithPredicate(x.GetCaptureOrVerifyFn()))));
+// m is true and x.Instr() returns an HloInstruction pointer to the shared
+// operand of the cosine and sine iff HloInstruction *instr points to a
+// division of a cosine by a sine that both operate on the same instruction.
+class UniqueHloInstruction {
+ public:
+ UniqueHloInstruction()
+ : is_set_(false), instr_(nullptr), capture_or_verify_() {}
+ HloInstruction* Instr() const { return instr_; }
+ void SetInstr(HloInstruction* instr) {
+ is_set_ = true;
+ instr_ = instr;
+ }
+
+ // Stores instr when invoked the first time. Otherwise, compares instr to the
+ // stored value and sets the stored value to nullptr if the comparison fails.
+ bool CaptureOrVerify(HloInstruction* instr) {
+ if (is_set_ && instr != instr_) {
+ instr_ = nullptr;
+ }
+ if (!is_set_) {
+ is_set_ = true;
+ instr_ = instr;
+ }
+ return instr_;
+ }
+
+ // Returns a std::function for capturing or verifying an instruction using
+ // WithPredicate.
+ std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
+ if (!capture_or_verify_) {
+ capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
+ return CaptureOrVerify(const_cast<HloInstruction*>(instr));
+ };
+ }
+ return capture_or_verify_;
+ }
+
+ private:
+ bool is_set_;
+ HloInstruction* instr_;
+ std::function<bool(const HloInstruction*)> capture_or_verify_;
+};
+
+// Returns an architecture-specific constant for the calculation of an upper
+// bound for the size of the scratch space for layer norm kernels.
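+// For example, the constant evaluates to 32 * 128 = 4096 on Ampere and to
+// 32 * 144 = 4608 on Hopper.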
+absl::StatusOr<int64_t> CConstant(
+ se::CudaComputeCapability cuda_compute_capability) {
+ if (cuda_compute_capability.major == se::CudaComputeCapability::AMPERE) {
+ return 32 * 128;
+ } else if (cuda_compute_capability.major ==
+ se::CudaComputeCapability::HOPPER) {
+ return 32 * 144;
+ }
+ return xla::Internal("Norm kernels require Ampere or Hopper architecture.");
+}
+
+// Returns whether the element type of instr is compatible with layer norm
+// kernels.
+bool CompatibleElementType(const HloInstruction* instr) {
+ PrimitiveType element_type = instr->shape().element_type();
+ return element_type == BF16 || element_type == F16 || element_type == F32;
+}
+
+// Returns the dimensions associated with shape, adjusted for the removal of any
+// degenerate dimensions in shape. Specifically, for each dimension d in
+// dimensions, returns the new index of d if all dimensions of size 1 are
+// removed from shape. If d has size 1, it is not included in the returned
+// vector.
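+// For example (illustrative), for shape [2, 1, 3, 1] and dimensions {0, 2} the
+// result is {0, 1}; for dimensions {1, 3} the result is {}, since both
+// dimensions are degenerate.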
+std::vector<int64_t> AdjustedDimensions(const Shape& shape,
+ absl::Span<const int64_t> dimensions) {
+ absl::flat_hash_map<int64_t, int64_t> dimension_map;
+ for (int64_t dimension = 0, non_degen_dimension = 0; dimension < shape.rank();
+ ++dimension) {
+ if (shape.dimensions(dimension) > 1) {
+ dimension_map.insert({dimension, non_degen_dimension});
+ non_degen_dimension++;
+ }
+ }
+ std::vector<int64_t> adjusted_dimensions;
+ for (int64_t dimension : dimensions) {
+ auto non_degenerate_dimension = dimension_map.find(dimension);
+ if (non_degenerate_dimension != dimension_map.end()) {
+ adjusted_dimensions.emplace_back(non_degenerate_dimension->second);
+ }
+ }
+ return adjusted_dimensions;
+}
+
+// Returns the dimensions of broadcast or reduction instructions, adjusted for
+// the removal of any degenerate dimensions in the output or input.
+std::vector<int64_t> AdjustedDimensions(const HloInstruction* instr) {
+ Shape shape;
+ if (instr->opcode() == HloOpcode::kBroadcast) {
+ shape = instr->shape();
+ } else if (instr->opcode() == HloOpcode::kReduce) {
+ shape = instr->operand(0)->shape();
+ } else {
+ return {};
+ }
+ return AdjustedDimensions(shape, instr->dimensions());
+}
+
+// Returns whether the HLO Computation applied by instr calculates the sum of
+// the elements. When provided, compares reduce_dims to the dimensions of the
+// reduction.
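+// Example (illustrative HLO sketch):
+//   add { p0 = f32[] parameter(0)  p1 = f32[] parameter(1)
+//         ROOT s = f32[] add(p0, p1) }
+//   r = f32[2] reduce(f32[2,3] x, f32[] constant(0)), dimensions={1},
+//       to_apply=add
+// Both AppliesAddReduce(r) and AppliesAddReduce(r, {1}) return true.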
+bool AppliesAddReduce(const HloInstruction* instr,
+ absl::Span<const int64_t> reduce_dims = {}) {
+ if (instr->opcode() != HloOpcode::kReduce) {
+ return false;
+ }
+
+ // Verify the dimensions of the reduction.
+ if (!reduce_dims.empty() && AdjustedDimensions(instr) != reduce_dims) {
+ return false;
+ }
+
+ HloComputation* reduce_comp = instr->to_apply();
+ HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
+ return instr->operand_count() == 2 &&
+ instr->operand(1)->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::IsScalar(instr->operand(1)->shape()) &&
+ instr->operand(1)->literal().GetAsDouble({}) == 0. &&
+ reduce_comp_root->opcode() == HloOpcode::kAdd &&
+ reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
+ reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
+}
+
+// Returns whether instr multiplies the result of a reduction by one over the
+// number of reduced elements.
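+// For example (illustrative), multiply(broadcast(constant(0.0833333)),
+// reduce(f32[4,12] x, dimensions={1})) reduces 12 elements and 0.0833333 is
+// within the tolerance of 1/12, so the function returns true.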
+bool CalculatesExpectation(const HloInstruction* instr) {
+ instr = SkipUnaryOps(instr);
+ if (instr->opcode() != HloOpcode::kMultiply) {
+ return false;
+ }
+ bool bcast_operand = instr->operand(0)->opcode() != HloOpcode::kBroadcast;
+ const HloInstruction *broadcast = instr->operand(bcast_operand),
+ *reduce = SkipUnaryOps(instr->operand(!bcast_operand));
+ if (reduce->opcode() != HloOpcode::kReduce ||
+ broadcast->opcode() != HloOpcode::kBroadcast ||
+ broadcast->operand(0)->opcode() != HloOpcode::kConstant) {
+ return false;
+ }
+
+ float actual_r_nelems =
+ broadcast->operand(0)->literal().GetAsDouble({}).value();
+ int64_t nelems = 1;
+ for (int64_t norm_dim : reduce->dimensions()) {
+ nelems *= reduce->operand(0)->shape().dimensions()[norm_dim];
+ }
+ // The absolute value of the difference between the actual scaling factor
+ // and the reference value must not exceed a prescribed threshold.
+ float r_nelems = 1. / static_cast<float>(nelems);
+ float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
+ return std::abs(actual_r_nelems - r_nelems) <
+ ((actual_r_nelems + r_nelems) * numerical_epsilon);
+}
+
+// Returns whether target can be reached from instr by recursively traversing
+// the graph across converts, bitcasts and reshapes.
+bool FindTargetRecursive(
+ const HloInstruction* instr, const HloInstruction* target,
+ absl::flat_hash_set<const HloInstruction*>& visited_instrs,
+ const HloInstruction* transpose) {
+ visited_instrs.emplace(instr);
+ const absl::flat_hash_set<HloOpcode> supported_ops = {
+ HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
+ if (instr == target) {
+ return true;
+ }
+ // Look for target among the users of instr.
+ for (HloInstruction* user : instr->users()) {
+ if ((supported_ops.contains(user->opcode()) || user == transpose) &&
+ !visited_instrs.contains(user)) {
+ return FindTargetRecursive(user, target, visited_instrs, transpose);
+ }
+ }
+ // Ascend the graph if target is not found and instr is a convert, bitcast
+ // or reshape.
+ if (supported_ops.contains(instr->opcode())) {
+ return FindTargetRecursive(instr->operand(0), target, visited_instrs,
+ transpose);
+ }
+ return false;
+}
+
+bool FindTarget(const HloInstruction* custom_call, const HloInstruction* instr,
+ const HloInstruction* target,
+ const NormMetadataMap& norm_metadata) {
+ absl::flat_hash_set<const HloInstruction*> visited_instrs;
+ auto custom_call_metadata = norm_metadata.find(custom_call);
+ if (custom_call_metadata == norm_metadata.end()) {
+ return false;
+ }
+ return FindTargetRecursive(instr, target, visited_instrs,
+ custom_call_metadata->second.x_transpose);
+}
+
+// Maps the dimension numbers in dimensions from shape original_shape to shape
+// reshaped_shape, assuming that the shapes are related through a strict
+// reshape. Returns an empty vector if a dimension mapping is not found.
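+// For example (illustrative), mapping from shape [2, 3, 4] to shape [6, 4]
+// groups dimensions {0, 1} into reshaped dimension 0 and maps {2} to {1};
+// requesting {1} alone returns {}, since dimension 1 cannot be separated from
+// the {0, 1} group.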
+std::vector<int64_t> MapDimensions(const Shape& original_shape,
+ const Shape& reshaped_shape,
+ const absl::Span<const int64_t> dimensions) {
+ auto dimension_product =
+ [](const Shape& shape,
+ absl::Span<const int64_t> product_dimensions) -> int64_t {
+ int64_t product = 1;
+ for (int64_t product_dimension : product_dimensions) {
+ product *= shape.dimensions(product_dimension);
+ }
+ return product;
+ };
+ // Construct the dimension mapping.
+ absl::flat_hash_map<int64_t, std::vector<int64_t>> dimensions_map;
+ std::vector<int64_t> original_dimensions, reshaped_dimensions;
+ for (int64_t original_dimension = 0, reshaped_dimension = 0;
+ original_dimension < original_shape.rank(); ++original_dimension) {
+ original_dimensions.emplace_back(original_dimension);
+ while ((reshaped_dimensions.empty() ||
+ dimension_product(reshaped_shape, reshaped_dimensions) <
+ dimension_product(original_shape, original_dimensions)) &&
+ reshaped_dimension < reshaped_shape.rank()) {
+ reshaped_dimensions.emplace_back(reshaped_dimension++);
+ }
+
+ // Many-to-many dimension mappings are not supported.
+ if (original_dimensions.size() > 1 && reshaped_dimensions.size() > 1) {
+ return {};
+ }
+
+ if (dimension_product(original_shape, original_dimensions) ==
+ dimension_product(reshaped_shape, reshaped_dimensions)) {
+ std::vector<int64_t> original_dimensions_in_dimensions;
+ std::set_intersection(
+ original_dimensions.begin(), original_dimensions.end(),
+ dimensions.begin(), dimensions.end(),
+ std::back_inserter(original_dimensions_in_dimensions));
+ // The unique mapping of dimensions requires either all or none of the
+ // entries of original_dimensions to be an element of dimensions.
+ if (!original_dimensions_in_dimensions.empty() &&
+ original_dimensions_in_dimensions.size() !=
+ original_dimensions.size()) {
+ return {};
+ }
+ for (int64_t dimension : original_dimensions) {
+ dimensions_map.insert({dimension, reshaped_dimensions});
+ }
+ original_dimensions.clear();
+ reshaped_dimensions.clear();
+ }
+ }
+
+ // Map the dimension numbers to the reshaped shape.
+ std::vector<int64_t> mapped_dimensions;
+ for (int64_t dimension : dimensions) {
+ auto mapped_dimension = dimensions_map.find(dimension);
+ if (mapped_dimension == dimensions_map.end()) {
+ return {};
+ }
+ mapped_dimensions.insert(mapped_dimensions.end(),
+ mapped_dimension->second.begin(),
+ mapped_dimension->second.end());
+ }
+
+ // Eliminate duplicates in the mapped dimension numbers.
+ mapped_dimensions.erase(
+ std::unique(mapped_dimensions.begin(), mapped_dimensions.end()),
+ mapped_dimensions.end());
+ return mapped_dimensions;
+}
+
+// Recursively traverses the graph across converts, bitcasts and reshapes,
+// starting from instr, and returns the first addition-reduction identified.
+// Returns nullptr if no addition-reduction is found.
+HloInstruction* FindAddReduceRecursive(
+ HloInstruction* instr, const Shape& orig_instr_shape,
+ const absl::Span<const int64_t> reduce_dims,
+ absl::flat_hash_set<HloInstruction*>& visited_instrs) {
+ visited_instrs.emplace(instr);
+ const absl::flat_hash_set<HloOpcode> supported_ops = {
+ HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
+ // Look for a reduction among the users of instr.
+ for (HloInstruction* user : instr->users()) {
+ if (user->opcode() == HloOpcode::kReduce) {
+ std::vector<int64_t> mapped_reduce_dims =
+ MapDimensions(orig_instr_shape, instr->shape(), reduce_dims);
+ if (!mapped_reduce_dims.empty() &&
+ AppliesAddReduce(user, mapped_reduce_dims)) {
+ return user;
+ }
+ }
+ if (supported_ops.contains(user->opcode()) &&
+ !visited_instrs.contains(user)) {
+ return FindAddReduceRecursive(user, orig_instr_shape, reduce_dims,
+ visited_instrs);
+ }
+ }
+ // Ascend the graph if the addition-reduction is not found and instr is a
+ // convert, bitcast or reshape.
+ if (supported_ops.contains(instr->opcode())) {
+ return FindAddReduceRecursive(instr->mutable_operand(0), orig_instr_shape,
+ reduce_dims, visited_instrs);
+ }
+ return nullptr;
+}
+
+HloInstruction* FindAddReduce(HloInstruction* instr,
+ const absl::Span<const int64_t> reduce_dims) {
+ absl::flat_hash_set<HloInstruction*> visited_instrs;
+ return FindAddReduceRecursive(instr, instr->shape(), reduce_dims,
+ visited_instrs);
+}
+
+// Type conversion from and to any of BF16, FP16 and FP32.
+template <typename Pattern>
+auto SupportedConvert(Pattern pattern) {
+ auto supported_convert = [](const HloInstruction* instr) -> bool {
+ return CompatibleElementType(instr) &&
+ CompatibleElementType(instr->operand(0));
+ };
+ return m::Convert(pattern).WithPredicate(supported_convert);
+}
+
+// Bitcast or reshape adding or removing degenerate dimensions.
+template <typename Pattern>
+auto SupportedBitcastOrReshape(Pattern pattern) {
+ auto supported_bitcast_or_reshape = [](const HloInstruction* instr) -> bool {
+ return ShapeUtil::Equal(
+ ShapeUtil::DropDegenerateDimensions(instr->shape()),
+ ShapeUtil::DropDegenerateDimensions(instr->operand(0)->shape()));
+ };
+ return m::AnyOf<HloInstruction>(
+ m::Bitcast(pattern).WithPredicate(supported_bitcast_or_reshape),
+ m::Reshape(pattern).WithPredicate(supported_bitcast_or_reshape));
+}
+
+// Matches pattern, SupportedConvert(pattern),
+// SupportedBitcastOrReshape(pattern),
+// SupportedConvert(SupportedBitcastOrReshape(pattern)) and
+// SupportedBitcastOrReshape(SupportedConvert(pattern)).
+template <typename Pattern>
+auto OptionalSupportedTransform(Pattern pattern) {
+ auto shared_subpattern = m::SharedSubpattern(pattern);
+ return m::AnyOf<HloInstruction>(
+ SupportedConvert(SupportedBitcastOrReshape(shared_subpattern)),
+ SupportedBitcastOrReshape(SupportedConvert(shared_subpattern)),
+ SupportedConvert(shared_subpattern),
+ SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
+}
+
+// Bitcast or reshape with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern>
+auto BitcastOrReshape(Pattern pattern) {
+ return OptionalSupportedTransform(
+ m::AnyOf<HloInstruction>(m::Bitcast(pattern), m::Reshape(pattern)));
+}
+
+// Transpose with optional supported type conversion and/or addition or removal
+// of degenerate dimensions.
+template <typename Pattern>
+auto Transpose(Pattern pattern) {
+ return OptionalSupportedTransform(m::Transpose(pattern));
+}
+
+// Rsqrt with optional supported type conversion and/or addition or removal of
+// degenerate dimensions.
+template <typename Pattern>
+auto Rsqrt(HloInstruction** rsqrt, Pattern pattern) {
+ return OptionalSupportedTransform(m::Rsqrt(rsqrt, pattern));
+}
+
+// AddAnyOrder with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto AddAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
+ return OptionalSupportedTransform(m::AddAnyOrder(pattern0, pattern1));
+}
+
+// Subtract with optional supported type conversion and/or addition or removal
+// of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto Subtract(Pattern0 pattern0, Pattern1 pattern1) {
+ return OptionalSupportedTransform(m::Subtract(pattern0, pattern1));
+}
+
+// Capturing subtract with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto Subtract(HloInstruction** subtract, Pattern0 pattern0, Pattern1 pattern1) {
+ return OptionalSupportedTransform(m::Subtract(subtract, pattern0, pattern1));
+}
+
+// Multiply with optional supported type conversion and/or addition or removal
+// of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto MultiplyAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
+ return OptionalSupportedTransform(m::MultiplyAnyOrder(pattern0, pattern1));
+}
+
+// Capturing multiply with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto MultiplyAnyOrder(HloInstruction** multiply, Pattern0 pattern0,
+ Pattern1 pattern1) {
+ return OptionalSupportedTransform(
+ m::MultiplyAnyOrder(multiply, pattern0, pattern1));
+}
+
+// Multiplication of pattern by itself with optional supported type conversion
+// and/or addition or removal of degenerate dimensions.
+template <typename Pattern>
+auto Square(Pattern pattern) {
+ return MultiplyAnyOrder(pattern, pattern)
+ .WithPredicate([](const HloInstruction* instr) {
+ return instr->unique_operands().size() == 1;
+ });
+}
+
+// Multiplication of the square of pattern by pattern with optional supported
+// type conversion and/or addition or removal of degenerate dimensions. The root
+// instruction of pattern cannot be a multiplication.
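+// For example (illustrative), Cube(pattern) matches multiply(multiply(x, x), x)
+// in any operand order when x matches pattern, while multiply(multiply(x, x), y)
+// is rejected by the uniqueness predicate.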
+template <typename Pattern>
+auto Cube(Pattern pattern) {
+ auto unique_cube = [](const HloInstruction* instr) -> bool {
+ bool square_operand = instr->operand(0)->opcode() != HloOpcode::kMultiply;
+ return instr->operand(!square_operand)->opcode() != HloOpcode::kMultiply &&
+ instr->operand(square_operand)->operand(0) ==
+ instr->operand(!square_operand);
+ };
+ return MultiplyAnyOrder(Square(pattern), pattern).WithPredicate(unique_cube);
+}
+
+// Addition-reduction of pattern with optional supported type conversion and/or
+// addition or removal of degenerate dimensions.
+template <typename Pattern>
+auto AddReduce(Pattern pattern) {
+ return OptionalSupportedTransform(
+ m::Reduce(pattern, m::Op())
+ .WithPredicate([](const HloInstruction* instr) {
+ return AppliesAddReduce(instr);
+ }));
+}
+
+// Capturing addition-reduction of pattern with optional supported type
+// conversion and/or addition or removal of degenerate dimensions.
+template <typename Pattern>
+auto AddReduce(HloInstruction** reduction, Pattern pattern) {
+ return OptionalSupportedTransform(
+ m::Reduce(reduction, pattern, m::Op())
+ .WithPredicate([](const HloInstruction* instr) {
+ return AppliesAddReduce(instr);
+ }));
+}
+
+// Negated addition-reduction.
+template <typename Pattern>
+auto NegateAddReduce(HloInstruction** reduction, Pattern pattern) {
+ return m::AnyOf<HloInstruction>(AddReduce(reduction, m::Negate(pattern)),
+ m::Negate(AddReduce(reduction, pattern)));
+}
+
+// Expected value, or mean, with optional broadcast.
+template <typename Pattern>
+auto Expectation(Pattern pattern) {
+ auto shared_subpattern =
+ MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
+ .WithPredicate([](const HloInstruction* instr) {
+ return CalculatesExpectation(instr);
+ });
+ return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+ shared_subpattern);
+}
+
+// Expected value, or mean, with optional broadcast.
+template <typename Pattern>
+auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
+ auto shared_subpattern = OptionalSupportedTransform(
+ m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
+ .WithPredicate([](const HloInstruction* instr) {
+ return CalculatesExpectation(instr);
+ })
+ .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+ return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+ shared_subpattern);
+}
+
+// Expected value, or mean, with optional broadcast.
+template <typename Pattern>
+auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
+ Pattern pattern) {
+ auto shared_subpattern = OptionalSupportedTransform(
+ m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
+ AddReduce(reduce, pattern))
+ .WithPredicate([](const HloInstruction* instr) {
+ return CalculatesExpectation(instr);
+ })
+ .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+ return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+ shared_subpattern);
+}
+
+// Variance, expressed as expectation(X^2) - expectation(X)^2 or
+// expectation((X - expectation(X))^2).
+auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
+ UniqueHloInstruction* x) {
+ return m::AnyOf<HloInstruction>(
+ Subtract(
+ Expectation(Square(OptionalSupportedTransform(
+ m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
+ Square(Expectation(expectation,
+ OptionalSupportedTransform(m::Op().WithPredicate(
+ x->GetCaptureOrVerifyFn())))))
+ .WithPredicate(variance->GetCaptureOrVerifyFn()),
+ Expectation(
+ Square(Subtract(
+ OptionalSupportedTransform(
+ m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
+ Expectation(expectation,
+ OptionalSupportedTransform(m::Op().WithPredicate(
+ x->GetCaptureOrVerifyFn()))))))
+ .WithPredicate(variance->GetCaptureOrVerifyFn()));
+}
+
+// Reciprocal of the square root of variance + epsilon with optional broadcast.
+auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
+ UniqueHloInstruction* variance,
+ UniqueHloInstruction* expectation,
+ UniqueHloInstruction* epsilon) {
+ auto shared_subpattern = m::SharedSubpattern(Rsqrt(
+ norm_factor, AddAnyOrder(Variance(variance, expectation, x),
+ m::Broadcast(m::ConstantScalar().WithPredicate(
+ epsilon->GetCaptureOrVerifyFn())))));
+ return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+ shared_subpattern);
+}
+
+// Any order of p0 * p1 * p2.
+template <typename P0, typename P1, typename P2>
+auto MultiplyMultiplyAnyOrder(P0 p0, P1 p1, P2 p2) {
+ return m::AnyOf<HloInstruction>(
+ MultiplyAnyOrder(p0, MultiplyAnyOrder(p1, p2)),
+ MultiplyAnyOrder(p1, MultiplyAnyOrder(p0, p2)),
+ MultiplyAnyOrder(p2, MultiplyAnyOrder(p0, p1)));
+}
+
+// Any order of p0 + p1 + p2.
+template <typename P0, typename P1, typename P2>
+auto AddAddAnyOrder(P0 p0, P1 p1, P2 p2) {
+ return m::AnyOf<HloInstruction>(AddAnyOrder(p0, AddAnyOrder(p1, p2)),
+ AddAnyOrder(p1, AddAnyOrder(p0, p2)),
+ AddAnyOrder(p2, AddAnyOrder(p0, p1)));
+}
+
+// Any order of p0 * (p1 + p2).
+template <typename P0, typename P1, typename P2>
+auto MultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2) {
+ return m::AnyOf<HloInstruction>(
+ MultiplyAnyOrder(p0, AddAnyOrder(p1, p2)),
+ AddAnyOrder(MultiplyAnyOrder(p0, p1), MultiplyAnyOrder(p0, p2)));
+}
+
+// Any order of p0 - p1 + p2.
+template <typename P0, typename P1, typename P2>
+auto SubtractAddAnyOrder(P0 p0, P1 p1, P2 p2) {
+ return m::AnyOf<HloInstruction>(AddAnyOrder(Subtract(p0, p1), p2),
+ AddAnyOrder(Subtract(p2, p1), p0),
+ Subtract(AddAnyOrder(p0, p2), p1));
+}
+
+// Any order of (p0 - p1) * p2 * p3 + p4.
+template <typename P0, typename P1, typename P2, typename P3, typename P4>
+auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {
+ return m::AnyOf<HloInstruction>(
+ SubtractAddAnyOrder(MultiplyMultiplyAnyOrder(p0, p2, p3),
+ MultiplyMultiplyAnyOrder(p1, p2, p3), p4),
+ AddAnyOrder(MultiplyMultiplyAnyOrder(Subtract(p0, p1), p2, p3), p4));
+}
+
+// Expectation fused into a layer norm Custom Call.
+auto FusedExpectation(UniqueHloInstruction* custom_call) {
+ auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
+ m::CustomCall({kCudnnNormCallTarget})
+ .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+ 1));
+ return m::AnyOf<HloInstruction>(shared_subpattern,
+ BitcastOrReshape(shared_subpattern));
+}
+
+// Expectation fused into a layer norm Custom Call.
+auto FusedExpectation(UniqueHloInstruction* fused_expectation,
+ UniqueHloInstruction* custom_call) {
+ auto shared_subpattern = m::SharedSubpattern(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnNormCallTarget})
+ .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+ 1)
+ .WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
+ return m::AnyOf<HloInstruction>(shared_subpattern,
+ BitcastOrReshape(shared_subpattern));
+}
+
+// Norm factor fused into a layer norm Custom Call.
+auto FusedNormFactor(UniqueHloInstruction* custom_call) {
+ auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
+ m::CustomCall({kCudnnNormCallTarget})
+ .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+ 2));
+ return m::AnyOf<HloInstruction>(shared_subpattern,
+ BitcastOrReshape(shared_subpattern));
+}
+
+// Norm factor fused into a layer norm Custom Call.
+auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
+ UniqueHloInstruction* custom_call) {
+ auto shared_subpattern = m::SharedSubpattern(
+ m::GetTupleElement(
+ m::CustomCall({kCudnnNormCallTarget})
+ .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+ 2)
+ .WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
+ return m::AnyOf<HloInstruction>(shared_subpattern,
+ BitcastOrReshape(shared_subpattern));
+}
+
+// Derivative of the norm factor w.r.t. variance + epsilon,
+// d(norm_factor)/d(variance + epsilon)
+// = d((variance + epsilon)^-1/2)/d(variance + epsilon)
+// = -1/2 * norm_factor^3.
+// Forwards custom_call to FusedNormFactor for verification.
+auto DNormFactor(UniqueHloInstruction* custom_call) {
+ return MultiplyAnyOrder(m::Broadcast(m::ConstantScalar(-0.5)),
+ Cube(FusedNormFactor(custom_call)));
+}
+
+// Zero-centered input of the layer norm, X - expectation(X). Verifies that
+// custom_call is a forward layer norm fusing X. Forwards custom_call to
+// FusedExpectation for verification.
+auto XCenter(UniqueHloInstruction* x, UniqueHloInstruction* custom_call,
+ const NormMetadataMap& norm_metadata) {
+ auto capture_or_verify_x =
+ [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
+ return x->CaptureOrVerify(
+ FindTarget(custom_call->Instr(), instr->operand(0),
+ custom_call->Instr()->operand(0), norm_metadata)
+ ? custom_call->Instr()->mutable_operand(0)
+ : nullptr);
+ };
+ return Subtract(m::Op(), m::Broadcast(FusedExpectation(custom_call)))
+ .WithPredicate(capture_or_verify_x);
+}
+
+// Zero-centered input of the layer norm, X - expectation(X). Captures X in x if
+// custom_call is a forward layer norm fusing X. Forwards custom_call to
+// FusedExpectation for comparison.
+auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
+ UniqueHloInstruction* fused_expectation,
+ UniqueHloInstruction* custom_call,
+ const NormMetadataMap& norm_metadata) {
+ auto capture_or_verify_x =
+ [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
+ return x->CaptureOrVerify(
+ FindTarget(custom_call->Instr(), instr->operand(0),
+ custom_call->Instr()->operand(0), norm_metadata)
+ ? custom_call->Instr()->mutable_operand(0)
+ : nullptr);
+ };
+ return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
+ custom_call)))
+ .WithPredicate(x_center->GetCaptureOrVerifyFn())
+ .WithPredicate(capture_or_verify_x);
+}
+
+// Addition-reduction of the product of XCenter, the broadcasted scale and DY,
+// XCenter * scale * DY. Captures the scale in scale if custom_call is a forward
+// layer norm fusing the scale. Forwards custom_call to XCenter for comparison.
+auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
+ UniqueHloInstruction* dy, UniqueHloInstruction* x,
+ HloInstruction** reduce, const NormMetadataMap& norm_metadata) {
+ auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
+ const HloInstruction* instr) -> bool {
+ return scale->CaptureOrVerify(FindTarget(custom_call->Instr(), instr,
+ custom_call->Instr()->operand(1),
+ norm_metadata)
+ ? custom_call->Instr()->mutable_operand(1)
+ : nullptr);
+ };
+ return AddReduce(
+ reduce, MultiplyMultiplyAnyOrder(
+ XCenter(x, custom_call, norm_metadata),
+ m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
+ m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+}
+
+// Product of XCenter and the scaled and broadcasted product of F0 and
+// d(norm_factor)/d(variance + epsilon), XCenter * F0 * DNormFactor * 2 /
+// nelems. Forwards custom_call to XCenter, F0 and DNormFactor for capture or
+// verification.
+auto F1(UniqueHloInstruction* x, UniqueHloInstruction* x_center,
+ UniqueHloInstruction* fused_expectation,
+ UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
+ UniqueHloInstruction* dy, HloInstruction** reduce,
+ const NormMetadataMap& norm_metadata) {
+ auto broadcasts_two_over_nelems = [](const HloInstruction* instr) -> bool {
+ const HloInstruction* multiply = SkipUnaryOps(instr->operand(0));
+ bool bcast_operand =
+ multiply->operand(0)->opcode() != HloOpcode::kBroadcast;
+
+ // The captured scalar must be two over the number of elements in the
+ // broadcasted dimensions.
+ float actual_two_over_nelems = multiply->operand(bcast_operand)
+ ->operand(0)
+ ->literal()
+ .GetAsDouble({})
+ .value();
+ int64_t nelems = 1;
+ for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
+ if (!absl::c_linear_search(instr->dimensions(), i)) {
+ nelems *= instr->shape().dimensions()[i];
+ }
+ }
+ // The absolute value of the difference between the actual scaling factor
+ // and the reference value must not exceed a prescribed threshold.
+ float two_over_nelems = 2. / static_cast<float>(nelems);
+ float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
+ return std::abs(actual_two_over_nelems - two_over_nelems) <
+ ((actual_two_over_nelems + two_over_nelems) * numerical_epsilon);
+ };
+
+ return MultiplyAnyOrder(
+ XCenter(x_center, x, fused_expectation, custom_call, norm_metadata),
+ m::Broadcast(
+ MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
+ MultiplyAnyOrder(DNormFactor(custom_call),
+ F0(custom_call, scale, dy, x,
+ reduce, norm_metadata))))
+ .WithPredicate(broadcasts_two_over_nelems));
+}
+
+// Product of the norm factor, scale and DY, NormFactor * scale * DY. Captures
+// the scale in scale if custom_call is a forward layer norm fusing the scale.
+// Forwards custom_call to FusedNormFactor for comparison.
+auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
+ UniqueHloInstruction* dy, UniqueHloInstruction* custom_call,
+ const NormMetadataMap& norm_metadata) {
+ auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
+ const HloInstruction* instr) -> bool {
+ return scale->CaptureOrVerify(
+ FindTarget(custom_call->Instr(), instr->operand(0),
+ custom_call->Instr()->operand(1), norm_metadata)
+ ? custom_call->Instr()->mutable_operand(1)
+ : nullptr);
+ };
+ return MultiplyAnyOrder(
+ m::Broadcast(
+ BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
+ MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
+ m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+}
+
+class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit CudnnNormRewriterVisitor(
+ const se::CudaComputeCapability cuda_compute_capability)
+ : cuda_compute_capability_(cuda_compute_capability) {}
+
+ absl::Status HandleAdd(HloInstruction* instr) override {
+ TF_RETURN_IF_ERROR(MatchLayerNorm(instr));
+ TF_RETURN_IF_ERROR(MatchLayerNormGradient(instr));
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleSubtract(HloInstruction* instr) override {
+ return MatchLayerNorm(instr);
+ }
+
+ // Matches and rewrites layer norm patterns,
+ // Y = (X - expectation(X))/sqrt(variance(X) + epsilon) * scale + bias,
+ // into Custom Calls to cuDNN.
+ absl::Status MatchLayerNorm(HloInstruction* instr) {
+ UniqueHloInstruction x, expectation, variance, epsilon;
+ HloInstruction *scale, *bias, *reduce, *norm_factor, *broadcast_scale,
+ *broadcast_bias;
+ if (Match(
+ instr,
+ SubtractMultiplyAddAnyOrder(
+ OptionalSupportedTransform(
+ m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
+ Expectation(&expectation, &reduce,
+ OptionalSupportedTransform(m::Op().WithPredicate(
+ x.GetCaptureOrVerifyFn()))),
+ NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
+ m::Broadcast(&broadcast_scale, m::Op(&scale)),
+ m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
+#if CUDNN_VERSION < 8905
+ // Layer norm kernels are available with cuDNN 8.9.5 and above.
+ VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
+ return absl::OkStatus();
+#endif // CUDNN_VERSION < 8905
+
+ if (!instr->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_enable_cudnn_layer_norm()) {
+ VLOG(1) << "Layer norm Custom Calls disabled.";
+ return absl::OkStatus();
+ }
+
+ // Layer norm kernels require Ampere or Hopper architectures.
+ if (cuda_compute_capability_.major != se::CudaComputeCapability::AMPERE &&
+ cuda_compute_capability_.major != se::CudaComputeCapability::HOPPER) {
+ VLOG(1) << "Layer norm Custom Calls require Ampere or Hopper "
+ "architectures.";
+ return absl::OkStatus();
+ }
+
+ // Verify the uniqueness of the inputs.
+ if (!x.Instr() || !expectation.Instr() || !variance.Instr() ||
+ !epsilon.Instr()) {
+ VLOG(1) << "Layer norm operands not unique.";
+ return absl::OkStatus();
+ }
+
+ // Verify the input and output layouts.
+ // TODO(philipphack): Consider supporting more general cases.
+ if (!LayoutUtil::IsMonotonicWithDim0Major(x.Instr()->shape().layout()) ||
+ !LayoutUtil::IsMonotonicWithDim0Major(scale->shape().layout()) ||
+ !LayoutUtil::IsMonotonicWithDim0Major(bias->shape().layout()) ||
+ !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) {
+ VLOG(1) << "Layer norm input and/or output layouts nor supported.";
+ return absl::OkStatus();
+ }
+
+ // Verify the element types. The element types of input and output and the
+ // shapes of scale and bias must match.
+ if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
+ !CompatibleElementType(bias) ||
+ !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
+ !ShapeUtil::Equal(scale->shape(), bias->shape())) {
+ VLOG(1) << "Layer norm input types or shapes not supported.";
+ return absl::OkStatus();
+ }
+
+ // Verify that the shapes of scale and bias are compatible with the
+ // operation. The adjusted norm dimensions are the dimensions of the
+ // reduction after removing any degenerate dimensions from the input of
+ // the reduction.
+ std::vector<int64_t> norm_dims(reduce->dimensions().begin(),
+ reduce->dimensions().end());
+ std::vector<int64_t> norm_dims_adjusted = AdjustedDimensions(reduce);
+ if (norm_dims_adjusted.size() !=
+ ShapeUtil::DropDegenerateDimensions(scale->shape())
+ .dimensions_size()) {
+ VLOG(1) << "Layer norm input dimensions not supported.";
+ return absl::OkStatus();
+ }
+
+ // Verify the broadcasts of scale and bias.
+ if (!ShapeUtil::EqualIgnoringElementType(
+ ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
+ ShapeUtil::DropDegenerateDimensions(broadcast_scale->shape())) ||
+ !ShapeUtil::EqualIgnoringElementType(
+ ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
+ ShapeUtil::DropDegenerateDimensions(broadcast_bias->shape())) ||
+ norm_dims_adjusted != AdjustedDimensions(broadcast_scale) ||
+ norm_dims_adjusted != AdjustedDimensions(broadcast_bias)) {
+ VLOG(1) << "Layer norm operand broadcast not supported.";
+ return absl::OkStatus();
+ }
+
+ // If necessary, transpose the input so that the dimensions not being
+ // normalized are the leading dimensions.
+ std::vector<int64_t> non_norm_dims;
+ for (int64_t x_dim = 0; x_dim < x.Instr()->shape().rank(); ++x_dim) {
+ if (std::find(norm_dims.begin(), norm_dims.end(), x_dim) ==
+ norm_dims.end()) {
+ non_norm_dims.emplace_back(x_dim);
+ }
+ }
+ std::vector<int64_t> non_norm_dims_adjusted =
+ AdjustedDimensions(x.Instr()->shape(), non_norm_dims);
+
+ std::vector<int64_t> x_transpose_order = non_norm_dims;
+ x_transpose_order.insert(x_transpose_order.end(), norm_dims.begin(),
+ norm_dims.end());
+
+ bool apply_transpose = false;
+ for (int i = 0; i < x_transpose_order.size(); ++i) {
+ if (x_transpose_order[i] != i) {
+ apply_transpose = true;
+ break;
+ }
+ }
+
+ std::optional<HloInstruction*> x_transpose;
+ // The transpose applied to the output is the inverse of the transpose
+ // applied to the input.
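+ // For example (illustrative), x_transpose_order = {0, 2, 3, 1} yields
+ // y_transpose_order = {0, 3, 1, 2}.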
+ std::vector<int64_t> y_transpose_order(x_transpose_order.size());
+ if (apply_transpose) {
+ for (int k = 0; k < x_transpose_order.size(); ++k) {
+ y_transpose_order[x_transpose_order[k]] = k;
+ }
+ TF_ASSIGN_OR_RETURN(x_transpose,
+ MakeTransposeHlo(x.Instr(), x_transpose_order));
+ }
+
+ // Combine the dimensions not normalized into the first dimension of the
+ // input as required by cuDNN.
+ std::vector<int64_t> reshaped_dims = {1};
+ for (auto non_norm_dim : non_norm_dims) {
+ reshaped_dims[0] *= x.Instr()->shape().dimensions(non_norm_dim);
+ }
+ for (auto norm_dim : norm_dims) {
+ reshaped_dims.emplace_back(x.Instr()->shape().dimensions(norm_dim));
+ }
+ // cuDNN requires tensors to have at least four dimensions.
+ while (reshaped_dims.size() < 4) {
+ reshaped_dims.emplace_back(1);
+ }
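+ // For example (illustrative), an f32[2,4,8] input normalized over dimension
+ // 2 yields reshaped_dims = {8, 8, 1, 1} at this point.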
+
+ Shape reshaped_shape = ShapeUtil::MakeShape(
+ x.Instr()->shape().element_type(), reshaped_dims);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * x_reshape,
+ MakeReshapeHlo(reshaped_shape, x_transpose.value_or(x.Instr())));
+
+ // Reshape the scale and bias. The first dimension corresponds to the
+ // non-normalization dimension of the norm input and must have size 1.
+ std::vector<int64_t> reshaped_scale_dims = reshaped_dims;
+ reshaped_scale_dims[0] = 1;
+
+ Shape scale_bias_shape = ShapeUtil::MakeShape(
+ scale->shape().element_type(), reshaped_scale_dims);
+ TF_ASSIGN_OR_RETURN(HloInstruction * scale_reshape,
+ MakeReshapeHlo(scale_bias_shape, scale));
+ TF_ASSIGN_OR_RETURN(HloInstruction * bias_reshape,
+ MakeReshapeHlo(scale_bias_shape, bias));
+ GpuBackendConfig gpu_backend_config;
+ CudnnNormBackendConfig& backend_config =
+ *gpu_backend_config.mutable_cudnn_norm_backend_config();
+ backend_config.set_epsilon(
+ epsilon.Instr()->literal().GetAsDouble({}).value());
+ backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_INFER);
+ auto* algorithm = backend_config.mutable_algorithm();
+ algorithm->set_algo_id(0);
+ algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+ algorithm->set_is_cudnn_frontend(true);
+
+ // Set the workspace size to its upper bound.
+ // TODO(philipphack): Consider autotuning the norm kernels.
+ TF_ASSIGN_OR_RETURN(const int64_t c_constant,
+ CConstant(cuda_compute_capability_));
+ const int64_t workspace_size =
+ (2 * c_constant * (4 + 256)) + (2 * reshaped_dims[0] * 4) + 64;
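+ // For example (illustrative), on Ampere c_constant is 4096, giving a bound
+ // of 2 * 4096 * 260 + 8 * reshaped_dims[0] + 64 bytes.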
+ algorithm->mutable_workspace_size()->set_value(workspace_size);
+
+ // The output of the Custom Call is a tuple, the second element of which
+ // describes the scratch space.
+ Shape custom_call_shape = ShapeUtil::MakeTupleShape(
+ {x_reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})});
+
+ HloInstruction* custom_call =
+ instr->AddInstruction(HloInstruction::CreateCustomCall(
+ custom_call_shape, {x_reshape, scale_reshape, bias_reshape},
+ kCudnnNormCallTarget));
+ TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * gte,
+ MakeGetTupleElementHlo(custom_call, 0));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * y_reshape,
+ MakeReshapeHlo(x_transpose.value_or(instr)->shape(), gte));
+
+ std::optional<HloInstruction*> y_transpose;
+ if (apply_transpose) {
+ TF_ASSIGN_OR_RETURN(y_transpose,
+ MakeTransposeHlo(y_reshape, y_transpose_order));
+ }
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(instr, y_transpose.value_or(y_reshape)));
+
+ // Store metadata for potential use in the backward graph.
+ norm_metadata_.insert(
+ {custom_call,
+ NormMetadata({x_transpose.value_or(nullptr),
+ y_transpose.value_or(nullptr), norm_dims_adjusted,
+ non_norm_dims_adjusted})});
+
+ VLOG(1) << "Layer norm rewritten into Custom Call.";
+
+ // The layer norm training graph separately contains the norm factor
+ // divided by the sum of variance and epsilon.
+ for (HloInstruction* user : norm_factor->users()) {
+ if (user->opcode() == HloOpcode::kDivide &&
+ user->operand_index(norm_factor) == 0) {
+ TF_ASSIGN_OR_RETURN(bool changed,
+ MatchNormFactor(user, custom_call, variance,
+ expectation, epsilon));
+ if (changed) {
+ break;
+ }
+ }
+ }
+ }
+
+ return absl::OkStatus();
+ }
+
+ // The layer norm training graph separately contains the expectation as well
+ // as the norm factor and its cube, (variance + epsilon)^-1/2 and (variance +
+ // epsilon)^-3/2. When identified in the graph, these quantities are fused
+ // into the layer norm Custom Call.
+ absl::StatusOr<bool> MatchNormFactor(HloInstruction* instr,
+ HloInstruction* custom_call,
+ UniqueHloInstruction& variance,
+ UniqueHloInstruction& expectation,
+ UniqueHloInstruction& epsilon) {
+ HloInstruction* gte = custom_call->users()[0];
+ if (Match(instr,
+ m::Divide(
+ m::Op(),
+ AddAnyOrder(
+ m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
+ m::Broadcast(m::ConstantScalar().WithPredicate(
+ epsilon.GetCaptureOrVerifyFn())))))) {
+ // Verify the uniqueness of the operands.
+ if (!variance.Instr() || !epsilon.Instr()) {
+ VLOG(1) << "Layer norm operands not unique.";
+ return false;
+ }
+
+ // Verify the element types.
+ if (!CompatibleElementType(instr) ||
+ !CompatibleElementType(expectation.Instr())) {
+ VLOG(1) << "Layer norm input types not compatible.";
+ return false;
+ }
+
+ // Retrieve metadata of the forward layer norm.
+ auto norm_metadata = norm_metadata_.extract(custom_call);
+ if (!norm_metadata) {
+ VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
+ return false;
+ }
+
+ // The shape of the expectation and norm factor return values of the
+ // Custom Call is [nelems, 1, 1, 1], where nelems is the
+ // number of elements in the expectation and norm factor shapes.
+ auto make_compatible_shape = [](Shape shape) -> Shape {
+ return ShapeUtil::MakeShape(shape.element_type(),
+ {ShapeUtil::ElementsIn(shape), 1, 1, 1});
+ };
+
+ Shape expectation_shape =
+ make_compatible_shape(expectation.Instr()->shape());
+ Shape norm_factor_shape = make_compatible_shape(instr->shape());
+
+ // The augmented Custom Call additionally returns the expectation and the
+ // norm factor.
+ std::vector<Shape> tuple_shapes = custom_call->shape().tuple_shapes();
+ tuple_shapes.insert(tuple_shapes.begin() + 1,
+ {expectation_shape, norm_factor_shape});
+
+ Shape custom_call_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
+
+ HloInstruction* new_custom_call = instr->AddInstruction(
+ custom_call->CloneWithNewShape(custom_call_shape));
+
+ TF_ASSIGN_OR_RETURN(
+ GpuBackendConfig gpu_backend_config,
+ custom_call->backend_config<xla::gpu::GpuBackendConfig>());
+ CudnnNormBackendConfig& backend_config =
+ *gpu_backend_config.mutable_cudnn_norm_backend_config();
+ backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_TRAIN);
+
+ // Update the workspace size.
+ TF_ASSIGN_OR_RETURN(const int64_t c_constant,
+ CConstant(cuda_compute_capability_));
+ const int64_t workspace_size = (2 * c_constant * (4 + 256)) + 32;
+ backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
+ workspace_size);
+ TF_RETURN_IF_ERROR(
+ new_custom_call->set_backend_config(gpu_backend_config));
+
+ auto replace_with_new_cc = [new_custom_call, this](
+ HloInstruction* old_instr,
+ int tuple_index) -> absl::Status {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_gte,
+ MakeGetTupleElementHlo(new_custom_call, tuple_index));
+ HloInstruction* new_instr = new_gte;
+ if (!ShapeUtil::Equal(new_gte->shape(), old_instr->shape())) {
+ TF_ASSIGN_OR_RETURN(new_instr,
+ MakeReshapeHlo(old_instr->shape(), new_gte));
+ }
+ if (old_instr->opcode() != HloOpcode::kDivide) {
+ // Replace the result of the layer norm or the expectation.
+ TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
+ } else {
+ // Replace the norm factor, (variance + epsilon)^-1/2.
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(old_instr->mutable_operand(0), new_instr));
+ // Also replace the norm factor to the power of 3, (variance +
+ // epsilon)^-1/2 / (variance + epsilon) = ((variance +
+ // epsilon)^-1/2)^3.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_multiply0,
+ MakeBinaryHlo(HloOpcode::kMultiply, new_instr, new_instr));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_multiply1,
+ MakeBinaryHlo(HloOpcode::kMultiply, new_multiply0, new_instr));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_multiply1));
+ }
+ return absl::OkStatus();
+ };
+
+ // Replace the result of the original Custom Call as well as the
+ // expectation and the norm factor with the augmented Custom Call.
+ TF_RETURN_IF_ERROR(replace_with_new_cc(gte, 0));
+ TF_RETURN_IF_ERROR(replace_with_new_cc(expectation.Instr(), 1));
+ TF_RETURN_IF_ERROR(replace_with_new_cc(instr, 2));
+
+ // Update the Custom Call associated with the metadata of the forward
+ // norm.
+ norm_metadata.key() = new_custom_call;
+ norm_metadata_.insert(std::move(norm_metadata));
+
+ VLOG(1)
+ << "Expectation and norm factor fused into layer norm Custom Call.";
+ }
+
+ return true;
+ }
+
+ // Matches and rewrites the backward graph of layer norm patterns into Custom
+ // Calls to cuDNN when the associated forward graph has been rewritten into a
+ // cuDNN Custom Call. The gradients are
+ // DX = F1 + F2 - AddReduce(F1 + F2) / nelems,
+ // Dscale = AddReduce(DY * XCenter * NormFactor),
+ // Dbias = AddReduce(DY),
+ // with
+ // F0 = XCenter * scale * DY,
+ // F1 = XCenter * F0 * DNormFactor * 2 / nelems,
+ // F2 = NormFactor * scale * DY,
+ // XCenter = X - expectation(X),
+ // NormFactor = (variance(X) + epsilon)^-1/2 and
+ // DNormFactor = -1/2 * NormFactor^3.
+ absl::Status MatchLayerNormGradient(HloInstruction* instr) {
+ UniqueHloInstruction fwd_custom_call, x, x_center, scale, dy,
+ fused_expectation, fused_norm_factor;
+ HloInstruction *broadcast, *scalar, *dscale, *dbias, *reduce0, *reduce1,
+ *reduce2, *reduce3;
+ if (Match(instr,
+ AddAddAnyOrder(
+ m::Broadcast(
+ &broadcast,
+ MultiplyAddAnyOrder(
+ m::Broadcast(m::ConstantScalar(&scalar)),
+ NegateAddReduce(&reduce0,
+ F1(&x, &x_center, &fused_expectation,
+ &fwd_custom_call, &scale, &dy,
+ &reduce2, norm_metadata_)),
+ NegateAddReduce(
+ &reduce1, F2(&fused_norm_factor, &scale, &dy,
+ &fwd_custom_call, norm_metadata_)))),
+ F2(&fused_norm_factor, &scale, &dy, &fwd_custom_call,
+ norm_metadata_),
+ F1(&x, &x_center, &fused_expectation, &fwd_custom_call,
+ &scale, &dy, &reduce3, norm_metadata_)))) {
+ // Skip initial convert, if present.
+ if (instr->user_count() == 1 &&
+ instr->users()[0]->opcode() == HloOpcode::kConvert &&
+ CompatibleElementType(instr->users()[0])) {
+ instr = instr->users()[0];
+ }
+
+ // Verify the uniqueness of the captured Custom Call and inputs.
+ if (!fwd_custom_call.Instr() || !x.Instr() || !dy.Instr() ||
+ !x_center.Instr() || !scale.Instr() || !fused_expectation.Instr() ||
+ !fused_norm_factor.Instr()) {
+ VLOG(1) << "Layer norm gradient inputs not unique.";
+ return absl::OkStatus();
+ }
+
+ // Retrieve metadata of the forward layer norm.
+ auto norm_metadata = norm_metadata_.find(fwd_custom_call.Instr());
+ if (norm_metadata == norm_metadata_.end()) {
+ VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
+ return absl::OkStatus();
+ }
+
+ // Verify the dimensions of reductions in the backward graph.
+ if (AdjustedDimensions(reduce0) !=
+ norm_metadata->second.norm_dims_adjusted ||
+ AdjustedDimensions(reduce1) !=
+ norm_metadata->second.norm_dims_adjusted ||
+ AdjustedDimensions(reduce2) !=
+ norm_metadata->second.norm_dims_adjusted ||
+ AdjustedDimensions(reduce3) !=
+ norm_metadata->second.norm_dims_adjusted) {
+ VLOG(1) << "Unexpected reductions dimensions in layer norm gradient.";
+ return absl::OkStatus();
+ }
+
+ // The captured scalar must be one over the number of elements in the
+ // broadcasted dimensions.
+ float actual_r_nelems = scalar->literal().GetAsDouble({}).value();
+ int64_t nelems = 1;
+ for (int i = 0; i < broadcast->shape().dimensions_size(); ++i) {
+ if (!absl::c_linear_search(broadcast->dimensions(), i)) {
+ nelems *= broadcast->shape().dimensions()[i];
+ }
+ }
+ // The absolute value of the difference between the actual scaling factor
+ // and the reference value must not exceed a prescribed threshold.
+ float r_nelems = 1. / static_cast<float>(nelems);
+ float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
+ if (!(std::abs(actual_r_nelems - r_nelems) <
+ ((actual_r_nelems + r_nelems) * numerical_epsilon))) {
+ VLOG(1)
+ << "Layer norm backward broadcast operand outside expected range.";
+ return absl::OkStatus();
+ }
+
+ // Identify Dscale = AddReduce(DY * XCenter * norm factor) with factor0
+ // and factor1 intended to be XCenter and DY or DY and XCenter.
+ auto find_dscale =
+ [&fused_norm_factor, &norm_metadata](
+ const UniqueHloInstruction& factor0,
+ const UniqueHloInstruction& factor1) -> HloInstruction* {
+ for (HloInstruction* factor0_user : factor0.Instr()->users()) {
+ std::vector<HloInstruction*> users;
+ SkipUnaryOpsTopDownRecursive(factor0_user, users);
+ // One of the users of factor0 must be a chained multiplication by the
+ // fused norm factor and factor1.
+ for (HloInstruction* user : users) {
+ if (Match(user,
+ MultiplyAnyOrder(
+ m::Op(), MultiplyAnyOrder(
+ m::Broadcast(BitcastOrReshape(m::Op().Is(
+ fused_norm_factor.Instr()))),
+ m::Op().Is(factor1.Instr()))))) {
+ // Dscale is an addition-reduction of the product.
+ for (HloInstruction* multiply_user : user->users()) {
+ if (AppliesAddReduce(
+ multiply_user,
+ norm_metadata->second.non_norm_dims_adjusted)) {
+ return multiply_user;
+ }
+ }
+ }
+ }
+ }
+ return nullptr;
+ };
+ if (!(dscale = find_dscale(x_center, dy)) &&
+ !(dscale = find_dscale(dy, x_center))) {
+ VLOG(1) << "Unable to identify Dscale in graph.";
+ return absl::OkStatus();
+ }
+
+ // Find Dbias, i.e. an addition-reduction of DY, starting from DY.
+ // Rewriting proceeds without fusing Dbias if unsuccessful.
+ dbias = FindAddReduce(dy.Instr(),
+ norm_metadata->second.non_norm_dims_adjusted);
+
+ // Verify the input and output layouts.
+ // TODO(philipphack): Consider supporting more general cases.
+ if (!LayoutUtil::IsMonotonicWithDim0Major(dy.Instr()->shape().layout()) ||
+ !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()) ||
+ !LayoutUtil::IsMonotonicWithDim0Major(dscale->shape().layout()) ||
+ (dbias &&
+ !LayoutUtil::IsMonotonicWithDim0Major(dbias->shape().layout()))) {
+ VLOG(1) << "Layer norm input and/or output layouts nor supported.";
+ return absl::OkStatus();
+ }
+
+ // The types of X and DX must match.
+ if (x.Instr()->shape().element_type() != instr->shape().element_type()) {
+ VLOG(1) << "The types of X and DX must match.";
+ return absl::OkStatus();
+ }
+
+ // The types and shapes of scale, Dscale and Dbias (if present) must
+ // match.
+ if (!ShapeUtil::Equal(
+ ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
+ ShapeUtil::DropDegenerateDimensions(dscale->shape())) ||
+ (dbias &&
+ !ShapeUtil::Equal(
+ ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
+ ShapeUtil::DropDegenerateDimensions(dbias->shape())))) {
+ VLOG(1) << "Backward layer norm types not supported.";
+ return absl::OkStatus();
+ }
+
+ // Verify the element types.
+ if (!CompatibleElementType(dy.Instr())) {
+ VLOG(1) << "Backward layer norm types not supported.";
+ return absl::OkStatus();
+ }
+
+ // cuDNN requires the byte size of the element type of X to be at least
+ // that of DY and scale.
+ if (ShapeUtil::ByteSizeOfPrimitiveType(
+ x.Instr()->shape().element_type()) <
+ ShapeUtil::ByteSizeOfPrimitiveType(
+ dy.Instr()->shape().element_type()) ||
+ ShapeUtil::ByteSizeOfPrimitiveType(
+ x.Instr()->shape().element_type()) <
+ ShapeUtil::ByteSizeOfPrimitiveType(
+ scale.Instr()->shape().element_type())) {
+ VLOG(1) << "Backward layer norm types not supported.";
+ return absl::OkStatus();
+ }
+
+ // Transpose DY applying the stored transpose order of X from the forward
+ // graph.
+ HloInstruction* transposed_dy = dy.Instr();
+ if (norm_metadata->second.x_transpose) {
+ TF_ASSIGN_OR_RETURN(
+ transposed_dy,
+ MakeTransposeHlo(dy.Instr(),
+ norm_metadata->second.x_transpose->dimensions()));
+ }
+ TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_dy,
+ MakeReshapeHlo(x.Instr()->shape(), transposed_dy));
+
+ Shape dx_shape = ShapeUtil::MakeShape(instr->shape().element_type(),
+ x.Instr()->shape().dimensions());
+
+ Shape dscale_dbias_shape = ShapeUtil::MakeShape(
+ dscale->shape().element_type(), scale.Instr()->shape().dimensions());
+
+ GpuBackendConfig gpu_backend_config;
+ CudnnNormBackendConfig& backend_config =
+ *gpu_backend_config.mutable_cudnn_norm_backend_config();
+ backend_config.set_kind(CudnnNormBackendConfig::LAYER_BWD);
+ auto* algorithm = backend_config.mutable_algorithm();
+ algorithm->set_algo_id(0);
+ algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+ algorithm->set_is_cudnn_frontend(true);
+
+ // Set the workspace size to its upper bound.
+ // TODO(philipphack): Consider autotuning the norm kernels.
+ TF_ASSIGN_OR_RETURN(const int64_t c_constant,
+ CConstant(cuda_compute_capability_));
+ const int64_t workspace_size =
+ (2 * c_constant * (4 + 256)) +
+ (2 * x.Instr()->shape().dimensions(0) * 4) + 64;
+ algorithm->mutable_workspace_size()->set_value(workspace_size);
+
+ // The output of the Custom Call is a tuple. The output shape of Dscale
+ // and Dbias is that of scale.
+ Shape custom_call_shape = ShapeUtil::MakeTupleShape(
+ {dx_shape, dscale_dbias_shape, dscale_dbias_shape,
+ ShapeUtil::MakeShape(U8, {workspace_size})});
+
+ HloInstruction* custom_call =
+ instr->AddInstruction(HloInstruction::CreateCustomCall(
+ custom_call_shape,
+ {x.Instr(), scale.Instr(), reshaped_dy, fused_expectation.Instr(),
+ fused_norm_factor.Instr()},
+ kCudnnNormCallTarget));
+ TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
+
+ auto replace_with_cc = [custom_call, norm_metadata, transposed_dy, this](
+ HloInstruction* old_instr,
+ int tuple_index) -> absl::Status {
+ TF_ASSIGN_OR_RETURN(HloInstruction * gte,
+ MakeGetTupleElementHlo(custom_call, tuple_index));
+ HloInstruction* new_instr;
+ // Transpose DX applying the stored transpose order of Y from the
+ // forward graph.
+ if (tuple_index == 0 && norm_metadata->second.y_transpose) {
+ TF_ASSIGN_OR_RETURN(new_instr,
+ MakeReshapeHlo(transposed_dy->shape(), gte));
+ TF_ASSIGN_OR_RETURN(
+ new_instr,
+ MakeTransposeHlo(
+ new_instr, norm_metadata->second.y_transpose->dimensions()));
+ } else {
+ TF_ASSIGN_OR_RETURN(new_instr,
+ MakeReshapeHlo(old_instr->shape(), gte));
+ }
+ TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
+ return absl::OkStatus();
+ };
+
+ TF_RETURN_IF_ERROR(replace_with_cc(instr, 0));
+ TF_RETURN_IF_ERROR(replace_with_cc(dscale, 1));
+ if (dbias) {
+ TF_RETURN_IF_ERROR(replace_with_cc(dbias, 2));
+ }
+ VLOG(1) << "Gradients w.r.t. x"
+ << (dbias ? ", scale and bias" : " and scale")
+ << " rewritten into layer norm backward Custom Call.";
+ }
+
+ return absl::OkStatus();
+ }
+
+ private:
+ se::CudaComputeCapability cuda_compute_capability_;
+ NormMetadataMap norm_metadata_;
+};
+
+absl::StatusOr<bool> RunOnComputation(
+ HloComputation* computation,
+ se::CudaComputeCapability cuda_compute_capability) {
+ CudnnNormRewriterVisitor visitor(cuda_compute_capability);
+ TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+ return visitor.changed();
+}
+
+} // anonymous namespace
+
+CudnnNormRewriter::CudnnNormRewriter(
+ se::CudaComputeCapability cuda_compute_capability)
+ : cuda_compute_capability_(cuda_compute_capability) {}
+
+absl::StatusOr<bool> CudnnNormRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(
+ bool result, RunOnComputation(computation, cuda_compute_capability_));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h
new file mode 100644
index 0000000..a2332d3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h
@@ -0,0 +1,48 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites norm patterns into Custom Calls to the cuDNN library. Currently, the
+// forward and backward passes of layer norm patterns are implemented.
+class CudnnNormRewriter : public HloModulePass {
+ public:
+ explicit CudnnNormRewriter(se::CudaComputeCapability cuda_compute_capability);
+ absl::string_view name() const override { return "norm-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::CudaComputeCapability cuda_compute_capability_;
+};
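+
+// A minimal usage sketch (illustrative only): `pipeline` and `cc` are assumed
+// to be an existing HloPassPipeline and the target device's
+// CudaComputeCapability, respectively.
+//
+//   pipeline.AddPass<CudnnNormRewriter>(cc);
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));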
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
new file mode 100644
index 0000000..d130c08
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
@@ -0,0 +1,1610 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include <gtest/gtest.h>
+#include "xla/error_spec.h"
+#include "xla/stream_executor/device_description.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#endif
+
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CudnnNormRewriterTest : public GpuCodegenTest {
+ public:
+ se::CudaComputeCapability GetCudaComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability();
+ }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_cudnn_layer_norm(true);
+ return debug_options;
+ }
+
+ protected:
+ void SetUp() override {
+#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
+ GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#endif
+ if (!(GetCudaComputeCapability().major ==
+ se::CudaComputeCapability::AMPERE) &&
+ !(GetCudaComputeCapability().major ==
+ se::CudaComputeCapability::HOPPER)) {
+ GTEST_SKIP()
+ << "Layer norm kernels require Ampere or Hopper architectures.";
+ }
+ }
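+
+  // Verifies that executing the module yields results within an error bound of
+  // 1e-3 and that the optimized HLO matches the expected FileCheck pattern.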
+ void TestNorm(std::string hlo_text, std::string optimized_hlo) {
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, optimized_hlo);
+ }
+};
+
+// The following tests evaluate LayerNormXDY configurations, with X the rank of
+// the input and Y the dimensions that are normalized.
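+// For example, LayerNorm4D3 normalizes a rank-4 input over dimension 3, and
+// LayerNorm4D12 normalizes a rank-4 input over dimensions 1 and 2.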
+TEST_F(CudnnNormRewriterTest, LayerNorm2D1) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4] parameter(0)
+ input_square = f32[2,4] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2] multiply(input_square_sum, r_nelems_bcast)
+    input_sum = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
+ input_mean = f32[2] multiply(input_sum, r_nelems_bcast)
+ input_mean_square = f32[2] multiply(input_mean, input_mean)
+ variance = f32[2] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
+ norm_factor = f32[2] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
+ input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
+ input_center = f32[2,4] subtract(input, input_mean_bcast)
+ norm = f32[2,4] multiply(norm_factor_bcast, input_center)
+ scale = f32[4] parameter(1)
+ scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
+ norm_scale = f32[2,4] multiply(norm, scale_bcast)
+ bias = f32[4] parameter(2)
+ bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
+ ROOT out = f32[2,4] add(norm_scale, bias_broadcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+ r_nelems = f32[] constant(0.125)
+ r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
+ input_sum = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+ input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
+ input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
+ variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+ norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[8] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[8] parameter(2)
+ bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+ ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[1,4,6,8] parameter(0)
+ input_square = f32[1,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[1,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+ r_nelems = f32[] constant(0.125)
+ r_nelems_bcast = f32[1,4,6] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[1,4,6] multiply(input_square_sum, r_nelems_bcast)
+ input_sum = f32[1,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+ input_mean = f32[1,4,6] multiply(input_sum, r_nelems_bcast)
+ input_mean_square = f32[1,4,6] multiply(input_mean, input_mean)
+ variance = f32[1,4,6] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[1,4,6] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[1,4,6] add(variance, epsilon_bcast)
+ norm_factor = f32[1,4,6] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[1,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+ input_mean_bcast = f32[1,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+ input_center = f32[1,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[1,4,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[8] parameter(1)
+ scale_bcast = f32[1,4,6,8] broadcast(scale), dimensions={3}
+ norm_scale = f32[1,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[8] parameter(2)
+ bias_bcast = f32[1,4,6,8] broadcast(bias), dimensions={3}
+ ROOT out = f32[1,4,6,8] add(norm_scale, bias_bcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[1,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[1,4,6,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[24,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: ROOT [[GTE_BITCAST:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} bitcast([[GTE]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D2) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
+ r_nelems = f32[] constant(0.166667)
+ r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,4,8] multiply(input_square_sum, r_nelems_bcast)
+ reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
+ input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,4,8] multiply(input_mean, input_mean)
+ variance = f32[2,4,8] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[6] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[6] parameter(2)
+ bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
+ ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,4,6,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,1,6,8] parameter(0)
+ input_square = f32[2,1,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,1,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
+ r_nelems = f32[] constant(0.166667)
+ r_nelems_bcast = f32[2,1,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,1,8] multiply(input_square_sum, r_nelems_bcast)
+ reduce = f32[2,1,8] reduce(input, c0), dimensions={2}, to_apply=apply
+ input_mean = f32[2,1,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,1,8] multiply(input_mean, input_mean)
+ variance = f32[2,1,8] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,1,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,1,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,1,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,1,6,8] broadcast(norm_factor), dimensions={0,1,3}
+ input_mean_bcast = f32[2,1,6,8] broadcast(input_mean), dimensions={0,1,3}
+ input_center = f32[2,1,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,1,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[6] parameter(1)
+ scale_bcast = f32[2,1,6,8] broadcast(scale), dimensions={2}
+ norm_scale = f32[2,1,6,8] multiply(norm, scale_bcast)
+ bias = f32[6] parameter(2)
+ bias_broadcast = f32[2,1,6,8] broadcast(bias), dimensions={2}
+ ROOT out = f32[2,1,6,8] add(norm_scale, bias_broadcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,1,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,1,6,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,6]{3,2,1,0} transpose([[P0]]), dimensions={1,0,3,2}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D12) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+ r_nelems = f32[] constant(0.041667)
+ r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+ reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+ input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+ variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[4,6] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[4,6] parameter(2)
+ bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
+ ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> f32[2,4,6,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,1,8] parameter(0)
+ input_square = f32[2,4,1,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+ reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+ input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+ variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
+ input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
+ input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[4,1] parameter(1)
+ scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
+ norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
+ bias = f32[4,1] parameter(2)
+ bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
+ ROOT out = f32[2,4,1,8] add(norm_scale, bias_broadcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> f32[2,4,1,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,2,2,2] parameter(0)
+ input_square = f32[2,2,2,2] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,2,2] reduce(input_square, c0), dimensions={3}, to_apply=apply
+ r_nelems = f32[] constant(0.5)
+ r_nelems_bcast = f32[2,2,2] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,2,2] multiply(input_square_sum, r_nelems_bcast)
+ input_sum = f32[2,2,2] reduce(input, c0), dimensions={3}, to_apply=apply
+ input_mean = f32[2,2,2] multiply(input_sum, r_nelems_bcast)
+ input_mean_square = f32[2,2,2] multiply(input_mean, input_mean)
+ variance = f32[2,2,2] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,2,2] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,2,2] add(variance, epsilon_bcast)
+ norm_factor = f32[2,2,2] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,2,2,2] broadcast(norm_factor), dimensions={0,1,2}
+ input_mean_bcast = f32[2,2,2,2] broadcast(input_mean), dimensions={0,1,2}
+ input_center = f32[2,2,2,2] subtract(input, input_mean_bcast)
+ norm = f32[2,2,2,2] multiply(norm_factor_bcast, input_center)
+ scale = f32[2] parameter(1)
+ scale_bcast = f32[2,2,2,2] broadcast(scale), dimensions={2}
+ norm_scale = f32[2,2,2,2] multiply(norm, scale_bcast)
+ bias = f32[2] parameter(2)
+ bias_bcast = f32[2,2,2,2] broadcast(bias), dimensions={3}
+ ROOT out = f32[2,2,2,2] add(norm_scale, bias_bcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,2,2,2], {{.*}}: f32[2], {{.*}}: f32[2]) -> f32[2,2,2,2] {
+; CHECK-NOT: custom_call_target="__cudnn$norm"
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f16[2,4,6,8] parameter(0)
+ input_f32 = f32[2,4,6,8] convert(input)
+ input_square = f32[2,4,6,8] multiply(input_f32, input_f32)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+ r_nelems = f32[] constant(0.125)
+ r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
+ input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply
+ input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
+ input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
+ variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+ norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+ input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[8] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[8] parameter(2)
+ bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+ ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
+; CHECK-NOT: custom_call_target="__cudnn$norm"
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4] parameter(0)
+ input_square = f32[2,4] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast)
+ reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
+ input_mean = f32[2] multiply(reduce,r_nelems_bcast)
+ input_mean_square = f32[2] multiply(input_mean,input_mean)
+ variance = f32[2] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
+ norm_factor = f32[2] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
+ input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
+ input_center = f32[2,4] subtract(input,input_mean_bcast)
+ norm = f32[2,4] multiply(norm_factor_bcast,input_center)
+ scale = f32[4] parameter(1)
+ scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
+ norm_scale = f32[2,4] multiply(norm,scale_bcast)
+ bias = f32[4] parameter(2)
+ bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
+ norm_scale_bias = f32[2,4] add(norm_scale, bias_broadcast)
+ norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
+ ROOT out = (f32[2,4], f32[2], f32[2], f32[2]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> (f32[2,4], f32[2], f32[2], f32[2]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
+; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE1]])
+; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE2]])
+; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2]{0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+ r_nelems = f32[] constant(0.125)
+ r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
+ reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+ input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
+ variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+ norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[8] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+ norm_scale = f32[2,4,6,8] multiply(norm,scale_bcast)
+ bias = f32[8] parameter(2)
+ bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+ norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+ norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
+ ROOT out = (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
+; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE1]])
+; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE2]])
+; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,4,6]{2,1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+ r_nelems = f32[] constant(0.041667)
+ r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+ reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+ input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+ variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[4,6] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[4,6] parameter(2)
+ bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
+ norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+ norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+ ROOT out = (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
+; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
+; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,1,8] parameter(0)
+ input_square = f32[2,4,1,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+ reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+ input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+ variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
+ input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
+ input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
+ scale = f32[4,1] parameter(1)
+ scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
+ norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
+ bias = f32[4,1] parameter(2)
+ bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
+ norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_broadcast)
+ norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+ ROOT out = (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK: }
+; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
+; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
+; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4] parameter(0)
+ input_square = f32[2,4] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
+ reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast)
+ input_mean = f32[2] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2] multiply(input_mean,input_mean)
+ variance = f32[2] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
+ norm_factor = f32[2] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
+ input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
+ input_center = f32[2,4] subtract(input, input_mean_bcast)
+ norm = f32[2,4] multiply(input_center, norm_factor_bcast)
+ scale = f32[4] parameter(1)
+ scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
+ norm_scale = f32[2,4] multiply(norm, scale_bcast)
+ bias = f32[4] parameter(2)
+ bias_bcast = f32[2,4] broadcast(bias), dimensions={1}
+ norm_scale_bias = f32[2,4] add(norm_scale, bias_bcast)
+ doutput = f32[2,4] parameter(3)
+ dbias = f32[4] reduce(doutput, c0), dimensions={0}, to_apply=apply
+ norm_doutput = f32[2,4] multiply(norm, doutput)
+ dscale = f32[4] reduce(norm_doutput, c0), dimensions={0}, to_apply=apply
+ scale_doutput = f32[2,4] multiply(scale_bcast, doutput)
+ input_center_scale_doutput = f32[2,4] multiply(input_center, scale_doutput)
+ f0 = f32[2] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
+ norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
+ c1 = f32[] constant(-0.5)
+ c1_bcast = f32[2] broadcast(c1), dimensions={}
+ dnorm_factor = f32[2] multiply(norm_factor_cube, c1_bcast)
+ f0_dnorm_factor = f32[2] multiply(f0, dnorm_factor)
+ c2 = f32[] constant(0.5)
+ c2_bcast = f32[2] broadcast(c2), dimensions={}
+ f0_dnorm_factor_scaled = f32[2] multiply(f0_dnorm_factor, c2_bcast)
+ f0_dnorm_factor_scaled_bcast = f32[2,4] broadcast(f0_dnorm_factor_scaled), dimensions={0}
+ f1 = f32[2,4] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+ minus_f1 = f32[2,4] negate(f1)
+ minus_f1_sum = f32[2] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
+ f2 = f32[2,4] multiply(norm_factor_bcast, scale_doutput)
+ minus_f2 = f32[2,4] negate(f2)
+ minus_f2_sum = f32[2] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
+ minus_f1_f2_sum = f32[2] add(minus_f1_sum, minus_f2_sum)
+ minus_f1_f2_sum_scaled = f32[2] multiply(minus_f1_f2_sum, r_nelems_bcast)
+ minus_f1_f2_sum_scaled_bcast = f32[2,4] broadcast(minus_f1_f2_sum_scaled), dimensions={0}
+ f1_f2 = f32[2,4] add(f1, f2)
+ dinput = f32[2,4] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+ ROOT out = (f32[2,4], f32[2,4], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> (f32[2,4], f32[2,4], f32[4], f32[4]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
+; CHECK: }
+; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P3]])
+; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0
+; CHECK-DAG: "kind":"LAYER_BWD"
+; CHECK: }
+; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE3]])
+; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
+; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
+; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+ reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+ r_nelems = f32[] constant(0.125)
+ r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,4,6] multiply(input_square_sum,r_nelems_bcast)
+ input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,4,6] multiply(input_mean,input_mean)
+ variance = f32[2,4,6] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+ norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+ scale = f32[8] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[8] parameter(2)
+ bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+ norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+ doutput = f32[2,4,6,8] parameter(3)
+ dbias = f32[8] reduce(doutput, c0), dimensions={0,1,2}, to_apply=apply
+ norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
+ dscale = f32[8] reduce(norm_doutput, c0), dimensions={0,1,2}, to_apply=apply
+ scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
+ input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+ f0 = f32[2,4,6] reduce(input_center_scale_doutput, c0), dimensions={3}, to_apply=apply
+ norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
+ c1 = f32[] constant(-0.5)
+ c1_bcast = f32[2,4,6] broadcast(c1), dimensions={}
+ dnorm_factor = f32[2,4,6] multiply(norm_factor_cube, c1_bcast)
+ f0_dnorm_factor = f32[2,4,6] multiply(f0, dnorm_factor)
+ c2 = f32[] constant(0.25)
+ c2_bcast = f32[2,4,6] broadcast(c2), dimensions={}
+ f0_dnorm_factor_scaled = f32[2,4,6] multiply(f0_dnorm_factor, c2_bcast)
+ f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,2}
+ f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+ minus_f1 = f32[2,4,6,8] negate(f1)
+ minus_f1_sum = f32[2,4,6] reduce(minus_f1, c0), dimensions={3}, to_apply=apply
+ f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+ minus_f2 = f32[2,4,6,8] negate(f2)
+ minus_f2_sum = f32[2,4,6] reduce(minus_f2, c0), dimensions={3}, to_apply=apply
+ minus_f1_f2_sum = f32[2,4,6] add(minus_f1_sum, minus_f2_sum)
+ minus_f1_f2_sum_scaled = f32[2,4,6] multiply(minus_f1_f2_sum, r_nelems_bcast)
+ minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,2}
+ f1_f2 = f32[2,4,6,8] add(f1, f2)
+ dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+ ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) tuple(norm_scale_bias, dinput, dscale, dbias)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
+; CHECK: }
+; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
+; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P3]])
+; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0
+; CHECK-DAG: "kind":"LAYER_BWD"
+; CHECK: }
+; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE3]])
+; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE4]])
+; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE5]])
+; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[8]{0}, f32[8]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
+ reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
+ r_nelems = f32[] constant(0.166667)
+ r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,4,8] multiply(input_square_sum,r_nelems_bcast)
+ input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,4,8] multiply(input_mean,input_mean)
+ variance = f32[2,4,8] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+ scale = f32[6] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[6] parameter(2)
+ bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
+ norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+ doutput = f32[2,4,6,8] parameter(3)
+ dbias = f32[6] reduce(doutput, c0), dimensions={0,1,3}, to_apply=apply
+ norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
+ dscale = f32[6] reduce(norm_doutput, c0), dimensions={0,1,3}, to_apply=apply
+ scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
+ input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+ f0 = f32[2,4,8] reduce(input_center_scale_doutput, c0), dimensions={2}, to_apply=apply
+ norm_factor_cube = f32[2,4,8] divide(norm_factor, variance_plus_epsilon)
+ c1 = f32[] constant(-0.5)
+ c1_bcast = f32[2,4,8] broadcast(c1), dimensions={}
+ dnorm_factor = f32[2,4,8] multiply(norm_factor_cube, c1_bcast)
+ f0_dnorm_factor = f32[2,4,8] multiply(f0, dnorm_factor)
+ c2 = f32[] constant(0.333333)
+ c2_bcast = f32[2,4,8] broadcast(c2), dimensions={}
+ f0_dnorm_factor_scaled = f32[2,4,8] multiply(f0_dnorm_factor, c2_bcast)
+ f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,3}
+ f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+ minus_f1 = f32[2,4,6,8] negate(f1)
+ minus_f1_sum = f32[2,4,8] reduce(minus_f1, c0), dimensions={2}, to_apply=apply
+ f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+ minus_f2 = f32[2,4,6,8] negate(f2)
+ minus_f2_sum = f32[2,4,8] reduce(minus_f2, c0), dimensions={2}, to_apply=apply
+ minus_f1_f2_sum = f32[2,4,8] add(minus_f1_sum, minus_f2_sum)
+ minus_f1_f2_sum_scaled = f32[2,4,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+ minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,3}
+ f1_f2 = f32[2,4,6,8] add(f1, f2)
+ dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+ ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) tuple(norm_scale_bias, dinput, dscale, dbias)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
+; CHECK: }
+; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
+; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2}
+; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
+; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0
+; CHECK-DAG: "kind":"LAYER_BWD"
+; CHECK: }
+; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
+; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
+; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]])
+; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]])
+; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
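+// Layer norm training forward and backward pass normalizing over dimensions 1
+// and 2 of a 4D input.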
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+ reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+ r_nelems = f32[] constant(0.041667)
+ r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast)
+ input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,8] multiply(input_mean,input_mean)
+ variance = f32[2,8] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+ scale = f32[4,6] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[4,6] parameter(2)
+ bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
+ norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+ doutput = f32[2,4,6,8] parameter(3)
+ dbias = f32[4,6] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
+ norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
+ dscale = f32[4,6] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
+ scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
+ input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+ f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
+ norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+ c1 = f32[] constant(-0.5)
+ c1_bcast = f32[2,8] broadcast(c1), dimensions={}
+ dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
+ f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
+ c2 = f32[] constant(0.083333)
+ c2_bcast = f32[2,8] broadcast(c2), dimensions={}
+ f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
+ f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
+ f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+ minus_f1 = f32[2,4,6,8] negate(f1)
+ minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
+ f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+ minus_f2 = f32[2,4,6,8] negate(f2)
+ minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
+ minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
+ minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+ minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
+ f1_f2 = f32[2,4,6,8] add(f1, f2)
+ dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+ ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) tuple(norm_scale_bias, dinput, dscale, dbias)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
+; CHECK: }
+; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
+; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2}
+; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
+; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0
+; CHECK-DAG: "kind":"LAYER_BWD"
+; CHECK: }
+; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
+; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
+; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]])
+; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]])
+; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
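+// Layer norm training forward and backward pass normalizing over dimensions 1
+// and 2, where dimension 2 is degenerate (size 1).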
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,1,8] parameter(0)
+ input_square = f32[2,4,1,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+ reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast)
+ input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,8] multiply(input_mean,input_mean)
+ variance = f32[2,8] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
+ input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
+ input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,1,8] multiply(input_center, norm_factor_bcast)
+ scale = f32[4,1] parameter(1)
+ scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
+ norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
+ bias = f32[4,1] parameter(2)
+ bias_bcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
+ norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_bcast)
+ doutput = f32[2,4,1,8] parameter(3)
+ dbias = f32[4,1] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
+ norm_doutput = f32[2,4,1,8] multiply(norm, doutput)
+ dscale = f32[4,1] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
+ scale_doutput = f32[2,4,1,8] multiply(scale_bcast, doutput)
+ input_center_scale_doutput = f32[2,4,1,8] multiply(input_center, scale_doutput)
+ f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
+ norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+ c1 = f32[] constant(-0.5)
+ c1_bcast = f32[2,8] broadcast(c1), dimensions={}
+ dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
+ f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
+ c2 = f32[] constant(0.5)
+ c2_bcast = f32[2,8] broadcast(c2), dimensions={}
+ f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
+ f0_dnorm_factor_scaled_bcast = f32[2,4,1,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
+ f1 = f32[2,4,1,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+ minus_f1 = f32[2,4,1,8] negate(f1)
+ minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
+ f2 = f32[2,4,1,8] multiply(norm_factor_bcast, scale_doutput)
+ minus_f2 = f32[2,4,1,8] negate(f2)
+ minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
+ minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
+ minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+ minus_f1_f2_sum_scaled_bcast = f32[2,4,1,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
+ f1_f2 = f32[2,4,1,8] add(f1, f2)
+ dinput = f32[2,4,1,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+ ROOT out = (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) tuple(norm_scale_bias, dinput, dscale, dbias)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1], {{.*}}: f32[2,4,1,8]) -> (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
+; CHECK: }
+; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(3)
+; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P3]]), dimensions={2,0,3,1}
+; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
+; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0
+; CHECK-DAG: "kind":"LAYER_BWD"
+; CHECK: }
+; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG: [[FUSION0:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=0
+; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=1
+; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE4]])
+; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE5]])
+; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+// TODO(b/343124533) Reenable when fixed
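+// The gradient of the output is supplied with its last two dimensions
+// collapsed into one (f32[2,4,48]) and reshaped back inside the graph.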
+TEST_F(CudnnNormRewriterTest,
+ DISABLED_LayerNormTrainBackward4D1DoutputReshapeSplit) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
+ reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
+ input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
+ variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+ scale = f32[4] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[4] parameter(2)
+ bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
+ norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+ doutput = f32[2,4,48] parameter(3)
+ dbias = f32[4] reduce(doutput, c0), dimensions={0,2}, to_apply=apply
+ doutput_bitcast = f32[2,4,6,8] reshape(doutput)
+ norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
+ dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
+ scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
+ input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+ f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
+ norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
+ c1 = f32[] constant(-0.5)
+ c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
+ dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
+ f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
+ c2 = f32[] constant(0.5)
+ c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
+ f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
+ f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
+ f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+ minus_f1 = f32[2,4,6,8] negate(f1)
+ minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
+ f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+ minus_f2 = f32[2,4,6,8] negate(f2)
+ minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
+ minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
+ minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+ minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
+ f1_f2 = f32[2,4,6,8] add(f1, f2)
+ dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+ ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,48]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
+; CHECK: }
+; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,48]{2,1,0} parameter(3)
+; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
+; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0
+; CHECK-DAG: "kind":"LAYER_BWD"
+; CHECK: }
+; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
+; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
+; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
+; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
+; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+// TODO(b/343124533) Reenable when fixed
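+// The gradient of the output is supplied with its last dimension split into
+// three (f32[2,4,6,2,2,2]) and reshaped back inside the graph.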
+TEST_F(CudnnNormRewriterTest,
+ DISABLED_LayerNormTrainBackward4D1DoutputReshapeCombine) {
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a,b)
+ }
+
+ ENTRY test {
+ input = f32[2,4,6,8] parameter(0)
+ input_square = f32[2,4,6,8] multiply(input, input)
+ c0 = f32[] constant(0)
+ input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
+ reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
+ r_nelems = f32[] constant(0.25)
+ r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
+ input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
+ input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
+ input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
+ variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
+ epsilon = f32[] constant(0.001)
+ epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
+ variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
+ norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
+ norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
+ input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
+ input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+ norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+ scale = f32[4] parameter(1)
+ scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
+ norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+ bias = f32[4] parameter(2)
+ bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
+ norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+ doutput = f32[2,4,6,2,2,2] parameter(3)
+ dbias = f32[4] reduce(doutput, c0), dimensions={0,2,3,4,5}, to_apply=apply
+ doutput_bitcast = f32[2,4,6,8] reshape(doutput)
+ norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
+ dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
+ scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
+ input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+ f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
+ norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
+ c1 = f32[] constant(-0.5)
+ c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
+ dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
+ f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
+ c2 = f32[] constant(0.5)
+ c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
+ f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
+ f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
+ f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+ minus_f1 = f32[2,4,6,8] negate(f1)
+ minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
+ f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+ minus_f2 = f32[2,4,6,8] negate(f2)
+ minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
+ minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
+ minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+ minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
+ f1_f2 = f32[2,4,6,8] add(f1, f2)
+ dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+ ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
+ })";
+
+ const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,6,2,2,2]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0.001
+; CHECK-DAG: "kind":"LAYER_FWD_TRAIN"
+; CHECK: }
+; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,2,2,2]{5,4,3,2,1,0} parameter(3)
+; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
+; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK: custom_call_target="__cudnn$norm",
+; CHECK: backend_config={
+; CHECK-DAG: "epsilon":0
+; CHECK-DAG: "kind":"LAYER_BWD"
+; CHECK: }
+; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
+; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
+; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
+; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
+; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+ )";
+
+ TestNorm(hlo_text, optimized_hlo);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc
new file mode 100644
index 0000000..2acdf9a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc
@@ -0,0 +1,528 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/bind_front.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/cudnn_support_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+// Creates and returns an HLO that zero-pads one or more dimensions in the given
+// instruction so that its shape is equal to the given shape.
+//
+// Padding is added to the end of each relevant dimension.
+//
+// If the instruction already has the given shape, simply returns it without an
+// intervening pad.
+static HloInstruction* PadInstruction(HloInstruction* instr,
+ const Shape& new_shape) {
+ HloComputation* comp = instr->parent();
+
+ const Shape& shape = instr->shape();
+ PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank());
+
+ bool added_padding = false;
+ for (int64_t dim = 0; dim < shape.rank(); ++dim) {
+ if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
+ continue;
+ }
+ CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim));
+ pad_config.mutable_dimensions(dim)->set_edge_padding_high(
+ new_shape.dimensions(dim) - shape.dimensions(dim));
+ added_padding = true;
+ }
+ if (!added_padding) {
+ return instr;
+ }
+
+ auto* zero = comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
+ return comp->AddInstruction(
+ HloInstruction::CreatePad(new_shape, instr, zero, pad_config),
+ &instr->metadata());
+}
+
+// Modifies the given convolution to have the given input and result shapes.
+static absl::Status PadConv(HloCustomCallInstruction* conv,
+ absl::Span<const Shape> new_input_shapes,
+ const Shape& new_result_shape) {
+ CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
+ << "conv must use 0 scratch bytes, i.e. this pass must be run "
+ "before CudnnConvAlgorithmPicker.";
+ std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(conv->operand_count());
+ for (int i = 0; i < conv->operand_count(); ++i) {
+ new_operands.push_back(
+ PadInstruction(conv->mutable_operand(i), new_input_shapes[i]));
+ }
+ const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+ bool changed = false;
+ for (int i = 0; i < conv->operand_count(); ++i) {
+ changed |= (new_operands[i] != conv->mutable_operand(i));
+ }
+ CHECK(changed) << "We should have had to pad at least one input operand.";
+
+ auto add = [&](std::unique_ptr<HloInstruction> new_instr) {
+ return conv->parent()->AddInstruction(std::move(new_instr));
+ };
+
+ Shape new_conv_shape = ShapeUtil::MakeTupleShape(
+ {new_result_shape, ShapeUtil::MakeShape(U8, {0})});
+ auto* new_conv =
+ add(conv->CloneWithNewOperands(new_conv_shape, new_operands));
+
+ // Clone conv's name to new_conv. This is safe because we're going to remove
+ // conv below.
+ new_conv->SetAndSanitizeName(conv->name());
+
+ VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with "
+ << new_conv->ToString();
+
+ // Slice the new conv result if necessary, keeping in mind that new_conv
+ // has tuple shape (new_result_shape, u8[0]).
+ if (!ShapeUtil::Equal(result_shape, new_result_shape)) {
+ std::vector<int64_t> start_indices(result_shape.dimensions_size(), 0);
+ std::vector<int64_t> end_indices(result_shape.dimensions().begin(),
+ result_shape.dimensions().end());
+ std::vector<int64_t> strides(result_shape.dimensions_size(), 1);
+
+ auto* new_conv_result = add(
+ HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0));
+ auto* empty_temp_buffer =
+ add(HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({})));
+ auto* sliced_result = add(HloInstruction::CreateSlice(
+ result_shape, new_conv_result, start_indices, end_indices, strides));
+ new_conv =
+ add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer}));
+ }
+
+ return conv->parent()->ReplaceInstruction(conv, new_conv);
+}
+
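+// Returns all custom calls to cuDNN convolutions in the given computation.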
+static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
+ HloComputation* comp) {
+ std::vector<HloCustomCallInstruction*> convs;
+ for (HloInstruction* instr : comp->instructions()) {
+ if (IsCustomCallToDnnConvolution(*instr)) {
+ convs.push_back(Cast<HloCustomCallInstruction>(instr));
+ }
+ }
+ return convs;
+}
+
+// This is the main function of the transform. It runs on a given custom-call
+// node to a cuDNN convolution, calls resolve_pad_shapes to resolve the
+// desired input/output feature map shapes, and adds the necessary padding and
+// slicing nodes around it.
+//
+// resolve_pad_shapes takes conv, a custom-call instruction to a cuDNN
+// convolution that may need padding, figures out the desired padded input and
+// output tensor shapes, and stores them in new_input_shapes and
+// new_result_shape. Note that new_input_shapes is a vector because the
+// convolution has multiple input tensors. In addition to the status, the
+// function returns true if padding is necessary and false otherwise.
+static absl::StatusOr<bool> ResolveAndPad(
+ HloCustomCallInstruction* conv,
+ std::function<absl::StatusOr<bool>(HloCustomCallInstruction* conv,
+ std::vector<Shape>* new_input_shapes,
+ Shape* new_result_shape)>
+ resolve_pad_shapes) {
+ std::vector<Shape> new_input_shapes;
+ Shape new_result_shape;
+ TF_ASSIGN_OR_RETURN(bool result, resolve_pad_shapes(conv, &new_input_shapes,
+ &new_result_shape));
+ if (result) {
+ TF_RETURN_IF_ERROR(PadConv(conv, new_input_shapes, new_result_shape));
+ return true;
+ }
+ return false;
+}
+
+// Adds padding to cudnn convolutions to make them run faster on GPUs with
+// tensor cores.
+//
+// - f16 convolutions are padded to have input/output channel dimensions that
+// are multiples of 8, so that we can use tensor cores.
+//
+// - f16 convolutions with 3 input channels and 32 or 64 output channels are
+// padded to 4 input channels. There's a special-cased cudnn algorithm just
+// for this.
+//
+// Don't run this pass on GPUs without tensor cores -- it will make them slower!
+//
+// TODO(jlebar): Also pad dots.
+static absl::StatusOr<bool> TryResolvePaddedShapesForTensorCore(
+ HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
+ Shape* new_result_shape_ptr) {
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
+ const auto& dnums = conv->convolution_dimension_numbers();
+ auto* lhs = conv->mutable_operand(0);
+ auto* rhs = conv->mutable_operand(1);
+ const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+ // Nothing to do on non-f16 convolutions.
+ if (result_shape.element_type() != PrimitiveType::F16) {
+ return false;
+ }
+
+ // When the convolution is grouped, the feature dimensions must stay in
+ // agreement with the group size, so we cannot pad them independently.
+ if (conv->feature_group_count() > 1 || conv->batch_group_count() > 1) {
+ VLOG(2) << "Do not pad grouped convolution.";
+ return false;
+ }
+
+ // TODO(timshen): Don't skip forward-activation convs if we find a benchmark
+ // where there's a speedup.
+ if (kind == CudnnConvKind::kForwardActivation) {
+ return false;
+ }
+
+ Shape new_lhs_shape = lhs->shape();
+ Shape new_rhs_shape = rhs->shape();
+ Shape& new_result_shape = *new_result_shape_ptr;
+ new_result_shape = conv->shape().tuple_shapes(0);
+
+ // new_{input,filter,output}_shape point to the appropriate one of
+ // new_{lhs,rhs,result}_shape.
+ Shape* new_input_shape;
+ Shape* new_filter_shape;
+ Shape* new_output_shape;
+ std::tie(new_input_shape, new_filter_shape, new_output_shape) = [&] {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ case CudnnConvKind::kForwardGraph:
+ return std::make_tuple(&new_lhs_shape, &new_rhs_shape,
+ &new_result_shape);
+ case CudnnConvKind::kBackwardInput:
+ return std::make_tuple(&new_result_shape, &new_rhs_shape,
+ &new_lhs_shape);
+ case CudnnConvKind::kBackwardFilter:
+ return std::make_tuple(&new_lhs_shape, &new_result_shape,
+ &new_rhs_shape);
+ }
+ }();
+
+ // If there are 3 input features and 32 or 64 output features, pad the input
+ // features to 4. Otherwise, try padding to multiples of 8 and check that
+ // this doesn't make any of the conv buffers too much larger.
+ auto input_features =
+ new_input_shape->dimensions(dnums.input_feature_dimension());
+ auto output_features =
+ new_output_shape->dimensions(dnums.output_feature_dimension());
+ if (input_features == 3 && (output_features == 32 || output_features == 64)) {
+ new_input_shape->set_dimensions(dnums.input_feature_dimension(), 4);
+ new_filter_shape->set_dimensions(dnums.kernel_input_feature_dimension(), 4);
+ } else {
+ auto pad_dim = [](Shape* s, int64_t dim) {
+ s->set_dimensions(dim, RoundUpTo<int64_t>(s->dimensions(dim), 8));
+ };
+ pad_dim(new_input_shape, dnums.input_feature_dimension());
+ pad_dim(new_filter_shape, dnums.kernel_input_feature_dimension());
+ pad_dim(new_filter_shape, dnums.kernel_output_feature_dimension());
+ pad_dim(new_output_shape, dnums.output_feature_dimension());
+
+ // We won't pad a conv if doing so increases the total number of bytes in
+ // the lhs, rhs, or result by more than this amount.
+ //
+ // TODO(jlebar): This number was tuned experimentally. It represents a
+ // compromise on our current benchmarks; it speeds some up significantly,
+ // and doesn't slow any down. But we can observe by changing this value
+ // that there's additional room for speedups. Achieving those speedups
+ // without also slowing other things down will likely require a more
+ // sophisticated heuristic, possibly some form of auto-tuning.
+ static constexpr double kMaxBytesTouchedBound = 1.35;
+
+ // Check that padding wouldn't increase the total bytes read/written by this
+ // operation too much.
+ auto check_size_increase = [&](const Shape& old_shape,
+ const Shape& new_shape) {
+ int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
+ int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
+ if (new_bytes <= old_bytes * kMaxBytesTouchedBound) {
+ return true;
+ }
+ VLOG(3)
+ << "Not padding convolution; doing so would change input / result "
+ "shape from "
+ << ShapeUtil::HumanString(old_shape) << " to "
+ << ShapeUtil::HumanString(new_shape) << ", a size increase of "
+ << new_bytes / static_cast<double>(old_bytes) << "x > "
+ << kMaxBytesTouchedBound << "x: " << conv->ToString();
+ return false;
+ };
+
+ if (!check_size_increase(lhs->shape(), new_lhs_shape) ||
+ !check_size_increase(rhs->shape(), new_rhs_shape) ||
+ !check_size_increase(result_shape, new_result_shape)) {
+ return false;
+ }
+ }
+
+ if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
+ ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) {
+ VLOG(3) << "No need to pad features of " << conv->ToString();
+ return false;
+ }
+
+ new_input_shapes_ptr->push_back(new_lhs_shape);
+ new_input_shapes_ptr->push_back(new_rhs_shape);
+ return true;
+}
+
+// Adds padding to cudnn integer convolutions to make input and output feature
+// maps multiples of pad_to (usually 4 or 32).
+absl::StatusOr<bool> TryResolvePaddedShapesForIntegerConvolution(
+ int pad_to, const se::CudaComputeCapability& compute_capability,
+ HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
+ Shape* new_result_shape_ptr) {
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
+ const Shape& input_shape = conv->operand(0)->shape();
+ const Shape& kernel_shape = conv->operand(1)->shape();
+ const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+ // Integer convolution only
+ if (!primitive_util::IsIntegralType(input_shape.element_type())) {
+ return false;
+ }
+
+ // kForward and kForwardActivation only
+ if (kind != CudnnConvKind::kForward &&
+ kind != CudnnConvKind::kForwardActivation) {
+ return false;
+ }
+
+ const auto& dnums = conv->convolution_dimension_numbers();
+ std::vector<Shape>& new_input_shapes = *new_input_shapes_ptr;
+ for (auto operand : conv->operands()) {
+ new_input_shapes.push_back(operand->shape());
+ }
+ Shape& new_result_shape = *new_result_shape_ptr;
+ new_result_shape = conv->shape().tuple_shapes(0);
+
+ // The input/kernel/output might already be vectorized (i.e. cudnn layout
+ // NCHW_VECT_C). If so, we pad the features dim so that
+ // size(features_dim) * size(vect_dim) is a multiple of pad_to.
+ std::optional<int64_t> input_vect_dim;
+ std::optional<int64_t> kernel_vect_dim;
+ std::optional<int64_t> result_vect_dim;
+ std::tie(input_vect_dim, kernel_vect_dim, result_vect_dim) =
+ FindVectorizedFeatureDims(dnums, input_shape, kernel_shape, result_shape);
+
+ int64_t input_vect_size =
+ input_vect_dim.has_value() ? input_shape.dimensions(*input_vect_dim) : 1;
+ int64_t kernel_vect_size = kernel_vect_dim.has_value()
+ ? kernel_shape.dimensions(*kernel_vect_dim)
+ : 1;
+ int64_t result_vect_size = result_vect_dim.has_value()
+ ? result_shape.dimensions(*result_vect_dim)
+ : 1;
+ if (pad_to % input_vect_size != 0 || pad_to % kernel_vect_size != 0 ||
+ pad_to % result_vect_size != 0) {
+ // If the conv is already vectorized but pad_to is not a multiple of the
+ // vector size, we choose not to pad. This is a weird case, because the
+ // only useful vector sizes in cudnn (as of writing) are 4 and 32, and those
+ // are also the only pad_to cases.
+ return false;
+ }
+
+ // Check that cudnn supports our desired integer padding/vectorization.
+ TF_ASSIGN_OR_RETURN(bool cudnn_supports,
+ CudnnSupportsOptimizedIntegerConvolution(
+ compute_capability, *conv, pad_to));
+ if (!cudnn_supports) {
+ return false;
+ }
+
+ // Pad the features to multiples of pad_to.
+ {
+ auto pad_dim = [&](Shape* s, int64_t dim, int64_t cur_vect_size) {
+ CHECK_EQ(pad_to % cur_vect_size, 0);
+ s->set_dimensions(
+ dim, RoundUpTo<int64_t>(s->dimensions(dim), pad_to / cur_vect_size));
+ };
+
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ CHECK_EQ(new_input_shapes.size(), 2);
+ // Input feature maps
+ pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
+ input_vect_size);
+ // Kernel for the input feature maps
+ pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
+ kernel_vect_size);
+ // Kernel for the output feature maps. In the NCHW_VECT_C layout, only
+ // the kernel input feature dim is vectorized, so this has cur_vect_size 1.
+ pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
+ /*cur_vect_size=*/1);
+ // Output feature maps
+ pad_dim(&new_result_shape, dnums.output_feature_dimension(),
+ result_vect_size);
+ break;
+ case CudnnConvKind::kForwardActivation:
+ CHECK(new_input_shapes.size() == 3 || new_input_shapes.size() == 4);
+ // Input feature maps
+ pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
+ input_vect_size);
+ // Kernel for the input feature maps
+ pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
+ kernel_vect_size);
+ // Kernel for the output feature maps. In the NCHW_VECT_C layout, only
+ // the kernel input feature dim is vectorized, so this has cur_vect_size 1.
+ pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
+ /*cur_vect_size=*/1);
+
+ // Bias. This is a 1D vector of length output-depth, and it's unclear if
+ // we *have* to pad it. But hey, we might as well. cur_vect_size 1
+ // because NCHW_VECT_C doesn't apply here (there is no channels
+ // dimension!).
+ pad_dim(&new_input_shapes[2], /*dim=*/0, /*cur_vect_size=*/1);
+
+ if (new_input_shapes.size() == 4) {
+ // Optional side input. Same layout as result, so gets padded the
+ // same.
+ pad_dim(&new_input_shapes[3], dnums.output_feature_dimension(),
+ result_vect_size);
+ }
+ // Output feature maps
+ pad_dim(&new_result_shape, dnums.output_feature_dimension(),
+ result_vect_size);
+ break;
+ default:
+ CHECK(false);
+ }
+
+ // We won't pad a conv if doing so increases the total number of bytes in
+ // the lhs, rhs, or result by a factor of this much or more.
+ //
+ // Note: It's important that this bound is exclusive. It's a performance
+ // regression to pad and increase input/output size by 2x, so we only pad
+ // strictly less than 2x.
+ //
+ // TODO(jlebar): This number was tuned experimentally, but without much
+ // experimental evidence.
+ static constexpr double kMaxBytesTouchedBound = 2;
+
+ // Check that padding wouldn't increase the total bytes read/written by this
+ // operation too much.
+ auto check_size_increase = [&](const Shape& old_shape,
+ const Shape& new_shape) {
+ int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
+ int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
+ if (new_bytes < old_bytes * kMaxBytesTouchedBound) {
+ return true;
+ }
+ VLOG(3)
+ << "Not padding convolution; doing so would change input / result "
+ "shape from "
+ << ShapeUtil::HumanString(old_shape) << " to "
+ << ShapeUtil::HumanString(new_shape) << ", a size increase of "
+ << new_bytes / static_cast<double>(old_bytes)
+ << "x >= " << kMaxBytesTouchedBound << "x: " << conv->ToString();
+ return false;
+ };
+
+ // Check size increase only on the input and output. No need to check the
+ // filter, since that's determined by the input/output. The bias (if
+ // present) is tiny (1D array of length output-depth), so padding doesn't
+ // matter. And the side-input, if present, has the same shape as the result.
+ if (!check_size_increase(conv->operand(0)->shape(), new_input_shapes[0]) ||
+ !check_size_increase(result_shape, new_result_shape)) {
+ return false;
+ }
+ }
+
+ bool changed = false;
+ for (int64_t i = 0; i < conv->operand_count(); ++i) {
+ changed |=
+ !ShapeUtil::Equal(conv->operand(i)->shape(), new_input_shapes[i]);
+ }
+ if (!changed) {
+ VLOG(3) << "No need to pad features of " << conv->ToString();
+ }
+
+ return changed;
+}
+
+absl::StatusOr<bool> CudnnPadForConvolutions::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
+ // On Turing and later (sm75+), pad to multiples of 32 bytes if possible,
+ // because that lets us use the fast int8x32 data type.
+ bool local_changed = false;
+ if (compute_capability_.IsAtLeast(7, 5)) {
+ TF_ASSIGN_OR_RETURN(
+ local_changed,
+ ResolveAndPad(conv, absl::bind_front(
+ TryResolvePaddedShapesForIntegerConvolution,
+ 32, compute_capability_)));
+ }
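+ // If padding for the int8x32 path was not applied, fall back to padding
+ // the integer convolution's feature dimensions to multiples of 4.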
+ if (!local_changed) {
+ TF_ASSIGN_OR_RETURN(
+ local_changed,
+ ResolveAndPad(conv, absl::bind_front(
+ TryResolvePaddedShapesForIntegerConvolution,
+ 4, compute_capability_)));
+ }
+ changed |= local_changed;
+ }
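+ // On Volta and later, pad f16 convolutions so that cuDNN can use tensor
+ // cores (see TryResolvePaddedShapesForTensorCore above).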
+ if (compute_capability_.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
+ for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
+ TF_ASSIGN_OR_RETURN(
+ bool local_changed,
+ ResolveAndPad(conv, TryResolvePaddedShapesForTensorCore));
+ changed |= local_changed;
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h
new file mode 100644
index 0000000..719efec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+// This transform applies two kinds of zero-padding to cuDNN convolutions:
+// padding for tensor cores and padding for integer convolutions. It also adds
+// slice instructions to remove the unnecessary output features afterwards.
+class CudnnPadForConvolutions : public HloModulePass {
+ public:
+ explicit CudnnPadForConvolutions(se::CudaComputeCapability compute_capability)
+ : compute_capability_(compute_capability) {}
+
+ absl::string_view name() const override {
+ return "cudnn_pad_for_convolutions";
+ }
+ // Runs the pass on the given module and returns whether any change was made.
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const se::CudaComputeCapability compute_capability_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc
new file mode 100644
index 0000000..7cee2c5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc
@@ -0,0 +1,456 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = xla::match;
+
+class CudnnPadForConvolutionsTest : public HloTestBase {};
+
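+// Grouped convolutions cannot be padded independently of the group size, so
+// the pass should leave this f16 convolution unchanged.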
+TEST_F(CudnnPadForConvolutionsTest, DoNotPadF16ForwardConvWhenGrouped) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[704,48,1,49]{3,2,1,0} parameter(0)
+ filter = f16[44,768,1,50]{3,2,1,0} parameter(1)
+ ROOT result = (f16[1,128,48,768]{3,2,1,0}, u8[0]{0})
+ custom-call(input, filter)
+ , window={size=1x50 pad=0_0x64_64}
+ , dim_labels=fb01_io01->01bf
+ , feature_group_count=16
+ , custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+}
+
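+// With 41 input channels, both the input and the filter should be padded to
+// 48 input channels, the next multiple of 8.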
+TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvInputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,41] parameter(0)
+ filter = f16[2,2,41,40] parameter(1)
+ ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+
+ SCOPED_TRACE(module->ToString());
+
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::CustomCall(
+ {kCudnnConvForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
+ m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 48, 40}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvOutputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ output = f16[10,20,30,41] parameter(0)
+ filter = f16[2,2,40,41] parameter(1)
+ ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardInput"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::CustomCall(
+ {kCudnnConvBackwardInputCallTarget},
+ m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
+ m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 40, 48}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvOutputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,40] parameter(0)
+ filter = f16[2,2,40,41] parameter(1)
+ ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvForwardCallTarget}, m::Parameter(0),
+ m::Pad(m::Parameter(1), m::Op())))),
+ m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvInputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ output = f16[10,20,30,40] parameter(0)
+ filter = f16[2,2,41,40] parameter(1)
+ result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardInput"
+ ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ GmockMatch(m::GetTupleElement(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvBackwardInputCallTarget}, m::Parameter(0),
+ m::Pad(m::Parameter(1), m::Op())))),
+ m::Op()))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvInputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,41] parameter(0)
+ output = f16[10,20,30,40] parameter(1)
+ result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardFilter"
+ ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ GmockMatch(m::GetTupleElement(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvBackwardFilterCallTarget},
+ m::Pad(m::Parameter(0), m::Op()), m::Parameter(1)))),
+ m::Op()))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvOutputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,40] parameter(0)
+ output = f16[10,20,30,41] parameter(1)
+ result = (f16[2,2,40,41], u8[0]) custom-call(input, output),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardFilter"
+ ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ GmockMatch(m::GetTupleElement(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvBackwardFilterCallTarget}, m::Parameter(0),
+ m::Pad(m::Parameter(1), m::Op())))),
+ m::Op()))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInputFeatures3To4) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,3] parameter(0)
+ filter = f16[2,2,3,32] parameter(1)
+ ROOT result = (f16[10,20,30,32], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+
+ SCOPED_TRACE(module->ToString());
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::CustomCall(
+ {kCudnnConvForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 4}),
+ m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 4, 32}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvInputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,41] parameter(0)
+ filter = s8[2,2,41,40] parameter(1)
+ ROOT result = (f32[10,20,30,40], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+
+ SCOPED_TRACE(module->ToString());
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::CustomCall(
+ {kCudnnConvForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 44}),
+ m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 44, 40}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvOutputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,40] parameter(0)
+ filter = s8[2,2,40,41] parameter(1)
+ ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvForwardCallTarget}, m::Parameter(0),
+ m::Pad(m::Parameter(1), m::Op())))),
+ m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInt8To32OnSm75) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,40] parameter(0)
+ filter = s8[2,2,40,41] parameter(1)
+ ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 64}),
+ m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 64, 64})))),
+ m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32OnSm70) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,40] parameter(0)
+ filter = s8[2,2,40,41] parameter(1)
+ ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvForwardCallTarget}, m::Parameter(0),
+ m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
+ m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32FloatOutputSm75) {
+  // This test checks that the padding pass correctly calls
+  // CudnnSupportsOptimizedIntegerConvolution(), which should reject int8x32
+  // padding for this convolution because its output type is f32. The
+  // convolution is instead padded for int8x4, which does support float
+  // outputs.
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,38] parameter(0)
+ filter = s8[2,2,38,41] parameter(1)
+ ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 40}),
+ m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
+ m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadInt8UnsupportedFilterTypeOutputSm75) {
+ // This test checks that the padding pass correctly calls
+  // CudnnSupportsOptimizedIntegerConvolution(), which should reject this
+  // convolution because the kernel type is f32.
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,38] parameter(0)
+ filter = f32[2,2,38,41] parameter(1)
+ ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadToInt8x32ExcessiveBlowup) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[128,4,48,48] parameter(0)
+ filter = s8[64,4,3,3] parameter(1)
+ ROOT result = (f32[128,64,48,48], u8[0]) custom-call(input, filter),
+ window={size=3x3}, dim_labels=bf01_io01->bf01,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,41,4] parameter(0)
+ filter = s8[2,2,41,4,168] parameter(1)
+ ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Slice(m::GetTupleElement(
+ m::CustomCall({kCudnnConvForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op())
+ .WithShape(S8, {10, 20, 30, 48, 4}),
+ m::Pad(m::Parameter(1), m::Op())
+ .WithShape(S8, {2, 2, 48, 4, 192})))
+ .WithShape(S8, {10, 20, 30, 48, 4})),
+ m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32BiasActivation) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,41,4] parameter(0)
+ filter = s8[2,2,41,4,168] parameter(1)
+ bias = f32[10] parameter(2)
+ side_input = s8[10,20,30,42,4] parameter(3)
+ ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter, bias, side_input),
+ window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
+ custom_call_target="__cudnn$convBiasActivationForward"
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Slice(
+ m::GetTupleElement(
+ m::CustomCall(
+ {kCudnnConvBiasActivationForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op())
+ .WithShape(S8, {10, 20, 30, 48, 4}),
+ m::Pad(m::Parameter(1), m::Op())
+ .WithShape(S8, {2, 2, 48, 4, 192}),
+ m::Pad(m::Parameter(2), m::Op()).WithShape(F32, {32}),
+ m::Pad(m::Parameter(3), m::Op())
+ .WithShape(S8, {10, 20, 30, 48, 4})))
+ .WithShape(S8, {10, 20, 30, 48, 4})),
+ m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest,
+ PadIntFusedForwardConvInputAndOutputChannels) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule Test
+
+ ENTRY %Test (input: s8[1,3,3,2], filter: s8[3,3,2,5], side_input: s8[1,3,3,5], bias: s8[5]) -> f32[1,3,3,5] {
+ %input = s8[1,3,3,3]{3,2,1,0} parameter(0)
+ %filter = s8[3,3,2,5]{3,2,1,0} parameter(1)
+ %bias = s8[5]{0} parameter(3)
+ %convert = f32[5]{0} convert(s8[5]{0} %bias)
+ %side_input = f32[1,3,3,5]{3,2,1,0} parameter(2)
+ %custom-call.1 = (f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,3,3,3]{3,2,1,0} %input, s8[3,3,2,5]{3,2,1,0} %filter, f32[5]{0} %convert, f32[1,3,3,5]{3,2,1,0} %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"activationMode\":\"2\",\"convResultScale\":1,\"sideInputScale\":1}"
+ ROOT %get-tuple-element.1 = f32[1,3,3,5]{3,2,1,0} get-tuple-element((f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) %custom-call.1), index=0
+ })")
+ .value();
+ EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::GetTupleElement(m::Tuple(
+ m::Slice(m::GetTupleElement(m::CustomCall(
+ {kCudnnConvBiasActivationForwardCallTarget},
+ m::Pad(m::Parameter(0), m::Op()),
+ m::Pad(m::Parameter(1), m::Op()),
+ m::Pad(m::Convert(m::Parameter(3)), m::Op()),
+ m::Pad(m::Parameter(2), m::Op())))),
+ m::Op()))));
+}
+
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc
new file mode 100644
index 0000000..30e0a4b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc
@@ -0,0 +1,482 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_simplify_padding.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+namespace {
+namespace m = ::xla::match;
+
+// If exactly one index of `vals` is false, returns that index. If zero or
+// more than one index is false, returns nullopt.
+std::optional<int64_t> FindFalseIndex(absl::Span<const bool> vals) {
+ std::optional<int64_t> missing_dim;
+ for (int i = 0; i < vals.size(); i++) {
+ if (vals[i]) {
+ continue;
+ }
+ if (missing_dim.has_value()) {
+ VLOG(2) << "Multiple dimensions are missing from conv dnums; can't "
+ "determine which is vect_c dimension";
+ return std::nullopt;
+ }
+ missing_dim = i;
+ }
+ return missing_dim;
+}
+
+// Finds the vect_c dimension in the convolution's output.
+//
+// The vect_c dimension is the output dimension that is not mentioned in
+// `dnums`. If there's zero or more than one such dimension, returns nullopt.
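+//
+// For example (an illustrative sketch): for an NCHW_VECT_C output of shape
+// [N, C/32, H, W, 32] whose dnums name the batch, feature, and two spatial
+// dimensions as {0, 1, 2, 3}, the only unmentioned dimension is 4, so this
+// returns 4.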
+std::optional<int64_t> FindOutputVectCDim(HloInstruction* conv) {
+ const ConvolutionDimensionNumbers& dnums =
+ conv->convolution_dimension_numbers();
+ int64_t num_dims = conv->shape().tuple_shapes(0).dimensions_size();
+ absl::InlinedVector<bool, 5> seen_dims(num_dims);
+ seen_dims[dnums.output_batch_dimension()] = true;
+ seen_dims[dnums.output_feature_dimension()] = true;
+ for (int64_t d : dnums.output_spatial_dimensions()) {
+ seen_dims[d] = true;
+ }
+ return FindFalseIndex(seen_dims);
+}
+
+// Finds the vect_c dimension in the convolution's kernel.
+std::optional<int64_t> FindKernelVectCDim(HloInstruction* conv) {
+ const ConvolutionDimensionNumbers& dnums =
+ conv->convolution_dimension_numbers();
+ int64_t num_dims = conv->operand(1)->shape().dimensions_size();
+ absl::InlinedVector<bool, 5> seen_dims(num_dims);
+ seen_dims[dnums.kernel_input_feature_dimension()] = true;
+ seen_dims[dnums.kernel_output_feature_dimension()] = true;
+ for (int64_t d : dnums.kernel_spatial_dimensions()) {
+ seen_dims[d] = true;
+ }
+ return FindFalseIndex(seen_dims);
+}
+
+// Attempts to count the number of output features at the end of conv that are
+// guaranteed to be 0.
+//
+// This is the same as counting the number of values o at the end of the kernel
+// for which kernel[i,o,h,w] is 0 for all values i,h,w.
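+//
+// For example (an illustrative case for the simplest branch below): with
+// dim_labels=b01f_01io the kernel output-feature dimension is 3, so if the
+// kernel is pad(s8[3,3,10,6], s8[] constant(0)) with edge_padding_high=4 on
+// dimension 3, the last 4 output features are guaranteed to be 0 and this
+// returns 4.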
+std::optional<int64_t> NumTrailingZeroOutputFeatures(HloInstruction* conv) {
+ const ConvolutionDimensionNumbers& dnums =
+ conv->convolution_dimension_numbers();
+ int64_t feature_dim = dnums.kernel_output_feature_dimension();
+ const HloInstruction* weights = conv->operand(1);
+
+ // If the filter is reordered for an int8x32 NCHW_VECT_C convolution, find the
+ // original, un-reordered filter and check *it* for trailing zero output
+ // features.
+ auto backend_config = conv->backend_config<GpuBackendConfig>();
+ if (backend_config.ok() &&
+ backend_config->cudnn_conv_backend_config().reordered_int8_nchw_vect()) {
+ VLOG(2) << "Matched int8x32 convolution with filter reordering";
+
+ // Try to set weights to the original, un-reordered value.
+ const HloInstruction *reshape, *transpose;
+ bool matched =
+ Match(weights, m::Reshape(m::Transpose(
+ &transpose, m::Reshape(&reshape, m::Op(&weights)))));
+
+ // Verify some properties of the reshape-transpose-reshape pattern.
+ // If these don't hold, it means that some pass (e.g. constant folding)
+    // has modified the filter, making it infeasible to get the original,
+ // un-reordered value.
+ if (!matched || feature_dim != 0 || transpose->shape().rank() != 8) {
+ VLOG(2) << "The filter output feature dimension cannot be determined, as "
+ "the reordering sequence is modified";
+ return std::nullopt;
+ }
+
+ // Calculate the actual output feature dimension before the transpose.
+ // For example: the input filter [I, O, H, W] will get reshaped to
+ // [I/32, 8, 4, O/8, 4, 2, H, W], transposed in a way that is compatible
+ // with cuDNN INT8x32_CONFIG convolutions (see 'cudnn_support_utils.h') and
+    // reshaped again to [O, I/32, H, W, 32]. While the output feature
+    // dimension is zero in the reordered shape, we need its index in the
+    // original shape (which is one in this example).
+ const auto& transpose_dimensions =
+ Cast<HloTransposeInstruction>(transpose)->dimensions();
+
+    // Calculate the combined size of the dimensions that precede the first
+    // output component [O/8], which appears at position 3 of the transpose.
+ int64_t preceding_size = 1;
+ for (int64_t i = transpose_dimensions.at(3) - 1; i >= 0; --i) {
+ preceding_size *= reshape->shape().dimensions(i);
+ }
+
+ // Skip dimensions in front until the output features dimension is found.
+ int64_t accumulated_size = 1;
+ for (int64_t size : weights->shape().dimensions()) {
+ if (accumulated_size < preceding_size) {
+ accumulated_size *= size;
+ ++feature_dim;
+ } else {
+ break;
+ }
+ }
+ // Sanity check; if this condition doesn't hold, something is off.
+ if (accumulated_size != preceding_size) {
+      VLOG(2) << "Unexpected dimension sizes; cannot determine the original "
+                 "output feature dimension";
+ return std::nullopt;
+ }
+ VLOG(2) << "Computed output feature dimension: " << feature_dim;
+ }
+
+ VLOG(2) << "Computing NumTrailingZeroOutputFeatures of " << conv->ToString()
+ << "\nwith weights " << weights->ToString();
+ if (Match(weights, m::Pad(m::Op(), m::ConstantEffectiveScalar(0)))) {
+ const PaddingConfig::PaddingConfigDimension& padding_config =
+ weights->padding_config().dimensions(feature_dim);
+ // The last N output feature weights are all 0.
+ VLOG(2) << "Success: Weights is a pad; padding on output feature dim is "
+ << padding_config.edge_padding_high();
+ return padding_config.edge_padding_high();
+ } else if (const HloInstruction * pad; Match(
+ weights, m::Reshape(m::Pad(&pad, m::Op(),
+ m::ConstantEffectiveScalar(0))))) {
+ // Check that the reshape merely adds a VECT_C to the kernel input features.
+ // That is, we reshape from [I,O,H,W] (in some order) to [I/k,k,O,H,W] (in
+ // the same order) for some constant k (probably 32). Then check how much
+ // the pad adds to the O dimension.
+ std::optional<int64_t> vect_c_dim = FindKernelVectCDim(conv);
+ if (!vect_c_dim.has_value()) {
+ VLOG(2) << "fail: Can't find vect_c dimension in conv.";
+ return std::nullopt;
+ }
+ if (*vect_c_dim != dnums.kernel_input_feature_dimension() + 1) {
+ VLOG(2) << "fail: vect_c dim is in the wrong place; should be right "
+ "after kernel input feature dims in conv.";
+ return std::nullopt;
+ }
+ absl::InlinedVector<int64_t, 5> expected_pad_dim_sizes(
+ weights->shape().dimensions().begin(),
+ weights->shape().dimensions().end());
+ expected_pad_dim_sizes[dnums.kernel_input_feature_dimension()] *=
+ weights->shape().dimensions(*vect_c_dim);
+ expected_pad_dim_sizes.erase(expected_pad_dim_sizes.begin() + *vect_c_dim);
+ if (pad->shape().dimensions() != expected_pad_dim_sizes) {
+ VLOG(2) << "fail: Reshape doesn't simply merge vect_c dimension into "
+ "input features dim "
+ << weights->ToString() << " but expected dims "
+ << absl::StrJoin(expected_pad_dim_sizes, ",");
+ return std::nullopt;
+ }
+
+    // If the filter dnums are e.g. [I,O,H,W] then after the reshape they are
+    // [I/k,k,O,H,W], and the new index of O is one greater than before the
+    // reshape (we know the reshape only adds the I/k and k dims, and that
+    // they are contiguous); e.g. O moves from index 1 to index 2 here, so we
+    // subtract one to recover the pre-reshape index. OTOH if the O comes
+    // before the I in the original, then the index of O doesn't change after
+    // the reshape.
+ int64_t feature_dim_before_reshape = feature_dim;
+ if (dnums.kernel_output_feature_dimension() >
+ dnums.kernel_input_feature_dimension()) {
+ feature_dim_before_reshape--;
+ }
+ const PaddingConfig::PaddingConfigDimension& padding_config =
+ pad->padding_config().dimensions(feature_dim_before_reshape);
+
+ // The last N output feature weights are all 0.
+ VLOG(2) << "Success: Weights is a reshape of a pad; padding on output "
+ "feature dim is "
+ << padding_config.edge_padding_high();
+ return padding_config.edge_padding_high();
+ } else if (Match(weights, m::Constant())) {
+ // Iterate backwards over `weights` to find the index of the first nonzero
+ // value.
+ //
+ // TODO(jlebar): This may be slow, because it iterates over potentially the
+ // whole constant and does a multi_index -> linear_index conversion for each
+ // element. If necessary we could rewrite this by using linear indices, but
+ // we'd have to be careful of the fact that literals can have arbitrary
+ // layouts, so you can't just iterate over the literal's bytes.
+ const Literal& lit = weights->literal();
+ const auto& dims = weights->shape().dimensions();
+ absl::InlinedVector<int64_t, 5> multi_index;
+ for (int64_t dim : dims) {
+ multi_index.push_back(dim - 1);
+ }
+    // This iterates through the literal, treating feature_dim as the
+    // most-major dimension, looking for the last feature that contains a
+    // nonzero value.
+ auto decrement_multi_index = [&] {
+ for (int i = 0; i < multi_index.size(); ++i) {
+ if (i != feature_dim) {
+ int64_t& idx = multi_index[i];
+ --idx;
+ if (idx == -1) {
+ idx = dims[i] - 1;
+ } else {
+ return true;
+ }
+ }
+ }
+ int64_t& idx = multi_index[feature_dim];
+ --idx;
+ return idx != -1;
+ };
+ do {
+ if (!lit.IsZero(multi_index)) {
+ break;
+ }
+ } while (decrement_multi_index());
+
+    // The iteration stops at the last feature that contains a nonzero value
+    // (or wraps to -1 if the constant is all zeros); the first trailing
+    // all-zero feature is then the next index (or 0 if the index is -1).
+ int64_t first_trailing_zero_feature = multi_index[feature_dim] + 1;
+
+ if (first_trailing_zero_feature == 0) {
+ VLOG(2) << "Weights constant is entirely zero.";
+ } else {
+ VLOG(2) << "First nonzero index in weights constant is "
+ << absl::StrJoin(multi_index, ",");
+ }
+ int64_t ret =
+ std::max<int64_t>(0, weights->shape().dimensions(feature_dim) -
+ first_trailing_zero_feature);
+ VLOG(2) << "Success: weights is a constant; num zero trailing output "
+ "features is "
+ << ret;
+ return ret;
+ }
+ return std::nullopt;
+}
+
+absl::StatusOr<bool> TrySimplifyPadding(HloInstruction* instr) {
+ // Match one of the following patterns.
+ // conv -> slice -> pad
+  //   conv -> reshape -> slice -> pad
+ // conv -> transpose -> reshape -> slice -> pad
+ //
+ // where `pad` (the root of the pattern) is `instr`.
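+  //
+  // As a concrete illustration (adapted from the
+  // PaddedTransposedAndReshapedOutput test in cudnn_simplify_padding_test.cc),
+  // the conv -> transpose -> reshape -> slice -> pad form looks like:
+  //
+  //   conv_result = get-tuple-element(conv), index=0
+  //   transposed  = s8[10,2,32,10,10] transpose(conv_result),
+  //                 dimensions={0,1,4,2,3}
+  //   reshaped    = s8[10,64,10,10] reshape(transposed)
+  //   slice       = s8[10,60,10,10] slice(reshaped),
+  //                 slice={[0:10], [0:60], [0:10], [0:10]}
+  //   pad         = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0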
+ HloInstruction* conv;
+ HloInstruction* transpose = nullptr; // optional
+ HloInstruction* reshape = nullptr; // optional
+ HloInstruction* slice;
+ HloInstruction* pad;
+ auto conv_matcher = m::GetTupleElement(
+ m::CustomCall(&conv).WithPredicate([](const HloInstruction* instr) {
+ return instr->custom_call_target() == kCudnnConvForwardCallTarget ||
+ instr->custom_call_target() ==
+ kCudnnConvBiasActivationForwardCallTarget;
+ }),
+ 0);
+ auto pad_matcher = m::Pad(m::Op(), m::ConstantEffectiveScalar(0));
+ if (!MatchAndLogIfFailed(instr, "conv-slice-pad",
+ m::Pad(&pad, m::Slice(&slice, conv_matcher),
+ m::ConstantEffectiveScalar(0)),
+ VLOG_IS_ON(3), pad_matcher) &&
+ !MatchAndLogIfFailed(
+ instr, "conv-reshape-slice-pad",
+ m::Pad(&pad, m::Slice(&slice, m::Reshape(&reshape, conv_matcher)),
+ m::ConstantEffectiveScalar(0)),
+ VLOG_IS_ON(3), pad_matcher) &&
+ !MatchAndLogIfFailed(
+ instr, "conv-transpose-reshape-slice-pad",
+ m::Pad(&pad,
+ m::Slice(&slice,
+ m::Reshape(&reshape,
+ m::Transpose(&transpose, conv_matcher))),
+ m::ConstantEffectiveScalar(0)),
+ VLOG_IS_ON(3), pad_matcher)) {
+ return false;
+ }
+
+ VLOG(2) << "Found pattern to attempt to simplify:\n"
+ << "conv: " << conv->ToString() //
+ << "\ntranspose: "
+ << (transpose != nullptr ? transpose->ToString() : "(null)")
+ << "\nreshape: "
+ << (reshape != nullptr ? reshape->ToString() : "(null)")
+ << "\nslice: " << slice->ToString() //
+ << "\npad: " << pad->ToString();
+
+ // Now check that we can merge the slice into the pad, because the slice is
+ // slicing off elements that we know are 0 and the pad is just adding those 0s
+ // back.
+ //
+ // First, we have to check whether any of the output features at the end of
+ // the conv are known to be 0.
+ std::optional<int64_t> num_known_zero_output_features =
+ NumTrailingZeroOutputFeatures(conv);
+ if (!num_known_zero_output_features.has_value() ||
+ *num_known_zero_output_features == 0) {
+ VLOG(2) << "fail: Didn't find any known-zero output features";
+ return false;
+ }
+
+  // We now know that the last *num_known_zero_output_features output features
+  // of the conv are zero. Check whether the optional-transpose +
+  // optional-reshape + slice combination removes only features that are known
+  // to be zero. If so, we can merge the slice into the pad by padding that
+  // much less.
+ const auto& dnums = conv->convolution_dimension_numbers();
+ int64_t output_feature_dim;
+ if (reshape == nullptr) {
+ CHECK_EQ(transpose, nullptr);
+ output_feature_dim = dnums.output_feature_dimension();
+ } else {
+ std::optional<int64_t> vect_c_dim_before_transpose =
+ FindOutputVectCDim(conv);
+ if (!vect_c_dim_before_transpose.has_value()) {
+ VLOG(2) << "Couldn't find vect_c output dim in conv.";
+ return false;
+ }
+
+ // If there's no transpose, check that the vect_c dim is immediately after
+ // the feature dim. OTOH if there is a transpose, check that the transpose
+ // moves the vect_c dim immediately after the feature dim.
+ int64_t feature_dim_after_transpose;
+ int64_t vect_c_dim_after_transpose;
+ if (transpose == nullptr) {
+ feature_dim_after_transpose = dnums.output_feature_dimension();
+ vect_c_dim_after_transpose = *vect_c_dim_before_transpose;
+ } else {
+ const auto& transpose_dims = transpose->dimensions();
+ feature_dim_after_transpose = std::distance(
+ transpose->dimensions().begin(),
+ absl::c_find(transpose_dims, dnums.output_feature_dimension()));
+ vect_c_dim_after_transpose = std::distance(
+ transpose->dimensions().begin(),
+ absl::c_find(transpose_dims, *vect_c_dim_before_transpose));
+ }
+ if (vect_c_dim_after_transpose != feature_dim_after_transpose + 1) {
+      VLOG(2) << "fail: after transpose (if present), vect_c dim must appear "
+                 "immediately after output feature dim; computed "
+                 "vect_c_dim_after_transpose to be "
+              << vect_c_dim_after_transpose;
+ return false;
+ }
+
+ // Now check that the reshape merges the feature + vect_c dims and
+ // doesn't do anything else.
+ absl::InlinedVector<int64_t, 5> expected_reshape_dim_sizes(
+ reshape->operand(0)->shape().dimensions().begin(),
+ reshape->operand(0)->shape().dimensions().end());
+ expected_reshape_dim_sizes[feature_dim_after_transpose] *=
+ expected_reshape_dim_sizes[vect_c_dim_after_transpose];
+ expected_reshape_dim_sizes.erase(expected_reshape_dim_sizes.begin() +
+ vect_c_dim_after_transpose);
+ if (reshape->shape().dimensions() != expected_reshape_dim_sizes) {
+ VLOG(2) << "fail: Reshape doesn't merge vect_c with feature dimension.";
+ return false;
+ }
+
+ output_feature_dim = feature_dim_after_transpose;
+ }
+
+  // Check that `slice` starts at the beginning of every dimension and uses
+  // unit strides.
+ if (!absl::c_all_of(slice->slice_starts(), [](auto v) { return v == 0; }) ||
+ !absl::c_all_of(slice->slice_strides(), [](auto v) { return v == 1; })) {
+ VLOG(2) << "fail: Slice doesn't start at the front or has stride != 1.";
+ return false;
+ }
+
+ // We're only allowed to slice the feature dim.
+ for (int64_t dim = 0; dim < slice->slice_limits().size(); dim++) {
+ if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
+ (dim != output_feature_dim &&
+ slice->slice_limits(dim) !=
+ slice->operand(0)->shape().dimensions(dim))) {
+ VLOG(2) << "fail: Slice removes something other than the features dim.";
+ return false;
+ }
+ }
+ int64_t num_sliced_from_feature_dim =
+ slice->operand(0)->shape().dimensions(output_feature_dim) -
+ slice->slice_limits(output_feature_dim);
+
+ // If we slice off more than the known-zero output features, then we need to
+ // keep the slice -- it's "doing something".
+ if (num_sliced_from_feature_dim > *num_known_zero_output_features) {
+ VLOG(2) << "fail: Slice removes " << num_sliced_from_feature_dim
+ << " features from the conv, but only "
+ << *num_known_zero_output_features
+ << " features in the conv are known to be zero.";
+ return false;
+ }
+
+ // Check if we can merge the slice into the pad.
+ if (pad->padding_config().dimensions(output_feature_dim).interior_padding() !=
+ 0) {
+ VLOG(2)
+ << "fail: Can't merge slice into pad because pad adds interior padding "
+ "in feature dimension.";
+ return false;
+ }
+
+ // Okay! If we got here, it's legal to fold the slice into the pad. We pad
+ // less, because we know that the sliced-off elements are all 0. Ideally, the
+ // pad becomes a nop and gets eliminated by algsimp later.
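+  //
+  // For example (matching the PaddedWeights test in
+  // cudnn_simplify_padding_test.cc): if the last 4 output features are known
+  // to be zero, the slice removes exactly those 4 features, and the original
+  // pad re-adds 5, then the rewritten pad adds only 5 - 4 = 1 feature.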
+ VLOG(1) << "Eliminating " << num_sliced_from_feature_dim
+ << " elements of padding from conv " << conv->name();
+ PaddingConfig new_padding_config = pad->padding_config();
+ PaddingConfig::PaddingConfigDimension* new_pad_feature_dim =
+ new_padding_config.mutable_dimensions(output_feature_dim);
+ // This is safe even if the new edge_padding_high is negative -- negative
+ // padding is allowed.
+ new_pad_feature_dim->set_edge_padding_high(
+ new_pad_feature_dim->edge_padding_high() - num_sliced_from_feature_dim);
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_pad,
+ MakePadHlo(slice->mutable_operand(0),
+ pad->mutable_operand(1), new_padding_config));
+ TF_RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad));
+ return true;
+}
+
+} // anonymous namespace
+
+absl::StatusOr<bool> CudnnSimplifyPadding::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ TF_ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr));
+ changed |= c;
+ }
+ }
+ return changed;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h
new file mode 100644
index 0000000..67580b4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h
@@ -0,0 +1,67 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// Simplifies or eliminates padding introduced by CudnnPadForConvolutions and
+// CudnnVectorizeConvolutions.
+//
+// CudnnVectorizeConvolutions will generate code that does the following.
+// - pad input and output features to a multiple of 32 (or 4),
+// - reshape input from [N,C,H,W] to [N,C/32,H,W,32] and reshape kernel from
+// [I,O,H,W] to [I/32,32,O,H,W],
+// - run the conv,
+// - reshape output from [N,C/32,H,W,32] to [N,C,H,W], and finally
+// - slice off the padding on the C channel.
+//
+// But if this is followed by another convolution (very common), then the slice
+// is immediately followed by another pad. This may be redundant; we know that
+// the trailing channels sliced off from the first conv are 0.
+//
+// Ideally we can eliminate the whole reshape+slice+pad+reshape sequence between
+// the two convolutions.
+//
+// Specifically, this pass tries to merge the slice at the end of the sequence
+// above into the pad from the next convolution (when we can prove that the
+// sliced-off elements are all 0). We then rely on algsimp to remove the pad if
+// it's a nop and then to merge and eliminate the remaining reshapes.
+//
+// This pass should run after CudnnVectorizeConvolutions and there should be no
+// simplification passes in between that modify the reshape-transpose-reshape
+// introduced by int8x32 convolution filter reordering.
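+//
+// A minimal usage sketch (mirroring the pipeline exercised by
+// cudnn_simplify_padding_test.cc; the pipeline name, `compute_capability`,
+// and `cudnn_version` below are illustrative placeholders):
+//
+//   HloPassPipeline pipeline("cudnn-padding-cleanup");
+//   pipeline.AddPass<CudnnPadForConvolutions>(compute_capability);
+//   pipeline.AddPass<CudnnVectorizeConvolutions>(compute_capability,
+//                                                cudnn_version);
+//   pipeline.AddPass<CudnnSimplifyPadding>();
+//   // ReshapeMover + AlgebraicSimplifier then clean up pads that became nops
+//   // and the remaining reshapes.
+//   pipeline.AddPass<ReshapeMover>();
+//   pipeline.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions());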
+class CudnnSimplifyPadding : public HloModulePass {
+ public:
+ CudnnSimplifyPadding() = default;
+
+ absl::string_view name() const override { return "cudnn_simplify_padding"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc
new file mode 100644
index 0000000..e924cca
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc
@@ -0,0 +1,771 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_simplify_padding.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/functional/function_ref.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "xla/literal.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/call_inliner.h"
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+#include "xla/service/hlo_pass_fix.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/reshape_mover.h"
+#include "xla/service/tuple_simplifier.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class CudnnSimplifyPaddingTest : public HloTestBase {
+ protected:
+ // Runs the whole relevant pass pipeline starting at CudnnPadForConvolutions.
+ // This lets us test that we're matching the patterns that actually get
+ // generated by padding+vectorization.
+ absl::StatusOr<bool> RunEndToEnd(std::pair<int, int> compute_capability,
+ HloModule* module) {
+ se::CudaComputeCapability cc{compute_capability.first,
+ compute_capability.second};
+
+ TF_RETURN_IF_ERROR(
+ RunHloPass(CudnnPadForConvolutions(cc), module).status());
+
+ TF_RETURN_IF_ERROR(
+ RunHloPass(CudnnVectorizeConvolutions(
+ cc, /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0}),
+ module)
+ .status());
+ VLOG(1) << "after vectorizing convs:\n" << module->ToString();
+
+ TF_RETURN_IF_ERROR(RunHloPass(CallInliner(), module).status());
+ VLOG(1) << "after inliner:\n" << module->ToString();
+
+ TF_RETURN_IF_ERROR(RunHloPass(TupleSimplifier(), module).status());
+ VLOG(1) << "after tuple simplifier:\n" << module->ToString();
+
+ TF_ASSIGN_OR_RETURN(bool changed,
+ RunHloPass(CudnnSimplifyPadding(), module));
+ VLOG(1) << "after simplify_padding:\n" << module->ToString();
+
+ {
+ // reshape-mover expects to be run alongside algsimp.
+ HloPassFix<HloPassPipeline> pipeline("reshape-mover and algsimp");
+ pipeline.AddPass<ReshapeMover>();
+ pipeline.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions());
+ TF_RETURN_IF_ERROR(RunHloPass(pipeline, module).status());
+ }
+ VLOG(1) << "after reshape mover + algsimp:\n" << module->ToString();
+
+ return changed;
+ }
+
+ absl::StatusOr<bool> RunJustThisPass(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(bool changed,
+ RunHloPass(CudnnSimplifyPadding(), module));
+ VLOG(1) << "after simplify_padding:\n" << module->ToString();
+
+ // I know the name says "just this pass", but you really want algsimp too,
+ // otherwise the resulting patterns are ugly/hard to match.
+ TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
+ AlgebraicSimplifierOptions()),
+ module)
+ .status());
+ return changed;
+ }
+};
+
+void ExpectOnlyPadsOneDim(int64_t dim, int64_t padding_high,
+ const PaddingConfig& p) {
+ SCOPED_TRACE(p.DebugString());
+ for (int i = 0; i < p.dimensions_size(); ++i) {
+ SCOPED_TRACE(absl::StrCat("dimension ", i));
+ EXPECT_EQ(p.dimensions(i).edge_padding_low(), 0);
+ if (i == dim) {
+ EXPECT_EQ(p.dimensions(i).edge_padding_high(), padding_high);
+ } else {
+ EXPECT_EQ(p.dimensions(i).edge_padding_high(), 0);
+ }
+ }
+}
+
+template <typename NativeT>
+void SetConstantValue(
+ HloInstruction* instr,
+ absl::FunctionRef<NativeT(absl::Span<const int64_t>, NativeT)> value_fn) {
+ Literal new_literal = instr->literal().Clone();
+ new_literal.MutableEachCell<int8_t>(value_fn);
+ TF_EXPECT_OK(instr->parent()->ReplaceWithNewInstruction(
+ instr, HloInstruction::CreateConstant(std::move(new_literal))));
+}
+
+TEST_F(CudnnSimplifyPaddingTest, EndToEnd) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ conv1 = (s8[10,20,30,190], u8[0]) custom-call(
+ s8[10,20,30,63] parameter(0), s8[3,5,63,190] parameter(1),
+ f32[10] parameter(2), s8[10,20,30,190] parameter(3)),
+ window={size=3x5}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBiasActivationForward"
+ conv1_result = get-tuple-element(conv1), index=0
+ ROOT conv2 = (s8[10,20,30,29], u8[0]) custom-call(
+ conv1_result, s8[3,5,190,29] parameter(4),
+ f32[10] parameter(5), s8[10,20,30,29] parameter(6)),
+ window={size=3x5}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBiasActivationForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ // conv2 should be fed directly from conv1, without any intervening
+ // reshapes/pads.
+ EXPECT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Slice(m::Reshape(m::GetTupleElement(m::CustomCall(
+ {"__cudnn$convBiasActivationForward"},
+ m::GetTupleElement(
+ m::CustomCall({"__cudnn$convBiasActivationForward"}), 0),
+ m::Op(), m::Op(), m::Op())))),
+ m::Op())));
+}
+
+TEST_F(CudnnSimplifyPaddingTest, EndToEndNCHW) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ conv1 = (s8[1,64,480,400], u8[0]) custom-call(
+ s8[1,112,480,400] parameter(0), s8[3,3,112,64] parameter(1),
+ f32[64] parameter(2)),
+ window={size=3x3}, dim_labels=bf01_01io->bf01,
+ custom_call_target="__cudnn$convBiasActivationForward"
+ conv1_result = get-tuple-element(conv1), index=0
+ convert = f32[1,64,480,400] convert(conv1_result)
+ constant = f32[] constant(0.349002093)
+ broadcast = f32[1,64,480,400] broadcast(constant)
+ ROOT multiply = f32[1,64,480,400] multiply(convert, broadcast)
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+ // The SimplifyPadding pass itself does not do anything.
+ EXPECT_FALSE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ // The reshape introduced by CudnnVectorizeConvolutions should have been moved
+ // to the root.
+ EXPECT_THAT(root, GmockMatch(m::Reshape(m::Multiply())));
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedWeights) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+ const HloInstruction* pad = nullptr;
+ ASSERT_THAT(root,
+ GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
+ m::ConstantScalar(0))));
+
+ ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
+}
+
+// This is similar to PaddedWeights, except that only 3 elements of the weights
+// are padded to 0 while we slice off 4 elements from the output features. As a
+// result, not all of the sliced elements are 0, and we can't merge the slice
+// into the pad that follows.
+TEST_F(CudnnSimplifyPaddingTest, PaddedWeightsNotPaddedEnough) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_3
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNCHW) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+ weights = s8[2,32,64,3,3] reshape(weights_p)
+ conv = (s8[10,2,32,10,10], u8[0]) custom-call(
+ s8[10,2,32,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=bf?01_i?o01->bf?01,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_5x0_0x0_0
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+ const HloInstruction* pad = nullptr;
+ ASSERT_THAT(
+ root, GmockMatch(
+ m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
+ m::ConstantScalar(0))));
+
+ ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/1, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNHWC) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights_p = pad(s8[3,3,64,60] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+ weights = s8[3,3,2,32,64] reshape(weights_p)
+ conv = (s8[10,10,10,2,32], u8[0]) custom-call(
+ s8[10,10,10,2,32] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f?_01i?o->b01f?,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,60] slice(s8[10,10,10,64] reshape(conv_result)), slice={[0:10], [0:10], [0:10], [0:60]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+ const HloInstruction* pad = nullptr;
+ ASSERT_THAT(
+ root, GmockMatch(
+ m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
+ m::ConstantScalar(0))));
+
+ ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedTransposedAndReshapedOutput) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+ weights = s8[2,32,64,3,3] reshape(weights_p)
+ conv = (s8[10,2,10,10,32], u8[0]) custom-call(
+ s8[10,2,10,10,32] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=bf01?_i?o01->bf01?,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
+ slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+ const HloInstruction* pad = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Pad(
+ &pad,
+ m::Reshape(m::Transpose(m::GetTupleElement(m::CustomCall(), 0))),
+ m::ConstantScalar(0))));
+
+ ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/2, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeight) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(0),
+ s8[3,3,10,10] constant({...})
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ // Set the constant's value. (The HLO text above sets it to all 0s.)
+ {
+ HloInstruction* weights = nullptr;
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
+ m::Op(), m::Constant(&weights)))),
+ m::Op())));
+ SetConstantValue<int8_t>(
+ weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
+ if (dims[3] < 6) return 1;
+ return 0;
+ });
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+ const HloInstruction* pad = nullptr;
+ ASSERT_THAT(root,
+ GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
+ m::ConstantScalar(0))));
+
+ ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeightIsNotLargeEnough) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(0),
+ s8[3,3,10,10] constant({...})
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ // Set the constant's value. (The HLO text above sets it to all 0s.)
+ {
+ HloInstruction* weights = nullptr;
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
+ m::Op(), m::Constant(&weights)))),
+ m::Op())));
+ SetConstantValue<int8_t>(
+ weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
+ // The sixth feature dimension (i.e. index 5) is only partially 0.
+ if (dims[3] < 5 /*|| (dims[3] == 5 && dims[2] > 1)*/) return 0;
+ return 1;
+ });
+ }
+
+  // Some of the values sliced off are not 0, so we can't merge the slice into
+ // the pad.
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, ReshapeDoesntMergeVectCDim) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+ weights = s8[2,64,3,3,32] reshape(weights_p)
+ conv = (s8[10,2,10,10,32], u8[0]) custom-call(
+ s8[10,2,10,10,32] parameter(1),
+ weights_p
+ ), window={size=3x3}, dim_labels=bf01?_io01?->bf01?,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInOutput) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+ weights = s8[2,64,3,3,32] reshape(weights_p)
+ conv = (s8[10,2,10,10,4,8], u8[0]) custom-call(
+ s8[10,2,10,10,32] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=bf01?_io01?->bf01??,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ conv_transposed = s8[10,2,4,8,10,10] transpose(conv_result), dimensions={0,1,4,5,2,3}
+ slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInKernel) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+ weights = s8[2,64,3,3,4,8] reshape(weights_p)
+ conv = (s8[10,2,10,10,32], u8[0]) custom-call(
+ s8[10,2,10,10,32] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=bf01?_io01??->bf01?,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
+ slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginning) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,9,10,6] slice(conv_result), slice={[0:10], [1:10], [0:10], [0:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginningOfFeatureDim) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,5] slice(conv_result), slice={[0:10], [0:10], [0:10], [1:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceHasStride) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,3] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6:2]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PadAddsInteriorPadding) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5_1
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceMoreElementsThanPad) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+ conv = (s8[10,10,10,10], u8[0]) custom-call(
+ s8[10,10,10,10] parameter(1),
+ weights
+ ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ conv_result = get-tuple-element(conv), index=0
+ slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+ ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_2
+ }
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+ const HloInstruction* slice = nullptr;
+ // The pass creates a pad with negative padding; this is simplified by algsimp
+ // into a slice.
+ ASSERT_THAT(root, GmockMatch(m::Slice(
+ &slice, m::GetTupleElement(m::CustomCall(), 0))));
+ for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
+ SCOPED_TRACE(i);
+ EXPECT_EQ(slice->slice_starts(i), 0);
+ EXPECT_EQ(slice->slice_strides(i), 1);
+ if (i != 3) {
+ EXPECT_EQ(slice->slice_limits(i), 10);
+ } else {
+ EXPECT_EQ(slice->slice_limits(i), 8);
+ }
+ }
+}
+
+TEST_F(CudnnSimplifyPaddingTest, NoChangeOnNonTrivialConstants) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule jit_outer
+
+ENTRY main.26 {
+ reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
+ constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
+ { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+ }, {
+ { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+ }, {
+ { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
+ cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
+ get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
+ slice.2 = f32[1,5,1,12]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:1], [0:12]}
+ constant.0 = f32[] constant(0)
+ ROOT pad.1 = f32[1,5,3,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x2_0x0_0
+}
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, NoChangeOnComplexSlices) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule jit_outer
+
+ENTRY main.26 {
+ reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
+ constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
+ { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+ }, {
+ { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+ }, {
+ { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
+ cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
+ get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
+ slice.2 = f32[1,5,5,4]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [2:6]}
+ constant.0 = f32[] constant(0)
+ ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_8
+}
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, ScanOrderFeatureDimLast) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule jit_outer
+
+ENTRY main.26 {
+ reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
+ constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
+ { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+ }, {
+ { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+ }, {
+ { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+ { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
+ cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
+ get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
+ slice.2 = f32[1,5,5,6]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [0:6]}
+ constant.0 = f32[] constant(0)
+ ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_6
+}
+ )")
+ .value();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputFirst) {
+ // Test feature dimension calculation from reordering transpose (oi01)
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
+ s8[1,112,80,80] parameter(0), s8[63,112,3,3] parameter(1)),
+ window={size=3x3}, dim_labels=bf01_oi01->bf01,
+ custom_call_target="__cudnn$convForward"
+ gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
+ const.0 = s8[] constant(0)
+ ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputLast) {
+ // Test feature dimension calculation from reordering transpose (01io)
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
+ s8[1,112,80,80] parameter(0), s8[3,3,112,63] parameter(1)),
+ window={size=3x3}, dim_labels=bf01_01io->bf01,
+ custom_call_target="__cudnn$convForward"
+ gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
+ const.0 = s8[] constant(0)
+ ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+}
+
+} // anonymous namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc
new file mode 100644
index 0000000..698b8fb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc
@@ -0,0 +1,650 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/cudnn_support_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Finds convolutions that this pass may be able to transform, namely int8_t
+// cudnn forward or forward-bias-activation convolutions.
+//
+// cudnn as of v8.2 supports the following data type combinations for forward
+// and forward-bias-activation convolutions. We have to make sure we only
+// vectorize to one of these supported configs.
+//
+// in out
+// int8x1 int8x1
+// int8x1 float
+// int8x1 int32_t
+//
+// int8x4 int8x4
+// int8x4 float
+//
+// int8x32 int8x32
+//
+// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward
+//
+// For now we restrict ourselves to only the int8xN -> int8xN cases. We could
+// allow the int8x4 -> float case in the future if desirable.
+static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
+ HloComputation* comp) {
+ std::vector<HloCustomCallInstruction*> convs;
+ for (HloInstruction* instr : comp->instructions()) {
+ if (instr->opcode() != HloOpcode::kCustomCall ||
+ (instr->custom_call_target() != kCudnnConvForwardCallTarget &&
+ instr->custom_call_target() !=
+ kCudnnConvBiasActivationForwardCallTarget) ||
+ instr->operand_count() < 2) {
+ continue;
+ }
+
+ PrimitiveType input_ty = instr->operand(0)->shape().element_type();
+ PrimitiveType output_ty = instr->shape().tuple_shapes(0).element_type();
+ if (input_ty == output_ty && (input_ty == S8 || input_ty == U8)) {
+ convs.push_back(Cast<HloCustomCallInstruction>(instr));
+ }
+ }
+ return convs;
+}
+
+// Converts an XlaBuilder into an HloComputation in the same module as
+// `sibling_computation`.
+//
+// Yes, we serialize/deserialize as a proto. :)
+static absl::StatusOr<HloComputation*> BuilderToHloComputation(
+ XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
+ TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
+ TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
+ HloModuleConfig config(program_shape);
+ TF_ASSIGN_OR_RETURN(auto new_module,
+ HloModule::CreateFromProto(comp.proto(), config));
+
+ HloModule* dest_module = sibling_computation->parent();
+ HloCloneContext context(dest_module);
+ return dest_module->DeepCloneComputation(new_module->entry_computation(),
+ &context);
+}
+
+// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
+// after `dim`.
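+//
+// For example, given `instr` of shape s8[10,32,20], dim=1, vect_size=4, this
+// returns a reshape to s8[10,8,4,20] (mirroring the SplitShapeAtDim example
+// below).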
+static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
+ XlaBuilder& b = *instr.builder();
+ Shape shape = b.GetShape(instr).value();
+ DimensionVector new_dims(shape.dimensions().begin(),
+ shape.dimensions().end());
+ CHECK_EQ(new_dims[dim] % vect_size, 0);
+ new_dims[dim] /= vect_size;
+ new_dims.insert(new_dims.begin() + dim + 1, vect_size);
+ return Reshape(instr, new_dims);
+}
+
+// Reshapes `shape` so that there's an extra dimension of size `vect_size` right
+// after `dim`.
+//
+// For example given shape=s8[10, 32, 20], dim=1, vect_size=4, returns
+// s8[10, 8, 4, 20].
+static Shape SplitShapeAtDim(Shape shape, int64_t dim, int64_t vect_size) {
+ DimensionVector new_dims(shape.dimensions().begin(),
+ shape.dimensions().end());
+ CHECK_EQ(new_dims[dim] % vect_size, 0);
+ new_dims[dim] /= vect_size;
+ new_dims.insert(new_dims.begin() + dim + 1, vect_size);
+ return ShapeUtil::MakeShape(shape.element_type(), new_dims);
+}
+
+// Transposes dimension `src` to right before `dst`.
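+//
+// For example, on a rank-4 op, MoveDim(instr, /*src=*/3, /*dst=*/1) emits a
+// transpose with permutation {0, 3, 1, 2}, i.e. dimension 3 ends up
+// immediately before what was dimension 1.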
+static XlaOp MoveDim(XlaOp instr, int64_t src, int64_t dst) {
+ XlaBuilder& b = *instr.builder();
+ int64_t rank = b.GetShape(instr)->dimensions_size();
+
+ DimensionVector idxs(rank);
+ absl::c_iota(idxs, 0);
+ if (src < dst) {
+ idxs.insert(idxs.begin() + dst, src);
+ idxs.erase(idxs.begin() + src);
+ } else {
+ idxs.erase(idxs.begin() + src);
+ idxs.insert(idxs.begin() + dst, src);
+ }
+ return Transpose(instr, idxs);
+}
+
+// Reshapes instr so that dimension `vect_dim` has size `vect_size`, by stealing
+// elements from `dim`.
+//
+// Requires that this is possible without merging and re-splitting the two
+// dimensions. I.e. there should be some amount of dim that we can "split off"
+// and add to vect_dim to get it to have size vect_size.
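+//
+// For example, given `instr` of shape s8[10,20,30,16,4] with dim=3,
+// vect_dim=4, and vect_size=32, this splits dim 3 into [2, 8], moves the 8
+// next to the existing vector dim, and collapses the two, yielding
+// s8[10,20,30,2,32].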
+static XlaOp RevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
+ int64_t vect_size) {
+ XlaBuilder& b = *instr.builder();
+ Shape shape = b.GetShape(instr).value();
+ auto size = [&](int64_t d) { return shape.dimensions(d); };
+
+ CHECK_LE(size(vect_dim), vect_size);
+ CHECK_EQ(vect_size % size(vect_dim), 0);
+
+ int64_t split_factor = vect_size / size(vect_dim);
+ CHECK_EQ(size(dim) % split_factor, 0);
+
+ // Split dim into [C, split_factor].
+ instr = SplitAtDim(instr, dim, split_factor);
+
+ // SplitAtDim may have added a dimension before vect_dim.
+ if (vect_dim > dim) {
+ vect_dim++;
+ }
+
+ // Move the split_factor dimension to right before vect_dim.
+ instr = MoveDim(instr, dim + 1, vect_dim);
+
+ // Moving the split_factor dimension may have *removed* a dimension before
+ // vect_dim.
+ if (vect_dim > dim) {
+ vect_dim--;
+ }
+
+ // Collapse the split_factor dimension into vect_dim.
+ return Collapse(instr, {vect_dim, vect_dim + 1});
+}
+
+// Inverse of RevectorizeInstr. Reshapes instr so that dimension `vect_dim` has
+// size `vect_size`, moving excess elements into `dim`.
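+//
+// For example, given `instr` of shape s8[10,20,30,2,32] with dim=3,
+// vect_dim=4, and orig_vect_size=4, this yields s8[10,20,30,16,4] -- the
+// inverse of the RevectorizeInstr example above.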
+static XlaOp UnrevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
+ int64_t orig_vect_size) {
+ XlaBuilder& b = *instr.builder();
+ Shape shape = b.GetShape(instr).value();
+ auto size = [&](int64_t d) { return shape.dimensions(d); };
+
+ CHECK_GE(size(vect_dim), orig_vect_size);
+ CHECK_EQ(size(vect_dim) % orig_vect_size, 0);
+
+ // Split vect_dim into [C, orig_vect_size].
+ instr = SplitAtDim(instr, vect_dim, orig_vect_size);
+
+ // SplitAtDim may have added a dimension before dim.
+ if (dim > vect_dim) {
+ dim++;
+ }
+
+ // Move the `C` dimension to right after `dim`. Take into account that
+ // SplitAtDim may have added a dimension before dim.
+ instr = MoveDim(instr, vect_dim, dim + 1);
+
+ // MoveDim may have *removed* a dimension before dim.
+ if (dim > vect_dim) {
+ dim--;
+ }
+
+ // Collapse the `C` and `dim` dimensions.
+ return Collapse(instr, {dim, dim + 1});
+}
+
+// Adds a vectorized-feature dimension to dnums right after the current feature
+// dimension.
+//
+// ConvolutionDimensionNumbers doesn't represent the vectorized-feature
+// dimension explicitly, because the whole concept of a vectorized-feature
+// dimension is specific to cudnn. Rather, the vectorized-feature dimension is
+// implicit; it's the first dimension that *doesn't* appear in the dnums.
+//
+// This function "makes room" in dnums for the new vectorized dimension by
+// incrementing any dimensions which appear after the feature dim. The implicit
+// vector dim is then in this "empty" spot.
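+//
+// For example, for an NCHW input (dnums "bf01", feature dim 1), the spatial
+// dims 2 and 3 become 3 and 4, leaving slot 2 free for the implicit vector
+// dim -- "bf?01" in the notation used by the tests.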
+static ConvolutionDimensionNumbers VectorizeDnums(
+ ConvolutionDimensionNumbers dnums, bool reordered_filter) {
+ int64_t input_vect_dim = dnums.input_feature_dimension();
+ if (dnums.input_batch_dimension() > input_vect_dim) {
+ dnums.set_input_batch_dimension(dnums.input_batch_dimension() + 1);
+ }
+ for (int64_t& d : *dnums.mutable_input_spatial_dimensions()) {
+ if (d > input_vect_dim) {
+ ++d;
+ }
+ }
+
+ if (!reordered_filter) {
+ int64_t kernel_vect_dim = dnums.kernel_input_feature_dimension();
+ if (dnums.kernel_output_feature_dimension() > kernel_vect_dim) {
+ dnums.set_kernel_output_feature_dimension(
+ dnums.kernel_output_feature_dimension() + 1);
+ }
+ for (int64_t& d : *dnums.mutable_kernel_spatial_dimensions()) {
+ if (d > kernel_vect_dim) {
+ ++d;
+ }
+ }
+ }
+
+ int64_t output_vect_dim = dnums.output_feature_dimension();
+ if (dnums.output_batch_dimension() > output_vect_dim) {
+ dnums.set_output_batch_dimension(dnums.output_batch_dimension() + 1);
+ }
+ for (int64_t& d : *dnums.mutable_output_spatial_dimensions()) {
+ if (d > output_vect_dim) {
+ ++d;
+ }
+ }
+
+ return dnums;
+}
+
+// Reorders the convolution's filter and bias (if present) according to
+// cudnnReorderFilterAndBias. Also marks that the filter + bias are reordered
+// in the conv's backend-config.
+absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv,
+ XlaOp* operands) {
+ bool has_bias = conv->operand_count() > 2;
+ VLOG(1) << "Reordering filter" << (has_bias ? " and bias" : "")
+ << " (replacement for cudnnReorderFilterAndBias)";
+
+ auto builder = operands->builder();
+ ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
+
+ // Update convolution backend config.
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ conv->backend_config<GpuBackendConfig>());
+ CudnnConvBackendConfig& config =
+ *gpu_config.mutable_cudnn_conv_backend_config();
+ config.set_reordered_int8_nchw_vect(true);
+ TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+
+ // Reorder the filter.
+ TF_ASSIGN_OR_RETURN(Shape filter_shape, builder->GetShape(operands[1]));
+ TF_ASSIGN_OR_RETURN(auto reorder, CudnnInferTransposeForFilterReordering(
+ filter_shape, dnums));
+ XlaOp reshape = Reshape(reorder.transpose_shape, operands[1]);
+ XlaOp transpose = Transpose(reshape, reorder.permutation);
+ operands[1] = Reshape(reorder.result_shape, transpose);
+
+ // The reshape-transpose-reshape we did above makes sure the resulting filter
+ // has dimension numbers corresponding to "oihw?", so update them.
+ dnums.set_kernel_output_feature_dimension(0);
+ dnums.set_kernel_input_feature_dimension(1);
+ dnums.set_kernel_spatial_dimensions(0, 2);
+ dnums.set_kernel_spatial_dimensions(1, 3);
+ conv->set_convolution_dimension_numbers(dnums);
+
+ if (has_bias) {
+ // Reorder the bias.
+ TF_ASSIGN_OR_RETURN(Shape bias_shape, builder->GetShape(operands[2]));
+ TF_ASSIGN_OR_RETURN(reorder,
+ CudnnInferTransposeForBiasReordering(bias_shape));
+ reshape = Reshape(reorder.transpose_shape, operands[2]);
+ transpose = Transpose(reshape, reorder.permutation);
+ operands[2] = Reshape(reorder.result_shape, transpose);
+ }
+ return absl::OkStatus();
+}
+
+// Tries to vectorize an already-vectorized convolution.
+//
+// That is, given a convolution of shape [N, C/k, H, W, k], changes it to have
+// shape [N, C/vect_size, H, W, vect_size]. Similarly changes the filter from
+// [H, W, I/k, k, O] to [H, W, I/vect_size, vect_size, O].
+//
+// (The dimensions can appear in any order; which one is N, C, etc. is
+// determined by the convolution's dnums.)
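+//
+// For example, with vect_size=32 an s8[10,20,30,16,4] input (NHWC with a
+// trailing vector dim of 4) becomes s8[10,20,30,2,32]; see the Vectorize4To32
+// test.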
+static absl::StatusOr<bool> TryRevectorizeConv(
+ const se::CudaComputeCapability& compute_capability,
+ const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
+ int vect_size) {
+ const Shape& input_shape = conv->operand(0)->shape();
+ const Shape& kernel_shape = conv->operand(1)->shape();
+ const Shape& output_shape = conv->shape().tuple_shapes(0);
+ const ConvolutionDimensionNumbers* dnums =
+ &conv->convolution_dimension_numbers();
+
+ // Find the vectorized-features dim in the input/kernel/output.
+ std::optional<int64_t> input_vect_dim;
+ std::optional<int64_t> kernel_vect_dim;
+ std::optional<int64_t> output_vect_dim;
+ std::tie(input_vect_dim, kernel_vect_dim, output_vect_dim) =
+ FindVectorizedFeatureDims(*dnums, input_shape, kernel_shape,
+ output_shape);
+
+ if (!input_vect_dim.has_value() || !kernel_vect_dim.has_value() ||
+ !output_vect_dim.has_value()) {
+ return false;
+ }
+
+ int64_t input_feat_size =
+ input_shape.dimensions(dnums->input_feature_dimension());
+ int64_t output_feat_size =
+ output_shape.dimensions(dnums->output_feature_dimension());
+ int64_t input_vect_size = input_shape.dimensions(*input_vect_dim);
+ int64_t output_vect_size = output_shape.dimensions(*output_vect_dim);
+ if (vect_size % input_vect_size != 0 || vect_size % output_vect_size != 0 ||
+ input_feat_size % (vect_size / input_vect_size) != 0 ||
+ output_feat_size % (vect_size / output_vect_size) != 0) {
+ return false;
+ }
+
+  // If this is an integer convolution, check that we only vectorize when cuDNN
+ // supports the vectorized implementation.
+ if (primitive_util::IsIntegralType(input_shape.element_type())) {
+ TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
+ CudnnSupportsOptimizedIntegerConvolution(
+ compute_capability, *conv, vect_size));
+ if (!supported_target_vectorization) {
+ VLOG(3) << "Skipping re-vectorization of conv to vector size: "
+ << vect_size << ": " << conv->ToString();
+ return false;
+ }
+ }
+
+ VLOG(1) << "Re-vectorizing conv channels from "
+ << input_shape.dimensions(*input_vect_dim) << " to " << vect_size
+ << ": " << conv->ToString();
+
+ // We use XlaBuilder because it's a lot easier to get these tricky
+ // reshape/transposes correct using that API.
+ XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
+ b.SetOpMetadata(conv->metadata());
+
+ XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
+ absl::InlinedVector<XlaOp, 4> new_operands = {
+ RevectorizeInstr(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
+ dnums->input_feature_dimension(), *input_vect_dim,
+ vect_size),
+ RevectorizeInstr(filter, dnums->kernel_input_feature_dimension(),
+ *kernel_vect_dim, vect_size),
+ };
+ if (conv->operand_count() > 2) {
+ // Bias, if present. This is passed through unmodified.
+ new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
+ }
+ if (conv->operand_count() > 3) {
+ new_operands.push_back(RevectorizeInstr(
+ Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
+ dnums->input_feature_dimension(), *input_vect_dim, vect_size));
+ }
+
+ if (conv->operand_count() > 4) {
+ return InvalidArgument(
+ "Don't understand a conv with more than 4 arguments: %s",
+ conv->ToString());
+ }
+
+ // Reorder filter and bias for the int8x32 convolutions. This requires cudnn
+ // >= 8.3.0.
+ //
+ // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
+ const auto& debug_options = conv->GetModule()->config().debug_options();
+ bool use_reordering =
+ input_shape.element_type() == xla::S8 && vect_size == 32 &&
+ debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
+ cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
+ if (use_reordering) {
+ // Reordering helper supports vector sizes of 4 and 32, so an additional
+ // reshape-transpose-reshape is not necessary in these cases.
+ int64_t kernel_vect_size = kernel_shape.dimensions(*kernel_vect_dim);
+ if (kernel_vect_size == 4 || kernel_vect_size == 32) {
+ new_operands[1] = filter;
+ }
+ TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
+ dnums = &conv->convolution_dimension_numbers();
+ }
+
+ // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
+ // value in the tuple represents the convolution's scratch memory.
+ DimensionVector new_output_dims(output_shape.dimensions().begin(),
+ output_shape.dimensions().end());
+ new_output_dims[dnums->output_feature_dimension()] /=
+ (vect_size / output_vect_size);
+ new_output_dims[*output_vect_dim] = vect_size;
+ XlaOp new_conv = CustomCallWithConvDnums(
+ &b, conv->custom_call_target(), new_operands,
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(output_shape.element_type(), new_output_dims),
+ ShapeUtil::MakeShape(U8, {0})}),
+ /*operand_shapes_with_layout=*/{},
+ /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
+ /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
+ /*window=*/conv->window(),
+ /*dnums=*/*dnums);
+
+ XlaOp new_conv_result = GetTupleElement(new_conv, 0);
+ XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
+
+ XlaOp new_conv_result_unrevectorized = UnrevectorizeInstr(
+ new_conv_result, dnums->output_feature_dimension(), *output_vect_dim,
+ /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));
+
+ TF_ASSIGN_OR_RETURN(
+ HloComputation * new_conv_comp,
+ BuilderToHloComputation(
+ b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
+ conv->parent()));
+
+ // Set the name on the new conv. This is purely cosmetic, but we attempt to
+ // preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
+ auto new_conv_comp_instrs = new_conv_comp->instructions();
+ auto new_conv_it =
+ absl::c_find_if(new_conv_comp_instrs, [](HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kCustomCall;
+ });
+ if (new_conv_it != new_conv_comp_instrs.end()) {
+ new_conv_comp->parent()->SetAndUniquifyInstrName(*new_conv_it,
+ conv->name());
+ }
+
+ // Replace the old conv with a call to the computation we just created.
+ VLOG(1) << "Re-vectorized conv to " << new_conv_comp->ToString();
+ TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
+ conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
+ new_conv_comp)));
+
+ return true;
+}
+
+// Tries to vectorize a convolution.
+//
+// Given a convolution of dimensions [N, C, H, W], tries to convert it to have
+// shape [N, C/vect_size, H, W, vect_size]. Similarly, given a kernel of shape
+// [H, W, I, O], tries to convert it to [H, W, I/vect_size, vect_size, O].
+//
+// This requires that C be a multiple of vect_size. CudnnPadForConvolutions can
+// add padding to make this true.
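+//
+// For example, with vect_size=4 an s8[10,20,30,40] input becomes
+// s8[10,20,30,10,4] and an s8[2,2,40,44] filter becomes s8[2,2,10,4,44]; see
+// the VectorizeTo4 test.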
+static absl::StatusOr<bool> TryVectorizeConv(
+ const se::CudaComputeCapability& compute_capability,
+ const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
+ int64_t vect_size) {
+ const Shape& input_shape = conv->operand(0)->shape();
+ const Shape& output_shape = conv->shape().tuple_shapes(0);
+ const ConvolutionDimensionNumbers* dnums =
+ &conv->convolution_dimension_numbers();
+ int64_t in_channels =
+ input_shape.dimensions(dnums->input_feature_dimension());
+ int64_t out_channels =
+ output_shape.dimensions(dnums->output_feature_dimension());
+
+ if (in_channels % vect_size != 0 || out_channels % vect_size != 0) {
+ return false;
+ }
+
+ if (input_shape.dimensions_size() >
+ 2 + dnums->input_spatial_dimensions_size()) {
+ // Conv already has an extra dimension, which we assume is the vectorized
+ // features dim.
+ return false;
+ }
+
+  // If this is an integer convolution, check that we only vectorize when cuDNN
+ // supports the vectorized implementation.
+ if (primitive_util::IsIntegralType(input_shape.element_type())) {
+ TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
+ CudnnSupportsOptimizedIntegerConvolution(
+ compute_capability, *conv, vect_size));
+ if (!supported_target_vectorization) {
+ VLOG(3) << "Skipping vectorization of conv to vector size: " << vect_size
+ << ": " << conv->ToString();
+ return false;
+ }
+ }
+
+ VLOG(1) << "Vectorizing conv channels by " << vect_size << ": "
+ << conv->ToString();
+
+ // We use XlaBuilder because it's a lot easier to get these tricky
+ // reshape/transposes correct using that API.
+ XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
+ b.SetOpMetadata(conv->metadata());
+
+ XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
+ absl::InlinedVector<XlaOp, 4> new_operands = {
+ SplitAtDim(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
+ dnums->input_feature_dimension(), vect_size),
+ SplitAtDim(filter, dnums->kernel_input_feature_dimension(), vect_size),
+ };
+ if (conv->operand_count() > 2) {
+ // Bias, if present. This is passed through unmodified.
+ new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
+ }
+ if (conv->operand_count() > 3) {
+    // Handle side input, which has the same shape as the output.
+ new_operands.push_back(
+ SplitAtDim(Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
+ dnums->output_feature_dimension(), vect_size));
+ }
+ if (conv->operand_count() > 4) {
+ return InvalidArgument(
+ "Don't understand a conv with more than 4 arguments: %s",
+ conv->ToString());
+ }
+
+ // Reorder filter and bias for the int8x32 convolutions. This requires cudnn
+ // >= 8.3.0.
+ //
+ // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
+ const auto& debug_options = conv->GetModule()->config().debug_options();
+ bool use_reordering =
+ input_shape.element_type() == xla::S8 && vect_size == 32 &&
+ debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
+ cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
+ if (use_reordering) {
+ new_operands[1] = filter;
+ TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
+ dnums = &conv->convolution_dimension_numbers();
+ }
+
+ // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
+ // value in the tuple represents the convolution's scratch memory.
+ Shape new_output_shape = SplitShapeAtDim(
+ output_shape, dnums->output_feature_dimension(), vect_size);
+ XlaOp new_conv = CustomCallWithConvDnums(
+ &b, conv->custom_call_target(), new_operands,
+ ShapeUtil::MakeTupleShape(
+ {new_output_shape, ShapeUtil::MakeShape(U8, {0})}),
+ /*operand_shapes_with_layout=*/{},
+ /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
+ /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
+ /*window=*/conv->window(),
+ /*dnums=*/VectorizeDnums(*dnums, use_reordering));
+
+ XlaOp new_conv_result = GetTupleElement(new_conv, 0);
+ XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
+
+ // Reshape back to the original shape.
+ XlaOp conv_result_collapsed =
+ Collapse(new_conv_result, {dnums->output_feature_dimension(),
+ dnums->output_feature_dimension() + 1});
+
+ TF_ASSIGN_OR_RETURN(
+ HloComputation * new_conv_comp,
+ BuilderToHloComputation(
+ b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
+ conv->parent()));
+
+  // Replace the old conv with a call to the computation we just created.
+ VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
+ TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
+ conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
+ new_conv_comp)));
+ return true;
+}
+
+} // namespace
+
+absl::StatusOr<bool> CudnnVectorizeConvolutions::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
+ // Try to (re)vectorize to int8x32 if this is an sm75+ GPU. If we can't,
+ // fall back to int8x4.
+ bool local_changed = false;
+ if (compute_capability_.IsAtLeast(7, 5)) {
+ TF_ASSIGN_OR_RETURN(
+ local_changed,
+ TryRevectorizeConv(compute_capability_, cudnn_version_, conv, 32));
+ if (!local_changed) {
+ TF_ASSIGN_OR_RETURN(
+ local_changed,
+ TryVectorizeConv(compute_capability_, cudnn_version_, conv, 32));
+ }
+ }
+ if (!local_changed) {
+ TF_ASSIGN_OR_RETURN(
+ local_changed,
+ TryVectorizeConv(compute_capability_, cudnn_version_, conv, 4));
+ }
+ changed |= local_changed;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h
new file mode 100644
index 0000000..6fd2f6e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h
@@ -0,0 +1,73 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+// Changes the shape of cudnn convolutions to allow faster "vectorized"
+// algorithms.
+//
+// On sm61+ will convert int8_t convolutions from
+//
+// - [N, C, H, W] to [N, C/4, H, W, 4],
+//
+// assuming C is divisible by 4.
+//
+// On sm75+ will convert int8_t convolutions from
+//
+// - [N, C, H, W] to [N, C/32, H, W, 32],
+// - [N, C/4, H, W, 4] to [N, C/32, H, W, 32], and
+// - [N, C, H, W] to [N, C/4, H, W, 4] (same as sm61+),
+//
+// assuming C is divisible by 4 or 32.
+//
+// This pass will not pad the channel dim to a multiple of 4 or 32, so you
+// should run CudnnPadForConvolutions before this.
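+//
+// Typical construction (mirroring the unit tests, which target sm75 with
+// cuDNN 8.3.0):
+//
+//   CudnnVectorizeConvolutions pass(se::CudaComputeCapability{7, 5},
+//                                   se::dnn::VersionInfo(8, 3, 0));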
+class CudnnVectorizeConvolutions : public HloModulePass {
+ public:
+ explicit CudnnVectorizeConvolutions(
+ se::CudaComputeCapability compute_capability,
+ se::dnn::VersionInfo cudnn_version)
+ : compute_capability_(compute_capability),
+ cudnn_version_(cudnn_version) {}
+
+ absl::string_view name() const override {
+ return "cudnn_vectorize_convolutions";
+ }
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const se::CudaComputeCapability compute_capability_;
+ const se::dnn::VersionInfo cudnn_version_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc
new file mode 100644
index 0000000..7528870
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc
@@ -0,0 +1,758 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/status/statusor.h"
+#include "xla/service/call_inliner.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class CudnnVectorizeConvolutionsTest : public HloTestBase {
+ protected:
+ // Runs this pass and some cleanup to make pattern-matching easier.
+ absl::StatusOr<bool> Run(std::pair<int, int> compute_capability,
+ HloModule* module) {
+ CudnnVectorizeConvolutions pass(
+ se::CudaComputeCapability{compute_capability.first,
+ compute_capability.second},
+ se::dnn::VersionInfo(8, 3, 0));
+ TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module));
+
+ CallInliner inliner;
+ TF_RETURN_IF_ERROR(RunHloPass(&inliner, module).status());
+
+ return changed;
+ }
+};
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,40] parameter(0)
+ filter = s8[2,2,40,44] parameter(1)
+ ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward",
+ backend_config="{bar: 0}"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 10, 4}),
+ m::Reshape(m::Parameter(1))
+ .WithShape(S8, {2, 2, 10, 4, 44}))
+ .WithConvDnums("b01f?_01i?o->b01f?"))
+ .WithShape(S8, {10, 20, 30, 11, 4})),
+ m::Op())));
+
+ EXPECT_EQ(conv->raw_backend_config_string(), "{bar: 0}");
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4UnsupportedFilterType) {
+ // This test checks that the vectorize pass correctly calls
+  // CudnnSupportsOptimizedIntegerConvolution(), which should reject this
+ // convolution because its filter type is f32.
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,40] parameter(0)
+ filter = f32[2,2,40,44] parameter(1)
+ ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward",
+ backend_config="{bar: 0}"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4NCHW) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,48,20,30] parameter(0)
+ filter = s8[48,44,2,2] parameter(1)
+ ROOT result = (s8[10,44,20,30], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=bf01_io01->bf01,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 12, 4, 20, 30}),
+ m::Reshape(m::Parameter(1))
+ .WithShape(S8, {12, 4, 44, 2, 2}))
+ .WithConvDnums("bf?01_i?o01->bf?01"))
+ .WithShape(S8, {10, 11, 4, 20, 30})),
+ m::Op())));
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, IncrementAllDnums) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[16,16,16,16] parameter(0)
+ filter = s8[16,16,3,3] parameter(1)
+ ROOT result = (s8[16,16,16,16], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=fb01_i01o->fb01,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {4, 4, 16, 16, 16}),
+ m::Reshape(m::Parameter(1))
+ .WithShape(S8, {4, 4, 16, 3, 3}))
+ .WithConvDnums("f?b01_i?01o->f?b01"))
+ .WithShape(S8, {4, 4, 16, 16, 16})),
+ m::Op())));
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, FilterDnums) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[1,20,9,9] parameter(0)
+ filter = s8[3,3,20,32] parameter(1)
+ ROOT result = (s8[1,32,9,9], u8[0]) custom-call(s8[1,20,9,9] input, s8[3,3,20,32] filter),
+ window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {1, 5, 4, 9, 9}),
+ m::Reshape(m::Parameter(1))
+ .WithShape(S8, {3, 3, 5, 4, 32}))
+ .WithConvDnums("bf?01_01i?o->bf?01"))
+ .WithShape(S8, {1, 8, 4, 9, 9})),
+ m::Op())));
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,41] parameter(0)
+ filter = s8[2,2,41,44] parameter(1)
+ ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ CudnnVectorizeConvolutions pass(
+ /*compute_capability=*/{7, 5},
+ /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+
+ SCOPED_TRACE(module->ToString());
+ EXPECT_FALSE(changed);
+}
+
+// Don't vectorize int8_t -> int32_t into int8x4 or int8x32; this is not
+// supported in cudnn.
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsS32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,41] parameter(0)
+ filter = s8[2,2,41,44] parameter(1)
+ ROOT result = (s32[10,20,30,44], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ SCOPED_TRACE(module->ToString());
+ EXPECT_FALSE(changed);
+}
+
+// Don't vectorize int8_t -> float into int8x4 or int8x32. Vectorizing to
+// int8x4 *is* allowed by cudnn, but we don't do it at the moment.
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsF32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,41] parameter(0)
+ filter = s8[2,2,41,44] parameter(1)
+ ROOT result = (f32[10,20,30,44], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ SCOPED_TRACE(module->ToString());
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,64] parameter(0)
+ filter = s8[2,2,64,128] parameter(1)
+ ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 2, 32}),
+ m::Reshape(
+ m::Transpose(
+ m::Reshape(m::Parameter(1))
+ .WithShape(S8, {2, 2, 2, 8, 4, 16, 4, 2}))
+ .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
+ .WithPredicate([](const HloInstruction* instr) {
+ return absl::c_equal(
+ instr->dimensions(),
+ std::vector<int64_t>{2, 0, 1, 5, 7, 3, 6,
+ 4});
+ }))
+ .WithShape(S8, {128, 2, 2, 2, 32})))
+ .WithShape(S8, {10, 20, 30, 4, 32})),
+ m::Op())));
+
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,64] parameter(0)
+ filter = s8[2,2,64,128] parameter(1)
+ bias = f32[128] parameter(2)
+ side_input = s8[10,20,30,64] parameter(3)
+
+ ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter, bias, side_input),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 2, 32}),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
+ .WithShape(S8, {128, 2, 2, 2, 32}),
+ m::Reshape(
+ m::Transpose(m::Reshape(m::Parameter(2))
+ .WithShape(F32, {4, 4, 2, 4}))
+ .WithShape(F32, {4, 2, 4, 4})
+ .WithPredicate([](const HloInstruction* instr) {
+ return absl::c_equal(
+ instr->dimensions(),
+ std::vector<int64_t>{0, 2, 1, 3});
+ }))
+ .WithShape(F32, {128}),
+ m::Reshape(m::Parameter(3))
+ .WithShape(S8, {10, 20, 30, 2, 32})))
+ .WithShape(S8, {10, 20, 30, 4, 32})),
+ m::Op())));
+
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, InputNHWC_OutputNCHW) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,64] parameter(0)
+ filter = s8[2,2,64,128] parameter(1)
+ bias = f32[128] parameter(2)
+ side_input = s8[10,128,20,30] parameter(3)
+
+ ROOT result = (s8[10,128,20,30], u8[0]) custom-call(input, filter, bias, side_input),
+ window={size=2x2}, dim_labels=b01f_01io->bf01,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 2, 32}),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
+ .WithShape(S8, {128, 2, 2, 2, 32}),
+ m::Reshape(
+ m::Transpose(m::Reshape(m::Parameter(2))
+ .WithShape(F32, {4, 4, 2, 4}))
+ .WithShape(F32, {4, 2, 4, 4})
+ .WithPredicate([](const HloInstruction* instr) {
+ return absl::c_equal(
+ instr->dimensions(),
+ std::vector<int64_t>{0, 2, 1, 3});
+ }))
+ .WithShape(F32, {128}),
+ m::Reshape(m::Parameter(3))
+ .WithShape(S8, {10, 4, 32, 20, 30})))
+ .WithShape(S8, {10, 4, 32, 20, 30})),
+ m::Op())));
+
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,64] parameter(0)
+ filter = s8[2,2,64,128] parameter(1)
+ ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ ASSERT_THAT(
+ root,
+ GmockMatch(m::Tuple(
+ m::Reshape(m::GetTupleElement(
+ m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 16, 4}),
+ m::Reshape(m::Parameter(1))
+ .WithShape(S8, {2, 2, 16, 4, 128})))
+ .WithShape(S8, {10, 20, 30, 32, 4})),
+ m::Op())));
+
+ EXPECT_FALSE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,16,4] parameter(0)
+ filter = s8[3,5,16,192,4] parameter(1)
+ bias = f32[64] parameter(2)
+ side_input = s8[10,20,30,16,4] parameter(3)
+ ROOT result = (s8[10,20,30,48,4], u8[0]) custom-call(input, filter, bias, side_input),
+ window={size=3x5}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ auto conv_pat =
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+ .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+ .WithShape(S8, {10, 20, 30, 2, 32}),
+ m::Reshape(
+ m::Transpose(m::Reshape(m::Parameter(1))
+ .WithShape(S8, {3, 5, 2, 8, 24, 4, 2, 4}))
+ .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
+ .WithPredicate([](const HloInstruction* instr) {
+ return absl::c_equal(
+ instr->dimensions(),
+ std::vector<int64_t>{2, 0, 1, 4, 6, 3, 5, 7});
+ }))
+ .WithShape(S8, {192, 2, 3, 5, 32}),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
+ .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+ .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+ .WithShape(S8, {10, 20, 30, 2, 32}))
+ .WithConvDnums("b01f?_oi01?->b01f?"))
+ .WithShape(S8, {10, 20, 30, 6, 32});
+ ASSERT_THAT(root, GmockMatch(m::Tuple(
+ m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+ S8, {10, 20, 30, 6, 8, 4}))
+ .WithShape(S8, {10, 20, 30, 6, 8, 4}))
+ .WithShape(S8, {10, 20, 30, 48, 4}),
+ m::Op())));
+
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,16,20,30,4] parameter(0)
+ filter = s8[16,128,2,2,4] parameter(1)
+ bias = f32[64] parameter(2)
+ side_input = s8[10,16,20,30,4] parameter(3)
+ ROOT result = (s8[10,32,20,30,4], u8[0]) custom-call(input, filter, bias, side_input),
+ window={size=2x2}, dim_labels=bf01_io01->bf01,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ auto conv_pat =
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 2, 8, 20, 30, 4}))
+ .WithShape(S8, {10, 2, 20, 30, 8, 4}))
+ .WithShape(S8, {10, 2, 20, 30, 32}),
+ m::Reshape(
+ m::Transpose(m::Reshape(m::Parameter(1))
+ .WithShape(S8, {2, 8, 16, 4, 2, 2, 2, 4}))
+ .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
+ .WithPredicate([](const HloInstruction* instr) {
+ return absl::c_equal(
+ instr->dimensions(),
+ std::vector<int64_t>{0, 5, 6, 2, 4, 1, 3, 7});
+ }))
+ .WithShape(S8, {128, 2, 2, 2, 32}),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
+ .WithShape(S8, {10, 2, 8, 20, 30, 4}))
+ .WithShape(S8, {10, 2, 20, 30, 8, 4}))
+ .WithShape(S8, {10, 2, 20, 30, 32}))
+ .WithConvDnums("bf01_oi01->bf01"))
+ .WithShape(S8, {10, 4, 20, 30, 32});
+ ASSERT_THAT(root, GmockMatch(m::Tuple(
+ m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+ S8, {10, 4, 20, 30, 8, 4}))
+ .WithShape(S8, {10, 4, 8, 20, 30, 4}))
+ .WithShape(S8, {10, 32, 20, 30, 4}),
+ m::Op())));
+
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[4,10,20,30,16] parameter(0)
+ filter = s8[4,3,5,16,192] parameter(1)
+ bias = f32[64] parameter(2)
+ side_input = s8[4,10,20,30,16] parameter(3)
+ ROOT result = (s8[4,10,20,30,48], u8[0]) custom-call(input, filter, bias, side_input),
+ window={size=3x5}, dim_labels=?b01f_?01io->?b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ auto conv_pat =
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+ .WithShape(S8, {4, 10, 20, 30, 2, 8}))
+ .WithShape(S8, {8, 4, 10, 20, 30, 2}))
+ .WithShape(S8, {32, 10, 20, 30, 2}),
+ m::Reshape(
+ m::Transpose(m::Reshape(m::Parameter(1))
+ .WithShape(S8, {4, 3, 5, 2, 8, 24, 4, 2}))
+ .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
+ .WithPredicate([](const HloInstruction* instr) {
+ return absl::c_equal(
+ instr->dimensions(),
+ std::vector<int64_t>{3, 1, 2, 5, 7, 4, 6, 0});
+ }))
+ .WithShape(S8, {192, 2, 3, 5, 32}),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
+ .WithShape(S8, {4, 10, 20, 30, 2, 8}))
+ .WithShape(S8, {8, 4, 10, 20, 30, 2}))
+ .WithShape(S8, {32, 10, 20, 30, 2}))
+ .WithConvDnums("?b01f_oi01->?b01f"))
+ .WithShape(S8, {32, 10, 20, 30, 6});
+ ASSERT_THAT(root, GmockMatch(m::Tuple(
+ m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+ S8, {8, 4, 10, 20, 30, 6}))
+ .WithShape(S8, {4, 10, 20, 30, 6, 8}))
+ .WithShape(S8, {4, 10, 20, 30, 48}),
+ m::Op())));
+
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorize4To32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,16,4] parameter(0)
+ filter = s8[2,2,16,128,4] parameter(1)
+ bias = f32[10] parameter(2)
+ side_input = s8[10,20,30,16,4] parameter(3)
+ ROOT result = (s8[10,20,30,32,4], u8[0]) custom-call(input, filter, bias, side_input),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize16To32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,4,16] parameter(0)
+ filter = s8[3,5,4,192,16] parameter(1)
+ ROOT result = (s8[10,20,30,12,16], u8[0]) custom-call(input, filter),
+ window={size=3x5}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ auto filter_pat =
+ m::Reshape(
+ m::Transpose(
+ m::Reshape(m::Parameter(1)).WithShape(S8, {3, 5, 2, 2, 192, 16}))
+ .WithShape(S8, {3, 5, 2, 192, 2, 16}))
+ .WithShape(S8, {3, 5, 2, 192, 32});
+ auto conv_pat =
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(
+ m::Transpose(m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 2, 2, 16}))
+ .WithShape(S8, {10, 20, 30, 2, 2, 16}))
+ .WithShape(S8, {10, 20, 30, 2, 32}),
+ m::Reshape(
+ m::Transpose(m::Reshape(filter_pat)
+ .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
+ .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
+ .WithShape(S8, {192, 2, 3, 5, 32}))
+ .WithConvDnums("b01f_oi01->b01f"))
+ .WithShape(S8, {10, 20, 30, 6, 32});
+ ASSERT_THAT(root, GmockMatch(m::Tuple(
+ m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+ S8, {10, 20, 30, 6, 2, 16}))
+ .WithShape(S8, {10, 20, 30, 6, 2, 16}))
+ .WithShape(S8, {10, 20, 30, 12, 16}),
+ m::Op())));
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeMixedTo32) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = s8[10,20,30,8,8] parameter(0)
+ filter = s8[3,5,2,192,32] parameter(1)
+ ROOT result = (s8[10,20,30,96,2], u8[0]) custom-call(input, filter),
+ window={size=3x5}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })")
+ .value();
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+ EXPECT_TRUE(changed);
+
+ SCOPED_TRACE(module->ToString());
+ auto* root = module->entry_computation()->root_instruction();
+
+ const HloInstruction* conv = nullptr;
+ auto conv_pat =
+ m::GetTupleElement(
+ m::CustomCall(
+ &conv, {kCudnnConvForwardCallTarget},
+ m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+ .WithShape(S8, {10, 20, 30, 2, 4, 8}))
+ .WithShape(S8, {10, 20, 30, 2, 4, 8}))
+ .WithShape(S8, {10, 20, 30, 2, 32}),
+ m::Reshape(
+ m::Transpose(m::Reshape(m::Parameter(1))
+ .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
+ .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
+ .WithShape(S8, {192, 2, 3, 5, 32}))
+ .WithConvDnums("b01f_oi01->b01f"))
+ .WithShape(S8, {10, 20, 30, 6, 32});
+ ASSERT_THAT(root, GmockMatch(m::Tuple(
+ m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+ S8, {10, 20, 30, 6, 16, 2}))
+ .WithShape(S8, {10, 20, 30, 6, 16, 2}))
+ .WithShape(S8, {10, 20, 30, 96, 2}),
+ m::Op())));
+ EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+ ->cudnn_conv_backend_config()
+ .reordered_int8_nchw_vect());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
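The reshape/transpose/reshape chains matched in these tests implement a pure index regrouping of the channel axis: an old (feature, vector=4) pair becomes a (feature/8, vector=32) pair, with the new 32-wide vector formed from 8 consecutive old feature groups fused with the old 4-wide vector. The sketch below is illustrative only (it is not part of the patch) and merely checks that arithmetic for a 64-channel axis, matching the {10,2,8,20,30,4} -> {10,2,20,30,8,4} -> {10,2,20,30,32} shapes asserted above.

#include <cassert>
#include <cstdint>

int main() {
  // Old layout: channel index c = f * 4 + v with 16 feature groups of vector 4.
  // New layout: channel index c = F * 32 + V after regrouping to vector 32.
  constexpr int64_t kChannels = 64;
  for (int64_t c = 0; c < kChannels; ++c) {
    const int64_t f = c / 4;
    const int64_t v = c % 4;
    const int64_t F = f / 8;            // split f = 16 into 2 x 8
    const int64_t V = (f % 8) * 4 + v;  // fuse the 8 with the old vector of 4
    assert(F * 32 + V == c);            // the regrouping is a bijection
  }
  return 0;
}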
diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc
new file mode 100644
index 0000000..af9591a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc
@@ -0,0 +1,240 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+CustomKernelFusionRewriter::CustomKernelFusionRewriter(
+ const se::DeviceDescription* device,
+ const CustomKernelFusionPatternRegistry* patterns)
+ : device_(device), patterns_(patterns) {}
+
+// Returns the set of instructions that have users outside of the matched
+// pattern and have a replacement that must be applied after building a new
+// custom fusion instruction. Only the root instruction may have external users
+// without requiring a replacement, as the fusion itself replaces it. If any
+// other instruction has external users but no replacement, returns an empty
+// optional.
+static std::optional<absl::flat_hash_set<HloInstruction*>>
+GetPatternReplacements(const CustomKernelFusionPattern::Match& match) {
+ absl::flat_hash_set<HloInstruction*> requires_replacement;
+ absl::flat_hash_set<HloInstruction*> instructions_set(
+ match.instructions().begin(), match.instructions().end());
+
+ for (HloInstruction* instr : match.instructions()) {
+ for (HloInstruction* user : instr->users()) {
+ if (instr == match.root() || instructions_set.contains(user)) continue;
+
+ if (match.HasReplacement(instr)) {
+ requires_replacement.insert(instr);
+ continue;
+ }
+
+ VLOG(3) << "Custom kernel fusion intermediate result " << instr->name()
+ << " has users outside of a matched pattern: " << user->name();
+ return std::nullopt;
+ }
+ }
+
+ return requires_replacement;
+}
+
+// Returns the operands from outside the matched pattern that have to become
+// custom kernel fusion parameters, in the order they are first used.
+static absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
+ const CustomKernelFusionPattern::Match& match) {
+ absl::InlinedVector<HloInstruction*, 4> captures;
+
+ absl::flat_hash_set<HloInstruction*> instructions_set(
+ match.instructions().begin(), match.instructions().end());
+
+ for (HloInstruction* instr : match.instructions()) {
+ for (HloInstruction* operand : instr->operands()) {
+ if (!instructions_set.contains(operand) &&
+ absl::c_find(captures, operand) == captures.end()) {
+ captures.emplace_back(operand);
+ }
+ }
+ }
+
+ return captures;
+}
+
+// Creates custom kernel fusion computation and moves all matched instructions
+// into it.
+static absl::StatusOr<HloComputation*> CreateFusionBody(
+ HloModule* module, const CustomKernelFusionPattern::Match& match,
+ absl::Span<HloInstruction* const> captures) {
+ HloComputation::Builder builder(match.config().name());
+
+ // A mapping from original instructions to instructions in the fusion body.
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
+
+ auto mapped_operands = [&](HloInstruction* instr) {
+ absl::InlinedVector<HloInstruction*, 4> operands;
+ for (HloInstruction* operand : instr->operands()) {
+ operands.push_back(instr_mapping.at(operand));
+ }
+ return operands;
+ };
+
+ // For every captured value create a parameter instruction in the computation
+ // body and set up instruction mapping.
+ for (const HloInstruction* capture : captures) {
+ int64_t index = instr_mapping.size();
+ instr_mapping[capture] =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ index, capture->shape(), absl::StrCat("p", index)));
+ }
+
+ // TODO(ezhulenev): Instructions in the pattern must be topologically sorted,
+ // otherwise we'll get a crash! Figure out how to do it!
+ for (HloInstruction* instr : match.instructions()) {
+ instr_mapping[instr] = builder.AddInstruction(
+ instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
+ }
+
+ HloInstruction* root = builder.last_added_instruction();
+
+  // If the custom kernel fusion requires a workspace, add a custom call that
+  // allocates the workspace and return a tuple of the "real" result and the
+  // workspace.
+ if (match.workspace_size_bytes() > 0) {
+ auto workspace_shape =
+ ShapeUtil::MakeShape(PrimitiveType::U8, {match.workspace_size_bytes()});
+ HloInstruction* workspace =
+ builder.AddInstruction(HloInstruction::CreateCustomCall(
+ workspace_shape, {}, CustomKernelFusionPattern::kWorkspace, "",
+ CustomCallApiVersion::API_VERSION_TYPED_FFI));
+ builder.AddInstruction(HloInstruction::CreateTuple({root, workspace}));
+ }
+
+ return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
+}
+
+static absl::StatusOr<HloInstruction*> CreateFusionInstruction(
+ HloModule* module, const CustomKernelFusionPattern::Match& match,
+ absl::Span<HloInstruction* const> captures, HloComputation* body) {
+  // We'll be replacing the root operation of a custom kernel fusion with a
+  // fusion instruction that calls the fusion computation.
+ HloInstruction* root = match.root();
+ HloComputation* parent = root->parent();
+
+ // Add a fusion operation calling outlined fusion computation.
+ HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
+ body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
+ captures, body));
+ module->SetAndUniquifyInstrName(fusion, match.config().name());
+
+  // Set the backend config to the matched custom fusion config.
+ GpuBackendConfig gpu_config;
+ FusionBackendConfig& backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+ backend_config.set_kind("__custom_fusion");
+ *backend_config.mutable_custom_fusion_config() = match.config();
+ backend_config.mutable_custom_fusion_config()->set_kernel_index(0);
+ TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
+
+  // If there is no workspace, return the constructed fusion instruction as is.
+ if (match.workspace_size_bytes() == 0) return fusion;
+
+  // Otherwise, return the tuple element corresponding to the original result.
+ return parent->AddInstruction(
+ HloInstruction::CreateGetTupleElement(fusion, 0));
+}
+
+absl::StatusOr<bool> CustomKernelFusionRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ std::vector<CustomKernelFusionPattern::Match> matches;
+
+ // Collect all potential custom fusion matches in the module.
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instr : computation->instructions()) {
+ auto matched = patterns_->Match(*device_, instr);
+ matches.insert(matches.end(), matched.begin(), matched.end());
+ }
+ }
+
+ if (matches.empty()) return false;
+
+ for (const CustomKernelFusionPattern::Match& match : matches) {
+ VLOG(2) << "Matched custom kernel fusion " << match.config().name()
+ << "; root instruction: " << match.instructions().back()->name();
+
+    auto replacements = GetPatternReplacements(match);
+    if (!replacements.has_value()) continue;
+
+ auto captures = GetPatternCaptures(match);
+
+ TF_ASSIGN_OR_RETURN(HloComputation * fusion_body,
+ CreateFusionBody(module, match, captures));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * fusion,
+ CreateFusionInstruction(module, match, captures, fusion_body));
+
+ VLOG(2) << "Added a fusion instruction: " << fusion->name()
+ << " for custom kernel fusion " << match.config().name()
+ << " (instruction count = " << match.instructions().size() << ")";
+
+    for (HloInstruction* instr : *replacements) {
+ VLOG(2) << "Replace matched instruction: " << instr->name()
+ << " with a pattern replacement";
+
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * replacement,
+ match.BuildReplacement(instr, Cast<HloFusionInstruction>(fusion)));
+
+ TF_RETURN_IF_ERROR(
+ instr->ReplaceAllUsesWith(replacement, match.config().name()));
+
+ VLOG(2) << "Replaced instruction: " << instr->name()
+ << " with: " << replacement->name();
+ }
+
+ VLOG(2) << "Replace custom kernel fusion root instruction "
+ << match.root()->name() << "with " << fusion->name();
+ HloComputation* parent = match.root()->parent();
+ TF_RETURN_IF_ERROR(parent->ReplaceInstruction(match.root(), fusion));
+ }
+
+ return true;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h
new file mode 100644
index 0000000..849fdbb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h
@@ -0,0 +1,86 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla::gpu {
+
+// Pattern-matches HLO instructions against custom kernel fusions (hand-written
+// CUDA C++ kernels, e.g. custom GEMMs implemented with CUTLASS) and rewrites
+// them into fusion instructions and fusion computations.
+//
+// Example: pattern matching dot operation into CUTLASS gemm
+//
+// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+// %p0 = f16[15,19]{1,0} parameter(0)
+// %p1 = f16[19,17]{1,0} parameter(1)
+// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+// lhs_contracting_dims={1}, rhs_contracting_dims={0}
+// }
+//
+// After the pass:
+//
+// %cutlass_gemm (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+// %p0 = f16[15,19]{1,0} parameter(0)
+// %p1 = f16[19,17]{1,0} parameter(1)
+// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+// lhs_contracting_dims={1}, rhs_contracting_dims={0}
+// }
+//
+// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+// %p0 = f16[15,19]{1,0} parameter(0)
+// %p1 = f16[19,17]{1,0} parameter(1)
+// ROOT %r = f16[15,17]{1,0} fusion(%p0, %p1), kind=kCustom,
+//     calls=%cutlass_gemm,
+// backend_config={kind: "__custom_fusion",
+// custom_fusion_config: {"name":"cutlass_gemm"}}
+// }
+//
+class CustomKernelFusionRewriter : public HloModulePass {
+ public:
+ explicit CustomKernelFusionRewriter(
+ const se::DeviceDescription* device,
+ const CustomKernelFusionPatternRegistry* patterns =
+ CustomKernelFusionPatternRegistry::Default());
+
+ absl::string_view name() const override {
+ return "custom-kernel-fusion-rewriter";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const se::DeviceDescription* device_;
+ const CustomKernelFusionPatternRegistry* patterns_;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_
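As a usage sketch only (it assumes an XLA build environment; the helper and pipeline names are illustrative, `device` and `module` come from the caller, and the default pattern registry is used), the pass can be added to a pipeline like any other HLO pass:

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/device_description.h"

// Illustrative helper; returns whether the module changed.
absl::StatusOr<bool> RewriteCustomKernelFusions(
    xla::HloModule* module, const stream_executor::DeviceDescription& device) {
  xla::HloPassPipeline pipeline("custom-kernel-fusion-rewriting");
  // Uses the default CustomKernelFusionPatternRegistry.
  pipeline.AddPass<xla::gpu::CustomKernelFusionRewriter>(&device);
  return pipeline.Run(module);
}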
diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc
new file mode 100644
index 0000000..235e9de
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+
+//===----------------------------------------------------------------------===//
+// Simple pattern matchers for testing the custom kernel fusion rewriter.
+//===----------------------------------------------------------------------===//
+
+struct SimpleGemmPattern : public CustomKernelFusionPattern {
+ explicit SimpleGemmPattern(int64_t workspace = 0) : workspace(workspace) {}
+
+ std::optional<Match> TryMatch(const se::DeviceDescription& device,
+ HloInstruction* instr) const override {
+ if (auto* dot = DynCast<HloDotInstruction>(instr)) {
+ CustomFusionConfig config;
+ config.set_name("simple_gemm");
+ return Match{config, {instr}, workspace};
+ }
+ return std::nullopt;
+ }
+
+ int64_t workspace;
+};
+
+//===----------------------------------------------------------------------===//
+
+class CustomKernelFusionRewriterTest : public HloTestBase {};
+
+TEST_F(CustomKernelFusionRewriterTest, SimpleGemm) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+ %p0 = f16[15,19]{1,0} parameter(0)
+ %p1 = f16[19,17]{1,0} parameter(1)
+ ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %simple_gemm {{.*}} {
+ ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
+ ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
+ ; CHECK: ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
+ ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main {{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[15,17]{1,0} fusion
+ ; CHECK: kind=kCustom, calls=%simple_gemm,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ CustomKernelFusionPatternRegistry patterns;
+ patterns.Emplace<SimpleGemmPattern>();
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ CustomKernelFusionRewriter pass(&device, &patterns);
+ RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
+}
+
+TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+ %p0 = f16[15,19]{1,0} parameter(0)
+ %p1 = f16[19,17]{1,0} parameter(1)
+ ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %simple_gemm {{.*}} {
+ ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
+ ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
+ ; CHECK: [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
+ ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ; CHECK: [[WORKSPACE:%[^ ]+]] = u8[1024]{0} custom-call(),
+ ; CHECK: custom_call_target="__custom_kernel_fusion$workspace"
+ ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0})
+ ; CHECK: tuple([[DOT]], [[WORKSPACE]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main {{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%simple_gemm,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT {{.*}} get-tuple-element([[FUSION]]), index=0
+ ; CHECK: }
+ )";
+
+ CustomKernelFusionPatternRegistry patterns;
+ patterns.Emplace<SimpleGemmPattern>(1024);
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ CustomKernelFusionRewriter pass(&device, &patterns);
+ RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc
new file mode 100644
index 0000000..b1e0b98
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc
@@ -0,0 +1,136 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/permutation_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Sort contracting dimensions of a dot() instruction preserving lhs-rhs pairs.
+absl::Status SortDotDimensions(HloDotInstruction* dot) {
+ const DotDimensionNumbers& dims = dot->dot_dimension_numbers();
+ DotDimensionNumbers new_dims(dims);
+ new_dims.clear_lhs_contracting_dimensions();
+ new_dims.clear_rhs_contracting_dimensions();
+ const bool sort_by_lhs =
+ DistinctNumbersAreConsecutiveIfSorted(dims.lhs_contracting_dimensions());
+ // Sort lhs and rhs by sort_key using the fact that
+ // sort_key is guaranteed to have only distinct consecutive numbers.
+ const absl::Span<const int64_t>& sort_key =
+ sort_by_lhs ? dims.lhs_contracting_dimensions()
+ : dims.rhs_contracting_dimensions();
+ std::vector<int64_t> permutation;
+ for (const int64_t a : sort_key) {
+ permutation.push_back(a - *absl::c_min_element(sort_key));
+ }
+ const std::vector<int64_t> sorted_lhs =
+ Permute(dims.lhs_contracting_dimensions(), permutation);
+ *new_dims.mutable_lhs_contracting_dimensions() = {sorted_lhs.begin(),
+ sorted_lhs.end()};
+ const std::vector<int64_t> sorted_rhs =
+ Permute(dims.rhs_contracting_dimensions(), permutation);
+ *new_dims.mutable_rhs_contracting_dimensions() = {sorted_rhs.begin(),
+ sorted_rhs.end()};
+ std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
+ dot->shape(), dot->mutable_operand(0), dot->mutable_operand(1), new_dims,
+ dot->precision_config(), {dot->sparsity().begin(), dot->sparsity().end()},
+ absl::MakeSpan(dot->operands()).subspan(HloDotInstruction::kOperands));
+ dot->SetupDerivedInstruction(new_dot.get());
+
+ VLOG(3) << "Sorted dot() dimensions:\n"
+ << "\t before: " << dot->ToString() << "\n"
+ << "\t after: " << new_dot->ToString();
+ return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
+}
+
+} // namespace
+
+absl::StatusOr<bool> DotDimensionSorter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ std::vector<HloInstruction*> dots_to_process;
+ for (const HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* instr : computation->instructions()) {
+ if (instr->opcode() != HloOpcode::kDot) {
+ continue;
+ }
+ // TODO(b/265688934): should non-default layouts be expected here at all?
+ if ((instr->operand(0)->shape().has_layout() &&
+ !LayoutUtil::IsMonotonicWithDim0Major(
+ instr->operand(0)->shape().layout())) ||
+ (instr->operand(1)->shape().has_layout() &&
+ !LayoutUtil::IsMonotonicWithDim0Major(
+ instr->operand(1)->shape().layout()))) {
+ continue;
+ }
+ const DotDimensionNumbers& dims = instr->dot_dimension_numbers();
+ if (dims.lhs_contracting_dimensions_size() == 0) {
+ continue;
+ }
+ const bool cons_lhs = DistinctNumbersAreConsecutiveIfSorted(
+ dims.lhs_contracting_dimensions());
+ const bool cons_rhs = DistinctNumbersAreConsecutiveIfSorted(
+ dims.rhs_contracting_dimensions());
+ const bool sorted_lhs =
+ absl::c_is_sorted(dims.lhs_contracting_dimensions());
+ const bool sorted_rhs =
+ absl::c_is_sorted(dims.rhs_contracting_dimensions());
+ // The side to be sorted has to be consecutive and not sorted yet;
+ // the other side should not get worsened.
+ // TODO(b/265688934): we may still want to change which one is sorted
+ // if this reduces the amount of transposed data.
+ if ((cons_lhs && !sorted_lhs && !cons_rhs) ||
+ (cons_rhs && !sorted_rhs && !cons_lhs) ||
+ (cons_lhs && !sorted_lhs && cons_rhs && !sorted_rhs)) {
+ dots_to_process.push_back(instr);
+ }
+ }
+ }
+ if (dots_to_process.empty()) {
+ return false;
+ }
+ for (HloInstruction* dot : dots_to_process) {
+ TF_RETURN_IF_ERROR(SortDotDimensions(Cast<HloDotInstruction>(dot)));
+ }
+ return true;
+}
+
+} // namespace gpu
+} // namespace xla
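The sorting trick in SortDotDimensions above relies on the sort key being a set of distinct, consecutive integers: subtracting its minimum yields a scatter permutation that orders the key ascendingly, and applying the same permutation to both operands keeps lhs/rhs contracting pairs aligned. The standalone sketch below is not part of the patch; `Scatter` assumes the output[permutation[i]] = input[i] convention and reproduces the {3,2}/{2,1} -> {2,3}/{1,2} example exercised by the tests.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Scatter-style permutation: output[permutation[i]] = input[i].
std::vector<int64_t> Scatter(const std::vector<int64_t>& input,
                             const std::vector<int64_t>& permutation) {
  std::vector<int64_t> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    output[permutation[i]] = input[i];
  }
  return output;
}

int main() {
  const std::vector<int64_t> lhs = {3, 2};  // lhs_contracting_dims
  const std::vector<int64_t> rhs = {2, 1};  // rhs_contracting_dims
  // lhs is consecutive once sorted, so it serves as the sort key.
  const int64_t min = *std::min_element(lhs.begin(), lhs.end());
  std::vector<int64_t> permutation;
  for (const int64_t a : lhs) {
    permutation.push_back(a - min);  // {1, 0}
  }
  assert((Scatter(lhs, permutation) == std::vector<int64_t>{2, 3}));
  assert((Scatter(rhs, permutation) == std::vector<int64_t>{1, 2}));
  return 0;
}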
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h
new file mode 100644
index 0000000..872fb72
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h
@@ -0,0 +1,52 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Sorts contracting dimensions of dot() operands when this reduces the
+// number of transposes. Example:
+// dot(p0, p1), lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1} ->
+// dot(p0, p1), lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+// The first case gets transposes inserted by dot_decomposer, the second one
+// does not and thus is generally more efficient.
+
+// TODO(b/265688934): do the same for batch dimensions?
+
+class DotDimensionSorter : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "dot_dimension_sorter"; }
+
+ // Run the pass on computations in 'module'.
+ // Returns whether the 'module' was changed.
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc
new file mode 100644
index 0000000..364c140
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc
@@ -0,0 +1,191 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class WithoutDotDimensionSorterTest : public GpuCodegenTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+ // The pass is disabled here to preserve suboptimal dimension order in
+ // 1) UnsortedDimsCreateTransposes to reveal the transposes.
+ // 2) DimOrderCanBeChanged for the comparison of ordered vs unordered.
+ // The pass does not touch SortedDimsDoNotCreateTransposes anyway because
+ // the dimensions are already ordered there.
+ debug_options.add_xla_disable_hlo_passes("dot_dimension_sorter");
+ return debug_options;
+ }
+};
+
+TEST_F(WithoutDotDimensionSorterTest, UnsortedDimsCreateTransposes) {
+ const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[1,14,9,32] parameter(0)
+ p1 = f16[12,9,32] parameter(1)
+ ROOT _ = f16[1,14,12] dot(p0, p1),
+ lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK: transpose
+)");
+}
+
+TEST_F(WithoutDotDimensionSorterTest, SortedDimsDoNotCreateTransposes) {
+ const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[1,14,9,32] parameter(0)
+ p1 = f16[12,9,32] parameter(1)
+ ROOT _ = f16[1,14,12] dot(p0, p1),
+ lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+}
+)";
+
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT: transpose
+)");
+}
+
+TEST_F(WithoutDotDimensionSorterTest, DimOrderCanBeChanged) {
+ const char* hlo_text_ref = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[1,14,9,32] parameter(0)
+ p1 = f16[12,9,32] parameter(1)
+ ROOT _ = f16[1,14,12] dot(p0, p1),
+ lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+ const char* hlo_text_modified = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[1,14,9,32] parameter(0)
+ p1 = f16[12,9,32] parameter(1)
+ ROOT _ = f16[1,14,12] dot(p0, p1),
+ lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+}
+)";
+
+ EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_modified,
+ ErrorSpec{1e-5, 1e-3},
+ /*run_hlo_passes=*/true));
+}
+
+using DotDimensionSorterTest = GpuCodegenTest;
+
+TEST_F(DotDimensionSorterTest, SortContractingDims) {
+ const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[1,144,96,32] parameter(0)
+ p1 = f16[122,96,32] parameter(1)
+ ROOT _ = f16[1,144,122] dot(p0, p1),
+ lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ const auto& dims =
+ module->entry_computation()->root_instruction()->dot_dimension_numbers();
+
+ EXPECT_EQ(dims.lhs_contracting_dimensions(0), 3);
+ EXPECT_EQ(dims.lhs_contracting_dimensions(1), 2);
+
+ EXPECT_EQ(dims.rhs_contracting_dimensions(0), 2);
+ EXPECT_EQ(dims.rhs_contracting_dimensions(1), 1);
+
+ TF_ASSERT_OK_AND_ASSIGN(bool modified,
+ DotDimensionSorter().Run(module.get()));
+ EXPECT_TRUE(modified);
+ const auto& dims2 =
+ module->entry_computation()->root_instruction()->dot_dimension_numbers();
+
+ EXPECT_EQ(dims2.lhs_contracting_dimensions(0), 2);
+ EXPECT_EQ(dims2.lhs_contracting_dimensions(1), 3);
+
+ EXPECT_EQ(dims2.rhs_contracting_dimensions(0), 1);
+ EXPECT_EQ(dims2.rhs_contracting_dimensions(1), 2);
+}
+
+TEST_F(DotDimensionSorterTest, NothingToReorder) {
+ const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[1,144,96,32] parameter(0)
+ p1 = f16[122,96,32] parameter(1)
+ ROOT _ = f16[1,144,122] dot(p0, p1),
+ lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+
+ TF_ASSERT_OK_AND_ASSIGN(bool modified,
+ DotDimensionSorter().Run(module.get()));
+ EXPECT_FALSE(modified);
+}
+
+TEST_F(DotDimensionSorterTest, SparseDotSortContractingDims) {
+ const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[1,144,96,16] parameter(0)
+ p1 = f16[122,96,32] parameter(1)
+ meta = u16[1,144,96,2] parameter(2)
+ ROOT _ = f16[1,144,122] dot(p0, p1, meta), sparsity=L.3@2:4,
+ lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ TF_ASSERT_OK_AND_ASSIGN(bool modified,
+ DotDimensionSorter().Run(module.get()));
+ EXPECT_TRUE(modified);
+ HloDotInstruction* dot = DynCast<HloDotInstruction>(
+ module->entry_computation()->root_instruction());
+ EXPECT_TRUE(dot != nullptr && dot->sparse_operands() == 1);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc
new file mode 100644
index 0000000..d9e095e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc
@@ -0,0 +1,74 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_operand_converter.h"
+
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+
+namespace xla::gpu {
+
+bool DotOperandConverter::InstructionMatchesPattern(
+ HloInstruction* instruction) {
+ if (instruction->opcode() != HloOpcode::kDot) {
+ return false;
+ }
+ HloInstruction* lhs = instruction->mutable_operand(0);
+ HloInstruction* rhs = instruction->mutable_operand(1);
+
+ PrimitiveType lhs_type = lhs->shape().element_type();
+ PrimitiveType rhs_type = rhs->shape().element_type();
+
+ if (lhs_type == rhs_type) {
+ return false;
+ }
+
+ // Exclude conversions between FP8 types.
+ absl::flat_hash_set<PrimitiveType> non_converting = {F8E4M3FN, F8E5M2};
+ if (non_converting.contains(lhs_type) && non_converting.contains(rhs_type)) {
+ return false;
+ }
+
+ PrimitiveType desired_type =
+ ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
+
+ return desired_type == lhs_type || desired_type == rhs_type;
+}
+
+absl::StatusOr<HloInstruction*> DotOperandConverter::ExpandInstruction(
+ HloInstruction* instruction) {
+ HloInstruction* lhs = instruction->mutable_operand(0);
+ HloInstruction* rhs = instruction->mutable_operand(1);
+
+  // Find the higher-precision type among the two operands, and add a convert
+  // instruction to convert the lower-precision operand to that type.
+ PrimitiveType desired_type =
+ ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
+ int operand_index = desired_type == lhs->shape().element_type() ? 1 : 0;
+ HloInstruction* inst_to_replace =
+ desired_type == lhs->shape().element_type() ? rhs : lhs;
+ auto upcast_shape = inst_to_replace->shape();
+ upcast_shape.set_element_type(desired_type);
+ auto* convert_inst = instruction->AddInstruction(
+ HloInstruction::CreateConvert(upcast_shape, inst_to_replace));
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape(
+ operand_index, convert_inst));
+ return nullptr;
+}
+
+} // namespace xla::gpu
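A self-contained sketch of the operand-selection rule above (outside the patch; the precision ranking is a simplified stand-in for ShapeUtil::HigherPrecisionElementType): whichever operand does not already carry the higher-precision type is the one that receives the convert.

#include <cassert>
#include <map>
#include <string>

int main() {
  // Illustrative precision ranking only; the pass defers to ShapeUtil's rules.
  const std::map<std::string, int> rank = {{"s8", 0}, {"bf16", 1}, {"f32", 2}};
  const auto convert_operand_index = [&](const std::string& lhs,
                                         const std::string& rhs) {
    const std::string& desired = rank.at(lhs) >= rank.at(rhs) ? lhs : rhs;
    // Mirrors ExpandInstruction: if lhs already has the desired type, the rhs
    // (operand 1) is converted; otherwise the lhs (operand 0) is.
    return desired == lhs ? 1 : 0;
  };
  assert(convert_operand_index("s8", "bf16") == 0);  // lhs is upcast to bf16
  assert(convert_operand_index("bf16", "s8") == 1);  // rhs is upcast to bf16
  return 0;
}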
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h
new file mode 100644
index 0000000..b269bed
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h
@@ -0,0 +1,46 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_
+
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/op_expander_pass.h"
+#include "xla/util.h"
+
+namespace xla::gpu {
+
+// Upcasts the lower-precision operand of a dot so that both operands share the
+// higher-precision operand type.
+class DotOperandConverter : public OpExpanderPass {
+ public:
+ explicit DotOperandConverter(HloPredicate extra_filter = nullptr)
+ : OpExpanderPass(std::move(extra_filter)) {}
+
+ absl::string_view name() const override { return "operand_converter"; }
+
+ protected:
+ bool InstructionMatchesPattern(HloInstruction* instruction) override;
+
+ absl::StatusOr<HloInstruction*> ExpandInstruction(
+ HloInstruction* instruction) override;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc
new file mode 100644
index 0000000..be05b67
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_operand_converter.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/utils/hlo_matchers.h"
+#include "xla/primitive_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+
+class DotOperandConverterTest : public HloTestBase {
+ public:
+ void TestConvert(bool left_less_precise, PrimitiveType lhs_type,
+ PrimitiveType rhs_type, PrimitiveType result_type) {
+ absl::string_view module_tmpl = R"(
+ HloModule module
+
+ ENTRY main {
+ p0 = $0[2,3]{1,0} parameter(0)
+ p1 = $1[3,2]{1,0} parameter(1)
+ ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+ rhs_contracting_dims={0}
+ })";
+ auto module_string = absl::Substitute(
+ module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type),
+ primitive_util::LowercasePrimitiveTypeName(rhs_type),
+ primitive_util::LowercasePrimitiveTypeName(result_type));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
+ DotOperandConverter().Run(module.get()));
+ EXPECT_TRUE(upcasted);
+ if (left_less_precise) {
+ auto original_lhs = op::Parameter(0);
+ auto upcasted_lhs =
+ AllOf(op::Convert(original_lhs),
+ op::Shape(absl::Substitute(
+ "$0[2,3]{1,0}",
+ primitive_util::LowercasePrimitiveTypeName(rhs_type))));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ AllOf(op::Dot(upcasted_lhs, op::Parameter(1)),
+ op::Shape(absl::Substitute(
+ "$0[2,2]{1,0}",
+ primitive_util::LowercasePrimitiveTypeName(result_type)))));
+ } else {
+ auto original_rhs = op::Parameter(1);
+ auto upcasted_rhs =
+ AllOf(op::Convert(original_rhs),
+ op::Shape(absl::Substitute(
+ "$0[3,2]{1,0}",
+ primitive_util::LowercasePrimitiveTypeName(lhs_type))));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ AllOf(op::Dot(op::Parameter(0), upcasted_rhs),
+ op::Shape(absl::Substitute(
+ "$0[2,2]{1,0}",
+ primitive_util::LowercasePrimitiveTypeName(result_type)))));
+ }
+ }
+};
+
+TEST_F(DotOperandConverterTest, ConvertsLeftAndRight) {
+ TestConvert(/*left_less_precise=*/true, S8, BF16, F32);
+ TestConvert(/*left_less_precise=*/false, BF16, S8, F32);
+}
+
+TEST_F(DotOperandConverterTest, NoConvertHappensWithSameTypes) {
+ absl::string_view module_string = R"(
+ HloModule module
+
+ ENTRY main {
+ p0 = s8[2,3]{1,0} parameter(0)
+ p1 = s8[3,2]{1,0} parameter(1)
+ ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+ rhs_contracting_dims={0}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
+ DotOperandConverter().Run(module.get()));
+ EXPECT_FALSE(upcasted);
+}
+
+TEST_F(DotOperandConverterTest, NoConvertFromF8toF8) {
+ absl::string_view module_string = R"(
+ HloModule module
+
+ ENTRY main {
+ p0 = f8e4m3fn[2,3]{1,0} parameter(0)
+ p1 = f8e5m2[3,2]{1,0} parameter(1)
+ ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+ rhs_contracting_dims={0}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
+ DotOperandConverter().Run(module.get()));
+ EXPECT_FALSE(upcasted);
+}
+
+TEST_F(DotOperandConverterTest, CompilerOptimizesUsingDotOperandConverter) {
+ absl::string_view module_string = R"(
+ HloModule module
+
+ ENTRY main {
+ p0 = s8[2,3]{1,0} parameter(0)
+ p1 = bf16[3,2]{1,0} parameter(1)
+ ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+ rhs_contracting_dims={0}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(module_string));
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc
new file mode 100644
index 0000000..637689a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc
@@ -0,0 +1,110 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h"
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class SparseDotRewriterImpl : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleDot(HloInstruction* instr) override {
+ // Only handle sparse dots with a single RHS sparse descriptor.
+ HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
+ if (dot->sparse_operands() != 1 || dot->sparsity().front().index() != 1) {
+ return absl::OkStatus();
+ }
+
+ HloInstruction* lhs = dot->mutable_operand(0);
+ HloInstruction* rhs = dot->mutable_operand(1);
+ HloInstruction* meta = dot->mutable_operand(2);
+
+ // Swap LHS and RHS in the attributes.
+ DotDimensionNumbers dnums = dot->dot_dimension_numbers();
+ std::swap(*dnums.mutable_lhs_batch_dimensions(),
+ *dnums.mutable_rhs_batch_dimensions());
+ std::swap(*dnums.mutable_lhs_contracting_dimensions(),
+ *dnums.mutable_rhs_contracting_dimensions());
+
+ PrecisionConfig precision_config = dot->precision_config();
+ std::swap(precision_config.mutable_operand_precision()->at(0),
+ precision_config.mutable_operand_precision()->at(1));
+
+ SparsityDescriptor sparsity = dot->sparsity().front();
+ sparsity.set_index(0);
+
+ // Create new dot with LHS and RHS swapped.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_dot,
+ MakeDotHlo(rhs, lhs, dnums, precision_config,
+ dot->shape().element_type(), {std::move(sparsity)}, {meta}));
+ dot->SetupDerivedInstruction(new_dot);
+
+ // Result dimensions: <batch>, <rhs_noncontracting>, <lhs_noncontracting>
+ int batch_dims = dnums.lhs_batch_dimensions().size();
+ int new_lhs_noncontracting = rhs->shape().rank() - batch_dims -
+ dnums.lhs_contracting_dimensions().size();
+ int new_rhs_noncontracting = lhs->shape().rank() - batch_dims -
+ dnums.rhs_contracting_dimensions().size();
+
+ int rank = dot->shape().rank();
+ DimensionVector dimensions(rank);
+ for (int i = 0; i < batch_dims; ++i) {
+ dimensions[i] = i;
+ }
+ for (int i = 0; i < new_lhs_noncontracting; ++i) {
+ dimensions[i + batch_dims] = i + batch_dims + new_rhs_noncontracting;
+ }
+ for (int i = 0; i < new_rhs_noncontracting; ++i) {
+ dimensions[i + batch_dims + new_lhs_noncontracting] = i + batch_dims;
+ }
+
+ // Transpose the result.
+ TF_ASSIGN_OR_RETURN(HloInstruction * transpose,
+ MakeTransposeHlo(new_dot, dimensions));
+ transpose->set_metadata(dot->metadata());
+ *transpose->mutable_shape()->mutable_layout() = dot->shape().layout();
+
+ return ReplaceInstruction(dot, transpose);
+ }
+};
+
+} // namespace
+
+absl::StatusOr<bool> DotSparsityRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ return SparseDotRewriterImpl().RunOnModule(module, execution_threads);
+}
+
+} // namespace gpu
+} // namespace xla
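A standalone sketch (outside the patch) of the transpose permutation computed above: after the operands are swapped, batch dimensions stay in place and the two non-contracting blocks trade positions, so the transposed result matches the original dot's dimension order. The numbers reproduce the accompanying SparseDotRhsToLhs test.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the dimension bookkeeping in SparseDotRewriterImpl::HandleDot.
std::vector<int64_t> TransposeDims(int batch, int new_lhs_nc, int new_rhs_nc) {
  std::vector<int64_t> dims(batch + new_lhs_nc + new_rhs_nc);
  for (int i = 0; i < batch; ++i) dims[i] = i;
  for (int i = 0; i < new_lhs_nc; ++i) dims[i + batch] = i + batch + new_rhs_nc;
  for (int i = 0; i < new_rhs_nc; ++i) dims[i + batch + new_lhs_nc] = i + batch;
  return dims;
}

int main() {
  // Two batch dimensions and one non-contracting dimension per side give the
  // permutation {0, 1, 3, 2} expected by the test.
  assert((TransposeDims(/*batch=*/2, /*new_lhs_nc=*/1, /*new_rhs_nc=*/1) ==
          std::vector<int64_t>{0, 1, 3, 2}));
  return 0;
}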
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h
new file mode 100644
index 0000000..b912e2b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h
@@ -0,0 +1,42 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Make sure sparse dot requirements are met (sparse operand is LHS).
+class DotSparsityRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "dot_sparsity_rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc
new file mode 100644
index 0000000..28f813f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc
@@ -0,0 +1,85 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+
+class DotSparsityRewriterTest : public HloTestBase {
+ public:
+ DotSparsityRewriterTest() : HloTestBase(/*verifier_layout_sensitive=*/true) {}
+};
+
+TEST_F(DotSparsityRewriterTest, SparseDotRhsToLhs) {
+ const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+ lhs = f16[4,2,16,8,64] parameter(0)
+ rhs = f16[2,4,8,32,128] parameter(1)
+ meta = u16[2,4,8,4,128] parameter(2)
+ ROOT dot = f16[4,2,16,128] dot(lhs, rhs, meta),
+ lhs_contracting_dims={3,4}, rhs_contracting_dims={2,3},
+ lhs_batch_dims={0,1}, rhs_batch_dims={1,0}, sparsity=R.3@2:4
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ TF_ASSERT_OK_AND_ASSIGN(bool modified,
+ DotSparsityRewriter().Run(module.get()));
+ EXPECT_TRUE(modified);
+
+ const HloTransposeInstruction* transpose = DynCast<HloTransposeInstruction>(
+ module->entry_computation()->root_instruction());
+ ASSERT_TRUE(transpose != nullptr);
+ EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 1, 3, 2));
+
+ const HloDotInstruction* dot =
+ DynCast<HloDotInstruction>(transpose->operand(0));
+ ASSERT_TRUE(dot != nullptr);
+
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+ EXPECT_EQ(dnums.lhs_contracting_dimensions(0), 2);
+ EXPECT_EQ(dnums.lhs_contracting_dimensions(1), 3);
+ EXPECT_EQ(dnums.rhs_contracting_dimensions(0), 3);
+ EXPECT_EQ(dnums.rhs_contracting_dimensions(1), 4);
+ EXPECT_EQ(dnums.lhs_batch_dimensions(0), 1);
+ EXPECT_EQ(dnums.lhs_batch_dimensions(1), 0);
+ EXPECT_EQ(dnums.rhs_batch_dimensions(0), 0);
+ EXPECT_EQ(dnums.rhs_batch_dimensions(1), 1);
+
+ EXPECT_EQ(dot->sparse_operands(), 1);
+ EXPECT_EQ(dot->sparsity().front().index(), 0);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc
new file mode 100644
index 0000000..a4f901f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc
@@ -0,0 +1,569 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
+
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instruction_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/flatten_call_graph.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+void SetChannelIdForNewCollective(HloInstruction* new_instr,
+ const HloModule* module) {
+ // This tracks mappings of old->new channel id for async collectives
+ // wrapped in the form of HloAsyncInstruction; the start and done ops need
+ // to share the same unique channel id.
+ absl::flat_hash_map<int64_t, int64_t> old_to_new_channel_id_map;
+ absl::flat_hash_map<int64_t, HloComputation*> channel_id_comp_map;
+ if (new_instr->IsAsynchronous() && hlo_query::IsCollectiveCommunicationOp(
+ new_instr->async_wrapped_opcode())) {
+ HloInstruction* wrapped_instr =
+ DynCast<HloAsyncInstruction>(new_instr)->async_wrapped_instruction();
+ int64_t old_channel_id = *wrapped_instr->channel_id();
+ // Look up the channel id with find(): operator[] would default-insert a
+ // zero entry and hide a missing mapping.
+ auto it = old_to_new_channel_id_map.find(old_channel_id);
+ int64_t new_channel_id;
+ if (it == old_to_new_channel_id_map.end()) {
+ new_channel_id = hlo_query::NextChannelId(*module);
+ VLOG(2) << "Generated new channel id " << new_channel_id;
+ old_to_new_channel_id_map[old_channel_id] = new_channel_id;
+ } else {
+ new_channel_id = it->second;
+ }
+
+ VLOG(2) << "Setting channel id to " << new_channel_id;
+
+ wrapped_instr->set_channel_id(new_channel_id);
+ if (channel_id_comp_map.find(new_channel_id) == channel_id_comp_map.end()) {
+ channel_id_comp_map[new_channel_id] =
+ new_instr->async_wrapped_computation();
+ } else {
+ channel_id_comp_map[new_channel_id]->AddAsyncStart(new_instr);
+ }
+ } else if (hlo_query::IsCollectiveCommunicationOp(new_instr->opcode()) ||
+ hlo_query::IsAsyncCollectiveStartOp(new_instr)) {
+ new_instr->set_channel_id(hlo_query::NextChannelId(*module));
+ }
+}
+
+using Interval = std::pair<int64_t, int64_t>;
+
+// Parses a string of the format `{{a,b},{c,d},{e,f}...}` to a vector of pairs.
+absl::StatusOr<std::vector<Interval>> ParseVectorOfPairs(
+ absl::string_view str) {
+ TF_ASSIGN_OR_RETURN(std::vector<ReplicaGroup> replica_groups,
+ ParseReplicaGroupsOnly(str));
+ std::vector<Interval> res;
+ res.reserve(replica_groups.size());
+ for (const ReplicaGroup& replica_group : replica_groups) {
+ TF_RET_CHECK(replica_group.replica_ids_size() == 2);
+ int64_t a = replica_group.replica_ids(0);
+ int64_t b = replica_group.replica_ids(1);
+ res.emplace_back(a, b);
+ }
+ return res;
+}
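+
+// Usage sketch (illustrative only): for an attribute string of
+// "{{0,6},{1,7}}", ParseVectorOfPairs is expected to yield the intervals
+// {0,6} and {1,7}:
+//
+//   TF_ASSIGN_OR_RETURN(std::vector<Interval> intervals,
+//                       ParseVectorOfPairs("{{0,6},{1,7}}"));
+//   // intervals[0] == Interval{0, 6}; intervals[1] == Interval{1, 7}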
+
+// This function fixes the `_xla_send_recv_validation` attribute for peeled
+// instructions. When the loop trip count is odd, the peeled instructions are
+// moved before the loop. The collectives in these instructions correspond to
+// the first iteration of the original loop. We have to run this peeled
+// collective for all those devices that had the 0-th iteration as a valid
+// iteration.
+absl::Status SetSendRecvValidationForPeeledInstr(HloInstruction* new_instr,
+ HloInstruction* old_instr) {
+ TF_RET_CHECK(
+ new_instr->opcode() == old_instr->opcode() &&
+ "cloned instruction and original instruction have different opcodes");
+ if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
+ HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
+ HloOpcode::kRecv>(old_instr)) {
+ return absl::OkStatus();
+ }
+
+ const auto& attribute_map = new_instr->frontend_attributes().map();
+ if (!attribute_map.contains(kSendRecvValidationAttr)) {
+ return absl::OkStatus();
+ }
+
+ VLOG(3) << "Original send-recv iterations: "
+ << attribute_map.at(kSendRecvValidationAttr);
+
+ TF_ASSIGN_OR_RETURN(
+ auto send_recv_validation_attr,
+ ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
+
+ uint64_t n_pairs = send_recv_validation_attr.size();
+ if (n_pairs == 0) {
+ return absl::OkStatus();
+ }
+ std::vector<Interval> send_recv_validation_attr_updated(n_pairs, {1, 0});
+ // Check which of the attributes have iteration number zero as a valid
+ // iteration. For all of those, set the peeled instruction to run.
+ for (std::uint64_t i = 0; i < send_recv_validation_attr.size(); i++) {
+ if (send_recv_validation_attr[i].first <= 0 &&
+ send_recv_validation_attr[i].second >= 0) {
+ send_recv_validation_attr_updated[i] = {0, 0};
+ }
+ }
+
+ hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
+ /*instr=*/new_instr, /*attr_name=*/kSendRecvValidationAttr,
+ /*intervals=*/send_recv_validation_attr_updated);
+ return absl::OkStatus();
+}
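+
+// Worked example (illustrative only): assume the loop was peeled and the
+// original attribute was "{{0,6},{3,9}}". The first device had iteration 0
+// valid, so its peeled copy gets {0,0}; the second device did not, so it
+// keeps the "never valid" default {1,0}. The peeled instruction ends up with
+// _xla_send_recv_validation="{{0,0},{1,0}}".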
+
+// This function fixes the `_xla_send_recv_validation` attribute for the two new
+// collectives inside the loop. The calculation of the new valid iterations
+// depends on whether the loop was peeled or not.
+//
+// If the loop was not peeled, then
+// - iteration 0 of the new loop corresponds to iterations 0,1 of the old loop.
+// - iteration 1 of the new loop corresponds to iterations 2,3 of the old loop.
+// - and so on...
+// If the loop was peeled, then the first iteration runs before the loop. So,
+// - iteration 0 of the new loop corresponds to iterations 1,2 of the old loop.
+// - iteration 1 of the new loop corresponds to iterations 3,4 of the old loop.
+// - and so on...
+//
+// Consider the case when the loop was peeled, and the original attribute for
+// some device was {4,7}. Consider that the two new collectives are
+// `collective.1` and `collective.2` (they execute in this order inside the new
+// loop). In the old loop, iterations 4,5,6,7 were valid. In the new
+// loop,
+// - collective.2 in iteration 1 of new loop runs 4th iteration of old loop.
+// - collective.1 in iteration 2 of new loop runs 5th iteration of old loop.
+// - collective.2 in iteration 2 of new loop runs 6th iteration of old loop.
+// - collective.1 in iteration 3 of new loop runs 7th iteration of old loop.
+// So, the updated attributes for that device are {1,2} for `collective.2` and
+// {2,3} for `collective.1`.
+//
+// In a similar fashion we can generalize the computation of new values based on
+// the values of the old attribute as done in the logic below.
+absl::Status SetSendRecvValidation(HloInstruction* cp1, HloInstruction* cp2,
+ bool is_peeled) {
+ TF_RET_CHECK(
+ cp2->opcode() == cp1->opcode() &&
+ "cloned instruction and original instruction have different opcodes");
+ if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
+ HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
+ HloOpcode::kRecv>(cp1)) {
+ return absl::OkStatus();
+ }
+ const auto& attribute_map = cp2->frontend_attributes().map();
+ if (!attribute_map.contains(kSendRecvValidationAttr)) {
+ return absl::OkStatus();
+ }
+ VLOG(3) << "Original send-recv iterations: "
+ << attribute_map.at(kSendRecvValidationAttr);
+
+ TF_ASSIGN_OR_RETURN(
+ auto send_recv_validation_attr,
+ ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
+
+ if (send_recv_validation_attr.size() == 0) {
+ return absl::OkStatus();
+ }
+
+ std::vector<Interval> send_recv_iterations_new_instr1,
+ send_recv_iterations_new_instr2;
+ send_recv_iterations_new_instr1.reserve(send_recv_validation_attr.size());
+ send_recv_iterations_new_instr2.reserve(send_recv_validation_attr.size());
+ for (const Interval& pair : send_recv_validation_attr) {
+ int64_t a = pair.first;
+ int64_t b = pair.second;
+ if (is_peeled) {
+ send_recv_iterations_new_instr1.emplace_back(
+ std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
+ send_recv_iterations_new_instr2.emplace_back(
+ std::max(0.0, std::floor((a - 1) / 2.0)),
+ std::max(0.0, std::floor((b - 2) / 2.0)));
+ } else {
+ send_recv_iterations_new_instr1.emplace_back(std::floor((a + 1) / 2.0),
+ std::floor(b / 2.0));
+ send_recv_iterations_new_instr2.emplace_back(
+ std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
+ }
+ }
+
+ hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
+ /*instr=*/cp1, /*attr_name=*/kSendRecvValidationAttr,
+ /*intervals=*/send_recv_iterations_new_instr1);
+ hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
+ /*instr=*/cp2, /*attr_name=*/kSendRecvValidationAttr,
+ /*intervals=*/send_recv_iterations_new_instr2);
+
+ VLOG(3) << "Updated send-recv iterations for " << cp1->name() << " : "
+ << cp1->frontend_attributes().map().at(kSendRecvValidationAttr);
+ VLOG(3) << "Updated send-recv iterations for " << cp2->name() << " : "
+ << cp2->frontend_attributes().map().at(kSendRecvValidationAttr);
+ return absl::OkStatus();
+}
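+
+// Worked example (illustrative only): with no peeling and an original
+// attribute of {4,7}, the formulas above give
+//   cp1: {floor((4+1)/2), floor(7/2)}         = {2,3}
+//   cp2: {floor(4/2), max(0, floor((7-1)/2))} = {2,3}
+// i.e. both collectives run in new-loop iterations 2 and 3, which together
+// cover old iterations 4,5,6,7 (cp1 the even ones, cp2 the odd ones).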
+
+// Handle control predecessors/successors for every old-new instruction pair.
+// For every new instruction, we find the relevant predecessor/successor
+// relationships of the old instruction and we reconstruct them by looking up
+// new (already created) predecessors/successors.
+//
+// When rewiring dependencies from the output of the original body to the input
+// of the cloned body, we skip collectives and ops in
+// `skip_control_dep_injection`.
+absl::Status HandleControlDependencies(
+ const HloComputation* while_body,
+ const absl::flat_hash_map<HloInstruction*, HloInstruction*>& old_to_new_map,
+ HloInstruction::InstructionVector* old_loop_roots,
+ HloInstruction* input_parameter,
+ const absl::flat_hash_set<HloInstruction*>& skip_control_dep_injection) {
+ for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+ if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+ HloInstruction* new_instr = old_to_new_map.at(old_instr);
+ VLOG(2) << "Processing control predecessors for "
+ << new_instr->ToString();
+ std::vector<HloInstruction*> new_control_pred;
+ new_control_pred.reserve(old_instr->control_predecessors().size());
+ for (HloInstruction* pred : old_instr->control_predecessors()) {
+ if (!old_to_new_map.contains(pred)) {
+ continue;
+ }
+ new_control_pred.push_back(old_to_new_map.at(pred));
+ }
+
+ TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
+ for (HloInstruction* new_pred : new_control_pred) {
+ TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
+ VLOG(2) << "Adding " << new_pred->ToString()
+ << " to control dependency of " << new_instr->ToString();
+ }
+ }
+ }
+ for (HloInstruction* input_consumer : input_parameter->users()) {
+ for (HloInstruction* old_input : input_consumer->users()) {
+ if (old_to_new_map.find(old_input) != old_to_new_map.end()) {
+ HloInstruction* new_input = old_to_new_map.at(old_input);
+ if (skip_control_dep_injection.find(old_input) ==
+ skip_control_dep_injection.end() &&
+ !IsCollective(old_input)) {
+ for (HloInstruction* old_root : *old_loop_roots) {
+ TF_RETURN_IF_ERROR(old_root->AddControlDependencyTo(new_input));
+ }
+ }
+ }
+ }
+ }
+
+ return absl::OkStatus();
+}
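+
+// Illustration (not part of the pass logic): if the original body had a
+// control edge A -> B and both were cloned to A' and B', the first loop above
+// re-creates the edge as A' -> B'. The second loop then adds control edges
+// from the old loop roots to the clones of instructions that consume the
+// parameter's users (typically get-tuple-elements), except for collectives
+// and ops in `skip_control_dep_injection`, so the cloned iteration is ordered
+// after the original one.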
+
+absl::StatusOr<bool> FullyUnroll(HloInstruction* while_instr,
+ HloModule* module) {
+ HloComputation* while_body = while_instr->while_body();
+ bool changed = false;
+ VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
+
+ auto loop_roots = while_body->root_instruction()->mutable_operands();
+ HloInstruction* input_parameter = while_body->parameter_instruction(0);
+ VLOG(2) << "Processing input parameter " << input_parameter->ToString();
+
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
+ absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
+ std::string clone_suffix = "full_unroll_clone";
+
+ TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
+ while_instr->backend_config<WhileLoopBackendConfig>());
+ std::vector<HloInstruction*> ops_to_clone;
+ ops_to_clone.reserve(while_body->MakeInstructionPostOrder().size());
+
+ // Pre-loop prep.
+ HloInstruction* old_input_parameter = input_parameter;
+ HloInstruction* new_input_parameter = while_body->root_instruction();
+ absl::flat_hash_set<HloInstruction*> seen_ops;
+ for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+ if (seen_ops.contains(old_instr)) {
+ continue;
+ }
+ ops_to_clone.push_back(old_instr);
+ seen_ops.insert(old_instr);
+ }
+
+ int n = config.known_trip_count().n();
+ while (--n) {
+ std::vector<HloInstruction*> new_ops_to_clone;
+ old_to_new_map[old_input_parameter] = new_input_parameter;
+ for (HloInstruction* old_instr : ops_to_clone) {
+ if (old_to_new_map.contains(old_instr)) {
+ continue;
+ }
+ VLOG(2) << "Cloning instruction " << old_instr->ToString();
+ std::vector<HloInstruction*> new_operands;
+ for (HloInstruction* old_operand : old_instr->mutable_operands()) {
+ new_operands.push_back(old_to_new_map[old_operand]);
+ }
+ HloInstruction* new_instr =
+ while_body->AddInstruction(old_instr->CloneWithNewOperands(
+ old_instr->shape(), new_operands, clone_suffix));
+
+ // If an elementwise instruction with a constant operand is present, we
+ // won't inject a control dependency at the end, to allow more
+ // constant-folding opportunities.
+ if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
+ skip_control_dep_injection.insert(old_instr);
+ }
+ SetChannelIdForNewCollective(new_instr, module);
+ old_to_new_map[old_instr] = new_instr;
+ new_ops_to_clone.push_back(new_instr);
+ VLOG(2) << "Added instruction " << new_instr->ToString();
+ }
+
+ while_body->set_root_instruction(
+ old_to_new_map[while_body->root_instruction()]);
+ VLOG(2) << "Replaced with new root "
+ << while_body->root_instruction()->ToString();
+
+ TF_RETURN_IF_ERROR(HandleControlDependencies(
+ while_body, old_to_new_map, &loop_roots, old_input_parameter,
+ skip_control_dep_injection));
+
+ // Inductive step update: clean/update the necessary buffers to prepare them
+ // for the next unrolling iteration.
+ old_to_new_map.clear();
+ skip_control_dep_injection.clear();
+ loop_roots = while_body->root_instruction()->mutable_operands();
+ old_input_parameter = new_input_parameter;
+ new_input_parameter = while_body->root_instruction();
+ ops_to_clone = std::move(new_ops_to_clone);
+ changed = true;
+ }
+
+ WhileLoopBackendConfig new_config;
+ new_config.mutable_known_trip_count()->set_n(1);
+ TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
+
+ return changed;
+}
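+
+// Illustration (not part of the pass logic): for a while body with
+// known_trip_count n=3, the loop above runs twice, each time cloning every
+// body instruction and re-rooting the body at the cloned tuple, so the final
+// body holds three copies of the original instructions and the trip count is
+// set to 1.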
+
+absl::Status PeelInstructionsForOddTripCount(HloModule* module,
+ HloInstruction* while_instr) {
+ std::string suffix = "peeled_double_buffer";
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
+ HloComputation* while_body = while_instr->while_body();
+ HloInstruction* input_parameter = while_body->parameter_instruction(0);
+ HloInstruction* input_tuple = while_instr->mutable_operand(0);
+
+ auto old_loop_roots = while_body->root_instruction()->mutable_operands();
+ HloComputation* parent_comp = while_instr->parent();
+ old_to_new_map[input_parameter] = input_tuple;
+
+ for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+ if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+ continue;
+ }
+ VLOG(2) << "Peeling instruction " << old_instr->ToString();
+ std::vector<HloInstruction*> new_operands(old_instr->operand_count());
+ for (int64_t i = 0; i < old_instr->operand_count(); i++) {
+ new_operands[i] = old_to_new_map[old_instr->mutable_operand(i)];
+ }
+ HloInstruction* new_instr =
+ parent_comp->AddInstruction(old_instr->CloneWithNewOperands(
+ old_instr->shape(), new_operands, suffix));
+
+ SetChannelIdForNewCollective(new_instr, module);
+ TF_CHECK_OK(SetSendRecvValidationForPeeledInstr(new_instr, old_instr));
+ old_to_new_map[old_instr] = new_instr;
+ VLOG(2) << "Added instruction " << new_instr->ToString()
+ << " to parent computation.";
+ }
+
+ std::vector<HloInstruction*> new_roots;
+ for (HloInstruction* instr : old_loop_roots) {
+ new_roots.push_back(old_to_new_map[instr]);
+ }
+ TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(
+ 0, old_to_new_map[while_body->root_instruction()]));
+ VLOG(2) << "Replaced with new input tuple "
+ << while_instr->operand(0)->ToString();
+
+ // Handle existing control dependencies.
+ for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+ if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+ HloInstruction* new_instr = old_to_new_map[old_instr];
+ VLOG(2) << "Processing control predecessors for peeled instruction "
+ << new_instr->ToString();
+ std::vector<HloInstruction*> new_control_pred;
+ new_control_pred.reserve(old_instr->control_predecessors().size());
+ for (HloInstruction* pred : old_instr->control_predecessors()) {
+ new_control_pred.push_back(old_to_new_map[pred]);
+ }
+
+ TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
+ for (HloInstruction* new_pred : new_control_pred) {
+ TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
+ VLOG(2) << "Adding " << new_pred->ToString()
+ << " to control dependency of peeled instruction: "
+ << new_instr->ToString();
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
+// TODO(olechwierowicz): Extract common logic of this and `FullyUnroll` to
+// a separate function.
+absl::StatusOr<bool> DoubleBufferingUnroll(HloInstruction* while_instr,
+ HloModule* module) {
+ TF_ASSIGN_OR_RETURN(auto config,
+ while_instr->backend_config<WhileLoopBackendConfig>());
+
+ CHECK(config.has_known_trip_count())
+ << "Only loops with known trip count are supported.";
+ int64_t exact_trip_count = config.known_trip_count().n();
+ VLOG(2) << "Processing while loop " << while_instr->ToString()
+ << " with trip count: " << exact_trip_count;
+
+ HloComputation* while_body = while_instr->while_body();
+
+ VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
+
+ auto old_loop_roots = while_body->root_instruction()->mutable_operands();
+ HloInstruction* input_parameter = while_body->parameter_instruction(0);
+ VLOG(2) << "Processing input parameter " << input_parameter->ToString();
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
+ absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
+
+ bool is_peeled = exact_trip_count % 2;
+ if (is_peeled) {
+ VLOG(2) << "Found loop with odd trip count; 1 iteration will be peeled "
+ "outside of the main body.";
+ TF_RETURN_IF_ERROR(PeelInstructionsForOddTripCount(module, while_instr));
+ exact_trip_count -= 1;
+ }
+
+ std::string suffix = "double_buffer_clone";
+ old_to_new_map[input_parameter] = while_body->root_instruction();
+ for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+ if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+ continue;
+ }
+ VLOG(2) << "Cloning instruction " << old_instr->ToString();
+ std::vector<HloInstruction*> new_operands;
+ for (HloInstruction* old_operand : old_instr->mutable_operands()) {
+ new_operands.push_back(old_to_new_map[old_operand]);
+ }
+ HloInstruction* new_instr =
+ while_body->AddInstruction(old_instr->CloneWithNewOperands(
+ old_instr->shape(), new_operands, suffix));
+
+ // If an elementwise instruction with a constant operand is present, we
+ // won't inject a control dependency at the end, to allow more
+ // constant-folding opportunities.
+ if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
+ skip_control_dep_injection.insert(old_instr);
+ }
+ SetChannelIdForNewCollective(new_instr, module);
+ TF_CHECK_OK(SetSendRecvValidation(old_instr, new_instr, is_peeled));
+ old_to_new_map[old_instr] = new_instr;
+ VLOG(2) << "Added instruction " << new_instr->ToString();
+ }
+
+ while_body->set_root_instruction(
+ old_to_new_map[while_body->root_instruction()]);
+ VLOG(2) << "Replaced with new root "
+ << while_body->root_instruction()->ToString();
+
+ // Handle existing control dependencies.
+ TF_RETURN_IF_ERROR(HandleControlDependencies(while_body, old_to_new_map,
+ &old_loop_roots, input_parameter,
+ skip_control_dep_injection));
+
+ WhileLoopBackendConfig new_config;
+ new_config.mutable_known_trip_count()->set_n(exact_trip_count / 2);
+ TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
+ return true; // changed
+}
+
+} // namespace
+
+absl::StatusOr<bool> DoubleBufferLoopUnrolling::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ std::vector<HloInstruction*> while_instrs;
+ for (auto comp : module->MakeNonfusionComputations()) {
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+ HloPredicateIsOp<HloOpcode::kWhile>);
+ }
+ VLOG(2) << "Processing " << while_instrs.size() << " while loops.";
+
+ for (HloInstruction* while_instr : while_instrs) {
+ TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
+ while_instr->backend_config<WhileLoopBackendConfig>());
+ if (!config.has_known_trip_count() || config.known_trip_count().n() == 1) {
+ VLOG(2) << while_instr->ToString()
+ << " doesn't have exact trip count, skipping loop unrolling "
+ "for now";
+ continue;
+ }
+
+ if (unroll_strategy_ == UnrollStrategy::kFullUnroll) {
+ TF_ASSIGN_OR_RETURN(changed, FullyUnroll(while_instr, module));
+ } else if (unroll_strategy_ == UnrollStrategy::kDoubleBuffer) {
+ TF_ASSIGN_OR_RETURN(changed, DoubleBufferingUnroll(while_instr, module));
+ } else {
+ LOG(FATAL) << absl::StrCat("Unhandled unrolling strategy: ",
+ unroll_strategy_);
+ }
+ }
+
+ VLOG(2) << "LoopDoubleBufferTransformer output: " << module->ToString();
+
+ // Run necessary cleanup to ensure LoopDoubleBufferTransformer behaves
+ // correctly.
+ if (changed) {
+ // The call graph will not be flat if one of the unrolled loops contains any
+ // kind of call to another computation, since the call will be duplicated,
+ // thereby adding a second callsite for that computation.
+ TF_RETURN_IF_ERROR(
+ FlattenCallGraph().Run(module, execution_threads).status());
+ }
+
+ return changed;
+}
+
+} // end namespace gpu
+} // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h
new file mode 100644
index 0000000..26bb178
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h
@@ -0,0 +1,75 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// With the `kDoubleBuffer` strategy:
+// This pass performs the unrolling-by-2 loop transformation to effectively
+// achieve double buffering between the inputs and outputs of previously
+// rolled iterations. It only runs on loops with known trip counts.
+// For an even number of iterations, unrolling-by-2 is done directly.
+// For an odd number of iterations, the first iteration of the loop is peeled
+// outside of the while loop to make the trip count even, and the body is then
+// unrolled by 2.
+// The trip count property of the loop is updated to the correct value (n/2).
+//
+// With the `kFullUnroll` strategy:
+// This pass fully unrolls the loop, applying the same cloning machinery as
+// `kDoubleBuffer`, repeated trip-count times. It updates the trip count of
+// the while loop to 1 and relies on other passes (like `WhileLoopSimplifier`)
+// to eventually simplify away the while loop.
+//
+// Note that this pass will flatten the call graph if any loop has been
+// unrolled.
+// TODO(olechwierowicz): Rename the loop unroller to something more generic like
+// 'DoubleBufferLoopUnrolling'.
+class DoubleBufferLoopUnrolling : public HloModulePass {
+ public:
+ enum class UnrollStrategy { kDoubleBuffer, kFullUnroll };
+
+ explicit DoubleBufferLoopUnrolling(
+ UnrollStrategy unroll_strategy = UnrollStrategy::kDoubleBuffer)
+ : unroll_strategy_(unroll_strategy) {}
+ ~DoubleBufferLoopUnrolling() override = default;
+
+ absl::string_view name() const override {
+ return "loop-double-buffer-transformer";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ UnrollStrategy unroll_strategy_;
+};
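+
+// Usage sketch (illustrative only; assumes an existing HloPassPipeline named
+// `pipeline`):
+//
+//   pipeline.AddPass<DoubleBufferLoopUnrolling>(
+//       DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+//
+// Only while loops whose WhileLoopBackendConfig has a known trip count
+// greater than 1 are transformed; other while loops are skipped.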
+
+} // end namespace gpu
+} // end namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc
new file mode 100644
index 0000000..05e704d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc
@@ -0,0 +1,1243 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+
+#include "absl/container/flat_hash_set.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/tuple_simplifier.h"
+#include "xla/test.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using tsl::testing::IsOkAndHolds;
+
+int64_t CountInstructions(const HloComputation& computation, HloOpcode opcode) {
+ int64_t count = 0;
+ for (const auto& instruction : computation.instructions()) {
+ if (instruction->opcode() == opcode) {
+ count++;
+ }
+ }
+ return count;
+}
+
+int64_t CountInstructions(const HloModule& module, HloOpcode opcode) {
+ int64_t count = 0;
+ for (const auto& computation : module.computations()) {
+ count += CountInstructions((*computation), opcode);
+ }
+ return count;
+}
+
+class GpuLoopDoubleBufferTransformerTest : public HloTestBase {
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_while_loop_double_buffering(true);
+ return debug_options;
+ }
+};
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollOddTripCountTest) {
+ const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+ TupleSimplifier tuple_simp;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
+ EXPECT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
+ EXPECT_TRUE(changed);
+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+ *module->entry_computation(), HloOpcode::kWhile);
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ while_instruction->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ EXPECT_EQ(exact_trip_count, 1);
+ EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+ HloOpcode::kAllGatherStart),
+ 11);
+ EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 11);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollEvenTripCountTest) {
+ const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ // Start all-gather communication
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+ TupleSimplifier tuple_simp;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
+ EXPECT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ HloInstruction* while_instruction = nullptr;
+ for (auto instr : module->entry_computation()->instructions()) {
+ if (instr->opcode() == HloOpcode::kWhile) {
+ while_instruction = instr;
+ }
+ }
+ ASSERT_NE(while_instruction, nullptr);
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ while_instruction->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ EXPECT_EQ(exact_trip_count, 1);
+ EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+ HloOpcode::kAllGatherStart),
+ 10);
+ EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 10);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopEvenTripCount) {
+ const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ // Start all-gather communication
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ // Intertwined with the all-gather communication, an operation happens which
+ // depends on param_1, but crucially has a different output shape (which
+ // excludes reusing param_1's buffer for its output).
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ // The all-gather communication finishes
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer;
+ TupleSimplifier tuple_simp;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
+ EXPECT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+ *module->entry_computation(), HloOpcode::kWhile);
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ while_instruction->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ // We expect that after unrolling, the total trip count is half of original
+ // count.
+ EXPECT_EQ(exact_trip_count, 5);
+ // We expect that after unrolling, there should be 2 allgather starts,
+ // both in while body.
+ EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+ HloOpcode::kAllGatherStart),
+ 2);
+ EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 2);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopOddTripCount) {
+ const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ // Start all-gather communication
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ // Intertwined with the all-gather communication, an operation happens which
+ // depends on param_1, but crucially has a different output shape (which
+ // excludes reusing param_1's buffer for its output).
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ // The all-gather communication finishes
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer;
+ TupleSimplifier tuple_simp;
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+ // We expect that for the while loop, no further copy needs to be added to the
+ // module.
+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+ *module->entry_computation(), HloOpcode::kWhile);
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ while_instruction->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ // We expect that after unrolling, the total trip count is half of original
+ // count.
+ EXPECT_EQ(exact_trip_count, 5);
+
+ // We expect that after unrolling, there should be 3 allgather starts,
+ // 1 in parent computation, 2 in while body.
+ EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+ HloOpcode::kAllGatherStart),
+ 2);
+ EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 3);
+
+ // We expect that after unrolling, the third operand of the input tuple should
+ // be the peeled allgather done.
+ EXPECT_EQ(while_instruction->operand(0)->operand(2)->opcode(),
+ HloOpcode::kAllGatherDone);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ UnrolledLoopNoControlDepsForConstantAdd) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ c2 = f32[] constant(2)
+ add = f32[] add(c2, param_0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(add, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer;
+ TupleSimplifier tuple_simp;
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+ *module->entry_computation(), HloOpcode::kWhile);
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ while_instruction->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ // We expect that after unrolling, the total trip count is half of original
+ // count.
+ EXPECT_EQ(exact_trip_count, 5);
+
+ // We expect that after unrolling, there should be 4 adds
+ EXPECT_EQ(
+ CountInstructions((*while_instruction->while_body()), HloOpcode::kAdd),
+ 4);
+
+ // We expect that after unrolling, the first operand of the output tuple
+ // should not have any control dependency since it's an elementwise add with
+ // a constant operand.
+ EXPECT_EQ(while_instruction->while_body()
+ ->root_instruction()
+ ->operand(0)
+ ->control_predecessors()
+ .size(),
+ 0);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ UnrolledLoopNoControlDepsForCollective) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
+ one = s32[] constant(1)
+ all-reduce-done = f32[] all-reduce-done(all-reduce-start)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer;
+ TupleSimplifier tuple_simp;
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+ *module->entry_computation(), HloOpcode::kWhile);
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ while_instruction->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ // We expect that after unrolling, the total trip count is half of original
+ // count.
+ EXPECT_EQ(exact_trip_count, 5);
+
+ // We expect that after unrolling, there should be 2 all-reduce-starts
+ EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+ HloOpcode::kAllReduceStart),
+ 2);
+ absl::flat_hash_set<int64_t> channel_ids;
+ hlo_query::ForEachInstructionWithOpcode(
+ *while_instruction->while_body(), HloOpcode::kAllReduceStart,
+ [&channel_ids](HloInstruction* ar) {
+ // We expect that after unrolling, all-reduces should not have any
+ // control deps.
+ EXPECT_EQ(ar->control_predecessors().size(), 0);
+ channel_ids.insert(*(ar->channel_id()));
+ });
+ // We expect the 2 all-reduces to have different channel ids.
+ EXPECT_EQ(channel_ids.size(), 2);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ FullyUnrolledLoopNoControlDepsForCollective) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
+ one = s32[] constant(1)
+ all-reduce-done = f32[] all-reduce-done(all-reduce-start)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+ TupleSimplifier tuple_simp;
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+ HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+ *module->entry_computation(), HloOpcode::kWhile);
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ while_instruction->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ EXPECT_EQ(exact_trip_count, 1);
+
+ // We expect that after unrolling, there should be 10 all-reduce-starts
+ EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+ HloOpcode::kAllReduceStart),
+ 10);
+ absl::flat_hash_set<int64_t> channel_ids;
+ hlo_query::ForEachInstructionWithOpcode(
+ *while_instruction->while_body(), HloOpcode::kAllReduceStart,
+ [&channel_ids](HloInstruction* ar) {
+ // We expect that after unrolling, all-reduces should not have any
+ // control deps.
+ EXPECT_EQ(ar->control_predecessors().size(), 0);
+ channel_ids.insert(*(ar->channel_id()));
+ });
+ // We expect all 10 all-reduces to have different channel ids.
+ EXPECT_EQ(channel_ids.size(), 10);
+}
+
+// The following 2 tests also address the regression described here:
+// https://github.com/openxla/xla/issues/6353
+TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopRemainsFlattened) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_while_loop_remains_flattened
+
+condition_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+
+condition {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (s32[]) parameter(0)
+ ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
+}
+
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer;
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ absl::flat_hash_set<const HloComputation*> while_loops_callees;
+
+ hlo_query::ForEachInstructionWithOpcode(
+ *module, HloOpcode::kWhile,
+ [&while_loops_callees](HloInstruction* instr) {
+ EXPECT_TRUE(
+ while_loops_callees.insert(instr->while_condition()).second);
+ EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
+ });
+
+ // We expect that the nested while loop has been duplicated, along with its
+ // associated computations.
+ EXPECT_EQ(while_loops_callees.size(), 6);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ NestedWhileLoopRemainsFlattenedOddTripCount) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_while_loop_remains_flattened
+
+condition_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+
+condition {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (s32[]) parameter(0)
+ ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
+}
+
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer;
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ absl::flat_hash_set<const HloComputation*> while_loops_callees;
+
+ hlo_query::ForEachInstructionWithOpcode(
+ *module, HloOpcode::kWhile,
+ [&while_loops_callees](HloInstruction* instr) {
+ EXPECT_TRUE(
+ while_loops_callees.insert(instr->while_condition()).second);
+ EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
+ });
+
+ // We expect that the nested while loop has been duplicated, along with its
+ // associated computations.
+ EXPECT_EQ(while_loops_callees.size(), 8);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ NestedWhileLoopRemainsFlattenedWhenFullyUnrolled) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_while_loop_remains_flattened
+
+condition_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+
+condition {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (s32[]) parameter(0)
+ ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
+}
+
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ absl::flat_hash_set<const HloComputation*> while_loops_callees;
+
+ hlo_query::ForEachInstructionWithOpcode(
+ *module, HloOpcode::kWhile,
+ [&while_loops_callees](HloInstruction* instr) {
+ EXPECT_TRUE(
+ while_loops_callees.insert(instr->while_condition()).second);
+ EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
+ });
+
+ hlo_query::ForEachInstructionWithOpcode(
+ *module->entry_computation(), HloOpcode::kWhile,
+ [](HloInstruction* instr) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileLoopBackendConfig config,
+ instr->backend_config<WhileLoopBackendConfig>());
+ int64_t exact_trip_count = config.known_trip_count().n();
+ EXPECT_EQ(exact_trip_count, 1);
+ });
+
+ // We expect that the nested while loop has been fully duplicated 10
+ // times. The one outer while loop still remains, so that's 11 while
+ // instructions. We check that there are 22 distinct computations in total:
+ // one body and one condition per while instruction.
+ EXPECT_EQ(while_loops_callees.size(), 22);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreUnrolled) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_are_unrolled
+condition_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+condition {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body {
+ input_tuple = (s32[]) parameter(0)
+ ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
+}
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer;
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ int64_t num_whiles = 0;
+ hlo_query::ForEachInstructionWithOpcode(
+ *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
+ EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
+ ->known_trip_count()
+ .n(),
+ 5);
+ ++num_whiles;
+ });
+ // We expect the number of while loops to be 4 in total after unrolling.
+ EXPECT_EQ(num_whiles, 4);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreFullyUnrolled) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_are_unrolled
+condition_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+condition {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body {
+ input_tuple = (s32[]) parameter(0)
+ ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
+}
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ int64_t num_whiles = 0;
+ hlo_query::ForEachInstructionWithOpcode(
+ *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
+ EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
+ ->known_trip_count()
+ .n(),
+ 1);
+ ++num_whiles;
+ });
+ EXPECT_EQ(num_whiles, 12);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithCollectivePermute) {
+ const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
+ frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+ VLOG(0) << module->ToString();
+ EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+ // CHECK: %body {{.+}} {
+ // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}}
+ // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+ // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+ // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}}
+ // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+ // CHECK: }
+ // CHECK: ENTRY %main {{.+}} {
+ // CHECK-NOT: collective-permute
+ // CHECK: }
+ )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ WhileLoopWithCollectivePermutePeeled) {
+ const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(15)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,0}},
+ frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10},{4,11},{5,12},{6,13},{7,14}}"}
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
+}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+ VLOG(0) << module->ToString();
+
+ EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+ // CHECK: %body
+ // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}}
+ // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+ // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]])
+ // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}}
+ // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+ // CHECK: ENTRY %main {{.+}} {
+ // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}}
+ // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
+ // CHECK: %[[while:.+]] = {{.+}} while({{.+}} %[[out_peeled]])
+ // CHECK: }
+ )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ WhileLoopWithCollectivePermuteBackwardCycle) {
+ const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(14)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
+ frontend_attributes={_xla_send_recv_validation="{{7,13},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}}"}
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"14"}}
+}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+ // CHECK: %body
+ // CHECK: %[[cp1:.+]] = f32[] collective-permute(f32[] %param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}}
+ // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+ // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+ // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}}
+ // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+ // CHECK: ENTRY %main
+ // CHECK-NOT: collective-permute
+ // CHECK: }
+ )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ WhileLoopWithCollectivePermuteBackwardCyclePeeled) {
+ const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(15)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
+ frontend_attributes={_xla_send_recv_validation="{{7,14},{6,13},{5,12},{4,11},{3,10},{2,9},{1,8},{0,7}}"}
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
+}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+ // CHECK: %body
+ // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3}{{[}]}}}
+ // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+ // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+ // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}}
+ // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+ // CHECK: }
+ // CHECK: ENTRY %main
+ // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}}
+ // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
+ // CHECK: ROOT {{.+}} = {{.+}} while({{.+}} %[[out_peeled]])
+ // CHECK: }
+ )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ WhileLoopWithCollectivePermuteStartDone) {
+ const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ collective-permute-start = (f32[], f32[], u32[], u32[]) collective-permute-start(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
+ frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
+ collective-permute = f32[] collective-permute-done(collective-permute-start)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+ // CHECK: %body
+ // CHECK: %[[cp_start1:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}}
+ // CHECK: %[[cp1:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start1]])
+ // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+ // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+ // CHECK: %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}}
+ // CHECK: %[[cp2:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start2]])
+ // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+ // CHECK: }
+ // CHECK: ENTRY %main
+ // CHECK-NOT: collective-permute
+ // CHECK: }
+ )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithRecvDone) {
+ const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ after-all.0 = token[] after-all()
+ recv.0 = (f32[], u32[], token[]) recv(after-all.0), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
+ _xla_send_recv_pipeline="0",
+ _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
+ }
+ recv-done.0 = (f32[], token[]) recv-done(recv.0), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ recv-data = f32[] get-tuple-element(recv-done.0), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(recv-data, cond_plus_1)
+}
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+ // CHECK: %body
+ // CHECK: %[[recv1:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}
+ // CHECK: %[[recv2:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}
+ // CHECK: ENTRY %main
+ // CHECK-NOT: recv
+ // CHECK: }
+ )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithSendDone) {
+ const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+ input_tuple = (f32[], s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ trip_count = s32[] constant(10)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+ Arg_1 = f32[] parameter(1)
+ Arg_0 = f32[] parameter(0)
+ ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ after-all.0 = token[] after-all()
+ send.0 = (f32[], u32[], token[]) send(param_0, after-all.0), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
+ _xla_send_recv_pipeline="0",
+ _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
+ }
+ send-done.0 = token[] send-done(send.0), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(param_0, cond_plus_1)
+}
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+ EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+ // CHECK: %body
+ // CHECK: %[[send1:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}
+ // CHECK: %[[send2:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}
+ // CHECK: ENTRY %main
+ // CHECK-NOT: send
+ // CHECK: }
+ )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+ WhileLoopWithTripCount1ShouldBeSkipped) {
+ const char* const kModuleString = R"(
+HloModule loop_unrolling_skipped
+condition_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(0)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+condition {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ trip_count = s32[] constant(0)
+ ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body {
+ input_tuple = (s32[]) parameter(0)
+ ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"1"}}
+}
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"1"}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+ ParseAndReturnVerifiedModule(kModuleString));
+ DoubleBufferLoopUnrolling double_buffer(
+ DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+ // The processing of the loop should be completely skipped.
+ EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(false));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
new file mode 100644
index 0000000..f059607
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
@@ -0,0 +1,627 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/custom_call_target_registry.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/gpu_constants.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace m = ::xla::match;
+
+// A dataflow path flowing from a definition to a user.
+using DefUseDataflowPath = absl::InlinedVector<HloInstruction*, 2>;
+
+// All dataflow paths flowing from a definition to all users. Each user will
+// have a separate entry in the vector.
+using DefUseDataflowPaths = absl::InlinedVector<DefUseDataflowPath, 4>;
+
+// A dataflow path flowing from a user to a definition.
+using UseDefDataflowPath = absl::InlinedVector<HloInstruction*, 4>;
+
+// All dataflow paths flowing from a user to all definitions of its operands.
+using UseDefDataflowPaths = absl::InlinedVector<HloInstruction*, 8>;
+
+using DataflowPathView = absl::Span<HloInstruction* const>;
+using DataflowPathsView = absl::Span<DataflowPathView>;
+
+using InstructionSet = absl::flat_hash_set<HloInstruction*>;
+
+bool IsNoOp(const HloInstruction* hlo) {
+ return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
+ HloOpcode::kGetTupleElement>(hlo);
+}
+
+bool IsCustomCall(const HloInstruction* hlo, absl::string_view platform_name) {
+ auto* custom_call = DynCast<HloCustomCallInstruction>(hlo);
+ if (custom_call == nullptr) return false;
+
+ // TODO(vuson): properly handle token by following
+ // `LhloDialectEmitter::EmitCustomCallOp`'s `CreateOperands` logic for
+ // `LhloDialectEmitter::EmitFusionOp`'s `RewriteFusionOperand`
+ if (custom_call->shape().IsTuple() &&
+ absl::c_any_of(
+ custom_call->shape().tuple_shapes(),
+ [&](const Shape& sub_shape) { return sub_shape.IsToken(); }))
+ return false;
+
+ const std::string call_target_name = custom_call->custom_call_target();
+
+ bool is_ffi_custom_call =
+ custom_call->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI;
+
+ void* call_target = CustomCallTargetRegistry::Global()->Lookup(
+ call_target_name, std::string(platform_name));
+
+ absl::StatusOr<ffi::HandlerRegistration> handler_registration =
+ ffi::FindHandler(call_target_name, platform_name);
+
+ // At least one implementation should be available at run time.
+ bool found_custom_call = !is_ffi_custom_call && call_target != nullptr;
+ bool found_ffi_handler = is_ffi_custom_call && handler_registration.ok();
+
+ return found_custom_call || found_ffi_handler;
+}
+
+// Returns true if the slice is 128-byte-aligned. The slice's starting address
+// is determined by the product of all non-sliced dimensions and an offset
+// defined by the `slice_starts` of the slice op.
+//
+// For dynamic cases, we don't have info about the start indices, so we have to
+// be conservative by only accepting sliced shapes that have the product of all
+// non-sliced dimensions being a multiple of `kXlaAllocatedBufferAlignBytes`.
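+//
+// Illustrative sketch (not from the original source): for
+//   %p0 = f16[4,8,8]{2,1,0} parameter(0)
+//   %s  = f16[1,8,8]{2,1,0} slice(%p0), slice={[2:3], [0:8], [0:8]}
+// the sliced leading dimension has a byte stride of 8 * 8 * 2 = 128 bytes, so
+// every possible slice start is 128-byte aligned and the slice is accepted.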
+bool IsAlignedSlice(const HloInstruction* slice) {
+ DCHECK(slice->opcode() == HloOpcode::kSlice ||
+ slice->opcode() == HloOpcode::kDynamicSlice ||
+ slice->opcode() == HloOpcode::kDynamicUpdateSlice)
+ << "Unknown slice operation: " << slice->ToString();
+
+ if (!IsContiguousSlice(*slice)) return false;
+
+ auto [full_shape, slice_shape] = [&] {
+ if (auto* dus = DynCast<HloDynamicUpdateSliceInstruction>(slice)) {
+ return std::make_pair(dus->shape(), dus->update()->shape());
+ }
+ return std::make_pair(slice->operand(0)->shape(), slice->shape());
+ }();
+
+ auto strides = ShapeUtil::ByteStrides(slice_shape);
+ if (!strides.has_value()) return false;
+
+ for (auto dim : slice_shape.layout().minor_to_major()) {
+ if ((strides.value()[dim] % kXlaAllocatedBufferAlignBytes) == 0) {
+ return true;
+ }
+ if (slice_shape.dimensions(dim) < full_shape.dimensions(dim)) {
+ return (slice->opcode() == HloOpcode::kSlice &&
+ (((*strides)[dim] * slice->slice_starts(dim)) %
+ kXlaAllocatedBufferAlignBytes ==
+ 0));
+ }
+ }
+ return true;
+}
+
+// Pattern matches the following IR (generated by `jax.lax.scan`) to check if
+// the offset is a loop iteration number:
+
+// clang-format off
+// param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0)
+// // the index in `gte` has to be the loop iteration index
+// gte = s32[] get-tuple-element(param), index=0
+// c0 = s32[] constant(0)
+// compare = pred[] compare(gte, c0), direction=LT
+// c_trip_count = s32[] constant(16)
+// add = s32[] add(gte, c_trip_count)
+// select = s32[] select(compare, add, gte)
+// clang-format on
+
+bool IsLoopIterationNumber(const HloInstruction& offset) {
+ const HloComputation* parent = offset.parent();
+ if (!parent->IsWhileBodyComputation()) return false;
+
+ // A scan loop's trip count must be known at compile time, since it iterates
+ // over the leading dimension of the statically shaped input.
+ const HloInstruction* while_instr = parent->WhileCallInstruction();
+ auto config = while_instr->backend_config<xla::WhileLoopBackendConfig>();
+ if (!config.ok() || !config->has_known_trip_count()) return false;
+ int32_t trip_count = config->known_trip_count().n();
+
+ // First, let's check the offset computation pattern.
+ if (!Match(&offset, m::Select(m::Lt(m::GetTupleElement(m::Parameter(0)),
+ m::ConstantScalar<int32_t>(0)),
+ m::Add(m::GetTupleElement(m::Parameter(0)),
+ m::ConstantScalar(trip_count)),
+ m::GetTupleElement(m::Parameter())))) {
+ return false;
+ }
+
+ // Next, we check that the parameter used in the offset computation is the
+ // loop induction variable.
+ int64_t param_idx = offset.operand(2)->tuple_index();
+ const HloInstruction* root = offset.parent()->root_instruction();
+ if (root->opcode() != HloOpcode::kTuple) {
+ return false;
+ }
+ // Check the update operation
+ const HloInstruction* updated_var =
+ offset.parent()->root_instruction()->operand(param_idx);
+ if (!Match(updated_var, m::Add(m::GetTupleElement(m::Parameter(0), param_idx),
+ m::ConstantScalar(1)))) {
+ return false;
+ }
+ // Check that the while condition compares this induction variable against
+ // the trip count.
+ const HloInstruction* condition_root =
+ while_instr->while_condition()->root_instruction();
+ if (!Match(condition_root,
+ m::Lt(m::GetTupleElement(m::Parameter(0), param_idx),
+ m::ConstantScalar(trip_count)))) {
+ return false;
+ }
+ // Check that the induction variable is initialized to zero.
+ const HloInstruction* init_loop_iter =
+ while_instr->operand(0)->operand(param_idx);
+ if (!Match(init_loop_iter, m::ConstantScalar(0))) {
+ return false;
+ }
+
+ return true;
+}
+
+// Returns true for constants that are handled by the dynamic slice fusion
+// runtime. Such constants do not force a D2H copy and hence preserve the
+// CUDA graph.
+bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) {
+ if (auto* cst = DynCast<HloConstantInstruction>(&offset)) {
+ switch (cst->shape().element_type()) {
+ case PrimitiveType::S32:
+ case PrimitiveType::S64:
+ case PrimitiveType::U32:
+ case PrimitiveType::U64:
+ return true;
+ default:
+ return false;
+ };
+ }
+ return false;
+}
+
+// This checks whether a dynamic index operation has all offsets that are either
+// constant or loop iteration offsets.
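+//
+// Illustrative sketch (hypothetical HLO, not from this file): inside a
+// scan-style while body, a dynamic-slice such as
+//   %c0 = s32[] constant(0)
+//   %iv = s32[] select(...)  // the loop iteration number, see above
+//   %ds = f32[1,8,8]{2,1,0} dynamic-slice(%buf, %iv, %c0, %c0),
+//     dynamic_slice_sizes={1,8,8}
+// passes the check, since every index operand is either a handled constant or
+// the loop iteration number.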
+bool HasConstantOrLoopIterationOffsets(
+ const HloDynamicIndexInstruction& instr) {
+ return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) {
+ return IsLoopIterationNumber(*offset) ||
+ IsHandledConstantForDynamicSliceFusion(*offset);
+ });
+}
+
+UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) {
+ UseDefDataflowPaths sliced_operand_paths;
+
+ // This set is used to avoid duplicates in the matched results. It contains
+ // the matched instructions that we have seen so far.
+ InstructionSet processed_instrs;
+
+ const auto& aliasing_pairs =
+ Cast<HloCustomCallInstruction>(instr)->output_to_operand_aliasing();
+ absl::flat_hash_set<int64_t> aliased_operands;
+ for (const auto& pair : aliasing_pairs) {
+ aliased_operands.insert(pair.second.first);
+ }
+
+ for (const auto* operand : instr->operands()) {
+ // output_to_operand_aliasing means the operand is to be materialized, which
+ // is against the whole idea of address computation fusion. Skip this
+ // operand.
+ if (aliased_operands.contains(instr->operand_index(operand))) continue;
+ UseDefDataflowPath maybe_sliced_operand_path;
+ bool slice_found = false;
+ // TODO: currently HloFindIf exits upon encountering the first node that
+ // matches. This works well if each operand has only one dataflow path (i.e.
+ // it only flows through unary ops). We might want to keep searching until
+ // the queue is empty: if the operand is a tuple, it might have different
+ // dataflow paths (one for each element).
+ auto maybe_slice_instr =
+ HloBfsFindIf({operand}, [&](const HloInstruction* cur) {
+ // If the node is a match that has been processed, stop the traversal.
+ if (processed_instrs.contains(cur)) return true;
+
+ maybe_sliced_operand_path.push_back(const_cast<HloInstruction*>(cur));
+
+ if (IsOpcodeAnyOf<HloOpcode::kDynamicSlice, HloOpcode::kSlice>(cur)) {
+ if (IsAlignedSlice(cur)) {
+ slice_found = true;
+ return slice_found;
+ }
+ }
+
+ return !IsNoOp(cur);
+ });
+
+ if (maybe_slice_instr == std::nullopt) continue;
+ auto dynamic_index_operation =
+ DynCast<HloDynamicIndexInstruction>(maybe_slice_instr.value());
+ bool valid_slice_found =
+ slice_found &&
+ ((dynamic_index_operation &&
+ HasConstantOrLoopIterationOffsets(*dynamic_index_operation)) ||
+ (*maybe_slice_instr)->opcode() == HloOpcode::kSlice);
+ if (valid_slice_found ||
+ processed_instrs.contains(maybe_slice_instr.value())) {
+ // Even in the case of stopping at a match that has been processed, we
+ // still need to add instructions encountered in the sliced operand path
+ // during the latest traversal.
+ sliced_operand_paths.insert(sliced_operand_paths.end(),
+ maybe_sliced_operand_path.rbegin(),
+ maybe_sliced_operand_path.rend());
+ processed_instrs.insert(maybe_sliced_operand_path.begin(),
+ maybe_sliced_operand_path.end());
+ }
+ }
+
+ sliced_operand_paths.push_back(const_cast<HloInstruction*>(instr));
+ return sliced_operand_paths;
+}
+
+// Each user of `instr` that goes into a DUS will have an entry in the returned
+// vector.
+// Each entry contains the sliced path for that user, i.e. the sequence of ops
+// following the dataflow from the user itself to the DUS (inclusive).
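+//
+// Illustrative sketch (hypothetical HLO, not from this file): for a hero
+//   %gemm = (f16[8,8]{1,0}, s8[256]{0}) custom-call(...),
+//     custom_call_target="__cublas$gemm"
+// whose result is written back into a larger buffer via
+//   %gte = f16[8,8]{1,0} get-tuple-element(%gemm), index=0
+//   %bc  = f16[1,8,8]{2,1,0} bitcast(%gte)
+//   %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%dst, %bc, %iv, %c0, %c0)
+// the returned vector contains a single path {%gte, %bc, %dus}.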
+DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) {
+ DefUseDataflowPaths sliced_user_paths;
+ // This set is used to avoid duplicates in the matched results. It contains
+ // the matched instructions that we have seen so far.
+ InstructionSet processed_instrs;
+
+ auto traverse_hlo_and_collect = [&](HloInstruction* start) {
+ DefUseDataflowPath maybe_sliced_user_path;
+ bool dus_found = false;
+ auto maybe_dus_instr = HloBfsFindIf(
+ {start},
+ [&](const HloInstruction* cur) {
+ // If the node is a match that has been processed, stop the
+ // traversal.
+ if (processed_instrs.contains(cur)) return true;
+ maybe_sliced_user_path.push_back(const_cast<HloInstruction*>(cur));
+ if (const auto slice_instr =
+ DynCast<HloDynamicUpdateSliceInstruction>(cur)) {
+ if (IsAlignedSlice(slice_instr)) {
+ dus_found = true;
+ return true;
+ }
+ }
+ return cur->user_count() > 1 || !IsNoOp(cur);
+ },
+ /*visit_operands=*/false);
+ if (maybe_dus_instr == std::nullopt) return;
+ auto dynamic_index_operation =
+ DynCast<HloDynamicIndexInstruction>(maybe_dus_instr.value());
+ bool valid_dus_found =
+ dus_found && dynamic_index_operation &&
+ HasConstantOrLoopIterationOffsets(*dynamic_index_operation);
+ if (valid_dus_found || processed_instrs.contains(maybe_dus_instr.value())) {
+ // Even in the case of stopping at a match that has been processed, we
+ // still need to add instructions encountered in the sliced user path
+ // during the latest traversal.
+ processed_instrs.insert(maybe_sliced_user_path.begin(),
+ maybe_sliced_user_path.end());
+ sliced_user_paths.push_back(std::move(maybe_sliced_user_path));
+ }
+ };
+
+ if (instr->shape().IsTuple()) {
+ for (auto* user : instr->users()) {
+ if (DynCast<HloGetTupleElementInstruction>(user)) {
+ traverse_hlo_and_collect(user);
+ }
+ }
+ } else {
+ if (instr->user_count() == 1) {
+ traverse_hlo_and_collect(instr->users().front());
+ }
+ }
+
+ return sliced_user_paths;
+}
+
+absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
+ DataflowPathView matches) {
+ absl::InlinedVector<HloInstruction*, 4> captures;
+
+ InstructionSet matched_instrs(matches.begin(), matches.end());
+
+ for (HloInstruction* instr : matches) {
+ for (HloInstruction* operand : instr->operands()) {
+ if (!matched_instrs.contains(operand) &&
+ absl::c_find(captures, operand) == captures.end()) {
+ captures.emplace_back(operand);
+ }
+ }
+ }
+
+ return captures;
+}
+
+absl::Status CreateRootTuple(
+ HloInstruction* hero, HloComputation::Builder& builder,
+ DataflowPathsView sliced_user_paths,
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
+ instr_mapping) {
+ unsigned tuple_size = hero->shape().tuple_shapes_size();
+
+ std::vector<HloInstruction*> sliced_elems(tuple_size, nullptr);
+ for (auto& sliced_user_path : sliced_user_paths) {
+ auto gte = Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
+ sliced_elems[gte->tuple_index()] = sliced_user_path.back();
+ }
+
+ std::vector<HloInstruction*> elements;
+ for (size_t i = 0; i < tuple_size; ++i) {
+ if (sliced_elems[i] != nullptr) {
+ elements.push_back(instr_mapping[sliced_elems[i]]);
+ continue;
+ }
+ auto* gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(instr_mapping[hero], i));
+ if (hero->shape().tuple_shapes(i).IsTuple()) {
+ instr_mapping[gte] = gte;
+ TF_RETURN_IF_ERROR(CreateRootTuple(gte, builder, {}, instr_mapping));
+ elements.push_back(builder.last_added_instruction());
+ } else {
+ elements.push_back(gte);
+ }
+ }
+ if (elements.size() > 1)
+ builder.AddInstruction(HloInstruction::CreateTuple(elements));
+
+ return absl::OkStatus();
+}
+
+absl::StatusOr<HloComputation*> CreateFusionBody(
+ HloModule* module, DataflowPathView sliced_operand_paths,
+ DataflowPathsView sliced_user_paths, DataflowPathView captures) {
+ HloComputation::Builder builder("address-computation");
+
+ // A mapping from original instructions to instructions in the fusion body.
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
+
+ auto mapped_operands = [&](HloInstruction* instr) {
+ absl::InlinedVector<HloInstruction*, 4> operands;
+ for (HloInstruction* operand : instr->operands()) {
+ operands.push_back(instr_mapping.at(operand));
+ }
+ return operands;
+ };
+
+ // For every captured value create a parameter instruction in the computation
+ // body and set up instruction mapping.
+ for (const HloInstruction* capture : captures) {
+ int64_t index = instr_mapping.size();
+ instr_mapping[capture] =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ index, capture->shape(), absl::StrCat("p", index)));
+ }
+
+ // Instructions in the pattern are already topologically sorted: we visited
+ // them along the use-def path and then reversed the list, so definitions
+ // come before their uses.
+ HloInstruction* hero;
+ for (HloInstruction* instr : sliced_operand_paths) {
+ instr_mapping[instr] = builder.AddInstruction(
+ instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
+ hero = instr;
+ }
+
+ for (auto& sliced_user_path : sliced_user_paths) {
+ for (HloInstruction* instr : sliced_user_path) {
+ instr_mapping[instr] = builder.AddInstruction(
+ instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
+ }
+ }
+
+ // If the hero returns a tuple, create a root tuple to make sure there's a
+ // buffer assigned for each of its elements. Skip empty tuples.
+ if (hero->shape().IsTuple() && hero->shape().tuple_shapes_size() > 0) {
+ TF_RETURN_IF_ERROR(
+ CreateRootTuple(hero, builder, sliced_user_paths, instr_mapping));
+ }
+
+ return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
+}
+
+absl::StatusOr<HloInstruction*> CreateFusionInstruction(
+ HloModule* module, HloInstruction* orig, DataflowPathView captures,
+ HloComputation* body, bool dynamic) {
+ HloComputation* parent = orig->parent();
+
+ // Add a fusion operation calling outlined fusion computation.
+ HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
+ body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
+ captures, body));
+ module->SetAndUniquifyInstrName(fusion, "address_computation");
+
+ // We don't need to set/update output_to_operand_aliasing for the new fusion
+ // instruction because all buffers are already assigned at this point.
+
+ // Set backends config to a matched custom fusion config.
+ GpuBackendConfig gpu_config;
+ FusionBackendConfig& backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+ backend_config.set_kind("__custom_fusion");
+ CustomFusionConfig config;
+ config.set_name(dynamic ? "dynamic_address_computation"
+ : "address_computation");
+ *backend_config.mutable_custom_fusion_config() = config;
+ TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
+
+ return fusion;
+}
+
+} // namespace
+
+absl::StatusOr<bool> DynamicSliceFusionRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ absl::flat_hash_map<HloInstruction*,
+ std::pair<UseDefDataflowPaths, DefUseDataflowPaths>>
+ matches;
+
+ // Collect all potential custom call matches in the non-fusion computations.
+ for (HloComputation* computation : module->computations()) {
+ if (computation->IsFusionComputation()) continue;
+ for (HloInstruction* instr : computation->instructions()) {
+ UseDefDataflowPaths sliced_operand_paths = {instr};
+ bool has_sliced_operand_paths = false;
+ if (IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) {
+ sliced_operand_paths = GetSlicedOperandPaths(instr);
+ has_sliced_operand_paths = sliced_operand_paths.size() > 1;
+ }
+ if (instr->opcode() == HloOpcode::kReduceScatter ||
+ IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) {
+ DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr);
+ bool has_sliced_user_paths = absl::c_any_of(
+ sliced_user_paths,
+ [&](auto& sliced_user_path) { return !sliced_user_path.empty(); });
+
+ if (absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) {
+ return DynCast<HloDynamicUpdateSliceInstruction>(
+ sliced_user_path.back()) == nullptr;
+ })) {
+ return absl::InternalError(
+ "Expect sliced user path to end with a DUS.");
+ }
+
+ if (has_sliced_operand_paths || has_sliced_user_paths) {
+ matches[instr] = std::make_pair(std::move(sliced_operand_paths),
+ std::move(sliced_user_paths));
+ }
+ }
+ }
+ }
+
+ if (matches.empty()) return false;
+
+ for (auto& [hero, paths] : matches) {
+ auto& [sliced_operand_paths, sliced_user_paths] = paths;
+ std::vector<HloInstruction*> matched_instrs;
+ absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs));
+
+ std::vector<DataflowPathView> sliced_user_paths_view;
+ for (auto& sliced_user_path : sliced_user_paths) {
+ absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs));
+ DataflowPathView sliced_user_path_view{&sliced_user_path.front(),
+ sliced_user_path.size()};
+ sliced_user_paths_view.push_back(std::move(sliced_user_path_view));
+ }
+
+ auto captures = GetPatternCaptures(matched_instrs);
+
+ TF_ASSIGN_OR_RETURN(
+ HloComputation * fusion_body,
+ CreateFusionBody(module, sliced_operand_paths,
+ DataflowPathsView(sliced_user_paths_view), captures));
+
+ bool has_dynamic_slices = absl::c_any_of(matched_instrs, [&](auto* instr) {
+ return DynCast<HloDynamicIndexInstruction>(instr) != nullptr;
+ });
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * fusion,
+ CreateFusionInstruction(module, hero, captures, fusion_body,
+ has_dynamic_slices));
+
+ HloComputation* parent = hero->parent();
+ if (fusion->shape().IsTuple()) {
+ TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape(
+ const_cast<HloInstruction*>(hero), fusion));
+ for (auto& sliced_user_path : sliced_user_paths) {
+ auto old_gte =
+ Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
+ HloInstruction* gte =
+ parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+ fusion, old_gte->tuple_index()));
+ TF_RETURN_IF_ERROR(
+ parent->ReplaceInstruction(sliced_user_path.back(), gte));
+ }
+ } else {
+ auto* instr_to_be_replaced = const_cast<HloInstruction*>(hero);
+ if (sliced_user_paths.empty()) {
+ // The only case where a tuple-shaped original hero op is fused into a
+ // non-tuple-shaped fusion is when only one element of the original
+ // tuple is used. In that case, we need to replace that single
+ // get-tuple-element (instead of the hero op itself) with the fusion
+ // instruction.
+ if (hero->shape().IsTuple()) {
+ if (hero->user_count() != 1 ||
+ !DynCast<HloGetTupleElementInstruction>(hero->users().front())) {
+ return absl::InternalError(
+ "Expect a single get-tuple-element user of the original "
+ "tuple-shaped hero op when address computation fusion does "
+ "not return a tuple");
+ }
+ instr_to_be_replaced = hero->users().front();
+ }
+ } else {
+ instr_to_be_replaced = sliced_user_paths.front().back();
+ }
+ TF_RETURN_IF_ERROR(
+ parent->ReplaceInstruction(instr_to_be_replaced, fusion));
+ // This is required for collective operations which will not be removed.
+ if (hero->parent()) {
+ TF_RETURN_IF_ERROR(hero->parent()->RemoveInstruction(hero));
+ }
+ }
+ }
+
+ return true;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h
new file mode 100644
index 0000000..ad996de
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h
@@ -0,0 +1,91 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_
+
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Pattern matches (slice(s) + custom call) to custom address computation
+// fusions and rewrites them into fusion instructions and fusion computations.
+//
+// Example:
+//
+// ENTRY %main {
+// %p0 = bf16[2,8,8]{2,1,0} parameter(0)
+// %p1 = bf16[2,8,8]{2,1,0} parameter(1)
+// %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+// %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
+// %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+// %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
+// ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
+// custom_call_target="__cublas$gemm"
+// }
+//
+// After the pass:
+//
+// %address_computation {
+// %p0 = bf16[2,8,8]{2,1,0} parameter(0)
+// %p1 = bf16[2,8,8]{2,1,0} parameter(1)
+// %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+// %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
+// %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+// %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
+// ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
+// custom_call_target="__cublas$gemm"
+// }
+//
+// ENTRY %main {
+// %p0 = bf16[2,8,8]{2,1,0} parameter(0)
+// %p1 = bf16[2,8,8]{2,1,0} parameter(1)
+// ROOT %fusion.2 = bf16[8,8]{1,0} fusion(%p0, %p1),
+// kind=kCustom, calls=%address_computation,
+// backend_config={"fusion_backend_config":{
+// "kind":"__custom_fusion",
+// "custom_fusion_config":{"name":"address_computation"}
+// }}
+// }
+//
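+// A usage sketch (illustrative only; the "gpu" platform name matches the one
+// used in the unit tests):
+//
+//   DynamicSliceFusionRewriter rewriter("gpu");
+//   TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module));
+//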
+class DynamicSliceFusionRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "address-computation-fusion-rewriter";
+ }
+
+ explicit DynamicSliceFusionRewriter(std::string platform_name)
+ : platform_name_(std::move(platform_name)) {}
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ std::string platform_name_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc
new file mode 100644
index 0000000..2bd7168
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc
@@ -0,0 +1,2064 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
+
+#include <cstddef>
+#include <optional>
+
+#include "absl/status/status.h"
+#include "xla/client/lib/constants.h"
+#include "xla/client/xla_builder.h"
+#include "xla/ffi/ffi.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/custom_call_target_registry.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+
+class DynamicSliceFusionRewriterTest : public HloTestBase {};
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemm) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWithWorkspace) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+ ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+ ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+ ; CHECK: tuple([[DOT]], [[WORKSPACE]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWorkspaceIgnored) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ ROOT %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+ ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+ ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+ ; CHECK: tuple([[DOT]], [[WORKSPACE]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotRoot) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandHasMultipleUsers) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[4,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[2:3], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %bitcast.41)
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[2:3], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[P0]], [[P1]])
+ ; CHECK-DAG: kind=kCustom, calls=%address-computation,
+ ; CHECK-DAG: backend_config={
+ ; CHECK-DAG: "kind":"__custom_fusion",
+ ; CHECK-DAG: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK-DAG: }
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[B0]])
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
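+// Two GEMM custom calls share the same sliced operands; each call is outlined
+// into its own address-computation fusion.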
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsHaveMultipleUsers) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+
+ ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation{{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+ ; CHECK: %address-computation{{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
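+// The slice chain does not start at a parameter: the outer slice stays in
+// ENTRY and becomes a fusion operand, and only the innermost slice moves into
+// the fused computation.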
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmSlicingNotParameter) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[4,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.12 = f16[2,8,8]{2,1,0} slice(%p0), slice={[0:2], [0:8], [0:8]}
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%slice.12), slice={[1:2], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[2,8,8]{2,1,0} slice([[P0]]), slice={[0:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[S0]], [[P1]])
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
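+// Slices that are not contiguous in memory are not supported by the rewriter,
+// so the module is expected to stay unchanged (std::nullopt).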
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotContiguousSlice) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,4,6]{2,1,0} slice(%p0), slice={[1:2], [0:4], [0:6]}
+ %bitcast.41 = f16[4,6]{1,0} bitcast(%slice.13)
+ %slice.14 = f16[1,6,4]{2,1,0} slice(%p1), slice={[1:2], [0:6], [0:4]}
+ %bitcast.42 = f16[6,4]{1,0} bitcast(%slice.14)
+
+ ROOT %custom-call.1 = f16[4,4]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+ std::nullopt);
+}
+
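+// An add between the slices and the custom call is not a no-op, so the slice
+// chain is not foldable and no rewrite is expected (std::nullopt).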
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+ %add.0 = f16[1,8,8]{2,1,0} add(%slice.13, %slice.14)
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%add.0)
+ %slice.15 = f16[1,8,8]{2,1,0} slice(%p1), slice={[0:1], [0:8], [0:8]}
+ %slice.16 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %add.1 = f16[1,8,8]{2,1,0} add(%slice.15, %slice.16)
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%add.1)
+
+ ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+ std::nullopt);
+}
+
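+// The same slice feeds both operands of the second GEMM: the fusion should
+// take a single parameter and materialize the slice exactly once
+// (hence the CHECK-NOT: slice).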
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmDuplicateOperand) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main {
+ %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
+ %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
+ %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
+ %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0}
+ %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240),
+ custom_call_target="__cublas$gemm",
+ backend_config={
+ "gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"20000",
+ "rhs_stride":"10000",
+ "grad_x":false,
+ "grad_y":false
+ }
+ }
+ %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0
+ %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]}
+ ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26),
+ custom_call_target="__cublas$gemm",
+ backend_config={
+ "gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"10000",
+ "rhs_stride":"10000",
+ "grad_x":false,
+ "grad_y":false
+ }
+ }
+ })";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
+ ; CHECK: [[S0:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[0:100], [0:100]}
+ ; CHECK-NOT: slice
+ ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) custom-call([[S0]], [[S0]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
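+// This test and the next verify that fusion operands follow the order in which
+// the original values feed the custom call, not the parameter numbering of the
+// entry computation.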
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %p1 = f16[2,8,8]{2,1,0} parameter(0)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder2) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %p1 = f16[2,8,8]{2,1,0} parameter(1)
+ %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+ %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+ ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[0:1], [0:8], [0:8]}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+ ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
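+// The GEMM output aliases its third operand (output_to_operand_aliasing): the
+// aliased slice stays in ENTRY and is passed to the fusion unchanged, while
+// the other slice is folded into the fused computation.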
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandAliasingOutput) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
+ %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
+ %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
+ %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0}
+ %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[16:116], [0:100]}
+ %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]}
+ ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34),
+ custom_call_target="__cublas$gemm",
+ output_to_operand_aliasing={{0}: (2, {})},
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":1,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"10000",
+ "rhs_stride":"10000",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P2:%[^ ]+]] = f32[100,100]{1,0} parameter(2)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f32[100,100]{1,0} parameter(1)
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[16:116], [0:100]}
+ ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P1]], [[S1]], [[P2]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[P:%[^ ]+]] = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
+ ; CHECK: [[GTE0:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=0
+ ; CHECK: [[GTE1:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=1
+ ; CHECK: [[CONCAT:%[^ ]+]] = f32[200,100]{1,0} concatenate([[GTE0]], [[GTE1]]), dimensions={0}
+ ; CHECK: [[S:%[^ ]+]] = f32[100,100]{1,0} slice([[CONCAT]]), slice={[99:199], [0:100]}
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[CONCAT]], [[GTE0]], [[S]])
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
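+// Both GEMM operands are bitcasts of the same slice: the fusion takes a single
+// parameter and the slice is shared inside the fused computation.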
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsFromSameSlice) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[2,8,8]{2,1,0} parameter(0)
+ %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+ %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+ %bitcast.42 = f16[8,8]{0,1} bitcast(%slice.13)
+
+ ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{0,1} bitcast([[S0]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]])
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
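+// Typed-FFI custom-call handler registered below as "__xla_test$$memcpy": it
+// performs a device-to-device copy from `src` to `dst` on the given stream and
+// is used by the SimpleCustomCall test.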
+static absl::Status Memcpy(se::Stream* stream, ffi::AnyBuffer src,
+ ffi::AnyBuffer dst) {
+ se::DeviceMemoryBase dst_mem = dst.device_memory();
+ se::DeviceMemoryBase src_mem = src.device_memory();
+ return stream->MemcpyD2D(&dst_mem, src_mem, src_mem.size());
+}
+
+XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
+ ffi::Ffi::Bind()
+ .Ctx<ffi::Stream>()
+ .Arg<ffi::AnyBuffer>() // src
+ .Arg<ffi::AnyBuffer>() // dst
+);
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", "gpu",
+ kMemcpy);
+
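+// Builds the module via XlaBuilder: a static slice of a broadcast feeds a
+// typed-FFI custom call, which should be outlined into an address_computation
+// fusion.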
+TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCall) {
+ XlaBuilder b(TestName());
+ CustomCall(&b, "__xla_test$$memcpy",
+ /*operands=*/
+ {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
+ {128}, {1})},
+ ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"",
+ /*has_side_effect=*/false,
+ /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
+ /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
+ /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
+ TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+ xla::HloModuleConfig hlo_config(
+ xla::ProgramShape(computation.proto().host_program_shape()),
+ /*ignore_layouts=*/false);
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+ hlo_config.set_debug_options(debug_options);
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+ computation.proto(), hlo_config));
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
+ ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
+ ; CHECK: custom_call_target="__xla_test$$memcpy",
+ ; CHECK: api_version=API_VERSION_TYPED_FFI
+ ; CHECK: }
+
+ ; CHECK: ENTRY %{{.*}} {
+ ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42)
+ ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+ expected);
+}
+
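+// No-op legacy (untyped) custom-call target used by the *Legacy and
+// UnalignedSlice tests below.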
+void Callback_Void(se::gpu::GpuStreamHandle stream, void** buffers,
+ const char* /*opaque*/, size_t /*opaque_len*/) {}
+
+XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, "gpu");
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCallLegacy) {
+ XlaBuilder b(TestName());
+ CustomCall(&b, "Callback_Void",
+ /*operands=*/
+ {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
+ {128}, {1})},
+ ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
+ TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+ xla::HloModuleConfig hlo_config(
+ xla::ProgramShape(computation.proto().host_program_shape()),
+ /*ignore_layouts=*/false);
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+ hlo_config.set_debug_options(debug_options);
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+ computation.proto(), hlo_config));

+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
+ ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
+ ; CHECK: custom_call_target="Callback_Void"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %{{.*}} {
+ ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42)
+ ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+ expected);
+}
+
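+// A sliced operand nested inside a tuple argument: the slice and its enclosing
+// tuple are rebuilt inside the fusion, while the untouched tuple operand is
+// passed in whole as a parameter.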
+TEST_F(DynamicSliceFusionRewriterTest, TupleSliceCustomCallLegacy) {
+ XlaBuilder b(TestName());
+ CustomCall(
+ &b, "Callback_Void",
+ /*operands=*/
+ {
+ Tuple(&b,
+ {
+ Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
+ {0, 0}, {4, 8}, {1, 1}),
+ Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
+ }),
+ Tuple(&b,
+ {
+ Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
+ Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
+ }),
+ },
+ ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
+ TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+ xla::HloModuleConfig hlo_config(
+ xla::ProgramShape(computation.proto().host_program_shape()),
+ /*ignore_layouts=*/false);
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+ hlo_config.set_debug_options(debug_options);
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+ computation.proto(), hlo_config));
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
+ ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
+ ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[T0]], [[P2]]),
+ ; CHECK: custom_call_target="Callback_Void"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %{{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion(
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+ expected);
+}
+
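+// A custom call with a nested tuple result: the fusion rebuilds the full
+// result tuple from get-tuple-elements so that ENTRY can keep extracting the
+// pieces it uses.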
+TEST_F(DynamicSliceFusionRewriterTest, TupledOutputCustomCallLegacy) {
+ XlaBuilder b(TestName());
+ auto custom_call = CustomCall(
+ &b, "Callback_Void",
+ /*operands=*/
+ {
+ Tuple(&b,
+ {
+ Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
+ {0, 0}, {4, 8}, {1, 1}),
+ Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
+ }),
+ Tuple(&b,
+ {
+ Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
+ Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
+ }),
+ },
+ ShapeUtil::MakeTupleShape({
+ ShapeUtil::MakeShape(F32, {8}),
+ ShapeUtil::MakeTupleShape({
+ ShapeUtil::MakeShape(F32, {128}),
+ ShapeUtil::MakeShape(F32, {256}),
+ }),
+ ShapeUtil::MakeShape(F32, {1024}),
+ ShapeUtil::MakeShape(F32, {4, 8}),
+ }),
+ /*opaque=*/"");
+ Tuple(&b, {GetTupleElement(GetTupleElement(custom_call, 1), 0),
+ GetTupleElement(custom_call, 2)});
+ TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+ xla::HloModuleConfig hlo_config(
+ xla::ProgramShape(computation.proto().host_program_shape()),
+ /*ignore_layouts=*/false);
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+ hlo_config.set_debug_options(debug_options);
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+ computation.proto(), hlo_config));
+
+ const char* expected = R"(
+ ; CHECK: %address-computation {{.*}} {
+ ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
+ ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
+ ; CHECK: [[CC:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) custom-call([[T0]], [[P2]]),
+ ; CHECK: custom_call_target="Callback_Void"
+ ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[8]{0} get-tuple-element([[CC]]), index=0
+ ; CHECK-DAG: [[GTE1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[CC]]), index=1
+ ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE1]]), index=0
+ ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[256]{0} get-tuple-element([[GTE1]]), index=1
+ ; CHECK-DAG: [[T1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) tuple([[GTE2]], [[GTE3]])
+ ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1024]{0} get-tuple-element([[CC]]), index=2
+ ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,8]{1,0} get-tuple-element([[CC]]), index=3
+ ; CHECK: ROOT {{.*}} = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) tuple([[GTE0]], [[T1]], [[GTE4]], [[GTE5]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK-DAG: [[GTE6:%[^ ]+]] = f32[1024]{0} get-tuple-element([[FUSION]]), index=2
+ ; CHECK-DAG: [[GTE7:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[FUSION]]), index=1
+ ; CHECK-DAG: [[GTE8:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE7]]), index=0
+ ; CHECK: ROOT {{.*}} = (f32[128]{0}, f32[1024]{0}) tuple([[GTE8]], [[GTE6]])
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+ expected);
+}
+
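+// The slice starts at an offset that is not suitably aligned, so the rewriter
+// is expected to leave the module unchanged (std::nullopt).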
+TEST_F(DynamicSliceFusionRewriterTest, UnalignedSlice) {
+ XlaBuilder b(TestName());
+ CustomCall(
+ &b, "Callback_Void",
+ /*operands=*/
+ {Slice(Broadcast(ConstantR0WithType(&b, S32, 42), {17}), {1}, {17}, {1})},
+ ShapeUtil::MakeShape(S32, {16}), /*opaque=*/"");
+ TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+ xla::HloModuleConfig hlo_config(
+ xla::ProgramShape(computation.proto().host_program_shape()),
+ /*ignore_layouts=*/false);
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+ hlo_config.set_debug_options(debug_options);
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+ computation.proto(), hlo_config));
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+ std::nullopt);
+}
+
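+// The Dynamic* and DUS* tests below use dynamic-slice / dynamic-update-slice:
+// the offset values become extra fusion parameters and the custom fusion is
+// named "dynamic_address_computation" instead of "address_computation".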
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemm) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[2,8,8]{2,1,0} parameter(0)
+ p1 = f16[2,8,8]{2,1,0} parameter(1)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+ slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+ ROOT custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWithWorkspace) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[2,8,8]{2,1,0} parameter(0)
+ p1 = f16[2,8,8]{2,1,0} parameter(1)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+ slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+ ROOT custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+ ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+ ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+ ; CHECK: tuple([[DOT]], [[WORKSPACE]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWorkspaceIgnored) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[2,8,8]{2,1,0} parameter(0)
+ p1 = f16[2,8,8]{2,1,0} parameter(1)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+ slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+ custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ ROOT get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+ ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+ ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+ ; CHECK: tuple([[DOT]], [[WORKSPACE]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmNotRoot) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[2,8,8]{2,1,0} parameter(0)
+ p1 = f16[2,8,8]{2,1,0} parameter(1)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+ slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+ custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ ROOT res = f16[8,8]{1,0} add(custom-call.1, custom-call.1)
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
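+// A dynamic-update-slice consuming the GEMM result is folded into the fusion
+// as well, so the fusion output has the shape of the updated buffer.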
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemm) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[1,8,8]{2,1,0} parameter(0)
+ p1 = f16[1,8,8]{2,1,0} parameter(1)
+ p2 = f16[4,8,8]{2,1,0} parameter(2)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ bitcast.41 = f16[8,8]{1,0} bitcast(p0)
+ bitcast.42 = f16[8,8]{1,0} bitcast(p1)
+
+ custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+ ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
+ ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(3)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(4)
+ ; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]),
+ ; CHECK-DAG: custom_call_target="__cublas$gemm"
+ ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
+ ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmNotRoot) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[2,8,8]{2,1,0} parameter(0)
+ p1 = f16[2,8,8]{2,1,0} parameter(1)
+ p2 = f16[4,8,8]{2,1,0} parameter(2)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+ slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+ custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+ dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
+ ROOT res = f16[4,8,8]{2,1,0} log(dus)
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+ ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+ ; CHECK-DAG: custom_call_target="__cublas$gemm"
+ ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
+ ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} log([[FUSION]])
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWithWorkspace) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[2,8,8]{2,1,0} parameter(0)
+ p1 = f16[2,8,8]{2,1,0} parameter(1)
+ p2 = f16[4,8,8]{2,1,0} parameter(2)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+ slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+ bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+ custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+
+ get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
+ bitcast.43 = f16[1,8,8]{2,1,0} bitcast(get-tuple-element.0)
+ dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
+ get-tuple-element.1 = s8[256]{0} get-tuple-element(custom-call.1), index=1
+ ROOT tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(dus, get-tuple-element.1)
+ }
+ )";
+
+ const char* expected = R"(
+ ; CHECK: address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+ ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
+ ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+ ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+ ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+ ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+ ; CHECK: custom_call_target="__cublas$gemm"
+ ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+ ; CHECK: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
+ ; CHECK: [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+ ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+ ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
+ ; CHECK: tuple([[DUS]], [[WORKSPACE]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: [[DUS_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
+ ; CHECK: [[WORKSPACE_MAIN:%[^ ]+]] = s8[256]{0} get-tuple-element([[FUSION]]), index=1
+ ; CHECK: ROOT {{.*}} = (f16[4,8,8]{2,1,0}, s8[256]{0})
+ ; CHECK: tuple([[DUS_MAIN]], [[WORKSPACE_MAIN]])
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY %main.9 {
+ %p0 = f16[8,8]{1,0} parameter(0)
+ %p1 = f16[8,8]{1,0} parameter(1)
+ %p2 = f16[4,8,8]{2,1,0} parameter(2)
+ %c1_s32 = s32[] constant(1)
+ %c0_s32 = s32[] constant(0)
+
+ %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%p0, %p1),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
+ %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0)
+ ROOT %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32)
+ })";
+
+ const char* expected = R"(
+ ; CHECK: address-computation {{.*}} {
+ ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
+ ; CHECK-DAG: [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
+ ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
+ ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(3)
+ ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(4)
+ ; CHECK-DAG: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[P0]], [[P1]]),
+ ; CHECK-DAG: custom_call_target="__cublas$gemm"
+ ; CHECK-DAG: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+ ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
+ ; CHECK-DAG: [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+ ; CHECK-DAG: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+ ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
+ ; CHECK: tuple([[DUS]], [[WORKSPACE]])
+ ; CHECK: }
+
+ ; CHECK: ENTRY %main{{.*}} {
+ ; CHECK: [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
+ ; CHECK: kind=kCustom, calls=%address-computation,
+ ; CHECK: backend_config={
+ ; CHECK: "kind":"__custom_fusion",
+ ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+ ; CHECK: }
+ ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
+ ; CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
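+// Collective case: a reduce-scatter feeding a dynamic-update-slice with
+// constant offsets is outlined into a dynamic address-computation fusion.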
+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSConstantOffset) {
+ const char* hlo = R"(
+ HloModule test, replica_count=2
+
+ add {
+ param_0 = f16[] parameter(0)
+ param_1 = f16[] parameter(1)
+ ROOT add.1 = f16[] add(param_0, param_1)
+ }
+
+ ENTRY main.9 {
+ param_0 = f16[128,128]{1,0} parameter(0)
+ param_1 = f16[128,128]{1,0} parameter(1)
+ constant_20 = u32[] constant(20)
+ constant_0 = u32[] constant(0)
+ reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+ ROOT loop_dynamic_update_slice_fusion = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, constant_20, constant_0)
+ }
+ )";
+
+ const char* expected = R"(
+ // CHECK: %address-computation{{.+}} {
+ // CHECK: %[[RS:.+]] = f16[64,128]{1,0} reduce-scatter({{.+}})
+ // CHECK: ROOT %{{.+}} = f16[128,128]{1,0} dynamic-update-slice(%{{.+}}, %[[RS]], %{{.+}}, %{{.+}})
+ // CHECK: }
+ // CHECK: ENTRY {{.+}} {
+ // CHECK-NOT: reduce-scatter
+ // CHECK: ROOT %{{.+}} = {{.+}} fusion(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}), kind=kCustom, calls=%address-computation, {{.+}}"name":"dynamic_address_computation"
+ // CHECK: }
+ )";
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
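+// The DUS offset is a runtime parameter rather than a constant or loop-derived
+// value, so the rewriter is expected to leave the module unchanged.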
+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSParameterOffset) {
+ const char* hlo = R"(
+ HloModule test, replica_count=2
+
+ add.clone {
+ x.1 = f16[] parameter(0)
+ y.1 = f16[] parameter(1)
+ ROOT add.462 = f16[] add(x.1, y.1)
+ }
+
+ ENTRY %main.9 {
+ param_0 = f16[128,128]{1,0} parameter(0)
+ param_1 = f16[128,128]{1,0} parameter(1)
+ param_2 = u32[] parameter(2)
+ constant_0 = u32[] constant(0)
+ reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone
+ ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, param_2, constant_0)
+ })";
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+ std::nullopt);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSLoopIterationOffset) {
+ const char* hlo = R"(
+ HloModule jit_scan, replica_count=2
+
+ add {
+ param_0 = f32[] parameter(0)
+ param_1 = f32[] parameter(1)
+ ROOT add.6 = f32[] add(param_0, param_1)
+ }
+
+ Body {
+ arg_tuple.1 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(arg_tuple.1), index=0
+ constant.1 = s32[] constant(1)
+ add.7 = s32[] add(get-tuple-element.5, constant.1)
+ get-tuple-element.6 = f32[128,128]{1,0} get-tuple-element(arg_tuple.1), index=3
+ get-tuple-element.7 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.1), index=2
+ reduce-scatter.0 = f32[64,128]{1,0} reduce-scatter(get-tuple-element.6), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+ bitcast.63 = f32[1,64,128]{2,1,0} bitcast(reduce-scatter.0)
+ constant.2 = s32[] constant(0)
+ compare.4 = pred[] compare(get-tuple-element.5, constant.2), direction=LT
+ constant.3 = s32[] constant(128)
+ add.8 = s32[] add(get-tuple-element.5, constant.3)
+ select.2 = s32[] select(compare.4, add.8, get-tuple-element.5)
+ dynamic-update-slice.2 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.7, bitcast.63, select.2, constant.2, constant.2)
+ ROOT tuple.1 = tuple(add.7, get-tuple-element.6, dynamic-update-slice.2, get-tuple-element.6)
+ } // Body
+
+ Cond {
+ arg_tuple.0 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+ get-tuple-element.4 = s32[] get-tuple-element(arg_tuple.0), index=0
+ constant.0 = s32[] constant(128)
+ ROOT compare.5 = pred[] compare(get-tuple-element.4, constant.0), direction=LT
+ }
+
+ ENTRY main.55 {
+ Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2)
+ constant.4 = s32[] constant(0)
+ Arg_1.2 = f32[128,128]{1,0} parameter(1)
+ constant.5 = f32[] constant(0)
+ broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={}
+ Arg_0.1 = f32[128,128]{1,0} parameter(0)
+ tuple = tuple(constant.4, Arg_1.2, broadcast.1, Arg_0.1)
+ while = while(tuple), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}}
+ get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while), index=1
+ get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while), index=2
+ ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51)
+ })";
+ const char* expected = R"(
+ // CHECK: %address-computation{{.*}}{
+ // CHECK: {{.+}} = {{.*}}reduce-scatter({{.+}})
+ // CHECK: {{.+}} = {{.*}}dynamic-update-slice({{.+}})
+ // CHECK: }
+ // CHECK: Body{{.+}}{
+ // CHECK-NOT: {{.+}} = {{.*}}reduce-scatter({{.+}})
+ // CHECK: {{.+}} = {{.+}}fusion({{.+}}), kind=kCustom, calls=%address-computation{{.*}}"name":"dynamic_address_computation"
+ // CHECK: }
+ )";
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) {
+ const char* hlo = R"(
+ HloModule test
+
+ %Body {
+ param = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0)
+ p0 = get-tuple-element(param), index=0
+ p1 = get-tuple-element(param), index=1
+ p2 = get-tuple-element(param), index=2
+ loop_iter = get-tuple-element(param), index=3
+
+ bitcast.41 = f16[8,8]{1,0} bitcast(p0)
+ bitcast.42 = f16[8,8]{1,0} bitcast(p1)
+ custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+ c0 = u32[] constant(0)
+ c_trip_count = u32[] constant(11)
+ compare = pred[] compare(loop_iter, c0), direction=LT
+ add = u32[] add(loop_iter, c_trip_count)
+ offset = u32[] select(compare, add, loop_iter)
+ dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, offset, c0, c0)
+ c1 = u32[] constant(1)
+ add2 = u32[] add(loop_iter, c1)
+ ROOT tuple = tuple(p0, p1, dus, u32[] add2)
+ }
+
+ %Cond {
+ %param.1 = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0)
+ %i.1 = u32[] get-tuple-element(%param.1), index=3
+ %trip_count = u32[] constant(11)
+ ROOT %done = pred[] compare(u32[] %i.1, u32[] %trip_count), direction=LT
+ }
+
+ ENTRY %test {
+ %p0.1 = f16[1,8,8]{2,1,0} parameter(0)
+ %p1.1 = f16[1,8,8]{2,1,0} parameter(1)
+ %p2.1 = f16[4,8,8]{2,1,0} parameter(2)
+ %c0.1 = u32[] constant(0)
+ %initial_tuple = tuple(%p0.1, %p1.1, %p2.1, u32[] %c0.1)
+ ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}}
+ })";
+
+ const char* expected = R"(
+ // CHECK: %Body{{.+}}{
+ // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0)
+ // CHECK: %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=3
+ // CHECK: %[[OFFSET:.+]] = u32[] select({{.+}})
+ // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, {{.+}}, {{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation, {{.+}}"name":"dynamic_address_computation"
+ // CHECK: ROOT %tuple = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}})
+ // CHECK: }
+ // CHECK: ENTRY %test{{.+}}{
+ // CHECK: ROOT %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"11"}}
+    // CHECK: }
+ )";
+
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmParameterOffset) {
+ const char* hlo = R"(
+ HloModule test
+
+ ENTRY main.9 {
+ p0 = f16[1,8,8]{2,1,0} parameter(0)
+ p1 = f16[1,8,8]{2,1,0} parameter(1)
+ p2 = f16[4,8,8]{2,1,0} parameter(2)
+ p3 = s32[] parameter(3)
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ bitcast.41 = f16[8,8]{1,0} bitcast(p0)
+ bitcast.42 = f16[8,8]{1,0} bitcast(p1)
+
+ custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+ custom_call_target="__cublas$gemm",
+ backend_config={"gemm_backend_config":{
+ "alpha_real":1,
+ "beta":0,
+ "dot_dimension_numbers":{
+ "lhs_contracting_dimensions":["1"],
+ "rhs_contracting_dimensions":["0"],
+ "lhs_batch_dimensions":[],
+ "rhs_batch_dimensions":[]
+ },
+ "alpha_imag":0,
+ "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+ "epilogue":"DEFAULT",
+ "lhs_stride":"64",
+ "rhs_stride":"64",
+ "grad_x":false,
+ "grad_y":false
+ }}
+ bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+ ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, p3, c0_s32, c0_s32)
+ })";
+
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+ std::nullopt);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) {
+ const char* hlo = R"(
+ HloModule lax_scan
+
+ // This is the HLO generated for the following:
+ //
+ // inp = jax.random.uniform(jax.random.key(128), (128, 128, 128))
+ // init = jnp.identity(128)
+ // ans = jax.lax.scan(lambda carry, x : (init, x@carry), init, inp)
+
+ Body {
+ arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+ get-tuple-element.16 = s32[] get-tuple-element(arg_tuple.15), index=0
+ constant.21 = s32[] constant(1)
+ add.2 = s32[] add(get-tuple-element.16, constant.21)
+ get-tuple-element.30 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=4
+ get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=2
+ get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=3
+ constant.23 = s32[] constant(0)
+ compare.2 = pred[] compare(get-tuple-element.16, constant.23), direction=LT
+ constant.22 = s32[] constant(128)
+ add.3 = s32[] add(get-tuple-element.16, constant.22)
+ select.1 = s32[] select(compare.2, add.3, get-tuple-element.16)
+ dynamic-slice.1 = f32[1,128,128]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,128,128}
+ bitcast.72 = f32[128,128]{1,0} bitcast(dynamic-slice.1)
+ get-tuple-element.17 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=1
+ custom-call.1 = (f32[128,128]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm"
+ get-tuple-element = f32[128,128]{1,0} get-tuple-element(custom-call.1), index=0
+ bitcast.77 = f32[1,128,128]{2,1,0} bitcast(get-tuple-element)
+ dynamic-update-slice.1 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23)
+ ROOT tuple.38 = tuple(add.2, get-tuple-element.30, dynamic-update-slice.1, get-tuple-element.19, get-tuple-element.30)
+ } // Body
+
+ Cond {
+ arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+ get-tuple-element.41 = s32[] get-tuple-element(arg_tuple.40), index=0
+ constant.46 = s32[] constant(128)
+ ROOT compare.3 = pred[] compare(get-tuple-element.41, constant.46), direction=LT
+ }
+
+ ENTRY main {
+ constant.4 = s32[] constant(0)
+ Arg_1.2 = f32[128,128]{1,0} parameter(1)
+ constant.5 = f32[] constant(0)
+ broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={}
+ Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2)
+ Arg_0.1 = f32[128,128]{1,0} parameter(0)
+ tuple.7 = tuple(constant.4, Arg_1.2, broadcast.1, Arg_2.3, Arg_0.1)
+ while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}}
+ get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while.48), index=1
+ get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while.48), index=2
+ ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51)
+  } // main
+
+)";
+ auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ const char* expected = R"(
+ // CHECK: %address-computation{{.*}} {{.+}} {
+ // CHECK: {{.+}} = {{.+}}dynamic-slice
+ // CHECK: {{.+}} = {{.+}}custom-call
+ // CHECK: {{.+}} = {{.+}}dynamic-update-slice
+ // CHECK: }
+ // CHECK: %Body{{.+}}{
+ // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0)
+ // CHECK: %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=0
+ // CHECK: %[[OFFSET:.+]] = s32[] select({{.+}})
+ // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation{{.+}}"name":"dynamic_address_computation"
+ // CHECK: %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0
+ // CHECK: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}})
+ // CHECK: }
+ // CHECK: ENTRY %main{{.+}}{
+ // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"128"}}
+ // CHECK: }
+ )";
+ RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc
new file mode 100644
index 0000000..5a09bf5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc
@@ -0,0 +1,327 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/fusion_merger.h"
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/model/gpu_performance_model_base.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_graph_dumper.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+
+namespace xla {
+namespace gpu {
+
+// For each fusion F, attempts to fuse F into *all* of F's users (it does not
+// fuse at all if F cannot be fused into at least one of them).
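+//
+// A minimal sketch of the intended rewrite, using illustrative HLO names that
+// do not come from any particular test:
+//
+//   f  = f32[4] fusion(p0), kind=kLoop, calls=comp          // producer fusion
+//   u0 = f32[4] fusion(f, p1), kind=kLoop, calls=c0         // consumer 0
+//   u1 = f32[4] fusion(f, p2), kind=kLoop, calls=c1         // consumer 1
+//
+// becomes, after `comp` is merged into both consumers and `f` is removed:
+//
+//   u0 = f32[4] fusion(p0, p1), kind=kLoop, calls=c0_merged // comp inlined into c0
+//   u1 = f32[4] fusion(p0, p2), kind=kLoop, calls=c1_merged // comp inlined into c1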
+class FusionInstructionMerger {
+ public:
+ explicit FusionInstructionMerger(
+ HloComputation* computation, const se::DeviceDescription& gpu_device_info,
+ HloCostAnalysis::ShapeSizeFunction shape_size_function)
+ : computation_(computation),
+ shape_size_function_(shape_size_function),
+ gpu_device_info_(gpu_device_info),
+ dump_fusion_visualization_(computation->parent()
+ ->config()
+ .debug_options()
+ .xla_dump_fusion_visualization()) {}
+
+ absl::Status Run();
+
+ bool changed() const { return changed_; }
+
+ private:
+ FusionDecision ShouldFuse(HloInstruction* producer);
+ absl::Status FuseIntoAllUsers(HloInstruction* producer);
+
+ HloComputation* computation_;
+ HloCostAnalysis::ShapeSizeFunction shape_size_function_;
+  // Many cheap checks can prevent fusion merging, so we postpone running the
+  // full HLO cost analysis of the computation; it may not be needed at all.
+ std::optional<GpuHloCostAnalysis> cost_analysis_;
+ FusionInfoCache fusion_info_cache_;
+ const se::DeviceDescription& gpu_device_info_;
+ bool changed_ = false;
+ bool dump_fusion_visualization_ = false;
+
+ // Fusion instruction merge stats.
+ int total_visited_ = 0;
+ int total_merged_ = 0;
+ int num_fail_no_users_ = 0;
+ int num_fail_not_loop_fusion_ = 0;
+ int num_fail_merge_all_users_ = 0;
+ int num_fail_inefficient_fusion_emitter_ = 0;
+ int num_fail_fusion_too_large_ = 0;
+ int num_fail_uncoalesced_read_ = 0;
+ int num_fail_slower_if_fused_ = 0;
+
+ FusionInstructionMerger(const FusionInstructionMerger&) = delete;
+ FusionInstructionMerger& operator=(const FusionInstructionMerger&) = delete;
+};
+
+absl::Status FusionInstructionMerger::FuseIntoAllUsers(
+ HloInstruction* producer) {
+  // Merge the fused instructions from 'producer' into each of its users.
+ std::vector<HloInstruction*> users = producer->users();
+ for (HloInstruction* user : users) {
+ if (dump_fusion_visualization_) {
+ RegisterFusionState(
+ *computation_,
+ absl::StrCat("About to fuse |", producer->name(), "| into |",
+ user->name(), "| inside FusionMerger"),
+ /*consumer=*/*user,
+ /*producer=*/producer);
+ }
+
+ TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(user));
+
+ // Wrap consumers which are not fusions first.
+ HloInstruction* consumer = user;
+ if (consumer->opcode() != HloOpcode::kFusion) {
+ consumer = computation_->AddInstruction(HloInstruction::CreateFusion(
+ user->shape(), ChooseFusionKind(*producer, *user), user));
+ TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer));
+ }
+
+ consumer->MergeFusionInstruction(producer);
+ TF_RETURN_IF_ERROR(cost_analysis_->RevisitInstruction(consumer));
+ fusion_info_cache_.Invalidate(consumer);
+
+ if (dump_fusion_visualization_) {
+ RegisterFusionState(*computation_,
+ absl::StrCat("Fused |", producer->name(), "| into |",
+ user->name(), "| inside FusionMerger"),
+ *consumer);
+ }
+
+ changed_ = true;
+ }
+
+ CHECK_EQ(0, producer->user_count()) << producer->ToString();
+ TF_RETURN_IF_ERROR(computation_->RemoveInstruction(producer));
+ TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(producer));
+ fusion_info_cache_.Invalidate(producer);
+ VLOG(2) << "Merged fusion instruction: " << producer->name()
+ << " into users { "
+ << absl::StrJoin(users, ", ",
+ [](std::string* out, HloInstruction* user) {
+ absl::StrAppend(out, user->name());
+ })
+ << " }";
+ return absl::OkStatus();
+}
+
+absl::Status FusionInstructionMerger::Run() {
+ for (HloInstruction* producer : computation_->MakeInstructionPostOrder()) {
+ if (producer->opcode() != HloOpcode::kFusion) {
+ continue;
+ }
+ FusionDecision should_fuse = ShouldFuse(producer);
+ if (should_fuse) {
+ TF_RETURN_IF_ERROR(FuseIntoAllUsers(producer));
+ ++total_merged_;
+ } else {
+ VLOG(3) << "Not fusing fusion |" << producer->name()
+              << "| with all of its users due to: " << should_fuse.Explain();
+ if (dump_fusion_visualization_ && !producer->users().empty()) {
+ RegisterFusionState(
+ *computation_,
+ absl::StrCat(
+ "Not fusing fusion |", producer->name(),
+ "| into all of its users due to: ", should_fuse.Explain()),
+ // Just pick any consumer, since we are trying to merge into all.
+ /*consumer=*/*producer->users()[0],
+ /*producer=*/producer);
+ }
+ }
+ }
+
+ VLOG(1) << "FusionInstructionMerger EXIT"
+ << " computation: " << computation_->name()
+ << " total_visited: " << total_visited_
+ << " total_merged: " << total_merged_ << " merge failures { "
+ << " no_users: " << num_fail_no_users_
+ << " not_loop_fusion: " << num_fail_not_loop_fusion_
+ << " merge_all_users: " << num_fail_merge_all_users_
+ << " uncoalesced_read: " << num_fail_uncoalesced_read_
+ << " inefficient_fusion_emitter: "
+ << num_fail_inefficient_fusion_emitter_
+ << " slower_if_fused: " << num_fail_slower_if_fused_
+ << " fusion_too_large: " << num_fail_fusion_too_large_ << " }";
+ return absl::OkStatus();
+}
+
+bool TransposesMostData(const HloInstruction& fusion) {
+ float score = 0;
+
+ for (const HloInstruction* instr : fusion.fused_instructions()) {
+ if (IsPhysicallyTransposing(*instr)) {
+ score += 1.0 * ShapeUtil::ElementsInRecursive(instr->shape()) /
+ ShapeUtil::ElementsInRecursive(fusion.shape());
+ if (score >= 0.5) {
+ VLOG(3) << fusion.ToString() << " transpose ratio exceeds " << score;
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) {
+ ++total_visited_;
+
+ VLOG(4) << "Considering producer " << producer->name();
+
+ // Skip 'producer' instruction if there are no users into which we can
+ // merge.
+ if (producer->users().empty()) {
+ ++num_fail_no_users_;
+ return "fusion has no users";
+ }
+
+ // Skip 'producer' instruction if it is not a loop fusion. Library fusion
+ // instructions match specific patterns, so they shouldn't be further fused.
+ // Input fusion instructions need to be rooted at a particular HLO (e.g.
+ // kReduce), so they shouldn't be further fused either.
+ if (!producer->IsLoopFusion()) {
+ ++num_fail_not_loop_fusion_;
+ return "not a loop fusion";
+ }
+
+ auto producer_hero = GetRealHeroForMultiOutputFusion(*producer);
+
+ bool has_reduction_user = false;
+ for (const HloInstruction* user : producer->users()) {
+ if (user->opcode() == HloOpcode::kBitcast) {
+ ++num_fail_merge_all_users_;
+ return "not fusing bitcast ops";
+ }
+ if (user->IsCustomFusion()) {
+ ++num_fail_merge_all_users_;
+ return "not fusing custom fusions";
+ }
+ auto consumer_hero = GetRealHeroForMultiOutputFusion(*user);
+ if (auto compatible =
+ FusionHeroesAreCompatible(producer_hero, consumer_hero);
+ !compatible) {
+ return compatible;
+ }
+ FusionDecision fusible = IsProducerConsumerFusible(*producer, *user);
+ if (!fusible) {
+ ++num_fail_merge_all_users_;
+ VLOG(9) << user->ToString();
+ return fusible;
+ }
+ if (IsInputFusibleReduction(*user)) {
+ has_reduction_user = true;
+ }
+ }
+
+  // We do not want to worsen the reduction's memory access pattern by
+  // connecting it to a producer that transposes most of its data.
+ if (has_reduction_user && TransposesMostData(*producer)) {
+ ++num_fail_uncoalesced_read_;
+ return "would read mostly uncoalesced";
+ }
+
+ for (const HloInstruction* user : producer->users()) {
+    // Skip 'producer' if merging it into at least one of the users would make
+    // the resulting fusion use too much shared memory or registers.
+ FusionDecision fits = FusionFitsInBudget(
+ *user, *producer, gpu_device_info_,
+ /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
+ if (!fits) {
+ ++num_fail_fusion_too_large_;
+ return fits;
+ }
+ }
+
+ if (!cost_analysis_) {
+ VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
+ cost_analysis_.emplace(
+ GpuHloCostAnalysis::Options{shape_size_function_,
+ /*per_second_rates=*/{},
+ /*count_multiple_input_accesses=*/true},
+ gpu_device_info_);
+ TF_CHECK_OK(computation_->Accept(&cost_analysis_.value()));
+ }
+
+ for (const HloInstruction* user : producer->users()) {
+ if (cost_analysis_->ProducerConsumerMergedTooLarge(*producer, *user)) {
+ ++num_fail_inefficient_fusion_emitter_;
+ return FusionDecision{} << "if merged with " << user->name()
+ << " will generate huge IR";
+ }
+ }
+
+ GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+ producer, gpu_device_info_, &*cost_analysis_,
+ GpuPerformanceModelOptions::Default(), producer->users());
+ if (t.time_fused > t.time_unfused) {
+ ++num_fail_slower_if_fused_;
+ return "will execute slower if fused";
+ }
+
+ return {};
+}
+
+absl::StatusOr<bool> FusionMerger::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ VLOG(1) << "FusionMerger for module: " << module->name();
+ for (auto* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ VLOG(9) << "Before running FusionInstructionMerger for computation: "
+ << computation->name();
+ XLA_VLOG_LINES(9, computation->ToString());
+
+ FusionInstructionMerger fusion_merger(computation, gpu_device_info_,
+ shape_size_function_);
+ TF_RETURN_IF_ERROR(fusion_merger.Run());
+ changed |= fusion_merger.changed();
+
+ VLOG(9) << "After running FusionInstructionMerger for computation: "
+ << computation->name() << " changed: " << changed;
+ XLA_VLOG_LINES(9, computation->ToString());
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.h b/third_party/xla/xla/service/gpu/transforms/fusion_merger.h
new file mode 100644
index 0000000..15ea960
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.h
@@ -0,0 +1,85 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// An HLO pass that attempts to merge fusion instructions to reduce memory
+// bandwidth requirements and kernel launch overhead.
+//
+// Consider the example below. On the left-hand side, op A is the producer and
+// ops B and C are its consumers. FusionMerger duplicates producer ops and fuses
+// them into all consumers. The result is depicted on the right-hand side below.
+//
+//        p                    p
+//        |                  /   \
+//        v                 /     \
+//        A            +fusion+  +fusion+
+//      /   \          |  A'  |  |  A"  |
+//     |     |         |  |   |  |  |   |
+//     v     v         |  v   |  |  v   |
+//     B     C         |  B   |  |  C   |
+//                     +------+  +------+
+//
+// Op A has been cloned twice and fused with B and C. The kernel launch overhead
+// is reduced from 3 to 2. The memory bandwidth requirements may be reduced.
+// We trade 1 read of input(A) + 1 write and 2 reads of output(A) for 2 reads of
+// input(A). In general the achievable savings in memory bandwidth depend on
+// the difference between memory read and memory written and on the number of
+// consumers. The FusionMerger pass takes this into account when making fusion
+// decisions.
+//
+// The pass traverses the HLO module in post-order (defs before uses).
+// Fusion instructions are merged into their users if some conditions are met:
+// * The result of merging the fusion instruction into its users would not
+// increase bytes transferred.
+// * Producer ops are fusible with _all_ consumers. If they cannot be fused
+//   into every consumer, they are not fused at all.
+// * Producers are kLoop fusion ops.
+//
+// None of these restrictions are necessary for correctness. In fact, lifting
+// the latter two could be beneficial.
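+//
+// Example usage (a minimal sketch: the pipeline wiring below is illustrative,
+// and `device_info` / the pointer size are assumed to be provided by the
+// caller rather than prescribed by this pass):
+//
+//   HloCostAnalysis::ShapeSizeFunction size_fn = [](const Shape& shape) {
+//     return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
+//   };
+//   HloPassPipeline pipeline("fusion");
+//   pipeline.AddPass<FusionMerger>(device_info, size_fn);
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));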
+
+class FusionMerger : public HloModulePass {
+ public:
+ explicit FusionMerger(const se::DeviceDescription& d,
+ HloCostAnalysis::ShapeSizeFunction f)
+ : gpu_device_info_(d), shape_size_function_(f) {}
+ absl::string_view name() const override { return "fusion_merger"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::DeviceDescription gpu_device_info_;
+ HloCostAnalysis::ShapeSizeFunction shape_size_function_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc
new file mode 100644
index 0000000..5068f65
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc
@@ -0,0 +1,1170 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/fusion_merger.h"
+
+#include <cstdint>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class FusionMergerTest : public HloTestBase {
+ HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
+ return [&](const Shape& shape) {
+ constexpr int64_t kPointerSize = 8;
+ return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+ };
+ }
+
+ public:
+ FusionMerger fusion_merger_{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+ ShapeSizeBytesFunction()};
+ FusionMergerTest() : HloTestBase() {}
+};
+
+// Tests that we can merge a fusion instruction that is below threshold.
+//
+// Computation after fusion merger pass (Fusion2 is merged into Fusion0 and
+// Fusion1):
+//                   Param
+//                 /   |   \
+//          Fusion3 Fusion0 Fusion1
+//                 \   |   /
+//                   Tuple
+//
+TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule MergeSharedFusionInstruction
+
+comp.3 {
+ constant.param_0 = f32[4]{0} parameter(0)
+ param.param_1.2 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(1)
+ get-tuple-element.6 = f32[4]{0} get-tuple-element(param.param_1.2), index=0
+ ROOT add.7 = f32[4]{0} add(constant.param_0, get-tuple-element.6)
+}
+
+comp.2 {
+ param.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+ get-tuple-element.4 = f32[4]{0} get-tuple-element(param.param_1.1), index=1
+ get-tuple-element.5 = f32[4]{0} get-tuple-element(param.param_1.1), index=2
+ ROOT add.6 = f32[4]{0} add(get-tuple-element.4, get-tuple-element.5)
+}
+
+comp.1 {
+ add.1.param_1.1 = f32[4]{0} parameter(1)
+ constant.param_1.3 = f32[4]{0} parameter(0)
+ add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
+ ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
+}
+
+comp {
+ add.1.param_1 = f32[4]{0} parameter(1)
+ constant.param_1.1 = f32[4]{0} parameter(0)
+ multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
+ ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
+}
+
+ENTRY MergeSharedFusionInstruction.Computation0 {
+ constant = f32[4]{0} constant({1, 1, 1, 1})
+ param = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+ fusion.3 = f32[4]{0} fusion(constant, param), kind=kLoop, calls=comp.3
+ fusion.4 = f32[4]{0} fusion(param), kind=kLoop, calls=comp.2
+ fusion.5 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp.1
+ fusion.6 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp
+ ROOT tuple = (f32[4]{0}, f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.5, fusion.6)
+})")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+
+ auto* root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(HloOpcode::kTuple, root->opcode());
+ // Check operand 0 (not merged). Should have 4 instructions.
+ auto* operand0 = root->operand(0);
+ EXPECT_EQ(HloOpcode::kFusion, operand0->opcode());
+ EXPECT_EQ(4, operand0->fused_instruction_count());
+ // Check operand 1 (should have merged in its operand fusion instruction).
+ auto* operand1 = root->operand(1);
+ EXPECT_EQ(HloOpcode::kFusion, operand1->opcode());
+ EXPECT_EQ(7, operand1->fused_instruction_count());
+ // Check operand 2 (should have merged in its operand fusion instruction).
+ auto* operand2 = root->operand(2);
+ EXPECT_EQ(HloOpcode::kFusion, operand2->opcode());
+ EXPECT_EQ(7, operand2->fused_instruction_count());
+}
+
+TEST_F(FusionMergerTest, MoreMemoryAccessIfFused) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f32add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT _ = f32[] add(x, y)
+}
+
+comp0 {
+ p = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0)
+ gte0 = f32[2048] get-tuple-element(p), index=0
+ gte1 = f32[2048] get-tuple-element(p), index=1
+ add.9 = f32[2048] add(gte0, gte1)
+ gte2 = f32[2048] get-tuple-element(p), index=2
+ add.10 = f32[2048] add(add.9, gte2)
+ gte3 = f32[2048] get-tuple-element(p), index=3
+ add.11 = f32[2048] add(add.10, gte3)
+ p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1)
+ gte4 = f32[2048] get-tuple-element(p1), index=0
+ gte5 = f32[2048] get-tuple-element(p1), index=1
+ add.12 = f32[2048] add(gte4, gte5)
+ gte6 = f32[2048] get-tuple-element(p1), index=2
+ add.13 = f32[2048] add(add.12, gte6)
+ gte7 = f32[2048] get-tuple-element(p1), index=3
+ add.14 = f32[2048] add(add.13, gte7)
+ ROOT r = f32[2048] add(add.14, add.11)
+}
+
+comp1 {
+ p = f32[2048] parameter(0)
+ c0 = f32[] constant(0)
+ ROOT r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
+}
+
+comp2 {
+ p = f32[2048] parameter(0)
+ c0 = f32[] constant(0)
+ r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
+ ROOT n = f32[] negate(r)
+}
+
+ENTRY m.Computation2 {
+ p0 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0)
+ p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1)
+ fusion.0 = f32[2048] fusion(p0, p1), kind=kLoop, calls=comp0
+ fusion.1 = f32[] fusion(fusion.0), kind=kLoop, calls=comp1
+ fusion.2 = f32[] fusion(fusion.0), kind=kLoop, calls=comp2
+ ROOT tuple = (f32[], f32[]) tuple(fusion.1, fusion.2)
+}
+)")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, LessMemoryAccessIfFused) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+comp.2 {
+ state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+ get-tuple-element.5 = f32[4]{0} get-tuple-element(state.param_1.1), index=0
+ get-tuple-element.6 = f32[4]{0} get-tuple-element(state.param_1.1), index=1
+ add.7 = f32[4]{0} add(get-tuple-element.5, get-tuple-element.6)
+ get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=2
+ ROOT add.8 = f32[4]{0} add(add.7, get-tuple-element.7)
+}
+
+comp.1 {
+ add.1.param_1.1 = f32[4]{0} parameter(1)
+ constant.param_1.3 = f32[4]{0} parameter(0)
+ add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
+ ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
+}
+
+comp {
+ add.1.param_1 = f32[4]{0} parameter(1)
+ constant.param_1.1 = f32[4]{0} parameter(0)
+ multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
+ ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
+}
+
+ENTRY m.Computation2 {
+ constant = f32[4]{0} constant({1, 1, 1, 1})
+ state = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+ fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2
+ fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1
+ fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp
+ ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4)
+})")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+// Check that we're willing to merge f1_computation into f2_computation, even
+// though f2 is an input fusion node.
+TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[32]{0} parameter(0)
+ ROOT f1_root = f32[32]{0} add(f1_p0, f1_p0)
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[32]{0} parameter(0)
+ f2_mul = f32[32]{0} multiply(f2_p0, f2_p0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[32]{0} parameter(0)
+ f1 = f32[32]{0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+ })")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter())));
+}
+
+TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule jit_matmul.36
+
+ max (parameter.13: f32[], parameter.14: f32[]) -> f32[] {
+ parameter.13 = f32[] parameter(0)
+ parameter.14 = f32[] parameter(1)
+ ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14)
+ }
+
+ add (parameter.29: f32[], parameter.30: f32[]) -> f32[] {
+ parameter.29 = f32[] parameter(0)
+ parameter.30 = f32[] parameter(1)
+ ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30)
+ }
+
+ fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] {
+ param_1.4 = f32[200,200,200]{2,1,0} parameter(0)
+ param_2.1 = f32[200,200]{1,0} parameter(1)
+ broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2}
+ subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3)
+ exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0)
+ constant.27 = f32[] constant(0)
+ ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add
+ }
+
+ fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] {
+ param_1.9 = f32[200,200]{1,0} parameter(1)
+ broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1}
+ param_0.7 = f32[200,200]{1,0} parameter(0)
+ broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2}
+ ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8)
+ }
+
+ ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] {
+ parameter.2 = f32[200,200]{1,0} parameter(1)
+ parameter.1 = f32[200,200]{1,0} parameter(0)
+ fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3
+ constant.11 = f32[] constant(-inf)
+ reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max
+ ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1
+ })")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Fusion(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
+ // TODO(b/247762001): the case here does not represent the problem -
+ // profiling shows that it works faster if merged (even on larger dimensions).
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+ add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
+ // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
+ ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,256]{0,1,2} parameter(0)
+ f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+ })")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeReduceNotTooUnfriendlyLayouts) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+ slice1 = f32[5,16,256]{0,1,2} slice(f1_p0), slice={[0:5], [0:16], [0:256]}
+ // Here the copy changes the layout only of a part of the data.
+ f1_copy = f32[5,16,256]{2,1,0} copy(slice1)
+ slice2 = f32[11,16,256]{0,1,2} slice(f1_p0), slice={[0:11], [0:16], [0:256]}
+ bitcast = f32[11,16,256]{2,1,0} bitcast(slice2)
+ ROOT f1_root = f32[16,16,256]{2,1,0} concatenate(f1_copy, bitcast), dimensions={0}
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[16,16] reduce(f2_p0, f2_zero), dimensions={2},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,256]{0,1,2} parameter(0)
+ f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[16,16] fusion(f1), kind=kInput, calls=f2_computation
+ })")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+// Check that we limit the number of operands to fusions we create.
+TEST_F(FusionMergerTest, AvoidsLargeFusion) {
+ constexpr int64_t kNumParams = MaxOperandsAndOutputsPerFusion() + 1;
+
+ // Compute
+ // p0 + p1 + p2 + ... + pn,
+ // Use so many parameters that they do not fit into one fusion.
+ auto module = CreateNewVerifiedModule();
+ HloComputation::Builder b(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
+
+ std::vector<HloInstruction*> entry_params;
+
+ for (int64_t i = 0; i < kNumParams; ++i) {
+ entry_params.push_back(
+ b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
+ }
+ auto make_fusion = [&](absl::Span<HloInstruction* const> params) {
+ // Build a fusion computation for calculating the sum of all parameters.
+ HloComputation::Builder sub_builder("subcomp");
+ HloInstruction* sum = nullptr;
+ for (int64_t i = 0; i < params.size(); ++i) {
+ auto p = sub_builder.AddInstruction(
+ HloInstruction::CreateParameter(i, shape, "p"));
+ if (sum == nullptr) {
+ sum = p;
+ } else {
+ sum = sub_builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, p));
+ }
+ }
+ HloComputation* subcomp =
+ module->AddEmbeddedComputation(sub_builder.Build());
+ return HloInstruction::CreateFusion(
+ shape, HloInstruction::FusionKind::kLoop, params, subcomp);
+ };
+ auto fusion = b.AddInstruction(
+ make_fusion(absl::MakeSpan(entry_params)
+ .subspan(0, MaxOperandsAndOutputsPerFusion())));
+ b.AddInstruction(make_fusion({entry_params.back(), fusion}));
+ module->AddEntryComputation(b.Build());
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+// TODO(b/119692968): Remove this test once fusion emitter is fixed.
+TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f1 {
+ Arg_0.5 = f32[200000] parameter(0)
+ slice.7 = f32[100000] slice(Arg_0.5), slice={[0:199999:2]}
+ slice.8 = f32[100000] slice(Arg_0.5), slice={[1:200000:2]}
+ add.9 = f32[100000] add(slice.7, slice.8)
+ slice.10 = f32[50000] slice(add.9), slice={[0:99999:2]}
+ slice.11 = f32[50000] slice(add.9), slice={[1:100000:2]}
+ add.12 = f32[50000] add(slice.10, slice.11)
+ slice.13 = f32[25000] slice(add.12), slice={[0:49999:2]}
+ slice.14 = f32[25000] slice(add.12), slice={[1:50000:2]}
+ add.15 = f32[25000] add(slice.13, slice.14)
+ slice.16 = f32[12500] slice(add.15), slice={[0:24999:2]}
+ slice.17 = f32[12500] slice(add.15), slice={[1:25000:2]}
+ add.18 = f32[12500] add(slice.16, slice.17)
+ slice.19 = f32[6250] slice(add.18), slice={[0:12499:2]}
+ slice.20 = f32[6250] slice(add.18), slice={[1:12500:2]}
+ add.21 = f32[6250] add(slice.19, slice.20)
+ slice.22 = f32[3125] slice(add.21), slice={[0:6249:2]}
+ slice.23 = f32[3125] slice(add.21), slice={[1:6250:2]}
+ ROOT add.24 = f32[3125] add(slice.22, slice.23)
+}
+
+f2 {
+ Arg_0 = f32[3125] parameter(0)
+ slice.25 = f32[1562] slice(Arg_0), slice={[0:3124:2]}
+ slice.26 = f32[1562] slice(Arg_0), slice={[1:3125:2]}
+ add.27 = f32[1562] add(slice.25, slice.26)
+ slice.28 = f32[781] slice(add.27), slice={[0:1561:2]}
+ slice.29 = f32[781] slice(add.27), slice={[1:1562:2]}
+ add.30 = f32[781] add(slice.28, slice.29)
+ slice.31 = f32[390] slice(add.30), slice={[0:780:2]}
+ slice.32 = f32[390] slice(add.30), slice={[1:781:2]}
+ add.33 = f32[390] add(slice.31, slice.32)
+ slice.34 = f32[195] slice(add.33), slice={[0:389:2]}
+ slice.35 = f32[195] slice(add.33), slice={[1:390:2]}
+ add.36 = f32[195] add(slice.34, slice.35)
+ slice.37 = f32[97] slice(add.36), slice={[0:194:2]}
+ slice.38 = f32[97] slice(add.36), slice={[1:195:2]}
+ add.39 = f32[97] add(slice.37, slice.38)
+ slice.40 = f32[48] slice(add.39), slice={[0:96:2]}
+ slice.41 = f32[48] slice(add.39), slice={[1:97:2]}
+ ROOT add.42 = f32[48] add(slice.40, slice.41)
+}
+
+ENTRY e {
+ p0 = f32[200000] parameter(0)
+ f1 = f32[3125] fusion(p0), kind=kLoop, calls=f1
+ ROOT r = f32[48] fusion(f1), kind=kLoop, calls=f2
+})")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeSliceIntoReusingConsumer) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f1 {
+ p01 = s8[1000000] parameter(0)
+ ROOT s0 = s8[10] slice(p01), slice={[0:10]}
+}
+
+f2 {
+ p02 = s8[10] parameter(0)
+ ROOT b0 = s8[10,1000000] broadcast(p02), dimensions={0}
+}
+
+ENTRY e {
+ p0 = s8[1000000] parameter(0)
+ f1 = s8[10] fusion(p0), kind=kLoop, calls=f1
+ ROOT r = s8[10,1000000] fusion(f1), kind=kLoop, calls=f2
+})")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ %f_a (p: f32[]) -> f32[1024,1024,1024] {
+ %p = f32[] parameter(0)
+ %b = f32[1024,1024,1024] broadcast(%p), dimensions={}
+ ROOT %t = f32[1024,1024,1024] tanh(%b)
+ }
+
+ %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+ %p = f32[1024,1024,1024] parameter(0)
+ ROOT %t = f32[1024,1024,1024] tanh(%p)
+ }
+
+ %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+ %p = f32[1024,1024,1024] parameter(0)
+ ROOT %t = f32[1024,1024,1024] tanh(%p)
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_a
+ f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_b
+ f3 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
+ ROOT f4 = f32[1024,1024,1024] add(f2, f3)
+ })")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+ %p = f32[1024,1024,1024] parameter(0)
+ ROOT %t = f32[1024,1024,1024] tanh(%p)
+ }
+
+ %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+ %p = f32[1024,1024,1024] parameter(0)
+ ROOT %t = f32[1024,1024,1024] add(%p, %p)
+ }
+
+ ENTRY entry {
+ p0 = f32[1024,1024,1024] parameter(0)
+ f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
+ ROOT f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
+ })")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillNotMergeExpensiveFusionsWithReusingConsumer) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ %f_b {
+ %p = f32[1024,1024,1024] parameter(0)
+ %t1 = f32[1024,1024,1024] tanh(%p)
+ %t2 = f32[1024,1024,1024] tanh(%t1)
+ %t3 = f32[1024,1024,1024] tanh(%t2)
+ %t4 = f32[1024,1024,1024] tanh(%t3)
+ %t5 = f32[1024,1024,1024] tanh(%t4)
+ %t6 = f32[1024,1024,1024] tanh(%t5)
+ %t7 = f32[1024,1024,1024] tanh(%t6)
+ %t8 = f32[1024,1024,1024] tanh(%t7)
+ ROOT %t9 = f32[1024,1024,1024] tanh(%t8)
+ }
+
+ %f_c {
+ %p = f32[1024,1024,1024] parameter(0)
+ ROOT %t = f32[1024,1024,1024,2048] broadcast(%p), dimensions={0,1,2}
+ }
+
+ ENTRY entry {
+ p0 = f32[1024,1024,1024] parameter(0)
+ f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
+ ROOT f2 = f32[1024,1024,1024,2048] fusion(f1), kind=kLoop, calls=%f_c
+ })")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, NoMergeWithBitcast) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f32add {
+ x.634 = f32[] parameter(0)
+ y.635 = f32[] parameter(1)
+ ROOT add.636 = f32[] add(x.634, y.635)
+}
+
+fused_computation.103 {
+ param_0.310 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
+ param_1.420 = f32[8,512]{1,0} parameter(1)
+ bitcast.1144 = f32[1,8,512]{2,1,0} bitcast(param_1.420)
+ convert.252 = f16[1,8,512]{2,1,0} convert(bitcast.1144)
+ bitcast.1143 = f16[8,512]{1,0} bitcast(convert.252)
+ broadcast.481 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1143), dimensions={1,2}
+ divide.15 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.310, broadcast.481)
+ ROOT bitcast.1142 = f16[8,512,1536]{1,2,0} bitcast(divide.15)
+}
+
+fused_computation.105 {
+ param_1.426 = f16[8,1536,512]{2,1,0} parameter(1)
+ bitcast.1896 = f16[1,8,1536,512]{3,2,1,0} bitcast(param_1.426)
+ transpose.238 = f16[1,8,512,1536]{2,3,1,0} transpose(bitcast.1896), dimensions={0,1,3,2}
+ param_0.315 = f16[8,512]{1,0} parameter(0)
+ broadcast.482 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.315), dimensions={1,2}
+ subtract.22 = f16[1,8,512,1536]{2,3,1,0} subtract(transpose.238, broadcast.482)
+ ROOT exponential.15 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.22)
+}
+
+fused_computation.104 {
+ param_0.1000 = f16[8,1536,512]{2,1,0} parameter(0)
+ convert.652 = f32[8,1536,512]{2,1,0} convert(param_0.1000)
+ constant_752 = f32[] constant(-0)
+ ROOT reduce.232 = f32[8,512]{1,0} reduce(convert.652, constant_752),
+ dimensions={1}, to_apply=f32add
+}
+
+ENTRY entry {
+ p0 = f16[8,1536,512]{2,1,0} parameter(0)
+ p1 = f16[8,512]{1,0} parameter(1)
+ fusion.105 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.105
+ bitcast.1787 = f16[8,1536,512]{2,1,0} bitcast(fusion.105)
+ fusion.104 = f32[8,512]{1,0} fusion(bitcast.1787), kind=kInput, calls=fused_computation.104
+ ROOT fusion.103 = f16[8,512,1536]{1,2,0} fusion(fusion.105, fusion.104), kind=kLoop, calls=fused_computation.103
+}
+ )")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, CostBasedMerge) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+fused_computation.45 {
+ param_1.194 = f16[8,1536,512]{2,1,0} parameter(1)
+ bitcast.1042 = f16[1,8,512,1536]{2,3,1,0} bitcast(param_1.194)
+ param_0.135 = f16[8,512]{1,0} parameter(0)
+ broadcast.391 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.135), dimensions={1,2}
+ subtract.6 = f16[1,8,512,1536]{2,3,1,0} subtract(bitcast.1042, broadcast.391)
+ ROOT exponential.11 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.6)
+}
+
+f32add {
+ x.634 = f32[] parameter(0)
+ y.635 = f32[] parameter(1)
+ ROOT add.636 = f32[] add(x.634, y.635)
+}
+
+fused_computation.44 {
+ param_0.869 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
+ convert.221 = f32[1,8,512,1536]{2,3,1,0} convert(param_0.869)
+ transpose.212 = f32[1,8,1536,512]{3,2,1,0} transpose(convert.221), dimensions={0,1,3,2}
+ bitcast.1041 = f32[8,1536,512]{2,1,0} bitcast(transpose.212)
+ constant_429 = f32[] constant(0)
+ ROOT reduce.149 = f32[8,512]{1,0} reduce(bitcast.1041, constant_429), dimensions={1}, to_apply=f32add
+}
+
+fused_computation.43 {
+ param_0.130 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
+ param_1.188 = f32[8,512]{1,0} parameter(1)
+ bitcast.1040 = f32[1,8,512]{2,1,0} bitcast(param_1.188)
+ convert.220 = f16[1,8,512]{2,1,0} convert(bitcast.1040)
+ bitcast.1039 = f16[8,512]{1,0} bitcast(convert.220)
+ broadcast.390 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1039), dimensions={1,2}
+ divide.11 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.130, broadcast.390)
+ ROOT bitcast.1038 = f16[8,512,1536]{1,2,0} bitcast(divide.11)
+}
+
+ENTRY entry {
+ p0 = f16[8,1536,512]{2,1,0} parameter(0)
+ p1 = f16[8,512]{1,0} parameter(1)
+ fusion.45 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.45
+ fusion.44 = f32[8,512]{1,0} fusion(fusion.45), kind=kInput, calls=fused_computation.44
+ ROOT fusion.43 = f16[8,512,1536]{1,2,0} fusion(fusion.45, fusion.44), kind=kLoop, calls=fused_computation.43
+}
+ )")
+ .value();
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ // For some reason, we would not merge any fusions when using the MLIR
+ // reduction emitter. The cost model queries the reduction emitter regarding
+ // the launch dimensions, so it seems likely that it is caused by different
+ // launch dimensions.
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+// Outputs of fusions 66 and 67 here are heavily reused by fusion 59 - so
+// it is better to not merge here.
+TEST_F(FusionMergerTest, CostBasedNoMerge) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+add_float_.56 {
+ x.57 = f32[] parameter(0)
+ y.58 = f32[] parameter(1)
+ ROOT add.59 = f32[] add(x.57, y.58)
+}
+
+fused_computation.66 {
+ constant.635 = f32[] constant(0)
+ broadcast.257 = f32[459,3]{1,0} broadcast(constant.635), dimensions={}
+ constant.641 = f32[] constant(1)
+ broadcast.256 = f32[459,3]{1,0} broadcast(constant.641), dimensions={}
+ broadcast.255 = f32[459]{0} broadcast(constant.635), dimensions={}
+ iota.28 = f32[459]{0} iota(), iota_dimension=0
+ constant.629 = f32[] constant(1.49891067)
+ broadcast.253 = f32[459]{0} broadcast(constant.629), dimensions={}
+ multiply.39 = f32[459]{0} multiply(iota.28, broadcast.253)
+ constant.633 = f32[] constant(-1)
+ broadcast.252 = f32[459]{0} broadcast(constant.633), dimensions={}
+ add.31 = f32[459]{0} add(multiply.39, broadcast.252)
+ ceil.11 = f32[459]{0} ceil(add.31)
+ constant.630 = f32[] constant(685)
+ broadcast.251 = f32[459]{0} broadcast(constant.630), dimensions={}
+ clamp.49 = f32[459]{0} clamp(broadcast.255, ceil.11, broadcast.251)
+ subtract.11 = f32[459]{0} subtract(clamp.49, multiply.39)
+ broadcast.249 = f32[459,3]{1,0} broadcast(subtract.11), dimensions={0}
+ iota.26 = f32[459,3]{1,0} iota(), iota_dimension=1
+ add.30 = f32[459,3]{1,0} add(broadcast.249, iota.26)
+ abs.3 = f32[459,3]{1,0} abs(add.30)
+ subtract.10 = f32[459,3]{1,0} subtract(broadcast.256, abs.3)
+ maximum.6 = f32[459,3]{1,0} maximum(broadcast.257, subtract.10)
+ ROOT reduce.3 = f32[459]{0} reduce(maximum.6, constant.635), dimensions={1}, to_apply=add_float_.56
+}
+
+fused_computation.67 {
+ constant.684 = f32[] constant(0)
+ broadcast.296 = f32[1130,3]{1,0} broadcast(constant.684), dimensions={}
+ constant.685 = f32[] constant(1)
+ broadcast.295 = f32[1130,3]{1,0} broadcast(constant.685), dimensions={}
+ broadcast.294 = f32[1130]{0} broadcast(constant.684), dimensions={}
+ iota.41 = f32[1130]{0} iota(), iota_dimension=0
+ constant.675 = f32[] constant(1.34513271)
+ broadcast.293 = f32[1130]{0} broadcast(constant.675), dimensions={}
+ multiply.47 = f32[1130]{0} multiply(iota.41, broadcast.293)
+ constant.677 = f32[] constant(-1)
+ broadcast.290 = f32[1130]{0} broadcast(constant.677), dimensions={}
+ add.39 = f32[1130]{0} add(multiply.47, broadcast.290)
+ ceil.15 = f32[1130]{0} ceil(add.39)
+ constant.676 = f32[] constant(1517)
+ broadcast.289 = f32[1130]{0} broadcast(constant.676), dimensions={}
+ clamp.53 = f32[1130]{0} clamp(broadcast.294, ceil.15, broadcast.289)
+ subtract.19 = f32[1130]{0} subtract(clamp.53, multiply.47)
+ broadcast.287 = f32[1130,3]{1,0} broadcast(subtract.19), dimensions={0}
+ iota.39 = f32[1130,3]{1,0} iota(), iota_dimension=1
+ add.38 = f32[1130,3]{1,0} add(broadcast.287, iota.39)
+ abs.7 = f32[1130,3]{1,0} abs(add.38)
+ subtract.18 = f32[1130,3]{1,0} subtract(broadcast.295, abs.7)
+ maximum.10 = f32[1130,3]{1,0} maximum(broadcast.296, subtract.18)
+ ROOT reduce.4 = f32[1130]{0} reduce(maximum.10, constant.684), dimensions={1}, to_apply=add_float_.56
+}
+
+fused_computation.59 {
+ constant.532 = f32[] constant(0)
+ broadcast.316 = f32[1130,3]{1,0} broadcast(constant.532), dimensions={}
+ constant.663 = f32[] constant(1)
+ broadcast.315 = f32[1130,3]{1,0} broadcast(constant.663), dimensions={}
+ broadcast.314 = f32[1130]{0} broadcast(constant.532), dimensions={}
+ iota.47 = f32[1130]{0} iota(), iota_dimension=0
+ constant.579 = f32[] constant(1.34513271)
+ broadcast.311 = f32[1130]{0} broadcast(constant.579), dimensions={}
+ multiply.51 = f32[1130]{0} multiply(iota.47, broadcast.311)
+ constant.578 = f32[] constant(-1)
+ broadcast.310 = f32[1130]{0} broadcast(constant.578), dimensions={}
+ add.43 = f32[1130]{0} add(multiply.51, broadcast.310)
+ ceil.17 = f32[1130]{0} ceil(add.43)
+ constant.576 = f32[] constant(1517)
+ broadcast.309 = f32[1130]{0} broadcast(constant.576), dimensions={}
+ clamp.55 = f32[1130]{0} clamp(broadcast.314, ceil.17, broadcast.309)
+ subtract.24 = f32[1130]{0} subtract(clamp.55, multiply.51)
+ broadcast.306 = f32[1130,3]{1,0} broadcast(subtract.24), dimensions={0}
+ iota.45 = f32[1130,3]{1,0} iota(), iota_dimension=1
+ add.42 = f32[1130,3]{1,0} add(broadcast.306, iota.45)
+ abs.9 = f32[1130,3]{1,0} abs(add.42)
+ subtract.23 = f32[1130,3]{1,0} subtract(broadcast.315, abs.9)
+ maximum.12 = f32[1130,3]{1,0} maximum(broadcast.316, subtract.23)
+ param_2.183 = f32[1130]{0} parameter(2)
+ broadcast.172 = f32[1130,3]{1,0} broadcast(param_2.183), dimensions={0}
+ divide.3 = f32[1130,3]{1,0} divide(maximum.12, broadcast.172)
+ bitcast.53 = f32[3390]{0} bitcast(divide.3)
+ broadcast.171 = f32[3390,1377]{1,0} broadcast(bitcast.53), dimensions={0}
+ broadcast.276 = f32[459,3]{1,0} broadcast(constant.532), dimensions={}
+ broadcast.275 = f32[459,3]{1,0} broadcast(constant.663), dimensions={}
+ broadcast.274 = f32[459]{0} broadcast(constant.532), dimensions={}
+ iota.35 = f32[459]{0} iota(), iota_dimension=0
+ constant.614 = f32[] constant(1.49891067)
+ broadcast.273 = f32[459]{0} broadcast(constant.614), dimensions={}
+ multiply.43 = f32[459]{0} multiply(iota.35, broadcast.273)
+ broadcast.272 = f32[459]{0} broadcast(constant.578), dimensions={}
+ add.35 = f32[459]{0} add(multiply.43, broadcast.272)
+ ceil.13 = f32[459]{0} ceil(add.35)
+ constant.611 = f32[] constant(685)
+ broadcast.269 = f32[459]{0} broadcast(constant.611), dimensions={}
+ clamp.51 = f32[459]{0} clamp(broadcast.274, ceil.13, broadcast.269)
+ subtract.15 = f32[459]{0} subtract(clamp.51, multiply.43)
+ broadcast.267 = f32[459,3]{1,0} broadcast(subtract.15), dimensions={0}
+ iota.33 = f32[459,3]{1,0} iota(), iota_dimension=1
+ add.34 = f32[459,3]{1,0} add(broadcast.267, iota.33)
+ abs.5 = f32[459,3]{1,0} abs(add.34)
+ subtract.14 = f32[459,3]{1,0} subtract(broadcast.275, abs.5)
+ maximum.8 = f32[459,3]{1,0} maximum(broadcast.276, subtract.14)
+ param_1.177 = f32[459]{0} parameter(1)
+ broadcast.170 = f32[459,3]{1,0} broadcast(param_1.177), dimensions={0}
+ divide.2 = f32[459,3]{1,0} divide(maximum.8, broadcast.170)
+ bitcast.52 = f32[1377]{0} bitcast(divide.2)
+ broadcast.169 = f32[3390,1377]{1,0} broadcast(bitcast.52), dimensions={1}
+ multiply.15 = f32[3390,1377]{1,0} multiply(broadcast.171, broadcast.169)
+ bitcast.61 = f32[1130,3,459,3]{3,2,1,0} bitcast(multiply.15)
+ transpose.68 = f32[459,1130,3,3]{2,0,3,1} transpose(bitcast.61), dimensions={2,0,3,1}
+ copy.1 = f32[459,1130,3,3]{3,2,1,0} copy(transpose.68)
+ bitcast.50 = f32[1130,459,9]{2,1,0} bitcast(copy.1)
+ broadcast.168 = f32[1130,459,6,9]{3,2,1,0} broadcast(bitcast.50), dimensions={0,1,3}
+ param_0.171 = u8[1,688,1520,6]{3,2,1,0} parameter(0)
+ bitcast.49 = u8[688,1520,1,6]{3,1,0,2} bitcast(param_0.171)
+ convert.175 = f32[688,1520,1,6]{3,1,0,2} convert(bitcast.49)
+ broadcast.167 = f32[459,1130,1]{2,1,0} broadcast(clamp.51), dimensions={0}
+ broadcast.166 = f32[459,1130,1]{2,1,0} broadcast(clamp.55), dimensions={1}
+ concatenate.3 = f32[459,1130,2]{2,1,0} concatenate(broadcast.167, broadcast.166), dimensions={2}
+ convert.174 = s32[459,1130,2]{2,1,0} convert(concatenate.3)
+ bitcast.48 = s32[518670,2]{1,0} bitcast(convert.174)
+ gather.1 = f32[518670,3,3,1,6]{2,1,4,0,3} gather(convert.175, bitcast.48), offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={3,3,1,6}
+ transpose.69 = f32[1,518670,6,3,3]{4,3,2,1,0} transpose(gather.1), dimensions={3,0,4,1,2}
+ bitcast.47 = f32[1130,459,6,9]{3,2,1,0} bitcast(transpose.69)
+ multiply.14 = f32[1130,459,6,9]{3,2,1,0} multiply(broadcast.168, bitcast.47)
+ reduce.2 = f32[1130,459,6]{2,1,0} reduce(multiply.14, constant.532), dimensions={3}, to_apply=add_float_.56
+ convert.173 = f16[1130,459,6]{2,1,0} convert(reduce.2)
+ bitcast.46 = f16[1,459,1130,6]{3,2,1,0} bitcast(convert.173)
+ constant.533 = f16[] constant(0)
+ pad.9 = f16[1,480,1130,6]{3,2,1,0} pad(bitcast.46, constant.533), padding=0_0x0_21x0_0x0_0
+ pad.8 = f16[1,480,1152,6]{3,2,1,0} pad(pad.9, constant.533), padding=0_0x0_0x0_22x0_0
+ constant.532f16 = f16[] constant(0)
+ ROOT pad.7 = f16[1,485,1157,6]{3,2,1,0} pad(pad.8, constant.532f16), padding=0_0x2_3x2_3x0_0
+}
+
+ENTRY e {
+ arg0.1 = u8[1,688,1520,6]{3,2,1,0} parameter(0), parameter_replication={false}
+ fusion.66 = f32[459]{0} fusion(), kind=kLoop, calls=fused_computation.66
+ fusion.67 = f32[1130]{0} fusion(), kind=kLoop, calls=fused_computation.67
+ ROOT fusion.59 = f16[1,485,1157,6]{2,1,3,0} fusion(arg0.1, fusion.66, fusion.67), kind=kLoop, calls=fused_computation.59
+}
+ )")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, NoMergeBecauseTooManyBasicBlockSplits) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+region_6.97 {
+ Arg_0.98 = pred[] parameter(0)
+ Arg_1.99 = pred[] parameter(1)
+ ROOT or.100 = pred[] or(Arg_0.98, Arg_1.99)
+}
+
+region_4.50 {
+ Arg_0.51 = f64[] parameter(0)
+ Arg_1.52 = f64[] parameter(1)
+ ROOT add.53 = f64[] add(Arg_0.51, Arg_1.52)
+}
+
+f2 {
+ param_0 = s64[1]{0} parameter(0)
+ constant_70 = f64[] constant(0)
+ convert.41.clone.1 = f64[1]{0} convert(param_0)
+ ROOT pad.99.clone.1 = f64[3]{0} pad(convert.41.clone.1, constant_70), padding=0_2
+}
+
+f1 {
+ param_0.361 = pred[5]{0} parameter(0)
+ broadcast.107 = pred[10,5]{1,0} broadcast(param_0.361), dimensions={1}
+ param_6.244 = pred[5]{0} parameter(6)
+ broadcast.111.clone.1 = pred[10,5]{1,0} broadcast(param_6.244), dimensions={1}
+ param_1.450 = f64[10,5]{1,0} parameter(1)
+ constant_294_clone_1 = f64[] constant(1)
+ broadcast.153.clone.1 = f64[10,5]{1,0} broadcast(constant_294_clone_1), dimensions={}
+ compare.22.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.153.clone.1), direction=GE
+ constant_75_clone_1 = f64[] constant(-1)
+ broadcast.109.clone.1 = f64[10,5]{1,0} broadcast(constant_75_clone_1), dimensions={}
+ add.34.clone.1 = f64[10,5]{1,0} add(param_1.450, broadcast.109.clone.1)
+ param_5.322 = f64[10,5,4]{1,0,2} parameter(5)
+ slice.45.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [3:4]}
+ bitcast.94.clone.1 = f64[10,5]{1,0} bitcast(slice.45.clone.1)
+ divide.7.clone.1 = f64[10,5]{1,0} divide(add.34.clone.1, bitcast.94.clone.1)
+ add.33.clone.1 = f64[10,5]{1,0} add(divide.7.clone.1, broadcast.153.clone.1)
+ constant_70 = f64[] constant(0)
+ broadcast.157.clone.1 = f64[10,5]{1,0} broadcast(constant_70), dimensions={}
+ compare.26.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.157.clone.1), direction=LE
+ slice.46.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [0:1]}
+ bitcast.93.clone.1 = f64[10,5]{1,0} bitcast(slice.46.clone.1)
+ divide.6.clone.1 = f64[10,5]{1,0} divide(param_1.450, bitcast.93.clone.1)
+ broadcast.295.clone.1 = f64[10,5,3]{1,0,2} broadcast(param_1.450), dimensions={0,1}
+ param_4.368 = f64[10,5,2]{1,0,2} parameter(4)
+ pad.103.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_70), padding=0_0x0_0x1_0
+ compare.121.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.103.clone.1), direction=GE
+ pad.102.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_294_clone_1), padding=0_0x0_0x0_1
+ compare.120.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.102.clone.1), direction=LT
+ and.39.clone.1 = pred[10,5,3]{1,0,2} and(compare.121.clone.1, compare.120.clone.1)
+ transpose.9 = pred[3,10,5]{2,1,0} transpose(and.39.clone.1), dimensions={2,0,1}
+ constant_296_clone_1 = pred[] constant(false)
+ reduce.91.clone.1 = pred[10,5]{1,0} reduce(transpose.9, constant_296_clone_1), dimensions={0}, to_apply=region_6.97
+ broadcast.294.clone.1 = pred[10,5,3]{1,0,2} broadcast(reduce.91.clone.1), dimensions={0,1}
+ pad.99.clone.1 = f64[3]{0} parameter(3)
+ broadcast.292.clone.1 = f64[3]{0} broadcast(constant_70), dimensions={}
+ compare.117.clone.1 = pred[3]{0} compare(pad.99.clone.1, broadcast.292.clone.1), direction=NE
+ broadcast.290.clone.1 = pred[10,5,3]{1,0,2} broadcast(compare.117.clone.1), dimensions={2}
+ select.67.clone.1 = pred[10,5,3]{1,0,2} select(broadcast.294.clone.1, and.39.clone.1, broadcast.290.clone.1)
+ convert.40.clone.1 = f64[10,5,3]{1,0,2} convert(select.67.clone.1)
+ broadcast.288.clone.1 = f64[10,5,3,3]{1,0,2,3} broadcast(convert.40.clone.1), dimensions={0,1,2}
+ param_2.361 = f64[10,5,4,3]{1,0,2,3} parameter(2)
+ slice.114.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [1:4], [0:3]}
+ multiply.53.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.114.clone.1)
+ transpose.10 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.53.clone.1), dimensions={3,2,0,1}
+ reduce.90.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.10, constant_70), dimensions={1}, to_apply=region_4.50
+ transpose.11 = f64[10,5,3]{1,0,2} transpose(reduce.90.clone.1), dimensions={1,2,0}
+ slice.28.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [0:1]}
+ bitcast.99.clone.1 = f64[10,5]{1,0} bitcast(slice.28.clone.1)
+ slice.108.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [0:3], [0:3]}
+ multiply.49.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.108.clone.1)
+ transpose.12 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.49.clone.1), dimensions={3,2,0,1}
+ reduce.82.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.12, constant_70), dimensions={1}, to_apply=region_4.50
+ transpose.13 = f64[10,5,3]{1,0,2} transpose(reduce.82.clone.1), dimensions={1,2,0}
+ slice.107.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [0:1]}
+ bitcast.240.clone.1 = f64[10,5]{1,0} bitcast(slice.107.clone.1)
+ subtract.27.clone.1 = f64[10,5]{1,0} subtract(bitcast.99.clone.1, bitcast.240.clone.1)
+ slice.27.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [2:3]}
+ bitcast.98.clone.1 = f64[10,5]{1,0} bitcast(slice.27.clone.1)
+ slice.26.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [2:3]}
+ bitcast.97.clone.1 = f64[10,5]{1,0} bitcast(slice.26.clone.1)
+ add.36.clone.1 = f64[10,5]{1,0} add(bitcast.97.clone.1, bitcast.98.clone.1)
+ slice.24.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [1:2]}
+ bitcast.95.clone.1 = f64[10,5]{1,0} bitcast(slice.24.clone.1)
+ slice.121.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [1:2]}
+ bitcast.274.clone.1 = f64[10,5]{1,0} bitcast(slice.121.clone.1)
+ subtract.26.clone.1 = f64[10,5]{1,0} subtract(bitcast.95.clone.1, bitcast.274.clone.1)
+ divide.21 = f64[10,5]{1,0} divide(subtract.26.clone.1, subtract.27.clone.1)
+ constant_77_clone_1 = f64[] constant(2)
+ broadcast.117.clone.1 = f64[10,5]{1,0} broadcast(constant_77_clone_1), dimensions={}
+ multiply.37.clone.1 = f64[10,5]{1,0} multiply(divide.21, broadcast.117.clone.1)
+ subtract.25.clone.1 = f64[10,5]{1,0} subtract(add.36.clone.1, multiply.37.clone.1)
+ subtract.24.clone.1 = f64[10,5]{1,0} subtract(param_1.450, bitcast.274.clone.1)
+ divide.9.clone.1 = f64[10,5]{1,0} divide(subtract.24.clone.1, subtract.26.clone.1)
+ clamp.7.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.9.clone.1, broadcast.153.clone.1)
+ multiply.36.clone.1 = f64[10,5]{1,0} multiply(subtract.25.clone.1, clamp.7.clone.1)
+ subtract.23.clone.1 = f64[10,5]{1,0} subtract(bitcast.98.clone.1, multiply.36.clone.1)
+ compare.13.clone.1 = pred[10,5]{1,0} compare(subtract.23.clone.1, broadcast.157.clone.1), direction=GE
+ negate.19.clone.1 = f64[10,5]{1,0} negate(divide.21)
+ multiply.35.clone.1 = f64[10,5]{1,0} multiply(negate.19.clone.1, clamp.7.clone.1)
+ multiply.34.clone.1 = f64[10,5]{1,0} multiply(multiply.35.clone.1, broadcast.117.clone.1)
+ negate.18.clone.1 = f64[10,5]{1,0} negate(subtract.23.clone.1)
+ multiply.33.clone.1 = f64[10,5]{1,0} multiply(subtract.23.clone.1, subtract.23.clone.1)
+ subtract.22.clone.1 = f64[10,5]{1,0} subtract(divide.21, subtract.23.clone.1)
+ constant_78_clone_1 = f64[] constant(4)
+ broadcast.113.clone.1 = f64[10,5]{1,0} broadcast(constant_78_clone_1), dimensions={}
+ multiply.32.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.113.clone.1)
+ multiply.31.clone.1 = f64[10,5]{1,0} multiply(multiply.32.clone.1, multiply.35.clone.1)
+ subtract.21.clone.1 = f64[10,5]{1,0} subtract(multiply.33.clone.1, multiply.31.clone.1)
+ compare.12.clone.1 = pred[10,5]{1,0} compare(subtract.21.clone.1, broadcast.157.clone.1), direction=GT
+ constant_79_clone_1 = f64[] constant(2.2250738585072014e-308)
+ broadcast.112.clone.1 = f64[10,5]{1,0} broadcast(constant_79_clone_1), dimensions={}
+ maximum.18.clone.1 = f64[10,5]{1,0} maximum(broadcast.112.clone.1, subtract.21.clone.1)
+ sqrt.1.clone.1 = f64[10,5]{1,0} sqrt(maximum.18.clone.1)
+ select.47.clone.1 = f64[10,5]{1,0} select(compare.12.clone.1, sqrt.1.clone.1, broadcast.157.clone.1)
+ add.35.clone.1 = f64[10,5]{1,0} add(negate.18.clone.1, select.47.clone.1)
+ select.46.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, multiply.34.clone.1, add.35.clone.1)
+ subtract.20.clone.1 = f64[10,5]{1,0} subtract(negate.18.clone.1, select.47.clone.1)
+ multiply.30.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.117.clone.1)
+ select.45.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, subtract.20.clone.1, multiply.30.clone.1)
+ divide.8.clone.1 = f64[10,5]{1,0} divide(select.46.clone.1, select.45.clone.1)
+ clamp.6.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.8.clone.1, broadcast.153.clone.1)
+ multiply.29.clone.1 = f64[10,5]{1,0} multiply(subtract.27.clone.1, clamp.6.clone.1)
+ add.32.clone.1 = f64[10,5]{1,0} add(multiply.29.clone.1, bitcast.240.clone.1)
+ select.44.clone.1 = f64[10,5]{1,0} select(compare.26.clone.1, divide.6.clone.1, add.32.clone.1)
+ select.43.clone.1 = f64[10,5]{1,0} select(compare.22.clone.1, add.33.clone.1, select.44.clone.1)
+ select.42.clone.1 = f64[10,5]{1,0} select(broadcast.111.clone.1, param_1.450, select.43.clone.1)
+ select.41 = f64[10,5]{1,0} select(broadcast.107, select.42.clone.1, broadcast.157.clone.1)
+ ROOT tuple.14 = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) tuple(select.41, select.42.clone.1, clamp.6.clone.1, subtract.25.clone.1, bitcast.97.clone.1, multiply.37.clone.1, bitcast.98.clone.1, divide.21)
+}
+
+ENTRY e {
+ p3 = s64[1]{0} parameter(3)
+ f2 = f64[3]{0} fusion(p3), kind=kLoop, calls=f2
+
+ p0 = pred[5]{0} parameter(0)
+ p1 = f64[10,5]{1,0} parameter(1)
+ p2 = f64[10,5,4,3]{1,0,2,3} parameter(2)
+ p4 = f64[10,5,2]{1,0,2} parameter(4)
+ p5 = f64[10,5,4]{1,0,2} parameter(5)
+ p6 = pred[5]{0} parameter(6)
+ ROOT ret = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) fusion(p0, p1, p2, f2, p4, p5, p6), kind=kLoop, calls=f1
+}
+ )")
+ .value();
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, CommonElementwiseUsedParameter) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ p {
+ p0 = f32[10000000] parameter(0)
+ p1 = f32[10000000] parameter(1)
+ p2 = f32[10000000] parameter(2)
+ p3 = f32[10000000] parameter(3)
+ a0 = f32[10000000] add(p1, p2)
+ a1 = f32[10000000] add(a0, p3)
+ ROOT _ = add(p0, a1)
+ }
+
+ c1 {
+ p0 = f32[10000000] parameter(0)
+ p1 = f32[10000000] parameter(1)
+ ROOT _ = add(p0, p1)
+ }
+
+ c2 {
+ p0 = f32[10000000] parameter(0)
+ p1 = f32[10000000] parameter(1)
+ ROOT _ = multiply(p0, p1)
+ }
+
+ ENTRY entry {
+ p0 = f32[10000000] parameter(0)
+ p1 = f32[10000000] parameter(1)
+ p2 = f32[10000000] parameter(2)
+ p3 = f32[10000000] parameter(3)
+ f = f32[10000000] fusion(p0, p1, p2, p3), kind=kLoop, calls=p
+ f1 = f32[10000000] fusion(p0, f), kind=kLoop, calls=c1
+ f2 = f32[10000000] fusion(p1, f), kind=kLoop, calls=c2
+ ROOT _ = (f32[10000000], f32[10000000]) tuple(f1, f2)
+ }
+ )")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, IncompatibleNonTrivialHeroes) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation {
+ param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
+ param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
+ s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
+ t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
+ sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
+ exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
+ ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
+ }
+
+ fused_computation.2 {
+ param_0.2 = f32[32,16,18]{2,1,0} parameter(0)
+ s.2 = f32[32,16,18]{2,1,0} sqrt(param_0.2)
+ ROOT t.2 = f32[32,18,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
+ }
+
+ ENTRY main {
+ p = f32[18,16,32]{2,1,0} parameter(0)
+ p2 = f32[32,16,18]{2,1,0} parameter(1)
+ fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
+ ROOT fusion2 = f32[32,18,16]{2,1,0} fusion(fusion), kind=kInput, calls=fused_computation.2
+ }
+ )")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, DoNotMergeDUSFusions) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ %fused_computation (param_0: f32[8], param_1.2: f32[], param_2.3: f32[8]) -> f32[8] {
+ %param_0 = f32[8]{0} parameter(0)
+ %param_2.3 = f32[8]{0} parameter(2)
+ %slice.2 = f32[5]{0} slice(f32[8]{0} %param_2.3), slice={[0:5]}
+ %param_1.2 = f32[] parameter(1)
+ %broadcast.2 = f32[5]{0} broadcast(f32[] %param_1.2), dimensions={}
+ %add.2 = f32[5]{0} add(f32[5]{0} %slice.2, f32[5]{0} %broadcast.2)
+ %two.1 = s32[] constant(2)
+ ROOT %dynamic-update-slice.2 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0, f32[5]{0} %add.2, s32[] %two.1)
+ }
+
+ %fused_computation.1 (param_0.1: f32[8], param_1.4: f32[6], param_2.6: f32[]) -> f32[8] {
+ %param_0.1 = f32[8]{0} parameter(0)
+ %param_1.4 = f32[6]{0} parameter(1)
+ %param_2.6 = f32[] parameter(2)
+ %broadcast.3 = f32[6]{0} broadcast(f32[] %param_2.6), dimensions={}
+ %add.3 = f32[6]{0} add(f32[6]{0} %param_1.4, f32[6]{0} %broadcast.3)
+ %three.1 = s32[] constant(3)
+ ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[6]{0} %add.3, s32[] %three.1)
+ }
+
+ ENTRY %Test (parameter: f32[8]) -> f32[8] {
+ %parameter = f32[8]{0} parameter(0)
+ %slice.1 = f32[6]{0} slice(f32[8]{0} %parameter), slice={[0:6]}
+ %one = f32[] constant(1)
+ %fusion.1 = f32[8]{0} fusion(f32[8]{0} %parameter, f32[6]{0} %slice.1, f32[] %one), kind=kLoop, calls=%fused_computation.1
+ ROOT %fusion = f32[8]{0} fusion(f32[8]{0} %fusion.1, f32[] %one, f32[8]{0} %parameter), kind=kLoop, calls=%fused_computation
+ }
+ )")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, MergeDUSFusionWithElementwiseFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ %fused_computation {
+ %param_0 = f32[1,8]{1,0} parameter(0)
+ %bitcast = f32[8]{0} bitcast(%param_0)
+ ROOT %neg = f32[8]{0} negate(%bitcast)
+ }
+
+ %fused_computation.1 {
+ %param_0.1 = f32[8]{0} parameter(0)
+ %param_1.4 = f32[5]{0} parameter(1)
+ %three.1 = s32[] constant(3)
+ %exp = f32[5]{0} exponential(%param_1.4)
+ ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[5]{0} %exp, s32[] %three.1)
+ }
+
+ ENTRY %Test {
+ %parameter = f32[5]{0} parameter(0)
+ %parameter.1 = f32[1,8]{1,0} parameter(1)
+ %fusion = f32[8]{0} fusion(f32[1,8]{1,0} %parameter.1), kind=kLoop, calls=%fused_computation
+ ROOT %fusion.1 = f32[8]{0} fusion(f32[8]{0} %fusion, f32[5]{0} %parameter), kind=kLoop, calls=%fused_computation.1
+ }
+ )")
+ .value();
+ EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, DoNotMergeTwoReduces) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add.13235 = f32[] add(p0, p1)
+ }
+
+ ENTRY main {
+ p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
+ ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
+ }
+ )")
+ .value();
+ EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc
new file mode 100644
index 0000000..d7f8505
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc
@@ -0,0 +1,152 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+
+#include <functional>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> FusionWrapper::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ auto instructions = module->entry_computation()->MakeInstructionPostOrder();
+ bool changed = false;
+
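+  // Declared as a std::function (rather than a plain lambda) so that the
+  // handler can call itself recursively for control-flow instructions.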
+ std::function<absl::Status(HloInstruction*)> handle_instruction;
+ handle_instruction = [&](HloInstruction* instruction) -> absl::Status {
+ switch (instruction->opcode()) {
+ case HloOpcode::kConditional:
+ case HloOpcode::kWhile:
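+        // Recurse into the called computations so that instructions inside
+        // control-flow bodies get wrapped as well.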
+ for (auto* computation : instruction->called_computations()) {
+ for (auto* inner_instruction :
+ computation->MakeInstructionPostOrder()) {
+ TF_RETURN_IF_ERROR(handle_instruction(inner_instruction));
+ }
+ }
+ break;
+ case HloOpcode::kAbs:
+ case HloOpcode::kAdd:
+ case HloOpcode::kAnd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kBitcastConvert:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kCeil:
+ case HloOpcode::kCbrt:
+ case HloOpcode::kClamp:
+ case HloOpcode::kClz:
+ case HloOpcode::kCompare:
+ case HloOpcode::kComplex:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCopy:
+ case HloOpcode::kCos:
+ case HloOpcode::kDivide:
+ case HloOpcode::kDot:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kErf:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kFloor:
+ case HloOpcode::kGather:
+ case HloOpcode::kImag:
+ case HloOpcode::kIota:
+ case HloOpcode::kIsFinite:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kMap:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNegate:
+ case HloOpcode::kNot:
+ case HloOpcode::kOr:
+ case HloOpcode::kPad:
+ case HloOpcode::kPopulationCount:
+ case HloOpcode::kPower:
+ case HloOpcode::kReal:
+ case HloOpcode::kReshape:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReducePrecision:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kReverse:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kRoundNearestEven:
+ case HloOpcode::kRsqrt:
+ case HloOpcode::kScatter:
+ case HloOpcode::kSelect:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightLogical:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kSign:
+ case HloOpcode::kSin:
+ case HloOpcode::kSlice:
+ case HloOpcode::kSqrt:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kStochasticConvert:
+ case HloOpcode::kTan:
+ case HloOpcode::kTanh:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kXor: {
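+        // Wrap the lone instruction in a fusion containing only it; below, the
+        // fusion and its computation are renamed to "wrapped_<opcode>" and
+        // "wrapped_<opcode>_computation" and the original instruction is
+        // removed.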
+ auto* computation = instruction->parent();
+ auto* fusion_instruction =
+ computation->AddInstruction(HloInstruction::CreateFusion(
+ instruction->shape(),
+ ChooseFusionKind(*instruction, *instruction), instruction));
+ const absl::string_view wrapped_opcode =
+ HloOpcodeString(instruction->opcode());
+ module->SetAndUniquifyInstrName(
+ fusion_instruction, absl::StrCat("wrapped_", wrapped_opcode));
+ module->SetAndUniquifyComputationName(
+ fusion_instruction->fused_instructions_computation(),
+ absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
+ if (module->has_schedule()) {
+ module->schedule().replace_instruction(computation, instruction,
+ fusion_instruction);
+ }
+ TF_RETURN_IF_ERROR(
+ fusion_instruction->CopyAllControlDepsFrom(instruction));
+ TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+ changed = true;
+ break;
+ }
+ default:
+ break;
+ }
+ return absl::OkStatus();
+ };
+
+ for (auto* instruction : instructions) {
+ TF_RETURN_IF_ERROR(handle_instruction(instruction));
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h
new file mode 100644
index 0000000..30b1c8a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h
@@ -0,0 +1,42 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Wraps leftover unfused instructions in the entry computation that have no
+// LHLO equivalent into fusions containing just that instruction.
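+// For example, a lone concatenate at the top level of the entry computation is
+// rewritten into a kLoop fusion whose body contains just that concatenate; the
+// pass names the result "wrapped_concatenate" and its computation
+// "wrapped_concatenate_computation" (see the tests).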
+class FusionWrapper : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "fusion-wrapper"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc
new file mode 100644
index 0000000..a46338f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc
@@ -0,0 +1,188 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+
+#include <optional>
+
+#include <gtest/gtest.h>
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class FusionWrapperTest : public HloTestBase {};
+
+TEST_F(FusionWrapperTest, SimpleOp) {
+ RunAndFilecheckHloRewrite(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ p0 = f16[30,41] parameter(0)
+ p1 = f16[30,41] parameter(1)
+ ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
+ })",
+ FusionWrapper(), R"(
+// CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] {
+// CHECK: %param_0 = f16[30,41]{1,0} parameter(0)
+// CHECK: %param_1 = f16[30,41]{1,0} parameter(1)
+// CHECK: ROOT %result.1 = f16[60,41]{1,0} concatenate(%param_0, %param_1), dimensions={0}
+// CHECK: }
+
+// CHECK: ENTRY %TestComputation (p0: f16[30,41], p1: f16[30,41]) -> f16[60,41] {
+// CHECK: %p0 = f16[30,41]{1,0} parameter(0)
+// CHECK: %p1 = f16[30,41]{1,0} parameter(1)
+// CHECK: ROOT %wrapped_concatenate = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%wrapped_concatenate_computation
+// CHECK: })");
+}
+
+TEST_F(FusionWrapperTest, Scatter) {
+ RunAndFilecheckHloRewrite(R"(
+ HloModule ScatterIntoScalar
+
+ update_s32 {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+ }
+
+ ENTRY main {
+ parameter.1 = s32[] parameter(0)
+ parameter.2 = s32[0]{0} parameter(1)
+ parameter.3 = s32[] parameter(2)
+ ROOT scatter_ScatterIntoScalar = s32[] scatter(parameter.1, parameter.2, parameter.3),
+ update_window_dims={},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={},
+ index_vector_dim=0,
+ to_apply=update_s32
+ })",
+ FusionWrapper(), R"(
+// CHECK: wrapped_scatter_computation
+// CHECK: %[[param_0:.*]] = s32[] parameter(0)
+// CHECK: %[[param_1:.*]] = s32[0]{0} parameter(1)
+// CHECK: %[[param_2:.*]] = s32[] parameter(2)
+// CHECK: ROOT %{{.*}} = s32[] scatter(%[[param_0]], %[[param_1]], %[[param_2]])
+
+// CHECK: ENTRY
+// CHECK: %[[p0:.*]] = s32[] parameter(0)
+// CHECK: %[[p1:.*]] = s32[0]{0} parameter(1)
+// CHECK: %[[p2:.*]] = s32[] parameter(2)
+// CHECK: ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%wrapped_scatter_computation
+// CHECK: })");
+}
+
+TEST_F(FusionWrapperTest, ControlDependency) {
+ RunAndFilecheckHloRewrite(R"(
+ HloModule TestModule
+
+ fusion {
+ ROOT param = f32[] parameter(0)
+ }
+
+ ENTRY main {
+ param = f32[] parameter(0)
+ fusion = f32[] fusion(param), kind=kLoop, calls=fusion
+ constant_one = f32[] constant(1)
+ ROOT add = f32[] add(param, constant_one), control-predecessors={fusion}
+ })",
+ FusionWrapper(), R"(
+// CHECK: ROOT %wrapped_add = f32[] fusion(%param.1, %constant_one),
+// CHECK-SAME: control-predecessors={%fusion})");
+}
+
+TEST_F(FusionWrapperTest, While) {
+ RunAndFilecheckHloRewrite(R"(
+ HloModule While
+
+ %body {
+ %parameter.5 = (f32[5]{0}) parameter(0)
+ %constant_8 = f32[] constant(0)
+ %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
+ ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
+ }
+
+ %cond {
+ %parameter.12 = (f32[5]{0}) parameter(0)
+ ROOT %constant_1 = pred[] constant(false)
+ }
+
+ ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
+ %parameter.1 = f32[5]{0} parameter(0)
+ %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
+ %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
+ ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
+ })",
+ FusionWrapper(), R"(
+// CHECK: %wrapped_broadcast_computation {{.*}} {
+// CHECK: %param_0.1 = f32[] parameter(0)
+// CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={}
+// CHECK: }
+// CHECK: %body {{.*}} {
+// CHECK: %parameter.5 = (f32[5]{0}) parameter(0)
+// CHECK: %constant_8 = f32[] constant(0)
+// CHECK: %wrapped_broadcast = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%wrapped_broadcast_computation
+// CHECK: ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast)
+// CHECK: }
+// CHECK: %cond {{.*}} {
+// CHECK: %parameter.12 = (f32[5]{0}) parameter(0)
+// CHECK: ROOT %constant_1 = pred[] constant(false)
+// CHECK: }
+// CHECK: %wrapped_copy_computation {{.*}} {
+// CHECK: %param_0 = f32[5]{0} parameter(0)
+// CHECK: ROOT %copy.0 = f32[5]{0} copy(%param_0)
+// CHECK: }
+// CHECK: ENTRY %main {{.*}} {
+// CHECK: %parameter.1 = f32[5]{0} parameter(0)
+// CHECK: %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation
+// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy)
+// CHECK: ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body
+// CHECK: })");
+}
+
+TEST_F(FusionWrapperTest, WhileInFusion) {
+ RunAndFilecheckHloRewrite(R"(
+ HloModule While
+
+ %body {
+ %parameter.5 = (f32[5]{0}) parameter(0)
+ %constant_8 = f32[] constant(0)
+ %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
+ ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
+ }
+
+ %cond {
+ %parameter.12 = (f32[5]{0}) parameter(0)
+ ROOT %constant_1 = pred[] constant(false)
+ }
+
+ %fusion {
+ %parameter.1 = f32[5]{0} parameter(0)
+ %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
+ %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
+ ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
+ }
+
+ ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
+ %parameter.1 = f32[5]{0} parameter(0)
+ ROOT %fusion = (f32[5]{0}) fusion(f32[5]{0} %parameter.1), kind=kLoop, calls=%fusion
+ })",
+ FusionWrapper(),
+ // No change
+ std::nullopt);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc
new file mode 100644
index 0000000..3395ed8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc
@@ -0,0 +1,124 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
+
+#include <cstdint>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace m = match;
+
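+// Rewrites cuBLAS GEMM custom calls whose LHS or RHS operand is a broadcast
+// over all batch dimensions, replacing the broadcast operand with its input
+// and clearing the corresponding batch dimensions.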
+class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleCustomCall(HloInstruction *instr) override {
+ HloInstruction *existing_gemm;
+ HloInstruction *bcast;
+ if (Match(instr, m::CustomCall(&existing_gemm,
+ {kGemmCallTarget, kCublasLtMatmulCallTarget})
+ .WithOperand(0, m::Broadcast(&bcast, m::Op()))) ||
+ (Match(instr, m::CustomCall(&existing_gemm, {kGemmCallTarget,
+ kCublasLtMatmulCallTarget})
+ .WithOperand(1, m::Broadcast(&bcast, m::Op()))))) {
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ existing_gemm->backend_config<GpuBackendConfig>());
+ GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+ DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers();
+ int bcast_operand_index = instr->operand_index(bcast);
+ int num_bcast_dims = (bcast->shape().dimensions_size() -
+ bcast->operand(0)->shape().dimensions_size());
+ int num_batch_dims = dim_nums->lhs_batch_dimensions_size();
+
+ const tsl::protobuf::RepeatedField<int64_t> &batch_dimensions =
+ (bcast_operand_index == 1) ? dim_nums->rhs_batch_dimensions()
+ : dim_nums->lhs_batch_dimensions();
+ // This optimization is only valid if the set of broadcasted dimensions
+ // is exactly the set of batch dimensions. First, check that all newly
+      // broadcast dimensions have been inserted on the left, i.e. all new
+      // dimensions must be in [0, num_bcast_dims) or, equivalently, all
+      // original dimensions are >= num_bcast_dims.
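+      // For example, broadcasting f32[2,2] to f32[3,2,2] with dimensions={1,2}
+      // inserts one new dimension on the left (num_bcast_dims = 1); the
+      // original dimensions map to 1 and 2, both >= 1 and neither a batch
+      // dimension, so the loop below does not bail out.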
+ for (int64_t bcast_dim : bcast->dimensions()) {
+ if (bcast_dim < num_bcast_dims) {
+ return absl::OkStatus();
+ }
+ // bcast_dim should not be in batch_dimensions.
+ if (absl::c_linear_search(batch_dimensions, bcast_dim)) {
+ return absl::OkStatus();
+ }
+ }
+
+ // Then check that all batch dimensions are being broadcast, and that
+ // there is at least one batch dimension.
+ CHECK_GT(num_bcast_dims, 0);
+ if (num_bcast_dims != num_batch_dims) {
+ return absl::OkStatus();
+ }
+
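+      // Fold the broadcast: clear the batch dimensions of the formerly
+      // broadcast operand and shift its contracting dimension left by the
+      // number of removed batch dimensions.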
+ if (bcast_operand_index == 1) {
+ CHECK_EQ(dim_nums->rhs_contracting_dimensions_size(), 1);
+ dim_nums->set_rhs_contracting_dimensions(
+ 0, dim_nums->rhs_contracting_dimensions(0) - num_batch_dims);
+ dim_nums->clear_rhs_batch_dimensions();
+ } else {
+ CHECK_EQ(dim_nums->lhs_contracting_dimensions_size(), 1);
+ dim_nums->set_lhs_contracting_dimensions(
+ 0, dim_nums->lhs_contracting_dimensions(0) - num_batch_dims);
+ dim_nums->clear_lhs_batch_dimensions();
+ }
+ TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape(
+ bcast_operand_index, bcast->mutable_operand(0)));
+ TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
+ MarkAsChanged();
+ }
+ return absl::OkStatus();
+ }
+};
+
+static absl::StatusOr<bool> RunOnComputation(HloComputation *computation) {
+ GemmBroadcastFoldingVisitor visitor;
+ TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+ return visitor.changed();
+}
+
+absl::StatusOr<bool> GemmBroadcastFoldingRewriter::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ bool changed = false;
+ for (HloComputation *computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h
new file mode 100644
index 0000000..8606136
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h
@@ -0,0 +1,51 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// cuBLAS GEMM has support for strided batched calls, where the stride is used
+// to determine the offset between the batches.
+//
+// This allows (kCustomCall:gemm A kBroadcast(B)) or
+// (kCustomCall:gemm kBroadcast(A) B)
+// to be rewritten as (kCustomCall:gemm A B) with a zero stride for the
+// broadcasted operand if the broadcast operates on all the batch dimensions.
+//
+// This pass pattern-matches the above case and removes the unnecessary
+// broadcast.
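+//
+// Illustrative sketch: given
+//   %b = f32[3,2,2] broadcast(f32[2,2] %y), dimensions={1,2}
+//   %gemm = custom-call(%x, %b) with batch dimensions {0},
+// the broadcast is dropped, %y is passed to the custom call directly, and the
+// batch dimensions are cleared from its dot dimension numbers (see the tests).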
+class GemmBroadcastFoldingRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "cublas-gemm-broadcast-folding-rewriter";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc
new file mode 100644
index 0000000..57e68fc
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc
@@ -0,0 +1,230 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
+
+#include <memory>
+
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest {
+ protected:
+ const auto& GpuComputeComp() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .gpu_compute_capability();
+ }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+    // These tests exercise the cuBLAS rewriter, so make sure that cuBLAS is
+    // used for them.
+ debug_options.set_xla_gpu_enable_triton_gemm(false);
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+ return debug_options;
+ }
+};
+
+TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteRhs) {
+ const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+ x = f32[3,2,2]{2,1,0} parameter(0)
+ y = f32[2,2]{1,0} parameter(1)
+ y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
+ ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,2], {{.*}}: f32[2,2]) -> f32[3,2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["2"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":["0"]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteLhs) {
+ const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+ x = f32[2,2]{1,0} parameter(0)
+ y = f32[3,2,2]{2,1,0} parameter(1)
+ x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
+ ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[3,2,2]) -> f32[3,2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest,
+ BroadcastedStridedRewriteRhsPassChanged) {
+ const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+ x = f32[3,2,2]{2,1,0} parameter(0)
+ y = f32[2,2]{1,0} parameter(1)
+ y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
+ ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ // Use GemmRewriter to generate cublasGemm call.
+ GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ this->RunHloPass(&gemm_rewriter, module.get()));
+ EXPECT_TRUE(changed);
+ GemmBroadcastFoldingRewriter pass;
+ TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest,
+ BroadcastedStridedRewriteLhsPassChanged) {
+ const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+ x = f32[2,2]{1,0} parameter(0)
+ y = f32[3,2,2]{2,1,0} parameter(1)
+ x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
+ ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ // Use GemmRewriter to generate cublasGemm call.
+ GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ this->RunHloPass(&gemm_rewriter, module.get()));
+ EXPECT_TRUE(changed);
+ GemmBroadcastFoldingRewriter pass;
+ TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest, LHSBatchDimNonZero) {
+ const char* hlo_text = R"(
+HloModule LHSBatchDimNonZero
+
+ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
+ %Arg_1 = f32[4,3]{1,0} parameter(0)
+ %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
+ %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
+ ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[7,4,3]{2,1,0} %broadcast.22, f32[4,7,3]{2,1,0} %Arg_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ // Use GemmRewriter to generate cublasGemm call.
+ GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ this->RunHloPass(&gemm_rewriter, module.get()));
+ EXPECT_TRUE(changed);
+ GemmBroadcastFoldingRewriter pass;
+ TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest, RHSBatchDimNonZero) {
+ const char* hlo_text = R"(
+HloModule RHSBatchDimNonZero
+
+ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
+ %Arg_1 = f32[4,3]{1,0} parameter(0)
+ %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
+ %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
+ ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[4,7,3]{2,1,0} %Arg_2, f32[7,4,3]{2,1,0} %broadcast.22), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2}
+}
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ this->RunHloPass(&gemm_rewriter, module.get()));
+ EXPECT_TRUE(changed);
+ GemmBroadcastFoldingRewriter pass;
+ TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_FALSE(changed);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc
new file mode 100644
index 0000000..0801830
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc
@@ -0,0 +1,815 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <queue>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_padding_requirements.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/triton_fusion_analysis.h"
+#include "xla/service/gpu/triton_tiling_propagation.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+using triton_fusion::CombineDotRequirements;
+using triton_fusion::DimensionOrder;
+using triton_fusion::DimOrderMap;
+using triton_fusion::DimOrdersAndReqs;
+using triton_fusion::DimOrdersAndReqsOrError;
+using triton_fusion::DotProperties;
+using triton_fusion::DotRequirements;
+using triton_fusion::DotRequirementsOrError;
+using triton_fusion::FusionContext;
+using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible;
+using triton_fusion::TransformDirection;
+
+// This represents a directed graph.
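+// Node ids are dense integers handed out in insertion order; arcs point from
+// users to operands (see FusionPlan below).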
+class AdjacencyList {
+ public:
+ using NodeId = int64_t;
+
+ NodeId AddNode() {
+ adj_.emplace_back();
+ return adj_.size() - 1;
+ }
+
+ const std::vector<NodeId>& GetOutNeighbors(NodeId node_id) const {
+ return adj_.at(node_id);
+ }
+
+ void ReserveSpaceForOutNeighbors(NodeId node_id, size_t count) {
+ adj_.at(node_id).reserve(count);
+ }
+
+ void AddArc(NodeId from, NodeId to) { adj_.at(from).push_back(to); }
+
+ // Currently the Root node is the node which was added first.
+ NodeId GetRoot() const {
+ CHECK(!adj_.empty());
+ return 0;
+ }
+
+ private:
+ // Adjacency list: A vector of out-neighbors for each node.
+ std::vector<std::vector<NodeId>> adj_;
+};
+
+struct HloAndDimOrder {
+ const HloInstruction* original_hlo = nullptr;
+ DimensionOrder dim_order;
+};
+
+struct HloAndIterSpec {
+ const HloInstruction* original_hlo;
+ TensorIterationSpec iter_spec;
+
+ auto ToTuple() const { return std::make_tuple(original_hlo, iter_spec); }
+ bool operator==(const HloAndIterSpec& other) const {
+ return ToTuple() == other.ToTuple();
+ }
+ template <typename H>
+ friend H AbslHashValue(H h, const HloAndIterSpec& key) {
+ return H::combine(std::move(h), key.ToTuple());
+ }
+};
+
+struct NodeFusionPlan {
+ const HloInstruction* original_hlo = nullptr;
+ bool should_fuse = false;
+};
+
+struct FusionPlan {
+ // The graph describing the structure of the fusion that we build - nodes
+ // corresponding to the instructions and arcs pointing from users to operands.
+ AdjacencyList graph;
+ // The fusion plan for each node.
+ absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> map;
+};
+
+struct FusionPlanAndRequirements {
+ FusionPlan fusion_plan;
+ DotRequirements requirements;
+};
+
+struct HlosAndRequirements {
+ // The original HLO (which is outside the fusion computation).
+ const HloInstruction* original_hlo = nullptr;
+ // The fused HLO inside the new fusion computation, built by the builder.
+ //
+ // This can have the same opcode as `original_hlo` or it can be a parameter if
+ // the original HLO can't be fused.
+ const HloInstruction* fused_hlo = nullptr;
+ // The requirements imposed by the fused operations.
+ //
+ // If we fuse further operations they may have to conform to these
+ // requirements.
+ DotRequirements requirements;
+};
+
+// Clones the hero kDot operation into the fusion.
+HloInstruction& FuseDot(const HloDotInstruction& dot,
+ const HloInstruction& fused_lhs,
+ const HloInstruction& fused_rhs,
+ std::optional<const HloInstruction*> fused_meta,
+ HloComputation::Builder& builder // append
+) {
+ VLOG(3) << "Fusing " << dot.ToString();
+
+ std::vector<HloInstruction*> hlo_new_operands = {
+ const_cast<HloInstruction*>(&fused_lhs),
+ const_cast<HloInstruction*>(&fused_rhs)};
+ if (fused_meta.has_value()) {
+ hlo_new_operands.push_back(const_cast<HloInstruction*>(fused_meta.value()));
+ }
+ return *builder.AddInstruction(
+ dot.CloneWithNewOperands(dot.shape(), hlo_new_operands));
+}
+
+// Tells how many new parameters a fusion gains by fusing the operation as an
+// input.
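+// For example, fusing a binary add as an input adds its two operands as inputs
+// and removes its single output, for a net gain of one parameter.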
+int64_t NumAddedParameters(const HloInstruction& hlo) {
+ // Non-scalar constant is equivalent to a parameter: one input, one output.
+ if (hlo.opcode() == HloOpcode::kParameter ||
+ (hlo.opcode() == HloOpcode::kConstant &&
+ !ShapeUtil::IsScalar(hlo.shape()))) {
+ return 0;
+ }
+  // All other instructions add all of their own inputs and remove their own
+  // single output.
+ return hlo.operand_count() - 1;
+}
+
+// Just a helper to reduce "unwrapping" code where we use this.
+std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqs(
+ const HloInstruction& hlo, const DimensionOrder& dim_order,
+ const DotProperties& properties,
+ const se::GpuComputeCapability& gpu_version,
+ const DotRequirements& requirements) {
+ DimOrdersAndReqsOrError dim_orders_and_new_reqs =
+ GetPropagatedDimOrdersAndRequirements(
+ hlo, dim_order, TransformDirection::kOutputToInput, properties);
+ if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
+ return std::nullopt;
+ }
+ DotRequirementsOrError combined_reqs = CombineDotRequirements(
+ requirements,
+ std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
+ if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
+ return std::nullopt;
+ }
+ return DimOrdersAndReqs{
+ std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
+ std::get<DotRequirements>(combined_reqs)};
+}
+
+// Like GetOperandDimOrdersAndCombinedReqs, but additionally returns
+// std::nullopt when the propagation is not considered profitable to fuse.
+std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqsIfProfitable(
+ const HloInstruction& hlo, const DimensionOrder& dim_order,
+ const DotProperties& properties,
+ const se::GpuComputeCapability& gpu_version,
+ const DotRequirements& requirements) {
+ DimOrdersAndReqsOrError dim_orders_and_new_reqs =
+ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
+ hlo, TransformDirection::kOutputToInput,
+ /*src_operand_index=*/std::nullopt, dim_order, gpu_version,
+ properties);
+ if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
+ return std::nullopt;
+ }
+ DotRequirementsOrError combined_reqs = CombineDotRequirements(
+ requirements,
+ std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
+ if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
+ return std::nullopt;
+ }
+ return DimOrdersAndReqs{
+ std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
+ std::get<DotRequirements>(combined_reqs)};
+}
+
+// Helper that propagates `hlo_dim_order` from `hlo` to its user `user` (only
+// if considered profitable) and merges the resulting requirements; returns
+// std::nullopt on failure or conflict.
+std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
+ const HloInstruction& hlo, const DimensionOrder& hlo_dim_order,
+ const HloInstruction& user, const DotProperties& properties,
+ const se::GpuComputeCapability& gpu_version,
+ const DotRequirements& requirements) {
+ DimOrdersAndReqsOrError dim_orders_and_new_reqs =
+ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
+ user, TransformDirection::kInputToOutput, user.operand_index(&hlo),
+ hlo_dim_order, gpu_version, properties);
+ if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
+ return std::nullopt;
+ }
+ DotRequirementsOrError combined_reqs = CombineDotRequirements(
+ requirements,
+ std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
+ if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
+ return std::nullopt;
+ }
+ return DimOrdersAndReqs{
+ std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
+ std::get<DotRequirements>(combined_reqs)};
+}
+
+// Builds the fusion map and the requirements which can later be used to
+// actually fuse that subgraph.
+FusionPlanAndRequirements BuildFusionPlanTowardOperands(
+ const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
+ const std::optional<int>& max_params,
+ const se::GpuComputeCapability& gpu_version,
+ const DotProperties& properties,
+ const DotRequirements& requirements_so_far) {
+ CHECK(!max_params.has_value() || max_params.value() >= 1);
+
+ // The graph describing the structure of the fusion that we build - nodes
+ // corresponding to the instructions and arcs pointing from users to operands.
+ // We can build and modify this graph easily without the need to create
+ // HloInstructions at this point.
+ AdjacencyList graph;
+ // Stores the original HLO and the dimension order for each node. This is a
+ // temporary map which is used when processing the nodes in this function.
+ absl::flat_hash_map<AdjacencyList::NodeId, HloAndDimOrder>
+ hlo_and_dim_order_map;
+ // Stores the information needed to build the fused HLO for each node (what
+ // was the original HLO and whether we should fuse it or create a parameter).
+ // This is one of the outputs of this function.
+ absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> fusion_plan_map;
+ // Allows reusing nodes when multiple instructions iterate over the same HLO
+ // using the same iteration spec. In that case we don't duplicate the
+ // instruction in the fusion.
+ absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map;
+ // The requirements imposed by the fusion choices made in this function,
+ // combined with the existing requirements. This is one of the outputs of this
+ // function.
+ DotRequirements combined_reqs = requirements_so_far;
+
+ auto get_or_create_fusion_node =
+ [&](const HloInstruction& hlo, const DimensionOrder& dim_order,
+ bool* is_new_node = nullptr) -> AdjacencyList::NodeId {
+ HloAndIterSpec reuse_key = {&hlo, dim_order.ToTensorIterationSpec()};
+ if (auto it = node_reuse_map.find(reuse_key); it != node_reuse_map.end()) {
+ if (is_new_node != nullptr) {
+ *is_new_node = false;
+ }
+ return it->second;
+ }
+ AdjacencyList::NodeId node_id = graph.AddNode();
+ CHECK(hlo_and_dim_order_map.insert({node_id, {&hlo, dim_order}}).second);
+ CHECK(node_reuse_map.insert({reuse_key, node_id}).second);
+ if (is_new_node != nullptr) {
+ *is_new_node = true;
+ }
+ return node_id;
+ };
+ AdjacencyList::NodeId root =
+ get_or_create_fusion_node(root_hlo, root_dim_order);
+
+ // Nodes at the fusion edge that can either get fused too or become parameters
+ // of the fusion. Used to track the number of parameters.
+ absl::flat_hash_set<AdjacencyList::NodeId> inputs({root});
+ std::queue<AdjacencyList::NodeId> queue({root});
+ int64_t num_requeued = 0;
+ // BFS
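+  // The loop terminates once every node still in the queue has been re-queued
+  // without progress in between (queue.size() == num_requeued); those nodes
+  // are turned into parameters in the loop below.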
+ while (queue.size() > num_requeued) {
+ AdjacencyList::NodeId node_id = queue.front();
+ queue.pop();
+ const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
+ const HloInstruction& original_hlo = *hlo_and_dim_order.original_hlo;
+ const DimensionOrder& dim_order = hlo_and_dim_order.dim_order;
+
+ // Watch the total number of fusion parameters.
+ if (max_params.has_value() &&
+ inputs.size() + NumAddedParameters(original_hlo) > max_params.value()) {
+ // Re-queue: the number of parameters may go down when other instructions
+ // are processed.
+ queue.push(node_id);
+ // Prevent infinite loops.
+ ++num_requeued;
+ continue;
+ }
+ num_requeued = 0;
+ if (original_hlo.opcode() == HloOpcode::kParameter) {
+ CHECK(fusion_plan_map
+ .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
+ .second);
+ continue;
+ }
+ auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(
+ original_hlo, dim_order, properties, gpu_version, combined_reqs);
+ if (!opt_result.has_value()) {
+ CHECK(fusion_plan_map
+ .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
+ .second);
+ continue;
+ }
+ const DimOrderMap operand_dim_orders = std::move(opt_result->dim_orders);
+ combined_reqs = std::move(opt_result->requirements);
+ inputs.erase(node_id);
+ graph.ReserveSpaceForOutNeighbors(node_id, original_hlo.operand_count());
+ for (int64_t i = 0; i < original_hlo.operand_count(); ++i) {
+ const HloInstruction& operand = *original_hlo.operand(i);
+ const DimensionOrder& operand_dim_order = operand_dim_orders.at(&operand);
+ bool is_new_node = false;
+ AdjacencyList::NodeId operand_node_id =
+ get_or_create_fusion_node(operand, operand_dim_order, &is_new_node);
+ graph.AddArc(node_id, operand_node_id);
+ if (is_new_node) {
+ VLOG(6) << "Enqueueing " << operand.ToString() << ":"
+ << operand_dim_order.ToString();
+ inputs.insert(operand_node_id);
+ queue.push(operand_node_id);
+ }
+ }
+ CHECK(
+ fusion_plan_map.insert({node_id, {&original_hlo, /*should_fuse=*/true}})
+ .second);
+ }
+ // Handle the remaining requeued items.
+ while (!queue.empty()) {
+ AdjacencyList::NodeId node_id = queue.front();
+ queue.pop();
+
+ const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
+ CHECK(fusion_plan_map
+ .insert({node_id,
+ {hlo_and_dim_order.original_hlo, /*should_fuse=*/false}})
+ .second);
+ }
+ return {{std::move(graph), std::move(fusion_plan_map)},
+ std::move(combined_reqs)};
+}
+
+// Builds the HLO instructions for the fusion represented by `fusion_plan`,
+// starting from `node_id`.
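+// Operands are materialized before the nodes that use them; `fused_hlo_map`
+// memoizes already-built nodes so shared subgraphs are emitted only once.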
+HloInstruction& BuildFusionTowardOperandsImpl(
+ AdjacencyList::NodeId node_id, const FusionPlan& fusion_plan,
+ absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*>&
+ fused_hlo_map, // read/append
+ HloComputation::Builder& builder, // append
+ std::vector<HloInstruction*>& fusion_params // append
+) {
+ if (auto it = fused_hlo_map.find(node_id); it != fused_hlo_map.end()) {
+ return *it->second;
+ }
+
+ const NodeFusionPlan& node_fusion_plan = fusion_plan.map.at(node_id);
+ const bool should_fuse = node_fusion_plan.should_fuse;
+ const HloInstruction& original_hlo = *node_fusion_plan.original_hlo;
+
+ HloInstruction* fused_hlo = nullptr;
+ if (should_fuse) {
+ HloInstruction::InstructionVector new_operands;
+ for (AdjacencyList::NodeId operand_id :
+ fusion_plan.graph.GetOutNeighbors(node_id)) {
+ new_operands.push_back(&BuildFusionTowardOperandsImpl(
+ operand_id, fusion_plan, fused_hlo_map, builder, fusion_params));
+ }
+ fused_hlo = builder.AddInstruction(
+ original_hlo.CloneWithNewOperands(original_hlo.shape(), new_operands));
+ } else {
+ fusion_params.push_back(const_cast<HloInstruction*>(&original_hlo));
+ fused_hlo = builder.AddInstruction(HloInstruction::CreateParameter(
+ fusion_params.size() - 1, original_hlo.shape(),
+ absl::StrCat("parameter_", fusion_params.size() - 1)));
+ }
+
+ CHECK(fused_hlo_map.insert({node_id, fused_hlo}).second);
+ return *fused_hlo;
+}
+
+// Builds the HLO instructions for the fusion represented by `fusion_plan`.
+HloInstruction& BuildFusionTowardOperands(
+ const FusionPlan& fusion_plan,
+ HloComputation::Builder& builder, // append
+ std::vector<HloInstruction*>& fusion_params // append
+) {
+ absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*> fused_hlo_map;
+ return BuildFusionTowardOperandsImpl(fusion_plan.graph.GetRoot(), fusion_plan,
+ fused_hlo_map, builder, fusion_params);
+}
+
+// Grows the fusion toward the operands.
+//
+// This always succeeds.
+//
+// If it's not possible to fuse something, it fuses a parameter instead.
+//
+// The fusion can grow until it has `max_params` parameters, and it can only
+// grow with operations for which the DimOrder propagation works and which do
+// not impose requirements contradicting the existing ones.
+//
+// The return value contains the HLOs corresponding to `root_hlo` and the
+// requirements corresponding to the whole fusion so far.
+HlosAndRequirements FuseTowardOperands(
+ const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
+ const std::optional<int>& max_params,
+ const se::GpuComputeCapability& gpu_version,
+ const DotProperties& properties, const DotRequirements& requirements_so_far,
+ HloComputation::Builder& builder, // append
+ std::vector<HloInstruction*>& fusion_params // append
+) {
+ FusionPlanAndRequirements fusion_plan_and_reqs =
+ BuildFusionPlanTowardOperands(root_hlo, root_dim_order, max_params,
+ gpu_version, properties,
+ requirements_so_far);
+ HloInstruction& fused_hlo_or_param = BuildFusionTowardOperands(
+ fusion_plan_and_reqs.fusion_plan, builder, fusion_params);
+ return HlosAndRequirements{&root_hlo, &fused_hlo_or_param,
+ fusion_plan_and_reqs.requirements};
+}
+
+// Grows the fusion toward the given dot operand.
+//
+// This always succeeds.
+//
+// If it's not possible to fuse something, it fuses a parameter instead.
+//
+// The fusion can grow until it has `max_params` parameters, and it can only
+// grow with operations for which the DimOrder propagation works and which do
+// not impose requirements contradicting the existing ones.
+//
+// The return value contains the HLOs corresponding to the given dot operand and
+// the requirements corresponding to the whole fusion so far.
+absl::StatusOr<HlosAndRequirements> FuseDotOperand(
+ const HloInstruction& dot, int operand_index,
+ const se::GpuComputeCapability& gpu_version,
+ HloComputation::Builder& builder, // append
+ std::vector<HloInstruction*>& fusion_params // append
+) {
+  // Direct dot inputs have well-defined dimension orders.
+ TF_ASSIGN_OR_RETURN(const FusionContext context,
+ FusionContext::FromDotOperand(dot, operand_index));
+ const HloInstruction& operand = *dot.operand(operand_index);
+ return FuseTowardOperands(operand, context.dim_orders().at(&operand),
+ TritonFusionAnalysis::kMaxParameterPerDotOperand,
+ gpu_version, context.dot_properties(),
+ context.requirements(), builder, fusion_params);
+}
+
+// Grows the fusion toward the users.
+//
+// This always succeeds.
+//
+// The fusion can grow as long as the DimOrder propagation works and the users
+// don't impose requirements contradicting the existing requirements.
+//
+// The return value contains the HLOs corresponding to the "lowest" fused user
+// or `hlo` if no users can be fused.
+//
+// It also grows the fusion upward, toward the "other" operands of the users,
+// but currently only in special cases such as a binary elementwise operation
+// with a broadcast of a scalar constant.
+HlosAndRequirements FuseTowardUsers(
+ const HloInstruction& hlo, const HloInstruction& fused_hlo,
+ const DimensionOrder& hlo_dim_order,
+ const se::GpuComputeCapability& gpu_version,
+ const DotProperties& properties, const DotRequirements& requirements,
+ HloComputation::Builder& builder, // append
+ std::vector<HloInstruction*>& fusion_params // append
+) {
+ const HlosAndRequirements existing_hlos_and_requirements = {&hlo, &fused_hlo,
+ requirements};
+ if (hlo.user_count() != 1) {
+ return existing_hlos_and_requirements;
+ }
+ const HloInstruction& user = *hlo.users()[0];
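+  // Only users that are distributive over addition are fused past the dot,
+  // presumably so that later rewrites that sum partial results (e.g. split-K)
+  // remain correct.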
+ if (!legacy_triton::IsDistributiveOverAddition(user)) {
+ return existing_hlos_and_requirements;
+ }
+
+ // Get the dim orders for the user.
+ auto opt_user_result = GetUserDimOrdersAndCombinedReqsIfProfitable(
+ hlo, hlo_dim_order, user, properties, gpu_version, requirements);
+ if (!opt_user_result.has_value()) {
+ return existing_hlos_and_requirements;
+ }
+ DimensionOrder user_dim_order = opt_user_result->dim_orders.at(&user);
+ DotRequirements combined_requirements = opt_user_result->requirements;
+
+ HloInstruction::InstructionVector new_operands;
+ if (user.operand_count() == 1) {
+ new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
+ } else {
+ // Get the dim orders for the operands of the user.
+    // We shouldn't do a profitability check here; that decision was already
+    // made in GetUserDimOrdersAndCombinedReqsIfProfitable.
+ auto opt_operand_result = GetOperandDimOrdersAndCombinedReqs(
+ user, user_dim_order, properties, gpu_version, combined_requirements);
+ // This shouldn't fail, because currently we only encounter this when we
+ // have just propagated down the DimOrders on a binary elementwise
+ // operation (user). In that case propagating up the DimOrders should always
+ // work.
+ if (!opt_operand_result.has_value()) {
+ return existing_hlos_and_requirements;
+ }
+ DimOrderMap operand_dim_orders = opt_operand_result->dim_orders;
+ combined_requirements = opt_operand_result->requirements;
+
+ // Fuse the other operands of the user.
+ for (int i = 0; i < user.operand_count(); ++i) {
+ const HloInstruction& operand = *user.operand(i);
+ if (&operand == &hlo) {
+ new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
+ } else {
+ HlosAndRequirements hlos_and_requirements = FuseTowardOperands(
+ operand, operand_dim_orders.at(&operand),
+ /*max_params=*/std::nullopt, gpu_version, properties,
+ combined_requirements, builder, fusion_params);
+ new_operands.push_back(
+ const_cast<HloInstruction*>(hlos_and_requirements.fused_hlo));
+ combined_requirements = hlos_and_requirements.requirements;
+ }
+ }
+ }
+
+ const HloInstruction& fused_user = *builder.AddInstruction(
+ user.CloneWithNewOperands(user.shape(), new_operands));
+ return FuseTowardUsers(user, fused_user, user_dim_order, gpu_version,
+ properties, combined_requirements, builder,
+ fusion_params);
+}
+
+// Grows the fusion toward the users of the dot.
+//
+// This always succeeds.
+//
+// The fusion can grow as long as the DimOrder propagation works and the users
+// don't impose requirements contradicting the existing requirements.
+//
+// The return value contains the HLOs corresponding to the "lowest" fused user
+// or `dot` if no users can be fused.
+//
+// It also grows the fusion towards the "other" operands of the users, but
+// currently only in special cases, such as binary elementwise operation with
+// broadcast of scalar constant.
+HlosAndRequirements FuseDotOutput(
+ const HloInstruction& dot, const HloInstruction& fused_dot,
+ const se::GpuComputeCapability& gpu_version,
+ const DotRequirements& requirements,
+ HloComputation::Builder& builder, // append
+ std::vector<HloInstruction*>& fusion_params // append
+) {
+ const auto context =
+ FusionContext::FromDotOutput(dot, /*split_k=*/1, requirements);
+ return FuseTowardUsers(dot, fused_dot, context.dim_orders().at(&dot),
+ gpu_version, context.dot_properties(),
+ context.requirements(), builder, fusion_params);
+}
+
+// Fuses the dot and the operations around it that are compatible and
+// profitable to fuse into a new fusion computation constructed using the
+// builder. `fusion_inputs` gets populated with the non-fused instructions that
+// become operands of the call to this fusion. `fusion_output_ptr` (if not
+// nullptr) gets assigned the original instruction that has to be replaced by
+// the call to the fusion.
+absl::StatusOr<FusionDecision> CreateDotFusion(
+ const HloDotInstruction& dot, const se::GpuComputeCapability gpu_version,
+ HloComputation::Builder& builder,
+ std::vector<HloInstruction*>& fusion_inputs,
+ HloInstruction** fusion_output_ptr) {
+ VLOG(5) << dot.ToString();
+ if (CodegenDecision is_supported =
+ legacy_triton::IsTritonSupportedInstruction(dot, gpu_version);
+ !is_supported) {
+ VLOG(3) << is_supported.Explain();
+ return is_supported;
+ }
+
+ // Verify sparse dot constraints.
+ if (dot.sparse_operands()) {
+ const SparsityDescriptor& descriptor = dot.sparsity().front();
+ if (dot.sparse_operands() != 1 || descriptor.index() != 0) {
+ return InvalidArgument("Sparsity is only supported on left operand");
+ }
+ if (descriptor.type() != SparsityType::SPARSITY_STRUCTURED_N_M ||
+ descriptor.n() != 2 || descriptor.m() != 4) {
+ return InvalidArgument("Only 2:4 structured sparsity is supported");
+ }
+    // The DotDimensionSorter pass makes sure the sparse dimension is minor.
+ CHECK_EQ(descriptor.dimension(), dot.operand(0)->shape().rank() - 1);
+ }
+
+ TF_ASSIGN_OR_RETURN(HlosAndRequirements lhs_hlos_and_reqs,
+ FuseDotOperand(dot, /*operand_index=*/0, gpu_version,
+ builder, fusion_inputs));
+ TF_ASSIGN_OR_RETURN(HlosAndRequirements rhs_hlos_and_reqs,
+ FuseDotOperand(dot, /*operand_index=*/1, gpu_version,
+ builder, fusion_inputs));
+ std::optional<const HloInstruction*> meta_hlo;
+ if (dot.sparse_operands()) {
+ TF_ASSIGN_OR_RETURN(HlosAndRequirements meta_hlos_and_reqs,
+ FuseDotOperand(dot, /*operand_index=*/2, gpu_version,
+ builder, fusion_inputs));
+ meta_hlo.emplace(meta_hlos_and_reqs.fused_hlo);
+ }
+ HloInstruction& fused_dot =
+ FuseDot(dot, *lhs_hlos_and_reqs.fused_hlo, *rhs_hlos_and_reqs.fused_hlo,
+ meta_hlo, builder);
+ // For now the RHS doesn't support splits, so it also doesn't impose any
+ // requirements.
+ HlosAndRequirements fused_output_and_reqs =
+ FuseDotOutput(dot, fused_dot, gpu_version, lhs_hlos_and_reqs.requirements,
+ builder, fusion_inputs);
+
+ if (fusion_output_ptr != nullptr) {
+ *fusion_output_ptr =
+ const_cast<HloInstruction*>(fused_output_and_reqs.original_hlo);
+ }
+
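+  // Dots using the BF16_BF16_F32_X3/X6 algorithms, dots covered by the
+  // xla_gpu_triton_gemm_any flag, and sparse dots are accepted without any
+  // further profitability check.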
+ const PrecisionConfig::Algorithm algorithm =
+ dot.precision_config().algorithm();
+ if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
+ algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
+ dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
+ dot.sparse_operands()) {
+ return FusionDecision{};
+ }
+
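+  // Otherwise require the fusion to contain something besides the dot, its
+  // parameters and bitcasts/reshapes; a plain matmul is reported as having no
+  // profitable operations to fuse.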
+ bool is_pure_matmul = true;
+ (void)builder.ForEachInstruction([&](const HloInstruction* fused_hlo) {
+ static constexpr std::array<HloOpcode, 4> kPureOpcodes = {
+ HloOpcode::kBitcast, HloOpcode::kDot, HloOpcode::kParameter,
+ HloOpcode::kReshape};
+ if (absl::c_find(kPureOpcodes, fused_hlo->opcode()) == kPureOpcodes.end()) {
+ is_pure_matmul = false;
+ // Stop iterating.
+ return absl::CancelledError();
+ }
+ return absl::OkStatus();
+ });
+ if (!is_pure_matmul) {
+ return FusionDecision{};
+ }
+
+ return "No profitable operations to fuse.";
+}
+
+// Extracts parts of the HLO graph that include dot() operations and can target
+// the Triton GEMM emitter into fused computations.
+class GemmFusionVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit GemmFusionVisitor(const se::GpuComputeCapability& gpu_version)
+ : gpu_version_(gpu_version) {}
+  // Checks whether a dot() should target the Triton GEMM emitter; if so, fuses
+  // all its compatible inputs and outputs into a new computation and replaces
+  // the original dot() with a call to that computation.
+ absl::Status HandleDot(HloInstruction* dot) override {
+ CHECK_EQ(dot->opcode(), HloOpcode::kDot);
+
+ int64_t gemm_rewrite_size_threshold =
+ dot->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_gemm_rewrite_size_threshold();
+ TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
+ IsMatrixMultiplicationTooSmallForRewriting(
+ *dot, gemm_rewrite_size_threshold));
+ if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*dot)) {
+ return absl::OkStatus();
+ }
+
+ std::string fusion_name = absl::StrCat("gemm_fusion_", dot->name());
+ HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation"));
+ std::vector<HloInstruction*> fusion_inputs;
+ HloInstruction* fusion_output = nullptr;
+ TF_ASSIGN_OR_RETURN(
+ const FusionDecision should_fuse,
+ CreateDotFusion(*Cast<HloDotInstruction>(dot), gpu_version_, builder,
+ fusion_inputs, &fusion_output));
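+    // CreateDotFusion may return before adding any instruction (e.g. for an
+    // unsupported dot); in that case there is nothing to fuse.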
+ if (builder.last_added_instruction() == nullptr) {
+ return absl::OkStatus();
+ }
+    // If a GEMM requiring padding for cuBLAS is encountered here, it is
+    // because ShouldTritonHandleGEMM() accepted it earlier and padding was
+    // skipped. Accept it here as well, ignoring profitability checks.
+ // TODO(rocm): check ROCM padding requirements.
+ if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+ if (!CublasRequiresPadding(
+ *Cast<HloDotInstruction>(dot),
+ std::get<se::CudaComputeCapability>(gpu_version_)) &&
+ !should_fuse) {
+ return absl::OkStatus();
+ }
+ }
+
+ HloComputation* computation =
+ dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
+ /*is_entry=*/false);
+ HloInstruction* dot_fusion =
+ dot->parent()->AddInstruction(HloInstruction::CreateFusion(
+ computation->root_instruction()->shape(),
+ HloInstruction::FusionKind::kCustom, fusion_inputs, computation));
+    // Copy the metadata of the `dot` to the newly created `fusion` op. This
+    // simplifies metadata handling in the subsequent split-k rewriting.
+ dot_fusion->set_metadata(dot->metadata());
+ dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name);
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ dot_fusion->backend_config<GpuBackendConfig>());
+ FusionBackendConfig& backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+ backend_config.set_kind(std::string(kTritonGemmFusionKind));
+ TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(gpu_config));
+
+ if (fusion_output->IsRoot()) {
+ fusion_output->parent()->set_root_instruction(dot_fusion);
+ TF_RETURN_IF_ERROR(
+ fusion_output->parent()->RemoveInstructionAndUnusedOperands(
+ fusion_output));
+ MarkAsChanged();
+ } else {
+ TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion));
+ }
+ XLA_VLOG_LINES(5, computation->ToString(HloPrintOptions::ShortParsable()));
+ return absl::OkStatus();
+ }
+
+ private:
+ se::GpuComputeCapability gpu_version_;
+};
+
+absl::StatusOr<bool> RunOnComputation(
+ HloComputation* computation, const se::GpuComputeCapability& gpu_version) {
+ GemmFusionVisitor visitor(gpu_version);
+ TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+ return visitor.changed();
+}
+
+} // namespace
+
+bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
+ const se::GpuComputeCapability& gpu_version) {
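+  // Build the candidate fusion into a throwaway computation just to get the
+  // fusion decision; the "disposable" builder is never added to a module.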
+ std::vector<HloInstruction*> fusion_inputs;
+ HloComputation::Builder builder("disposable");
+ return CreateDotFusion(dot, gpu_version, builder, fusion_inputs,
+ /*fusion_output_ptr=*/nullptr)
+ ->CanFuse();
+}
+
+absl::StatusOr<bool> GemmFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ TF_RETURN_IF_ERROR(
+ EnsureTritonSupportsComputeCapability(compute_capability_));
+
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result,
+ RunOnComputation(computation, compute_capability_));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h
new file mode 100644
index 0000000..7f8fe6f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h
@@ -0,0 +1,57 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
+
+// This file contains the code for fusing dots and other operations into Triton
+// GEMM fusions.
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Filters GEMMs that are better handled by Triton.
+bool ShouldTritonHandleGEMM(HloDotInstruction&,
+ const se::GpuComputeCapability&);
+
+// Rewrites compatible dot() operations into custom fusions with computations
+// that target the Triton-based matmul emitter.
+class GemmFusion : public HloModulePass {
+ public:
+ explicit GemmFusion(const se::GpuComputeCapability& compute_capability)
+ : compute_capability_(compute_capability) {}
+ absl::string_view name() const override { return "triton-gemm-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::GpuComputeCapability compute_capability_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc
new file mode 100644
index 0000000..85ad2e8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc
@@ -0,0 +1,1334 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/cublas_padding_requirements.h"
+#include "xla/service/gpu/triton_fusion_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::FieldsAre;
+
+namespace m = ::xla::match;
+
+class GemmFusionTest : public HloTestBase {
+ public:
+ GemmFusionTest()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {}
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_triton_gemm_any(false);
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+ return debug_options;
+ }
+
+ se::GpuComputeCapability gpu_version_{
+ se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}};
+
+ void MatchHloModule(HloModule& module, absl::string_view pattern) {
+ TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result,
+ RunFileCheck(module.ToString(), pattern));
+ EXPECT_TRUE(filecheck_result);
+ }
+};
+
+TEST_F(GemmFusionTest, TransposeSubdimensionGroup) {
+ // This HLO is artificial because unnecessary reshapes get optimized
+ // out during compilation. It tests the ability of GemmFusion
+ // to handle transposes of groups of subdimensions.
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+ p0 = f32[32,3] parameter(0)
+ t1 = f32[3,32] transpose(p0), dimensions={1,0}
+ r1 = f32[3,8,4] reshape(t1)
+ r0 = f32[3,32] reshape(r1)
+ p1 = f16[32,7] parameter(1)
+ c1 = f32[32,7] convert(p1)
+ ROOT d = f32[3,7] dot(r0, c1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+ .value();
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, UnsupportedTransposeIsNotFused) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f16[1,512,8,1024]{3,1,0,2} parameter(0)
+ c = f16[1,512,8,1024]{3,2,1,0} copy(p0)
+ b = f16[4096,1024]{1,0} bitcast(c)
+ p1 = f16[128,1024]{1,0} parameter(1)
+ ROOT d = f16[4096,128]{1,0} dot(b, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})")
+ .value();
+ EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, BitcastChain) {
+ // This HLO is artificial because unnecessary reshapes get optimized
+ // out during compilation. It tests the ability of GemmFusion
+ // to handle various kinds of bitcasts.
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+ p0 = s8[60,5] parameter(0)
+ r0 = s8[3,20,5] reshape(p0)
+ c0 = f16[3,20,5] convert(r0)
+ p1 = f16[3,200] parameter(1)
+ r12 = f16[600] reshape(p1)
+ r11 = f16[30,20] reshape(r12)
+ r1 = f16[3,10,20] reshape(r11)
+ ROOT d = f16[3,5,10] dot(c0, r1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={2},
+ lhs_batch_dims={0}, rhs_batch_dims={0}
+})")
+ .value();
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, SplitDimensionTwice) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = s8[4,2,32,4,2] parameter(0)
+ r1 = s8[8,32,8] reshape(p0)
+ t1 = s8[32,8,8] transpose(r1), dimensions={1,0,2}
+ r0 = s8[32,64] reshape(t1)
+ p1 = s8[32,32] parameter(1)
+ c0 = f16[32,32] convert(p1)
+ ROOT d = f16[64,32] dot(r0, c0),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})")
+ .value();
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, DoNotTriggerOnUnsupportedOutputConversions) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f16[128,256] parameter(0)
+ p1 = f16[256,512] parameter(1)
+ r = f16[128,512] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT c = u8[128,512] convert(r)
+})"));
+ EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, FuseDotWithTrivialNoncontractingDim) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+ p0 = s8[60,5] parameter(0)
+ r0 = s8[3,20,5] reshape(p0)
+ c0 = f16[3,20,5] convert(r0)
+ p1 = f16[3,1,20] parameter(1)
+ ROOT d = f16[3,5,1] dot(c0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={2},
+ lhs_batch_dims={0}, rhs_batch_dims={0}
+})")
+ .value();
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, HandleDotIfCublasRequiresPadding) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[5,3] parameter(0)
+ p1 = f16[5,7] parameter(1)
+ ROOT d = f16[3,7] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_TRUE(CublasRequiresPadding(
+ *xla::Cast<HloDotInstruction>(
+ module->entry_computation()->root_instruction()),
+ cc));
+ EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, FuseSliceOfParameterWithOtherUsers) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[97,121] parameter(0)
+ s0 = f32[7,101] slice(p0), slice={[3:10], [10:111]}
+ p1 = f32[101,16] parameter(1)
+ d = f32[16,7] dot(p1, s0),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+ s1 = f32[3,33] slice(p0), slice={[10:13], [20:53]}
+ ROOT t = tuple(d, s1)
+})"));
+
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, DoNotFuseSliceOfMixedDimensions) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = bf16[768,64] parameter(0)
+ s0 = bf16[768,32] slice(p0), slice={[0:768], [0:32]}
+ b0 = bf16[256,3,32] reshape(s0)
+ b1 = bf16[256,96] reshape(b0)
+ p1 = bf16[256,96] parameter(1)
+ ROOT d = bf16[96,96] dot(b1, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, DoNotFuseSlicesOfNonMajorFragments) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[2,2,256,256] parameter(0)
+ s0 = f32[1,1,256,256] slice(p0),
+ slice={[0:1], [0:1], [0:256], [0:256]}
+ r0 = f32[256,256] reshape(s0)
+ p1 = f16[2,2,256,256] parameter(1)
+ s1 = f16[1,1,256,256] slice(p1),
+ slice={[0:1], [0:1], [0:256], [0:256]}
+ r1 = f16[256,256] reshape(s1)
+ ROOT d = f32[256,256] dot(r0, r1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, DynamicSliceIsFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ dot_lhs = f32[2,18] parameter(0)
+ dynamic_slice_input = f32[2,64,2] parameter(1)
+ start_index0 = s32[] parameter(2)
+ start_index1_2 = s32[] constant(0)
+ dynamic_slice = f32[1,64,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1_2, start_index1_2),
+ dynamic_slice_sizes={1,64,2}
+ reshape = f32[64,2] reshape(dynamic_slice)
+ ROOT dot = f16[18,64] dot(dot_lhs, reshape),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+ m::Parameter(), m::Constant()))));
+}
+
+TEST_F(GemmFusionTest, DynamicSlicesAreFusedEvenIfTheyShareIndices) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[2,64,2] parameter(0)
+ p1 = s32[] parameter(1)
+ p2 = s32[] parameter(2)
+ p3 = s32[] parameter(3)
+ ds0 = f32[1,64,2] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,64,2}
+ a = f32[64,2] reshape(ds0)
+ ds1 = f32[1,64,2] dynamic-slice(p0, p3, p2, p1), dynamic_slice_sizes={1,64,2}
+ b = f32[64,2] reshape(ds1)
+ ROOT d = f16[64,64] dot(a, b),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ // TODO(b/339810582): Don't duplicate scalar parameters to dot fusions,
+ // because they are never tiled differently.
+ // TODO(b/339814210): Don't count scalar parameters towards dot fusion
+ // parameter limit.
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
+ m::Parameter(), m::Parameter(), m::Parameter(),
+ m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionTest, DoNotFuseDynamicSliceOfNonMajorFragments) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ dot_lhs = f32[2,4]{1,0} parameter(0)
+ dynamic_slice_input = f32[4,5,2]{2,1,0} parameter(1)
+ c0 = s32[] constant(0)
+ c2 = s32[] constant(2)
+ dynamic_slice = f32[4,1,2]{2,1,0} dynamic-slice(dynamic_slice_input, c0, c2, c0),
+ dynamic_slice_sizes={4,1,2}
+ reshape = f32[4,2]{1,0} reshape(dynamic_slice)
+ ROOT dot = f32[4,4]{1,0} dot(dot_lhs, reshape),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ // FusionDecision "Unsupported dynamic slice on non-major-most dimension."
+ EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ dot_lhs = f32[2,4]{1,0} parameter(0)
+ dynamic_slice_input = f32[5,5]{1,0} parameter(1)
+ start_index0 = s32[] constant(2)
+ start_index1 = s32[] constant(0)
+ dynamic_slice = f32[2,5]{1,0} dynamic-slice(dynamic_slice_input, start_index0, start_index1),
+ dynamic_slice_sizes={2,5}
+ ROOT d = f32[4,5]{1,0} dot(dot_lhs, dynamic_slice),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+ m::Constant(), m::Constant()))));
+}
+
+TEST_F(GemmFusionTest, SliceToDegenerateIsSkipped) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p = f32[3] parameter(0)
+ s = f32[1] slice(p), slice={[2:3]}
+ r = f32[] reshape(s)
+ b = f32[3,3] broadcast(r), dimensions={}
+ ROOT d = f32[3,3] dot(b, b),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)"));
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+
+ ASSERT_TRUE(GemmFusion(cc).Run(module.get()).value());
+
+ // Slice is not fused.
+ MatchHloModule(*module, R"(
+; CHECK-NOT: slice
+; CHECK: ENTRY
+; CHECK: slice
+)");
+}
+
+TEST_F(GemmFusionTest, MultipleUsesAreHandled) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ c = f32[] constant(1)
+ b = f32[6,8] broadcast(c), dimensions={}
+ p0 = f32[6,8] parameter(0)
+ a1 = f32[6,8] add(p0, b)
+ e = f32[6,8] exponential(a1)
+ a2 = f32[6,8] add(e, b)
+ d = f32[6,8] divide(b, a2)
+ p2 = f16[8,6] parameter(1)
+ cv = f32[8,6] convert(p2)
+ ROOT r = f32[6,6] dot(d, cv),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, BinaryElementwiseOfBroadcastIsFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p2 = f32[3072] parameter(2)
+ b = f32[8192,3072] broadcast(p2), dimensions={1}
+ p0 = f16[8192,3072] parameter(0)
+ p0c = f32[8192,3072] convert(p0)
+ a = f32[8192,3072] add(p0c, b)
+ p1 = f32[3072,768] parameter(1)
+ ROOT r = f32[8192,768] dot(a, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, BinaryElementwiseOfUnsupportedBroadcastIsNotFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p2 = f32[768] parameter(2)
+ b = f32[8192,768,4] broadcast(p2), dimensions={1}
+ s = f32[8192,3072] bitcast(b)
+ p0 = f16[8192,3072] parameter(0)
+ p0c = f32[8192,3072] convert(p0)
+ a = f32[8192,3072] add(p0c, s)
+ p1 = f32[3072,768] parameter(1)
+ ROOT r = f32[8192,768] dot(a, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+class GemmFusionLevel2Test : public GemmFusionTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_triton_fusion_level(2);
+ return debug_options;
+ }
+};
+
+TEST_F(GemmFusionTest, ConcatenationDivisibleBy64IsFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = bf16[8192,1]{1,0} parameter(0)
+ p1 = bf16[2752,8192]{1,0} parameter(1)
+ p2 = bf16[2752,8192]{1,0} parameter(2)
+ concat = bf16[5504,8192]{1,0} concatenate(p1, p2), dimensions={0}
+ bitcast = bf16[8192,5504]{0,1} bitcast(concat)
+ ROOT r = f32[1,5504]{1,0} dot(p0, bitcast),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+ const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+ EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionLevel2Test, ReshapeToScalarIsHandled) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = s8[5,3] parameter(0)
+ c = f16[5,3] convert(p0)
+ p1 = f16[1] parameter(1)
+ r = f16[] reshape(p1)
+ b = f16[5,7] broadcast(r)
+ ROOT d = f16[3,7] dot(c, b),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionLevel2Test, DoNotFuseIncompatibleDimensionSplits) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p1 = s8[5,7,2,3]{3,2,1,0} parameter(1)
+ t1 = s8[7,5,2,3]{3,2,1,0} transpose(p1), dimensions={1,0,2,3}
+ r1 = s8[7,30]{1,0} reshape(t1)
+ cvt = f16[7,30]{1,0} convert(r1)
+ p2 = f16[2,7,5,3]{3,2,1,0} parameter(2)
+ t2 = f16[7,2,5,3]{3,2,1,0} transpose(p2), dimensions={1,0,2,3}
+ r2 = f16[7,30]{1,0} reshape(t2)
+ a = f16[7,30]{1,0} add(cvt, r2)
+ p0 = f16[7,79]{1,0} parameter(0)
+ ROOT dot = f16[30,79]{1,0} dot(a, p0),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Transpose(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParameters) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ tmp_0 = f32[] constant(1)
+ tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={}
+ tmp_2 = f32[3,49]{1,0} parameter(6)
+ tmp_3 = f32[] constant(0)
+ tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={}
+ tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT
+ tmp_6 = f32[3,49]{1,0} convert(tmp_5)
+ tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6)
+ tmp_8 = s32[] parameter(13)
+ tmp_9 = f32[] convert(tmp_8)
+ tmp_10 = f32[] maximum(tmp_9, tmp_0)
+ tmp_11 = f32[] divide(tmp_3, tmp_10)
+ tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={}
+ tmp_13 = pred[3,49]{1,0} parameter(7)
+ tmp_14 = pred[3,49]{1,0} parameter(10)
+ tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14)
+ tmp_16 = f32[3,49]{1,0} convert(tmp_15)
+ tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16)
+ tmp_18 = f32[3,49]{1,0} negate(tmp_17)
+ tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18)
+ tmp_20 = f32[3,49]{1,0} parameter(19)
+ tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20)
+ tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21)
+ tmp_23 = f32[3,49]{1,0} negate(tmp_22)
+ tmp_24 = f32[3,49]{1,0} negate(tmp_6)
+ tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17)
+ tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20)
+ tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26)
+ tmp_28 = f32[3,49]{1,0} parameter(18)
+ tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28)
+ tmp_30 = f32[3,49]{1,0} parameter(17)
+ tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30)
+ tmp_32 = f32[3,49]{1,0} parameter(16)
+ tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32)
+ tmp_34 = f32[3,49]{1,0} parameter(15)
+ tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34)
+ tmp_36 = f32[3,49]{1,0} parameter(14)
+ tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36)
+ tmp_38 = f32[1,1]{1,0} constant({ {0} })
+ tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1}
+ tmp_40 = f32[] reshape(tmp_39)
+ tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={}
+ tmp_42 = u32[48]{0} parameter(11)
+ tmp_43 = u32[48]{0} parameter(5)
+ tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0}
+ tmp_45 = u32[3,32]{1,0} reshape(tmp_44)
+ tmp_46 = u32[96]{0} reshape(tmp_45)
+ tmp_47 = u32[] constant(1)
+ tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={}
+ tmp_49 = u32[96]{0} reshape(tmp_48)
+ tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49)
+ tmp_51 = u32[3,32]{1,0} reshape(tmp_50)
+ tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48)
+ tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52)
+ tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={}
+ tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54)
+ tmp_56 = f32[1,1]{1,0} constant({ {1} })
+ tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1}
+ tmp_58 = f32[] reshape(tmp_57)
+ tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={}
+ tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59)
+ tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41)
+ tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61)
+ tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={}
+ tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT
+ tmp_65 = f32[3,32]{1,0} convert(tmp_64)
+ tmp_66 = f32[3,49]{1,0} parameter(9)
+ tmp_67 = f32[49]{0} parameter(4)
+ tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1}
+ tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68)
+ tmp_70 = f32[1,49]{1,0} parameter(12)
+ tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={}
+ tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71)
+ tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1}
+ tmp_74 = f32[49]{0} reshape(tmp_73)
+ tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1}
+ tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75)
+ tmp_77 = f32[1,49]{1,0} parameter(3)
+ tmp_78 = f32[1,49]{1,0} parameter(8)
+ tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71)
+ tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72)
+ tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80)
+ tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71)
+ tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82)
+ tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83)
+ tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1}
+ tmp_86 = f32[49]{0} reshape(tmp_85)
+ tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1}
+ tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87)
+ tmp_89 = f32[1,49]{1,0} parameter(2)
+ tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1}
+ tmp_91 = f32[49]{0} reshape(tmp_90)
+ tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1}
+ tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92)
+ tmp_94 = f32[49,32]{1,0} parameter(1)
+ tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ tmp_96 = f32[32]{0} parameter(0)
+ tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1}
+ tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97)
+ tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98)
+ tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63)
+ tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63)
+ ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
+ HloOpcode::kFusion);
+ EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
+ HloInstruction::FusionKind::kCustom);
+ EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
+ TritonFusionAnalysis::kMaxParameterPerDotOperand * 2);
+}
+
+TEST_F(GemmFusionLevel2Test,
+ DoNotFuseTooManyParametersWhenAnInstructionWouldAddMultipleParameters) {
+ static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
+ "We have to update this test.");
+ // If we fuse the select, it adds 2 additional parameters at once (not 3,
+ // because the select instruction itself is removed from the parameters).
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ a = f32[3,49]{1,0} parameter(0)
+ b = f32[3,49]{1,0} parameter(1)
+ c = pred[3,49]{1,0} parameter(2)
+ d = f32[3,49]{1,0} parameter(3)
+ e = f32[3,49]{1,0} parameter(4)
+ add0 = f32[3,49]{1,0} add(a, b)
+ select = f32[3,49]{1,0} select(c, d, e)
+ add1 = f32[3,49]{1,0} add(add0, select)
+ f = f32[3,32]{1,0} parameter(5)
+ ROOT tmp_102 = f32[49,32]{1,0} dot(add1, f), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
+ HloOpcode::kFusion);
+ EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
+ HloInstruction::FusionKind::kCustom);
+ EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
+ TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
+}
+
+TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParametersForConcat) {
+ static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
+ "We have to update this test.");
+  // The concat shouldn't push the fusion over the allowed parameter limit.
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ a = f32[3,3]{1,0} parameter(0)
+ b = f32[3,3]{1,0} parameter(1)
+ c = f32[3,3]{1,0} parameter(2)
+ d = f32[3,3]{1,0} parameter(3)
+ e = f32[3,3]{1,0} parameter(4)
+ f = f16[3,3]{1,0} parameter(5)
+ concat = f32[15,3]{1,0} concatenate(a, b, c, d, e), dimensions={0}
+ convert = f32[3,3]{1,0} convert(f)
+ ROOT dot = f32[15,3]{1,0} dot(concat, convert), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
+ HloOpcode::kFusion);
+ EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
+ HloInstruction::FusionKind::kCustom);
+ EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
+ TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
+}
+
+TEST_F(GemmFusionLevel2Test,
+ InstructionsReachableFromMultipleOperandsAreHandledCorrectly) {
+ static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
+ "We have to update this test.");
+  // There was a bug where dead code was generated into some fusions in a
+  // specific edge case. When some instructions were reachable through both the
+  // LHS and the RHS operands, the BFS (breadth-first search) through the LHS
+  // operand "marked" one operation as non-fusible because it would exceed the
+  // limit on fusion parameters per operand, but the BFS through the RHS
+  // operand went through that node and fused some more operands. The resulting
+  // fusion was therefore not connected and caused errors. This test checks
+  // that such configurations now generate a correct HLO.
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ a = f32[2,4]{1,0} parameter(0)
+ b = f32[2,4]{1,0} parameter(1)
+ c = f32[2,4]{1,0} parameter(2)
+ d = f32[2,4]{1,0} parameter(3)
+ e = f32[2,4]{1,0} parameter(4)
+ add0 = f32[2,4]{1,0} add(a, b)
+ add1 = f32[2,4]{1,0} add(add0, c)
+ add2 = f32[2,4]{1,0} add(add1, d)
+ add3 = f32[2,4]{1,0} add(add2, e)
+ ROOT r = f32[2,2]{1,0} dot(add3, add0),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ // ~VerifiedHloModule() will verify the module.
+}
+
+TEST_F(GemmFusionLevel2Test, EachScopeIsFusedToASeparateSubgraph) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ a = f32[2,4]{1,0} parameter(0)
+ b = f32[2,4]{1,0} parameter(1)
+ add = f32[2,4]{1,0} add(a, b)
+ ROOT r = f32[2,2]{1,0} dot(add, add),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+ MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
+CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]])
+CHECK-DAG: %[[P2:.*]] = f32[2,4]{1,0} parameter(2)
+CHECK-DAG: %[[P3:.*]] = f32[2,4]{1,0} parameter(3)
+CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P2]], f32[2,4]{1,0} %[[P3]])
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
+CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]),
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+// The two inputs of the add operation are the same and they are iterated the
+// same way, so the same parameter node is reused for them.
+// The reuse happens per "operand fusion", so the adds on the LHS and the RHS
+// still use different nodes.
+TEST_F(GemmFusionLevel2Test, ParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ a = f32[2,4]{1,0} parameter(0)
+ add = f32[2,4]{1,0} add(a, a)
+ ROOT r = f32[2,2]{1,0} dot(add, add),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+ MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
+CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
+CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P1]])
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
+CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+// NEGATE has the same iteration spec at both usages, so the node is reused
+// (implying that P0 is also reused).
+TEST_F(GemmFusionLevel2Test, NonParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ a = f32[4,4]{1,0} parameter(0)
+ b = f32[4,4]{1,0} parameter(1)
+ negate = f32[4,4]{1,0} negate(a)
+ sine = f32[4,4]{1,0} sine(negate)
+ add = f32[4,4]{1,0} add(negate, sine)
+ ROOT r = f32[4,4]{1,0} dot(add, b),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+ MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: %[[NEGATE:.*]] = f32[4,4]{1,0} negate(f32[4,4]{1,0} %[[P0]])
+CHECK-DAG: %[[SINE:.*]] = f32[4,4]{1,0} sine(f32[4,4]{1,0} %[[NEGATE]])
+CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[NEGATE]], f32[4,4]{1,0} %[[SINE]])
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P1]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
+CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+// The direct read of the input and the transposed read of the input have
+// different iteration specs, so we don't reuse the node.
+TEST_F(GemmFusionLevel2Test, NodesAreNotReusedIfTheyHaveDifferentIterSpecs) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ a = f32[4,4]{1,0} parameter(0)
+ b = f32[4,4]{1,0} parameter(1)
+ tr_a = f32[4,4]{1,0} transpose(a), dimensions={1,0}
+ add = f32[4,4]{1,0} add(a, tr_a)
+ ROOT r = f32[4,4]{1,0} dot(add, b),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+ MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: %[[P2:.*]] = f32[4,4]{1,0} parameter(2)
+CHECK-DAG: %[[TRANSPOSE:.*]] = f32[4,4]{1,0} transpose(f32[4,4]{1,0} %[[P1]])
+CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[TRANSPOSE]])
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P2]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
+CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+TEST_F(GemmFusionLevel2Test, OperationsAddingMoreParametersGetMultipleTries) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+e {
+ p0 = f32[2,2] parameter(0)
+ c0 = f32[] constant(12345)
+ b0 = f32[2,2] broadcast(c0), dimensions={}
+ m0 = f32[2,2] multiply(p0, b0)
+ c1 = f32[] constant(34567)
+ b1 = f32[2,2] broadcast(c1), dimensions={}
+ a0 = f32[2,2] add(m0, b1)
+ b3 = f32[2,2,2] broadcast(a0), dimensions={0,1}
+ p2 = f32[2,2,2] parameter(2)
+ m2 = f32[2,2,2] multiply(p2, b3)
+ p1 = f32[2]{0} parameter(1)
+ c2 = f32[] constant(5678)
+ b2 = f32[2] broadcast(c2), dimensions={}
+ a1 = f32[2]{0} add(p1, b2)
+ b4 = f32[2,2,2] broadcast(a1), dimensions={2}
+ m1 = f32[2,2,2] multiply(m2, b4)
+ b = f32[4,2] bitcast(m1)
+ p3 = f16[2,2] parameter(3)
+ p3c = f32[2,2] convert(p3)
+ ROOT r = f32[4,2] dot(b, p3c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+ m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutPreAmpere) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[2,53] parameter(0)
+ p0e = f32[2,53] exponential(p0)
+ p1 = s16[53,2] parameter(1)
+ p1c = f32[53,2] convert(p1)
+ ROOT dot = f32[2,2] dot(p0e, p1c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ EXPECT_THAT(
+ GemmFusion(se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0})
+ .Run(module.get()),
+ tsl::testing::StatusIs(
+ absl::StatusCode::kFailedPrecondition,
+ ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
+ "(compute capability 8.0) and up, but got")));
+}
+
+TEST_F(GemmFusionLevel2Test, GemmFusionSucceedsOnNonCudaGpu) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[2,53] parameter(0)
+ p0e = f32[2,53] exponential(p0)
+ p1 = s16[53,2] parameter(1)
+ p1c = f32[53,2] convert(p1)
+ ROOT dot = f32[2,2] dot(p0e, p1c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ EXPECT_TRUE(GemmFusion(se::RocmComputeCapability{}).Run(module.get()).ok());
+}
+
+TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+HloModule t
+
+ENTRY e {
+ p0 = f32[2,35] parameter(0)
+ p0n = f32[2,35] negate(p0)
+ p0e = f32[2,35] exponential(p0)
+ a = f32[2,35] add(p0e, p0n)
+ p1 = f16[35,2] parameter(1)
+ p1c = f32[35,2] convert(p1)
+ ROOT dot = f32[2,2] dot(a, p1c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter()))));
+ TF_ASSERT_OK_AND_ASSIGN(
+ const auto analysis,
+ TritonFusionAnalysis::Execute(*module->entry_computation()
+ ->root_instruction()
+ ->called_computations()[0]));
+ EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(),
+ 1);
+ EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size(),
+ 1);
+}
+
+TEST_F(GemmFusionLevel2Test,
+ ParameterUsedNonElementwiseTwiceIsFusedOnBothPaths) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+HloModule t
+
+ENTRY e {
+ p0 = f32[4,4] parameter(0)
+ p0t = f32[4,4] transpose(p0), dimensions={1,0}
+ a = f32[4,4] add(p0, p0t)
+ p1 = f16[4,5] parameter(1)
+ p1c = f32[4,5] convert(p1)
+ ROOT dot = f32[4,5] dot(a, p1c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test,
+ ComputationParameterWithMultipleUsersIsNotTrivialToFuse) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = f32[400,400] parameter(0)
+
+ c0 = f16[400,400] convert(p0)
+ p1 = f16[400,400] parameter(1)
+ dot0 = f16[400,400] dot(c0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+ c1 = f16[400,400] convert(p0)
+ p2 = f16[400,400] parameter(2)
+ dot1 = f16[400,400] dot(c1, p2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+ ROOT a = f16[400,400] add(dot0, dot1)
+})"));
+ EXPECT_FALSE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+}
+
+TEST_F(GemmFusionLevel2Test, NarrowingConversionIsAlwaysBetterToFuse) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+ p0 = s8[512,512] parameter(0)
+ c0 = f16[512,512] convert(p0)
+ p1 = f16[512,512] parameter(1)
+ dot0 = f16[512,512] dot(c0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+ n = f16[512,512] negate(c0)
+ ROOT a = f16[512,512] add(dot0, n)
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Add(m::Fusion(m::Parameter(), m::Parameter()),
+ m::Negate()))));
+}
+
+TEST_F(GemmFusionLevel2Test, NestedSlicingIsAnalyzedCorrectly) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+triton_gemm_d_computation {
+ p0 = f32[6,24]{1,0} parameter(0)
+ slice1 = f32[5,20]{1,0} slice(p0), slice={[1:6], [3:23]}
+ n1 = f32[5,20]{1,0} negate(slice1)
+ slice2 = f32[3,7]{1,0} slice(n1), slice={[1:4], [13:20]}
+ p1 = f32[7,37]{1,0} parameter(1)
+ ROOT d = f32[3,37]{1,0} dot(slice2, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+ p0 = f32[7,37]{1,0} parameter(0)
+ p1 = f32[6,24]{1,0} parameter(1)
+ ROOT triton_gemm_d = f32[3,37]{1,0} fusion(p1, p0), kind=kCustom,
+ calls=triton_gemm_d_computation
+})"));
+ const HloComputation* computation =
+ module->entry_computation()->root_instruction()->called_computations()[0];
+ TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
+ TritonFusionAnalysis::Execute(*computation));
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
+ computation->parameter_instruction(0), 0),
+ ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6,
+ /*slice_start=*/2, /*sliced_count=*/3,
+ /*subfragments=*/ElementsAre(3))));
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
+ computation->parameter_instruction(0), 1),
+ ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24,
+ /*slice_start=*/16, /*sliced_count=*/7,
+ /*subfragments=*/ElementsAre(7))));
+}
+
+TEST_F(GemmFusionLevel2Test, FusedConcatenationIsAnalyzedCorrectly) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+e {
+ p0 = s8[153,1536] parameter(0)
+ p1 = s8[153,128] parameter(1)
+ p2 = s8[153,256] parameter(2)
+ cat = s8[153,1920] concatenate(p0, p1, p2), dimensions={1}
+ cvt = bf16[153,1920] convert(cat)
+ p3 = bf16[16,153] parameter(3)
+ ROOT d = bf16[16,1920] dot(p3, cvt),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+ m::Parameter(), m::Parameter()))));
+ const HloComputation* computation =
+ module->entry_computation()->root_instruction()->called_computations()[0];
+ TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
+ TritonFusionAnalysis::Execute(*computation));
+
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+ computation->parameter_instruction(1), 0),
+ ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153,
+ /*slice_start=*/0, /*sliced_count=*/153,
+ /*subfragments=*/ElementsAre(153))));
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+ computation->parameter_instruction(1), 1),
+ ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536,
+ /*slice_start=*/0, /*sliced_count=*/1536,
+ /*subfragments=*/ElementsAre(1536))));
+
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+ computation->parameter_instruction(2), 0),
+ ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153,
+ /*slice_start=*/0, /*sliced_count=*/153,
+ /*subfragments=*/ElementsAre(153))));
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+ computation->parameter_instruction(2), 1),
+ ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128,
+ /*slice_start=*/-1536, /*sliced_count=*/128,
+ /*subfragments=*/ElementsAre(128))));
+
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+ computation->parameter_instruction(3), 0),
+ ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153,
+ /*slice_start=*/0, /*sliced_count=*/153,
+ /*subfragments=*/ElementsAre(153))));
+ EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+ computation->parameter_instruction(3), 1),
+ ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256,
+ /*slice_start=*/-1536 - 128,
+ /*sliced_count=*/256,
+ /*subfragments=*/ElementsAre(256))));
+}
+
+TEST_F(GemmFusionLevel2Test, IndivisibleConcatenationIsNotFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+e {
+ p0 = s8[124,1024] parameter(0)
+ p1 = s8[124,1001] parameter(1)
+ cat = s8[124,2025] concatenate(p0, p1), dimensions={1}
+ cvt = f16[124,2025] convert(cat)
+ p2 = f16[123,124] parameter(2)
+ ROOT d = f16[2025,123] dot(cvt, p2),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test, ConcatenationOfContractingIsNotFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+e {
+ p0 = s8[124,1024] parameter(0)
+ p1 = s8[124,1024] parameter(1)
+ cat = s8[124,2048] concatenate(p0, p1), dimensions={1}
+ cvt = f16[124,2048] convert(cat)
+ p2 = f16[123,2048] parameter(2)
+ ROOT d = f16[124,123] dot(cvt, p2),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test, ConcatenationOfBatchIsNotFused) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+e {
+ p0 = s8[124,1024,50] parameter(0)
+ p1 = s8[124,1024,50] parameter(1)
+ cat = s8[124,2048,50] concatenate(p0, p1), dimensions={1}
+ cvt = f16[124,2048,50] convert(cat)
+ p2 = f16[123,2048,50] parameter(2)
+ ROOT d = f16[2048,124,123] dot(cvt, p2),
+ lhs_batch_dims={1}, rhs_batch_dims={1},
+ lhs_contracting_dims={2}, rhs_contracting_dims={2}
+})"));
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test,
+ DifferentConcatenationOfSameParametersIsFusedViaNodeDuplication) {
+ // This means that the same input is passed to the fusion multiple times and
+ // is read differently each time.
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+e {
+ p0 = s8[128,2] parameter(0)
+ p1 = s8[128,2] parameter(1)
+ cat0 = s8[256,2] concatenate(p0, p1), dimensions={0}
+ cvt0 = f16[256,2] convert(cat0)
+ cat1 = s8[256,2] concatenate(p1, p0), dimensions={0}
+ n1 = s8[256,2] negate(cat1)
+ cvt1 = f16[256,2] convert(n1)
+ a = f16[256,2] add(cvt1, cvt0)
+ p2 = f16[2,18] parameter(2)
+ ROOT d = f16[18,256] dot(p2, a),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+
+ EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+ se::CudaComputeCapability::AMPERE, 0})
+ .Run(module.get())
+ .value());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
+ m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionTest, CopiesDotMetadataToFusionOp) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[2,18] parameter(0)
+ p1 = f16[256,2] parameter(1)
+ ROOT d = f16[18,256] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="foo"}
+})")
+ .value();
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+ EXPECT_EQ(
+ module->entry_computation()->root_instruction()->metadata().op_name(),
+ "foo");
+}
+
+// A test fixture class for testing the threshold for small matrices.
+class SmallDotGemmFusionTest : public GemmFusionTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
+ return debug_options;
+ }
+};
+
+TEST_F(SmallDotGemmFusionTest, SkipSmallMatrixMultiplicationRewrite) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[2,10] parameter(0)
+ p1 = f16[10,2] parameter(1)
+ ROOT d = f16[10,10] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})")
+ .value();
+
+ EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+ MatchHloModule(*module, R"(
+; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,10], {{.*}}: f16[10,2]) -> f16[10,10] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,10]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[10,2]{1,0} parameter(1)
+; CHECK: ROOT {{.*}} = f16[10,10]{1,0} dot(f16[2,10]{1,0} [[P0]], f16[10,2]{1,0} [[P1]])
+})");
+}
+
+TEST_F(SmallDotGemmFusionTest, LargeMatrixMultiplicationIsRewritten) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[2,18] parameter(0)
+ p1 = f16[50,2] parameter(1)
+ ROOT d = f16[18,50] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})")
+ .value();
+
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+ MatchHloModule(*module, R"(
+; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,18], {{.*}}: f16[50,2]) -> f16[18,50] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,18]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[50,2]{1,0} parameter(1)
+; CHECK: ROOT {{.*}} = f16[18,50]{1,0}
+; CHECK: fusion(f16[2,18]{1,0} [[P0]], f16[50,2]{1,0} [[P1]]),
+; CHECK: kind=kCustom
+; CHECK: __triton_gemm
+})");
+}
+
+class SparseDotTest : public GemmFusionTest {};
+
+TEST_F(SparseDotTest, DotWithSparseLhsOperandIsRewritten) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+ENTRY main {
+ lhs = f16[2,16] parameter(0)
+ rhs = f16[32,2] parameter(1)
+ meta = u16[2,2] parameter(2)
+ ROOT dot = f32[2,2] dot(lhs, rhs, meta),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
+})")
+ .value();
+ EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+ MatchHloModule(*module, R"(
+; CHECK-LABEL: ENTRY %main ({{.*}}: f16[2,16], {{.*}}: f16[32,2], {{.*}}: u16[2,2]) -> f32[2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,16]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[32,2]{1,0} parameter(1)
+; CHECK-NEXT: [[META:%[^ ]+]] = u16[2,2]{1,0} parameter(2)
+; CHECK: ROOT {{.*}} = f32[2,2]{1,0}
+; CHECK-SAME: fusion(f16[2,16]{1,0} [[P0]], f16[32,2]{1,0} [[P1]], u16[2,2]{1,0} [[META]]),
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: __triton_gemm
+})");
+}
+
+TEST_F(SparseDotTest, DotWithSparseRhsOperandIsNotSupported) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+ENTRY main {
+ lhs = f16[2,32] parameter(0)
+ rhs = f16[16,2] parameter(1)
+ meta = u16[2,2] parameter(2)
+ ROOT dot = f32[2,2] dot(lhs, rhs, meta),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=R.0@2:4
+})")
+ .value();
+ auto result = GemmFusion(gpu_version_).Run(module.get());
+ EXPECT_FALSE(result.ok());
+}
+
+TEST_F(SparseDotTest, UnsupportedSparsityType) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+ENTRY main {
+ lhs = f16[2,8] parameter(0)
+ rhs = f16[32,2] parameter(1)
+ meta = u16[2,1] parameter(2)
+ ROOT dot = f32[2,2] dot(lhs, rhs, meta),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@1:4
+})")
+ .value();
+ auto result = GemmFusion(gpu_version_).Run(module.get());
+ EXPECT_FALSE(result.ok());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
new file mode 100644
index 0000000..a48d4f9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -0,0 +1,2370 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/algorithm_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/gpu/gpu_blas_lt.h"
+#include "xla/types.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/ml_dtypes.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/protobuf/dnn.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = match;
+
+// Give this instruction a more useful name than "custom-call.42".
+absl::Status SetName(HloModule *module, HloInstruction *gemm) {
+ if (IsCublasLtMatmul(*gemm)) {
+ module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul");
+ return absl::OkStatus();
+ }
+
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ gemm->backend_config<GpuBackendConfig>());
+ const GemmBackendConfig &config = gpu_config.gemm_backend_config();
+ const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
+ bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
+ !dot_dims.rhs_batch_dimensions().empty();
+
+ module->SetAndUniquifyInstrName(
+ gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm");
+ return absl::OkStatus();
+}
+
+// Returns whether a given PrimitiveType is supported by cuBLASLt Epilogue
+// Fusion. A table of supported data types can be found in the cuBLASLt
+// documentation: https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul.
+// Note that `Ctype` also describes the output type of the GEMM. Rows with
+// `Non-default epilogue not supported` entries in the last column indicate data
+// types not compatible with Epilogue Fusion.
+bool SupportsEpilogueFusion(PrimitiveType type) {
+ switch (type) {
+ case F8E4M3FN:
+ case F8E5M2:
+ case F16:
+ case BF16:
+ case F32:
+ case F64:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool IsF8Type(const HloInstruction *instr) {
+ return primitive_util::IsF8Type(instr->shape().element_type());
+}
+
+// Returns a new shape with non-batch dimensions padded to multiples of 16, as
+// required by cuBLASLt FP8 gemms.
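+// Illustrative example (hypothetical shape): with batch_dims = {0}, a shape
+// f8e4m3fn[5,10,18] would be padded to f8e4m3fn[5,16,32]; the batch dimension
+// of size 5 is left untouched.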
+Shape PadShapeToMultipleOf16(const Shape old_shape,
+ const absl::Span<const int64_t> batch_dims) {
+ Shape padded_shape = old_shape;
+ for (int i = 0; i < old_shape.rank(); ++i) {
+ if (!absl::c_linear_search(batch_dims, i)) {
+ int64_t padded_dimension =
+ RoundUpTo<int64_t>(old_shape.dimensions(i), 16);
+ padded_shape.set_dimensions(i, padded_dimension);
+ }
+ }
+ return padded_shape;
+}
+
+// Pad the dimensions of the operands to the target shape.
+HloInstruction *PadOperandToTargetShape(const Shape &target,
+ HloInstruction *x) {
+ if (ShapeUtil::Equal(target, x->shape()) ||
+ !ShapeUtil::SameElementType(x->shape(), target)) {
+ return x;
+ }
+
+ PaddingConfig padding_config;
+ for (int i = 0; i < x->shape().rank(); ++i) {
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(0);
+ dimension->set_edge_padding_high(target.dimensions(i) -
+ x->shape().dimensions(i));
+ dimension->set_interior_padding(0);
+ }
+
+ HloInstruction *zero = x->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(x->shape().element_type())));
+ return x->AddInstruction(
+ HloInstruction::CreatePad(target, x, zero, padding_config));
+}
+
+// Pad the non-batch dimensions of the operands to multiples of 16 as required
+// by cuBLASLt FP8 gemms.
+HloInstruction *PadOperandToMultipleOf16(absl::Span<const int64_t> batch_dims,
+ HloInstruction *x) {
+ Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims);
+ return PadOperandToTargetShape(padded_shape, x);
+}
+
+// Calculates the reciprocal of scalar when invert is true and converts to FP32.
+absl::StatusOr<HloInstruction *> InvertAndConvertScalar(HloInstruction *scalar,
+ bool invert) {
+ DCHECK(ShapeUtil::IsScalar(scalar->shape()));
+
+ if (invert) {
+ Literal one_literal = LiteralUtil::One(scalar->shape().element_type());
+ HloInstruction *one = scalar->parent()->AddInstruction(
+ HloInstruction::CreateConstant(one_literal.Clone()));
+ TF_ASSIGN_OR_RETURN(scalar, MakeBinaryHlo(HloOpcode::kDivide, one, scalar,
+ &scalar->metadata()));
+ }
+ if (scalar->shape().element_type() != F32) {
+ scalar = MakeConvertToHlo(scalar, F32, &scalar->metadata());
+ }
+
+ return scalar;
+}
+
+// A path of instructions by traversing downwards through users, as (op,
+// operand_index) pairs. operand_index is the index to get to the previous
+// element in the path. I.e.,
+// path[i].first->operand(path[i].second) == path[i-1].first
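+// As an illustrative sketch (hypothetical instruction names): for a chain
+// x -> convert(x) -> multiply(convert, scale), a path rooted at x would be
+// {{x, -1}, {convert, 0}, {multiply, 0}}, since the convert is operand 0 of
+// the multiply and x is operand 0 of the convert.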
+using InstrPath = std::vector<std::pair<HloInstruction *, int>>;
+
+// From 'instr', recursively traverses operands until an FP8 instruction is
+// encountered. Only unary ops and a few types of non-unary ops are traversed.
+// If an FP8 instruction is found, returns the path from the FP8 instruction to
+// 'instr'. Returns nullopt when no FP8 instruction is reached.
+//
+// The intent is, given 'instr' is the operand of a dot, to find a sequence of
+// instructions that can potentially be fused into a cuBLAS LT FP8 gemm.
+std::optional<InstrPath> FindF8SubgraphRecursive(
+ HloInstruction *instr, absl::flat_hash_set<int> &visited_instrs) {
+ // Avoid visiting the same instruction more than once.
+ if (!visited_instrs.emplace(instr->unique_id()).second) {
+ return std::nullopt;
+ }
+ if (IsF8Type(instr)) {
+ // The initial operand index is meaningless. Arbitrarily use -1.
+ return InstrPath{{instr, -1}};
+ }
+ if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide ||
+ instr->opcode() == HloOpcode::kDynamicSlice ||
+ instr->opcode() == HloOpcode::kPad) {
+ std::optional<InstrPath> subgraph =
+ FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs);
+ if (subgraph) {
+ subgraph->emplace_back(std::make_pair(instr, 0));
+ }
+ return subgraph;
+ } else if (instr->opcode() == HloOpcode::kMultiply ||
+ instr->opcode() == HloOpcode::kSelect) {
+ for (int k = 0; k < 2; ++k) {
+ // Iterate over operands 0 and 1 for multiply and operands 1 and 2 for
+ // select.
+ int operand_idx = k + (instr->opcode() == HloOpcode::kSelect);
+ std::optional<InstrPath> subgraph = FindF8SubgraphRecursive(
+ instr->mutable_operand(operand_idx), visited_instrs);
+ if (subgraph) {
+ subgraph->emplace_back(std::make_pair(instr, operand_idx));
+ return subgraph;
+ }
+ }
+ }
+ return std::nullopt;
+}
+
+// Contains information on a parameter (either the LHS or RHS) for a
+// gemm that can be potentially pattern-matched into an FP8 cublasLT gemm.
+struct MatchedFp8Param {
+ // The FP8 input to the gemm.
+ HloInstruction *fp8_input = nullptr;
+ // If nonnull, the scale for the 'x'
+ HloInstruction *scale = nullptr;
+ // Whether the scale, if present, multiplies or divides 'x'
+ bool mult_scale = false;
+ // A list of instructions from x to the dot instruction commutative with
+ // dequantization. Such instructions can be moved before the FP8 gemm.
+ InstrPath commutative_ops;
+};
+
+// Given an operand of a dot, `instr`, returns a MatchedFp8Param if this operand
+// allows rewriting the dot in an FP8 cublasLT custom call, optionally with
+// scaling. In particular, returns a MatchedFp8Param if either 'instr' is FP8
+// or there is a path from an FP8 instruction 'fp8_input' to 'instr'
+// consisting of the following.
+// 1. A convert to a wider type.
+// 2. Optionally, a multiplication/division by a scalar, representing the scale.
+// If present, the scalar scale is returned as 'scale' and 'mult_scale'
+// is set to true or false depending on whether there is a multiplication or
+// a division.
+// 3. A possibly-empty set of ops commutative with steps (1) and (2), meaning
+// they can be safely moved before step (1). Such ops are returned in
+// 'commutative_ops'.
+// Steps (1) and (2) together are a dequantization, and can be fused into a
+// cublas LT matmul. Step (3) can be moved before the cublas LT matmul.
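+//
+// A minimal illustrative HLO sketch of such an operand (hypothetical shapes
+// and names, not a guaranteed rewrite):
+//   x  = f8e4m3fn[16,16] parameter(0)
+//   c  = f32[16,16] convert(x)                        <- step (1)
+//   s  = f32[] parameter(1)
+//   bs = f32[16,16] broadcast(s), dimensions={}
+//   dq = f32[16,16] multiply(c, bs)                   <- step (2), mult_scale
+//   op = f32[16,16] transpose(dq), dimensions={1,0}   <- step (3)
+// Here the intended match is fp8_input=x, scale=s, mult_scale=true and
+// commutative_ops containing only the transpose.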
+std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
+ absl::flat_hash_set<int> visited_instrs;
+ std::optional<InstrPath> maybe_subgraph =
+ FindF8SubgraphRecursive(instr, visited_instrs);
+ if (!maybe_subgraph) {
+ return std::nullopt;
+ }
+ InstrPath &subgraph = maybe_subgraph.value();
+
+ MatchedFp8Param param;
+
+ // Directly operating on an FP8 operand.
+ if (subgraph.size() == 1) {
+ CHECK(IsF8Type(subgraph[0].first));
+ param.fp8_input = subgraph[0].first;
+ return param;
+ }
+
+ int num_dequant_ops;
+ // When not operating directly on an FP8 operand, the second and
+ // third instructions in the subgraph can describe a dequantization, i.e. a
+ // convert instruction followed by a multiply/divide instruction.
+ if (subgraph.size() > 2 &&
+ Match(subgraph[2].first,
+ m::MultiplyAnyOrder(m::Convert(m::Op(¶m.fp8_input)),
+ m::Broadcast(m::Op(¶m.scale))))) {
+ param.mult_scale = true;
+ num_dequant_ops = 2;
+ } else if (subgraph.size() > 2 &&
+ Match(subgraph[2].first,
+ m::Divide(m::Convert(m::Op(¶m.fp8_input)),
+ m::Broadcast(m::Op(¶m.scale))))) {
+ param.mult_scale = false;
+ num_dequant_ops = 2;
+ } else if (subgraph.size() > 1 &&
+ Match(subgraph[1].first, m::Convert(m::Op(¶m.fp8_input)))) {
+ // We have a convert from FP8 without a scale in this case.
+ param.scale = nullptr;
+ num_dequant_ops = 1;
+ } else {
+ VLOG(1) << "Possible intended FP8 GEMM operating on "
+ << instr->ToShortString() << " not rewritten into FP8 Custom Call.";
+ return std::nullopt;
+ }
+
+ auto preserves_element_type = [](const HloInstruction *instr) -> bool {
+ return ShapeUtil::SameElementType(instr->shape(),
+ instr->operand(0)->shape());
+ };
+ auto use_spmd_partitioning = [](const HloInstruction *instr) -> bool {
+ return instr->GetModule()->config().use_spmd_partitioning();
+ };
+
+ // Skip the initial FP8 instruction and the dequantization instructions.
+ int start = 1 + num_dequant_ops;
+ for (int i = start; i < subgraph.size(); ++i) {
+ // The remaining instructions must be commutative with dequantization.
+ // Bitcast, broadcast, copy, dynamic-slice, pad, reshape, select, slice,
+ // transpose, all-gather, all-to-all and collective-permute instructions are
+ // supported. Specifically, the all-gather, all-to-all and
+ // collective-permute operations are permitted only in SPMD cases since the
+ // optimization cannot be guaranteed to be applied to all replicas in the
+ // MPMD scenario.
+ if (!Match(
+ subgraph[i].first,
+ m::AnyOf<HloInstruction>(
+ m::Bitcast().WithPredicate(preserves_element_type),
+ m::Broadcast(), m::Copy(), m::DynamicSlice(), m::Pad(),
+ m::Reshape(), m::Select(), m::Slice(), m::Transpose(),
+ m::AllGather().WithPredicate(use_spmd_partitioning),
+ m::AllToAll().WithPredicate(use_spmd_partitioning),
+ m::CollectivePermute().WithPredicate(use_spmd_partitioning)))) {
+ VLOG(1) << "Possible intended FP8 GEMM operating on "
+ << instr->ToShortString()
+ << " not rewritten into FP8 Custom Call.";
+ return std::nullopt;
+ }
+ // One of the operands of select must be zero for the op to be commutative
+ // with dequantization.
+ if (Match(subgraph[i].first, m::Select()) &&
+ !Match(subgraph[i].first->operand(subgraph[i].second == 2 ? 1 : 2),
+ m::Broadcast(m::ConstantScalar(0)))) {
+ VLOG(1) << "Possible intended FP8 GEMM operating on "
+ << instr->ToShortString()
+ << " not rewritten into FP8 Custom Call. Select requires a zero "
+ "operand to be exchanged with dequantization.";
+ return std::nullopt;
+ }
+ }
+
+ param.commutative_ops = {subgraph.begin() + start, subgraph.end()};
+ return param;
+}
+
+// Transposes a matrix by swapping the contracting and non-contracting
+// dimension. There must be only one contracting and only one non-contracting
+// dimension. Keeps the layout the same.
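+// Illustrative example (hypothetical shape): for a f32[8,3,5] operand with
+// batch_dims = {0} and contracting_dim = 2, the computed permutation is
+// {0, 2, 1} and the created transpose has shape f32[8,5,3].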
+HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
+ absl::Span<const int64_t> batch_dims) {
+ // Identify the dimensional order which describes a transpose of the
+ // contracting and non-contracting dimensions of the GEMM.
+ std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
+ // Discard the batch dimensions.
+ for (int64_t batch_dim : batch_dims) {
+ permutation[batch_dim] = batch_dim;
+ }
+ // Identify the non-contracting dimension.
+ int non_contracting_dim;
+ for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
+ if (permutation[i] == -1 && contracting_dim != i) {
+ non_contracting_dim = i;
+ }
+ }
+ permutation[non_contracting_dim] = contracting_dim;
+ permutation[contracting_dim] = non_contracting_dim;
+
+ Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
+ *new_shape.mutable_layout() = instr->shape().layout();
+ return instr->AddInstruction(
+ HloInstruction::CreateTranspose(new_shape, instr, permutation));
+}
+
+// If the bias is a sequence of ops that depend only on broadcasts of
+// constants, materialize the bias if it's small.
+//
+// Normally the constant-folding pass would materialize the bias if it is
+// calculated entirely from constants. But if the bias is a broadcast of a
+// constant, constant-folding won't expand the broadcast, on the theory that
+// folding broadcasts of constants causes us to consume more memory and can
+// actually make things slower (because any op which reads the constant has
+// to read more memory).
+//
+// OTOH in our case, we don't want to run an op that just broadcasts a
+// constant so we can fuse it into this gemm. That would defeat the whole
+// purpose of this fusion, which is to launch fewer kernels. So if we can,
+// we expand out this constant ourselves.
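+//
+// As an illustrative case (not an exhaustive list), a bias of the form
+// reshape(broadcast(constant)) whose result fits within the size limit below
+// is evaluated into a single constant literal here, so the gemm can consume a
+// materialized constant instead of keeping a separate broadcast op alive.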
+//
+// TODO(b/192499646): Even better would be to use cublasLT to fuse the
+// broadcasted bias, if it supports that fusion efficiently.
+HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
+ // This limit was not chosen carefully.
+ constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;
+
+ // Don't fold broadcasts of scalars -- algsimp will just collapse it again.
+ auto is_nonscalar = [](const HloInstruction *instr) {
+ return !ShapeUtil::IsEffectiveScalar(instr->shape());
+ };
+
+ // For now, only fold broadcast(constant) or
+ // reshape/transpose/bitcast(broadcast(constant)). This lets us avoid the
+ // complexity in the constant-folding pass about what is and isn't legal to
+ // fold.
+ auto broadcast_of_nonscalar =
+ m::Broadcast(m::Constant().WithPredicate(is_nonscalar));
+
+ if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
+ (Match(bias, broadcast_of_nonscalar) ||
+ Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
+ Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
+ Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
+ HloEvaluator evaluator(/*max_loop_iterations=*/0);
+ Literal result;
+ if (evaluator.TryEvaluate(
+ bias, &result,
+ /*recursively_evaluate_nonconstant_operands=*/true)) {
+ return bias->parent()->AddInstruction(
+ HloInstruction::CreateConstant(std::move(result)));
+ }
+ }
+
+ return bias;
+}
+
+auto Gemm(HloInstruction **instr) {
+ return m::CustomCall(instr, {kGemmCallTarget});
+}
+
+auto CublasLtMatmul(HloInstruction **instr) {
+ return m::CustomCall(instr, {kCublasLtMatmulCallTarget});
+}
+
+auto CublasLtMatmulF8(HloInstruction **instr) {
+ return m::CustomCall(instr, {kCublasLtMatmulF8CallTarget});
+}
+
+auto CublasLtMatmulMaybeF8(HloInstruction **instr) {
+ return m::CustomCall(
+ instr, {kCublasLtMatmulCallTarget, kCublasLtMatmulF8CallTarget});
+}
+
+auto GemmOrCublasLtMatmul(HloInstruction **instr) {
+ return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget});
+}
+
+auto GemmOrCublasLtMatmulMaybeF8(HloInstruction **instr) {
+ return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget,
+ kCublasLtMatmulF8CallTarget});
+}
+
+auto BcastConstScalar(HloInstruction **instr, double value) {
+ return m::Broadcast(instr, m::ConstantScalar(value));
+}
+
+auto BcastConstScalar(double value) { return BcastConstScalar(nullptr, value); }
+
+auto BcastConstScalarNear(double value) {
+ return m::Broadcast(m::ConstantScalar().WithPredicate(
+ [expected = value](const HloInstruction *instr) {
+ // Not a very robust floating-point comparison, but good enough for our
+ // purposes.
+ std::optional<double> actual =
+ xla::Cast<const HloConstantInstruction>(instr)
+ ->literal()
+ .GetAsDouble({});
+ if (!actual.has_value()) return false;
+ double epsilon;
+ switch (instr->shape().element_type()) {
+ case F16:
+ epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
+ break;
+ case BF16:
+ epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
+ break;
+ case F32:
+ epsilon = 128 * std::numeric_limits<float>::epsilon();
+ break;
+ case F64:
+ epsilon = 128 * std::numeric_limits<double>::epsilon();
+ break;
+ default:
+ return false;
+ }
+ return abs(*actual - expected) < (abs(*actual + expected) * epsilon);
+ }));
+}
+
+template <typename Pattern>
+auto OptionalSlice(HloInstruction **optional_slice, Pattern pattern) {
+ return m::AnyOf<HloInstruction>(m::Slice(optional_slice, pattern),
+ std::move(pattern));
+}
+
+template <typename Pattern>
+auto OptionalConvert(HloInstruction **optional_convert, Pattern pattern) {
+ return m::AnyOf<HloInstruction>(m::Convert(optional_convert, pattern),
+ std::move(pattern));
+}
+
+template <typename Pattern>
+auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) {
+ return m::AnyOf<HloInstruction>(m::Bitcast(optional_bitcast, pattern),
+ std::move(pattern));
+}
+
+// The rewriting proceeds in a bottom-up way:
+//
+// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
+//
+// (kMultiply (kCustomCall:gemm A B) C) is rewritten by folding C (provided it
+// is a constant) into the alpha parameter of the custom call.
+//
+// (kAdd (kCustomCall:gemm A B) C) is rewritten into (kCustomCall:gemm A B C),
+// where the "beta" parameter is set to 1 (provided it was zero before,
+// and provided C has no other users).
+// We then guide the buffer assignment to alias the buffer of the custom call
+// and C.
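+//
+// A minimal illustrative sketch of the rewrites described above (hypothetical
+// shapes; `k` is a constant scalar):
+//   d = f32[4,4] dot(a, b)      -->  d = custom-call(a, b)   (kGemmCallTarget)
+//   multiply(d, broadcast(k))   -->  custom-call(a, b) with alpha scaled by k
+//   add(d, c)                   -->  custom-call(a, b, c) with beta set to 1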
+//
+// For scaled FP8 GEMMs on Hopper systems, the following steps mentioned in
+// RFC #22 (https://github.com/openxla/xla/discussions/22) are elided and
+// rewritten into a Custom Call:
+//
+// 1. Cast each input from FP8 to a wider type such as FP16 or FP32.
+// 2. Unscale each input by multiplying each input by the corresponding input
+// scale.
+// 3. Evaluate the matrix multiplication on the scaled inputs.
+// 4. Compute the maximum of the absolute values in the result of the GEMM
+// (DAmax).
+// 5. Scale the output by dividing the output by the output scale.
+// 6. Cast the output back to FP8. Since saturation should be done on
+// overflow, this is represented by a Clamp instruction followed by a Convert
+// instruction.
+
+// Steps 1 through 3 can be elided independently of the remainder. Steps 5 and
+// 6 are elided only if steps 1 through 3 were successfully transformed. Step
+// 4 requires steps 5 and 6, i.e. the computation of DAmax can be elided only
+// when the output of the GEMM is requested in FP8 format.
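+//
+// An illustrative sketch of the output side (steps 5 and 6; hypothetical
+// shapes, with `lo`/`hi` standing for the clamping constants of the FP8 type):
+//   d   = f32[32,32] custom-call(...)                        <- steps 1-3
+//   ds  = f32[32,32] divide(d, broadcast(d_scale))           <- step 5
+//   cl  = f32[32,32] clamp(broadcast(lo), ds, broadcast(hi))
+//   out = f8e4m3fn[32,32] convert(cl)                        <- step 6
+// HandleDivide, HandleMultiply and HandleConvert below aim to fold the scaling
+// and the clamped conversion back into the custom call when this pattern is
+// matched.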
+class GemmRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit GemmRewriterVisitor(const se::GpuComputeCapability &gpu_version,
+ const int32_t toolkit_version,
+ const GemmRewriterOptions options)
+ : gpu_version_(gpu_version),
+ toolkit_version_(toolkit_version),
+ options_(options) {}
+
+ absl::Status HandleDot(HloInstruction *instr) override {
+ if (!IsMatrixMultiplication(*instr) &&
+ !IsMatrixVectorMultiplication(*instr)) {
+ return absl::OkStatus();
+ }
+ // Sparse dot is not supported.
+ if (Cast<HloDotInstruction>(instr)->sparse_operands()) {
+ return absl::OkStatus();
+ }
+
+ int64_t gemm_rewrite_size_threshold =
+ instr->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_gemm_rewrite_size_threshold();
+ TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
+ IsMatrixMultiplicationTooSmallForRewriting(
+ *instr, gemm_rewrite_size_threshold));
+ if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*instr)) {
+ return absl::OkStatus();
+ }
+
+ CHECK(!instr->IsRank2Transpose());
+ if (instr->operand(0)->IsRank2Transpose() ||
+ instr->operand(1)->IsRank2Transpose()) {
+ return absl::OkStatus();
+ }
+ // Create a GemmBackendConfig based on the instruction.
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
+ instr->backend_config<GpuBackendConfig>());
+ GemmBackendConfig &gemm_backend_config =
+ *gpu_backend_config.mutable_gemm_backend_config();
+ gemm_backend_config.set_alpha_real(1.0);
+ gemm_backend_config.set_alpha_imag(0.0);
+ gemm_backend_config.set_beta(0.0);
+ *gemm_backend_config.mutable_dot_dimension_numbers() =
+ instr->dot_dimension_numbers();
+ *gemm_backend_config.mutable_precision_config() = instr->precision_config();
+
+ HloInstruction *lhs = instr->mutable_operand(0);
+ HloInstruction *rhs = instr->mutable_operand(1);
+ auto attributes = instr->frontend_attributes().map();
+ gemm_backend_config.set_grad_x(attributes["grad_x"] == "true");
+ gemm_backend_config.set_grad_y(attributes["grad_y"] == "true");
+
+ int64_t lhs_batch_dims_size =
+ instr->dot_dimension_numbers().lhs_batch_dimensions_size();
+ bool is_lhs_vector =
+ lhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
+ bool is_rhs_vector =
+ rhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
+ int64_t lhs_stride =
+ is_lhs_vector ? lhs->shape().dimensions(lhs_batch_dims_size)
+ : lhs->shape().dimensions(lhs_batch_dims_size) *
+ lhs->shape().dimensions(lhs_batch_dims_size + 1);
+ int64_t rhs_stride =
+ is_rhs_vector ? rhs->shape().dimensions(lhs_batch_dims_size)
+ : rhs->shape().dimensions(lhs_batch_dims_size) *
+ rhs->shape().dimensions(lhs_batch_dims_size + 1);
+
+ gemm_backend_config.set_lhs_stride(lhs_stride);
+ gemm_backend_config.set_rhs_stride(rhs_stride);
+
+ switch (options_.dtype) {
+ case GemmRewriterOptions::DType::kFp8Only: {
+ // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call.
+ TF_ASSIGN_OR_RETURN(
+ bool supported_by_cublaslt,
+ GemmIsSupportedByCublasLt(*instr, gemm_backend_config));
+ std::optional<MatchedFp8Param> a, b;
+ if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot &&
+ (a = MatchFp8Param(
+ const_cast<HloInstruction *>(instr->operand(0)))) &&
+ (b = MatchFp8Param(
+ const_cast<HloInstruction *>(instr->operand(1))))) {
+ if (IsRocm(gpu_version_) && toolkit_version_ < 60200 &&
+ instr->shape().element_type() != F16 &&
+ instr->shape().element_type() != F32) {
+ TF_ASSIGN_OR_RETURN(
+ instr, TurnF8DotWithUnsupportedOutputTypeIntoF32(instr));
+ }
+ TF_ASSIGN_OR_RETURN(bool created_call,
+ CreateF8CustomCall(instr, gpu_backend_config,
+ a.value(), b.value()));
+ if (created_call) {
+ return absl::OkStatus();
+ }
+ }
+ if (IsF8Type(instr->operand(0))) {
+ // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt
+ // custom call, so turn into an FP16 dot which may be rewritten as an
+ // FP16 Triton, cublas or cublasLt call.
+ TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr));
+ }
+ break;
+ }
+ case GemmRewriterOptions::DType::kNonFp8Only: {
+ // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call.
+ TF_ASSIGN_OR_RETURN(
+ absl::string_view gemm_custom_call_target,
+ GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config));
+ const Shape &output_shape = instr->shape();
+ HloInstruction *gemm_call =
+ instr->AddInstruction(HloInstruction::CreateCustomCall(
+ output_shape,
+ {instr->mutable_operand(0), instr->mutable_operand(1)},
+ gemm_custom_call_target));
+ TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call));
+ } break;
+ };
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleMultiply(HloInstruction *instr) override {
+ HloInstruction *alpha, *existing_gemm;
+ if (Match(instr,
+ m::MultiplyAnyOrder(
+ GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(),
+ m::Broadcast(m::ConstantScalar(&alpha)).WithOneUser()))) {
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ existing_gemm->backend_config<GpuBackendConfig>());
+ GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+ // Do not fuse alpha into an S32 GEMM, as S32 GEMMs only support fixed
+ // values for alpha/beta.
+ if (existing_gemm->shape().element_type() == S32) {
+ return absl::OkStatus();
+ }
+
+ if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
+ complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
+ complex128 new_alpha =
+ *alpha->literal().GetAsComplex128({}) * prev_alpha;
+ config.set_alpha_real(new_alpha.real());
+ config.set_alpha_imag(new_alpha.imag());
+ TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
+ return ReplaceInstruction(instr, existing_gemm);
+ }
+ }
+
+ HloInstruction *d_scale;
+ if (Match(instr, m::MultiplyAnyOrder(
+ CublasLtMatmulF8(&existing_gemm).WithOneUser(),
+ m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
+ return F8ScaleD(instr, existing_gemm, d_scale);
+ }
+
+ // Attempt to match approximate GELU activation
+ // (https://arxiv.org/abs/1606.08415), where:
+ // approx_gelu(x) = x * cdf(x)
+ // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
+ HloInstruction *cdf, *slice_or_bitcast = nullptr;
+ if (Match(instr, m::MultiplyAnyOrder(
+ m::AnyOf<HloInstruction>(
+ m::Slice(&slice_or_bitcast,
+ CublasLtMatmulMaybeF8(&existing_gemm)),
+ m::Bitcast(&slice_or_bitcast,
+ CublasLtMatmulMaybeF8(&existing_gemm)),
+ CublasLtMatmulMaybeF8(&existing_gemm)),
+ m::Op(&cdf).WithOneUser())) &&
+ Match(cdf,
+ m::MultiplyAnyOrder(
+ BcastConstScalar(0.5),
+ m::AddAnyOrder(
+ BcastConstScalar(1.0),
+ m::Tanh(
+ m::MultiplyAnyOrder(
+ BcastConstScalarNear(sqrt(M_2_PI)),
+ m::AddAnyOrder(
+ m::Op().Is(slice_or_bitcast ? slice_or_bitcast
+ : existing_gemm),
+ m::MultiplyAnyOrder(
+ BcastConstScalarNear(0.044715),
+ m::MultiplyAnyOrder(
+ m::Op().Is(slice_or_bitcast
+ ? slice_or_bitcast
+ : existing_gemm),
+ m::MultiplyAnyOrder(
+ m::Op().Is(slice_or_bitcast
+ ? slice_or_bitcast
+ : existing_gemm),
+ m::Op().Is(slice_or_bitcast
+ ? slice_or_bitcast
+ : existing_gemm))
+ .WithOneUser())
+ .WithOneUser())
+ .WithOneUser())
+ .WithOneUser())
+ .WithOneUser())
+ .WithOneUser())))) {
+ return FuseGeluActivation(instr, existing_gemm, slice_or_bitcast);
+ }
+ return absl::OkStatus();
+ }
+
+ // Fuse the scaling of an FP8 GEMM into the Custom Call.
+ absl::Status HandleDivide(HloInstruction *instr) override {
+ HloInstruction *existing_gemm, *d_scale;
+ if (Match(instr, m::Divide(CublasLtMatmulF8(&existing_gemm).WithOneUser(),
+ m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
+ return F8ScaleD(instr, existing_gemm, d_scale);
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleAdd(HloInstruction *instr) override {
+ if (options_.bias_mode == GemmRewriterOptions::BiasMode::kNoBias) {
+ // See comments for `GemmRewriterOptions::BiasMode` for details.
+ return absl::OkStatus();
+ }
+
+ HloInstruction *bias, *existing_gemm = nullptr;
+ HloInstruction *optional_slice = nullptr;
+ HloInstruction *optional_convert = nullptr;
+ HloInstruction *optional_bitcast = nullptr;
+ // Attempt to elide broadcast and fuse addition of a vector bias into
+ // GEMM, including when slicing is applied to the result.
+ if (Match(instr,
+ m::AddAnyOrder(
+ OptionalBitcast(
+ &optional_bitcast,
+ OptionalSlice(
+ &optional_slice,
+ CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
+ .WithOneUser())
+ .WithOneUser(),
+ m::Broadcast(&bias,
+ OptionalConvert(&optional_convert, m::Op()))))) {
+ TF_ASSIGN_OR_RETURN(
+ bool was_fused,
+ FuseVectorBiasAdd(instr, bias, existing_gemm, optional_slice,
+ optional_convert, optional_bitcast));
+
+ if (was_fused) {
+ return absl::OkStatus();
+ }
+ }
+ // Attempt to elide broadcast and fuse addition of a vector bias into
+ // *batched* GEMM as a matrix bias addition using FuseMatrixBiasAdd.
+ // add(bitcast(gemm(a, b)), broadcast(bias)) ->
+ // bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) ->
+ // bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd)
+ //
+ if (Match(
+ instr,
+ m::AddAnyOrder(
+ m::Bitcast(CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
+ .WithOneUser(),
+ m::Broadcast(&bias, m::Op()).WithOneUser()))) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_add,
+ MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
+ MakeBitcastHlo(bias, existing_gemm->shape())));
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
+
+ // Continue below.
+ instr = new_add;
+ }
+
+ // Do not fuse broadcast unless we can fuse its input, as it will cause
+ // broadcast materialization.
+ auto is_not_broadcast = [](const HloInstruction *instr) {
+ return instr->opcode() != HloOpcode::kBroadcast;
+ };
+
+ // add(bitcast(gemm(a, b)), bias) ->
+ // bitcast(add(gemm(a, b), bitcast(bias))) ->
+ // bitcast(gemm(a, b, bitcast(bias))) (later down in this function).
+ //
+ // We see this idiom in models that contain batch-dots, where we cast
+ // between a rank-2 shape for non-batch dots and a higher-rank shape for
+ // batch-dots.
+ //
+ // The last stage of the transform may fail (because of any of the checks in
+ // FuseMatrixBiasAdd), but if so that's okay -- we'll have done a useless
+ // transformation, but it doesn't hurt anything.
+ if (Match(instr,
+ m::AddAnyOrder(
+ m::Bitcast(
+ GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
+ .WithOneUser(),
+ m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+ HloInstruction *new_bitcast =
+ MakeBitcastHlo(bias, existing_gemm->shape(), &bias->metadata());
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_add,
+ MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
+ new_bitcast, &bias->metadata()));
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
+
+ // Continue below transforming new_add.
+ instr = new_add;
+ }
+
+ // Attempt to fuse matrix bias into gemm with optional convert
+ // add(convert(gemm(a, b)), c) -> gemm(a, b, c)
+ // add(gemm(a, b), c) -> gemm(a, b, c)
+ if (Match(instr,
+ m::AddAnyOrder(
+ m::AnyOf<HloInstruction>(
+ GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(),
+ m::Convert(
+ GemmOrCublasLtMatmul(&existing_gemm).WithOneUser())
+ .WithOneUser()),
+ m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
+ existing_gemm->backend_config<GpuBackendConfig>());
+ const GemmBackendConfig &gemm_backend_config =
+ gpu_backend_config.gemm_backend_config();
+ // Check whether this type combination is supported here.
+ TF_ASSIGN_OR_RETURN(
+ bool types_are_supported,
+ IsLegacyCublasMatmul(*existing_gemm)
+ ? TypesAreSupportedByLegacyCublas(*existing_gemm,
+ gemm_backend_config, instr)
+ : TypesAreSupportedByCublasLt(*existing_gemm, gemm_backend_config,
+ instr));
+
+ // For mixed-type GEMMs, only fuse the add if it has no consumers, e.g.:
+ // ROOT add
+ // ROOT tuple(add)
+ bool has_no_consumer =
+ instr->shape().element_type() ==
+ existing_gemm->shape().element_type() ||
+ instr->user_count() == 0 ||
+ (instr->user_count() == 1 &&
+ instr->users()[0]->opcode() == HloOpcode::kTuple &&
+ instr->users()[0]->user_count() == 0);
+
+ if (types_are_supported && has_no_consumer) {
+ return FuseMatrixBiasAdd(instr, bias, existing_gemm);
+ }
+ }
+
+ HloInstruction *optional_bitcast_matrix = nullptr;
+ HloInstruction *optional_slice_matrix = nullptr;
+ if (Match(instr,
+ m::AddAnyOrder(
+ OptionalBitcast(
+ &optional_bitcast_matrix,
+ OptionalSlice(&optional_slice_matrix,
+ GemmOrCublasLtMatmulMaybeF8(&existing_gemm)
+ .WithOneUser()))
+ .WithOneUser(),
+ m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+ // The matrix bias must not be FP8, see
+ // https://docs.nvidia.com/cuda/cublas/index.html.
+ if (!IsF8Type(bias)) {
+ return FuseMatrixBiasAdd(instr, bias, existing_gemm,
+ optional_bitcast_matrix,
+ optional_slice_matrix);
+ }
+ }
+
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleMaximum(HloInstruction *instr) override {
+ HloInstruction *existing_gemm, *zeros;
+ HloInstruction *optional_slice_or_bitcast = nullptr;
+ // Attempt to elide maximum and fuse ReLU activation into GEMM, including
+ // when slicing or bitcasting is applied to the result.
+ if (Match(instr,
+ m::MaximumAnyOrder(
+ m::AnyOf<HloInstruction>(
+ m::Slice(
+ &optional_slice_or_bitcast,
+ CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
+ m::Bitcast(
+ &optional_slice_or_bitcast,
+ CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
+ CublasLtMatmulMaybeF8(&existing_gemm))
+ .WithOneUser(),
+ m::Broadcast(&zeros, m::ConstantScalar(0))))) {
+ TF_RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm,
+ optional_slice_or_bitcast));
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleConvert(HloInstruction *instr) override {
+ HloInstruction *clamp_lower, *clamp_upper, *existing_gemm,
+ *d_scale = nullptr, *binary = nullptr;
+ // Attempt to elide the scaling and conversion of the result of an FP8
+ // GEMM, including the optional calculation of the maximum of the absolute
+ // values before scaling, and adapt the Custom Call.
+ if (Match(instr,
+ m::Convert(
+ m::Clamp(
+ m::Broadcast(m::ConstantScalar(&clamp_lower)),
+ m::AnyOf<HloInstruction>(
+ CublasLtMatmulF8(&existing_gemm),
+ m::Divide(&binary, CublasLtMatmulF8(&existing_gemm),
+ m::Broadcast(m::Op(&d_scale))),
+ m::MultiplyAnyOrder(&binary,
+ CublasLtMatmulF8(&existing_gemm),
+ m::Broadcast(m::Op(&d_scale)))),
+ m::Broadcast(m::ConstantScalar(&clamp_upper)))
+ .WithOneUser()))) {
+ return F8ConvertD(
+ instr, existing_gemm, d_scale, clamp_lower, clamp_upper,
+ /*mult_scale=*/(binary && binary->opcode() == HloOpcode::kMultiply));
+ }
+ return absl::OkStatus();
+ }
+
+ static bool IsCuda(const se::GpuComputeCapability &gpu_version) {
+ return std::holds_alternative<se::CudaComputeCapability>(gpu_version);
+ }
+
+ static absl::StatusOr<se::CudaComputeCapability> GetCudaComputeCapability(
+ const se::GpuComputeCapability &gpu_version) {
+ auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
+ if (cuda_cc == nullptr) {
+ return absl::InvalidArgumentError("Compute Capability is not CUDA.");
+ }
+ return *cuda_cc;
+ }
+
+ static bool IsRocm(const se::GpuComputeCapability &gpu_version) {
+ return std::holds_alternative<se::RocmComputeCapability>(gpu_version);
+ }
+
+ static absl::StatusOr<se::RocmComputeCapability> GetRocmComputeCapability(
+ const se::GpuComputeCapability &gpu_version) {
+ auto rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);
+ if (rocm_cc == nullptr) {
+ return absl::InvalidArgumentError("Compute Capability is not ROCm.");
+ }
+ return *rocm_cc;
+ }
+
+ absl::StatusOr<bool> CreateF8CustomCall(HloInstruction *instr,
+ GpuBackendConfig &gpu_backend_config,
+ MatchedFp8Param a,
+ MatchedFp8Param b) {
+ GemmBackendConfig &gemm_backend_config =
+ *gpu_backend_config.mutable_gemm_backend_config();
+ if (IsCuda(gpu_version_)) {
+ TF_ASSIGN_OR_RETURN(auto cuda_compute_capability,
+ GetCudaComputeCapability(gpu_version_));
+ // FP8 GEMM kernels are only available on Ada, Hopper, and later
+ // architectures.
+ if (!cuda_compute_capability.IsAtLeast(8, 9)) {
+ VLOG(1) << "FP8 Custom Calls require Ada, Hopper, or later "
+ "architectures. Got: "
+ << cuda_compute_capability.ToString()
+ << " and toolkit version: " << toolkit_version_;
+ return false;
+ }
+ // FP8 GEMM kernels are only available with CUDA 12.0 and above
+ if (toolkit_version_ < 12000) {
+ VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer.";
+ return false;
+ }
+ }
+
+ if (IsRocm(gpu_version_)) {
+ TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
+ GetRocmComputeCapability(gpu_version_));
+ if (!rocm_compute_capability.has_fp8_support()) {
+        VLOG(1) << "FP8 Custom Calls require MI300 or later architectures.";
+ return false;
+ }
+ if (toolkit_version_ < 60000) {
+ // FP8 GEMM kernels are only available with ROCm 6.0 and above
+ VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer.";
+ return false;
+ }
+ }
+
+ PrimitiveType a_type = a.fp8_input->shape().element_type();
+ PrimitiveType b_type = b.fp8_input->shape().element_type();
+
+ // cuBLASLt FP8 GEMM kernels require one of the two operands to be in
+ // F8E4M3FN format.
+ if (IsCuda(gpu_version_)) {
+ if (a_type == F8E5M2 && b_type == F8E5M2) {
+ VLOG(1)
+ << "Failed to rewrite " << instr->ToShortString()
+ << " into FP8 Custom Call. The element type of one of the operands "
+ "must be F8E4M3FN.";
+ return false;
+ }
+ if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
+ (b_type != F8E5M2 && b_type != F8E4M3FN)) {
+ VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+ << " into FP8 Custom Call. The input types must be F8E5M2 or "
+ "F8E4M3FN, but got "
+ << PrimitiveType_Name(a_type) << " and "
+ << PrimitiveType_Name(b_type);
+ return false;
+ }
+ }
+
+ if (IsRocm(gpu_version_)) {
+ if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
+ VLOG(1)
+ << "Failed to rewrite " << instr->ToShortString()
+ << " into FP8 Custom Call. The element type of one of the operands "
+ "must be F8E4M3FNUZ.";
+ return false;
+ }
+ if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
+ (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
+ VLOG(1)
+ << "Failed to rewrite " << instr->ToShortString()
+ << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
+ "F8E4M3FNUZ, but got "
+ << PrimitiveType_Name(a_type) << " and "
+ << PrimitiveType_Name(b_type);
+ return false;
+ }
+ }
+
+ absl::Span<const int64_t> batch_dims =
+ gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions();
+
+ // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
+ // format. Set the factors to one when no scaling factors were captured.
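+    // When a scaling factor was matched as a divisor (mult_scale is false),
+    // its reciprocal is computed below so that the scales passed to the
+    // Custom Call can be applied uniformly as multiplicative factors.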
+ Literal one_literal = LiteralUtil::One(F32);
+ HloInstruction *one = instr->AddInstruction(
+ HloInstruction::CreateConstant(one_literal.Clone()));
+ std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
+ std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
+ scales_f32;
+ for (int i = 0; i < scales.size(); ++i) {
+ if (scales[i]) {
+ if (!ShapeUtil::IsScalar(scales[i]->shape())) {
+ VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+ << " into FP8 Custom Call. The scaling factors must be "
+ "scalars.";
+ return false;
+ }
+ if (!mult_scale[i]) {
+ inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
+ scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
+ }
+ scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
+ if (scales_f32[i]->shape().element_type() != F32) {
+ scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
+ }
+ } else {
+ scales_f32[i] = one;
+ }
+ }
+
+ PrimitiveType d_type = instr->shape().element_type();
+ bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32);
+ if (IsCuda(gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) {
+ supported_d_type = true;
+ }
+ if (IsRocm(gpu_version_) && toolkit_version_ >= 60200 &&
+ (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) {
+ supported_d_type = true;
+ }
+ if (!supported_d_type) {
+ VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+ << " into FP8 Custom Call. Output element type must be "
+ << (IsCuda(gpu_version_) ? "F8E4M3FN, F8E5M2, BF16, F16 or F32. "
+ : toolkit_version_ >= 60200
+ ? "F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. "
+ : "BF16, F16 or F32. ")
+ << "Actual element type is " << PrimitiveType_Name(d_type);
+ return false;
+ }
+
+ // Each operand must have exactly one contracting and one non-contracting
+ // dimension.
+ absl::Span<const int64_t> a_contracting_dims =
+ gemm_backend_config.dot_dimension_numbers()
+ .lhs_contracting_dimensions();
+ absl::Span<const int64_t> b_contracting_dims =
+ gemm_backend_config.dot_dimension_numbers()
+ .rhs_contracting_dimensions();
+ if (a_contracting_dims.size() != 1 || b_contracting_dims.size() != 1) {
+ VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+ << " into FP8 Custom Call. A and B must have one contracting "
+ "dimension.";
+ return false;
+ }
+ if ((a.commutative_ops.empty() ? a.fp8_input
+ : a.commutative_ops.back().first)
+ ->shape()
+ .dimensions_size() -
+ batch_dims.size() !=
+ 2 ||
+ (b.commutative_ops.empty() ? b.fp8_input
+ : b.commutative_ops.back().first)
+ ->shape()
+ .dimensions_size() -
+ batch_dims.size() !=
+ 2) {
+ VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. A and B must have one "
+                 "non-contracting dimension.";
+ return false;
+ }
+
+ // Sequentially apply the collected unary, dynamic-slice, pad and select ops
+ // to the unconverted and unscaled operands.
+ auto shift_ops = [&instr](HloInstruction *&x, InstrPath &x_ops) -> void {
+ for (std::pair<HloInstruction *, int> op : x_ops) {
+ std::vector<HloInstruction *> operands = {x};
+ // Insert the additional operands of dynamic-slice ops.
+ if (op.first->opcode() == HloOpcode::kDynamicSlice) {
+ for (int i = 1; i < op.first->operand_count(); ++i) {
+ operands.emplace_back(op.first->mutable_operand(i));
+ }
+ }
+ // Convert the second operand of pad ops.
+ if (op.first->opcode() == HloOpcode::kPad) {
+ HloInstruction *convert =
+ instr->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(op.first->operand(1)->shape(),
+ x->shape().element_type()),
+ op.first->mutable_operand(1)));
+ operands.emplace_back(convert);
+ }
+ // Convert and insert the additional operands of select ops.
+ if (op.first->opcode() == HloOpcode::kSelect) {
+ // The first operand is the predicate.
+ operands.emplace(operands.begin(), op.first->mutable_operand(0));
+ // Convert the remaining operand.
+ int operand_idx = op.second == 2 ? 1 : 2;
+ HloInstruction *convert =
+ instr->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(
+ op.first->operand(operand_idx)->shape(),
+ x->shape().element_type()),
+ op.first->mutable_operand(operand_idx)));
+ operands.emplace(operands.begin() + operand_idx, convert);
+ }
+ x = instr->AddInstruction(op.first->CloneWithNewOperands(
+ ShapeUtil::MakeShapeWithDenseLayout(
+ x->shape().element_type(), op.first->shape().dimensions(),
+ op.first->shape().layout().minor_to_major()),
+ operands));
+ }
+ return;
+ };
+ shift_ops(a.fp8_input, a.commutative_ops);
+ shift_ops(b.fp8_input, b.commutative_ops);
+
+ TF_ASSIGN_OR_RETURN(GemmConfig gemm_config,
+ GemmConfig::For(instr, gemm_backend_config));
+
+ DotDimensionNumbers *dim_nums =
+ gemm_backend_config.mutable_dot_dimension_numbers();
+ int batch_dim_offset = batch_dims.size();
+
+    // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A,
+    // to be row-major. If A is column-major, swap the contracting and
+    // non-contracting dimensions and transpose the matrix to effectively make
+    // it row-major.
+    // TODO(philipphack): Remove once cuBLASLt supports A being column-major.
+ if (gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor) {
+ CHECK(a_contracting_dims[0] == batch_dim_offset ||
+ a_contracting_dims[0] == batch_dim_offset + 1);
+ if (a_contracting_dims[0] == batch_dim_offset) {
+ dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset + 1);
+ } else {
+ dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset);
+ }
+ a.fp8_input =
+ TransposeMatrix(a.fp8_input, a_contracting_dims[0], batch_dims);
+ }
+
+ // Similarly, cuBLASLt requires the second operand to be column-major, so
+ // make it column-major if it is currently row-major.
+ if (gemm_config.rhs_layout.order == MatrixLayout::Order::kRowMajor) {
+ CHECK(b_contracting_dims[0] == batch_dim_offset ||
+ b_contracting_dims[0] == batch_dim_offset + 1);
+ if (b_contracting_dims[0] == batch_dim_offset) {
+ dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset + 1);
+ } else {
+ dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset);
+ }
+ b.fp8_input =
+ TransposeMatrix(b.fp8_input, b_contracting_dims[0], batch_dims);
+ }
+
+ a.fp8_input = PadOperandToMultipleOf16(batch_dims, a.fp8_input);
+ b.fp8_input = PadOperandToMultipleOf16(batch_dims, b.fp8_input);
+ Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims);
+
+ std::vector<HloInstruction *> operands_list = {
+ a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
+
+ HloInstruction *new_custom_call =
+ instr->AddInstruction(HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeShapeWithDenseLayout(
+ instr->shape().element_type(), new_output_shape.dimensions(),
+ instr->shape().layout().minor_to_major()),
+ operands_list, kCublasLtMatmulF8CallTarget));
+ TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config));
+ TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call));
+
+ // Slice the result of the GEMM if the operands were padded.
+ HloInstruction *slice = nullptr;
+ if (new_output_shape.dimensions() != instr->shape().dimensions()) {
+ std::vector<int64_t> start_indices(instr->shape().rank(), 0);
+ std::vector<int64_t> strides(instr->shape().rank(), 1);
+ slice = instr->AddInstruction(HloInstruction::CreateSlice(
+ instr->shape(), new_custom_call, start_indices,
+ instr->shape().dimensions(), strides));
+ }
+
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(instr, slice ? slice : new_custom_call));
+ VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call.";
+ return true;
+ }
+
+ absl::Status F8ScaleD(HloInstruction *instr, HloInstruction *existing_gemm,
+ HloInstruction *d_scale) {
+ if (!ShapeUtil::IsScalar(d_scale->shape())) {
+ return absl::OkStatus();
+ }
+
+ // When the output of an FP8 GEMM is scaled but not type converted to FP8,
+ // cublasLT requires the scaling factor to be forwarded to the Custom Call
+ // as a_scale (chosen here) or b_scale. The scaling factor is fused here
+ // when no input scaling factors were fused during the creation of the
+ // Custom Call. When the maximum of the absolute value of the output of an
+ // FP8 GEMM is calculated and the output is scaled and type converted to
+ // FP8, the scaling of the output is fused in F8ConvertD.
+ if (!existing_gemm->operand(2)->IsConstant() ||
+ existing_gemm->operand(2)->literal().GetAsDouble({}) != 1.) {
+ return absl::OkStatus();
+ }
+
+    // Applying the output scaling to an input (see previous comment) is only
+    // valid for the default and ReLU epilogues and when no matrix bias has
+    // been fused.
+ TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
+ existing_gemm->backend_config<GpuBackendConfig>());
+ const GemmBackendConfig &config = gpu_backend_config.gemm_backend_config();
+ if ((config.epilogue() != GemmBackendConfig::DEFAULT &&
+ config.epilogue() != GemmBackendConfig::RELU) ||
+ config.beta() != 0.) {
+ return absl::OkStatus();
+ }
+
+ // If necessary, invert the scaling factor of D and convert to F32.
+ TF_ASSIGN_OR_RETURN(
+ d_scale,
+ InvertAndConvertScalar(d_scale, instr->opcode() == HloOpcode::kDivide));
+
+ TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, d_scale));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
+
+ VLOG(1) << "Scaling of FP8 GEMM fused into Custom Call.";
+ return absl::OkStatus();
+ }
+
+ absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm,
+ HloInstruction *d_scale, HloInstruction *clamp_lower,
+ HloInstruction *clamp_upper,
+ bool mult_scale = false) {
+ // Verify the data types and the operands of clamp.
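+    // The clamp must saturate exactly to the representable range of the
+    // target FP8 type; otherwise the convert cannot be folded into the GEMM.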
+ if (instr->shape().element_type() == F8E4M3FN) {
+ if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e4m3fn>::lowest())) ||
+ !clamp_upper->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e4m3fn>::max()))) {
+ return absl::OkStatus();
+ }
+ } else if (instr->shape().element_type() == F8E5M2) {
+ if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e5m2>::lowest())) ||
+ !clamp_upper->literal().IsAllFloat(static_cast<float>(
+ std::numeric_limits<tsl::float8_e5m2>::max()))) {
+ return absl::OkStatus();
+ }
+ } else {
+ return absl::OkStatus();
+ }
+
+ if (d_scale && !ShapeUtil::IsScalar(d_scale->shape())) {
+ return absl::OkStatus();
+ }
+
+ // The possible second user of the GEMM must be the calculation of the
+ // maximum of the absolute value of the result of the GEMM. Since it is
+ // unknown in what form this operation will be used, it is identified in a
+ // top-down approach by inspecting the users of the GEMM.
+ const std::vector<HloInstruction *> gemm_users = existing_gemm->users();
+ HloInstruction *reduce_damax = nullptr;
+ if (gemm_users.size() == 2) {
+ // In the presence of a ReLU activation, the abs instruction is elided
+ // since abs(ReLU(x)) = ReLU(x).
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ existing_gemm->backend_config<GpuBackendConfig>());
+ const GemmBackendConfig &config = gpu_config.gemm_backend_config();
+ for (int i = 0; i < gemm_users.size(); ++i) {
+ HloInstruction *maybe_reduce = nullptr;
+ if (gemm_users[i]->opcode() == HloOpcode::kAbs) {
+ if (gemm_users[i]->users().size() != 1) continue;
+ maybe_reduce = gemm_users[i]->users()[0];
+ } else {
+        // If there is no Abs instruction, a ReLU epilogue is required to
+        // ensure all values are nonnegative.
+ if (config.epilogue() != GemmBackendConfig::BIAS_RELU &&
+ config.epilogue() != GemmBackendConfig::RELU)
+ continue;
+ maybe_reduce = gemm_users[i];
+ }
+
+ if (maybe_reduce->opcode() == HloOpcode::kReduce &&
+ maybe_reduce->operands().size() == 2 &&
+ maybe_reduce->operand(1)->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::IsScalar(maybe_reduce->operand(1)->shape())) {
+ HloInstruction *reduce = maybe_reduce;
+ HloComputation *reduce_comp = reduce->to_apply();
+ HloInstruction *reduce_comp_root = reduce_comp->root_instruction();
+ if (reduce->operand(1)->literal().GetAsDouble({}) <= 0. &&
+ reduce_comp_root->opcode() == HloOpcode::kMaximum &&
+ reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
+ reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter) {
+ reduce_damax = reduce;
+ }
+ }
+ }
+ if (!reduce_damax) {
+ return absl::OkStatus();
+ }
+ } else if (gemm_users.size() > 2) {
+ return absl::OkStatus();
+ }
+
+ TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
+ existing_gemm->backend_config<GpuBackendConfig>());
+ const GemmBackendConfig &gemm_backend_config =
+ gpu_backend_config.gemm_backend_config();
+
+ if (gemm_backend_config.beta() != 0.0) {
+ if (existing_gemm->operand(2)->shape().element_type() != BF16 &&
+ existing_gemm->operand(2)->shape().element_type() != F16) {
+ VLOG(1) << "The scaling and conversion of the result of "
+ << existing_gemm->ToShortString()
+ << " is not fused into the FP8 Custom Call because it "
+ "conflicts with the existing fusion of the addition of a "
+ "matrix bias with element type other than BF16 or F16.";
+ return absl::OkStatus();
+ } else {
+ // Turn off the output to operand aliasing, since the fp8 output and
+ // bf16/fp16 bias have different sizes.
+ xla::Cast<HloCustomCallInstruction>(existing_gemm)
+ ->set_output_to_operand_aliasing({});
+ }
+ }
+
+ // If necessary, invert the scaling factor of D and convert to F32.
+ if (d_scale) {
+ TF_ASSIGN_OR_RETURN(d_scale,
+ InvertAndConvertScalar(d_scale, !mult_scale));
+ TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
+ gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
+ }
+
+ // If present, elide the calculation of the maximum of the absolute values
+ // of the result of the GEMM.
+ if (reduce_damax) {
+ return F8AddDAmax(instr, existing_gemm, reduce_damax);
+ }
+
+ std::unique_ptr<HloInstruction> new_gemm =
+ existing_gemm->CloneWithNewShape(instr->shape());
+
+ TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm)));
+
+ VLOG(1) << "Conversion" << (reduce_damax ? " and amax calculation" : "")
+ << " fused into FP8 GEMM.";
+ return absl::OkStatus();
+ }
+
+ // Adds a scalar DAmax return value to an FP8 GEMM.
+ absl::Status F8AddDAmax(HloInstruction *instr, HloInstruction *existing_gemm,
+ HloInstruction *reduce_damax) {
+ // Change the output shape of the Custom Call to tuple(D, DAmax).
+ Shape damax_shape = ShapeUtil::MakeScalarShape(F32);
+ Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({instr->shape(), damax_shape});
+ HloInstruction *gemm_and_damax =
+ instr->AddInstruction(existing_gemm->CloneWithNewShape(tuple_shape));
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ gemm_and_damax->backend_config<GpuBackendConfig>());
+ GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+ config.set_damax_output(true);
+ TF_RETURN_IF_ERROR(gemm_and_damax->set_backend_config(gpu_config));
+
+ // Obtain D and DAmax separately from the output tuple.
+ HloInstruction *d =
+ instr->AddInstruction(HloInstruction::CreateGetTupleElement(
+ instr->shape(), gemm_and_damax, 0));
+ HloInstruction *damax = instr->AddInstruction(
+ HloInstruction::CreateGetTupleElement(damax_shape, gemm_and_damax, 1));
+
+ // Convert DAmax from FP32 to the requested type and elide reduce.
+ HloInstruction *damax_converted = instr->AddInstruction(
+ HloInstruction::CreateConvert(reduce_damax->shape(), damax));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(instr, d));
+
+ return absl::OkStatus();
+ }
+
+ // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add
+ // instruction in the following form:
+ // Add(OptionalBitcast(OptionalSlice(gemm)), bias)
+  // where 'gemm' is expected to be a cuBLAS custom_call. The slice is
+  // introduced when the inputs of the gemm may have been padded; the bitcast
+  // is introduced to handle high-rank inputs.
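+  // For example (illustrative only), the matched add may look like
+  // Add(Bitcast(Slice(gemm)), bias) when the gemm operands were padded and
+  // the original dot had a rank greater than two.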
+ absl::Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias,
+ const HloInstruction *gemm,
+ HloInstruction *bitcast = nullptr,
+ HloInstruction *slice = nullptr) {
+ TF_RET_CHECK(Shape::Equal().IgnoreElementType()(bias->shape(),
+ bitcast ? bitcast->shape()
+ : slice ? slice->shape()
+ : gemm->shape()));
+
+ // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
+ // supports fixed values for alpha/beta.
+ if (gemm->shape().element_type() == S32) {
+ return absl::OkStatus();
+ }
+
+ // To ensure correctness, only slices that chop off the ends of dimensions
+ // are supported.
+ if (slice) {
+ int slice_op_dim = slice->operand(0)->shape().rank();
+ if (slice->slice_starts() != std::vector<int64_t>(slice_op_dim, 0) ||
+ slice->slice_strides() != std::vector<int64_t>(slice_op_dim, 1)) {
+ return absl::OkStatus();
+ }
+ }
+    // Legacy cuBLAS gemm overwrites the bias matrix, so fusing it in place is
+    // only possible if the bias has no other users. cuBLASLt gemm can operate
+    // out-of-place.
+ bool can_overwrite_bias = [bias]() {
+ if (bias->user_count() > 1) {
+ // There is another user of the data, do not overwrite it.
+ return false;
+ }
+
+ if (bias->opcode() != HloOpcode::kParameter) {
+ // Not a parameter; can overwrite.
+ return true;
+ }
+
+ // The bias is a parameter of the computation; check if it is aliased.
+ if (!bias->parent()->IsEntryComputation()) {
+ // Only the HloModule has input/output aliasing, since this is not the
+ // entry computation, there are no guarantees about aliasing; do not
+ // overwrite.
+ return false;
+ }
+ const auto &in_out_alias_config =
+ bias->GetModule()->input_output_alias_config();
+ // If the parameter is aliased, we can overwrite it.
+ // TODO(victorstone): The assumption when calling ParameterHasAlias is
+ // that bias is not a tuple. This is why we pass {} as the argument for
+ // param_index.
+ return in_out_alias_config.ParameterHasAlias(bias->parameter_number(),
+ /*param_index=*/{});
+ }();
+ bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) ||
+ IsCublasLtMatmul(*gemm) || can_overwrite_bias;
+
+ auto gpu_config = gemm->backend_config<GpuBackendConfig>().value();
+ GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+ // It is possible to fuse into a cublasLt matmul that already has a vector
+ // bias, but no other epilogue will commute with the matrix bias add.
+ bool supported_epilogue =
+ ((config.epilogue() == GemmBackendConfig::DEFAULT) ||
+ (config.epilogue() == GemmBackendConfig::BIAS));
+
+ if ((config.beta() != 0) || !want_to_fuse_bias ||
+ (gemm->user_count() != 1) || !supported_epilogue) {
+ return absl::OkStatus();
+ }
+
+ config.set_beta(1.0);
+
+ std::vector<HloInstruction *> operands(gemm->operands().begin(),
+ gemm->operands().end());
+ HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias);
+ if (bitcast) {
+ maybe_constant_folded_bias =
+ instr->AddInstruction(HloInstruction::CreateBitcast(
+ slice->shape(), maybe_constant_folded_bias));
+ }
+
+ maybe_constant_folded_bias =
+ PadOperandToTargetShape(gemm->shape(), maybe_constant_folded_bias);
+
+ operands.insert(operands.begin() + 2, maybe_constant_folded_bias);
+
+ std::unique_ptr<HloInstruction> fused_op =
+ gemm->CloneWithNewOperands(gemm->shape(), operands);
+    // For mixed-type GEMMs, set the output element type to the bias element
+    // type.
+ fused_op->mutable_shape()->set_element_type(bias->shape().element_type());
+ TF_RETURN_IF_ERROR(fused_op->set_backend_config(gpu_config));
+
+ // Choose whether the bias must alias the output. Legacy cublas GEMMs must
+ // operate in place and alias the bias with the output, whereas with
+ // cublasLt we can choose.
+ //
+ // Operating in place is always safe; copy-insertion will insert copies if
+ // necessary. But (we assume) copying is slower than operating
+ // out-of-place, so for cublasLt (where we have the choice), we try to
+    // operate in place if we think a copy won't be necessary.
+ //
+ // We assume that parameters are always read-only and therefore we'd need to
+ // copy if we were going to operate in place. (This is not quite true; the
+ // param could have input/output aliasing.) We also assume that if there
+ // are other uses of the bias, we might need to copy. (Again, not quite
+ // true if those uses all come before this operation. But copy-insertion
+ // runs before scheduling, so it can't know and has to conservatively insert
+ // copies.)
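+    // The aliasing pair {{}, {2, {}}} below means that the entire output
+    // (empty output index) aliases operand 2, i.e. the bias, at its top-level
+    // shape index.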
+ if (IsLegacyCublasMatmul(*fused_op) || can_overwrite_bias) {
+ xla::Cast<HloCustomCallInstruction>(fused_op.get())
+ ->set_output_to_operand_aliasing({{{}, {2, {}}}});
+ }
+ TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
+ if (slice) {
+ fused_op = slice->CloneWithNewOperands(
+ slice->shape(),
+ {slice->parent()->AddInstruction(std::move(fused_op))});
+ }
+
+ if (bitcast) {
+ fused_op = bitcast->CloneWithNewOperands(
+ bitcast->shape(),
+ {bitcast->parent()->AddInstruction(std::move(fused_op))});
+ }
+
+ return ReplaceWithNewInstruction(instr, std::move(fused_op));
+ }
+
+ // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add
+ // instruction in the following form:
+ // Add(OptionalBitcast(OptionalSlice(gemm)), Broadcast(OptionalConvert()))
+ // where 'gemm' is expected to be a cuBLAS custom_call. The optional
+ // convert is only used for F8 matmuls as cublasLt has specific constraints
+ // on the vector bias type for such matmuls. The optional bitcast is
+  // necessary to handle high-rank input cases.
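+  // For example, for an FP8 matmul the matched pattern may look like
+  // Add(gemm, Broadcast(Convert(bias_bf16))), in which case the convert's
+  // BF16 input is what gets fused as the bias.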
+ absl::StatusOr<bool> FuseVectorBiasAdd(HloInstruction *instr,
+ HloInstruction *broadcast,
+ HloInstruction *gemm,
+ HloInstruction *slice = nullptr,
+ HloInstruction *convert = nullptr,
+ HloInstruction *bitcast = nullptr) {
+ if (!bitcast) {
+ TF_RET_CHECK(ShapeUtil::Compatible(
+ broadcast->shape(), (slice ? slice->shape() : gemm->shape())));
+ }
+ // Verify that the data type is supported by Epilogue Fusion.
+ if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
+ return false;
+ }
+
+ HloInstruction *bias = broadcast->mutable_operand(0);
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ gemm->backend_config<GpuBackendConfig>());
+ GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+ // # output column dims == # non-contracting rhs operand dims.
+ const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
+ size_t num_col_dims = gemm->operand(1)->shape().rank() -
+ dot_dims.rhs_batch_dimensions_size() -
+ dot_dims.rhs_contracting_dimensions_size();
+
+ if ((gemm->user_count() != 1) ||
+ (config.epilogue() != GemmBackendConfig::DEFAULT) ||
+ (bias->shape().rank() != num_col_dims)) {
+ return false;
+ }
+    // We require the bias vector to have been broadcast in the most major
+    // dimensions; i.e. its most minor physical dimensions align with the most
+    // minor physical dimensions of the gemm output.
+ absl::Span<const int64_t> broadcast_dims = broadcast->dimensions();
+ for (size_t i = 0; i < num_col_dims; ++i) {
+ int64_t dim =
+ (bitcast ? bitcast : gemm)->shape().layout().minor_to_major(i);
+
+ // Find the corresponding dimension from the bias vector.
+ auto it = absl::c_find(broadcast_dims, dim);
+
+ if (it == broadcast_dims.end()) {
+ return false;
+ }
+
+ int64_t vector_dim = it - broadcast_dims.begin();
+ if (bias->shape().layout().minor_to_major(i) != vector_dim) {
+ return false;
+ }
+ }
+
+ std::vector<HloInstruction *> operands(gemm->operands().begin(),
+ gemm->operands().end());
+    // When a (non-trivial) matrix bias and a vector bias coexist for an FP8
+    // matmul, only the matrix bias is fused.
+ if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
+ config.beta() != 0.0) {
+ return true;
+ }
+
+ if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
+ bias->shape().element_type() == F32) {
+ if (convert == nullptr) {
+ return false;
+ }
+
+ HloInstruction *bias_f16_or_bf16 = convert->mutable_operand(0);
+ auto compatible_bias_type = [](const PrimitiveType bias_type,
+ const PrimitiveType output_type) {
+ if (bias_type == BF16) {
+ return output_type == F8E4M3FN || output_type == F8E5M2 ||
+ output_type == F32 || output_type == BF16;
+ } else if (bias_type == F16) {
+ return output_type == F16 || output_type == F8E4M3FN ||
+ output_type == F8E5M2;
+ }
+ return false;
+ };
+
+ // cuBLAS LT does not support FP32 biases on matmuls with FP8 inputs,
+ // even if the matmul output is FP32. We do not unconditionally convert
+ // the bias to a supported precision (F16 or BF16) because this lowers
+ // precision. Instead, we only fuse the bias if the bias itself is a
+ // convert from F16 or BF16, fusing the input of the convert instruction
+ // to the matmul.
+ if (compatible_bias_type(bias_f16_or_bf16->shape().element_type(),
+ gemm->shape().element_type())) {
+ bias = bias_f16_or_bf16;
+ } else {
+ VLOG(1) << "Epilogue fusion of FP32 vector bias into FP8 GEMM is "
+ "currently not supported. See the cublasLT support matrix.";
+ return false;
+ }
+ }
+
+ // In the case of high rank input for FP8, it is necessary to consider
+ // potential padding for the bias.
+ if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && bitcast) {
+ bias = PadOperandToMultipleOf16(
+ config.dot_dimension_numbers().rhs_batch_dimensions(), bias);
+ }
+ // Replace add(gemm, broadcast) with fused new_gemm.
+ operands.push_back(bias);
+ config.set_epilogue(GemmBackendConfig::BIAS);
+ std::unique_ptr<HloInstruction> result =
+ gemm->CloneWithNewOperands(gemm->shape(), operands);
+ TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
+ if (slice) {
+ result = slice->CloneWithNewOperands(
+ slice->shape(), {slice->parent()->AddInstruction(std::move(result))});
+ }
+
+ if (bitcast) {
+ result = bitcast->CloneWithNewOperands(
+ bitcast->shape(),
+ {bitcast->parent()->AddInstruction(std::move(result))});
+ }
+ TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(result)));
+ return true;
+ }
+
+ absl::Status FuseReluActivation(HloInstruction *instr,
+ HloInstruction *broadcast,
+ HloInstruction *gemm,
+ HloInstruction *slice_or_bitcast = nullptr) {
+ TF_RET_CHECK(ShapeUtil::Compatible(
+ broadcast->shape(),
+ (slice_or_bitcast ? slice_or_bitcast->shape() : gemm->shape())));
+
+ if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
+ return absl::OkStatus();
+ }
+
+ if (gemm->user_count() != 1) {
+ return absl::OkStatus();
+ }
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ gemm->backend_config<GpuBackendConfig>());
+ GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+ if (config.epilogue() == GemmBackendConfig::DEFAULT) {
+ config.set_epilogue(GemmBackendConfig::RELU);
+ } else if (config.epilogue() == GemmBackendConfig::BIAS) {
+ config.set_epilogue(GemmBackendConfig::BIAS_RELU);
+ } else {
+ return absl::OkStatus();
+ }
+
+ std::unique_ptr<HloInstruction> result = gemm->Clone();
+ TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
+
+ if (slice_or_bitcast) {
+ result = slice_or_bitcast->CloneWithNewOperands(
+ slice_or_bitcast->shape(),
+ {slice_or_bitcast->parent()->AddInstruction(std::move(result))});
+ }
+
+ return ReplaceWithNewInstruction(instr, std::move(result));
+ }
+
+ absl::Status FuseGeluActivation(HloInstruction *multiply,
+ HloInstruction *gemm,
+ HloInstruction *slice_or_bitcast = nullptr) {
+ if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
+ return absl::OkStatus();
+ }
+ // For CUDA versions less than 12.3.2, cuBLAS LT returns
+ // CUBLAS_STATUS_NOT_SUPPORTED in some cases when fusing gelu into an FP8
+ // matmul. We cannot check the patch version, so disable this fusion with
+ // CUDA versions less than 12.4.
+ if (IsCuda(gpu_version_) && toolkit_version_ < 12040 &&
+ IsCublasLtMatmulF8(*gemm)) {
+ return absl::OkStatus();
+ }
+
+    // The GELU calculation itself has four users of the gemm output; any
+    // additional user means the raw gemm output must also be returned as an
+    // auxiliary output.
+ bool has_aux = gemm->user_count() > 4;
+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ gemm->backend_config<GpuBackendConfig>());
+ GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+
+ if (config.epilogue() == GemmBackendConfig::DEFAULT) {
+ config.set_epilogue(has_aux ? GemmBackendConfig::GELU_AUX
+ : GemmBackendConfig::GELU);
+ } else if (config.epilogue() == GemmBackendConfig::BIAS) {
+ config.set_epilogue(has_aux ? GemmBackendConfig::BIAS_GELU_AUX
+ : GemmBackendConfig::BIAS_GELU);
+ } else {
+ return absl::OkStatus();
+ }
+
+ std::unique_ptr<HloInstruction> output = gemm->CloneWithNewShape(
+ has_aux ? ShapeUtil::MakeTupleShape({gemm->shape(), gemm->shape()})
+ : gemm->shape());
+ TF_RETURN_IF_ERROR(output->set_backend_config(gpu_config));
+ TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get()));
+
+ if (slice_or_bitcast) {
+ output = slice_or_bitcast->CloneWithNewOperands(
+ slice_or_bitcast->shape(),
+ {gemm->parent()->AddInstruction(std::move(output))});
+ }
+
+ if (has_aux) {
+ HloInstruction *tuple_output =
+ gemm->parent()->AddInstruction(std::move(output));
+ TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
+ gemm, HloInstruction::CreateGetTupleElement(tuple_output, 1)));
+ output = HloInstruction::CreateGetTupleElement(tuple_output, 0);
+ }
+
+ return ReplaceWithNewInstruction(multiply, std::move(output));
+ }
+
+ private:
+ se::GpuComputeCapability gpu_version_;
+ int32_t toolkit_version_;
+ GemmRewriterOptions options_;
+
+ // Choose cublas or cublasLt for the target of the custom call that instr will
+ // be rewritten into.
+ absl::StatusOr<absl::string_view> GetNonFp8GemmCustomCallTarget(
+ const HloInstruction &instr,
+ const GemmBackendConfig &gemm_backend_config) const {
+ if (!instr.GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_enable_cublaslt()) {
+ // cublasLt is not enabled.
+ return absl::string_view(kGemmCallTarget);
+ }
+
+ // cublasLt is enabled, check if other internal conditions are met.
+ const HloInstruction *lhs = instr.operand(0);
+ const HloInstruction *rhs = instr.operand(1);
+ if (lhs->shape().element_type() == S8 ||
+ rhs->shape().element_type() == S8) {
+ // TODO(b/241446501) The XLA usage of cublasLt does not yet handle
+ // int8 matmuls. Fallback to legacy cublas.
+ return absl::string_view(kGemmCallTarget);
+ }
+
+ // All internal conditions are met, check if we meet the requirements of
+ // cublasLt.
+ TF_ASSIGN_OR_RETURN(bool gemm_is_supported_by_cublas_lt,
+ GemmIsSupportedByCublasLt(instr, gemm_backend_config));
+ if (gemm_is_supported_by_cublas_lt) {
+ return absl::string_view(kCublasLtMatmulCallTarget);
+ }
+
+ // This case is not supported by cublasLt, fallback to legacy cublas.
+ return absl::string_view(kGemmCallTarget);
+ }
+
+ absl::StatusOr<bool> TypesAreSupportedByLegacyCublas(
+ const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config,
+ const HloInstruction *bias = nullptr) const {
+ // Figure out the Atype/Btype.
+ const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
+ const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
+ const PrimitiveType output_type =
+ bias ? bias->shape().element_type() : instr.shape().element_type();
+    const std::array<PrimitiveType, 8> supported_type = {
+        PrimitiveType::S8,  PrimitiveType::F16, PrimitiveType::BF16,
+        PrimitiveType::F32, PrimitiveType::S32, PrimitiveType::F64,
+        PrimitiveType::C64, PrimitiveType::C128};
+    // Legacy cuBLAS has a defined set of type combinations that it supports.
+    // Figure out the computeType and scaleType.
+ if (!absl::c_linear_search(supported_type, output_type)) return false;
+ TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
+ se::gpu::AsBlasDataType(output_type));
+    // TODO(tdanyluk): Investigate why we don't use the actual precision (and
+    // algorithm) here. Why do we use the default?
+ TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
+ se::gpu::GetBlasComputationType(
+ PrecisionConfig::ALG_UNSET, a_dtype, output_type,
+ stream_executor::blas::kDefaultComputePrecision));
+ se::blas::DataType scale_type =
+ se::gpu::GetScaleType(output_dtype, compute_type);
+
+ using se::blas::ComputationType;
+ using se::blas::DataType;
+ // This matrix of supported types is taken directly from cublas
+ // documentation.
+ // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
+ const std::array<
+ std::tuple<ComputationType, DataType /*scale_type*/,
+ PrimitiveType /*a_dtype*/, PrimitiveType /*b_dtype*/,
+ DataType /*output_dtype*/>,
+        17>
+ supported_type_combinations = {{
+ {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
+ PrimitiveType::F16, DataType::kHalf},
+
+ {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
+ PrimitiveType::S8, DataType::kInt32},
+
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+ PrimitiveType::BF16, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+ PrimitiveType::F16, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
+ PrimitiveType::S8, DataType::kFloat},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+ PrimitiveType::BF16, DataType::kFloat},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+ PrimitiveType::F16, DataType::kFloat},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+
+ // There would be an entry here for A/BType complex int8, but we do
+ // not support that type.
+ {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
+ PrimitiveType::C64, DataType::kComplexFloat},
+
+ {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+ {ComputationType::kF16AsF32, DataType::kComplexFloat,
+ PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+ {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+ {ComputationType::kBF16AsF32, DataType::kComplexFloat,
+ PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+ {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+ {ComputationType::kTF32AsF32, DataType::kComplexFloat,
+ PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+ {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
+ PrimitiveType::F64, DataType::kDouble},
+ {ComputationType::kF64, DataType::kComplexDouble,
+ PrimitiveType::C128, PrimitiveType::C128,
+ DataType::kComplexDouble},
+ }};
+
+ return absl::c_linear_search(
+ supported_type_combinations,
+ std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
+ output_dtype));
+ }
+
+ absl::StatusOr<bool> TypesAreSupportedByCublasLt(
+ const HloInstruction &instr, const GemmBackendConfig &backend_config,
+ const HloInstruction *bias = nullptr) const {
+ // Figure out the Atype/Btype.
+ const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
+ const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
+ const PrimitiveType output_type =
+ bias ? bias->shape().element_type() : instr.shape().element_type();
+ const std::array<PrimitiveType, 12> supported_type = {
+ PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN,
+ PrimitiveType::S8, PrimitiveType::F16,
+ PrimitiveType::BF16, PrimitiveType::F32,
+ PrimitiveType::S32, PrimitiveType::F64,
+ PrimitiveType::C64, PrimitiveType::C128};
+ if (!absl::c_linear_search(supported_type, output_type)) return false;
+ // cublasLt has a defined set of combinations of types that it supports.
+ // Figure out the computeType and scaleType.
+ TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
+ se::gpu::AsBlasDataType(output_type));
+ const int max_precision = *absl::c_max_element(
+ backend_config.precision_config().operand_precision());
+ const PrecisionConfig::Algorithm algorithm =
+ backend_config.precision_config().algorithm();
+ if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm)) return false;
+
+ TF_ASSIGN_OR_RETURN(
+ const se::blas::ComputationType compute_type,
+ se::gpu::GetBlasComputationType(
+ algorithm, a_dtype, instr.shape().element_type(), max_precision));
+ se::blas::DataType scale_type =
+ se::gpu::GetScaleType(output_dtype, compute_type);
+
+ using se::blas::ComputationType;
+ using se::blas::DataType;
+ using TypeCombinations = std::initializer_list<std::tuple<
+ ComputationType, DataType /*scale_type*/, PrimitiveType /*a_dtype*/,
+ PrimitiveType /*b_dtype*/, DataType /*output_dtype*/>>;
+ // This matrix of supported types is taken directly from cublasLt
+ // documentation.
+ // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
+ const TypeCombinations supported_cublas_type_combinations = {
+ // FP8 types:
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E4M3FN, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E4M3FN, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E4M3FN, DataType::kFloat},
+
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E5M2, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E5M2, DataType::kF8E4M3FN},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E5M2, DataType::kF8E5M2},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E5M2, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+ PrimitiveType::F8E5M2, DataType::kFloat},
+
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+ PrimitiveType::F8E4M3FN, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+ PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+ PrimitiveType::F8E4M3FN, DataType::kF8E5M2},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+ PrimitiveType::F8E4M3FN, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+ PrimitiveType::F8E4M3FN, DataType::kFloat},
+ // There would be an entry here for A/BType complex int8, but we do
+ // not support that type.
+ {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
+ PrimitiveType::C64, DataType::kComplexFloat},
+
+ {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+ {ComputationType::kF16AsF32, DataType::kComplexFloat,
+ PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+        // The next four combinations may be supported by hipBLASLt, but they
+        // are not covered by any unit tests.
+ {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+ {ComputationType::kBF16AsF32, DataType::kComplexFloat,
+ PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+ {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+ {ComputationType::kTF32AsF32, DataType::kComplexFloat,
+ PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+ {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
+ PrimitiveType::F64, DataType::kDouble},
+ {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128,
+ PrimitiveType::C128, DataType::kComplexDouble},
+ };
+ if (IsCuda(gpu_version_) &&
+ absl::c_linear_search(supported_cublas_type_combinations,
+ std::tuple{compute_type, scale_type, a_dtype,
+ b_dtype, output_dtype})) {
+ return true;
+ }
+ const TypeCombinations supported_hipblas_type_combinations = {
+ // FP8 types:
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+ PrimitiveType::F8E5M2FNUZ, DataType::kFloat},
+
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+ PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+ };
+ if (IsRocm(gpu_version_) &&
+ absl::c_linear_search(supported_hipblas_type_combinations,
+ std::tuple{compute_type, scale_type, a_dtype,
+ b_dtype, output_dtype})) {
+ return true;
+ }
+ const TypeCombinations supported_type_combinations = {
+ // Other data types:
+ {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
+ PrimitiveType::F16, DataType::kHalf},
+
+ {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
+ PrimitiveType::S8, DataType::kInt32},
+ {ComputationType::kI32, DataType::kFloat, PrimitiveType::S8,
+ PrimitiveType::S8, DataType::kInt8},
+
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+ PrimitiveType::BF16, DataType::kBF16},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+ PrimitiveType::F16, DataType::kHalf},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
+ PrimitiveType::S8, DataType::kFloat},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+ PrimitiveType::BF16, DataType::kFloat},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+ PrimitiveType::F16, DataType::kFloat},
+ {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
+ PrimitiveType::F32, DataType::kFloat},
+ };
+
+ return absl::c_linear_search(
+ supported_type_combinations,
+ std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
+ output_dtype));
+ }
+
+ absl::StatusOr<bool> GemmIsSupportedByCublasLt(
+ const HloInstruction &instr,
+ const GemmBackendConfig &gemm_backend_config) const {
+ const HloInstruction *lhs = instr.operand(0);
+ const Shape &output_shape = instr.shape();
+
+ TF_ASSIGN_OR_RETURN(
+ bool types_are_supported_by_cublas_lt,
+ TypesAreSupportedByCublasLt(instr, gemm_backend_config));
+ if (!types_are_supported_by_cublas_lt) {
+ return false;
+ }
+
+ // The cublasLt API has two currently known limitations:
+ // 1. Batch count must be <2^16.
+ constexpr int64_t kMaxBatchCount = 65535;
+ // We get the batch dimension size from lhs here, but we could just as well
+ // use rhs; they are guaranteed to be the same (TODO:Verify).
+ const auto &batch_dimensions =
+ gemm_backend_config.dot_dimension_numbers().lhs_batch_dimensions();
+ int batch_count = (batch_dimensions.empty() ? 0 : 1);
+ // All batch dimensions get flattened into a single batch dimension.
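+    // For example, batch dimensions of sizes 2, 3 and 4 give a flattened
+    // batch count of 24.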
+ for (auto batch_dimension : batch_dimensions) {
+ batch_count *= lhs->shape().dimensions(batch_dimension);
+ }
+ if (batch_count > kMaxBatchCount) {
+ // This is not supported by cublasLt.
+ return false;
+ }
+
+ if (auto isrocm = std::get_if<se::RocmComputeCapability>(&gpu_version_);
+ isrocm) {
+ if (!isrocm->has_hipblaslt()) {
+ return false;
+ }
+ }
+
+ // 2. cublasLt does not support rhs col dimension size > 4194240 for
+ // C64.
+ constexpr int kMaxDimensionSize{4194240};
+ if (output_shape.element_type() != C64) {
+      // The output type is not C64, so the restriction below does not apply.
+ return true;
+ }
+
+ if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+ if (std::get<se::CudaComputeCapability>(gpu_version_).IsAtLeastAmpere()) {
+        // cuBLASLt has an implementation for complex data with compute type
+        // 32F_FAST_32TF that uses tensor cores and is free from this
+        // restriction. That implementation is only available on Ampere and
+        // newer architectures, though (TF32 was introduced with Ampere).
+ return true;
+ }
+ }
+
+ TF_ASSIGN_OR_RETURN(GemmConfig gemm_config,
+ GemmConfig::For(&instr, gemm_backend_config));
+
+ // Check that the size of the non-contracting dimension is not too large.
+ return gemm_config.rhs_layout.num_cols <= kMaxDimensionSize;
+ }
+
+  // Turns an F8 dot with an unsupported output type into an F8 dot with F32
+  // output, then converts the F32 output back to the original output type.
+ absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
+ HloInstruction *instr) {
+ Shape output_f32_shape = instr->shape();
+ output_f32_shape.set_element_type(F32);
+ HloInstruction *f32_dot =
+ instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape));
+ HloInstruction *convert = instr->AddInstruction(
+ HloInstruction::CreateConvert(instr->shape(), f32_dot));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert));
+ return f32_dot;
+ }
+
+ // Turns an F8 dot into an F16 dot, converting operands to F16 and
+ // converting the output back to F8.
+ absl::StatusOr<HloInstruction *> TurnF8DotIntoF16Dot(HloInstruction *instr) {
+ DCHECK(IsF8Type(instr->operand(0)));
+ DCHECK(IsF8Type(instr->operand(1)));
+
+ // Convert operands to F16
+ for (int i = 0; i < 2; ++i) {
+ Shape operand_f16_shape = instr->operand(i)->shape();
+ operand_f16_shape.set_element_type(F16);
+ HloInstruction *convert =
+ instr->AddInstruction(HloInstruction::CreateConvert(
+ operand_f16_shape, instr->mutable_operand(i)));
+ TF_RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert));
+ }
+
+ // If output is F8, change output to F16 and then convert it back to F8
+ if (IsF8Type(instr)) {
+ Shape output_f16_shape = instr->shape();
+ output_f16_shape.set_element_type(F16);
+ HloInstruction *f16_dot =
+ instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape));
+ HloInstruction *convert_to_f8 = instr->AddInstruction(
+ HloInstruction::CreateConvert(instr->shape(), f16_dot));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8));
+ return f16_dot;
+ } else {
+ return instr;
+ }
+ }
+};
+
+// Rewriter that appends a workspace buffer to cuBLAS and cuBLASLt custom
+// calls. We run it separately after the gemm rewriter, so that we can do
+// pattern matching without having to match output tuples.
+class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit GemmWorkspaceRewriteVisitor(
+ const se::GpuComputeCapability &gpu_version)
+ : gpu_version_(gpu_version) {}
+
+ absl::Status HandleCustomCall(HloInstruction *instr) override {
+ bool has_aux_output = false;
+ if (instr->custom_call_target() == kCublasLtMatmulCallTarget ||
+ instr->custom_call_target() == kCublasLtMatmulF8CallTarget) {
+ TF_ASSIGN_OR_RETURN(const auto gpu_config,
+ instr->backend_config<xla::gpu::GpuBackendConfig>());
+ const xla::gpu::GemmBackendConfig &config =
+ gpu_config.gemm_backend_config();
+ xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
+ TF_ASSIGN_OR_RETURN(
+ has_aux_output,
+ xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(epilogue));
+
+ if (!((instr->shape().IsTuple() &&
+ instr->shape().tuple_shapes_size() ==
+ has_aux_output + config.damax_output() + 1) ||
+ instr->shape().IsArray())) {
+ return absl::OkStatus();
+ }
+ } else if (instr->custom_call_target() != kGemmCallTarget ||
+ !instr->shape().IsArray()) {
+ return absl::OkStatus();
+ }
+
+ auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version_);
+
+    // Pass a user-managed workspace to legacy cuBLAS operations, as otherwise
+    // cuBLAS would use its own internal pool, which would compete with the
+    // XLA allocator for device memory.
+ int64_t workspace = cuda_cc == nullptr ? GemmConfig::kDefaultWorkspace
+ : cuda_cc->IsAtLeastHopper()
+ ? GemmConfig::kHopperWorkspace
+ : GemmConfig::kDefaultWorkspace;
+
+    // We do not know the workspace size required by cuBLAS, but we can guess
+    // that in the worst case cuBLAS will transpose all operands into a tiled
+    // layout optimal for the tensor cores. It doesn't make sense to allocate a
+    // larger workspace.
+    //
+    // TODO(ezhulenev): This is not based on any measurement, just common
+    // sense; we should tweak it to find the minimal workspace size.
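+    // As a rough illustration, a legacy gemm with two f32[1024,1024] operands
+    // has operands_byte_size = 2 * 4 MiB = 8 MiB, so its workspace becomes
+    // min(8 MiB, the size chosen above).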
+ if (instr->custom_call_target() == kGemmCallTarget) {
+ int64_t operands_byte_size = 0;
+ for (auto &operand : instr->operands()) {
+ operands_byte_size += ShapeUtil::ByteSizeOf(operand->shape());
+ }
+ workspace = std::min(workspace, operands_byte_size);
+ }
+
+ // Append workspace buffer to instruction outputs.
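+    // The workspace becomes a trailing s8[workspace] array in the result
+    // tuple; existing users keep reading their original outputs through the
+    // get-tuple-elements created below.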
+ std::vector<Shape> output_shapes = instr->shape().IsArray()
+ ? std::vector<Shape>{instr->shape()}
+ : instr->shape().tuple_shapes();
+ output_shapes.emplace_back(ShapeUtil::MakeShape(S8, {workspace}));
+ Shape output_shape = ShapeUtil::MakeTupleShape(output_shapes);
+
+ // Clone custom call with a new shape.
+ HloInstruction *new_call = instr->AddInstruction(
+ instr->CloneWithNewOperands(output_shape, instr->operands()));
+
+ // Update operand aliasing if it was a fused gemm with aliased output.
+ auto *custom_call = xla::Cast<HloCustomCallInstruction>(new_call);
+ if (!custom_call->output_to_operand_aliasing().empty()) {
+ custom_call->set_output_to_operand_aliasing({{{0}, {2, {}}}});
+ }
+
+ if (instr->shape().IsTuple()) {
+ for (auto user : instr->users()) {
+ auto user_get_tuple =
+ dynamic_cast<HloGetTupleElementInstruction *>(user);
+ TF_RET_CHECK(user_get_tuple);
+ HloInstruction *get_output =
+ instr->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_call, user_get_tuple->tuple_index()));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(user_get_tuple, get_output));
+ }
+ return absl::OkStatus();
+ } else {
+ HloInstruction *get_output = instr->AddInstruction(
+ HloInstruction::CreateGetTupleElement(new_call, 0));
+ return ReplaceInstruction(instr, get_output);
+ }
+ }
+
+ private:
+ se::GpuComputeCapability gpu_version_;
+};
+
+absl::StatusOr<bool> RunOnComputation(HloComputation *computation,
+ se::GpuComputeCapability gpu_version,
+ int32_t toolkit_version,
+ GemmRewriterOptions options) {
+ GemmRewriterVisitor visitor(gpu_version, toolkit_version, options);
+ TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+ GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version);
+ TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor));
+ return visitor.changed();
+}
+
+} // anonymous namespace
+
+GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version,
+ int32_t toolkit_version, GemmRewriterOptions options)
+ : gpu_version_(gpu_version),
+ toolkit_version_(toolkit_version),
+ options_(options) {}
+
+absl::StatusOr<bool> GemmRewriter::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ bool changed = false;
+ for (HloComputation *computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result,
+ RunOnComputation(computation, gpu_version_,
+ toolkit_version_, options_));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h
new file mode 100644
index 0000000..cce09c4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h
@@ -0,0 +1,98 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// cuBLAS GEMM in the most general form can run the following operation:
+//
+// (kAdd
+// (kMultiply (kDot A B) alpha)
+// (kMultiply C beta))
+//
+// where A, B, C are matrices or vectors and `alpha` and `beta` are host
+// constants. In matrix-vector multiplication, one operand must be a matrix and
+// the other must be a vector. The additional requirement is that C has no other
+// users (otherwise, it does not make sense to fuse it inside the custom call).
+//
+// Both multiplication and addition can be avoided (equivalent to setting
+// `alpha` to one and `beta` to zero).
+//
+// This pass pattern-matches the most general form of this computation
+// (we assume transposes are already folded) and rewrites it into a custom call
+// whose operands are (A, B, C), with `alpha` and `beta` stored in the backend
+// config.
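+//
+// For illustration, a simplified before/after sketch (the shapes, workspace
+// size, and config fields below are examples only; gemm_rewriter_test.cc has
+// the authoritative FileCheck patterns):
+//
+//   // Before:
+//   dot = f32[2,4] dot(f32[2,3] x, f32[3,4] y),
+//         lhs_contracting_dims={1}, rhs_contracting_dims={0}
+//
+//   // After (legacy cuBLAS; cuBLASLt uses "__cublas$lt$matmul" instead):
+//   gemm = (f32[2,4], s8[N]) custom-call(x, y),
+//          custom_call_target="__cublas$gemm",
+//          backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,...}
+//   out  = f32[2,4] get-tuple-element(gemm), index=0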
+
+struct GemmRewriterOptions {
+ // The DType of the GEMM to rewrite.
+ enum class DType { kFp8Only, kNonFp8Only };
+ DType dtype = DType::kNonFp8Only;
+
+ // Disabling bias prevents using the `beta * C` term in the GEMM, which can
+ // remove dependencies between multiple matrix multiplications. This, in
+ // turn, can improve the performance of the overall computation by allowing
+ // multiple GEMMs to be scheduled in parallel.
+ //
+ // As an example, consider the following computation: `(A * A) + (B * B)`.
+ // With bias enabled, the `GemmRewriter` will emit the following GEMMs:
+ //
+ // AA := GEMM(A * A)
+ // ROOT := GEMM(B * B + AA)
+ //
+ // Because the second GEMM depends on the first, they cannot be scheduled in
+ // parallel. Instead, with bias disabled, the `GemmRewriter` will emit the
+ // following:
+ //
+ // AA := GEMM(A * A)
+ // BB := GEMM(B * B)
+ // ROOT := AA + BB
+ //
+ // In this case, the two GEMMs can be scheduled in parallel.
+ enum class BiasMode { kBias, kNoBias };
+ BiasMode bias_mode = BiasMode::kBias;
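+
+ // Example (illustrative only): rewrite just the non-FP8 GEMMs and skip bias
+ // fusion so that independent GEMMs can be scheduled in parallel:
+ //
+ //   GemmRewriterOptions options{GemmRewriterOptions::DType::kNonFp8Only,
+ //                               GemmRewriterOptions::BiasMode::kNoBias};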
+};
+
+class GemmRewriter : public HloModulePass {
+ public:
+ GemmRewriter(se::GpuComputeCapability gpu_version, int32_t toolkit_version,
+ GemmRewriterOptions options = {});
+ absl::string_view name() const override { return "cublas-gemm-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::GpuComputeCapability gpu_version_;
+ int32_t toolkit_version_;
+ GemmRewriterOptions options_;
+};
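+
+// A minimal standalone-usage sketch, modeled on gemm_rewriter_test.cc; the
+// compute capability and toolkit version below are placeholders, and real
+// pipelines pass the values for the target device:
+//
+//   GemmRewriter rewriter(
+//       se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
+//       /*toolkit_version=*/12040);
+//   TF_ASSIGN_OR_RETURN(bool changed,
+//                       rewriter.Run(module, /*execution_threads=*/{}));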
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc
new file mode 100644
index 0000000..e4e2aad
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc
@@ -0,0 +1,8423 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+
+#include <array>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/gpu_executable.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/test.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#endif
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace m = ::xla::match;
+
+class GemmRewriteTest : public GpuCodegenTest {
+ const auto& device_desc() {
+ return backend().default_stream_executor()->GetDeviceDescription();
+ }
+
+ protected:
+ const se::GpuComputeCapability& Capability() {
+ return device_desc().gpu_compute_capability();
+ }
+
+ int32_t GetToolkitVersion() const {
+#if GOOGLE_CUDA
+ return CUDA_VERSION;
+#elif TENSORFLOW_USE_ROCM
+ return TF_ROCM_VERSION;
+#endif
+ return 0;
+ }
+
+ bool IsCuda() {
+ return std::holds_alternative<se::CudaComputeCapability>(Capability());
+ }
+
+ se::GpuComputeCapability CudaHopperOrRocmMI300() {
+ if (IsCuda()) {
+ return se::CudaComputeCapability::Hopper();
+ } else {
+ return se::RocmComputeCapability{"gfx942"};
+ }
+ }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+ // These tests exercise the cuBLAS rewriter, so make sure that cuBLAS is
+ // actually used for them.
+ debug_options.set_xla_gpu_enable_triton_gemm(false);
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+ return debug_options;
+ }
+
+ bool SkipGpuBlasLtTest() {
+ return !IsCuda() &&
+ !std::get<se::RocmComputeCapability>(Capability()).has_hipblaslt() &&
+ GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
+ }
+
+ bool HasFp8Support() {
+ if (IsCuda()) {
+ return std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(8, 9);
+ }
+ return std::get<se::RocmComputeCapability>(Capability()).has_fp8_support();
+ }
+
+ bool HasCudaComputeCapability(const se::CudaComputeCapability& cc) {
+ return IsCuda() &&
+ std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(cc);
+ }
+};
+
+TEST_F(GemmRewriteTest, CheckCustomCallTarget) {
+ if (SkipGpuBlasLtTest()) {
+ GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+ }
+
+ const char* hlo_text = R"(
+HloModule SimpleGemm
+
+ENTRY AddDotsFunc {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ if (debug_options.xla_gpu_enable_cublaslt()) {
+ MatchOptimizedHlo(hlo_text,
+ R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
+ } else {
+ MatchOptimizedHlo(hlo_text,
+ R"(; CHECK: custom_call_target="__cublas$gemm")");
+ }
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+TEST_F(GemmRewriteTest, TestBatchedAutotuning) {
+ if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP()
+ << "There is no autotuning starting with the Nvidia Ampere generation";
+ }
+
+ const char* hlo_text = R"(
+HloModule ComplexDotMultipleNonContracting
+
+ENTRY %test {
+ %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
+ %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
+ ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
+}
+
+)";
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: selected_algorithm
+ )");
+}
+#endif
+
+TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) {
+ if (SkipGpuBlasLtTest()) {
+ GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+ }
+
+ const char* hlo_text = R"(
+HloModule SimpleGemm
+
+ENTRY AddDotsFunc {
+ x = f32[128,128] parameter(0)
+ y = f32[128,128] parameter(1)
+ ROOT dot_a = f32[128,128] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ ErrorSpec error_spec = [&] {
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ if (debug_options.xla_gpu_enable_cublaslt()) {
+ return ErrorSpec{1e-3, 1e-3};
+ } else {
+ return ErrorSpec{1e-3, 1e-3};
+ }
+ }();
+
+ auto get_module = [&]() {
+ HloModuleConfig config;
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_exclude_nondeterministic_ops(true);
+ config.set_debug_options(debug_options);
+ return ParseAndReturnVerifiedModule(hlo_text, config);
+ };
+
+ se::StreamExecutorMemoryAllocator allocator(
+ backend().default_stream_executor());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> optimized_module,
+ backend().compiler()->RunHloPasses(
+ *get_module(), backend().default_stream_executor(), &allocator));
+
+ absl::StatusOr<bool> filecheck_result =
+ RunFileCheck(optimized_module->ToString(),
+ R"(
+; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}"
+ )");
+ TF_ASSERT_OK(filecheck_result.status());
+ EXPECT_TRUE(filecheck_result.value());
+ EXPECT_TRUE(RunAndCompare(*get_module(), error_spec));
+}
+
+TEST_F(GemmRewriteTest, BF16GemmCodeGen) {
+ const char* hlo_text = R"(
+HloModule bf16codegendgemm
+
+ENTRY bf16gemm {
+ %parameter.1 = bf16[3]{0} parameter(0)
+ %parameter.2 = bf16[3]{0} parameter(1)
+ ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+}
+ )";
+
+ if (HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
+ // The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
+ // native BF16 multiply support.
+ MatchOptimizedHlo(hlo_text, R"(
+ ; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
+ ; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
+ ; CHECK: [[INSTR_2:%[^ ]+]] = bf16[3]{0} multiply([[P0]], [[P1]])
+ ; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[INSTR_2]])
+ ; CHECK: [[INSTR_4:%[^ ]+]] = f32[] constant(0)
+ ; CHECK: [[INSTR_5:%[^ ]+]] = f32[] reduce([[INSTR_3]], [[INSTR_4]]), dimensions={0}, to_apply=[[INSTR_6:%[^ ]+]]
+ ; CHECK: ROOT [[INSTR_7:%[^ ]+]] = bf16[] convert([[INSTR_5]])
+ )");
+ } else {
+ MatchOptimizedHlo(hlo_text, R"(
+ ; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
+ ; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
+ ; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
+ ; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
+ ; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
+ ; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
+ ; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
+ ; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
+ )");
+ }
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-4}));
+}
+
+TEST_F(GemmRewriteTest, BF16Transpose) {
+ const char* hlo_text = R"(
+HloModule broadcast
+
+ENTRY broadcast {
+ p = bf16[9] parameter(0)
+ ROOT out = bf16[1,9] broadcast(p), dimensions={1}
+}
+)";
+
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK: bf16[1,9]{1,0} bitcast
+; CHECK: bf16[1,9]{1,0} copy
+)");
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// A test fixture class for tests that should produce similar results with
+// legacy cuBLAS and cuBLASLt.
+class ParameterizedGemmRewriteTest
+ : public GemmRewriteTest,
+ public ::testing::WithParamInterface<bool> {
+ public:
+ ParameterizedGemmRewriteTest() {
+ const bool kUsingCublasLt = GetParam();
+ replacements_[kCustomCallTargetPlaceholder] =
+ kUsingCublasLt ? "__cublas$lt$matmul" : "__cublas$gemm";
+ }
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_cublaslt(GetParam());
+ debug_options.set_xla_gpu_enable_triton_gemm(false);
+ return debug_options;
+ }
+ void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
+ bool print_operand_shape = false) {
+ GemmRewriteTest::MatchOptimizedHlo(
+ hlo, absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
+ }
+ absl::string_view CustomCallTarget() {
+ return replacements_[kCustomCallTargetPlaceholder];
+ }
+
+ protected:
+ void SetUp() override {
+ if (SkipGpuBlasLtTest()) {
+ GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+ }
+ }
+
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
+
+ private:
+ static constexpr const char* kCustomCallTargetPlaceholder{
+ "<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"};
+};
+
+TEST_P(ParameterizedGemmRewriteTest, Simple) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, SimpleRewrite) {
+ const char* hlo_text = R"(
+HloModule SimpleGemm
+
+ENTRY AddDotsFunc {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, MultipleContractingDims) {
+ const char* hlo_text = R"(
+HloModule MultipleContractingCheckGemm
+
+ENTRY AddDotsFunc {
+ x = f32[3,4,2] parameter(0)
+ y = f32[3,4,5] parameter(1)
+ ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, operand_precision={highest,highest}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-NOT: copy
+;
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,4,2], {{.*}}: f32[3,4,5]) -> f32[2,5] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1)
+; CHECK-DAG: [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]])
+; CHECK-DAG: [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]])
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[BITCAST0]], [[BITCAST1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, ArgTransposeFoldCheck) {
+ const char* hlo_text = R"(
+HloModule ArgTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+ x = f32[3,2] parameter(0)
+ y = f32[3,4] parameter(1)
+ x_transposed = f32[2,3] transpose(x), dimensions={1, 0}
+ ROOT dot_a = f32[2,4] dot(x_transposed, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["0"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchedArgRowColTransposeFoldCheck) {
+ const char* hlo_text = R"(
+HloModule BatchedArgRowColTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+ x = f32[5,3,2] parameter(0)
+ y = f32[5,3,4] parameter(1)
+ x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1}
+ ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,3,2], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":["0"]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchRowTransposeFoldCheck) {
+ const char* hlo_text = R"(
+HloModule BatchRowTransposeFoldCheck
+
+ENTRY AddDotsFunc {
+ x = f32[2,5,3] parameter(0)
+ y = f32[5,3,4] parameter(1)
+ x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2}
+ ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,5,3], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["2"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":["1"]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) {
+ const char* hlo_text = R"(
+HloModule BatchFromMinorDimTransposeDoesntFold
+
+ENTRY AddDotsFunc {
+ x = f32[3,2,5] parameter(0)
+ y = f32[5,3,4] parameter(1)
+ x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0}
+ ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,5], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]])
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[FUSION]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["2"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":["0"]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, LargeBatch) {
+ const char* hlo_text = R"(
+HloModule BatchedArgRowColTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+ x = f32[20000,4,3,2] parameter(0)
+ y = f32[20000,4,3,4] parameter(1)
+ ROOT dot_a = f32[20000,4,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+}
+
+)";
+
+ // Batch sizes larger than 2^16-1 are not supported by cublasLt. Ensure that
+ // the custom_call_target is __cublas$gemm.
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[20000,4,3,2], {{.*}}: f32[20000,4,3,4]) -> f32[20000,4,2,4] {
+; CHECK: [[P0:%[^ ]+]] = f32[20000,4,3,2]{3,2,1,0} parameter(0)
+; CHECK: [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]])
+; CHECK: [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1)
+; CHECK: [[BC1:%[^ ]+]] = f32[80000,3,4]{2,1,0} bitcast([[P1]])
+; CHECK: [[GEMM:%[^ ]+]] = (f32[80000,2,4]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[BC0]], [[BC1]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":["0"]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK: }
+; CHECK: [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} get-tuple-element([[GEMM]]), index=0
+; CHECK: ROOT {{[^ ]+}} = f32[20000,4,2,4]{3,2,1,0} bitcast([[OUT]])
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, InstrTransposeFoldCheck) {
+ const char* hlo_text = R"(
+HloModule InstrTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[4,2] {
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P1]], [[P0]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["0"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutTransposed) {
+ const char* hlo_text = R"(
+HloModule BatchedInstrLayoutCheck
+
+ENTRY AddDotsFunc {
+ x = f32[5,2,3] parameter(0)
+ y = f32[5,3,4] parameter(1)
+ dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+ ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,5,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["2"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":["0"]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) {
+ const char* hlo_text = R"(
+HloModule BatchedInstrLayoutBatchNotInMinorDim
+
+ENTRY AddDotsFunc {
+ x = f32[5,2,3] parameter(0)
+ y = f32[5,3,4] parameter(1)
+ dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+ ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,4,5] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["2"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":["0"]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, AlphaSimpleRewrite) {
+ const char* hlo_text = R"(
+HloModule AlphaSimpleRewrite
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ k = f32[] constant(3.0)
+ k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, F64C64_CublasLtSupportTest) {
+ // This test should fail if the GEMM rewriter does not correctly rewrite
+ // F64/C64 dots to cuBLASLt or legacy cuBLAS calls.
+ {
+ const char* hlo_text = R"(
+HloModule F64_rewrite
+
+ENTRY AddDotsFunc {
+ x = f64[2,2] parameter(0)
+ y = f64[2,2] parameter(1)
+ k = f64[] constant(3.0)
+ k_broadcast = f64[2, 2] broadcast(k), dimensions={}
+ dot_a = f64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT dot_a_multiplied = f64[2, 2] multiply(dot_a, k_broadcast)
+}
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+ }
+ {
+ const char* hlo_text = R"(
+HloModule C64_rewrite
+
+ENTRY AddDotsFunc {
+ x = c64[2,2] parameter(0)
+ y = c64[2,2] parameter(1)
+ k = c64[] constant((3.0, 3.0))
+ k_broadcast = c64[2, 2] broadcast(k), dimensions={}
+ dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
+}
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) {
+ if (!IsCuda() && GetDebugOptionsForTest().xla_gpu_enable_cublaslt()) {
+ GTEST_SKIP() << "TODO: Unsupported C64 gpublas-lt datatype on ROCM";
+ }
+ const char* hlo_text = R"(
+HloModule ComplexAlphaSimpleRewrite
+
+ENTRY AddDotsFunc {
+ x = c64[2,2] parameter(0)
+ y = c64[2,2] parameter(1)
+ k = c64[] constant((3.0, 3.0))
+ k_broadcast = c64[2, 2] broadcast(k), dimensions={}
+ dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: c64[2,2], {{.*}}: c64[2,2]) -> c64[2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":3
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, AlphaMultipleUsersNoRewrite) {
+ const char* hlo_text = R"(
+HloModule AlphaMultipleUsersNoRewrite
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ k = f32[] constant(3.0)
+ k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+ ROOT out = f32[2,2] add(dot_a_multiplied, dot_a)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{[^ ]+}} = {{.*}} custom-call({{[^,]+}}, {{[^)]+}}),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, AlphaVectorNoRewrite) {
+ const char* hlo_text = R"(
+HloModule AlphaVectorNoRewrite
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ alpha = f32[2] constant({1, 2})
+ alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1}
+ dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast)
+}
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BF16Gemm) {
+ const char* hlo_text = R"(
+HloModule bf16gemm
+
+ENTRY bf16gemm {
+ %parameter.1 = bf16[12,4]{1,0} parameter(0)
+ %parameter.2 = bf16[4,8]{1,0} parameter(1)
+ ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+ )";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+ )",
+ /*print_operand_shape=*/true);
+ } else {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BF16GemmStrided) {
+ const char* hlo_text = R"(
+HloModule bf16gemm
+
+ENTRY bf16gemm {
+ %parameter.1 = bf16[3,3,4] parameter(0)
+ %parameter.2 = bf16[3,3,2] parameter(1)
+ ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest}
+}
+
+ )";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+ ; CHECK: {{.*}} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+ )",
+ /*print_operand_shape=*/true);
+ } else {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8Gemm) {
+ const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+ %parameter.1 = s8[12,4]{1,0} parameter(0)
+ %parameter.2 = s8[4,8]{1,0} parameter(1)
+ ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+ )";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
+ )",
+ /*print_operand_shape=*/true);
+ } else {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+ )",
+ /*print_operand_shape=*/true);
+ }
+}
+
+TEST_F(GemmRewriteTest, Int8GemmRankGreaterThanTwo) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
+ }
+
+ const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY main.4 {
+ Arg_0.1 = s8[1,8,2]{2,1,0} parameter(0)
+ Arg_1.2 = s8[2,4]{1,0} parameter(1)
+ ROOT dot.3 = s32[1,8,4]{2,1,0} dot(Arg_0.1, Arg_1.2),
+ lhs_contracting_dims={2}, rhs_contracting_dims={0}
+}
+ )";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %{{.*}}, s8[4,4]{0,1} %{{.*}}), custom_call_target="__cublas$gemm",
+ )",
+ /*print_operand_shape=*/true);
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoAlphaRewrite) {
+ const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+ %parameter.1 = s8[12,4]{1,0} parameter(0)
+ %parameter.2 = s8[4,8]{1,0} parameter(1)
+ k = s32[] constant(2)
+ k_broadcast = s32[12,8] broadcast(k), dimensions={}
+ %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT dot_multiplied = s32[12,8] multiply(%dot.8, k_broadcast)
+}
+ )";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+ )",
+ /*print_operand_shape=*/true);
+ } else {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+ )",
+ /*print_operand_shape=*/true);
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoBetaRewrite) {
+ const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+ %parameter.1 = s8[12,4]{1,0} parameter(0)
+ %parameter.2 = s8[4,8]{1,0} parameter(1)
+ bias = s32[12,8] parameter(2)
+ %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = s32[12,8] add(%dot.8, bias)
+}
+ )";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+ )",
+ /*print_operand_shape=*/true);
+ } else {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+ )",
+ /*print_operand_shape=*/true);
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8GemmNotMultipleOfFour) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
+ }
+
+ const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+ %parameter.1 = s8[13,4]{1,0} parameter(0)
+ %parameter.2 = s8[4,9]{1,0} parameter(1)
+ ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+ )";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
+ )",
+ /*print_operand_shape=*/true);
+ } else {
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: {{.*}} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+ )",
+ /*print_operand_shape=*/true);
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
+ }
+
+ std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
+ type_combinations = {{"s8", "s8", true},
+ {"s32", "s32", true},
+ {"bf16", "bf16", true},
+ {"f16", "f16", true},
+ {"f32", "f32", true},
+ {"f64", "f64", true},
+ {"c64", "c64", true},
+ {"c128", "c128", true},
+ // Mixed-type GEMM combinations.
+ {"s8", "s32", true},
+ {"s8", "f32", true},
+ {"f16", "f32", true},
+ {"bf16", "f32", true}};
+
+ if (!IsCuda() ||
+ HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ // For compute capabilities before Ampere, we may do upcasting, so it
+ // would be impossible for this test to fail. That is why we only add these
+ // cases when the compute capability is at least Ampere.
+ std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
+ more_type_combinations = {
+ {"s8", "bf16", false}, {"s8", "f16", false},
+ {"s8", "f64", false}, {"s8", "c64", false},
+ {"s8", "c128", false},
+
+ {"s32", "f32", false}, {"s32", "f64", false},
+ {"s32", "c64", false}, {"s32", "c128", false},
+
+ {"f16", "bf16", false}, {"f16", "f64", false},
+ {"f16", "c64", false}, {"f16", "c128", false},
+
+ {"bf16", "f16", false}, {"bf16", "f64", false},
+ {"bf16", "c64", false}, {"bf16", "c128", false},
+
+ {"f32", "f64", false}, {"f32", "c64", false},
+ {"f32", "c128", false},
+
+ {"f64", "c64", false}, {"f64", "c128", false},
+ };
+ type_combinations.insert(type_combinations.end(),
+ more_type_combinations.begin(),
+ more_type_combinations.end());
+ }
+
+ for (const auto& type_combination : type_combinations) {
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<ABType>>"] = std::get<0>(type_combination);
+ replacements["<<DType>>"] = std::get<1>(type_combination);
+ const char* hlo_template = R"(
+ HloModule type_combo
+
+ ENTRY type_combo {
+ %parameter.1 = <<ABType>>[4,4]{1,0} parameter(0)
+ %parameter.2 = <<ABType>>[4,4]{1,0} parameter(1)
+ ROOT %dot = <<DType>>[4,4] dot(%parameter.1, %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+ )";
+ const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+ if (std::get<2>(type_combination)) {
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ } else {
+ EXPECT_FALSE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ }
+ }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingBf16ToF64) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ Arg_0.1 = bf16[4,3]{1,0} parameter(0)
+ Arg_1.2 = bf16[3,6]{1,0} parameter(1)
+ ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ // This is a type combination which is not supported by cuBLASLt, so we
+ // expect GemmRewriter to choose legacy cuBLAS.
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ Arg_0.1 = c64[4,3]{1,0} parameter(0)
+ Arg_1.2 = c64[3,6]{1,0} parameter(1)
+ ROOT dot.3 = c128[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ // This is a type combination which is not supported by cuBLASLt, so we
+ // expect GemmRewriter to choose legacy cuBLAS.
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ Arg_0.1 = f16[4,3]{1,0} parameter(0)
+ Arg_1.2 = f16[3,6]{1,0} parameter(1)
+ ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest, highest}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ Arg_0.1 = f16[4,3]{1,0} parameter(0)
+ Arg_1.2 = f16[3,6]{1,0} parameter(1)
+ ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ // This is a type combination which is not supported by cuBLASLt, so we
+ // expect GemmRewriter to choose legacy cuBLAS.
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ Arg_0.1 = f32[4,3]{1,0} parameter(0)
+ Arg_1.2 = f32[3,6]{1,0} parameter(1)
+ ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ // This is a type combination which is not supported by cuBLASLt, so we
+ // expect GemmRewriter to choose legacy cuBLAS.
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, DoNotUpconvertOutput) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+ param_0 = f16[240,88]{1,0} parameter(0)
+ param_1 = f16[88,4]{1,0} parameter(1)
+ dot = f16[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ constant_255 = f16[] constant(255)
+ broadcast = f16[240,4]{1,0} broadcast(constant_255), dimensions={}
+ multiply = f16[240,4]{1,0} multiply(dot, broadcast)
+ ROOT result = f32[240,4]{1,0} convert(multiply)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ // The fp16-input/fp32-output combination is supported by both legacy cuBLAS
+ // and cuBLASLt, so we expect GemmRewriter to fuse the convert into the GEMM.
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Convert(
+ m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UnsupportedMixTypeGemm) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+ param_0 = f32[240,88]{1,0} parameter(0)
+ param_1 = f32[88,4]{1,0} parameter(1)
+ dot = f32[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ constant_255 = f32[] constant(255)
+ broadcast = f32[240,4]{1,0} broadcast(constant_255), dimensions={}
+ multiply = f32[240,4]{1,0} multiply(dot, broadcast)
+ ROOT result = u8[240,4]{1,0} convert(multiply)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ // u8 output is not supported by legacy cuBLAS or cuBLASLt, so we expect
+ // GemmRewriter not to fuse the convert into the GEMM.
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Convert(
+ m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, CheckIsGemmAliasedBeforeFusion) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+ Arg_0.1 = f16[8,16]{1,0} parameter(0)
+ Arg_1.2 = f16[16,32]{1,0} parameter(1)
+ dot.8 = f16[8,32]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ Arg_2.3 = f16[8,32]{1,0} parameter(2)
+ constant.5 = f16[] constant(1)
+ broadcast.6 = f16[8,32]{1,0} broadcast(constant.5), dimensions={}
+ add.7 = f16[8,32]{1,0} add(Arg_2.3, broadcast.6)
+ add.9 = f16[8,32]{1,0} add(dot.8, add.7)
+ convert.10 = f32[8,32]{1,0} convert(add.9)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ // The fp16-input/fp32-output combination is supported by both legacy cuBLAS
+ // and cuBLASLt, but the GEMM output is already aliased with one of the
+ // inputs, so we expect GemmRewriter not to fuse the convert into the GEMM.
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Convert(
+ m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
+}
+
+INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt,
+ ParameterizedGemmRewriteTest, ::testing::Bool());
+#endif
+
+// A test fixture class for tests which are specific to legacy cublas
+class LegacyCublasGemmRewriteTest : public GemmRewriteTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_triton_gemm(false);
+ debug_options.set_xla_gpu_enable_cublaslt(false);
+ return debug_options;
+ }
+};
+
+TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplication) {
+ const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f32[2048] parameter(0)
+ p1 = f32[2048, 16384] parameter(1)
+ ROOT d = f32[16384] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})";
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(
+ se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
+ /*toolkit_version=*/12040),
+ R"(
+; CHECK: %[[P0:.+]] = f32[2048]{0} parameter(0)
+; CHECK: %[[P1:.+]] = f32[2048,16384]{1,0} parameter(1)
+; CHECK: %[[CUSTOM_CALL:.+]] = (f32[16384]{0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplicationWithBatch) {
+ const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f32[10, 10, 2048] parameter(0)
+ p1 = f32[10, 10, 2048, 16384] parameter(1)
+ ROOT d = f32[10, 10, 16384] dot(p0, p1),
+ lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1},
+ lhs_contracting_dims={2}, rhs_contracting_dims={2}
+})";
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(
+ se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
+ /*toolkit_version=*/12040),
+ R"(
+; CHECK: %[[P0:.+]] = f32[10,10,2048]{2,1,0} parameter(0)
+; CHECK: %[[P1:.+]] = f32[10,10,2048,16384]{3,2,1,0} parameter(1)
+; CHECK: %[[CUSTOM_CALL:.+]] = (f32[10,10,16384]{2,1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, SparseDotNotSupported) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+ lhs = f16[5,16] parameter(0)
+ rhs = f16[32,10] parameter(1)
+ meta = u16[5,2] parameter(2)
+ ROOT dot = f32[5,10] dot(lhs, rhs, meta),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
+})";
+ auto hlo_pass = GemmRewriter(
+ se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
+ /*toolkit_version=*/12040);
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get()));
+ EXPECT_FALSE(changed);
+}
+
+// Test that the alpha and beta fields of the GemmBackendConfig are updated.
+// A bias must be present for the beta value to be set.
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. The negate(param_2) has no semantic use; it exists
+// only so that the bias may be overwritten.
+TEST_F(LegacyCublasGemmRewriteTest, AlphaBetaRewrite) {
+ const char* hlo_text = R"(
+HloModule NonZeroAlphaBeta
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ param_2 = f32[2,2] parameter(2)
+ bias = f32[2,2] negate(param_2)
+ k = f32[] constant(3.0)
+ k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+ ROOT out = f32[2,2] add(dot_a_multiplied, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK: [[O:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: output_to_operand_aliasing={
+; CHECK-SAME: {0}: (2, {})
+; CHECK-SAME: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[O]]), index=0
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
+ const char* hlo_text = R"(
+HloModule BiasMultipleUsersNoOverwrite
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] parameter(2)
+ k = f32[] constant(3.0)
+ k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+ biased_out = f32[2,2] add(dot_a_multiplied, bias)
+ ROOT out = f32[2,2] add(biased_out, bias)
+}
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT: [[CUSTOM_CALL:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+
+)");
+}
+
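+// The bias is a parameter of the computation and therefore not overwritable; the
+// add is left unfused and beta stays 0.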
+TEST_F(LegacyCublasGemmRewriteTest, BiasParameterNoOverwrite) {
+ const char* hlo_text = R"(
+HloModule BiasParameterNoOverwrite
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] parameter(2)
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,2] add(dot_a, bias)
+}
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
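+// A bias extracted from a tuple parameter is copied first; the copy is
+// overwritable, so the add is fused with beta=1 and the output aliases the copy.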
+TEST_F(LegacyCublasGemmRewriteTest, BiasTupleParameterOverwrite) {
+ const char* hlo_text = R"(
+HloModule BiasTupleParameterOverwrite
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ param_2 = (f32[2,2], f32[3,3]) parameter(2)
+ bias = f32[2,2] get-tuple-element(param_2), index=0
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,2] add(dot_a, bias)
+}
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: (f32[2,2], f32[3,3])) -> f32[2,2] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG: [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2)
+; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[P2]]), index=0
+; CHECK-DAG: [[BIAS_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[BIAS]])
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS_COPY]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: output_to_operand_aliasing={
+; CHECK-SAME: {0}: (2, {})
+; CHECK-SAME: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
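+// A parameter bias that is donated via input_output_alias may be overwritten, so
+// it is passed to the GEMM directly and the add is fused with beta=1.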
+TEST_F(LegacyCublasGemmRewriteTest, AliasedBiasOverwrite) {
+ const char* hlo_text = R"(
+HloModule AliasedBiasOverwrite, input_output_alias={ {}: (2, {}, must-alias) }
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] parameter(2)
+ k = f32[] constant(3.0)
+ k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+ ROOT out = f32[2,2] add(dot_a_multiplied, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
+; CHECK: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: output_to_operand_aliasing={
+; CHECK-SAME: {0}: (2, {})
+; CHECK-SAME: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
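+// The bias has multiple users, so the add is not fused and beta stays 0.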
+TEST_F(LegacyCublasGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
+ const char* hlo_text = R"(
+HloModule LargerBiasMultipleUsersNoRewrite
+
+ENTRY AddDotsFunc {
+ x = f32[1024,1024] parameter(0)
+ y = f32[1024,1024] parameter(1)
+ bias = f32[1024,1024] parameter(2)
+ dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ biased_out = f32[1024,1024] add(dot_a, bias)
+ ROOT out = f32[1024,1024] add(biased_out, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. This negate(param_2) has no semantic use; it exists
+// only so that the bias may be overwritten.
+TEST_F(LegacyCublasGemmRewriteTest, BF16GemmWithBias) {
+ const char* hlo_text = R"(
+HloModule BF16GemmWithBias
+
+ENTRY BF16GemmWithBias {
+ x = bf16[8,8]{1,0} parameter(0)
+ y = bf16[8,8]{1,0} parameter(1)
+ dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ param_2 = bf16[8,8]{1,0} parameter(2)
+ bias = bf16[8,8]{1,0} negate(param_2)
+ ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
+}
+ )";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 2e-3}));
+
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
+; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
+; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
+; CHECK: [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: output_to_operand_aliasing={
+; CHECK-SAME: {0}: (2, {})
+; CHECK-SAME: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. This negate(param_2) has no semantic use; it exists
+// only so that the bias may be overwritten.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ param_2 = f32[2,4] parameter(2)
+ bias = f32[2,4] negate(param_2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,4] add(dot_a, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK: [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], {{[^,)]+}}),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: output_to_operand_aliasing={
+; CHECK-SAME: {0}: (2, {})
+; CHECK-SAME: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
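+// The bias is produced by another dot rather than being a parameter, so it is
+// overwritable; the second GEMM fuses it with beta=1 and aliases its output to
+// that operand.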
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ w = f32[2,3] parameter(0)
+ x = f32[3,4] parameter(1)
+ first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ y = f32[2,3] parameter(2)
+ z = f32[3,4] parameter(3)
+ second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,4] add(second_dot, first_dot)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
+; CHECK-NEXT: [[FIRST_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[FIRST_GEMM_OUT:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM]]), index=0
+; CHECK-NEXT: [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM_OUT]]),
+; CHECK: custom_call_target="__cublas$gemm",
+; CHECK: output_to_operand_aliasing={
+; CHECK-SAME: {0}: (2, {})
+; CHECK-SAME: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Test GEMM matrix bias add fusion with mixed types.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixType) {
+ std::vector<std::tuple<absl::string_view, absl::string_view>>
+ type_combinations = {
+ {"f16", "f32"},
+ {"bf16", "f32"},
+ };
+
+ const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+ x = <<ABType>>[16,32] parameter(0)
+ y = <<ABType>>[32,16] parameter(1)
+ z = <<DType>>[16,16] parameter(2)
+ dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias = <<DType>>[16,16] negate(z)
+ convert = <<DType>>[16,16] convert(dot_a)
+ ROOT out = <<DType>>[16,16] add(convert, bias)
+}
+
+)";
+ for (const auto& type_combination : type_combinations) {
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<ABType>>"] = std::get<0>(type_combination);
+ replacements["<<DType>>"] = std::get<1>(type_combination);
+ const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ continue;
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1),
+ m::Negate(m::Parameter(2))),
+ 0)));
+ }
+}
+
+// Test batched GEMM matrix bias add fusion with mixed types.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeBatched) {
+ std::vector<std::tuple<absl::string_view, absl::string_view>>
+ type_combinations = {
+ {"f16", "f32"},
+ {"bf16", "f32"},
+ };
+
+ const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+ x = <<ABType>>[4,16,32] parameter(0)
+ y = <<ABType>>[4,32,16] parameter(1)
+ z = <<DType>>[4,16,16] parameter(2)
+ dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+ bias = <<DType>>[4,16,16] negate(z)
+ convert = <<DType>>[4,16,16] convert(dot_a)
+ ROOT out = <<DType>>[4,16,16] add(convert, bias)
+})";
+ for (const auto& type_combination : type_combinations) {
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<ABType>>"] = std::get<0>(type_combination);
+ replacements["<<DType>>"] = std::get<1>(type_combination);
+ const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ continue;
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1),
+ m::Negate(m::Parameter(2))),
+ 0)));
+ }
+}
+#endif
+
+// Test GEMM matrix bias add fusion with a mixed-type combination that is not
+// supported.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[16,32] parameter(0)
+ y = bf16[32,16] parameter(1)
+ z = f64[16,16] parameter(2)
+ dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias = f64[16,16] negate(z)
+ convert = f64[16,16] convert(dot_a)
+ ROOT out = f64[16,16] add(convert, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
+; CHECK: %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
+; CHECK: ROOT {{.*}} fusion({{.*}}%[[gte]]
+)");
+}
+
+// Test GEMM matrix bias add fusion with mixed types that is not fused because
+// the bias add has additional consumers.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeAddWithMoreConsumers) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[16,32] parameter(0)
+ y = bf16[32,16] parameter(1)
+ z = f32[16,16] parameter(2)
+ dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias = f32[16,16] negate(z)
+ convert = f32[16,16] convert(dot_a)
+ add_bias = f32[16,16] add(convert, bias)
+ ROOT out = f32[16,16] negate(add_bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
+; CHECK: %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
+; CHECK: ROOT {{.*}} fusion({{.*}}%[[gte]]
+)");
+}
+
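+// A bitcast between the dot and the add is merged into the rewrite: the f32[4]
+// bias is bitcast to the dot shape and fused, and the result is bitcast back.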
+TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) {
+ const char* hlo_text = R"(
+HloModule test
+ENTRY test {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[4] parameter(2)
+ dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[4] add(f32[4] bitcast(dot), bias)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Bitcast(
+ m::GetTupleElement(
+ m::CustomCall(
+ {"__cublas$gemm"}, m::Parameter(0), m::Parameter(1),
+ m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})),
+ 0))
+ .WithShape(F32, {4})));
+}
+
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. This negate(param_2) has no semantic use; it exists
+// only so that the bias may be overwritten.
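+// Reshapes, transposes, and bitcasts of a constant bias are folded so that the
+// constant itself becomes the fused bias operand.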
+TEST_F(LegacyCublasGemmRewriteTest, FoldConstantBias) {
+ const char* hlo_text = R"(
+HloModule test
+ENTRY test {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
+
+ dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ param_2 = f32[2,2] parameter(2)
+ bias1 = f32[2,2] negate(param_2)
+ sum1 = add(dot1, bias1)
+
+ dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ sum2 = add(dot2, f32[2,2] reshape(bias))
+
+ dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias3 = f32[2,2] transpose(bias), dimensions={1,0}
+ sum3 = add(dot3, bias3)
+
+ dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ sum4 = add(dot4, f32[2,2] bitcast(bias))
+
+ ROOT root = tuple(sum1, sum2, sum3, sum4)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ SCOPED_TRACE(module->ToString());
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(m::CustomCall(m::Parameter(0), m::Parameter(1),
+ m::Negate(m::Parameter(2))),
+ 0),
+ m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+ 0),
+ m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+ 0),
+ m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+ 0))));
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// A test fixture for tests that are specific to cublasLt.
+class CublasLtGemmRewriteTest : public GemmRewriteTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_enable_cublaslt(true);
+ debug_options.set_xla_gpu_enable_triton_gemm(false);
+ return debug_options;
+ }
+
+ protected:
+ void SetUp() override {
+ if (SkipGpuBlasLtTest()) {
+ GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+ }
+ }
+};
+
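+// With cublasLt the bias is an ordinary input of the matmul rather than an
+// overwritten buffer, so a parameter bias is passed directly with beta=1.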
+TEST_F(CublasLtGemmRewriteTest, AlphaBetaRewrite) {
+ const char* hlo_text = R"(
+HloModule NonZeroAlphaBeta
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] parameter(2)
+ k = f32[] constant(3.0)
+ k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+ ROOT out = f32[2,2] add(dot_a_multiplied, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[GEMM]]), index=0
+)");
+}
+
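+// Unlike the legacy rewriter, cublasLt can fuse a bias that has other users: the
+// bias is read rather than overwritten (beta=1, no output_to_operand_aliasing).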
+TEST_F(CublasLtGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
+ const char* hlo_text = R"(
+HloModule BiasMultipleUsersNoOverwrite
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] parameter(2)
+ k = f32[] constant(3.0)
+ k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+ dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+ biased_out = f32[2,2] add(dot_a_multiplied, bias)
+ ROOT out = f32[2,2] add(biased_out, bias)
+}
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK-NOT: output_to_operand_aliasing
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
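+// Only the first add is folded into the matmul (beta=1); the second add of the
+// shared bias remains in the graph.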
+TEST_F(CublasLtGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
+ const char* hlo_text = R"(
+HloModule LargerBiasMultipleUsersNoRewrite
+
+ENTRY AddDotsFunc {
+ x = f32[1024,1024] parameter(0)
+ y = f32[1024,1024] parameter(1)
+ bias = f32[1024,1024] parameter(2)
+ dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ biased_out = f32[1024,1024] add(dot_a, bias)
+ ROOT out = f32[1024,1024] add(biased_out, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
+; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[1024,1024]{1,0} parameter(2)
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[1024,1024]{1,0} add([[GEMM]], [[BIAS]])
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BF16GemmWithBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY BF16GemmWithBias {
+ x = bf16[8,8]{1,0} parameter(0)
+ y = bf16[8,8]{1,0} parameter(1)
+ dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias = bf16[8,8]{1,0} parameter(2)
+ ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
+}
+ )";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
+; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
+; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
+; CHECK-DAG: [[BIAS:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
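+// The matrix bias parameter is passed directly as the third operand of the
+// cublasLt matmul with beta=1.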
+TEST_F(CublasLtGemmRewriteTest, MatrixBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[2,4] parameter(2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,4] add(dot_a, z)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ w = f32[2,3] parameter(0)
+ x = f32[3,4] parameter(1)
+ first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ y = f32[2,3] parameter(2)
+ z = f32[3,4] parameter(3)
+ second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,4] add(second_dot, first_dot)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
+; CHECK-NEXT: [[FIRST_GEMM_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM_TUPLE]]), index=0
+; CHECK-NEXT: [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: output_to_operand_aliasing={
+; CHECK: {0}: (2, {})
+; CHECK: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
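+// A vector bias broadcast along the last dimension is fused through the BIAS
+// epilogue instead of being materialized as a matrix bias.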
+TEST_F(CublasLtGemmRewriteTest, VectorBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[4] parameter(2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z), dimensions={1}
+ ROOT out = f32[2,4] add(dot_a, z_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+)");
+}
+
+// Epilogue fusion is disabled when the GEMM has multiple users.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasMultipleUsers) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[4,4] parameter(0)
+ y = f32[4,4] parameter(1)
+ z = f32[4] parameter(2)
+ c = f32[] constant(5)
+ dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ z_bcast = f32[4,4] broadcast(z), dimensions={1}
+ add_a = f32[4,4] add(dot_a, z_bcast)
+ c_bcast = f32[4,4] broadcast(c), dimensions={}
+ dot_b = f32[4,4] dot(dot_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ ROOT out = f32[4,4] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[4,4], [[DUMMY1:[^ ]+]]: f32[4]) -> f32[4,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} broadcast([[P1]]), dimensions={1}
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} add([[P0]], [[P2]])
+}
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4]) -> f32[4,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
+; CHECK-NEXT: [[MATMUL0_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK-NEXT: [[MATMUL0:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL0_TUPLE]]), index=0
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[4,4]{1,0} fusion([[MATMUL0]], [[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]]
+; CHECK: [[MATMUL1_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[MATMUL0]]
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[FUSION]], [[MATMUL1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
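+// The broadcast bias becomes a matrix bias of the flattened [6,30] matmul
+// (beta=1, with output-to-operand aliasing); a bitcast restores the batched shape.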
+TEST_F(CublasLtGemmRewriteTest, BatchedVectorBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3,4] parameter(0)
+ y = f32[4,5,6] parameter(1)
+ z = f32[3,5,6] parameter(2)
+ dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
+ ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
+; CHECK: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: output_to_operand_aliasing={
+; CHECK: {0}: (2, {})
+; CHECK: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BatchedSharedVectorBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3,4] parameter(0)
+ y = f32[4,5,6] parameter(1)
+ z = f32[6] parameter(2)
+ dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ z_bcast = f32[2,3,5,6] broadcast(z), dimensions={3}
+ ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[6]) -> f32[2,3,5,6] {
+; CHECK: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: output_to_operand_aliasing={
+; CHECK: {0}: (2, {})
+; CHECK: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[2] parameter(2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z), dimensions={0}
+ add = f32[2,4] add(dot_a, z_bcast)
+ ROOT out = f32[4,2] transpose(add), dimensions={1,0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
+; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
+)");
+}
+
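+// The vector bias is fused into the full-size matmul via the BIAS epilogue and
+// the slice is applied to the GEMM output afterwards.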
+TEST_F(CublasLtGemmRewriteTest, VectorBiasSliced) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[4,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[3] parameter(2)
+ dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ slice_a = f32[2,3] slice(dot_a), slice={[0:2], [0:3]}
+ z_bcast = f32[2,3] broadcast(z), dimensions={1}
+ ROOT out = f32[2,3] add(slice_a, z_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,3], {{.*}}: f32[3,4], {{.*}}: f32[3]) -> f32[2,3] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3]{0} parameter(2)
+; CHECK-NEXT: [[MATMUL:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+; CHECK-NEXT: [[GETTUPLE:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3]{1,0} slice([[GETTUPLE]]), slice={[0:2], [0:3]}
+ )");
+}
+
+// Epilogue fusion is disabled when the slice has multiple users.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasSlicedMultipleUsers) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[2] parameter(2)
+ c = f32[] constant(5)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
+ z_bcast = f32[2,2] broadcast(z), dimensions={1}
+ add_a = f32[2,2] add(slice_a, z_bcast)
+ c_bcast = f32[2,2] broadcast(c), dimensions={}
+ dot_b = f32[2,2] dot(slice_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,2] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[2,2] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
+; CHECK-NEXT: [[MATMUL0_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[MATMUL1_TUPLE:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call{{.*}}[[MATMUL1]]
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposed) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[2] parameter(2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] parameter(3)
+ ROOT out = f32[2,4] add(dot_a, z_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2_BCAST]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+)");
+}
+
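+// When both a matrix bias and a vector bias are present, a single matmul takes
+// the matrix bias as the beta operand and the vector bias via the BIAS epilogue.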
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[4] parameter(2)
+ z2 = f32[2,4] parameter(3)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z), dimensions={1}
+ add0 = f32[2,4] add(dot_a, z_bcast)
+ ROOT add1 = f32[2,4] add(add0, z2)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG: [[VECTOR_BIAS:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-DAG: [[MATRIX_BIAS:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[MATRIX_BIAS]], [[VECTOR_BIAS]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BF16VectorBias) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[16,24] parameter(0)
+ y = bf16[24,32] parameter(1)
+ z = bf16[32] parameter(2)
+ dot_a = bf16[16,32] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = bf16[16,32] broadcast(z), dimensions={1}
+ ROOT out = bf16[16,32] add(dot_a, z_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{3e-3, 1e-3}));
+
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[16,24], {{.*}}: bf16[24,32], {{.*}}: bf16[32]) -> bf16[16,32] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[16,24]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[24,32]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32]{0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+ )");
+}
+
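+// Small bf16 operands are padded (here to 8x8) before the matmul on
+// architectures with bf16 tensor cores.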
+TEST_F(CublasLtGemmRewriteTest, BF16VectorBiasPadded) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
+ "architectures with bf16 Tensor Cores.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[2,3] parameter(0)
+ y = bf16[3,4] parameter(1)
+ z = bf16[4] parameter(2)
+ dot_a = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = bf16[2,4] broadcast(z), dimensions={1}
+ ROOT out = bf16[2,4] add(dot_a, z_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] {
+; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
+; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
+ )");
+}
+
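+// A maximum against a zero broadcast directly after the dot is fused as the RELU
+// epilogue.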
+TEST_F(CublasLtGemmRewriteTest, ReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ c = f32[] constant(0)
+ c_bcast = f32[2,4] broadcast(c), dimensions={}
+ ROOT out = f32[2,4] maximum(dot_a, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BatchedReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3,4] parameter(0)
+ y = f32[4,5,6] parameter(1)
+ dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ c = f32[] constant(0)
+ c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
+ ROOT out = f32[2,3,5,6] maximum(dot_a, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6]) -> f32[2,3,5,6] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1)
+; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0}
+; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+; CHECK: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ReluActivationSliced) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ c = f32[] constant(0)
+ c_bcast = f32[2,2] broadcast(c), dimensions={}
+ slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
+ ROOT out = f32[2,2] maximum(slice_a, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+; CHECK: [[MATMUL:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL]]), slice={[0:2], [0:2]}
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[2,4] parameter(2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ add = f32[2,4] add(dot_a, z)
+ c = f32[] constant(0)
+ c_bcast = f32[2,4] broadcast(c), dimensions={}
+ ROOT out = f32[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, SquareMatrixBiasReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[4,4] parameter(0)
+ y = f32[4,4] parameter(1)
+ z = f32[4,4] parameter(2)
+ dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ add = f32[4,4] add(dot_a, z)
+ c = f32[] constant(0)
+ c_bcast = f32[4,4] broadcast(c), dimensions={}
+ ROOT out = f32[4,4] maximum(add, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4,4]) -> f32[4,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[4] parameter(2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z), dimensions={1}
+ add = f32[2,4] add(dot_a, z_bcast)
+ c = f32[] constant(0)
+ c_bcast = f32[2,4] broadcast(c), dimensions={}
+ ROOT out = f32[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_RELU"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BatchedVectorBiasReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3,4] parameter(0)
+ y = f32[4,5,6] parameter(1)
+ z = f32[3,5,6] parameter(2)
+ dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
+ add = f32[2,3,5,6] add(dot_a, z_bcast)
+ c = f32[] constant(0)
+ c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
+ ROOT out = f32[2,3,5,6] maximum(add, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
+; CHECK: [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+ )");
+}
+
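+// The trailing transpose is absorbed by giving the matmul a column-major
+// output layout ({0,1}), so only a bitcast to f32[4,2] remains after the
+// fused BIAS_RELU matmul.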
+TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposedReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[2] parameter(2)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z), dimensions={0}
+ add = f32[2,4] add(dot_a, z_bcast)
+ c = f32[] constant(0)
+ c_bcast = f32[2,4] broadcast(c), dimensions={}
+ maximum = f32[2,4] maximum(add, c_bcast)
+ ROOT out = f32[4,2] transpose(maximum), dimensions={1,0}
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
+; CHECK-NEXT: [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_RELU"
+; CHECK: }
+; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
+ )");
+}
+
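+// When both a vector and a matrix bias are added, the matrix bias becomes the
+// C operand of the matmul (beta=1) while the vector bias is handled by the
+// BIAS_RELU epilogue.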
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBiasReluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z_vec = f32[4] parameter(2)
+ z_matrix = f32[2,4] parameter(3)
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z_vec), dimensions={1}
+ add0 = f32[2,4] add(dot_a, z_bcast)
+ add1 = f32[2,4] add(add0, z_matrix)
+ c = f32[] constant(0)
+ c_bcast = f32[2,4] broadcast(c), dimensions={}
+ ROOT out = f32[2,4] maximum(add1, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P3]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_RELU"
+; CHECK: }
+ )");
+}
+
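+// The HLO below spells out the tanh approximation of GELU,
+//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+// where 0.797884583 ~= sqrt(2/pi); the rewriter should recognize this pattern
+// and fuse it into a GELU epilogue.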
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ mul.0 = f32[2,4] multiply(dot, dot)
+ mul.1 = f32[2,4] multiply(dot, mul.0)
+ const.0 = f32[] constant(0.044715)
+ bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+ mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+ add.0 = f32[2,4] add(dot, mul.2)
+ const.1 = f32[] constant(0.797884583)
+ bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+ mul.3 = f32[2,4] multiply(add.0, bcast.1)
+ tanh = f32[2,4] tanh(mul.3)
+ const.2 = f32[] constant(1)
+ bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+ add.2 = f32[2,4] add(tanh, bcast.2)
+ const.3 = f32[] constant(0.5)
+ bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+ mul.4 = f32[2,4] multiply(add.2, bcast.3)
+ ROOT out = f32[2,4] multiply(dot, mul.4)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"GELU"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWrongConstant) {
+  // Modify one constant slightly so that the GELU pattern no longer matches.
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ mul.0 = f32[2,4] multiply(dot, dot)
+ mul.1 = f32[2,4] multiply(dot, mul.0)
+ const.0 = f32[] constant(0.05)
+ bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+ mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+ add.0 = f32[2,4] add(dot, mul.2)
+ const.1 = f32[] constant(0.797884583)
+ bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+ mul.3 = f32[2,4] multiply(add.0, bcast.1)
+ tanh = f32[2,4] tanh(mul.3)
+ const.2 = f32[] constant(1)
+ bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+ add.2 = f32[2,4] add(tanh, bcast.2)
+ const.3 = f32[] constant(0.5)
+ bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+ mul.4 = f32[2,4] multiply(add.2, bcast.3)
+ ROOT out = f32[2,4] multiply(dot, mul.4)
+}
+
+)";
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-NOT: GELU
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) {
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60000
+  auto rocm_switch = false;  // GELU is available from ROCm 6.0 onwards.
+#else
+ auto rocm_switch = true;
+#endif
+ if (!IsCuda() && rocm_switch) {
+ GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[4] parameter(2)
+ dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z), dimensions={1}
+ add = f32[2,4] add(dot, z_bcast)
+ mul.0 = f32[2,4] multiply(add, add)
+ mul.1 = f32[2,4] multiply(add, mul.0)
+ const.0 = f32[] constant(0.044715)
+ bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+ mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+ add.0 = f32[2,4] add(add, mul.2)
+ const.1 = f32[] constant(0.797884583)
+ bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+ mul.3 = f32[2,4] multiply(add.0, bcast.1)
+ tanh = f32[2,4] tanh(mul.3)
+ const.2 = f32[] constant(1)
+ bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+ add.2 = f32[2,4] add(tanh, bcast.2)
+ const.3 = f32[] constant(0.5)
+ bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+ mul.4 = f32[2,4] multiply(add.2, bcast.3)
+ ROOT out = f32[2,4] multiply(add, mul.4)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_GELU"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ mul.0 = f32[2,4] multiply(dot, dot)
+ mul.1 = f32[2,4] multiply(dot, mul.0)
+ const.0 = f32[] constant(0.044715)
+ bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+ mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+ add.0 = f32[2,4] add(dot, mul.2)
+ const.1 = f32[] constant(0.797884583)
+ bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+ mul.3 = f32[2,4] multiply(add.0, bcast.1)
+ tanh = f32[2,4] tanh(mul.3)
+ const.2 = f32[] constant(1)
+ bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+ add.2 = f32[2,4] add(tanh, bcast.2)
+ const.3 = f32[] constant(0.5)
+ bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+ mul.4 = f32[2,4] multiply(add.2, bcast.3)
+ mul.5 = f32[2,4] multiply(dot, mul.4)
+ ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, dot)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> (f32[2,4], f32[2,4]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"GELU_AUX"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[4] parameter(2)
+ dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f32[2,4] broadcast(z), dimensions={1}
+ add = f32[2,4] add(dot, z_bcast)
+ mul.0 = f32[2,4] multiply(add, add)
+ mul.1 = f32[2,4] multiply(add, mul.0)
+ const.0 = f32[] constant(0.044715)
+ bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+ mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+ add.0 = f32[2,4] add(add, mul.2)
+ const.1 = f32[] constant(0.797884583)
+ bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+ mul.3 = f32[2,4] multiply(add.0, bcast.1)
+ tanh = f32[2,4] tanh(mul.3)
+ const.2 = f32[] constant(1)
+ bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+ add.2 = f32[2,4] add(tanh, bcast.2)
+ const.3 = f32[] constant(0.5)
+ bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+ mul.4 = f32[2,4] multiply(add.2, bcast.3)
+ mul.5 = f32[2,4] multiply(add, mul.4)
+ ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, add)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> (f32[2,4], f32[2,4]) {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_GELU_AUX"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBF16) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
+ "architectures with bf16 Tensor Cores.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[2,3] parameter(0)
+ y = bf16[3,4] parameter(1)
+ dot = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ mul.0 = bf16[2,4] multiply(dot, dot)
+ mul.1 = bf16[2,4] multiply(dot, mul.0)
+ const.0 = bf16[] constant(0.044715)
+ bcast.0 = bf16[2,4] broadcast(const.0), dimensions={}
+ mul.2 = bf16[2,4] multiply(mul.1, bcast.0)
+ add.0 = bf16[2,4] add(dot, mul.2)
+ const.1 = bf16[] constant(0.797884583)
+ bcast.1 = bf16[2,4] broadcast(const.1), dimensions={}
+ mul.3 = bf16[2,4] multiply(add.0, bcast.1)
+ tanh = bf16[2,4] tanh(mul.3)
+ const.2 = bf16[] constant(1)
+ bcast.2 = bf16[2,4] broadcast(const.2), dimensions={}
+ add.2 = bf16[2,4] add(tanh, bcast.2)
+ const.3 = bf16[] constant(0.5)
+ bcast.3 = bf16[2,4] broadcast(const.3), dimensions={}
+ mul.4 = bf16[2,4] multiply(add.2, bcast.3)
+ ROOT out = bf16[2,4] multiply(dot, mul.4)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{5e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] {
+; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
+; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBitcast) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_bitcast = f32[2,2,2] bitcast(dot)
+ mul.0 = f32[2,2,2] multiply(dot_bitcast, dot_bitcast)
+ mul.1 = f32[2,2,2] multiply(dot_bitcast, mul.0)
+ const.0 = f32[] constant(0.044715)
+ bcast.0 = f32[2,2,2] broadcast(const.0), dimensions={}
+ mul.2 = f32[2,2,2] multiply(mul.1, bcast.0)
+ add.0 = f32[2,2,2] add(dot_bitcast, mul.2)
+ const.1 = f32[] constant(0.797884583)
+ bcast.1 = f32[2,2,2] broadcast(const.1), dimensions={}
+ mul.3 = f32[2,2,2] multiply(add.0, bcast.1)
+ tanh = f32[2,2,2] tanh(mul.3)
+ const.2 = f32[] constant(1)
+ bcast.2 = f32[2,2,2] broadcast(const.2), dimensions={}
+ add.2 = f32[2,2,2] add(tanh, bcast.2)
+ const.3 = f32[] constant(0.5)
+ bcast.3 = f32[2,2,2] broadcast(const.3), dimensions={}
+ mul.4 = f32[2,2,2] multiply(add.2, bcast.3)
+ ROOT out = f32[2,2,2] multiply(dot_bitcast, mul.4)
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Bitcast(m::GetTupleElement(
+ m::CustomCall({"__cublas$lt$matmul"},
+ m::Parameter(0).WithShape(F32, {2, 3}),
+ m::Parameter(1).WithShape(F32, {3, 4})),
+ 0))
+ .WithShape(F32, {2, 2, 2})));
+}
+
+// For F16, the sizes of all dimensions of the operands are required to be
+// multiples of 8 to allow matrix bias fusion.
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasF16) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[8,16] parameter(0)
+ y = f16[16,8] parameter(1)
+ z = f16[8,8] parameter(2)
+ dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f16[8,8] add(dot_a, z)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasF32UnpaddedWithBitcast) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3]{1,0} parameter(0)
+ y = f32[3,4]{1,0} parameter(1)
+ z = f32[2]{0} parameter(2)
+ dot_a = f32[2,4]{0,1} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitc = f32[4,2]{1,0} bitcast(f32[2,4]{0,1} dot_a)
+ z_bcast = f32[4,2] broadcast(z), dimensions={1}
+ ROOT add = f32[4,2]{1,0} add(bitc, z_bcast)
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Bitcast(m::GetTupleElement(
+ m::CustomCall({"__cublas$lt$matmul"}, m::Parameter(0),
+ m::Parameter(1),
+ m::Parameter(2).WithShape(F32, {2})),
+ 0)
+ .WithShape(F32, {2, 4}))
+ .WithShape(F32, {4, 2})));
+}
+
+// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
+// newer architectures) so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Unpadded) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[8,16] parameter(0)
+ y = f16[16,8] parameter(1)
+ z = f16[8] parameter(2)
+ dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f16[8,8] broadcast(z), dimensions={1}
+ ROOT add = f16[8,8] add(dot_a, z_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT: pad("
+; CHECK: custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+ )");
+}
+
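+// HLO pad configs read low_high per dimension: padding=0_2x0_4 adds two rows
+// and four columns at the high end, turning f16[6,12] into f16[8,16].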
+TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Padded) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ GTEST_SKIP() << "Padding of GEMM operands only implemented on "
+ "architectures with Tensor Cores.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[6,12] parameter(0)
+ y = f16[12,6] parameter(1)
+ z = f16[6] parameter(2)
+ dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f16[6,6] broadcast(z), dimensions={1}
+ ROOT add = f16[6,6] add(dot_a, z_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
+; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+ )");
+}
+
+// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
+// newer architectures) so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Unpadded) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[8,16] parameter(0)
+ y = f16[16,8] parameter(1)
+ dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ c = f16[] constant(0)
+ c_bcast = f16[8,8] broadcast(c), dimensions={}
+ ROOT out = f16[8,8] maximum(dot_a, c_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT: pad("
+; CHECK: custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Padded) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ GTEST_SKIP() << "Padding of GEMM operands only implemented on "
+ "architectures with Tensor Cores.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[6,12] parameter(0)
+ y = f16[12,6] parameter(1)
+ dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ c = f16[] constant(0)
+ c_bcast = f16[6,6] broadcast(c), dimensions={}
+ ROOT out = f16[6,6] maximum(dot_a, c_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] {
+; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivationF16) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[8,16] parameter(0)
+ y = f16[16,8] parameter(1)
+ z = f16[8,8] parameter(2)
+ dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ add = f16[8,8] add(dot_a, z)
+ c = f16[] constant(0)
+ c_bcast = f16[8,8] broadcast(c), dimensions={}
+ ROOT out = f16[8,8] maximum(add, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+ )");
+}
+
+// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
+// newer architectures) so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Unpadded) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[8,16] parameter(0)
+ y = f16[16,8] parameter(1)
+ z = f16[8] parameter(2)
+ dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f16[8,8] broadcast(z), dimensions={1}
+ add = f16[8,8] add(dot_a, z_bcast)
+ c = f16[] constant(0)
+ c_bcast = f16[8,8] broadcast(c), dimensions={}
+ ROOT out = f16[8,8] maximum(add, c_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT: pad("
+; CHECK: custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Padded) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+ GTEST_SKIP() << "Padding of GEMM operands only implemented on "
+ "architectures with Tensor Cores.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f16[6,12] parameter(0)
+ y = f16[12,6] parameter(1)
+ z = f16[6] parameter(2)
+ dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f16[6,6] broadcast(z), dimensions={1}
+ add = f16[6,6] add(dot_a, z_bcast)
+ c = f16[] constant(0)
+ c_bcast = f16[6,6] broadcast(c), dimensions={}
+ ROOT out = f16[6,6] maximum(add, c_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
+; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+ )");
+}
+
+// For bfloat16, the sizes of all dimensions of the operands are required to be
+// multiples of 8 to allow matrix bias fusion.
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasBF16) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[8,16] parameter(0)
+ y = bf16[16,8] parameter(1)
+ z = bf16[8,8] parameter(2)
+ dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = bf16[8,8] add(dot_a, z)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[8,16], {{.*}}: bf16[16,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
+; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0)
+; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1)
+; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasBitcastBF16) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[8,16] parameter(0)
+ y = bf16[16,8] parameter(1)
+ bias = bf16[2,4,8] parameter(2)
+ dot = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bitcast = bf16[2,4,8] bitcast(dot)
+ ROOT out = bf16[2,4,8] add(bitcast, bias)
+}
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Bitcast(
+ m::GetTupleElement(
+ m::CustomCall(
+ {"__cublas$lt$matmul"},
+ m::Parameter(0).WithShape(BF16, {8, 16}),
+ m::Parameter(1).WithShape(BF16, {16, 8}),
+ m::Bitcast(m::Parameter(2)).WithShape(BF16, {8, 8})),
+ 0))
+ .WithShape(BF16, {2, 4, 8})));
+}
+
+// For bfloat16, the operands are padded if necessary on Ampere and newer
+// architectures so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Unpadded) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[8,16] parameter(0)
+ y = bf16[16,8] parameter(1)
+ z = bf16[8] parameter(2)
+ dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = bf16[8,8] broadcast(z), dimensions={1}
+ ROOT add = bf16[8,8] add(dot_a, z_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT: pad("
+; CHECK: custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Padded) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
+ "Ampere and newer architectures.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[6,12] parameter(0)
+ y = bf16[12,6] parameter(1)
+ z = bf16[6] parameter(2)
+ dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = bf16[6,6] broadcast(z), dimensions={1}
+ ROOT add = bf16[6,6] add(dot_a, z_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
+; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+ )");
+}
+
+// For bfloat16, the operands are padded if necessary on Ampere and newer
+// architectures so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Unpadded) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[8,16] parameter(0)
+ y = bf16[16,8] parameter(1)
+ dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ c = bf16[] constant(0)
+ c_bcast = bf16[8,8] broadcast(c), dimensions={}
+ ROOT out = bf16[8,8] maximum(dot_a, c_bcast)
+}
+
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT: pad("
+; CHECK: custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Padded) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
+ "Ampere and newer architectures.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[6,12] parameter(0)
+ y = bf16[12,6] parameter(1)
+ dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ c = bf16[] constant(0)
+ c_bcast = bf16[6,6] broadcast(c), dimensions={}
+ ROOT out = bf16[6,6] maximum(dot_a, c_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] {
+; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+ )");
+}
+
+// For bfloat16, the operands are padded if necessary on Ampere and newer
+// architectures so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Unpadded) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[8,16] parameter(0)
+ y = bf16[16,8] parameter(1)
+ z = bf16[8] parameter(2)
+ dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = bf16[8,8] broadcast(z), dimensions={1}
+ add = bf16[8,8] add(dot_a, z_bcast)
+ c = bf16[] constant(0)
+ c_bcast = bf16[8,8] broadcast(c), dimensions={}
+ ROOT out = bf16[8,8] maximum(add, c_bcast)
+})";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-NOT: pad("
+; CHECK: custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Padded) {
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
+ "Ampere and newer architectures.";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[6,12] parameter(0)
+ y = bf16[12,6] parameter(1)
+ z = bf16[6] parameter(2)
+ dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = bf16[6,6] broadcast(z), dimensions={1}
+ add = bf16[6,6] add(dot_a, z_bcast)
+ c = bf16[] constant(0)
+ c_bcast = bf16[6,6] broadcast(c), dimensions={}
+ ROOT out = bf16[6,6] maximum(add, c_bcast)
+}
+
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
+; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "TODO: Unsupported blas-lt F64 datatype on ROCM";
+ }
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f64[2,3] parameter(0)
+ y = f64[3,4] parameter(1)
+ z = f64[4] parameter(2)
+ dot_a = f64[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ z_bcast = f64[2,4] broadcast(z), dimensions={1}
+ add = f64[2,4] add(dot_a, z_bcast)
+ c = f64[] constant(0)
+ c_bcast = f64[2,4] broadcast(c), dimensions={}
+ ROOT out = f64[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-10, 1e-10}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f64[2,3], {{.*}}: f64[3,4], {{.*}}: f64[4]) -> f64[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f64[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f64[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f64[4]{0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f64[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_RELU"
+; CHECK: }
+ )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, AlphaSimpleRewriteBiasAddActivation) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = f32[2,3] parameter(0)
+ y = f32[3,4] parameter(1)
+ z = f32[4] parameter(2)
+ k = f32[] constant(3.0)
+ k_bcast = f32[2,4] broadcast(k), dimensions={}
+ dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+ dot_a_multiplied = f32[2, 4] multiply(dot_a, k_bcast)
+ z_bcast = f32[2,4] broadcast(z), dimensions={1}
+ add = f32[2,4] add(dot_a_multiplied, z_bcast)
+ c = f32[] constant(0)
+ c_bcast = f32[2,4] broadcast(c), dimensions={}
+ ROOT out = f32[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_RELU"
+; CHECK: }
+ )");
+}
+
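+// A constant bias reached through reshape, transpose, or bitcast should be
+// folded into a constant operand of the fused matmul; only the parameter bias
+// remains a parameter.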
+TEST_F(CublasLtGemmRewriteTest, FoldConstantBias) {
+ const char* hlo_text = R"(
+HloModule test
+ENTRY test {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
+
+ dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias1 = f32[2,2] parameter(2)
+ sum1 = add(dot1, bias1)
+
+ dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ sum2 = add(dot2, f32[2,2] reshape(bias))
+
+ dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias3 = f32[2,2] transpose(bias), dimensions={1,0}
+ sum3 = add(dot3, bias3)
+
+ dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ sum4 = add(dot4, f32[2,2] bitcast(bias))
+
+ ROOT root = tuple(sum1, sum2, sum3, sum4)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(Capability(), GetToolkitVersion());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ SCOPED_TRACE(module->ToString());
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter()),
+ 0),
+ m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+ 0),
+ m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+ 0),
+ m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+ 0))));
+}
+
+TEST_F(CublasLtGemmRewriteTest, MultipleMaximumUsers) {
+ const char* hlo_text = R"(
+HloModule multiple_maximum_users
+
+relu {
+ Arg_0 = f32[3,896,54]{2,1,0} parameter(0)
+ constant = f32[] constant(0)
+ broadcast = f32[3,896,54]{2,1,0} broadcast(constant), dimensions={}
+ ROOT maximum = f32[3,896,54]{2,1,0} maximum(Arg_0, broadcast)
+}
+
+ENTRY main {
+ constant = f32[] constant(1)
+ broadcast_1 = f32[3,896,1024]{2,1,0} broadcast(constant), dimensions={}
+ Arg_2 = f32[1024,54]{1,0} parameter(2)
+ dot = f32[3,896,54]{2,1,0} dot(broadcast_1, Arg_2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ Arg_1 = f32[54]{0} parameter(1)
+ broadcast_2 = f32[3,896,54]{2,1,0} broadcast(Arg_1), dimensions={2}
+ add = f32[3,896,54]{2,1,0} add(dot, broadcast_2)
+ call = f32[3,896,54]{2,1,0} call(add), to_apply=relu
+ Arg_0 = f32[1]{0} parameter(0)
+ reshape_1 = f32[1,1,1]{2,1,0} reshape(Arg_0)
+ broadcast_3 = f32[1,1,1]{2,1,0} broadcast(reshape_1), dimensions={0,1,2}
+ reshape_2 = f32[] reshape(broadcast_3)
+ broadcast_4 = f32[3,896,54]{2,1,0} broadcast(reshape_2), dimensions={}
+ multiply = f32[3,896,54]{2,1,0} multiply(call, broadcast_4)
+ ROOT tuple = (f32[3,896,54]{2,1,0}, f32[3,896,54]{2,1,0}) tuple(multiply, call)
+}
+)";
+
+ // TODO(cjfj): Why do we need to relax the error constraint here?!
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-4}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK: custom_call_target="__cublas$lt$matmul",
+ )");
+}
+
+// Test GEMM matrix bias add fusion with mixed types and an out-of-place
+// update (C != D).
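+// cuBLAS LT matmul computes D = alpha * (A x B) + beta * C and allows C and D
+// to be distinct buffers, which is the out-of-place case exercised here.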
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlace) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
+ }
+ std::vector<std::tuple<absl::string_view, absl::string_view>>
+ type_combinations = {
+ {"f16", "f32"},
+ {"bf16", "f32"},
+ };
+
+ const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+ x = <<ABType>>[16,32] parameter(0)
+ y = <<ABType>>[32,16] parameter(1)
+ z = <<DType>>[16,16] parameter(2)
+ dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ convert = <<DType>>[16,16] convert(dot_a)
+ ROOT out = <<DType>>[16,16] add(convert, z)
+})";
+ for (const auto& type_combination : type_combinations) {
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<ABType>>"] = std::get<0>(type_combination);
+ replacements["<<DType>>"] = std::get<1>(type_combination);
+ const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ continue;
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ EXPECT_THAT(
+ optimized_module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)));
+ }
+}
+
+// Test batched GEMM matrix bias add fusion with mixed types and an
+// out-of-place update (C != D).
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlaceBatched) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
+ }
+ std::vector<std::tuple<absl::string_view, absl::string_view>>
+ type_combinations = {
+ {"f16", "f32"},
+ {"bf16", "f32"},
+ };
+
+ const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+ x = <<ABType>>[4,16,32] parameter(0)
+ y = <<ABType>>[4,32,16] parameter(1)
+ z = <<DType>>[4,16,16] parameter(2)
+ dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+ convert = <<DType>>[4,16,16] convert(dot_a)
+ ROOT out = <<DType>>[4,16,16] add(convert, z)
+})";
+ for (const auto& type_combination : type_combinations) {
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<ABType>>"] = std::get<0>(type_combination);
+ replacements["<<DType>>"] = std::get<1>(type_combination);
+ const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ continue;
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ EXPECT_THAT(
+ optimized_module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+ 0)));
+ }
+}
+
+// Test GEMM matrix bias add fusion with mixed types and an in-place update
+// (C = D).
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeInPlace) {
+ if (!IsCuda()) {
+ GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
+ }
+ std::vector<std::tuple<absl::string_view, absl::string_view>>
+ type_combinations = {
+ {"f16", "f32"},
+ {"bf16", "f32"},
+ };
+ const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+ x = <<ABType>>[16,32] parameter(0)
+ y = <<ABType>>[32,16] parameter(1)
+ z = <<DType>>[16,16] parameter(2)
+ dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias = <<DType>>[16,16] negate(z)
+ convert = <<DType>>[16,16] convert(dot_a)
+ ROOT out = <<DType>>[16,16] add(convert, bias)
+})";
+
+ for (const auto& type_combination : type_combinations) {
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<ABType>>"] = std::get<0>(type_combination);
+ replacements["<<DType>>"] = std::get<1>(type_combination);
+ const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ continue;
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall(m::Parameter(0), m::Parameter(1),
+ m::Negate(m::Parameter(2))),
+ 0)));
+ }
+}
+
+// Test GEMM matrix bias add fusion with a mixed-type combination that is not
+// supported.
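+// With an f64 bias the add cannot be fused into the matmul, so the optimized
+// HLO keeps a plain bf16 cuBLAS LT matmul and performs the convert and add in
+// a separate fusion.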
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ x = bf16[16,32] parameter(0)
+ y = bf16[32,16] parameter(1)
+ z = f64[16,16] parameter(2)
+ dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias = f64[16,16] negate(z)
+ convert = f64[16,16] convert(dot_a)
+ ROOT out = f64[16,16] add(convert, bias)
+}
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+ if (IsCuda() &&
+ !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo_text));
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$lt$matmul
+; CHECK: %[[tuple:.*]] = bf16[16,16]{1,0} get-tuple-element(%[[custom_call]]), index=0
+; CHECK: ROOT {{.*}} fusion({{.*}}%[[tuple]]
+)");
+}
+
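+// FP8 rewrite tests write their HLO with the placeholders <<F8E4M3>>,
+// <<F8E5M2>> and <<F8E4M3_AMAX>>; the fixture substitutes the
+// platform-specific datatypes and maximum finite value (f8e4m3fn / f8e5m2 /
+// 448 on CUDA, f8e4m3fnuz / f8e5m2fnuz / 240 on ROCm) before running.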
+class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest {
+ public:
+ ParameterizedFp8GemmRewriteTest() {
+ replacements_[kF8E4M3DatatypePlaceholder] =
+#if GOOGLE_CUDA
+ "f8e4m3fn";
+#else
+ "f8e4m3fnuz";
+#endif
+ replacements_[kF8E5M2DatatypePlaceholder] =
+#if GOOGLE_CUDA
+ "f8e5m2";
+#else
+ "f8e5m2fnuz";
+#endif
+ replacements_[kF8E4M3AmaxPlaceholder] =
+#if GOOGLE_CUDA
+ "448.";
+#else
+ "240.";
+#endif
+ }
+
+ protected:
+  // Check that the HLO runs and contains an FP8 cuBLAS LT custom call on
+  // supported architectures (Ada, Hopper, and later).
+ void CheckFp8IfSupported(absl::string_view hlo_text,
+ ErrorSpec error_spec = ErrorSpec{1e-2, 1e-2}) {
+ if (!HasFp8Support()) {
+ return;
+ }
+ std::string replaced_hlo_text =
+ absl::StrReplaceAll(hlo_text, replacements_);
+ EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
+ error_spec));
+
+ // Most FP8 tests directly create a GemmRewriter and check the output.
+ // Here, also run the entire HLO pass pipeline to ensure no other passes
+ // interfere with GemmRewriter's pattern matching.
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(replaced_hlo_text));
+ const HloInstruction* call =
+ FindInstruction(optimized_module.get(), HloOpcode::kCustomCall);
+ ASSERT_NE(call, nullptr);
+ EXPECT_EQ(call->custom_call_target(), "__cublas$lt$matmul$f8");
+ }
+
+ void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
+ bool print_operand_shape = false) {
+ GemmRewriteTest::MatchOptimizedHlo(
+ absl::StrReplaceAll(hlo, replacements_),
+ absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
+ }
+
+ void RunAndFilecheckHloRewrite(
+ absl::string_view hlo, HloPassInterface&& hlo_pass,
+ std::optional<absl::string_view> expected,
+ std::function<void(HloModule*)> after_pass_checks = nullptr,
+ const HloModuleConfig* config = nullptr) {
+ if (expected.has_value()) {
+ std::string replaced_pattern =
+ absl::StrReplaceAll(expected.value(), replacements_);
+ GemmRewriteTest::RunAndFilecheckHloRewrite(
+ absl::StrReplaceAll(hlo, replacements_), std::move(hlo_pass),
+ replaced_pattern, after_pass_checks, config);
+ }
+ }
+
+ absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
+ ParseAndReturnVerifiedModule(absl::string_view hlo_text,
+ int64_t replica_count = 1,
+ int64_t num_partitions = 1) {
+ return GemmRewriteTest::ParseAndReturnVerifiedModule(
+ absl::StrReplaceAll(hlo_text, replacements_));
+ }
+
+ private:
+ static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
+ static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
+ static constexpr const char* kF8E4M3AmaxPlaceholder{"<<F8E4M3_AMAX>>"};
+};
+
+TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) {
+ if (HasFp8Support()) {
+ GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
+ }
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY PreAdaTest {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
+ ErrorSpec{1e-2, 1e-2}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
+; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
+; CHECK-DAG: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteOnPreAdaWithF32Output) {
+ if (HasFp8Support()) {
+ GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
+ }
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY PreAdaTest {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ ROOT out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
+ ErrorSpec{1e-2, 1e-2}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
+; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
+; CHECK-DAG: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ // Test with types unsupported by cuBLAS LT when FP8 is used. cuBLAS LT with
+ // FP8 requires one of the operands to be F8E4M3FN.
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY unsupported_types {
+ x = <<F8E5M2>>[16,16] parameter(0)
+ y = <<F8E5M2>>[16,16] parameter(1)
+ ROOT out = <<F8E5M2>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+)";
+ EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
+ ErrorSpec{1e-2, 1e-2}));
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(Capability(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <<F8E5M2>>[16,16], {{.*}}: <<F8E5M2>>[16,16]) -> <<F8E5M2>>[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(0)
+; CHECK-NEXT: [[P0_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P1]])
+; CHECK-NEXT: [[DOT:%[^ ]+]] = f16[16,16]{1,0} dot([[P0_CONVERT]], [[P1_CONVERT]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} convert([[DOT]])
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1)
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+ R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#else
+ R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#endif
+ R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+// Check that an FP8 matrix bias is not fused into the FP8 custom call.
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+ GTEST_SKIP() << "F8 gemm rewrite for D to be fp8 with Matrix Bias is only "
+ "supported in ROCm 6.2 and above.";
+#endif // TF_ROCM_VERSION < 60200
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ dot_a = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ b = <<F8E4M3>>[16,16] parameter(2)
+ ROOT out = <<F8E4M3>>[16,16] add(dot_a, b)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: <<F8E4M3>>[16,16]) -> <<F8E4M3>>[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[DOT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
+; CHECK-NEXT: [[P2:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} parameter(2)
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} add([[DOT]], [[P2]])
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[13,17] parameter(0)
+ y = <<F8E4M3>>[17,31] parameter(1)
+ x_f32 = f32[13,17] convert(x)
+ y_f32 = f32[17,31] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[13,17] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[17,31] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[13,17] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[17,31] multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[13,31] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[13,17], {{.*}}: <<F8E4M3>>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[13,17]{1,0} parameter(0)
+; CHECK-NEXT: [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[17,31]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,17]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C4:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]], [[C4]], /*index=5*/[[C4]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK-NEXT: [[DOT:%[^ ]+]] = f32[16,32]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[13,31]{1,0} slice([[DOT]]), slice={[0:13], [0:31]}
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[2,8,16] parameter(0)
+ y = <<F8E4M3>>[16,16] parameter(1)
+ x_f32 = f32[2,8,16] convert(x)
+ y_f32 = f32[16,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[2,8,16] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[2,8,16] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,16] multiply(y_f32, y_scale_bcast)
+ x_bitcast = f32[16,16] bitcast(x_unscaled)
+ ROOT out = f32[16,16] dot(x_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+ .WithShape(F32, {16, 16})));
+}
+
+// Test case where F8 inputs are converted to F32 before the dot, but without
+// any scaling.
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ ROOT out = f32[16,16] dot(x_f32, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[3] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[3] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[3] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[3] multiply(x_f32, x_scale_bcast)
+ zero = f32[] constant(0)
+ x_unscaled_padded = f32[30] pad(x_unscaled, zero), padding=0_27
+ x_unscaled_padded_bcast = f32[30,8,5] broadcast(x_unscaled_padded), dimensions={0}
+ x_unscaled_padded_bcast_sliced = f32[16,8,4] slice(x_unscaled_padded_bcast), slice={[2:18], [0:8], [0:4]}
+ x_unscaled_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_unscaled_padded_bcast_sliced)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[16,16] dot(x_unscaled_padded_bcast_sliced_reshaped, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
+; CHECK-NEXT: [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
+; CHECK-NEXT: [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
+; CHECK-NEXT: [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
+; CHECK-NEXT: [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C2]], /*index=5*/[[C2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ UnscaledABUnscaledDUnaryOpsWithConvertF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[3] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[3] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ zero = f32[] constant(0)
+ x_padded = f32[30] pad(x_f32, zero), padding=0_27
+ x_padded_bcast = f32[30,8,5] broadcast(x_padded), dimensions={0}
+ x_padded_bcast_sliced = f32[16,8,4] slice(x_padded_bcast), slice={[2:18], [0:8], [0:4]}
+ x_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_padded_bcast_sliced)
+ ROOT out = f32[16,16] dot(x_padded_bcast_sliced_reshaped, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
+; CHECK-NEXT: [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
+; CHECK-NEXT: [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
+; CHECK-NEXT: [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
+; CHECK-NEXT: [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]], [[C2]], /*index=5*/[[C2]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[32,32] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ zero = s32[] constant(0)
+ x_f32 = f32[32,32] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[32,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[32,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ dyn_slice = f32[16,32]{1,0} dynamic-slice(x_unscaled, zero, zero), dynamic_slice_sizes={16,32}
+ ROOT dot_a = f32[16,16] dot(dyn_slice, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[32,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} parameter(0)
+; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0)
+; CHECK-NEXT: [[DYN_SLICE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32}
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ k = pred[16,32] parameter(4)
+ c = f32[] constant(0)
+ c_bcast = f32[16,32] broadcast(c), dimensions={}
+ select_a = f32[16,32] select(k, y_unscaled, c_bcast)
+ ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P4:%[^ ]+]] = pred[16,32]{1,0} parameter(4)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[16,32]{1,0} broadcast([[C0]]), dimensions={}
+; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} convert([[C0_BCAST]])
+; CHECK-NEXT: [[SELECT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDSelectNonzeroConstantF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ k = pred[16,32] parameter(4)
+ c = f32[] constant(1)
+ c_bcast = f32[16,32] broadcast(c), dimensions={}
+ select_a = f32[16,32] select(k, y_unscaled, c_bcast)
+ ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[10,16,32] parameter(0)
+ y = <<F8E4M3>>[10,32,16] parameter(1)
+ x_f32 = f32[10,16,32] convert(x)
+ y_f32 = f32[10,32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[10,16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[10,32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[10,16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[10,32,16] multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[10,16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[10,16,32], {{.*}}: <<F8E4M3>>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[10,32,16]{2,1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["2"]
+; CHECK-DAG: "rhs_contracting_dimensions":["2"]
+; CHECK-DAG: "lhs_batch_dimensions":["0"]
+; CHECK-DAG: "rhs_batch_dimensions":["0"]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ k = f32[] constant(3.0)
+ k_bcast = f32[16,16] broadcast(k), dimensions={}
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[16,16] multiply(dot_a, k_bcast)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":3
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ c = f32[] constant(0)
+ c_bcast = f32[16,16] broadcast(c), dimensions={}
+ ROOT out = f32[16,16] maximum(dot_a, c_bcast)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_bf16 = bf16[16,32] convert(x)
+ y_bf16 = bf16[32,16] convert(y)
+ x_scale = bf16[] parameter(2)
+ y_scale = bf16[] parameter(3)
+ bias = bf16[16] parameter(4)
+ x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
+ y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
+ dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ b_bcast = bf16[16,16] broadcast(bias), dimensions={1}
+ dot = bf16[16,16] add(dot1, b_bcast)
+ mul.0 = bf16[16,16] multiply(dot, dot)
+ mul.1 = bf16[16,16] multiply(dot, mul.0)
+ const.0 = bf16[] constant(0.044715)
+ bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
+ mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
+ add.0 = bf16[16,16] add(dot, mul.2)
+ const.1 = bf16[] constant(0.797884583)
+ bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
+ mul.3 = bf16[16,16] multiply(add.0, bcast.1)
+ tanh = bf16[16,16] tanh(mul.3)
+ const.2 = bf16[] constant(1)
+ bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
+ add.2 = bf16[16,16] add(tanh, bcast.2)
+ const.3 = bf16[] constant(0.5)
+ bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
+ mul.4 = bf16[16,16] multiply(add.2, bcast.3)
+ ROOT out = bf16[16,16] multiply(dot, mul.4)
+ }
+)";
+
+ CheckFp8IfSupported(hlo_text);
+
+// Fusing GELU into FP8 cuBLAS matmuls is disabled on CUDA versions less
+// than 12.4.
+#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2)
+; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
+; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+ R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#else
+ R"(; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]),
+)"
+#endif
+ R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+ R"(; CHECK-GCN-DAG: "epilogue":"DEFAULT"
+)"
+#else
+ R"(; CHECK-DAG: "epilogue":"BIAS_GELU"
+)"
+#endif
+ R"(; CHECK: }
+ )");
+#endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDApproxGeluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_bf16 = bf16[16,32] convert(x)
+ y_bf16 = bf16[32,16] convert(y)
+ x_scale = bf16[] parameter(2)
+ y_scale = bf16[] parameter(3)
+ x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
+ y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
+ dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ mul.0 = bf16[16,16] multiply(dot, dot)
+ mul.1 = bf16[16,16] multiply(dot, mul.0)
+ const.0 = bf16[] constant(0.044715)
+ bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
+ mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
+ add.0 = bf16[16,16] add(dot, mul.2)
+ const.1 = bf16[] constant(0.797884583)
+ bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
+ mul.3 = bf16[16,16] multiply(add.0, bcast.1)
+ tanh = bf16[16,16] tanh(mul.3)
+ const.2 = bf16[] constant(1)
+ bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
+ add.2 = bf16[16,16] add(tanh, bcast.2)
+ const.3 = bf16[] constant(0.5)
+ bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
+ mul.4 = bf16[16,16] multiply(add.2, bcast.3)
+ ROOT out = bf16[16,16] multiply(dot, mul.4)
+ }
+)";
+
+ CheckFp8IfSupported(hlo_text);
+
+// Fusing GELU into FP8 cuBLAS matmuls is disabled on CUDA versions less
+// than 12.4.
+#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+ // Currently, hipBlasLt does not support bf16 output for FP8 matmuls, so no
+ // GELU fusion is performed in that case.
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2)
+; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
+; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+ R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#else
+ R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#endif
+ R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+ R"(; CHECK-GCN-DAG: "epilogue":"DEFAULT"
+)"
+#else
+ R"(; CHECK-DAG: "epilogue":"GELU"
+)"
+#endif
+ R"(; CHECK: }
+ )");
+#endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] divide(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] divide(y_f32, y_scale_bcast)
+ ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ b = f32[16,16] parameter(2)
+ one = f32[] constant(1)
+ ones = f32[16,16] broadcast(one), dimensions={}
+ b_ones = f32[16,16] add(b, ones)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(3)
+ y_scale = f32[] parameter(4)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = add(dot_a, b_ones)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK: [[C0:%[^ ]+]] = f32[16,16]{1,0} add({{.*}})
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: output_to_operand_aliasing={
+; CHECK-SAME: {0}: (2, {})
+; CHECK-SAME: }
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[14,31] parameter(0)
+ y = <<F8E4M3>>[31,14] parameter(1)
+ b = f32[14,14] parameter(2)
+ x_f32 = f32[14,31] convert(x)
+ y_f32 = f32[31,14] convert(y)
+ x_scale = f32[] parameter(3)
+ y_scale = f32[] parameter(4)
+ x_scale_bcast = f32[14,31] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[31,14] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[14,31] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[31,14] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[14,14] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = add(dot_a, b)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[14,31], {{.*}}: <<F8E4M3>>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} parameter(0)
+; CHECK-NEXT: [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[31,14]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[14,14]{1,0} parameter(2)
+; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]], /*index=5*/[[C3]], [[C3]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[DOT:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[14,14]{1,0} slice([[DOT]]), slice={[0:14], [0:14]}
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ z_scale = f32[] parameter(2)
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+ c1 = f32[] constant(-448.)
+ c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+ c2 = f32[] constant(448.)
+ c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ z_scale = f32[] parameter(2)
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
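+
+// Editorial sketch (not part of the rewriter or the FileCheck patterns above):
+// a scalar model of the scale algebra the preceding test checks. The rewriter
+// replaces `divide(dot_a, z_scale_bcast)` by feeding the reciprocal 1/z_scale
+// to the custom call as one of its scale operands (the [[P2_INV]] operand in
+// the CHECK lines); multiplying by the reciprocal is equivalent to the
+// original divide. The test name below is hypothetical.
+TEST(Fp8ScaleAlgebraSketch, DivideByScaleEqualsMultiplyByReciprocal) {
+ const float dot = 6.0f;      // stand-in for one element of dot(x, y)
+ const float z_scale = 2.0f;  // the scale supplied as parameter(2)
+ // Before the rewrite: the GEMM output is divided by z_scale.
+ const float before = dot / z_scale;
+ // After the rewrite: the graph multiplies by the precomputed reciprocal.
+ const float after = dot * (1.0f / z_scale);
+ EXPECT_FLOAT_EQ(before, after);
+}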
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ z_scale = f32[] parameter(2)
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+// Do not fuse the output scaling when it is not followed by a type conversion
+// and a matrix bias has already been fused.
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ b = f32[16,16] parameter(2)
+ z_scale = f32[] parameter(3)
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bias = f32[16,16] add(dot_a, b)
+ ROOT dot_a_scaled = f32[16,16] divide(dot_a_bias, z_scale_bcast)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2)
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK-PTX-NEXT: [[GEMM:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-PTX-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-PTX-NEXT: [[P3_BCAST:%[^ ]+]] = f32[16,16]{1,0} broadcast([[P3]]), dimensions={}
+; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} divide([[GEMM]], [[P3_BCAST]])
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ z_scale = f32[] parameter(4)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+ c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+ c2 = f32[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
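+
+// Editorial sketch, not exercised by FileCheck: the end-to-end scale algebra
+// of the test above in scalar form. The operand order mirrors the CHECK-PTX
+// custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], [[P4_INV]]):
+// the x-scale, the y-scale, a constant 1, and the inverted output scale.
+// kAmax is a hypothetical stand-in for <<F8E4M3_AMAX>>, whose value depends
+// on the F8 flavor the fixture instantiates; the test name is hypothetical.
+TEST(Fp8ScaleAlgebraSketch, OperandAndOutputScalesCompose) {
+ const float kAmax = 448.0f;  // assumed e4m3fn maximum; illustrative only
+ auto clamp = [](float v, float lo, float hi) {
+   return v < lo ? lo : (v > hi ? hi : v);
+ };
+ const float dot = 10.0f, x_scale = 0.5f, y_scale = 4.0f, z_scale = 2.0f;
+ // Before the rewrite: dequantize both operands, then divide and clamp.
+ const float before =
+     clamp(dot * x_scale * y_scale / z_scale, -kAmax, kAmax);
+ // After the rewrite: the custom call receives x_scale, y_scale and the
+ // reciprocal 1/z_scale; no explicit divide or clamp remains in the HLO.
+ const float after =
+     clamp(dot * x_scale * y_scale * (1.0f / z_scale), -kAmax, kAmax);
+ EXPECT_FLOAT_EQ(before, after);
+}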
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ z_scale = f32[] parameter(4)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
+ c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+ c2 = f32[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-NOT: divide
+
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ z_scale = f32[] parameter(4)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ c = f32[] constant(0)
+ c_bcast = f32[16,16] broadcast(c), dimensions={}
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ relu_a = f32[16,16] maximum(dot_a, c_bcast)
+ relu_a_scaled = f32[16,16] divide(relu_a, z_scale_bcast)
+ c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+ c2 = f32[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+ relu_a_clamped = f32[16,16] clamp(c1_bcast, relu_a_scaled, c2_bcast)
+ ROOT out = <<F8E4M3>>[16,16] convert(relu_a_clamped)
+ }
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+ )");
+}
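+
+// Editorial note with a trivial scalar model: in the test above, the HLO
+// `maximum(dot_a, c_bcast)` with c = 0 is absorbed into the custom call's
+// epilogue (the "RELU" entry in the backend_config) instead of remaining as
+// a separate instruction. The test name below is hypothetical.
+TEST(Fp8EpilogueSketch, ReluClampsNegativesToZero) {
+ auto relu = [](float v) { return v > 0.0f ? v : 0.0f; };
+ EXPECT_EQ(relu(-2.0f), 0.0f);
+ EXPECT_EQ(relu(3.0f), 3.0f);
+}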
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f16[] parameter(0)
+ b = f16[] parameter(1)
+ ROOT c = f16[] maximum(a, b)
+ }
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f16 = f16[16,32] convert(x)
+ y_f16 = f16[32,16] convert(y)
+ b = f16[16,16] parameter(2)
+ one = f16[] constant(1)
+ ones = f16[16,16] broadcast(one), dimensions={}
+ b_ones = f16[16,16] add(b, ones)
+ x_scale = f16[] parameter(3)
+ y_scale = f16[] parameter(4)
+ z_scale = f16[] parameter(5)
+ x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+ y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+ dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bias = f16[16,16] add(dot_a, b_ones)
+ abs_dot_a = f16[16,16] abs(dot_a_bias)
+ c0 = f16[] constant(-inf)
+ amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
+ dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
+ c1 = f16[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f16[16,16] broadcast(c1), dimensions={}
+ c2 = f16[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f16[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ ROOT result = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
+ }
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK: [[C0:%[^ ]+]] = f16[16,16]{1,0} add({{.*}})
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK: [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5)
+; CHECK-PTX: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
+; CHECK-NOT: output_to_operand_aliasing
+; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
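+
+// Editorial sketch of the amax side output checked above: amax is the maximum
+// absolute value of the biased, pre-quantization GEMM output, and after the
+// rewrite it is produced as an extra element of the custom-call tuple (the
+// f32[] in the CHECK-PTX line) instead of an abs/reduce pair in HLO. The
+// values below are arbitrary; the test name is hypothetical.
+TEST(Fp8ScaleAlgebraSketch, AmaxIsTheMaximumAbsoluteValue) {
+ const float values[] = {-3.5f, 1.25f, 2.0f, -0.5f};
+ // The HLO reduce starts from -inf; 0 is an equivalent identity once the
+ // inputs have been passed through abs().
+ float amax = 0.0f;
+ for (float v : values) {
+   const float a = v < 0.0f ? -v : v;
+   if (a > amax) amax = a;
+ }
+ EXPECT_FLOAT_EQ(amax, 3.5f);
+}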
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f16 = f16[16,32] convert(x)
+ y_f16 = f16[32,16] convert(y)
+ b = f16[16] parameter(2)
+ b_bcast = f16[16,16] broadcast(b), dimensions={1}
+ x_scale = f16[] parameter(3)
+ y_scale = f16[] parameter(4)
+ z_scale = f16[] parameter(5)
+ x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+ y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+ dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bias = f16[16,16] add(dot_a, b_bcast)
+ dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
+ c1 = f16[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f16[16,16] broadcast(c1), dimensions={}
+ c2 = f16[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f16[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <<F8E4M3>>[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1)
+; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5)
+; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]])
+; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
+; CHECK-PTX: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]),
+; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ b = f32[16] parameter(2)
+ b_bf16 = bf16[16] convert(b)
+ b_f32 = f32[16] convert(b_bf16)
+ b_bcast = f32[16,16] broadcast(b_f32), dimensions={1}
+ x_scale = f32[] parameter(3)
+ y_scale = f32[] parameter(4)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[16,16] add(dot_a, b_bcast)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[VB:%[^ ]+]] = f32[16]{0} parameter(2)
+; CHECK-NEXT: [[VBC:%[^ ]+]] = bf16[16]{0} convert([[VB]])
+; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]], [[VBC]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDVectorBiasThenReluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ b = f16[16] parameter(2)
+ b_bcast = f16[16,16] broadcast(b), dimensions={1}
+ x_f16 = f16[16,32] convert(x)
+ y_f16 = f16[32,16] convert(y)
+ x_scale = f16[] parameter(3)
+ y_scale = f16[] parameter(4)
+ x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+ y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+ c = f16[] constant(0)
+ c_bcast = f16[16,16] broadcast(c), dimensions={}
+ dot_a0 = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a = f16[16,16] add(dot_a0, b_bcast)
+ ROOT out = f16[16,16] maximum(dot_a, c_bcast)
+ }
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
+; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS_RELU"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[4,16,16] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ b = f32[32] parameter(2)
+ b_f16 = f16[32] convert(b)
+ b_bcast = f16[4,16,32] broadcast(b_f16), dimensions={2}
+ x_f16 = f16[4,16,16] convert(x)
+ y_f16 = f16[16,32] convert(y)
+ x_scale = f16[] parameter(3)
+ y_scale = f16[] parameter(4)
+ x_scale_bcast = f16[4,16,16] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f16[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f16[4,16,16] multiply(x_f16, x_scale_bcast)
+ x_unscaled_bitcast = f16[64,16] bitcast(x_unscaled)
+ y_unscaled = f16[16,32] multiply(y_f16, y_scale_bcast)
+ dot_a = f16[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bitcast = f16[4,16,32]{2,1,0} bitcast(dot_a)
+ ROOT out = f16[4,16,32] add(dot_a_bitcast, b_bcast)
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Bitcast(m::GetTupleElement(
+ m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+ .WithShape(F16, {64, 32}))
+ .WithShape(F16, {4, 16, 32})));
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[B:%[^ ]+]] = f32[32]{0} parameter(2)
+; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]])
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+; CHECK: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK: ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]])
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ Rank3ScaledABUnscaledDVectorBiasPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[4,15,15] parameter(0)
+ y = <<F8E4M3>>[15,31] parameter(1)
+ b = f32[31] parameter(2)
+ b_f16 = f16[31] convert(b)
+ b_bcast = f16[4,15,31] broadcast(b_f16), dimensions={2}
+ x_f16 = f16[4,15,15] convert(x)
+ y_f16 = f16[15,31] convert(y)
+ x_scale = f16[] parameter(3)
+ y_scale = f16[] parameter(4)
+ x_scale_bcast = f16[4,15,15] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f16[15,31] broadcast(y_scale), dimensions={}
+ x_unscaled = f16[4,15,15] multiply(x_f16, x_scale_bcast)
+ x_unscaled_bitcast = f16[60,15] bitcast(x_unscaled)
+ y_unscaled = f16[15,31] multiply(y_f16, y_scale_bcast)
+ dot_a = f16[60,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bitcast = f16[4,15,31]{2,1,0} bitcast(dot_a)
+ ROOT out = f16[4,15,31] add(dot_a_bitcast, b_bcast)
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Bitcast(m::Slice(m::GetTupleElement(
+ m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+ .WithShape(F16, {64, 32}))
+ .WithShape(F16, {60, 31}))
+ .WithShape(F16, {4, 15, 31})));
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[4,15,15]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[60,15]{1,0} bitcast([[P0]])
+; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,15]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C2:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[B:%[^ ]+]] = f32[31]{0} parameter(2)
+; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]])
+; CHECK-NEXT: [[C3:%[^ ]+]] = f16[] constant(0)
+; CHECK-NEXT: [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"BIAS"
+; CHECK: }
+; CHECK: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT: [[SLICE:%[^ ]+]] = f16[60,31]{1,0} slice([[GEMM]]), slice={[0:60], [0:31]}
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[4,15,31]{2,1,0} bitcast([[SLICE]])
+ )");
+}
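+
+// Editorial sketch of the padding arithmetic the CHECK lines above encode:
+// the non-aligned GEMM dimensions are padded up to the next multiple of 16
+// (60x15 -> 64x16 with padding=0_4x0_1, 31 -> 32 with padding=0_1) and the
+// result is sliced back to [60,31] afterwards. The helper below only restates
+// that observation; it is not part of the rewriter, and its name is
+// hypothetical.
+TEST(Fp8PaddingSketch, PaddedDimensionsAreTheNextMultipleOf16) {
+ auto round_up_16 = [](int n) { return (n + 15) / 16 * 16; };
+ EXPECT_EQ(round_up_16(60), 64);
+ EXPECT_EQ(round_up_16(15), 16);
+ EXPECT_EQ(round_up_16(31), 32);
+ EXPECT_EQ(round_up_16(45), 48);  // the padded rank-3 matrix-bias test below
+}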
+
+TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[4,16,16] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ b = f32[4,16,32] parameter(2)
+ x_f32 = f32[4,16,16] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(3)
+ y_scale = f32[] parameter(4)
+ x_scale_bcast = f32[4,16,16] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[4,16,16] multiply(x_f32, x_scale_bcast)
+ x_unscaled_bitcast = f32[64,16] bitcast(x_unscaled)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bitcast = f32[4,16,32]{2,1,0} bitcast(dot_a)
+ ROOT out = f32[4,16,32] add(dot_a_bitcast, b)
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Bitcast(m::GetTupleElement(
+ m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+ .WithShape(F32, {64, 32}))
+ .WithShape(F32, {4, 16, 32})));
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2)
+; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[GEMM:%[^ ]+]] = f32[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK: ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]])
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ Rank3ScaledABUnscaledDMatrixBiasPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[3,15,15] parameter(0)
+ y = <<F8E4M3>>[15,31] parameter(1)
+ b = f32[3,15,31] parameter(2)
+ x_f32 = f32[3,15,15] convert(x)
+ y_f32 = f32[15,31] convert(y)
+ x_scale = f32[] parameter(3)
+ y_scale = f32[] parameter(4)
+ x_scale_bcast = f32[3,15,15] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[15,31] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[3,15,15] multiply(x_f32, x_scale_bcast)
+ x_unscaled_bitcast = f32[45,15] bitcast(x_unscaled)
+ y_unscaled = f32[15,31] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[45,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bitcast = f32[3,15,31]{2,1,0} bitcast(dot_a)
+ ROOT out = f32[3,15,31] add(dot_a_bitcast, b)
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Bitcast(m::Slice(m::GetTupleElement(
+ m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+ .WithShape(F32, {48, 32}))
+ .WithShape(F32, {45, 31}))
+ .WithShape(F32, {3, 15, 31})));
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[3,15,15]{2,1,0} parameter(0)
+; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[45,15]{1,0} bitcast([[P0]])
+; CHECK-NEXT: [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,15]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[C2:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1
+; CHECK-NEXT: [[B:%[^ ]+]] = f32[3,15,31]{2,1,0} parameter(2)
+; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[45,31]{1,0} bitcast([[B]])
+; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[45,31]{1,0} slice([[GEMM]]), slice={[0:45], [0:31]}
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[3,15,31]{2,1,0} bitcast([[SLICE]])
+ )");
+}
+
+// Do not fuse the matrix bias when there is a slice that does not merely chop
+// off the ends of dimensions.
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDMatrixBiasWithSliceF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>>[48,16] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ b = f32[32,16] parameter(2)
+ x_f32 = f32[48,16] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(3)
+ y_scale = f32[] parameter(4)
+ x_scale_bcast = f32[48,16] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[48,16] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[48,32] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_sliced = f32[32,16] slice(dot_a), slice={[16:48], [16:32]}
+ ROOT out = f32[32,16] add(dot_a_sliced, b)
+ }
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_TRUE(changed);
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[48,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[GEMM:%[^_]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[32,16]{1,0} slice([[GEMM]]), slice={[16:48], [16:32]}
+; CHECK-NEXT: [[B:%[^ ]+]] = f32[32,16]{1,0} parameter(2)
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[32,16]{1,0} add([[SLICE]], [[B]])
+ )");
+}
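+
+// Editorial sketch of the property named in the comment before this test: a
+// slice only "chops off the ends" of a dimension when it starts at index 0
+// (like the slice={[0:60], [0:31]} produced in the padded tests), so the
+// slice={[16:48], [16:32]} here blocks the bias fusion. The helper is
+// hypothetical and not part of the rewriter.
+TEST(Fp8SliceSketch, OnlyZeroBasedSlicesChopOffTheEnds) {
+ auto chops_only_ends = [](int start0, int start1) {
+   return start0 == 0 && start1 == 0;
+ };
+ EXPECT_TRUE(chops_only_ends(0, 0));     // slice={[0:60], [0:31]}
+ EXPECT_FALSE(chops_only_ends(16, 16));  // slice={[16:48], [16:32]}
+}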
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ absl::string_view hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ all_gather = f32[16,64]{1,0} all-gather(x_unscaled), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+ all_gather1 = f32[64,32]{1,0} all-gather(y_unscaled), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
+ ROOT dot_a = f32[16,32] dot(all_gather, all_gather1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+)";
+
+ HloModuleConfig config = GetModuleConfigForTest();
+ config.set_use_spmd_partitioning(true);
+ config.set_num_partitions(8);
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK: [[AG:%[^ ]+]] = <<F8E4M3>>[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}}
+; CHECK: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK: [[AG1:%[^ ]+]] = <<F8E4M3>>[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}}
+; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0}
+; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: ROOT [[GEMM:%[^_]+]] = f32[16,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+ )",
+ nullptr, &config);
+}
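+
+// Editorial sketch, not part of the pass: the CHECK lines above expect the
+// all-gather to run on the F8 operands rather than on the dequantized f32
+// values, so the collective moves one byte per element instead of four. The
+// byte counts below simply restate the [16,64] gather shape in the CHECKs;
+// the 4x saving is a plausible motivation, not something the test asserts.
+TEST(Fp8CollectiveSketch, GatheringF8MovesAQuarterOfTheBytes) {
+ const int elements = 16 * 64;        // all-gather output shape [16,64]
+ const int f32_bytes = elements * 4;  // f32 all-gather in the input HLO
+ const int f8_bytes = elements * 1;   // F8 all-gather after the rewrite
+ EXPECT_EQ(f32_bytes, 4096);
+ EXPECT_EQ(f8_bytes, 1024);
+ EXPECT_EQ(f32_bytes / f8_bytes, 4);
+}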
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ absl::string_view hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ all_to_all = f32[16,32]{1,0} all-to-all(x_unscaled), channel_id=1, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={0}
+ ROOT dot_a = f32[16,16] dot(all_to_all, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ }
+)";
+
+ HloModuleConfig config = GetModuleConfigForTest();
+ config.set_use_spmd_partitioning(true);
+ config.set_num_partitions(8);
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK: [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}}
+; CHECK: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )",
+ nullptr, &config);
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDWithCollectivePermuteF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ absl::string_view hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[16,32] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[16,32] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+ collective_permute = f32[16,32]{1,0} collective-permute(x_unscaled), source_target_pairs={{0,0}, {1,1}, {2,4}, {3,5}, {4,2}, {5,3}, {6,6}, {7,7}}
+ ROOT dot_a = f32[16,16] dot(collective_permute, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ }
+)";
+
+ HloModuleConfig config = GetModuleConfigForTest();
+ config.set_use_spmd_partitioning(true);
+ config.set_num_partitions(8);
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK: [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}}
+; CHECK: [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )",
+ nullptr, &config);
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDMatrixBiasThenVectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f16 = f16[16,32] convert(x)
+ y_f16 = f16[32,16] convert(y)
+ b = f16[16] parameter(2)
+ b_bcast = f16[16,16] broadcast(b), dimensions={1}
+ b2 = f16[16,16] parameter(3)
+ x_scale = f16[] parameter(4)
+ y_scale = f16[] parameter(5)
+ x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+ y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+ dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot_a_bias1 = f16[16,16] add(dot_a, b2)
+ ROOT dot_a_bias = f16[16,16] add(dot_a_bias1, b_bcast)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
+; CHECK-DAG: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT: [[MB:%[^ ]+]] = f16[16,16]{1,0} parameter(3)
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT: [[CV0:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(5)
+; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]], /*index=5*/[[C1]], [[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":1
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+; CHECK: [[GEMMOUT:%[^ ]+]] = f16[16,16]{1,0} get-tuple-element([[GEMMOUT_TUPLE]]), index=0
+; CHECK: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
+; CHECK: [[VBC:%[^ ]+]] = f16[16,16]{1,0} broadcast([[VB]]), dimensions={1}
+; CHECK: ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} add([[GEMMOUT]], [[VBC]])
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] maximum(a, b)
+ }
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ z_scale = f32[] parameter(4)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ abs_dot_a = f32[16,16] abs(dot_a)
+ c0 = f32[] constant(-inf)
+ amax = f32[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
+ dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+ c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+ c2 = f32[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABScaledDWithDAmaxF8WithF16Intermediates) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate
+ // values instead of F32 intermediate values.
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f16[] parameter(0)
+ b = f16[] parameter(1)
+ ROOT c = f16[] maximum(a, b)
+ }
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f16 = f16[16,32] convert(x)
+ y_f16 = f16[32,16] convert(y)
+ x_scale = f16[] parameter(2)
+ y_scale = f16[] parameter(3)
+ z_scale = f16[] parameter(4)
+ x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+ y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+ dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ abs_dot_a = f16[16,16] abs(dot_a)
+ c0 = f16[] constant(-inf)
+ amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
+ dot_a_scaled = f16[16,16] divide(dot_a, z_scale_bcast)
+ c1 = f16[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f16[16,16] broadcast(c1), dimensions={}
+ c2 = f16[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f16[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ ROOT out = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(2)
+; CHECK-NEXT: [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT: [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1)
+; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4)
+; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]])
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABScaledDReluActivationWithDAmaxF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] maximum(a, b)
+ }
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E4M3>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ z_scale = f32[] parameter(4)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ czero = f32[] constant(0)
+ czero_bcast = f32[16,16] broadcast(czero), dimensions={}
+ dot_a_relu = f32[16,16] maximum(dot_a, czero_bcast)
+ c0 = f32[] constant(-inf)
+ amax = f32[] reduce(dot_a_relu, c0), dimensions={0,1}, to_apply=apply
+ dot_a_scaled = f32[16,16] divide(dot_a_relu, z_scale_bcast)
+ c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+ c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+ c2 = f32[] constant(<<F8E4M3_AMAX>>)
+ c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+ dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+ dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+ ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
+; CHECK: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"RELU"
+; CHECK: }
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* raw_hlo_template = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[1600,3200] parameter(0)
+ y = <<F8E4M3>>[3200,1600] parameter(1)
+ x_f32 = f32[1600,3200] convert(x)
+ y_f32 = f32[3200,1600] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
+ }
+)";
+
+ std::string hlo_template =
+ absl::StrReplaceAll(raw_hlo_template, replacements_);
+
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<precision>>"] = "default";
+ const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
+ EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
+
+ replacements["<<precision>>"] = "highest";
+ const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
+ EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ std::array<std::array<absl::string_view, 7>, 32> combinations;
+ int i = 0;
+
+ for (bool d_is_col : {false, true}) {
+ for (bool a_is_col : {false, true}) {
+ for (bool b_is_col : {false, true}) {
+ for (int lhs_contracting_dim : {0, 1}) {
+ for (int rhs_contracting_dim : {0, 1}) {
+ const absl::string_view lcd =
+ lhs_contracting_dim == 1 ? "{1}" : "{0}";
+ const absl::string_view rcd =
+ rhs_contracting_dim == 1 ? "{1}" : "{0}";
+ const absl::string_view a_shape =
+ lhs_contracting_dim == 1 ? "[64,32]" : "[32,64]";
+ const absl::string_view b_shape =
+ rhs_contracting_dim == 0 ? "[32,16]" : "[16,32]";
+ const absl::string_view a_layout = a_is_col ? "{0,1}" : "{1,0}";
+ const absl::string_view b_layout = b_is_col ? "{0,1}" : "{1,0}";
+ const absl::string_view output_layout =
+ d_is_col ? "{0,1}" : "{1,0}";
+ combinations[i++] = std::array{
+ lcd, rcd, a_shape, b_shape, a_layout, b_layout, output_layout};
+ }
+ }
+ }
+ }
+ }
+
+ const char* hlo_template = R"(
+ HloModule test
+ ENTRY test {
+ x = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
+ x_f32 = f32<<Ashape>><<Alayout>> convert(x)
+ x_scale = f32[] parameter(2)
+ x_scale_bcast = f32<<Ashape>> broadcast(x_scale), dimensions={}
+ x_unscaled = f32<<Ashape>> multiply(x_f32, x_scale_bcast)
+ y = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
+ y_f32 = f32<<Bshape>><<Blayout>> convert(y)
+ y_scale = f32[] parameter(3)
+ y_scale_bcast = f32<<Bshape>> broadcast(y_scale), dimensions={}
+ y_unscaled = f32<<Bshape>> multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[64,16]<<Olayout>> dot(x_unscaled, y_unscaled), lhs_contracting_dims=<<Lcd>>, rhs_contracting_dims=<<Rcd>>
+ }
+ )";
+ for (const auto& combination : combinations) {
+ absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+ replacements["<<Lcd>>"] = std::get<0>(combination);
+ replacements["<<Rcd>>"] = std::get<1>(combination);
+ replacements["<<Ashape>>"] = std::get<2>(combination);
+ replacements["<<Bshape>>"] = std::get<3>(combination);
+ replacements["<<Alayout>>"] = std::get<4>(combination);
+ replacements["<<Blayout>>"] = std::get<5>(combination);
+ replacements["<<Olayout>>"] = std::get<6>(combination);
+ const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+ CheckFp8IfSupported(hlo_text);
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+ ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+ )");
+ }
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+ ScaledABUnscaledDF8ParameterizedBatched) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+  // TODO(wenscarl): For batched matmul, not all combinations of A, B, and
+  // output layouts are successfully pattern-matched to an FP8 custom call.
+  // Only a handful of cases are tested here.
+ std::array<std::array<std::string, 7>, 32> combinations;
+ std::string lcd, rcd, a_shape, b_shape, a_layout, b_layout, o_layout;
+ int i = 0;
+ for (bool o_is_col : {false, true}) {
+ for (int lhs_contracting_dim : {2, 1}) {
+ for (int rhs_contracting_dim : {2, 1}) {
+ lcd = lhs_contracting_dim == 2 ? "{2}" : "{1}";
+ rcd = rhs_contracting_dim == 2 ? "{2}" : "{1}";
+ a_shape = lhs_contracting_dim == 2 ? "[2,64,32]" : "[2,32,64]";
+ b_shape = rhs_contracting_dim == 1 ? "[2,32,16]" : "[2,16,32]";
+ o_layout = o_is_col ? "{2, 0, 1}" : "{2, 1, 0}";
+ for (std::string a_layout : {"{2,1,0}", "{1,2,0}"}) {
+ for (std::string b_layout : {"{2,1,0}", "{1,2,0}"}) {
+ combinations[i++] = std::array{lcd, rcd, a_shape, b_shape,
+ a_layout, b_layout, o_layout};
+ }
+ }
+ }
+ }
+ }
+
+ const char* hlo_template = R"(
+ HloModule m
+ENTRY f {
+ x_q = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
+ x_scale = f32[] parameter(2)
+ x_scale_broadcast = f32<<Ashape>><<Alayout>> broadcast(x_scale), dimensions={}
+ x_q_convert = f32<<Ashape>><<Alayout>> convert(x_q)
+ x_qdq = f32<<Ashape>><<Alayout>> multiply(x_q_convert, x_scale_broadcast)
+
+ y_q = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
+ y_scale = f32[] parameter(3)
+ y_scale_broadcast = f32<<Bshape>><<Blayout>> broadcast(y_scale), dimensions={}
+ y_q_convert = f32<<Bshape>><<Blayout>> convert(y_q)
+ y_qdq = f32<<Bshape>><<Blayout>> multiply(y_q_convert, y_scale_broadcast)
+
+ ROOT out = f32[2,64,16]<<Olayout>> dot(x_qdq, y_qdq), lhs_batch_dims={0}, lhs_contracting_dims=<<Lcd>>, rhs_batch_dims={0}, rhs_contracting_dims=<<Rcd>>
+}
+ )";
+
+ for (const auto& combination : combinations) {
+ absl::flat_hash_map<std::string, std::string> replacements;
+ replacements["<<Lcd>>"] = std::get<0>(combination);
+ replacements["<<Rcd>>"] = std::get<1>(combination);
+ replacements["<<Ashape>>"] = std::get<2>(combination);
+ replacements["<<Bshape>>"] = std::get<3>(combination);
+ replacements["<<Alayout>>"] = std::get<4>(combination);
+ replacements["<<Blayout>>"] = std::get<5>(combination);
+ replacements["<<Olayout>>"] = std::get<6>(combination);
+
+ const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+ CheckFp8IfSupported(hlo_text);
+
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+ ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+ )");
+ }
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = <<F8E4M3>>[16,32] parameter(0)
+ y = <<F8E5M2>>[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+
+)";
+
+ CheckFp8IfSupported(hlo_text);
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+ ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+ )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+ GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif // TF_ROCM_VERSION < 60000
+
+  // Test that FNUZ FP8 gemms are not rewritten on CUDA, as cuBLAS does not
+  // support them; on ROCm the rewrite to the FP8 custom call is expected.
+ const char* hlo_text = R"(
+ HloModule test
+
+ ENTRY test {
+ x = f8e4m3fnuz[16,32] parameter(0)
+ y = f8e4m3fnuz[32,16] parameter(1)
+ x_f32 = f32[16,32] convert(x)
+ y_f32 = f32[32,16] convert(y)
+ x_scale = f32[] parameter(2)
+ y_scale = f32[] parameter(3)
+ x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+ y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+ x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+ y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+ ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+)";
+#if GOOGLE_CUDA
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+ EXPECT_FALSE(changed);
+#endif
+#if TENSORFLOW_USE_ROCM
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2}));
+ RunAndFilecheckHloRewrite(
+ hlo_text,
+ GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+ GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+ R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} parameter(0)
+; CHECK-PTX-NEXT: [[P0_CV:%[^ ]+]] = f32[16,32]{1,0} convert([[P0]])
+; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-PTX-NEXT: [[P2_B:%[^ ]+]] = f32[16,32]{1,0} broadcast([[P2]]), dimensions={}
+; CHECK-PTX-NEXT: [[P0_UNSCALED:%[^ ]+]] = f32[16,32]{1,0} multiply([[P0_CV]], [[P2_B]])
+; CHECK-PTX-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
+; CHECK-PTX-NEXT: [[P1_CV:%[^ ]+]] = f32[32,16]{1,0} convert([[P1]])
+; CHECK-PTX-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-PTX-NEXT: [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={}
+; CHECK-PTX-NEXT: [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]])
+; CHECK-PTX-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]),
+; CHECK-GCN-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
+; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8",
+; CHECK: backend_config={
+; CHECK-DAG: "alpha_real":1
+; CHECK-DAG: "alpha_imag":0
+; CHECK-DAG: "beta":0
+; CHECK-DAG: "dot_dimension_numbers":{
+; CHECK-DAG: "lhs_contracting_dimensions":["1"]
+; CHECK-PTX-DAG: "rhs_contracting_dimensions":["0"]
+; CHECK-GCN-DAG: "rhs_contracting_dimensions":["1"]
+; CHECK-DAG: "lhs_batch_dimensions":[]
+; CHECK-DAG: "rhs_batch_dimensions":[]
+; CHECK-DAG: }
+; CHECK-DAG: "precision_config":{
+; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG: }
+; CHECK-DAG: "epilogue":"DEFAULT"
+; CHECK: }
+ )");
+#endif
+}
+
+INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt,
+ ParameterizedFp8GemmRewriteTest, ::testing::Bool());
+#endif
+
+TEST_F(GemmRewriteTest, NoFuseBiasBroadcast) {
+ const char* hlo = R"(
+
+HloModule module
+
+ENTRY main.10 {
+ Arg_0.1 = f16[384,128]{1,0} parameter(0)
+ Arg_1.2 = f16[128,256]{1,0} parameter(1)
+ dot.4 = f16[384,256]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ Arg_2.3 = f16[256]{0} parameter(2)
+ reshape.5 = f16[1,256]{1,0} reshape(Arg_2.3)
+ broadcast.6 = f16[1,256]{1,0} broadcast(reshape.5), dimensions={0,1}
+ reshape.7 = f16[256]{0} reshape(broadcast.6)
+ broadcast.8 = f16[384,256]{1,0} broadcast(reshape.7), dimensions={1}
+ ROOT add.9 = f16[384,256]{1,0} add(dot.4, broadcast.8)
+})";
+
+ MatchOptimizedHlo(hlo, R"(
+// CHECK: "beta":0
+ )");
+}
+
+TEST_F(GemmRewriteTest, ReduceOfBatchDot) {
+ absl::string_view hlo_string =
+ R"(
+HloModule test
+
+region_5.50 {
+ Arg_0.51 = f32[] parameter(0)
+ Arg_1.52 = f32[] parameter(1)
+ ROOT add.53 = f32[] add(Arg_0.51, Arg_1.52)
+}
+
+ENTRY main {
+ p0 = bf16[3,32,3,13]{3,2,1,0} parameter(0)
+ p1 = bf16[3,32,3,64]{3,2,1,0} parameter(1)
+ dot.95 = bf16[3,3,13,64]{3,2,1,0} dot(p0, p1), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}, operand_precision={highest,highest}
+ transpose.96 = bf16[3,64,3,13]{1,3,2,0} transpose(dot.95), dimensions={0,3,1,2}
+ convert.101 = f32[3,64,3,13]{1,3,2,0} convert(transpose.96)
+ constant.66 = f32[] constant(0.0)
+ ROOT reduce.102 = f32[3,64,13]{2,1,0} reduce(convert.101, constant.66), dimensions={2}, to_apply=region_5.50
+}
+)";
+  // Make sure the dot is lowered to a custom call. There is an algebraic
+  // simplifier rewrite that could turn the dot into a non-canonical dot late
+  // in the pipeline, which would make it unsupported by the GemmRewriter.
+ MatchOptimizedHlo(hlo_string, R"(
+ // CHECK: custom_call_target="__cublas$gemm"
+ )");
+}
+
+TEST_F(GemmRewriteTest, DotWithBias) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY main {
+ p0 = f32[1024,1024] parameter(0)
+ p1 = f32[1024,1024] parameter(1)
+ p2 = f32[1024,1024] parameter(2)
+ p3 = f32[1024,1024] parameter(3)
+ dot0 = f32[1024,1024] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot1 = f32[1024,1024] dot(p2, p3),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT root = f32[1024,1024] add(dot0, dot1)
+ })";
+
+ const char* expected = R"()
+ // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0)
+ // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1)
+ // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2)
+ // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3)
+ // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]])
+ // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0
+ // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]], %[[S0]])
+ // CHECK: ROOT %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0
+ })";
+
+ RunAndFilecheckHloRewrite(
+ hlo,
+ GemmRewriter(
+ se::CudaComputeCapability{}, /*toolkit_version=*/0,
+ GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only}),
+ expected);
+}
+
+TEST_F(GemmRewriteTest, DotWithoutBias) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY main {
+ p0 = f32[1024,1024] parameter(0)
+ p1 = f32[1024,1024] parameter(1)
+ p2 = f32[1024,1024] parameter(2)
+ p3 = f32[1024,1024] parameter(3)
+ dot0 = f32[1024,1024] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ dot1 = f32[1024,1024] dot(p2, p3),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT root = f32[1024,1024] add(dot0, dot1)
+ })";
+
+ const char* expected = R"()
+ // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0)
+ // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1)
+ // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]])
+ // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0
+ // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2)
+ // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3)
+ // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]])
+ // CHECK: %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0
+ // CHECK: ROOT %[[S2:.*]] = f32[1024,1024]{1,0} add(%[[S0]], %[[S1]])
+ })";
+
+ RunAndFilecheckHloRewrite(
+ hlo,
+ GemmRewriter(se::CudaComputeCapability{}, /*toolkit_version=*/0,
+ GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only,
+ GemmRewriterOptions::BiasMode::kNoBias}),
+ expected);
+}
+
+TEST_F(CublasLtGemmRewriteTest, CublasLtSuccessfullyMatchesLargeC64Lhs) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ p0 = c64[2000,3000,3]{2,1,0} parameter(0)
+ p1 = c64[3,6]{1,0} parameter(1)
+ ROOT dot = c64[2000,3000,6]{2,1,0} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+}
+)";
+ // Large lhs is fine for cuBLASlt.
+ MatchOptimizedHlo(hlo_text,
+ R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
+}
+
+TEST_F(CublasLtGemmRewriteTest, CublasLtOnlyMatchesLargeC64RhsPostAmpere) {
+ const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+ p0 = c64[6,3]{1,0} parameter(0)
+ p1 = c64[3,2000,3000]{2,1,0} parameter(1)
+ ROOT dot = c64[6,2000,3000]{2,1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+ if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+ // From Ampere onwards, cuBLASlt supports large rhs.
+ MatchOptimizedHlo(hlo_text,
+ R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
+ } else {
+    // Before Ampere, an rhs whose non-contracting dimensions exceed 4194240
+    // (combined) is not supported for the C64 type.
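+    // (Here the combined non-contracting dimensions are 2000 * 3000 =
+    // 6,000,000, which exceeds that limit.)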
+ MatchOptimizedHlo(
+ hlo_text, R"(; CHECK-NOT: custom_call_target="__cublas$lt$matmul")");
+ }
+}
+
+class GemmRewriteAllocationTest : public GpuCodegenTest {
+ public:
+ void CheckNumberOfAllocations(const std::string& hlo,
+ int expected_number_of_allocations) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+ GetOptimizedModule(hlo));
+ if (allocator_ == nullptr) {
+ allocator_ = std::make_unique<se::StreamExecutorMemoryAllocator>(
+ backend().default_stream_executor());
+ }
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Executable> executable,
+ backend().compiler()->RunBackend(std::move(optimized_module),
+ backend().default_stream_executor(),
+ allocator_.get()));
+ GpuExecutable* gpu_executable =
+ static_cast<GpuExecutable*>(executable.get());
+ absl::Span<const BufferAllocation> allocations =
+ gpu_executable->GetAllocations();
+ ASSERT_EQ(allocations.size(), expected_number_of_allocations);
+ }
+
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+ // Make sure the rewriter does not skip the rewrite for being too small.
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+ return debug_options;
+ }
+
+ private:
+ std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
+};
+
+TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) {
+ const char* hlo_text = R"(
+HloModule SharedBufferAssignment
+
+ENTRY AddDotsFunc {
+ x = f32[2,2] parameter(0)
+ y = f32[2,2] parameter(1)
+ bias = f32[2,2] add(x, y)
+ dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT out = f32[2,2] add(dot, bias)
+}
+
+)";
+
+ // Bias should be fused into the multiplication.
+ CheckNumberOfAllocations(hlo_text, 4);
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+class SmallDotGemmRewriteTest : public GemmRewriteTest {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
+ return debug_options;
+ }
+};
+
+TEST_F(SmallDotGemmRewriteTest, SkipSmallMatrixMultiplicationRewrite) {
+ const char* hlo_text = R"(
+HloModule SkipSmallMatrixRewrite
+
+ENTRY DotFunc {
+ x = f32[3,3] parameter(0)
+ y = f32[3,3] parameter(1)
+ ROOT out = f32[3,3] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[3,3], {{.*}}: f32[3,3]) -> f32[3,3] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,3]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,3]{1,0} parameter(1)
+; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} dot([[P0]], [[P1]]),
+; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0}
+)");
+}
+
+TEST_F(SmallDotGemmRewriteTest, LargeMatrixMultiplicationIsRewritten) {
+ const char* hlo_text = R"(
+HloModule SkipSmallMatrixRewrite
+
+ENTRY DotFunc {
+ x = f32[8,8] parameter(0)
+ y = f32[8,8] parameter(1)
+ ROOT out = f32[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[8,8], {{.*}}: f32[8,8]) -> f32[8,8] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8,8]{1,0} parameter(1)
+; CHECK: {{[^ ]+}} = {{.*}} custom-call([[P0]], [[P1]])
+)");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc
new file mode 100644
index 0000000..fddb9e6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc
@@ -0,0 +1,183 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemv_rewriter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/shape.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Construct a new layout by adding a new minor-most dimension to the input
+// layout. For example, {3, 2, 1, 0} is extended to {4, 3, 2, 1, 0}.
+// We expect the input layout to have been normalized by LayoutNormalizer, so
+// that it has a descending ordering.
+absl::StatusOr<Layout> GetLayoutWithNewMinorMostDimension(
+ const Layout& layout) {
+ // Check that the layout is normalized.
+ if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) {
+ return absl::InvalidArgumentError("Layout is not normalized.");
+ }
+ return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() + 1);
+}
+
+class GemvRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleDot(HloInstruction* instr) override {
+ HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
+ const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers();
+ HloInstruction* lhs = dot->mutable_operand(0);
+ HloInstruction* rhs = dot->mutable_operand(1);
+
+    // This pass relies on the dot decomposer, which ensures that all non-batch
+    // dimensions are merged into one.
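+    // Under that assumption, an operand has a (single) non-contracting
+    // dimension exactly when its rank exceeds its number of batch plus
+    // contracting dimensions by one, which is what the checks below test.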
+ bool lhs_has_non_contracting_dim =
+ lhs->shape().rank() ==
+ dim_numbers.lhs_batch_dimensions_size() +
+ dim_numbers.lhs_contracting_dimensions_size() + 1;
+ bool rhs_has_non_contracting_dim =
+ rhs->shape().rank() ==
+ dim_numbers.rhs_batch_dimensions_size() +
+ dim_numbers.rhs_contracting_dimensions_size() + 1;
+
+ // Skip matrix-matrix multiplication.
+ if (lhs_has_non_contracting_dim && rhs_has_non_contracting_dim) {
+ return absl::OkStatus();
+ }
+
+ // Skip vector-vector multiplication.
+ if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) {
+ return absl::OkStatus();
+ }
+
+ if (dot->shape().is_dynamic()) {
+ return absl::OkStatus();
+ }
+
+ changed_ = true;
+
+ HloComputation* computation = dot->parent();
+ HloInstruction* new_lhs = lhs;
+ if (!lhs_has_non_contracting_dim) {
+ const Shape& lhs_shape = lhs->shape();
+ absl::Span<const int64_t> lhs_dimensions = lhs_shape.dimensions();
+ std::vector<int64_t> new_lhs_dimensions(lhs_dimensions.begin(),
+ lhs_dimensions.end());
+ new_lhs_dimensions.push_back(1);
+ Shape new_lhs_shape(
+ lhs_shape.element_type(), new_lhs_dimensions,
+ absl::InlinedVector<bool, 4>(new_lhs_dimensions.size(), false),
+ /*tuple_shapes=*/{});
+ TF_ASSIGN_OR_RETURN(
+ *new_lhs_shape.mutable_layout(),
+ GetLayoutWithNewMinorMostDimension(lhs_shape.layout()));
+ new_lhs = computation->AddInstruction(
+ HloInstruction::CreateBitcast(new_lhs_shape, lhs));
+ }
+
+ HloInstruction* new_rhs = rhs;
+ if (!rhs_has_non_contracting_dim) {
+ const Shape& rhs_shape = rhs->shape();
+ absl::Span<const int64_t> rhs_dimensions = rhs_shape.dimensions();
+ std::vector<int64_t> new_rhs_dimensions(rhs_dimensions.begin(),
+ rhs_dimensions.end());
+ new_rhs_dimensions.push_back(1);
+ Shape new_rhs_shape(
+ rhs_shape.element_type(), new_rhs_dimensions,
+ absl::InlinedVector<bool, 4>(new_rhs_dimensions.size(), false),
+ /*tuple_shapes=*/{});
+ TF_ASSIGN_OR_RETURN(
+ *new_rhs_shape.mutable_layout(),
+ GetLayoutWithNewMinorMostDimension(rhs_shape.layout()));
+ new_rhs = computation->AddInstruction(
+ HloInstruction::CreateBitcast(new_rhs_shape, rhs));
+ }
+
+ std::vector<int64_t> new_out_dimensions;
+ new_out_dimensions.reserve(dot->shape().dimensions().size() + 1);
+ for (int64_t dim_size : dot->shape().dimensions()) {
+ new_out_dimensions.push_back(dim_size);
+ }
+ if (!lhs_has_non_contracting_dim) {
+ // Insert the trivial dimension before the non-contracting dimension from
+ // rhs.
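+      // For example (matching the vector-matrix case in the unit tests), a
+      // [n] @ [n,m] dot with output shape [m] becomes [1,m] here.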
+ int non_contracting_dim_size = new_out_dimensions.back();
+ new_out_dimensions[new_out_dimensions.size() - 1] = 1;
+ new_out_dimensions.push_back(non_contracting_dim_size);
+ } else {
+ new_out_dimensions.push_back(1);
+ }
+
+ Shape new_out_shape(
+ dot->shape().element_type(), new_out_dimensions,
+ absl::InlinedVector<bool, 4>(new_out_dimensions.size(), false),
+ /*tuple_shapes=*/{});
+ TF_ASSIGN_OR_RETURN(
+ *new_out_shape.mutable_layout(),
+ GetLayoutWithNewMinorMostDimension(dot->shape().layout()));
+
+ HloInstruction* new_dot =
+ computation->AddInstruction(HloInstruction::CreateDot(
+ new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(),
+ dot->precision_config()));
+ HloInstruction* bitcast = computation->AddInstruction(
+ HloInstruction::CreateBitcast(dot->shape(), new_dot));
+ return computation->ReplaceInstruction(dot, bitcast);
+ }
+
+ bool changed() const { return changed_; }
+
+ private:
+ bool changed_ = false;
+};
+
+} // namespace
+
+absl::StatusOr<bool> GemvRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ GemvRewriterVisitor gemv_rewriter;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_RETURN_IF_ERROR(computation->Accept(&gemv_rewriter));
+ }
+ return gemv_rewriter.changed();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h
new file mode 100644
index 0000000..9339101
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrite a matrix-vector or a vector-matrix multiplication into a
+// matrix-matrix multiplication with a trivial dimension. For example,
+// [m x n] @ [n] is rewritten to [m x n] @ [n x 1], and [n] @ [m x n] is
+// rewritten to [n x 1] @ [m x n].
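+//
+// A concrete instance of the rewrite (instruction names are illustrative;
+// compare gemv_rewriter_test.cc):
+//
+//   before: d = f32[32] dot(f32[32,7] p0, f32[7] p1),
+//             lhs_contracting_dims={1}, rhs_contracting_dims={0}
+//   after:  b = f32[7,1] bitcast(p1)
+//           d = f32[32,1] dot(p0, b),
+//             lhs_contracting_dims={1}, rhs_contracting_dims={0}
+//           r = f32[32] bitcast(d)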
+class GemvRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "gemv-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc
new file mode 100644
index 0000000..d255528
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc
@@ -0,0 +1,149 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemv_rewriter.h"
+
+#include <memory>
+#include <optional>
+
+#include <gtest/gtest.h>
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+class GemvRewriterTest : public HloTestBase {};
+
+TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationToGemm) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY e {
+ p0 = f32[32,7] parameter(0)
+ p1 = f32[7] parameter(1)
+ ROOT d = f32[32] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ })";
+
+ const char* expected = R"()
+// CHECK: %[[P0:.*]] = f32[32,7]{1,0} parameter(0)
+// CHECK: %[[P1:.*]] = f32[7]{0} parameter(1)
+// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P1]])
+// CHECK: %[[DOT:.*]] = f32[32,1]{1,0} dot(%[[P0]], %[[BITCAST]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+// CHECK: ROOT %[[ROOT:.*]] = f32[32]{0} bitcast(%[[DOT]])
+})";
+
+ RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
+}
+
+TEST_F(GemvRewriterTest, RewriteVectorMatrixMultiplicationToGemm) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY e {
+ p0 = f32[7] parameter(0)
+ p1 = f32[7,32] parameter(1)
+ ROOT d = f32[32] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+ })";
+
+ const char* expected = R"()
+// CHECK: %[[P0:.*]] = f32[7]{0} parameter(0)
+// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P0]])
+// CHECK: %[[P1:.*]] = f32[7,32]{1,0} parameter(1)
+// CHECK: %[[DOT:.*]] = f32[1,32]{1,0} dot(%[[BITCAST]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+// CHECK: ROOT %[[ROOT:.*]].1 = f32[32]{0} bitcast(%[[DOT]])
+})";
+
+ RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
+}
+
+TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationWithBatch) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY e {
+ p0 = f32[2,5,32,7] parameter(0)
+ p1 = f32[2,5,7] parameter(1)
+ ROOT d = f32[2,5,32] dot(p0, p1),
+ lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
+ lhs_contracting_dims={3}, rhs_contracting_dims={2}
+ })";
+
+ const char* expected = R"()
+// CHECK: %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0)
+// CHECK: %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1)
+// CHECK: %[[BITCAST:.*]] = f32[2,5,7,1]{3,2,1,0} bitcast(%[[P1]])
+// CHECK: %[[DOT:.*]] = f32[2,5,32,1]{3,2,1,0} dot(%[[P0]], %[[BITCAST]]),
+// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+// CHECK: ROOT %[[ROOT:.*]] = f32[2,5,32]{2,1,0} bitcast(%[[DOT]])
+})";
+
+ RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
+}
+
+TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY e {
+ p0 = f32[7] parameter(0)
+ p1 = f32[7] parameter(1)
+ ROOT d = f32[] dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+ })";
+
+ RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
+}
+
+TEST_F(GemvRewriterTest, DotNotRewriteMatrixMatrixMultiplication) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY e {
+ p0 = f32[5,7] parameter(0)
+ p1 = f32[7,32] parameter(1)
+ ROOT d = f32[5,32] dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ })";
+
+ RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
+}
+
+TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) {
+ const char* hlo = R"(
+ HloModule m
+
+ ENTRY e {
+ p0 = f32[5,32,7]{2,1,0} parameter(0)
+ p1 = f32[5,7]{0,1} parameter(1)
+ ROOT d = f32[5,32]{0,1} dot(p0, p1),
+ lhs_batch_dims={0}, rhs_batch_dims={0},
+ lhs_contracting_dims={2}, rhs_contracting_dims={1}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo));
+ GemvRewriter rewriter;
+ absl::StatusOr<bool> result = this->RunHloPass(&rewriter, module.get());
+ EXPECT_FALSE(result.ok());
+ EXPECT_EQ(result.status().message(), "Layout is not normalized.");
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc
new file mode 100644
index 0000000..ef78dbd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc
@@ -0,0 +1,201 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gpusolver_rewriter.h"
+
+#include <cstdint>
+#include <functional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/cusolver_context.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
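+// Resets `shape` to the default (descending) layout and then swaps the two
+// minor-most dimensions, so that the trailing 2D block is in column-major
+// (Fortran) order; e.g. a rank-3 layout {2,1,0} becomes {1,2,0}.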
+void SetFortranLayout(Shape* shape) {
+ LayoutUtil::SetToDefaultLayout(shape);
+ int n = shape->mutable_layout()->minor_to_major_size();
+ CHECK_GE(n, 2);
+ std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
+ shape->mutable_layout()->mutable_minor_to_major()->at(1));
+}
+
+absl::StatusOr<HloInstruction*> CreateCholesky(GpuSolverContext* context,
+ HloInstruction* operand,
+ const CholeskyOptions& options,
+ const OpMetadata& metadata) {
+ HloComputation* computation = operand->parent();
+
+ Shape a_shape = operand->shape();
+ int ndim = a_shape.dimensions_size();
+ CHECK_GE(ndim, 2);
+ int64_t n = a_shape.dimensions(ndim - 1);
+
+ std::vector<int64_t> batch_dims(a_shape.dimensions().begin(),
+ a_shape.dimensions().end() - 2);
+ std::vector<int64_t> batch_dim_ids(batch_dims.size());
+ absl::c_iota(batch_dim_ids, 0);
+ int64_t batch_size = absl::c_accumulate(batch_dims, 1, std::multiplies<>{});
+
+ // Find the workspace size.
+ se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
+ : se::blas::UpperLower::kUpper;
+ int64_t workspace_size; // Number of elements of size a_shape.element_type()
+ TF_ASSIGN_OR_RETURN(
+ workspace_size,
+ context->PotrfBufferSize(a_shape.element_type(), uplo, n, n, batch_size));
+
+ // TODO(phawkins): Ideally we would relax this constraint. What we actually
+ // want is that:
+ // a) the batch dimensions are major, in no particular order.
+  // b) the two minor dimensions are in Fortran (column-major) order.
+
+ SetFortranLayout(&a_shape);
+
+ // This call returns a tuple of (cholesky_result, workspace, info) where:
+  // * cholesky_result is the result of the Cholesky decomposition.
+ // * workspace is temporary scratch memory used by cuSolver.
+ // * info contains the Potrf success/failure status.
+ // Currently we have no meaningful way to report an error, so we simply
+ // discard the success/failure information. Obviously this is suboptimal.
+ Shape info_shape = ShapeUtil::MakeShape(S32, batch_dims);
+ Shape call_shape = ShapeUtil::MakeTupleShape(
+ {a_shape,
+ ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
+ info_shape});
+
+ HloInstruction* custom_call =
+ computation->AddInstruction(HloInstruction::CreateCustomCall(
+ call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
+ custom_call->set_metadata(metadata);
+ TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
+ HloInstruction* out = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(a_shape, custom_call, 0));
+ HloInstruction* info = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(info_shape, custom_call, 2));
+
+ // If info was non-zero, indicating that the Cholesky decomposition failed,
+ // returns an array full of NaNs for the corresponding batch element.
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
+ HloInstruction* zeros =
+ computation->AddInstruction(HloInstruction::CreateBroadcast(
+ info_shape, zero, /*broadcast_dimensions=*/{}));
+ HloInstruction* ok = computation->AddInstruction(
+ HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, batch_dims),
+ info, zeros, ComparisonDirection::kEq));
+ ok = computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(PRED, a_shape.dimensions()), ok,
+ /*broadcast_dimensions=*/batch_dim_ids));
+
+ TF_ASSIGN_OR_RETURN(Literal nan_literal,
+ LiteralUtil::NanValue(a_shape.element_type()));
+ HloInstruction* nan = computation->AddInstruction(
+ HloInstruction::CreateConstant(std::move(nan_literal)));
+ HloInstruction* nans =
+ computation->AddInstruction(HloInstruction::CreateBroadcast(
+ a_shape, nan, /*broadcast_dimensions=*/{}));
+
+ HloInstruction* select =
+ computation->AddInstruction(HloInstruction::CreateTernary(
+ a_shape, HloOpcode::kSelect, ok, out, nans));
+ return select;
+}
+
+// Tries to rewrite a single Cholesky decomposition into a custom call to
+// cuSolver.
+absl::StatusOr<bool> RunOnInstruction(GpuSolverContext* context,
+ HloInstruction* instruction) {
+ if (instruction->opcode() != HloOpcode::kCholesky) {
+ return false;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * custom_call,
+ CreateCholesky(context, instruction->mutable_operand(0),
+ instruction->cholesky_options(), instruction->metadata()));
+
+ VLOG(1) << "Replacing " << instruction->ToString() << " with "
+ << custom_call->ToString();
+
+ TF_RETURN_IF_ERROR(
+ instruction->parent()->ReplaceInstruction(instruction, custom_call));
+ return true;
+}
+
+} // namespace
+
+// Rewrites the Cholesky decompositions in the given computation into custom
+// calls to cuSolver. Returns true if it made any changes.
+absl::StatusOr<bool> GpusolverRewriter::RunOnComputation(
+ HloComputation* computation) {
+ std::vector<HloInstruction*> cusolver_calls;
+ for (auto* hlo : computation->instructions()) {
+ if (hlo->opcode() == HloOpcode::kCholesky) {
+ cusolver_calls.push_back(hlo);
+ }
+ }
+
+ if (cusolver_calls.empty()) {
+ return false;
+ }
+
+ TF_ASSIGN_OR_RETURN(GpuSolverContext context, GpuSolverContext::Create());
+
+ bool changed = false;
+ for (HloInstruction* instruction : cusolver_calls) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(&context, instruction));
+ changed |= result;
+ }
+ return changed;
+}
+
+GpusolverRewriter::GpusolverRewriter() = default;
+
+absl::StatusOr<bool> GpusolverRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+ changed |= result;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h
new file mode 100644
index 0000000..cdc0ff2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h
@@ -0,0 +1,47 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver.
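+//
+// A rough sketch of the rewrite (see CreateCholesky in the .cc file): a
+// `cholesky(a)` instruction becomes a custom call returning a tuple of
+// (result, workspace, info); batch elements whose `info` is non-zero are
+// replaced with NaNs via a select.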
+class GpusolverRewriter : public HloModulePass {
+ public:
+ GpusolverRewriter();
+ absl::string_view name() const override { return "gpusolver-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc
new file mode 100644
index 0000000..befe869
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc
@@ -0,0 +1,192 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_input_fusion.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Gets the representative input shape of the multi-output fusion.
+Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
+ // Get the HLO that determines the emitter used for lowering.
+ const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
+ if (real_hero->operands().empty()) {
+ // Simply return an empty shape if the representative node has no input
+ // operands.
+ return Shape();
+ } else {
+ return real_hero->operand(0)->shape();
+ }
+}
+
+class HorizontalInputFusionImpl {
+ public:
+ explicit HorizontalInputFusionImpl(HloComputation* computation,
+ const se::DeviceDescription& d)
+ : computation_(computation), device_info_(d) {}
+
+ ~HorizontalInputFusionImpl() = default;
+
+ absl::StatusOr<bool> Run();
+
+ private:
+ HloComputation* computation_;
+ const se::DeviceDescription& device_info_;
+}; // HorizontalInputFusionImpl
+
+// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to
+// right.
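+// For example, any rank-1 shape orders before any rank-2 shape, and f32[2,3]
+// orders before f32[2,4].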
+bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
+ const Shape& shape_b) {
+ if (shape_a.rank() != shape_b.rank()) {
+ return shape_a.rank() < shape_b.rank();
+ }
+ auto dims_a = shape_a.dimensions();
+ auto dims_b = shape_b.dimensions();
+ for (size_t i = 0; i < dims_a.size(); ++i) {
+ if (dims_a[i] != dims_b[i]) {
+ return dims_a[i] < dims_b[i];
+ }
+ }
+ return true;
+}
+
+std::vector<HloInstruction*> FindAndSortFusionCandidates(
+ HloInstruction* consumer) {
+ absl::flat_hash_set<HloInstruction*> fusion_instr_set;
+ std::vector<HloInstruction*> fusion_instrs;
+ for (HloInstruction* opnd : consumer->operands()) {
+ HloInstruction* predecessor = opnd->LatestNonGteAncestor();
+ // Find out the input fusion instructions whose only consumer is `consumer`.
+ // This guarantees that fusing these candidates will never create cycles, as
+ // there is no back edge.
+ if (IsInputFusibleReduction(*predecessor) &&
+ IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
+ if (fusion_instr_set.insert(predecessor).second) {
+ fusion_instrs.push_back(predecessor);
+ }
+ }
+ }
+
+ std::sort(fusion_instrs.begin(), fusion_instrs.end(),
+ [&](const HloInstruction* a, const HloInstruction* b) {
+ Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
+ Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
+ if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
+                // Sort shapes according to dimensions, so that identical
+                // input shapes are placed adjacent to each other.
+ return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
+ }
+ // Sort `fusion_instrs` according to instruction counts, because
+ // we'd like to fuse together computations of similar sizes.
+ return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
+ });
+
+ return fusion_instrs;
+}
+
+absl::StatusOr<bool> HorizontalInputFusionImpl::Run() {
+ bool changed = false;
+ XLA_VLOG_LINES(3, computation_->ToString());
+
+ // Using def-to-use order is sound since we do not modify users.
+ std::vector<HloInstruction*> def_to_use_order =
+ computation_->MakeInstructionPostOrder();
+ for (HloInstruction* consumer : def_to_use_order) {
+ auto candidates = FindAndSortFusionCandidates(consumer);
+ if (candidates.size() <= 1) {
+ continue;
+ }
+
+ // Convert candidates into fusions if needed.
+ for (size_t j = 0; j < candidates.size(); ++j) {
+ if (candidates[j]->opcode() != HloOpcode::kFusion) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * fusion_instr,
+ MakeFusionInstruction(candidates[j],
+ HloInstruction::FusionKind::kInput));
+ candidates[j] = fusion_instr;
+ changed = true;
+ }
+ }
+
+ size_t fusion_anchor_id = 0;
+ for (size_t j = 1; j < candidates.size(); ++j) {
+ HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
+ HloInstruction* fused = candidates[j];
+ if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
+ FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) {
+ VLOG(3) << "Fuse " << fused->ToString() << " into "
+ << fusion_anchor->ToString();
+ fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
+ changed = true;
+ } else {
+        // Update `fusion_anchor_id`, since `fused` is either incompatible
+        // with or not beneficial to fuse with the current fusion anchor.
+ VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
+ fusion_anchor_id = j;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace
+
+absl::StatusOr<bool> HorizontalInputFusion::RunOnComputation(
+ HloComputation* computation) {
+ HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_);
+ return horizontal_fusion_impl.Run();
+}
+
+absl::StatusOr<bool> HorizontalInputFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ VLOG(2) << "Run horizontal input fusion.";
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(comp));
+    changed |= result;
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h
new file mode 100644
index 0000000..a08168d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h
@@ -0,0 +1,63 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// This optimization pass horizontally fuses kInput fusions to both reduce the
+// kernel launch overhead and increase parallelism degree. See
+// HorizontalLoopFusion for general description and motivation about horizontal
+// fusion. HorizontalLoopFusion deals with kLoop fusions while this pass deals
+// with kInput fusions.
+//
+// Following HorizontalLoopFusion, a simple yet effective heuristic is used
+// to search the fusion candidates while avoiding creating cycles. That is,
+// we simply search for fusion candidates by looking for instructions whose
+// outputs are all consumed by the same instruction. This catches the typical
+// target cases; often, the candidate instructions are just consumed by the
+// ROOT tuple of the entry computation.
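+//
+// A minimal sketch of the targeted pattern (adapted from the unit tests;
+// shapes and names are illustrative): two kInput reduce fusions whose only
+// user is the ROOT tuple,
+//
+//   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
+//   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
+//   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
+//
+// are merged into a single multi-output kInput fusion whose fused root is a
+// tuple of the two reduces.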
+class HorizontalInputFusion : public HloModulePass {
+ public:
+ explicit HorizontalInputFusion(const se::DeviceDescription& d)
+ : device_info_(d) {}
+
+ absl::string_view name() const override { return "horizontal_input_fusion"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ absl::StatusOr<bool> RunOnComputation(HloComputation*);
+
+ const se::DeviceDescription& device_info_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc
new file mode 100644
index 0000000..5fc1a54
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc
@@ -0,0 +1,270 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_input_fusion.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class HorizontalInputFusionTest : public GpuCodegenTest {
+ public:
+ se::DeviceDescription device_description_{
+ TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+ HorizontalInputFusion horizontal_input_fusion_{device_description_};
+};
+
+TEST_F(HorizontalInputFusionTest, BasicTest) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule BasicTest
+
+ %add_f16 {
+ %x = f16[] parameter(0)
+ %y = f16[] parameter(1)
+ ROOT %add = f16[] add(%x, %y)
+ }
+
+ fused_computation.1 {
+ arg.1 = f16[1024]{0} parameter(0)
+ constant0 = f16[] constant(0)
+ ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+ }
+
+ fused_computation.2 {
+ arg.1 = f16[1024]{0} parameter(0)
+ constant0 = f16[] constant(0)
+ ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
+ fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
+ ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
+ }
+)")
+ .value();
+
+ EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
+
+ const HloInstruction* entry_root =
+ module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(entry_root,
+ GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
+ (m::GetTupleElement(m::Fusion())))));
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
+ auto module = CreateNewVerifiedModule();
+
+ HloComputation* reduce_computation;
+ {
+ auto embedded_builder = HloComputation::Builder("add");
+ auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
+ auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
+ embedded_builder.AddInstruction(
+ HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
+ reduce_computation =
+ module->AddEmbeddedComputation(embedded_builder.Build());
+ }
+
+ HloComputation::Builder builder(TestName());
+ std::vector<HloInstruction*> var_outs;
+ auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
+ auto output_shape = ShapeUtil::MakeShape(F32, {1024});
+ for (int64_t i = 0; i < 130; ++i) {
+ // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) ->
+ // f32[1024] {
+ // %param_0 = f32[1024,1024]{1,0} parameter(0)
+ // %param_1 = f32[] parameter(1)
+ // %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
+ // dimensions={}
+ // %multiply = f32[1024,1024]{1,0}
+ // multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0}
+ // %broadcast)
+ // %constant0 = f32[] constant(0)
+ // ROOT %reduce = f32[1024]{0}
+ // reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0),
+ // dimensions={1}, to_apply=%add
+ // }
+ HloInstruction* param_var_in = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
+ HloInstruction* param_alpha =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
+ auto alpha_broadcasted = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
+ auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+ input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
+ output_shape, mul, const0, {1}, reduce_computation));
+ var_outs.push_back(reduce);
+ }
+ builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
+ module->AddEntryComputation(builder.Build());
+
+ // Verify that horizontal fusion is kicked in. Check that there are multiple
+ // `reduce` instructions fused into the same fusion.
+ if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() < 4) {
+    // 6 is just an arbitrarily picked number; we don't know exactly how large
+    // the created fusion will be, due to the `FusionFitsInBudget` constraint.
+ CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
+ /*match_optimized_ir=*/false);
+ } else {
+ // Verify that we produced a multi-output reduction with independent groups.
+ CompileAndVerifyIr(module->Clone(), R"(CHECK: switch {{.*}} label {{.*}} [
+ CHECK-NEXT: label)",
+ /*match_optimized_ir=*/false);
+ }
+
+ // Testing with the entire gpu optimization pipeline.
+ EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
+ // This tests the below pattern. One known issue is that gtes (to fusions) can
+ // be removed after their producer fusions are merged. In the below case, gte2
+ // and gte6 will be gone if Fusion2 is fused into Fusion1.
+ //
+ // Fusion1 Fusion2
+ // | | | |
+ // | gte1 gte2 |
+ // | | | |
+ // | Fusion3 |
+ // | | | |
+ // gte3 gte4 gte5 gte6
+ // \ | | /
+ // =====ROOT=====
+ //
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule MultiOutputFusionTest
+
+ %add_f16 {
+ %x = f16[] parameter(0)
+ %y = f16[] parameter(1)
+ ROOT %add = f16[] add(%x, %y)
+ }
+
+ fused_computation.1 {
+ arg.1 = f16[1024]{0} parameter(0)
+ constant0 = f16[] constant(0)
+ reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+ add.0 = f16[1024] add(arg.1, arg.1)
+ ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
+ }
+
+ fused_computation.2 {
+ arg.1 = f16[1024]{0} parameter(0)
+ constant0 = f16[] constant(0)
+ reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+ add.0 = f16[1024] add(arg.1, arg.1)
+ ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
+ }
+
+ fused_computation.3 {
+ arg.0 = f16[1024]{0} parameter(0)
+ arg.1 = f16[1024]{0} parameter(1)
+ add.0 = f16[1024] add(arg.0, arg.1)
+ mul.0 = f16[1024] multiply(arg.0, arg.1)
+ ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
+ fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
+ gte.3 = f16[] get-tuple-element(fusion.1), index=0
+ gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
+ gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
+ gte.6 = f16[] get-tuple-element(fusion.2), index=0
+ fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
+ kind=kLoop, calls=fused_computation.3
+ gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
+ gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
+ ROOT tuple.1 = (f16[], f16[1024], f16[1024]{0}, f16[])
+ tuple(gte.3, gte.4, gte.5, gte.6)
+ }
+)")
+ .value();
+
+ EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
+}
+
+TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NonfusionInstrs
+
+ %add_f16 {
+ %x = f16[] parameter(0)
+ %y = f16[] parameter(1)
+ ROOT %add = f16[] add(%x, %y)
+ }
+
+ ENTRY entry_computation {
+ arg.0 = f16[1024]{0} parameter(0)
+ arg.1 = f16[1024]{0} parameter(1)
+ constant0 = f16[] constant(0)
+ reduce.0 = f16[] reduce(arg.0, constant0), dimensions={0}, to_apply=%add_f16
+ reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+ ROOT tuple.0 = (f16[], f16[]) tuple(reduce.0, reduce.1)
+ }
+)")
+ .value();
+
+ EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
+
+ const HloInstruction* entry_root =
+ module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(entry_root,
+ GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
+ (m::GetTupleElement(m::Fusion())))));
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc
new file mode 100644
index 0000000..0a3d705
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc
@@ -0,0 +1,744 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/sub_byte_normalization.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
+ auto outputs = GetOutputsOfFusible(fusible);
+ CHECK(!outputs.empty());
+ PrimitiveType first_output_type = outputs[0]->shape().element_type();
+ for (size_t i = 1; i < outputs.size(); ++i) {
+ PrimitiveType cur_output_type = outputs[i]->shape().element_type();
+ CHECK(first_output_type == cur_output_type)
+ << "Output types are expected to be unique, but see "
+ << PrimitiveType_Name(first_output_type) << " and "
+ << PrimitiveType_Name(cur_output_type);
+ }
+
+ return first_output_type;
+}
+
+class HorizontalLoopFusionImpl {
+ public:
+ explicit HorizontalLoopFusionImpl(HloComputation* computation,
+ absl::string_view prefix)
+ : computation_(computation), prefix_(prefix) {}
+
+ ~HorizontalLoopFusionImpl() = default;
+
+ absl::StatusOr<bool> Run();
+
+ private:
+ absl::Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs,
+ bool sliced_input_fusion,
+ std::vector<HloInstruction*>& to_fuse_candidates);
+
+  // If `sliced_input_fusion` is true, horizontally fuses `fused_fusion_instrs`
+  // into a kInput computation; otherwise, fuses `fused_fusion_instrs` into a
+  // kLoop computation.
+ //
+ // It is required that each of `fused_fusion_instrs` is a kLoop fusion. Also,
+ // we require their numbers of outputs to be the same, so that each output
+ // will be fused/concatenated with the same number of outputs from other fused
+ // fusion instrs. Then, all the fused outputs still have the same shapes for
+ // kernel generation.
+ //
+ // Returns the fused computation in `uniq_computation` and the operands that
+ // are used by `uniq_computation`.
+ absl::Status CreateFusedComputation(
+ absl::Span<HloInstruction*> fused_fusion_instrs,
+ std::unique_ptr<HloComputation>* uniq_computation,
+ std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion);
+
+  // Horizontally fuses the operands of the `consumer` instruction.
+  // `sliced_input_fusion` controls whether a kInput or a kLoop fusion
+  // instruction is created. `to_fuse_candidates` is the stack of instructions
+  // whose operands we want to try to fuse horizontally; when we create a new
+  // fusion instruction, we push it onto the stack in the hope of further
+  // fusing its operands.
+ absl::StatusOr<bool> FuseConsumerOperands(
+ HloInstruction* consumer, bool sliced_input_fusion,
+ std::vector<HloInstruction*>& to_fuse_candidates);
+
+ // FusionCandidates collects profitable candidates for a given consumer
+ // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
+ // acquire the next set of fusion candidates based on some heuristics.
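+  //
+  // A typical usage sketch (mirroring FuseConsumerOperands() below; names are
+  // illustrative):
+  //
+  //   FusionCandidates candidates(consumer, sliced_input_fusion);
+  //   while (true) {
+  //     auto fusibles = candidates.GetNextSpanOfFusions();
+  //     if (fusibles.empty()) break;
+  //     // ... fuse the instructions in `fusibles` ...
+  //   }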
+ class FusionCandidates {
+ public:
+ explicit FusionCandidates(HloInstruction* consumer,
+ bool sliced_input_fusion)
+ : fusible_instrs_(),
+ pos_(0),
+ sliced_input_fusion_(sliced_input_fusion) {
+ Initialize(consumer);
+ }
+
+ // Gets a span of fusions to be fused.
+ absl::Span<HloInstruction*> GetNextSpanOfFusions();
+
+ private:
+ void Initialize(HloInstruction*);
+
+ std::vector<HloInstruction*> fusible_instrs_;
+ // `pos_` points to the start position of the next span.
+ size_t pos_;
+    // The `sliced_input_fusion_` flag controls whether we want to fuse into a
+    // kLoop (false) or a kInput (true) kernel.
+ bool sliced_input_fusion_;
+ };
+
+ HloComputation* computation_;
+ std::string prefix_;
+}; // HorizontalLoopFusionImpl
+
+bool IsFusibleCandidate(const HloInstruction& instr) {
+ // For now, we do not support fusing instruction with control flow.
+ if (!instr.control_successors().empty() ||
+ !instr.control_predecessors().empty()) {
+ return false;
+ }
+
+ if (IsNestableVariadicReduction(instr)) {
+ return false;
+ }
+
+ // Require no further check for element-wise instructions.
+ if (instr.IsElementwise() && instr.operand_count() > 0) {
+ return true;
+ }
+
+ // Exclude fusions other than kLoop.
+ if (!instr.IsLoopFusion()) {
+ return false;
+ }
+
+  // Cannot support fusions that have multiple output types, because the
+  // concatenate (inserted for horizontal fusion) requires the same type
+  // for all of its operands.
+ auto outputs = GetOutputsOfFusible(instr);
+ CHECK(!outputs.empty());
+ const HloInstruction* first_output = outputs[0];
+ for (size_t i = 1; i < outputs.size(); ++i) {
+ if (first_output->shape().element_type() !=
+ outputs[i]->shape().element_type()) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// Returns whether `instr` is a profitable candidate to be horizontally fused.
+// Since the primary benefit of horizontal fusion comes from reducing the
+// kernel launch overhead, we want to exclude the instructions with
+// insignificant kernel launch overhead. In other words, we exclude instructions
+// if their computation latencies are longer than launch latencies. We estimate
+// the computation latency of a given instruction by its shapes and the
+// instruction count in its fused computation. We roughly observe that if a
+// fusion instruction has shapes smaller than `kShapeThreshold` and has fewer
+// instructions than `kInstrCountThreshold`, it is launch-latency-bound and
+// profitable by horizontal fusion.
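+//
+// For example, with sliced_input_fusion=false and the thresholds below, a
+// kLoop fusion whose outputs have 4096 elements each and whose fused
+// computation contains 20 instructions passes both checks, whereas one whose
+// output has more than 8192*8192 elements is rejected.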
+bool IsProfitableFusionCandidate(const HloInstruction& instr,
+ bool sliced_input_fusion) {
+  // For a kLoop fused kernel, each GPU thread processes one or more elements
+  // from each horizontally fused operand, while for a kInput fused kernel,
+  // each GPU thread can only process one element. From experience, we enable a
+  // larger tensor-size threshold for kLoop fusion.
+ const int64_t kShapeThreshold =
+ sliced_input_fusion ? 128 * 2048 : 8192 * 8192;
+ const int64_t kInstrCountThreshold = sliced_input_fusion ? 30 : 128;
+ const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
+ ? instr.fused_expression_root()
+ : &instr;
+
+ // Too large shapes are not easily profitable.
+ if (root->opcode() == HloOpcode::kTuple) {
+ // Since all output shapes are the same, use the first shape as the
+ // representative.
+ Shape shape = root->operand(0)->shape();
+ if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
+ VLOG(2) << "Profitable check failed due to element count with "
+ "sliced_input_fusion="
+ << sliced_input_fusion;
+ return false;
+ }
+ } else {
+ Shape shape = root->shape();
+ if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
+      VLOG(2) << "Profitable check failed due to element size with "
+ "sliced_input_fusion="
+ << sliced_input_fusion;
+ return false;
+ }
+ }
+
+ // Having too many instructions is not easily profitable.
+ if (instr.opcode() == HloOpcode::kFusion &&
+ instr.fused_instruction_count() > kInstrCountThreshold) {
+ return false;
+ }
+
+ return true;
+}
+
+// Returns whether `fusion_instr` has only row-major layouts.
+// The horizontal fusion excludes computations with non-row-major layouts,
+// because fusing computations with different layouts can result in uncoalesced
+// memory accesses and cause great performance overhead.
+bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
+ if (instr.opcode() != HloOpcode::kFusion) {
+ return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
+ }
+
+ auto fused_instrs = instr.fused_instructions_computation()->instructions();
+ for (HloInstruction* i : fused_instrs) {
+ if (!LayoutUtil::IsDenseArray(i->shape())) {
+ continue;
+ }
+ if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
+ return false;
+ }
+ }
+ return true;
+}
+
+// Returns whether any operand of `instr` is a parameter instruction that
+// is shared with `fusion_instrs`.
+bool AnyOpndIsParamSharedAmongFusions(
+ const HloInstruction* instr,
+ const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
+ return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
+ return opnd->opcode() == HloOpcode::kParameter &&
+ absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
+ return user != instr && fusion_instrs.contains(user);
+ });
+ });
+}
+
+void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
+ HloInstruction* consumer) {
+ // First, find out all potential target candidates. We will filter out
+ // unsupported/non-profitable cases below.
+ absl::flat_hash_set<HloInstruction*> fusible_candidates;
+ std::vector<HloInstruction*> ordered_fusible_candidates;
+ for (HloInstruction* opnd : consumer->operands()) {
+ HloInstruction* predecessor = opnd->LatestNonGteAncestor();
+    // We currently support kLoop fusions and element-wise HLOs. We may extend
+    // the supported set if the need arises.
+ if (IsFusibleCandidate(*predecessor)) {
+ if (fusible_candidates.insert(predecessor).second) {
+ // Add unseen fusion to ordered list.
+ ordered_fusible_candidates.push_back(predecessor);
+ }
+ }
+ }
+
+ for (HloInstruction* instr : ordered_fusible_candidates) {
+ if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
+ VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+ << " rejects maybe illegal instr " << instr->ToString()
+ << "; including it may create cycles in HLO.";
+ continue;
+ } else if (!IsProfitableFusionCandidate(*instr, sliced_input_fusion_)) {
+ VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+              << " rejects may-not-be-profitable fusion instr "
+ << instr->ToString();
+ continue;
+ } else if (!HasOnlyRowMajorLayout(*instr)) {
+ VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+ << " rejects non-row-major fusion instr " << instr->ToString();
+ continue;
+ } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
+ // Don't fuse fusions whose operands are parameter instructions that are
+ // shared among fusions because we cannot i/o alias the produced
+ // horizontal fusion due to the concat insertion.
+ VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+ << " rejects the fusion instr because it shares parameter with"
+ << " other fusion candidates, instr: " << instr->ToString();
+ continue;
+ } else {
+ VLOG(2) << "Find a fusion candidate " << instr->ToString();
+ // Encapsulate it into a fusion computation for unified representation
+ // for later processing.
+ fusible_instrs_.push_back(instr);
+ }
+ }
+
+  // Sort `fusible_instrs_` according to output types, the number of outputs,
+  // instruction counts, and output tensor element counts. For sliced input
+  // fusion, we only fuse instructions with the same number/type of outputs and
+  // whose computations have the same instruction count. For kLoop fusion, we
+  // require the fused instructions to have the same number/type of outputs and
+  // also the same output shape. We sort here so that compatible fusion
+  // candidates occupy a contiguous span.
+ std::stable_sort(
+ fusible_instrs_.begin(), fusible_instrs_.end(),
+ [&](const HloInstruction* a, const HloInstruction* b) {
+ if (GetUniqueOutputTypeOfFusible(*a) !=
+ GetUniqueOutputTypeOfFusible(*b)) {
+ return GetUniqueOutputTypeOfFusible(*a) <
+ GetUniqueOutputTypeOfFusible(*b);
+ } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
+ return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
+ } else if (GetInstrCountOfFusible(*a) != GetInstrCountOfFusible(*b)) {
+ return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
+ } else {
+ return ShapeUtil::ElementsIn(GetOutputsOfFusible(*a)[0]->shape()) <
+ ShapeUtil::ElementsIn(GetOutputsOfFusible(*b)[0]->shape());
+ }
+ });
+}
+
+// Gets a next span of fusion instructions to be fused.
+absl::Span<HloInstruction*>
+HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
+ if (pos_ >= fusible_instrs_.size()) {
+ return absl::Span<HloInstruction*>();
+ }
+
+  // Fusing too many computations at a time may not be profitable and may
+  // increase compile time due to large kernels, so set a limit. From profiling
+  // results, we found that a large horizontally fused kernel can have worse
+  // end-to-end performance even though the pure GPU kernel time is shorter.
+  // TODO: understand why E2E performance regresses for large horizontally
+  // fused kernels. Use an empirically chosen max fusion batch size based on
+  // the kind of operands being fused (see below).
+ const auto kMaxFusionBatchSize = [&]() -> int64_t {
+ if (sliced_input_fusion_) {
+ return 32;
+ } else {
+ if (fusible_instrs_[pos_]->opcode() == HloOpcode::kFusion) {
+ return 32;
+ } else {
+ return 64;
+ }
+ }
+ }();
+
+ size_t left = pos_;
+ size_t right = pos_ + 1;
+ size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
+ PrimitiveType first_output_type =
+ GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
+ // CUDA has a parameter size limit of ~4k bytes.
+ constexpr int64_t kMaxCudaParamSize = 4000;
+ size_t accum_io_size = 0;
+ size_t accum_num_outputs = 0;
+ for (; right < fusible_instrs_.size(); ++right) {
+ PrimitiveType cur_output_type =
+ GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
+ if (first_output_type != cur_output_type) {
+      // Cannot fuse computations that have multiple output types.
+ break;
+ }
+ if (first_output_size != GetOutputSizeOfFusible(*fusible_instrs_[right])) {
+      // Cannot fuse computations that have different numbers of outputs.
+ break;
+ }
+ if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
+ GetInstrCountOfFusible(*fusible_instrs_[right])) {
+      // Do not fuse computations with different instruction counts, as doing
+      // so may introduce control divergence. This is a very simple heuristic
+      // to avoid fusing computations with too much discrepancy, and we may
+      // improve it when the need arises.
+ break;
+ }
+ if (!sliced_input_fusion_ &&
+ !ShapeUtil::EqualIgnoringElementType(
+ GetOutputsOfFusible(*fusible_instrs_[left])[0]->shape(),
+ GetOutputsOfFusible(*fusible_instrs_[right])[0]->shape())) {
+      // This is for fusing into a kLoop-type kernel, so we require that each
+      // fusion operand has the same shape.
+ break;
+ }
+ size_t num_outputs = GetOutputSizeOfFusible(*fusible_instrs_[right]);
+ accum_num_outputs += num_outputs;
+ if (accum_num_outputs >= kMaxFusionBatchSize) {
+ // Hit max fusion batch size.
+ break;
+ }
+ accum_io_size += fusible_instrs_.at(right)->operand_count() + num_outputs;
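+    // Each accumulated operand/output buffer is passed to the fused kernel as
+    // a pointer, so approximate the parameter size as 8 bytes per buffer
+    // (assuming 8-byte pointers).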
+ if (accum_io_size * 8 >= kMaxCudaParamSize) {
+ break;
+ }
+ }
+ VLOG(2) << "horizontal fuse get instruction span with " << (right - left)
+ << " instructions for sliced_input_fusion=" << sliced_input_fusion_
+ << " fusion";
+ pos_ = right;
+ return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
+}
+
+absl::StatusOr<bool> HorizontalLoopFusionImpl::FuseConsumerOperands(
+ HloInstruction* consumer, bool sliced_input_fusion,
+ std::vector<HloInstruction*>& to_fuse_candidates) {
+ bool changed = false;
+ FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion);
+ while (true) {
+ auto fusibles = loop_fusion_candidates.GetNextSpanOfFusions();
+ if (fusibles.empty()) {
+ break;
+ } else if (fusibles.size() == 1) {
+ // Skip; there is just one fused_instr.
+ continue;
+ }
+
+ changed = true;
+ // Convert fusible into fusion_instrs to simplify the implementation of
+ // `Fuse()`.
+ std::vector<HloInstruction*> fusion_instrs;
+ for (HloInstruction* instr : fusibles) {
+ if (instr->opcode() == HloOpcode::kFusion) {
+ fusion_instrs.push_back(instr);
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * fusion_instr,
+ MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
+ fusion_instrs.push_back(fusion_instr);
+ }
+ }
+
+ TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs), sliced_input_fusion,
+ to_fuse_candidates));
+ }
+ return changed;
+}
+
+absl::Status HorizontalLoopFusionImpl::CreateFusedComputation(
+ absl::Span<HloInstruction*> fused_fusion_instrs,
+ std::unique_ptr<HloComputation>* uniq_computation,
+ std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion) {
+ // First, build a computation with only params.
+ HloComputation::Builder b(prefix_ + "horizontally_fused_computation");
+ size_t fused_comp_param_id = 0;
+ for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+ auto old_params = fused_fusion_instrs[i]->fused_parameters();
+ for (size_t j = 0; j < old_params.size(); ++j) {
+ HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
+      // Parameter names take the form param_i_j.
+ b.AddInstruction(HloInstruction::CreateParameter(
+ fused_comp_param_id++, bound_opnd->shape(),
+ absl::StrCat("param_", i, "_", j)));
+ bound_operands->push_back(bound_opnd);
+ }
+ }
+ // Always create a dummy tuple instruction to serve as the root of the
+ // computation, as the existence of a root instruction is required by the
+ // HloComputation. The real root instruction will replace it below.
+ HloInstruction* dummy_root = b.AddInstruction(
+ HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
+ *uniq_computation = b.Build(dummy_root);
+ HloComputation* comp = uniq_computation->get();
+
+ // Preparing clone_map, which maps old operand to new operand.
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
+ size_t new_param_id = 0;
+ for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+ auto old_params = fused_fusion_instrs[i]->fused_parameters();
+ for (size_t j = 0; j < old_params.size(); ++j) {
+ HloInstruction* old_param = old_params[j];
+ HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
+ clone_map.insert({old_param, new_param});
+ }
+ }
+
+ // Clone every fused computation.
+ const OpMetadata* metadata = nullptr;
+ for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+ auto def_to_use_order = fused_fusion_instrs[i]
+ ->fused_instructions_computation()
+ ->MakeInstructionPostOrder();
+ for (HloInstruction* old_instr : def_to_use_order) {
+ if (old_instr->opcode() == HloOpcode::kParameter ||
+ (sliced_input_fusion && old_instr->opcode() == HloOpcode::kTuple &&
+ old_instr == fused_fusion_instrs[i]->fused_expression_root())) {
+ // Parameters have been created, and we don't need tuples from
+ // multi-output fusions, as we will directly reference the tuple
+ // operands instead by using GetOutputsOfFusible().
+ continue;
+ }
+ std::vector<HloInstruction*> new_opnds;
+ const auto& old_opnds = old_instr->operands();
+ new_opnds.reserve(old_opnds.size());
+ for (HloInstruction* old_opnd : old_opnds) {
+ CHECK(clone_map.find(old_opnd) != clone_map.end());
+ new_opnds.push_back(clone_map[old_opnd]);
+ }
+ HloInstruction* new_instr = comp->AddInstruction(
+ old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
+ clone_map.insert({old_instr, new_instr});
+ // Get the metadata from the last fused instruction.
+ metadata = &old_instr->metadata();
+ }
+ }
+
+ // Since we require each fusion to have the same number of outputs, we can
+ // simply use the first fusion as the representative for output size.
+ size_t fused_instr_output_size =
+ GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
+
+ if (sliced_input_fusion) {
+ // Fusing into kInput fusion
+ std::vector<HloInstruction*> concated_outputs;
+ for (size_t i = 0; i < fused_instr_output_size; ++i) {
+ std::vector<HloInstruction*> instr_outputs(fused_fusion_instrs.size());
+ for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
+ const HloInstruction* old_output =
+ GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
+ HloInstruction* new_output = clone_map[old_output];
+ if (new_output->shape().dimensions_size() == 1) {
+ instr_outputs[j] = new_output;
+ } else {
+ Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
+ new_output->shape().element_type(),
+ {ShapeUtil::ElementsIn(new_output->shape())},
+ /*minor_to_major=*/std::vector<int64_t>(1, 0));
+ TF_ASSIGN_OR_RETURN(instr_outputs[j],
+ MakeReshapeHlo(new_shape, new_output));
+ }
+ }
+ TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
+ MakeConcatHlo(instr_outputs, 0));
+ concated_outputs.push_back(concated_output);
+ }
+
+ // Make slices of outputs.
+ std::vector<HloInstruction*> output_slices(concated_outputs.size() *
+ fused_fusion_instrs.size());
+ for (size_t i = 0; i < concated_outputs.size(); ++i) {
+ HloInstruction* concated_output = concated_outputs[i];
+ int64_t slice_start = 0;
+ // Create a slice per fused computation.
+ for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
+ const HloInstruction* old_output =
+ GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
+ Shape shape = old_output->shape();
+ int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
+ TF_ASSIGN_OR_RETURN(
+ output_slices[concated_outputs.size() * j + i],
+ MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
+ /*strides=*/{1}));
+ slice_start = slice_limit;
+ }
+ }
+
+ // Make a tuple of output_slices.
+ HloInstruction* tuple = comp->AddInstruction(
+ HloInstruction::CreateTuple(output_slices), metadata);
+ comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
+ TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
+
+ } else {
+ // Fusing into kLoop fusion
+ std::vector<HloInstruction*> tuple_operands(fused_instr_output_size *
+ fused_fusion_instrs.size());
+    // If fusing into a kLoop fusion, the new fusion root is a tuple of the
+    // fused fusion computations' roots.
+ for (size_t i = 0; i < fused_instr_output_size; ++i) {
+ for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
+ const HloInstruction* old_output =
+ GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
+ HloInstruction* new_output = clone_map[old_output];
+ tuple_operands[fused_instr_output_size * j + i] = new_output;
+ }
+ }
+ // Make a tuple instruction of fused instruction outputs as
+ // the root of fused computation.
+ HloInstruction* tuple =
+ comp->AddInstruction(HloInstruction::CreateTuple(tuple_operands));
+ comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
+ TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status HorizontalLoopFusionImpl::Fuse(
+ absl::Span<HloInstruction*> fused_fusion_instrs, bool sliced_input_fusion,
+ std::vector<HloInstruction*>& to_fuse_candidates) {
+ // Fuse fused_fusion_instrs and replace them with the new fused computation.
+ std::unique_ptr<HloComputation> uniq_computation;
+ std::vector<HloInstruction*> bound_operands;
+
+ TF_RETURN_IF_ERROR(CreateFusedComputation(fused_fusion_instrs,
+ &uniq_computation, &bound_operands,
+ sliced_input_fusion));
+
+ HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
+ std::move(uniq_computation));
+ HloInstruction* hori_fusion_instr = computation_->AddInstruction(
+ HloInstruction::CreateFusion(fused_comp->root_instruction()->shape(),
+ sliced_input_fusion
+ ? HloInstruction::FusionKind::kInput
+ : HloInstruction::FusionKind::kLoop,
+ bound_operands, fused_comp, prefix_),
+ &fused_comp->root_instruction()->metadata());
+ fused_comp->SetFusionInstruction(hori_fusion_instr);
+
+  // Push the newly fused instruction onto the fusion candidate stack, because
+  // its operands may now be eligible for horizontal fusion.
+ to_fuse_candidates.push_back(hori_fusion_instr);
+
+ // Insert bitcasts and replace corresponding users. Note that we do not insert
+ // the bitcasts in the fused computation as it does not fit into the slice
+ // input fusion pattern. However, inserting bitcasts outside the fused
+ // computation creates no performance cost.
+ size_t total_output_id = 0;
+ for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+ std::vector<HloInstruction*> bitcasts_or_gte;
+ HloInstruction* fused_instr = fused_fusion_instrs[i];
+ size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
+ for (size_t j = 0; j < num_outputs; ++j) {
+ const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * gep,
+ MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
+      // This pass runs late, so useless bitcasts won't be cleaned up.
+ if (output->shape().dimensions_size() == 1) {
+ bitcasts_or_gte.push_back(gep);
+ } else {
+ bitcasts_or_gte.push_back(computation_->AddInstruction(
+ HloInstruction::CreateBitcast(output->shape(), gep)));
+ }
+ }
+ HloInstruction* bitcast_or_tuple =
+ (bitcasts_or_gte.size() == 1)
+ ? bitcasts_or_gte.at(0)
+ : computation_->AddInstruction(
+ HloInstruction::CreateTuple(bitcasts_or_gte));
+ HloComputation* old_computation =
+ fused_instr->fused_instructions_computation();
+ HloModule* module = old_computation->parent();
+ TF_RETURN_IF_ERROR(
+ computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
+ TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(old_computation));
+ }
+
+ TF_RETURN_IF_ERROR(Cast<HloFusionInstruction>(hori_fusion_instr)
+ ->DeduplicateFusionOperands());
+
+ VLOG(1) << "Fused " << fused_fusion_instrs.size()
+ << " instructions into: " << hori_fusion_instr->ToString();
+ return absl::OkStatus();
+}
+
+absl::StatusOr<bool> HorizontalLoopFusionImpl::Run() {
+ bool changed = false;
+ XLA_VLOG_LINES(3, computation_->ToString());
+
+ // Traverse from use to def. Bitcasts are placed after h-fusions to resolve
+ // shape mismatch but bitcasts could prevent future h-fusion from happening.
+ // So, a bottom-up, use-to-def order should be more favorable. It also helps
+ // to save compiler iterations to reach the fixed point.
+ std::vector<HloInstruction*> to_fuse_candidates =
+ computation_->MakeInstructionPostOrder();
+
+ while (!to_fuse_candidates.empty()) {
+ HloInstruction* consumer = to_fuse_candidates.back();
+ to_fuse_candidates.pop_back();
+
+    // The consumer may be an operand of a previously fused instruction and
+    // thus no longer valid; skip it.
+ if (consumer->IsDead()) {
+ continue;
+ }
+
+    // We first try to fuse operands that have the same shape into a kLoop
+    // fusion instruction.
+ TF_ASSIGN_OR_RETURN(
+ bool loop_fusion_changed,
+ FuseConsumerOperands(consumer, false, to_fuse_candidates));
+
+    // For the remaining operands with different shapes, we further try to fuse
+    // them into a kInput fusion instruction.
+ TF_ASSIGN_OR_RETURN(
+ bool sliced_input_fusion_changed,
+ FuseConsumerOperands(consumer, true, to_fuse_candidates));
+
+ changed = changed || loop_fusion_changed || sliced_input_fusion_changed;
+ }
+ return changed;
+}
+
+} // namespace
+
+absl::StatusOr<bool> HorizontalLoopFusion::RunOnComputation(
+ HloComputation* computation) {
+ HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_);
+ return horizontal_fusion_impl.Run();
+}
+
+absl::StatusOr<bool> HorizontalLoopFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ VLOG(2) << "Run horizontal fusion.";
+
+  // Running on the entry computation is actually enough.
+ TF_ASSIGN_OR_RETURN(bool changed,
+ RunOnComputation(module->entry_computation()));
+
+ if (changed) {
+    // Correctly set element_size_in_bits for any added sub-byte slice and
+    // concatenate instructions.
+ TF_ASSIGN_OR_RETURN(
+ [[maybe_unused]] bool unused,
+ SubByteNormalization{SubByteNormalization::SET_ELEMENT_SIZE}.Run(
+ module));
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h
new file mode 100644
index 0000000..f29bcd3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h
@@ -0,0 +1,145 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// This optimization pass horizontally fuses computations for reducing kernel
+// launch overhead while increasing kernel launch dims on GPU. The initial
+// motivation of this horizontal fusion is due to the observation that the
+// training optimizer phase (e.g., AdamOptimizer and L2Loss, etc.) typically
+// has many small kernels as a result of applying the same formula on many
+// training parameters (or variables in Tensorflow). Fusing these small
+// kernels, hence, provides performance gain.
+//
+// Theoretically speaking, we may implement a cycle detection algorithm to make
+// sure no cycles are created after fusion. However, cycle detection check is
+// somewhat cumbersome; also, we observe that naive horizontal fusion of
+// arbitrary kernels may not be profitable due to control divergence and
+// possible increase of memory bandwidth pressure due to uncoalesced memory
+// accesses (note that horizontal fusion does not change the amount of memory
+// read+written at all). In practice, a simple yet effective heuristic is used
+// to avoid these issues while addressing the known beneficial cases. That is,
+// we simply search for fusion candidates by looking for instructions whose
+// outputs are all consumed by the same instruction. This catches the cases in
+// the training optimizer phase, as the candidate instructions are typically
+// consumed only by the ROOT tuple of the entry computation.
+//
+// The following illustrates the mechanism of the horizontal fusion. Before
+// fusion, there are two trivial kernels in the illustrating example. One has
+// only a Mul op, while the other consists of only an Add op. Since they are
+// only consumed by the same (ROOT) tuple instruction, horizontal fusion is
+// triggered.
+//
+// i0 i1 i2 i3
+// | | | |
+// v v v v
+// Mul Add
+// | |
+// v v
+// (ROOT) tuple
+//
+// We fuse into one of two possible patterns, depending on whether all the
+// fused operations have the same shape or not.
+//
+// case 1: if Mul and Add's output shape and type are the same, then we fuse
+// them into the below pattern:
+// i0 i1 i2 i3
+// | | | |
+// v v v v
+// Mul Add
+// | |
+// v v
+// (ROOT) tuple
+// the fused kernel will be kLoop type, and GPU code is emitted through
+// the LoopFusion class.
+//
+// case 2: if Mul and Add's output shapes are different, then we fuse them into
+// the below pattern that adds extra indexing:
+// i0 i1 i2 i3 +++ (Slice) Input Fusion
+// | | | | +
+// v v v v +
+// Mul Add +
+// | | +
+// v v +
+// Reshape0 Reshape1 +
+// | | +
+// v v +
+// Concatenate +
+// | | +
+// v v +
+// Slice0 Slice1 +++
+// | |
+// v v
+// Reshape2 Reshape3
+// | |
+// v v
+// (ROOT) tuple
+//
+// the fused kernel will be kInput type, and the GPU code is emitted through
+// the InputSlicesFusion class.
+//
+// In theory, the pattern in case 1 could also be fused into the case2 target
+// graph, but we prefer to fuse into kLoop type, because the codegen for it does
+// not have the slicing range check cost introduced by case 2 pattern.
+//
+// Note that the fusion style by case 2 provides an important advantage that
+// kernels of different shapes can be horizontally fused. The first pair of
+// reshapes (i.e., Reshape0 and Reshape1) reshape the dims to 1 dimension, so
+// that the outputs of the fused kernels can (always) be concatenated. The
+// second pair of reshapes (Reshape2 and Reshape3) restore the original shapes
+// to the output tensors.
+//
+// No extra copies are introduced by the horizontal fusion. Besides Reshape2
+// and Reshape3, the other instructions are fused into an input fusion; the
+// output dims of the concatenate will be used as the kernel launch dims.
+// Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the
+// outputs of Mul and Add are row-major.
+//
+// Note that reshapes are added only if the tensor isn't already a vector.
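+//
+// A minimal HLO sketch of the effect (adapted from the unit tests; shapes and
+// names are illustrative). Two kLoop fusions with different output shapes that
+// feed only the ROOT tuple,
+//
+//   fusion.1 = f16[1024]{0} fusion(arg.1, arg.2), kind=kLoop,
+//              calls=fused_computation.1
+//   fusion.2 = f16[123]{0} fusion(arg.3, arg.4), kind=kLoop,
+//              calls=fused_computation.2
+//   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0}) tuple(fusion.1, fusion.2)
+//
+// are merged into one multi-output fusion whose fused root has the form
+// tuple(slice(concatenate(...)), slice(concatenate(...))); get-tuple-elements
+// (plus bitcasts for outputs of rank > 1) outside the fusion recover the
+// original outputs.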
+class HorizontalLoopFusion : public HloModulePass {
+ public:
+ HorizontalLoopFusion() = default;
+ explicit HorizontalLoopFusion(absl::string_view prefix) : prefix_(prefix) {}
+
+ absl::string_view name() const override { return "horizontal_loop_fusion"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ absl::StatusOr<bool> RunOnComputation(HloComputation*);
+ std::string prefix_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc
new file mode 100644
index 0000000..781d27a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc
@@ -0,0 +1,851 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/hlo_dce.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/hlo_pass_fix.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class HorizontalLoopFusionTest : public HloTestBase {
+ public:
+ static bool IsFusion(const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kFusion;
+ }
+};
+
+TEST_F(HorizontalLoopFusionTest, BasicTest) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule BasicTest
+
+ fused_computation.1 {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ arg.3 = f16[123]{0} parameter(2)
+ arg.4 = f16[123]{0} parameter(3)
+ fusion.1 = f16[1024]{0}
+ fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+ fusion.2 = f16[123]{0}
+ fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
+ ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
+ tuple(fusion.1, fusion.2)
+ }
+)")
+ .value();
+
+ EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+ TF_ASSERT_OK(verifier().Run(module.get()).status());
+ EXPECT_FALSE(HloDCE().Run(module.get()).value());
+
+ const HloInstruction* entry_root =
+ module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(entry_root,
+ GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement(m::Fusion()))));
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Slice(m::Concatenate()),
+ m::Slice(m::Concatenate()))));
+}
+
+// Horizontal fusion should not be triggered, as fusing would create a cycle.
+TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NegativeTestForCycle
+
+ fused_computation.1 {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ arg.3 = f16[123]{0} parameter(2)
+ arg.4 = f16[123]{0} parameter(3)
+ // fusion.1 and fusion.2 will not be horizontally fused because that would
+ // create a cycle through fusion.1 -> add.2 -> fusion.2
+ fusion.1 = f16[123]{0}
+ fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+ add.2 = f16[123]{0} add(fusion.1, arg.4)
+ fusion.2 = f16[123]{0}
+ fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
+ ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
+ tuple(fusion.1, fusion.2, add.2)
+ }
+)")
+ .value();
+
+ EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NegativeTestForIncompatibleTypes
+
+ fused_computation.1 {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+ arg.1 = s32[123]{0} parameter(0)
+ arg.2 = s32[123]{0} parameter(1)
+ ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ arg.3 = s32[123]{0} parameter(2)
+ arg.4 = s32[123]{0} parameter(3)
+ // fusion.1 and fusion.2 will not be horizontally fused because their output
+ // types are different.
+ fusion.1 = f16[1024]{0}
+ fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+ fusion.2 = s32[123]{0}
+ fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
+ ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
+ tuple(fusion.1, fusion.2)
+ }
+)")
+ .value();
+
+ EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule FusingIntoKLoopAndKInputTogether
+
+ fused_computation.1 {
+ arg.1 = f16[129, 2048]{1, 0} parameter(0)
+ arg.2 = f16[129, 2048]{1, 0} parameter(1)
+ ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+ arg.1 = f16[129, 2048]{1, 0} parameter(0)
+ arg.2 = f16[129, 2048]{1, 0} parameter(1)
+ ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.3 {
+ arg.1 = f16[130, 2048]{1, 0} parameter(0)
+ arg.2 = f16[130, 2048]{1, 0} parameter(1)
+ ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.4 {
+ arg.1 = f16[130, 2048]{1, 0} parameter(0)
+ arg.2 = f16[130, 2048]{1, 0} parameter(1)
+ ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.5 {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ fused_computation.6 {
+ arg.1 = f16[128]{0} parameter(0)
+ arg.2 = f16[128]{0} parameter(1)
+ ROOT add.1 = f16[128]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[129, 2048]{1, 0} parameter(0)
+ arg.2 = f16[129, 2048]{1, 0} parameter(1)
+ arg.3 = f16[129, 2048]{1, 0} parameter(2)
+ arg.4 = f16[129, 2048]{1, 0} parameter(3)
+ arg.5 = f16[130, 2048]{1, 0} parameter(4)
+ arg.6 = f16[130, 2048]{1, 0} parameter(5)
+ arg.7 = f16[130, 2048]{1, 0} parameter(6)
+ arg.8 = f16[130, 2048]{1, 0} parameter(7)
+ arg.9 = f16[123]{0} parameter(8)
+ arg.10 = f16[123]{0} parameter(9)
+ arg.11 = f16[128]{0} parameter(10)
+ arg.12 = f16[128]{0} parameter(11)
+
+ // fusion.1 and fusion.2 will be fused into kLoop fusion
+ // fusion.3 and fusion.4 will be fused into another kLoop fusion
+ // fusion.5 and fusion.6 will be fused into kInput fusion
+
+ fusion.1 = f16[129,2048]{1, 0}
+ fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+
+ fusion.2 = f16[129,2048]{1, 0}
+ fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
+
+ fusion.3 = f16[130,2048]{1, 0}
+ fusion(arg.5, arg.6), kind=kLoop, calls=fused_computation.3
+
+ fusion.4 = f16[130,2048]{1, 0}
+ fusion(arg.7, arg.8), kind=kLoop, calls=fused_computation.4
+
+ fusion.5 = f16[123]{0}
+ fusion(arg.9, arg.10), kind=kLoop, calls=fused_computation.5
+
+ fusion.6 = f16[128]{0}
+ fusion(arg.11, arg.12), kind=kLoop, calls=fused_computation.6
+
+ ROOT tuple.1 = (f16[129,2048]{1, 0}, f16[129,2048]{1, 0},
+ f16[130,2048]{1, 0}, f16[130,2048]{1, 0},
+ f16[123]{0}, f16[128]{0})
+ tuple(fusion.1, fusion.2, fusion.3, fusion.4, fusion.5, fusion.6)
+ }
+)")
+ .value();
+
+ EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+
+ int input_fusion_count = 0;
+ int loop_fusion_count = 0;
+ for (auto inst : module->entry_computation()->MakeInstructionPostOrder()) {
+ if (inst->opcode() == HloOpcode::kFusion) {
+ input_fusion_count +=
+ (inst->fusion_kind() == HloInstruction::FusionKind::kInput) ? 1 : 0;
+ loop_fusion_count +=
+ (inst->fusion_kind() == HloInstruction::FusionKind::kLoop) ? 1 : 0;
+ }
+ }
+ EXPECT_EQ(input_fusion_count, 1);
+ EXPECT_EQ(loop_fusion_count, 2);
+}
+
+TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule MergeSharedFusionInstruction
+
+ ENTRY MergeSharedFusionInstruction.Computation0 {
+ param.1.1 = f32[4,1024]{1,0} parameter(0)
+ param.1.2 = f32[4,1024]{1,0} parameter(1)
+ param.1.3 = f32[4,1024]{1,0} parameter(2)
+ param.2.1 = f32[321,5]{1,0} parameter(3)
+ param.2.2 = f32[321,5]{1,0} parameter(4)
+ param.2.3 = f32[321,5]{1,0} parameter(5)
+ const.1 = f32[] constant(3)
+ const.2 = f32[] constant(3)
+ broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
+ broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
+ mul.1.1 = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
+ mul.1.2 = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
+ add.1 = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
+ mul.2.1 = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
+ mul.2.2 = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
+ add.2 = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
+ ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
+})")
+ .value();
+
+ HloPassPipeline fusion("fusion");
+ const se::DeviceDescription device_info =
+ TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false,
+ device_info);
+ fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true,
+ device_info);
+ EXPECT_TRUE(fusion.Run(module.get()).value());
+ EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+ TF_ASSERT_OK(verifier().Run(module.get()).status());
+
+ VLOG(2) << "Dump after horizontal fusion:";
+ VLOG(2) << module->ToString();
+
+ const HloInstruction* entry_root =
+ module->entry_computation()->root_instruction();
+ const HloInstruction* fusion_instr = nullptr;
+ // Check that we add bitcast when needed.
+ ASSERT_THAT(entry_root,
+ GmockMatch(m::Tuple(
+ m::Bitcast(m::GetTupleElement(m::Fusion(&fusion_instr))),
+ m::Bitcast(m::GetTupleElement(m::Fusion())))));
+ ASSERT_TRUE(fusion_instr->IsMultiOutputFusion());
+ EXPECT_THAT(fusion_instr->fused_expression_root(),
+ GmockMatch(m::Tuple(
+ m::Slice(m::Concatenate(m::Reshape(), m::Reshape())),
+ m::Slice(m::Concatenate(m::Reshape(), m::Reshape())))));
+
+ EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
+ HloComputation::Builder builder(TestName());
+
+ std::vector<HloInstruction*> var_outs;
+ for (int64_t i = 0; i < 128; ++i) {
+ // For shapes {1, 1024}, {2, 1024}, ..., {128, 1024}
+ Shape shape = ShapeUtil::MakeShape(F32, {i + 1, 1024});
+ HloInstruction* param_var_in = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 3 + 0, shape, "var.in"));
+ HloInstruction* param_alpha =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ i * 3 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
+ HloInstruction* param_delta = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 3 + 2, shape, "delta"));
+ HloInstruction* alpha_broadcasted = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(shape, param_alpha, {}));
+ HloInstruction* alpha_delta =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kMultiply, alpha_broadcasted, param_delta));
+ HloInstruction* var_out =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kSubtract, param_var_in, alpha_delta));
+ var_outs.push_back(var_out);
+ }
+ builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
+
+ auto module = CreateNewVerifiedModule();
+ module->AddEntryComputation(builder.Build());
+
+ // Testing with the entire gpu optimization pipeline.
+ EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule HeterogeneousMultiOutputFusions
+
+ fused_computation.1 {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ arg.3 = f16[1024]{0} parameter(2)
+ arg.4 = f16[1024]{0} parameter(3)
+ mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
+ mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
+ add.1 = f16[1024]{0} add(mul.1, mul.2)
+ ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
+ }
+
+ fused_computation.2 {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ arg.3 = f16[123]{0} parameter(2)
+ arg.4 = f16[123]{0} parameter(3)
+ add.1 = f16[123]{0} add(arg.1, arg.2)
+ add.2 = f16[123]{0} add(arg.3, arg.4)
+ mul.1 = f16[123]{0} multiply(add.1, add.2)
+ ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[1024]{0} parameter(0)
+ arg.2 = f16[1024]{0} parameter(1)
+ arg.3 = f16[1024]{0} parameter(2)
+ arg.4 = f16[1024]{0} parameter(3)
+ arg.5 = f16[123]{0} parameter(4)
+ arg.6 = f16[123]{0} parameter(5)
+ arg.7 = f16[123]{0} parameter(6)
+ arg.8 = f16[123]{0} parameter(7)
+ fusion.1 = (f16[1024]{0}, f16[1024]{0})
+ fusion(arg.1, arg.2, arg.3, arg.4),
+ kind=kLoop, calls=fused_computation.1
+ fusion.2 = (f16[123]{0}, f16[123]{0})
+ fusion(arg.5, arg.6, arg.7, arg.8),
+ kind=kLoop, calls=fused_computation.2
+ gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
+ gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
+ gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
+ gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
+ ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
+ tuple(gte.1, gte.2, gte.3, gte.4)
+ }
+)")
+ .value();
+
+ EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+ TF_ASSERT_OK(verifier().Run(module.get()).status());
+ EXPECT_FALSE(HloDCE().Run(module.get()).value());
+
+ VLOG(2) << "Dump after horizontal fusion:";
+ VLOG(2) << module->ToString();
+
+ EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
+ HloComputation::Builder builder(TestName());
+
+ std::vector<HloInstruction*> all_outputs;
+ for (int64_t i = 0; i < 48; ++i) {
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 1024 + i});
+ // ms <- grad**2 * (1 - rho) + ms * rho
+ HloInstruction* grad = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 9 + 0, shape, "grad"));
+ HloInstruction* ms = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 9 + 1, shape, "ms"));
+ HloInstruction* rho =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ i * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
+ HloInstruction* one_minus_rho =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ i * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
+ HloInstruction* rho_broadcasted =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(shape, rho, {}));
+ HloInstruction* one_minus_rho_broadcasted = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
+ HloInstruction* grad_squared = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
+ HloInstruction* ms_1st_term = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad_squared,
+ one_minus_rho_broadcasted));
+ HloInstruction* ms_2nd_term =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kMultiply, ms, rho_broadcasted));
+ HloInstruction* ms_out =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, ms_1st_term, ms_2nd_term));
+
+ // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+ HloInstruction* momentum = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 9 + 4, shape, "momentum"));
+ HloInstruction* mom = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 9 + 5, shape, "mom"));
+ HloInstruction* lr = builder.AddInstruction(HloInstruction::CreateParameter(
+ i * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
+ HloInstruction* epsilon =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ i * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
+ HloInstruction* lr_broadcasted =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(shape, lr, {}));
+ HloInstruction* epsilon_broadcasted = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(shape, epsilon, {}));
+ HloInstruction* mom_1st_term =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kMultiply, momentum, mom));
+ HloInstruction* ms_eps =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, ms_out, epsilon_broadcasted));
+ HloInstruction* ms_eps_rsq = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_eps));
+ HloInstruction* grad_ms_eps_rsq =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kMultiply, grad, ms_eps_rsq));
+ HloInstruction* mom_2nd_term =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kMultiply, lr_broadcasted, grad_ms_eps_rsq));
+ HloInstruction* mom_out =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, mom_1st_term, mom_2nd_term));
+
+ // var <- var - mom
+ HloInstruction* var = builder.AddInstruction(
+ HloInstruction::CreateParameter(i * 9 + 8, shape, "var"));
+ HloInstruction* var_out =
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kSubtract, var, mom_out));
+
+ all_outputs.push_back(ms_out);
+ all_outputs.push_back(mom_out);
+ all_outputs.push_back(var_out);
+ }
+ builder.AddInstruction(HloInstruction::CreateTuple(all_outputs));
+
+ auto module = CreateNewVerifiedModule();
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
+}
+
+TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule DynamicUpdateSlice
+
+ fusion.1 {
+ p.0 = f16[5,9,10]{2,1,0} parameter(0)
+ p.1 = s32[] parameter(1)
+ p.2 = f16[1,9,10]{2,1,0} parameter(2)
+ c.0 = s32[] constant(0)
+ ROOT %dynamic-update-slice =
+ f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
+ }
+
+ fusion.2 {
+ p.0 = f16[5,9,10]{2,1,0} parameter(0)
+ p.1 = s32[] parameter(1)
+ p.2 = f16[1,9,10]{2,1,0} parameter(2)
+ c.0 = s32[] constant(0)
+ ROOT %dynamic-update-slice =
+ f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
+ }
+
+ ENTRY entry {
+ p.00 = f16[5,9,10]{2,1,0} parameter(0)
+ p.01 = f16[5,9,10]{2,1,0} parameter(1)
+ p.10 = s32[] parameter(2)
+ p.11 = s32[] parameter(3)
+ p.20 = f16[1,9,10]{2,1,0} parameter(4)
+ p.21 = f16[1,9,10]{2,1,0} parameter(5)
+
+ f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
+ f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
+ ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
+ })")
+ .value();
+
+ EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+ TF_ASSERT_OK(verifier().Run(module.get()).status());
+ EXPECT_FALSE(HloDCE().Run(module.get()).value());
+
+ VLOG(2) << "Dump after horizontal fusion:";
+ VLOG(2) << module->ToString();
+
+ EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NegativeTestForSharedParam
+
+ fused_computation.1 {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+ arg.1 = f16[123]{0} parameter(0)
+ arg.2 = f16[123]{0} parameter(1)
+ ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+ arg.1 = f16[123]{0} parameter(0)
+ // arg.2 is shared by fusion.1 and fusion.2
+ arg.2 = f16[123]{0} parameter(1)
+ arg.3 = f16[123]{0} parameter(2)
+ fusion.1 = f16[123]{0}
+ fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+ fusion.2 = f16[123]{0}
+ fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
+ ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
+ tuple(fusion.1, fusion.2)
+ }
+)")
+ .value();
+
+ EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NonfusionInstrs
+
+ fused_computation.0 {
+ arg.0 = f16[] parameter(0)
+ arg.1 = f16[123]{0} parameter(1)
+ broadcast.0 = f16[123]{0} broadcast(arg.0), dimensions={}
+ ROOT mul.1 = f16[123]{0} multiply(broadcast.0, arg.1)
+ }
+
+ fused_computation.1 {
+ arg.0 = f16[] parameter(0)
+ arg.1 = f16[456]{0} parameter(1)
+ broadcast.0 = f16[456]{0} broadcast(arg.0), dimensions={}
+ ROOT add.1 = f16[456]{0} add(broadcast.0, arg.1)
+ }
+
+ ENTRY entry_computation {
+ arg.0 = f16[] parameter(0)
+ arg.1 = f16[] parameter(1)
+ arg.2 = f16[123]{0} parameter(2)
+ arg.3 = f16[456]{0} parameter(3)
+ // Test fusion of non-fusion instructions. sqrt.0 and sqrt.1 are to be
+ // fused.
+ sqrt.0 = f16[] sqrt(arg.0)
+ sqrt.1 = f16[] sqrt(arg.1)
+ // fusion.0 and fusion.1 are to be fused.
+ fusion.0 = f16[123]{0}
+ fusion(sqrt.0, arg.2), kind=kLoop, calls=fused_computation.0
+ fusion.1 = f16[456]{0}
+ fusion(sqrt.1, arg.3), kind=kLoop, calls=fused_computation.1
+ ROOT tuple.1 = (f16[123]{0}, f16[456]{0}) tuple(fusion.0, fusion.1)
+ }
+)")
+ .value();
+
+ HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
+ iterative_h_fusion.AddPass<HorizontalLoopFusion>();
+ iterative_h_fusion.AddPass<HloDCE>();
+ EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
+
+ // Verify that fusion.0 and fusion.1 are fused.
+ const HloInstruction* entry_root =
+ module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(entry_root,
+ GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement(m::Fusion()))));
+ EXPECT_TRUE(fusion->IsMultiOutputFusion());
+
+ // Verify that the total number of fusion instructions is 2 so that we
+ // know sqrt.0 and sqrt.1 are fused.
+ EXPECT_EQ(
+ absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
+ 2);
+}
+
+TEST_F(HorizontalLoopFusionTest, TraversalOrder) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule cluster
+
+ %fused_computation (param_0: f32[256,256], param_1: f32[], param_2: f32[])
+ -> f32[256,256] {
+ %param_0 = f32[256,256]{1,0} parameter(0)
+ %param_1 = f32[] parameter(1)
+ %param_2 = f32[] parameter(2)
+ %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
+ %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
+ ROOT %multiply.1 = f32[256,256]{1,0}
+ multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
+ }
+
+ %fused_computation.1 (param_0: f32[256,256], param_1: f32[], param_2: f32[])
+ -> f32[256,256] {
+ %param_0 = f32[256,256]{1,0} parameter(0)
+ %param_1 = f32[] parameter(1)
+ %param_2 = f32[] parameter(2)
+ %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
+ %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
+ ROOT %multiply.1 = f32[256,256]{1,0}
+ multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
+ }
+
+ ENTRY %entry_computation (arg0: f32[256,256], arg1: f32[256,256], arg2: f32[],
+ arg3: f32[], arg4: f32[], arg5: f32[])
+ -> (f32[256,256], f32[256,256]) {
+ %arg0 = f32[256,256]{1,0} parameter(0), parameter_replication={false}
+ %arg1 = f32[256,256]{1,0} parameter(1), parameter_replication={false}
+ %arg2 = f32[] parameter(2), parameter_replication={false}
+ %arg3 = f32[] parameter(3), parameter_replication={false}
+ %arg4 = f32[] parameter(4), parameter_replication={false}
+ %arg5 = f32[] parameter(5), parameter_replication={false}
+ %sqrt = f32[] sqrt(f32[] %arg2)
+ %sqrt.1 = f32[] sqrt(f32[] %arg3)
+ %fusion = f32[256,256]{1,0}
+ fusion(f32[256,256]{1,0} %arg0, f32[] %sqrt, f32[] %sqrt.1),
+ kind=kLoop, calls=%fused_computation
+ %sqrt.2 = f32[] sqrt(f32[] %arg4)
+ %sqrt.3 = f32[] sqrt(f32[] %arg5)
+ %fusion.1 = f32[256,256]{1,0}
+ fusion(f32[256,256]{1,0} %arg1, f32[] %sqrt.2, f32[] %sqrt.3),
+ kind=kLoop, calls=%fused_computation.1
+ ROOT %tuple.163 = (f32[256,256]{1,0}, f32[256,256]{1,0})
+ tuple(f32[256,256]{1,0} %fusion.1, f32[256,256]{1,0} %fusion)
+ }
+)")
+ .value();
+
+ HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
+ iterative_h_fusion.AddPass<HorizontalLoopFusion>();
+ EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
+
+ // Verify that the total number of fusion instructions is 2 so that we
+ // know all the sqrt instructions are fused into a kernel. Note that if we
+ // traversed from def to use (i.e., top-down) instead of from use to def, we
+ // would end up with 3 fusions instead of 2.
+ EXPECT_EQ(
+ absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
+ 2);
+}
+
+// Simplified reproducer for Google bug b/242287055.
+// Things that happened:
+// - horizontal loop fusion joined addition a0 and multiplication m0
+// - the resulting fusion had 4 inputs: (gte1, gte0, gte1, gte0)
+// - buffer assignment aliased outputs of this fusion with its inputs
+// - some threads simultaneously did the addition, others the multiplication
+// - as a result, some inputs were overwritten before being read
+// The conditional operation is meaningless (its branches are equivalent); it
+// is there only to confuse buffer assignment.
+TEST_F(HorizontalLoopFusionTest, NoBufferAliasingOfDuplicateParameter) {
+ const char* hlo_text = R"(
+HloModule m
+
+branch_a {
+ p0 = s32[] parameter(0)
+ c0 = s32[] constant(1)
+ c1 = s32[] constant(2)
+ b0 = s32[4096] broadcast(c0), dimensions={}
+ b1 = s32[4096] broadcast(c1), dimensions={}
+ ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
+}
+
+branch_b {
+ p0 = s32[] parameter(0)
+ c0 = s32[] constant(1)
+ c1 = s32[] constant(2)
+ b0 = s32[4096] broadcast(c0), dimensions={}
+ b1 = s32[4096] broadcast(c1), dimensions={}
+ ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
+}
+
+ENTRY e {
+ p0 = s32[] parameter(0)
+ c0 = s32[] constant(0)
+ cond = (s32[4096], s32[4096]) conditional(p0, c0, c0), branch_computations={branch_a, branch_b}
+ p1 = s32[4096] parameter(1)
+ gte0 = s32[4096] get-tuple-element(cond), index=0
+ gte1 = s32[4096] get-tuple-element(cond), index=1
+ a0 = s32[4096] add(gte1, gte0)
+ m0 = s32[4096] multiply(gte1, gte0)
+ ROOT r = (s32[4096], s32[4096]) tuple(m0, a0)
+}
+)";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+TEST_F(HorizontalLoopFusionTest, CopyInsertionFusionControlFlow) {
+ const char* hlo_text = R"(
+HloModule cluster
+
+ENTRY main {
+ cst = f32[1]{0} constant({0})
+ cp1 = f32[1]{0} copy(cst)
+ cp2 = f32[1]{0} copy(cst)
+ cp3 = f32[1]{0} copy(cst)
+ cp4 = f32[1]{0} copy(cst), control-predecessors={cp1}
+ ROOT tuple_out = (f32[1]{0}, f32[1]{0}, f32[1]{0}, f32[1]{0}) tuple(cp1, cp2, cp3, cp4)
+}
+)";
+
+ auto module = ParseAndReturnUnverifiedModule(hlo_text).value();
+ EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+
+ VLOG(2) << module->ToString();
+
+ // Verify that the total number of fusion instructions is 1.
+ EXPECT_EQ(
+ absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
+ 1);
+
+ const HloInstruction* entry_root =
+ module->entry_computation()->root_instruction();
+ // Check that we fuse when supported.
+ EXPECT_THAT(entry_root,
+ GmockMatch(m::Tuple(m::Copy(), m::GetTupleElement(m::Fusion()),
+ m::GetTupleElement(m::Fusion()), m::Copy())));
+}
+
+TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ fused_computation.94 {
+ tmp_0 = f32[] parameter(0)
+ tmp_1 = f32[] parameter(1)
+ tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
+ tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
+ tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
+ tmp_5 = s32[] parameter(2)
+ tmp_6 = s32[] parameter(3)
+ tmp_7 = s32[] minimum(tmp_5, tmp_6)
+ tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
+ tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
+ ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
+ }
+
+ minmax_func.1536 {
+ tmp_0 = f32[] parameter(0)
+ tmp_1 = f32[] parameter(2)
+ tmp_2 = s32[] parameter(1)
+ tmp_3 = s32[] parameter(3)
+ ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
+ }
+
+ fused_computation {
+ tmp_0 = f32[554112,10]{1,0} parameter(0)
+ tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+ tmp_2 = f32[] constant(-inf)
+ tmp_3 = s32[] constant(0)
+ ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+ }
+
+ fused_computation2 {
+ tmp_0 = f32[554112,10]{1,0} parameter(0)
+ tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+ tmp_2 = f32[] constant(inf)
+ tmp_3 = s32[] constant(1)
+ ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+ }
+
+ ENTRY e {
+ tmp_0 = f32[554112,10]{1,0} parameter(0)
+ tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
+ tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
+ tmp_3 = f32[554112,10]{1,0} parameter(1)
+ tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_3), kind=kLoop, calls=fused_computation2
+ tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
+ ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
+ })")
+ .value();
+
+ EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
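
A minimal sketch of the structural invariant asserted in BasicTest above: after the pass runs, the horizontally fused kernel is a multi-output fusion whose root slices a concatenation of the original per-fusion outputs. The helper name is hypothetical; the pattern itself is taken directly from the test expectations.

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/pattern_matcher.h"

namespace xla {
namespace gpu {

// True if `fusion` has the shape produced by horizontal loop fusion in
// BasicTest: a multi-output fusion whose root is
// tuple(slice(concatenate(...)), slice(concatenate(...))).
inline bool LooksHorizontallyFusedSketch(const HloInstruction* fusion) {
  namespace m = ::xla::match;
  return fusion->IsMultiOutputFusion() &&
         Match(fusion->fused_expression_root(),
               m::Tuple(m::Slice(m::Concatenate()),
                        m::Slice(m::Concatenate())));
}

}  // namespace gpu
}  // namespace xla
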
diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc
new file mode 100644
index 0000000..5e32f2e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/meta/type_traits.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/fusion_node_indexing_evaluation.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+bool ElementIsF32OrF16(const Shape& shape) {
+ PrimitiveType type = shape.element_type();
+ return type == F32 || type == F16;
+}
+
+class EmptyFusionQueue : public FusionQueue {
+ public:
+ std::pair<HloInstruction*, std::vector<int64_t>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() override {
+ return {nullptr, {}};
+ }
+ void RemoveInstruction(HloInstruction* instruction) override {}
+ const std::vector<bool>* FusionConfiguration() override { return nullptr; }
+};
+
+} // namespace
+
+absl::StatusOr<bool> GpuInstructionFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ fusion_node_evaluations_.clear();
+ auto fusible_computations =
+ GetFusibleComputations(*module, execution_threads);
+ fusible_computations_ = {fusible_computations.begin(),
+ fusible_computations.end()};
+ return InstructionFusion::Run(module, execution_threads);
+}
+
+/*static*/ bool GpuInstructionFusion::IsExpensive(
+ const HloInstruction& instruction) {
+ // Some floating-point math ops are cheap on the GPU.
+ switch (instruction.opcode()) {
+ case HloOpcode::kDivide:
+ case HloOpcode::kSqrt:
+ case HloOpcode::kRsqrt:
+ case HloOpcode::kExp:
+ if (ElementIsF32OrF16(instruction.shape())) {
+ return false;
+ }
+ break;
+ default:
+ break;
+ }
+ return InstructionFusion::IsExpensive(instruction);
+}
+
+FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks(
+ HloInstruction* consumer, int64_t operand_index) {
+ HloInstruction* producer = consumer->mutable_operand(operand_index);
+
+ // Output fusions are not currently supported on GPUs.
+ if (producer->opcode() == HloOpcode::kFusion) {
+ return "the producer is a fusion";
+ }
+
+ if (consumer->IsCustomFusion()) {
+ return "the consumer is a custom fusion";
+ }
+
+ // Cost condition: do not fuse an expensive producer into a consumer that
+ // reuses the producer's output elements.
+ if (is_expensive(*producer) &&
+ ReusesOperandElements(consumer, operand_index)) {
+ return "the producer is expensive, and the consumer reuses inputs";
+ }
+
+ // Do not fuse into fusions if the resulting kernel would suffer from
+ // uncoalesced reads due to a transposed memory access pattern.
+ if (IsInputFusibleReduction(*consumer) &&
+ IsPhysicallyTransposing(*producer)) {
+ return "fusing the producer would break read coalescing";
+ }
+
+ RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer));
+
+ if (CreatesHeavyComputation(*producer, *consumer)) {
+ return "the fusion would create a heavy computation";
+ }
+
+ return InstructionFusion::ShouldFuse(consumer, operand_index);
+}
+
+FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
+ int64_t operand_index) {
+ RETURN_IF_NOT_FUSIBLE(ShouldFuseInexpensiveChecks(consumer, operand_index));
+
+ auto producer = consumer->operand(operand_index);
+
+ // The following checks are potentially expensive.
+ RETURN_IF_NOT_FUSIBLE(
+ FusionFitsInBudget(*consumer, *producer, device_info_,
+ /*is_consumer_producer_fusion=*/true));
+
+ if (consumer->opcode() != HloOpcode::kFusion) {
+ return {};
+ }
+
+ // Also check that our emitter can handle the fusion node. We currently can
+ // have exponential time/memory requirements for emitting certain fusion
+ // kernels, in which case we don't want to fuse.
+ // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
+ if (fusion_node_evaluations_.find(consumer) ==
+ fusion_node_evaluations_.end()) {
+ // We have no cached results for this fusion node yet. This can happen when
+ // we run the InstructionFusion pass more than once. We can only cache the
+ // results within one run.
+ fusion_node_evaluations_.emplace(consumer,
+ FusionNodeIndexingEvaluation(consumer));
+ }
+ if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
+ return "the fusion would result in an overly large code duplication";
+ }
+ return {};
+}
+
+HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
+ const HloInstruction* producer, const HloInstruction* consumer) {
+ return ChooseFusionKind(*producer, *consumer);
+}
+
+HloInstruction* GpuInstructionFusion::FuseInstruction(
+ HloInstruction* fusion_instruction, HloInstruction* producer) {
+ auto evaluation = fusion_node_evaluations_.find(fusion_instruction);
+ if (evaluation == fusion_node_evaluations_.end()) {
+ evaluation = fusion_node_evaluations_
+ .emplace(fusion_instruction,
+ FusionNodeIndexingEvaluation(fusion_instruction))
+ .first;
+ }
+ auto indexing_users = evaluation->second.RemoveFusionOperand(producer);
+ HloInstruction* new_producer =
+ InstructionFusion::FuseInstruction(fusion_instruction, producer);
+ evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
+ return new_producer;
+}
+
+std::unique_ptr<FusionQueue> GpuInstructionFusion::GetFusionQueue(
+ HloComputation* computation) {
+ if (fusible_computations_.contains(computation)) {
+ return InstructionFusion::GetFusionQueue(computation);
+ }
+ return std::make_unique<EmptyFusionQueue>();
+}
+
+} // namespace gpu
+} // namespace xla
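
A note on the FusionDecision convention used in ShouldFuseInexpensiveChecks() and ShouldFuse() above: constructing the decision from a string means "do not fuse" and carries the reason, while an empty decision ({}) means "fuse". A minimal sketch with a hypothetical check name:

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/instruction_fusion.h"

namespace xla {
namespace gpu {

// Sketch of the FusionDecision convention: a string is a negative decision
// carrying its reason; an empty decision means "fuse".
inline FusionDecision CheckConsumerNotCustomFusionSketch(
    const HloInstruction& consumer) {
  if (consumer.IsCustomFusion()) {
    return "the consumer is a custom fusion";
  }
  return {};
}

}  // namespace gpu
}  // namespace xla

The RETURN_IF_NOT_FUSIBLE macro used above appears to propagate such a negative decision early, so only the passing case falls through.
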
diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h
new file mode 100644
index 0000000..d7fb7f2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h
@@ -0,0 +1,82 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_
+
+#include <stdint.h>
+
+#include <memory>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/fusion_node_indexing_evaluation.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+class GpuInstructionFusion : public InstructionFusion {
+ public:
+ GpuInstructionFusion(bool may_duplicate, const se::DeviceDescription& d)
+ : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate),
+ device_info_(d) {}
+
+ static bool IsExpensive(const HloInstruction& instruction);
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ protected:
+ std::unique_ptr<FusionQueue> GetFusionQueue(
+ HloComputation* computation) override;
+ FusionDecision ShouldFuse(HloInstruction* consumer,
+ int64_t operand_index) override;
+
+ HloInstruction::FusionKind ChooseKind(
+ const HloInstruction* producer, const HloInstruction* consumer) override;
+
+ private:
+ // Called by ShouldFuse() to run the computationally inexpensive checks on
+ // whether the operand at 'operand_index' should be fused into 'consumer'.
+ FusionDecision ShouldFuseInexpensiveChecks(HloInstruction* consumer,
+ int64_t operand_index);
+
+ HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
+ HloInstruction* producer) override;
+
+ // Computations eligible for fusion in the current Run() invocation.
+ absl::flat_hash_set<const HloComputation*> fusible_computations_;
+
+ // Keeps track of the number of times each instruction inside a fusion node
+ // is indexed with different index vectors.
+ absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
+ fusion_node_evaluations_;
+
+ se::DeviceDescription device_info_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_
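
A minimal sketch of driving this pass, mirroring how the tests construct it with a duplication policy and a device description. The helper name and the non-duplicating-then-duplicating ordering (taken from the HorizontalLoopFusionAfterVerticalFusion test earlier in this change) are assumptions, not part of this header.

#include "absl/status/statusor.h"
#include "tsl/platform/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
namespace gpu {

// Run producer/consumer fusion twice, first without and then with
// instruction duplication, and report whether anything changed.
inline absl::StatusOr<bool> RunGpuInstructionFusionSketch(
    HloModule* module, const se::DeviceDescription& device_info) {
  bool changed = false;
  for (bool may_duplicate : {false, true}) {
    GpuInstructionFusion fusion(may_duplicate, device_info);
    TF_ASSIGN_OR_RETURN(bool pass_changed, fusion.Run(module));
    changed |= pass_changed;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla

In the tests that follow, device_info comes from TestGpuDeviceInfo::RTXA6000DeviceInfo().
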
diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc
new file mode 100644
index 0000000..454dab1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc
@@ -0,0 +1,1006 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+
+#include <cstdint>
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_utils.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace m = ::xla::match;
+
+namespace xla {
+namespace gpu {
+
+class InstructionFusionTest : public HloTestBase {
+ public:
+ GpuInstructionFusion duplicating_instruction_fusion_{
+ /*may_duplicate=*/true, TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+};
+
+TEST_F(InstructionFusionTest, NoFusionIntoCustomFusionConsumer) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+c {
+ p0 = bf16[3000,53]{1,0} parameter(0)
+ p1 = bf16[22,53]{1,0} parameter(1)
+ d = bf16[3000,22]{1,0} dot(p0, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+ r = bf16[1,1,3000,22]{3,2,1,0} reshape(d)
+ ROOT c = bf16[1,1,3000,22]{2,1,3,0} copy(r)
+}
+
+ENTRY e {
+ p1 = bf16[3000,53]{1,0} parameter(1)
+ p0 = bf16[22,53]{1,0} parameter(0)
+ cp0 = bf16[22,53]{1,0} convert(p0)
+ ROOT f = bf16[1,1,3000,22]{2,1,3,0} fusion(p1, cp0), kind=kCustom, calls=c
+})"));
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest,
+ CostlyProducerAndOperandElementReusingConsumerNotFused) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
+ HloInstruction* log1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kLog, const0));
+ HloInstruction* broadcast2 =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {1}), log1, {}));
+
+ auto module = CreateNewVerifiedModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(broadcast2, computation->root_instruction());
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+ EXPECT_EQ(broadcast2, computation->root_instruction());
+}
+
+TEST_F(InstructionFusionTest,
+ NonCostlyProducerAndOperandElementReusingConsumerFused) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
+ HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
+ HloInstruction* broadcast2 =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(S32, {1}), negate1, {}));
+
+ auto module = CreateNewVerifiedModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(broadcast2, computation->root_instruction());
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+ EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
+}
+
+TEST_F(InstructionFusionTest,
+ CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
+ HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
+ HloInstruction* reshape2 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), exp1));
+
+ auto module = CreateNewVerifiedModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(reshape2, computation->root_instruction());
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+ EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
+}
+
+TEST_F(InstructionFusionTest,
+ CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
+ HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
+ HloInstruction* transpose2 = builder.AddInstruction(
+ HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {}));
+
+ auto module = CreateNewVerifiedModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(transpose2, computation->root_instruction());
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+ EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotFused) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1}), "0"));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(F32, {1, 1}), param0, param0));
+ auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 1, 1}), dot1));
+ auto log = builder.AddInstruction(HloInstruction::CreateUnary(
+ reshape2->shape(), xla::HloOpcode::kLog, reshape2));
+
+ auto module = CreateNewVerifiedModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(log, computation->root_instruction());
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
+
+ auto module = CreateNewVerifiedModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(transpose2, computation->root_instruction());
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+// Tests that broadcasts are fused into a fusion with a reduce root.
+TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY BroadcastIntoReduce {
+ constant = f32[] constant(1)
+ broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={}
+ constant.1 = f32[] constant(0)
+ ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
+ to_apply=add
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_THAT(
+ root->fused_expression_root(),
+ GmockMatch(m::Reduce(m::Broadcast(m::Constant()), m::Constant())));
+}
+
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ constant.1 = f32[] constant(0)
+ ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add
+ })")
+ .value();
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ fused_reduce {
+ p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0)
+ mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1)
+ c0.1 = f32[] constant(0)
+ ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[]) tuple(fusion)
+ })")
+ .value();
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry {
+ p0 = s32[512,512,2] parameter(0)
+ p1 = f32[1,1,512,512] parameter(1)
+ constant_1 = f32[] constant(1)
+ reduce-window.1 = reduce-window(p1, constant_1),
+ window={size=1x1x9x9}, to_apply=add
+ ROOT ret = gather(reduce-window.1, p0), offset_dims={0,1,2,3},
+ collapsed_slice_dims={}, start_index_map={1,2},
+ index_vector_dim=2, slice_sizes={1,1,1,1}
+ })")
+ .value();
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_THAT(root->fused_expression_root(),
+ GmockMatch(m::Add(m::Copy(), m::Copy())));
+}
+
+TEST_F(InstructionFusionTest, BitcastIntoAdd) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ ENTRY BroadcastIntoAdd {
+ p0 = f32[4,1,1]{2,1,0} parameter(0)
+ p1 = f32[4,1]{1,0} parameter(1)
+ bitcast = f32[4,1]{1,0} bitcast(p0)
+ ROOT add = f32[4,1] add(bitcast, p1)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_THAT(root->fused_expression_root(),
+ GmockMatch(m::Add(m::Bitcast(m::Parameter()), m::Parameter())));
+}
+
+TEST_F(InstructionFusionTest, AddIntoBitcast) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ ENTRY BroadcastIntoAdd {
+ p0 = f32[4,1]{1,0} parameter(0)
+ p1 = f32[4,1]{1,0} parameter(1)
+ add = f32[4,1] add(p0, p1)
+ ROOT bitcast = f32[4,1,1] bitcast(add)
+ })")
+ .value();
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, ConvertIntoBitcastBothConsumedByTuple) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test
+
+ ENTRY main {
+ param_0 = f32[2048,16000]{1,0} parameter(0)
+ convert = bf16[2048,16000]{1,0} convert(param_0)
+ bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
+ ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
+ })")
+ .value();
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, DontFuseGTE) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ ENTRY DontFuseGTE {
+ p0 = (f32[10], f32[10]) parameter(0)
+ gte0 = f32[10] get-tuple-element(p0), index=0
+ gte1 = f32[10] get-tuple-element(p0), index=1
+ ROOT add = f32[10] add(gte0, gte1)
+ })")
+ .value();
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+// Compute sum(p1/p0), where p0 and p1 have type f32, twice. Check that the
+// division is duplicated and fused into both reduces.
+TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ Add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+ ENTRY TestComputation {
+ zero = f32[] constant(0)
+ p0 = f32[100] parameter(0)
+ p1 = f32[100] parameter(1)
+ recip = f32[100] divide(p1, p0)
+ sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ ROOT root = (f32[], f32[]) tuple(sum1, sum2)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
+ << module->ToString();
+}
+
+// Compute sum(p1/p0), where p0 and p1 have type s32, twice. Check that the
+// division is *not* duplicated and fused into both reduces, because integer
+// division is considered expensive.
+TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ Add {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+ }
+ ENTRY TestComputation {
+ zero = s32[] constant(0)
+ p0 = s32[100] parameter(0)
+ p1 = s32[100] parameter(1)
+ recip = s32[100] divide(p1, p0)
+ sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ ROOT mul = (s32[], s32[]) tuple(sum1, sum2)
+ })")
+ .value();
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value())
+ << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ ENTRY NoOutputFusion {
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[3,4]{1,0} parameter(1)
+ dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ d = f32[4,4]{1,0} multiply(dot, dot)
+ ROOT mul = f32[4,4] multiply(d, broadcast)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
+ EXPECT_THAT(
+ root->fused_expression_root(),
+ GmockMatch(m::Multiply(m::Multiply(m::Parameter(), m::Parameter()),
+ m::Broadcast(m::Constant()))));
+}
+
+// Counts the HLO ops with a given op code in the specified module.
+static int Count(const HloModule& module, HloOpcode op) {
+ int count = 0;
+ for (const auto* computation : module.computations()) {
+ for (const auto* instruction : computation->instructions()) {
+ if (instruction->opcode() == op) {
+ ++count;
+ }
+ }
+ }
+ return count;
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusion) {
+ // sub --> add --> tuple
+ // \---------------/
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ p2 = f32[4,3]{1,0} parameter(2)
+ sub = f32[4,3]{1,0} subtract(p0, p2)
+ add = f32[4,3]{1,0} add(sub, p1)
+ ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add)
+ })")
+ .value();
+
+ // Multi-output fusion is disabled here and performed in the
+ // MultiOutputFusion pass instead.
+ ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, FuseScalarConstant) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ ENTRY FuseScalarConstant {
+ p0 = f32[] parameter(0)
+ c0 = f32[] constant(1)
+ add1 = f32[] add(p0, c0)
+ b0 = f32[2]{0} broadcast(add1), dimensions={}
+ c1 = f32[2]{0} constant({1, 2})
+ ROOT add2 = f32[2]{0} add(b0, c1)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_THAT(
+ root->fused_expression_root(),
+ GmockMatch(m::Add(m::Broadcast(m::Add(m::Parameter(), m::Constant())),
+ m::Parameter())));
+}
+
+// Check that we limit the number of operands to fusions we create.
+TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
+ constexpr int64_t kNumParams = 200;
+ ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
+
+ // Compute p0 + p1 + ... + pN.
+ HloComputation::Builder b(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
+ auto param0 =
+ b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
+ auto sum = param0;
+ for (int64_t i = 1; i < kNumParams; ++i) {
+ auto param =
+ b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"));
+ sum = b.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param));
+ }
+ auto module = CreateNewVerifiedModule();
+ auto computation = module->AddEntryComputation(b.Build());
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ for (const HloInstruction* instr : computation->instructions()) {
+ EXPECT_LE(instr->operand_count(), MaxOperandsAndOutputsPerFusion())
+ << instr->ToString();
+ }
+}
+
+TEST_F(InstructionFusionTest, FuseIntoScatter) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+ }
+
+ ENTRY FuseIntoScatter {
+ p0 = s32[3,3] parameter(0)
+ p1 = s32[2] parameter(1)
+ indices = s32[2] add(p1, p1)
+ p2 = s32[2,3] parameter(2)
+ updates = s32[2,3] add(p2, p2)
+ scatter = s32[3,3] scatter(p0, indices, updates),
+ to_apply=add,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ ROOT add = s32[3,3] add(scatter, scatter)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
+ EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
+}
+
+TEST_F(InstructionFusionTest, DontFuseIntoFirstOperandOfScatter) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+ }
+
+ ENTRY FuseIntoScatter {
+ p0 = s32[3,3] parameter(0)
+ operand = s32[3,3] add(p0, p0)
+ p1 = s32[2] parameter(1)
+ indices = s32[2] add(p1, p1)
+ p2 = s32[2,3] parameter(2)
+ updates = s32[2,3] add(p2, p2)
+ scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ ROOT add = s32[3,3] add(scatter, scatter)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
+ EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
+}
+
+TEST_F(InstructionFusionTest, ScatterOpShouldNotFuseWithSharedOperand) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY Test {
+ parameter.0 = f32[8,8] parameter(0)
+ parameter.1 = s32[7] parameter(1)
+ indices = s32[7] add(parameter.1, parameter.1)
+ slice = f32[7,8] slice(parameter.0), slice={[0:7],[0:8]}
+ ROOT scatter = f32[8,8] scatter(parameter.0, indices, slice),
+ to_apply=add,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ })")
+ .value();
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+ // Verify that we don't fuse scatter and slice together since
+ // scatter modifies the input buffer in-place, which is also used
+ // as slice's input, and we don't know where the scatter indices point to.
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(
+ root, GmockMatch(m::Fusion(m::Parameter(), m::Slice(), m::Parameter())));
+}
+
+TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY BroadcastIntoReduce {
+ constant = f32[16] constant({0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15})
+ broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={0}
+ constant.1 = f32[] constant(0)
+ ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
+ to_apply=add
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+ // The f32[16] constant should not be fused into the reduce, but the f32[]
+ // constant should be.
+ auto* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_THAT(
+ root->fused_instructions_computation()->root_instruction(),
+ GmockMatch(m::Reduce(m::Broadcast(m::Parameter()), m::Constant())));
+}
+
+TEST_F(InstructionFusionTest, FuseReverse) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ ENTRY Reverse {
+ p0 = f32[50,96,1024]{2,1,0} parameter(0)
+ add = f32[50,96,1024]{2,1,0} add(p0, p0)
+ ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0}
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_THAT(root->fused_expression_root(),
+ GmockMatch(m::Reverse(m::Add(m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveF32) {
+ auto m = CreateNewVerifiedModule();
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
+ HloInstruction* rem = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kRemainder, param0, one));
+ HloInstruction* sqrt = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kSqrt, param0));
+ HloInstruction* rsqrt = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kRsqrt, param0));
+ HloInstruction* exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
+ EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*sqrt));
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rsqrt));
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*exp));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveF64) {
+ auto m = CreateNewVerifiedModule();
+ Shape r0f64 = ShapeUtil::MakeShape(F64, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f64, "param0"));
+
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f64, HloOpcode::kDivide, param0, one));
+ HloInstruction* rem = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f64, HloOpcode::kRemainder, param0, one));
+ HloInstruction* sqrt = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f64, HloOpcode::kSqrt, param0));
+ HloInstruction* rsqrt = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f64, HloOpcode::kRsqrt, param0));
+ HloInstruction* exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f64, HloOpcode::kExp, param0));
+
+ EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*div));
+ EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
+ EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*sqrt));
+ EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rsqrt));
+ EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*exp));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveS32) {
+ auto m = CreateNewVerifiedModule();
+ Shape r0s32 = ShapeUtil::MakeShape(S32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0s32, "param0"));
+
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32, HloOpcode::kDivide, param0, one));
+ HloInstruction* rem = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32, HloOpcode::kRemainder, param0, one));
+
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveBroadcastS32) {
+ auto m = CreateNewVerifiedModule();
+ Shape r1s32 = ShapeUtil::MakeShape(S32, {10});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1s32, "param0"));
+
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* one_broad =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r1s32, one, {}));
+
+ HloInstruction* div = builder.AddInstruction(HloInstruction::CreateBinary(
+ r1s32, HloOpcode::kDivide, param0, one_broad));
+ HloInstruction* rem = builder.AddInstruction(HloInstruction::CreateBinary(
+ r1s32, HloOpcode::kRemainder, param0, one_broad));
+
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
+ EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
+}
+
+TEST_F(InstructionFusionTest, FloatingPointExpIsCheap) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ Add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+ ENTRY TestComputation {
+ zero = f32[] constant(0)
+ p0 = f32[100] parameter(0)
+ recip = f32[100] exponential(p0)
+ sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ ROOT root = (f32[], f32[]) tuple(sum1, sum2)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
+ << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, SmallReducedDimensionIsNotLoweredToLoop) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+ }
+
+ ENTRY FuseSmallReduction {
+ p0 = s32[1048576,4] parameter(0)
+ p1 = s32[1048576,4] parameter(1)
+ sum = s32[1048576,4] add(p0, p1)
+ init = s32[] constant(0)
+ ROOT reduce = s32[1048576] reduce(sum, init), dimensions={1}, to_apply=add
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
+}
+
+TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule m
+
+ f {
+ tmp_0 = f32[] parameter(0)
+ tmp_1 = f32[] parameter(1)
+ tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
+ tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
+ tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
+ tmp_5 = s32[] parameter(2)
+ tmp_6 = s32[] parameter(3)
+ tmp_7 = s32[] minimum(tmp_5, tmp_6)
+ tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
+ tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
+ ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
+ }
+
+ minmax {
+ tmp_0 = f32[] parameter(0)
+ tmp_1 = f32[] parameter(2)
+ tmp_2 = s32[] parameter(1)
+ tmp_3 = s32[] parameter(3)
+ ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=f
+ }
+
+ ENTRY e {
+ tmp_0 = f32[554112,10]{1,0} parameter(0)
+ tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+ tmp_2 = f32[] constant(-inf)
+ tmp_3 = s32[] constant(0)
+ ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax
+ })")
+ .value();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
+ TestGpuDeviceInfo::RTXA6000DeviceInfo())
+ .Run(module.get())
+ .value());
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter())));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction()->fused_expression_root(),
+ GmockMatch(
+ m::Reduce(m::Parameter(), m::Iota(), m::Constant(), m::Constant())));
+}
+
+TEST_F(InstructionFusionTest, InputReductionFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+ add.clone.13 {
+ x.27 = f32[] parameter(0)
+ y.27 = f32[] parameter(1)
+ ROOT add.1036 = f32[] add(x.27, y.27)
+ }
+ add.clone.14 {
+ x.28 = f32[] parameter(0)
+ y.28 = f32[] parameter(1)
+ ROOT add.1037 = f32[] add(x.28, y.28)
+ }
+ add {
+ x = bf16[] parameter(0)
+ convert.448 = f32[] convert(x)
+ y = bf16[] parameter(1)
+ convert.449 = f32[] convert(y)
+ add.597 = f32[] add(convert.448, convert.449)
+ ROOT convert.450 = bf16[] convert(add.597)
+ }
+ ENTRY FuseSmallReduction {
+ param_2.7 = bf16[8,16,64,2048]{3,2,1,0} parameter(2)
+ convert.1395 = f32[8,16,64,2048]{3,2,1,0} convert(param_2.7)
+ param_0.85 = bf16[8,16,64,2048]{3,2,1,0} parameter(0)
+ convert.1393 = f32[8,16,64,2048]{3,2,1,0} convert(param_0.85)
+ multiply.1652 = f32[8,16,64,2048]{3,2,1,0} multiply(convert.1395, convert.1393)
+ convert.1392 = bf16[8,16,64,2048]{3,2,1,0} convert(multiply.1652)
+ bitcast.15934 = bf16[128,64,2048]{2,1,0} bitcast(convert.1392)
+ convert.1391 = f32[128,64,2048]{2,1,0} convert(bitcast.15934)
+ param_1.15 = bf16[] parameter(1)
+ convert.1394 = f32[] convert(param_1.15)
+ reduce.462 = f32[128,64]{1,0} reduce(convert.1391, convert.1394), dimensions={2}, to_apply=add.clone.13
+ reduce.121 = f32[64]{0} reduce(reduce.462, convert.1394), dimensions={0}, to_apply=add.clone.14
+ ROOT convert.890 = bf16[64]{0} convert(reduce.121)
+ })")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* fused_convert_fusion =
+ module->entry_computation()->root_instruction();
+
+ ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
+ SCOPED_TRACE(module->ToString());
+ EXPECT_EQ(fused_convert_fusion->fusion_kind(),
+ HloInstruction::FusionKind::kInput);
+}
+
+TEST_F(InstructionFusionTest, DotStrengthReductionFusion) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+scalar_add_computation {
+ scalar_rhs = f32[] parameter(1)
+ scalar_lhs = f32[] parameter(0)
+ ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+ param_1.3 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} parameter(1)
+ param_0.6 = f16[16,64,96,1,2,16]{5,4,3,2,1,0} parameter(0)
+ bitcast.26 = f16[16,64,96,2,16]{4,3,2,1,0} bitcast(param_0.6)
+ broadcast.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} broadcast(bitcast.26), dimensions={0,1,2,4,5}
+ multiply.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} multiply(broadcast.4, param_1.3)
+ convert.8 = f32[16,64,96,6,2,16]{5,4,3,2,1,0} convert(multiply.4)
+ constant_2 = f32[] constant(0)
+ reduce.3 = f32[16,64,96,6,2]{3,4,2,1,0} reduce(convert.8, constant_2), dimensions={5}, to_apply=scalar_add_computation
+ bitcast.25 = f32[16,64,96,2,6]{4,3,2,1,0} bitcast(reduce.3)
+ convert.7 = f16[16,64,96,2,6]{4,3,2,1,0} convert(bitcast.25)
+ ROOT bitcast.24 = f16[16,64,96,2,1,6]{5,4,3,2,1,0} bitcast(convert.7)
+})")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ const HloInstruction* fused_convert_fusion =
+ module->entry_computation()->root_instruction()->operand(0);
+
+ ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
+ SCOPED_TRACE(module->ToString());
+ EXPECT_EQ(fused_convert_fusion->fusion_kind(),
+ HloInstruction::FusionKind::kInput);
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
+}
+
+TEST_F(InstructionFusionTest, ReductionFusionOtherUnaryElementwiseOpsAreFused) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+scalar_add_computation {
+ scalar_rhs = f32[] parameter(1)
+ scalar_lhs = f32[] parameter(0)
+ ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+ param_0 = f16[64,96,6,16]{3,2,1,0} parameter(0)
+ constant_2 = f32[] constant(0)
+ reduce.3 = f32[64,6,16]{2,1,0} reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
+ negate = f32[64,6,16]{2,1,0} negate(reduce.3)
+ ROOT sine = f16[64,6,16]{2,1,0} sine(negate)
+})")
+ .value();
+
+ EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+ HloInstruction* fused_convert_fusion =
+ module->entry_computation()->root_instruction();
+
+ ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
+ SCOPED_TRACE(module->ToString());
+ EXPECT_EQ(fused_convert_fusion->fusion_kind(),
+ HloInstruction::FusionKind::kInput);
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
+}
+
+TEST_F(InstructionFusionTest, DoNotFuseInsideReducer) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+scalar_add_computation {
+ scalar_rhs = f32[] parameter(1)
+ scalar_lhs = f32[] parameter(0)
+ add.1 = f32[] add(scalar_lhs, scalar_rhs)
+ ROOT add.2 = f32[] add(add.1, scalar_rhs)
+}
+
+ENTRY main {
+ param_0 = f16[64,96] parameter(0)
+ constant_2 = f32[] constant(0)
+ ROOT reduce = f32[64] reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
+})")
+ .value();
+
+ EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc
new file mode 100644
index 0000000..caa8d3c1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc
@@ -0,0 +1,596 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/layout_assignment.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/host_memory_offload_annotations.h"
+#include "xla/service/logical_buffer.h"
+#include "xla/shape.h"
+#include "xla/shape_layout.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tsl/util/env_var.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+using se::dnn::DataLayout;
+using se::dnn::FilterLayout;
+
+// Returns (input, filter, output) layouts.
+static std::tuple<DataLayout, FilterLayout, DataLayout>
+HeuristicLayoutAssignment(const HloInstruction* instr,
+ const se::GpuComputeCapability& gpu_version,
+ const se::dnn::VersionInfo& dnn_version) {
+ // DataLayout and FilterLayout use weird enum names. Translations:
+ // N <=> Batch or Output
+ // C <=> Depth or Input
+ // H <=> Y
+ // W <=> X
+ //
+ // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
+ //
+ // If you have trouble keeping these straight, consider that all that matters
+ // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?
+
+ constexpr auto kAllNCHW =
+ std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX);
+ // kBatchDepthYX4 has the same layout as kBatchDepthYX32; they're both VECT_C
+ // layouts as far as cudnn is concerned.
+ constexpr auto kAllNCHW_VECT_C =
+ std::make_tuple(DataLayout::kBatchDepthYX4, FilterLayout::kOutputInputYX4,
+ DataLayout::kBatchDepthYX4);
+ constexpr auto kAllNHWC =
+ std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
+ DataLayout::kBatchYXDepth);
+
+ // Integer convolution must use NHWC or NCHW_VECT_C.
+ //
+ // TODO(jlebar): Do non-VECT_C int8_t convs still require NHWC with new
+ // versions of cudnn?
+ const ConvolutionDimensionNumbers& dnums =
+ instr->convolution_dimension_numbers();
+ Shape input_shape = instr->operand(0)->shape();
+ PrimitiveType input_ty = instr->operand(0)->shape().element_type();
+ if (primitive_util::IsIntegralType(input_ty)) {
+ if (input_ty == S8 && dnums.input_spatial_dimensions_size() == 2 &&
+ input_shape.dimensions_size() == 5) {
+ VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString();
+ return kAllNCHW_VECT_C;
+ }
+ VLOG(2) << "Using NHWC for int8_t conv " << instr->ToString();
+ return kAllNHWC;
+ }
+
+ if (primitive_util::IsF8Type(input_ty)) {
+ VLOG(2) << "Using NHWC for FP8 conv " << instr->ToString();
+ return kAllNHWC;
+ }
+
+ const DebugOptions& debug_options =
+ instr->GetModule()->config().debug_options();
+
+ if (debug_options.xla_gpu_force_conv_nchw()) {
+ VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
+ return kAllNCHW;
+ }
+
+ if (debug_options.xla_gpu_force_conv_nhwc()) {
+ VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
+ return kAllNHWC;
+ }
+
+ const auto* rocm_compute_capability =
+ std::get_if<se::RocmComputeCapability>(&gpu_version);
+ if (rocm_compute_capability && input_ty == F16) return kAllNHWC;
+
+ const bool isFloat16 = (input_ty == F16) || (input_ty == BF16);
+ if (std::holds_alternative<se::CudaComputeCapability>(gpu_version)) {
+ // If we're not Volta or not fp16/bfloat16, or not conv2D, the decision is
+ // easy: Use NCHW.
+ const auto* cuda_compute_capability =
+ std::get_if<se::CudaComputeCapability>(&gpu_version);
+ bool is_volta =
+ cuda_compute_capability &&
+ cuda_compute_capability->IsAtLeast(se::CudaComputeCapability::VOLTA);
+ if (!isFloat16 || !is_volta ||
+ instr->shape().tuple_shapes(0).dimensions_size() != 4) {
+ return kAllNCHW;
+ }
+
+ // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
+ // convs with stride are significantly faster with NCHW layouts.
+ //
+ // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
+ // which on paper gives good performance. However, there are two
+ // observations:
+ // * a mixed layout combination is more cuDNN-bug prone, based on empirical
+ // evidence.
+ // * we've also observed that for mixed layouts, cuDNN transposes data back
+ // and forth from a different layout combination. If we end up with
+ // transposes anyway, we prefer to have them in XLA, as they can be fused.
+ if (std::make_tuple(dnn_version.major_version(),
+ dnn_version.minor_version()) <= std::make_tuple(7, 3) &&
+ instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
+ window_util::HasStride(instr->window())) {
+ return kAllNCHW;
+ }
+ } else if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+ bool is_enabled = false;
+ TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_ROCM_NHWC",
+ /*default_val=*/false, &is_enabled));
+ auto rocm_compute_capability =
+ std::get<se::RocmComputeCapability>(gpu_version);
+ if (!isFloat16 || (!rocm_compute_capability.has_nhwc_layout_support()) ||
+ instr->shape().tuple_shapes(0).dimensions_size() != 4 || !is_enabled) {
+ return kAllNCHW;
+ }
+ }
+
+ VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
+
+ // For other Volta f16 convolutions, use NHWC.
+ return kAllNHWC;
+}
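
Editorial aside, not part of the patch: the enum-name translation described in the comment at the top of HeuristicLayoutAssignment can be made concrete with a hypothetical logging helper. The sketch assumes only the se::dnn::DataLayout values already used in this function.

static const char* DataLayoutName(DataLayout layout) {
  // Map the cuDNN-style enum names back to the NCHW/NHWC terminology used in
  // the comments above (N <=> Batch, C <=> Depth, H <=> Y, W <=> X).
  switch (layout) {
    case DataLayout::kBatchDepthYX:
      return "NCHW";
    case DataLayout::kBatchDepthYX4:
      return "NCHW_VECT_C";
    case DataLayout::kBatchYXDepth:
      return "NHWC";
    default:
      return "other";
  }
}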
+
+// Adds layout constraints on the cudnn custom-call instruction. The layout
+// constraints are represented in terms of minor_to_major fields of both
+// operands and the output shape. Depending on the underlying algorithm, one of
+// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
+absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
+ Shape lhs_shape = instr->operand(0)->shape();
+ Shape rhs_shape = instr->operand(1)->shape();
+ Shape result_shape = instr->shape().tuple_shapes(0);
+
+ Shape* input_shape;
+ Shape* filter_shape;
+ Shape* output_shape;
+
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ case CudnnConvKind::kForwardGraph:
+ input_shape = &lhs_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &result_shape;
+ break;
+ case CudnnConvKind::kBackwardInput:
+ input_shape = &result_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &lhs_shape;
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ input_shape = &lhs_shape;
+ filter_shape = &result_shape;
+ output_shape = &rhs_shape;
+ break;
+ }
+
+ {
+ DataLayout input;
+ FilterLayout filter;
+ DataLayout output;
+ std::tie(input, filter, output) =
+ HeuristicLayoutAssignment(instr, gpu_version_, dnn_version_);
+
+ TF_ASSIGN_OR_RETURN(
+ std::tie(*input_shape->mutable_layout(),
+ *filter_shape->mutable_layout(),
+ *output_shape->mutable_layout()),
+ StreamExecutorConvLayoutsToXlaLayouts(
+ instr->convolution_dimension_numbers(), input, filter, output));
+ }
+
+ // The custom call returns a tuple of (actual_result, scratch_buffer);
+ // call_result_buf is the logical buffer for actual_result, the thing that
+ // contains the result of the conv call.
+ TF_ASSIGN_OR_RETURN(
+ const LogicalBuffer* call_result_buf,
+ points_to_analysis_->GetBufferDefinedAt(instr, /*index=*/{0}));
+
+ // Set layouts of the instructions' shapes.
+ TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0));
+ TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1));
+ TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf));
+ // For fused convolutions, instr->operand(2), if it exists, is the bias
+ // buffer. There is no need to assign a layout to it, as it has only one
+ // dimension. instr->operand(3), if it exists, is the side input buffer.
+ if (kind == CudnnConvKind::kForwardActivation &&
+ instr->operand_count() == 4) {
+ // The side input layout must match the output layout.
+ TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3));
+ }
+
+ // For graph convolutions, align the layouts of the non-scalar inputs to any
+ // pointwise ops with the output layout.
+ if (kind == CudnnConvKind::kForwardGraph) {
+ for (int k = 2; k < instr->operand_count(); ++k) {
+ if (!ShapeUtil::IsScalar(instr->operand(k)->shape())) {
+ TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, k));
+ }
+ }
+ }
+
+ if (instr->operand_count() > 2 && kind != CudnnConvKind::kForwardActivation &&
+ kind != CudnnConvKind::kForwardGraph) {
+ return Internal(
+ "Invalid convolution. Conv has a side input, but kind is not fused "
+ "conv forward or graph conv foward: %s",
+ instr->ToString());
+ }
+
+ return absl::OkStatus();
+}
+
+namespace {
+
+// Imposes the default layout with first two dimensions swapped on input
+// `shape`.
+void SetFortranLayout(Shape* shape) {
+ LayoutUtil::SetToDefaultLayout(shape);
+ int n = shape->mutable_layout()->minor_to_major_size();
+ CHECK_GE(n, 2);
+ std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
+ shape->mutable_layout()->mutable_minor_to_major()->at(1));
+}
+
+bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
+ const Shape& shape) {
+ const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
+ // If we are able to construct a `MatrixLayout` then the dot can support
+ // this layout.
+ return MatrixLayout::For(shape, dot_dims.lhs_batch_dimensions().size(),
+ dot->operand(0)->shape().rank() -
+ dot_dims.lhs_contracting_dimensions().size() -
+ dot_dims.lhs_batch_dimensions().size(),
+ dot_dims.rhs_batch_dimensions().size(),
+ dot->operand(1)->shape().rank() -
+ dot_dims.rhs_contracting_dimensions().size() -
+ dot_dims.rhs_batch_dimensions().size())
+ .ok();
+}
+
+} // namespace
+
+absl::Status GpuLayoutAssignment::AddBackendConstraints(
+ LayoutConstraints* constraints) {
+ // Add convolution constraints in reverse postorder so that the earliest
+ // convolution layout propagates first. This reduces the likelihood of fusion
+ // nodes with copies.
+ auto post_order = constraints->computation()->MakeInstructionPostOrder();
+ for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
+ ++iterator) {
+ HloInstruction* instruction = *iterator;
+ if (IsCustomCallToDnnConvolution(*instruction)) {
+ TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
+ Cast<HloCustomCallInstruction>(instruction), constraints));
+ }
+
+ CHECK(!IsCublasGemm(*instruction))
+ << "Gemm rewriting should run after layout assignment";
+
+ if (instruction->opcode() == HloOpcode::kDot) {
+ const Shape& output_shape = instruction->shape();
+ const Shape& lhs_shape = instruction->operand(0)->shape();
+ const Shape& rhs_shape = instruction->operand(1)->shape();
+ const DotDimensionNumbers& dot_dims =
+ instruction->dot_dimension_numbers();
+
+ // Matmuls require the batch dimensions to be in consecutive physical
+ // dimensions and likewise for the contracting and non-contracting
+ // dimensions. Additionally, no batch dimension can be in the most
+ // minor physical dimension for inputs or the output.
+ absl::Span<const int64_t> lhs_batch_dims =
+ dot_dims.lhs_batch_dimensions();
+ absl::Span<const int64_t> lhs_contracting_dims =
+ dot_dims.lhs_contracting_dimensions();
+ TF_ASSIGN_OR_RETURN(std::vector<int64_t> lhs_non_contracting_dims,
+ GetNonContractingDims(lhs_shape, lhs_batch_dims,
+ lhs_contracting_dims));
+
+ absl::Span<const int64_t> rhs_batch_dims =
+ dot_dims.rhs_batch_dimensions();
+ absl::Span<const int64_t> rhs_contracting_dims =
+ dot_dims.rhs_contracting_dimensions();
+ TF_ASSIGN_OR_RETURN(std::vector<int64_t> rhs_non_contracting_dims,
+ GetNonContractingDims(rhs_shape, rhs_batch_dims,
+ rhs_contracting_dims));
+
+ const DebugOptions& debug_options =
+ instruction->GetModule()->config().debug_options();
+
+ bool is_bf16_to_bf16 =
+ (output_shape.element_type() == PrimitiveType::BF16 &&
+ lhs_shape.element_type() == PrimitiveType::BF16 &&
+ rhs_shape.element_type() == PrimitiveType::BF16);
+ bool is_s8_to_s32 = (output_shape.element_type() == PrimitiveType::S32 &&
+ lhs_shape.element_type() == PrimitiveType::S8 &&
+ rhs_shape.element_type() == PrimitiveType::S8 &&
+ output_shape.dimensions_size() == 2 &&
+ lhs_shape.dimensions_size() == 2 &&
+ rhs_shape.dimensions_size() == 2);
+
+ if (is_s8_to_s32 ||
+ (is_bf16_to_bf16 &&
+ debug_options.xla_gpu_ensure_minor_dot_contraction_dims())) {
+ TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
+ instruction, /*operand=*/0,
+ /*dim_groups=*/
+ {lhs_batch_dims, lhs_non_contracting_dims, lhs_contracting_dims}));
+ TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
+ instruction, /*operand=*/1,
+ /*dim_groups=*/
+ {rhs_batch_dims, rhs_non_contracting_dims, rhs_contracting_dims}));
+ TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
+ } else {
+ if (!lhs_batch_dims.empty() || lhs_contracting_dims.size() > 1 ||
+ lhs_non_contracting_dims.size() > 1) {
+ TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 0, lhs_batch_dims,
+ lhs_contracting_dims,
+ lhs_non_contracting_dims));
+ }
+ if (!rhs_batch_dims.empty() || rhs_non_contracting_dims.size() > 1 ||
+ rhs_contracting_dims.size() > 1) {
+ TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 1, rhs_batch_dims,
+ rhs_contracting_dims,
+ rhs_non_contracting_dims));
+ }
+ // If we have at least one batch dimension or there is more than one
+ // non-contracting dimension on lhs or rhs, we need to set a layout for
+ // the dot output.
+ if (!lhs_batch_dims.empty() || lhs_non_contracting_dims.size() > 1 ||
+ rhs_non_contracting_dims.size() > 1) {
+ TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
+ }
+ }
+ } else if (instruction->opcode() == HloOpcode::kTranspose) {
+ const HloInstruction* operand = instruction->operand(0);
+ if ((operand->opcode() != HloOpcode::kDot) ||
+ (operand->user_count() > 1)) {
+ continue;
+ }
+
+ // If possible, set the layout of the dot operation such that the output of
+ // the transpose (as a bitcast) has the default layout.
+ Shape shape = operand->shape();
+ *shape.mutable_layout() =
+ LayoutUtil::MakeLayoutFromMajorToMinor(instruction->dimensions());
+
+ if (DotCanSupportShapeWithLayout(operand, shape)) {
+ TF_RETURN_IF_ERROR(
+ SetOperandLayout(shape, instruction, /*operand_no=*/0));
+ }
+ } else if (instruction->opcode() == HloOpcode::kFft) {
+ // cuFFT requires a dim0 major layout.
+ Shape op0_shape = instruction->operand(0)->shape();
+ LayoutUtil::SetToDefaultLayout(&op0_shape);
+ Shape output_shape = instruction->shape();
+ LayoutUtil::SetToDefaultLayout(&output_shape);
+ TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
+ TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
+ } else if (instruction->opcode() == HloOpcode::kSort &&
+ instruction->operand(0)->shape().rank() > 1) {
+ // Make sure that all the operands and the output(s) have the same layout.
+ Shape keys_shape = instruction->operand(0)->shape();
+ Layout keys_layout =
+ LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
+ for (int64_t i = 0; i < instruction->operand_count(); ++i) {
+ Shape shape = instruction->operand(i)->shape();
+ *shape.mutable_layout() = keys_layout;
+ TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i));
+ const LogicalBuffer* output_buffer;
+ if (instruction->shape().IsArray()) {
+ TF_ASSIGN_OR_RETURN(
+ output_buffer,
+ points_to_analysis_->GetBufferDefinedAt(instruction, {}));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ output_buffer,
+ points_to_analysis_->GetBufferDefinedAt(instruction, {i}));
+ }
+ TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
+ }
+ } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
+ // TODO(phawkins): Ideally we would relax this constraint. What we
+ // actually want is that:
+ // a) the batch dimensions are major, in no particular order.
+ // b) the two minor dimensions are in fortran (column-major) order,
+ // although for the 'a' argument we could potentially accept row-major
+ // order and fold the transpose into the operator.
+ Shape op0_shape = instruction->operand(0)->shape();
+ Shape op1_shape = instruction->operand(1)->shape();
+ Shape output_shape = instruction->shape();
+ SetFortranLayout(&op0_shape);
+ SetFortranLayout(&op1_shape);
+ SetFortranLayout(&output_shape);
+ TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
+ TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1));
+ TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
+ } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
+ // XLA:GPU can only support reduce-scatter where the scatter dimension
+ // is the most major dimension in the layout.
+ auto ars = Cast<HloReduceScatterInstruction>(instruction);
+ TF_RETURN_IF_ERROR(SetInstructionLayout(
+ ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
+ ars));
+ } else if (instruction->opcode() == HloOpcode::kAllGather) {
+ // XLA:GPU can only support all-gathers where the gather dimension is the
+ // most major dimension in the layout.
+ auto ag = Cast<HloAllGatherInstruction>(instruction);
+ TF_RETURN_IF_ERROR(SetInstructionLayout(
+ ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()),
+ ag));
+ } else if (instruction->opcode() == HloOpcode::kAllToAll &&
+ instruction->shape().IsArray()) {
+ // XLA:GPU can only support all-to-all with split dimensions where the
+ // split dimension is the most major dimension in the layout.
+ auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
+ TF_RETURN_IF_ERROR(SetInstructionLayout(
+ ShapeUtil::MoveDimToMajor(all_to_all->shape(),
+ *all_to_all->split_dimension()),
+ all_to_all));
+ } else if (instruction->opcode() == HloOpcode::kSend) {
+ Shape s = instruction->operand(0)->shape();
+ LayoutUtil::SetToDefaultLayout(&s);
+ TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction->operand(0)));
+ TF_RETURN_IF_ERROR(
+ SetArrayOperandLayout(s.layout(), instruction->operand(0), 0));
+ } else if (instruction->opcode() == HloOpcode::kRecv) {
+ Shape s = instruction->shape();
+ ShapeUtil::ForEachMutableSubshape(
+ &s, [&](Shape* subshape, const ShapeIndex& index) {
+ LayoutUtil::SetToDefaultLayout(subshape);
+ });
+ TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction));
+ }
+ }
+ return absl::OkStatus();
+}
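
Editorial sketch, not part of the patch: the matmul layout rule stated in the comment inside the kDot branch above (batch, contracting, and non-contracting dimensions each occupy consecutive physical dimensions, and no batch dimension may be most minor) can be probed the same way SetDotOperandLayout does, via MatrixLayout::For. The names and shape below are hypothetical.

// For a batched operand f32[8,32,64] with batch dim 0, rows dim 1, cols dim 2,
// a row-major layout keeps the batch dimension most major and is accepted.
Shape lhs = ShapeUtil::MakeShape(F32, {8, 32, 64});
*lhs.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
bool supported = MatrixLayout::For(lhs, /*batch_dims=*/{0}, /*row_dims=*/{1},
                                   /*col_dims=*/{2})
                     .ok();  // Expected true; a {0, 2, 1} layout would place
                             // the batch dim most minor and be rejected.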
+
+absl::Status GpuLayoutAssignment::SetDotOperandLayout(
+ const HloInstruction* instruction, int64_t operand,
+ absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
+ absl::Span<const int64_t> col_dims) {
+ Shape shape = instruction->operand(operand)->shape();
+
+ // First, try to use the existing layout, if present.
+ if (shape.has_layout() &&
+ MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
+ // Re-set the operand layout, so it becomes mandatory.
+ return SetOperandLayout(shape, instruction, operand);
+
+ // Next, try the default layout (for the sake of everybody's sanity).
+ LayoutUtil::SetToDefaultLayout(&shape);
+ if (MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
+ return SetOperandLayout(shape, instruction, operand);
+
+ // Otherwise, fall back to forcing a (batch, rows, cols) layout.
+ return SetOperandMajorToMinorLayout(
+ instruction, operand,
+ /*dim_groups=*/{batch_dims, row_dims, col_dims});
+}
+
+absl::Status GpuLayoutAssignment::SetOperandMajorToMinorLayout(
+ const HloInstruction* instruction, int64_t operand,
+ std::initializer_list<absl::Span<const int64_t>> dim_groups) {
+ size_t size = 0;
+ for (auto group : dim_groups) size += group.size();
+ std::vector<int64_t> major_to_minor;
+ major_to_minor.reserve(size);
+ for (const auto& group : dim_groups) {
+ major_to_minor.insert(major_to_minor.end(), group.begin(), group.end());
+ }
+
+ Shape shape = instruction->operand(operand)->shape();
+ *shape.mutable_layout() =
+ LayoutUtil::MakeLayoutFromMajorToMinor(major_to_minor);
+ return SetOperandLayout(shape, instruction, operand);
+}
+
+absl::Status GpuLayoutAssignment::SetDotLayout(
+ const HloInstruction* instruction, LayoutConstraints* constraints) {
+ // If a user has requested a layout that we can support, use that.
+ for (const HloInstruction* user : instruction->users()) {
+ for (int64_t i = 0; i < user->operand_count(); ++i) {
+ if (user->operand(i) != instruction) {
+ continue;
+ }
+
+ const ShapeLayout* constraint = constraints->OperandLayout(user, i);
+ if ((constraint != nullptr) &&
+ DotCanSupportShapeWithLayout(instruction, constraint->shape())) {
+ return SetInstructionLayout(constraint->shape(), instruction);
+ }
+ }
+ }
+
+ // Otherwise, use the default layout.
+ return SetInstructionLayout(
+ LayoutUtil::GetWithDefaultLayout(instruction->shape()), instruction);
+}
+
+bool GpuLayoutAssignment::PropagateReductionLayoutToOperand(
+ const HloInstruction* user) {
+ // We try to propagate a layout to make the reduction a row reduction. But
+ // propagating the layout is only beneficial if the reduction emitter would be
+ // used for the row reduction.
+ int64_t reduction_size = 1;
+ for (int64_t reduction_dim : user->dimensions()) {
+ reduction_size *= user->operand(0)->shape().dimensions(reduction_dim);
+ }
+ int64_t kept_dimension_size = ShapeUtil::ElementsIn(user->shape());
+ return IsUnnestedReductionFasterThanElemental(
+ {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}});
+}
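
A worked example of the heuristic above (editorial; the numbers are assumed for illustration): reducing f32[128,64,2048] over dimension {2} gives reduction_size = 2048 and 128 * 64 = 8192 kept output elements, so the propagation decision comes down to the query sketched here.

// Same query shape as in PropagateReductionLayoutToOperand, with the example
// numbers plugged in.
bool prefer_row_reduction = IsUnnestedReductionFasterThanElemental(
    {/*is_row_reduction=*/true, {1, /*kept=*/128 * 64, /*reduced=*/2048}});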
+
+bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance(
+ const HloInstruction* instruction) {
+ // The host offloading custom calls will eventually be removed by the
+ // offloader, so we need to make sure that they do not change the layout and
+ // thus cause layout mismatches after the removal.
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
+ if (custom_call != nullptr &&
+ (custom_call->custom_call_target() ==
+ host_memory_offload_annotations::kMoveToHostCustomCallTarget ||
+ custom_call->custom_call_target() ==
+ host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
+ return false;
+ }
+
+ return LayoutAssignment::InstructionCanChangeLayoutInstance(instruction);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.h b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h
new file mode 100644
index 0000000..efa58f3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h
@@ -0,0 +1,81 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_
+
+#include <cstdint>
+#include <initializer_list>
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/computation_layout.h"
+#include "xla/service/layout_assignment.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+// GPU-specific layout assignment pass which preassigns layouts to satisfy
+// layout constraints for operands and results of library calls.
+class GpuLayoutAssignment : public LayoutAssignment {
+ public:
+ explicit GpuLayoutAssignment(
+ ComputationLayout* entry_computation_layout,
+ const se::GpuComputeCapability& gpu_version,
+ const se::dnn::VersionInfo& dnn_version,
+ ChannelLayoutConstraints* channel_constraints = nullptr)
+ : LayoutAssignment(entry_computation_layout, channel_constraints),
+ gpu_version_(gpu_version),
+ dnn_version_(dnn_version) {}
+ ~GpuLayoutAssignment() override = default;
+
+ protected:
+ absl::Status AddBackendConstraints(LayoutConstraints* constraints) override;
+
+ private:
+ absl::Status AddBackendConstraintsToDnnConvCustomCall(
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints);
+
+ // dim_groups are ordered from major to minor dimensions.
+ absl::Status SetOperandMajorToMinorLayout(
+ const HloInstruction* instruction, int64_t operand,
+ std::initializer_list<absl::Span<const int64_t>> dim_groups);
+
+ absl::Status SetDotOperandLayout(const HloInstruction* instruction,
+ int64_t operand,
+ absl::Span<const int64_t> batch_dims,
+ absl::Span<const int64_t> row_dims,
+ absl::Span<const int64_t> col_dims);
+
+ absl::Status SetDotLayout(const HloInstruction* instruction,
+ LayoutConstraints* constraints);
+
+ bool PropagateReductionLayoutToOperand(const HloInstruction* user) override;
+
+ bool InstructionCanChangeLayoutInstance(
+ const HloInstruction* instruction) override;
+
+ const se::GpuComputeCapability gpu_version_;
+ const se::dnn::VersionInfo dnn_version_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_
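
A minimal usage sketch of the pass declared above (editorial; the tests in the following file exercise it the same way), assuming `module`, `gpu_version`, and `dnn_version` are provided by the caller:

ComputationLayout computation_layout(
    module->entry_computation()->ComputeProgramShape(),
    /*ignore_layouts=*/false);
GpuLayoutAssignment layout_assignment(&computation_layout, gpu_version,
                                      dnn_version);
// Run() returns true if any layouts were changed.
TF_ASSIGN_OR_RETURN(bool changed, layout_assignment.Run(module.get()));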
diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc
new file mode 100644
index 0000000..dd1cbc6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc
@@ -0,0 +1,677 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/layout_assignment.h"
+
+#include <cstdint>
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/service/computation_layout.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_layout.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+using ::tsl::testing::IsOkAndHolds;
+
+class LayoutAssignmentTest : public HloTestBase {
+ public:
+ se::CudaComputeCapability GetCudaComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .cuda_compute_capability();
+ }
+
+ se::GpuComputeCapability GetGpuComputeCapability() {
+ return backend()
+ .default_stream_executor()
+ ->GetDeviceDescription()
+ .gpu_compute_capability();
+ }
+
+ se::dnn::VersionInfo GetDnnVersion() {
+ // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but
+ // none of the tests trigger this heuristic.
+ return GetDnnVersionInfoOrDefault(backend().default_stream_executor(),
+ se::dnn::VersionInfo{8, 3, 0});
+ }
+};
+
+TEST_F(LayoutAssignmentTest, Elementwise) {
+ Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
+ Shape ashape_in_row_major(ashape);
+ Shape ashape_in_col_major(ashape);
+ *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ // Enumerate all possible combinations of layouts.
+ for (const Shape& lhs_shape_with_layout :
+ {ashape_in_row_major, ashape_in_col_major}) {
+ for (const Shape& rhs_shape_with_layout :
+ {ashape_in_row_major, ashape_in_col_major}) {
+ for (const Shape& result_shape_with_layout :
+ {ashape_in_row_major, ashape_in_col_major}) {
+ // GpuLayoutAssignment should assign the same layout to "add" and its
+ // two operands.
+ auto builder = HloComputation::Builder(TestName());
+ auto x = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ashape, "x"));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ashape, "y"));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y));
+ auto module = CreateNewVerifiedModule();
+ HloComputation* computation =
+ module->AddEntryComputation(builder.Build(add));
+
+ ComputationLayout computation_layout(
+ computation->ComputeProgramShape());
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(lhs_shape_with_layout);
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(rhs_shape_with_layout);
+ *computation_layout.mutable_result_layout() =
+ ShapeLayout(result_shape_with_layout);
+
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+
+ for (const HloInstruction* operand : add->operands()) {
+ EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
+ operand->shape().layout()));
+ }
+ }
+ }
+ }
+}
+
+TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[5,2,3]{1,2,0} parameter(0)
+ p1 = f32[5,3,4]{1,2,0} parameter(1)
+ ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
+ lhs_batch_dims={0}, lhs_contracting_dims={2},
+ rhs_batch_dims={0}, rhs_contracting_dims={1}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {1, 2, 0}),
+ m::Op().WithShape(F32, {5, 3, 4}, {1, 2, 0}))
+ .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[5,3,2] parameter(0)
+ p1 = f32[5,4,3]{0,1,2} parameter(1)
+ ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
+ lhs_batch_dims={0}, lhs_contracting_dims={1},
+ rhs_batch_dims={0}, rhs_contracting_dims={2}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 3, 2}, {2, 1, 0}),
+ m::Op().WithShape(F32, {5, 4, 3}, {2, 1, 0}))
+ .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[2,3,5]{2,1,0} parameter(0)
+ p1 = f32[3,4,5] parameter(1)
+ ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
+ lhs_batch_dims={2}, lhs_contracting_dims={1},
+ rhs_batch_dims={2}, rhs_contracting_dims={0}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Op().WithShape(F32, {2, 3, 5}, {0, 1, 2}),
+ m::Op().WithShape(F32, {3, 4, 5}, {1, 0, 2}))));
+}
+
+TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[5,6,2,3] parameter(0)
+ p1 = f32[6,5,3,4] parameter(1)
+ ROOT dot.1330.10585 = f32[5,6,2,4] dot(p0, p1),
+ lhs_batch_dims={0,1}, lhs_contracting_dims={3},
+ rhs_batch_dims={1,0}, rhs_contracting_dims={2}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 6, 2, 3}, {3, 2, 1, 0}),
+ m::Op().WithShape(F32, {6, 5, 3, 4}, {3, 2, 0, 1}))));
+}
+
+TEST_F(LayoutAssignmentTest, TransposedDotLayout) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[5,2,3] parameter(0)
+ p1 = f32[5,3,4,6] parameter(1)
+ dot = f32[5,2,4,6] dot(p0, p1),
+ lhs_batch_dims={0}, lhs_contracting_dims={2},
+ rhs_batch_dims={0}, rhs_contracting_dims={1}
+ ROOT out = f32[2,5,4,6] transpose(dot), dimensions={1,0,2,3}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Transpose(
+ m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {2, 1, 0}),
+ m::Op().WithShape(F32, {5, 3, 4, 6}, {3, 2, 1, 0}))
+ .WithShape(F32, {5, 2, 4, 6}, {3, 2, 0, 1}))
+ .WithShape(F32, {2, 5, 4, 6}, {3, 2, 1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, TransposedDotOfDotLayout) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[8,50] parameter(0)
+ p1 = f32[2,8,4,4] parameter(1)
+ p2 = f32[4,38] parameter(2)
+ dot.1 = f32[50,2,4,4]{3,2,1,0} dot(p0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+ dot.2 = f32[50,2,4,38]{3,2,1,0} dot(dot.1, p2),
+ lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ ROOT out = f32[2,50,38,4]{2,3,0,1} transpose(dot.2), dimensions={1,0,3,2}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ // The transpose layout is not supported by dot.2. Also, we need a copy
+ // between dot.1 and dot.2, because the layout dot.2 needs for its lhs
+ // operand cannot be used as the output layout of dot.1.
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Transpose(
+ m::Dot(m::Copy(m::Dot(m::Op().WithShape(F32, {8, 50}, {1, 0}),
+ m::Op().WithShape(F32, {2, 8, 4, 4},
+ {3, 2, 0, 1}))
+ .WithShape(F32, {50, 2, 4, 4}, {3, 2, 1, 0}))
+ .WithShape(F32, {50, 2, 4, 4}, {3, 1, 0, 2}),
+ m::Op().WithShape(F32, {4, 38}, {1, 0}))
+ .WithShape(F32, {50, 2, 4, 38}, {3, 2, 1, 0}))
+ .WithShape(F32, {2, 50, 38, 4}, {2, 3, 0, 1})));
+}
+
+TEST_F(LayoutAssignmentTest, DotLayoutS8) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY int8_t {
+ p0 = s8[32,64] parameter(0)
+ p1 = s8[64,96] parameter(1)
+ ROOT out = s32[32,96] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Dot(m::Op().WithShape(S8, {32, 64}, {1, 0}),
+ m::Op().WithShape(S8, {64, 96}, {0, 1}))));
+}
+
+TEST_F(LayoutAssignmentTest, SortLayout) {
+ const char* hlo_text = R"(
+ HloModule SortLayout
+
+ compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ p.1.lhs = f32[] parameter(2)
+ p.1.rhs = f32[] parameter(3)
+ ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
+ }
+
+ ENTRY sort {
+ keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
+ values = f32[2,3]{1,0} parameter(0)
+ transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
+ ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
+ dimensions={1}, to_apply=compare
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Sort(m::Op().WithShape(F32, {3, 2}, {1, 0}),
+ m::Op().WithShape(F32, {3, 2}, {1, 0}))));
+}
+
+TEST_F(LayoutAssignmentTest, FftLayout) {
+ const char* hlo_text = R"(
+ HloModule Fft_module
+
+ ENTRY Fft {
+ input = c64[8,32]{0,1} parameter(0)
+ fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
+ ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Copy(
+ m::Transpose(m::Fft(m::Op().WithShape(C64, {8, 32}, {1, 0}))
+ .WithShape(C64, {8, 32}, {1, 0})))));
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallConstrainedAlias) {
+ const char* module_str = R"(
+HloModule TestModule
+
+ENTRY entry {
+ Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
+ Arg_1 = f32[2,5,5]{2,1,0} parameter(1)
+ Arg_2 = f32[2,5,5]{2,1,0} parameter(2)
+ dot.0 = f32[2,5,5]{2,1,0} dot(Arg_1, Arg_2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}
+ custom-call.0 = (f32[2,5,5]{1,2,0}, s8[16]{0}, s8[16]{0}) custom-call(Arg_0, dot.0), custom_call_target="dummy_call", operand_layout_constraints={f32[2,5,5]{1,2,0}, f32[2,5,5]{1,2,0}}, output_to_operand_aliasing={{0}: (1, {})}
+ ROOT get-tuple-element.0 = f32[2,5,5]{1,2,0} get-tuple-element(custom-call.0), index=0
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+
+ const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
+ auto expect_layout = [](const Shape& shape,
+ absl::Span<const int64_t> minor_to_major) {
+ const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+ EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
+ << "Expected layout " << expected << ", actual " << shape.layout();
+ };
+ expect_layout(ShapeUtil::GetSubshape(call_0->shape(), {0}), {1, 2, 0});
+ expect_layout(call_0->operand(0)->shape(), {1, 2, 0});
+ expect_layout(call_0->operand(1)->shape(), {1, 2, 0});
+}
+
+TEST_F(LayoutAssignmentTest, MoveToHostCustomCallConstrained) {
+ const char* module_str = R"(
+HloModule TestModule
+
+ENTRY entry {
+ Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
+ custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToHost"
+ ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+
+ const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
+ const Layout input_layout = call_0->operand(0)->shape().layout();
+ const Layout output_layout = call_0->shape().layout();
+ EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
+ << "Expected the same input/output layouts. Input: " << input_layout
+ << ". Output: " << output_layout;
+}
+
+TEST_F(LayoutAssignmentTest, MoveToDeviceCustomCallConstrained) {
+ const char* module_str = R"(
+HloModule TestModule
+
+ENTRY entry {
+ Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
+ custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToDevice"
+ ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+
+ const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
+ const Layout input_layout = call_0->operand(0)->shape().layout();
+ const Layout output_layout = call_0->shape().layout();
+ EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
+ << "Expected the same input/output layouts. Input: " << input_layout
+ << ". Output: " << output_layout;
+}
+
+TEST_F(LayoutAssignmentTest, ConvCuDNNF8) {
+ if (!GetCudaComputeCapability().IsAtLeast(
+ se::CudaComputeCapability::HOPPER)) {
+    GTEST_SKIP() << "FP8 convolutions require HOPPER or newer architecture.";
+ }
+
+ const char* hlo = R"(
+
+ HloModule jit_conv_general_dilated
+
+ ENTRY main.4 {
+ Arg_0 = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
+ Arg_1 = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
+ ROOT conv = f8e4m3fn[1,64,64,32]{3,2,1,0} convolution(Arg_0, Arg_1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
+ }
+)";
+
+ MatchOptimizedHlo(hlo, R"(
+ // CHECK: [[P0:%[^ ]+]] = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
+ // CHECK: [[P1:%[^ ]+]] = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
+ // CHECK-NEXT: [[P2:%[^ ]+]] = f8e4m3fn[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
+ // CHECK-NEXT: [[CONV:%[^ ]+]] = (f8e4m3fn[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+ )");
+}
+
+TEST_F(LayoutAssignmentTest, ConvCuDNNBF16) {
+ if (!GetCudaComputeCapability().IsAtLeast(
+ se::CudaComputeCapability::AMPERE)) {
+ GTEST_SKIP() << "Conv with Bfloat16 uses NHWC layout for "
+ "architectures with Tensor Cores.";
+ }
+
+ const char* hlo = R"(
+
+ HloModule jit_conv_general_dilated
+
+ ENTRY main.4 {
+ Arg_0.1 = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+ ROOT convolution.3 = bf16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 64, 64, 16) rhs_shape=(3, 3, 16, 32) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.8/dist-packages/flax/linen/linear.py" source_line=438}
+ }
+)";
+
+ MatchOptimizedHlo(hlo, R"(
+ // CHECK: [[P0:%[^ ]+]] = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+ // CHECK: [[P1:%[^ ]+]] = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+ // CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
+ // CHECK-NEXT: %cudnn-conv.1 = (bf16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
+ )");
+}
+
+TEST_F(LayoutAssignmentTest, ConvCuDNNFP16) {
+ if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
+ GTEST_SKIP() << "Conv with FP16 uses NHWC layout for "
+ "architectures with Tensor Cores.";
+ }
+
+ const char* hlo = R"(
+
+ HloModule jit_conv_general_dilated
+
+ ENTRY main.4 {
+ Arg_0.1 = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+ Arg_1.2 = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+ ROOT convolution.3 = f16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
+ }
+)";
+
+ MatchOptimizedHlo(hlo, R"(
+ // CHECK: [[P0:%[^ ]+]] = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+ // CHECK: [[P1:%[^ ]+]] = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+ // CHECK-NEXT: [[P2:%[^ ]+]] = f16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
+ // CHECK-NEXT: %cudnn-conv.1 = (f16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
+ )");
+}
+
+TEST_F(LayoutAssignmentTest, ReduceOperandLayout) {
+ const char* module_str = R"(
+scalar_add_computation {
+ scalar_lhs = c64[] parameter(0)
+ scalar_rhs = c64[] parameter(1)
+ ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+ param_0 = c64[512,64,1024,32,128]{4,3,2,1,0} parameter(0)
+ negate = c64[512,64,1024,32,128]{4,3,2,1,0} negate(param_0)
+ constant_7 = c64[] constant((0, 0))
+ ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1,3}, to_apply=scalar_add_computation
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+ auto reduce = m->entry_computation()->root_instruction();
+ EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
+ LayoutUtil::MakeLayout({3, 1, 4, 2, 0}).minor_to_major());
+}
+
+TEST_F(LayoutAssignmentTest, ReduceOperandLayoutDivisorOfWarpSize) {
+ // Same as ReduceOperandLayout, but with a small reduction dimension that
+ // is a divisor of the warp size.
+ const char* module_str = R"(
+scalar_add_computation {
+ scalar_lhs = c64[] parameter(0)
+ scalar_rhs = c64[] parameter(1)
+ ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+ param_0 = c64[512,16,1024,128]{3,2,1,0} parameter(0)
+ negate = c64[512,16,1024,128]{3,2,1,0} negate(param_0)
+ constant_7 = c64[] constant((0, 0))
+ ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1}, to_apply=scalar_add_computation
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+ EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+ auto reduce = m->entry_computation()->root_instruction();
+ EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
+ LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major());
+}
+
+TEST_F(LayoutAssignmentTest, SendRcvLayout) {
+ const char* hlo = R"(
+HloModule Module
+
+condition {
+ p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
+ ROOT lt = pred[] constant(1)
+}
+
+body {
+ p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
+
+ t1 = f32[100,100] get-tuple-element(p), index=0
+ t = (f32[100,100], u32[], token[]) get-tuple-element(p), index=1
+ sdone = token[] send-done(t), channel_id=3, frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ tk = token[] after-all()
+
+
+ rcvd = (f32[100,100]{0,1}, u32[], token[]) recv(tk), channel_id=2
+ zz = (f32[100,100]{0,1}, token[]) recv-done(rcvd), channel_id=2
+
+ rcvd_d = get-tuple-element(zz), index=0
+
+ snd = (f32[100,100]{0,1}, u32[], token[]) send(t1, tk), channel_id=3, frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ a = add(t1, t1)
+
+ b = add(rcvd_d, a)
+
+ ROOT tup = tuple(b, snd)
+}
+
+ENTRY %main {
+ p0 = f32[100,100] parameter(0)
+ tk = token[] after-all()
+ snd = (f32[100,100]{0,1}, u32[], token[]) send(p0, tk), channel_id=1, frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ t = tuple(p0, snd)
+ ROOT loop = while(t), condition=condition, body=body
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+ ParseAndReturnVerifiedModule(hlo));
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape());
+
+ RunAndFilecheckHloRewrite(
+ hlo,
+ GpuLayoutAssignment{&computation_layout, GetGpuComputeCapability(),
+ GetDnnVersion()},
+ R"(
+// CHECK: (f32[100,100]{1,0}, u32[], token[]) recv
+// CHECK: (f32[100,100]{1,0}, token[]) recv-done
+// CHECK: (f32[100,100]{1,0}, u32[], token[]) send
+ )");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc
new file mode 100644
index 0000000..ae66093
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc
@@ -0,0 +1,240 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/move_copy_to_users.h"
+
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor {
+  // Turn copy->pad into pad->copy.
+ absl::Status HandlePad(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
+ HloInstruction* c = hlo->mutable_operand(1);
+ if (operand->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied = operand->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * earlier_pad,
+ MakePadHlo(copied, c, hlo->padding_config(), &hlo->metadata()));
+ // MakePadHlo fails to propagate layout.
+ *earlier_pad->mutable_shape()->mutable_layout() =
+ copied->shape().layout();
+ HloInstruction* later_copy = MakeCopyHlo(earlier_pad, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ return absl::OkStatus();
+ }
+
+ // Turn copy->slice into slice->copy, as slice is layout-preserving.
+ absl::Status HandleSlice(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
+ if (operand->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied = operand->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * earlier_slice,
+ MakeSliceHlo(copied, hlo->slice_starts(), hlo->slice_limits(),
+ hlo->slice_strides(), &hlo->metadata()));
+ *earlier_slice->mutable_shape()->mutable_layout() =
+ copied->shape().layout();
+ HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ return absl::OkStatus();
+ }
+
+ // Turn copy->dynamic-slice into dynamic-slice->copy, as dynamic-slice is
+ // layout-preserving.
+ absl::Status HandleDynamicSlice(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
+ if (operand->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied = operand->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * earlier_slice,
+ MakeDynamicSliceHlo(
+ copied,
+ absl::Span<HloInstruction* const>(hlo->operands()).subspan(1),
+ hlo->dynamic_slice_sizes(), &hlo->metadata()));
+ *earlier_slice->mutable_shape()->mutable_layout() =
+ copied->shape().layout();
+ HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ return absl::OkStatus();
+ }
+
+ // Turn copy->reduce_window into reduce_window->copy, as reduce_window is
+ // layout-preserving.
+ absl::Status HandleReduceWindow(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
+ if (operand->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied = operand->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * earlier_reduce_window,
+ MakeReduceWindowHlo(copied, hlo->mutable_operand(1), hlo->window(),
+ hlo->called_computations()[0], &hlo->metadata()));
+ *earlier_reduce_window->mutable_shape()->mutable_layout() =
+ copied->shape().layout();
+ HloInstruction* later_copy =
+ MakeCopyHlo(earlier_reduce_window, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status HandleReduce(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
+ // Reductions can handle transposes, e.g. via column reduction.
+ if (operand->opcode() == HloOpcode::kCopy && !hlo->shape().IsTuple()) {
+ HloInstruction* new_reduce = hlo->AddInstruction(
+ hlo->CloneWithNewOperands(hlo->shape(), {operand->mutable_operand(0),
+ hlo->mutable_operand(1)}));
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_reduce));
+ }
+ return absl::OkStatus();
+ }
+
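+  // Deliberately a no-op: a copy is never sunk across bitcast-convert.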
+ absl::Status HandleBitcastConvert(HloInstruction* hlo) override {
+ return absl::OkStatus();
+ }
+
+ // Sink kCopy across elementwise unary.
+ absl::Status HandleElementwiseUnary(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
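+    // Leave reduce-precision alone: it carries extra attributes
+    // (exponent/mantissa bits) that the generic unary rewrite below
+    // does not reproduce.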
+ if (hlo->opcode() == HloOpcode::kReducePrecision) {
+ return absl::OkStatus();
+ }
+ if (operand->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied = operand->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * earlier_elementwise,
+ MakeUnaryHlo(hlo->opcode(), copied, &hlo->metadata()));
+ HloInstruction* later_copy =
+ MakeCopyHlo(earlier_elementwise, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ return absl::OkStatus();
+ }
+
+  // Sink kCopy across reverse.
+ absl::Status HandleReverse(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
+ if (operand->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied = operand->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * earlier_reverse,
+ MakeReverseHlo(copied, hlo->dimensions(), &hlo->metadata()));
+ HloInstruction* later_copy = MakeCopyHlo(earlier_reverse, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ return absl::OkStatus();
+ }
+
+ // Sink kCopy across convert.
+ absl::Status HandleConvert(HloInstruction* hlo) override {
+ HloInstruction* operand = hlo->mutable_operand(0);
+ if (operand->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied = operand->mutable_operand(0);
+ HloInstruction* earlier_convert = MakeConvertToHlo(
+ copied, hlo->shape().element_type(), &hlo->metadata());
+ HloInstruction* later_copy = MakeCopyHlo(earlier_convert, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ return absl::OkStatus();
+ }
+
+ // Sink kCopy across elementwise binary.
+ absl::Status HandleElementwiseBinary(HloInstruction* hlo) override {
+ HloInstruction* a = hlo->mutable_operand(0);
+ HloInstruction* b = hlo->mutable_operand(1);
+ if (a->opcode() == HloOpcode::kCopy && b->opcode() == HloOpcode::kCopy) {
+ HloInstruction* copied_a = a->mutable_operand(0);
+ HloInstruction* copied_b = b->mutable_operand(0);
+ if (copied_a->shape() == copied_b->shape()) {
+ HloInstruction* earlier_elementwise;
+ if (hlo->opcode() == HloOpcode::kCompare) {
+ TF_ASSIGN_OR_RETURN(
+ earlier_elementwise,
+ MakeCompareHlo(hlo->comparison_direction(), copied_a, copied_b,
+ &hlo->metadata()));
+ } else {
+ TF_ASSIGN_OR_RETURN(earlier_elementwise,
+ MakeBinaryHlo(hlo->opcode(), copied_a, copied_b,
+ &hlo->metadata()));
+ }
+ HloInstruction* later_copy =
+ MakeCopyHlo(earlier_elementwise, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+ }
+ }
+ return absl::OkStatus();
+ }
+
+ // Move copy across kConcat if it occurs on all operands.
+ absl::Status HandleConcatenate(HloInstruction* hlo) override {
+ const HloInstruction* first = hlo->operand(0);
+ if (first->opcode() != HloOpcode::kCopy) {
+ return absl::OkStatus();
+ }
+ const HloInstruction* inner_op = first->operand(0);
+ const Layout& inner_op_layout = inner_op->shape().layout();
+
+ std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(hlo->operand_count());
+ for (HloInstruction* op : hlo->mutable_operands()) {
+ if (op->opcode() != HloOpcode::kCopy ||
+ op->operand(0)->shape().layout() != inner_op_layout) {
+ VLOG(3) << "Mismatch between " << op->ToString()
+ << " and expected op layout " << inner_op_layout.ToString();
+ return absl::OkStatus();
+ }
+ new_operands.push_back(op->mutable_operand(0));
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_concat,
+ MakeConcatHlo(new_operands, hlo->concatenate_dimension()));
+ *new_concat->mutable_shape()->mutable_layout() = inner_op_layout;
+
+ HloInstruction* new_copy = MakeCopyHlo(new_concat, hlo->shape());
+ TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_copy));
+ return absl::OkStatus();
+ }
+};
+
+} // end namespace
+
+absl::StatusOr<bool> MoveCopyToUsers::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ return MoveCopyToUsersVisitor{}.RunOnModule(module, execution_threads);
+}
+
+} // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h
new file mode 100644
index 0000000..698db04
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h
@@ -0,0 +1,39 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Sink kCopy operations as far down the graph as possible.
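+//
+// As an illustration only (a hypothetical HLO snippet, not taken from this
+// change), a layout-changing copy feeding a unary op
+//
+//   copy = f32[16,8]{0,1} copy(param)    <- param has layout {1,0}
+//   root = f32[16,8]{0,1} sqrt(copy)
+//
+// is rewritten so that the op runs on the original layout and the copy moves
+// past it, towards the users:
+//
+//   sqrt = f32[16,8]{1,0} sqrt(param)
+//   root = f32[16,8]{0,1} copy(sqrt)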
+class MoveCopyToUsers : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "move_copy_to_users"; }
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
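+
+// A minimal usage sketch (illustrative only; the pipeline name and wiring are
+// assumptions, not part of this change):
+//
+//   HloPassPipeline pipeline("sink-copies");
+//   pipeline.AddPass<MoveCopyToUsers>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));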
+
+} // end namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc
new file mode 100644
index 0000000..85999db
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc
@@ -0,0 +1,274 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/move_copy_to_users.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/service/layout_assignment.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace {
+
+class MoveCopyToUsersTest : public HloTestBase {
+ public:
+ MoveCopyToUsersTest()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/true,
+ LayoutAssignment::InstructionCanChangeLayout) {}
+ void CheckMoveCopyToUsers(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, MoveCopyToUsers{}, expected);
+ }
+};
+
+TEST_F(MoveCopyToUsersTest, Pad) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = s8[1,17,9,9]{3,1,2,0} parameter(0)
+ copy = s8[1,17,9,9]{1,3,2,0} copy(input)
+ constant = s8[] constant(0)
+ ROOT pad = s8[1,32,9,9]{1,3,2,0} pad(copy, constant), padding=0_0x0_15x0_0x0_0
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[constant_0:%[^ ]+]] = s8[] constant(0)
+// CHECK: [[pad_1_1:%[^ ]+]] = s8[1,32,9,9]{3,1,2,0} pad([[input_2:%[^ ]+]], [[constant_0]]), padding=0_0x0_15x0_0x0_0
+// CHECK: ROOT [[copy_1_3:%[^ ]+]] = s8[1,32,9,9]{1,3,2,0} copy([[pad_1_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Unary) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ ROOT pad = f32[1,17,9,9]{1,3,2,0} sqrt(copy)
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} sqrt([[input_0]])
+// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Reverse) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ ROOT pad = f32[1,17,9,9]{1,3,2,0} reverse(copy), dimensions={1,2}
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} reverse([[input_0]]), dimensions={1,2}
+// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Convert) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ ROOT converted = f16[1,17,9,9]{1,3,2,0} convert(copy)
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[sqrt_1:%[^ ]+]] = f16[1,17,9,9]{3,2,1,0} convert([[input_0]])
+// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f16[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Slice) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ ROOT slice = f32[1,4,6,6]{1,3,2,0} slice(copy), slice={[0:1],[0:4],[0:6],[0:6]}
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[slice_0:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} slice([[input_1:%[^ ]+]]), slice={[0:1], [0:4], [0:6], [0:6]}
+// CHECK-NEXT: ROOT [[copy_1_2:%[^ ]+]] = f32[1,4,6,6]{1,3,2,0} copy([[slice_0]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, DynamicSlice) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ p0 = s32[] parameter(1)
+ p1 = s32[] parameter(2)
+ p2 = s32[] parameter(3)
+ p3 = s32[] parameter(4)
+ ROOT ds = f32[1,4,6,6]{1,3,2,0} dynamic-slice(copy, p0, p1, p2, p3), dynamic_slice_sizes={1,4,6,6}
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[ds:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} dynamic-slice({{.*}}), dynamic_slice_sizes={1,4,6,6}
+// CHECK-NEXT: ROOT {{.*}} = f32[1,4,6,6]{1,3,2,0} copy([[ds]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, ReduceWindow) {
+ const char* hlo = R"(
+HloModule R2Window
+
+mul {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+ENTRY R2Window {
+ operand = f32[256,384]{1,0} parameter(0)
+ c = f32[256,384]{0,1} copy(operand)
+ constant = f32[] constant(1)
+ ROOT reduce-window = f32[256,384]{0,1} reduce-window(c, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[reduce_window_1_0:%[^ ]+]] = f32[256,384]{1,0} reduce-window([[operand_1:%[^ ]+]], [[constant_2:%[^ ]+]]), window={size=2x3 pad=0_1x1_1}, to_apply=[[mul_3:%[^ ]+]]
+// CHECK-NEXT: ROOT [[copy_4:%[^ ]+]] = f32[256,384]{0,1} copy([[reduce_window_1_0]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Reduce) {
+ const char* hlo = R"(
+HloModule R2
+
+mul {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+ENTRY R2 {
+ operand = f32[256,384,10]{2,1,0} parameter(0)
+ c = f32[256,384,10]{0,1,2} copy(operand)
+ constant = f32[] constant(1)
+ ROOT reduce = f32[384,10]{0,1} reduce(c, constant), dimensions={0}, to_apply=mul
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[operand:%[^ ]+]] = f32[256,384,10]{2,1,0} parameter(0)
+// CHECK: ROOT [[reduce:%[^ ]+]] = f32[384,10]{0,1} reduce([[operand]], [[constant_2:%[^ ]+]]), dimensions={0}, to_apply=[[mul_3:%[^ ]+]]
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Binary) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+ input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
+ ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[input2_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(1)
+// CHECK: [[add_1_2:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} add([[input_0]], [[input2_1]])
+// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[add_1_2]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, BinaryDifferentLayoutNoChange) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,0,1} parameter(0)
+ input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
+ ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, std::nullopt);
+}
+
+TEST_F(MoveCopyToUsersTest, Concat) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+ input2 = f32[5,17,9,9]{3,2,1,0} parameter(1)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ copy2 = f32[5,17,9,9]{1,3,2,0} copy(input2)
+ ROOT add = f32[6,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[input2_1:%[^ ]+]] = f32[5,17,9,9]{3,2,1,0} parameter(1)
+// CHECK: [[concat:%[^ ]+]] = f32[6,17,9,9]{3,2,1,0} concatenate([[input_0]], [[input2_1]])
+// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[6,17,9,9]{1,3,2,0} copy([[concat]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, ConcatDifferentLayoutNoChange) {
+ const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+ input = f32[1,17,9,9]{3,2,0,1} parameter(0)
+ input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
+ copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+ copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
+ ROOT add = f32[2,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
+}
+)";
+
+ CheckMoveCopyToUsers(hlo, std::nullopt);
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc
new file mode 100644
index 0000000..35bfe8e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc
@@ -0,0 +1,521 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_dfs_reachability.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/model/gpu_performance_model_base.h"
+#include "xla/service/hlo_graph_dumper.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+bool IsProfitableOperand(HloInstruction* instr) {
+ // Effective scalars are not a profitable shared operand. Skip them.
+ return !ShapeUtil::IsEffectiveScalar(instr->shape());
+}
+
+// Finds and returns the unique `slice` op where `parent` is used in `instr`.
+// Returns `nullptr` if no such `slice` exists.
+const HloSliceInstruction* FindUniqueSlice(const HloInstruction* parent,
+ const HloInstruction* instr) {
+ if (const auto* slice = DynCast<HloSliceInstruction>(instr)) {
+ return slice;
+ } else if (const auto* fusion = DynCast<HloFusionInstruction>(instr)) {
+ const HloSliceInstruction* result = nullptr;
+ for (size_t i = 0; i < fusion->operand_count(); ++i) {
+ if (fusion->operand(i) == parent) {
+ // Parameter used more than once -> there's no unique slice.
+ if (result) return nullptr;
+
+ auto* called_param = fusion->fused_parameter(i);
+ if (called_param->user_count() != 1) return nullptr;
+
+ result = FindUniqueSlice(called_param, called_param->users()[0]);
+ if (!result) return nullptr;
+ }
+ }
+ return result;
+ } else {
+ return nullptr;
+ }
+}
+
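+// Returns a "do not fuse" decision if `instr1` and `instr2` each access
+// `parent` only through a single slice and those slices are provably
+// non-overlapping; returns no objection otherwise.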
+FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1,
+ const HloInstruction& instr2,
+ const HloInstruction* parent) {
+ if (parent->shape().IsTuple()) return {};
+ // Allow MOF if the parameter is small, even if there's no overlap. 1024 bytes
+ // were arbitrarily chosen as the threshold.
+ if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) return {};
+
+ const HloSliceInstruction* slice1 = FindUniqueSlice(parent, &instr1);
+ const HloSliceInstruction* slice2 = FindUniqueSlice(parent, &instr2);
+ if (!slice1 || !slice2) return {};
+
+ // TODO(jreiffers): Check strides as well.
+ auto& starts1 = slice1->slice_starts();
+ auto& starts2 = slice2->slice_starts();
+ auto& limits1 = slice1->slice_limits();
+ auto& limits2 = slice2->slice_limits();
+
+ for (int64_t dim = 0; dim < parent->shape().rank(); ++dim) {
+ bool overlap = starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim];
+ if (!overlap) {
+ return "slices are non-overlapping";
+ }
+ }
+ return {};
+}
+
+FusionDecision LegalToFuse(const HloInstruction& instr1,
+ const HloInstruction& instr2,
+ const se::DeviceDescription& device_info,
+ FusionInfoCache* fusion_info_cache) {
+ CHECK(instr1.opcode() == HloOpcode::kFusion);
+
+ // The emitter only supports in-place DUS for fusions with a single DUS at the
+ // root. Don't sibling fuse DUS for now.
+ // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
+ // share the input and output buffers and add support to the emitter.
+ if (instr1.fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice ||
+ (instr2.opcode() == HloOpcode::kFusion &&
+ instr2.fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice)) {
+ return "can't fuse multiple DUSs";
+ }
+
+ // Do this check last, as it may be expensive.
+ return FusionFitsInBudget(instr1, instr2, device_info,
+ /*is_consumer_producer_fusion=*/false,
+ fusion_info_cache);
+}
+
+// We prefer multi-output fusions over other fusions over unfused ops, because
+// we want to preserve fusion opportunities if possible.
+int FusionPriority(const HloInstruction* instr) {
+ if (instr->IsMultiOutputFusion()) {
+ return 2;
+ }
+ if (instr->opcode() == HloOpcode::kFusion) {
+ return 1;
+ }
+ return 0;
+}
+
+HloInstruction* SelectPreferredFusionCandidate(
+ const std::vector<HloInstruction*> candidates) {
+ if (candidates.empty()) {
+ return nullptr;
+ }
+ return *std::max_element(
+ candidates.begin(), candidates.end(),
+ [](const HloInstruction* a, const HloInstruction* b) {
+ return FusionPriority(a) < FusionPriority(b);
+ });
+}
+
+// Do not fuse a producer if the other operands of the fusion are
+// reachable from the producer, this would create a cycle.
+FusionDecision OperandReachableFromProducer(
+ const HloInstruction& producer, const HloInstruction& consumer,
+ const HloDfsReachability& reachability) {
+ for (const auto* operand : consumer.operands()) {
+ // If a get-tuple-element instruction is not in the reachability
+ // map, it has been created by fusion in this pass. Simply move
+ // on to its operand, which is in the reachability map.
+ if (!reachability.IsPresent(operand) &&
+ operand->opcode() == HloOpcode::kGetTupleElement) {
+ operand = operand->operand(0);
+ }
+ CHECK(reachability.IsPresent(operand) && reachability.IsPresent(&producer))
+ << "Reachability map is incomplete. This should never "
+ "happen.";
+ if (&producer != operand && reachability.IsReachable(&producer, operand)) {
+ return {
+ absl::StrCat(producer.name(), " would introduce a cycle when fused")};
+ }
+ }
+ return {};
+}
+
+FusionDecision ProducerCandidateIsFusible(
+ const HloInstruction& producer, const HloInstruction& consumer,
+ const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache,
+ const se::DeviceDescription& device_info,
+ GpuHloCostAnalysis* cost_analysis) {
+ if (!IsFusibleAsMultiOutputFusionRoot(consumer)) {
+ return "consumer not eligible as multi-output fusion root.";
+ }
+
+ RETURN_IF_NOT_FUSIBLE(
+ ShapesCompatibleForMultiOutputFusion(consumer, producer));
+
+ RETURN_IF_NOT_FUSIBLE(
+ OperandReachableFromProducer(producer, consumer, reachability));
+
+ RETURN_IF_NOT_FUSIBLE(FusionFitsInBudget(
+ producer, consumer, device_info,
+ /*is_consumer_producer_fusion=*/false, fusion_info_cache));
+
+ if (cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer)) {
+ return "will generate too large IR";
+ }
+
+ GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+ &producer, device_info, cost_analysis,
+ GpuPerformanceModelOptions::Default(),
+ /*fused_consumers=*/{&consumer},
+ /*multi_output=*/true);
+ if (t.time_fused > t.time_unfused) {
+ return "will execute slower if fused";
+ }
+
+ return {};
+}
+
+std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
+ const HloInstruction* producer, const HloDfsReachability& reachability,
+ FusionInfoCache* fusion_info_cache,
+ const se::DeviceDescription& device_info,
+ GpuHloCostAnalysis* cost_analysis) {
+ std::vector<HloInstruction*> fusion_candidates;
+ const HloComputation* computation = producer->parent();
+ const HloModule* module = computation->parent();
+ bool dump_fusion =
+ module->config().debug_options().xla_dump_fusion_visualization();
+
+ // If the producer is not a valid candidate for MOF, no need to check any of
+ // its users.
+ if (!IsProducerMultiOutputFusible(*producer)) {
+ return fusion_candidates;
+ }
+
+ // If there is only one user, and it is not a multi-output fusion node, this
+ // fusion possibility was already considered and rejected by the FusionMerger
+ // pass. No need to try again!
+ if (producer->user_count() == 1 &&
+ !producer->users()[0]->IsMultiOutputFusion()) {
+ return fusion_candidates;
+ }
+
+ for (HloInstruction* consumer : producer->users()) {
+ VLOG(3) << "Looking at producer " << producer->name()
+ << " and its consumer " << consumer->name();
+
+ if (auto decision = ProducerCandidateIsFusible(
+ *producer, *consumer, reachability, fusion_info_cache, device_info,
+ cost_analysis)) {
+ fusion_candidates.push_back(consumer);
+ } else if (dump_fusion) {
+ RegisterFusionState(
+ *computation,
+ absl::StrCat("Not considering fusion of producer |", producer->name(),
+ "| into consumer |", consumer->name(),
+ "| due to: ", decision.Explain()),
+ *consumer, producer);
+ }
+ }
+ return fusion_candidates;
+}
+
+bool IsSiblingFusionCandidate(const HloInstruction* instr) {
+ if (instr->users().empty() || !IsFusibleAsMultiOutputFusionRoot(*instr) ||
+ IsNestableVariadicReduction(*instr)) {
+ return false;
+ }
+  // Check that the users of a multi-output fusion are all get-tuple-element
+  // ops. If any is not, we bail out because the transformation assumes
+  // get-tuple-element users.
+ return (!instr->IsMultiOutputFusion() ||
+ absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
+ return user->opcode() == HloOpcode::kGetTupleElement;
+ }));
+}
+
+FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1,
+ const HloInstruction& sibling_consumer_2,
+ const HloInstruction& common_producer,
+ const HloDfsReachability& reachability,
+ FusionInfoCache* fusion_info_cache,
+ const se::DeviceDescription& device_info) {
+ if (reachability.IsConnected(&sibling_consumer_1, &sibling_consumer_2)) {
+ return {absl::StrCat(sibling_consumer_1.name(), " and ",
+ sibling_consumer_2.name(), " are connected")};
+ }
+
+ RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion(
+ sibling_consumer_1, sibling_consumer_2));
+
+  // Technically, this check is order-dependent (e.g. siblings A, B, C where
+  // {A, B} and {B, C} overlap, but {A, C} do not). If the priority order is
+  // [C, A, B], only {C, B} will be fused, and A will only be fused in the
+  // next iteration of the fusion pipeline, potentially requiring several
+  // iterations to converge. We assume this case to be very rare in practice.
+ RETURN_IF_NOT_FUSIBLE(ParameterSlicesAreNonOverlapping(
+ sibling_consumer_1, sibling_consumer_2, &common_producer));
+
+ // This check should be last, as it may be expensive.
+ RETURN_IF_NOT_FUSIBLE(LegalToFuse(sibling_consumer_1, sibling_consumer_2,
+ device_info, fusion_info_cache));
+ return {};
+}
+
+} // namespace
+
+void MultiOutputFusion::RecomputeReachability() {
+ reachability_ = HloDfsReachability::Build(computation_);
+}
+
+bool MultiOutputFusion::FuseSiblings(HloInstruction* parent,
+ FusionInfoCache* fusion_info_cache,
+ GpuHloCostAnalysis* cost_analysis) {
+ const HloComputation* computation = parent->parent();
+ const HloModule* module = computation->parent();
+ bool dump_fusion =
+ module->config().debug_options().xla_dump_fusion_visualization();
+
+ if (!IsProfitableOperand(parent)) {
+ VLOG(3) << "Operand " << parent->ToShortString() << " is not profitable";
+ return false;
+ }
+ bool changed = false;
+ std::vector<HloInstruction*> siblings;
+ // Only consider siblings that are fusion candidates.
+ absl::c_copy_if(parent->users(), std::back_inserter(siblings),
+ IsSiblingFusionCandidate);
+ // Sort the siblings such that multi-output fusion ops occur first, followed
+ // by fusion ops, followed by unfused ops.
+ absl::c_stable_sort(siblings,
+ [](const HloInstruction* a, const HloInstruction* b) {
+ return FusionPriority(a) > FusionPriority(b);
+ });
+
+ for (auto i = siblings.begin(); i != siblings.end(); ++i) {
+ VLOG(3) << "Considering " << (*i)->name();
+ if ((*i)->opcode() != HloOpcode::kFusion) {
+ continue;
+ }
+ for (auto j = i + 1; j != siblings.end();) {
+ VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
+
+ if (auto fusible = CanFuseSiblings(**i, **j, *parent, *reachability_,
+ fusion_info_cache, device_info_);
+ !fusible) {
+ // We pick `j` arbitrarily as a consumer.
+ if (dump_fusion) {
+ RegisterFusionState(
+ *computation,
+ absl::StrCat("Not fusing siblings |", (**i).name(), "| and |",
+ (**j).name(), "| due to: ", fusible.Explain()),
+ // Randomly pick one consumer.
+ /*consumer=*/**i,
+ /*producer=*/parent);
+ }
+ ++j;
+ continue;
+ }
+ if (!ConsumeFuel(name(), [&] {
+ return absl::StrFormat("Not fusing siblings %s and %s.",
+ (*i)->name(), (*j)->name());
+ })) {
+ ++j;
+ continue;
+ }
+ VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
+ fusion_info_cache->Invalidate(*i);
+ fusion_info_cache->Invalidate(*j);
+ HloInstruction* remaining = *i;
+ HloInstruction* fused = *j;
+ TF_CHECK_OK(cost_analysis->RemoveInstruction(remaining));
+ TF_CHECK_OK(cost_analysis->RemoveInstruction(fused));
+
+ DumpFusionState(*remaining,
+ absl::StrCat("About to fuse sibling |", fused->name(),
+ "| into sibling |", remaining->name(),
+ "| inside multi-output fusion"),
+ /*producer=*/fused);
+
+ if (fused->opcode() == HloOpcode::kFusion) {
+ remaining->MergeFusionInstructionIntoMultiOutput(fused);
+ if (fused->IsInputFusion()) {
+ remaining->set_fusion_kind(HloInstruction::FusionKind::kInput);
+ }
+ } else {
+ remaining->FuseInstructionIntoMultiOutput(fused);
+ CHECK_EQ(0, fused->user_count());
+ TF_CHECK_OK(computation_->RemoveInstruction(fused));
+ }
+ DumpFusionState(*remaining,
+ absl::StrCat("Fused into |", remaining->name(),
+ "| inside multi-output fusion"));
+ TF_CHECK_OK(cost_analysis->RevisitInstruction(remaining));
+ changed = true;
+ siblings.erase(j);
+ RecomputeReachability();
+ }
+ }
+ return changed;
+}
+
+absl::StatusOr<bool> MultiOutputFusion::DoMultiOutputFusion() {
+ bool changed = false;
+ RecomputeReachability();
+ GpuHloCostAnalysis cost_analysis({shape_size_function_,
+ /*per_second_rates=*/{},
+ /*count_multiple_input_accesses=*/true},
+ device_info_);
+ TF_RETURN_IF_ERROR(computation_->Accept(&cost_analysis));
+ std::vector<HloInstruction*> defs_before_uses =
+ computation_->MakeInstructionPostOrder();
+
+ FusionInfoCache fusion_info_cache;
+ // Traverse the HLO in uses-before-defs order.
+ for (auto it = defs_before_uses.rbegin(); it != defs_before_uses.rend();
+ ++it) {
+ auto* producer = *it;
+ // Never multi-output fuse constants. To the extent that we want to fuse
+ // constants, that should be handled by the regular fusion pass.
+ if (producer->opcode() == HloOpcode::kConstant) {
+ VLOG(3) << producer->name() << " is a constant.";
+ continue;
+ }
+ if (producer->IsCustomFusion()) {
+ continue;
+ }
+ // First, fuse the consumer ops of the current op, which are siblings.
+ if (FuseSiblings(/*parent=*/producer, &fusion_info_cache, &cost_analysis)) {
+ changed = true;
+ }
+ // Second, perform producer-consumer multi-output fusion. This order will
+ // ensure that all get-tuple-element ops inserted as a by-product of
+ // multi-output fusion will occur before the current op in the order of
+    // traversal, and hence, not get in the way of subsequent fusion attempts.
+ const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
+ producer, *reachability_, &fusion_info_cache, device_info_,
+ &cost_analysis);
+ auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
+ if (consumer_for_fusion == nullptr) {
+ continue;
+ }
+ if (!ConsumeFuel(name(), [&] {
+ return absl::StrFormat("Not fusing %s and %s.", producer->name(),
+ consumer_for_fusion->name());
+ })) {
+ continue;
+ }
+ changed = true;
+ fusion_info_cache.Invalidate(producer);
+ fusion_info_cache.Invalidate(consumer_for_fusion);
+ TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(producer));
+ TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion));
+
+ HloInstruction* input_fusion;
+ if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
+ input_fusion = consumer_for_fusion;
+ VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
+ << consumer_for_fusion->name();
+ } else {
+ input_fusion = computation_->AddInstruction(HloInstruction::CreateFusion(
+ consumer_for_fusion->shape(),
+ ChooseFusionKind(*producer, *consumer_for_fusion),
+ consumer_for_fusion));
+ VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
+ << consumer_for_fusion->name() << " into "
+ << input_fusion->name();
+ TF_CHECK_OK(
+ computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
+ }
+
+ DumpFusionState(*input_fusion,
+ absl::StrCat("About to fuse producer |", producer->name(),
+ "| into consumer |", input_fusion->name(),
+ "| inside multi-output fusion"),
+ /*producer=*/producer);
+
+ if (producer->opcode() == HloOpcode::kFusion) {
+ input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
+ } else {
+ input_fusion->FuseInstructionIntoMultiOutput(producer);
+ CHECK_EQ(0, producer->user_count());
+ TF_CHECK_OK(computation_->RemoveInstruction(producer));
+ }
+ TF_RETURN_IF_ERROR(cost_analysis.RevisitInstruction(input_fusion));
+
+ DumpFusionState(*input_fusion,
+ absl::StrCat("Fused into |", input_fusion->name(),
+ "| inside multi-output fusion"));
+ RecomputeReachability();
+ }
+ return changed;
+}
+
+void MultiOutputFusion::DumpFusionState(const HloInstruction& consumer,
+ absl::string_view label,
+ const HloInstruction* producer) {
+ if (consumer.GetModule()
+ ->config()
+ .debug_options()
+ .xla_dump_fusion_visualization()) {
+ RegisterFusionState(*computation_, label, consumer, producer);
+ }
+}
+
+absl::StatusOr<bool> MultiOutputFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (auto* computation : GetFusibleComputations(*module, execution_threads)) {
+ computation_ = computation;
+ TF_ASSIGN_OR_RETURN(bool computation_changed, DoMultiOutputFusion());
+ changed |= computation_changed;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h
new file mode 100644
index 0000000..9ebabe6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h
@@ -0,0 +1,134 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_
+
+#include <memory>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_dfs_reachability.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Multi-output fusion of sibling and producer-consumer instructions for the
+// GPU backend to reduce memory bandwidth requirements.
+//
+// 0) Before multi- 1) Sibling multi- 2) Producer-consumer
+// output fusion output fusion multi-output fusion
+//
+// p p p
+// | | |
+// v v v
+// A A +-fusion--+
+// / \ | | A |
+// | | +-fusion--+ | / \ |
+// v v | / \ | | B | |
+// B C | B C | | | | |
+// \ / | | | | | v v |
+// v v | v v | | tuple |
+// ROOT | tuple | +---------+
+// +---------+ / \
+// / \ gte_b gte_a
+// gte_b gte_c | |
+// | | | v
+// \ / | C
+// v v \ /
+// ROOT v v
+// ROOT
+//
+// Multi-output fusion ops have a tuple op at their root containing multiple
+// elements as outputs. GetTupleElement ops (depicted as gte_* above) are
+// inserted to extract tuple elements for consumers.
+//
+// The two different flavors of multi-output fusion this pass performs are
+// depicted above.
+// 1) Fusion of sibling ops reduces memory bandwidth requirements, because
+// common input parameters have to be read only once.
+// 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by
+// saving one read from memory. In the example above, B does not need to read
+// the output of A from memory, while C still does (using gte_a).
+// Note that sibling (1) and producer-consumer (2) multi-output fusion can be
+// combined.
+//
+// The MultiOutputFusion pass modifies the HLO in reverse post-order (uses
+// before defs). First, it attempts to fuse the consumer ops of the current op,
+// which are siblings (1). After that, it attempts to fuse the current op with
+// one of its consumers (2). This order avoids a phase ordering issue (described
+// in go/fusionfusion). It ensures that all GetTupleElement ops inserted as a
+// by-product of multi-output fusion will occur before the current op in the
+// order of traversal, and hence, not get in the way of subsequent fusion
+// attempts.
+//
+// The MultiOutputFusion pass ensures several conditions are met for fusion.
+// Some of them are relevant for correctness. In particular, no cycles must be
+// introduced into the HLO module. Moreover, the code emitters for multi-output
+// fusion must support the combination of ops and their shapes. Other
+// restrictions are rather arbitrary and lifting them could be beneficial.
+// * Sibling fusion (1) requires at least one op to be a kFusion.
+// * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e.
+// the fusion kinds must match.
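+//
+// As a purely illustrative (hypothetical) HLO sketch of sibling fusion (1),
+// two fusions reading the same parameter
+//
+//   f1 = f32[128]{0} fusion(p0), kind=kLoop, calls=...
+//   f2 = f32[128]{0} fusion(p0), kind=kLoop, calls=...
+//
+// become a single multi-output fusion with a tuple at the root, and
+// get-tuple-element ops extract the individual results for the original users:
+//
+//   mof  = (f32[128]{0}, f32[128]{0}) fusion(p0), kind=kLoop, calls=...
+//   gte0 = f32[128]{0} get-tuple-element(mof), index=0
+//   gte1 = f32[128]{0} get-tuple-element(mof), index=1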
+
+class MultiOutputFusion : public HloModulePass {
+ public:
+ explicit MultiOutputFusion(
+ const se::DeviceDescription& device_info,
+ HloCostAnalysis::ShapeSizeFunction shape_size_function)
+ : device_info_(device_info), shape_size_function_(shape_size_function) {}
+
+ absl::string_view name() const override { return "multi_output_fusion"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ bool FuseSiblings(HloInstruction* parent, FusionInfoCache* fusion_info_cache,
+ GpuHloCostAnalysis* cost_analysis);
+
+ absl::StatusOr<bool> DoMultiOutputFusion();
+
+ // Recompute reachability for the current computation.
+ void RecomputeReachability();
+
+ void DumpFusionState(const HloInstruction& consumer, absl::string_view label,
+ const HloInstruction* producer = nullptr);
+
+ // Computation for the pass.
+ HloComputation* computation_;
+
+ // The reachability map of current computation.
+ std::unique_ptr<HloDfsReachability> reachability_;
+
+ se::DeviceDescription device_info_;
+ HloCostAnalysis::ShapeSizeFunction shape_size_function_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc
new file mode 100644
index 0000000..1fd4263
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc
@@ -0,0 +1,2236 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+namespace m = ::xla::match;
+
+class MultiOutputFusionTest : public HloTestBase {
+ HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
+ return [&](const Shape& shape) {
+ constexpr int64_t kPointerSize = 8;
+ return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+ };
+ }
+
+ public:
+ MultiOutputFusion mof_{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+ ShapeSizeBytesFunction()};
+
+ void CheckMultiOutputFusion(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(
+ hlo,
+ MultiOutputFusion{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+ ShapeSizeBytesFunction()},
+ expected);
+ }
+};
+
+const char kModulePrefix[] = R"(
+ HloModule test_module
+
+ scalar_add_computation {
+ scalar_lhs.0 = f32[] parameter(0)
+ scalar_rhs.0 = f32[] parameter(1)
+ ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
+ }
+ scalar_mul_computation {
+ scalar_lhs.1 = f32[] parameter(0)
+ scalar_rhs.1 = f32[] parameter(1)
+ ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
+ })";
+
+static int64_t CountMultiOutputFusions(const HloModule* module) {
+ int multi_output_fusion_count = 0;
+ for (auto* computation : module->MakeNonfusionComputations()) {
+ for (auto* instr : computation->instructions()) {
+ if (instr->IsMultiOutputFusion()) {
+ multi_output_fusion_count++;
+ }
+ }
+ }
+ return multi_output_fusion_count;
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
+ // Fusion with reduce instruction root and a sibling reduce instruction
+ // sharing the same input param.
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ const.2 = f32[] constant(1)
+ fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
+ reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[6400]{0} parameter(1)
+ mul = f32[6400]{0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[6400]{0} parameter(1)
+ r1 = f32[64,100]{0,1} reshape(p1.2)
+ const.2 = f32[] parameter(0)
+ ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[6400]{0} parameter(1)
+ fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+ fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, ReduceMofDifferentTypes) {
+ // Fusion with a reduce instruction root and a sibling reduce instruction
+ // sharing the same input param, but producing different element types
+ // (f16 vs. f32).
+ const char* hlo = R"(
+HloModule module
+
+scalar_add_computation {
+ scalar_lhs.1 = f32[] parameter(0)
+ scalar_rhs.1 = f32[] parameter(1)
+ ROOT add.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
+}
+
+scalar_add_computation_f16 {
+ scalar_lhs.0 = f16[] parameter(0)
+ scalar_rhs.0 = f16[] parameter(1)
+ ROOT add.0 = f16[] add(scalar_lhs.0, scalar_rhs.0)
+}
+
+fused_computation {
+ param_0.2 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ c.1 = f16[128,512,28,28]{3,2,1,0} convert(param_0.2)
+ const.0 = f16[] constant(0)
+ ROOT reduce.0 = f16[512]{0} reduce(c.1, const.0), dimensions={0,2,3}, to_apply=scalar_add_computation_f16
+}
+
+ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ const.2 = f32[] constant(0)
+ reduce.1 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ fusion = f16[512]{0} fusion(p1), kind=kInput, calls=fused_computation
+ ROOT root = (f32[512]{0}, f16[512]{0}) tuple(reduce.1, fusion)
+})";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation
+// CHECK-NEXT: [[param_0_2_0:%[^ ]+]] = f32[128,512,28,28]{3,2,1,0} parameter(0)
+// CHECK-NEXT: [[c_1_1:%[^ ]+]] = f16[128,512,28,28]{3,2,1,0} convert([[param_0_2_0]])
+// CHECK-NEXT: [[const_0_2:%[^ ]+]] = f16[] constant(0)
+// CHECK-NEXT: [[reduce_0_3:%[^ ]+]] = f16[512]{0} reduce([[c_1_1]], [[const_0_2]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_f16_4:%[^ ]+]]
+// CHECK-NEXT: [[param_1_5:%[^ ]+]] = f32[] parameter(1)
+// CHECK-NEXT: [[reduce_2_6:%[^ ]+]] = f32[512]{0} reduce([[param_0_2_0]], [[param_1_5]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_7:%[^ ]+]]
+// CHECK-NEXT: ROOT [[tuple_8:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) tuple([[reduce_0_3]], [[reduce_2_6]])
+// CHECK: [[fusion_9:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) fusion([[p1_10:%[^ ]+]], [[const_2_11:%[^ ]+]]), kind=kInput, calls=[[fused_computation_12:%[^ ]+]]
+)");
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[10,10]{1,0} parameter(1)
+ mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[10,10]{1,0} parameter(1)
+ const.2 = f32[] parameter(0)
+ ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1.3 = f32[10,10]{1,0} parameter(1)
+ fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1
+ p2 = f32[] parameter(2)
+ fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
+ // Two sibling fusions with reduce instruction roots sharing the same input
+ // param.
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ const.2 = f32[] parameter(0)
+ ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+ fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionNoSiblingFusionForCommonScalar) {
+ // Two sibling fusions with bitcast roots sharing the same scalar input param.
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ param_0.87 = bf16[32,4096,16384]{2,1,0} parameter(0)
+ param_1.4620 = s32[] parameter(1)
+ constant_3949 = s32[] constant(0)
+ compare.1026 = pred[] compare(param_1.4620, constant_3949), direction=LT
+ constant_5437 = s32[] constant(32)
+ add.6859 = s32[] add(param_1.4620, constant_5437)
+ select.1599 = s32[] select(compare.1026, add.6859, param_1.4620)
+ dynamic-slice.59 = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0.87, select.1599, constant_3949, constant_3949), dynamic_slice_sizes={1,4096,16384}
+ ROOT bitcast.41089 = bf16[4096,16384]{1,0} bitcast(dynamic-slice.59)
+ }
+
+ fused_computation_2 {
+ param_0 = bf16[32,4096,16384]{2,1,0} parameter(0)
+ param_1 = s32[] parameter(1)
+ constant = s32[] constant(0)
+ compare = pred[] compare(param_1, constant), direction=LT
+ constant.32 = s32[] constant(32)
+ add = s32[] add(param_1, constant.32)
+ select = s32[] select(compare, add, param_1)
+ dynamic-slice = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0, select, constant, constant), dynamic_slice_sizes={1,4096,16384}
+ ROOT bitcast.41087 = bf16[4096,16384]{1,0} bitcast(dynamic-slice)
+ }
+
+ ENTRY entry {
+ p0 = s32[] parameter(0)
+ p1 = bf16[32,4096,16384]{2,1,0} parameter(1)
+ p2 = bf16[32,4096,16384]{2,1,0} parameter(2)
+ fusion.1 = bf16[4096,16384]{1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = bf16[4096,16384]{1,0} fusion(p2, p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (bf16[4096,16384]{1,0}, bf16[4096,16384]{1,0}) tuple(fusion.1, fusion.2)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest,
+ MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
+ // Multi-output fusion whose root is a tuple of two reduce instructions, and
+ // a sibling reduce instruction sharing the same input param.
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
+ const.1 = f32[] constant(1)
+ p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
+ reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
+ }
+
+ ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ const = f32[] constant(1)
+ fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
+ get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
+ get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
+ reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest,
+ MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
+ // Verify that if we already have a multi-output fusion, we prefer to pick a
+ // reduce op from its operands for checking shape compatibility.
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[10,10]{1,0} parameter(1)
+ mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1)
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[10,10]{1,0} parameter(1)
+ const.2 = f32[] parameter(0)
+ ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[10,10]{1,0} parameter(1)
+ p2 = f32[] parameter(2)
+ fusion.1 = (f32[10,10], f32[]) fusion(p0, p1), kind=kInput, calls=fused_computation_1
+ get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[]) fusion.1), index=0
+ get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[]) fusion.1), index=1
+ fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, LoopVariadicReductionFusions) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation.94 {
+ tmp_0 = f32[] parameter(0)
+ tmp_1 = f32[] parameter(1)
+ tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
+ tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
+ tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
+ tmp_5 = s32[] parameter(2)
+ tmp_6 = s32[] parameter(3)
+ tmp_7 = s32[] minimum(tmp_5, tmp_6)
+ tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
+ tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
+ ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
+ }
+
+ minmax_func.1536 {
+ tmp_0 = f32[] parameter(0)
+ tmp_1 = f32[] parameter(2)
+ tmp_2 = s32[] parameter(1)
+ tmp_3 = s32[] parameter(3)
+ ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
+ }
+
+ fused_computation {
+ tmp_0 = f32[554112,10]{1,0} parameter(0)
+ tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+ tmp_2 = f32[] constant(-inf)
+ tmp_3 = s32[] constant(0)
+ ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+ }
+
+ fused_computation2 {
+ tmp_0 = f32[554112,10]{1,0} parameter(0)
+ tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+ tmp_2 = f32[] constant(inf)
+ tmp_3 = s32[] constant(1)
+ ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+ }
+
+ ENTRY e {
+ tmp_0 = f32[554112,10]{1,0} parameter(0)
+ tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
+ tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
+ tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation2
+ tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
+ ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
+ })"))
+ .value();
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, InputVariadicReductionFusions) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation.1117 {
+ param_0.2433 = f32[] parameter(0)
+ param_1.2571 = f32[] parameter(1)
+ compare.1770 = pred[] compare(param_0.2433, param_1.2571), direction=LE
+ select.682 = f32[] select(compare.1770, param_0.2433, param_1.2571)
+ compare.1303.clone.1 = pred[] compare(param_0.2433, param_1.2571), direction=EQ
+ param_2.6460 = s32[] parameter(2)
+ param_3.6755 = s32[] parameter(3)
+ minimum.633.clone.1 = s32[] minimum(param_2.6460, param_3.6755)
+ select.398.clone.1 = s32[] select(compare.1770, param_2.6460, param_3.6755)
+ select.397.clone.1 = s32[] select(compare.1303.clone.1, minimum.633.clone.1, select.398.clone.1)
+ ROOT tuple.151 = (f32[], s32[]) tuple(select.682, select.397.clone.1)
+ }
+
+ minmax_func.223 {
+ lhs_value.224 = f32[] parameter(0)
+ rhs_value.226 = f32[] parameter(2)
+ lhs_index.225 = s32[] parameter(1)
+ rhs_index.227 = s32[] parameter(3)
+ ROOT fusion.1117 = (f32[], s32[]) fusion(lhs_value.224, rhs_value.226, lhs_index.225, rhs_index.227), kind=kLoop, calls=fused_computation.1117
+ }
+
+ fused_computation.73 {
+ bitcast.86661 = f32[3,1024,300]{2,1,0} parameter(0)
+ iota.734 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
+ bitcast.97555 = s32[3,1024,300]{2,1,0} bitcast(iota.734)
+ constant_3917 = f32[] constant(inf)
+ constant_3918 = s32[] constant(0)
+ ROOT reduce.1069 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86661, bitcast.97555, constant_3917, constant_3918), dimensions={2}, to_apply=minmax_func.223
+ }
+
+ fused_computation.84 {
+ bitcast.86676 = f32[3,1024,300]{2,1,0} parameter(0)
+ iota.732 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
+ bitcast.97553 = s32[3,1024,300]{2,1,0} bitcast(iota.732)
+ constant_3915 = f32[] constant(inf)
+ constant_3916 = s32[] constant(0)
+ ROOT reduce.1070 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86676, bitcast.97553, constant_3915, constant_3916), dimensions={2}, to_apply=minmax_func.223
+ }
+
+ ENTRY e {
+ p0 = f32[3,1024,300]{2,1,0} parameter(0)
+ fusion.84 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.84
+ gte.391 = s32[3,1024]{1,0} get-tuple-element(fusion.84), index=1
+ fusion.73 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.73
+ gte.393 = s32[3,1024]{1,0} get-tuple-element(fusion.73), index=1
+ ROOT r = s32[3,1024]{1,0} add(gte.391, gte.393)
+ })"))
+ .value();
+ EXPECT_TRUE(mof_.Run(module.get()).value());
+ EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->user_count(),
+ 1);
+ const HloInstruction* fusion =
+ module->entry_computation()->parameter_instruction(0)->users()[0];
+ EXPECT_THAT(fusion, GmockMatch(m::Fusion()));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[6400]{0} parameter(0)
+ const.2 = f32[] constant(1)
+ broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
+ ROOT div = f32[6400]{0} divide(p0.2, broadcast)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ const.2 = f32[] constant(1)
+ broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
+ div = f32[6400]{0} divide(p0, broadcast)
+ ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+ ROOT mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), dimensions={0,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0}) tuple(fusion.1, fusion.2)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
+ f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2),
+ dimensions={}
+ ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
+ f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
+ calls=fused_computation_1
+ fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop,
+ calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
+ f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0})
+ tuple(gte0, gte1, fusion.2)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add())));
+}
+
+TEST_F(MultiOutputFusionTest,
+ MultiOutputFusionSiblingMultiOutputLoopAndMultiOutputLoop) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,16]{1,0} parameter(0)
+ mul = f32[8,16]{1,0} multiply(p0.1, p0.1)
+ exp = f32[8,16]{1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,16]{1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ broadcast = f32[8,16]{1,0} broadcast(const.2),
+ dimensions={}
+ add = f32[8,16]{1,0} add(p0.2, broadcast)
+ ROOT tuple.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(add, broadcast)
+ }
+
+ ENTRY entry {
+ p0 = f32[8,16]{1,0} parameter(0)
+ fusion.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
+ calls=fused_computation_1
+ fusion.2 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
+ calls=fused_computation_2
+ gte0 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=1
+ gte2 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=0
+ gte3 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=1
+ ROOT root = (f32[8,16]{1,0}, f32[8,16]{1,0}, f32[8,16]{1,0},
+ f32[8,16]{1,0})
+ tuple(gte0, gte1, gte2, gte3)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(
+ fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add(), m::Broadcast())));
+}
+
+TEST_F(MultiOutputFusionTest,
+ MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,2]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
+ f32[8,1,5,16,1,2]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2),
+ dimensions={0,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
+ f32[8,1,5,16,1,2]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
+ calls=fused_computation_1
+ fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop,
+ calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
+ f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0})
+ tuple(gte0, gte1, fusion.2)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, SiblingFusionBitcastAndLoopFusionNotFused) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+
+fused_computation_1 {
+ p0.1 = f32[2048,16000]{1,0} parameter(0)
+ bitcast = f32[2048,1,16000]{2,1,0} bitcast(p0.1)
+ ROOT exp = f32[2048,1,16000]{2,1,0} exponential(bitcast)
+}
+
+ENTRY main {
+ param_0 = f32[2048,16000]{1,0} parameter(0)
+ fusion = f32[2048,1,16000]{2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation_1
+ bitcast = f32[16000,1,2048]{2,1,0} bitcast(param_0)
+ ROOT tuple.143 = (f32[16000,1,2048]{2,1,0}, f32[2048,1,16000]{2,1,0}) tuple(bitcast, fusion)
+})")
+ .value();
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionBitcastAndElementwiseNotFused) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+
+ENTRY main {
+ param_0 = f32[2048,16000]{1,0} parameter(0)
+ convert = bf16[2048,16000]{1,0} convert(param_0)
+ bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
+ ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
+})")
+ .value();
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ ENTRY reduce {
+ p0 = f32[32,32,32]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ exp = f32[32,32,32]{2,1,0} exponential(p0)
+ reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
+ to_apply=scalar_add_computation
+ ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement())));
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Exp())));
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_add {
+ p0.1 = f32[32,32,32]{2,1,0} parameter(0)
+ p1.1 = f32[32,32,32]{2,1,0} parameter(1)
+ ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
+ }
+
+ ENTRY reduce {
+ p0 = f32[32,32,32]{2,1,0} parameter(0)
+ p1 = f32[32,32,32]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
+ reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+ to_apply=scalar_add_computation
+ ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement())));
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Add())));
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f32[32,32,32]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
+ greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
+ f32[32,32,32]{2,1,0} broadcast), direction=GT
+ p0.1 = f32[32,32,32]{2,1,0} parameter(0)
+ ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
+ greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
+ }
+
+ fused_reduce {
+ p0.2 = f32[32,32,32]{2,1,0} parameter(0)
+ c1 = f32[] constant(0)
+ r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
+ to_apply=scalar_add_computation
+ mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
+ r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
+ to_apply=scalar_add_computation
+ ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p0 = f32[32,32,32]{2,1,0} parameter(0)
+ p1 = f32[32,32,32]{2,1,0} parameter(1)
+ select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
+ calls=fused_reduce
+ gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
+ tuple(gte1, gte1, select)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root,
+ GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement(), m::GetTupleElement())));
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_element_wise {
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
+ }
+
+ fused_reduce {
+ p0.2 = f32[2,2,2]{2,1,0} parameter(0)
+ mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
+ f32[2,2,2]{2,1,0} p0.2)
+ broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
+ c1 = f32[] constant(0)
+ ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
+ f32[] c1), dimensions={1,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
+ fusion = f32[2,2]{1,0} fusion(element_wise), kind=kLoop, calls=fused_reduce
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f16[32,32,32]{2,1,0} parameter(1)
+ c0 = f16[] constant(0)
+ broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
+ greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
+ f16[32,32,32]{2,1,0} broadcast), direction=GT
+ p0.1 = f16[32,32,32]{2,1,0} parameter(0)
+ ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
+ greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[32,32,32]{2,1,0} parameter(0)
+ convert = f32[32,32,32]{2,1,0} convert(p0.2)
+ c1 = f32[] constant(0)
+ r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
+ to_apply=scalar_add_computation
+ mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
+ r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
+ to_apply=scalar_add_computation
+ ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
+ }
+ ENTRY reduce {
+ p0 = f16[32,32,32]{2,1,0} parameter(0)
+ p1 = f16[32,32,32]{2,1,0} parameter(1)
+ select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
+ calls=fused_reduce
+ gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
+ tuple(gte1, gte1, select)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root,
+ GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+ m::GetTupleElement(), m::GetTupleElement())));
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
+}
+
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ mixed_input_layouts_computation {
+ p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+ ENTRY reduce {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+ reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+ })"))
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionAvoidsCycles) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_add {
+ p0 = f32[32,32,32]{2,1,0} parameter(0)
+ p1 = f32[32,32,32]{2,1,0} parameter(1)
+ ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
+ }
+
+ fused_mul {
+ p2 = f32[64,64,64]{2,1,0} parameter(0)
+ p3 = f32[64,64,64]{2,1,0} parameter(1)
+ ROOT multiply = f32[64,64,64]{2,1,0} multiply(p2, p3)
+ }
+
+ fused_reduce_1 {
+ p4 = f32[32,32,32]{2,1,0} parameter(0)
+ p5 = f32[64,64,64]{2,1,0} parameter(1)
+ slice = f32[32,32,32]{2,1,0} slice(p5), slice={[0:32], [0:32], [0:32]}
+ add = f32[32,32,32]{2,1,0} add(p4, slice)
+ c0 = f32[] constant(0)
+ ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+ to_apply=scalar_add_computation
+ }
+
+ fused_reduce_2 {
+ p6 = f32[32,32,32]{2,1,0} parameter(0)
+ p7 = f32[64,64,64]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ pad = f32[64,64,64]{2,1,0} pad(p6, c0), padding=16_16x16_16x16_16
+ mul = f32[64,64,64]{2,1,0} multiply(pad, p7)
+ ROOT r1 = f32[64,64]{1,0} reduce(mul, c0), dimensions={2},
+ to_apply=scalar_add_computation
+ }
+
+ ENTRY reduce {
+ p8 = f32[32,32,32]{2,1,0} parameter(0)
+ p9 = f32[64,64,64]{2,1,0} parameter(1)
+ // `add` and `mul` can be multi-output fused with `reduce1` and `reduce2`,
+ // respectively. However, doing both is not possible, because multi-output
+ // fusion will introduce an extra dependency from `add` to `mul` or vice
+ // versa. Hence, the second multi-output fusion would introduce a cycle.
+ add = f32[32,32,32]{2,1,0} fusion(p8, p8), kind=kLoop, calls=fused_add
+ mul = f32[64,64,64]{2,1,0} fusion(p9, p9), kind=kLoop, calls=fused_mul
+
+ reduce1 = f32[32,32]{1,0} fusion(add, mul), kind=kInput,
+ calls=fused_reduce_1
+ reduce2 = f32[64,64]{1,0} fusion(add, mul), kind=kInput,
+ calls=fused_reduce_2
+ ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[64,64]{1,0},
+ f32[64,64,64]{2,1,0}) tuple(add, reduce1, reduce2, mul)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
+}
+
+TEST_F(MultiOutputFusionTest, PreferFuseProducerIntoFusionConsumer) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_add {
+ p0 = f32[32,32,32]{2,1,0} parameter(0)
+ p1 = f32[32,32,32]{2,1,0} parameter(1)
+ ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
+ }
+ fused_reduce {
+ p0 = f32[32,32,32]{2,1,0} parameter(0)
+ p1 = f32[64,64,64]{2,1,0} parameter(1)
+ slice = f32[32,32,32]{2,1,0} slice(p1), slice={[0:32], [0:32], [0:32]}
+ add = f32[32,32,32]{2,1,0} add(p0, slice)
+ c0 = f32[] constant(0)
+ ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+ to_apply=scalar_add_computation
+ }
+ ENTRY reduce {
+ p0 = f32[32,32,32]{2,1,0} parameter(0)
+ p1 = f32[64,64,64]{2,1,0} parameter(1)
+ add = f32[32,32,32]{2,1,0} fusion(p0, p0), kind=kLoop, calls=fused_add
+ c0 = f32[] constant(0)
+ reduce2 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+ to_apply=scalar_add_computation
+ reduce = f32[32,32]{1,0} fusion(add, p1), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[32,32]{1,0})
+ tuple(add, reduce, reduce2)
+ })"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ int multi_output_fusion_count = 0;
+ for (auto* computation : module->MakeNonfusionComputations()) {
+ for (auto* instr : computation->instructions()) {
+ if (instr->IsMultiOutputFusion()) {
+ multi_output_fusion_count++;
+ }
+ }
+ }
+ EXPECT_EQ(1, multi_output_fusion_count);
+}
+
+// Check that we limit the number of operands to fusions we create.
+TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) {
+ constexpr int64_t kNumParams = 200;
+ ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
+
+ // Compute
+ // p0 * p1,
+ // p0 * p1 + p1 * p2
+ // p0 * p1 + p1 * p2 + p2 * p3
+ // ...
+ // where each of the (pi * pj)'s is represented as a fusion node so that
+ // multi-output fusion will pay attention to it.
+ auto module = CreateNewVerifiedModule();
+ HloComputation::Builder b(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
+
+ std::vector<HloInstruction*> params;
+ for (int64_t i = 0; i < kNumParams; ++i) {
+ params.push_back(
+ b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
+ }
+
+ // Creates a fusion node that calculates x*y.
+ auto make_fusion = [&](HloInstruction* x, HloInstruction* y) {
+ HloComputation::Builder sub_builder("subcomp");
+ auto* p0 = sub_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "p"));
+ auto* p1 = sub_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "p"));
+ sub_builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
+ HloComputation* subcomp =
+ module->AddEmbeddedComputation(sub_builder.Build());
+ return HloInstruction::CreateFusion(
+ shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp);
+ };
+
+ auto* sum = b.AddInstruction(make_fusion(params[0], params[1]));
+ for (int64_t i = 2; i < kNumParams; ++i) {
+ sum = b.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, sum,
+ b.AddInstruction(make_fusion(params[i - 1], params[i]))));
+ }
+ auto computation = module->AddEntryComputation(b.Build());
+ EXPECT_TRUE(mof_.Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ for (const HloInstruction* instr : computation->instructions()) {
+ EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()),
+ MaxOperandsAndOutputsPerFusion())
+ << instr->ToString();
+ }
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
+ auto module = ParseAndReturnVerifiedModule(R"(HloModule dus_mof
+ fusion.1 {
+ p.0 = f16[50,96,1024]{2,1,0} parameter(0)
+ p.1 = f16[1,96,1024]{2,1,0} parameter(1)
+ c.0 = s32[3]{0} constant({0, 0, 0})
+ ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+ }
+
+ fusion.2 {
+ p.0 = f16[50,96,1024]{2,1,0} parameter(0)
+ p.1 = f16[1,96,1024]{2,1,0} parameter(1)
+ c.0 = s32[3]{0} constant({0, 0, 0})
+ ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+ }
+
+ ENTRY entry {
+ p.00 = f16[50,96,1024]{2,1,0} parameter(0)
+ p.01 = f16[50,96,1024]{2,1,0} parameter(1)
+ p.1 = f16[1,96,1024]{2,1,0} parameter(2)
+
+ f1 = f16[50,96,1024] fusion(p.00, p.1), kind=kLoop, calls=fusion.1
+ f2 = f16[50,96,1024] fusion(p.01, p.1), kind=kLoop, calls=fusion.2
+ ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2)
+ })")
+ .value();
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+// Check that we don't fuse too many reductions together.
+TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation0 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation1 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation2 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation3 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation4 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation5 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation6 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation7 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation8 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ fused_computation9 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+ to_apply=scalar_add_computation
+ }
+ ENTRY computation {
+ zero = f32[] constant(0)
+ param0 = f32[64,64] parameter(0)
+ param1 = f32[64,64] parameter(1)
+ param2 = f32[64,64] parameter(2)
+ param3 = f32[64,64] parameter(3)
+ param4 = f32[64,64] parameter(4)
+ param5 = f32[64,64] parameter(5)
+ param6 = f32[64,64] parameter(6)
+ param7 = f32[64,64] parameter(7)
+ param8 = f32[64,64] parameter(8)
+ param9 = f32[64,64] parameter(9)
+ out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
+ out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
+ out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
+ out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
+ out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
+ out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
+ out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
+ out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
+ out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
+ out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
+ ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
+ }
+ )"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+
+ EXPECT_EQ(5, CountMultiOutputFusions(module.get()));
+}
+
+TEST_F(MultiOutputFusionTest, DoNotGroupTooManyReductions) {
+ auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+ fused_computation0 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation1 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation2 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation3 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation4 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation5 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation6 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation7 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation8 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ fused_computation9 {
+ p0 = f32[64,64] parameter(0)
+ p1 = f32[64,64] parameter(1)
+ p2 = f32[] parameter(2)
+ add = f32[64,64] add(p0, p1)
+ ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+ to_apply=scalar_add_computation
+ }
+ ENTRY computation {
+ zero = f32[] constant(0)
+ param0 = f32[64,64] parameter(0)
+ param1 = f32[64,64] parameter(1)
+ param2 = f32[64,64] parameter(2)
+ param3 = f32[64,64] parameter(3)
+ param4 = f32[64,64] parameter(4)
+ param5 = f32[64,64] parameter(5)
+ param6 = f32[64,64] parameter(6)
+ param7 = f32[64,64] parameter(7)
+ param8 = f32[64,64] parameter(8)
+ param9 = f32[64,64] parameter(9)
+ out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
+ out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
+ out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
+ out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
+ out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
+ out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
+ out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
+ out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
+ out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
+ out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
+ ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
+ }
+ )"))
+ .value();
+ ASSERT_TRUE(mof_.Run(module.get()).value());
+
+ EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
+}
+
+TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule xla_computation_update_step.10931
+
+%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
+ %scalar_lhs.1 = f64[] parameter(0)
+ %scalar_rhs.1 = f64[] parameter(1)
+ ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
+}
+
+%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
+ %param_0.8 = f64[64,64]{1,0} parameter(0)
+ %param_1.11 = f64[64,64]{1,0} parameter(1)
+ %multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
+ %constant_5217.3 = f64[] constant(0)
+ %broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
+ %multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
+ %reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
+ %param_2.9 = f64[64,64]{1,0} parameter(2)
+ %multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
+ %constant_5217.1.clone.1 = f64[] constant(0)
+ %broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
+ %multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
+ %reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
+ ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
+}
+
+%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
+ %parameter.6427 = f64[] parameter(0)
+ %parameter.6428 = f64[] parameter(1)
+ ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
+}
+
+%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
+ %param_0.7 = f64[64,64]{1,0} parameter(0)
+ %param_1.9 = f64[64,64]{1,0} parameter(1)
+ %multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
+ %constant_5217.2 = f64[] constant(0)
+ ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
+}
+
+ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
+ %param_0.1090 = f64[64,64]{1,0} parameter(0)
+ %param_1.1377 = f64[64,64]{1,0} parameter(1)
+ %param_2.1948 = f64[64,64]{1,0} parameter(2)
+ %fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
+ %get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
+ %fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
+ %get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
+ ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
+}
+ )")
+ .value();
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+and.reduce_sub_computation {
+ x = pred[] parameter(0)
+ y = pred[] parameter(1)
+ ROOT and = pred[] and(x, y)
+}
+
+fused_computation.1 {
+ param_4.658 = f32[2,20,256]{2,0,1} parameter(4)
+ slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]}
+ constant.6847 = s32[] constant(0)
+ broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={}
+ param_9.415 = s32[3]{0} parameter(9)
+ compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE
+ constant.6846 = pred[] constant(true)
+ reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation
+ broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={}
+ param_5.528 = f32[2,512]{1,0} parameter(5)
+ slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]}
+ bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384)
+ constant.5418 = f32[] constant(0)
+ broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={}
+ select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227)
+ add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173)
+ param_0.299 = s32[] parameter(0)
+ constant.5157 = s32[] constant(11)
+ dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299)
+ slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]}
+ constant.6800 = s32[] constant(0)
+ broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={}
+ param_8.484 = s32[3]{0} parameter(8)
+ compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE
+ constant.6798 = pred[] constant(true)
+ reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation
+ broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={}
+ param_3.1169 = f32[2,512]{1,0} parameter(3)
+ slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]}
+ bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382)
+ select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227)
+ add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172)
+ constant.5154 = s32[] constant(10)
+ dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299)
+ slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]}
+ constant.6794 = s32[] constant(0)
+ broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={}
+ param_7.478 = s32[3]{0} parameter(7)
+ compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE
+ constant.6793 = pred[] constant(true)
+ reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation
+ broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={}
+ param_2.1685 = f32[2,512]{1,0} parameter(2)
+ slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]}
+ bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380)
+ select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227)
+ add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171)
+ constant.5153 = s32[] constant(9)
+ dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299)
+ slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]}
+ constant.6788 = s32[] constant(0)
+ broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={}
+ param_6.495 = s32[3]{0} parameter(6)
+ compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE
+ constant.6786 = pred[] constant(true)
+ reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation
+ broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={}
+ param_1.1408 = f32[2,512]{1,0} parameter(1)
+ slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]}
+ bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378)
+ select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227)
+ add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170)
+ constant.5152 = s32[] constant(8)
+ ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299)
+}
+
+fused_computation.2 {
+ param_4.655 = f32[2,20,256]{2,0,1} parameter(4)
+ slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]}
+ param_6.483 = pred[] parameter(6)
+ broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={}
+ param_5.525 = f32[2,512]{1,0} parameter(5)
+ slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]}
+ bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368)
+ constant.5415 = f32[] constant(0)
+ broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={}
+ select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225)
+ add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161)
+ param_0.265 = s32[] parameter(0)
+ constant.5151 = s32[] constant(7)
+ dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265)
+ slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]}
+ constant.6782 = s32[] constant(0)
+ broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={}
+ param_9.391 = s32[3]{0} parameter(9)
+ compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE
+ constant.6781 = pred[] constant(true)
+ reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation
+ broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={}
+ param_3.1167 = f32[2,512]{1,0} parameter(3)
+ slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]}
+ bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366)
+ select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225)
+ add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160)
+ constant.5150 = s32[] constant(6)
+ dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265)
+ slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]}
+ constant.6776 = s32[] constant(0)
+ broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={}
+ param_8.464 = s32[3]{0} parameter(8)
+ compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE
+ constant.6775 = pred[] constant(true)
+ reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation
+ broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={}
+ param_2.1684 = f32[2,512]{1,0} parameter(2)
+ slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]}
+ bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364)
+ select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225)
+ add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159)
+ constant.5149 = s32[] constant(5)
+ dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265)
+ slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]}
+ constant.6770 = s32[] constant(0)
+ broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={}
+ param_7.458 = s32[3]{0} parameter(7)
+ compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE
+ constant.6769 = pred[] constant(true)
+ reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation
+ broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={}
+ param_1.1405 = f32[2,512]{1,0} parameter(1)
+ slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]}
+ bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362)
+ select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225)
+ add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158)
+ constant.5148 = s32[] constant(4)
+ ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265)
+}
+
+ENTRY main {
+ param_0.0 = s32[] parameter(0)
+ param_1.0 = f32[2,512]{1,0} parameter(1)
+ param_2.0 = f32[2,512]{1,0} parameter(2)
+ param_3.0 = f32[2,512]{1,0} parameter(3)
+ param_4.0 = f32[2,20,256]{2,1,0} parameter(4)
+ param_5.0 = f32[2,512]{1,0} parameter(5)
+ param_6.0 = s32[3]{0} parameter(6)
+ param_7.0 = s32[3]{0} parameter(7)
+ param_8.0 = s32[3]{0} parameter(8)
+ param_9.0 = s32[3]{0} parameter(9)
+ fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1
+ param_10 = pred[] parameter(10)
+ fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2
+ ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2)
+}
+ )")
+ .value();
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
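+// Do not fuse with the root instruction of the entry computation.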
+TEST_F(MultiOutputFusionTest, DoNotFuseRoot) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+no_op {
+ arg_empty_tuple = () parameter(0)
+ ROOT tuple = () tuple()
+}
+
+fused_computation {
+ param_0 = f32[] parameter(0)
+ ROOT convert = s32[] convert(param_0)
+}
+
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ fusion = s32[] fusion(param_0), kind=kLoop, calls=fused_computation
+ tuple = () tuple()
+ conditional = () conditional(fusion, tuple, tuple), branch_computations={no_op, no_op}
+ constant = f32[] constant(1)
+ ROOT root = f32[] add(param_0, constant)
+}
+ )")
+ .value();
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
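+// Do not merge fusions when the cost model deems the merge unprofitable.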
+TEST_F(MultiOutputFusionTest, CostBasedNoMerge) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+region_3.63 {
+ Arg_0.64 = f32[] parameter(0)
+ Arg_1.65 = f32[] parameter(1)
+ ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
+}
+
+fused_computation.29 {
+ param_0.161 = f32[5,32,32,1]{3,2,1,0} parameter(0)
+ multiply.208 = f32[5,32,32,1]{3,2,1,0} multiply(param_0.161, param_0.161)
+ bitcast.67 = f32[5,32,32]{2,1,0} bitcast(multiply.208)
+ constant.265 = f32[] constant(0)
+ reduce-window.81 = f32[5,30,31]{2,1,0} reduce-window(bitcast.67, constant.265), window={size=1x3x2}, to_apply=region_3.63
+ constant.264 = f32[] constant(0.166666672)
+ broadcast.204 = f32[5,30,31]{2,1,0} broadcast(constant.264), dimensions={}
+ multiply.205 = f32[5,30,31]{2,1,0} multiply(reduce-window.81, broadcast.204)
+ constant.263 = f32[] constant(0)
+ reduce-window.80 = f32[5,30,31]{2,1,0} reduce-window(multiply.205, constant.263), window={size=1x2x3 pad=0_0x0_1x1_1}, to_apply=region_3.63
+ constant.262 = f32[] constant(0.0138888899)
+ broadcast.201 = f32[5,30,31]{2,1,0} broadcast(constant.262), dimensions={}
+ multiply.204 = f32[5,30,31]{2,1,0} multiply(reduce-window.80, broadcast.201)
+ constant.261 = f32[] constant(0)
+ reduce-window.78 = f32[5,30,31]{2,1,0} reduce-window(multiply.204, constant.261), window={size=1x1x2 pad=0_0x0_0x0_1}, to_apply=region_3.63
+ constant.113 = f32[] constant(0.5)
+ broadcast.137 = f32[5,30,31]{2,1,0} broadcast(constant.113), dimensions={}
+ multiply.125 = f32[5,30,31]{2,1,0} multiply(reduce-window.78, broadcast.137)
+ constant.114 = f32[] constant(0)
+ ROOT reduce-window.17 = f32[5,30,31]{2,1,0} reduce-window(multiply.125, constant.114), window={size=1x2x1 pad=0_0x0_1x0_0}, to_apply=region_3.63
+}
+
+fused_computation.15 {
+ constant.108 = f32[] constant(0.5)
+ broadcast.105 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.108), dimensions={}
+ param_3.126 = f32[5,30,31]{2,1,0} parameter(3)
+ constant.295 = f32[] constant(0.25)
+ broadcast.234 = f32[5,30,31]{2,1,0} broadcast(constant.295), dimensions={}
+ multiply.242 = f32[5,30,31]{2,1,0} multiply(param_3.126, broadcast.234)
+ broadcast.233 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.242), dimensions={0,2,3}
+ param_2.154 = f32[5,30,31]{2,1,0} parameter(2)
+ multiply.241 = f32[5,30,31]{2,1,0} multiply(param_2.154, broadcast.234)
+ broadcast.232 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.241), dimensions={1,2,3}
+ multiply.240 = f32[5,5,30,31]{3,2,1,0} multiply(broadcast.233, broadcast.232)
+ param_1.188 = f32[5,5,30,31]{3,2,1,0} parameter(1)
+ constant.294 = f32[] constant(0.159154937)
+ broadcast.231 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.294), dimensions={}
+ multiply.239 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.231)
+ param_0.164 = f32[5,5,30,31]{3,2,1,0} parameter(0)
+ add.19 = f32[5,5,30,31]{3,2,1,0} add(multiply.239, param_0.164)
+ constant.293 = f32[] constant(0)
+ reduce-window.90 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.19, constant.293), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
+ constant.292 = f32[] constant(0.5)
+ broadcast.230 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.292), dimensions={}
+ multiply.238 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.90, broadcast.230)
+ constant.291 = f32[] constant(0)
+ reduce-window.89 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.238, constant.291), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
+ constant.290 = f32[] constant(0.25)
+ broadcast.229 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.290), dimensions={}
+ multiply.237 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.89, broadcast.229)
+ multiply.236 = f32[5,5,30,31]{3,2,1,0} multiply(multiply.237, multiply.237)
+ subtract.10 = f32[5,5,30,31]{3,2,1,0} subtract(multiply.240, multiply.236)
+ constant.289 = f32[] constant(0)
+ broadcast.228 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.289), dimensions={}
+ maximum.6 = f32[5,5,30,31]{3,2,1,0} maximum(subtract.10, broadcast.228)
+ sqrt.6 = f32[5,5,30,31]{3,2,1,0} sqrt(maximum.6)
+ constant.110 = f32[] constant(0)
+ broadcast.107 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.110), dimensions={}
+ compare.4 = pred[5,5,30,31]{3,2,1,0} compare(sqrt.6, broadcast.107), direction=EQ
+ constant.243 = f32[] constant(0.159154937)
+ broadcast.193 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.243), dimensions={}
+ multiply.194 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.193)
+ add.15 = f32[5,5,30,31]{3,2,1,0} add(multiply.194, param_0.164)
+ constant.242 = f32[] constant(0)
+ reduce-window.66 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.15, constant.242), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
+ constant.241 = f32[] constant(0.5)
+ broadcast.192 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.241), dimensions={}
+ multiply.193 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.66, broadcast.192)
+ constant.240 = f32[] constant(0)
+ reduce-window.65 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.193, constant.240), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
+ constant.239 = f32[] constant(0.25)
+ broadcast.191 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.239), dimensions={}
+ multiply.192 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.65, broadcast.191)
+ compare.3 = pred[5,5,30,31]{3,2,1,0} compare(multiply.192, broadcast.107), direction=EQ
+ and.1 = pred[5,5,30,31]{3,2,1,0} and(compare.4, compare.3)
+ constant.109 = f32[] constant(1.57079637)
+ broadcast.104 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.109), dimensions={}
+ atan2.1 = f32[5,5,30,31]{3,2,1,0} atan2(sqrt.6, multiply.192)
+ select.4 = f32[5,5,30,31]{3,2,1,0} select(and.1, broadcast.104, atan2.1)
+ constant.107 = f32[] constant(0.159154937)
+ broadcast.106 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.107), dimensions={}
+ multiply.100 = f32[5,5,30,31]{3,2,1,0} multiply(select.4, broadcast.106)
+ ROOT subtract.3 = f32[5,5,30,31]{3,2,1,0} subtract(broadcast.105, multiply.100)
+}
+
+fused_computation.4 {
+ param_0.172 = f32[5,30,31]{2,1,0} parameter(0)
+ constant.315 = f32[] constant(0.125)
+ broadcast.242 = f32[5,30,31]{2,1,0} broadcast(constant.315), dimensions={}
+ multiply.250 = f32[5,30,31]{2,1,0} multiply(param_0.172, broadcast.242)
+ constant.314 = f32[] constant(0)
+ reduce-window.100 = f32[5,30,31]{2,1,0} reduce-window(multiply.250, constant.314), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
+ constant.79 = f32[] constant(0.055555556)
+ broadcast.85 = f32[5,30,31]{2,1,0} broadcast(constant.79), dimensions={}
+ multiply.80 = f32[5,30,31]{2,1,0} multiply(reduce-window.100, broadcast.85)
+ constant.81 = f32[] constant(0)
+ reduce-window.1 = f32[5,30,31]{2,1,0} reduce-window(multiply.80, constant.81), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
+ constant.80 = f32[] constant(0.111111112)
+ broadcast.86 = f32[5,30,31]{2,1,0} broadcast(constant.80), dimensions={}
+ multiply.79 = f32[5,30,31]{2,1,0} multiply(reduce-window.1, broadcast.86)
+ bitcast.26 = f32[5,930]{1,0} bitcast(multiply.79)
+ ROOT reduce.8 = f32[5]{0} reduce(bitcast.26, constant.81), dimensions={1}, to_apply=region_3.63
+}
+
+ENTRY e {
+ Arg_0.1 = f32[5,32,32,1]{3,2,1,0} parameter(0)
+ p1 = f32[5,5,30,31]{3,2,1,0} parameter(1)
+ p2 = f32[5,5,30,31]{3,2,1,0} parameter(2)
+ p3 = f32[5,30,31]{2,1,0} parameter(3)
+ fusion.29 = f32[5,30,31]{2,1,0} fusion(Arg_0.1), kind=kLoop, calls=fused_computation.29
+ fusion.15 = f32[5,5,30,31]{3,2,1,0} fusion(p2, p1, p3, fusion.29), kind=kLoop, calls=fused_computation.15
+ ROOT fusion.4 = f32[5]{0} fusion(fusion.29), kind=kInput, calls=fused_computation.4
+})")
+ .value();
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
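+// Do not fuse siblings that read disjoint slices of the same operand.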
+TEST_F(MultiOutputFusionTest, NoOverlappingRead) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation_1 {
+ p0.1 = f32[100,200]{1,0} parameter(0)
+ slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[0:100]}
+ mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
+ exp = f32[50,100]{1,0} exponential(slice.0)
+ ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[100,200]{1,0} parameter(0)
+ slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[0:50],[100:200]}
+ const.2 = f32[] constant(0)
+ broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
+ ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
+ }
+
+ ENTRY entry {
+ p0 = f32[100,200]{1,0} parameter(0)
+ fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
+ calls=fused_computation_1
+ gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
+ fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
+ calls=fused_computation_2
+ ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
+ tuple(gte0, gte1, fusion.2)
+ })")
+ .value();
+
+ EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
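+// Fuse siblings that read overlapping slices of the same operand.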
+TEST_F(MultiOutputFusionTest, OverlappingRead) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation_1 {
+ p0.1 = f32[100,200]{1,0} parameter(0)
+ slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[50:150]}
+ mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
+ exp = f32[50,100]{1,0} exponential(slice.0)
+ ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[100,200]{1,0} parameter(0)
+ slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[30:80],[20:120]}
+ const.2 = f32[] constant(0)
+ broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
+ ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
+ }
+
+ ENTRY entry {
+ p0 = f32[100,200]{1,0} parameter(0)
+ fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
+ calls=fused_computation_1
+ gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
+ fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
+ calls=fused_computation_2
+ ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
+ tuple(gte0, gte1, fusion.2)
+ })")
+ .value();
+
+ EXPECT_TRUE(mof_.Run(module.get()).value());
+}
+
+class TransposeMultiOutputFusionTest : public MultiOutputFusionTest {};
+
+TEST_F(TransposeMultiOutputFusionTest, MultipleCopies) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f32[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+ p = f32[16,32]{1,0} parameter(0)
+ fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+ c1 = f32[16,32]{0,1} copy(p)
+ ROOT t = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple(fusion, c1)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
+// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
+// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[16,32]{0,1} copy([[param_0_1_0]])
+// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK-NEXT: }
+
+// CHECK: [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, MultipleTransposes) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f32[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+}
+
+ENTRY main {
+ p = f32[16,32]{1,0} parameter(0)
+ fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
+ c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
+ ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(fusion, c1)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[32,16], f32[32,16]) {
+// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0}
+// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[32,16]{1,0} transpose([[param_0_1_0]]), dimensions={1,0}
+// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK-NEXT: }
+// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, CopyAndTranspose) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f32[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+ p = f32[16,32]{1,0} parameter(0)
+ fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+ c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
+ ROOT t = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple(fusion, c1)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, R"(
+ // CHECK: %fused_computation ({{[^ ]+}} f32[16,32]) -> (f32[16,32], f32[32,16]) {
+ // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+ // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0]])
+ // CHECK-NEXT: [[copy:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1]])
+ // CHECK-NEXT: [[transpose:[^ ]+]] = f32[32,16]{1,0} transpose([[param_0]]), dimensions={1,0}
+ // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple([[copy]], [[transpose]])
+ // CHECK: %fusion = (f32[16,32]{0,1}, f32[32,16]{1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, MultipleCopiesDifferentTypes) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f16[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} convert(param_0.1)
+ ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+ p = f16[16,32]{1,0} parameter(0)
+ fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+ c1 = f16[16,32]{0,1} copy(p)
+ ROOT t = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple(fusion, c1)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f16[16,32]) -> (f32[16,32], f16[16,32]) {
+// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f16[16,32]{1,0} parameter(0)
+// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} convert([[param_0_1_0]])
+// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
+// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f16[16,32]{0,1} copy([[param_0_1_0]])
+// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK: [[fusion_5:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) fusion([[p_6:%[^ ]+]]), kind=kInput, calls=[[fused_computation_7:%[^ ]+]]
+)");
+}
+
+// Do not group copy and reduction.
+TEST_F(TransposeMultiOutputFusionTest, TiledReduceCopy) {
+ const char* hlo = R"(
+HloModule module
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = add(lhs, rhs)
+}
+
+fused_computation {
+ param_0.1 = f32[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+ p = f32[16,32]{1,0} parameter(0)
+ fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+ z = f32[] constant(0)
+ r1 = f32[32]{0} reduce(p, z), dimensions={0}, to_apply=add
+ ROOT t = (f32[16,32]{0,1}, f32[32]{0}) tuple(fusion, r1)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, std::nullopt);
+}
+
+// Do not group incompatible transposes.
+TEST_F(TransposeMultiOutputFusionTest, IncompatibleTransposes) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
+ param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
+ s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
+ t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
+ sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
+ exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
+ ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
+}
+
+fused_computation.2 {
+ param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
+ s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
+ ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
+}
+
+ENTRY main {
+ p = f32[18,16,32]{2,1,0} parameter(0)
+ p2 = f32[32,16,18]{2,1,0} parameter(1)
+ fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
+ fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
+ ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, std::nullopt);
+}
+
+// A variation of the test above, where no CSE was run.
+TEST_F(TransposeMultiOutputFusionTest, TransposesNoCSE) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
+ param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
+ s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
+ t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
+ sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
+ exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
+ exp.2 = f32[32,16,18]{2,1,0} exponential(sub.1)
+ ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.2)
+}
+
+fused_computation.2 {
+ param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
+ s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
+ ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
+}
+
+ENTRY main {
+ p = f32[18,16,32]{2,1,0} parameter(0)
+ p2 = f32[32,16,18]{2,1,0} parameter(1)
+ fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
+ fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
+ ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, std::nullopt);
+}
+
+TEST_F(TransposeMultiOutputFusionTest, CopyAndInput) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f32[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+ p = f32[16,32]{1,0} parameter(0)
+ fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+ c1 = exponential(p)
+ ROOT t = tuple(fusion, c1)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
+// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
+// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
+// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK-NEXT: }
+// CHECK: [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, TransposeAndInputEpilogueFusion) {
+ const char* hlo = R"(
+HloModule module
+
+fused_computation {
+ param_0.1 = f32[16,32]{1,0} parameter(0)
+ s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+ t.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+ ROOT out = f32[32,16,1]{2,1,0} bitcast(t.1)
+}
+
+ENTRY main {
+ p = f32[16,32]{1,0} parameter(0)
+ fusion = f32[32,16,1]{2,1,0} fusion(p), kind=kInput, calls=fused_computation
+ c1 = exponential(p)
+ ROOT t = tuple(fusion, c1)
+}
+ )";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation
+// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]])
+// CHECK-NEXT: [[out_3:%[^ ]+]] = f32[32,16,1]{2,1,0} bitcast([[c_1_2]])
+// CHECK-NEXT: [[c1_1_4:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
+// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]])
+// CHECK-NEXT: }
+// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+class ReduceMultiOutputFusionTest : public MultiOutputFusionTest {};
+
+TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoop) {
+ const char* hlo = R"(
+HloModule module
+
+add {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a, b)
+}
+
+fused_reduction {
+ p = f32[200] parameter(0)
+ z = f32[] constant(0)
+ e = f32[200] exponential(p)
+ ROOT r = f32[] reduce(e, z), dimensions={0}, to_apply=add
+}
+
+fused_elementwise {
+ p = f32[200] parameter(0)
+ ROOT r = f32[200] sqrt(p)
+}
+
+ENTRY computation {
+ p = f32[200] parameter(0)
+ o1 = f32[200] fusion(p), kind=kLoop, calls=fused_elementwise
+ o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
+ ROOT out = (f32[200], f32[]) tuple(o1, o2)
+}
+
+)";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_elementwise
+// CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[200]{0} parameter(0)
+// CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[200]{0} sqrt([[p_1_0]])
+// CHECK-NEXT: [[e_2:%[^ ]+]].clone.1 = f32[200]{0} exponential([[p_1_0]])
+// CHECK-NEXT: [[z_3:%[^ ]+]].clone.1 = f32[] constant(0)
+// CHECK-NEXT: [[r_4:%[^ ]+]].clone.1 = f32[] reduce([[e_2]].clone.1, [[z_3]].clone.1), dimensions={0}, to_apply=[[add_5:%[^ ]+]]
+// CHECK-NEXT: ROOT [[tuple_6:%[^ ]+]] = (f32[200]{0}, f32[]) tuple([[r_1_1]], [[r_4]].clone.1)
+// CHECK-NEXT:}
+// CHECK: [[o1_0:%[^ ]+]] = (f32[200]{0}, f32[]) fusion([[p_2_1:%[^ ]+]]), kind=kInput, calls=[[fused_elementwise_2:%[^ ]+]]
+ )");
+}
+
+TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShape) {
+ const char* hlo = R"(
+HloModule module
+
+add {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a, b)
+}
+
+fused_reduction {
+ p = f32[10,20] parameter(0)
+ z = f32[] constant(0)
+ e = f32[10,20] exponential(p)
+ b = f32[200] bitcast(e)
+ ROOT r = f32[] reduce(b, z), dimensions={0}, to_apply=add
+}
+
+fused_elementwise {
+ p = f32[10,20] parameter(0)
+ ROOT r = f32[10,20] sqrt(p)
+}
+
+ENTRY computation {
+ p = f32[10,20] parameter(0)
+ o1 = f32[10,20] fusion(p), kind=kLoop, calls=fused_elementwise
+ o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
+ ROOT out = (f32[10,20], f32[]) tuple(o1, o2)
+}
+)";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_elementwise (p.1: f32[10,20]) -> (f32[10,20], f32[]) {
+// CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[10,20]{1,0} parameter(0)
+// CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[10,20]{1,0} sqrt([[p_1_0]])
+// CHECK-NEXT: [[e_2:%[^ ]+]].clone.1 = f32[10,20]{1,0} exponential([[p_1_0]])
+// CHECK-NEXT: [[b_1_3:%[^ ]+]].clone.1 = f32[200]{0} bitcast([[e_2]].clone.1)
+// CHECK-NEXT: [[z_4:%[^ ]+]].clone.1 = f32[] constant(0)
+// CHECK-NEXT: [[r_5:%[^ ]+]].clone.1 = f32[] reduce([[b_1_3]].clone.1, [[z_4]].clone.1), dimensions={0}, to_apply=[[add_6:%[^ ]+]]
+// CHECK-NEXT: ROOT [[tuple_7:%[^ ]+]] = (f32[10,20]{1,0}, f32[]) tuple([[r_1_1]], [[r_5]].clone.1)
+// CHECK-NEXT: }
+ )");
+}
+
+TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShapeDifferentType) {
+ const char* hlo = R"(
+HloModule module, entry_computation_layout={(f16[100,200]{1,0},f32[],f32[])->(f16[100,200]{1,0}, f32[])}
+
+max {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] maximum(a, b)
+}
+
+fused_computation {
+ one_5 = f32[] constant(1)
+ one_b.5 = f32[100,200]{1,0} broadcast(one_5), dimensions={}
+ param_1.15 = f16[100,200]{1,0} parameter(1)
+ c.6 = f32[100,200]{1,0} convert(param_1.15)
+ param_0.11 = f32[] parameter(0)
+ b.6 = f32[100,200]{1,0} broadcast(param_0.11), dimensions={}
+ d.5 = f32[100,200]{1,0} divide(c.6, b.6)
+ a.6 = f32[100,200]{1,0} add(one_b.5, d.5)
+ bitcast.1 = f32[20000]{0} bitcast(a.6)
+ z_1 = f32[] constant(0)
+ ROOT r.1 = f32[] reduce(bitcast.1, z_1), dimensions={0}, to_apply=max
+}
+
+fused_computation.1 {
+ one_3 = f32[] constant(1)
+ one_b.3 = f32[100,200]{1,0} broadcast(one_3), dimensions={}
+ param_2.7 = f16[100,200]{1,0} parameter(2)
+ c.4 = f32[100,200]{1,0} convert(param_2.7)
+ param_1.10 = f32[] parameter(1)
+ b.4 = f32[100,200]{1,0} broadcast(param_1.10), dimensions={}
+ d.3 = f32[100,200]{1,0} divide(c.4, b.4)
+ a.4 = f32[100,200]{1,0} add(one_b.3, d.3)
+ param_0.8 = f32[] parameter(0)
+ output_scale_broadcast.1 = f32[100,200]{1,0} broadcast(param_0.8), dimensions={}
+ a_scaled.1 = f32[100,200]{1,0} multiply(a.4, output_scale_broadcast.1)
+ ROOT a_scaled_converted.1 = f16[100,200]{1,0} convert(a_scaled.1)
+}
+
+ENTRY computation {
+ output_scale = f32[] parameter(2)
+ input_scale = f32[] parameter(1)
+ p = f16[100,200]{1,0} parameter(0)
+ fusion.1 = f16[100,200]{1,0} fusion(output_scale, input_scale, p), kind=kLoop, calls=fused_computation.1
+ fusion = f32[] fusion(input_scale, p), kind=kInput, calls=fused_computation
+ ROOT out = (f16[100,200]{1,0}, f32[]) tuple(fusion.1, fusion)
+}
+)";
+
+ CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation.1 (param_0.8: f32[], param_1.10: f32[], param_2.7: f16[100,200]) -> (f16[100,200], f32[]) {
+// CHECK-NEXT: [[one_3_0:%[^ ]+]] = f32[] constant(1)
+// CHECK-NEXT: [[one_b_3_1:%[^ ]+]] = f32[100,200]{1,0} broadcast([[one_3_0]]), dimensions={}
+// CHECK-NEXT: [[param_2_7_2:%[^ ]+]] = f16[100,200]{1,0} parameter(2)
+// CHECK-NEXT: [[c_4_3:%[^ ]+]] = f32[100,200]{1,0} convert([[param_2_7_2]])
+// CHECK-NEXT: [[param_1_10_4:%[^ ]+]] = f32[] parameter(1)
+// CHECK-NEXT: [[b_4_5:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
+// CHECK-NEXT: [[d_3_6:%[^ ]+]] = f32[100,200]{1,0} divide([[c_4_3]], [[b_4_5]])
+// CHECK-NEXT: [[a_4_7:%[^ ]+]] = f32[100,200]{1,0} add([[one_b_3_1]], [[d_3_6]])
+// CHECK-NEXT: [[param_0_8_8:%[^ ]+]] = f32[] parameter(0)
+// CHECK-NEXT: [[output_scale_broadcast_1_9:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_0_8_8]]), dimensions={}
+// CHECK-NEXT: [[a_scaled_1_10:%[^ ]+]] = f32[100,200]{1,0} multiply([[a_4_7]], [[output_scale_broadcast_1_9]])
+// CHECK-NEXT: [[a_scaled_converted_1_11:%[^ ]+]] = f16[100,200]{1,0} convert([[a_scaled_1_10]])
+// CHECK-NEXT: [[one_5_12:%[^ ]+]].clone.1 = f32[] constant(1)
+// CHECK-NEXT: [[one_b_5_13:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[one_5_12]].clone.1), dimensions={}
+// CHECK-NEXT: [[c_6_14:%[^ ]+]].clone.1 = f32[100,200]{1,0} convert([[param_2_7_2]])
+// CHECK-NEXT: [[b_6_15:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
+// CHECK-NEXT: [[d_5_16:%[^ ]+]].clone.1 = f32[100,200]{1,0} divide([[c_6_14]].clone.1, [[b_6_15]].clone.1)
+// CHECK-NEXT: [[a_6_17:%[^ ]+]].clone.1 = f32[100,200]{1,0} add([[one_b_5_13]].clone.1, [[d_5_16]].clone.1)
+// CHECK-NEXT: [[bitcast_1_18:%[^ ]+]].clone.1 = f32[20000]{0} bitcast([[a_6_17]].clone.1)
+// CHECK-NEXT: [[z_1_19:%[^ ]+]].clone.1 = f32[] constant(0)
+// CHECK-NEXT: [[r_1_20:%[^ ]+]].clone.1 = f32[] reduce([[bitcast_1_18]].clone.1, [[z_1_19]].clone.1), dimensions={0}, to_apply=[[max_21:%[^ ]+]]
+// CHECK-NEXT: ROOT [[tuple_22:%[^ ]+]] = (f16[100,200]{1,0}, f32[]) tuple([[a_scaled_converted_1_11]], [[r_1_20]].clone.1)
+// CHECK-NEXT: }
+ )");
+}
+
+TEST_F(ReduceMultiOutputFusionTest, GetTupleElementMakeTupleSequence) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ fusion {
+ p0 = s32[] parameter(0)
+ p1 = s32[32] parameter(1)
+ custom-call = (bf16[], s32[], u32[]) custom-call(p1), custom_call_target="my_custom_call"
+ get-tuple-element.0 = bf16[] get-tuple-element(custom-call), index=0
+ get-tuple-element.1 = s32[] get-tuple-element(custom-call), index=1
+ bitcast = s32[1] bitcast(get-tuple-element.1)
+ dynamic-update-slice = s32[32] dynamic-update-slice(p1, bitcast, p0)
+ get-tuple-element.2 = u32[] get-tuple-element(custom-call), index=2
+ ROOT tuple.30 = (bf16[], s32[32], u32[]) tuple(get-tuple-element.0, dynamic-update-slice, get-tuple-element.2)
+ }
+
+ ENTRY entry{
+ p0 = s32[] parameter(0)
+ bitcast = s32[32] bitcast(p0)
+ ROOT address_computation.7.0 = (bf16[], s32[32], u32[]) fusion(p0, bitcast), kind=kCustom, calls=fusion
+ }
+ )")
+ .value();
+
+ ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc
new file mode 100644
index 0000000..493d167
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc
@@ -0,0 +1,703 @@
+
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+// Maps a computation to a boolean that indicates whether any collective
+// operation is invoked directly or indirectly in the computation.
+using CollectiveInComputation =
+ absl::flat_hash_map<const HloComputation*, bool>;
+
+using InstructionVector = HloInstruction::InstructionVector;
+
+// Records the starting index and the ending index (exclusive) of the pipelined
+// SendDone/RecvDone block in a while-op. They are indices into the operands of
+// the while-init tuple.
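+// For example (hypothetical layout), if while-init is
+//   tuple(iter, recv-done.0, send-done.0, buf)
+// then opnd_start = 1 and opnd_end = 3.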
+struct PipelinedP2PInfo {
+ int64_t opnd_start;
+ int64_t opnd_end;
+};
+
+// Returns whether the instruction is a collective operation.
+bool IsCollectiveOp(const HloInstruction* op) {
+ HloOpcode opcode = op->opcode();
+  // TODO(NVIDIA/4364298): The issue is recorded in b/309639264. To work around
+  // it, we need to prevent custom-calls from overlapping with Send/Recv.
+  // Remove the custom-call handling here once the bug is fixed.
+ if (opcode == HloOpcode::kCustomCall) {
+ return true;
+ }
+
+ return hlo_query::IsCollectiveCommunicationOp(opcode) ||
+ opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv;
+}
+
+// Returns whether the instruction may invoke collective operations directly
+// or indirectly.
+bool MayInvokeCollectiveOp(
+ const HloInstruction* hlo,
+ const CollectiveInComputation& collective_in_computation) {
+ if (IsCollectiveOp(hlo)) {
+ return true;
+ }
+ for (HloComputation* callee : hlo->called_computations()) {
+ auto collective_in_comp = collective_in_computation.find(callee);
+ CHECK(collective_in_comp != collective_in_computation.end());
+ if (collective_in_comp->second) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Returns the unique get-tuple-element user with the given idx or nullptr if
+// there isn't such a unique user.
+HloInstruction* FindUniqueGTEUserWithIndex(const HloInstruction* op,
+ int64_t idx) {
+ CHECK(op->shape().IsTuple());
+
+ HloInstruction* gte = nullptr;
+ for (auto user : op->users()) {
+ if (user->opcode() != HloOpcode::kGetTupleElement) {
+ continue;
+ }
+ if (user->tuple_index() == idx) {
+ if (gte == nullptr) {
+ gte = user;
+ } else {
+ return nullptr;
+ }
+ }
+ }
+ return gte;
+}
+
+// Returns whether there is any get-tuple-element user with the given idx.
+bool HasGTEUserWithIndex(const HloInstruction* op, int64_t idx) {
+ CHECK(op->shape().IsTuple());
+
+ for (auto user : op->users()) {
+ if (user->opcode() != HloOpcode::kGetTupleElement) {
+ continue;
+ }
+ if (user->tuple_index() == idx) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Returns the instruction hidden behind a trivial tuple, or `op` itself if
+// there is no such trivial tuple. This allows the discovery of recv-done for
+// the following case, for which the indirection would have been removed by
+// tuple-simplification.
+// gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+// gte.1 = token get-tuple-element(recv-done.p), index=1
+// op = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+//
+// TODO(bixia): investigate the possibility of implementing
+// m::TrivialTuple(m::RecvDone(&instr)) as suggested by code review.
+HloInstruction* MaySkipTrivialTuple(HloInstruction* op) {
+ if (op->opcode() != HloOpcode::kTuple) {
+ return op;
+ }
+ HloInstruction* hidden_op = nullptr;
+ for (auto opnd : op->mutable_operands()) {
+ if (opnd->opcode() != HloOpcode::kGetTupleElement) {
+ return op;
+ }
+ if (hidden_op == nullptr) {
+ hidden_op = opnd->mutable_operand(0);
+ } else if (opnd->mutable_operand(0) != hidden_op) {
+ return op;
+ }
+ }
+ return hidden_op;
+}
+
+// This routine is similar to the non-const version above, except that the
+// given instruction is used for pattern checking only and can't be mutated.
+const HloInstruction* MaySkipTrivialTuple(const HloInstruction* op) {
+ // Use const_cast to avoid repeating the non-const version above to find
+ // operands of the instruction through operands() instead of
+ // mutable_operands().
+ return MaySkipTrivialTuple(const_cast<HloInstruction*>(op));
+}
+
+// Finds a consecutive block of balanced SendDone/RecvDone in the while_init
+// of a while-loop, assuming its while_init is a tuple.
+std::optional<PipelinedP2PInfo>
+FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(
+ const HloInstruction* while_init) {
+ PipelinedP2PInfo pipelined_p2p_info{0, 0};
+ // Return whether the first SendDone/RecvDone has been seen.
+ auto has_started = [&]() {
+ return pipelined_p2p_info.opnd_start != pipelined_p2p_info.opnd_end;
+ };
+ // Record the difference between the number of SendDone and RecvDone in a
+ // consecutive block.
+ int difference = 0;
+  // If SendDone/RecvDone ops exist in a consecutive block in the while_init
+  // tuple, find that block.
+ for (int64_t i = 0; i < while_init->operand_count(); ++i) {
+ const HloInstruction* op = while_init->operand(i);
+ if ((op->opcode() == HloOpcode::kRecvDone ||
+ op->opcode() == HloOpcode::kSendDone) &&
+ op->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) {
+ if (op->opcode() == HloOpcode::kRecvDone) {
+ difference++;
+ } else {
+ difference--;
+ }
+ if (!has_started()) {
+ pipelined_p2p_info.opnd_start = i;
+ }
+ pipelined_p2p_info.opnd_end = i + 1;
+ } else {
+ if (has_started()) {
+ VLOG(10) << "End a consecutive block";
+ break;
+ }
+ }
+ }
+
+ if (difference != 0) {
+ VLOG(10) << "Mismatch number of SendDone and RecvDone: " << difference;
+ return std::nullopt;
+ }
+
+ if (has_started()) {
+ // Check for SendDone/RecvDone outside the consecutive block.
+ for (int64_t i = pipelined_p2p_info.opnd_end;
+ i < while_init->operand_count(); ++i) {
+ const HloInstruction* op = while_init->operand(i);
+ if (op->opcode() == HloOpcode::kRecvDone ||
+ op->opcode() == HloOpcode::kSendDone) {
+ VLOG(10) << "SendDone/RecvDone outside the consecutive block";
+ return std::nullopt;
+ }
+ }
+ }
+
+ if (!has_started()) {
+ VLOG(10) << "No SendDone/RecvDone in while-init ";
+ return std::nullopt;
+ }
+
+ return pipelined_p2p_info;
+}
+
+// Checks whether the while-op, its while-body and while-condition have a
+// recognized pipelined pattern. If a pipelined pattern is found, returns the
+// first and last indices for the pipelined instruction in the while-init tuple.
+// For pipelined Send/Recv to work, the SendDone/RecvDone doesn't have to be in
+// a consecutive block, but this simplifies the implementation and is the
+// pattern that the current gpu-p2p-pipeliner generates.
+//
+// As a summary, this is what the routine looks for:
+//
+// . The while-init is a tuple with a single user.
+// . The while-init has a consecutive block of SendDone and RecvDone. The
+//   numbers of SendDone and RecvDone are the same, and there isn't any other
+//   SendDone or RecvDone outside the block.
+// . The while-body and the while-condition each have a single tuple-shaped
+//   parameter.
+// . For the while-op result tuple and the while-body parameter tuple:
+//   The index corresponding to the index of a SendDone in while-init should
+//   not correspond to any get-tuple-element user.
+//   The index corresponding to the index of a RecvDone in while-init should
+//   correspond to exactly one get-tuple-element user.
+// . In the while-body result tuple, the operand at an index corresponding to a
+//   SendDone or RecvDone in the while-init should also be a SendDone or
+//   RecvDone with the same opcode.
+//
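+// A hypothetical (simplified) sketch of a matching pattern:
+//   init = tuple(iter, recv-done.0, send-done.0)  // consecutive block [1, 3)
+//   w = while(init), condition=cond, body=body
+//   r = get-tuple-element(w), index=1   // the only GTE for the RecvDone slot
+//   (no get-tuple-element of w or of body's parameter uses index 2)
+//   body root = tuple(next_iter, recv-done.1, send-done.1)
+//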
+// TODO(bixia): support pipelined SendDone/RecvDone not in a consecutive block
+// if the gpu-p2p-pipeliner ever generates such code in the future.
+std::optional<PipelinedP2PInfo> FindPipelinedP2P(
+ const HloInstruction* while_op) {
+ VLOG(10) << "while_op: " << while_op->ToString();
+ const HloInstruction* while_init = while_op->while_init();
+ if (while_init->opcode() != HloOpcode::kTuple ||
+ while_init->user_count() != 1) {
+ return std::nullopt;
+ }
+
+ // The while-body and while-condition should have one parameter of a tuple
+ // shape.
+ const HloComputation* while_body = while_op->while_body();
+ const HloComputation* while_condition = while_op->while_condition();
+ if (while_body->num_parameters() != 1 ||
+ while_condition->num_parameters() != 1) {
+ return std::nullopt;
+ }
+
+ std::optional<PipelinedP2PInfo> pipelined_p2p_info =
+ FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(while_init);
+ if (!pipelined_p2p_info.has_value()) {
+ return std::nullopt;
+ }
+
+ VLOG(10) << "opnd_start " << pipelined_p2p_info->opnd_start << " opnd_end "
+ << pipelined_p2p_info->opnd_end;
+
+ // In the while-result or while-body parameter, the index for RecvDone should
+ // correspond to one get-tuple-element user and the index for SendDone should
+ // not correspond to any get-tuple-element user.
+ for (int64_t i = pipelined_p2p_info->opnd_start;
+ i < pipelined_p2p_info->opnd_end; ++i) {
+ const HloInstruction* op = while_init->operand(i);
+ if (op->opcode() == HloOpcode::kRecvDone) {
+ if (!FindUniqueGTEUserWithIndex(while_op, i)) {
+ VLOG(10) << "While result get-tuple-element user with index " << i
+ << " not unique";
+ return std::nullopt;
+ }
+ if (!FindUniqueGTEUserWithIndex(while_body->parameter_instruction(0),
+ i)) {
+ VLOG(10) << "While-body parameter get-tuple-element user with index "
+ << i << " not unique";
+ return std::nullopt;
+ }
+ } else {
+ CHECK(op->opcode() == HloOpcode::kSendDone);
+ if (HasGTEUserWithIndex(while_op, i) ||
+ HasGTEUserWithIndex(while_body->parameter_instruction(0), i)) {
+ VLOG(10) << "SendDone with index " << i << " has unexpected users";
+ return std::nullopt;
+ }
+ }
+ }
+
+  // Each element in the while-body result tuple corresponding to a pipelined
+  // SendDone/RecvDone in the while-init must have the same opcode.
+ const HloInstruction* root = while_body->root_instruction();
+ for (int64_t i = pipelined_p2p_info->opnd_start;
+ i < pipelined_p2p_info->opnd_end; ++i) {
+ const HloInstruction* op_init = while_init->operand(i);
+ const HloInstruction* op_root = root->operand(i);
+ op_root = MaySkipTrivialTuple(op_root);
+ if (op_init->opcode() != op_root->opcode()) {
+ VLOG(10) << "Mismatching opcode, op_init: " << op_init->ToString()
+ << " op_root: " << op_root->ToString();
+ return std::nullopt;
+ }
+ }
+
+ return pipelined_p2p_info.value();
+}
+
+absl::Status RemoveOpFromParent(HloInstruction* op) {
+ TF_RETURN_IF_ERROR(op->DropAllControlDeps());
+ TF_RETURN_IF_ERROR(op->parent()->RemoveInstruction(op));
+ return absl::OkStatus();
+}
+
+absl::Status ReplaceOpInSequence(HloInstruction* old_op, HloInstruction* new_op,
+ HloInstructionSequence& instruction_sequence) {
+ VLOG(10) << "old_op: " << old_op->ToString();
+ VLOG(10) << "new_op: " << new_op->ToString();
+ instruction_sequence.replace_instruction(old_op, new_op);
+ return RemoveOpFromParent(old_op);
+}
+
+absl::Status ReplaceUsesAndUpdateSequence(
+ HloInstruction* old_op, HloInstruction* new_op,
+ HloInstructionSequence& instruction_sequence, bool diff_shape = false) {
+ VLOG(10) << "old_op: " << old_op->ToString();
+ VLOG(10) << "new_op: " << new_op->ToString();
+ if (diff_shape) {
+ TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWithDifferentShape(new_op));
+ } else {
+ TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWith(new_op));
+ }
+ return ReplaceOpInSequence(old_op, new_op, instruction_sequence);
+}
+
+absl::Status ReplaceUsesAndUpdateSequence(
+ const InstructionVector& old_ops, const InstructionVector& new_ops,
+ HloInstructionSequence& instruction_sequence) {
+ CHECK(old_ops.size() == new_ops.size());
+ for (int64_t i = 0; i < old_ops.size(); ++i) {
+ TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(old_ops[i], new_ops[i],
+ instruction_sequence));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status RemoveDoneOpsAndUpdateSequence(
+ const InstructionVector& ops,
+ HloInstructionSequence& instruction_sequence) {
+ auto remove_op = [&](HloInstruction* op) {
+ VLOG(10) << "op: " << op->ToString();
+ TF_RETURN_IF_ERROR(RemoveOpFromParent(op));
+ instruction_sequence.remove_instruction(op);
+ return absl::OkStatus();
+ };
+ for (auto op : ops) {
+ if (op->opcode() == HloOpcode::kTuple) {
+ InstructionVector to_remove;
+ HloInstruction* tuple_op = op;
+ op = MaySkipTrivialTuple(tuple_op);
+ to_remove.push_back(tuple_op);
+ for (auto opnd : tuple_op->mutable_operands()) {
+ to_remove.push_back(opnd);
+ }
+ for (auto opnd : to_remove) {
+ TF_RETURN_IF_ERROR(remove_op(opnd));
+ }
+ }
+ TF_RETURN_IF_ERROR(remove_op(op));
+ }
+ return absl::OkStatus();
+}
+
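+// Inserts `ops` into `instruction_sequence` immediately before the first
+// instruction at or after `idx` that may invoke a collective op, advancing
+// `idx` and `idx_tot` past the inserted instructions. Returns whether such an
+// insertion point was found.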
+bool InsertBeforeFirstCollectiveOp(
+ const InstructionVector& ops,
+ const CollectiveInComputation& collective_in_computation,
+ HloInstructionSequence& instruction_sequence, int64_t& idx,
+ int64_t& idx_tot) {
+ bool inserted = false;
+ while (idx < idx_tot) {
+ HloInstruction* hlo = instruction_sequence.instructions()[idx];
+ if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
+ for (auto op : ops) {
+ instruction_sequence.insert_instruction(op, idx);
+ idx++;
+ idx_tot++;
+ }
+ inserted = true;
+ break;
+ }
+ idx++;
+ }
+ return inserted;
+}
+
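+// Copies the name (with a ".clone" suffix), metadata, frontend attributes and
+// backend config from `old_op` to `new_op`.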
+void CopyInstructionInfo(const HloInstruction* old_op, HloInstruction* new_op) {
+ new_op->SetAndSanitizeName(absl::StrCat(old_op->name(), ".clone"));
+ new_op->set_metadata(old_op->metadata());
+ new_op->add_frontend_attributes(old_op->frontend_attributes());
+ new_op->CopyBackendConfigFrom(old_op);
+}
+
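+// The next two helpers create a RecvDone/SendDone on the given Recv/Send,
+// reusing the channel id and instruction info of the corresponding old op.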
+HloInstruction* CreateRecvDoneFrom(const HloInstruction* old_recv_done,
+ HloInstruction* recv,
+ HloComputation* computation) {
+ HloInstruction* recv_done =
+ computation->AddInstruction(HloInstruction::CreateRecvDone(
+ recv, old_recv_done->channel_id().value()));
+ CopyInstructionInfo(old_recv_done, recv_done);
+ return recv_done;
+}
+
+HloInstruction* CreateSendDoneFrom(const HloInstruction* old_send_done,
+ HloInstruction* send,
+ HloComputation* computation) {
+ HloInstruction* send_done =
+ computation->AddInstruction(HloInstruction::CreateSendDone(
+ send, old_send_done->channel_id().value()));
+ CopyInstructionInfo(old_send_done, send_done);
+ return send_done;
+}
+
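+// Rewrites the while-body for the pipelined Send/Recv transformation: the body
+// parameter now carries Send/Recv instead of SendDone/RecvDone, so matching
+// RecvDone/SendDone ops are recreated inside the body, the root tuple is
+// rebuilt to return the Send/Recv ops, and the new SendDone ops are inserted
+// before the first op that may invoke a collective.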
+absl::Status RewritePipelinedP2PWhileBody(
+ const CollectiveInComputation& collective_in_computation,
+ const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op,
+ int64_t opnd_start, int64_t opnd_end) {
+ HloComputation* computation = while_op->while_body();
+ HloInstruction* while_init = while_op->while_init();
+ HloInstruction* root = computation->root_instruction();
+ HloInstructionSequence& instruction_sequence =
+ computation->parent()->schedule().GetOrCreateSequence(computation);
+
+ HloInstruction* param = computation->parameter_instruction(0);
+ *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
+
+ InstructionVector recv_dones;
+ InstructionVector new_recv_dones;
+ InstructionVector new_send_dones;
+ for (int64_t i = opnd_start; i < opnd_end; ++i) {
+ const HloInstruction* op = root->operand(i);
+ op = MaySkipTrivialTuple(op);
+ if (op->opcode() == HloOpcode::kRecvDone) {
+ HloInstruction* gte = FindUniqueGTEUserWithIndex(param, i);
+ CHECK(gte != nullptr);
+ recv_dones.push_back(gte);
+
+ // Create the new RecvDone using the new while-body parameter.
+ HloInstruction* recv = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(param, i));
+
+ HloInstruction* recv_done = CreateRecvDoneFrom(op, recv, computation);
+ new_recv_dones.push_back(recv_done);
+ continue;
+ }
+ CHECK(op->opcode() == HloOpcode::kSendDone);
+ // Create the new SendDone using the new while-op result.
+ HloInstruction* send = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(param, i));
+ HloInstruction* send_done = CreateSendDoneFrom(op, send, computation);
+ new_send_dones.push_back(send_done);
+ }
+ TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
+ instruction_sequence));
+
+ // Create a new root tuple.
+ InstructionVector done_ops;
+ InstructionVector new_opnds;
+ for (int64_t i = 0; i < while_init->operand_count(); ++i) {
+ HloInstruction* op = root->mutable_operand(i);
+ if (i >= opnd_start && i < opnd_end) {
+ new_opnds.push_back(MaySkipTrivialTuple(op)->mutable_operand(0));
+ done_ops.push_back(op);
+ } else {
+ new_opnds.push_back(op);
+ }
+ }
+ HloInstruction* new_root =
+ computation->AddInstruction(HloInstruction::CreateTuple(new_opnds));
+ computation->set_root_instruction(new_root,
+ /*accept_different_shape=*/true);
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(root));
+ instruction_sequence.replace_instruction(root, new_root);
+
+ TF_RETURN_IF_ERROR(
+ RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
+
+  // Find a place to put the new SendDone ops: either before the first
+  // may-invoke-collective op that is not in the pipelined Send/Recv chain or
+  // before the first op in the pipelined Send/Recv chain.
+ int64_t idx = 0;
+ int64_t idx_end = instruction_sequence.size();
+ bool inserted =
+ InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
+ instruction_sequence, idx, idx_end);
+ CHECK(inserted); // The while-body contains Send/Recv ops, so the insertion must succeed.
+ CHECK(idx_end == instruction_sequence.size());
+
+ // The module schedule will be updated at the end of the pass.
+ return absl::OkStatus();
+}
+
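+// Rewrites the while-condition of a pipelined while-op by updating its
+// parameter shape to match the new while-op operand shapes.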
+void RewritePipelinedP2PWhileCond(
+ const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op) {
+ HloComputation* computation = while_op->while_condition();
+ HloInstruction* param = computation->parameter_instruction(0);
+ *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
+ VLOG(10) << computation->ToString();
+}
+
+// Rewrites the while-op with a recognized pipelined SendDone/RecvDone pattern
+// to pipeline Send/Recv instead.
+absl::Status TransformLoop(
+ const PipelinedP2PInfo& pipelined_info,
+ const CollectiveInComputation& collective_in_computation, int64_t& idx,
+ int64_t& idx_end, HloInstructionSequence& instruction_sequence,
+ HloInstruction* while_op) {
+ HloComputation* computation = while_op->parent();
+ int64_t opnd_start = pipelined_info.opnd_start;
+ int64_t opnd_end = pipelined_info.opnd_end;
+ VLOG(10) << "Transform pipelined while-op " << while_op->ToString();
+ HloInstruction* while_init = while_op->while_init();
+ InstructionVector new_while_init_opnds;
+ std::vector<Shape> new_parameter_shapes;
+ for (int64_t i = 0; i < while_init->operand_count(); ++i) {
+ HloInstruction* op = while_init->mutable_operand(i);
+ if (i >= opnd_start && i < opnd_end) {
+ // Get Send/Recv from SendDone/RecvDone.
+ new_while_init_opnds.push_back(op->mutable_operand(0));
+ } else {
+ new_while_init_opnds.push_back(op);
+ }
+ new_parameter_shapes.push_back(new_while_init_opnds.back()->shape());
+ }
+
+ RewritePipelinedP2PWhileCond(new_parameter_shapes, while_op);
+ TF_RETURN_IF_ERROR(RewritePipelinedP2PWhileBody(
+ collective_in_computation, new_parameter_shapes, while_op, opnd_start,
+ opnd_end));
+ HloInstruction* new_while_init = computation->AddInstruction(
+ HloInstruction::CreateTuple(new_while_init_opnds), "while-init");
+ VLOG(10) << "new_while_init: " << new_while_init->ToString();
+ HloInstruction* new_while_op = computation->AddInstruction(
+ HloInstruction::CreateWhile(
+ while_op->while_body()->root_instruction()->shape(),
+ while_op->while_condition(), while_op->while_body(), new_while_init),
+ "while-result");
+ CopyInstructionInfo(while_op, new_while_op);
+ VLOG(10) << "new_while_op: " << new_while_op->ToString();
+
+ InstructionVector recv_dones;
+ InstructionVector new_recv_dones;
+ InstructionVector new_send_dones;
+ InstructionVector done_ops;
+ for (int64_t i = opnd_start; i < opnd_end; ++i) {
+ HloInstruction* op = while_init->mutable_operand(i);
+ done_ops.push_back(op);
+ if (op->opcode() == HloOpcode::kRecvDone) {
+ HloInstruction* gte = FindUniqueGTEUserWithIndex(while_op, i);
+ CHECK(gte != nullptr);
+ recv_dones.push_back(gte);
+
+ // Create the new RecvDone using the new while-op result.
+ HloInstruction* recv = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(new_while_op, i));
+ HloInstruction* recv_done = computation->AddInstruction(
+ HloInstruction::CreateRecvDone(recv, op->channel_id().value()));
+ new_recv_dones.push_back(recv_done);
+ CopyInstructionInfo(op, recv_done);
+ continue;
+ }
+ CHECK(op->opcode() == HloOpcode::kSendDone);
+ // Create the new SendDone using the new while-op result.
+ HloInstruction* send = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(new_while_op, i));
+ HloInstruction* send_done = computation->AddInstruction(
+ HloInstruction::CreateSendDone(send, op->channel_id().value()));
+ new_send_dones.push_back(send_done);
+ CopyInstructionInfo(op, send_done);
+ }
+
+ TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(
+ while_op, new_while_op, instruction_sequence, /*diff_shape=*/true));
+ TF_RETURN_IF_ERROR(
+ ReplaceOpInSequence(while_init, new_while_init, instruction_sequence));
+ TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
+ instruction_sequence));
+ TF_RETURN_IF_ERROR(
+ RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
+
+ int64_t opnd_tot = opnd_end - opnd_start;
+ // Verify that the number of ops we have removed from the sequence is
+ // opnd_tot and that they were all before the position of the new while-op.
+ CHECK(idx_end == instruction_sequence.size() + opnd_tot);
+ CHECK(instruction_sequence.instructions()[idx - opnd_tot] == new_while_op);
+
+ // Update idx_end to reflect the current size of the instruction sequence.
+ // Update idx to right after the new while-op.
+ idx_end -= opnd_tot;
+ idx = idx - opnd_tot + 1;
+ bool inserted =
+ InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
+ instruction_sequence, idx, idx_end);
+ CHECK(idx_end == instruction_sequence.size());
+ // If there is no op that may invoke a collective after the while-op, add
+ // the new SendDone ops before the last instruction in the sequence.
+ if (!inserted) {
+ CHECK(idx_end == idx);
+ idx--;
+ for (auto send_done : new_send_dones) {
+ instruction_sequence.insert_instruction(send_done, idx++);
+ }
+ }
+ return absl::OkStatus();
+}
+
+// Finds while-loops with pipelined Send/Recv in the computation and rotates
+// the SendDone/RecvDone ops for each such while-loop.
+absl::StatusOr<bool> ProcessComputation(
+ HloModule* module, HloComputation* computation,
+ CollectiveInComputation& collective_in_computation) {
+ VLOG(10) << "Process compuation " << computation->name();
+ bool changed = false;
+ HloInstructionSequence& instruction_sequence =
+ module->schedule().GetOrCreateSequence(computation);
+ int64_t idx = 0;
+ int64_t idx_end = instruction_sequence.size();
+ while (idx < idx_end) {
+ HloInstruction* hlo = instruction_sequence.instructions()[idx];
+
+ if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
+ collective_in_computation[computation] = true;
+ }
+
+ if (hlo->opcode() != HloOpcode::kWhile) {
+ idx++;
+ continue;
+ }
+
+ std::optional<PipelinedP2PInfo> pipelined_info = FindPipelinedP2P(hlo);
+ if (!pipelined_info.has_value()) {
+ idx++;
+ continue;
+ }
+ TF_RETURN_IF_ERROR(TransformLoop(pipelined_info.value(),
+ collective_in_computation, idx, idx_end,
+ instruction_sequence, hlo));
+ changed = true;
+ }
+ return changed;
+}
+} // namespace
+
+absl::StatusOr<bool> PipelinedP2PRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ if (!module->has_schedule()) return changed;
+ CollectiveInComputation collective_in_computation;
+ // Visit the computations from callees to callers, so that a while-body is
+ // processed before its corresponding while-op.
+ for (auto* computation :
+ module->MakeComputationPostOrder(execution_threads)) {
+ if (computation->IsFusionComputation()) {
+ collective_in_computation[computation] = false;
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ bool cur_changed,
+ ProcessComputation(module, computation, collective_in_computation));
+ changed |= cur_changed;
+ }
+
+ if (changed) {
+ TF_RETURN_IF_ERROR(module->schedule().Update());
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h
new file mode 100644
index 0000000..d2aca8c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h
@@ -0,0 +1,133 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_
+
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// PipelinedP2PRewriter is a pass that rewrites pipelined Send/Recv code for
+// point-to-point communication so that the SendDone and RecvDone at the end
+// of a loop iteration are rotated to the beginning of the next iteration. The
+// pass operates on a scheduled module and updates the instruction sequence.
+//
+// In particular, a pipelined Send/Recv chain with one channel group, matching
+// this code pattern:
+//
+// main:
+// recv
+// send
+// recv-done
+// send-done
+// while-init = (recv-done, send-done, ...)
+// while-op = while(while-init) ...
+//
+// while-body:
+// ...
+// recv
+// send
+// recv-done
+// send-done
+// ROOT tuple(recv-done, send-done, ...)
+//
+// Will be transformed to:
+//
+// main:
+// recv
+// send
+// while-init = (recv, send, ...)
+// while-op = while(while-init) ...
+// recv-done
+// send-done
+//
+// while-body:
+// recv-done
+// ...
+// send-done
+// recv
+// send
+// ROOT tuple(recv, send, ...)
+//
+// A pipelined Send/Recv chain with two channel groups, matching this code
+// pattern:
+//
+// main:
+// recv.0
+// send.0
+// recv.1
+// send.1
+// recv-done.0
+// send-done.0
+// recv-done.1
+// send-done.1
+// while-init = (recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
+// while-op = while(while-init) ...
+//
+// while-body:
+// ...
+// recv.0
+// send.0
+// recv.1
+// send.1
+// recv-done.0
+// send-done.0
+// recv-done.1
+// send-done.1
+// ROOT = tuple(recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
+//
+// Will be transformed to:
+//
+// main:
+//
+// recv.0
+// send.0
+// recv.1
+// send.1
+// while-init = (recv.0, send.0, recv.1, send.1, ...)
+// while-op = while(while-init) ...
+// recv-done.0
+// send-done.0
+// recv-done.1
+// send-done.1
+//
+// while-body:
+// recv-done.0
+// recv-done.1
+// ...
+// send-done.0
+// send-done.1
+// recv.0
+// send.0
+// recv.1
+// send.1
+// ROOT tuple(recv.0, send.0, recv.1, send.1, ...)
+//
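+// A minimal usage sketch (mirroring the unit tests in
+// pipelined_p2p_rewriter_test.cc): the pass requires a scheduled module and is
+// run like any other HLO pass, e.g.
+//
+//   PipelinedP2PRewriter rewriter;
+//   TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module));
+//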
+class PipelinedP2PRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "pipelined-p2p-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc
new file mode 100644
index 0000000..287603c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc
@@ -0,0 +1,674 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h"
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class PipelinedP2pRewriterTest : public HloTestBase {
+ protected:
+ void DoFileCheck(const HloModule* module, absl::string_view expected) {
+ HloPrintOptions options;
+ options.set_print_operand_shape(false);
+ options.set_print_result_shape(false);
+ TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+ RunFileCheck(module->ToString(options), expected));
+ EXPECT_TRUE(filecheck_matched);
+ }
+};
+
+TEST_F(PipelinedP2pRewriterTest, SendRecvUnpipelinedNotTransform) {
+ const char* kModuleStr = R"(
+HloModule test
+
+cond {
+ param = (u32[], u32[2]) parameter(0)
+ count = get-tuple-element(%param), index=0
+ ub = u32[] constant(11)
+ ROOT result = pred[] compare(count, ub), direction=LT
+ }
+
+body {
+ param = (u32[], u32[2]) parameter(0)
+ count = get-tuple-element(param), index=0
+ send-data = u32[2] get-tuple-element(param), index=1
+
+ after-all.0.n = token[] after-all()
+ recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{3,0}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.n),
+ channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{3,0}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.0 = token[] send-done(send.0), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+
+ recv-data = u32[2] get-tuple-element(recv-done.0), index=0
+
+ c1 = u32[] constant(1)
+ new_count = u32[] add(count, c1)
+
+ r = u32[2] broadcast(c1), dimensions={}
+ s = u32[2] add(r, recv-data)
+
+ ROOT result = (u32[], u32[2]) tuple(new_count, s)
+ }
+
+ ENTRY test_computation {
+ c0 = u32[] constant(0)
+ c1 = u32[] constant(1)
+ r = u32[] replica-id()
+ a = u32[] add(c1, r)
+ init = u32[2] broadcast(a), dimensions={}
+ while_init = (u32[], u32[2]) tuple(c0, init)
+ while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond,
+ backend_config={"known_trip_count":{"n":"11"}}
+ ROOT recv-data = u32[2] get-tuple-element(while_result), index=1
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(kModuleStr));
+ PipelinedP2PRewriter rewriter;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+// Tests the rewrite for a pipelined Send/Recv chain with only one channel
+// group.
+TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) {
+ const char* kModuleStr = R"(
+ HloModule test, is_scheduled=true
+
+ while-cond {
+ param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+ ub = u32[] constant(25)
+ ROOT cond-result = pred[] compare(count, ub), direction=LT
+ }
+
+ while-body {
+ param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+
+ recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+ recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
+
+ c1 = u32[] constant(1)
+ new-count = u32[] add(count, c1)
+ replica = u32[] replica-id()
+ c10 = u32[] constant(10)
+ sum = u32[] add(replica, c10)
+ sum2 = u32[] add(sum, count)
+ conv = f32[] convert(sum2)
+ p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
+ b = f32[1, 1024, 1024] add(p, recv-data)
+ c = f32[1, 1024, 1024] multiply(b, b)
+ d = f32[1, 1024, 1024] tan(c)
+ s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+ send-data = f32[1, 1024, 1024] add(c, s)
+
+ after-all = token[] after-all()
+ recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+ channel_id=1, frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.p = token[] send-done(send), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+ gte.1 = token[] get-tuple-element(recv-done.p), index=1
+ recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+ ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
+ tuple(new-count, recv-done-tuple, send-done.p)
+ }
+
+ ENTRY main {
+ c0 = u32[] constant(0)
+ f0 = f32[] constant(0.0)
+ init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+ after-all.1 = token[] after-all()
+ recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.1.p = token[] send-done(send.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ while-init.p = (u32[], (f32[1,1024,1024], token[]), token[])
+ tuple(c0, recv-done.1.p, send-done.1.p)
+ while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
+ while(while-init.p),
+ body=while-body, condition=while-cond,
+ backend_config={"known_trip_count":{"n":"25"}}
+
+ recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
+
+ ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
+ }
+ )";
+
+ const char* kExpected = R"(
+ CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
+ CHECK: %param.1 = parameter(0)
+ CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
+ CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
+ CHECK: %count.1 = get-tuple-element(%param.1), index=0
+ CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK: %recv-data = get-tuple-element(%recv-done.p.clone), index=0
+ CHECK: %c1 = constant(1)
+ CHECK: %new-count = add(%count.1, %c1)
+ CHECK: %replica = replica-id()
+ CHECK: %c10 = constant(10)
+ CHECK: %sum = add(%replica, %c10)
+ CHECK: %sum2 = add(%sum, %count.1)
+ CHECK: %conv = convert(%sum2)
+ CHECK: %p = broadcast(%conv), dimensions={}
+ CHECK: %b = add(%p, %recv-data)
+ CHECK: %c = multiply(%b, %b)
+ CHECK: %d = tan(%c)
+ CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+ CHECK: %send-data = add(%c, %s)
+ CHECK: %after-all = after-all()
+ CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+ CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+ CHECK: ROOT %tuple = tuple(%new-count, %recv, %send)
+ CHECK: }
+
+ CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
+ CHECK: %param = parameter(0)
+ CHECK: %count = get-tuple-element(%param), index=0
+ CHECK: %ub = constant(25)
+ CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
+ CHECK: }
+
+ CHECK: ENTRY %main () -> f32[1,1024,1024] {
+ CHECK: %c0 = constant(0)
+ CHECK: %f0 = constant(0)
+ CHECK: %init = broadcast(%f0), dimensions={}
+ CHECK: %after-all.1 = after-all()
+ CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+ CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+ CHECK: %while-init = tuple(%c0, %recv.1, %send.1)
+ CHECK: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body,
+ CHECK-SAME{LITERAL}: backend_config={"known_trip_count":{"n":"25"}}
+ CHECK: %get-tuple-element.2 = get-tuple-element(%while-result.p.clone), index=1
+ CHECK: %get-tuple-element.3 = get-tuple-element(%while-result.p.clone), index=2
+ CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK: ROOT %entry-result = get-tuple-element(%recv-done.1.p.clone), index=0
+ CHECK: })";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(kModuleStr));
+ PipelinedP2PRewriter rewriter;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ DoFileCheck(module.get(), kExpected);
+}
+
+// Repeats the Send/Recv pattern of the previous test to verify that we can
+// rewrite a module with multiple pipelined while-loops without crashing.
+TEST_F(PipelinedP2pRewriterTest, SendRecvTwoPipelinedWhileLoops) {
+ const char* kModuleStr = R"(
+ HloModule test, is_scheduled=true
+
+ while-cond {
+ param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+ ub = u32[] constant(25)
+ ROOT cond-result = pred[] compare(count, ub), direction=LT
+ }
+
+ while-body {
+ param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+
+ recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+ send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
+
+ c1 = u32[] constant(1)
+ new-count = u32[] add(count, c1)
+
+ after-all = token[] after-all()
+ recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+ channel_id=1, frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.p = token[] send-done(send), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+ gte.1 = token[] get-tuple-element(recv-done.p), index=1
+ recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+ ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
+ tuple(new-count, recv-done-tuple, send-done.p)
+ }
+
+ while-cond-2 {
+ param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+ ub = u32[] constant(25)
+ ROOT cond-result = pred[] compare(count, ub), direction=LT
+ }
+
+ while-body-2 {
+ param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+
+ recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+ send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
+
+ c1 = u32[] constant(1)
+ new-count = u32[] add(count, c1)
+
+ after-all = token[] after-all()
+ recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+ channel_id=1, frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.p = token[] send-done(send), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+ gte.1 = token[] get-tuple-element(recv-done.p), index=1
+ recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+ ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
+ tuple(new-count, recv-done-tuple, send-done.p)
+ }
+
+ ENTRY main {
+ c0 = u32[] constant(0)
+ f0 = f32[] constant(0.0)
+ init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+ after-all.1 = token[] after-all()
+ recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.1.p = token[] send-done(send.1), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ while-init.p = (u32[], (f32[1,1024,1024], token[]), token[])
+ tuple(c0, recv-done.1.p, send-done.1.p)
+ while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
+ while(while-init.p),
+ body=while-body, condition=while-cond,
+ backend_config={"known_trip_count":{"n":"25"}}
+
+ recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
+
+ after-all-2.1 = token[] after-all()
+ recv-2.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-2.1), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send-2.1 = (f32[1, 1024, 1024], u32[], token[]) send(recv-done.1.q, after-all-2.1), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done-2.1.p = (f32[1,1024,1024], token[]) recv-done(recv-2.1), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done-2.1.p = token[] send-done(send-2.1), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ while-init-2.p = (u32[], (f32[1,1024,1024], token[]), token[])
+ tuple(c0, recv-done-2.1.p, send-done-2.1.p)
+ while-result-2.p = (u32[], (f32[1,1024,1024], token[]), token[])
+ while(while-init-2.p),
+ body=while-body-2, condition=while-cond-2,
+ backend_config={"known_trip_count":{"n":"25"}}
+
+ recv-done-2.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result-2.p), index=1
+
+ ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done-2.1.q), index=0
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(kModuleStr));
+ PipelinedP2PRewriter rewriter;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+ // Check that we transform the module without crashing.
+ EXPECT_TRUE(changed);
+}
+
+// Tests the rewrite for a pipelined Send/Recv chain with two channel groups.
+TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) {
+ const char* kModuleStr = R"(
+ HloModule test, is_scheduled=true
+
+ while-cond {
+ param = (u32[], (f32[1,1024,1024], token[]), token[],
+ (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+ ub = u32[] constant(25)
+ ROOT cond-result = pred[] compare(count, ub), direction=LT
+ }
+
+ while-body {
+ param = (u32[], (f32[1,1024,1024], token[]), token[],
+ (f32[1,1024,1024], token[]), token[]) parameter(0)
+ count = get-tuple-element(param), index=0
+
+ recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+ recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0
+ recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3
+ recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
+
+ replica = u32[] replica-id()
+ constant0 = u32[] constant(0)
+ compare0 = pred[] compare(replica, constant0), direction=EQ
+ compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
+ recv-data = f32[1, 1024, 1024] select(compare, recv-data.0, recv-data.1)
+
+ c1 = u32[] constant(1)
+ new-count = u32[] add(count, c1)
+ c10 = u32[] constant(10)
+ sum = u32[] add(replica, c10)
+ sum2 = u32[] add(sum, count)
+ conv = f32[] convert(sum2)
+ p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
+ b = f32[1, 1024, 1024] add(p, recv-data)
+ c = f32[1, 1024, 1024] multiply(b, b)
+ d = f32[1, 1024, 1024] tan(c)
+ s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+ send-data = f32[1, 1024, 1024] add(c, s)
+
+ after-all = token[] after-all()
+ recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{3,0}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+ channel_id=1, frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{3,0}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.p = token[] send-done(send), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+
+ after-all.1 = token[] after-all()
+ recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+ _xla_send_recv_pipeline="1"
+ }
+ send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1),
+ channel_id=2, frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+ _xla_send_recv_pipeline="1"
+ }
+ recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_pipeline="1"
+ }
+ send-done.1.p = token[] send-done(send.1), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_pipeline="1"
+ }
+
+ ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[],
+ (f32[1,1024,1024], token[]), token[])
+ tuple(new-count, recv-done.p, send-done.p, recv-done.1.p, send-done.1.p)
+ }
+
+ ENTRY main {
+ c0 = u32[] constant(0)
+ f0 = f32[] constant(0.0)
+ init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+ after-all.2 = token[] after-all()
+ recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{3,0}}",
+ _xla_send_recv_pipeline="0"
+ }
+ send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{3,0}}",
+ _xla_send_recv_pipeline="0"
+ }
+ recv-done.2.p = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+ send-done.2.p = token[] send-done(send.2), channel_id=1,
+ frontend_attributes={
+ _xla_send_recv_pipeline="0"
+ }
+
+ after-all.3 = token[] after-all()
+ recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+ _xla_send_recv_pipeline="1"
+ }
+ send.3 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.3), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+ _xla_send_recv_pipeline="1"
+ }
+ recv-done.3.p = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_pipeline="1"
+ }
+ send-done.3.p = token[] send-done(send.3), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_pipeline="1"
+ }
+
+ while-init.p = (u32[], (f32[1,1024,1024], token[]), token[],
+ (f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2.p, send-done.2.p, recv-done.3.p, send-done.3.p)
+ while-result.p = (u32[], (f32[1,1024,1024], token[]), token[],
+ (f32[1,1024,1024], token[]), token[]) while(while-init.p),
+ body=while-body, condition=while-cond,
+ backend_config={"known_trip_count":{"n":"25"}}
+
+ recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
+ recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0
+ recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=3
+ recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0
+
+ replica = u32[] replica-id()
+ constant0 = u32[] constant(0)
+ compare0 = pred[] compare(replica, constant0), direction=EQ
+ compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
+ ROOT entry-result = f32[1, 1024, 1024] select(compare, recv-data.2, recv-data.3)
+ }
+ )";
+
+ const char* kExpected = R"(
+ CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
+ CHECK: %param.1 = parameter(0)
+ CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
+ CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
+ CHECK: %get-tuple-element.2 = get-tuple-element(%param.1), index=3
+ CHECK: %get-tuple-element.3 = get-tuple-element(%param.1), index=4
+ CHECK: %count.1 = get-tuple-element(%param.1), index=0
+ CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK: %recv-data.0 = get-tuple-element(%recv-done.p.clone), index=0
+ CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+ CHECK: %recv-data.1 = get-tuple-element(%recv-done.1.p.clone), index=0
+ CHECK: %replica = replica-id()
+ CHECK: %constant0 = constant(0)
+ CHECK: %compare0 = compare(%replica, %constant0), direction=EQ
+ CHECK: %compare = broadcast(%compare0), dimensions={}
+ CHECK: %recv-data.2 = select(%compare, %recv-data.0, %recv-data.1)
+ CHECK: %c1 = constant(1)
+ CHECK: %new-count = add(%count.1, %c1)
+ CHECK: %c10 = constant(10)
+ CHECK: %sum = add(%replica, %c10)
+ CHECK: %sum2 = add(%sum, %count.1)
+ CHECK: %conv = convert(%sum2)
+ CHECK: %p = broadcast(%conv), dimensions={}
+ CHECK: %b = add(%p, %recv-data.2)
+ CHECK: %c = multiply(%b, %b)
+ CHECK: %d = tan(%c)
+ CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+ CHECK: %send-data = add(%c, %s)
+ CHECK: %after-all = after-all()
+ CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+ CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+ CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+ CHECK: %after-all.1 = after-all()
+ CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+ CHECK{LITERAL}: %send.1 = send(%send-data, %after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+ CHECK: ROOT %tuple = tuple(%new-count, %recv, %send, %recv.1, %send.1)
+ CHECK: }
+
+ CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
+ CHECK: %param = parameter(0)
+ CHECK: %count = get-tuple-element(%param), index=0
+ CHECK: %ub = constant(25)
+ CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
+ CHECK: }
+
+ CHECK: ENTRY %main () -> f32[1,1024,1024] {
+ CHECK: %c0 = constant(0)
+ CHECK: %f0 = constant(0)
+ CHECK: %init = broadcast(%f0), dimensions={}
+ CHECK: %after-all.2 = after-all()
+ CHECK{LITERAL}: %recv.2 = recv(%after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+ CHECK{LITERAL}: %send.2 = send(%init, %after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+ CHECK: %after-all.3 = after-all()
+ CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+ CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+ CHECK: %while-init = tuple(%c0, %recv.2, %send.2, %recv.3, %send.3)
+ CHECK{LITERAL}: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}}
+ CHECK: %get-tuple-element.4 = get-tuple-element(%while-result.p.clone), index=1
+ CHECK: %get-tuple-element.5 = get-tuple-element(%while-result.p.clone), index=2
+ CHECK: %get-tuple-element.6 = get-tuple-element(%while-result.p.clone), index=3
+ CHECK: %get-tuple-element.7 = get-tuple-element(%while-result.p.clone), index=4
+ CHECK: %recv-done.2.p.clone = recv-done(%get-tuple-element.4), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK: %recv-data.3 = get-tuple-element(%recv-done.2.p.clone), index=0
+ CHECK: %recv-done.3.p.clone = recv-done(%get-tuple-element.6), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+ CHECK: %recv-data.4 = get-tuple-element(%recv-done.3.p.clone), index=0
+ CHECK: %replica.1 = replica-id()
+ CHECK: %constant0.1 = constant(0)
+ CHECK: %compare0.1 = compare(%replica.1, %constant0.1), direction=EQ
+ CHECK: %compare.1 = broadcast(%compare0.1), dimensions={}
+ CHECK: %send-done.2.p.clone = send-done(%get-tuple-element.5), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+ CHECK: %send-done.3.p.clone = send-done(%get-tuple-element.7), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+ CHECK: ROOT %entry-result = select(%compare.1, %recv-data.3, %recv-data.4)
+ CHECK: })";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(kModuleStr));
+ PipelinedP2PRewriter rewriter;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ DoFileCheck(module.get(), kExpected);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc
new file mode 100644
index 0000000..9ceca0c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc
@@ -0,0 +1,886 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/priority_fusion.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/meta/type_traits.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/dump.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusion_process_dump.pb.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/model/gpu_performance_model_base.h"
+#include "xla/service/gpu/model/symbolic_tile_analysis.h"
+#include "xla/service/hlo_graph_dumper.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/blocking_counter.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+bool ElementIsF32OrF16(const Shape& shape) {
+ PrimitiveType type = shape.element_type();
+ return type == F32 || type == F16;
+}
+
+bool IsFusible(const HloInstruction& instr) {
+ // Side-effecting operations are not fusible.
+ if (!instr.IsFusible()) {
+ return false;
+ }
+
+ // Element-wise operations are always fusible.
+ if (instr.IsElementwise()) {
+ return true;
+ }
+
+ // Other non-elementwise ops also supported by elemental fusion.
+ switch (instr.opcode()) {
+ case HloOpcode::kFusion:
+ return instr.fusion_kind() != HloInstruction::FusionKind::kCustom;
+
+ case HloOpcode::kCopy:
+ case HloOpcode::kIota:
+ case HloOpcode::kConstant:
+ case HloOpcode::kReduce:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kGather:
+ case HloOpcode::kPad:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kScatter:
+ case HloOpcode::kSlice:
+ case HloOpcode::kTranspose:
+ return true;
+ default:
+ return false;
+ }
+}
+
+// An implementation of FusionQueue that determines whether to fuse instructions
+// according to a cost model, and chooses the next fusion candidate according to
+// dynamically updated priorities. The elements in the queue are producer nodes
+// that could be fused, and the priority of a producer is the benefit in
+// performance when fusing it to all of its fusible users. We greedily pick the
+// max-benefit producer to fuse, and update the estimated benefits of the fused
+// nodes and their operands.
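+// Concretely, the priority of a producer is the estimated runtime saved by
+// fusing it into all of its users (time_unfused - time_fused, in nanoseconds;
+// see CalculateProducerPriority below), with bitcasts pinned to the maximum
+// priority and constants to the minimum.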
+class PriorityFusionQueue {
+ using Priority = int64_t;
+ using CanFuseCallback = std::function<FusionDecision(
+ HloInstruction* /*producer*/, int64_t /*consumer operand_index*/)>;
+
+ public:
+ PriorityFusionQueue(HloComputation* computation,
+ const GpuHloCostAnalysis::Options& cost_analysis_options,
+ const se::DeviceDescription* device_info,
+ FusionProcessDumpProto* fusion_process_dump,
+ tsl::thread::ThreadPool* thread_pool,
+ mlir::MLIRContext* mlir_context,
+ HloFusionAnalysisCache& fusion_analysis_cache,
+ bool triton_softmax_priority_fusion_enabled)
+ : computation_(computation),
+ device_info_(device_info),
+ cost_analysis_(cost_analysis_options, *device_info),
+ fusion_process_dump_(fusion_process_dump),
+ thread_pool_(thread_pool),
+ mlir_context_(mlir_context),
+ fusion_analysis_cache_(fusion_analysis_cache),
+ triton_softmax_priority_fusion_enabled_(
+ triton_softmax_priority_fusion_enabled) {
+ VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
+ TF_CHECK_OK(computation_->Accept(&cost_analysis_));
+
+ dump_fusion_visualization_ = computation->parent()
+ ->config()
+ .debug_options()
+ .xla_dump_fusion_visualization();
+
+ // Initializes the priority queue.
+ std::vector<HloInstruction*> instructions;
+ for (auto* instruction : computation->MakeInstructionPostOrder()) {
+ UpdatePerformanceModelCache(instruction);
+ if (instruction->opcode() == HloOpcode::kParameter ||
+ instruction->user_count() == 0 || !instruction->IsFusible() ||
+ instruction->opcode() == HloOpcode::kTuple ||
+ instruction->opcode() == HloOpcode::kGetTupleElement) {
+ continue;
+ }
+ instructions.push_back(instruction);
+ }
+ ComputeAndSetPriorities(instructions);
+ }
+
+ void ComputeAndSetPriorities(
+ const std::vector<HloInstruction*>& instructions) {
+ std::vector<Priority> priorities = ComputePriorities(instructions);
+
+ for (auto [instruction, priority] : llvm::zip(instructions, priorities)) {
+ auto key = std::make_pair(priority, instruction->unique_id());
+
+ // Remove instruction with the old priority from the queue.
+ auto reverse_it = reverse_map_.find(instruction);
+ if (reverse_it != reverse_map_.end()) {
+ const PriorityQueue::iterator& queue_it = reverse_it->second;
+ // Priority didn't change. Nothing to do.
+ if (key == queue_it->first) {
+ continue;
+ }
+ producer_priority_queue_.erase(queue_it);
+ reverse_map_.erase(reverse_it);
+ }
+
+ // If the priority is negative, it's not helpful to perform fusion on this
+ // instruction.
+ if (priority < 0) {
+ continue;
+ }
+
+ auto emplace_result = producer_priority_queue_.emplace(key, instruction);
+ reverse_map_.emplace(instruction, emplace_result.first);
+ }
+ }
+
+ std::vector<Priority> ComputePriorities(
+ const std::vector<HloInstruction*>& instructions) {
+ auto schedule_or_run = [this](std::function<void()> fn) {
+ if (thread_pool_) {
+ thread_pool_->Schedule(std::move(fn));
+ } else {
+ fn();
+ }
+ };
+ tsl::BlockingCounter counter(instructions.size());
+ std::vector<Priority> priorities(instructions.size());
+
+ for (size_t i = 0; i < instructions.size(); ++i) {
+ schedule_or_run([&, i] {
+ priorities[i] = CalculateProducerPriority(instructions[i]);
+ counter.DecrementCount();
+ });
+ }
+ counter.Wait();
+ return priorities;
+ }
+
+ // Gets the next pair of (producer, consumers) from the queue for fusion.
+ // Returns true if there is a next producer to fuse, otherwise false. Stores
+ // the producer and consumers in `current_producer_` and `current_consumers_`.
+ bool DequeueNextProducer() {
+ current_producer_ = nullptr;
+ current_consumers_.clear();
+
+ while (!producer_priority_queue_.empty() && current_consumers_.empty()) {
+ auto next_it = std::prev(producer_priority_queue_.end());
+
+ current_producer_ = next_it->second;
+ producer_priority_queue_.erase(next_it);
+ reverse_map_.erase(current_producer_);
+
+ current_consumers_ = current_producer_->users();
+
+ if (current_producer_->opcode() == HloOpcode::kBitcast) {
+ // We don't check if bitcasts can be fused with all consumers, so we
+ // have to do it here.
+ llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) {
+ return !CanFuseCached(current_producer_, consumer);
+ });
+ }
+ }
+
+ return !current_consumers_.empty();
+ }
+
+ void UpdatePerformanceModelCache(HloInstruction* producer) {
+ if (!IsFusible(*producer) && !IsGenericTritonFusion(*producer)) {
+ return;
+ }
+
+ auto config = GpuPerformanceModelOptions::PriorityFusion(
+ &fusion_analysis_cache_, &gpu_performance_model_cache_);
+
+ if (!gpu_performance_model_cache_.Get(*producer)) {
+ auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction(
+ producer, *device_info_, &cost_analysis_, config);
+ gpu_performance_model_cache_.Set(*producer, runtime_data);
+ }
+ }
+
+ // Update priorities of all affected ops.
+ void UpdatePriorities() {
+ // Revisit costs of all updated ops. It's important to update cost analysis
+ // before recalculating priorities.
+ for (auto instruction : to_update_priority_) {
+ TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction));
+ }
+ for (auto producer : to_update_priority_) {
+ UpdatePerformanceModelCache(producer);
+ }
+
+ ComputeAndSetPriorities(std::vector<HloInstruction*>{
+ to_update_priority_.begin(), to_update_priority_.end()});
+
+ to_update_priority_.clear();
+ }
+
+ // Prepares the producer and consumer instructions to be fused. Invalidates
+ // caches and writes logs.
+ void PreFusion(HloInstruction* producer, HloInstruction* consumer) {
+ if (dump_fusion_visualization_) {
+ RegisterFusionState(
+ *computation_,
+ absl::StrCat("About to fuse |", producer->name(), "| into |",
+ consumer->name(), "| inside PriorityFusion"),
+ *consumer, producer);
+ }
+ InvalidateCaches(consumer);
+ }
+
+ // Invalidates all cached values related to this instruction. Called before
+ // the instruction is fused. The instruction can be either a producer or a
+ // consumer.
+ void InvalidateCaches(HloInstruction* instruction) {
+ can_fuse_cache_.erase(instruction);
+ for (const HloInstruction* operand : instruction->operands()) {
+ auto it = can_fuse_cache_.find(operand);
+ if (it != can_fuse_cache_.end()) {
+ it->second.erase(instruction);
+ }
+ }
+
+ gpu_performance_model_cache_.Invalidate(*instruction);
+ fusion_analysis_cache_.Invalidate(*instruction);
+ fusion_info_cache_.Invalidate(instruction);
+ }
+
+ // Updates data for the new fusion instruction and its users and operands.
+ void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {
+ if (fusion_process_dump_) {
+ auto* fusion_step =
+ fusion_process_dump_->add_fusion_steps()->mutable_fusion();
+
+ // Explicit std::string is needed for OSS proto implementation.
+ fusion_step->set_fusion_name(std::string(fusion->name()));
+ fusion_step->set_producer_name(std::string(original_producer->name()));
+ fusion_step->set_consumer_name(std::string(original_consumer->name()));
+ }
+
+ if (dump_fusion_visualization_) {
+ RegisterFusionState(
+ *computation_,
+ absl::StrCat("Fused |", original_producer->name(), "| into |",
+ fusion->name(), "| inside PriorityFusion"),
+ *fusion);
+ }
+
+ // The original consumer was replaced with the fusion, but its pointer can
+ // still be referenced somewhere, for example, in to_update_priority_.
+ // Priority recomputation is called before DCE, so remove all references to
+ // the original consumer here.
+ if (fusion != original_consumer) {
+ RemoveInstruction(original_consumer);
+ }
+
+ // Detach 'original_producer' from its operands if it has no users.
+ // This avoids having it appear as a "phantom" user in subsequent priority
+ // calculations on 'fusion.operands' below, before it is finally removed
+ // in 'RemoveInstruction'.
+ if (original_producer->user_count() == 0) {
+ InvalidateCaches(original_producer);
+ original_producer->DetachFromOperandsAndUsers();
+ }
+
+ // Collect the instructions whose priorities need to be updated.
+ for (HloInstruction* operand : fusion->operands()) {
+ if (operand == original_producer ||
+ operand->opcode() == HloOpcode::kConstant ||
+ operand->opcode() == HloOpcode::kGetTupleElement) {
+ continue;
+ }
+ // We only need to consider instructions that are fusible, e.g., an rng with
+ // more than one user is not fusible.
+ if (!operand->IsFusible()) {
+ continue;
+ }
+
+ to_update_priority_.insert(operand);
+ }
+ to_update_priority_.insert(fusion);
+ }
+
+ // Removes data for the instruction.
+ void RemoveInstruction(HloInstruction* instruction) {
+ to_update_priority_.erase(instruction);
+ fusion_analysis_cache_.Invalidate(*instruction);
+
+ auto reverse_it = reverse_map_.find(instruction);
+ if (reverse_it == reverse_map_.end()) {
+ return;
+ }
+ producer_priority_queue_.erase(reverse_it->second);
+ reverse_map_.erase(reverse_it);
+ }
+
+ HloInstruction* current_producer() { return current_producer_; }
+
+ const std::vector<HloInstruction*>& current_consumers() {
+ return current_consumers_;
+ }
+
+ private:
+ // Returns the priority of the producer based on its current operands and
+ // users.
+ Priority CalculateProducerPriority(HloInstruction* producer) {
+ // Bitcasts should always be fused first, since they are no-ops.
+ if (producer->opcode() == HloOpcode::kBitcast) {
+ return std::numeric_limits<Priority>::max();
+ }
+ // We always fuse constants, but the cost model doesn't handle them very
+ // well: fusing constants changes costs significantly. Also, there's no
+ // point recomputing priorities. Therefore, we fuse all of them at the end.
+ if (producer->opcode() == HloOpcode::kConstant) {
+ return std::numeric_limits<Priority>::min();
+ }
+
+ // Don't fuse if we can't fuse in all users.
+ if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer);
+ !fusion_decision) {
+ if (fusion_process_dump_) {
+ absl::MutexLock lock(&fusion_process_dump_mutex_);
+ auto* step = fusion_process_dump_->add_fusion_steps()
+ ->mutable_producer_ineligible();
+ step->set_producer_name(std::string(producer->name()));
+ step->set_reason(fusion_decision.Explain());
+ }
+ return std::numeric_limits<Priority>::min();
+ }
+
+ GpuPerformanceModel::RunTimes run_times =
+ GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
+ producer, *device_info_, &cost_analysis_,
+ GpuPerformanceModelOptions::PriorityFusion(
+ &fusion_analysis_cache_, &gpu_performance_model_cache_),
+ producer->users());
+
+ if (fusion_process_dump_) {
+ absl::MutexLock lock(&fusion_process_dump_mutex_);
+ auto* step =
+ fusion_process_dump_->add_fusion_steps()->mutable_update_priority();
+ step->set_producer_name(std::string(producer->name()));
+ for (auto* consumer : producer->users()) {
+ step->add_consumer_names(std::string(consumer->name()));
+ }
+ step->set_us_fused(absl::ToDoubleMicroseconds(run_times.time_fused));
+ step->set_us_unfused(absl::ToDoubleMicroseconds(run_times.time_unfused));
+ }
+ return absl::ToInt64Nanoseconds(run_times.time_unfused -
+ run_times.time_fused);
+ }
+
+ FusionDecision CanFuseTriton(HloInstruction* producer,
+ HloInstruction* consumer) {
+ if (!triton_softmax_priority_fusion_enabled_) {
+ return "triton softmax fusion is not enabled";
+ }
+
+ if (IsGenericTritonFusion(*producer)) {
+ if (!IsFusible(*consumer)) {
+ return "the consumer is not fusible";
+ }
+ } else {
+ if (!IsFusible(*producer)) {
+ return "the producer is not fusible";
+ }
+ }
+
+ auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer);
+
+ SymbolicTileAnalysisOrError symbolic_tile_analysis_or =
+ SymbolicTileAnalysis::AnalyzeFusion(*fusion, mlir_context_);
+
+ if (const auto* fusion_decision =
+ std::get_if<FusionDecision>(&symbolic_tile_analysis_or)) {
+ return {
+ absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ",
+ fusion_decision->Explain())};
+ }
+
+ return {};
+ }
+
+ FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) {
+ if (IsGenericTritonFusion(*producer) || IsGenericTritonFusion(*consumer)) {
+ return CanFuseTriton(producer, consumer);
+ }
+
+ if (!IsFusible(*producer)) {
+ return "the producer is not fusible";
+ }
+
+ if (!IsFusible(*consumer)) {
+ return "the consumer is not fusible";
+ }
+
+ if (consumer->opcode() == HloOpcode::kBitcast) {
+ return "not fusing into a single bitcast as consumer";
+ }
+
+ // Scatter is special: it has no elemental version but is still input
+ // fusible. Block attempts to create scatter fusions that we cannot codegen.
+ if (auto can_fuse = CanEmitInputFusedScatter(*producer, *consumer);
+ !can_fuse) {
+ return can_fuse;
+ }
+
+ // Avoid fusing reduce into reduce. Our cost model doesn't currently
+ // understand this case due to a lack of tiling analysis.
+ // TODO(b/312200883): Remove this.
+ auto contains_significant_reduce = [&](const HloInstruction* instr) {
+ auto fusion = HloFusionAdaptor::ForInstruction(instr);
+ return HloAnyOf(*fusion, [](auto node) {
+ if (!(node.opcode() == HloOpcode::kReduce && node.shape().IsArray())) {
+ return false;
+ }
+
+ int64_t reduction_size =
+ ShapeUtil::ElementsIn(node.instruction().operand(0)->shape()) /
+ ShapeUtil::ElementsIn(node.shape());
+
+ // Small reductions are emitted using the elemental emitter anyway.
+ return reduction_size >= 16;
+ });
+ };
+ if (contains_significant_reduce(producer) &&
+ contains_significant_reduce(consumer)) {
+ return "both the producer and the consumer contain a reduce";
+ }
+
+ // Avoid fusing into the output of an "input" fusion when doing so would
+ // switch it to the loop emitter. This often occurs during epilogue fusion
+ // for reductions, which suffer from limited emitter support.
+ // TODO(b/312686229): Cost model should handle this.
+ const auto& analysis = fusion_analysis_cache_.Get(*producer);
+ if (analysis.GetEmitterFusionKind() ==
+ HloFusionAnalysis::EmitterFusionKind::kReduction) {
+ const auto& analysis_fused =
+ fusion_analysis_cache_.Get(*producer, *consumer);
+ if (analysis_fused.GetEmitterFusionKind() ==
+ HloFusionAnalysis::EmitterFusionKind::kLoop) {
+ return "fusion into output of a reduce fusion would create a loop "
+ "fusion";
+ }
+ }
+
+ // Avoid cases where we'd create a fusion that hits limitations in ptxas.
+ // It would be nicer to model this in the cost model instead.
+ if (auto fits_budget = FusionFitsInBudget(
+ *consumer, *producer, *device_info_,
+ /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
+ !fits_budget) {
+ return fits_budget;
+ }
+
+ // Also check that our emitter can handle the fusion node. We currently can
+ // have exponential time/memory requirements for emitting certain fusion
+ // kernels, in which case we don't want to fuse.
+ // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
+ if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) {
+ return "the fusion would result in an overly large code duplication";
+ }
+
+ // Don't fuse across a root instruction. There are situations where the root
+ // instruction is not the last instruction in the computation. Instructions
+ // after the root are not necessarily dead: they can be inputs to
+ // instructions with side effects, such as outfeed.
+ if (producer == producer->parent()->root_instruction()) {
+ return "not fusing into the output of the root instruction";
+ }
+
+ return InstructionFusion::ShouldFuseInPlaceOp(producer, consumer);
+ }
+
+ FusionDecision CanFuseCached(HloInstruction* producer,
+ HloInstruction* consumer) {
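+ // Fast path: return a previously computed decision if there is one. The
+ // lock only guards the cache lookup; CanFuse itself runs without the lock.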
+ {
+ absl::MutexLock lock(&can_fuse_cache_mutex_);
+ auto& producer_cache = can_fuse_cache_[producer];
+
+ auto it = producer_cache.find(consumer);
+ if (it != producer_cache.end()) {
+ return it->second;
+ }
+ }
+ auto fusion_decision = CanFuse(producer, consumer);
+
+ // The lock is required because writing to a flat_hash_map is not
+ // thread-safe, even for different keys. We never run this computation
+ // concurrently for the same producer, so it's guaranteed that we don't
+ // overwrite any value.
+ {
+ absl::MutexLock lock(&can_fuse_cache_mutex_);
+ can_fuse_cache_[producer][consumer] = fusion_decision;
+ }
+
+ return fusion_decision;
+ }
+
+ FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) {
+ if (producer->users().empty()) {
+ return "No users to fuse";
+ }
+
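+ // Bitcast users are intentionally skipped here; Run() likewise never fuses
+ // a producer into a bitcast consumer.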
+ bool has_non_bitcast_user = false;
+ for (const auto& user : producer->users()) {
+ if (user->opcode() == HloOpcode::kBitcast) {
+ continue;
+ }
+ has_non_bitcast_user = true;
+ if (auto fusion_decision = CanFuseCached(producer, user);
+ !fusion_decision) {
+ VLOG(10) << "Cannot fuse " << producer->name() << " with "
+ << user->name() << ", because: " << fusion_decision.Explain();
+ return fusion_decision;
+ }
+ }
+ if (!has_non_bitcast_user) {
+ return "not fusing because there are only bitcast users";
+ }
+ return {};
+ }
+
+ // The computation that this queue operates on. Stored for cost analysis.
+ HloComputation* computation_;
+
+ const se::DeviceDescription* device_info_;
+
+ // Cost model that defines the priorities in the queue.
+ GpuHloCostAnalysis cost_analysis_;
+
+ // The priority queue of producers, implemented as an ordered map. Each key
+ // is a pair: the first element is the priority and the second element is
+ // the unique ID of the instruction, used to break ties.
+ using PriorityQueue = std::map<std::pair<Priority, int>, HloInstruction*>;
+ PriorityQueue producer_priority_queue_;
+
+ // A reverse map that helps find an instruction in the priority queue.
+ absl::flat_hash_map<HloInstruction*, PriorityQueue::iterator> reverse_map_;
+
+ // The current producer being visited.
+ HloInstruction* current_producer_;
+
+ // The current consumers being visited.
+ std::vector<HloInstruction*> current_consumers_;
+
+ // The set of producers whose priorities need to be updated. Their
+ // priorities are changed because their neighbors got fused, but we delay
+ // the priority updates until current_consumers_ becomes empty. This is to
+ // avoid recomputing priorities multiple times before we dequeue a new
+ // producer.
+ absl::flat_hash_set<HloInstruction*> to_update_priority_;
+
+ // Proto with structured logs of fusion decisions. Used only for debugging. If
+ // null, logging is disabled.
+ FusionProcessDumpProto* fusion_process_dump_;
+ absl::Mutex fusion_process_dump_mutex_;
+
+ tsl::thread::ThreadPool* thread_pool_;
+
+ mlir::MLIRContext* mlir_context_;
+
+ HloFusionAnalysisCache& fusion_analysis_cache_;
+
+ // Caches the result of CanFuse for a (producer, consumer) pair. A cache
+ // entry is invalidated if the producer or consumer is modified.
+ absl::flat_hash_map<
+ const HloInstruction*,
+ absl::flat_hash_map<const HloInstruction*, FusionDecision>>
+ can_fuse_cache_;
+ absl::Mutex can_fuse_cache_mutex_;
+
+ GpuPerformanceModelCache gpu_performance_model_cache_;
+
+ // Cache for `FusionFitsInBudget` to avoid recomputing expensive properties
+ // like shared memory usage or number of unnested reductions of fusion nodes.
+ FusionInfoCache fusion_info_cache_;
+
+ bool triton_softmax_priority_fusion_enabled_;
+
+ bool dump_fusion_visualization_;
+};
+
+} // namespace
+
+/*static*/ bool PriorityFusion::IsExpensive(const HloInstruction& instruction) {
+ // Some floating-point math ops are cheap on the GPU.
+ switch (instruction.opcode()) {
+ case HloOpcode::kDivide:
+ case HloOpcode::kSqrt:
+ case HloOpcode::kRsqrt:
+ case HloOpcode::kExp:
+ if (ElementIsF32OrF16(instruction.shape())) {
+ return false;
+ }
+ break;
+ // Loop fusions are cheap.
+ case HloOpcode::kFusion:
+ return false;
+ default:
+ break;
+ }
+ return InstructionFusion::IsExpensive(instruction);
+}
+
+// Returns true if `instr` is a small constant.
+//
+// There is no single definition of what a small constant is in XLA.
+// IrEmitterContext::emit_constant treats only constants of 1 element as small.
+// HloPrintOptions::print_large_constants is effective for constants larger
+// than 10 elements.
+//
+// This function matches the emitter logic.
+bool IsSmallConstant(const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kConstant && instr->shape().IsArray() &&
+ ShapeUtil::ElementsIn(instr->shape()) <= 1;
+}
+
+bool PriorityFusion::ConsumeFuel(HloInstruction* producer,
+ HloInstruction* consumer) {
+ return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] {
+ return absl::StrFormat("Not fusing producer %s with consumer %s",
+ producer->name(), consumer->name());
+ });
+}
+
+absl::StatusOr<bool> PriorityFusion::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool dump_enabled =
+ DumpingEnabledForHloPass(name(), module->config().debug_options());
+ if (dump_enabled) {
+ fusion_process_dump_ = std::make_unique<FusionProcessDumpProto>();
+ *fusion_process_dump_->mutable_gpu_device_info() =
+ device_info_.ToGpuProto();
+ }
+
+ // Collect the computations within which more fusion is possible.
+ auto fusible_computations =
+ GetFusibleComputations(*module, execution_threads);
+
+ // Appends a ".0" suffix to the name of every instruction.
+ //
+ // Every time an instruction is duplicated, the last integer suffix is
+ // incremented.
+ // Before: broadcast.123 -> broadcast.124
+ // After: broadcast.123.0 -> broadcast.123.1
+ //
+ // This makes it easier to match instructions before and after fusion
+ // passes, because they keep the same unique prefix. The names are not used
+ // in the pipeline, but they make debugging much easier.
+ for (auto* computation : fusible_computations) {
+ for (auto* instruction : computation->instructions()) {
+ module->SetAndUniquifyInstrName(instruction,
+ absl::StrCat(instruction->name(), ".0"));
+ }
+ }
+
+ if (dump_enabled) {
+ fusion_process_dump_->set_hlo_module_before_fusion(
+ module->ToString(HloPrintOptions::ShortParsable()));
+ }
+
+ bool triton_softmax_priority_fusion_enabled =
+ module->config()
+ .debug_options()
+ .xla_gpu_enable_triton_softmax_priority_fusion();
+
+ bool changed = false;
+ for (auto* computation : fusible_computations) {
+ CHECK(!computation->IsFusionComputation());
+
+ auto fusion_queue = std::make_unique<PriorityFusionQueue>(
+ computation, cost_analysis_options_, &device_info_,
+ fusion_process_dump_.get(), thread_pool_, &mlir_context_,
+ fusion_analysis_cache_, triton_softmax_priority_fusion_enabled);
+
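+ // Process producers in priority order (largest estimated run-time saving
+ // first): fuse each dequeued producer into its eligible consumers, then
+ // update the priorities of the affected neighbors.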
+ while (fusion_queue->DequeueNextProducer()) {
+ auto producer = fusion_queue->current_producer();
+
+ for (auto* consumer : fusion_queue->current_consumers()) {
+ // Don't fuse into single bitcasts. We ignore them in
+ // CanFuseWithAllNonBitcastUsers(), so we need to check for them here.
+ if (consumer->opcode() == HloOpcode::kBitcast) {
+ continue;
+ }
+ if (!ConsumeFuel(producer, consumer)) continue;
+
+ VLOG(5) << "next: " << consumer->name() << "(" << consumer << ") + "
+ << producer->name() << "(" << producer << ")";
+
+ fusion_queue->PreFusion(producer, consumer);
+ auto fusion_instruction = Fuse(producer, consumer, computation);
+ fusion_queue->OnFusingInstruction(fusion_instruction, producer,
+ consumer);
+
+ changed = true;
+ }
+
+ if (producer->user_count() == 0) {
+ fusion_queue->RemoveInstruction(producer);
+ // Remove from computation.
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer));
+ }
+
+ fusion_queue->UpdatePriorities();
+ }
+
+ // Fuse all constants.
+ std::vector<HloInstruction*> constants;
+ for (auto* instruction : computation->instructions()) {
+ // Small constants should be fused, because they can be folded and
+ // codegened efficiently.
+ // Fusing large constants doesn't give much benefit, because they're
+ // treated like parameters and read from global memory anyway. Fusing
+ // and duplicating large constants can, however, cause problems if we
+ // want to dump the HLO and parse it back, because the duplicated
+ // constants would then be filled with different data.
+ if (IsSmallConstant(instruction)) {
+ constants.push_back(instruction);
+ }
+ }
+ for (auto* constant : constants) {
+ auto users = constant->users();
+ for (auto* user : users) {
+ if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) {
+ Fuse(constant, user, computation);
+ changed = true;
+ }
+ }
+ }
+ }
+
+ // The FusionAnalysis cache uses unique_id as the key. IDs are only unique
+ // within one module, so it's important to fully clear the cache if the same
+ // instance of the pass is run on a different module.
+ fusion_analysis_cache_.Clear();
+
+ if (dump_enabled) {
+ DumpPerModuleProtobufToFile(*module, *fusion_process_dump_,
+ module->config().debug_options(),
+ "priority_fusion_dump");
+ }
+
+ return changed;
+}
+
+FusionDecision PriorityFusion::ShouldFuse(HloInstruction* consumer,
+ int64_t operand_index) {
+ // This method is called in `InstructionFusion::Run` right before fusion, but
+ // it always allows the fusion here. Fusion decisions are fully controlled by
+ // the priority queue: if the queue returns a producer that shouldn't be
+ // fused, it's a bug that should be fixed in the queue logic.
+ return {};
+}
+
+HloInstruction::FusionKind PriorityFusion::ChooseKind(
+ const HloInstruction* producer, const HloInstruction* consumer) {
+ // Derive the kInput/kLoop fusion kind from the fusion analysis. This
+ // shouldn't matter, but some downstream passes still query the fusion kind
+ // instead of the fusion analysis.
+ const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer);
+ switch (analysis.GetEmitterFusionKind()) {
+ case HloFusionAnalysis::EmitterFusionKind::kLoop:
+ return HloInstruction::FusionKind::kLoop;
+ case HloFusionAnalysis::EmitterFusionKind::kTriton:
+ case HloFusionAnalysis::EmitterFusionKind::kCustomFusion:
+ case HloFusionAnalysis::EmitterFusionKind::kCuDnn:
+ return HloInstruction::FusionKind::kCustom;
+ case HloFusionAnalysis::EmitterFusionKind::kConcatenate:
+ case HloFusionAnalysis::EmitterFusionKind::kReduction:
+ case HloFusionAnalysis::EmitterFusionKind::kTranspose:
+ case HloFusionAnalysis::EmitterFusionKind::kInputSlices:
+ case HloFusionAnalysis::EmitterFusionKind::kScatter:
+ return HloInstruction::FusionKind::kInput;
+ }
+}
+
+HloInstruction* PriorityFusion::FuseInstruction(
+ HloInstruction* fusion_instruction, HloInstruction* producer) {
+ HloInstruction* result = fusion_instruction;
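+ // If the producer is itself a fusion, merge its fused computation into the
+ // new fusion instruction (keeping the Triton backend config when the
+ // producer is a generic Triton fusion); otherwise defer to the default
+ // fusion logic.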
+ if (producer->opcode() == HloOpcode::kFusion) {
+ if (IsGenericTritonFusion(*producer)) {
+ TF_CHECK_OK(fusion_instruction->set_backend_config(
+ *producer->backend_config<GpuBackendConfig>()));
+ }
+
+ fusion_instruction->MergeFusionInstruction(producer);
+ } else {
+ result = InstructionFusion::FuseInstruction(fusion_instruction, producer);
+ }
+ return result;
+}
+
+std::unique_ptr<FusionQueue> PriorityFusion::GetFusionQueue(
+ HloComputation* computation) {
+ return nullptr;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h
new file mode 100644
index 0000000..fce2be5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h
@@ -0,0 +1,100 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/gpu/fusion_process_dump.pb.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+
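+// GPU fusion pass that greedily fuses producers into their consumers in order
+// of the run-time benefit estimated by the GPU cost model.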
+class PriorityFusion : public InstructionFusion {
+ public:
+ PriorityFusion(tsl::thread::ThreadPool* thread_pool,
+ const se::DeviceDescription& device,
+ GpuHloCostAnalysis::Options cost_analysis_options)
+ : InstructionFusion(PriorityFusion::IsExpensive),
+ thread_pool_(thread_pool),
+ device_info_(device),
+ cost_analysis_options_(std::move(cost_analysis_options)),
+ fusion_analysis_cache_(device_info_) {}
+
+ absl::string_view name() const override { return "priority-fusion"; }
+
+ static bool IsExpensive(const HloInstruction& instruction);
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ protected:
+ std::unique_ptr<FusionQueue> GetFusionQueue(
+ HloComputation* computation) override;
+
+ FusionDecision ShouldFuse(HloInstruction* consumer,
+ int64_t operand_index) override;
+
+ HloInstruction::FusionKind ChooseKind(
+ const HloInstruction* producer, const HloInstruction* consumer) override;
+
+ private:
+ HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
+ HloInstruction* producer) override;
+
+ // Consumes a unit of compiler fuel and returns true if we should
+ // continue with the transformation.
+ bool ConsumeFuel(HloInstruction* producer, HloInstruction* consumer);
+
+ tsl::thread::ThreadPool* thread_pool_;
+ se::DeviceDescription device_info_;
+
+ // Cost model options that define the priorities in the queue.
+ GpuHloCostAnalysis::Options cost_analysis_options_;
+
+ // Proto with structured logs of fusion decisions. Used only for debugging. If
+ // null, logging is disabled.
+ std::unique_ptr<FusionProcessDumpProto> fusion_process_dump_;
+
+ HloFusionAnalysisCache fusion_analysis_cache_;
+
+ mlir::MLIRContext mlir_context_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc
new file mode 100644
index 0000000..4abd182
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc
@@ -0,0 +1,942 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/priority_fusion.h"
+
+#include <stdint.h>
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace m = ::xla::match;
+
+using ::testing::UnorderedElementsAre;
+using ::tsl::testing::IsOk;
+using ::tsl::testing::IsOkAndHolds;
+
+namespace xla {
+namespace gpu {
+
+class PriorityFusionTest : public HloTestBase {
+ HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
+ return [&](const Shape& shape) {
+ constexpr int64_t kPointerSize = 8;
+ return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+ };
+ }
+
+ public:
+ std::vector<HloFusionAnalysis::EmitterFusionKind> RunAndGetFusionKinds(
+ absl::string_view hlo) {
+ auto module = ParseAndReturnVerifiedModule(hlo).value();
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+ EXPECT_THAT(module->RemoveUnusedComputations(), IsOk());
+ std::vector<HloFusionAnalysis::EmitterFusionKind> kinds;
+ for (auto computation : module->computations()) {
+ if (!computation->FusionInstruction()) continue;
+
+ auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+ auto analysis = HloFusionAnalysis::Create(
+ *computation->FusionInstruction(), device_info);
+ kinds.push_back(analysis.GetEmitterFusionKind());
+ }
+ return kinds;
+ }
+
+ PriorityFusion priority_fusion_{
+ /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+ GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(),
+ /*per_second_rates=*/{},
+ /*count_multiple_input_accesses=*/true}};
+};
+
+TEST_F(PriorityFusionTest, FuseWithSharedArgument) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ ENTRY main {
+ %p0 = f32[] parameter(0)
+ %p1 = f32[] parameter(1)
+ %subtract = f32[] subtract(%p0, %p1)
+ %compare = pred[] compare(%subtract, %subtract), direction=NE
+ %add = f32[] add(%p0, %p1)
+ %abs = f32[] abs(%subtract)
+ ROOT %select = f32[] select(%compare, %add, %abs)
+ })")
+ .value();
+
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::Fusion()));
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
+}
+
+TEST_F(PriorityFusionTest, FusionFusionWithDuplication) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ square {
+ p = f32[16384]{0} parameter(0)
+ ROOT m = f32[16384]{0} multiply(p, p)
+ }
+
+ exp {
+ p = f32[16384]{0} parameter(0)
+ ROOT e = f32[16384]{0} exponential(p)
+ }
+
+ log {
+ p = f32[16384]{0} parameter(0)
+ ROOT l = f32[16384]{0} log(p)
+ }
+
+ ENTRY main {
+ p = f32[16384]{0} parameter(0)
+ s = f32[16384]{0} fusion(p), kind=kLoop, calls=square
+ e = f32[16384]{0} fusion(s), kind=kLoop, calls=exp
+ l = f32[16384]{0} fusion(s), kind=kInput, calls=log
+ ROOT t = (f32[16384], f32[16384]) tuple(l, e)
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ENTRY
+CHECK-NEXT: %[[PARAM:.*]] = f32[16384]{0} parameter(0)
+CHECK-NEXT: %[[FUSION_0:.*]] = f32[16384]{0} fusion(%[[PARAM]])
+CHECK-NEXT: %[[FUSION_1:.*]] = f32[16384]{0} fusion(%[[PARAM]])
+CHECK-NEXT: ROOT {{.*}} tuple(%[[FUSION_0]], %[[FUSION_1]])
+ )");
+}
+
+TEST_F(PriorityFusionTest, FuseBroadcastIntoBitcastConsumers) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ ENTRY main {
+ param_0 = f32[96]{0} parameter(0)
+ broadcast = f32[8,96,128,7]{3,2,1,0} broadcast(param_0), dimensions={1}
+ bitcast.6079.2 = f32[8,24,4,128,7]{4,3,2,1,0} bitcast(broadcast)
+ ROOT transpose.1990.2 = f32[8,24,128,7,4]{4,3,2,1,0} transpose(bitcast.6079.2), dimensions={0,1,3,4,2}
+ }
+ )";
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ENTRY
+CHECK-NEXT: %[[PARAM:.*]] = f32[96]{0} parameter(0)
+CHECK-NEXT: ROOT %{{.*}} fusion(%[[PARAM]])
+ )");
+}
+
+TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ ENTRY main {
+ p = f16[512]{0} parameter(0)
+ a = f16[512]{0} add(p, p)
+ c = f32[512]{0} convert(a)
+ s = f32[512]{0} multiply(c, c)
+ bc = s32[512]{0} bitcast(c)
+ ROOT t = (f32[512], s32[512]) tuple(s, bc)
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ENTRY
+CHECK-NEXT: %[[PARAM:.*]] = f16[512]{0} parameter(0)
+CHECK-NEXT: %[[FUSION_F32:.*]] = f32[512]{0} fusion(%[[PARAM]])
+CHECK-NEXT: %[[CONVERT_FUSION:.*]] = f32[512]{0} fusion(%[[PARAM]])
+CHECK-NEXT: %[[BITCAST:.*]] = s32[512]{0} bitcast(%[[CONVERT_FUSION]])
+CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[BITCAST]])
+ )");
+}
+
+TEST_F(PriorityFusionTest, FuseConvertIntoReduce) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add.13235 = f32[] add(p0, p1)
+ }
+
+ ENTRY main {
+ param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
+ param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
+ param_2.483 = f32[8192]{0} parameter(2)
+ param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
+ convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
+ convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
+ constant_7773 = f32[] constant(0)
+ broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
+ multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
+ reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
+ convert.13970 = bf16[1024]{0} convert(reduce.4813)
+ convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
+ multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
+ reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
+ convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
+ multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
+ reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
+ convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
+ ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK-COUNT-3: ROOT {{.*}} convert(
+CHECK: ENTRY %main
+CHECK-COUNT-3: fusion
+ )");
+}
+
+TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) {
+ // Regression test for epilogue fusion of a convert into a reduction, even
+ // when the convert has a bitcast as a consumer.
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ rhs.407 = f32[] parameter(1)
+ lhs.407 = f32[] parameter(0)
+ ROOT add.24451 = f32[] add(lhs.407, rhs.407)
+ }
+
+ ENTRY main {
+ param_1.15162 = f32[2752]{0} parameter(1)
+ convert.44829 = bf16[2752]{0} convert(param_1.15162)
+ bitcast.24686 = bf16[1,1,2752]{2,1,0} bitcast(convert.44829)
+ convert.44468 = f32[1,1,2752]{2,1,0} convert(bitcast.24686)
+ constant_13722 = bf16[] constant(1)
+ convert.17451 = f32[] convert(constant_13722)
+ broadcast.17565 = f32[1,1,2752]{2,1,0} broadcast(convert.17451), dimensions={}
+ negate.167 = f32[1,1,2752]{2,1,0} negate(convert.44468)
+ exponential.569 = f32[1,1,2752]{2,1,0} exponential(negate.167)
+ add.1850 = f32[1,1,2752]{2,1,0} add(broadcast.17565, exponential.569)
+ divide.1376 = f32[1,1,2752]{2,1,0} divide(broadcast.17565, add.1850)
+ multiply.9709 = f32[1,1,2752]{2,1,0} multiply(convert.44468, divide.1376)
+ param_0.15005 = f32[2752]{0} parameter(0)
+ convert.44826 = bf16[2752]{0} convert(param_0.15005)
+ bitcast.24683 = bf16[1,1,2752]{2,1,0} bitcast(convert.44826)
+ convert.44467 = f32[1,1,2752]{2,1,0} convert(bitcast.24683)
+ multiply.9708 = f32[1,1,2752]{2,1,0} multiply(multiply.9709, convert.44467)
+ convert.16959 = bf16[1,1,2752]{2,1,0} convert(multiply.9708)
+ fusion.3203 = bf16[2752]{0} bitcast(convert.16959)
+ convert.15093 = f32[2752]{0} convert(fusion.3203)
+ broadcast.13841 = f32[8192,2752]{1,0} broadcast(convert.15093), dimensions={1}
+ param_0.15525 = bf16[8192,2752]{1,0} parameter(2)
+ convert.13738 = f32[8192,2752]{1,0} convert(param_0.15525)
+ multiply.6422 = f32[8192,2752]{1,0} multiply(broadcast.13841, convert.13738)
+ constant_14382 = f32[] constant(0)
+ fusion.339 = f32[8192]{0} reduce(multiply.6422, constant_14382), dimensions={1}, to_apply=add
+ convert.44633 = bf16[8192]{0} convert(fusion.339)
+ ROOT bitcast.24487 = bf16[1,1,8192]{2,1,0} bitcast(convert.44633)
+ }
+ )";
+
+ EXPECT_THAT(
+ RunAndGetFusionKinds(kHlo),
+ UnorderedElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop,
+ HloFusionAnalysis::EmitterFusionKind::kReduction));
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ENTRY
+CHECK: ROOT {{.*}} bitcast({{.*}}fusion{{.*}})
+ )");
+}
+
+TEST_F(PriorityFusionTest, DoNotChangeReductionFusionToLoopFusion) {
+ // Regression test for epilogue fusion of a slice into a reduction. The
+ // fusion kind of the reduction fusion is intentionally set to kLoop, as we
+ // cannot rely on reductions always having fusion kind kInput.
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ rhs.407 = f32[] parameter(1)
+ lhs.407 = f32[] parameter(0)
+ ROOT add.24451 = f32[] add(lhs.407, rhs.407)
+ }
+
+ fused_computation {
+ p0 = f32[16,64]{1,0} parameter(0)
+ zero = f32[] constant(0.0)
+ ROOT reduce = f32[16]{0} reduce(p0, zero), dimensions={1}, to_apply=add
+ }
+
+ ENTRY main {
+ param0 = f32[16,64]{1,0} parameter(0)
+ fusion = f32[16]{0} fusion(param0), kind=kLoop, calls=fused_computation
+ ROOT slice = f32[8]{0} slice(fusion), slice={[0:8]}
+ })");
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ Arg_1.1046 = f32[] parameter(1)
+ Arg_0.1045 = f32[] parameter(0)
+ ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
+ }
+
+ ENTRY main {
+ param_0.17323 = pred[2048,2048]{1,0} parameter(0)
+ broadcast.22829 = pred[1,12,2048,2048]{3,2,1,0} broadcast(param_0.17323), dimensions={2,3}
+ param_1.19761 = bf16[2048,24576]{1,0} parameter(1)
+ convert.29880.clone.1 = f32[2048,24576]{1,0} convert(param_1.19761)
+ constant_10033_clone_1 = bf16[] constant(0.02002)
+ convert.30056.clone.1 = f32[] convert(constant_10033_clone_1)
+ broadcast.18898.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30056.clone.1), dimensions={}
+ multiply.13451.clone.1 = f32[2048,24576]{1,0} multiply(convert.29880.clone.1, broadcast.18898.clone.1)
+ tanh.798.clone.1 = f32[2048,24576]{1,0} tanh(multiply.13451.clone.1)
+ constant_10244_clone_1 = bf16[] constant(50)
+ convert.30039.clone.1 = f32[] convert(constant_10244_clone_1)
+ broadcast.18310.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30039.clone.1), dimensions={}
+ multiply.12550.clone.1 = f32[2048,24576]{1,0} multiply(tanh.798.clone.1, broadcast.18310.clone.1)
+ convert.29370.clone.1 = bf16[2048,24576]{1,0} convert(multiply.12550.clone.1)
+ bitcast.22330 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
+ transpose.6582 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22330), dimensions={0,3,2,1}
+ convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6582)
+ constant_10212 = f32[] constant(-2.38197633e+38)
+ broadcast.22828 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10212), dimensions={}
+ select.589 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22829, convert.33705, broadcast.22828)
+ bitcast.22075 = f32[12,2048,2048]{2,1,0} bitcast(select.589)
+ constant_10192 = f32[] constant(-inf)
+ reduce.1614 = f32[12,2048]{1,0} reduce(bitcast.22075, constant_10192), dimensions={2}, to_apply=add
+
+ predarg = pred[1,1,2048,2048]{3,2,1,0} parameter(2)
+ bitcast.11069 = pred[2048,2048]{1,0} bitcast(predarg)
+
+ broadcast.22825 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
+ bitcast.22331 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
+ transpose.6580 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22331), dimensions={0,3,2,1}
+ convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6580)
+ constant_10213 = f32[] constant(-2.38197633e+38)
+ broadcast.22824 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10213), dimensions={}
+ select.587 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22825, convert.33703, broadcast.22824)
+ broadcast.22819 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
+ subtract.1129 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.587, broadcast.22819)
+ exponential.418 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1129)
+ bitcast.22074 = f32[12,2048,2048]{2,1,0} bitcast(exponential.418)
+ constant_10490 = f32[] constant(0)
+ reduce.1613 = f32[12,2048]{1,0} reduce(bitcast.22074, constant_10490), dimensions={2}, to_apply=add
+
+ constant_468 = f32[] constant(-2.38197633e+38)
+ broadcast.22833 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
+ bitcast.22332 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
+ transpose.6584 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22332), dimensions={0,3,2,1}
+ convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6584)
+ broadcast.22832 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_468), dimensions={}
+ select.591 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22833, convert.33707, broadcast.22832)
+ broadcast.22821 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
+ subtract.1131 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.591, broadcast.22821)
+ exponential.420 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1131)
+ broadcast.18351 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1613), dimensions={1,2}
+ divide.340 = f32[1,12,2048,2048]{3,2,1,0} divide(exponential.420, broadcast.18351)
+ ROOT convert.29418 = bf16[1,12,2048,2048]{3,2,1,0} convert(divide.340)
+ })";
+
+ using Kind = HloFusionAnalysis::EmitterFusionKind;
+ EXPECT_THAT(
+ RunAndGetFusionKinds(kHlo),
+ UnorderedElementsAre(Kind::kLoop, Kind::kLoop, Kind::kLoop,
+ Kind::kReduction, Kind::kReduction, Kind::kTranspose,
+ Kind::kTranspose, Kind::kTranspose));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduce) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add.13235 = f32[] add(p0, p1)
+ }
+
+ ENTRY main {
+ p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
+ ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} reduce(
+CHECK: ROOT {{.*}} reduce(
+ )");
+}
+
+TEST_F(PriorityFusionTest, ConvertFusedIntoReduce) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add.13235 = f32[] add(p0, p1)
+ }
+
+ ENTRY main {
+ param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
+ param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
+ param_2.483 = f32[8192]{0} parameter(2)
+ param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
+ convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
+ convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
+ constant_7773 = f32[] constant(0)
+ broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
+ multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
+ reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
+ convert.13970 = bf16[1024]{0} convert(reduce.4813)
+ convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
+ multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
+ reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
+ convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
+ multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
+ reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
+ convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
+ ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK-COUNT-3: ROOT {{.*}} convert(
+CHECK: ENTRY %main
+CHECK-COUNT-3: fusion(
+CHECK-NOT: fusion(
+ )");
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseDynamicUpdateSliceIntoReduce) {
+ GTEST_SKIP() << "b/294198633";
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+add {
+ Arg_1.1046 = f32[] parameter(1)
+ Arg_0.1045 = f32[] parameter(0)
+ ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
+}
+
+ENTRY main {
+ param_0.10549 = f32[4,2112]{1,0} parameter(0)
+ param_5.2561 = pred[] parameter(5)
+ broadcast.19725 = pred[4,1]{1,0} broadcast(param_5.2561), dimensions={}
+ param_1.11587 = pred[4]{0} parameter(1)
+ constant_5837 = f32[] constant(1)
+ broadcast.19723 = f32[4]{0} broadcast(constant_5837), dimensions={}
+ param_2.5952 = f32[4,8000]{1,0} parameter(2)
+ param_3.4004 = f32[4]{0} parameter(3)
+ broadcast.19718 = f32[4,8000]{1,0} broadcast(param_3.4004), dimensions={0}
+ subtract.1112 = f32[4,8000]{1,0} subtract(param_2.5952, broadcast.19718)
+ exponential.418 = f32[4,8000]{1,0} exponential(subtract.1112)
+ constant_6254 = f32[] constant(0)
+ reduce.1154 = f32[4]{0} reduce(exponential.418, constant_6254), dimensions={1}, to_apply=add
+ log.38 = f32[4]{0} log(reduce.1154)
+ broadcast.19717 = f32[4,8000]{1,0} broadcast(log.38), dimensions={0}
+ subtract.1111 = f32[4,8000]{1,0} subtract(subtract.1112, broadcast.19717)
+ iota.170 = s32[4,1]{1,0} iota(), iota_dimension=0
+ constant_6281 = s32[] constant(0)
+ broadcast.19735 = s32[4]{0} broadcast(constant_6281), dimensions={}
+ param_4.3400 = s32[4,8000]{1,0} parameter(4)
+ slice.3186 = s32[4,40]{1,0} slice(param_4.3400), slice={[0:4], [0:40]}
+ iota.168 = s32[4,1]{1,0} iota(), iota_dimension=0
+ param_7.1596 = s32[4]{0} parameter(7)
+ compare.341 = pred[4]{0} compare(param_7.1596, broadcast.19735), direction=LT
+ constant_5833 = s32[] constant(40)
+ broadcast.19731 = s32[4]{0} broadcast(constant_5833), dimensions={}
+ add.8348 = s32[4]{0} add(param_7.1596, broadcast.19731)
+ select.418 = s32[4]{0} select(compare.341, add.8348, param_7.1596)
+ bitcast.20942 = s32[4,1]{1,0} bitcast(select.418)
+ concatenate.1337 = s32[4,2]{1,0} concatenate(iota.168, bitcast.20942), dimensions={1}
+ gather.43 = s32[4,1,1]{2,1,0} gather(slice.3186, concatenate.1337), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
+ bitcast.20941 = s32[4]{0} bitcast(gather.43)
+ select.398 = s32[4]{0} select(param_1.11587, broadcast.19735, bitcast.20941)
+ compare.334 = pred[4]{0} compare(select.398, broadcast.19735), direction=LT
+ constant_6260 = s32[] constant(8000)
+ broadcast.19720 = s32[4]{0} broadcast(constant_6260), dimensions={}
+ add.8336 = s32[4]{0} add(select.398, broadcast.19720)
+ select.396 = s32[4]{0} select(compare.334, add.8336, select.398)
+ bitcast.20830 = s32[4,1]{1,0} bitcast(select.396)
+ concatenate.1308 = s32[4,2]{1,0} concatenate(iota.170, bitcast.20830), dimensions={1}
+ gather.41 = f32[4,1,1]{2,1,0} gather(subtract.1111, concatenate.1308), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
+ bitcast.20824 = f32[4]{0} bitcast(gather.41)
+ select.389 = f32[4]{0} select(param_1.11587, broadcast.19723, bitcast.20824)
+ bitcast.20823 = f32[4,1]{1,0} bitcast(select.389)
+ param_6.1719 = s32[] parameter(6)
+ constant_6323 = s32[] constant(2048)
+ add.8549 = s32[] add(param_6.1719, constant_6323)
+ compare.388 = pred[] compare(add.8549, constant_6281), direction=LT
+ constant_5436 = s32[] constant(4160)
+ add.8339 = s32[] add(param_6.1719, constant_5436)
+ select.409 = s32[] select(compare.388, add.8339, add.8549)
+ dynamic-slice.36 = f32[4,1]{1,0} dynamic-slice(param_0.10549, constant_6281, select.409), dynamic_slice_sizes={4,1}
+ select.388 = f32[4,1]{1,0} select(broadcast.19725, bitcast.20823, dynamic-slice.36)
+ ROOT dynamic-update-slice.307 = f32[4,2112]{1,0} dynamic-update-slice(param_0.10549, select.388, constant_6281, select.409)
+})";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} dynamic-update-slice(
+CHECK: %[[REDUCE:.*]] = {{.*}} reduce(
+CHECK: ROOT {{.*}} log(%[[REDUCE]])
+CHECK: ENTRY
+CHECK-COUNT-2: fusion(
+ )");
+}
+
+TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) {
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+ }
+
+ ENTRY FuseIntoScatter {
+ p0 = s32[3,3] parameter(0)
+ operand = s32[3,3] add(p0, p0)
+ p1 = s32[2] parameter(1)
+ indices = s32[2] add(p1, p1)
+ p2 = s32[2,3] parameter(2)
+ updates = s32[2,3] add(p2, p2)
+ scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ ROOT add = s32[3,3] add(scatter, scatter)
+ })");
+
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloInstruction* fusion = nullptr;
+ ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
+ EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
+ EXPECT_THAT(fusion->fused_expression_root(),
+ GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
+}
+
+// This test is similar to DontFuseIntoFirstOperandOfScatter, but PriorityFusion
+// has a separate run to fuse constants. Fusing anything into a scatter fusion
+// will fail in the emitter.
+TEST_F(PriorityFusionTest, DontFuseConstantIntoFirstOperandOfScatter) {
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+ }
+
+ ENTRY FuseIntoScatter {
+ operand = s32[1] constant({0})
+ indices = s32[24,1] parameter(0)
+ constant = s32[] constant(1)
+ updates = s32[24,1] broadcast(constant)
+ ROOT scatter = s32[1] scatter(operand, indices, updates),
+ to_apply=add,
+ update_window_dims={1},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ })");
+
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
+ EXPECT_THAT(root->fused_expression_root(),
+ GmockMatch(m::Scatter(m::Parameter(), m::Parameter(),
+ m::Broadcast(m::Constant()))));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduceEvenIfOccupancyIsHigh) {
+ constexpr absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY main {
+ p0 = f32[4,3584,128,168]{3,2,1,0} parameter(0)
+ c = f32[] constant(0)
+ r1 = f32[4,3584,128]{2,1,0} reduce(p0, c), dimensions={3}, to_apply=add
+ ROOT r2 = f32[4,3584]{1,0} reduce(r1, c), dimensions={2}, to_apply=add
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} reduce(
+CHECK: ROOT {{.*}} reduce(
+ )");
+}
+
+TEST_F(PriorityFusionTest, FuseReductionEpilogueWithMultipleUsers) {
+ // Regression test that verifies we correctly fuse the `log` into the reduce.
+ constexpr absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x, y)
+ }
+
+ fused_computation {
+ p0 = f32[64,16384]{1,0} parameter(0)
+ c0 = f32[] constant(0)
+ ROOT reduce.858 = f32[64]{0} reduce(p0, c0), dimensions={1}, to_apply=add
+ }
+
+ ENTRY main {
+ p0 = f32[64,16384]{1,0} parameter(0)
+ fusion = f32[64]{0} fusion(p0), kind=kInput, calls=fused_computation
+ log = f32[64]{0} log(fusion)
+ negate = f32[64]{0} custom-call(log), custom_call_target="negate"
+ ROOT add = f32[64]{0} add(negate, log)
+ }
+ )";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+ CHECK: ENTRY
+ CHECK: %[[PARAM:.*]] = {{.*}} parameter(0)
+ CHECK: %[[FUSION:.*]] = {{.*}} fusion(%[[PARAM]])
+ CHECK: custom-call(%[[FUSION]])
+ )");
+}
+
+TEST_F(PriorityFusionTest, EpilogueFusion) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add.13235 = f32[] add(p0, p1)
+ }
+
+ fused_computation.1 {
+ p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ ROOT r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
+ }
+
+ fused_computation.2 {
+ p0 = f32[8,4,128]{2,1,0} parameter(0)
+ r1 = f32[8,4,128]{2,1,0} log(p0)
+ ROOT r2 = f32[8,4,128]{2,1,0} log(r1)
+ }
+
+ ENTRY main {
+ p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+ f1 = f32[8,4,128]{2,1,0} fusion(p0), kind=kInput, calls=%fused_computation.1
+ ROOT fusion = f32[8,4,128]{2,1,0} fusion(f1), kind=kLoop, calls=%fused_computation.2
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} = f32[8,4,128]{2,1,0} fusion(%p{{.*}}), kind=kInput, calls=%fused_computation)");
+}
+
+TEST_F(PriorityFusionTest, EpilogueFusionFails) {
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add.13235 = f32[] add(p0, p1)
+ }
+
+ fused_computation.1 {
+ p0 = f32[28672,4096]{1,0} parameter(0)
+ c0 = f32[] constant(0)
+ ROOT r = f32[28672]{0} reduce(p0, c0), dimensions={1}, to_apply=add
+ }
+
+ fused_computation.2 {
+ p0 = f32[28672]{0} parameter(0)
+ p1 = f32[28672]{0} parameter(1)
+ ROOT a = f32[28672]{0} add(p0, p1)
+ }
+
+ ENTRY main {
+ p0 = f32[28672,4096]{1,0} parameter(0)
+ p1 = f32[28672]{0} parameter(1)
+ f = f32[28672]{0} fusion(p0), kind=kInput, calls=%fused_computation.1
+ ROOT fusion = f32[28672]{0} fusion(f,p1), kind=kLoop, calls=%fused_computation.2
+ })");
+
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseIntoRoot) {
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule test_module
+
+ ENTRY %main (p.0: u32[2], p.1: u32[]) -> u32[2] {
+ %p.0 = u32[2]{0} parameter(0)
+ %p.1 = u32[] parameter(1)
+ ROOT %broadcast = u32[2]{0} broadcast(u32[] %p.1), dimensions={}, sharding={replicated}
+ %add = u32[2]{0} add(u32[2]{0} %p.0, u32[2]{0} %broadcast)
+ %tuple.1 = (u32[2]{0}) tuple(u32[2]{0} %add)
+ %token.0 = token[] after-all()
+ %outfeed.6 = token[] outfeed((u32[2]{0}) %tuple.1, token[] %token.0), outfeed_shape=(u32[2]{0}), sharding={maximal device=0}
+ })");
+
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, DontFuseConcat) {
+ // Regression test that verifies we don't fuse concat into a column reduction.
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ %maximum (param_0: f32[], param_1: f32[]) -> f32[] {
+ %param_0 = f32[] parameter(0)
+ %param_1 = f32[] parameter(1)
+ ROOT %maximum = f32[] maximum(f32[] %param_0, f32[] %param_1)
+ }
+
+ %fused_concat (param_0: f32[1,4,401,8,8], param_1: f32[1,1,4,1023,8], param_2: bf16[1,4,1023,8,8]) -> f32[1,4,1424,8,8] {
+ %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
+ %convert = f32[1,4,1023,8,8]{4,3,2,1,0} convert(bf16[1,4,1023,8,8]{4,3,2,1,0} %param_2)
+ %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
+ %bitcast = f32[4,1023,8]{2,1,0} bitcast(f32[1,1,4,1023,8]{4,3,2,1,0} %param_1)
+ %broadcast = f32[1,4,1023,8,8]{4,3,2,1,0} broadcast(f32[4,1023,8]{2,1,0} %bitcast), dimensions={1,2,4}
+ %add = f32[1,4,1023,8,8]{4,3,2,1,0} add(f32[1,4,1023,8,8]{4,3,2,1,0} %convert, f32[1,4,1023,8,8]{4,3,2,1,0} %broadcast)
+ %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
+ ROOT %concatenate = f32[1,4,1424,8,8]{4,3,2,1,0} concatenate(f32[1,4,1023,8,8]{4,3,2,1,0} %add, f32[1,4,401,8,8]{4,3,2,1,0} %param_0), dimensions={2}
+ }
+
+ %fused_reduce (param_0: f32[], param_1: f32[1,4,1424,8,8]) -> f32[4,8,8] {
+ %param_1 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(1)
+ %bitcast = f32[4,1424,8,8]{3,2,1,0} bitcast(f32[1,4,1424,8,8]{4,3,2,1,0} %param_1)
+ %param_0 = f32[] parameter(0)
+ ROOT %reduce = f32[4,8,8]{2,1,0} reduce(f32[4,1424,8,8]{3,2,1,0} %bitcast, f32[] %param_0), dimensions={1}, to_apply=%maximum
+ }
+
+ %fused_broadcast (param_0: f32[1,4,1424,8,8], param_1: f32[4,8,8]) -> f32[1,4,1424,8,8] {
+ %param_0 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(0)
+ %param_1 = f32[4,8,8]{2,1,0} parameter(1)
+ %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} broadcast(f32[4,8,8]{2,1,0} %param_1), dimensions={1,3,4}
+ ROOT %subtract = f32[1,4,1424,8,8]{4,3,2,1,0} subtract(f32[1,4,1424,8,8]{4,3,2,1,0} %param_0, f32[1,4,1424,8,8]{4,3,2,1,0} %broadcast)
+ }
+
+ ENTRY fusion {
+ %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
+ %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
+ %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
+ %concat = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%param_0, %param_1, %param_2), kind=kLoop, calls=fused_concat
+ %param_3 = f32[] parameter(3)
+ %reduce = f32[4,8,8]{2,1,0} fusion(%param_3, %concat), kind=kLoop, calls=fused_reduce
+ %param_4 = f32[4,8,8]{2,1,0} parameter(4)
+ %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%concat, %param_4), kind=kLoop, calls=fused_broadcast
+ ROOT tuple = (f32[4,8,8]{2,1,0}, f32[1,4,1424,8,8]{4,3,2,1,0}) tuple(%reduce, %broadcast)
+ }
+ )");
+
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, FuseOnlySmallConstant) {
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ ENTRY main {
+ param_0 = f32[32,32]{1,0} parameter(0)
+ c_1 = f32[] constant(1)
+ c_2 = f32[32,32] constant({...})
+ broadcast = f32[32,32]{1,0} broadcast(c_1), dimensions={}
+ add = f32[32,32]{1,0} add(param_0, broadcast)
+ ROOT mul = f32[32,32]{1,0} multiply(c_2, add)
+ }
+ )");
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
+ EXPECT_THAT(root->fused_expression_root(),
+ GmockMatch(m::Multiply(
+ m::Parameter(),
+ m::Add(m::Parameter(), m::Broadcast(m::Constant())))));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) {
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ HloModule module
+
+ fused_computation.1 {
+ iota.9.7 = s32[3,1,1]{2,1,0} iota(), iota_dimension=0
+ param_3.29 = s32[] parameter(2)
+ pad.2.7 = s32[3,1,2]{2,1,0} pad(iota.9.7, param_3.29), padding=0_0x0_0x0_1
+ param_2.39 = s32[] parameter(1)
+ broadcast.76.1 = s32[3,1,2]{2,1,0} broadcast(param_2.39), dimensions={}
+ compare.9.1 = pred[3,1,2]{2,1,0} compare(pad.2.7, broadcast.76.1), direction=GE
+ param_1.73 = s32[2]{0} parameter(0)
+ broadcast.78.1 = s32[3,2]{1,0} broadcast(param_1.73), dimensions={1}
+ bitcast.1 = s32[3,2]{1,0} bitcast(pad.2.7)
+ compare.10.1 = pred[3,2]{1,0} compare(bitcast.1, broadcast.78.1), direction=LE
+ bitcast.2 = pred[3,1,2]{2,1,0} bitcast(compare.10.1)
+ ROOT and.3.1 = pred[3,1,2]{2,1,0} and(compare.9.1, bitcast.2)
+ }
+
+ and {
+ x = pred[] parameter(0)
+ y = pred[] parameter(1)
+ ROOT and = pred[] and(x, y)
+ }
+
+ fused_computation.2 {
+ param0 = pred[3,1,2]{2,1,0} parameter(0)
+ slice = pred[1,1,2]{2,1,0} slice(param0), slice={[0:1], [0:1], [0:2]}
+ bitcast = pred[2]{0} bitcast(slice)
+ init = pred[] constant(true)
+ reduce = pred[2]{0} reduce(param0, init), dimensions={0,1}, to_apply=and
+ and = pred[2]{0} and(bitcast, reduce)
+ pad = pred[3]{0} pad(and, init), padding=0_1
+ broadcast = pred[3,2]{1,0} broadcast(pad), dimensions={0}
+ bitcast2 = pred[6]{0} bitcast(broadcast)
+ broadcast2 = pred[2,3]{1,0} broadcast(pad), dimensions={1}
+ bitcast3 = pred[6]{0} bitcast(broadcast2)
+ ROOT and2 = pred[6]{0} and(bitcast2, bitcast3)
+ }
+
+ ENTRY main {
+ p0 = s32[2]{0} parameter(0)
+ p1 = s32[] parameter(1)
+ p2 = s32[] parameter(2)
+ fusion1 = pred[3,1,2]{2,1,0} fusion(p0, p1, p2), kind=kLoop, calls=fused_computation.1
+ ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2
+ }
+ )");
+ auto& debug_options = module->mutable_config().mutable_debug_options();
+ debug_options.set_xla_gpu_mlir_emitter_level(3);
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, CanMergeTritonFusionWithBothProducerAndConsumer) {
+#ifndef GOOGLE_CUDA
+ GTEST_SKIP() << "Triton fusion is only enabled for CUDA devices.";
+#endif
+
+ const std::string kHloText = R"(
+HloModule t
+add {
+ Arg_0 = f32[] parameter(0)
+ Arg_1 = f32[] parameter(1)
+ ROOT add = f32[] add(Arg_0, Arg_1)
+}
+
+producer_computation {
+ parameter_0 = f32[125]{0} parameter(0)
+ ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0}
+}
+
+consumer_computation {
+ parameter_0 = f32[125,127]{1,0} parameter(0)
+ parameter_1 = f32[125,127]{1,0} parameter(1)
+ ROOT multiply = f32[125,127]{1,0} multiply(parameter_1, parameter_0)
+}
+
+triton_softmax_computation {
+ parameter_0 = f32[125,127]{1,0} parameter(0)
+ multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0)
+ constant_0 = f32[] constant(0)
+ reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
+ broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0}
+ ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4)
+}
+
+ENTRY main {
+ param_0 = f32[125]{0} parameter(0)
+ param_1 = f32[125,127]{1,0} parameter(1)
+ producer_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=producer_computation
+ triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
+ ROOT consumer_fusion = f32[125,127]{1,0} fusion(param_1, triton_softmax), kind=kLoop, calls=consumer_computation
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+ auto debug_options = module->config().debug_options();
+ debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true);
+ module->mutable_config().set_debug_options(debug_options);
+
+ EXPECT_TRUE(priority_fusion_.Run(module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom);
+ EXPECT_TRUE(IsGenericTritonFusion(*root));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) {
+ auto module = *ParseAndReturnVerifiedModule(R"(
+ %reducer {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ add = f32[] add(p0, p1)
+ ROOT max = f32[] maximum(add, p0)
+ }
+
+ %fused_reduce {
+ p0 = f32[256] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT reduce = f32[] reduce(p0, p1), dimensions={0}, to_apply=%reducer
+ }
+
+ ENTRY fusion {
+ p0 = f32[256] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT %reduce = f32[] fusion(p0, p1), kind=kInput, calls=fused_reduce
+ }
+ )");
+ EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc
new file mode 100644
index 0000000..d33c849
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc
@@ -0,0 +1,130 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_opt_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/status_macros.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> ReduceScatterCreator::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ const HloModuleConfig &config = module->config();
+ int64_t next_channel_id = hlo_query::NextChannelId(*module);
+
+ bool changed = false;
+ for (HloComputation *computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction *instruction :
+ computation->MakeInstructionPostOrder()) {
+ if (instruction->opcode() != HloOpcode::kAllReduce) {
+ continue;
+ }
+ auto *ar = Cast<HloAllReduceInstruction>(instruction);
+ auto ar_spec = MatchReduceScatter(ar, config.num_partitions(),
+ config.replica_count(),
+ /*allow_multiple_split_dims=*/false,
+ /*allow_intervening_reshape=*/true);
+ if (!ar_spec) {
+ VLOG(2) << "Cannot match reduce-scatter " << ar->ToString();
+ continue;
+ }
+
+ HloInstruction *ds = ar_spec->dynamic_slice;
+
+      // Convert the all-reduce into a reduce-scatter. The output shape of the
+      // reduce-scatter is the same as the input shape, except that the split
+      // dimension is shrunk to the size of the dynamic-slice result.
+ const int64_t split_dim = ar_spec->split_dim;
+ Shape scatter_shape = ar->shape();
+ const int64_t split_dim_size = scatter_shape.dimensions(split_dim);
+ HloInstruction *rs_input = ar->mutable_operand(0);
+ const int64_t scatter_dim_size = split_dim_size / ar_spec->group_size;
+ TF_RET_CHECK(scatter_dim_size * ar_spec->group_size <= split_dim_size);
+ if (split_dim_size % ar_spec->group_size != 0) {
+ // The dynamic-slice does not evenly split the scatter dim. In that
+ // case, create a reduce-scatter with the relevant slice of the
+ // all-reduce input.
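+        // For instance (mirroring the NonUniformSplit test below): with
+        // split_dim_size = 7 and group_size = 2, scatter_dim_size = 3, so we
+        // first slice the input down to 6 along the split dimension and
+        // reduce-scatter that slice down to 3.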
+ scatter_shape.set_dimensions(split_dim,
+ scatter_dim_size * ar_spec->group_size);
+ rs_input = computation->AddInstruction(HloInstruction::CreateSlice(
+ scatter_shape, rs_input,
+ std::vector<int64_t>(scatter_shape.rank(), 0),
+ scatter_shape.dimensions(),
+ std::vector<int64_t>(scatter_shape.rank(), 1)));
+ }
+ scatter_shape.set_dimensions(split_dim, scatter_dim_size);
+
+ std::optional<int64_t> channel_id;
+ if (ar->channel_id()) {
+ // We cannot reuse the channel_id on all-reduce for reduce-scatter.
+ channel_id = next_channel_id++;
+ }
+
+ HloInstruction *ars =
+ computation->AddInstruction(HloInstruction::CreateReduceScatter(
+ scatter_shape, {rs_input}, ar->to_apply(), ar->device_list(),
+ ar->constrain_layout(), channel_id, ar->use_global_device_ids(),
+ ar_spec->split_dim));
+
+      // If there was an intervening reshape, restore the non-split dimensions
+      // by reshaping the reduce-scatter result to the dynamic-slice shape.
+ HloInstruction *result = ars;
+ HloInstruction *reshape = nullptr;
+ if (ds->operand(0) != ar) {
+ reshape = ds->mutable_operand(0);
+ result = computation->AddInstruction(
+ HloInstruction::CreateReshape(ds->shape(), result));
+ }
+
+ // Note that RemoveInstructionAndUnusedOperands may not always remove the
+ // all-reduce operand of the dynamic-slice, so remove all the dead
+ // instructions manually.
+ TF_RETURN_IF_ERROR(ds->ReplaceAllUsesWith(result));
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(ds));
+ if (reshape) {
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(reshape));
+ }
+ TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar));
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h
new file mode 100644
index 0000000..4e74394
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h
@@ -0,0 +1,43 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Transforms dynamic-slice(all-reduce) to a reduce-scatter.
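+//
+// For example (sketch; the unit tests contain full HLO), a pattern like
+//
+//   %ar = all-reduce(%param), to_apply=%sum
+//   %offset = multiply(%replica-id, %shard_size)
+//   %ds = dynamic-slice(%ar, %offset, %zero, %zero)
+//
+// is rewritten into a single reduce-scatter of %param along the sliced
+// dimension.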
+class ReduceScatterCreator : public HloModulePass {
+ public:
+ ReduceScatterCreator() = default;
+ absl::string_view name() const override { return "reduce-scatter-creator"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc
new file mode 100644
index 0000000..39a2c72
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc
@@ -0,0 +1,572 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GpuReduceScatterCreatorTest : public HloTestBase {
+ public:
+ absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
+ absl::string_view hlo_module, int64_t num_replicas,
+ int64_t num_partitions, bool expect_change) {
+ HloModuleConfig config = GetModuleConfigForTest(
+ /*replica_count=*/num_replicas,
+ /*num_partitions=*/num_partitions);
+ config.set_use_spmd_partitioning(num_partitions > 1);
+ TF_ASSIGN_OR_RETURN(auto module,
+ ParseAndReturnVerifiedModule(hlo_module, config));
+ auto changed = ReduceScatterCreator().Run(module.get());
+ if (!changed.ok()) {
+ return changed.status();
+ }
+ EXPECT_EQ(changed.value(), expect_change);
+ return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
+ }
+
+ size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
+ return CollectiveCount(module, HloOpcode::kAllReduce);
+ }
+
+ size_t ReduceScatterCount(std::unique_ptr<HloModule> &module) {
+    return CollectiveCount(module, HloOpcode::kReduceScatter);
+ }
+
+ private:
+ size_t CollectiveCount(std::unique_ptr<HloModule> &module, HloOpcode opcode) {
+ return absl::c_count_if(
+ module->entry_computation()->instructions(),
+ [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
+ }
+};
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicas) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={}, to_apply=%sum
+ %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+ %rid = u32[] replica-id()
+ %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%id)
+ %slice_size = s32[] constant(4)
+ %offset = s32[] multiply(%reshape, %slice_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+ dynamic_slice_sizes={4,8,128}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/true));
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ const auto *rs = Cast<HloReduceScatterInstruction>(
+ module->entry_computation()->root_instruction());
+ EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithOffsetReshape) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={}, to_apply=%sum
+ %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+ %rid = u32[] replica-id()
+ %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+ %slice_size = s32[1] constant({4})
+ %offset = s32[1] multiply(%id, %slice_size)
+ %reshape = s32[] reshape(%offset)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %reshape, %zero, %zero),
+ dynamic_slice_sizes={4,8,128}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/true));
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ const auto *rs = Cast<HloReduceScatterInstruction>(
+ module->entry_computation()->root_instruction());
+ EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={}, to_apply=%sum
+ %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+ %rid = u32[] replica-id()
+ %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%id)
+ %slice_size = s32[] constant(4)
+ %offset = s32[] multiply(%reshape, %slice_size)
+ %zero = s32[] constant(0)
+ %reshape.1 = f32[32,16,64] reshape(%all-reduce)
+ ROOT %dynamic-slice = f32[4,16,64] dynamic-slice(%reshape.1, %offset, %zero, %zero),
+ dynamic_slice_sizes={4,16,64}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshapeSplitDimModified) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[336,1024] parameter(0)
+ %all-reduce = f32[336,1024] all-reduce(%param), replica_groups={}, to_apply=%sum
+ %rid = u32[] replica-id()
+ %id = s32[] convert(%rid)
+ %slice_size = s32[] constant(128)
+ %offset = s32[] multiply(%id, %slice_size)
+ %zero = s32[] constant(0)
+ %reshape.1 = f32[4,84,1024] reshape(%all-reduce)
+ ROOT %dynamic-slice = f32[4,84,128] dynamic-slice(%reshape.1, %zero, %zero, %offset),
+ dynamic_slice_sizes={4,84,128}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasDim2) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={}, to_apply=%sum
+ %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+ %rid = u32[] replica-id()
+ %rid_s32 = s32[] convert(%rid)
+ %slice_size = s32[] constant(16)
+ %offset = s32[] multiply(%rid_s32, %slice_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[32,8,16] dynamic-slice(%all-reduce, %zero, %zero, %offset),
+ dynamic_slice_sizes={32,8,16}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/true));
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ const auto *rs = Cast<HloReduceScatterInstruction>(
+ module->entry_computation()->root_instruction());
+ EXPECT_EQ(rs->scatter_dimension(), 2) << rs->ToString();
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWrongOffsets) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={}, to_apply=%sum
+ %table = s32[8]{0} constant({0,1,2,3,4,5,6,8})
+ %rid = u32[] replica-id()
+ %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%id)
+ %slice_size = s32[] constant(4)
+ %offset = s32[] multiply(%reshape, %slice_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+ dynamic_slice_sizes={4,8,128}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/1,
+ /*expect_change=*/false));
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasIotaTable) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={}, to_apply=%sum
+ %table = s32[8]{0} iota(), iota_dimension=0
+ %rid = u32[] replica-id()
+ %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%id)
+ %slice_size = s32[] constant(4)
+ %offset = s32[] multiply(%reshape, %slice_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+ dynamic_slice_sizes={4,8,128}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/2,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupedReplicas) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum
+ %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
+ %rid = u32[] replica-id()
+ %id = s32[1] dynamic-slice(%gtable, %rid), dynamic_slice_sizes={1}
+ %reshape.0 = s32[] reshape(%id)
+ %table = s32[4]{0} constant({0,8,16,24})
+ %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
+ %reshape.1 = s32[] reshape(%offset)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
+ dynamic_slice_sizes={8,8,128}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/8,
+ /*num_partitions=*/2,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllPartitions) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={{0},{1}}, to_apply=%sum, channel_id=1
+ %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+ %pid = u32[] partition-id()
+ %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%id)
+ %slice_size = s32[] constant(4)
+ %offset = s32[] multiply(%reshape, %slice_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+ dynamic_slice_sizes={4,8,128}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/2,
+ /*num_partitions=*/8,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReduceFollowedByAllReduce) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce.scattered = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=1
+ %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+ %pid = u32[] partition-id()
+ %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%id)
+ %slice_size = s32[] constant(4)
+ %offset = s32[] multiply(%reshape, %slice_size)
+ %zero = s32[] constant(0)
+ %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce.scattered, %offset, %zero, %zero),
+ dynamic_slice_sizes={4,8,128}
+ ROOT %all-reduce.sync = f32[4,8,128]{2,1,0} all-reduce(%dynamic-slice),
+ replica_groups={{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=2
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/2,
+ /*num_partitions=*/8,
+ /*expect_change=*/true));
+ EXPECT_EQ(AllReduceCount(module), 1);
+ EXPECT_EQ(ReduceScatterCount(module), 1);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobals) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+ %pid = u32[] partition-id()
+ %rid = u32[] replica-id()
+ %pcount = u32[] constant(4)
+ %ridxp = u32[] multiply(%rid, %pcount)
+ %gid = u32[] add(%ridxp, %pid)
+ %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
+ %id = s32[1] dynamic-slice(%gtable, %gid), dynamic_slice_sizes={1}
+ %reshape.0 = s32[] reshape(%id)
+ %table = s32[4]{0} constant({0,8,16,24})
+ %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
+ %reshape.1 = s32[] reshape(%offset)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
+ dynamic_slice_sizes={8,8,128}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/2,
+ /*num_partitions=*/4,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsOrthogonalReplicas) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={{1,3,2,0},{5,7,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+ %pid = u32[] partition-id()
+ %pid_table = s32[4]{0} constant({3,0,2,1})
+ %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%offset)
+ %shard_size = s32[] constant(8)
+ %mul = s32[] multiply(%reshape, %shard_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
+ dynamic_slice_sizes={8,8,128}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/2,
+ /*num_partitions=*/4,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Parameter(0))));
+ EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsNonOrthogonalReplicas) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[32,8,128]{2,1,0} parameter(0)
+ %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+ replica_groups={{1,3,2,0},{7,5,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+ %pid = u32[] partition-id()
+ %pid_table = s32[4]{0} constant({3,0,2,1})
+ %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%offset)
+ %shard_size = s32[] constant(8)
+ %mul = s32[] multiply(%reshape, %shard_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
+ dynamic_slice_sizes={8,8,128}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/2,
+ /*num_partitions=*/4,
+ /*expect_change=*/false));
+}
+
+TEST_F(GpuReduceScatterCreatorTest, NonUniformSplit) {
+ absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+ %param = f32[1,7]{1,0} parameter(0)
+ %all-reduce = f32[1,7]{1,0} all-reduce(%param),
+ replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+ %pid = u32[] partition-id()
+ %pid_table = s32[8]{0} constant({0, 1, 0, 1, 0, 1, 0, 1})
+ %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
+ %reshape = s32[] reshape(%offset)
+ %shard_size = s32[] constant(3)
+ %mul = s32[] multiply(%reshape, %shard_size)
+ %zero = s32[] constant(0)
+ ROOT %dynamic-slice = f32[1,3] dynamic-slice(%all-reduce, %zero, %mul),
+ dynamic_slice_sizes={1,3}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+ /*num_replicas=*/1,
+ /*num_partitions=*/8,
+ /*expect_change=*/true));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::ReduceScatter(m::Slice(m::Parameter(0)))));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc
new file mode 100644
index 0000000..8c2929c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc
@@ -0,0 +1,131 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleReduce(HloInstruction *hlo) override {
+ auto instr = Cast<HloReduceInstruction>(hlo);
+ absl::InlinedVector<HloInstruction *, 2> input_reshapes;
+ absl::InlinedVector<Shape, 2> canonical_reduce_shapes;
+
+ int idx = -1;
+ std::vector<int64_t> updated_reduced_dimensions;
+ for (HloInstruction *reduced_op : instr->inputs()) {
+ idx++;
+ const Shape &input_shape = reduced_op->shape();
+ const Shape &reduce_shape = instr->shape().IsTuple()
+ ? instr->shape().tuple_shapes(idx)
+ : instr->shape();
+
+ if (!ShapeUtil::HasDegenerateDimensions(reduced_op->shape())) {
+ return absl::OkStatus();
+ }
+ Shape canonical_input_shape =
+ ShapeUtil::DropDegenerateDimensions(input_shape);
+
+ Shape canonical_reduce_shape =
+ ShapeUtil::DropDegenerateDimensions(reduce_shape);
+
+ auto reduced_dimensions = instr->dimensions();
+ int64_t shift = 0;
+
+ for (int dim = 0; dim < input_shape.rank(); dim++) {
+ if (input_shape.dimensions(dim) == 1) {
+ shift++;
+ } else {
+ if (absl::c_linear_search(reduced_dimensions, dim) && idx == 0) {
+ // Only populate on first iteration.
+ updated_reduced_dimensions.push_back(dim - shift);
+ }
+ }
+ }
+
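+      // Every reduced dimension was degenerate, so the reduction itself is a
+      // no-op and the whole instruction collapses to a bitcast of the operand
+      // (exercised by the DegenerateWithEmptyDimension test).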
+ if (updated_reduced_dimensions.empty()) {
+        std::unique_ptr<HloInstruction> bitcast =
+            HloInstruction::CreateBitcast(reduce_shape, reduced_op);
+        return ReplaceWithNewInstruction(instr, std::move(bitcast));
+ }
+
+ input_reshapes.push_back(instr->parent()->AddInstruction(
+ HloInstruction::CreateBitcast(canonical_input_shape, reduced_op)));
+ canonical_reduce_shapes.push_back(canonical_reduce_shape);
+ }
+
+ Shape canonical_reduce_shape =
+ ShapeUtil::MakeMaybeTupleShape(canonical_reduce_shapes);
+ const Shape &orig_reduce_shape = instr->shape();
+ std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
+ canonical_reduce_shape, input_reshapes, instr->init_values(),
+ updated_reduced_dimensions, instr->to_apply());
+ instr->SetupDerivedInstruction(new_reduce.get());
+
+ if (canonical_reduce_shape != instr->shape()) {
+ HloInstruction *wrapped_reduce =
+ instr->parent()->AddInstruction(std::move(new_reduce));
+ absl::InlinedVector<HloInstruction *, 2> out;
+ if (!canonical_reduce_shape.IsTuple()) {
+ new_reduce =
+ HloInstruction::CreateBitcast(orig_reduce_shape, wrapped_reduce);
+ } else {
+ for (int oidx = 0; oidx < instr->input_count(); oidx++) {
+ HloInstruction *gte = instr->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
+ out.push_back(
+ instr->parent()->AddInstruction(HloInstruction::CreateBitcast(
+ orig_reduce_shape.tuple_shapes(oidx), gte)));
+ }
+ new_reduce = HloInstruction::CreateTuple(out);
+ }
+ }
+
+ return ReplaceWithNewInstruction(instr, std::move(new_reduce));
+ }
+};
+
+absl::StatusOr<bool> ReductionDegenerateDimRemover::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ TF_ASSIGN_OR_RETURN(bool changed,
+ ReductionDegenerateDimRemoverVisitor().RunOnModule(
+ module, execution_threads));
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h
new file mode 100644
index 0000000..1630aec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h
@@ -0,0 +1,56 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Enforces the invariant that reduction input and output have no degenerate
+// (size 1) dimension. Since these dimensions are physically meaningless, they
+// are removed using bitcasts.
+//
+// For example,
+//
+// f[1] out = reduce(f[100, 1, 1] input, dimensions={0, 1})
+//
+// becomes:
+//
+// f[100] tmp1 = f[100] bitcast(f[100, 1, 1] input)
+// f[] tmp2 = reduce(f[100] tmp1, dimensions={0})
+// f[1] out = f[1] bitcast(tmp2)
+//
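+// Variadic (tuple-shaped) reductions are handled analogously: the reduce is
+// rewritten on the canonicalized shapes and each tuple element is bitcast
+// back to its original shape.
+//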
+class ReductionDegenerateDimRemover : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "reduction-degenerate-dim-remover";
+ }
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc
new file mode 100644
index 0000000..7a9b7fa
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class ReductionDegenerateDimRemoverTest : public HloTestBase {
+ public:
+ void CheckDegenerateDimRemover(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, gpu::ReductionDegenerateDimRemover{},
+ expected);
+ }
+};
+
+TEST_F(ReductionDegenerateDimRemoverTest, ReductionWithDegenerateDimensions) {
+ const char* hlo = R"(
+HloModule ReduceWithDegenerateDimensions
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1,3,1,4,1,5,1] parameter(0)
+ zero = f32[] constant(0)
+
+ ROOT out = f32[1,1,1,1] reduce(input, zero), dimensions={1,3,5}, to_apply=add
+}
+
+)";
+
+ CheckDegenerateDimRemover(hlo, R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[] reduce([[bitcast_0]], [[zero_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[reduce_2]])
+ )");
+}
+
+TEST_F(ReductionDegenerateDimRemoverTest,
+ ReductionWithDegenerateDimensionsVariadic) {
+ const char* hlo = R"(
+HloModule ReduceWithDegenerateDimensions
+
+argmax {
+ running_max = f32[] parameter(0)
+ running_max_idx = u32[] parameter(1)
+ current_value = f32[] parameter(2)
+ current_value_idx = u32[] parameter(3)
+
+ current = (f32[], u32[]) tuple(running_max, running_max_idx)
+ potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+ cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+ new_max = f32[] select(cmp_code, current_value, running_max)
+ new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+ ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+ input = f32[1,3,1,4,1,5,1] parameter(0)
+ idxs = u32[1,3,1,4,1,5,1] parameter(1)
+ zero = f32[] constant(0)
+ zero_idx = u32[] constant(0)
+
+ ROOT out = (f32[1,1,1,1], u32[1,1,1,1]) reduce(input, idxs, zero, zero_idx), dimensions={1,3,5}, to_apply=argmax
+}
+
+)";
+
+ CheckDegenerateDimRemover(hlo, R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[bitcast_1_2:%[^ ]+]] = u32[3,4,5]{2,1,0} bitcast([[idxs_3:%[^ ]+]])
+// CHECK: [[reduce_4:%[^ ]+]] = (f32[], u32[]) reduce([[bitcast_0]], [[bitcast_1_2]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK-NEXT: [[get_tuple_element_8:%[^ ]+]] = f32[] get-tuple-element([[reduce_4]]), index=0
+// CHECK-NEXT: [[bitcast_2_9:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK-NEXT: [[get_tuple_element_1_10:%[^ ]+]] = u32[] get-tuple-element([[reduce_4]]), index=1
+// CHECK-NEXT: [[bitcast_3_11:%[^ ]+]] = u32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK-NEXT: ROOT [[tuple_12:%[^ ]+]] = (f32[1,1,1,1]{3,2,1,0}, u32[1,1,1,1]{3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
+)");
+}
+
+TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
+ const char* hlo = R"(
+HloModule ReduceWithDegenerateDimensions
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1,3,1,4,1,5,1] parameter(0)
+ zero = f32[] constant(0)
+
+ ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
+}
+)";
+
+ CheckDegenerateDimRemover(hlo,
+ R"(
+// CHECK: ROOT [[bitcast_0:%[^ ]+]] = f32[3,4,5,1]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+ )");
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc
new file mode 100644
index 0000000..ca4fba4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc
@@ -0,0 +1,122 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_dimension_grouper.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/layout_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleReduce(HloInstruction *hlo) override {
+ auto reduce = Cast<HloReduceInstruction>(hlo);
+
+ VLOG(4) << "Input: " << reduce->ToString();
+
+ absl::InlinedVector<HloInstruction *, 2> reduce_inputs_grouped;
+ std::vector<int64_t> reduced_dims_grouped;
+
+ int idx = -1;
+ for (HloInstruction *operand : reduce->inputs()) {
+ idx++;
+ std::vector<int64_t> new_grouped_dims;
+ const Shape &shape = operand->shape();
+ CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
+ << "Default layout should be enforced on reduction operand";
+ auto is_reduced = [&](int dim) {
+ return absl::c_linear_search(reduce->dimensions(), dim);
+ };
+
+ bool changed = false;
+ int64_t next_dim_size = 1;
+
+ // Since we have enforced the standard layout, iteration over logical
+ // dimensions is equivalent to iteration over the major-to-minor order.
+ for (int logical_dim = 0; logical_dim < shape.rank(); logical_dim++) {
+ VLOG(5) << "Processing dimension " << logical_dim << " of size "
+ << shape.dimensions(logical_dim);
+ if (is_reduced(logical_dim) && logical_dim < shape.rank() - 1 &&
+ is_reduced(logical_dim + 1)) {
+ VLOG(5) << "This and consecutive dimension are reduced, merging";
+ changed = true;
+ next_dim_size *= shape.dimensions(logical_dim);
+ continue;
+ }
+
+ if (is_reduced(logical_dim)) {
+ new_grouped_dims.push_back(next_dim_size *
+ shape.dimensions(logical_dim));
+ if (idx == 0) {
+ // Only populate for first argument.
+ reduced_dims_grouped.push_back(new_grouped_dims.size() - 1);
+ }
+ next_dim_size = 1;
+ } else {
+ new_grouped_dims.push_back(shape.dimensions(logical_dim));
+ }
+ }
+
+      // All inputs share the same dimensions, so if no adjacent reduced
+      // dimensions were found for this operand, there is nothing to group.
+      if (!changed) {
+ return absl::OkStatus();
+ }
+
+ Shape grouped_shape =
+ ShapeUtil::MakeShape(shape.element_type(), new_grouped_dims);
+ reduce_inputs_grouped.push_back(reduce->parent()->AddInstruction(
+ HloInstruction::CreateBitcast(grouped_shape, operand),
+ &operand->metadata()));
+ VLOG(5) << "Adding bitcast: " << reduce_inputs_grouped.back()->ToString();
+ }
+
+ std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
+ reduce->shape(), reduce_inputs_grouped, reduce->init_values(),
+ reduced_dims_grouped, reduce->to_apply());
+ VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
+ return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
+ }
+};
+
+absl::StatusOr<bool> ReductionDimensionGrouper::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ TF_ASSIGN_OR_RETURN(bool changed, ReduceDimensionGroupVisitor().RunOnModule(
+ module, execution_threads));
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h
new file mode 100644
index 0000000..1ebd600
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h
@@ -0,0 +1,56 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Groups adjacent (logically and physically) reduced dimensions in reduction
+// input.
+//
+// Precondition: ReductionLayoutNormalizer has been run (physical proximity and
+// logical proximity become the same).
+//
+// For example,
+//
+// f[] out = reduce(f[10,20,30] input, dimensions={0,1,2})
+//
+// becomes:
+//
+// f[600] tmp = f[600] bitcast(f[10,20,30] input)
+// f[] out = reduce(f[600] tmp, dimensions={0})
+//
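+// Only adjacent reduced dimensions are merged. For example, a reduce over
+// dimensions={2,3} of a f[100,10,32,3] input becomes a reduce over
+// dimensions={2} of a f[100,10,96] bitcast (see the unit tests).
+//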
+class ReductionDimensionGrouper : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "reduction-dimension-grouper";
+ }
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc
new file mode 100644
index 0000000..afbbbec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc
@@ -0,0 +1,103 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_dimension_grouper.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class ReductionDimensionGrouperTest : public HloTestBase {
+ public:
+ void CheckDimensionGrouper(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, gpu::ReductionDimensionGrouper{}, expected);
+ }
+};
+
+TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) {
+ const char* hlo = R"(
+HloModule ReductionWithGrouping
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[100,10,32,3]{3,2,1,0} parameter(0)
+ zero = f32[] constant(0)
+
+ ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add
+}
+)";
+
+ CheckDimensionGrouper(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK: ROOT [[out_1_2:%[^ ]+]] = f32[100,10]{0,1} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+ )");
+}
+
+TEST_F(ReductionDimensionGrouperTest, ReductionWithGroupingVariadic) {
+ const char* hlo = R"(
+HloModule ReductionWithGrouping
+
+argmax {
+ running_max = f32[] parameter(0)
+ running_max_idx = u32[] parameter(1)
+ current_value = f32[] parameter(2)
+ current_value_idx = u32[] parameter(3)
+
+ current = (f32[], u32[]) tuple(running_max, running_max_idx)
+ potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+ cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+ new_max = f32[] select(cmp_code, current_value, running_max)
+ new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+ ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+ input = f32[100,10,32,3]{3,2,1,0} parameter(0)
+ idxs = u32[100,10,32,3]{3,2,1,0} parameter(1)
+ zero = f32[] constant(0)
+ zero_idx = u32[] constant(0)
+
+ ROOT out = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(input, idxs, zero, zero_idx), dimensions={2,3}, to_apply=argmax
+}
+)";
+
+ CheckDimensionGrouper(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK: [[idxs_2:%[^ ]+]] = u32[100,10,32,3]{3,2,1,0} parameter(1)
+// CHECK: [[bitcast_1_3:%[^ ]+]] = u32[100,10,96]{2,1,0} bitcast([[idxs_2]])
+// CHECK: ROOT [[out_1_4:%[^ ]+]] = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_7:%[^ ]+]]
+)");
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc
new file mode 100644
index 0000000..fd45f8b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc
@@ -0,0 +1,203 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_layout_normalizer.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
+ absl::Status HandleReduce(HloInstruction *hlo) override {
+ auto reduce = Cast<HloReduceInstruction>(hlo);
+ VLOG(5) << "Input: " << reduce->ToString();
+
+ int operand_idx = -1;
+
+ absl::InlinedVector<HloInstruction *, 2> canonical_reduce_inputs;
+ absl::InlinedVector<Shape, 2> new_reduce_shapes;
+
+ DimensionVector out_reduce_dimensions;
+ const Shape &first_instruction_shape = reduce->inputs()[0]->shape();
+
+ for (HloInstruction *operand : reduce->inputs()) {
+ operand_idx++;
+
+ if (operand_idx != 0 &&
+ operand->shape().layout() != first_instruction_shape.layout()) {
+ HloInstruction *copy =
+ reduce->parent()->AddInstruction(HloInstruction::CreateUnary(
+ operand->shape(), HloOpcode::kCopy, operand));
+
+ LayoutUtil::ClearLayout(copy->mutable_shape());
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ first_instruction_shape, copy->mutable_shape()));
+
+ copy->set_metadata(operand->metadata());
+ operand = copy;
+ VLOG(3) << "Copying to establish consistent inputs layout: "
+ << copy->ToString();
+ }
+
+ const Shape &operand_shape = operand->shape();
+ const Layout &operand_layout = operand_shape.layout();
+
+ const Shape &reduce_shape =
+ reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(operand_idx)
+ : reduce->shape();
+
+ DimensionVector new_reduce_dimensions;
+ DimensionVector new_operand_shape_data;
+ DimensionVector new_reduce_shape_data;
+
+      // The layout order of the reduction output can differ from the order of
+      // the kept dimensions in the input operand, so we need to recompute the
+      // output layout.
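+      // For example, a f[10,20,30]{2,0,1} input reduced over dimension 0 into
+      // a f[20,30]{0,1} output becomes a f[20,10,30] (default layout) operand
+      // reduced over dimension 1, with the output keeping its {0,1} layout.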
+ DimensionVector new_reduce_shape_layout(reduce_shape.rank());
+ std::vector<int64_t> reduce_shape_logical_to_physical =
+ LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout());
+
+ auto to_reduce_logical_dim = [&](int64_t op_logical_dim) {
+ return op_logical_dim -
+ absl::c_count_if(reduce->dimensions(), [&](int64_t dim) {
+ CHECK(dim != op_logical_dim);
+ return dim < op_logical_dim;
+ });
+ };
+
+ for (int i = 0; i < operand_shape.rank(); i++) {
+        // Process the dimensions in major-to-minor order so that the new
+        // operand shape gets the default layout.
+ int64_t major_to_minor_dim_idx = operand_shape.rank() - i - 1;
+ int64_t logical_dim =
+ operand_layout.minor_to_major(major_to_minor_dim_idx);
+ int64_t dim_size = operand_shape.dimensions(logical_dim);
+ VLOG(5) << "Processing logical dimension " << logical_dim << " of size "
+ << dim_size;
+ new_operand_shape_data.push_back(dim_size);
+
+ if (absl::c_linear_search(reduce->dimensions(), logical_dim)) {
+ new_reduce_dimensions.push_back(i);
+ } else {
+ new_reduce_shape_data.push_back(dim_size);
+ int64_t logical_reduce_dim = to_reduce_logical_dim(logical_dim);
+ int64_t physical_reduce_dim =
+ reduce_shape_logical_to_physical[logical_reduce_dim];
+ VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", "
+ << "physical_reduce_dim = " << physical_reduce_dim;
+ new_reduce_shape_layout[reduce_shape.rank() - physical_reduce_dim -
+ 1] = new_reduce_shape_data.size() - 1;
+ }
+ }
+
+ Shape new_operand_shape = ShapeUtil::MakeShape(
+ operand_shape.element_type(), new_operand_shape_data);
+ Shape new_reduce_shape = ShapeUtil::MakeShapeWithDenseLayout(
+ reduce_shape.element_type(), new_reduce_shape_data,
+ new_reduce_shape_layout);
+
+ if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) {
+ return absl::OkStatus();
+ }
+
+ HloInstruction *canonical_reduce_input =
+ new_operand_shape != operand_shape
+ ? reduce->parent()->AddInstruction(
+ HloInstruction::CreateBitcast(new_operand_shape, operand))
+ : operand;
+ canonical_reduce_input->set_metadata(operand->metadata());
+ VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();
+
+ new_reduce_shapes.push_back(new_reduce_shape);
+ canonical_reduce_inputs.push_back(canonical_reduce_input);
+
+ if (out_reduce_dimensions.empty()) {
+ out_reduce_dimensions = new_reduce_dimensions;
+ } else {
+ TF_RET_CHECK(out_reduce_dimensions == new_reduce_dimensions);
+ }
+ }
+
+ Shape new_reduce_shape = ShapeUtil::MakeMaybeTupleShape(new_reduce_shapes);
+
+ std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
+ new_reduce_shape, canonical_reduce_inputs, reduce->init_values(),
+ out_reduce_dimensions, reduce->to_apply());
+ VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
+ const Shape &orig_reduce_shape = reduce->shape();
+
+ if (new_reduce_shape != orig_reduce_shape) {
+ HloInstruction *wrapped_reduce =
+ reduce->parent()->AddInstruction(std::move(new_reduce));
+
+ if (!new_reduce_shape.IsTuple()) {
+ new_reduce =
+ HloInstruction::CreateBitcast(reduce->shape(), wrapped_reduce);
+ } else {
+ // Bitcast each element of the tuple.
+ absl::InlinedVector<HloInstruction *, 2> out;
+ for (int oidx = 0; oidx < reduce->input_count(); oidx++) {
+ HloInstruction *gte = reduce->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
+ out.push_back(
+ reduce->parent()->AddInstruction(HloInstruction::CreateBitcast(
+ orig_reduce_shape.tuple_shapes(oidx), gte)));
+ }
+ new_reduce = HloInstruction::CreateTuple(out);
+ }
+ }
+
+ VLOG(5) << "Generated output: " << new_reduce->ToString();
+ return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
+ }
+};
+
+absl::StatusOr<bool> ReductionLayoutNormalizer::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ TF_ASSIGN_OR_RETURN(bool changed,
+ EnforceMinorToMajorReduceOpVisitor().RunOnModule(
+ module, execution_threads));
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h
new file mode 100644
index 0000000..f6e2d7c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h
@@ -0,0 +1,54 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Enforces default (minor-to-major) layout on all reduction inputs.
+// Note that since reduction output can request a custom layout,
+// this pass only guarantees standard layout for the input.
+//
+// For example,
+//
+// f[20,30]{0,1} out = reduce(f[10,20,30]{2,0,1} input, dimensions={0})
+//
+// becomes:
+//
+// f[20,10,30] tmp = f[20,10,30] bitcast(f[10,20,30]{2,0,1} input)
+// f[20,30]{0,1} out = reduce(f[20,10,30]{2,1,0} tmp, dimensions={1})
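+//
+// If a variadic reduction has operands whose layouts disagree, copies are
+// inserted so that every operand takes the layout of the first operand.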
+class ReductionLayoutNormalizer : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "reduction-layout-normalizer";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc
new file mode 100644
index 0000000..46f5e93
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_layout_normalizer.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/error_spec.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class ReductionLayoutNormalizerTest : public HloTestBase {
+ public:
+ void CheckReductionLayoutNormalizer(
+ absl::string_view hlo, std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
+ }
+};
+
+TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
+ const char* hlo = R"(
+HloModule ReduceWithLayoutChange
+
+add {
+ x0 = f32[] parameter(0)
+ y0 = f32[] parameter(1)
+ ROOT add0 = f32[] add(x0, y0)
+}
+
+ENTRY main {
+ arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
+ constant0 = f32[] constant(0)
+ ROOT reduce0 = f32[4,5,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
+ dimensions={1,6,7}, to_apply=add
+}
+
+)";
+
+ CheckReductionLayoutNormalizer(hlo,
+ R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
+ )");
+}
+
+TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
+ const char* hlo = R"(
+HloModule ReduceWithLayoutChangeVariadic
+
+
+argmax {
+ running_max = f32[] parameter(0)
+ running_max_idx = u32[] parameter(1)
+ current_value = f32[] parameter(2)
+ current_value_idx = u32[] parameter(3)
+
+ current = (f32[], u32[]) tuple(running_max, running_max_idx)
+ potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+ cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+ new_max = f32[] select(cmp_code, current_value, running_max)
+ new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+ ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+ arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
+ idxs = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
+ constant0 = f32[] constant(0)
+ constant1 = u32[] constant(0)
+ ROOT reduce0 = (
+ f32[4,5,16,12,12]{4,3,2,1,0},
+ u32[4,5,16,12,12]{4,3,2,1,0}
+ ) reduce(arg0, idxs, constant0,constant1), dimensions={1,6,7}, to_apply=argmax
+}
+
+
+)";
+
+ CheckReductionLayoutNormalizer(hlo,
+ R"(
+// CHECK: [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
+// CHECK: [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
+// CHECK: [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
+// CHECK: [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK: [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
+// CHECK: [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK: [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
+// CHECK: [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK: ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
+ )");
+}
+
+TEST_F(ReductionLayoutNormalizerTest,
+ LayoutCanonicalizerTestVariadicDifferentLayouts) {
+ const char* hlo = R"(
+HloModule ReduceWithLayoutChangeVariadicDifferent
+
+argmax {
+ running_max = f32[] parameter(0)
+ running_max_idx = u32[] parameter(1)
+ current_value = f32[] parameter(2)
+ current_value_idx = u32[] parameter(3)
+
+ current = (f32[], u32[]) tuple(running_max, running_max_idx)
+ potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+ cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+ new_max = f32[] select(cmp_code, current_value, running_max)
+ new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+ ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+ arg0 = f32[2,3,4,7]{2,1,0,3} parameter(0)
+ idxs = u32[2,3,4,7]{3,2,1,0} parameter(1)
+ constant0 = f32[] constant(0)
+ constant1 = u32[] constant(0)
+ ROOT reduce0 = (
+ f32[2,3,4]{2,1,0},
+ u32[2,3,4]{2,1,0}
+ ) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
+}
+
+
+)";
+
+ CheckReductionLayoutNormalizer(hlo,
+ R"(
+// CHECK: [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
+// CHECK: [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
+// CHECK: [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
+// CHECK: [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
+// CHECK: ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
+ )");
+ EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5}));
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc
new file mode 100644
index 0000000..dce9288
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc
@@ -0,0 +1,140 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_splitter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class ReductionSplitterVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit ReductionSplitterVisitor(bool ignore_small_dims)
+ : ignore_small_dims_(ignore_small_dims) {}
+ absl::Status HandleReduce(HloInstruction *reduce) override {
+ VLOG(4) << "Input: " << reduce->ToString();
+
+ // Reductions with contiguous dimensions are lowered to efficient code. No
+ // need to split such ops.
+ if (IsReductionFromOrToContiguousDimensions(*reduce)) {
+ VLOG(4) << "Reduction with contiguous dimensions. Return.";
+ return absl::OkStatus();
+ }
+ if (reduce->dimensions().size() < 2) {
+ return absl::OkStatus();
+ }
+ if (!reduce->shape().IsArray()) {
+ // TODO(cheshire): Handle variadic reduction.
+ return absl::OkStatus();
+ }
+
+ HloInstruction *operand = reduce->mutable_operand(0);
+ const Shape &shape = operand->shape();
+ CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
+ << "Default layout should be enforced on reduction operand";
+ // Verify that contiguous dimensions have been grouped by the
+ // ReductionDimensionGrouper pass.
+ for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
+ for (int64_t j = i + 1; j < reduce->dimensions().size(); ++j) {
+        CHECK(std::abs(reduce->dimensions(i) - reduce->dimensions(j)) > 1)
+ << "Reduction dimensions must not be consecutive";
+ }
+ }
+
+    // The reduce op has non-contiguous dimensions. Look for the reduced
+    // dimension with the largest extent; reducing along this dimension first
+    // shrinks the intermediate output the most.
+ int64_t max_shape_dim = 0;
+ int64_t max_reduce_dim = 0;
+ const auto &input_shape = reduce->operand(0)->shape();
+ for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
+ if (input_shape.dimensions(reduce->dimensions(i)) > max_shape_dim) {
+ max_reduce_dim = reduce->dimensions(i);
+ max_shape_dim = input_shape.dimensions(max_reduce_dim);
+ }
+ }
+ if (ignore_small_dims_ && max_shape_dim <= 8) {
+ return absl::OkStatus();
+ }
+
+ // Split the reduction into a pre-reduction and a final reduction.
+ VLOG(3) << "Splitting reduction " << reduce->name() << " at dimension "
+ << max_reduce_dim;
+ std::vector<int64_t> pre_reduce_dims;
+ pre_reduce_dims.push_back(max_reduce_dim);
+ std::vector<int64_t> pre_reduce_shape_dims(input_shape.dimensions().begin(),
+ input_shape.dimensions().end());
+ pre_reduce_shape_dims.erase(pre_reduce_shape_dims.begin() + max_reduce_dim);
+ Shape pre_reduce_shape = ShapeUtil::MakeShape(
+ reduce->shape().element_type(), pre_reduce_shape_dims);
+ std::unique_ptr<HloInstruction> pre_reduce = HloInstruction::CreateReduce(
+ pre_reduce_shape, reduce->mutable_operand(0),
+ reduce->mutable_operand(1), pre_reduce_dims, reduce->to_apply());
+ pre_reduce->set_metadata(reduce->metadata());
+
+ std::vector<int64_t> final_reduce_dims(reduce->dimensions().begin(),
+ reduce->dimensions().end());
+ final_reduce_dims.erase(
+ std::remove(final_reduce_dims.begin(), final_reduce_dims.end(),
+ max_reduce_dim),
+ final_reduce_dims.end());
+ for (int64_t i = 0; i < final_reduce_dims.size(); ++i) {
+ if (final_reduce_dims[i] > max_reduce_dim) {
+ final_reduce_dims[i]--;
+ }
+ }
+ std::unique_ptr<HloInstruction> final_reduce = HloInstruction::CreateReduce(
+ reduce->shape(),
+ reduce->parent()->AddInstruction(std::move(pre_reduce)),
+ reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply());
+ return ReplaceWithNewInstruction(reduce, std::move(final_reduce));
+ }
+
+ private:
+ bool ignore_small_dims_;
+};
+
+absl::StatusOr<bool> ReductionSplitter::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ TF_ASSIGN_OR_RETURN(bool changed,
+ ReductionSplitterVisitor(ignore_small_dims_)
+ .RunOnModule(module, execution_threads));
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h
new file mode 100644
index 0000000..87520d3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h
@@ -0,0 +1,59 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Splits a reduce op into two consecutive reduce ops if the reduce dimensions
+// are not contiguous. Ignores small reduce dimensions if `ignore_small_dims` is
+// set.
+//
+// Reductions with non-contiguous dimensions are emitted as simple element-wise
+// loops. This is inefficient when reducing large input shape dimensions.
+// Splitting such reductions allows using more efficient reduction emitters.
+//
+// The pass splits only one dimension per reduce op, so run it to a fixpoint
+// to split reduce ops along multiple dimensions.
+//
+// Precondition: ReductionDimensionGrouper has been run and adjacent reduce
+// dimensions have been grouped, and reduction layouts have been normalized.
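+//
+// A sketch of the rewrite, mirroring the shapes exercised in the unit tests:
+//
+//   f32[16,64] out = reduce(f32[1024,16,512,64,128] p0, dimensions={0,2,4})
+//
+// is split at the largest reduced dimension (dimension 0, extent 1024) into
+//
+//   f32[16,512,64,128] tmp = reduce(p0, dimensions={0})
+//   f32[16,64] out = reduce(tmp, dimensions={1,3})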
+
+class ReductionSplitter : public HloModulePass {
+ public:
+ explicit ReductionSplitter(bool ignore_small_dims)
+ : ignore_small_dims_(ignore_small_dims) {}
+ absl::string_view name() const override { return "reduction-splitter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ bool ignore_small_dims_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc
new file mode 100644
index 0000000..4b9f6fb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc
@@ -0,0 +1,152 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_splitter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape_util.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class ReductionSplitterTest : public HloTestBase {};
+
+TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test
+
+ add_computation {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x, y)
+ }
+
+ ENTRY entry_computation {
+ param_0 = f16[6,16,512,64]{3,2,1,0} parameter(0)
+ transpose.1781 = f16[6,512,16,64]{3,1,2,0} transpose(param_0), dimensions={0,2,1,3}
+ convert.6986 = f32[6,512,16,64]{3,1,2,0} convert(transpose.1781)
+ bitcast.2136 = f32[6,16,512,64]{3,2,1,0} bitcast(convert.6986)
+ constant_11111 = f32[] constant(0)
+ ROOT reduce.982 = f32[16,64]{1,0} reduce(bitcast.2136, constant_11111), dimensions={0,2}, to_apply=add_computation
+ }
+ )")
+ .value();
+ ASSERT_TRUE(
+ ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root_reduction =
+ module->entry_computation()->root_instruction();
+ ASSERT_THAT(root_reduction,
+ GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
+
+ auto* pre_reduction = root_reduction->operand(0);
+ EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({2}));
+ EXPECT_THAT(pre_reduction->shape(), ShapeUtil::MakeShape(F32, {6, 16, 64}));
+ EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({0}));
+ EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
+}
+
+TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test
+
+ add_computation {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x, y)
+ }
+
+ ENTRY entry_computation {
+ param_0 = f32[1024,16,512,64,128]{4,3,2,1,0} parameter(0)
+ constant_11111 = f32[] constant(0)
+ ROOT reduce.982 = f32[16,64]{1,0} reduce(param_0, constant_11111), dimensions={2,0,4}, to_apply=add_computation
+ }
+ )")
+ .value();
+ ASSERT_TRUE(
+ ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root_reduction =
+ module->entry_computation()->root_instruction();
+ ASSERT_THAT(root_reduction,
+ GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
+
+ auto* pre_reduction = root_reduction->operand(0);
+ EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({0}));
+ EXPECT_THAT(pre_reduction->shape(),
+ ShapeUtil::MakeShape(F32, {16, 512, 64, 128}));
+ EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({1, 3}));
+ EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
+}
+
+TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test
+
+ add_computation {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x, y)
+ }
+
+ ENTRY entry_computation {
+ param_0 = f32[16,8,1024,8]{3,2,1,0} parameter(0)
+ constant_11111 = f32[] constant(0)
+ ROOT reduce.982 = f32[16,1024]{1,0} reduce(param_0, constant_11111), dimensions={3,1}, to_apply=add_computation
+ }
+ )")
+ .value();
+ EXPECT_FALSE(
+ ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
+ EXPECT_TRUE(
+ ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
+}
+
+TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule test
+
+ add_computation {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x, y)
+ }
+
+ ENTRY entry_computation {
+ param_0 = f32[128,128,64,128]{3,2,1,0} parameter(0)
+ constant_11111 = f32[] constant(0)
+    // The dimensions to keep (1 and 2) are contiguous.
+ ROOT reduce.982 = f32[128,64]{1,0} reduce(param_0, constant_11111), dimensions={3,0}, to_apply=add_computation
+ }
+ )")
+ .value();
+ EXPECT_FALSE(
+ ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc
new file mode 100644
index 0000000..9ab62f6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc
@@ -0,0 +1,92 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/rename_fusions.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/container/btree_set.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr absl::string_view FusionKindToString(
+ HloInstruction::FusionKind kind) {
+ switch (kind) {
+ case HloInstruction::FusionKind::kCustom:
+ return "custom";
+ case HloInstruction::FusionKind::kLoop:
+ return "loop";
+ case HloInstruction::FusionKind::kInput:
+ return "input";
+ case HloInstruction::FusionKind::kOutput:
+ return "output";
+ }
+}
+
+std::string MakeFusionHeroNames(const HloInstruction* instruction) {
+ std::unique_ptr<HloFusionAdaptor> fusion_adaptor =
+ HloFusionAdaptor::ForInstruction(instruction);
+ absl::btree_set<absl::string_view> heroes;
+
+ for (auto root : fusion_adaptor->GetRoots()) {
+ heroes.insert(HloOpcodeString(FindNonTrivialHero(root).opcode()));
+ }
+ return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}});
+}
+
+void RenameFusion(HloModule* module, HloInstruction* instruction) {
+ std::string hero_names = MakeFusionHeroNames(instruction);
+ module->SetAndUniquifyInstrName(
+ instruction, absl::StrCat(FusionKindToString(instruction->fusion_kind()),
+ "_", hero_names, "_fusion"));
+ module->SetAndUniquifyComputationName(
+ instruction->fused_instructions_computation(),
+ absl::StrCat("fused_", hero_names));
+}
+
+} // namespace
+
+absl::StatusOr<bool> RenameFusions::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() != HloOpcode::kFusion ||
+ instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) {
+ continue;
+ }
+ RenameFusion(module, instruction);
+ }
+ }
+ return true;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/rename_fusions.h b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h
new file mode 100644
index 0000000..5abcd61
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h
@@ -0,0 +1,47 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// An HLO pass that gives fusions and fused computations descriptive names.
+//
+// The name is based on the hero instructions and the fusion kind, i.e.
+// fusions are named "<fusion kind>_<hero instructions>_fusion" and fused
+// computations are named "fused_<hero instructions>".
+// In the case of multiple roots, the hero instructions in the name are
+// underscore-separated and alphabetically sorted.
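+//
+// For example (as exercised in the accompanying test), a kLoop fusion whose
+// hero instruction is a multiply is renamed to "loop_multiply_fusion" and its
+// fused computation to "fused_multiply".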
+
+class RenameFusions : public HloModulePass {
+ absl::string_view name() const override { return "rename_fusions"; }
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc
new file mode 100644
index 0000000..4747085
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/rename_fusions.h"
+
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+class RenameFusionsTest : public HloTestBase {
+ protected:
+ RenameFusions rename_fusions_;
+};
+
+TEST_F(RenameFusionsTest, FusionInstructionNames) {
+ absl::string_view kHlo = R"(
+ HloModule test_module
+
+ square {
+ p = f32[16384] parameter(0)
+ ROOT m = f32[16384] multiply(p, p)
+ }
+
+ exp {
+ p = f32[16384] parameter(0)
+ ROOT e = f32[16384] exponential(p)
+ }
+
+ log {
+ p = f32[16384] parameter(0)
+ ROOT l = f32[16384] log(p)
+ }
+
+ add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT add = f32[] add(p0, p1)
+ }
+
+ ENTRY main {
+ p0 = bf16[1024,8192] parameter(0)
+ p1 = f32[8192] parameter(1)
+ p2 = f32[16384] parameter(2)
+ convert = f32[1024,8192] convert(p0)
+ broadcast = f32[1024,8192] broadcast(p1), dimensions={1}
+ c0 = f32[] constant(0)
+ multiply = f32[1024,8192] multiply(broadcast, convert)
+ reduce = f32[1024] reduce(multiply, c0), dimensions={1}, to_apply=add
+ convert.1 = bf16[1024] convert(reduce)
+ s = f32[16384] fusion(p2), kind=kLoop, calls=square
+ e = f32[16384] fusion(s), kind=kLoop, calls=exp
+ l = f32[16384] fusion(s), kind=kInput, calls=log
+ ROOT result = (bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(convert.1, l, e)
+ })";
+
+ RunAndFilecheckHloRewrite(kHlo, std::move(rename_fusions_), R"(
+CHECK: ENTRY %main
+CHECK: %loop_multiply_fusion{{.*}} calls=%fused_multiply
+CHECK: %input_log_fusion{{.*}} calls=%fused_log
+CHECK: %loop_exponential_fusion{{.*}} calls=%fused_exponential
+CHECK: ROOT %result
+ )");
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc
new file mode 100644
index 0000000..3841f4a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc
@@ -0,0 +1,75 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/llvm_ir/buffer_assignment_util.h"
+#include "xla/service/name_uniquer.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+
+namespace gpu {
+
+absl::StatusOr<bool> SanitizeConstantNames::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+
+ NameUniquer instr_name_uniquer(/*separator=*/"_");
+  // Collect the names used for the non-constant HLO instructions.
+ for (HloComputation* computation : module->computations(execution_threads)) {
+ for (HloInstruction* instr : computation->instructions()) {
+ if (instr->opcode() == HloOpcode::kConstant) {
+ continue;
+ }
+
+ // Record the non-constant HLO instruction name in uniquer, and keep
+ // original instruction name unchanged.
+ instr_name_uniquer.GetUniqueName(instr->name());
+ }
+ }
+
+ // Sanitize the names for the constant HLO instructions and make them unique.
+ // This is not merged into the above loop because we don't want this pass to
+ // change the names of non-constant instructions, that is, if a constant HLO
+ // conflicts with a non-constant HLO, we change the name of the constant HLO
+  // even if the non-constant HLO appears later in the HLO module.
+ for (HloComputation* computation : module->computations(execution_threads)) {
+ for (HloInstruction* instr : computation->instructions()) {
+ if (instr->opcode() != HloOpcode::kConstant) {
+ continue;
+ }
+ std::string sanitized_name = llvm_ir::SanitizeConstantName(*instr);
+ instr->SetAndSanitizeName(sanitized_name);
+ instr->UniquifyName(&instr_name_uniquer);
+ // Register this new name with the module's instruction_name_uniquer to
+ // avoid name collision that might happen in future.
+ module->instruction_name_uniquer().GetUniqueName(instr->name());
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h
new file mode 100644
index 0000000..f743137
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h
@@ -0,0 +1,44 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Sanitizes HLO instruction names for the GPU backend. Currently, it only
+// replaces . and - in the HLO constant instruction names with _ to please the
+// LLVM PTX backend.
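+//
+// For example (as exercised in the accompanying tests), constants named
+// "equal.to" or "equal-to" become "equal_to"; on collision a numeric suffix is
+// appended, e.g. "equal_to_1" and "equal_to_2".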
+class SanitizeConstantNames : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "sanitize-constant-names"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc
new file mode 100644
index 0000000..8e97790
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc
@@ -0,0 +1,111 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/literal_util.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+using SanitizeConstantNamesTest = HloTestBase;
+
+TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) {
+ const char *const kHloString = R"(
+ HloModule HyphenInInstructionName
+ ENTRY kernelEntry {
+ ROOT equal-to = s32[2]{0} constant({42, 73})
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+ HloInstruction *root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->name(), "equal_to");
+}
+
+TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) {
+ const char *const kHloString = R"(
+ HloModule HyphenInInstructionName
+ ENTRY kernelEntry {
+ ROOT equal.to = s32[2]{0} constant({42, 73})
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+ HloInstruction *root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->name(), "equal_to");
+}
+
+TEST_F(SanitizeConstantNamesTest, NewInstructionNameRegisteredWithModule) {
+ const char *const kHloString = R"(
+ HloModule HyphenInInstructionName
+ ENTRY kernelEntry {
+ ROOT equal.to = s32[2]{0} constant({42, 73})
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+ HloInstruction *root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->name(), "equal_to");
+
+ auto constant_instr =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1));
+ constant_instr->SetAndSanitizeName("equal_to");
+ module->entry_computation()->AddInstruction(std::move(constant_instr));
+
+ EXPECT_THAT(FindInstruction(module.get(), "equal_to.1"),
+ GmockMatch(m::Constant()));
+}
+
+TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) {
+ const char *const kHloString = R"(
+ HloModule BufferSanitizedName
+ ENTRY kernelEntry {
+ equal.to = s32[2]{0} constant({42, 73})
+ equal-to = s32[2]{0} constant({67, 3})
+ ROOT equal_to = s32[2]{0} add(equal.to, equal-to)
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+ EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"),
+ GmockMatch(m::Constant()));
+ EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"),
+ GmockMatch(m::Constant()));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc
new file mode 100644
index 0000000..26eb210
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc
@@ -0,0 +1,33 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scatter_expander.h"
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/primitive_util.h"
+
+namespace xla {
+
+bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
+ // TODO(b/129698548): Scattering elements larger than 64 bits is not
+ // supported by XLA:GPU.
+ // TODO(b/227486631): Variadic scatter is not yet supported by GPU.
+ return inst->opcode() == HloOpcode::kScatter &&
+ (inst->shape().IsTuple() ||
+ primitive_util::BitWidth(inst->shape().element_type()) > 64);
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_expander.h b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h
new file mode 100644
index 0000000..f86b932
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h
@@ -0,0 +1,40 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_
+
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/scatter_expander.h"
+
+namespace xla {
+
+// Legalizes scatters on the GPU.
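+// Only scatters that the GPU emitters cannot handle directly are expanded:
+// variadic (tuple-shaped) scatters and scatters whose element type is wider
+// than 64 bits (see InstructionMatchesPattern below and in the .cc file).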
+class GpuScatterExpander : public ScatterExpander {
+ public:
+ // Although we pass kEliminateAllScatters, we override this behavior in
+ // InstructionMatchesPattern and select only some scatters to expand.
+ GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {}
+
+ absl::string_view name() const override { return "gpu_scatter_expander"; }
+
+ protected:
+ bool InstructionMatchesPattern(HloInstruction* inst) override;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc
new file mode 100644
index 0000000..d9c1deb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc
@@ -0,0 +1,264 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scatter_slice_simplifier.h"
+
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+// Returns whether the instruction could be an operand for a slice instruction.
+bool IsValidIntermediaryUser(const HloInstruction* instruction) {
+ // Allow elementwise instructions, as they don't depend on the truncated
+ // elements. In case of multi-output scatters, the resulting shape is a tuple.
+ return instruction->IsElementwise() ||
+ instruction->opcode() == HloOpcode::kGetTupleElement;
+}
+
+// Matches the "Scatter -> Elementwise (zero or more) -> Slice" pattern.
+// Calculates the resulting scatter dimensions from the slice users.
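+// For example, a f32[9] scatter whose only users are f32[8] slices starting at
+// offset 0 with stride 1 infers a f32[8] result shape; a slice with a nonzero
+// start or a non-unit stride rejects the match.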
+class ScatterSliceMatcher {
+ public:
+ explicit ScatterSliceMatcher(const HloScatterInstruction* scatter)
+ : scatter_(scatter),
+ operand_dimensions_(
+ scatter->scatter_operands()[0]->shape().dimensions()),
+ result_dimensions_(operand_dimensions_.begin(),
+ operand_dimensions_.end()) {}
+
+ // Determine the scatter shape from the user slice instructions.
+ // If any of the users are not truncation slices, return `nullopt`.
+ std::optional<Shape> InferShape() {
+ VLOG(10) << "Evaluating scatter " << scatter_->name();
+ if (!AreAllUsersValid(scatter_)) {
+ return std::nullopt;
+ }
+ std::vector<Shape> result_shapes;
+ absl::c_transform(scatter_->scatter_operands(),
+ std::back_inserter(result_shapes),
+ [&](const HloInstruction* op) {
+ return ShapeUtil::MakeShape(op->shape().element_type(),
+ result_dimensions_);
+ });
+ return ShapeUtil::MakeMaybeTupleShape(result_shapes);
+ }
+
+ private:
+ // Update the resulting scatter dimensions from the slice configuration and
+ // the original scatter dimensions. Return `false` if the update is not
+ // possible.
+ bool UpdateDimensions(const HloSliceInstruction* slice) {
+ int64_t rank = slice->shape().rank();
+ for (int64_t i = 0; i < rank; ++i) {
+ if (slice->slice_starts(i) != 0 || slice->slice_strides(i) != 1) {
+ return false; // The slice is not a truncation.
+ }
+ if (slice->slice_limits(i) != result_dimensions_[i]) {
+ if (result_dimensions_[i] != operand_dimensions_[i]) {
+ return false; // Another slice has incompatible dimensions.
+ }
+ auto& update_window_dims =
+ scatter_->scatter_dimension_numbers().update_window_dims();
+ if (absl::c_binary_search(update_window_dims, i)) {
+ return false; // Update dimensions cannot be truncated.
+ }
+ result_dimensions_[i] = slice->slice_limits(i);
+ VLOG(10) << "Dimension " << i << " truncated to size "
+ << result_dimensions_[i];
+ }
+ }
+ return true;
+ }
+
+ // Verify that the instruction is a valid scatter user, i.e. is either a slice
+ // operation or is an elementwise operation that has slice users (recursive).
+ bool IsUserValid(const HloInstruction* op) {
+ VLOG(10) << "Visiting user " << op->name();
+
+ // If the user is a slice operation, verify the configuration and update
+ // the resulting dimensions.
+ if (auto* slice = DynCast<HloSliceInstruction>(op)) {
+ return UpdateDimensions(slice);
+ }
+ // If the user is an elementwise operation, verify the users recursively
+ // (unless already visited).
+ bool is_valid = visited_set_.contains(op) ||
+ (IsValidIntermediaryUser(op) && AreAllUsersValid(op));
+ if (is_valid) {
+ visited_set_.emplace(op);
+ }
+ return is_valid;
+ }
+
+  // Verify that all users are valid (see the definition of IsUserValid).
+ // If we reach the root instruction, fail the matching (slice is not found).
+ bool AreAllUsersValid(const HloInstruction* op) {
+ if (op->user_count() == 0) {
+ return !op->IsRoot();
+ }
+ return absl::c_all_of(op->users(), [this](const HloInstruction* user) {
+ return IsUserValid(user);
+ });
+ }
+
+ const HloScatterInstruction* scatter_;
+ absl::flat_hash_set<const HloInstruction*> visited_set_;
+ absl::Span<const int64_t> operand_dimensions_;
+ DimensionVector result_dimensions_;
+};
+
+// Create a replacement operand for the scatter instruction.
+HloInstruction* CreateSliceFrom(HloInstruction* operand, const Shape& shape) {
+ std::vector<int64_t> start_indices(shape.rank(), 0);
+ std::vector<int64_t> limit_indices(shape.rank());
+ std::vector<int64_t> strides(shape.rank(), 1);
+ for (int64_t i = 0; i < shape.rank(); ++i) {
+ limit_indices[i] = shape.dimensions(i);
+ }
+ return operand->AddInstruction(HloInstruction::CreateSlice(
+ shape, operand, start_indices, limit_indices, strides));
+}
+
+// Create a replacement for the scatter instruction.
+HloInstruction* CreateScatterFrom(HloScatterInstruction* scatter,
+ const Shape& shape) {
+ std::vector<HloInstruction*> operands(scatter->scatter_operand_count());
+ for (int64_t i = 0; i < operands.size(); ++i) {
+ operands[i] =
+ CreateSliceFrom(scatter->scatter_operands()[i],
+ shape.IsTuple() ? shape.tuple_shapes(i) : shape);
+ }
+ return scatter->AddInstruction(HloInstruction::CreateScatter(
+ shape, absl::MakeSpan(operands), scatter->scatter_indices(),
+ scatter->scatter_updates(), scatter->called_computations()[0],
+ scatter->scatter_dimension_numbers(), scatter->indices_are_sorted(),
+ scatter->unique_indices()));
+}
+
+class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleScatter(HloInstruction* instruction) override {
+ auto* scatter = Cast<HloScatterInstruction>(instruction);
+
+ // Infer scatter shape from the slice users.
+ std::optional<Shape> result_shape =
+ ScatterSliceMatcher(scatter).InferShape();
+ if (!result_shape.has_value()) {
+ return absl::OkStatus();
+ }
+ VLOG(2) << "Matched scatter " << scatter->name() << " with shape "
+ << scatter->shape().ToString() << ", inferred result shape "
+ << result_shape->ToString() << " (from the slice users)";
+
+ // Replace slice user instructions.
+ HloInstruction* new_scatter = CreateScatterFrom(scatter, *result_shape);
+ return ReplaceAllUsersRecursive(scatter, new_scatter);
+ }
+
+ private:
+ // Create a replacement for every user. If the user is a slice operation,
+ // replace it in the computation graph, the old branch will be removed.
+ absl::Status ReplaceAllUsersRecursive(HloInstruction* old_instruction,
+ HloInstruction* new_instruction) {
+ // Maintain the replacement map, needed for non-unary elementwise users.
+ replacements_[old_instruction] = new_instruction;
+
+    // It's important to make a copy of the users list, as it may be modified
+ // during the iteration.
+ std::vector<HloInstruction*> users = old_instruction->users();
+ for (HloInstruction* user : users) {
+ if (user->parent() == nullptr) {
+ VLOG(3) << "Skipping user " << user->name() << " (already replaced)";
+ continue;
+ }
+ TF_RETURN_IF_ERROR(ReplaceUserRecursive(user, new_instruction));
+ }
+ return absl::OkStatus();
+ }
+
+ // Replace the slice user with a new scatter (or a new chain of operations
+ // starting with a scatter). For elementwise operations, create a new user
+ // with updated operands (build the chain).
+ absl::Status ReplaceUserRecursive(HloInstruction* user,
+ HloInstruction* operand) {
+ VLOG(3) << "Replacing scatter user " << user->name();
+ if (user->opcode() == HloOpcode::kSlice) {
+ return ReplaceInstruction(user, operand);
+ }
+
+ // Create the replacement instruction with new shape.
+ HloInstruction* new_user = nullptr;
+ if (user->IsElementwise()) {
+ auto new_shape = [operand](HloInstruction* from) {
+ return ShapeUtil::MakeShape(from->shape().element_type(),
+ operand->shape().dimensions());
+ };
+ std::vector<HloInstruction*> new_operands;
+ absl::c_transform(user->operands(), std::back_inserter(new_operands),
+ [&](HloInstruction* op) {
+ auto it = replacements_.find(op);
+ return it != replacements_.end()
+ ? it->second
+ : CreateSliceFrom(op, new_shape(op));
+ });
+ new_user = user->AddInstruction(
+ user->CloneWithNewOperands(new_shape(user), new_operands));
+ } else {
+ auto* gte = Cast<HloGetTupleElementInstruction>(user);
+ TF_ASSIGN_OR_RETURN(new_user,
+ MakeGetTupleElementHlo(operand, gte->tuple_index(),
+ &user->metadata()));
+ }
+
+ // Replace slice user instructions recursively.
+ return ReplaceAllUsersRecursive(user, new_user);
+ }
+
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements_;
+};
+
+} // namespace
+
+absl::StatusOr<bool> ScatterSliceSimplifier::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ return ScatterSliceSimplifierVisitor{}.RunOnModule(module, execution_threads);
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h
new file mode 100644
index 0000000..96f39b5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h
@@ -0,0 +1,58 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Replaces a scatter followed by truncation slices with a new scatter that
+// produces the sliced shape directly, eliminating the slices.
+//
+// (a) Single output (b) Multiple outputs (c) Elementwise users
+//
+// T[N+1] scatter (T1, T2) scatter T scatter T constant
+// v v v v v
+// T[N] slice T1 gte T2 gte T maximum
+// v v v
+// T1 slice T2 slice T slice
+//
+// This pattern is used when the last element of the scatter output is intended
+// to accumulate updates from the input elements that should be ignored.
+// This is slow if there are many indices mapped to the last output index and
+// the scatter is implemented using atomics, so everything collides on that one
+// memory location.
+// As OOB scatter indices are dropped by the GPU implementation, we can remove
+// the slice step entirely and avoid the memory congestion in the scatter step.
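+//
+// A sketch of case (a), mirroring the unit tests:
+//
+//   %scatter = f32[9] scatter(%operands, %indices, %updates), ...
+//   ROOT %slice = f32[8] slice(%scatter), slice={[0:8]}
+//
+// becomes a scatter that produces f32[8] directly, with its operand sliced to
+// f32[8] and the trailing slice removed.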
+
+class ScatterSliceSimplifier : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "scatter-slice-simplifier"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc
new file mode 100644
index 0000000..8f1c93c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc
@@ -0,0 +1,336 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scatter_slice_simplifier.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+namespace m = ::xla::match;
+
+using ScatterSliceSimplifierTest = HloTestBase;
+
+TEST_F(ScatterSliceSimplifierTest, Scatter1D) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[9] constant(0)
+ %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ ROOT %slice = f32[8] slice(%scatter), slice={[0:8]}
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
+ m::Parameter(1))
+ .WithShape(F32, {8})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, Scatter3D) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+ %indices = s32[2] parameter(0)
+ %updates = f32[2,4,4] parameter(1)
+ %operands = f32[5,4,4] constant(0)
+ %scatter = f32[5,4,4] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ ROOT %slice = f32[4,4,4] slice(%scatter), slice={[0:4], [0:4], [0:4]}
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
+ m::Parameter(1))
+ .WithShape(F32, {4, 4, 4})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, ScatterMultiOutput) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32_add_F16 {
+ %lhs.0 = f32[] parameter(0)
+ %rhs.0 = f32[] parameter(2)
+ %add.0 = f32[] add(%lhs.0, %rhs.0)
+ %lhs.1 = f16[] parameter(1)
+ %rhs.1 = f16[] parameter(3)
+ %add.1 = f16[] add(%lhs.1, %rhs.1)
+ ROOT %tuple = (f32[], f16[]) tuple(%add.0, %add.1)
+}
+
+ENTRY main {
+ %indices = s32[4] parameter(0)
+ %updates.0 = f32[4] parameter(1)
+ %updates.1 = f16[4] parameter(2)
+ %operands.0 = f32[9] constant(0)
+ %operands.1 = f16[9] constant(0)
+ %scatter = (f32[9], f16[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_add_F16
+ %gte.0 = f32[9] get-tuple-element(%scatter), index=0
+ %slice.0 = f32[8] slice(%gte.0), slice={[0:8]}
+ %gte.1 = f16[9] get-tuple-element(%scatter), index=1
+ %slice.1 = f16[8] slice(%gte.1), slice={[0:8]}
+ ROOT %tuple = (f32[8], f16[8]) tuple(%slice.0, %slice.1)
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+ auto expected_scatter =
+ m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
+ m::Parameter(0), m::Parameter(1), m::Parameter(2));
+
+ Shape expected_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F16, {8})});
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::GetTupleElement(expected_scatter),
+ m::GetTupleElement(expected_scatter))
+ .WithShapeEqualTo(&expected_shape)));
+}
+
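+// Each computation below is a stand-alone negative case (slice offset other
+// than zero, strided slice, extra scatter user, incompatible slices, no slice
+// user, or slicing along an update dimension); the pass must leave all of
+// them unchanged, so the run below is expected to report no change.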
+TEST_F(ScatterSliceSimplifierTest, NotMatching) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+slice_not_truncation {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[9] constant(0)
+ %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ ROOT %slice = f32[8] slice(%scatter), slice={[1:9]}
+}
+
+slice_with_stride {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[9] constant(0)
+ %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ ROOT %slice = f32[4] slice(%scatter), slice={[0:8:2]}
+}
+
+scatter_multiple_users {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[9] constant(0)
+ %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ %slice = f32[8] slice(%scatter), slice={[0:8]}
+ ROOT %tuple = (f32[9], f32[8]) tuple(%scatter, %slice)
+}
+
+scatter_incompatible_slices {
+ %indices = s32[2] parameter(0)
+ %updates = f32[2,4] parameter(1)
+ %operands = f32[4,4] constant(0)
+ %scatter = f32[4,4] scatter(%operands, %indices, %updates), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ %slice.0 = f32[3,4] slice(%scatter), slice={[0:3], [0:4]}
+ %slice.1 = f32[4,3] slice(%scatter), slice={[0:4], [0:3]}
+ ROOT %tuple = (f32[3,4], f32[4,3]) tuple(%slice.0, %slice.1)
+}
+
+slice_not_found {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[8] constant(0)
+ %scatter = f32[8] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ ROOT %exp = f32[8] exponential(%scatter)
+}
+
+slice_update_dimensions {
+ %indices = s32[10] parameter(0)
+ %updates = f32[10,1,128] parameter(1)
+ %operands = f32[100,128] constant(0)
+ %scatter = f32[100,128] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ ROOT %slice = f32[100,64] slice(%scatter), slice={[0:100], [0:64]}
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_FALSE(RunHloPass(&test_pass, module.get()).value());
+}
+
+TEST_F(ScatterSliceSimplifierTest, IntermediaryUsers) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[9] constant(0)
+ %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ %unary = f32[9] abs(%scatter)
+ %slice.0 = f32[8] slice(%unary), slice={[0:8]}
+ %binary = f32[9] maximum(%scatter, %operands)
+ %slice.1 = f32[8] slice(%binary), slice={[0:8]}
+ ROOT %tuple = (f32[8], f32[8]) tuple(%slice.0, %slice.1)
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+ auto expected_scatter =
+ m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
+
+ Shape expected_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})});
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::Abs(expected_scatter),
+ m::Maximum(expected_scatter, m::Slice(m::Constant())))
+ .WithShapeEqualTo(&expected_shape)));
+}
+
+TEST_F(ScatterSliceSimplifierTest, IntermediaryChain) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[9] constant(0)
+ %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ %elementwise.0 = f32[9] abs(%scatter)
+ %elementwise.1 = f32[9] exponential(%elementwise.0)
+ %elementwise.2 = f32[9] add(%elementwise.0, %elementwise.1)
+ ROOT %result = f32[8] slice(%elementwise.2), slice={[0:8]}
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+ auto expected_scatter =
+ m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Add(m::Abs(expected_scatter),
+ m::Exp(m::Abs(expected_scatter)))
+ .WithShape(F32, {8})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, DiamondShape) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32_mul_F32 {
+ %lhs.0 = f32[] parameter(0)
+ %rhs.0 = f32[] parameter(2)
+ %add.0 = f32[] add(%lhs.0, %rhs.0)
+ %lhs.1 = f32[] parameter(1)
+ %rhs.1 = f32[] parameter(3)
+ %mul.1 = f32[] multiply(%lhs.1, %rhs.1)
+ ROOT %tuple = (f32[], f32[]) tuple(%add.0, %mul.1)
+}
+
+ENTRY main {
+ %indices = s32[4] parameter(0)
+ %updates.0 = f32[4] parameter(1)
+ %updates.1 = f32[4] parameter(2)
+ %operands.0 = f32[9] constant(0)
+ %operands.1 = f32[9] constant(0)
+ %scatter = (f32[9], f32[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_mul_F32
+ %gte.0 = f32[9] get-tuple-element(%scatter), index=0
+ %gte.1 = f32[9] get-tuple-element(%scatter), index=1
+ %consumer = f32[9] add(%gte.0, %gte.1)
+ ROOT %slice = f32[8] slice(%consumer), slice={[0:8]}
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+ auto expected_scatter =
+ m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
+ m::Parameter(0), m::Parameter(1), m::Parameter(2));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Add(m::GetTupleElement(expected_scatter),
+ m::GetTupleElement(expected_scatter))
+ .WithShape(F32, {8})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, ElementwiseSelect) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+ %indices = s32[4] parameter(0)
+ %updates = f32[4] parameter(1)
+ %operands = f32[9] constant(0)
+ %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+ %pred_ = pred[9] parameter(2)
+ %select = f32[9] select(%pred_, %scatter, %operands)
+ ROOT %slice = f32[8] slice(%select), slice={[0:8]}
+}
+ )")
+ .value();
+ ScatterSliceSimplifier test_pass;
+ ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+ auto expected_scatter =
+ m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Select(m::Slice(m::Parameter(2)), expected_scatter,
+ m::Slice(m::Constant()))
+ .WithShape(F32, {8})));
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc
new file mode 100644
index 0000000..9929b35
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc
@@ -0,0 +1,165 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/schedule_postprocessing.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+// Maps a computation to a boolean that indicates whether the computation may
+// invoke custom-calls directly or indirectly, which can eventually trigger gpu
+// synchronization.
+using CustomCallInComputation =
+ absl::flat_hash_map<const HloComputation*, bool>;
+
+// Returns whether the hlo may invoke custom-calls which may trigger gpu
+// synchronization. Currently, we only check for custom-calls, because they are
+// the only operations that can run in parallel with asynchronous collective
+// operations in an HLO schedule and may trigger gpu synchronization.
+bool MayInvokeCustomCall(
+ const HloInstruction* hlo,
+ const CustomCallInComputation& custom_call_in_computation) {
+ if (hlo->opcode() == HloOpcode::kCustomCall) {
+ return true;
+ }
+
+ return absl::c_any_of(
+ hlo->called_computations(), [&](const HloComputation* callee) {
+ return custom_call_in_computation.find(callee)->second;
+ });
+}
+
+// Returns true if this is an asynchronous collective start operation, excluding
+// P2P operations.
+absl::StatusOr<bool> IsRelevantAsynchronousStart(const HloInstruction* hlo) {
+ if (!hlo_query::IsAsyncCollectiveStartOp(hlo,
+ /*include_send_recv=*/false)) {
+ return false;
+ }
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ hlo->backend_config<GpuBackendConfig>());
+ const CollectiveBackendConfig& collective_backend_config =
+ gpu_config.collective_backend_config();
+ return !collective_backend_config.is_sync();
+}
+
+// Returns true if this is an asynchronous collective done operation, excluding
+// P2P operations.
+absl::StatusOr<bool> IsRelevantAsynchronousDone(const HloInstruction* hlo) {
+ return hlo_query::IsAsyncCollectiveDoneOp(hlo,
+ /*include_send_recv=*/false);
+}
+
+// For a given computation, finds all the asynchronous collective operations
+// that aren't parallel with custom-calls and sets their
+// no_parallel_custom_call attribute to true. Also records whether the given
+// computation may invoke custom-calls.
+absl::StatusOr<bool> ProcessComputation(
+ const HloSchedule& schedule, HloComputation* computation,
+ CustomCallInComputation& custom_call_in_computation) {
+ bool changed = false;
+ bool has_custom_call = false;
+ absl::flat_hash_set<HloInstruction*> async_starts;
+ const HloInstructionSequence& sequence = schedule.sequence(computation);
+
+ // Visit instructions in the sequence. Collect relevant asynchronous
+ // collective start ops. When we see a relevant asynchronous collective done
+ // op, remove the corresponding start op from the collection and set its
+ // attribute no_parallel_custom_call to true. When we see a custom-call, clear
+ // the start ops from the collection and keep their attribute
+ // no_parallel_custom_call as false.
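+  //
+  // As an illustration (hypothetical schedule order, not taken from a real
+  // module), for the sequence
+  //   ag-start (is_sync=false), custom-call, ag-done
+  // the custom-call clears ag-start from the collection, so its
+  // no_parallel_custom_call attribute is left at its default (false). For the
+  // sequence
+  //   ag-start (is_sync=false), ag-done, custom-call
+  // the attribute is set to true, because no custom-call runs between the
+  // start and the done.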
+ const std::vector<HloInstruction*>& all_instructions =
+ sequence.instructions();
+ for (HloInstruction* hlo : all_instructions) {
+ if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
+ async_starts.clear();
+ has_custom_call = true;
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(bool is_async_start, IsRelevantAsynchronousStart(hlo));
+ if (is_async_start) {
+ async_starts.insert(hlo);
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(bool is_async_done, IsRelevantAsynchronousDone(hlo));
+ if (is_async_done) {
+ HloInstruction* async_start = hlo->mutable_operand(0);
+ if (async_starts.contains(async_start)) {
+ changed = true;
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ async_start->backend_config<GpuBackendConfig>());
+ CollectiveBackendConfig& collective_backend_config =
+ *gpu_config.mutable_collective_backend_config();
+ collective_backend_config.set_no_parallel_custom_call(true);
+ TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
+ async_starts.erase(async_start);
+ }
+ }
+ }
+
+ custom_call_in_computation[computation] = has_custom_call;
+ return changed;
+}
+
+} // anonymous namespace
+
+absl::StatusOr<bool> SchedulePostprocessing::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ if (!module->has_schedule()) return false;
+ HloSchedule& schedule = module->schedule();
+ bool changed = false;
+ CustomCallInComputation custom_call_in_computation;
+
+  // We visit computations in the order of callees to callers, as information
+  // is propagated from callees to callers.
+ std::vector<HloComputation*> all_computations =
+ module->MakeComputationPostOrder(execution_threads);
+ for (auto iter = all_computations.begin(); iter != all_computations.end();
+ ++iter) {
+ HloComputation* computation = *iter;
+ if (computation->IsFusionComputation()) {
+ custom_call_in_computation[computation] = false;
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ bool result,
+ ProcessComputation(schedule, computation, custom_call_in_computation));
+ changed |= result;
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
new file mode 100644
index 0000000..899098d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
@@ -0,0 +1,50 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Amends a schedule result with the needed information to support a runtime
+// implementation. Currently, this pass refines the no_parallel_custom_call
+// attribute for asynchronous collective operations to support runtime
+// optimizations, such as skipping the rendezvous of all participating threads
+// for NCCL collective operations. In particular, it sets the attribute value
+// for collective-start operations with is_sync=false; it keeps the attribute
+// value untouched for operations with is_sync=true and for P2P operations,
+// assuming the runtime won't use those values.
+//
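+// For example (illustrative HLO mirroring the unit tests), given a scheduled
+// module containing
+//
+//   ag-start = (f32[1], f32[2]) all-gather-start(...), backend_config=
+//       {"collective_backend_config":{"is_sync":false}}
+//   ag-done  = f32[2] all-gather-done(ag-start)
+//
+// with no custom-call scheduled between the start and the done, the pass sets
+// no_parallel_custom_call=true in the start operation's backend config.
+//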
+class SchedulePostprocessing : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "schedule-postprocessing"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
new file mode 100644
index 0000000..0c9c6e6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
@@ -0,0 +1,163 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/schedule_postprocessing.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using SchedulePostprocessingTest = HloTestBase;
+
+TEST_F(SchedulePostprocessingTest, SynchronousOpsNotChanged) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+
+ ENTRY entry {
+ pf32 = f32[1] parameter(0)
+
+ all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
+ ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
+ }
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kHloString)));
+ SchedulePostprocessing pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(SchedulePostprocessingTest, P2POpsNotChanged) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+
+ ENTRY main {
+ f0 = f32[] constant(0.0)
+ init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+ after-all = token[] after-all()
+ recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2,
+ frontend_attributes={
+ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}"
+ }
+ recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2
+ ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0
+ }
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kHloString)));
+ SchedulePostprocessing pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(SchedulePostprocessingTest, AsynchronousOpsChanged) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+
+ ENTRY entry {
+ pf32 = f32[1] parameter(0)
+ pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
+ all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
+ ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
+ }
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kHloString)));
+ SchedulePostprocessing pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ start->backend_config<GpuBackendConfig>());
+ const CollectiveBackendConfig& collective_backend_config =
+ gpu_config.collective_backend_config();
+ EXPECT_TRUE(collective_backend_config.no_parallel_custom_call());
+}
+
+TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+
+ ENTRY entry {
+ pf32 = f32[1] parameter(0)
+ all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
+ pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
+ all-gather-done = f32[2] all-gather-done(all-gather-start)
+ ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
+ }
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kHloString)));
+ SchedulePostprocessing pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+ EXPECT_FALSE(changed);
+
+ HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ start->backend_config<GpuBackendConfig>());
+ const CollectiveBackendConfig& collective_backend_config =
+ gpu_config.collective_backend_config();
+ EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
+}
+
+TEST_F(SchedulePostprocessingTest,
+ AsynchronousOpsWithParallelNestedCustomcall) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+ foo {
+ v = f32[1] parameter(0)
+ ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call"
+ }
+
+ ENTRY entry {
+ pf32 = f32[1] parameter(0)
+ all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
+ pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo
+ all-gather-done = f32[2] all-gather-done(all-gather-start)
+ ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
+ }
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnUnverifiedModule((kHloString)));
+ SchedulePostprocessing pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+ EXPECT_FALSE(changed);
+
+ HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ start->backend_config<GpuBackendConfig>());
+ const CollectiveBackendConfig& collective_backend_config =
+ gpu_config.collective_backend_config();
+ EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc
new file mode 100644
index 0000000..d796213
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc
@@ -0,0 +1,74 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+// Populates `OpMetadata`'s `scheduling_name` field for all of the instructions
+// belonging to `computation`.
+absl::StatusOr<bool> AnnotateSchedulingInstructionNames(
+ HloComputation& computation) {
+ bool changed = false;
+ for (HloInstruction* inst : computation.instructions()) {
+ if (!inst->metadata().scheduling_name().empty()) {
+ continue;
+ }
+    // We skip constants as we might have to sanitize them in order to satisfy
+    // the LLVM backend, i.e. we allow the `GpuSanitizeConstantNames` pass to
+    // run post scheduling.
+ if (inst->opcode() == HloOpcode::kConstant) {
+ continue;
+ }
+ inst->set_metadata_scheduling_name(inst->name());
+ changed = true;
+ }
+ return changed;
+}
+
+} // namespace
+
+absl::StatusOr<bool> SchedulingInstructionAnnotator::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ CHECK(module->has_schedule())
+ << "The pass is supposed to run in the beginning of post-scheduling!";
+ bool changed = false;
+
+  // We visit computations in the order of callees to callers, as information
+  // is propagated from callees to callers.
+ for (HloComputation* computation :
+ module->MakeComputationPostOrder(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result,
+ AnnotateSchedulingInstructionNames(*computation));
+ changed |= result;
+ }
+
+ return changed;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h
new file mode 100644
index 0000000..03c21bb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// The pass amends the `OpMetadata` of each instruction with the instruction
+// name present at scheduling time. This is later used to make sure
+// instructions are not renamed post scheduling; otherwise, the recorded
+// scheduling names would no longer match the actual instruction names.
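+//
+// For example (illustrative HLO mirroring the unit tests), after the pass
+// runs,
+//
+//   add0 = f32[1] add(p0, p1)
+//
+// carries metadata={scheduling_name="add0"}, so the name chosen at scheduling
+// time is preserved in the metadata even if the instruction is later renamed.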
+class SchedulingInstructionAnnotator : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "scheduling-instruction-annotator";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc
new file mode 100644
index 0000000..abe8d50
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc
@@ -0,0 +1,131 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using SchedulingInstructionAnnotatorTest = HloTestBase;
+
+TEST_F(SchedulingInstructionAnnotatorTest,
+ AnnotatesAllInstructionsWithTheirRespectiveNames) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+
+ ENTRY entry {
+ p0 = f32[1] parameter(0)
+ p1 = f32[1] parameter(1)
+ add0 = f32[1] add(p0,p1)
+ ROOT exp0 = f32[1] exponential(add0)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ SchedulingInstructionAnnotator pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+
+ ASSERT_TRUE(changed);
+ for (const auto* comp : module->computations()) {
+ for (const auto* instruction : comp->instructions()) {
+ EXPECT_EQ(instruction->name(), instruction->metadata().scheduling_name());
+ }
+ }
+ constexpr absl::string_view kExpected = R"(
+// CHECK: %[[P0:.+]] = {{.*}} parameter(0)
+// CHECK-SAME: scheduling_name="[[P0]]"
+// CHECK: %[[P1:.+]] = {{.*}} parameter(1)
+// CHECK-SAME: scheduling_name="[[P1]]"
+// CHECK: %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[P1]])
+// CHECK-SAME: scheduling_name="[[ADD0]]"
+// CHECK: ROOT %[[EXP0:.+]] = {{.*}} exponential(%[[ADD0]])
+// CHECK-SAME: scheduling_name="[[EXP0]]"
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(
+ module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+ kExpected));
+ EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(SchedulingInstructionAnnotatorTest, SkipsAnnotatingConstants) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+
+ ENTRY entry {
+ p0 = f32[1] parameter(0)
+ c1 = f32[1] constant(42)
+ ROOT add0 = f32[1] add(p0, c1)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ SchedulingInstructionAnnotator pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+
+ ASSERT_TRUE(changed);
+ constexpr absl::string_view kExpected = R"(
+// CHECK: %[[P0:.+]] = {{.*}} parameter(0)
+// CHECK-SAME: scheduling_name="[[P0]]"
+// CHECK-NEXT: %[[C1:.+]] = f32[1]
+// CHECK-NOT: scheduling_name
+// CHECK-SAME: constant({42})
+// CHECK: %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[C1]])
+// CHECK-SAME: scheduling_name="[[ADD0]]"
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(
+ module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+ kExpected));
+ EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(SchedulingInstructionAnnotatorTest,
+ DoesNotAnnotateAllInstructionsWithTheirRespectiveNames) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule module, is_scheduled=true
+
+ ENTRY entry {
+ p0 = f32[1] parameter(0), metadata={scheduling_name="p0"}
+ p1 = f32[1] parameter(1), metadata={scheduling_name="p1"}
+ add0 = f32[1] add(p0,p1), metadata={scheduling_name="add0"}
+ ROOT exp0 = f32[1] exponential(add0), metadata={scheduling_name="exp0"}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ SchedulingInstructionAnnotator pass;
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+
+ EXPECT_FALSE(changed);
+}
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc
new file mode 100644
index 0000000..90bc607
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc
@@ -0,0 +1,797 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/softmax_rewriter_triton.h"
+
+#include <cstdint>
+#include <functional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
+#include "xla/service/gpu/model/gpu_indexing_performance_model.h"
+#include "xla/service/gpu/model/symbolic_tile_analysis.h"
+#include "xla/service/gpu/model/tiled_hlo_computation.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using hlo_query::IsBroadcastOfParameter;
+using hlo_query::IsBroadcastOfScalarConstant;
+
+bool HasDefaultLayout(const Shape& shape) {
+ return shape.has_layout() &&
+ LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
+}
+
+// Returns true if a trivially connected producer of 'consumer' with opcode
+// 'opcode' exists. If such an instruction is found, the value of 'producer' is
+// set to it. The definition of "trivial" operations is as given in
+// 'IsTriviallyFusible'.
+bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
+ HloOpcode opcode, const se::GpuComputeCapability& gpu_version);
+
+bool BitcastIsTilingNoop(HloInstruction* bitcast,
+ const se::GpuComputeCapability& gpu_version) {
+ CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
+
+ if (ShapeUtil::IsEffectiveScalar(bitcast->shape())) {
+ return true;
+ }
+
+ // In the Softmax rewriter for now, tiling is derived from a hero reduction
+ // operation, which should be reducing its input on the last axis. Therefore,
+ // a bitcast is always a no-op with regards to a tile if
+ // (1) it does not change the size of the reduction dimension of its input
+ // (the last one); if its input is already reduced, then (1) is true
+ // by default
+ // (2) the layout of its output is ordered in the same way as the layout of
+ // its input. This is a fuzzy definition, but since we assume fusible
+ // ops to always have a default layout, we can just check if both the
+ // bitcast and its input have a default layout
+ auto last_dimension = [](const HloInstruction* instr) {
+ return instr->shape().dimensions().back();
+ };
+
+ HloInstruction* reduce = nullptr;
+ TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce,
+ gpu_version);
+
+ return (HasDefaultLayout(bitcast->shape()) &&
+ HasDefaultLayout(bitcast->operand(0)->shape()) &&
+ (reduce != nullptr ||
+ last_dimension(bitcast->operand(0)) == last_dimension(bitcast)));
+}
+
+inline bool HasOneUse(const HloInstruction* instr) {
+ return instr->user_count() == 1;
+}
+
+// Supports two types of broadcast of parameters: either to one batch
+// dim, or to one reduction dim. For example, the following cases are supported:
+//
+// Case #1:
+// p = f32[a] parameter(0)
+// b = f32[a,x] broadcast(p), dimensions={0}
+//
+// Case #2:
+// p = f32[a] parameter(0)
+// b = f32[x,a] broadcast(p), dimensions={1}
+//
+// Case #3:
+// p = f32[a,b] parameter(0)
+// b = f32[x,a,b] broadcast(p), dimensions={1,2}
+//
+// Other broadcast tiling patterns are currently unsupported.
+// See b/328049138 for details.
+//
+// Unsupported case #1:
+// p = f32[a] parameter(0)
+// b = f32[x,a,y] broadcast(p), dimensions={1}
+//
+// Unsupported case #2:
+// p = f32[a,b] parameter(0)
+// b = f32[x,a,y,b] broadcast(p), dimensions={1,3}
+//
+// Unsupported case #3:
+// p = f32[a] parameter(0)
+// b = f32[x,y,a] broadcast(p), dimensions={2}
+//
+// Unsupported case #4:
+// p = f32[a,b] parameter(0)
+// b = f32[a,x,b] broadcast(p), dimensions={0,2}
+bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) {
+ CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
+ << "Expected broadcast " << hlo.ToShortString();
+ CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
+ << "Expected parameter " << hlo.operand(0)->ToShortString();
+
+ const HloBroadcastInstruction* broadcast =
+ Cast<HloBroadcastInstruction>(&hlo);
+
+ const HloParameterInstruction* parameter =
+ Cast<HloParameterInstruction>(hlo.operand(0));
+
+ // Support only one dim broadcast.
+ if (parameter->shape().dimensions_size() + 1 !=
+ broadcast->shape().dimensions_size()) {
+ return false;
+ }
+
+  // It is enough to ensure that the broadcast does not preserve both the
+  // first and the last dimensions of the parameter at the same time. Otherwise
+  // the broadcast is the unsupported case #4.
+ //
+ // Preserve the first dim:
+ // p = f32[a,b] parameter(0)
+ // b1 = f32[a,b,c] broadcast(p), dimensions={0,1}
+ bool preserve_first_dim = broadcast->dimensions().front() == 0;
+ // Preserve the last dim:
+ // p = f32[a,b] parameter(0)
+ // b1 = f32[c,a,b] broadcast(p), dimensions={1,2}
+ bool preserve_last_dim = broadcast->dimensions().back() ==
+ broadcast->shape().dimensions_size() - 1;
+ // We do not want to preserve both first and last dim, as it means the
+ // broadcast is not expanding on outermost dims.
+ return !(preserve_first_dim && preserve_last_dim);
+}
+
+bool IsBroadcastOfAScalar(const HloInstruction& hlo) {
+ CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
+ << "Expected broadcast " << hlo.ToShortString();
+ return ShapeUtil::IsScalar(hlo.operand(0)->shape());
+}
+
+bool IsSingleRowParameterBroadcast(const HloInstruction& hlo) {
+ CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
+ << "Expected broadcast " << hlo.ToShortString();
+ CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
+ << "Expected parameter " << hlo.operand(0)->ToShortString();
+
+ const HloBroadcastInstruction* broadcast =
+ Cast<HloBroadcastInstruction>(&hlo);
+ const HloParameterInstruction* parameter =
+ Cast<HloParameterInstruction>(hlo.operand(0));
+
+ if (parameter->shape().dimensions_size() != 1) {
+ return false;
+ }
+ return broadcast->dimensions()[0] == broadcast->shape().dimensions_size() - 1;
+}
+
+bool IsSupportedBroadcastOfParameter(const HloInstruction& hlo) {
+ return IsBroadcastOfParameter(hlo) &&
+ (IsBatchOrReductionDimBroadcast(hlo) || IsBroadcastOfAScalar(hlo) ||
+ IsSingleRowParameterBroadcast(hlo));
+}
+
+// Chooses which operand to use for fusion processing. Taking in a unary or
+// binary instruction, returns the first non-splat operand. If none is
+// present, returns any operand.
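+// For example (illustrative), for multiply(broadcast(%c0), %x) where %c0 is a
+// scalar constant, %x is returned.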
+HloInstruction* ChooseOperandForFusionProcessing(HloInstruction* instr) {
+ CHECK_GT(instr->operand_count(), 0);
+ CHECK_LE(instr->operand_count(), 2);
+
+ // TODO(b/326217416): Extend the broadcast of splat constants/parameters to a
+ // broadcast of any op.
+ if (instr->operand_count() > 1 &&
+ (IsBroadcastOfScalarConstant(*instr->operand(0)) ||
+ IsSupportedBroadcastOfParameter(*instr->operand(0)))) {
+ return instr->mutable_operand(1);
+ }
+ return instr->mutable_operand(0);
+}
+
+bool IsTriviallyFusible(HloInstruction* instr,
+ const se::GpuComputeCapability& gpu_version,
+ int num_allowed_users = 1) {
+ // Checks whether an op is trivially fusible. An op is said to be trivially
+ // fusible if it does not increase the amount of memory read/written by the
+ // resulting fusion, is compatible with any chosen tiling, and can be
+ // codegen'd using Triton. The op is allowed to have up to num_allowed_users
+ // users.
+ if (instr->user_count() > num_allowed_users ||
+ !HasDefaultLayout(instr->shape())) {
+ return false;
+ }
+
+ if (instr->opcode() == HloOpcode::kBitcast &&
+ BitcastIsTilingNoop(instr, gpu_version)) {
+ return true;
+ }
+
+ if (instr->IsElementwise() && instr->operand_count() == 1) {
+ return static_cast<bool>(IsTritonSupportedInstruction(*instr, gpu_version));
+ }
+
+ // Elementwise binary ops are trivially fusible if the operands are the same,
+ // or if exactly one of the operands is a splat constant.
+ if (instr->IsElementwiseBinary()) {
+ const HloInstruction* operand_0 = instr->operand(0);
+ const HloInstruction* operand_1 = instr->operand(1);
+
+    // Elementwise binary ops should be fused if both operands are the same
+    // and the op itself is supported by Triton.
+ if (operand_0 == operand_1) {
+ return static_cast<bool>(
+ IsTritonSupportedInstruction(*instr, gpu_version));
+ }
+
+ // For simplicity we only fuse elementwise binary ops with splat operands
+ // if they contain one non-splat operand.
+ // TODO(b/326217416): Extend the broadcast of splat constants/parameters to
+ // a broadcast of any op.
+ if ((IsBroadcastOfScalarConstant(*operand_0) ||
+ IsSupportedBroadcastOfParameter(*operand_0)) ^
+ (IsBroadcastOfScalarConstant(*operand_1) ||
+ IsSupportedBroadcastOfParameter(*operand_1))) {
+ return static_cast<bool>(
+ IsTritonSupportedInstruction(*instr, gpu_version));
+ }
+ }
+
+ return false;
+}
+
+bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
+ HloOpcode opcode,
+ const se::GpuComputeCapability& gpu_version) {
+ while (consumer->opcode() != opcode) {
+ if (IsTriviallyFusible(consumer, gpu_version)) {
+ consumer = ChooseOperandForFusionProcessing(consumer);
+ } else {
+ return false;
+ }
+ }
+
+ *producer = consumer;
+ return true;
+}
+
+bool IsTriviallyConnectedProducerOf(
+ HloInstruction* producer, HloInstruction* consumer,
+ const se::GpuComputeCapability& gpu_version) {
+ if (producer == consumer) {
+ return true;
+ }
+
+ HloInstruction* found_producer = consumer;
+ while (
+ TrivialEdge(&found_producer, consumer, producer->opcode(), gpu_version)) {
+ if (found_producer == producer) {
+ return true;
+ }
+
+ if (!IsTriviallyFusible(found_producer, gpu_version)) {
+ return false;
+ }
+
+ consumer = found_producer->mutable_operand(0);
+ }
+
+ return false;
+}
+
+// Finds the first non-fusible producer of a diamond. This instruction is
+// either
+//   1. the direct producer of the diamond, if that producer is used more than
+//      twice and/or is not otherwise trivially fusible, or
+//   2. the first instruction found by walking up the operand chain from the
+//      diamond's producer that is used more than once and/or is not trivially
+//      fusible.
+HloInstruction* FindFirstNonFusibleDiamondProducer(
+ HloInstruction* diamond_producer,
+ const se::GpuComputeCapability& gpu_version) {
+ if (IsTriviallyFusible(diamond_producer, gpu_version,
+ /*num_allowed_users=*/2)) {
+ diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
+ while (IsTriviallyFusible(diamond_producer, gpu_version)) {
+ diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
+ }
+ }
+
+ return diamond_producer;
+}
+
+// Creates a fusion corresponding to the input diamond chain. The resulting
+// fusion instruction is added to the module, but is not yet inserted into the
+// graph as a replacement of the original instructions.
+//
+// TODO(b/347956491): this awkward abstraction is needed to work around
+// limitations of HloFusionAdaptor, which underpins the implementation of
+// SymbolicTileAnalysis. We need to come up with a better solution.
+absl::StatusOr<HloFusionInstruction*> MakeFusionForDiamondChain(
+ const DiamondChainDescriptor& diamond_chain) {
+ auto [root, producer] = diamond_chain;
+
+ std::string suggested_name = "triton_softmax";
+ HloComputation::Builder builder(absl::StrCat(suggested_name, "_computation"));
+ // Original instruction -> fused one.
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*>
+ old_to_new_mapping;
+
+ int param = 0;
+ old_to_new_mapping[producer] =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ param, producer->shape(), absl::StrCat("parameter_", param)));
+ param++;
+
+ std::vector<HloInstruction*> parameters = {producer};
+
+ std::function<void(HloInstruction*)> create_computation =
+ [&](HloInstruction* instr) -> void {
+ if (old_to_new_mapping.contains(instr)) {
+ return;
+ }
+ std::vector<HloInstruction*> new_operands;
+ for (HloInstruction* operand : instr->mutable_operands()) {
+ create_computation(operand);
+ new_operands.push_back(old_to_new_mapping[operand]);
+ }
+ if (instr->opcode() == HloOpcode::kParameter) {
+ old_to_new_mapping[instr] =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ param, instr->shape(), absl::StrCat("parameter_", param)));
+ parameters.push_back(instr);
+ param++;
+ } else {
+ old_to_new_mapping[instr] = builder.AddInstruction(
+ instr->CloneWithNewOperands(instr->shape(), new_operands));
+ }
+ };
+ create_computation(root);
+
+ HloComputation* computation =
+ root->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
+ /*is_entry=*/false);
+
+ HloInstruction* softmax_fusion =
+ root->parent()->AddInstruction(HloInstruction::CreateFusion(
+ root->shape(), HloInstruction::FusionKind::kCustom, parameters,
+ computation));
+
+ softmax_fusion->GetModule()->SetAndUniquifyInstrName(softmax_fusion,
+ "triton_softmax");
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ softmax_fusion->backend_config<GpuBackendConfig>());
+ FusionBackendConfig& backend_config =
+ *gpu_config.mutable_fusion_backend_config();
+ backend_config.set_kind(std::string(kTritonFusionKind));
+ TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(gpu_config));
+ return xla::Cast<HloFusionInstruction>(softmax_fusion);
+}
+
+absl::Status FuseDiamondChainImpl(
+ const DiamondChainDescriptor& diamond_chain,
+ GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model) {
+ TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
+ MakeFusionForDiamondChain(diamond_chain));
+ HloInstruction* root = diamond_chain.root;
+
+ auto fusion_adaptor = HloFusionAdaptor::ForInstruction(softmax_fusion);
+
+ TF_ASSIGN_OR_RETURN(
+ TiledRunTimeDataOrError tiled_runtime_data_or,
+ indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor));
+
+ if (const auto* fusion_decision =
+ std::get_if<FusionDecision>(&tiled_runtime_data_or)) {
+ return absl::FailedPreconditionError(absl::StrCat(
+ "SymbolicTileAnalysis failed. ", fusion_decision->Explain()));
+ }
+
+ TiledRunTimeData tiled_runtime_data =
+ std::get<TiledRunTimeData>(std::move(tiled_runtime_data_or));
+
+ TF_ASSIGN_OR_RETURN(auto backend_config,
+ softmax_fusion->backend_config<GpuBackendConfig>());
+ *backend_config.mutable_fusion_backend_config()
+ ->mutable_block_level_fusion_config() =
+ tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig();
+ TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config));
+
+ if (root->IsRoot()) {
+ root->parent()->set_root_instruction(softmax_fusion);
+ TF_RETURN_IF_ERROR(
+ root->parent()->RemoveInstructionAndUnusedOperands(root));
+ } else {
+ TF_RETURN_IF_ERROR(
+ root->parent()->ReplaceInstruction(root, softmax_fusion));
+ }
+
+ VLOG(5) << softmax_fusion->ToString();
+ return absl::OkStatus();
+}
+
+// Returns `true` if the diamond chain passed as a parameter can be tiled
+// correctly using `SymbolicTileAnalysis`.
+absl::StatusOr<bool> CanSymbolicTileAnalysisTileDiamondChain(
+ const DiamondChainDescriptor& diamond_chain) {
+ TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
+ MakeFusionForDiamondChain(diamond_chain));
+ mlir::MLIRContext context;
+ SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error =
+ SymbolicTileAnalysis::AnalyzeComputation(
+ *softmax_fusion->called_computation(), &context);
+
+ bool can_tile = std::holds_alternative<SymbolicTileAnalysis>(
+ symbolic_tile_analysis_or_error);
+
+ TF_RETURN_IF_ERROR(diamond_chain.root->GetModule()->RemoveEmbeddedComputation(
+ softmax_fusion->called_computation()));
+ TF_RETURN_IF_ERROR(
+ diamond_chain.root->parent()->RemoveInstruction(softmax_fusion));
+
+ return can_tile;
+}
+
+FusionDecision ShouldFuseReduction(const HloInstruction& reduce,
+ const se::GpuComputeCapability& cc) {
+ if (CodegenDecision is_supported = IsTritonSupportedInstruction(reduce, cc);
+ !is_supported) {
+ return FusionDecision(is_supported.Explain());
+ }
+
+ // Ensure that the reduction's identity is either a constant or a supported
+ // convert of a constant.
+ const HloInstruction* identity = reduce.operand(1);
+ bool should_fuse_identity =
+ identity->opcode() == HloOpcode::kConstant ||
+ (identity->opcode() == HloOpcode::kConvert &&
+ identity->operand(0)->opcode() == HloOpcode::kConstant &&
+ IsTritonSupportedInstruction(*identity, cc));
+ if (!should_fuse_identity) {
+ return "Reduction identity is not a constant or a supported convert of a "
+ "constant.";
+ }
+
+ return {};
+}
+
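+// Illustrative shape of the pattern matched below (hypothetical HLO; names
+// and shapes are made up, not taken from a real module):
+//
+//   %producer = f32[a,b] ...
+//   %reduce   = f32[a]   reduce(%producer, %identity), dimensions={1}
+//   %bcast    = f32[a,b] broadcast(%reduce), dimensions={0}
+//   %root     = f32[a,b] subtract(%producer, %bcast)
+//
+// i.e. the root is an elementwise binary op consuming both the producer
+// (possibly through trivially fusible ops) and a broadcast of a last-axis
+// reduction of that producer, as in the max-subtraction step of softmax.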
+DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl(
+ HloInstruction* instr, const se::GpuComputeCapability& cc) {
+ if (!instr->IsElementwiseBinary()) {
+ return "Root is not elementwise binary.";
+ }
+
+ if (!IsTritonSupportedInstruction(*instr, cc)) {
+ return "Root is not supported for Triton instruction.";
+ }
+
+ HloInstruction* producer;
+ HloInstruction* broadcast;
+ HloInstruction* reduce;
+
+ if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast,
+ cc)) {
+ return "Could not find a trivial connection from root to a broadcast.";
+ }
+
+ if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce,
+ cc)) {
+ return "Could not find a trivial connection from matched broadcast to a "
+ "reduction.";
+ }
+
+ if (!(HasDefaultLayout(broadcast->shape()) &&
+ HasDefaultLayout(reduce->shape()))) {
+ return "Broadcast or reduce have non-default layouts.";
+ }
+
+ if (FusionDecision should_fuse_reduction = ShouldFuseReduction(*reduce, cc);
+ !should_fuse_reduction) {
+ VLOG(2) << should_fuse_reduction.Explain();
+ return should_fuse_reduction;
+ }
+
+ // Ensure that the reduction's identity is either a constant or a supported
+ // convert of a constant.
+ const HloInstruction* identity = reduce->operand(1);
+ bool should_fuse_identity =
+ identity->opcode() == HloOpcode::kConstant ||
+ (identity->opcode() == HloOpcode::kConvert &&
+ identity->operand(0)->opcode() == HloOpcode::kConstant &&
+ IsTritonSupportedInstruction(*identity, cc));
+ if (!should_fuse_identity) {
+ return "Reduction identity is not a constant or a supported convert of a "
+ "constant.";
+ }
+
+ if (!HasOneUse(broadcast) || !HasOneUse(reduce)) {
+ return "More than one use of broadcast or reduce.";
+ }
+
+ producer = reduce->mutable_operand(0);
+
+ if (absl::c_linear_search(broadcast->dimensions(),
+ broadcast->shape().rank() - 1)) {
+ return "Broadcast is not along the reduction dimension.";
+ }
+
+ while (IsTriviallyFusible(producer, cc)) {
+ producer = ChooseOperandForFusionProcessing(producer);
+ }
+
+ if (!HasDefaultLayout(producer->shape())) {
+ return "Producer has non-default layout.";
+ }
+
+ if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0),
+ cc)) {
+ return "Producer is not trivially connected.";
+ }
+
+ if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) {
+ return "Unsupported root-producer connection.";
+ }
+
+ VLOG(5) << "Matched Softmax diamond with: ";
+ VLOG(5) << "root: " << instr->ToString();
+ VLOG(5) << "producer: " << producer->ToString();
+ VLOG(5) << "broadcast: " << broadcast->ToString();
+ VLOG(5) << "reduce: " << reduce->ToString();
+
+ return producer;
+}
+
+// Returns a vector containing all the single diamonds in the given module.
+// The diamonds are returned in def-before-use order, and grouped by
+// computation.
+absl::StatusOr<std::vector<DiamondChainDescriptor>> FindAllFusibleDiamonds(
+ HloModule& module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads,
+ const se::GpuComputeCapability& cc) {
+ std::vector<DiamondChainDescriptor> matched_diamonds;
+
+ for (HloComputation* comp :
+ module.MakeNonfusionComputations(execution_threads)) {
+ if (comp->IsCustomCallComputation()) {
+ continue;
+ }
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ auto producer =
+ MatchesTritonCompatibleClosedReductionDiamondImpl(instr, cc);
+ if (std::holds_alternative<HloInstruction*>(producer)) {
+ DiamondChainDescriptor diamond_chain{
+ /*root=*/instr, /*producer=*/std::get<HloInstruction*>(producer)};
+ // We filter out the diamond chains that cannot be tiled correctly using
+ // `SymbolicTileAnalysis`.
+ TF_ASSIGN_OR_RETURN(
+ bool can_tile_diamond_chain,
+ CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
+ if (can_tile_diamond_chain) {
+ matched_diamonds.push_back(diamond_chain);
+ } else {
+ VLOG(5) << "Cannot tile the diamond pattern described by "
+ << "instructions " << instr->ToString() << " and "
+ << std::get<HloInstruction*>(producer)->ToString() << ".";
+ continue;
+ }
+
+ } else {
+ VLOG(5) << "Cannot match the diamond pattern for instruction "
+ << instr->ToString()
+ << ". Reason: " << std::get<FusionDecision>(producer).Explain();
+ }
+ }
+ }
+
+ return std::move(matched_diamonds);
+}
+
+// Returns the size of the reduction dimension of the input diamond.
+int64_t GetReductionDimensionSizeForDiamond(
+ const DiamondChainDescriptor& diamond_chain) {
+ HloInstruction* diamond_root = diamond_chain.root;
+ HloInstruction* instr = diamond_root->mutable_operand(1);
+ while (instr->opcode() != HloOpcode::kReduce) {
+ instr = ChooseOperandForFusionProcessing(instr);
+ }
+
+ int operand_rank = instr->operand(0)->shape().rank();
+ CHECK_EQ(instr->dimensions().size(), 1);
+ CHECK_EQ(instr->dimensions(0), operand_rank - 1);
+ return instr->operand(0)->shape().dimensions(operand_rank - 1);
+}
+
+// Returns a pointer to the last user of `instr` that is trivially fusible.
+HloInstruction* GetLastTriviallyFusibleUser(
+ HloInstruction* instr, const se::GpuComputeCapability& cc) {
+ while (HasOneUse(instr) && !instr->IsRoot() &&
+ IsTriviallyFusible(instr->users().front(), cc)) {
+ instr = instr->users().front();
+ }
+
+ // We do not care about the number of users for the last instruction of the
+ // fusion, so attempt to fuse one more instruction with this relaxed
+ // restriction.
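+ // For instance (illustration only), a trailing elementwise op with several
+ // consumers can still be absorbed here: it becomes the root of the fusion,
+ // and its consumers simply read the fusion's result instead.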
+ if (HasOneUse(instr) && !instr->IsRoot() &&
+ IsTriviallyFusible(
+ instr->users().front(), cc,
+ /*num_allowed_users=*/instr->users().front()->user_count())) {
+ instr = instr->users().front();
+ }
+ return instr;
+}
+
+} // anonymous namespace
+
+DiamondMatchingDecision
+SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
+ HloInstruction* instr) const {
+ return MatchesTritonCompatibleClosedReductionDiamondImpl(
+ instr, device_info_.gpu_compute_capability());
+}
+
+absl::StatusOr<std::vector<DiamondChainDescriptor>>
+SoftmaxRewriterTriton::FindAllFusibleDiamondChains(
+ HloModule& module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) const {
+ const se::GpuComputeCapability& cc = device_info_.gpu_compute_capability();
+ TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> matched_diamonds,
+ FindAllFusibleDiamonds(module, execution_threads, cc));
+
+ if (matched_diamonds.empty()) {
+ return std::vector<DiamondChainDescriptor>();
+ }
+
+ // If we matched several diamonds, it may be possible for some of them to be
+ // fused together. This is the case if the following conditions hold:
+ // 1. The path from the root of diamond n to the producer of diamond n+1
+ // is composed only of trivially fusible operations. In that case, the
+ // first non-trivially fusible producer of diamond n+1 must be exactly
+ // the root of diamond n.
+ // 2. The root of diamond n (i.e., the first non-fusible producer of
+ // diamond n+1) must have
+ // a. exactly one user if it is not exactly the producer of diamond n+1;
+ // b. exactly two users otherwise.
+ // 3. The axis being reduced must have the same length in all the diamonds
+ // being fused together.
+ //
+ // Crucially, this approach relies on a diamond root never being considered a
+ // trivially fusible operation.
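+ // As an illustration (hypothetical HLO mirroring the tests accompanying
+ // this change), in
+ //   subtract = f32[127,125] subtract(param_0, broadcast)
+ //   second_reduce = f32[127] reduce(subtract, zero), dimensions={1}, ...
+ //   second_broadcast = f32[127,125] broadcast(second_reduce), dimensions={0}
+ //   divide = f32[127,125] divide(subtract, second_broadcast)
+ // `subtract` is both the root of diamond n and the producer of diamond n+1
+ // (case 2.b: it has exactly two users), and both reductions run over an axis
+ // of length 125 (condition 3), so the two diamonds end up in the same chain.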
+ std::vector<DiamondChainDescriptor> diamond_chains;
+ diamond_chains.reserve(matched_diamonds.size());
+
+ HloInstruction* current_fusion_producer =
+ FindFirstNonFusibleDiamondProducer(matched_diamonds.front().producer, cc);
+ int current_reduce_dimension_size =
+ GetReductionDimensionSizeForDiamond(matched_diamonds.front());
+
+ for (int diamond_idx = 1; diamond_idx < matched_diamonds.size();
+ ++diamond_idx) {
+ HloInstruction* diamond_producer = matched_diamonds[diamond_idx].producer;
+ HloInstruction* previous_diamond_root =
+ matched_diamonds[diamond_idx - 1].root;
+
+ HloInstruction* first_non_fusible_diamond_producer =
+ FindFirstNonFusibleDiamondProducer(diamond_producer, cc);
+
+ int diamond_reduce_dimension_size =
+ GetReductionDimensionSizeForDiamond(matched_diamonds[diamond_idx]);
+
+ if (first_non_fusible_diamond_producer == previous_diamond_root && // 1
+ ((first_non_fusible_diamond_producer != diamond_producer &&
+ HasOneUse(first_non_fusible_diamond_producer)) || // 2.a
+ (first_non_fusible_diamond_producer == diamond_producer &&
+ first_non_fusible_diamond_producer->user_count() == 2)) && // 2.b
+ diamond_reduce_dimension_size == current_reduce_dimension_size) { // 3
+ continue;
+ }
+
+ // The "last trivially fusible user" chain of diamond chain n should never
+ // intersect with the "first non fusible diamond producer" chain of diamond
+ // chain n+1: if these chains intersected, then all the intermediate ops
+ // between the diamond chains could be trivially fused, and both diamond
+ // chains could be fused into a single diamond chain. Note that this only
+ // holds insofar as we do not allow fusing in bitcasts that modify the last
+ // dimension of the input array. It is however possible for the last
+ // trivially fusible user of diamond chain n to be the first non fusible
+ // diamond producer of diamond chain n+1.
+ diamond_chains.push_back(DiamondChainDescriptor{
+ GetLastTriviallyFusibleUser(previous_diamond_root, cc),
+ current_fusion_producer,
+ });
+
+ current_fusion_producer = first_non_fusible_diamond_producer;
+ current_reduce_dimension_size = diamond_reduce_dimension_size;
+ }
+
+ // The last diamond chain is still open; close it.
+ diamond_chains.push_back(DiamondChainDescriptor{
+ GetLastTriviallyFusibleUser(matched_diamonds.back().root, cc),
+ current_fusion_producer});
+
+ // We filter out the diamond chains that cannot be tiled correctly using
+ // `SymbolicTileAnalysis`.
+ std::vector<DiamondChainDescriptor> filtered_diamond_chains;
+ for (const DiamondChainDescriptor& diamond_chain : diamond_chains) {
+ TF_ASSIGN_OR_RETURN(bool can_tile_diamond_chain,
+ CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
+ if (can_tile_diamond_chain) {
+ filtered_diamond_chains.push_back(diamond_chain);
+ }
+ }
+ return filtered_diamond_chains;
+}
+
+absl::Status SoftmaxRewriterTriton::FuseDiamondChain(
+ const DiamondChainDescriptor& diamond_chain) {
+ HloFusionAnalysisCache fusion_analysis_cache(device_info_);
+ GpuPerformanceModelWithIndexingAnalysis indexing_performance_model(
+ &device_info_, &fusion_analysis_cache, shape_size_, &mlir_context_);
+
+ return FuseDiamondChainImpl(diamond_chain, indexing_performance_model);
+}
+
+absl::StatusOr<bool> SoftmaxRewriterTriton::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability(
+ device_info_.gpu_compute_capability()));
+
+ TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
+ FindAllFusibleDiamondChains(*module, execution_threads));
+
+ if (diamond_chains.empty()) {
+ return false;
+ }
+
+ // The diamond chains must be emitted in reverse order, to make sure that
+ // producer instructions are emitted correctly when the root of
+ // diamond chain n is exactly the producer of diamond chain n+1.
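+ // For instance, when the root of chain n is also the producer of chain n+1,
+ // fusing chain n+1 first keeps that root alive as an operand of the new
+ // fusion; fusing chain n afterwards simply rewires that operand, whereas the
+ // opposite order would leave chain n+1 pointing at a stale producer.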
+ for (auto diamond_chain = diamond_chains.rbegin();
+ diamond_chain != diamond_chains.rend(); ++diamond_chain) {
+ TF_RET_CHECK(FuseDiamondChain(*diamond_chain).ok());
+ }
+ return true;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h
new file mode 100644
index 0000000..36f780f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h
@@ -0,0 +1,101 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_
+
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
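+// Describes a diamond chain through its root and producer instructions. The
+// producer feeds the chain and is excluded from the fusion built for it (see
+// `SoftmaxRewriterTriton::FuseDiamondChain` below).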
+struct DiamondChainDescriptor {
+ HloInstruction* root = nullptr;
+ HloInstruction* producer = nullptr;
+};
+
+using DiamondMatchingDecision = std::variant<FusionDecision, HloInstruction*>;
+
+// Rewrite compatible Softmax into a custom fusion region to be code-generated
+// with the Triton-based Softmax emitter.
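+//
+// Example usage (a sketch with placeholder names, mirroring the tests
+// accompanying this change):
+//
+//   SoftmaxRewriterTriton rewriter(device_info, shape_size_function);
+//   TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module));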
+class SoftmaxRewriterTriton : public HloModulePass {
+ public:
+ explicit SoftmaxRewriterTriton(const se::DeviceDescription& device_info,
+ HloCostAnalysis::ShapeSizeFunction shape_size)
+ : device_info_(device_info), shape_size_(shape_size) {}
+
+ absl::string_view name() const override { return "triton-softmax-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ // Finds and returns all the fusible diamond chains in the module. The
+ // resulting vector is sorted according to a post-order matching (i.e. within
+ // the same computation, producer diamonds appear before consumer diamonds).
+ absl::StatusOr<std::vector<DiamondChainDescriptor>>
+ FindAllFusibleDiamondChains(
+ HloModule& module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) const;
+
+ // Constructs a Softmax fusion containing all the instructions between the
+ // root and the producer of a diamond chain. The producer is excluded from the
+ // fusion.
+ absl::Status FuseDiamondChain(const DiamondChainDescriptor& diamond_chain);
+
+ // Return the producer of the following pattern:
+ //
+ // producer
+ // | \
+ // | reduce_{max,sum,...}
+ // | |
+ // | broadcast
+ // | /
+ // binop (elementwise)
+ //
+ // where each edge may also contain trivial operations that can be generated
+ // by Triton. By "trivial" we mean operations that do not increase the amount
+ // of memory read/written by the fusion and that are compatible with any
+ // chosen tiling.
+ //
+ // We also assume that the reduction is done on the last axis of the producer
+ // array.
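+ //
+ // As a minimal illustration (hypothetical HLO mirroring the tests
+ // accompanying this change), the following matches, and `param_0` is the
+ // returned producer:
+ //
+ //   reduce = f32[127] reduce(param_0, -inf), dimensions={1}, to_apply=max
+ //   broadcast = f32[127,125] broadcast(reduce), dimensions={0}
+ //   ROOT subtract = f32[127,125] subtract(param_0, broadcast)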
+ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamond(
+ HloInstruction* instr) const;
+
+ private:
+ const se::DeviceDescription& device_info_;
+ const HloCostAnalysis::ShapeSizeFunction shape_size_;
+ mlir::MLIRContext mlir_context_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc
new file mode 100644
index 0000000..1b3139c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc
@@ -0,0 +1,1590 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/softmax_rewriter_triton.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using ::testing::HasSubstr;
+
+GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() {
+ return [&](const Shape& shape) {
+ constexpr int64_t kPointerSize = 8;
+ return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+ };
+}
+
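+// Returns true if `fusion` is a fusion instruction whose backend config
+// carries a block-level fusion config; the tests below use this to recognize
+// fusions produced by SoftmaxRewriterTriton.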
+bool HasBlockLevelFusionConfig(const HloInstruction* fusion) {
+ return fusion->opcode() == HloOpcode::kFusion &&
+ fusion->has_backend_config() &&
+ fusion->backend_config<GpuBackendConfig>().ok() &&
+ fusion->backend_config<GpuBackendConfig>()
+ ->fusion_backend_config()
+ .has_block_level_fusion_config();
+}
+
+// Wrapper around SoftmaxRewriterTriton(device_info, shape_size).Run(module)
+// that finds and fuses as many diamond chains as possible without invoking
+// any kind of cost analysis.
+absl::StatusOr<bool> SoftmaxRewriterTritonMatchAndRewrite(
+ const se::DeviceDescription& device_info, HloModule* module) {
+ CHECK_NE(module, nullptr);
+ SoftmaxRewriterTriton softmax_rewriter_triton(device_info,
+ ShapeSizeBytesFunction());
+ TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
+ softmax_rewriter_triton.FindAllFusibleDiamondChains(
+ *module, /*execution_threads=*/{}));
+
+ for (auto diamond_chain = diamond_chains.rbegin();
+ diamond_chain != diamond_chains.rend(); ++diamond_chain) {
+ TF_RETURN_IF_ERROR(
+ softmax_rewriter_triton.FuseDiamondChain(*diamond_chain));
+ }
+
+ return !diamond_chains.empty();
+}
+
+class SoftmaxRewriterTritonTest
+ : public HloTestBase,
+ public ::testing::WithParamInterface<PrimitiveType> {
+ protected:
+ se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+};
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseExactSoftmaxF32) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ exponential = f32[127,125]{1,0} exponential(subtract)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ VLOG(2) << module->ToString();
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseSoftmaxLikeComputationWithNonF32DataType) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f16[] parameter(0)
+ arg_1 = f16[] parameter(1)
+ ROOT maximum = f16[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f16[] parameter(0)
+ arg_1.1 = f16[] parameter(1)
+ ROOT add = f16[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f16[127,125]{1,0} parameter(0)
+ constant_neg_inf = f16[] constant(-inf)
+ reduce = f16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f16[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f16[127,125]{1,0} subtract(param_0, broadcast)
+ // Replace Softmax exponential with abs, because Triton doesn't support
+ // non-f32 exponentials.
+ abs = f16[127,125]{1,0} abs(subtract)
+ constant_zero = f16[] constant(0)
+ second_reduce = f16[127]{0} reduce(abs, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f16[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ // Replace divide with multiply, because Triton doesn't support f16
+ // divisions.
+ ROOT multiply = f16[127,125]{1,0} multiply(abs, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseSingleNormalizationDiamond) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ DoesNotFuseDiamondInvolvingUnsupportedTritonInstruction) {
+ const std::string hlo_string = R"(
+HloModule softmax
+add_computation {
+ arg_0.1 = bf16[] parameter(0)
+ arg_1.1 = bf16[] parameter(1)
+ ROOT add = bf16[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = bf16[127,125]{1,0} parameter(0)
+ constant_zero = bf16[] constant(0)
+ reduce = bf16[127]{0} reduce(param_0, constant_zero), dimensions={1}, to_apply=add_computation
+ broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT divide = bf16[127,125]{1,0} divide(param_0, broadcast)
+})";
+
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ const HloInstruction* bf16_divide =
+ module->entry_computation()->root_instruction();
+ EXPECT_FALSE(IsTritonSupportedInstruction(
+ *bf16_divide, device_info_.gpu_compute_capability()));
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ DoesNotFuseInstructionsUnsupportedByTritonIntoDiamonds) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = bf16[] parameter(0)
+ arg_1 = bf16[] parameter(1)
+ ROOT maximum = bf16[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = bf16[127,125]{1,0} parameter(0)
+ constant_neg_inf = bf16[] constant(-inf)
+ reduce = bf16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = bf16[127,125]{1,0} subtract(param_0, broadcast)
+ ROOT exponential = bf16[127,125]{1,0} exponential(subtract)
+})";
+
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ const HloInstruction* bf16_exponential =
+ hlo_query::GetFirstInstructionWithOpcode(*module->entry_computation(),
+ HloOpcode::kExp);
+ EXPECT_FALSE(IsTritonSupportedInstruction(
+ *bf16_exponential, device_info_.gpu_compute_capability()));
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Exp(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithWrongLayout) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{0,1} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseSoftmaxDiamondWithWrongReduceDimension) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={0}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={1}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseSoftmaxDiamondWithWrongBroadcastDimension) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[125,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[125,125]{1,0} broadcast(reduce), dimensions={1}
+ ROOT subtract = f32[125,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseSoftmaxDiamondWithExtraBroadcastUsage) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ ROOT multiply = f32[127,125]{1,0} multiply(broadcast, subtract)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseSoftmaxWithIntermediateUnaryElementwise) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ abs = f32[127,125]{1,0} abs(subtract)
+ exponential = f32[127,125]{1,0} exponential(abs)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(subtract, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ ROOT divide = f32[127,125]{1,0} divide(subtract, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseDiamondWithTrailingUnaryElementwiseAtTheRoot) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ ROOT abs = f32[127,125]{1,0} abs(subtract)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseDiamondWithUnaryElementwisePrefix) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ abs = f32[127,125]{1,0} abs(param_0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(abs, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseDiamondWithMultipleBroadcastDimensions) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[1,3,125,125]{3,2,1,0} parameter(0)
+ bitcast = f32[3,125,125]{2,1,0} bitcast(f32[1,3,125,125]{3,2,1,0} param_0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[3,125]{1,0} reduce(f32[3,125,125]{2,1,0} bitcast, f32[] constant_neg_inf), dimensions={2}, to_apply=max_computation
+ broadcast = f32[1,3,125,125]{3,2,1,0} broadcast(f32[3,125]{1,0} reduce), dimensions={1,2}
+ ROOT subtract = f32[1,3,125,125]{3,2,1,0} subtract(f32[1,3,125,125]{3,2,1,0} param_0, f32[1,3,125,125]{3,2,1,0} broadcast)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseSoftmaxDiamondWithParameterReducerIdentity) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ identity = f32[] parameter(1)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseSoftmaxDiamondWithTritonIncompatibleReducer) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ if_0 = pred[] is-finite(arg_0)
+ c = f32[] convert(if_0)
+ ROOT maximum = f32[] maximum(c, arg_1)
+}
+
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseSoftmaxDiamondWithLastDimensionBitcastAfterReduce) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+ param_0 = f32[3,127,125]{2,1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[3,127]{1,0} reduce(param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
+ bitcasted_reduce = f32[381]{0} bitcast(reduce)
+ broadcast = f32[381,125]{1,0} broadcast(bitcasted_reduce), dimensions={0}
+ bitcasted_broadcast = f32[3,127,125]{2,1,0} bitcast(broadcast)
+ ROOT subtract = f32[3,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseSoftmaxDiamondWithTransposeBitcast) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,127,125]{2,1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ bitcasted_param_0 = f32[127,1,125]{2,0,1} bitcast(param_0)
+ reduce = f32[127,1]{0,1} reduce(bitcasted_param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
+ broadcast = f32[127,1,125]{2,0,1} broadcast(reduce), dimensions={0,1}
+ bitcasted_broadcast = f32[1,127,125]{2,1,0} bitcast(broadcast)
+ ROOT subtract = f32[1,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseTwoDiamondsWithDifferentReductionAxisSizeTogether) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,625]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,625]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,625]{1,0} subtract(param_0, broadcast)
+ bitcasted_subtract = f32[127,5,125] bitcast(subtract)
+ exponential = f32[127,5,125] exponential(bitcasted_subtract)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127,5] reduce(exponential, constant_zero), dimensions={2}, to_apply=add_computation
+ second_broadcast = f32[127,5,125] broadcast(second_reduce), dimensions={0,1}
+ ROOT divide = f32[127,5,125] divide(exponential, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Bitcast(m::Fusion(m::Parameter())
+ .WithPredicate(HasBlockLevelFusionConfig)))
+ .WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseTwoDiamondsWithExtraUsageForFirstDiamondRoot) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ exponential = f32[127,125]{1,0} exponential(subtract)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+ ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, subtract)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
+ m::Fusion(m::Parameter())
+ .WithPredicate(HasBlockLevelFusionConfig))));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseTwoDiamondsWithExtraUsageForSecondDiamondProducer) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ exponential = f32[127,125]{1,0} exponential(subtract)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+ ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, exponential)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(
+ m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseSoftmaxDiamondWithTritonIncompatibleProducer) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+ param_0 = f16[127,125]{1,0} parameter(0)
+ exponential = f16[127,125] exponential(param_0)
+ convert = f32[127,125] convert(exponential)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(convert, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(convert, broadcast)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Exp(m::Parameter()))
+ .WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanNotFuseSoftmaxDiamondWithNonFusibleBitcastBetweenReduceAndProducer) {
+ const std::string hlo_string = R"(
+HloModule softmax
+
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,127,5,25]{3,2,1,0} parameter(0)
+ bitcast_0 = f32[127,125] bitcast(param_0)
+ bitcast_1 = f32[127,125] bitcast(param_0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseSoftmaxDiamondWithBitcastProducerFollowedByBitcastsOnEachUse) {
+ const std::string hlo_string = R"(
+HloModule softmax
+
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,1,127,125]{3,2,1,0} parameter(0)
+ bitcast_parent = f32[127,125]{1,0} bitcast(param_0)
+ bitcast_0 = f32[127,125]{1,0} bitcast(bitcast_parent)
+ bitcast_1 = f32[127,125]{1,0} bitcast(bitcast_parent)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnPreAmpereCudaGpu) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = bf16[127,125]{1,0} parameter(0)
+ param_0_f32 = f32[127,125]{1,0} convert(param_0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
+})";
+
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_THAT(
+ SoftmaxRewriterTriton(
+ TestGpuDeviceInfo::RTXA6000DeviceInfo(
+ se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}),
+ ShapeSizeBytesFunction())
+ .Run(module.get()),
+ tsl::testing::StatusIs(
+ tsl::error::FAILED_PRECONDITION,
+ ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
+ "(compute capability 8.0) and up, but got")));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, RewriterSucceedsOnNonCudaGpu) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = bf16[127,125]{1,0} parameter(0)
+ param_0_f32 = f32[127,125]{1,0} convert(param_0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
+})";
+
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+ EXPECT_TRUE(SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(),
+ ShapeSizeBytesFunction())
+ .Run(module.get())
+ .ok());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) {
+ const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ multiply = f32[127,125]{1,0} multiply(param_0, param_0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(multiply, broadcast)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ CanFuseIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ multiply = f32[127]{0} multiply(reduce, reduce)
+ broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) {
+ const std::string hlo_string = R"(
+HloModule fusible_diamonds
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ multiply = f32[127,125]{1,0} multiply(subtract, subtract)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ ROOT subtract_second = f32[127,125]{1,0} subtract(multiply, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) {
+ const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ ROOT multiply = f32[127,125]{1,0} multiply(subtract, subtract)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ DoesNotFuseIntermediateBinaryElementwiseWithBothSplatOperandsIntoDiamond) {
+ const std::string hlo_string = R"(
+HloModule nonfusible_splat
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ constant_0 = f32[] constant(0.333333343)
+ splat_0 = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
+ constant_1 = f32[] constant(0.66666)
+ splat_1 = f32[127,125]{1,0} broadcast(constant_1), dimensions={}
+ param_0 = f32[127,125]{1,0} parameter(0)
+ multiply_splats = f32[127,125]{1,0} multiply(splat_0, splat_1)
+ multiply_splat_param = f32[127,125]{1,0} multiply(multiply_splats, param_0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(multiply_splat_param, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ DoesNotFuseIntermediateBinaryElementwiseWithSameSplatOperandsIntoDiamond) {
+ const std::string hlo_string = R"(
+HloModule nonfusible_splat_diamond
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ constant_0 = f32[] constant(0.333333343)
+ splat = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
+ param_0 = f32[127,125]{1,0} parameter(0)
+ multiply = f32[127,125]{1,0} multiply(splat, splat)
+ add = f32[127,125]{1,0} add(param_0, multiply)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(add, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseRMSNormDiamond) {
+ const std::string hlo_string = R"(
+HloModule rms_norm
+add_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT add.1 = f32[] add(arg_0, arg_1)
+}
+ENTRY main.30 {
+ param_0 = f32[10,10,10,128]{3,2,1,0} parameter(0)
+ multiply_param = f32[10,10,10,128]{3,2,1,0} multiply(param_0, param_0)
+ constant_0 = f32[] constant(0)
+ reduce = f32[10,10,10]{2,1,0} reduce(multiply_param, constant_0), dimensions={3}, to_apply=add_computation
+ constant_1 = f32[] constant(0.333333343)
+ splat = f32[10,10,10]{2,1,0} broadcast(constant_1), dimensions={}
+ multiply_splat = f32[10,10,10]{2,1,0} multiply(reduce, splat)
+ epsilon = f32[] constant(1e-06)
+ splat_epsilon = f32[10,10,10]{2,1,0} broadcast(epsilon), dimensions={}
+ add = f32[10,10,10]{2,1,0} add(multiply_splat, splat_epsilon)
+ rsqrt = f32[10,10,10]{2,1,0} rsqrt(add)
+ broadcast = f32[10,10,10,128]{3,2,1,0} broadcast(rsqrt), dimensions={0,1,2}
+ ROOT multiply = f32[10,10,10,128]{3,2,1,0} multiply(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get())
+ .value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter())
+ .WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule fusible_diamonds
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ constant = f32[] constant(0.333333343)
+ broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
+ multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule fusible_diamonds
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ constant = f32[] constant(0.333333343)
+ broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
+ multiply = f32[127,125]{1,0} multiply(subtract, broadcast_splat)
+ constant_zero = f32[] constant(0)
+ second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ CanFuseBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) {
+ const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ constant = f32[] constant(0.333333343)
+ broadcast_splat = f32[127]{0} broadcast(constant), dimensions={}
+ multiply = f32[127]{0} multiply(broadcast_splat, reduce)
+ broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstant) {
+ const std::string hlo_string = R"(
+HloModule fusible_diamond
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+ constant = f32[] constant(0.333333343)
+ broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
+ ROOT multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+ CanFuseBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducer) {
+ const std::string hlo_string = R"(
+HloModule nonfusible_diamond
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT max = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_2 = f32[] constant(0.333333343)
+ broadcast_splat = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
+ param_1 = f32[127,125]{1,0} parameter(1)
+ multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, param_1)
+ multiply = f32[127,125]{1,0} multiply(param_0, broadcast_splat)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+ EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+ VLOG(2) << module->ToString();
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(
+ m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ DoesNotFuseBinaryElementwiseOperationWhereFirstOperandIsASplatAndSecondOperandIsASharedSplatProducer) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule nonfusible_diamond
+add_computation {
+ arg_0.1 = f32[] parameter(0)
+ arg_1.1 = f32[] parameter(1)
+ ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ constant_2 = f32[] constant(0.333333343)
+ broadcast_splat_shared = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
+ param_1 = f32[127,125]{1,0} parameter(1)
+ multiply_splat_shared = f32[127,125]{1,0} multiply(broadcast_splat_shared, param_1)
+ constant_3 = f32[] constant(0.5)
+ broadcast_splat = f32[127,125]{1,0} broadcast(constant_3), dimensions={}
+ multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, broadcast_splat_shared)
+ multiply = f32[127,125]{1,0} multiply(param_0, multiply_splat)
+ constant_neg_inf = f32[] constant(-inf)
+ reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=add_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest, FusionDecisionIsCapturedExplicitly) {
+ const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+ arg_0 = f32[] parameter(0)
+ arg_1 = f32[] parameter(1)
+ ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+ param_0 = f32[127,125]{1,0} parameter(0)
+ identity_f8 = f8e5m2[] parameter(1)
+ identity = f32[] convert(identity_f8)
+ reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
+ broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+ ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ SoftmaxRewriterTriton softmax_rewriter_triton(device_info_,
+ ShapeSizeBytesFunction());
+ int unmatched = 0, matched = 0;
+ for (HloInstruction* instruction :
+ module->entry_computation()->MakeInstructionPostOrder()) {
+ DiamondMatchingDecision decision =
+ softmax_rewriter_triton.MatchesTritonCompatibleClosedReductionDiamond(
+ instruction);
+ if (std::holds_alternative<FusionDecision>(decision)) {
+ std::string actual_decision =
+ std::get<FusionDecision>(decision).Explain();
+ EXPECT_THAT(
+ actual_decision,
+ AnyOf(
+ HasSubstr("Root is not elementwise binary"),
+ HasSubstr("identity is not a constant or a supported convert")));
+ unmatched++;
+ } else {
+ matched++;
+ }
+ }
+ EXPECT_EQ(unmatched, 6);
+ EXPECT_EQ(matched, 0);
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongReductionDimAsParameter) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ p0 = f32[32]{0} parameter(0)
+ p1 = f32[32,16]{1,0} parameter(1)
+ c = f32[] constant(0)
+
+ r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
+ b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
+ b1 = f32[32,16]{1,0} broadcast(p0), dimensions={0}
+ add0 = f32[32,16]{1,0} add(b1, p1)
+ ROOT add1 = f32[32,16]{1,0} add(add0, b0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ p0 = f32[16]{0} parameter(0)
+ p1 = f32[32,16]{1,0} parameter(1)
+ c = f32[] constant(0)
+
+ r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
+ b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
+ b1 = f32[32,16]{1,0} broadcast(p0), dimensions={1}
+ add0 = f32[32,16]{1,0} add(b1, p1)
+ ROOT add1 = f32[32,16]{1,0} add(add0, b0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ FusesBinaryElementwiseIfIntermediateDiamondOpWithMultiDimTensorBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ p0 = f32[32,16]{1,0} parameter(0)
+ p1 = f32[64,32,16]{2,1,0} parameter(1)
+ c = f32[] constant(0)
+
+ r0 = f32[64,32]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
+ b0 = f32[64,32,16]{2,1,0} broadcast(r0), dimensions={0,1}
+ b1 = f32[64,32,16]{2,1,0} broadcast(p0), dimensions={1,2}
+ add0 = f32[64,32,16]{2,1,0} add(b1, p1)
+ ROOT add1 = f32[64,32,16]{2,1,0} add(add0, b0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ FusesBinaryElementwiseIfIntermediateDiamondOpWithZeroDimTensorBroadcastAsParameter) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ parameter_0 = f32[] parameter(0)
+ parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
+ c = f32[] constant(0)
+
+ reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
+ broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
+ broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={}
+ add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
+ ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ FusesBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongNonReductionDimensions) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ parameter_0 = f32[16] parameter(0)
+ parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
+ c = f32[] constant(0)
+
+ reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
+ broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
+ broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={2}
+ add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
+ ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_TRUE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongBothBatchAndReductionDimensions) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ parameter_0 = f32[64] parameter(0)
+ parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
+ c = f32[] constant(0)
+
+ reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
+ broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
+ broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={0}
+ add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
+ ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchAndReductionDimAsParameter) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ p0 = f32[8]{0} parameter(0)
+ p1 = f32[32,8,16]{2,1,0} parameter(1)
+ c = f32[] constant(0)
+
+ r0 = f32[32,8]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
+ b0 = f32[32,8,16]{2,1,0} broadcast(r0), dimensions={0,1}
+ b1 = f32[32,8,16]{2,1,0} broadcast(p0), dimensions={1}
+ add0 = f32[32,8,16]{2,1,0} add(b1, p1)
+ ROOT add1 = f32[32,8,16]{2,1,0} add(add0, b0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithPartialBroadcastToBatchDim) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ p0 = f32[16,64]{1,0} parameter(0)
+ p1 = f32[8,16,32,64]{3,2,1,0} parameter(1)
+ c = f32[] constant(0)
+
+ r0 = f32[8,16,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
+ b0 = f32[8,16,32,64]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
+ b1 = f32[8,16,32,64]{3,2,1,0} broadcast(p0), dimensions={1,3}
+ add0 = f32[8,16,32,64]{3,2,1,0} add(b1, p1)
+ ROOT add1 = f32[8,16,32,64]{3,2,1,0} add(add0, b0)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+ SoftmaxRewriterTritonTest,
+ DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithMultiDimBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length)
+ const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+ y = f32[] parameter(1)
+ x = f32[] parameter(0)
+ ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+ p0 = f32[32,16]{1,0} parameter(0)
+ p1 = f32[128,64,32,16]{3,2,1,0} parameter(1)
+ c = f32[] constant(0)
+
+ r0 = f32[128,64,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
+ b0 = f32[128,64,32,16]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
+ b1 = f32[128,64,32,16]{3,2,1,0} broadcast(p0), dimensions={2,3}
+ add0 = f32[128,64,32,16]{3,2,1,0} add(b1, p1)
+ ROOT add1 = f32[128,64,32,16]{3,2,1,0} add(add0, b0)
+})";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+ EXPECT_FALSE(
+ SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc
new file mode 100644
index 0000000..b299db8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc
@@ -0,0 +1,342 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/runtime/cub_sort_thunk.h"
+#include "xla/service/stable_sort_expander.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Result of analyzing a sort comparer computation.
+struct SortComputationAnalysis {
+ int key_operand; // 0 or 1
+ bool descending;
+};
+
+std::pair<int64_t, int64_t> ParametersFromCmpOperands(
+ const HloCompareInstruction* cmp_op) {
+ if (cmp_op == nullptr) {
+ return std::pair<int64_t, int64_t>(-1, -1);
+ }
+ const HloParameterInstruction* param0 =
+ DynCast<HloParameterInstruction>(cmp_op->operand(0));
+ const HloParameterInstruction* param1 =
+ DynCast<HloParameterInstruction>(cmp_op->operand(1));
+ return (param0 && param1) ? std::make_pair(param0->parameter_number(),
+ param1->parameter_number())
+ : std::pair<int64_t, int64_t>(-1, -1);
+}
+
+// Returns sort info for compatible compare instructions. The instruction may
+// belong to a computation that has 2 or 4 parameters. If it is the root
+// instruction of a computation with 4 parameters, the analysis only succeeds
+// when 2 of the parameters are ignored.
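+//
+// Illustrative example (parameter names are hypothetical): for a key/value
+// sort whose comparator is
+//   p0 = f32[] parameter(0), p1 = f32[] parameter(1)   (the ignored values)
+//   p2 = u32[] parameter(2), p3 = u32[] parameter(3)   (the keys)
+//   ROOT gt = pred[] compare(p2, p3), direction=GT
+// the analysis returns {key_operand = 1, descending = true}.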
+std::optional<SortComputationAnalysis> AnalyzeCompareOp(
+ const HloInstruction* maybe_compare_op) {
+ // Root instruction must be a comparison with a valid direction.
+ const HloCompareInstruction* compare =
+ DynCast<HloCompareInstruction>(maybe_compare_op);
+ if (compare == nullptr || compare->direction() == ComparisonDirection::kEq ||
+ compare->direction() == ComparisonDirection::kNe) {
+ return std::nullopt;
+ }
+
+ // Compare should operate on the function parameters for a single tensor.
+ auto [index0, index1] = ParametersFromCmpOperands(compare);
+ if (index0 == -1 || index1 == -1) {
+ return std::nullopt;
+ }
+
+ // When sorting a pair of tensors, the parameters should be adjacent.
+ int first_index = std::min(index0, index1);
+ if (first_index % 2 != 0 || std::max(index0, index1) != first_index + 1) {
+ return std::nullopt;
+ }
+
+ // Return the tensor index and the sort direction.
+ bool descending = compare->direction() == ComparisonDirection::kGt ||
+ compare->direction() == ComparisonDirection::kGe;
+ bool reverse = first_index != index0;
+ return SortComputationAnalysis{first_index / 2, descending != reverse};
+}
+
+// Detects a sort with these properties:
+// - Has two operands -- one is an iota op
+// - Has a comparison computation that takes 4 inputs and compares them
+// hierarchically, so that the iota inputs are the final tie-breaker.
+//
+// The above is equivalent to a stable sort where the iota operand is completely
+// ignored. That simpler comparator is the one detected in AnalyzeCompareOp, but
+// that's insufficient, because the StableSortExpander pass expands it into the
+// more complex version detected below.
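+//
+// Sketch of the expanded comparator this function matches (following the
+// pattern exercised in the unit tests; values in parameters 0/1, iota
+// indices in parameters 2/3):
+//   cmp_lr      = compare(p0, p1), direction=GT
+//   cmp_rl      = compare(p1, p0), direction=GT
+//   cmp_eq      = compare(cmp_lr, cmp_rl), direction=EQ
+//   cmp_indices = compare(p2, p3), direction=LT
+//   ROOT select = select(cmp_eq, cmp_indices, cmp_lr)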
+std::optional<SortComputationAnalysis> AnalyzeComplexSortComputation(
+ const HloSortInstruction& sort_op) {
+ auto computation = sort_op.called_computations().front();
+ if (computation->num_parameters() != 4) {
+ return std::nullopt;
+ }
+
+ int64_t iota_operand_index =
+ StableSortExpander::IotaOperandIndexForStableSort(sort_op);
+ if (iota_operand_index < 0) {
+ return std::nullopt;
+ }
+
+ auto root = computation->root_instruction();
+ if (root->opcode() != HloOpcode::kSelect) {
+ return std::nullopt;
+ }
+
+ // Check that the middle operand of the select compares the iota input.
+ auto iota_cmp = DynCast<HloCompareInstruction>(root->operand(1));
+ auto [iotap0, iotap1] = ParametersFromCmpOperands(iota_cmp);
+ if (iota_cmp == nullptr ||
+ iota_cmp->direction() != ComparisonDirection::kLt ||
+ iotap0 != iota_operand_index * 2 ||
+ iotap1 != iota_operand_index * 2 + 1) {
+ return std::nullopt;
+ }
+
+ // Check that the first operand of the select is an EQ comparison of the
+ // values (non-iota) input.
+ auto eq_cmp = DynCast<HloCompareInstruction>(root->operand(0));
+ if (eq_cmp == nullptr || eq_cmp->direction() != ComparisonDirection::kEq) {
+ return std::nullopt;
+ }
+
+ // EQ comparison case 1: direct comparison of parameters
+ auto [p0, p1] = ParametersFromCmpOperands(eq_cmp);
+ if (p0 < 0 || p1 < 0) {
+ // EQ comparison case 2: comparison of comparisons. This is what
+ // the StableSortExpander pass currently generates.
+ auto cmp = DynCast<HloCompareInstruction>(eq_cmp->operand(0));
+ auto cmp_reverse = DynCast<HloCompareInstruction>(eq_cmp->operand(1));
+ auto [a, b] = ParametersFromCmpOperands(cmp);
+ auto [p, q] = ParametersFromCmpOperands(cmp_reverse);
+ if (cmp == nullptr || cmp_reverse == nullptr || a < 0 || b < 0 || a != q ||
+ b != p || cmp->direction() != cmp_reverse->direction() ||
+ cmp->direction() == Comparison::Direction::kEq ||
+ cmp->direction() == Comparison::Direction::kNe) {
+ return std::nullopt;
+ }
+ }
+
+ // At this point only the last operand of the select needs to be verified.
+ return AnalyzeCompareOp(root->operand(2));
+}
+
+std::optional<SortComputationAnalysis> AnalyzeSortOp(
+ const HloSortInstruction& sort_op) {
+ auto computation = sort_op.called_computations().front();
+
+ // First, check if the computation is a simple compare op on the operands.
+ auto result = AnalyzeCompareOp(computation->root_instruction());
+ if (!result.has_value()) {
+ // If the above fails, check if the sort instruction and comparer are more
+ // complex, like what is produced by the StableSortExpander pass.
+ result = AnalyzeComplexSortComputation(sort_op);
+ }
+ return result;
+}
+
+// Create runner for CUB sort operation.
+absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner(
+ HloSortInstruction* sort_op, const SortComputationAnalysis& sort_config) {
+ int value_index = 1 - sort_config.key_operand;
+ return CubSortRunnerInterface::Create(
+ sort_op->operand(sort_config.key_operand)->shape().element_type(),
+ sort_op->operand_count() == 2
+ ? std::optional(sort_op->operand(value_index)->shape().element_type())
+ : std::nullopt);
+}
+
+// Verify that the sort tensor shape is supported by CUB.
+bool IsCubCompatibleSort(HloSortInstruction* sort_op) {
+ VLOG(1) << "Sort instruction: " << sort_op->name();
+ if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) {
+ VLOG(2) << "Unsupported operand count: " << sort_op->operand_count();
+ return false;
+ }
+
+ const Shape& operand_shape = sort_op->operand(0)->shape();
+ if (sort_op->sort_dimension() != operand_shape.rank() - 1) {
+ VLOG(2) << "Sort dimension should be the minor one";
+ return false;
+ }
+ if (Product(operand_shape.dimensions()) < SortRewriter::SortSizeThreshold()) {
+ VLOG(2) << "Tensor shape size is too small to see an improvement";
+ return false;
+ }
+
+ auto sort_config = AnalyzeSortOp(*sort_op);
+ if (!sort_config.has_value()) {
+ VLOG(2) << "Only simple compare computations are supported";
+ return false;
+ }
+ if (!CreateRunner(sort_op, *sort_config).ok()) {
+ VLOG(2) << "Unsupported operand types (no compiled CUB kernels)";
+ return false;
+ }
+ VLOG(2) << "Sort operation is compatible";
+ return true;
+}
+
+// Restores the result shape after sorting a pair of tensors. The last element
+// of the custom call result tuple is the scratch buffer, which is discarded.
+HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
+ HloInstruction* custom_call, bool swap) {
+ HloComputation* parent = sort_op->parent();
+ HloInstruction* gte0 =
+ parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+ sort_op->operand(0)->shape(), custom_call, swap ? 1 : 0));
+ HloInstruction* gte1 =
+ parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+ sort_op->operand(1)->shape(), custom_call, swap ? 0 : 1));
+ return parent->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+}
+
+} // namespace
+
+// Rewrites a single sort instruction with a custom call.
+absl::StatusOr<bool> SortRewriter::RunOnInstruction(
+ HloSortInstruction* sort_op) {
+ // Get the sort tensor index and direction.
+ SortComputationAnalysis sort_config = AnalyzeSortOp(*sort_op).value();
+
+ // Get scratch size requirements from CUB.
+ const Shape& operand_shape = sort_op->operand(0)->shape();
+ int64_t batch_size = Product(operand_shape.dimensions()) /
+ operand_shape.dimensions(sort_op->sort_dimension());
+
+ TF_ASSIGN_OR_RETURN(auto runner, CreateRunner(sort_op, sort_config));
+ TF_ASSIGN_OR_RETURN(
+ int64_t scratch_size,
+ runner->GetScratchSize(Product(operand_shape.dimensions()), batch_size));
+
+ // Align and increase scratch size to fit the offsets.
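+ // For example, assuming sizeof(int) == 4 and batch_size == 10, this adds up
+ // to 4 alignment bytes plus (10 + 1) * 4 = 44 bytes for the batch offsets.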
+ if (batch_size > 1) {
+ scratch_size += sizeof(int) - scratch_size % sizeof(int);
+ scratch_size += (batch_size + 1) * sizeof(int);
+ }
+
+ // Values are only present if sorting a pair of tensors.
+ HloInstruction* keys = sort_op->mutable_operand(0);
+ HloInstruction* values = nullptr;
+ if (sort_op->operand_count() == 2) {
+ values = sort_op->mutable_operand(1);
+ if (sort_config.key_operand == 1) {
+ std::swap(keys, values);
+ }
+ }
+
+ // Build the resulting shape for the custom call.
+ std::vector<Shape> shapes{keys->shape()};
+ std::vector<HloInstruction*> operands{keys};
+ if (values != nullptr) {
+ shapes.push_back(values->shape());
+ operands.push_back(values);
+ }
+ shapes.push_back(ShapeUtil::MakeShape(U8, {scratch_size}));
+ Shape call_shape = ShapeUtil::MakeTupleShape(absl::MakeSpan(shapes));
+
+ // Build the custom call instruction.
+ HloInstruction* custom_call =
+ sort_op->parent()->AddInstruction(HloInstruction::CreateCustomCall(
+ call_shape, absl::MakeSpan(operands), kCubDeviceRadixSortTarget));
+
+ xla::SortOptions backend_config;
+ backend_config.set_descending(sort_config.descending);
+ TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
+
+ // Build the replacement instruction.
+ HloInstruction* replacement;
+ if (sort_op->operand_count() == 1) {
+ replacement =
+ sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+ sort_op->shape(), custom_call, 0));
+ } else {
+ replacement = UnpackResultPair(sort_op, custom_call,
+ /*swap=*/sort_config.key_operand == 1);
+ }
+
+ // Replace sort operation with custom call followed by GTE.
+ TF_RETURN_IF_ERROR(
+ sort_op->parent()->ReplaceInstruction(sort_op, replacement));
+ return true;
+}
+
+// Rewrites the sorts in the given computation into calls to CUB.
+absl::StatusOr<bool> SortRewriter::RunOnComputation(
+ HloComputation* computation) {
+ std::vector<HloSortInstruction*> sort_ops;
+ for (auto* inst : computation->instructions()) {
+ HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
+ if (sort != nullptr && IsCubCompatibleSort(sort)) {
+ sort_ops.push_back(sort);
+ }
+ }
+ bool changed = false;
+ for (auto* sort : sort_ops) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(sort));
+ changed |= result;
+ }
+ return changed;
+}
+
+// Replace compatible sort operations with custom calls.
+absl::StatusOr<bool> SortRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_VLOG_LINES(2, "SortRewriter::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ for (HloComputation* computation :
+ module->MakeNonfusionComputations(execution_threads)) {
+ TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+ changed |= result;
+ }
+ XLA_VLOG_LINES(2, "SortRewriter::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h
new file mode 100644
index 0000000..406df7a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h
@@ -0,0 +1,63 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites sort operations into CustomCall HLOs that call into CUB.
+// Only a subset of shapes is supported - either a single tensor with a simple
+// compare function or a pair of tensors where keys are unsigned integers.
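+//
+// Rough shape of the single-tensor rewrite (the scratch buffer size is
+// obtained from CUB; the custom call target is kCubDeviceRadixSortTarget):
+//   %sort = f32[N] sort(%input), dimensions={0}, to_apply=%compare
+// becomes
+//   %call = (f32[N], u8[scratch]) custom-call(%input)
+//   %result = f32[N] get-tuple-element(%call), index=0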
+
+class SortRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "sort-rewriter"; }
+
+ // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite
+ // tensors with sizes below this limit.
+ static int SortSizeThreshold() { return sort_size_threshold_; }
+ static void SetSortSizeThresholdForTestingOnly(int threshold) {
+ // We need to be able to reduce the threshold for testing, so that the tests
+ // can run and compare against the reference interpreter, which is quite
+ // slow.
+ sort_size_threshold_ = threshold;
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ absl::StatusOr<bool> RunOnInstruction(HloSortInstruction* sort_op);
+ absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+
+ static inline int sort_size_threshold_ = 16385;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc
new file mode 100644
index 0000000..e9bf60c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc
@@ -0,0 +1,45 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> SortRewriter::RunOnInstruction(
+ HloSortInstruction* sort_op) {
+ return false;
+}
+
+absl::StatusOr<bool> SortRewriter::RunOnComputation(
+ HloComputation* computation) {
+ return false;
+}
+
+absl::StatusOr<bool> SortRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ return false;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc
new file mode 100644
index 0000000..853de5b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc
@@ -0,0 +1,453 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class SortRewriterTest : public HloTestBase {
+ public:
+ void SetUp() override {
+ HloTestBase::SetUp();
+ SortRewriter::SetSortSizeThresholdForTestingOnly(1000);
+ }
+
+ bool RunModuleAndPass(HloModule* module) {
+ auto cloned = module->Clone();
+ bool changed = SortRewriter().Run(module).value();
+ if (changed) {
+ // Here we run an end-to-end test to make sure that SortRewriter does not
+ // introduce an incorrect rewrite. To do this, we need to clone the original
+ // module because the interpreter cannot process the already optimized
+ // module.
+ EXPECT_TRUE(RunAndCompare(std::move(cloned), ErrorSpec{0, 0}));
+ }
+ return changed;
+ }
+
+ void ExpectDirection(const HloInstruction* instruction, bool descending) {
+ auto config = instruction->backend_config<xla::SortOptions>();
+ EXPECT_EQ(config->descending(), descending);
+ }
+};
+
+// Basic sort: ascending.
+TEST_F(SortRewriterTest, SortKeysLessThan) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input = f32[1000] parameter(0)
+ ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+ ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+ /*descending=*/false);
+}
+
+// Basic sort: descending.
+TEST_F(SortRewriterTest, SortKeysGreaterThan) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
+}
+
+ENTRY %main {
+ %input = f32[1000] parameter(0)
+ ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+ ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+ /*descending=*/true);
+}
+
+// Comparer swaps the parameter order -> direction is reversed.
+TEST_F(SortRewriterTest, SortKeysGreaterThanSwapped) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(1)
+ %rhs = f32[] parameter(0)
+ ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
+}
+
+ENTRY %main {
+ %input = f32[1000] parameter(0)
+ ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+ ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+ /*descending=*/false);
+}
+
+// Sort a pair of tensors, keys go first.
+TEST_F(SortRewriterTest, SortPairs) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs_key = u32[] parameter(0)
+ %rhs_key = u32[] parameter(1)
+ %lhs_value = f32[] parameter(2)
+ %rhs_value = f32[] parameter(3)
+ ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
+}
+
+ENTRY %main {
+ %input_keys = u32[1000] parameter(0)
+ %input_values = f32[1000] parameter(1)
+ ROOT %sort = (u32[1000], f32[1000]) sort(%input_keys, %input_values),
+ dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
+ m::GetTupleElement(m::CustomCall(), 1))));
+}
+
+// Sort a pair of tensors, keys go last.
+TEST_F(SortRewriterTest, SortPairsSwapped) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs_value = f32[] parameter(0)
+ %rhs_value = f32[] parameter(1)
+ %lhs_key = u32[] parameter(2)
+ %rhs_key = u32[] parameter(3)
+ ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
+}
+
+ENTRY %main {
+ %input_values = f32[1000] parameter(0)
+ %input_keys = u32[1000] parameter(1)
+ ROOT %sort = (f32[1000], u32[1000]) sort(%input_values, %input_keys),
+ dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 1),
+ m::GetTupleElement(m::CustomCall(), 0))));
+}
+
+// CUB sort doesn't support more than two tensors.
+TEST_F(SortRewriterTest, NoRewriteManyTensors) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ %unused1 = f64[] parameter(2)
+ %unused2 = f64[] parameter(3)
+ %unused3 = u64[] parameter(4)
+ %unused4 = u64[] parameter(5)
+ ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input1 = f32[1000] parameter(0)
+ %input2 = f64[1000] parameter(1)
+ %input3 = u64[1000] parameter(2)
+ ROOT %sort = (f32[1000], f64[1000], u64[1000]) sort(%input1, %input2, %input3),
+ dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Sorting along a non-minor dimension is not supported.
+TEST_F(SortRewriterTest, NoRewriteNonMinorSortDimension) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input = f32[1000,4] parameter(0)
+ ROOT %sort = f32[1000,4] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Kernels are compiled for a subset of types.
+TEST_F(SortRewriterTest, NoRewriteUnsupportedType) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = pred[] parameter(0)
+ %rhs = pred[] parameter(1)
+ ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input = pred[1000] parameter(0)
+ ROOT %sort = pred[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Comparer must be a simple function.
+TEST_F(SortRewriterTest, NoRewriteComplexComparer) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %lhs_scaled = f32[] multiply(%lhs, f32[] constant(2))
+ %rhs = f32[] parameter(1)
+ ROOT %lt = pred[] compare(%lhs_scaled, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input = f32[1000] parameter(0)
+ ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Comparer must use adjacent input values.
+TEST_F(SortRewriterTest, NoRewriteMixedKeysValues) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs_key = u32[] parameter(0)
+ %rhs_key = u32[] parameter(1)
+ %lhs_value = u32[] parameter(2)
+ %rhs_value = u32[] parameter(3)
+ ROOT %mixed = pred[] compare(%rhs_key, %lhs_value), direction=LT
+}
+
+ENTRY %main {
+ %input_keys = u32[1000] parameter(0)
+ %input_values = u32[1000] parameter(1)
+ ROOT %sort = (u32[1000], u32[1000]) sort(%input_keys, %input_values),
+ dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Small shapes do not see improvement from CUB sort.
+TEST_F(SortRewriterTest, NoRewriteSmallSize) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input = f32[100] parameter(0)
+ ROOT %sort = f32[100] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Basic sort: with batch dimension.
+TEST_F(SortRewriterTest, SortWithBatchDim) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input = f32[10,100] parameter(0)
+ ROOT %sort = f32[10,100] sort(%input), dimensions={1}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+ ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+ /*descending=*/false);
+}
+
+// Basic sort: with multiple batch dimensions.
+TEST_F(SortRewriterTest, SortWithMultipleBatchDims) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+ %input = f32[10,10,10] parameter(0)
+ ROOT %sort = f32[10,10,10] sort(%input), dimensions={2}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::GetTupleElement(
+ m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+ ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+ /*descending=*/false);
+}
+
+// Sort a pair of tensors (values, indices generated by iota) with a complex
+// compare.
+TEST_F(SortRewriterTest, SortPairsIotaComparerSimple) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = u16[] parameter(0)
+ %rhs = u16[] parameter(1)
+ %lhs_index = s32[] parameter(2)
+ %rhs_index = s32[] parameter(3)
+
+ cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
+ cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
+ cmp_eq = pred[] compare(%lhs, %rhs), direction=EQ
+
+ ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
+}
+
+ENTRY %main {
+ %inputs = u16[1000] parameter(0)
+ %iota = s32[1000] iota(), iota_dimension=0
+ ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
+ dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
+ m::GetTupleElement(m::CustomCall(), 1))));
+}
+
+// Sort a pair of tensors (values, indices generated by iota) with a complex
+// compare computation that matches the output of the StableSortExpander pass.
+TEST_F(SortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+ %lhs = u16[] parameter(0)
+ %rhs = u16[] parameter(1)
+ %lhs_index = s32[] parameter(2)
+ %rhs_index = s32[] parameter(3)
+
+ cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
+ cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
+ cmp_rl = pred[] compare(%rhs, %lhs), direction=GT
+ cmp_eq = pred[] compare(cmp_lr, cmp_rl), direction=EQ
+
+ ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
+}
+
+ENTRY %main {
+ %inputs = u16[1000] parameter(0)
+ %iota = s32[1000] iota(), iota_dimension=0
+ ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
+ dimensions={0}, to_apply=%compare
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+ EXPECT_TRUE(RunModuleAndPass(module.get()));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
+ m::GetTupleElement(m::CustomCall(), 1))));
+}
+
+TEST_F(SortRewriterTest, SortSizeThresholdIsSet) {
+ EXPECT_EQ(SortRewriter::SortSizeThreshold(), 1000);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
new file mode 100644
index 0000000..68805b1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
@@ -0,0 +1,225 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/runtime/thunk.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+bool IsOnlyRootNonDefaultStream(HloComputation* computation) {
+ HloInstruction* root = computation->root_instruction();
+ auto root_gpu_config = root->backend_config<GpuBackendConfig>();
+ if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple) {
+ return false;
+ }
+ int64_t root_stream_id = root_gpu_config->operation_queue_id();
+ VLOG(2) << "Found fusion computation's root stream id to be "
+ << root_stream_id;
+ if (root_stream_id == Thunk::kDefaultExecutionStreamId.value()) {
+ return false;
+ }
+ for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
+ if (instr == root) {
+ continue;
+ }
+ int64_t instr_stream_id =
+ instr->backend_config<GpuBackendConfig>()->operation_queue_id();
+ if (instr_stream_id != Thunk::kDefaultExecutionStreamId.value() &&
+ instr_stream_id != root_stream_id) {
+ return false;
+ }
+ }
+ return true;
+}
+
+absl::StatusOr<bool> AnnotateStreamAttributesForInstruction(
+ HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
+ if (instr->called_computations().size() != 1) {
+ return false;
+ }
+ HloComputation* called_comp = instr->called_computations()[0];
+ int64_t stream_id = instr_gpu_config.operation_queue_id();
+
+ if (!IsOnlyRootNonDefaultStream(called_comp) ||
+ stream_id != Thunk::kDefaultExecutionStreamId.value()) {
+ return false;
+ }
+
+ auto comp_root_gpu_config =
+ called_comp->root_instruction()->backend_config<GpuBackendConfig>();
+
+ instr_gpu_config.set_operation_queue_id(
+ comp_root_gpu_config->operation_queue_id());
+ *instr_gpu_config.mutable_wait_on_operation_queues() =
+ comp_root_gpu_config->wait_on_operation_queues();
+ TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
+ return true;
+}
+
+absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
+ HloInstruction* instr, int64_t channel_id,
+ GpuBackendConfig& instr_gpu_config) {
+ // Do nothing if the copy-start has already been annotated.
+ if (instr_gpu_config.operation_queue_id() !=
+ Thunk::kDefaultExecutionStreamId.value()) {
+ return false;
+ }
+ instr_gpu_config.set_operation_queue_id(channel_id);
+ TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
+ VLOG(3) << "Add copy-start's backend config: " << channel_id;
+ return true;
+}
+
+absl::StatusOr<bool> WrapIntoFusionAndAnnotateStreamAttributes(
+ HloInstruction* instruction, int64_t channel_id,
+ GpuBackendConfig& instr_gpu_config) {
+ auto* computation = instruction->parent();
+ auto* module = computation->parent();
+ auto* fusion_instruction =
+ computation->AddInstruction(HloInstruction::CreateFusion(
+ instruction->shape(), ChooseFusionKind(*instruction, *instruction),
+ instruction));
+ const absl::string_view wrapped_opcode =
+ HloOpcodeString(instruction->opcode());
+ module->SetAndUniquifyInstrName(fusion_instruction,
+ absl::StrCat("wrapped_", wrapped_opcode));
+ module->SetAndUniquifyComputationName(
+ fusion_instruction->fused_instructions_computation(),
+ absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
+ if (module->has_schedule()) {
+ // Update the scheduling names of the fusion and its root instruction
+ // to match their newly assigned instruction names during creation.
+ fusion_instruction->set_metadata_scheduling_name(
+ fusion_instruction->name());
+ HloInstruction* root = fusion_instruction->fused_expression_root();
+ root->set_metadata_scheduling_name(root->name());
+ module->schedule().replace_instruction(computation, instruction,
+ fusion_instruction);
+ }
+ TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
+ TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+
+ instr_gpu_config.set_operation_queue_id(channel_id);
+ TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config));
+ VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction "
+ << instruction->ToString();
+ VLOG(3) << " Fusion wrapper: " << fusion_instruction->ToString();
+ return true;
+}
+
+absl::StatusOr<bool> AnnotateStreamAttributesForUsers(
+ HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
+ bool changed = false;
+ int64_t stream_id = instr_gpu_config.operation_queue_id();
+ if (stream_id == Thunk::kDefaultExecutionStreamId.value()) {
+ return changed;
+ }
+ std::vector<HloInstruction*> all_consumers;
+ for (auto user : instr->users()) {
+ if (user->opcode() == HloOpcode::kGetTupleElement) {
+ user = user->users()[0];
+ }
+ all_consumers.push_back(user);
+ }
+
+ for (auto user : all_consumers) {
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ user->backend_config<GpuBackendConfig>());
+ auto it = absl::c_find(gpu_config.wait_on_operation_queues(), stream_id);
+ if (it == gpu_config.wait_on_operation_queues().end() &&
+ gpu_config.operation_queue_id() != stream_id) {
+ gpu_config.mutable_wait_on_operation_queues()->Add(stream_id);
+ TF_RETURN_IF_ERROR(user->set_backend_config(gpu_config));
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+} // namespace
+
+absl::StatusOr<bool> StreamAttributeAnnotator::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_VLOG_LINES(
+ 5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ int64_t channel_id = hlo_query::NextChannelId(*module);
+ for (const HloComputation* comp :
+ module->MakeComputationPostOrder(execution_threads)) {
+ for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+ auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
+ if (!instr_gpu_config.ok()) {
+ continue;
+ }
+ // For a fusion instruction, only annotate when the fusion root is a
+ // single instruction running on a non-default stream.
+ if (instr->opcode() == HloOpcode::kFusion) {
+ TF_ASSIGN_OR_RETURN(bool comp_result,
+ AnnotateStreamAttributesForInstruction(
+ instr, instr_gpu_config.value()));
+ changed |= comp_result;
+ } else if (instr->opcode() == HloOpcode::kCopyStart) {
+ TF_ASSIGN_OR_RETURN(bool comp_result,
+ AnnotateStreamAttributesForCopyStart(
+ instr, channel_id, instr_gpu_config.value()));
+ changed |= comp_result;
+ continue;
+ } else if (comp->IsAsyncComputation() &&
+ (instr->opcode() == HloOpcode::kDynamicSlice ||
+ instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
+ TF_ASSIGN_OR_RETURN(bool comp_result,
+ WrapIntoFusionAndAnnotateStreamAttributes(
+ instr, channel_id, instr_gpu_config.value()));
+ changed |= comp_result;
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ bool user_result,
+ AnnotateStreamAttributesForUsers(instr, instr_gpu_config.value()));
+ changed |= user_result;
+ }
+ }
+ XLA_VLOG_LINES(
+ 5, "StreamAttributeAnnotator::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h
new file mode 100644
index 0000000..81816f8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h
@@ -0,0 +1,60 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass checks whether:
+// 1. any instruction that consumes data from other compute streams is
+//    missing the "wait_on_operation_queues" attribute;
+// 2. any fusion instruction has a root running on a non-default stream.
+// It annotates the corresponding instruction with the correct attribute in
+// GpuBackendConfig. Instructions annotated with operation_queue_id > 0 are
+// wrapped into an AsyncInstruction and split into AsyncStart and AsyncDone
+// in the StreamAttributeAsyncWrapper pass.
+// The pass also checks whether any user of a non-default-stream instruction
+// is missing the correct "wait_on_operation_queues" attribute and, if so,
+// sets it to the producer's operation_queue_id.
+// "wait_on_operation_queues" is later used by the emitter to emit the
+// correct WaitForStreams thunk.
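+//
+// Illustrative example (a sketch based on the unit tests): given
+//   %add = f32[1] add(%p0, %p1),
+//          backend_config={"operation_queue_id":"1",
+//                          "wait_on_operation_queues":[]}
+//   %exp = f32[1] exponential(%add)
+// the pass adds 1 to %exp's "wait_on_operation_queues" list.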
+
+class StreamAttributeAnnotator : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "stream-attribute-annotator";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc
new file mode 100644
index 0000000..c7d2ca5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc
@@ -0,0 +1,340 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using StreamAttributeAnnotatorTest = HloTestBase;
+
+TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithAsync
+
+ ENTRY entry {
+ p1_32 = f32[1] parameter(0)
+ p2_32 = f32[1] parameter(1)
+ add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
+ exp_32 = f32[1] exponential(add_32)
+
+ neg32 = f32[1] negate(add_32)
+ ROOT add_out_32 = f32[1] add(neg32, exp_32)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ StreamAttributeAnnotator attr_annotator;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ const HloInstruction* add = FindInstruction(module.get(), "add_32");
+ for (auto user : add->users()) {
+ // Every user should have an annotation.
+ EXPECT_TRUE(user->has_backend_config());
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ user->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
+ }
+}
+
+TEST_F(StreamAttributeAnnotatorTest, MultipleStreamsAreCombined) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithAsync
+
+ ENTRY entry {
+ p1_32 = f32[1] parameter(0)
+ p2_32 = f32[1] parameter(1)
+ add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
+ exp_32 = f32[1] exponential(p2_32), backend_config={"operation_queue_id":"2", "wait_on_operation_queues":[]}
+
+ ROOT add_out_32 = f32[1] add(add_32, exp_32)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ StreamAttributeAnnotator attr_annotator;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ // Root should wait on 2 streams.
+ EXPECT_TRUE(root->has_backend_config());
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ root->backend_config<GpuBackendConfig>());
+ std::vector<int64_t> expected_stream_ids = {1, 2};
+ for (auto id : expected_stream_ids) {
+ auto it = absl::c_find(gpu_config.wait_on_operation_queues(), id);
+ EXPECT_NE(it, gpu_config.wait_on_operation_queues().end());
+ }
+}
+
+TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithAsync
+
+ ENTRY entry {
+ p1_32 = f32[16,32] parameter(0)
+ p2_32 = f32[32,16] parameter(1)
+
+ custom-call.3 = (f32[16,16], s8[1028]{0}) custom-call(p1_32, p2_32), custom_call_target="__cublas$gemm", backend_config={"operation_queue_id":"1","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","grad_x":false,"grad_y":false}}
+ get-tuple-element.24 = f32[16,16] get-tuple-element(custom-call.3), index=0
+
+ exp_32 = f32[16,16] exponential(get-tuple-element.24)
+
+ ROOT neg32 = f32[16,16] negate(exp_32)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ StreamAttributeAnnotator attr_annotator;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ const HloInstruction* exp = FindInstruction(module.get(), "exp_32");
+ EXPECT_TRUE(exp->has_backend_config());
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ exp->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
+}
+
+TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithFusion
+
+ fused_computation.1 {
+ fusion_p0_32 = f32[16,16] parameter(0)
+ fusion_p2_32 = f32[16,16] parameter(1)
+ ROOT add = f32[16,16] add(fusion_p0_32, fusion_p2_32), backend_config={"operation_queue_id":"1","wait_on_operation_queues":[]}
+ }
+
+ ENTRY entry {
+ p1_32 = f32[16,16] parameter(0)
+ p2_32 = f32[16,16] parameter(1)
+ ROOT fusion.1 = f32[16,16] fusion(p1_32, p2_32), kind=kLoop, calls=fused_computation.1
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ StreamAttributeAnnotator attr_annotator;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ const HloInstruction* fusion = FindInstruction(module.get(), "fusion.1");
+ EXPECT_TRUE(fusion->has_backend_config());
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ fusion->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+}
+
+TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule offloading
+ ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] {
+ %param_1 = f32[1024]{0} parameter(1)
+ %param_0 = f32[1024]{0} parameter(0)
+ %res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1)
+ %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3)
+ %res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3)
+ %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4)
+ %res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4)
+ %copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start)
+ %res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5)
+ %copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2)
+ %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2)
+ %res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6)
+ %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done)
+ %res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5)
+ %copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3)
+ %res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3)
+ %copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1)
+ %res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1)
+ ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ StreamAttributeAnnotator attr_annotator;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ for (std::string i : {"", ".1", ".2", ".3"}) {
+ const HloInstruction* cp_start =
+ FindInstruction(module.get(), "copy-start" + i);
+ EXPECT_TRUE(cp_start->has_backend_config());
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ cp_start->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+ }
+}
+
+TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithAsyncDynamicUpdateSlice, is_scheduled=true
+
+ ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] {
+ param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0), metadata={scheduling_name="param_0"}
+ param_1 = f32[1,128,128]{2,1,0} parameter(1), metadata={scheduling_name="param_1"}
+ izero = s32[] constant(0), metadata={scheduling_name="izero"}
+ dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[])
+ dynamic-update-slice-start(param_0, param_1, izero, izero, izero),
+ metadata={scheduling_name="dynamic-update-slice-start.2"}
+ ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)}
+ dynamic-update-slice-done(dynamic-update-slice-start.2),
+ metadata={scheduling_name="dynamic-update-slice-done.2"}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+ EXPECT_TRUE(module->has_schedule());
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ StreamAttributeAnnotator().Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ // Check that the dynamic-update-slice instruction is wrapped in a fusion
+ // and the fusion is annotated with the correct operation_queue_id.
+ const HloInstruction* dus =
+ FindInstruction(module.get(), HloOpcode::kDynamicUpdateSlice);
+ const HloComputation* computation = dus->parent();
+ EXPECT_TRUE(computation->IsFusionComputation());
+ const HloInstruction* fusion = computation->FusionInstruction();
+ EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
+
+ EXPECT_TRUE(fusion->has_backend_config());
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ fusion->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+ // Check that the scheduling name is the same as the instruction name.
+ for (const auto* comp : module->computations()) {
+ for (const auto* instruction : comp->instructions()) {
+ if (!instruction->metadata().scheduling_name().empty()) {
+ EXPECT_EQ(instruction->name(),
+ instruction->metadata().scheduling_name());
+ }
+ }
+ }
+ constexpr absl::string_view kExpectedSchedulingName = R"(
+// CHECK: %wrapped_dynamic-update-slice_computation
+// CHECK: ROOT %[[DYNAMIC_UPDATE_SLICE:.+]] = f32[256,128,128]{2,1,0:S(5)} dynamic-update-slice(
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_UPDATE_SLICE]]"}
+// CHECK: %[[DYNAMIC_UPDATE_SLICE_START:.+]] = {{.*}} fusion-start(
+// CHECK-SAME: calls=%wrapped_dynamic-update-slice_computation
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_UPDATE_SLICE_START]]"}
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(
+ module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+ kExpectedSchedulingName));
+ EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithAsyncDynamicSlice, is_scheduled=true
+
+ ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] {
+ param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0), metadata={scheduling_name="param_0"}
+ izero = s32[] constant(0), metadata={scheduling_name="izero"}
+ dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[])
+ dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128},
+ metadata={scheduling_name="dynamic-slice-start.2"}
+ ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0}
+ dynamic-slice-done(dynamic-slice-start.2),
+ metadata={scheduling_name="dynamic-slice-done.2"}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ EXPECT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ StreamAttributeAnnotator().Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ // Check that the dynamic-slice instruction is wrapped in a fusion
+ // and the fusion is annotated with the correct operation_queue_id.
+ const HloInstruction* ds =
+ FindInstruction(module.get(), HloOpcode::kDynamicSlice);
+ const HloComputation* computation = ds->parent();
+ EXPECT_TRUE(computation->IsFusionComputation());
+ const HloInstruction* fusion = computation->FusionInstruction();
+ EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
+
+ EXPECT_TRUE(fusion->has_backend_config());
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ fusion->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+ // Check that the scheduling name is the same as the instruction name.
+ for (const auto* comp : module->computations()) {
+ for (const auto* instruction : comp->instructions()) {
+ if (!instruction->metadata().scheduling_name().empty()) {
+ EXPECT_EQ(instruction->name(),
+ instruction->metadata().scheduling_name());
+ }
+ }
+ }
+ constexpr absl::string_view kExpectedSchedulingName = R"(
+// CHECK: %wrapped_dynamic-slice_computation
+// CHECK: ROOT %[[DYNAMIC_SLICE:.+]] = f32[1,128,128]{2,1,0} dynamic-slice(
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_SLICE]]"}
+// CHECK: %[[DYNAMIC_SLICE_START:.+]] = {{.*}} fusion-start(
+// CHECK-SAME: calls=%wrapped_dynamic-slice_computation
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_SLICE_START]]"}
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_matches,
+ RunFileCheck(
+ module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+ kExpectedSchedulingName));
+ EXPECT_TRUE(filecheck_matches);
+}
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc
new file mode 100644
index 0000000..be0eb6f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc
@@ -0,0 +1,74 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/runtime/thunk.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+namespace {
+static absl::StatusOr<bool> AsynchronizeInstruction(HloInstruction* instr) {
+ auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
+ if (!instr_gpu_config.ok() || instr_gpu_config->operation_queue_id() ==
+ Thunk::kDefaultExecutionStreamId.value()) {
+ return false;
+ }
+ HloComputation* computation = instr->parent();
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * done,
+ computation->CreateAsyncInstructions(
+ instr, {}, StreamAttributeAsyncWrapper::kParallelExecutionThread,
+ /*replace=*/true));
+ TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+ done->backend_config<GpuBackendConfig>());
+ // Set force_earliest_schedule of the done op to false so it can be
+ // scheduled far apart from the start op.
+ gpu_config.set_force_earliest_schedule(false);
+ TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config));
+ VLOG(5) << "Created async instruction: " << done->ToString();
+ return true;
+}
+} // namespace
+
+absl::StatusOr<bool> StreamAttributeAsyncWrapper::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_VLOG_LINES(
+ 2, "StreamAttributeAsyncWrapper::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ for (const HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* instr : comp->instructions()) {
+ TF_ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr));
+ changed |= result;
+ }
+ }
+ XLA_VLOG_LINES(
+ 2, "StreamAttributeAsyncWrapper::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h
new file mode 100644
index 0000000..157b579
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h
@@ -0,0 +1,49 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass finds the instructions that the StreamAttributeAnnotator pass
+// annotated with a non-default stream id in their backend configs and wraps
+// them in AsyncStart/AsyncDone pairs to achieve asynchronous execution.
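+//
+// Illustrative sketch (hypothetical HLO mirroring the unit test, not an
+// exact dump): an instruction such as
+//   add = f32[1] add(x, y), backend_config={"operation_queue_id":"1"}
+// is wrapped into an async start/done pair on the "parallel" execution
+// thread, with the wrapped op keeping operation_queue_id = 1.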
+class StreamAttributeAsyncWrapper : public HloModulePass {
+ public:
+ inline static constexpr char kParallelExecutionThread[] = "parallel";
+
+ absl::string_view name() const override {
+ return "async-stream-attribute-wrapper";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc
new file mode 100644
index 0000000..32ed4c5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc
@@ -0,0 +1,77 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using StreamAttributeAsyncWrapperTest = HloTestBase;
+
+TEST_F(StreamAttributeAsyncWrapperTest, NonDefaultOpIsWrapped) {
+ constexpr absl::string_view kHloString = R"(
+ HloModule ModuleWithAsync
+
+ ENTRY entry {
+ p1_32 = f32[1] parameter(0)
+ p2_32 = f32[1] parameter(1)
+ add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[], "force_earliest_schedule":true}
+ ROOT exp_32 = f32[1] exponential(add_32), backend_config={"operation_queue_id":"0", "wait_on_operation_queues":[1]}
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ StreamAttributeAsyncWrapper async_wrapper;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, async_wrapper.Run(module.get()));
+ EXPECT_TRUE(changed);
+ const HloInstruction* producer =
+ module->entry_computation()->root_instruction()->operand(0);
+ EXPECT_EQ(producer->opcode(), HloOpcode::kAsyncDone);
+ // Verify that the force_earliest_schedule is set to false for the done op.
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig done_gpu_config,
+ producer->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(done_gpu_config.force_earliest_schedule(), false);
+
+ const HloInstruction* producer_start = producer->operand(0);
+ EXPECT_EQ(producer_start->opcode(), HloOpcode::kAsyncStart);
+
+ const xla::HloAsyncInstruction* async =
+ Cast<HloAsyncInstruction>(producer_start);
+ EXPECT_EQ(async->async_wrapped_opcode(), HloOpcode::kAdd);
+ // Verify that the backend config is kept intact
+ TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+ async->backend_config<GpuBackendConfig>());
+ EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+ EXPECT_EQ(gpu_config.force_earliest_schedule(), true);
+ EXPECT_EQ(async->async_execution_thread(), "parallel");
+}
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc
new file mode 100644
index 0000000..1cc6206
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc
@@ -0,0 +1,113 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_specializer.h"
+
+#include <stddef.h>
+
+#include <initializer_list>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/primitive_util.h"
+#include "xla/service/hlo.pb.h"
+#include "xla/service/tuple_util.h"
+#include "xla/shape.h"
+#include "xla/status_macros.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+absl::StatusOr<HloInstruction*> SmallBufferOptimization(
+ HloCustomCallInstruction* topk) {
+ Shape data_shape = topk->operand(0)->shape();
+ auto supported_dtypes = {F32, BF16};
+ if (!absl::c_linear_search(supported_dtypes, data_shape.element_type())) {
+ return InvalidArgument(
+ "Invalid Dtype: %s",
+ primitive_util::LowercasePrimitiveTypeName(data_shape.element_type()));
+ }
+ // We only support topk of the shape [x] or [batch, x].
+ if (data_shape.dimensions_size() > 2) {
+ return InvalidArgument("Invalid input dimensions: %s",
+ data_shape.ToString());
+ }
+ bool has_batch = data_shape.dimensions_size() == 2;
+ constexpr size_t max_k = 16;
+ constexpr size_t min_n = 1024;
+ size_t n = data_shape.dimensions(has_batch ? 1 : 0);
+ size_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
+ if (k > max_k) {
+ return InvalidArgument("k too large (%d), must be <= %d", k, max_k);
+ }
+ if (n < min_n) {
+ return InvalidArgument("Input too small (n=%d, min_n=%d)", n, min_n);
+ }
+ HloComputation* comp = topk->parent();
+ HloInstruction* new_topk =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ topk->shape(), topk->operands(),
+ // We don't need the original to_apply, but keeping it around allows
+ // us to round-trip this CustomCall in tests.
+ topk->to_apply(), "__gpu$TopK",
+ /*opaque=*/"", CustomCallApiVersion::API_VERSION_TYPED_FFI));
+ return TupleUtil::ExtractPrefix(new_topk, 2);
+}
+
+class SpecializeTopkVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleCustomCall(HloInstruction* inst) override {
+ HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
+ if (topk == nullptr || topk->custom_call_target() != "TopK") {
+ return absl::OkStatus();
+ }
+ TF_RET_CHECK(topk->operand_count() == 1);
+
+ if (auto small_topk = SmallBufferOptimization(topk); small_topk.ok()) {
+ return ReplaceInstruction(topk, *small_topk);
+ } else {
+ VLOG(2) << "Small TopK optimization doesn't match: "
+ << small_topk.status();
+ }
+
+ return absl::OkStatus();
+ }
+};
+
+} // namespace
+
+absl::StatusOr<bool> TopkSpecializer::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ return SpecializeTopkVisitor().RunOnModule(module, execution_threads);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer.h b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h
new file mode 100644
index 0000000..e3ec565
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h
@@ -0,0 +1,41 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass transforms an eligible TopK CustomCall into a call executed by
+// runtime/topk.cc.
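+//
+// Illustrative sketch (hypothetical HLO, based on SmallBufferOptimization in
+// the .cc file): a call such as
+//   (f32[1,16], s32[1,16]) custom-call(f32[1,4096] %in),
+//       custom_call_target="TopK", to_apply=%compare
+// is retargeted to "__gpu$TopK" with the typed-FFI API version when the
+// dtype is F32/BF16, k <= 16 and n >= 1024.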
+class TopkSpecializer : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "topk-specializer"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc
new file mode 100644
index 0000000..96d7e49
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_specializer.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/topk_rewriter.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+using ::testing::Combine;
+using ::testing::Values;
+
+// Params:
+// - n_kb: number of elements in kilobytes.
+// - k: number of elements to return.
+// - batch_size
+// - dtype
+using ParameterizedInterface =
+ ::testing::WithParamInterface<std::tuple<int, int, int, std::string_view>>;
+
+class TopkTest : public HloTestBase, public ParameterizedInterface {
+ public:
+ TopkTest()
+ : HloTestBase(*PlatformUtil::GetPlatform("gpu"),
+ *PlatformUtil::GetPlatform("gpu"), true, true, {}) {}
+
+ protected:
+ absl::StatusOr<std::unique_ptr<HloModule>> TopkHlo(int n, int k,
+ int batch_size,
+ std::string_view dtype) {
+ return ParseAndReturnVerifiedModule(absl::Substitute(
+ R"(
+ %compare {
+ %p.1.lhs.40628 = s32[] parameter(2)
+ %p.1.rhs.40629 = s32[] parameter(3)
+ %constant.40630 = pred[] constant(true)
+ %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
+ %p.0.lhs.40626 = f32[] parameter(0)
+ %p.0.rhs.40627 = f32[] parameter(1)
+ %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
+ ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
+ }
+
+ ENTRY top_k {
+ %arg = $3[$2,$0] parameter(0)
+ ROOT %result = ($3[$2,$1], s32[$2,$1]) custom-call(%arg), custom_call_target="TopK", to_apply=%compare
+ }
+ )",
+ n, k, batch_size, dtype));
+ }
+};
+
+class GeneralizeTopkVisitor : public DfsHloRewriteVisitor {
+ public:
+ absl::Status HandleCustomCall(HloInstruction* inst) override {
+ HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
+ if (topk == nullptr || topk->custom_call_target() != "__gpu$TopK") {
+ return absl::OkStatus();
+ }
+ HloComputation* comp = topk->parent();
+ auto original_shape = ShapeUtil::SliceTuple(topk->shape(), 0, 2);
+ HloInstruction* original_topk =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ original_shape, topk->operands(), topk->to_apply(), "TopK"));
+ // TupleUtil::ExtractPrefix creates the following structure:
+ // TopK
+ // -------------
+ // | | |
+ // Get Get Get
+ // \ | /
+ // CreateTuple
+ // Here we walk to the CreateTuple and replace it with the original topk.
+ HloInstruction* new_tuple = topk->users()[0]->users()[0];
+ return ReplaceInstruction(new_tuple, original_topk);
+ }
+};
+
+class GeneralizeTopk : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "generalized-topk"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(HloModule* module,
+ const absl::flat_hash_set<absl::string_view>&
+ execution_threads) override {
+ return GeneralizeTopkVisitor().RunOnModule(module, execution_threads);
+ }
+};
+
+void ToSortAndSlice(HloModule* module) {
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, GeneralizeTopk().Run(module));
+ ASSERT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(changed, TopkDecomposer().Run(module));
+ ASSERT_TRUE(changed);
+}
+
+TEST_P(TopkTest, ProducesCorrectResult) {
+ const auto [n_kb, k, batch_size, dtype] = GetParam();
+ const size_t n = n_kb * 1024;
+ TF_ASSERT_OK_AND_ASSIGN(auto topk_module, TopkHlo(n, k, batch_size, dtype));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ gpu::TopkSpecializer().Run(topk_module.get()));
+ ASSERT_TRUE(changed);
+ EXPECT_TRUE(
+ RunAndCompare(std::move(topk_module), std::nullopt, ToSortAndSlice));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ TopkTests, TopkTest,
+ Combine(
+ /*n_kb=*/Values(1, 8, 12, 32),
+ /*k=*/Values(1, 2, 4, 8, 16, 7, 12),
+ /*batch_size=*/Values(1, 16, 32, 64, 128),
+ /*dtype=*/Values(absl::string_view("f32"), "bf16")),
+ [](const auto& info) {
+ return absl::Substitute("n$0KiB_k$1_batch_size$2_$3",
+ std::get<0>(info.param), std::get<1>(info.param),
+ std::get<2>(info.param), std::get<3>(info.param));
+ });
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc
new file mode 100644
index 0000000..41ba135
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc
@@ -0,0 +1,154 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_splitter.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/numeric/bits.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr size_t kRequiredAlignment = 1024;
+constexpr size_t kMaximumBatchSize = 1024;
+
+class TopkSplitterVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit TopkSplitterVisitor(size_t split_threshold)
+ : split_threshold_(split_threshold) {}
+
+ absl::Status HandleCustomCall(HloInstruction* inst) override {
+ HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
+ if (topk == nullptr || topk->custom_call_target() != "TopK") {
+ return absl::OkStatus();
+ }
+ HloComputation* comp = inst->parent();
+ Shape data_shape = topk->operand(0)->shape();
+ bool has_batch = data_shape.dimensions_size() == 2;
+ // TODO(doak): Support multiple batches.
+ if (has_batch && data_shape.dimensions(0) != 1) {
+ return absl::OkStatus();
+ }
+ size_t n = data_shape.dimensions(has_batch ? 1 : 0);
+ int64_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
+ // If K approaches N, splitting the input will not be beneficial anymore.
+ if (k > sqrt(n)) {
+ return absl::OkStatus();
+ }
+ // TODO(doak): Relax this alignment requirement.
+ if (n % kRequiredAlignment != 0) {
+ return absl::OkStatus();
+ }
+ if (n < split_threshold_) return absl::OkStatus();
+ int new_batch =
+ std::min(absl::bit_floor(n / split_threshold_), kMaximumBatchSize);
+ int new_n = n / new_batch;
+ // Split the input into new_batch batches and compute TopK over each batch.
+ Shape split_input_shape =
+ ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, new_n});
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * reshaped,
+ MakeReshapeHlo(split_input_shape, topk->mutable_operand(0)));
+ Shape batch_topk_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, k}),
+ ShapeUtil::MakeShape(S32, {new_batch, k})});
+ HloInstruction* batch_topk =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ batch_topk_shape, {reshaped}, topk->to_apply(), "TopK",
+ /*opaque=*/""));
+ // Fix indices, adding j*new_n to the j-th batch of indices.
+ TF_ASSIGN_OR_RETURN(HloInstruction * indices,
+ MakeGetTupleElementHlo(batch_topk, 1));
+ TF_ASSIGN_OR_RETURN(HloInstruction * values,
+ MakeGetTupleElementHlo(batch_topk, 0));
+ Shape iota_shape = ShapeUtil::MakeShape(S32, {new_batch});
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * fix,
+ MakeBinaryHlo(
+ HloOpcode::kMultiply, MakeIotaHlo(comp, iota_shape, 0),
+ MakeBroadcastHlo(MakeR0ConstantHlo<int32_t>(comp, new_n),
+ /*broadcast_dimensions=*/{}, iota_shape)));
+ TF_ASSIGN_OR_RETURN(
+ indices, MakeBinaryHlo(HloOpcode::kAdd, indices,
+ MakeBroadcastHlo(fix, {0}, indices->shape())));
+ // With the indices restored, compute a final top-k. Since this topk uses
+ // arbitrary indices, we need to use sort+slice.
+ Shape linear_index_shape = ShapeUtil::MakeShape(S32, {k * new_batch});
+ Shape linear_shape = ShapeUtil::ChangeElementType(
+ linear_index_shape, data_shape.element_type());
+ Shape linear_sort_shape =
+ ShapeUtil::MakeTupleShape({linear_shape, linear_index_shape});
+ // Assuming the outputs of the TopK above are stably sorted, using a stable
+ // sort here is enough to guarantee global stable sorting:
+ // - Within a block, elements are stably sorted by TopK.
+ // - Since blocks are organized linearly from smallest to largest, the
+ // index used on the stable sort below will also respect block ordering.
+ HloInstruction* aggregated_sort =
+ comp->AddInstruction(HloInstruction::CreateSort(
+ linear_sort_shape, 0,
+ {*MakeReshapeHlo(linear_shape, values),
+ *MakeReshapeHlo(linear_index_shape, indices)},
+ topk->to_apply(), /*is_stable=*/true));
+ auto slice_tuple = [&](HloInstruction* sort, const size_t index) {
+ return *MakeReshapeHlo(
+ topk->shape().tuple_shapes(index),
+ *MakeSliceHlo(*MakeGetTupleElementHlo(sort, index), {0}, {k}, {1}));
+ };
+ return ReplaceInstruction(topk,
+ comp->AddInstruction(HloInstruction::CreateTuple({
+ slice_tuple(aggregated_sort, 0),
+ slice_tuple(aggregated_sort, 1),
+ })));
+ }
+
+ private:
+ size_t split_threshold_;
+};
+
+} // namespace
+
+absl::StatusOr<bool> TopKSplitter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ return TopkSplitterVisitor(split_threshold_)
+ .RunOnModule(module, execution_threads);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter.h b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h
new file mode 100644
index 0000000..c6fe429
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h
@@ -0,0 +1,52 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_
+
+#include <cstddef>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Splits a large TopK into batches of smaller TopKs, followed by sorting and
+// slicing the results of those smaller TopKs. We consider a TopK to be
+// 'large' when the last dimension of its input is larger than
+// `split_threshold`.
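+//
+// Illustrative sketch (hypothetical shapes mirroring the unit tests): a TopK
+// with k=5 over f32[1,1073741824] is rewritten into a TopK over a reshaped
+// f32[batch, n/batch] input, followed by an index fix-up (adding
+// j*(n/batch) to batch j), a stable sort, and a slice back to k=5.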
+class TopKSplitter : public HloModulePass {
+ public:
+ explicit TopKSplitter(size_t split_threshold = 1024 * 1024)
+ : split_threshold_(split_threshold) {}
+ absl::string_view name() const override { return "topk-splitter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ const size_t split_threshold_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc
new file mode 100644
index 0000000..8236c26
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc
@@ -0,0 +1,210 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_splitter.h"
+
+#include <stdint.h>
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_dce.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/topk_rewriter.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace m = ::xla::match;
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+using TopkSplitterTest = HloTestBase;
+
+constexpr absl::string_view kComparator = R"(
+ %compare {
+ %p.1.lhs.40628 = s32[] parameter(2)
+ %p.1.rhs.40629 = s32[] parameter(3)
+ %constant.40630 = pred[] constant(true)
+ %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
+ %p.0.lhs.40626 = f32[] parameter(0)
+ %p.0.rhs.40627 = f32[] parameter(1)
+ %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
+ ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
+ })";
+
+TEST_F(TopkSplitterTest, SplitsTopK) {
+ const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+ %arg.1 = f32[1,1073741824] parameter(0)
+ ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+ kComparator);
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
+ auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
+ auto slice_result = [&](auto input, size_t i) {
+ return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
+ };
+ auto index_correction =
+ m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
+ auto sorted = m::Sort(
+ m::Reshape(m::GetTupleElement(first_topk, 0)),
+ m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
+ EXPECT_TRUE(
+ Match(module->entry_computation()->root_instruction(),
+ m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
+}
+
+TEST_F(TopkSplitterTest, SplitsTopKNoBatchDimension) {
+ const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+ %arg.1 = f32[1073741824] parameter(0)
+ ROOT %cc.2 = (f32[5], s32[5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+ kComparator);
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
+ auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
+ auto slice_result = [&](auto input, size_t i) {
+ return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
+ };
+ auto index_correction =
+ m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
+ auto sorted = m::Sort(
+ m::Reshape(m::GetTupleElement(first_topk, 0)),
+ m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
+ EXPECT_TRUE(
+ Match(module->entry_computation()->root_instruction(),
+ m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
+}
+
+TEST_F(TopkSplitterTest, SplitFailsUnderThreshold) {
+ const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+ %arg.1 = f32[1,524288] parameter(0)
+ ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+ kComparator);
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_THAT(
+ RunHloPass(TopKSplitter(/*split_threshold=*/1048576), module.get()),
+ IsOkAndHolds(false));
+}
+
+TEST_F(TopkSplitterTest, SplitFailsUnaligned) {
+ const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+ %arg.1 = f32[1,524289] parameter(0)
+ ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+ kComparator);
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
+ IsOkAndHolds(false));
+}
+
+TEST_F(TopkSplitterTest, SplitFailsLargeK) {
+ const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+ %arg.1 = f32[1,524288] parameter(0)
+ ROOT %cc.2 = (f32[1,1024], s32[1,1024]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+ kComparator);
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
+ IsOkAndHolds(false));
+}
+
+TEST_F(TopkSplitterTest, Equivalent) {
+ const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+ %arg.1 = f32[1,16384] parameter(0)
+ ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+ kComparator);
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
+ auto round_trip = [](HloModule* module) {
+ EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
+ return true;
+ }).Run(module),
+ IsOkAndHolds(true));
+ EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
+ EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
+ EXPECT_TRUE(HloDCE().Run(module).status().ok());
+ };
+ EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
+}
+
+TEST_F(TopkSplitterTest, StableSorts) {
+ const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+ %constant.1 = f32[] constant(42)
+ %broadcast.2= f32[1,16384] broadcast(f32[] %constant.1), dimensions={}
+ ROOT %cc.3 = (f32[1,5], s32[1,5]) custom-call(%broadcast.2), custom_call_target= "TopK", to_apply=%compare
+})",
+ kComparator);
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
+ auto round_trip = [](HloModule* module) {
+ EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
+ return true;
+ }).Run(module),
+ IsOkAndHolds(true));
+ EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
+ EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
+ EXPECT_TRUE(HloDCE().Run(module).status().ok());
+ };
+ EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc
new file mode 100644
index 0000000..fb023fc
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc
@@ -0,0 +1,389 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/numeric/bits.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+absl::InlinedVector<int64_t, 2> GetSortedReducedDims(
+ HloReduceInstruction *reduce) {
+ absl::InlinedVector<int64_t, 2> reduced_dims{reduce->dimensions().begin(),
+ reduce->dimensions().end()};
+ absl::c_sort(reduced_dims);
+ return reduced_dims;
+}
+
+bool IsMinMaxReduction(HloReduceInstruction *reduce) {
+ HloComputation *called = &reduce->to_apply()[0];
+ if (auto reduction_kind = MatchReductionComputation(called)) {
+ return reduction_kind == ReductionKind::MAX ||
+ reduction_kind == ReductionKind::MIN;
+ }
+ return false;
+}
+
+} // namespace
+
+class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit ReductionRewriterVisitor(se::GpuComputeCapability gpu_version)
+ : gpu_version_(gpu_version) {}
+
+ absl::Status HandleReduce(HloInstruction *hlo) override {
+ auto *reduce = Cast<HloReduceInstruction>(hlo);
+ VLOG(3) << "Reduction instruction: " << reduce->ToString();
+
+ const HloModuleConfig &config = reduce->GetModule()->config();
+ if (!MatchReductionForSplit(reduce, config)) {
+ return absl::OkStatus();
+ }
+ ReductionDimensions reduction_dims =
+ GetReductionKindAndContiguousComponents(*hlo);
+ if (ReductionIsRaceFree(config, reduction_dims)) {
+ VLOG(3) << "Base case: dimensions fit";
+ return absl::OkStatus();
+ }
+ auto sorted_dims_to_reduce = GetSortedReducedDims(reduce);
+ CHECK_LE(sorted_dims_to_reduce.size(), 2);
+
+ // If the major reduced dimension does not fit, reduce the minor dimension
+ // first, then the major.
+ if (reduction_dims.is_row_reduction &&
+ reduction_dims
+ .dimensions[ReductionDimensions::kRowMajorReducedDimension] >
+ BatchedReductionRaceFreeBound()) {
+ VLOG(2) << "Splitting batched dimension reduce into a separate reduction";
+ return RewriteBatchDimensionLargerThanTile(reduce, reduction_dims,
+ sorted_dims_to_reduce);
+ }
+ SplitParams split_params =
+ ComputeSplitParams(reduce, reduction_dims, sorted_dims_to_reduce);
+ return SplitReductionDimension(reduce, split_params, sorted_dims_to_reduce);
+ }
+
+ private:
+ bool MatchReductionForSplit(HloReduceInstruction *reduce,
+ const HloModuleConfig &config) {
+ // MLIR emitters only support race-free reductions.
+ // TODO(jreiffers): Verify performance and implement atomics for reductions
+ // if needed.
+ bool reductions_via_mlir_disabled =
+ config.debug_options().xla_gpu_mlir_emitter_level() < 4;
+ if (reductions_via_mlir_disabled && IsMinMaxReduction(reduce)) {
+ // TODO(cheshire): Also enable for integers.
+ VLOG(1) << "Not performing tree expansion on min/max-reduction: "
+ << reduce->ToString()
+ << " since min/max operations are associative";
+ return false;
+ }
+ if (!IsReductionFromOrToContiguousDimensions(*reduce)) {
+ VLOG(3) << "Is not a reduction from or to contiguous dimensions";
+ return false;
+ }
+ VLOG(3) << "Perform rewrite";
+ return true;
+ }
+
+ // We observe that a larger n_div_k can improve tree reduction performance
+ // in most cases by reducing memory stores and the launch overhead of
+ // blocks. Swap k and n_div_k if possible.
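+ //
+ // Hypothetical example of the swap decision below: for a row reduction with
+ // k1 = 12, k2 = 6 and an even n, vectorization is plausible (both k2 and n
+ // are even) and k1 itself is even, so the dimensions are swapped.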
+ bool ShouldSwapInnerAndOuterReducedMinorDimension(uint64_t k1, uint64_t k2,
+ uint64_t n,
+ int64_t race_free_bound,
+ bool is_row_reduction) {
+ CHECK(k1 >= k2);
+ // Keep the inner reduction race free.
+ if (k1 > race_free_bound) {
+ return false;
+ }
+ // Swapping only affects row reduction vectorization.
+ if (is_row_reduction) {
+ // Rough conditions for row reduction vectorization; they do not mean that
+ // vectorization will definitely occur.
+ bool maybe_vectorized = k2 % 2 == 0 && n % 2 == 0;
+ if (maybe_vectorized) {
+ // Swap if n_div_k is small enough or the k dim can also be vectorized.
+ return k2 * 2 < k1 || k1 % 2 == 0;
+ }
+ // The current reduction emitter only checks reduction input dimensions,
+ // not fusion input dimensions. Because the pad and the inner reduction
+ // always fuse into the same computation, each thread may read multiple
+ // unaligned elements that cannot be vectorized, which hurts performance.
+ // Don't swap in this situation.
+ return n % 2 == 0 || k1 % 2 != 0;
+ }
+ // For column reductions we know of no case where swapping fails to yield a
+ // performance gain.
+ return true;
+ }
+
+ // Parameters for how to split a dimension `dim` with `k` elements into
+ // `k1` x `k2`.
+ struct SplitParams {
+ int64_t k1;
+ int64_t k2;
+ int64_t dim;
+ };
+
+ // Attempts to find the best way to split a dimension `dim` with `k` elements
+ // into `k1` x `k2`.
+ SplitParams ComputeSplitParams(
+ HloReduceInstruction *reduce, const ReductionDimensions &reduction_dims,
+ absl::Span<const int64_t> sorted_dims_to_reduce) {
+ absl::Span<int64_t const> input_shape_dims =
+ reduce->inputs()[0]->shape().dimensions();
+
+ int64_t reduced_dim = sorted_dims_to_reduce.back();
+ int64_t reduced_dim_size = input_shape_dims[reduced_dim];
+ VLOG(3) << "reduced dim size = " << reduced_dim_size;
+
+ // We will do this reduction in two stages. The first will reduce from k
+ // elements to k1 elements in the reduction dimension. The second will
+ // reduce further, from k1 to 1 element.
+ //
+ // We do this by splitting the input shape [a, k, b] into [a, k1, k2, b].
+ //
+ // We want to choose k1 to be roughly equal to sqrt(k) so that we process
+ // "most of" the reduction in the first step. But it is also important that
+ // we choose a value of k1 with the least amount of padding we need to add
+ // to k to make it divisible by k1. We search for the best value of k2
+ // between sqrt(k)/2 and sqrt(k). If there are several possible values for
+ // k2 that result in the minimum amount of padding, we also want k2 to
+ // be a power of 2, so that the GPU kernel doesn't spend all its time doing
+ // slow integer divmods to compute indices into the shape [a,k1,k2,b].
+ // Note that by searching in the range between sqrt(k)/2 and sqrt(k), we
+ // will have a power of 2 in that range.
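+ //
+ // Illustrative example (matching the RowReductionSingleDimensionNoBatched
+ // test): for k = 50021, floor(sqrt(k)) = 223 and the search over (111, 223]
+ // finds 126, which needs only one element of padding (50022 = 397 * 126).
+ // That yields {k1 = 397, k2 = 126}, which the swap heuristic
+ // (ShouldSwapInnerAndOuterReducedMinorDimension) then turns into
+ // {k1 = 126, k2 = 397}, so the padded input is reshaped to [126, 397].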
+ uint64_t k2 =
+ static_cast<uint64_t>(std::floor(std::sqrt(reduced_dim_size)));
+ int64_t race_free_bound = ReductionDimensionRaceFreeBound(
+ reduce->GetModule()->config(), reduction_dims);
+ if (k2 > race_free_bound) {
+ // This means we need more than one split. It is best to limit k2 to the
+ // maximum size that doesn't require further splitting. Otherwise we might
+ // choose a rather small reduce dimension size for the first step (in the
+ // worst case, sqrt(race_free_bound + 1)).
+ k2 = race_free_bound;
+ }
+ uint64_t minimum_padding = (k2 - reduced_dim_size % k2) % k2;
+ uint64_t best_k1 = (reduced_dim_size + minimum_padding) / k2;
+ for (uint64_t i = k2 - 1; i > k2 / 2; --i) {
+ uint64_t padding = (i - reduced_dim_size % i) % i;
+ if (padding < minimum_padding ||
+ (padding == minimum_padding && absl::has_single_bit(i))) {
+ minimum_padding = padding;
+ best_k1 = (reduced_dim_size + padding) / i;
+ }
+ }
+ uint64_t padded_k = reduced_dim_size + minimum_padding;
+
+ // We picked the best {k1, k2} pair based on the padding size and on whether
+ // index computation is fast, but we ignored the memory read/write and block
+ // launch overhead, which also matter for kernel performance. Swapping k1 and
+ // k2 does not change the padding size or the cost of index computation, so
+ // we only need to compare memory traffic and block launch overhead to decide
+ // whether to swap.
+ uint64_t best_k2 = padded_k / best_k1;
+ if (ShouldSwapInnerAndOuterReducedMinorDimension(
+ best_k1, best_k2, reduced_dim_size, race_free_bound,
+ reduction_dims.is_row_reduction)) {
+ std::swap(best_k1, best_k2);
+ }
+ return SplitParams{static_cast<int64_t>(best_k1),
+ static_cast<int64_t>(best_k2), reduced_dim};
+ }
+
+ // Replaces the original reduce with pad->reshape->inner_reduce->outer_reduce:
+ // * 1. Pads the split dimension of the inputs to k1 * k2 if necessary.
+ // * 2. Reshapes the split dimension of the padded inputs into [k1, k2].
+ // * 3. The inner reduction reduces the dims specified in the original
+ //      reduction, except that instead of the split dimension it reduces k2.
+ // * 4. The outer reduction reduces k1 only.
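+ //
+ // Illustrative instance (mirroring the ColumnReductionSimple test): with
+ // split_params = {k1 = 128, k2 = 128, dim = 0},
+ //   f32[100] out = reduce(f32[16384,100] input, dimensions={0})
+ // becomes (no padding is needed, since 16384 = 128 * 128):
+ //   f32[128,128,100] reshaped = bitcast(input)
+ //   f32[128,100] inner = reduce(reshaped, dimensions={1})
+ //   f32[100] out = reduce(inner, dimensions={0})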
+ absl::Status SplitReductionDimension(
+ HloReduceInstruction *reduce, const SplitParams &split_params,
+ absl::Span<const int64_t> sorted_dims_to_reduce) {
+ absl::Span<int64_t const> reduce_input_dims =
+ reduce->inputs()[0]->shape().dimensions();
+ int64_t split_dim_size = reduce_input_dims[split_params.dim];
+ VLOG(2) << "dimension to split = " << split_params.dim << " with "
+ << split_dim_size << " elements into " << split_params.k1 << " by "
+ << split_params.k2;
+
+ // Pad 'k' to 'k1 * k2' if necessary.
+ HloInstruction::InstructionVector padded_inputs(reduce->inputs().begin(),
+ reduce->inputs().end());
+ auto padded_size = split_params.k1 * split_params.k2;
+ absl::InlinedVector<int64_t, 3> padded_dimensions(reduce_input_dims.begin(),
+ reduce_input_dims.end());
+ if (split_dim_size != padded_size) {
+ padded_dimensions[split_params.dim] = padded_size;
+ PaddingConfig padding_config =
+ MakeNoPaddingConfig(reduce_input_dims.size());
+ padding_config.mutable_dimensions(split_params.dim)
+ ->set_edge_padding_high(padded_size - split_dim_size);
+
+ for (int input_idx = 0; input_idx < padded_inputs.size(); ++input_idx) {
+ auto &reduction_input = padded_inputs[input_idx];
+ Shape padded_shape = ShapeUtil::MakeShape(
+ reduction_input->shape().element_type(), padded_dimensions);
+ VLOG(2) << "Generated padded shape: " << padded_shape.ToString();
+ reduction_input = reduce->parent()->AddInstruction(
+ HloInstruction::CreatePad(padded_shape, reduction_input,
+ reduce->init_values()[input_idx],
+ padding_config),
+ &reduction_input->metadata());
+ }
+ }
+
+ // Compute output type of reshape that expands the split dimension into
+ // [k1, k2].
+ absl::InlinedVector<int64_t, 3> reshaped_dimensions;
+ int64_t input_rank = reduce_input_dims.size();
+ for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
+ if (dim_idx == split_params.dim) {
+ reshaped_dimensions.push_back(split_params.k1);
+ reshaped_dimensions.push_back(split_params.k2);
+ } else {
+ reshaped_dimensions.push_back(padded_dimensions[dim_idx]);
+ }
+ }
+
+ // Compute dimensions to reduce for inner reduction.
+ absl::InlinedVector<int64_t, 2> inner_reduce_dims(
+ sorted_dims_to_reduce.begin(), sorted_dims_to_reduce.end());
+ auto split_dim_it = std::find(inner_reduce_dims.begin(),
+ inner_reduce_dims.end(), split_params.dim);
+ *split_dim_it += 1;
+
+ // Compute dimension to reduce for outer reduction.
+ absl::InlinedVector<int64_t, 1> outer_reduce_dims{
+ split_params.dim -
+ std::distance(inner_reduce_dims.begin(), split_dim_it)};
+
+ // Compute output shape of the inner reduction.
+ absl::InlinedVector<int64_t, 3> inner_reduce_shape =
+ RemoveElements(inner_reduce_dims, reshaped_dimensions);
+
+ // Reshape the split dimension of the padded inputs into [k1, k2].
+ HloInstruction::InstructionVector reshaped_padded_inputs;
+ absl::InlinedVector<Shape, 2> inner_reduce_shapes;
+ for (HloInstruction *padded_input : padded_inputs) {
+ Shape reshaped_shape = ShapeUtil::MakeShape(
+ padded_input->shape().element_type(), reshaped_dimensions);
+ HloInstruction *reshaped_padded_input = reduce->parent()->AddInstruction(
+ HloInstruction::CreateBitcast(reshaped_shape, padded_input),
+ &padded_input->metadata());
+ VLOG(2) << "Generated reshape: " << reshaped_padded_input->ToString();
+ reshaped_padded_inputs.push_back(reshaped_padded_input);
+ inner_reduce_shapes.push_back(ShapeUtil::MakeShape(
+ padded_input->shape().element_type(), inner_reduce_shape));
+ }
+
+ // Inner reduce that reduces [k1, k2] to [k1].
+ HloInstruction *inner_reduce = reduce->parent()->AddInstruction(
+ HloInstruction::CreateReduce(
+ ShapeUtil::MakeMaybeTupleShape(inner_reduce_shapes),
+ reshaped_padded_inputs, reduce->init_values(), inner_reduce_dims,
+ reduce->to_apply()),
+ &reduce->metadata());
+ VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString();
+
+ // Outer reduce that reduces [k1] to the final result.
+ std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce(
+ reduce->shape(), inner_reduce, reduce->init_values(), outer_reduce_dims,
+ reduce->to_apply());
+
+ VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString();
+ return ReplaceWithNewInstruction(reduce, std::move(outer_reduce));
+ }
+
+ // Rewrites batch dimension reduction into a separate reduce operation.
+ absl::Status RewriteBatchDimensionLargerThanTile(
+ HloReduceInstruction *hlo,
+ const ReductionDimensions &reduction_dimensions,
+ absl::Span<const int64_t> sorted_dims_to_reduce) {
+ // TODO(cheshire): this codepath is essentially the exact reverse of what
+ // algebraic_simplifier is doing, we need to make sure they don't keep
+ // undoing each other.
+ CHECK(reduction_dimensions.is_row_reduction);
+
+ absl::InlinedVector<Shape, 2> tuple_shapes;
+ int64_t minor_reduction_dim = sorted_dims_to_reduce.back();
+ for (HloInstruction *input : hlo->inputs()) {
+ tuple_shapes.push_back(
+ ShapeUtil::DeleteDimension(minor_reduction_dim, input->shape()));
+ }
+
+ HloInstruction *inner_reduce =
+ hlo->parent()->AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeMaybeTupleShape(tuple_shapes), hlo->inputs(),
+ hlo->init_values(), {minor_reduction_dim}, hlo->to_apply()));
+
+ VLOG(1) << "Inner reduction: " << inner_reduce->ToString();
+ std::unique_ptr<HloInstruction> out = HloInstruction::CreateReduce(
+ hlo->shape(), inner_reduce, hlo->init_values(), {0}, hlo->to_apply());
+ VLOG(1) << "Generated: " << out->ToString();
+ return ReplaceWithNewInstruction(hlo, std::move(out));
+ }
+
+ se::GpuComputeCapability gpu_version_;
+};
+
+absl::StatusOr<bool> TreeReductionRewriter::Run(
+ HloModule *module,
+ const absl::flat_hash_set<absl::string_view> &execution_threads) {
+ VLOG(5) << "Rewriter input: " << module->ToString();
+ TF_ASSIGN_OR_RETURN(bool changed,
+ ReductionRewriterVisitor(gpu_version_)
+ .RunOnModule(module, execution_threads));
+ VLOG(5) << "Rewriter output: " << module->ToString();
+ return changed;
+}
+
+} // end namespace gpu
+} // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h
new file mode 100644
index 0000000..7f57d21
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h
@@ -0,0 +1,96 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites reductions so that they can be implemented without atomics.
+//
+// Rule application: rewrite a single HLO reduce operation into two.
+//
+// Case 1: Row reduction, where a batched dimension is present and larger
+// than the Z-tiling size.
+// -----------------------------------------------------------------
+//
+// Rewriting:
+//
+// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
+//
+// Into:
+//
+// f32[A, B] tmp = reduce(f32[A, B, C] input, dimensions={2})
+// f32[B] out = reduce(f32[A, B] tmp, dimensions={0})
+//
+// Case 2: Row reduction
+// ------------------------------------------------------------------
+//
+// Let M be the thread tiling multiplied by the warp size.
+// We go from (assuming C > M):
+//
+// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
+//
+// to:
+//
+// f32[A, B, P] padded = pad(input) // Let P = ceil(C/M) * M.
+// f32[A, B, Q, M] reshaped = bitcast(padded) // Let Q = ceil(C/M)
+// f32[B, Q] inner_reduce = reduce(reshaped, dimensions={0, 3})
+// f32[B] outer_reduce = reduce(inner_reduce, dimensions={1})
+//
+// Case 3: Column reduction
+// -------------------------------------------------------------------
+//
+// Let T be the tiling size for the column reduction.
+//
+// We go from (assuming B > T):
+//
+// f32[A, C] out = reduce(f32[A, B, C] input, dimensions={1})
+//
+// to:
+//
+// f32[A, P, C] padded = pad(input) // Let P = ceil(B/T) * T.
+// f32[A, Q, T, C] reshaped = bitcast(padded) // Let Q = ceil(B/T)
+// f32[A, Q, C] inner_reduce = reduce(reshaped, dimensions={2})
+// f32[A, C] outer_reduce = reduce(inner_reduce, dimensions={1})
+//
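+// Typical usage is as a pass in the GPU compiler pipeline (a sketch, assuming
+// an HloPassPipeline `pipeline` and the target's GpuComputeCapability
+// `gpu_version` are in scope):
+//
+//   pipeline.AddPass<TreeReductionRewriter>(gpu_version);
+//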
+class TreeReductionRewriter : public HloModulePass {
+ public:
+ explicit TreeReductionRewriter(se::GpuComputeCapability gpu_version)
+ : gpu_version_(gpu_version) {}
+
+ ~TreeReductionRewriter() override = default;
+ absl::string_view name() const override { return "tree-reduction-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ se::GpuComputeCapability gpu_version_;
+};
+
+} // end namespace gpu
+} // end namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc
new file mode 100644
index 0000000..bea969e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc
@@ -0,0 +1,557 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class TreeReductionRewriterTest : public HloTestBase {
+ public:
+ void CheckTreeRewriter(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(
+ hlo,
+#if TENSORFLOW_USE_ROCM
+ gpu::TreeReductionRewriter{se::RocmComputeCapability {
+ "908"
+ }},
+#else
+ gpu::TreeReductionRewriter{se::CudaComputeCapability{8, 1}},
+#endif
+ expected);
+ }
+};
+
+TEST_F(TreeReductionRewriterTest, RowReductionSingleDimensionNoBatched) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[50021] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[pad_0:%[^ ]+]] = f32[50022]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1
+// CHECK: [[bitcast_3:%[^ ]+]] = f32[126,397]{1,0} bitcast([[pad_0]])
+// CHECK: [[reduce_4:%[^ ]+]] = f32[126]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]]
+// CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionWeirdOutputLayout) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[2,4,17000]{2,1,0} parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[2,4]{0,1} reduce(input, zero), dimensions={2}, to_apply=add
+}
+)";
+
+ // Check that we preserve the layout.
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: f32[2,4]{0,1} reduce(
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest,
+ RowReductionSingleDimensionNoBatchedDivisible) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[50048] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[50048]{0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,391]{1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[128]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionNoBatched) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[100,10,65536] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[100,10] reduce(input, zero), dimensions={2}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[100,10,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[100,10,256]{2,1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100,10]{1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={2}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest,
+ RowReductionSingleDimensionNoBatchedLargeInput) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1048576] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1048576]{0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[1024,1024]{1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1024]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionFits) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[8,100,65536] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[8,100,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[100,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={0,3}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[32,100,90000] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[reduce_0:%[^ ]+]] = f32[32,100]{1,0} reduce([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), dimensions={2}, to_apply=[[add_3:%[^ ]+]]
+// CHECK: ROOT [[out_1_4:%[^ ]+]] = f32[100]{0} reduce([[reduce_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_3]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionSimple) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[16384,100] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK: [[input_0:%[^ ]+]] = f32[16384,100]{1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,128,100]{2,1,0} bitcast([[input_0]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[128,100]{1,0} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_2]], [[zero_3]]), dimensions={0}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionSimpleNoDivisible) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[10303,100] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[10303,100]{1,0} parameter(0)
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1x0_0
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[64,161,100]{2,1,0} bitcast([[pad_0]])
+// CHECK: [[reduce_3:%[^ ]+]] = f32[64,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionOtherIndex) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[16384,2,2,2] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[16384,2,2,2]{3,2,1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,128,2,2,2]{4,3,2,1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[128,2,2,2]{3,2,1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionVeryLargeInput) {
+ const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1048576,5] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[1024,1024,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,5]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[5]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, VariadicReductionLargeRow) {
+ const char* hlo = R"(
+HloModule Reduce_R1x2_to_R0x2_argmax
+
+argmax {
+ running_max = f32[] parameter(0)
+ running_max_idx = u32[] parameter(1)
+ current_value = f32[] parameter(2)
+ current_value_idx = u32[] parameter(3)
+
+ current = (f32[], u32[]) tuple(running_max, running_max_idx)
+ potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+ cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+ new_max = f32[] select(cmp_code, current_value, running_max)
+ new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+ ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+ input = f32[2,100003] parameter(0)
+ idxs = u32[2,100003] iota(), iota_dimension=0
+ zero = f32[] constant(0)
+ zero_idx = u32[] constant(0)
+
+ ROOT out = (f32[2], u32[2]) reduce(
+ input, idxs, zero, zero_idx),
+ dimensions={1},
+ to_apply=%argmax
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[pad_0:%[^ ]+]] = f32[2,100005]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_2
+// CHECK: [[bitcast_3:%[^ ]+]] = f32[2,295,339]{2,1,0} bitcast([[pad_0]])
+// CHECK: [[zero_idx_4:%[^ ]+]] = u32[] constant(0)
+// CHECK: [[pad_1_5:%[^ ]+]] = u32[2,100005]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_2
+// CHECK: [[bitcast_1_7:%[^ ]+]] = u32[2,295,339]{2,1,0} bitcast([[pad_1_5]])
+// CHECK: [[reduce_8:%[^ ]+]] = (f32[2,295]{1,0}, u32[2,295]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]]
+// CHECK: [[get_tuple_element_10:%[^ ]+]] = f32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=0
+// CHECK: [[get_tuple_element_1_11:%[^ ]+]] = u32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=1
+// CHECK: ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, VariadicReductionLargeBatchSize) {
+ const char* hlo = R"(
+HloModule Reduce_R1x2_to_R0x2_argmax
+
+argmax {
+ running_max = f32[] parameter(0)
+ running_max_idx = u32[] parameter(1)
+ current_value = f32[] parameter(2)
+ current_value_idx = u32[] parameter(3)
+
+ current = (f32[], u32[]) tuple(running_max, running_max_idx)
+ potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+ cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+ new_max = f32[] select(cmp_code, current_value, running_max)
+ new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+ ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+ input = f32[20,2,100] parameter(0)
+ idxs = u32[20,2,100] iota(), iota_dimension=0
+ zero = f32[] constant(0)
+ zero_idx = u32[] constant(0)
+
+ ROOT out = (f32[2], u32[2]) reduce(
+ input, idxs, zero, zero_idx),
+ dimensions={0,2},
+ to_apply=%argmax
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[reduce_0:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce([[input_1:%[^ ]+]], [[idxs_2:%[^ ]+]], [[zero_3:%[^ ]+]], [[zero_idx_4:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_5:%[^ ]+]]
+// CHECK: [[get_tuple_element_6:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=0
+// CHECK: [[get_tuple_element_1_7:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=1
+// CHECK: ROOT [[out_1_8:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_6]], [[get_tuple_element_1_7]], [[zero_3]], [[zero_idx_4]]), dimensions={0}, to_apply=[[argmax_5]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, KeepInnerReductionVectorized) {
+ const char* hlo = R"(
+HloModule KeepInnerRowReductionVectorized
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1024,73984] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[1024,289,256]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,289]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, PreferLargeVectorizedDimension) {
+ const char* hlo = R"(
+HloModule PreferLargeVectorizedDimension
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1024,98304] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[1024,256,384]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, SwapIfNonAlignedBeforePadding) {
+ const char* hlo = R"(
+HloModule SwapIfNonAlignedBeforePadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1024,19739] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK-DAG: [[bitcast_0:%[^ ]+]] = f32[1024,140,141]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK-DAG: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, DontSwapIfNonAlignedBeforePadding) {
+ const char* hlo = R"(
+HloModule DontSwapIfNonAlignedBeforePadding
+
+add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+ input = f32[1024,19459] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK-DAG: [[bitcast_0:%[^ ]+]] = f32[1024,140,139]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK-DAG: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+ )");
+}
+
+TEST_F(TreeReductionRewriterTest, NonCosequtiveReductionDims) {
+ const char* hlo = R"(
+ HloModule NonCosequtiveReductionDims
+
+ add {
+ accum = f32[] parameter(0)
+ op = f32[] parameter(1)
+ ROOT out = f32[] add(accum, op)
+ }
+
+ ENTRY main {
+ input = f32[5,3,4,5] parameter(0)
+ zero = f32[] constant(0)
+ ROOT out = f32[5,4] reduce(input, zero), dimensions={1,3}, to_apply=add
+ }
+ )";
+
+ CheckTreeRewriter(hlo, std::nullopt);
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc
new file mode 100644
index 0000000..e81bdae
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc
@@ -0,0 +1,88 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/triangular_solve_rewriter.h"
+
+#include <cstdint>
+#include <numeric>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> TriangularSolveRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ std::vector<HloInstruction*> to_rewrite;
+ for (HloInstruction* instr : comp->instructions()) {
+ if (instr->opcode() == HloOpcode::kTriangularSolve) {
+ to_rewrite.push_back(instr);
+ }
+ }
+
+ for (HloInstruction* instr : to_rewrite) {
+ const Shape& b_shape = instr->operand(1)->shape();
+ int64_t batch_size = std::accumulate(
+ b_shape.dimensions().begin(), b_shape.dimensions().end() - 2,
+ int64_t{1}, [](int64_t a, int64_t b) { return a * b; });
+
+ // Batch-1 triangular solves get 0 temp bytes, because unbatched trsm()
+ // doesn't require temp memory.
+ int64_t temp_bytes = batch_size == 1 ? 0 : 2 * sizeof(void*) * batch_size;
+ Shape new_shape = ShapeUtil::MakeTupleShape({
+ instr->shape(),
+ ShapeUtil::MakeShape(S8, {temp_bytes}),
+ });
+
+ HloInstruction* custom_call =
+ comp->AddInstruction(HloInstruction::CreateCustomCall(
+ new_shape, instr->operands(), kTriangularSolveCallTarget));
+ module->SetAndUniquifyInstrName(custom_call, "triangular-solve");
+ TF_RETURN_IF_ERROR(
+ custom_call->set_backend_config(instr->triangular_solve_options()));
+
+ // Preserve metadata from `instr`.
+ custom_call->set_metadata(instr->metadata());
+ custom_call->set_frontend_attributes(instr->frontend_attributes());
+
+ // Get the actual result out of the custom call's tuple.
+ TF_ASSIGN_OR_RETURN(HloInstruction * gte,
+ MakeGetTupleElementHlo(custom_call, 0));
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
+ // Record that the module was modified so the pass reports a change.
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h
new file mode 100644
index 0000000..c52e0ff
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h
@@ -0,0 +1,60 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites HLO TriangularSolve ops into a custom-call.
+//
+// The motivation for this is that we need to add temp memory to batched
+// triangular-solve ops in order to call cublas trsmBatched. We rewrite batch 1
+// ops as well so that we have fewer codepaths to worry about in the backend.
+//
+// cublas trsmBatched takes arrays in GPU memory of pointers to the inputs and
+// outputs, `a` and `b`. In XLA the inputs/outputs are always contiguous, but
+// we still have to materialize out these arrays.
+//
+// We use the same trick as for cudnn convolutions: This custom-call returns a
+// tuple (actual-result, temp-memory). In our case the temp buffer always
+// has size 2 * sizeof(void*) * batch_size, because we need two arrays of
+// pointers.
+//
+// The custom-call has a backend-config equal to the TriangularSolveOptions
+// object.
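+//
+// Illustrative sizing example (assuming 64-bit pointers): if `b` has shape
+// f32[4,5,32,32], then batch_size = 4 * 5 = 20 and the temp buffer in the
+// returned tuple has shape s8[2 * 8 * 20] = s8[320].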
+class TriangularSolveRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "triangular-solve-rewriter";
+ }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
new file mode 100644
index 0000000..10ae640
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
@@ -0,0 +1,190 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
+
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+namespace {
+
+using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
+
+// Returns the input instruction as a fusion instruction, if it represents a
+// Triton fusion. Otherwise, returns nullptr.
+absl::StatusOr<const HloFusionInstruction*> AsTritonFusion(
+ const HloInstruction* hlo) {
+ if (hlo->opcode() != HloOpcode::kFusion) {
+ return nullptr;
+ }
+ const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
+ fusion->backend_config<GpuBackendConfig>());
+ const FusionBackendConfig& backend_config =
+ gpu_config.fusion_backend_config();
+ if (backend_config.kind() == kTritonFusionKind) {
+ return fusion;
+ }
+ return nullptr;
+}
+
+std::unique_ptr<HloModule> NewHloModuleFromFusion(
+ const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
+ bool clear_backend_config) {
+ std::unique_ptr<HloModule> new_module =
+ ExtractInstructionIntoNewModule(fusion);
+ if (clear_backend_config) {
+ new_module->entry_computation()->root_instruction()->clear_backend_config();
+ }
+ new_module->mutable_config().set_debug_options(debug_opts);
+
+ return new_module;
+}
+
+} // namespace
+
+namespace triton_fusion_numerics_pass_internal {
+
+absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
+ AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
+ const AutotuneConfig& config, const DebugOptions& debug_opts,
+ bool clear_backend_config) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+ util.Compile([&](const DebugOptions& opts) {
+ return NewHloModuleFromFusion(fusion, opts,
+ clear_backend_config);
+ }));
+ TF_ASSIGN_OR_RETURN(auto rz_buffers, RedzoneBuffers::FromInstruction(
+ fusion, config, debug_opts,
+ RedzoneBuffers::kAllInputs));
+ TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
+ TF_ASSIGN_OR_RETURN(std::optional<ProfilingOutput> profiling_output,
+ util.ProfileExecutable(executable.get(), stream,
+ rz_buffers.input_buffers(),
+ rz_buffers.input_shapes()));
+ if (!profiling_output.has_value()) {
+ return Internal("No output after a successful verification run.");
+ }
+
+ return std::move(profiling_output->output);
+}
+
+absl::Status CompareBuffers(const ScopedShapedBuffer& current,
+ const ScopedShapedBuffer& expected,
+ const Shape& shape, const HloModuleConfig& config,
+ se::Stream* stream) {
+ BufferComparator comparator(
+ shape, config.debug_options().xla_gpu_autotune_gemm_rtol());
+ TF_ASSIGN_OR_RETURN(bool outputs_match,
+ comparator.CompareEqual(stream, current.root_buffer(),
+ expected.root_buffer()));
+
+ if (!outputs_match) {
+ return Internal("Triton fusion output does not match emitters output.");
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ForAllTritonFusions(
+ const HloModule& module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads,
+ absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn) {
+ for (HloComputation* computation :
+ module.MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ TF_ASSIGN_OR_RETURN(auto triton_fusion, AsTritonFusion(instruction));
+ if (triton_fusion != nullptr) {
+ TF_RETURN_IF_ERROR(fn(*triton_fusion));
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
+} // namespace triton_fusion_numerics_pass_internal
+
+namespace {
+absl::Status VerifyTritonFusion(AutotunerCompileUtil& util,
+ const HloFusionInstruction& fusion,
+ const AutotuneConfig& config,
+ const DebugOptions& debug_opts) {
+ TF_ASSIGN_OR_RETURN(auto triton_result,
+ triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+ util, fusion, config, debug_opts,
+ /*clear_backend_config=*/false));
+ TF_ASSIGN_OR_RETURN(auto emitters_result,
+ triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+ util, fusion, config, debug_opts,
+ /*clear_backend_config=*/true));
+
+ TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
+ return triton_fusion_numerics_pass_internal::CompareBuffers(
+ triton_result, emitters_result, fusion.shape(),
+ fusion.GetModule()->config(), stream);
+}
+
+} // namespace
+
+absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ if (config_.IsDeviceless()) {
+ return absl::InternalError(
+ "Cannot run TritonFusionNumericsVerifier on a deviceless compilation.");
+ }
+
+ const DebugOptions& debug_options = module->config().debug_options();
+ TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
+ AutotunerCompileUtil::Create(config_, debug_options));
+ TF_RET_CHECK(opt_compile_util.has_value());
+
+ TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions(
+ *module, execution_threads, [&](const HloFusionInstruction& fusion) {
+ return VerifyTritonFusion(*opt_compile_util, fusion, config_,
+ debug_options);
+ }));
+ return false;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
new file mode 100644
index 0000000..e3dc6eb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
@@ -0,0 +1,74 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/stream.h"
+
+namespace xla::gpu {
+
+// For each Triton fusion in the HLO module, this pass checks that the output
+// of the fusion generated via Triton matches the output the fusion would
+// produce if generated with the regular emitters.
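+//
+// Typical usage (a sketch, assuming an HloPassPipeline `pipeline` and an
+// AutotuneConfig `autotune_config` are in scope; in practice the pass is only
+// useful when xla_gpu_verify_triton_fusion_numerics is enabled):
+//
+//   pipeline.AddPass<TritonFusionNumericsVerifier>(autotune_config);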
+class TritonFusionNumericsVerifier : public HloModulePass {
+ public:
+ explicit TritonFusionNumericsVerifier(const AutotuneConfig& config)
+ : config_(config) {}
+
+ static absl::string_view Name() { return "triton-numerics-verifier"; }
+ absl::string_view name() const override { return Name(); }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ AutotuneConfig config_;
+};
+
+namespace triton_fusion_numerics_pass_internal {
+// These are exposed only for testing. Do not use.
+absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
+ AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
+ const AutotuneConfig& config, const DebugOptions& debug_opts,
+ bool clear_backend_config);
+absl::Status CompareBuffers(const ScopedShapedBuffer& current,
+ const ScopedShapedBuffer& expected,
+ const Shape& shape, const HloModuleConfig& config,
+ se::Stream* stream);
+absl::Status ForAllTritonFusions(
+ const HloModule& module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads,
+ absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn);
+} // namespace triton_fusion_numerics_pass_internal
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
new file mode 100644
index 0000000..0382577
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
@@ -0,0 +1,195 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/test_helpers.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+
+namespace xla::gpu {
+namespace {
+
+class TritonFusionNumericsVerifierTest
+ : public HloTestBase,
+ public ::testing::WithParamInterface<PrimitiveType> {
+ public:
+ DebugOptions GetDebugOptionsForTest() override {
+ auto options = HloTestBase::GetDebugOptionsForTest();
+ options.set_xla_gpu_enable_triton_softmax_fusion(true);
+ options.set_xla_gpu_verify_triton_fusion_numerics(true);
+ return options;
+ }
+
+ protected:
+ std::unique_ptr<xla::HloModule> Module(absl::string_view hlo_text_template,
+ absl::string_view type) {
+ auto m = GetOptimizedModule(absl::Substitute(hlo_text_template, type));
+ TF_EXPECT_OK(m);
+ return std::move(m.value());
+ }
+
+ const HloFusionInstruction* TritonFusion(const xla::HloModule& module) {
+ const HloFusionInstruction* fusion_result = nullptr;
+
+ absl::Status res =
+ triton_fusion_numerics_pass_internal::ForAllTritonFusions(
+ module, /*execution_threads=*/{},
+ [&](const HloFusionInstruction& fusion) -> absl::Status {
+ EXPECT_EQ(fusion_result, nullptr);
+ fusion_result = &fusion;
+ return absl::OkStatus();
+ });
+ return fusion_result;
+ }
+
+ AutotuneConfig CreateAutotuneConfig() {
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+ auto executors_or = PlatformUtil::GetStreamExecutors(platform);
+ TF_EXPECT_OK(executors_or);
+ return AutotuneConfig{DeviceConfig{executors_or->at(0), nullptr},
+ GetDebugOptionsForTest()};
+ }
+
+ AutotunerCompileUtil CreateAutotunerCompileUtil(AutotuneConfig& config) {
+ auto opt_compile_util_or =
+ AutotunerCompileUtil::Create(config, GetDebugOptionsForTest());
+ TF_EXPECT_OK(opt_compile_util_or);
+ EXPECT_TRUE(opt_compile_util_or->has_value());
+ return std::move(opt_compile_util_or->value());
+ }
+};
+
+constexpr absl::string_view kSoftmaxHlo = R"(
+HloModule softmax
+max_computation {
+ arg_0 = $0[] parameter(0)
+ arg_1 = $0[] parameter(1)
+ ROOT maximum = $0[] maximum(arg_0, arg_1)
+}
+add_computation {
+ arg_0.1 = $0[] parameter(0)
+ arg_1.1 = $0[] parameter(1)
+ ROOT add = $0[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+ param_0 = $0[127,125]{1,0} parameter(0)
+ constant_neg_inf = $0[] constant(-inf)
+ reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+ broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0}
+ subtract = $0[127,125]{1,0} subtract(param_0, broadcast)
+ exponential = $0[127,125]{1,0} exponential(subtract)
+ constant_zero = $0[] constant(0)
+ second_reduce = $0[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+ second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+ ROOT divide = $0[127,125]{1,0} divide(exponential, second_broadcast)
+}
+)";
+
+bool HloPassHasRun(const HloModule& module, absl::string_view pass_name) {
+ for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) {
+ if (pass_metadata.pass_name() == pass_name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+TEST_P(TritonFusionNumericsVerifierTest, VerifyExactSoftmaxFusionNumerics) {
+ PrimitiveType data_type = GetParam();
+
+ auto module = Module(kSoftmaxHlo,
+ primitive_util::LowercasePrimitiveTypeName(data_type));
+
+ // At this point all HLO passes have been executed successfully, because the
+ // Module() function hasn't failed. In particular the numerics verification
+ // pass should have also run and **not** found any issues. Below we just
+ // ensure that the pass has indeed been correctly enabled and that there are
+ // Triton Fusions in the input module.
+
+ EXPECT_TRUE(HloPassHasRun(*module, TritonFusionNumericsVerifier::Name()));
+ auto fusion = TritonFusion(*module);
+ EXPECT_NE(fusion, nullptr);
+}
+
+TEST_F(TritonFusionNumericsVerifierTest, CheckMismatch) {
+ // This test intentionally compares two different Triton modules to each
+ // other. This is to test that the verifier functions correctly catch and
+ // report mismatches.
+ //
+ // Note that as part of computing the two modules below, the numerics verifier
+ // pass also runs individually for each module. These runs compare each
+ // module to the corresponding emitter-generated version, and they match. In
+ // that sense this test covers what is being tested by
+ // VerifyExactSoftmaxFusionNumerics. The reason to keep two tests is that
+ // VerifyExactSoftmaxFusionNumerics is minimal and will be easier to debug if
+ // it fails.
+
+ auto module_f16 = Module(kSoftmaxHlo, "f16");
+ auto fusion_f16 = TritonFusion(*module_f16);
+ EXPECT_NE(fusion_f16, nullptr);
+
+ auto module_f32 = Module(kSoftmaxHlo, "f32");
+ auto fusion_f32 = TritonFusion(*module_f32);
+ EXPECT_NE(fusion_f32, nullptr);
+
+ AutotuneConfig autotune_config = CreateAutotuneConfig();
+ AutotunerCompileUtil compile_util =
+ CreateAutotunerCompileUtil(autotune_config);
+ const DebugOptions& debug_options = GetDebugOptionsForTest();
+
+ auto f16_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+ compile_util, *fusion_f16, autotune_config, debug_options,
+ /*clear_backend_config=*/false);
+ TF_EXPECT_OK(f16_result);
+
+ auto f32_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+ compile_util, *fusion_f32, autotune_config, debug_options,
+ /*clear_backend_config=*/false);
+ TF_EXPECT_OK(f32_result);
+
+ auto stream = autotune_config.GetStream();
+ TF_EXPECT_OK(stream);
+
+ // Intentionally compare the fusions from the different modules, triggering a
+ // mismatch.
+ auto cmp = triton_fusion_numerics_pass_internal::CompareBuffers(
+ *f16_result, *f32_result, fusion_f16->shape(),
+ fusion_f16->GetModule()->config(), *stream);
+
+ EXPECT_FALSE(cmp.ok());
+}
+
+INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite,
+ TritonFusionNumericsVerifierTest,
+ ::testing::Values(F32, F16, BF16));
+
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc
new file mode 100644
index 0000000..0712040
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc
@@ -0,0 +1,115 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/variadic_op_splitter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+// The parameter space on the GPU device is limited. We pick an arbitrary low
+// constant here to try to prevent exceeding this parameter space. For a proper
+// fix, we would have to take into account which parameters share a buffer, and
+// how big these buffers are.
+constexpr int32_t kMaxParameters = 128;
+
+absl::StatusOr<bool> SplitConcatenate(HloInstruction* concat,
+ HloComputation* comp) {
+ auto operands = concat->operands();
+ std::vector<HloInstruction*> operands_to_split(operands.begin(),
+ operands.end());
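+ // Illustrative example with kMaxParameters = 128: a concatenate of 300
+ // operands is rewritten in the first round into
+ //   concat(op0..op127), concat(op128..op255), op256, ..., op299
+ // and the second round concatenates those 46 values into the final
+ // replacement concatenate.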
+ while (operands_to_split.size() > 1) {
+ std::vector<HloInstruction*> new_operands;
+ absl::Span<HloInstruction*> operands_span(operands_to_split);
+ for (int64_t offset = 0; offset < operands_to_split.size();
+ offset += kMaxParameters) {
+ // Check whether this batch is a remainder that does not completely fill
+ // one "batch" of exactly 'kMaxParameters' operands. If fewer than
+ // 'kMaxParameters' operands are left and this is the only batch
+ // (offset == 0), we still concatenate them now. Otherwise we carry the
+ // leftover operands over to the next round, so that they can be
+ // concatenated together with some of the newly created concats.
+ if (offset > 0 && offset + kMaxParameters > operands_to_split.size()) {
+ new_operands.insert(new_operands.end(),
+ operands_to_split.begin() + offset,
+ operands_to_split.end());
+ } else {
+ Shape new_shape = concat->shape();
+ int64_t concat_dimension_size = 0;
+ for (int64_t i = 0;
+ i < kMaxParameters && offset + i < operands_to_split.size(); ++i) {
+ concat_dimension_size +=
+ operands_to_split[i + offset]->shape().dimensions(
+ concat->concatenate_dimension());
+ }
+ new_shape.set_dimensions(concat->concatenate_dimension(),
+ concat_dimension_size);
+ auto new_concat = comp->AddInstruction(concat->CloneWithNewOperands(
+ new_shape, operands_span.subspan(offset, kMaxParameters)));
+ new_operands.push_back(new_concat);
+ }
+ }
+ operands_to_split = new_operands;
+ }
+ TF_RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0]));
+ return true;
+}
+
+std::vector<HloInstruction*> GetRelevantVariadicOps(HloComputation* comp) {
+ std::vector<HloInstruction*> ops;
+ for (HloInstruction* instr : comp->instructions()) {
+ if (instr->opcode() == HloOpcode::kConcatenate &&
+ instr->operand_count() > kMaxParameters) {
+ ops.push_back(instr);
+ }
+ }
+ return ops;
+}
+
+} // namespace
+
+absl::StatusOr<bool> VariadicOpSplitter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ for (HloInstruction* op : GetRelevantVariadicOps(comp)) {
+ // TODO(b/112613927): Handle also other ops than concatenate.
+ TF_ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp));
+ changed |= result;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h
new file mode 100644
index 0000000..304afa1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Splits variadic ops with many operands into pieces such that we don't exceed
+// the parameter space on the GPU. Currently only concatenate ops are split up.
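+// For example, a concatenate with 256 operands is rewritten into a
+// concatenate of two concatenates with 128 operands each.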
+class VariadicOpSplitter : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "variadic-op-splitter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc
new file mode 100644
index 0000000..1d72613
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/variadic_op_splitter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+using match::Concatenate;
+
+class VariadicOpSplitterTest : public HloTestBase {};
+
+TEST_F(VariadicOpSplitterTest, DontSplit) {
+ auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ p0 = f16[30,41] parameter(0)
+ p1 = f16[30,41] parameter(1)
+ ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
+ })")
+ .value();
+ EXPECT_FALSE(VariadicOpSplitter().Run(module.get()).value());
+}
+
+TEST_F(VariadicOpSplitterTest, SplitInto2) {
+ auto builder = HloComputation::Builder(TestName());
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
+ std::vector<HloInstruction*> concat_operands(255, operand);
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(S32, {255}), concat_operands, 0));
+ auto module = CreateNewVerifiedModule();
+ auto entry_computation = module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
+ EXPECT_TRUE(Match(entry_computation->root_instruction(),
+ Concatenate().WithNumOperands(128).WithOperand(
+ 0, Concatenate().WithNumOperands(128))));
+}
+
+TEST_F(VariadicOpSplitterTest, SplitInto3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
+ std::vector<HloInstruction*> concat_operands(256, operand);
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(S32, {256}), concat_operands, 0));
+ auto module = CreateNewVerifiedModule();
+ auto entry_computation = module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
+ EXPECT_TRUE(Match(entry_computation->root_instruction(),
+ Concatenate(Concatenate().WithNumOperands(128),
+ Concatenate().WithNumOperands(128))));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc
new file mode 100644
index 0000000..eeb24ce
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc
@@ -0,0 +1,1148 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/shape_inference.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = match;
+
+// Enables the creation of FP8 GEMM Custom Calls for all-gather and
+// reduce-scatter windowed einsums in gemm_rewriter.cc by moving the scalings
+// and type conversions of FP8 operands into the bodies of their while loops,
+// i.e. rewrites
+//
+// inputs --> dequant --> while loop {dynamic-slice/collective-permute/dot}
+//
+// into
+//
+// inputs --> while loop {dequant --> dynamic-slice/collective-permute/dot}.
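+//
+// The FP8 operands replace their dequantized counterparts in the while loop's
+// tuple, the scalar scaling factors are appended to it, and the convert/scale
+// pair is re-created inside the loop body next to the dots and dynamic-slices
+// it feeds.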
+absl::Status ShiftDequantizationF8(const HloComputation* comp,
+ const std::array<HloInstruction*, 2>& gte) {
+ HloInstruction* while_instr = comp->WhileCallInstruction();
+ if (!while_instr) {
+ return absl::OkStatus();
+ }
+
+ // Identify the scalings and type conversions applied to the inputs of the
+ // while loop.
+ HloInstruction* param_tuple = while_instr->mutable_operand(0);
+ std::array<HloInstruction*, 2> binaries, operands, scales;
+ for (int k = 0; k < 2; ++k) {
+ if (!Match(param_tuple->mutable_operand(k),
+ m::AnyOf<HloInstruction>(
+ m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])),
+ m::Broadcast(m::Op(&scales[k]))),
+ m::MultiplyAnyOrder(&binaries[k],
+ m::Convert(m::Op(&operands[k])),
+ m::Broadcast(m::Op(&scales[k])))))) {
+ VLOG(5) << "Unable to identify FP8 dequantization pattern.";
+ return absl::OkStatus();
+ }
+ }
+
+ // For the dot to be rewritten by gemm_rewriter.cc into an FP8 GEMM, at most
+ // one of the inputs can be F8E5M2.
+ std::array<PrimitiveType, 2> operand_types{
+ operands[0]->shape().element_type(), operands[1]->shape().element_type()};
+ if (!((operand_types[0] == F8E4M3FN && operand_types[1] == F8E4M3FN) ||
+ (operand_types[0] == F8E4M3FN && operand_types[1] == F8E5M2) ||
+ (operand_types[0] == F8E5M2 && operand_types[1] == F8E4M3FN))) {
+ VLOG(5) << "Unsupported types.";
+ return absl::OkStatus();
+ }
+
+ // The dequantized types must be BF16, FP16 or FP32.
+ for (int k = 0; k < 2; ++k) {
+ if (binaries[k]->shape().element_type() != BF16 &&
+ binaries[k]->shape().element_type() != F16 &&
+ binaries[k]->shape().element_type() != F32) {
+ VLOG(5) << "Unsupported types.";
+ return absl::OkStatus();
+ }
+ }
+
+ // The FP8 scaling operands must be scalars.
+ if (!ShapeUtil::IsScalar(scales[0]->shape()) ||
+ !ShapeUtil::IsScalar(scales[1]->shape())) {
+ VLOG(5) << "Scaling factors must be scalars.";
+ return absl::OkStatus();
+ }
+
+ // Identify the dot and collective-permute or dynamic-slice instructions in
+ // the all-gather or reduce-scatter patterns in while's body.
+ HloComputation* while_body = while_instr->while_body();
+ HloComputation* while_condition = while_instr->while_condition();
+ HloInstruction* while_root = while_body->root_instruction();
+ std::array<HloInstruction*, 2> dots, dyn_slices{nullptr, nullptr},
+ coll_perms{nullptr, nullptr};
+ if (Match(
+ while_root,
+ m::Tuple(m::CollectivePermute(
+ &coll_perms[1], m::CollectivePermute(
+ &coll_perms[0], m::Op().Is(gte[0]))),
+ m::Op().Is(gte[1]),
+ m::DynamicUpdateSlice(
+ m::DynamicUpdateSlice().WithOperand(
+ 1, m::Dot(&dots[0], m::Op().Is(gte[0]),
+ m::Op().Is(gte[1]))),
+ m::Dot(&dots[1], m::Op(), m::Op().Is(gte[1])), m::Op(),
+ m::Op(), m::Op()),
+ m::Op(), m::Op()))) {
+ VLOG(5) << "Identified all-gather windowed einsum pattern.";
+ } else if (Match(
+ while_root,
+ m::Tuple(m::Op().Is(gte[0]), m::Op().Is(gte[1]),
+ m::AddAnyOrder(
+ m::Dot(&dots[0], m::DynamicSlice(&dyn_slices[0]),
+ m::Op().Is(gte[1])),
+ m::Op()),
+ m::CollectivePermute(m::AddAnyOrder(
+ m::Dot(&dots[1], m::DynamicSlice(&dyn_slices[1]),
+ m::Op().Is(gte[1])),
+ m::Op())),
+ m::Op()))) {
+ VLOG(5) << "Identified reduce-scatter windowed einsum pattern.";
+ } else {
+ VLOG(5) << "Unable to identify valid windowed einsum pattern.";
+ return absl::OkStatus();
+ }
+
+ // Replace the dequantized dot operands in the parameter tuple used by while
+ // with FP8 operands.
+ for (int k = 0; k < 2; ++k) {
+ TF_RETURN_IF_ERROR(
+ param_tuple->ReplaceOperandWithDifferentShape(k, operands[k]));
+ ShapeUtil::UpdateTupleShape(operands[k]->shape(), k,
+ param_tuple->mutable_shape());
+ param_tuple->AppendOperand(scales[k]);
+ ShapeUtil::AppendShapeToTuple(scales[k]->shape(),
+ param_tuple->mutable_shape());
+ }
+
+ // Update the parameter tuples of while's body and condition computations.
+ for (HloComputation* while_comp : {while_body, while_condition}) {
+ while_comp->ReplaceParameter(
+ 0, HloInstruction::CreateParameter(
+ 0, param_tuple->shape(),
+ while_comp->parameter_instruction(0)->name()));
+ }
+
+ // In the while body, replace the existing get-tuple-element instructions
+ // retrieving BF16/FP16/FP32 dot operands with dequantized get-tuple-element
+ // instructions retrieving FP8 dot operands from the input tuple.
+ HloInstruction* body_param = while_body->parameter_instruction(0);
+ for (int k = 0; k < 2; ++k) {
+ TF_ASSIGN_OR_RETURN(HloInstruction * operand_f8,
+ MakeGetTupleElementHlo(body_param, k));
+
+ if (while_root->operand(k) == gte[k]) {
+ TF_RETURN_IF_ERROR(
+ while_root->ReplaceOperandWithDifferentShape(k, operand_f8));
+ ShapeUtil::UpdateTupleShape(operand_f8->shape(), k,
+ while_root->mutable_shape());
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * operand_scale,
+ MakeGetTupleElementHlo(
+ body_param, body_param->shape().tuple_shapes_size() - 2 + k));
+
+ // Also add the scaling factor to the output tuple of the while body.
+ while_root->AppendOperand(operand_scale);
+ ShapeUtil::AppendShapeToTuple(operand_scale->shape(),
+ while_root->mutable_shape());
+
+ // Dequantize the operands of the dots and dynamic-slices.
+ HloInstruction* operand_f32 =
+ MakeConvertToHlo(operand_f8, gte[k]->shape().element_type());
+ HloInstruction* broadcast_scale =
+ MakeBroadcastHlo(operand_scale, {}, operand_f32->shape());
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * operand_scaled,
+ MakeBinaryHlo(binaries[k]->opcode(), operand_f32, broadcast_scale));
+
+ // Replace the original get-tuple-element instructions accessing the
+ // operands of the dots and dynamic-slices with the dequantized FP8
+ // operands. The order of dequantization and dynamic-slices will be
+ // exchanged in gemm_rewriter.cc.
+ for (int l = 0; l < 2; ++l) {
+ if (dots[l]->operand(k) == gte[k]) {
+ TF_RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled));
+ }
+ if (dyn_slices[l] && dyn_slices[l]->operand(0) == gte[k]) {
+ TF_RETURN_IF_ERROR(
+ dyn_slices[l]->ReplaceOperandWith(0, operand_scaled));
+ }
+ }
+
+ // In the all-gather case, coll_perms[0] has two users, coll_perms[1] and
+ // dots[1], which prevents it from being exchanged with dequantization in
+ // gemm_rewriter.cc. Instead, directly insert the dequantization before
+ // dots[1] here.
+ if (coll_perms[0] && coll_perms[0]->operand(0) == gte[k]) {
+ std::array<HloInstruction*, 2> coll_perms_f8{nullptr, nullptr};
+ // Change the type of both collective-permutes to FP8.
+ coll_perms_f8[0] =
+ while_body->AddInstruction(coll_perms[0]->CloneWithNewOperands(
+ operand_f8->shape(), {operand_f8}));
+ coll_perms_f8[1] =
+ while_body->AddInstruction(coll_perms[1]->CloneWithNewOperands(
+ coll_perms_f8[0]->shape(), {coll_perms_f8[0]}));
+
+ // Insert the dequantization between coll_perms[0] and dots[1].
+ HloInstruction* coll_perm0_f32 =
+ MakeConvertToHlo(coll_perms_f8[0], gte[k]->shape().element_type());
+ TF_ASSIGN_OR_RETURN(HloInstruction * x_scaled,
+ MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32,
+ broadcast_scale));
+ TF_RETURN_IF_ERROR(dots[1]->ReplaceOperandWith(0, x_scaled));
+
+ // Update the output tuple.
+ TF_RETURN_IF_ERROR(
+ while_root->ReplaceOperandWithDifferentShape(0, coll_perms_f8[1]));
+ ShapeUtil::UpdateTupleShape(coll_perms_f8[1]->shape(), 0,
+ while_root->mutable_shape());
+ }
+ }
+
+ // Update the shape of the while call in the parent computation.
+ TF_RETURN_IF_ERROR(
+ while_instr->ReplaceAllUsesWithDifferentShape(while_instr->AddInstruction(
+ while_instr->CloneWithNewShape(while_root->shape()))));
+ TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr));
+
+ if (coll_perms[0]) {
+ TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1]));
+ TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0]));
+ }
+ TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[0]));
+ TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[1]));
+
+ VLOG(5) << "FP8 dequantization moved into while loop.";
+ return absl::OkStatus();
+}
+
+int64_t NumberOfInstructionsInComp(const HloComputation* comp, HloOpcode op) {
+ int64_t total_count = 0;
+ for (const HloInstruction* inst : comp->instructions()) {
+ if (inst->opcode() == op) {
+ ++total_count;
+ }
+ }
+ return total_count;
+}
+
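+// Assigns `stream_id` as the operation queue of `dot` and makes the dot's
+// first user wait on that queue, so the dot can run on a separate compute
+// stream from its consumer.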
+absl::Status UpdateDotAndConsumerConfig(HloInstruction* dot,
+ int64_t stream_id) {
+ auto dot_gpu_config = dot->backend_config<gpu::GpuBackendConfig>();
+ HloInstruction* updater = dot->users()[0];
+ auto updater_gpu_config = updater->backend_config<gpu::GpuBackendConfig>();
+ dot_gpu_config->set_operation_queue_id(stream_id);
+ updater_gpu_config->mutable_wait_on_operation_queues()->Add(stream_id);
+
+ TF_RETURN_IF_ERROR(dot->set_backend_config(dot_gpu_config.value()));
+ TF_RETURN_IF_ERROR(updater->set_backend_config(updater_gpu_config.value()));
+ return absl::OkStatus();
+}
+
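+// Sets the force_earliest_schedule field in `instr`'s GpuBackendConfig to
+// `force_delay`. Used below to pin the first collective-permute (and, for
+// all-gather loops, the dots) early in the schedule.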
+absl::Status SetForceDelayForInstruction(HloInstruction* instr,
+ bool force_delay) {
+ auto gpu_config = instr->backend_config<gpu::GpuBackendConfig>();
+
+ gpu_config->set_force_earliest_schedule(force_delay);
+
+ TF_RETURN_IF_ERROR(instr->set_backend_config(gpu_config.value()));
+ return absl::OkStatus();
+}
+
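+// Processes an unrolled reduce-scatter windowed einsum loop body: dispatches
+// the matched dots onto additional compute streams and forces the first
+// collective-permute to be scheduled at the beginning of the loop.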
+absl::StatusOr<bool> HandleRsWindowedEinsumLoop(HloComputation* comp,
+ int64_t stream_id) {
+ bool changed = false;
+  // If we have an einsum loop with only 1 dot, this means either
+ // the loop is not unrolled or only 1 partition is available.
+ // It's a no-op in either case.
+ if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
+ return changed;
+ }
+ for (auto inst : comp->MakeInstructionPostOrder()) {
+ HloInstruction* matched_dot;
+ std::array<HloInstruction*, 2> gte;
+ // The dot we'd like to parallelize is consuming the second loop input
+ // as RHS.
+ if (Match(inst,
+ m::Dot(&matched_dot,
+ m::DynamicSlice().WithOperand(
+ 0, m::GetTupleElement(>e[0], m::Parameter(), 0)),
+ m::GetTupleElement(>e[1], m::Parameter(), 1)))) {
+ // If present, move the dequantization of FP8 operands of the dot into the
+ // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
+ // Call.
+ TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
+
+ // Dispatch the dot to additional compute stream.
+ TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
+ ++stream_id;
+ changed = true;
+ }
+
+    // We need to enforce that the first collective-permute is always
+    // scheduled at the beginning of the loop.
+ HloInstruction* matched_cp;
+ if (Match(inst, m::CollectivePermute(
+ &matched_cp, m::GetTupleElement(m::Parameter(), 2)))) {
+ TF_RETURN_IF_ERROR(
+ SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
+ changed = true;
+ }
+ }
+ return changed;
+}
+
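+// Processes an unrolled all-gather windowed einsum loop body: like the
+// reduce-scatter case above, but the matched dots are also marked with
+// force_earliest_schedule.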
+absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp,
+ int64_t stream_id) {
+ bool changed = false;
+  // If we have an einsum loop with only 1 dot, this means either
+ // the loop is not unrolled or only 1 partition is available.
+ // It's a no-op in either case.
+ if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
+ return changed;
+ }
+ for (auto inst : comp->MakeInstructionPostOrder()) {
+ HloInstruction* matched_dot;
+ std::array<HloInstruction*, 2> gte;
+ // The dot we'd like to parallelize is consuming the second loop input
+ // as RHS and first loop input as LHS.
+ if (Match(inst, m::Dot(&matched_dot,
+ m::GetTupleElement(>e[0], m::Parameter(), 0),
+ m::GetTupleElement(>e[1], m::Parameter(), 1)))) {
+ // If present, move the dequantization of FP8 operands of the dot into the
+ // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
+ // Call.
+ TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
+
+ // Dispatch the dot to additional compute stream.
+ TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
+ ++stream_id;
+ TF_RETURN_IF_ERROR(
+ SetForceDelayForInstruction(matched_dot, /*force_delay=*/true));
+ changed = true;
+ }
+
+    // We need to enforce that the first collective-permute is always
+    // scheduled at the beginning of the loop.
+ HloInstruction* matched_cp;
+ if (Match(inst, m::CollectivePermute(
+ &matched_cp, m::GetTupleElement(m::Parameter(), 0)))) {
+ TF_RETURN_IF_ERROR(
+ SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
+ changed = true;
+ }
+ }
+ return changed;
+}
+
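+// Returns the number of elements in the while loop's operand tuple. The
+// cached full all-gather activation buffer is appended at the end of that
+// tuple, so this value is used to address it.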
+static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) {
+ const HloInstruction* loop_tuple = while_loop->operand(0);
+ const Shape& tuple_shape = loop_tuple->shape();
+ CHECK(tuple_shape.IsTuple());
+ return tuple_shape.tuple_shapes_size();
+}
+
+absl::Status ProcessWindowedEinsumLoopForActivationCaching(
+ WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop,
+ HloInstruction* ag_with_shared_operand) {
+ HloInstruction* loop = ag_loop.loop;
+  // Transform the while body to cache the all-gathered result in the
+  // output buffer to be consumed by the dot.
+ HloComputation* while_body = loop->while_body();
+ HloInstruction* input_gte;
+ for (HloInstruction* gte : while_body->parameter_instruction(0)->users()) {
+ if (gte->tuple_index() == 0) {
+ input_gte = gte;
+ }
+ }
+ // Get the output operand of the full buffer.
+ HloInstruction* root = while_body->root_instruction();
+ // Change loop body to include the new input and output element.
+ HloInstruction* input_tuple = while_body->parameter_instruction(0);
+ const Shape& input_shape = input_tuple->shape();
+ // The full buffer that we will use to cache the accumulated activation
+ // is the last operand in the output tuple.
+ int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop);
+ std::vector<Shape> new_input_shapes(input_shape.tuple_shapes().begin(),
+ input_shape.tuple_shapes().end());
+ new_input_shapes.push_back(ag_with_shared_operand->shape());
+ // Update body input shape
+ Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes);
+ *input_tuple->mutable_shape() = new_input_shape;
+ HloInstruction* full_buffer_output_gte =
+ while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ag_with_shared_operand->shape(), input_tuple,
+ full_cache_buffer_index));
+
+ // Update condition input shape
+ HloComputation* cond_comp = loop->while_condition();
+ HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0);
+ *cond_input_tuple->mutable_shape() = new_input_shape;
+
+ // Update input to the while instruction in parent computation
+ HloInstruction* original_while_input = loop->mutable_operand(0);
+ HloComputation* parent_comp = loop->parent();
+ std::vector<HloInstruction*> new_operands(
+ original_while_input->operands().begin(),
+ original_while_input->operands().end());
+ new_operands.push_back(
+ parent_comp->AddInstruction(HloInstruction::CreateBroadcast(
+ ag_with_shared_operand->shape(),
+ parent_comp->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(new_input_shapes[0].element_type()))),
+ {})));
+ HloInstruction* new_while_input =
+ parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands));
+ TF_RETURN_IF_ERROR(
+ loop->ReplaceOperandWithDifferentShape(0, new_while_input));
+ TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape(
+ original_while_input, new_while_input));
+ *loop->mutable_shape() = new_input_shape;
+
+ HloInstruction* new_full_buffer_output = nullptr;
+  // Find the DUS in the loop body and re-use the slice indices.
+  // This should just be a constant(0).
+  HloInstruction* dus_boundary_constant;
+  // The slice we need this time is the output of the first
+  // collective-permute.
+  HloInstruction* first_cp_output;
+ for (HloInstruction* gte_user : input_gte->users()) {
+ if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
+ first_cp_output = gte_user;
+ break;
+ }
+ }
+ for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
+ HloInstruction* slice_indices;
+ // If we have a DUS(PARAM,DS) pattern, we need to update the output
+ // buffer with the first slice.
+ if (Match(inst,
+ m::DynamicUpdateSlice(
+ m::GetTupleElement(m::Parameter()), m::Op(),
+ m::Constant(&dus_boundary_constant),
+ m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
+ m::Op()))) {
+ slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
+ dus_boundary_constant->shape(), slice_indices));
+ VLOG(5) << "Created slice op for first slice: "
+ << slice_indices->ToString();
+ full_buffer_output_gte =
+ while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ full_buffer_output_gte->shape(), full_buffer_output_gte,
+ input_gte,
+ {dus_boundary_constant, slice_indices, dus_boundary_constant}));
+ }
+ // If we have a DUS(DUS,DS) pattern, then the einsum loop is
+ // unrolled, we need to update the output buffer again with the
+ // second slice. Since the second slice will have different indices,
+ // we need to re-capture slice_indices.
+ if (Match(inst,
+ m::DynamicUpdateSlice(
+ m::DynamicUpdateSlice(), m::Op(), m::Constant(),
+ m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
+ m::Op()))) {
+ slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
+ dus_boundary_constant->shape(), slice_indices));
+ VLOG(5) << "Created slice op for second slice: "
+ << slice_indices->ToString();
+ new_full_buffer_output =
+ while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ full_buffer_output_gte->shape(), full_buffer_output_gte,
+ first_cp_output,
+ {dus_boundary_constant, slice_indices, dus_boundary_constant}));
+ }
+
+ // If we have a Dot(DS(parameter_index1)), then operands are sharded along
+ // the contracting dim. Slice indices will be the contracting dim's slices.
+ HloInstruction* slice_index;
+ HloInstruction* ds_index_constant;
+ HloInstruction* remainder;
+ HloInstruction* ds_param;
+    // For unrolled loops there will be 2 dynamic-slices; match each one to
+    // get the slice index that is used to write the corresponding received
+    // shard into the cached activation buffer. Since the final buffer is
+    // written twice per iteration, we need to match the correct slice index
+    // for each DS.
+ if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) &&
+ Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) {
+ for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size();
+ ds_op_i++) {
+ if (!Match(
+ ds_param->mutable_operand(ds_op_i),
+ m::Reshape(&slice_index, m::DynamicSlice(m::Constant(),
+ m::Op(&remainder)))) &&
+ !Match(ds_param->mutable_operand(ds_op_i),
+ m::Constant(&ds_index_constant))) {
+ return absl::OkStatus();
+ }
+ }
+      // The first DS has its slice index calculated from the loop iterator:
+      // Remainder(add(gte, partition_id)).
+ if (Match(remainder,
+ m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) {
+ full_buffer_output_gte =
+ while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ full_buffer_output_gte->shape(), full_buffer_output_gte,
+ input_gte,
+ {ds_index_constant, ds_index_constant, slice_index}));
+ }
+      // The second DS has its slice index calculated from the loop iterator
+      // plus 1, hence Remainder(add(add(gte, 1), partition_id)).
+ if (Match(remainder,
+ m::Remainder(
+ m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()),
+ m::Op()))) {
+ new_full_buffer_output =
+ while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ full_buffer_output_gte->shape(), full_buffer_output_gte,
+ first_cp_output,
+ {ds_index_constant, ds_index_constant, slice_index}));
+ }
+ }
+ }
+ std::vector<HloInstruction*> original_operands(root->operands().begin(),
+ root->operands().end());
+ original_operands.push_back(new_full_buffer_output);
+ HloInstruction* new_output_tuple = while_body->AddInstruction(
+ HloInstruction::CreateTuple(original_operands));
+ TF_RETURN_IF_ERROR(
+ while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));
+ return absl::OkStatus();
+}
+
+bool HasReplicaGroups(const HloInstruction* inst) {
+ return inst->replica_groups().size() > 0;
+}
+
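+// Returns true for single-user transpose/reshape/copy instructions, i.e. ops
+// that may sit between a dot and an all-to-all without blocking the rewrites
+// below.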
+bool ShouldAddToChain(const HloInstruction* inst) {
+ switch (inst->opcode()) {
+ case HloOpcode::kTranspose:
+ case HloOpcode::kReshape:
+ case HloOpcode::kCopy:
+ return inst->user_count() == 1;
+ default:
+ return false;
+ }
+}
+
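+// Result of matching a dot that feeds an all-to-all, possibly through a chain
+// of reshapes/transposes/copies. `a2a_replacement` holds the newly created
+// all-to-all when the original one had to be re-created directly after the
+// dot.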
+struct MatchedGemmA2aResult {
+ HloInstruction* producer_gemm;
+ HloInstruction* lhs;
+ HloInstruction* rhs;
+ HloInstruction* a2a_replacement = nullptr;
+ bool matched = false;
+};
+
+class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
+ public:
+ explicit WindowedEinsumVisitor(
+ std::vector<WindowedEinsumHandler::WindowedEinsumAgLoops>& all_ag_loops)
+ : all_ag_loops_(all_ag_loops) {}
+ absl::StatusOr<bool> MatchA2aGemmWithIntermediateReshapes(
+ HloInstruction* dot, HloInstruction** lhs, HloInstruction** rhs) {
+ if (Match(dot, m::Dot(m::AllToAll(lhs).WithOneUse().WithPredicate(
+ HasReplicaGroups),
+ m::Op(rhs))) &&
+ !DynCast<HloAllToAllInstruction>((*lhs))->constrain_layout() &&
+ !(*lhs)->shape().IsTuple()) {
+ return true;
+ }
+ std::vector<HloInstruction*> allowed_intermediate_ops(
+ {dot->mutable_operand(0)});
+
+ HloAllToAllInstruction* matched_a2a = nullptr;
+    // We keep pushing until a condition is unmet or we have found the a2a.
+ while (true) {
+ HloInstruction* curr = allowed_intermediate_ops.back();
+ if (ShouldAddToChain(curr)) {
+ allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
+ std::begin(curr->operands()),
+ std::end(curr->operands()));
+ } else if (curr->opcode() == HloOpcode::kAllToAll &&
+ curr->user_count() == 1) {
+ matched_a2a = DynCast<HloAllToAllInstruction>(curr);
+ allowed_intermediate_ops.pop_back();
+ break;
+ } else {
+ return false;
+ }
+ }
+ CHECK(matched_a2a != nullptr);
+ if (matched_a2a->constrain_layout() || matched_a2a->shape().IsTuple() ||
+ !HasReplicaGroups(matched_a2a) || !matched_a2a->split_dimension()) {
+ return false;
+ }
+    // We need to create a new a2a that is a direct producer of the dot and
+    // replace the original a2a with it; the intermediate reshapes then
+    // operate on the original a2a's input. We first need to determine the new
+    // split dimension after all the reshape ops.
+ int64_t split_dimension = *matched_a2a->split_dimension();
+ for (int64_t i = allowed_intermediate_ops.size() - 1; i >= 0; i--) {
+ HloInstruction* current_op = allowed_intermediate_ops[i];
+ if (current_op->opcode() == HloOpcode::kReshape) {
+ std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
+ ShapeUtil::DimensionsUnmodifiedByReshape(
+ current_op->operand(0)->shape(), current_op->shape());
+ auto it = absl::c_find_if(
+ unmodified_dims,
+ [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
+ return dim_pair.first == split_dimension;
+ });
+        // The split dimension of the a2a has been modified; we cannot deduce
+        // the new split dim easily, so skip decomposition.
+        if (it == unmodified_dims.end()) {
+          VLOG(5) << "Split dimension of: " << matched_a2a->ToShortString()
+                  << " has been modified by reshapes. Skipping it for "
+                     "decomposition.";
+ return false;
+ }
+ // Assign the new split dim.
+ split_dimension = it->second;
+ } else if (current_op->opcode() == HloOpcode::kTranspose) {
+ const auto& transpose_dims = current_op->dimensions();
+ for (int64_t j = 0; j < transpose_dims.size(); j++) {
+ if ((int64_t)transpose_dims[j] == split_dimension) {
+ split_dimension = j;
+ break;
+ }
+ }
+ }
+ }
+ TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
+ 0, matched_a2a->mutable_operand(0)));
+ HloInstruction* new_a2a =
+ matched_a2a->parent()->AddInstruction(HloInstruction::CreateAllToAll(
+ allowed_intermediate_ops.front()->shape(),
+ {allowed_intermediate_ops.front()}, matched_a2a->replica_groups(),
+ false, hlo_query::NextChannelId(*matched_a2a->GetModule()),
+ split_dimension));
+
+ TF_RETURN_IF_ERROR(dot->ReplaceOperandWith(0, new_a2a));
+ TF_RETURN_IF_ERROR(
+ matched_a2a->parent()->RemoveInstructionAndUnusedOperands(matched_a2a));
+ MarkAsChanged();
+ *lhs = new_a2a;
+ *rhs = dot->mutable_operand(1);
+ return true;
+ }
+
+ absl::Status HandleDot(HloInstruction* dot) override {
+ CHECK_EQ(dot->opcode(), HloOpcode::kDot);
+ HloComputation* comp = dot->parent();
+    // Rewrites an all-gather+dot pattern that shares the same operand
+    // with a windowed einsum loop so that the dot consumes the output of
+    // the loop, and removes the all-gather.
+    // Now that we have processed all loops, we can check if there are any
+    // all-gather+dot patterns that we can optimize. We'd want to transform:
+ // input
+ // / |
+ // / |
+ // AG windowed loop
+ // /
+ // /
+ // dot
+ // to:
+ // input
+ // |
+ // |
+ // windowed loop
+ // |
+ // |
+ // dot
+ // The windowed einsum loop will also be rewritten to output the full input
+ // to be consumed by the dot. This is advantageous since the chained dot can
+ // fully utilize all the resources on the GPU while comm is hidden by the
+ // first collective matmul loop.
+    // Iterate by reference so that marking a loop as consumed below persists
+    // for subsequent dots.
+    for (WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop :
+         all_ag_loops_) {
+ HloInstruction* loop = ag_loop.loop;
+ HloInstruction* ag_operand = nullptr;
+
+ if (Match(dot, m::Dot(m::AllGather(&ag_operand), m::Op())) ||
+ Match(dot, m::Dot(m::Op(), m::AllGather(&ag_operand)))) {
+ HloInstruction* windowed_lhs =
+ loop->mutable_operand(0)->mutable_operand(0);
+ HloInstruction* ag_with_shared_operand = nullptr;
+ if (ag_operand && ag_operand->mutable_operand(0) == windowed_lhs) {
+ ag_with_shared_operand = ag_operand;
+ }
+
+ if (!ag_with_shared_operand) {
+ continue;
+ }
+
+ VLOG(5) << "Found all-gather that shares the same operand with a "
+ "windowed einsum loop : "
+ << loop->ToString();
+
+ if (!ag_loop.consumed) {
+ TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching(
+ ag_loop, ag_with_shared_operand));
+ ag_loop.consumed = true;
+ }
+ int64_t cache_output_index = dot->operand_index(ag_with_shared_operand);
+ HloComputation* comp = dot->parent();
+ HloInstruction* new_gte =
+ comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+ loop, GetAgActivationCacheIndex(loop) - 1));
+ TF_RETURN_IF_ERROR(
+ dot->ReplaceOperandWith(cache_output_index, new_gte));
+ TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand));
+ }
+ }
+    // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
+    // to minimize communication overhead. To do this, the original input is
+    // sliced into replica_group-sized pieces, and an all-to-all+gemm is
+    // performed on each slice.
+ HloInstruction* lhs;
+ HloInstruction* rhs;
+ std::vector<xla::ReplicaGroup> replica_groups;
+ TF_ASSIGN_OR_RETURN(bool matched,
+ MatchA2aGemmWithIntermediateReshapes(dot, &lhs, &rhs));
+ if (matched) {
+ replica_groups = lhs->replica_groups();
+ // We split the a2a+gemm along the contracting dimension into multiple
+ // a2a+gemms and perform partial dots, partial results are added to the
+ // final output buffer.
+ int64_t group_size = replica_groups[0].replica_ids_size();
+ if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
+ return group.replica_ids_size() != group_size;
+ }) != replica_groups.end()) {
+ VLOG(5) << "All-to-all split groups don't have the same number of "
+ "replicas.";
+ return absl::OkStatus();
+ }
+
+      // Get the dimension to slice for lhs and rhs; we slice on the
+      // contracting dimensions to calculate partial results.
+ const DotDimensionNumbers& original_dot_dnums =
+ dot->dot_dimension_numbers();
+ const PrecisionConfig& original_precision = dot->precision_config();
+ const auto& lhs_contracting_dims =
+ dot->dot_dimension_numbers().lhs_contracting_dimensions();
+ const auto& rhs_contracting_dims =
+ dot->dot_dimension_numbers().rhs_contracting_dimensions();
+
+ if (lhs_contracting_dims.size() != 1 ||
+ rhs_contracting_dims.size() != 1) {
+ VLOG(5) << "Contracting dimensions have multiple elements, all-to-all "
+ "sharding will be skipped.";
+ return absl::OkStatus();
+ }
+ int64_t lhs_contracting_dim = lhs_contracting_dims[0];
+ int64_t rhs_contracting_dim = rhs_contracting_dims[0];
+ HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(lhs);
+ int64_t contracting_dim_value =
+ rhs->shape().dimensions()[rhs_contracting_dim];
+
+      // Each split is sliced out of the input buffer; we need to determine
+      // the slice sizes and increments.
+ std::vector<int64_t> lhs_slice_sizes(a2a->shape().rank(), 0);
+ std::vector<int64_t> lhs_slice_increments(a2a->shape().rank(), 1);
+ std::vector<int64_t> lhs_slice_max_range(
+ a2a->shape().dimensions().begin(), a2a->shape().dimensions().end());
+
+ std::vector<int64_t> rhs_slice_sizes(rhs->shape().rank(), 0);
+ std::vector<int64_t> rhs_slice_increments(rhs->shape().rank(), 1);
+ std::vector<int64_t> rhs_slice_max_range(
+ rhs->shape().dimensions().begin(), rhs->shape().dimensions().end());
+
+ // Create a zero-valued buffer to hold output.
+ HloInstruction* output_buffer =
+ comp->AddInstruction(HloInstruction::CreateBroadcast(
+ dot->shape(),
+ comp->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(dot->shape().element_type()))),
+ {}));
+ HloInstruction* a2a_operand = a2a->mutable_operand(0);
+ if (contracting_dim_value % group_size) {
+ VLOG(5) << absl::StrFormat(
+ "Contracting dimension %d needs to be divisible by group_size %d",
+ contracting_dim_value, group_size);
+ return absl::OkStatus();
+ }
+ int64_t size_per_split = contracting_dim_value / group_size;
+
+      // Each split is sliced out of the input buffer; we need to determine
+      // the slice sizes and increments.
+ lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
+ rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
+
+ Shape lhs_slice_shape = a2a->shape();
+ Shape rhs_slice_shape = rhs->shape();
+
+ lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
+ rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
+
+ HloInstruction* lhs_slice;
+ HloInstruction* rhs_slice;
+
+ HloInstruction* partial_result = output_buffer;
+
+ Shape partial_all_to_all_shape = lhs_slice_shape;
+
+ TF_ASSIGN_OR_RETURN(
+ Shape partial_dot_shape,
+ ShapeInference::InferDotOpShape(
+ partial_all_to_all_shape, rhs_slice_shape, original_dot_dnums,
+ /*preferred_element_type=*/std::nullopt));
+ int64_t stream_id = hlo_query::NextChannelId(*a2a->GetModule());
+ for (int64_t i = 0; i < group_size; ++i) {
+ lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+ lhs_slice_shape, a2a_operand, lhs_slice_sizes, lhs_slice_max_range,
+ lhs_slice_increments));
+ a2a->SetupDerivedInstruction(lhs_slice);
+ lhs_slice_sizes[lhs_contracting_dim] =
+ lhs_slice_max_range[lhs_contracting_dim];
+ lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
+
+ rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+ rhs_slice_shape, rhs, rhs_slice_sizes, rhs_slice_max_range,
+ rhs_slice_increments));
+ a2a->SetupDerivedInstruction(rhs_slice);
+ rhs_slice_sizes[rhs_contracting_dim] =
+ rhs_slice_max_range[rhs_contracting_dim];
+ rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
+
+ HloInstruction* partial_all_to_all =
+ comp->AddInstruction(HloInstruction::CreateAllToAll(
+ partial_all_to_all_shape, {lhs_slice}, a2a->device_list(),
+ false, hlo_query::NextChannelId(*a2a->GetModule()),
+ a2a->split_dimension()));
+ a2a->SetupDerivedInstruction(partial_all_to_all);
+
+ HloInstruction* partial_dot =
+ comp->AddInstruction(HloInstruction::CreateDot(
+ partial_dot_shape, partial_all_to_all, rhs_slice,
+ original_dot_dnums, original_precision));
+ partial_result = comp->AddInstruction(
+ HloInstruction::CreateBinary(partial_dot->shape(), HloOpcode::kAdd,
+ partial_dot, partial_result));
+ a2a->SetupDerivedInstruction(partial_result);
+ TF_RETURN_IF_ERROR(
+ UpdateDotAndConsumerConfig(partial_dot, stream_id++));
+ }
+ TF_RETURN_IF_ERROR(ReplaceInstruction(dot, partial_result));
+ }
+ return absl::OkStatus();
+ }
+
+ absl::StatusOr<MatchedGemmA2aResult> MatchGemmA2aWithIntermediateReshapes(
+ HloInstruction* inst) {
+ MatchedGemmA2aResult result;
+ HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(inst);
+ if (!HasReplicaGroups(a2a) || a2a->constrain_layout() ||
+ a2a->shape().IsTuple()) {
+ return result;
+ }
+ if (Match(a2a, m::AllToAll(m::Dot(&result.producer_gemm, m::Op(&result.lhs),
+ m::Op(&result.rhs))
+ .WithOneUse()))) {
+ result.matched = true;
+ return result;
+ }
+ std::vector<HloInstruction*> allowed_intermediate_ops(
+ {a2a->mutable_operand(0)});
+
+ HloInstruction* matched_dot = nullptr;
+    // We keep pushing until a condition is unmet or we have found the
+    // producer dot.
+ while (true) {
+ HloInstruction* curr = allowed_intermediate_ops.back();
+ if (ShouldAddToChain(curr)) {
+ allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
+ std::begin(curr->operands()),
+ std::end(curr->operands()));
+ } else if (curr->opcode() == HloOpcode::kDot && curr->user_count() == 1) {
+ matched_dot = curr;
+ allowed_intermediate_ops.pop_back();
+ break;
+ } else {
+ return result;
+ }
+ }
+ CHECK(matched_dot != nullptr);
+    // We need to create a new a2a that is a direct consumer of the dot and
+    // replace the original a2a with it; the intermediate reshapes then
+    // operate on the new a2a's output. We first need to determine the new
+    // split dimension after all the reshape ops.
+ int64_t split_dimension = *a2a->split_dimension();
+ for (int64_t i = 0; i < allowed_intermediate_ops.size(); i++) {
+ HloInstruction* current_op = allowed_intermediate_ops[i];
+ if (current_op->opcode() == HloOpcode::kReshape) {
+ std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
+ ShapeUtil::DimensionsUnmodifiedByReshape(
+ current_op->operand(0)->shape(), current_op->shape());
+ auto it = absl::c_find_if(
+ unmodified_dims,
+ [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
+ return dim_pair.second == split_dimension;
+ });
+        // The split dimension of the a2a has been modified; we cannot deduce
+        // the new split dim easily, so skip decomposition.
+        if (it == unmodified_dims.end()) {
+          VLOG(5) << "Split dimension of: " << a2a->ToShortString()
+                  << " has been modified by reshapes. Skipping it for "
+                     "decomposition.";
+ return result;
+ }
+ // Assign the new split dim.
+ split_dimension = it->first;
+ } else if (current_op->opcode() == HloOpcode::kTranspose) {
+ const auto& transpose_dims = current_op->dimensions();
+ split_dimension = transpose_dims[split_dimension];
+ }
+ }
+ result.a2a_replacement =
+ matched_dot->parent()->AddInstruction(HloInstruction::CreateAllToAll(
+ matched_dot->shape(), {matched_dot}, a2a->replica_groups(), false,
+ hlo_query::NextChannelId(*matched_dot->GetModule()),
+ split_dimension));
+ TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
+ 0, result.a2a_replacement));
+ inst->SetupDerivedInstruction(result.a2a_replacement);
+
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(inst, allowed_intermediate_ops.front()));
+ result.lhs = matched_dot->mutable_operand(0);
+ result.rhs = matched_dot->mutable_operand(1);
+ result.producer_gemm = matched_dot;
+ result.matched = true;
+ return result;
+ }
+
+  // Rewrites a gemm+all-to-all into multiple independent partial gemm+a2as
+  // to minimize communication overhead. To do this, the original input is
+  // sliced into replica_group-sized pieces, and a gemm+all-to-all is
+  // performed on each slice.
+ absl::Status HandleAllToAll(HloInstruction* inst) override {
+ CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll);
+ HloComputation* comp = inst->parent();
+ // Rewrites a gemm+alltoall into multiple independent partial gemm+a2as
+ // to minimize communication overhead.
+ std::vector<xla::ReplicaGroup> replica_groups;
+ TF_ASSIGN_OR_RETURN(MatchedGemmA2aResult matched_result,
+ MatchGemmA2aWithIntermediateReshapes(inst));
+ if (matched_result.matched) {
+ HloInstruction* a2a = inst;
+ if (matched_result.a2a_replacement) {
+ a2a = matched_result.a2a_replacement;
+ }
+ replica_groups = a2a->replica_groups();
+      // Similar to a2a+gemm, we split along the contracting dimensions
+      // and aggregate the result at each step.
+ int64_t group_size = replica_groups[0].replica_ids_size();
+
+ if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
+ return group.replica_ids_size() != group_size;
+ }) != replica_groups.end()) {
+ VLOG(5) << "All-to-all split groups don't have the same number of "
+ "replicas.";
+ return absl::OkStatus();
+ }
+
+      // Get the dimension to slice for lhs and rhs; we slice on the
+      // contracting dimensions to calculate partial results.
+ const DotDimensionNumbers& original_dot_dnums =
+ matched_result.producer_gemm->dot_dimension_numbers();
+ const PrecisionConfig& original_precision =
+ matched_result.producer_gemm->precision_config();
+ const auto& lhs_contracting_dims =
+ matched_result.producer_gemm->dot_dimension_numbers()
+ .lhs_contracting_dimensions();
+ const auto& rhs_contracting_dims =
+ matched_result.producer_gemm->dot_dimension_numbers()
+ .rhs_contracting_dimensions();
+
+ if (lhs_contracting_dims.size() != 1 ||
+ rhs_contracting_dims.size() != 1) {
+ VLOG(5) << "Contracting dimensions have multiple elements, all-to-all "
+ "sharding will be skipped.";
+ return absl::OkStatus();
+ }
+ int64_t lhs_contracting_dim = lhs_contracting_dims[0];
+ int64_t rhs_contracting_dim = rhs_contracting_dims[0];
+ HloAllToAllInstruction* all_to_all = DynCast<HloAllToAllInstruction>(a2a);
+ int64_t contracting_dim_value =
+ matched_result.rhs->shape().dimensions()[rhs_contracting_dim];
+      // Each split is sliced out of the input buffer; we need to determine
+      // the slice sizes and increments.
+ std::vector<int64_t> lhs_slice_sizes(matched_result.lhs->shape().rank(),
+ 0);
+ std::vector<int64_t> lhs_slice_increments(
+ matched_result.lhs->shape().rank(), 1);
+ std::vector<int64_t> lhs_slice_max_range(
+ matched_result.lhs->shape().dimensions().begin(),
+ matched_result.lhs->shape().dimensions().end());
+
+ std::vector<int64_t> rhs_slice_sizes(matched_result.rhs->shape().rank(),
+ 0);
+ std::vector<int64_t> rhs_slice_increments(
+ matched_result.rhs->shape().rank(), 1);
+ std::vector<int64_t> rhs_slice_max_range(
+ matched_result.rhs->shape().dimensions().begin(),
+ matched_result.rhs->shape().dimensions().end());
+
+ // Create a zero-valued buffer to hold output.
+ HloInstruction* output_buffer =
+ comp->AddInstruction(HloInstruction::CreateBroadcast(
+ all_to_all->shape(),
+ comp->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(all_to_all->shape().element_type()))),
+ {}));
+ if (contracting_dim_value % group_size) {
+ VLOG(5) << absl::StrFormat(
+ "Contracting dimension %d needs to be divisible by group_size %d",
+ contracting_dim_value, group_size);
+ return absl::OkStatus();
+ }
+
+ int64_t size_per_split = contracting_dim_value / group_size;
+      // Each split is sliced out of the input buffer; we need to determine
+      // the slice sizes and increments.
+ lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
+ rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
+
+ Shape lhs_slice_shape = matched_result.lhs->shape();
+ Shape rhs_slice_shape = matched_result.rhs->shape();
+
+ lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
+ rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
+
+ HloInstruction* lhs_slice;
+ HloInstruction* rhs_slice;
+
+ HloInstruction* partial_result = output_buffer;
+ Shape partial_all_to_all_shape = all_to_all->shape();
+
+ TF_ASSIGN_OR_RETURN(
+ Shape partial_dot_shape,
+ ShapeInference::InferDotOpShape(
+ lhs_slice_shape, rhs_slice_shape, original_dot_dnums,
+ /*preferred_element_type=*/std::nullopt));
+ int64_t stream_id = hlo_query::NextChannelId(*all_to_all->GetModule());
+ for (int64_t i = 0; i < group_size; ++i) {
+ lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+ lhs_slice_shape, matched_result.lhs, lhs_slice_sizes,
+ lhs_slice_max_range, lhs_slice_increments));
+ all_to_all->SetupDerivedInstruction(lhs_slice);
+ lhs_slice_sizes[lhs_contracting_dim] =
+ lhs_slice_max_range[lhs_contracting_dim];
+ lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
+
+ rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+ rhs_slice_shape, matched_result.rhs, rhs_slice_sizes,
+ rhs_slice_max_range, rhs_slice_increments));
+
+ all_to_all->SetupDerivedInstruction(rhs_slice);
+ rhs_slice_sizes[rhs_contracting_dim] =
+ rhs_slice_max_range[rhs_contracting_dim];
+ rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
+
+ HloInstruction* partial_dot = comp->AddInstruction(
+ HloInstruction::CreateDot(partial_dot_shape, lhs_slice, rhs_slice,
+ original_dot_dnums, original_precision));
+
+ HloInstruction* partial_all_to_all =
+ comp->AddInstruction(HloInstruction::CreateAllToAll(
+ partial_all_to_all_shape, {partial_dot},
+ all_to_all->device_list(), false,
+ hlo_query::NextChannelId(*all_to_all->GetModule()),
+ all_to_all->split_dimension()));
+ all_to_all->SetupDerivedInstruction(partial_all_to_all);
+ partial_result = comp->AddInstruction(HloInstruction::CreateBinary(
+ partial_all_to_all_shape, HloOpcode::kAdd, partial_all_to_all,
+ partial_result));
+ all_to_all->SetupDerivedInstruction(partial_result);
+ TF_RETURN_IF_ERROR(
+ UpdateDotAndConsumerConfig(partial_dot, stream_id++));
+ }
+ TF_RETURN_IF_ERROR(ReplaceInstruction(all_to_all, partial_result));
+ }
+
+ return absl::OkStatus();
+ }
+
+ private:
+ std::vector<WindowedEinsumHandler::WindowedEinsumAgLoops>& all_ag_loops_;
+};
+
+} // namespace
+
+absl::StatusOr<bool> WindowedEinsumHandler::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ XLA_VLOG_LINES(
+ 5, "WindowedEinsumHandler::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ int64_t stream_id = hlo_query::NextChannelId(*module);
+
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ if (comp->name().find(kWindowedEinsumRsLoopName) == 0) {
+ VLOG(5) << "Processing computation: " << comp->name();
+ TF_ASSIGN_OR_RETURN(bool comp_result,
+ HandleRsWindowedEinsumLoop(comp, stream_id));
+      changed |= comp_result;
+ } else if (comp->name().find(kWindowedEinsumAgLoopName) == 0) {
+ VLOG(5) << "Processing computation: " << comp->name();
+ TF_ASSIGN_OR_RETURN(bool comp_result,
+ HandleAgWindowedEinsumLoop(comp, stream_id));
+ all_ag_loops_.push_back(
+ WindowedEinsumAgLoops(comp->WhileCallInstruction()));
+      changed |= comp_result;
+ }
+ }
+ for (HloComputation* comp :
+ module->MakeNonfusionComputations(execution_threads)) {
+ WindowedEinsumVisitor visitor(all_ag_loops_);
+ TF_RETURN_IF_ERROR(comp->Accept(&visitor));
+ changed |= visitor.changed();
+ }
+
+ XLA_VLOG_LINES(5,
+ "WindowedEinsumHandler::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h
new file mode 100644
index 0000000..bcc7680e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h
@@ -0,0 +1,64 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_
+
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass targets the windowed einsum optimization in the SPMD pipeline,
+// which rewrites all-gather+gemm or gemm+reduce-scatter into sharded loops
+// to achieve overlap between sharded gemms and communication. On GPU, this
+// pass further optimizes those loops by annotating independent gemms with
+// stream ids in the backend config. By running them on different streams, we
+// can practically achieve overlap between the gemms too.
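+//
+// Windowed einsum loops produced by the SPMD pipeline are recognized by the
+// name prefix of their while-body computations (see kWindowedEinsumRsLoopName
+// and kWindowedEinsumAgLoopName below).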
+class WindowedEinsumHandler : public HloModulePass {
+ public:
+ absl::string_view name() const override { return "windowed-einsum-handler"; }
+
+ struct WindowedEinsumAgLoops {
+ explicit WindowedEinsumAgLoops(HloInstruction* loop) : loop(loop) {}
+ HloInstruction* loop;
+ bool consumed = false;
+ };
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ constexpr static const char* kWindowedEinsumRsLoopName =
+ "windowed_dot_general_body_rs";
+ constexpr static const char* kWindowedEinsumAgLoopName =
+ "windowed_dot_general_body_ag";
+
+ private:
+ std::vector<WindowedEinsumAgLoops> all_ag_loops_;
+};
+
+} // namespace xla::gpu
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc
new file mode 100644
index 0000000..a8ea1b1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc
@@ -0,0 +1,918 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
+
+#include <memory>
+#include <string>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using WindowedEinsumHanlderTest = HloTestBase;
+
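+// Returns the instruction named `name` in `comp`, or nullptr if there is
+// none.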
+HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) {
+ for (auto inst : comp->instructions()) {
+ if (inst->name() == name) {
+ return inst;
+ }
+ }
+ return nullptr;
+}
+
+TEST_F(WindowedEinsumHanlderTest, AgLoopsHaveStreamIds) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,512,24576]{2,1,0}, bf16[24576,24576]{1,0})->bf16[2048,24576]{1,0}}, num_partitions=4
+
+windowed_dot_general_body_ag.1 {
+ param = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
+ get-tuple-element = bf16[512,24576]{1,0} get-tuple-element(param), index=0
+ collective-permute = bf16[512,24576]{1,0} collective-permute(get-tuple-element), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ get-tuple-element.1 = bf16[24576,24576]{1,0} get-tuple-element(param), index=1
+ get-tuple-element.2 = bf16[2048,24576]{1,0} get-tuple-element(param), index=2
+ dot.2 = bf16[512,24576]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ constant.1 = s32[4]{0} constant({0, 512, 1024, 1536})
+ get-tuple-element.4 = u32[] get-tuple-element(param), index=4
+ partition-id = u32[] partition-id()
+ add = u32[] add(get-tuple-element.4, partition-id)
+ constant = u32[] constant(4)
+ remainder = u32[] remainder(add, constant)
+ dynamic-slice = s32[1]{0} dynamic-slice(constant.1, remainder), dynamic_slice_sizes={1}
+ reshape.4 = s32[] reshape(dynamic-slice)
+ constant.2 = s32[] constant(0)
+ dynamic-update-slice = bf16[2048,24576]{1,0} dynamic-update-slice(get-tuple-element.2, dot.2, reshape.4, constant.2), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ dot.3 = bf16[512,24576]{1,0} dot(collective-permute, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ constant.3 = u32[] constant(1)
+ add.1 = u32[] add(get-tuple-element.4, constant.3)
+ add.2 = u32[] add(add.1, partition-id)
+ remainder.1 = u32[] remainder(add.2, constant)
+ dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.1, remainder.1), dynamic_slice_sizes={1}
+ reshape.5 = s32[] reshape(dynamic-slice.1)
+ dynamic-update-slice.1 = bf16[2048,24576]{1,0} dynamic-update-slice(dynamic-update-slice, dot.3, reshape.5, constant.2)
+ get-tuple-element.3 = bf16[2048,24576]{1,0} get-tuple-element(param), index=3
+ add.3 = u32[] add(add.1, constant.3)
+ ROOT tuple = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(collective-permute, get-tuple-element.1, dynamic-update-slice.1, get-tuple-element.3, add.3)
+} // windowed_dot_general_body_ag.1
+
+windowed_dot_general_cond_ag {
+ param.1 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
+ get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
+ constant.8 = u32[] constant(4)
+ ROOT compare = pred[] compare(get-tuple-element.5, constant.8), direction=LT
+}
+
+ENTRY test_main {
+ param.4 = bf16[1,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+ reshape.8 = bf16[512,24576]{1,0} reshape(param.4)
+ param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+ constant.18 = bf16[] constant(0)
+ broadcast = bf16[2048,24576]{1,0} broadcast(constant.18), dimensions={}
+ constant.20 = u32[] constant(0)
+ tuple.2 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(reshape.8, param.5, broadcast, broadcast, constant.20)
+ while = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag.1
+ ROOT get-tuple-element.13 = bf16[2048,24576]{1,0} get-tuple-element(while), index=2
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+ EXPECT_TRUE(changed);
+
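+  // The loop's first dot (dot.2) should be placed on a non-default operation
+  // queue and forced to schedule as early as possible; the collective-permute
+  // feeding the next step should also be scheduled as early as possible.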
+ HloInstruction* ag_loop =
+ module->entry_computation()->root_instruction()->mutable_operand(0);
+ HloComputation* ag_loop_body = ag_loop->while_body();
+ HloInstruction* inst = FindInstructionByName(ag_loop_body, "dot.2");
+ EXPECT_GT(inst->backend_config<GpuBackendConfig>()->operation_queue_id(), 0);
+ EXPECT_TRUE(
+ inst->backend_config<GpuBackendConfig>()->force_earliest_schedule());
+
+ HloInstruction* cp1 =
+ FindInstructionByName(ag_loop_body, "collective-permute");
+ EXPECT_TRUE(
+ cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
+}
+
+TEST_F(WindowedEinsumHandlerTest, RsLoopsHaveStreamIds) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[2048,24576]{1,0})->bf16[512,24576]{1,0}}, num_partitions=4
+
+windowed_dot_general_body_rs_clone.1 {
+ param.2 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
+ get-tuple-element.6 = bf16[2048,24576]{1,0} get-tuple-element(param.2), index=0
+ get-tuple-element.7 = bf16[24576,24576]{1,0} get-tuple-element(param.2), index=1
+ get-tuple-element.9 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=2
+ collective-permute.1 = bf16[512,24576]{1,0} collective-permute(get-tuple-element.9), channel_id=4, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ constant.10 = s32[4]{0} constant({0, 512, 1024, 1536})
+ get-tuple-element.11 = u32[] get-tuple-element(param.2), index=4
+ constant.12 = u32[] constant(2)
+ add.8 = u32[] add(get-tuple-element.11, constant.12)
+ constant.13 = u32[] constant(1)
+ add.9 = u32[] add(add.8, constant.13)
+ partition-id.3 = u32[] partition-id()
+ add.10 = u32[] add(add.9, partition-id.3)
+ constant.9 = u32[] constant(4)
+ remainder.3 = u32[] remainder(add.10, constant.9)
+ dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.10, remainder.3), dynamic_slice_sizes={1}
+ reshape.7 = s32[] reshape(dynamic-slice.4)
+ constant.11 = s32[] constant(0)
+ dynamic-slice.5 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.7, constant.11), dynamic_slice_sizes={512,24576}
+ dot.7 = bf16[512,24576]{1,0} dot(dynamic-slice.5, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ add.11 = bf16[512,24576]{1,0} add(collective-permute.1, dot.7), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ get-tuple-element.10 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=3
+ add.6 = u32[] add(get-tuple-element.11, partition-id.3)
+ remainder.2 = u32[] remainder(add.6, constant.9)
+ dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.10, remainder.2), dynamic_slice_sizes={1}
+ reshape.6 = s32[] reshape(dynamic-slice.2)
+ dynamic-slice.3 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.6, constant.11), dynamic_slice_sizes={512,24576}
+ dot.5 = bf16[512,24576]{1,0} dot(dynamic-slice.3, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ add.7 = bf16[512,24576]{1,0} add(get-tuple-element.10, dot.5), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+ collective-permute.2 = bf16[512,24576]{1,0} collective-permute(add.7), channel_id=5, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
+ ROOT tuple.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(get-tuple-element.6, get-tuple-element.7, add.11, collective-permute.2, add.8)
+}
+
+windowed_dot_general_cond_rs {
+ param.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
+ get-tuple-element.12 = u32[] get-tuple-element(param.3), index=4
+ constant.17 = u32[] constant(4)
+ ROOT compare.1 = pred[] compare(get-tuple-element.12, constant.17), direction=LT
+}
+
+ENTRY main.9_spmd {
+ param.6 = bf16[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
+ param.7 = bf16[512,24576]{1,0} parameter(1)
+ param.8 = bf16[2048,24576]{1,0} parameter(2)
+ constant.20 = u32[] constant(0)
+ tuple.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(param.8, param.6, param.7, param.7, constant.20)
+ while.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs_clone.1
+ ROOT get-tuple-element.14 = bf16[512,24576]{1,0} get-tuple-element(while.1), index=2
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+ EXPECT_TRUE(changed);
+
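+  // The loop's first dot (dot.7) should be placed on a non-default operation
+  // queue, and collective-permute.1 should be scheduled as early as possible.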
+ HloInstruction* rs_loop =
+ module->entry_computation()->root_instruction()->mutable_operand(0);
+ HloComputation* rs_loop_body = rs_loop->while_body();
+ HloInstruction* inst = FindInstructionByName(rs_loop_body, "dot.7");
+  EXPECT_GT(inst->backend_config<GpuBackendConfig>()->operation_queue_id(), 0);
+
+ HloInstruction* cp1 =
+ FindInstructionByName(rs_loop_body, "collective-permute.1");
+ EXPECT_TRUE(
+ cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
+}
+
+TEST_F(WindowedEinsumHandlerTest, AgLoopsMultipleConsumersAreChained) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[24576,24576]{1,0})->bf16[2,2048,24576]{2,1,0}}, num_partitions=4
+
+windowed_dot_general_body_ag {
+ param.1 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element.1 = bf16[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
+ collective-permute = bf16[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+ collective-permute.1 = bf16[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=3, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+ get-tuple-element.2 = bf16[24576,24576]{1,0} get-tuple-element(param.1), index=1
+ get-tuple-element.3 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
+ dot = bf16[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ constant.2 = s32[] constant(0)
+ constant.3 = s32[4]{0} constant({0, 512, 1024, 1536})
+ get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
+ partition-id = u32[] partition-id()
+ add = u32[] add(get-tuple-element.5, partition-id)
+ constant.1 = u32[] constant(4)
+ remainder = u32[] remainder(add, constant.1)
+ dynamic-slice = s32[1]{0} dynamic-slice(constant.3, remainder), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(dynamic-slice)
+ dynamic-update-slice = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.2, reshape, constant.2)
+ dot.1 = bf16[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ constant.5 = u32[] constant(1)
+ add.1 = u32[] add(get-tuple-element.5, constant.5)
+ add.2 = u32[] add(add.1, partition-id)
+ remainder.1 = u32[] remainder(add.2, constant.1)
+ dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.3, remainder.1), dynamic_slice_sizes={1}
+ reshape.1 = s32[] reshape(dynamic-slice.1)
+ dynamic-update-slice.1 = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.2, reshape.1, constant.2)
+ get-tuple-element.4 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
+ add.3 = u32[] add(add.1, constant.5)
+ ROOT tuple = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
+} // windowed_dot_general_body_ag
+
+windowed_dot_general_cond_ag {
+ param = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element = u32[] get-tuple-element(param), index=4
+ constant = u32[] constant(4)
+ ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
+}
+
+ENTRY main.12_spmd {
+ param.4 = bf16[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+ param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+ constant.22 = bf16[] constant(0)
+ broadcast = bf16[2,2048,24576]{2,1,0} broadcast(constant.22), dimensions={}
+ constant.24 = u32[] constant(0)
+ tuple.2 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
+ while = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+ get-tuple-element.13 = bf16[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
+ copy.1 = bf16[2,2048,24576]{2,1,0} copy(get-tuple-element.13)
+ all-gather = bf16[2,2048,24576]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
+ param.6 = bf16[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
+ ROOT dot.7 = bf16[2,2048,24576]{2,1,0} dot(all-gather, param.6), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+ EXPECT_TRUE(changed);
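+  // The pass should expose the loop's accumulated all-gather result as a new
+  // while-tuple element (index 5) for consumers outside the loop.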
+ HloInstruction* ag_loop =
+ FindInstructionByName(module->entry_computation(), "while");
+ HloInstruction* inst =
+ FindInstructionByName(module->entry_computation(), "dot.7");
+ // dot.7 should now consume output of the windowed einsum while loop.
+ EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
+ EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
+ EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
+
+ // while loop's root should now have a chain of DUS.
+ HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction();
+ EXPECT_THAT(ag_while_root,
+ GmockMatch(m::Tuple(
+ m::Op(), m::Op(), m::Op(), m::Op(), m::Op(),
+ m::DynamicUpdateSlice(
+ m::DynamicUpdateSlice(
+ m::GetTupleElement(m::Parameter())
+ .WithPredicate([](const HloInstruction* instr) {
+ return instr->tuple_index() == 5;
+ }),
+ m::Op(), m::Op(), m::Op(), m::Op()),
+ m::Op(), m::Op(), m::Op(), m::Op()))));
+}
+
+TEST_F(WindowedEinsumHandlerTest, A2aGemmHaveStreamIds) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,8192]{3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=8
+
+ENTRY main.9_spmd {
+ param0 = bf16[1,8192,32768]{2,1,0} parameter(0)
+ param1 = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
+ all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(param1), channel_id=4, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={1}
+ ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(all-to-all, param0), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+}
+)";
+
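+  // Expected rewrite (per the FileCheck pattern below): the all-to-all + dot
+  // pair is decomposed into four sliced all-to-all/dot chunks, each dot is
+  // annotated with its own operation queue id, and the partial results are
+  // accumulated through adds that wait on the corresponding queues.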
+ const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
+CHECK: %[[A2A3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["5"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
+
+CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+ RunFileCheck(module->ToString(), kExpected));
+ EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, GemmA2aHaveStreamIds) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,32768]{3,2,1,0})->bf16[1,4,2048,8192]{3,2,1,0}}, num_partitions=4
+
+ENTRY main.9_spmd {
+ param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
+ param.10 = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
+ dot.12 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.10, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}
+ ROOT all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(dot.12), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={1}
+}
+)";
+
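+  // Expected rewrite (per the FileCheck pattern below): the dot + all-to-all
+  // pair is decomposed into four sliced dot/all-to-all chunks, each dot is
+  // annotated with its own operation queue id, and the all-to-all outputs are
+  // accumulated through adds.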
+ const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
+CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [24576:32768]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [16384:24576]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [8192:16384]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
+
+CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+ RunFileCheck(module->ToString(), kExpected));
+ EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, A2aTransposeLoopsHaveStreamIds) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=4
+
+ENTRY main.9_spmd {
+ param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
+ param.10 = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
+ all-to-all = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} all-to-all(param.10), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={3}
+ transpose.15 = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(all-to-all), dimensions={0,3,1,2,4,5}
+ reshape.2170 = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(transpose.15)
+ reshape.2173 = bf16[4,8192,1,2048]{3,2,1,0} reshape(reshape.2170)
+ transpose.16 = bf16[1,4,2048,8192]{2,0,3,1} transpose(reshape.2173), dimensions={2,0,3,1}
+ copy.53 = bf16[1,4,2048,8192]{3,2,1,0} copy(transpose.16)
+ ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(copy.53, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+}
+)";
+
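+  // Expected rewrite (per the FileCheck pattern below): the transpose/reshape/
+  // copy chain is applied to the operand before slicing, and the all-to-all +
+  // dot pair is then decomposed into four sliced chunks with per-dot operation
+  // queue ids, accumulated through adds.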
+ const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
+CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} %[[P1:.*]]), dimensions={0,3,1,2,4,5}
+CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} %[[TRANSPOSE0:.*]])
+CHECK-DAG: %[[RESHAPE1:.*]] = bf16[4,8192,1,2048]{3,2,1,0} reshape(bf16[1,4,8192,1,2048]{4,3,2,1,0} %[[RESHAPE0:.*]])
+CHECK-DAG: %[[TRANSPOSE1:.*]] = bf16[1,4,2048,8192]{2,0,3,1} transpose(bf16[4,8192,1,2048]{3,2,1,0} %[[RESHAPE1:.*]]), dimensions={2,0,3,1}
+CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{2,0,3,1} %[[TRANSPOSE1:.*]])
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
+CHECK: %[[A2A3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["8"],"force_earliest_schedule":false}
+
+CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+ EXPECT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+ RunFileCheck(module->ToString(), kExpected));
+ EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, GemmA2aTransposeLoopsHaveStreamIds) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,4,2048,32768]{3,2,1,0}, bf16[1,32768,8192]{2,1,0})->bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0}}, num_partitions=4
+
+ENTRY main.9_spmd {
+ param.9 = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
+ param.10 = bf16[1,32768,8192]{2,1,0} parameter(1)
+ dot.13 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.9, param.10), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+ copy.55 = bf16[1,4,2048,8192]{3,2,1,0} copy(dot.13)
+ transpose.17 = bf16[4,1,2048,8192]{3,2,0,1} transpose(copy.55), dimensions={1,0,2,3}
+ copy.56 = bf16[4,1,2048,8192]{3,2,1,0} copy(transpose.17)
+ reshape.2216 = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(copy.56)
+ reshape.2219 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(reshape.2216)
+ ROOT all-to-all.1 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} all-to-all(reshape.2219), channel_id=7, replica_groups={{0,1,2,3}}, dimensions={1}
+}
+)";
+
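+  // Expected rewrite (per the FileCheck pattern below): the dot + all-to-all
+  // pair is decomposed into four sliced chunks with per-dot operation queue
+  // ids, and the transpose/reshape/copy chain is applied once to the
+  // accumulated result rather than to each chunk.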
+ const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
+CHECK-DAG: %[[P0:.*]] = bf16[1,32768,8192]{2,1,0} parameter(1)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [24576:32768], [0:8192]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"12","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [16384:24576], [0:8192]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"11","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [8192:16384], [0:8192]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"10","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
+CHECK: replica_groups={
+CHECK: {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
+CHECK-DAG: %[[ADD3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
+
+CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{3,2,1,0} %[[ADD3:.*]])
+CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[4,1,2048,8192]{3,2,0,1} transpose(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), dimensions={1,0,2,3}
+CHECK-DAG: %[[COPY1:.*]] = bf16[4,1,2048,8192]{3,2,1,0} copy(bf16[4,1,2048,8192]{3,2,0,1} %[[TRANSPOSE0:.*]])
+CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(bf16[4,1,2048,8192]{3,2,1,0} %[[COPY1:.*]])
+
+CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,2048,8192]{4,3,2,1,0} %[[RESHAPE0:.*]])
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+ EXPECT_TRUE(changed);
+ TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+ RunFileCheck(module->ToString(), kExpected));
+ EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, AllGatherF8) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4
+
+windowed_dot_general_body_ag {
+ param.1 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element.1 = f32[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
+ collective-permute = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+ collective-permute.1 = f32[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+ get-tuple-element.2 = f32[24576,24576]{1,0} get-tuple-element(param.1), index=1
+ get-tuple-element.3 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
+ dot = f32[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ constant.12 = s32[] constant(0)
+ constant.13 = s32[4]{0} constant({0, 512, 1024, 1536})
+ get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
+ partition-id = u32[] partition-id()
+ add = u32[] add(get-tuple-element.5, partition-id)
+ constant.11 = u32[] constant(4)
+ remainder = u32[] remainder(add, constant.11)
+ dynamic-slice = s32[1]{0} dynamic-slice(constant.13, remainder), dynamic_slice_sizes={1}
+ reshape = s32[] reshape(dynamic-slice)
+ dynamic-update-slice = f32[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.12, reshape, constant.12)
+ dot.1 = f32[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ constant.15 = u32[] constant(1)
+ add.1 = u32[] add(get-tuple-element.5, constant.15)
+ add.2 = u32[] add(add.1, partition-id)
+ remainder.1 = u32[] remainder(add.2, constant.11)
+ dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.13, remainder.1), dynamic_slice_sizes={1}
+ reshape.1 = s32[] reshape(dynamic-slice.1)
+ dynamic-update-slice.1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.12, reshape.1, constant.12)
+ get-tuple-element.4 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
+ add.3 = u32[] add(add.1, constant.15)
+ ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
+} // windowed_dot_general_body_ag
+
+windowed_dot_general_cond_ag {
+ param = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element = u32[] get-tuple-element(param), index=4
+ constant.10 = u32[] constant(4)
+ ROOT compare = pred[] compare(get-tuple-element, constant.10), direction=LT
+}
+
+ENTRY test_main {
+ param.4 = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+ reshape.8 = f8e4m3fn[2,512,24576]{2,1,0} reshape(param.4)
+ param.5 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+ constant.18 = f32[] constant(0)
+ broadcast = f32[2,2048,24576]{2,1,0} broadcast(constant.18), dimensions={}
+ constant.20 = u32[] constant(0)
+ scale_lhs = f32[] parameter(2)
+ scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
+ lhs_bf32 = f32[2,512,24576]{2,1,0} convert(reshape.8)
+ lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_bf32, scale_lhs_bcast)
+ scale_rhs = f32[] parameter(3)
+ scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
+ rhs_bf32 = f32[24576,24576]{1,0} convert(param.5)
+ rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf32, scale_rhs_bcast)
+ tuple.2 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, broadcast, broadcast, constant.20)
+ while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+ ROOT get-tuple-element.13 = f32[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
+}
+)";
+
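+  // Expected rewrite (per the FileCheck pattern below): the loop now carries
+  // the f8 operands directly, their scales are appended as tuple elements 5
+  // and 6, and the dequantization (convert + scale multiply) is performed
+  // inside the loop body next to each dot.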
+ RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(),
+ R"(
+; CHECK-LABEL: windowed_dot_general_body_ag
+; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
+; CHECK-NEXT: [[GTE0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=0
+; CHECK-NEXT: [[CP0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[GTE0]]), channel_id=4
+; CHECK-NEXT: [[CP1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[CP0]]), channel_id=5
+; CHECK-NEXT: [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
+; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=2
+; CHECK-NEXT: [[CONVERT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[GTE0]])
+; CHECK-NEXT: [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
+; CHECK-NEXT: [[BCAST0:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
+; CHECK-NEXT: [[MUL0:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
+; CHECK-NEXT: [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
+; CHECK-NEXT: [[GTE4:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
+; CHECK-NEXT: [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE4]]), dimensions={}
+; CHECK-NEXT: [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
+; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL0]], [[MUL1]]),
+; CHECK-DAG: lhs_contracting_dims={2},
+; CHECK-DAG: rhs_contracting_dims={0},
+; CHECK-DAG: backend_config={
+; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
+; CHECK-DAG: "wait_on_operation_queues":[],
+; CHECK-DAG: "force_earliest_schedule":true}
+; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0)
+; CHECK-NEXT: [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
+; CHECK-NEXT: [[GTE5:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
+; CHECK-NEXT: [[PID:%[^ ]+]] = u32[] partition-id()
+; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[GTE5]], [[PID]])
+; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(4)
+; CHECK-NEXT: [[REM0:%[^ ]+]] = u32[] remainder([[ADD0]], [[C2]])
+; CHECK-NEXT: [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
+; CHECK-NEXT: [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
+; CHECK-NEXT: [[DUPDATESLICE0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[GTE2]], [[DOT0]], [[C0]], [[RESHAPE0]], [[C0]]),
+; CHECK-DAG: backend_config={
+; CHECK-DAG: "operation_queue_id":"0",
+; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"],
+; CHECK-DAG: "force_earliest_schedule":false}
+; CHECK-NEXT: [[CONVERT2:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[CP0]])
+; CHECK-NEXT: [[MUL2:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT2]], [[BCAST0]])
+; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL2]], [[MUL1]]),
+; CHECK-DAG: lhs_contracting_dims={2},
+; CHECK-DAG: rhs_contracting_dims={0}
+; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(1)
+; CHECK-NEXT: [[ADD1:%[^ ]+]] = u32[] add([[GTE5]], [[C3]])
+; CHECK-NEXT: [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
+; CHECK-NEXT: [[REM1:%[^ ]+]] = u32[] remainder([[ADD2]], [[C2]])
+; CHECK-NEXT: [[DSLICE1:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
+; CHECK-NEXT: [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE1]])
+; CHECK-NEXT: [[DUPDATESLICE1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[DUPDATESLICE0]], [[DOT1]], [[C0]], [[RESHAPE1]], [[C0]])
+; CHECK-NEXT: [[GTE6:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=3
+; CHECK-NEXT: [[ADD3:%[^ ]+]] = u32[] add([[ADD1]], [[C3]])
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[CP1]], [[GTE1]], [[DUPDATESLICE1]], [[GTE6]], [[ADD3]], /*index=5*/[[GTE3]], [[GTE4]])
+)");
+}
+
+TEST_F(WindowedEinsumHandlerTest, ReduceScatterF8) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f8e4m3fn[2,2048,24576]{2,1,0}, f32[], f32[])->f32[2,512,24576]{2,1,0}}, num_partitions=4
+
+windowed_dot_general_body_rs {
+ param.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element.7 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.3), index=0
+ get-tuple-element.8 = f32[24576,24576]{1,0} get-tuple-element(param.3), index=1
+ get-tuple-element.9 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=2
+ collective-permute.2 = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.9), channel_id=9, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
+ constant.23 = s32[] constant(0)
+ constant.24 = s32[4]{0} constant({0, 512, 1024, 1536})
+ get-tuple-element.11 = u32[] get-tuple-element(param.3), index=4
+ constant.26 = u32[] constant(2)
+ add.8 = u32[] add(get-tuple-element.11, constant.26)
+ constant.27 = u32[] constant(1)
+ add.9 = u32[] add(add.8, constant.27)
+ partition-id.3 = u32[] partition-id()
+ add.10 = u32[] add(add.9, partition-id.3)
+ constant.22 = u32[] constant(4)
+ remainder.3 = u32[] remainder(add.10, constant.22)
+ dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.24, remainder.3), dynamic_slice_sizes={1}
+ reshape.3 = s32[] reshape(dynamic-slice.4)
+ dynamic-slice.5 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.3, constant.23), dynamic_slice_sizes={2,512,24576}
+ dot.3 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.5, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ add.11 = f32[2,512,24576]{2,1,0} add(collective-permute.2, dot.3)
+ get-tuple-element.10 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=3
+ add.6 = u32[] add(get-tuple-element.11, partition-id.3)
+ remainder.2 = u32[] remainder(add.6, constant.22)
+ dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.24, remainder.2), dynamic_slice_sizes={1}
+ reshape.2 = s32[] reshape(dynamic-slice.2)
+ dynamic-slice.3 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.2, constant.23), dynamic_slice_sizes={2,512,24576}
+ dot.2 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.3, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ add.7 = f32[2,512,24576]{2,1,0} add(get-tuple-element.10, dot.2)
+ collective-permute.3 = f32[2,512,24576]{2,1,0} collective-permute(add.7), channel_id=10, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
+ ROOT tuple.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(get-tuple-element.7, get-tuple-element.8, add.11, collective-permute.3, add.8)
+} // windowed_dot_general_body_rs
+
+windowed_dot_general_cond_rs {
+ param.2 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element.6 = u32[] get-tuple-element(param.2), index=4
+ constant.21 = u32[] constant(4)
+ ROOT compare.1 = pred[] compare(get-tuple-element.6, constant.21), direction=LT
+}
+
+ENTRY main.9_spmd {
+ param.6 = f8e4m3fn[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
+ param.7 = f32[2,512,24576]{2,1,0} parameter(1)
+ param.8 = f8e4m3fn[2,2048,24576]{2,1,0} parameter(2)
+ constant.20 = u32[] constant(0)
+ scale_lhs = f32[] parameter(3)
+ scale_lhs_bcast = f32[2,2048,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
+ lhs_bf16 = f32[2,2048,24576]{2,1,0} convert(param.8)
+ lhs_scaled = f32[2,2048,24576]{2,1,0} multiply(lhs_bf16, scale_lhs_bcast)
+ scale_rhs = f32[] parameter(4)
+ scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
+ rhs_bf16 = f32[24576,24576]{1,0} convert(param.6)
+ rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf16, scale_rhs_bcast)
+ tuple.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, param.7, param.7, constant.20)
+ while.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs
+ ROOT get-tuple-element.14 = f32[2,512,24576]{2,1,0} get-tuple-element(while.1), index=2
+}
+)";
+
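+  // Expected rewrite (per the FileCheck pattern below): as in the all-gather
+  // case, the loop carries the f8 operands directly with their scales appended
+  // as tuple elements 5 and 6, the dequantization happens inside the loop
+  // body, and the first dot gets its own operation queue id that the following
+  // add waits on.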
+ RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(),
+ R"(
+; CHECK-LABEL: windowed_dot_general_body_rs
+; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
+; CHECK-NEXT: [[GTE0:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=0
+; CHECK-NEXT: [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
+; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=2
+; CHECK-NEXT: [[CP0:%[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[GTE2]]), channel_id=9
+; CHECK-NEXT: [[CONVERT0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[GTE0]])
+; CHECK-NEXT: [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
+; CHECK-NEXT: [[BCAST0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
+; CHECK-NEXT: [[MUL0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
+; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0)
+; CHECK-NEXT: [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
+; CHECK-NEXT: [[GTE4:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
+; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(2)
+; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[GTE4]], [[C2]])
+; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(1)
+; CHECK-NEXT: [[ADD1:%[^ ]+]] = u32[] add([[ADD0]], [[C3]])
+; CHECK-NEXT: [[PID:%[^ ]+]] = u32[] partition-id()
+; CHECK-NEXT: [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
+; CHECK-NEXT: [[C4:%[^ ]+]] = u32[] constant(4)
+; CHECK-NEXT: [[REM0:%[^ ]+]] = u32[] remainder([[ADD2]], [[C4]])
+; CHECK-NEXT: [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
+; CHECK-NEXT: [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
+; CHECK-NEXT: [[DSLICE1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE0]], [[C0]]), dynamic_slice_sizes={2,512,24576}
+; CHECK-NEXT: [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
+; CHECK-NEXT: [[GTE5:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
+; CHECK-NEXT: [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE5]]), dimensions={}
+; CHECK-NEXT: [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
+; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE1]], [[MUL1]]),
+; CHECK-DAG: lhs_contracting_dims={2},
+; CHECK-DAG: rhs_contracting_dims={0},
+; CHECK-DAG: backend_config={
+; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
+; CHECK-DAG: "wait_on_operation_queues":[],
+; CHECK-DAG: "force_earliest_schedule":false}
+; CHECK-NEXT: [[ADD3:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[CP0]], [[DOT0]]),
+; CHECK-DAG: backend_config={"
+; CHECK-DAG: operation_queue_id":"0",
+; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"],
+; CHECK-DAG: "force_earliest_schedule":false}
+; CHECK-NEXT: [[GTE6:[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=3
+; CHECK-NEXT: [[ADD4:%[^ ]+]] = u32[] add([[GTE4]], [[PID]])
+; CHECK-NEXT: [[REM1:%[^ ]+]] = u32[] remainder([[ADD4]], [[C4]])
+; CHECK-NEXT: [[DSLICE2:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
+; CHECK-NEXT: [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE2]])
+; CHECK-NEXT: [[DSLICE3:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE1]], [[C0]]), dynamic_slice_sizes={2,512,24576}
+; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE3]], [[MUL1]]),
+; CHECK-DAG: lhs_contracting_dims={2},
+; CHECK-DAG: rhs_contracting_dims={0}
+; CHECK-NEXT: [[ADD5:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[GTE6]], [[DOT1]])
+; CHECK-NEXT: [[CP1:[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[ADD5]]), channel_id=10
+; CHECK-NEXT: ROOT [[OUT:[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[GTE0]], [[GTE1]], [[ADD3]], [[CP1]], [[ADD0]], /*index=5*/[[GTE3]], [[GTE5]])
+)");
+}
+
+TEST_F(WindowedEinsumHandlerTest,
+       AgLoopsMultipleConsumersAreChainedWithShardedContractingDim) {
+ constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8
+
+windowed_dot_general_body_ag {
+ param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0
+ collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
+ collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
+ get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1
+ get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2
+ constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584})
+ get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4
+ partition-id.194 = u32[] partition-id()
+ add.4309 = u32[] add(get-tuple-element.592, partition-id.194)
+ constant.11431 = u32[] constant(8)
+ remainder.194 = u32[] remainder(add.4309, constant.11431)
+ dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1}
+ reshape.12959 = s32[] reshape(dynamic-slice.388)
+ constant.11433 = s32[] constant(0)
+ dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288}
+ dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244)
+ constant.11434 = u32[] constant(1)
+ add.4312 = u32[] add(get-tuple-element.592, constant.11434)
+ add.4313 = u32[] add(add.4312, partition-id.194)
+ remainder.195 = u32[] remainder(add.4313, constant.11431)
+ dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1}
+ reshape.12960 = s32[] reshape(dynamic-slice.390)
+ dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288}
+ dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+ add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245)
+ get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3
+ add.4315 = u32[] add(add.4312, constant.11434)
+ ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315)
+} // windowed_dot_general_body_ag
+
+windowed_dot_general_cond_ag {
+ param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
+ get-tuple-element = u32[] get-tuple-element(param), index=4
+ constant = u32[] constant(4)
+ ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
+}
+
+ENTRY main.12_spmd {
+ param.4 = bf16[16,2048,512]{2,1,0} parameter(0)
+ param.5 = bf16[4096,6288]{1,0} parameter(1)
+ constant.22 = bf16[] constant(0)
+ broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={}
+ constant.24 = u32[] constant(0)
+ tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
+ while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+ get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2
+ all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true
+ param.6 = bf16[16,2048,6288]{2,1,0} parameter(2)
+ ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(kHloString));
+
+ WindowedEinsumHandler gpu_handler;
+ bool changed;
+ TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+ EXPECT_TRUE(changed);
+
+ HloInstruction* ag_loop =
+ FindInstructionByName(module->entry_computation(), "while");
+ HloInstruction* inst =
+ FindInstructionByName(module->entry_computation(), "dot.7");
+  // dot.7 should now consume the output of the windowed einsum while loop.
+ EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
+ EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
+ EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
+}
+} // namespace
+} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc
deleted file mode 100644
index b54d006..0000000
--- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc
+++ /dev/null
@@ -1,389 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/tree_reduction_rewriter.h"
-
-#include <algorithm>
-#include <cmath>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/numeric/bits.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-absl::InlinedVector<int64_t, 2> GetSortedReducedDims(
- HloReduceInstruction *reduce) {
- absl::InlinedVector<int64_t, 2> reduced_dims{reduce->dimensions().begin(),
- reduce->dimensions().end()};
- absl::c_sort(reduced_dims);
- return reduced_dims;
-}
-
-bool IsMinMaxReduction(HloReduceInstruction *reduce) {
- HloComputation *called = &reduce->to_apply()[0];
- if (auto reduction_kind = MatchReductionComputation(called)) {
- return reduction_kind == ReductionKind::MAX ||
- reduction_kind == ReductionKind::MIN;
- }
- return false;
-}
-
-} // namespace
-
-class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
- public:
- explicit ReductionRewriterVisitor(se::GpuComputeCapability gpu_version)
- : gpu_version_(gpu_version) {}
-
- absl::Status HandleReduce(HloInstruction *hlo) override {
- auto *reduce = Cast<HloReduceInstruction>(hlo);
- VLOG(3) << "Reduction instruction: " << reduce->ToString();
-
- const HloModuleConfig &config = reduce->GetModule()->config();
- if (!MatchReductionForSplit(reduce, config)) {
- return absl::OkStatus();
- }
- ReductionDimensions reduction_dims =
- GetReductionKindAndContiguousComponents(*hlo);
- if (ReductionIsRaceFree(config, reduction_dims)) {
- VLOG(3) << "Base case: dimensions fit";
- return absl::OkStatus();
- }
- auto sorted_dims_to_reduce = GetSortedReducedDims(reduce);
- CHECK_LE(sorted_dims_to_reduce.size(), 2);
-
- // If the major reduced dimension does not fit, reduce the minor dimension
- // first, then the major.
- if (reduction_dims.is_row_reduction &&
- reduction_dims
- .dimensions[ReductionDimensions::kRowMajorReducedDimension] >
- BatchedReductionRaceFreeBound()) {
- VLOG(2) << "Splitting batched dimension reduce into a separate reduction";
- return RewriteBatchDimensionLargerThanTile(reduce, reduction_dims,
- sorted_dims_to_reduce);
- }
- SplitParams split_params =
- ComputeSplitParams(reduce, reduction_dims, sorted_dims_to_reduce);
- return SplitReductionDimension(reduce, split_params, sorted_dims_to_reduce);
- }
-
- private:
- bool MatchReductionForSplit(HloReduceInstruction *reduce,
- const HloModuleConfig &config) {
- // MLIR emitters only support race-free reductions.
-    // TODO(jreiffers): Verify performance and implement atomics for reductions
-    // if needed.
- bool reductions_via_mlir_disabled =
- config.debug_options().xla_gpu_mlir_emitter_level() < 4;
- if (reductions_via_mlir_disabled && IsMinMaxReduction(reduce)) {
- // TODO(cheshire): Also enable for integers.
- VLOG(1) << "Not performing tree expansion on min/max-reduction: "
- << reduce->ToString()
- << " since min/max operations are associative";
- return false;
- }
- if (!IsReductionFromOrToContiguousDimensions(*reduce)) {
- VLOG(3) << "Is not a reduction from or to contiguous dimensions";
- return false;
- }
- VLOG(3) << "Perform rewrite";
- return true;
- }
-
-  // We observe that a larger n_div_k can improve tree reduction performance in
-  // most cases by reducing memory stores and the launch overhead of blocks.
-  // Swap k and n_div_k if possible. (A standalone sketch of this heuristic
-  // follows this function.)
- bool ShouldSwapInnerAndOuterReducedMinorDimension(uint64_t k1, uint64_t k2,
- uint64_t n,
- int64_t race_free_bound,
- bool is_row_reduction) {
- CHECK(k1 >= k2);
- // Keep inner reduction as race free.
- if (k1 > race_free_bound) {
- return false;
- }
- // Swapping only affects row reduction vectorization.
- if (is_row_reduction) {
-      // These are rough conditions for row reduction vectorization; they do
-      // not mean that vectorization will definitely occur.
- bool maybe_vectorized = k2 % 2 == 0 && n % 2 == 0;
- if (maybe_vectorized) {
-        // Swap if n_div_k is small enough or if the k dimension can also be
-        // vectorized.
- return k2 * 2 < k1 || k1 % 2 == 0;
- }
-      // The current reduction emitter only checks reduction input dimensions,
-      // not fusion input dimensions. Because the pad and the inner reduction
-      // always fuse into the same computation, each thread may end up reading
-      // multiple unaligned elements that cannot be vectorized, which gives bad
-      // performance. Don't swap if we would run into this situation.
- return n % 2 == 0 || k1 % 2 != 0;
- }
-    // For column reductions, there is no known situation in which swapping
-    // fails to give a performance gain.
- return true;
- }
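The heuristic above is easiest to see with concrete numbers. The following is a minimal, self-contained sketch (hypothetical sizes, a made-up function name, only the row-reduction branch, and the race-free-bound check omitted):

#include <cstdint>
#include <cstdio>

// Mirrors only the row-reduction branch of the swap heuristic above.
bool ShouldSwapRowReduction(uint64_t k1, uint64_t k2, uint64_t n) {
  bool maybe_vectorized = k2 % 2 == 0 && n % 2 == 0;
  if (maybe_vectorized) {
    // Swap if n_div_k is small enough or if k can also be vectorized.
    return k2 * 2 < k1 || k1 % 2 == 0;
  }
  // Avoid swaps that could leave threads reading unaligned elements.
  return n % 2 == 0 || k1 % 2 != 0;
}

int main() {
  // With k1 = 64, k2 = 8, n = 4096, both factors are even, so we swap.
  std::printf("swap = %d\n", ShouldSwapRowReduction(64, 8, 4096));
  return 0;
}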
-
-  // Parameters describing how to split a dimension `dim` with `k` elements
-  // into `k1` x `k2`.
- struct SplitParams {
- int64_t k1;
- int64_t k2;
- int64_t dim;
- };
-
- // Attempts to find the best way to split a dimension `dim` with `k` elements
- // into `k1` x `k2`.
- SplitParams ComputeSplitParams(
- HloReduceInstruction *reduce, const ReductionDimensions &reduction_dims,
- absl::Span<const int64_t> sorted_dims_to_reduce) {
- absl::Span<int64_t const> input_shape_dims =
- reduce->inputs()[0]->shape().dimensions();
-
- int64_t reduced_dim = sorted_dims_to_reduce.back();
- int64_t reduced_dim_size = input_shape_dims[reduced_dim];
- VLOG(3) << "reduced dim size = " << reduced_dim_size;
-
-    // We will do this reduction in two stages. The first reduces from k
-    // elements to k1 elements in the reduction dimension. The second reduces
-    // further, from those k1 elements down to 1 element.
-    //
-    // We do this by splitting the input shape [a, k, b] into [a, k1, k2, b].
-    //
-    // We want to choose k1 to be roughly equal to sqrt(k) so that we process
-    // "most of" the reduction in the first step. But it is also important to
-    // choose the split so that the padding we must add to k to make it
-    // divisible by k2 is as small as possible. We search for the best value of
-    // k2 between sqrt(k)/2 and sqrt(k). If there are several possible values
-    // for k2 that result in the minimum amount of padding, we also prefer k2
-    // to be a power of 2, so that the GPU kernel doesn't spend all its time
-    // doing slow integer divmods to compute indices into the shape
-    // [a,k1,k2,b]. Note that by searching in the range between sqrt(k)/2 and
-    // sqrt(k), we are guaranteed to have a power of 2 in that range. (A worked
-    // example of this search follows this function.)
- uint64_t k2 =
- static_cast<uint64_t>(std::floor(std::sqrt(reduced_dim_size)));
- int64_t race_free_bound = ReductionDimensionRaceFreeBound(
- reduce->GetModule()->config(), reduction_dims);
- if (k2 > race_free_bound) {
- // This means we need more than one split. It is best to limit the n/k
- // dimension to the maximum size that doesn't require further splitting.
- // Otherwise we might choose a rather small reduce dimension size for the
- // first step (in the worst case, sqrt(race_free_bound + 1)).
- k2 = race_free_bound;
- }
- uint64_t minimum_padding = (k2 - reduced_dim_size % k2) % k2;
- uint64_t best_k1 = (reduced_dim_size + minimum_padding) / k2;
- for (uint64_t i = k2 - 1; i > k2 / 2; --i) {
- uint64_t padding = (i - reduced_dim_size % i) % i;
- if (padding < minimum_padding ||
- (padding == minimum_padding && absl::has_single_bit(i))) {
- minimum_padding = padding;
- best_k1 = (reduced_dim_size + padding) / i;
- }
- }
- uint64_t padded_k = reduced_dim_size + minimum_padding;
-
-    // We picked the best {k1, k2} pair based on the padding size and on
-    // whether index computation is fast, but ignored the overhead of memory
-    // reads/writes and block launches, which also matter for kernel
-    // performance. Swapping k1 and k2 leaves the padding size and the
-    // index-computation cost unchanged, so we only need to compare memory
-    // traffic and block launches to choose the better of the two orderings.
- uint64_t best_k2 = padded_k / best_k1;
- if (ShouldSwapInnerAndOuterReducedMinorDimension(
- best_k1, best_k2, reduced_dim_size, race_free_bound,
- reduction_dims.is_row_reduction)) {
- std::swap(best_k1, best_k2);
- }
- return SplitParams{static_cast<int64_t>(best_k1),
- static_cast<int64_t>(best_k2), reduced_dim};
- }
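The search above is small enough to run standalone. Here is a rough sketch assuming a hypothetical k = 1000 and a race-free bound of 512; ChooseSplit is an illustrative name, and the final swap step is left out:

#include <cmath>
#include <cstdint>
#include <cstdio>

struct Split {
  uint64_t k1;
  uint64_t k2;
};

// Padding-minimizing search mirroring ComputeSplitParams above.
Split ChooseSplit(uint64_t k, uint64_t race_free_bound) {
  uint64_t k2 = static_cast<uint64_t>(std::floor(std::sqrt(k)));
  if (k2 > race_free_bound) k2 = race_free_bound;
  uint64_t min_padding = (k2 - k % k2) % k2;
  uint64_t best_k1 = (k + min_padding) / k2;
  for (uint64_t i = k2 - 1; i > k2 / 2; --i) {
    uint64_t padding = (i - k % i) % i;
    bool is_power_of_two = (i & (i - 1)) == 0;
    if (padding < min_padding ||
        (padding == min_padding && is_power_of_two)) {
      min_padding = padding;
      best_k1 = (k + padding) / i;
    }
  }
  return Split{best_k1, (k + min_padding) / best_k1};
}

int main() {
  // For k = 1000: sqrt(k) ~ 31.6, and k2 = 25 needs no padding, so k1 = 40.
  Split s = ChooseSplit(/*k=*/1000, /*race_free_bound=*/512);
  std::printf("k1 = %llu, k2 = %llu\n",
              static_cast<unsigned long long>(s.k1),
              static_cast<unsigned long long>(s.k2));
  return 0;
}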
-
-  // Replaces the original reduce with pad->reshape->inner_reduce->outer_reduce.
-  // * 1. Pads the split dimension of the inputs to k1 * k2 if necessary.
-  // * 2. Reshapes the split dimension of the padded inputs into [k1, k2].
-  // * 3. The inner reduction reduces the dims specified in the original
-  //      reduction; instead of reducing the split dimension, it reduces k2.
-  // * 4. The outer reduction reduces k1 only.
-  // (A shape-level walkthrough follows this function.)
- absl::Status SplitReductionDimension(
- HloReduceInstruction *reduce, const SplitParams &split_params,
- absl::Span<const int64_t> sorted_dims_to_reduce) {
- absl::Span<int64_t const> reduce_input_dims =
- reduce->inputs()[0]->shape().dimensions();
- int64_t split_dim_size = reduce_input_dims[split_params.dim];
- VLOG(2) << "dimension to split = " << split_params.dim << " with "
- << split_dim_size << " elements into " << split_params.k1 << " by "
- << split_params.k2;
-
- // Pad 'k' to 'k1 * k2' if necessary.
- HloInstruction::InstructionVector padded_inputs(reduce->inputs().begin(),
- reduce->inputs().end());
- auto padded_size = split_params.k1 * split_params.k2;
- absl::InlinedVector<int64_t, 3> padded_dimensions(reduce_input_dims.begin(),
- reduce_input_dims.end());
- if (split_dim_size != padded_size) {
- padded_dimensions[split_params.dim] = padded_size;
- PaddingConfig padding_config =
- MakeNoPaddingConfig(reduce_input_dims.size());
- padding_config.mutable_dimensions(split_params.dim)
- ->set_edge_padding_high(padded_size - split_dim_size);
-
- for (int input_idx = 0; input_idx < padded_inputs.size(); ++input_idx) {
- auto &reduction_input = padded_inputs[input_idx];
- Shape padded_shape = ShapeUtil::MakeShape(
- reduction_input->shape().element_type(), padded_dimensions);
- VLOG(2) << "Generated padded shape: " << padded_shape.ToString();
- reduction_input = reduce->parent()->AddInstruction(
- HloInstruction::CreatePad(padded_shape, reduction_input,
- reduce->init_values()[input_idx],
- padding_config),
- &reduction_input->metadata());
- }
- }
-
- // Compute output type of reshape that expands the split dimension into
- // [k1, k2].
- absl::InlinedVector<int64_t, 3> reshaped_dimensions;
- int64_t input_rank = reduce_input_dims.size();
- for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
- if (dim_idx == split_params.dim) {
- reshaped_dimensions.push_back(split_params.k1);
- reshaped_dimensions.push_back(split_params.k2);
- } else {
- reshaped_dimensions.push_back(padded_dimensions[dim_idx]);
- }
- }
-
- // Compute dimensions to reduce for inner reduction.
- absl::InlinedVector<int64_t, 2> inner_reduce_dims(
- sorted_dims_to_reduce.begin(), sorted_dims_to_reduce.end());
- auto split_dim_it = std::find(inner_reduce_dims.begin(),
- inner_reduce_dims.end(), split_params.dim);
- *split_dim_it += 1;
-
- // Compute dimension to reduce for outer reduction.
- absl::InlinedVector<int64_t, 1> outer_reduce_dims{
- split_params.dim -
- std::distance(inner_reduce_dims.begin(), split_dim_it)};
-
- // Compute output shape of the inner reduction.
- absl::InlinedVector<int64_t, 3> inner_reduce_shape =
- RemoveElements(inner_reduce_dims, reshaped_dimensions);
-
- // Reshape the split dimensions of the padded inputs into [k1, k2].
- HloInstruction::InstructionVector reshaped_padded_inputs;
- absl::InlinedVector<Shape, 2> inner_reduce_shapes;
- for (HloInstruction *padded_input : padded_inputs) {
- Shape reshaped_shape = ShapeUtil::MakeShape(
- padded_input->shape().element_type(), reshaped_dimensions);
- HloInstruction *reshaped_padded_input = reduce->parent()->AddInstruction(
- HloInstruction::CreateBitcast(reshaped_shape, padded_input),
- &padded_input->metadata());
- VLOG(2) << "Generated reshape: " << reshaped_padded_input->ToString();
- reshaped_padded_inputs.push_back(reshaped_padded_input);
- inner_reduce_shapes.push_back(ShapeUtil::MakeShape(
- padded_input->shape().element_type(), inner_reduce_shape));
- }
-
- // Inner reduce that reduces [k1, k2] to [k1].
- HloInstruction *inner_reduce = reduce->parent()->AddInstruction(
- HloInstruction::CreateReduce(
- ShapeUtil::MakeMaybeTupleShape(inner_reduce_shapes),
- reshaped_padded_inputs, reduce->init_values(), inner_reduce_dims,
- reduce->to_apply()),
- &reduce->metadata());
- VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString();
-
-    // Outer reduce that reduces the remaining [k1] dimension.
- std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce(
- reduce->shape(), inner_reduce, reduce->init_values(), outer_reduce_dims,
- reduce->to_apply());
-
- VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString();
- return ReplaceWithNewInstruction(reduce, std::move(outer_reduce));
- }
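To make the shape bookkeeping above concrete, here is a small standalone walkthrough for a hypothetical input [a, k, b] = [8, 1000, 32] reduced over dimension 1, assuming k1 = 40 and k2 = 25 (values chosen for illustration only):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical input shape [a, k, b] = [8, 1000, 32], reducing dimension 1.
  std::vector<int64_t> input = {8, 1000, 32};
  const size_t split_dim = 1;
  const int64_t k1 = 40, k2 = 25;

  // 1. Pad the split dimension up to k1 * k2 (1000 here, so no padding).
  std::vector<int64_t> padded = input;
  padded[split_dim] = k1 * k2;

  // 2. Reshape the split dimension into [k1, k2]: [8, 40, 25, 32].
  std::vector<int64_t> reshaped;
  for (size_t d = 0; d < padded.size(); ++d) {
    if (d == split_dim) {
      reshaped.push_back(k1);
      reshaped.push_back(k2);
    } else {
      reshaped.push_back(padded[d]);
    }
  }

  // 3. The inner reduce collapses the k2 axis (split_dim + 1): [8, 40, 32].
  // 4. The outer reduce collapses the remaining k1 axis (split_dim): [8, 32].
  std::printf("reshaped rank = %zu, inner dim = %zu, outer dim = %zu\n",
              reshaped.size(), split_dim + 1, split_dim);
  return 0;
}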
-
- // Rewrites batch dimension reduction into a separate reduce operation.
- absl::Status RewriteBatchDimensionLargerThanTile(
- HloReduceInstruction *hlo,
- const ReductionDimensions &reduction_dimensions,
- absl::Span<const int64_t> sorted_dims_to_reduce) {
-    // TODO(cheshire): This codepath is essentially the exact reverse of what
-    // algebraic_simplifier is doing; we need to make sure they don't keep
-    // undoing each other.
- CHECK(reduction_dimensions.is_row_reduction);
-
- absl::InlinedVector<Shape, 2> tuple_shapes;
- int64_t minor_reduction_dim = sorted_dims_to_reduce.back();
- for (HloInstruction *input : hlo->inputs()) {
- tuple_shapes.push_back(
- ShapeUtil::DeleteDimension(minor_reduction_dim, input->shape()));
- }
-
- HloInstruction *inner_reduce =
- hlo->parent()->AddInstruction(HloInstruction::CreateReduce(
- ShapeUtil::MakeMaybeTupleShape(tuple_shapes), hlo->inputs(),
- hlo->init_values(), {minor_reduction_dim}, hlo->to_apply()));
-
- VLOG(1) << "Inner reduction: " << inner_reduce->ToString();
- std::unique_ptr<HloInstruction> out = HloInstruction::CreateReduce(
- hlo->shape(), inner_reduce, hlo->init_values(), {0}, hlo->to_apply());
- VLOG(1) << "Generated: " << out->ToString();
- return ReplaceWithNewInstruction(hlo, std::move(out));
- }
-
- se::GpuComputeCapability gpu_version_;
-};
-
-absl::StatusOr<bool> GpuTreeReductionRewriter::Run(
- HloModule *module,
- const absl::flat_hash_set<absl::string_view> &execution_threads) {
- VLOG(5) << "Rewriter input: " << module->ToString();
- TF_ASSIGN_OR_RETURN(bool changed,
- ReductionRewriterVisitor(gpu_version_)
- .RunOnModule(module, execution_threads));
- VLOG(5) << "Rewriter output: " << module->ToString();
- return changed;
-}
-
-} // end namespace gpu
-} // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h
deleted file mode 100644
index 5f6edf8..0000000
--- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_
-#define XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_
-
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites reductions so that they can be implemented without atomics.
-//
-// Rule application: rewrite a single HLO reduce operation into two.
-//
-// Case 1: Row reduction, where a batched dimension is present and larger than
-//         the Z-tiling size.
-// -----------------------------------------------------------------
-//
-// Rewriting:
-//
-// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
-//
-// Into:
-//
-// f32[A, B] tmp = reduce(f32[A, B, C] input, dimensions={2})
-// f32[B] out = reduce(f32[A, B] tmp, dimensions={0})
-//
-// Case 2: Row reduction
-// ------------------------------------------------------------------
-//
-// Let M be the thread tiling multiplied by the warp size.
-// We go from (assuming C > M):
-//
-// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
-//
-// to:
-//
-// f32[A, B, P] padded = pad(input) // Let P = ceil(C/M) * M.
-// f32[A, B, Q, M] reshaped = bitcast(padded) // Let Q = ceil(C/M)
-// f32[B, Q] inner_reduce = reduce(reshaped, dimensions={0, 3})
-// f32[B] outer_reduce = reduce(inner_reduce, dimensions={1})
-//
-// Case 3: Column reduction
-// -------------------------------------------------------------------
-//
-// Let T be the tiling size for the column reduction.
-//
-// We go from (assuming B > T):
-//
-// f32[A, C] out = reduce(f32[A, B, C] input, dimensions={1})
-//
-// to:
-//
-// f32[A, P, C] padded = pad(input) // Let P = ceil(B/T) * T.
-// f32[A, Q, T, C] reshaped = bitcast(padded) // Let Q = ceil(B/T)
-// f32[A, Q, C] inner_reduce = reduce(reshaped, dimensions={2})
-// f32[A, C] outer_reduce = reduce(inner_reduce, dimensions={1})
-//
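As a concrete instance of Case 2 (illustrative numbers only): with C = 3000 and M = 1024, Q = ceil(C/M) = 3 and P = Q * M = 3072, so the input is padded from 3000 to 3072 and split into 3 groups of 1024. A one-liner check:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t C = 3000;             // Minor reduced dimension (hypothetical).
  const int64_t M = 1024;             // Thread tiling * warp size (hypothetical).
  const int64_t Q = (C + M - 1) / M;  // ceil(C / M) = 3
  const int64_t P = Q * M;            // 3072
  std::printf("Q = %lld, P = %lld\n", static_cast<long long>(Q),
              static_cast<long long>(P));
  return 0;
}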
-class GpuTreeReductionRewriter : public HloModulePass {
- public:
- explicit GpuTreeReductionRewriter(se::GpuComputeCapability gpu_version)
- : gpu_version_(gpu_version) {}
-
- ~GpuTreeReductionRewriter() override = default;
- absl::string_view name() const override {
- return "gpu-tree-reduction-rewriter";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- se::GpuComputeCapability gpu_version_;
-};
-
-} // end namespace gpu
-} // end namespace xla
-
-#endif // XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc b/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc
deleted file mode 100644
index 2dcd365..0000000
--- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/triangular_solve_rewriter.h"
-
-#include <cstdint>
-#include <numeric>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> TriangularSolveRewriter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- std::vector<HloInstruction*> to_rewrite;
- for (HloInstruction* instr : comp->instructions()) {
- if (instr->opcode() == HloOpcode::kTriangularSolve) {
- to_rewrite.push_back(instr);
- }
- }
-
- for (HloInstruction* instr : to_rewrite) {
- const Shape& b_shape = instr->operand(1)->shape();
- int64_t batch_size = std::accumulate(
- b_shape.dimensions().begin(), b_shape.dimensions().end() - 2,
- int64_t{1}, [](int64_t a, int64_t b) { return a * b; });
-
- // batch 1 triangular solves get 0 temp bytes, because unbatched trsm()
- // doesn't require temp memory.
- int64_t temp_bytes = batch_size == 1 ? 0 : 2 * sizeof(void*) * batch_size;
- Shape new_shape = ShapeUtil::MakeTupleShape({
- instr->shape(),
- ShapeUtil::MakeShape(S8, {temp_bytes}),
- });
-
- HloInstruction* custom_call =
- comp->AddInstruction(HloInstruction::CreateCustomCall(
- new_shape, instr->operands(), kTriangularSolveCallTarget));
- module->SetAndUniquifyInstrName(custom_call, "triangular-solve");
- TF_RETURN_IF_ERROR(
- custom_call->set_backend_config(instr->triangular_solve_options()));
-
- // Preserve metadata from `instr`.
- custom_call->set_metadata(instr->metadata());
- custom_call->set_frontend_attributes(instr->frontend_attributes());
-
- // Get the actual result out of the custom call's tuple.
- TF_ASSIGN_OR_RETURN(HloInstruction * gte,
- MakeGetTupleElementHlo(custom_call, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
- }
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
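The temp-buffer sizing above is easy to check by hand. A small self-contained sketch for a hypothetical b operand of shape [3, 5, 64, 64], where the first two dimensions are batch dimensions:

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical b shape: everything but the last two dims is a batch dim.
  std::vector<int64_t> b_dims = {3, 5, 64, 64};
  int64_t batch_size =
      std::accumulate(b_dims.begin(), b_dims.end() - 2, int64_t{1},
                      [](int64_t a, int64_t b) { return a * b; });
  // Unbatched solves need no temp memory; batched ones need two arrays of
  // device pointers (one for `a`, one for `b`).
  int64_t temp_bytes =
      batch_size == 1 ? 0
                      : 2 * static_cast<int64_t>(sizeof(void*)) * batch_size;
  std::printf("batch_size = %lld, temp_bytes = %lld\n",
              static_cast<long long>(batch_size),
              static_cast<long long>(temp_bytes));
  return 0;
}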
diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h b/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h
deleted file mode 100644
index 6d4b1c1..0000000
--- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_
-#define XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites HLO TriangularSolve ops into a custom-call.
-//
-// The motivation for this is that we need to add temp memory to batched
-// triangular-solve ops in order to call cublas trsmBatched. We rewrite batch 1
-// ops as well so that we have fewer codepaths to worry about in the backend.
-//
-// cublas trsmBatched takes arrays in GPU memory of pointers to the inputs and
-// outputs, `a` and `b`. In XLA the inputs/outputs are always contiguous, but
-// we still have to materialize these pointer arrays.
-//
-// We use the same trick as for cudnn convolutions: this custom-call returns a
-// tuple (actual-result, temp-memory). In our case the temp buffer always has
-// size 2 * sizeof(void*) * batch_size, because we need two arrays of
-// pointers.
-//
-// The custom-call has a backend-config equal to the TriangularSolveOptions
-// object.
-class TriangularSolveRewriter : public HloModulePass {
- public:
- absl::string_view name() const override {
- return "triangular-solve-rewriter";
- }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc
index c911bf1..c5cc0ad 100644
--- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc
@@ -23,7 +23,7 @@
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gemm_fusion.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/verified_hlo_module.h"
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc
deleted file mode 100644
index 75c43fe..0000000
--- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc
+++ /dev/null
@@ -1,190 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/triton_fusion_numerics_verifier.h"
-
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/functional/any_invocable.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-namespace {
-
-using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
-
-// Returns the input instruction as a fusion instruction, if it represents a
-// Triton fusion. Otherwise, returns nullptr.
-absl::StatusOr<const HloFusionInstruction*> AsTritonFusion(
- const HloInstruction* hlo) {
- if (hlo->opcode() != HloOpcode::kFusion) {
- return nullptr;
- }
- const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
- TF_ASSIGN_OR_RETURN(auto gpu_config,
- fusion->backend_config<GpuBackendConfig>());
- const FusionBackendConfig& backend_config =
- gpu_config.fusion_backend_config();
- if (backend_config.kind() == kTritonFusionKind) {
- return fusion;
- }
- return nullptr;
-}
-
-std::unique_ptr<HloModule> NewHloModuleFromFusion(
- const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
- bool clear_backend_config) {
- std::unique_ptr<HloModule> new_module =
- ExtractInstructionIntoNewModule(fusion);
- if (clear_backend_config) {
- new_module->entry_computation()->root_instruction()->clear_backend_config();
- }
- new_module->mutable_config().set_debug_options(debug_opts);
-
- return new_module;
-}
-
-} // namespace
-
-namespace triton_fusion_numerics_pass_internal {
-
-absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
- AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
- const AutotuneConfig& config, const DebugOptions& debug_opts,
- bool clear_backend_config) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
- util.Compile([&](const DebugOptions& opts) {
- return NewHloModuleFromFusion(fusion, opts,
- clear_backend_config);
- }));
- TF_ASSIGN_OR_RETURN(auto rz_buffers, RedzoneBuffers::FromInstruction(
- fusion, config, debug_opts,
- RedzoneBuffers::kAllInputs));
- TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
- TF_ASSIGN_OR_RETURN(std::optional<ProfilingOutput> profiling_output,
- util.ProfileExecutable(executable.get(), stream,
- rz_buffers.input_buffers(),
- rz_buffers.input_shapes()));
- if (!profiling_output.has_value()) {
- return Internal("No output after a successful verification run.");
- }
-
- return std::move(profiling_output->output);
-}
-
-absl::Status CompareBuffers(const ScopedShapedBuffer& current,
- const ScopedShapedBuffer& expected,
- const Shape& shape, const HloModuleConfig& config,
- se::Stream* stream) {
- BufferComparator comparator(
- shape, config.debug_options().xla_gpu_autotune_gemm_rtol());
- TF_ASSIGN_OR_RETURN(bool outputs_match,
- comparator.CompareEqual(stream, current.root_buffer(),
- expected.root_buffer()));
-
- if (!outputs_match) {
- return Internal("Triton fusion output does not match emitters output.");
- }
- return absl::OkStatus();
-}
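For intuition only, the following is a rough host-side analogue of the relative-tolerance comparison above, with made-up data. The real pass delegates to BufferComparator on device buffers, which handles more cases (NaNs, low-precision types), so this is just a conceptual sketch:

#include <cmath>
#include <cstdio>
#include <vector>

// Element-wise comparison with a relative tolerance.
bool AllClose(const std::vector<float>& a, const std::vector<float>& b,
              float rtol) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > rtol * std::fabs(b[i])) return false;
  }
  return true;
}

int main() {
  std::vector<float> triton = {1.0f, 2.0f, 2.999f};
  std::vector<float> emitters = {1.0f, 2.0f, 3.0f};
  std::printf("match = %d\n", AllClose(triton, emitters, /*rtol=*/1e-3f));
  return 0;
}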
-
-absl::Status ForAllTritonFusions(
- const HloModule& module,
- const absl::flat_hash_set<absl::string_view>& execution_threads,
- absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn) {
- for (HloComputation* computation :
- module.MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* instruction : computation->instructions()) {
- TF_ASSIGN_OR_RETURN(auto triton_fusion, AsTritonFusion(instruction));
- if (triton_fusion != nullptr) {
- TF_RETURN_IF_ERROR(fn(*triton_fusion));
- }
- }
- }
- return absl::OkStatus();
-}
-
-} // namespace triton_fusion_numerics_pass_internal
-
-namespace {
-absl::Status VerifyTritonFusion(AutotunerCompileUtil& util,
- const HloFusionInstruction& fusion,
- const AutotuneConfig& config,
- const DebugOptions& debug_opts) {
- TF_ASSIGN_OR_RETURN(auto triton_result,
- triton_fusion_numerics_pass_internal::CompileAndRunFusion(
- util, fusion, config, debug_opts,
- /*clear_backend_config=*/false));
- TF_ASSIGN_OR_RETURN(auto emitters_result,
- triton_fusion_numerics_pass_internal::CompileAndRunFusion(
- util, fusion, config, debug_opts,
- /*clear_backend_config=*/true));
-
- TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
- return triton_fusion_numerics_pass_internal::CompareBuffers(
- triton_result, emitters_result, fusion.shape(),
- fusion.GetModule()->config(), stream);
-}
-
-} // namespace
-
-absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- if (config_.IsDeviceless()) {
- return absl::InternalError(
- "Cannot run TritonFusionNumericsVerifier on a deviceless compilation.");
- }
-
- const DebugOptions& debug_options = module->config().debug_options();
- TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
- AutotunerCompileUtil::Create(config_, debug_options));
- TF_RET_CHECK(opt_compile_util.has_value());
-
- TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions(
- *module, execution_threads, [&](const HloFusionInstruction& fusion) {
- return VerifyTritonFusion(*opt_compile_util, fusion, config_,
- debug_options);
- }));
- return false;
-}
-
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h
deleted file mode 100644
index 6d74f46..0000000
--- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_
-#define XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/functional/any_invocable.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/stream.h"
-
-namespace xla::gpu {
-
-// For each Triton fusion in the HLO module, this pass checks that the output
-// of the fusion generated via Triton matches the output of the same fusion
-// generated with the regular emitters.
-class TritonFusionNumericsVerifier : public HloModulePass {
- public:
- explicit TritonFusionNumericsVerifier(const AutotuneConfig& config)
- : config_(config) {}
-
- static absl::string_view Name() { return "triton-numerics-verifier"; }
- absl::string_view name() const override { return Name(); }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
- AutotuneConfig config_;
-};
-
-namespace triton_fusion_numerics_pass_internal {
-// These are exposed only for testing. Do not use.
-absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
- AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
- const AutotuneConfig& config, const DebugOptions& debug_opts,
- bool clear_backend_config);
-absl::Status CompareBuffers(const ScopedShapedBuffer& current,
- const ScopedShapedBuffer& expected,
- const Shape& shape, const HloModuleConfig& config,
- se::Stream* stream);
-absl::Status ForAllTritonFusions(
- const HloModule& module,
- const absl::flat_hash_set<absl::string_view>& execution_threads,
- absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn);
-} // namespace triton_fusion_numerics_pass_internal
-
-} // namespace xla::gpu
-
-#endif // XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc
deleted file mode 100644
index 8703eff..0000000
--- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc
+++ /dev/null
@@ -1,195 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/triton_fusion_numerics_verifier.h"
-
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/platform_util.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/test_helpers.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-
-namespace xla::gpu {
-namespace {
-
-class TritonFusionNumericsVerifierTest
- : public HloTestBase,
- public ::testing::WithParamInterface<PrimitiveType> {
- public:
- DebugOptions GetDebugOptionsForTest() override {
- auto options = HloTestBase::GetDebugOptionsForTest();
- options.set_xla_gpu_enable_triton_softmax_fusion(true);
- options.set_xla_gpu_verify_triton_fusion_numerics(true);
- return options;
- }
-
- protected:
- std::unique_ptr<xla::HloModule> Module(absl::string_view hlo_text_template,
- absl::string_view type) {
- auto m = GetOptimizedModule(absl::Substitute(hlo_text_template, type));
- TF_EXPECT_OK(m);
- return std::move(m.value());
- }
-
- const HloFusionInstruction* TritonFusion(const xla::HloModule& module) {
- const HloFusionInstruction* fusion_result = nullptr;
-
- absl::Status res =
- triton_fusion_numerics_pass_internal::ForAllTritonFusions(
- module, /*execution_threads=*/{},
- [&](const HloFusionInstruction& fusion) -> absl::Status {
- EXPECT_EQ(fusion_result, nullptr);
- fusion_result = &fusion;
- return absl::OkStatus();
- });
- return fusion_result;
- }
-
- AutotuneConfig CreateAutotuneConfig() {
- se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
- auto executors_or = PlatformUtil::GetStreamExecutors(platform);
- TF_EXPECT_OK(executors_or);
- return AutotuneConfig{DeviceConfig{executors_or->at(0), nullptr},
- GetDebugOptionsForTest()};
- }
-
- AutotunerCompileUtil CreateAutotunerCompileUtil(AutotuneConfig& config) {
- auto opt_compile_util_or =
- AutotunerCompileUtil::Create(config, GetDebugOptionsForTest());
- TF_EXPECT_OK(opt_compile_util_or);
- EXPECT_TRUE(opt_compile_util_or->has_value());
- return std::move(opt_compile_util_or->value());
- }
-};
-
-constexpr absl::string_view kSoftmaxHlo = R"(
-HloModule softmax
-max_computation {
- arg_0 = $0[] parameter(0)
- arg_1 = $0[] parameter(1)
- ROOT maximum = $0[] maximum(arg_0, arg_1)
-}
-add_computation {
- arg_0.1 = $0[] parameter(0)
- arg_1.1 = $0[] parameter(1)
- ROOT add = $0[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
- param_0 = $0[127,125]{1,0} parameter(0)
- constant_neg_inf = $0[] constant(-inf)
- reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
- broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0}
- subtract = $0[127,125]{1,0} subtract(param_0, broadcast)
- exponential = $0[127,125]{1,0} exponential(subtract)
- constant_zero = $0[] constant(0)
- second_reduce = $0[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
- second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0}
- ROOT divide = $0[127,125]{1,0} divide(exponential, second_broadcast)
-}
-)";
-
-bool HloPassHasRun(const HloModule& module, absl::string_view pass_name) {
- for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) {
- if (pass_metadata.pass_name() == pass_name) {
- return true;
- }
- }
- return false;
-}
-
-TEST_P(TritonFusionNumericsVerifierTest, VerifyExactSoftmaxFusionNumerics) {
- PrimitiveType data_type = GetParam();
-
- auto module = Module(kSoftmaxHlo,
- primitive_util::LowercasePrimitiveTypeName(data_type));
-
- // At this point all HLO passes have been executed successfully, because the
- // Module() function hasn't failed. In particular the numerics verification
- // pass should have also run and **not** found any issues. Below we just
- // ensure that the pass has indeed been correctly enabled and that there are
- // Triton Fusions in the input module.
-
- EXPECT_TRUE(HloPassHasRun(*module, TritonFusionNumericsVerifier::Name()));
- auto fusion = TritonFusion(*module);
- EXPECT_NE(fusion, nullptr);
-}
-
-TEST_F(TritonFusionNumericsVerifierTest, CheckMismatch) {
- // This test intentionally compares two different Triton modules to each
- // other. This is to test that the verifier functions correctly catch and
- // report mismatches.
- //
- // Note that as part of computing the two modules below, the numerics verifier
- // pass also runs individually for each module. These runs compare the
-  // modules to the corresponding emitter-generated versions, which match. In
- // that sense this test covers what is being tested by
- // VerifyExactSoftmaxFusionNumerics. The reason to keep two tests is that
- // VerifyExactSoftmaxFusionNumerics is minimal and will be easier to debug if
- // it fails.
-
- auto module_f16 = Module(kSoftmaxHlo, "f16");
- auto fusion_f16 = TritonFusion(*module_f16);
- EXPECT_NE(fusion_f16, nullptr);
-
- auto module_f32 = Module(kSoftmaxHlo, "f32");
- auto fusion_f32 = TritonFusion(*module_f32);
- EXPECT_NE(fusion_f32, nullptr);
-
- AutotuneConfig autotune_config = CreateAutotuneConfig();
- AutotunerCompileUtil compile_util =
- CreateAutotunerCompileUtil(autotune_config);
- const DebugOptions& debug_options = GetDebugOptionsForTest();
-
- auto f16_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
- compile_util, *fusion_f16, autotune_config, debug_options,
- /*clear_backend_config=*/false);
- TF_EXPECT_OK(f16_result);
-
- auto f32_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
- compile_util, *fusion_f32, autotune_config, debug_options,
- /*clear_backend_config=*/false);
- TF_EXPECT_OK(f32_result);
-
- auto stream = autotune_config.GetStream();
- TF_EXPECT_OK(stream);
-
- // Intentionally compare the fusions from the different modules, triggering a
- // mismatch.
- auto cmp = triton_fusion_numerics_pass_internal::CompareBuffers(
- *f16_result, *f32_result, fusion_f16->shape(),
- fusion_f16->GetModule()->config(), *stream);
-
- EXPECT_FALSE(cmp.ok());
-}
-
-INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite,
- TritonFusionNumericsVerifierTest,
- ::testing::Values(F32, F16, BF16));
-
-} // namespace
-} // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc b/third_party/xla/xla/service/gpu/variadic_op_splitter.cc
deleted file mode 100644
index f137157..0000000
--- a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc
+++ /dev/null
@@ -1,115 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/variadic_op_splitter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/shape.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-// The parameter space on the GPU device is limited. We pick an arbitrarily low
-// constant here to try to prevent exceeding this parameter space. For a proper
-// fix, we would have to take into account which parameters share a buffer, and
-// how big these buffers are.
-constexpr int32_t kMaxParameters = 128;
-
-absl::StatusOr<bool> SplitConcatenate(HloInstruction* concat,
- HloComputation* comp) {
- auto operands = concat->operands();
- std::vector<HloInstruction*> operands_to_split(operands.begin(),
- operands.end());
- while (operands_to_split.size() > 1) {
- std::vector<HloInstruction*> new_operands;
- absl::Span<HloInstruction*> operands_span(operands_to_split);
- for (int64_t offset = 0; offset < operands_to_split.size();
- offset += kMaxParameters) {
-      // Check whether the operands starting at 'offset' fail to fill a
-      // complete "batch" of exactly 'kMaxParameters' operands. If this is a
-      // partial tail batch (and not the very first batch), we defer those
-      // operands to the next round, so that they can be concatenated together
-      // with some of the newly created concats. Otherwise, we concatenate the
-      // next (up to) 'kMaxParameters' operands into a new concat now. (A
-      // standalone sketch of this batching follows this function.)
- if (offset > 0 && offset + kMaxParameters > operands_to_split.size()) {
- new_operands.insert(new_operands.end(),
- operands_to_split.begin() + offset,
- operands_to_split.end());
- } else {
- Shape new_shape = concat->shape();
- int64_t concat_dimension_size = 0;
- for (int64_t i = 0;
- i < kMaxParameters && offset + i < operands_to_split.size(); ++i) {
- concat_dimension_size +=
- operands_to_split[i + offset]->shape().dimensions(
- concat->concatenate_dimension());
- }
- new_shape.set_dimensions(concat->concatenate_dimension(),
- concat_dimension_size);
- auto new_concat = comp->AddInstruction(concat->CloneWithNewOperands(
- new_shape, operands_span.subspan(offset, kMaxParameters)));
- new_operands.push_back(new_concat);
- }
- }
- operands_to_split = new_operands;
- }
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0]));
- return true;
-}
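A self-contained sketch of the batching loop above, with plain integers standing in for HLO operands and the limit lowered from 128 to 4 purely for illustration:

#include <cstdio>
#include <vector>

// Groups `values` into "concats" of at most `max_operands`, deferring a partial
// tail (when it is not the first batch) to the next round, mirroring
// SplitConcatenate above. Returns the number of rounds taken.
int SplitIntoBatches(std::vector<int> values, int max_operands) {
  int rounds = 0;
  while (values.size() > 1) {
    ++rounds;
    std::vector<int> next;
    for (size_t offset = 0; offset < values.size(); offset += max_operands) {
      if (offset > 0 && offset + max_operands > values.size()) {
        // Partial tail: carry the leftovers into the next round unchanged.
        next.insert(next.end(), values.begin() + offset, values.end());
      } else {
        // Full batch (or the very first batch): "concatenate" it into one.
        next.push_back(1);  // Stand-in for the newly created concat.
      }
    }
    values = next;
  }
  return rounds;
}

int main() {
  // 10 operands, limit 4: round 1 gives {concat, concat, two leftovers};
  // round 2 concatenates those four into one, so two rounds in total.
  std::printf("rounds = %d\n", SplitIntoBatches(std::vector<int>(10, 0), 4));
  return 0;
}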
-
-std::vector<HloInstruction*> GetRelevantVariadicOps(HloComputation* comp) {
- std::vector<HloInstruction*> ops;
- for (HloInstruction* instr : comp->instructions()) {
- if (instr->opcode() == HloOpcode::kConcatenate &&
- instr->operand_count() > kMaxParameters) {
- ops.push_back(instr);
- }
- }
- return ops;
-}
-
-} // namespace
-
-absl::StatusOr<bool> VariadicOpSplitter::Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) {
- bool changed = false;
- for (HloComputation* comp :
- module->MakeNonfusionComputations(execution_threads)) {
- for (HloInstruction* op : GetRelevantVariadicOps(comp)) {
- // TODO(b/112613927): Handle also other ops than concatenate.
- TF_ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp));
- changed |= result;
- }
- }
- return changed;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.h b/third_party/xla/xla/service/gpu/variadic_op_splitter.h
deleted file mode 100644
index 4449ce2..0000000
--- a/third_party/xla/xla/service/gpu/variadic_op_splitter.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_
-#define XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Splits variadic ops with many operands into pieces such that we don't exceed
-// the parameter space on the GPU. Currently only concatenate ops are split up.
-class VariadicOpSplitter : public HloModulePass {
- public:
- absl::string_view name() const override { return "variadic-op-splitter"; }
-
- using HloPassInterface::Run;
- absl::StatusOr<bool> Run(
- HloModule* module,
- const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-} // namespace gpu
-} // namespace xla
-
-#endif // XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc b/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc
deleted file mode 100644
index 6d7b72e..0000000
--- a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/variadic_op_splitter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/literal_util.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-using match::Concatenate;
-
-class VariadicOpSplitterTest : public HloTestBase {};
-
-TEST_F(VariadicOpSplitterTest, DontSplit) {
- auto module = ParseAndReturnVerifiedModule(R"(
- HloModule TestModule
-
- ENTRY TestComputation {
- p0 = f16[30,41] parameter(0)
- p1 = f16[30,41] parameter(1)
- ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
- })")
- .value();
- EXPECT_FALSE(VariadicOpSplitter().Run(module.get()).value());
-}
-
-TEST_F(VariadicOpSplitterTest, SplitInto2) {
- auto builder = HloComputation::Builder(TestName());
- auto operand = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
- std::vector<HloInstruction*> concat_operands(255, operand);
- builder.AddInstruction(HloInstruction::CreateConcatenate(
- ShapeUtil::MakeShape(S32, {255}), concat_operands, 0));
- auto module = CreateNewVerifiedModule();
- auto entry_computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
- EXPECT_TRUE(Match(entry_computation->root_instruction(),
- Concatenate().WithNumOperands(128).WithOperand(
- 0, Concatenate().WithNumOperands(128))));
-}
-
-TEST_F(VariadicOpSplitterTest, SplitInto3) {
- auto builder = HloComputation::Builder(TestName());
- auto operand = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
- std::vector<HloInstruction*> concat_operands(256, operand);
- builder.AddInstruction(HloInstruction::CreateConcatenate(
- ShapeUtil::MakeShape(S32, {256}), concat_operands, 0));
- auto module = CreateNewVerifiedModule();
- auto entry_computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
- EXPECT_TRUE(Match(entry_computation->root_instruction(),
- Concatenate(Concatenate().WithNumOperands(128),
- Concatenate().WithNumOperands(128))));
-}
-
-} // namespace
-} // namespace gpu
-} // namespace xla
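The tests deleted above exercised the splitting invariant of the removed pass: a concatenate with more than 128 operands is rewritten into nested concatenates of at most 128 operands each, so 255 operands become a root concatenate of 128 operands whose first operand is itself a 128-operand concatenate, and 256 operands become two 128-operand concatenates under the root. Below is a standalone sketch of that chunking rule; it is illustrative only, independent of the XLA implementation, and the helper name and constant are hypothetical.

// Standalone sketch (not the XLA pass): the operand-chunking rule the deleted
// tests checked. A flat operand list is cut into chunks of at most
// kMaxOperands; a final partial chunk stays alongside the full chunks at the
// root, which is how 255 operands yield a 128-operand root whose first
// operand is a 128-operand concatenate.
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kMaxOperands = 128;

// Returns the operand counts of the chunks a flat operand list is split into.
std::vector<int64_t> SplitIntoChunks(int64_t num_operands) {
  std::vector<int64_t> chunks;
  while (num_operands > kMaxOperands) {
    chunks.push_back(kMaxOperands);
    num_operands -= kMaxOperands;
  }
  if (num_operands > 0) chunks.push_back(num_operands);
  return chunks;
}

int main() {
  // Two operands: nothing to split, matching the DontSplit test.
  assert(SplitIntoChunks(2).size() == 1);
  // 255 operands: one full chunk of 128 plus a remainder of 127.
  assert(SplitIntoChunks(255) == (std::vector<int64_t>{128, 127}));
  // 256 operands: two full chunks of 128, as in the SplitInto3 test.
  assert(SplitIntoChunks(256) == (std::vector<int64_t>{128, 128}));
  return 0;
}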
diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc
index fc319e6..7f6cce5 100644
--- a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc
+++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc
@@ -219,14 +219,12 @@
absl::StatusOr<int64_t> HeapSimulator::MinimumMemoryForComputation(
const HloComputation& computation, const HloInstructionSequence& sequence,
const HloAliasAnalysis& alias_analysis,
- const LogicalBuffer::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>*
- memory_by_computation) {
+ const LogicalBuffer::SizeFunction& size_function) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result<HloValue> result,
HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
computation, sequence, alias_analysis, size_function,
- HeapSimulator::Options(), memory_by_computation));
+ HeapSimulator::Options()));
return result.heap_size;
}
@@ -267,11 +265,9 @@
const HloComputation& computation,
const HloInstructionSequence& instruction_sequence,
const HloAliasAnalysis& alias_analysis,
- const BufferValue::SizeFunction& size_fn, const Options& options,
- const absl::flat_hash_map<const HloComputation*, int64_t>*
- memory_by_computation) {
+ const BufferValue::SizeFunction& size_fn, const Options& options) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*schedule=*/nullptr, memory_by_computation);
+ /*schedule=*/nullptr);
HloSchedule schedule(computation.parent());
schedule.set_sequence(&computation, instruction_sequence);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
@@ -291,7 +287,7 @@
const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
const Options& options) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*schedule=*/schedule, nullptr);
+ /*schedule=*/schedule);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloLiveRange> hlo_live_range,
HloLiveRange::Run(*schedule, alias_analysis, &computation));
@@ -492,19 +488,16 @@
return absl::OkStatus();
}
-HeapSimulator::HeapSimulator(
- std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
- const BufferValue::SizeFunction& size_fn, const Options& options,
- const HloSchedule* schedule,
- const absl::flat_hash_map<const HloComputation*, int64_t>*
- memory_by_computation)
+HeapSimulator::HeapSimulator(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+ const BufferValue::SizeFunction& size_fn,
+ const Options& options,
+ const HloSchedule* schedule)
: no_fragmentation_stats_(
std::make_unique<NoFragmentationStatsHeap<HloValue>>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- schedule_(schedule),
- memory_by_computation_(memory_by_computation) {
+ schedule_(schedule) {
debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}
@@ -629,21 +622,10 @@
template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
- const HloInstruction* instruction, int64_t alloc_size_by_instruction,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation) {
+ const HloInstruction* instruction, int64_t alloc_size_by_instruction) {
// We only count the memory usage of the largest subcomputation, instead of
// adding them all, because subcomputations won't execute in parallel.
int64_t max_subcomputation_bytes = 0;
- for (const auto* c : instruction->called_computations()) {
- auto it = memory_by_computation.find(c);
- if (it != memory_by_computation.end()) {
- int64_t subcomputation_bytes = it->second;
- if (subcomputation_bytes > max_subcomputation_bytes) {
- max_subcomputation_bytes = subcomputation_bytes;
- }
- }
- }
if (max_subcomputation_bytes > 0 &&
(instruction->opcode() == HloOpcode::kWhile ||
instruction->opcode() == HloOpcode::kCall ||
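The comment retained above states the accounting rule that the deleted map used to feed: when an instruction calls several subcomputations, only the largest callee is charged, because callees never run in parallel. A minimal standalone sketch of that rule follows; the byte figures simply echo the 5 and 16 used by the deleted heap_simulator_test.cc entries and are illustrative.

// Minimal standalone sketch of the "largest subcomputation only" rule.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical peak memory of a while condition and a while body.
  std::vector<int64_t> subcomputation_bytes = {5, 16};
  int64_t max_bytes = 0;
  for (int64_t bytes : subcomputation_bytes) {
    max_bytes = std::max(max_bytes, bytes);
  }
  // 16 is charged, not 5 + 16 = 21: summing would overestimate memory for
  // callees that execute sequentially.
  std::cout << "accounted subcomputation bytes: " << max_bytes << "\n";
  return 0;
}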
diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.h b/third_party/xla/xla/service/heap_simulator/heap_simulator.h
index 09e12d2..e446a4e 100644
--- a/third_party/xla/xla/service/heap_simulator/heap_simulator.h
+++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.h
@@ -148,9 +148,7 @@
static absl::StatusOr<int64_t> MinimumMemoryForComputation(
const HloComputation& computation, const HloInstructionSequence& sequence,
const HloAliasAnalysis& alias_analysis,
- const LogicalBuffer::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>*
- memory_by_computation = nullptr);
+ const LogicalBuffer::SizeFunction& size_function);
static absl::StatusOr<int64_t> MinimumMemoryForComputation(
const HloComputation& computation, const HloInstructionSequence& sequence,
@@ -184,9 +182,7 @@
const HloInstructionSequence& instruction_sequence,
const HloAliasAnalysis& alias_analysis,
const BufferValue::SizeFunction& size_fn,
- const Options& options = Options(),
- const absl::flat_hash_map<const HloComputation*, int64_t>*
- memory_by_computation = nullptr);
+ const Options& options = Options());
// Same as above, but runs on with a schedule that covers all nested
// computations.
@@ -204,9 +200,7 @@
// be run recursively. I.e. the simulation is run over the whole module.
HeapSimulator(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
const BufferValue::SizeFunction& size_fn,
- const Options& options, const HloSchedule* schedule = nullptr,
- const absl::flat_hash_map<const HloComputation*, int64_t>*
- memory_by_computation = nullptr);
+ const Options& options, const HloSchedule* schedule = nullptr);
~HeapSimulator();
absl::Status RunComputation(
@@ -244,13 +238,10 @@
const std::unique_ptr<HeapAlgorithm<HloValue>> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
- // schedule_ is set by buffer assignment, and memory_by_computation_ is
- // set by hlo scheduling. Then, in RunComputation, we check both in order to
- // handle subcomputations. It would be good to unify the handling of
- // subcomputations, but it's not clear how.
+  // schedule_ is set by buffer assignment. Then, in RunComputation, we check
+  // it in order to handle subcomputations. It would be good to unify the
+  // handling of subcomputations, but it's not clear how.
const HloSchedule* schedule_;
- const absl::flat_hash_map<const HloComputation*, int64_t>*
- memory_by_computation_;
// Hold some sets for error-checking the sequence of Alloc and Free calls.
absl::flat_hash_set<const HloValue*> allocated_buffers_;
@@ -290,9 +281,7 @@
virtual void AccountForSubcomputationMemory(
const HloInstruction* instruction,
// The total number of bytes allocated by instruction.
- int64_t alloc_size_by_instruction,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation) {}
+ int64_t alloc_size_by_instruction) {}
// Free de-allocates a previously allocated buffer.
virtual void Free(const BufferType* buffer, int64_t size) = 0;
@@ -328,9 +317,8 @@
void Alloc(const BufferType* buffer, int64_t size) override;
void AccountForSubcomputationMemory(
- const HloInstruction* instruction, int64_t alloc_size_by_instruction,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation) override;
+ const HloInstruction* instruction,
+ int64_t alloc_size_by_instruction) override;
void Free(const BufferType* buffer, int64_t size) override;
diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc
index cff0e2f..c0f48d4 100644
--- a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc
+++ b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc
@@ -210,9 +210,6 @@
auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
};
- absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
- memory_by_computation[cond_computation] = 5;
- memory_by_computation[body_computation] = 16;
std::unique_ptr<HloAliasAnalysis> alias_analysis =
HloAliasAnalysis::Run(module.get()).value();
@@ -221,7 +218,7 @@
// so we don't double count.
EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
*entry_computation, schedule.sequence(entry_computation),
- *alias_analysis, size_fn, &memory_by_computation)
+ *alias_analysis, size_fn)
.value());
}
diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto
index fdeaa68..83c3b85 100644
--- a/third_party/xla/xla/service/hlo.proto
+++ b/third_party/xla/xla/service/hlo.proto
@@ -112,7 +112,7 @@
}
// Serialization of HloInstruction.
-// Next ID: 89
+// Next ID: 90
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -385,6 +385,9 @@
// For HLO value tracking.
xla.OriginalValueProto original_value = 88;
+
+ // Specifies if a call instruction is a composite.
+ bool is_composite = 89;
}
// Serialization of HloComputation.
@@ -579,6 +582,7 @@
FUSION = 2;
LAYOUT = 3;
DOT = 4;
+ FLAGNET = 5;
}
// Information about the optimization profile that this module contains.
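For reference, the new field surfaces through the generated C++ proto accessors. A minimal sketch, assuming an XLA build where the generated hlo.pb.h header is available; the function name is hypothetical.

// Illustrative only: builds an HloInstructionProto with the new field set.
#include "xla/service/hlo.pb.h"

xla::HloInstructionProto MakeCompositeCallProto() {
  xla::HloInstructionProto instr;
  instr.set_opcode("call");
  instr.set_is_composite(true);  // field 89 added above
  return instr;
}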
diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc
index 16ce018..a7190b3 100644
--- a/third_party/xla/xla/service/hlo_computation_test.cc
+++ b/third_party/xla/xla/service/hlo_computation_test.cc
@@ -15,8 +15,8 @@
#include "xla/hlo/ir/hlo_computation.h"
+#include <cstdint>
#include <memory>
-#include <set>
#include <string>
#include <string_view>
#include <vector>
@@ -24,19 +24,24 @@
#include <gmock/gmock.h>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "xla/comparison_util.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_matchers.h"
-#include "xla/literal.h"
+#include "xla/literal_util.h"
#include "xla/service/hlo_parser.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/shape.h"
+#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/test_helpers.h"
#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
namespace xla {
@@ -940,5 +945,32 @@
cloned_done.get()->async_wrapped_computation());
}
+TEST_F(HloComputationTest, CompositeCall) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ add (x: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %constant = f32[] constant(2)
+ ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+ }
+
+ ENTRY %CallR0F32AddScalar.v2 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call = f32[] call(f32[] %constant.1), to_apply=add, is_composite=true,
+ frontend_attributes={
+ composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},
+ composite.name="foo.bar",
+ composite.version="1"
+ }
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ HloInstruction* composite_call = FindInstruction(module.get(), "call");
+ EXPECT_EQ(composite_call->opcode(), HloOpcode::kCall);
+ EXPECT_TRUE(composite_call->is_composite());
+ EXPECT_EQ(composite_call->frontend_attributes().map().size(), 3);
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc
index 981b967..20f7690 100644
--- a/third_party/xla/xla/service/hlo_instruction_test.cc
+++ b/third_party/xla/xla/service/hlo_instruction_test.cc
@@ -15,6 +15,11 @@
#include "xla/hlo/ir/hlo_instruction.h"
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <limits>
+#include <memory>
#include <optional>
#include <set>
#include <string>
@@ -22,16 +27,23 @@
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/collective_device_list.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout_util.h"
+#include "xla/literal_util.h"
#include "xla/protobuf_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/test_helpers.h"
@@ -40,6 +52,7 @@
#include "xla/util.h"
#include "xla/window_util.h"
#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
namespace xla {
namespace {
@@ -2752,7 +2765,7 @@
module->AddEntryComputation(main_builder.Build());
// Should find conditional branch computations in the graph and it should
- // point to the conditonal instruction.
+ // point to the conditional instruction.
int num_conditional_branch_comp = 0;
for (HloComputation* comp : module->MakeComputationPostOrder()) {
if (comp->IsConditionalBranchComputation()) {
@@ -2827,7 +2840,7 @@
module->AddEntryComputation(main_builder.Build());
// Should find conditional branch computations in the graph and it should
- // point to the conditonal instruction.
+ // point to the conditional instruction.
int num_conditional_branch_comp = 0;
for (HloComputation* comp : module->MakeComputationPostOrder()) {
if (comp->IsConditionalBranchComputation()) {
diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.cc b/third_party/xla/xla/service/hlo_memory_scheduler.cc
index 283b82e..83e4072 100644
--- a/third_party/xla/xla/service/hlo_memory_scheduler.cc
+++ b/third_party/xla/xla/service/hlo_memory_scheduler.cc
@@ -90,11 +90,8 @@
static absl::StatusOr<HloInstructionSequence> Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation) {
- ListScheduler scheduler(computation, points_to_analysis, size_function,
- memory_by_computation);
+ const BufferValue::SizeFunction& size_function) {
+ ListScheduler scheduler(computation, points_to_analysis, size_function);
return scheduler.CreateSchedule();
}
@@ -115,13 +112,10 @@
ListScheduler(HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation)
+ const BufferValue::SizeFunction& size_function)
: computation_(computation),
points_to_analysis_(points_to_analysis),
- size_function_(size_function),
- memory_by_computation_(memory_by_computation) {
+ size_function_(size_function) {
// Create a map containing the LogicalBuffer uses for each HLO
// instruction. An HLO instruction "uses" a LogicalBuffer if the
// LogicalBuffer is in an operand of the instruction as indicated by
@@ -242,29 +236,7 @@
freed_bytes += size_function_(*buffer);
}
}
- // We only count the memory usage of the largest subcomputation, instead of
- // adding them all, because subcomputations won't execute in parallel.
- int64_t max_subcomputation_bytes = 0;
- for (const auto* c : instruction->called_computations()) {
- auto it = memory_by_computation_.find(c);
- if (it != memory_by_computation_.end()) {
- int64_t subcomputation_bytes = it->second;
- if (subcomputation_bytes > max_subcomputation_bytes) {
- max_subcomputation_bytes = subcomputation_bytes;
- }
- }
- }
- int64_t bytes_defined;
- if (max_subcomputation_bytes > 0 &&
- (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
- opcode == HloOpcode::kConditional)) {
- // The output buffer of while/call/conditional is always aliased with the
- // output buffer of the root instruction in the body. Don't double count.
- bytes_defined = max_subcomputation_bytes;
- } else {
- bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
- }
- return freed_bytes - bytes_defined;
+ return freed_bytes - entry.bytes_defined;
}
// Constructs the scheduling priority of the given instruction.
@@ -392,11 +364,6 @@
HloComputation* computation_;
const TuplePointsToAnalysis& points_to_analysis_;
const BufferValue::SizeFunction& size_function_;
- // Computations are analyzed in post-order. When scheduling an instruction
- // that includes subcomputations, such as a while loop, we use this map to
- // look up the memory needed by subcomputations.
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation_;
// A map containing the LogicalBuffers that each instruction uses.
absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
@@ -426,19 +393,15 @@
const HloAliasAnalysis& alias_analysis,
const BufferValue::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
VLOG(2) << "Computation: " << computation->name();
if (algorithm) {
return algorithm(computation, points_to_analysis, alias_analysis,
- size_function, memory_by_computation, postprocessor,
- peak_memory);
+ size_function, postprocessor, peak_memory);
}
return DefaultMemoryScheduler(computation, points_to_analysis, alias_analysis,
- size_function, memory_by_computation,
- postprocessor, peak_memory);
+ size_function, postprocessor, peak_memory);
}
} // namespace
@@ -448,8 +411,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const BufferValue::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
// These variables are a hack to prevent overflows.
int64_t cumulative_total_size = 0;
@@ -526,9 +487,9 @@
CHECK_EQ(sequence.size(), computation->instruction_count());
if (peak_memory) {
TF_ASSIGN_OR_RETURN(
- *peak_memory, HeapSimulator::MinimumMemoryForComputation(
- *computation, sequence, alias_analysis, size_function,
- &memory_by_computation));
+ *peak_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ *computation, sequence, alias_analysis, size_function));
}
return sequence;
}
@@ -538,8 +499,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const BufferValue::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
// Index of HloInstruction in the `computation`.
absl::flat_hash_map<const HloInstruction*, int64_t> inst_index;
@@ -586,9 +545,9 @@
CHECK_EQ(sequence.size(), computation->instruction_count());
if (peak_memory) {
TF_ASSIGN_OR_RETURN(
- *peak_memory, HeapSimulator::MinimumMemoryForComputation(
- *computation, sequence, alias_analysis, size_function,
- &memory_by_computation));
+ *peak_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ *computation, sequence, alias_analysis, size_function));
}
return sequence;
@@ -605,16 +564,14 @@
const absl::flat_hash_set<absl::string_view>& execution_threads,
int64_t* peak_memory) -> absl::StatusOr<HloSchedule> {
HloSchedule schedule(module);
- absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
for (auto* computation :
module->MakeComputationPostOrder(execution_threads)) {
if (!computation->IsFusionComputation()) {
- TF_ASSIGN_OR_RETURN(
- HloInstructionSequence computation_sequence,
- ScheduleComputationHelper(
- computation, points_to_analysis, alias_analysis, size_func,
- computation_scheduler, memory_by_computation, postprocessor,
- /*peak_memory=*/nullptr));
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
+ ScheduleComputationHelper(
+ computation, points_to_analysis, alias_analysis,
+ size_func, computation_scheduler, postprocessor,
+ /*peak_memory=*/nullptr));
schedule.set_sequence(computation, std::move(computation_sequence));
}
}
@@ -631,20 +588,18 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const BufferValue::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
- TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence,
- ListScheduler::Run(computation, points_to_analysis,
- size_function, memory_by_computation));
+ TF_ASSIGN_OR_RETURN(
+ HloInstructionSequence sequence,
+ ListScheduler::Run(computation, points_to_analysis, size_function));
if (postprocessor) {
sequence = postprocessor(sequence);
}
if (peak_memory) {
TF_ASSIGN_OR_RETURN(
- *peak_memory, HeapSimulator::MinimumMemoryForComputation(
- *computation, sequence, alias_analysis, size_function,
- &memory_by_computation));
+ *peak_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ *computation, sequence, alias_analysis, size_function));
}
return sequence;
}
@@ -654,8 +609,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const BufferValue::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
HloInstructionSequence sequence(computation->MakeInstructionPostOrder());
if (postprocessor) {
@@ -663,9 +616,9 @@
}
if (peak_memory) {
TF_ASSIGN_OR_RETURN(
- *peak_memory, HeapSimulator::MinimumMemoryForComputation(
- *computation, sequence, alias_analysis, size_function,
- &memory_by_computation));
+ *peak_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ *computation, sequence, alias_analysis, size_function));
}
return sequence;
}
@@ -675,8 +628,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const BufferValue::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
// We try a few schedulers and choose whichever returns a lower min-memory,
// not accounting for fragmentation.
@@ -690,24 +641,21 @@
TF_ASSIGN_OR_RETURN(
HloInstructionSequence list_sequence,
ListMemoryScheduler(computation, points_to_analysis, alias_analysis,
- size_function, memory_by_computation, postprocessor,
- &list_memory));
+ size_function, postprocessor, &list_memory));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
int64_t dfs_memory;
TF_ASSIGN_OR_RETURN(
HloInstructionSequence dfs_sequence,
DFSMemoryScheduler(computation, points_to_analysis, alias_analysis,
- size_function, memory_by_computation, postprocessor,
- &dfs_memory));
+ size_function, postprocessor, &dfs_memory));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
int64_t post_order_memory;
- TF_ASSIGN_OR_RETURN(
- HloInstructionSequence post_order_sequence,
- PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis,
- size_function, memory_by_computation,
- postprocessor, &post_order_memory));
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence post_order_sequence,
+ PostOrderMemoryScheduler(
+ computation, points_to_analysis, alias_analysis,
+ size_function, postprocessor, &post_order_memory));
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
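The hunk above keeps the existing strategy of trying the list, DFS, and post-order schedulers and keeping whichever reports the lower min-memory (not accounting for fragmentation), while dropping the memory_by_computation argument from each call. Below is a standalone sketch of that selection step; the candidate names and byte counts are made up for illustration.

// Standalone sketch: pick the candidate with the lowest reported peak memory.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<std::string, int64_t>> candidates = {
      {"list", 96}, {"dfs", 128}, {"post-order", 112}};
  std::pair<std::string, int64_t> best = candidates.front();
  for (const auto& candidate : candidates) {
    if (candidate.second < best.second) best = candidate;
  }
  std::cout << "chosen scheduler: " << best.first << " (peak " << best.second
            << " bytes)\n";
  return 0;
}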
@@ -815,21 +763,6 @@
return std::move(schedule);
}
-absl::StatusOr<HloInstructionSequence> ScheduleComputation(
- HloComputation* computation, const BufferValue::SizeFunction& size_function,
- const MemorySchedulerPostprocessor& postprocessor) {
- CHECK(!computation->IsFusionComputation());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
- TuplePointsToAnalysis::Run(computation->parent()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(computation->parent()));
- absl::flat_hash_map<const HloComputation*, int64_t> empty_map;
- return ScheduleComputationHelper(
- computation, *points_to_analysis, *alias_analysis, size_function,
- /*algorithm=*/nullptr, empty_map, postprocessor,
- /*peak_memory=*/nullptr);
-}
-
HloMemoryScheduler::HloMemoryScheduler(
const BufferValue::SizeFunction& size_function,
const ModuleSchedulerAlgorithm& algorithm)
diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.h b/third_party/xla/xla/service/hlo_memory_scheduler.h
index 112ced3..2fb211a 100644
--- a/third_party/xla/xla/service/hlo_memory_scheduler.h
+++ b/third_party/xla/xla/service/hlo_memory_scheduler.h
@@ -19,7 +19,6 @@
#include <cstdint>
#include <functional>
-#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@@ -51,7 +50,6 @@
std::function<absl::StatusOr<HloInstructionSequence>(
HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
const LogicalBuffer::SizeFunction&,
- const absl::flat_hash_map<const HloComputation*, int64_t>&,
const MemorySchedulerPostprocessor&,
/*peak_memory*/ int64_t*)>;
@@ -73,8 +71,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
// DFS-order scheduler
@@ -83,8 +79,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
// BFS-order scheduler
@@ -102,8 +96,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
// Naive Post Order scheduler
@@ -112,8 +104,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
// The default scheduling algorithm. Runs the list scheduler, the DFS scheduler,
@@ -125,8 +115,6 @@
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const absl::flat_hash_map<const HloComputation*, int64_t>&
- memory_by_computation,
const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
absl::StatusOr<HloSchedule> DefaultModuleScheduler(
@@ -146,13 +134,6 @@
const absl::flat_hash_set<absl::string_view>& execution_threads = {},
int64_t* peak_memory = nullptr);
-// Computes the schedule for a single computation.
-// Currently only used by the GPU backend.
-absl::StatusOr<HloInstructionSequence> ScheduleComputation(
- HloComputation* computation,
- const LogicalBuffer::SizeFunction& size_function,
- const MemorySchedulerPostprocessor& postprocessor);
-
// A pass which schedules the HLO instructions in a module. The HloModule's
// schedule field is set to the resulting HloSchedule using
// HloModule::set_schedule.
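With the map parameter removed from the MemorySchedulerAlgorithm signature above, custom scheduler callbacks take one fewer argument. Below is a sketch of such a callback under the new arity, assuming an XLA build with this header on the include path; MakePostOrderAlgorithm is a hypothetical helper that simply defers to PostOrderMemoryScheduler.

// Sketch only: a MemorySchedulerAlgorithm matching the trimmed signature.
#include <cstdint>

#include "xla/service/hlo_memory_scheduler.h"

namespace xla {

MemorySchedulerAlgorithm MakePostOrderAlgorithm() {
  return [](HloComputation* computation,
            const TuplePointsToAnalysis& points_to_analysis,
            const HloAliasAnalysis& alias_analysis,
            const LogicalBuffer::SizeFunction& size_function,
            const MemorySchedulerPostprocessor& postprocessor,
            int64_t* peak_memory) -> absl::StatusOr<HloInstructionSequence> {
    // Placeholder policy: just reuse the post-order scheduler.
    return PostOrderMemoryScheduler(computation, points_to_analysis,
                                    alias_analysis, size_function,
                                    postprocessor, peak_memory);
  };
}

}  // namespace xla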
diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc
index 2910932..f237575 100644
--- a/third_party/xla/xla/service/hlo_module_test.cc
+++ b/third_party/xla/xla/service/hlo_module_test.cc
@@ -37,9 +37,9 @@
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc
index 8f097ca..2ff0697 100644
--- a/third_party/xla/xla/service/hlo_parser.cc
+++ b/third_party/xla/xla/service/hlo_parser.cc
@@ -80,6 +80,7 @@
#include "tsl/lib/gtl/map_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
namespace xla {
@@ -1915,6 +1916,7 @@
std::vector<HloInstruction*> async_wrapped_operands;
std::vector<Shape> async_wrapped_operand_shapes;
Shape async_wrapped_root_shape;
+ async_wrapped_operand_shapes.reserve(operands.size());
for (const HloInstruction* operand : operands) {
async_wrapped_operand_shapes.push_back(operand->shape());
}
@@ -2249,8 +2251,11 @@
}
case HloOpcode::kCall: {
optional<HloComputation*> to_apply;
+ optional<bool> is_composite = false;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
+ attrs["is_composite"] = {/*required=*/false, AttrTy::kBool,
+ &is_composite};
if ((!preset_operands && !ParseOperands(&operands, builder)) ||
!ParseAttributes(attrs, allow_attributes, shape)) {
return nullptr;
@@ -2266,8 +2271,10 @@
})) {
return nullptr;
}
- return builder->AddInstruction(
- HloInstruction::CreateCall(*shape, operands, *to_apply));
+
+ auto call_op = HloInstruction::CreateCall(*shape, operands, *to_apply);
+ call_op->set_is_composite(is_composite.value());
+ return builder->AddInstruction(std::move(call_op));
}
case HloOpcode::kReduceWindow: {
optional<HloComputation*> reduce_computation;
@@ -3178,6 +3185,13 @@
optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted};
+ optional<std::vector<int64_t>> operand_batching_dims;
+ attrs["operand_batching_dims"] = {
+ /*required=*/false, AttrTy::kBracedInt64List, &operand_batching_dims};
+ optional<std::vector<int64_t>> start_indices_batching_dims;
+ attrs["start_indices_batching_dims"] = {/*required=*/false,
+ AttrTy::kBracedInt64List,
+ &start_indices_batching_dims};
if ((!preset_operands &&
!ParseOperands(&operands, builder, /*expected_size=*/2)) ||
@@ -3190,7 +3204,13 @@
/*offset_dims=*/*offset_dims,
/*collapsed_slice_dims=*/*collapsed_slice_dims,
/*start_index_map=*/*start_index_map,
- /*index_vector_dim=*/*index_vector_dim);
+ /*index_vector_dim=*/*index_vector_dim,
+ /*operand_batching_dims=*/
+ operand_batching_dims ? *operand_batching_dims
+ : std::vector<int64_t>(),
+ /*start_indices_batching_dims=*/
+ start_indices_batching_dims ? *start_indices_batching_dims
+ : std::vector<int64_t>());
if (!maybe_infer_shape([&] {
return ShapeInference::InferGatherShape(operands[0]->shape(),
operands[1]->shape(),
@@ -3226,6 +3246,13 @@
optional<bool> unique_indices = false;
attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
&unique_indices};
+ optional<std::vector<int64_t>> input_batching_dims;
+ attrs["input_batching_dims"] = {
+ /*required=*/false, AttrTy::kBracedInt64List, &input_batching_dims};
+ optional<std::vector<int64_t>> scatter_indices_batching_dims;
+ attrs["scatter_indices_batching_dims"] = {/*required=*/false,
+ AttrTy::kBracedInt64List,
+ &scatter_indices_batching_dims};
if ((!preset_operands && !ParseOperands(&operands, builder)) ||
!ParseAttributes(attrs, allow_attributes, shape)) {
@@ -3243,7 +3270,13 @@
/*update_window_dims=*/*update_window_dims,
/*inserted_window_dims=*/*inserted_window_dims,
/*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
- /*index_vector_dim=*/*index_vector_dim);
+ /*index_vector_dim=*/*index_vector_dim,
+ /*input_batching_dims=*/
+ input_batching_dims ? *input_batching_dims
+ : std::vector<int64_t>(),
+ /*scatter_indices_batching_dims=*/
+ scatter_indices_batching_dims ? *scatter_indices_batching_dims
+ : std::vector<int64_t>());
if (!maybe_infer_shape([&] {
absl::InlinedVector<const Shape*, 3> arg_shapes;
@@ -3421,11 +3454,21 @@
if (!ParseAttributeName(&attribute)) {
return false;
}
- if (lexer_.GetKind() != TokKind::kString) {
+
+ std::string result;
+ if (lexer_.GetKind() == TokKind::kString) {
+ if (!ParseString(&result)) {
+ return false;
+ }
+ } else if (lexer_.GetKind() == TokKind::kLbrace) {
+ if (!ParseJsonDict(&result)) {
+ return false;
+ }
+ } else {
return false;
}
- (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
- lexer_.Lex();
+
+ (*frontend_attributes->mutable_map())[attribute] = result;
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRbrace,
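With the lexer branch added above, a frontend attribute value may now be either a quoted string or a brace-delimited dictionary. The constants below restate, side by side, the two forms exercised by the composite-call test cases added later in this diff; they are illustrative string literals only.

// Illustrative only: the two attribute value forms the parser now accepts.
constexpr const char* kStringValuedAttr =
    R"attr(frontend_attributes={composite.name="foo.bar"})attr";
constexpr const char* kDictValuedAttr =
    R"attr(frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>}})attr";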
diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc
index 4c67c01..6378f08 100644
--- a/third_party/xla/xla/service/hlo_parser_test.cc
+++ b/third_party/xla/xla/service/hlo_parser_test.cc
@@ -15,6 +15,7 @@
#include "xla/service/hlo_parser.h"
+#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
@@ -22,16 +23,25 @@
#include <vector>
#include <gtest/gtest.h>
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/array.h"
#include "xla/hlo/ir/collective_device_list.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/service/hlo_lexer.h"
+#include "xla/service/hlo_module_config.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/shape.h"
@@ -40,6 +50,7 @@
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/window_util.h"
#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
@@ -68,7 +79,7 @@
//
// In general we want to avoid these because we want HLO text to be
// round-trippable! But nested instructions, e.g. add(sqrt(x), y), cannot be
-// round-triped without modification.
+// round-tripped without modification.
struct NonRoundtripTestData {
std::string test_name;
std::string input_module_string;
@@ -463,6 +474,96 @@
)"
},
+// composite call
+{
+"CompositeCall",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %constant = f32[] constant(2)
+ ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+)"
+},
+// composite call with extra frontend attributes
+{
+"CompositeCallWithExtraFrontendAttributes",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %constant = f32[] constant(2)
+ ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1",foo="bar"}
+}
+
+)"
+},
+// composite call optional composite.attributes and composite.version
+{
+"CompositeCallOptionalAttributesAndVersion",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %constant = f32[] constant(2)
+ ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.name="foo.bar"}
+}
+
+)"
+},
+// composite call optional composite.attributes
+{
+"CompositeCallOptionalAttributes",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %constant = f32[] constant(2)
+ ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.name="foo.bar",composite.version="1"}
+}
+
+)"
+},
+// composite call optional composite.version
+{
+"CompositeCallOptionalVersion",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %constant = f32[] constant(2)
+ ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar"}
+}
+
+)"
+},
// CustomCall with backend_config.
{
"CustomCallWithOpaque",
@@ -1062,6 +1163,18 @@
)"
},
{
+"BatchGather",
+R"(HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0}}
+
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512], start_indices: s64[10,9,8,7,5,512]) -> f32[10,9,8,7,30,29,28,27,26,512] {
+ %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0)
+ %start_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1)
+ ROOT %gather = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, operand_batching_dims={5}, start_indices_batching_dims={5}, index_vector_dim=4, slice_sizes={30,29,28,27,26,1}
+}
+
+)"
+},
+{
"Scatter",
R"(HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0}, s64[10,9,8,7,5]{4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
@@ -1081,6 +1194,25 @@
)"
},
{
+"BatchScatter",
+R"(HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46,512]{5,4,3,2,1,0}}
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512], scatter_indices: s64[10,9,8,7,5,512], updates: f32[10,9,8,7,30,29,28,27,26,512]) -> f32[50,49,48,47,46,512] {
+ %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0)
+ %scatter_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1)
+ %updates = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} parameter(2)
+ ROOT %scatter = f32[50,49,48,47,46,512]{5,4,3,2,1,0} scatter(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, input_batching_dims={5}, scatter_indices_batching_dims={5}, index_vector_dim=4, to_apply=%add_F32.v3
+}
+
+)"
+},
+{
"TupleScatter",
R"(HloModule TupleScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0}, s64[10,9,8,7,5]{4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}, bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->(f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0})}
diff --git a/third_party/xla/xla/service/hlo_unstacker.cc b/third_party/xla/xla/service/hlo_unstacker.cc
index c6b0971..21d0eb9 100644
--- a/third_party/xla/xla/service/hlo_unstacker.cc
+++ b/third_party/xla/xla/service/hlo_unstacker.cc
@@ -790,14 +790,8 @@
HloInstruction* bitcast = mutable_dynamic_slicing_fusion->AddInstruction(
HloInstruction::CreateBitcast(mutable_dynamic_slicing_fusion->shape(),
new_operand));
- HloInstruction* bitcast_fusion =
- mutable_dynamic_slicing_fusion->AddInstruction(
- HloInstruction::CreateFusion(mutable_dynamic_slicing_fusion->shape(),
- HloInstruction::FusionKind::kLoop,
- bitcast));
-
return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape(
- bitcast_fusion);
+ bitcast);
}
// This function recognizes fusions with the following pattern:
@@ -1430,6 +1424,7 @@
/*force_unroll=*/true, /*prepare=*/false));
CHECK(unrolled);
}
+ VLOG(3) << "after unstacking \n" << module->ToString();
return true;
}
diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc
index 37a9843..3b00f92 100644
--- a/third_party/xla/xla/service/hlo_unstacker_test.cc
+++ b/third_party/xla/xla/service/hlo_unstacker_test.cc
@@ -34,18 +34,18 @@
using UnstackerTest = HloTestBase;
-int64_t GetSliceCountInEntry(HloModule* module) {
- int64_t slice_instrs_count = 0;
+int64_t GetInstrCountWithOpcodeInEntry(HloModule* module, HloOpcode opcode) {
+ int64_t instr_with_opcode_count = 0;
for (HloInstruction* instr :
module->entry_computation()->MakeInstructionPostOrder()) {
- if (instr->opcode() == HloOpcode::kSlice) {
- slice_instrs_count++;
+ if (instr->opcode() == opcode) {
+ instr_with_opcode_count++;
}
}
- return slice_instrs_count;
+ return instr_with_opcode_count;
}
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) {
+TEST_F(UnstackerTest, UnstackDSFusionPattern) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -63,7 +63,8 @@
p1 = s8[3,128,128] get-tuple-element(wide_p), index=2
one = s32[] constant(1)
inc = s32[] add(i, one)
- %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf
+ %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice
+ conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf
ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1)
}
@@ -80,7 +81,7 @@
init = s32[] constant(0)
while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0)
while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body
- while_use = s8[3,128,128] get-tuple-element(while.out), index=2
+ while_use = s8[3,128,128] get-tuple-element(while.out), index=2
ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
}
)";
@@ -90,12 +91,15 @@
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
+  // Check that the bitcast is unfused and there are no fusions left.
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+ 0);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
- std::nullopt));
+ std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUser2) {
+TEST_F(UnstackerTest, UnstackReduceFusionPattern) {
std::string hlo_string = R"(
HloModule SimpleLoop
dynamic-slice.609.reduce_sub_computation {
@@ -148,10 +152,10 @@
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
EXPECT_TRUE(unstacked);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
- std::nullopt));
+ std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUserNoBitcast) {
+TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] {
@@ -195,15 +199,17 @@
ParseAndReturnVerifiedModule(hlo_string));
auto original = module->Clone();
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
- std::cout << module->ToString() << std::endl;
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
+ // Check that all the fusions are removed.
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+ 0);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUserNoBitcastKeepFused) {
+TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] {
@@ -249,33 +255,35 @@
auto unfuse = [](HloInstruction* instruction) { return false; };
TF_ASSERT_OK_AND_ASSIGN(bool unstacked,
HloUnstacker(unfuse).Run(module.get()));
- std::cout << module->ToString() << std::endl;
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 0);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 0);
+ // Check that dynamic-slices are still fused.
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+ 3);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUserDifferentLayout) {
+TEST_F(UnstackerTest, UnstackDSFusionPatternWithDifferentLayout) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.30.clone (param_0.153: bf16[32,4,64,64,3], param_1.123: s32[]) -> bf16[64,4,64,3] {
- %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} parameter(0)
+ %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0} parameter(0)
%param_1.123 = s32[]{:T(128)} parameter(1)
%constant.227 = s32[]{:T(128)} constant(0)
- %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}
- ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %dynamic-slice.5)
+ %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3}
+ ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0} %dynamic-slice.5)
}
%while.body (wide_param: (s32[], bf16[8,128], bf16[32,4,64,64,3])) -> (s32[], bf16[8,128], bf16[32,4,64,64,3]) {
wide_p = (s32[], bf16[8,128], bf16[32,4,64,64,3]) parameter(0)
i = s32[] get-tuple-element(wide_p), index=0
p0 = bf16[8,128] get-tuple-element(wide_p), index=1
- p1 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} get-tuple-element(wide_p), index=2
+ p1 = bf16[32,4,64,64,3]{2,1,4,3,0} get-tuple-element(wide_p), index=2
one = s32[] constant(1)
inc = s32[] add(i, one)
- %fusion.67830 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone
+ %fusion.67830 = bf16[64,4,64,3]{0,1,3,2} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone
ROOT out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(inc, p0, p1)
}
@@ -291,7 +299,7 @@
p1 = bf16[8,128] parameter(1)
init = s32[] constant(0)
while.input = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(init, p1, p0)
- while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body
+ while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body
while_use = bf16[32,4,64,64,3] get-tuple-element(while.out), index=2
ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
}
@@ -301,11 +309,17 @@
auto original = module->Clone();
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
EXPECT_TRUE(unstacked);
+ // Check for the creation of slice instructions.
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice),
+ 32);
+  // Check that all the fusions are removed.
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+ 0);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt));
}
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPattern) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -358,14 +372,14 @@
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt, false));
}
// Instead of slicing the entire shape, this test slices only even elements from
// the first parameter.
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDynamicIndex) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithDynamicIndex) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] {
@@ -423,7 +437,7 @@
std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserMultipleIndex) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithMultipleIndex) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice.1 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
@@ -497,12 +511,12 @@
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions. For each unstacked input, we
// create 4 slices, 8 in total.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 8);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDiffereOperandsOrder) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithDifferentOperandsOrder) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -555,12 +569,12 @@
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackLoopMultipleNestedFusionUsersSameUnstackingComps) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithSameUnstackingComps) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice.1 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -631,12 +645,12 @@
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt, false));
}
-TEST_F(UnstackerTest, NotUnstackLoopMultipleDifferentUnstackingComps) {
+TEST_F(UnstackerTest, NotUnstackNestedDSFusionPatternWithDifferentUnstackingComps) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice.1 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] {
@@ -691,7 +705,86 @@
EXPECT_FALSE(unstacked);
}
-TEST_F(UnstackerTest, UnstackMultipleLoops) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternSingleNestedLoop) {
+ std::string hlo_string = R"(
+ HloModule SimpleLoop
+ %fused_computation.slice (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
+ %param_0.51117 = s8[4,128,128] parameter(0)
+ p1 = s32[] parameter(1)
+ %constant.85694 = s32[] constant(0)
+ %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[4,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128}
+ ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040)
+ }
+
+ %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] {
+ %param_0.34523 = bf16[8,128] parameter(0)
+ %param_1.30691 = s8[4,128,128] parameter(1)
+ p2 = s32[] parameter(2)
+ %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice
+ ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf
+ }
+
+ %while.body.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
+ wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+ i = s32[] get-tuple-element(wide_p), index=0
+ inner_param_0 = bf16[8,128] get-tuple-element(wide_p), index=1
+ inner_param_1 = s8[4,128,128] get-tuple-element(wide_p), index=2
+ one = s32[] constant(1)
+ inc = s32[] add(i, one)
+ fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner
+ ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(inc, fusion.conv, inner_param_1)
+ }
+
+ %while.cond.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
+ wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+ i = s32[] get-tuple-element(wide_p), index=0
+ %constant.12857 = s32[] constant(4)
+ ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT
+ }
+
+ %while.body (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
+ wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+ i = s32[] get-tuple-element(wide_p), index=0
+ param0 = bf16[8,128] get-tuple-element(wide_p), index=1
+ param1 = s8[4,128,128] get-tuple-element(wide_p), index=2
+ one = s32[] constant(2)
+ zero = s32[] constant(0)
+ mult = s32[] multiply(i, one)
+ inner.in = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1)
+ inner.out = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in), condition=%while.cond.inner, body=%while.body.inner
+ fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out), index=1
+ ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(mult, fusion.conv.inner, param1)
+ }
+
+ %while.cond (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
+ wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+ i = s32[] get-tuple-element(wide_p), index=0
+ %constant.12857 = s32[] constant(20)
+ add = s32[] add(%constant.12857, %constant.12857)
+ ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, add), direction=LT
+ }
+
+ ENTRY main {
+ weight = s8[4,128,128] parameter(0)
+ p1 = bf16[8,128] parameter(1)
+ init = s32[] constant(1)
+ while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight)
+ while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond , body=%while.body
+ ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ auto original = module->Clone();
+ TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
+ EXPECT_TRUE(unstacked);
+ // Check for the creation of slice instructions.
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 4);
+ EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
+ std::nullopt, false));
+}
+
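The count check above relies on the test file's GetInstrCountWithOpcodeInEntry helper. A minimal sketch of what such a helper looks like (illustrative only; the real helper lives elsewhere in hlo_unstacker_test.cc and may differ in detail):

#include <cstdint>
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"

// Illustrative sketch: count entry-computation instructions with `opcode`.
int64_t GetInstrCountWithOpcodeInEntry(xla::HloModule* module,
                                       xla::HloOpcode opcode) {
  int64_t count = 0;
  for (const xla::HloInstruction* instr :
       module->entry_computation()->instructions()) {
    if (instr->opcode() == opcode) {
      ++count;
    }
  }
  return count;
}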
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternTwoNestedLoops) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice1 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
@@ -827,91 +920,12 @@
EXPECT_TRUE(unstacked);
// Check for the creation of slice instructions. For each loop there is one
// unstacked input that creates 4 slices, in total 8 slices for two loops.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 8);
+ EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8);
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
std::nullopt, false));
}
-TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) {
- std::string hlo_string = R"(
- HloModule SimpleLoop
- %fused_computation.slice (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
- %param_0.51117 = s8[4,128,128] parameter(0)
- p1 = s32[] parameter(1)
- %constant.85694 = s32[] constant(0)
- %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[4,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128}
- ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040)
- }
-
- %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] {
- %param_0.34523 = bf16[8,128] parameter(0)
- %param_1.30691 = s8[4,128,128] parameter(1)
- p2 = s32[] parameter(2)
- %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice
- ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf
- }
-
- %while.body.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
- wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
- i = s32[] get-tuple-element(wide_p), index=0
- inner_param_0 = bf16[8,128] get-tuple-element(wide_p), index=1
- inner_param_1 = s8[4,128,128] get-tuple-element(wide_p), index=2
- one = s32[] constant(1)
- inc = s32[] add(i, one)
- fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner
- ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(inc, fusion.conv, inner_param_1)
- }
-
- %while.cond.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
- wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
- i = s32[] get-tuple-element(wide_p), index=0
- %constant.12857 = s32[] constant(4)
- ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT
- }
-
- %while.body (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
- wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
- i = s32[] get-tuple-element(wide_p), index=0
- param0 = bf16[8,128] get-tuple-element(wide_p), index=1
- param1 = s8[4,128,128] get-tuple-element(wide_p), index=2
- one = s32[] constant(2)
- zero = s32[] constant(0)
- mult = s32[] multiply(i, one)
- inner.in = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1)
- inner.out = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in), condition=%while.cond.inner, body=%while.body.inner
- fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out), index=1
- ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(mult, fusion.conv.inner, param1)
- }
-
- %while.cond (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
- wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
- i = s32[] get-tuple-element(wide_p), index=0
- %constant.12857 = s32[] constant(20)
- add = s32[] add(%constant.12857, %constant.12857)
- ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, add), direction=LT
- }
-
- ENTRY main {
- weight = s8[4,128,128] parameter(0)
- p1 = bf16[8,128] parameter(1)
- init = s32[] constant(1)
- while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight)
- while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond , body=%while.body
- ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
- }
- )";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseAndReturnVerifiedModule(hlo_string));
- auto original = module->Clone();
- TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
- EXPECT_TRUE(unstacked);
- // Check for the creation of slice instructions.
- EXPECT_EQ(GetSliceCountInEntry(module.get()), 4);
- EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
- std::nullopt, false));
-}
-
-TEST_F(UnstackerTest, UnstackSingleLoopOnlyWithDSAndDUS) {
+TEST_F(UnstackerTest, UnstackDSAndDUSPattern) {
std::string hlo_string = R"(
HloModule SimpleLoop
%fused_computation.slice (param_0.51117: s32[4,3], offset: s32[]) -> s32[3] {
@@ -975,7 +989,7 @@
// Unstacking outer loop at index 1 forces to unstacked inner while at index 1
// as well. This is because the output of the outer loop at index 1 is aliased
// to the output of the inner while at index 1.
-TEST_F(UnstackerTest, UnstackNestedLoopWithDSAndDUS) {
+TEST_F(UnstackerTest, UnstackDSAndDUSPatternNestedLoop) {
std::string hlo_string = R"(
HloModule SimpleLoop
@@ -1059,7 +1073,7 @@
// Unstacking the first loop at index 1 forces to unstack the second loop at
// index 1 as well.
-TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUS) {
+TEST_F(UnstackerTest, UnstackDSAndDUSPatternLoopFeedingLoop) {
std::string hlo_string = R"(
HloModule SimpleLoop
@@ -1076,45 +1090,43 @@
%param_0.51117 = bf16[4,1,8,257,128] parameter(0)
offset = s32[] parameter(1)
zero = s32[] constant(0)
- %dynamic-slice.22040 = bf16[1,1,8,257,128]
- dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero,
- zero, zero), dynamic_slice_sizes={1,1,8,257,128} ROOT %bitcast.31250 =
- bf16[1,8,257,128] bitcast(%dynamic-slice.22040)
+ %dynamic-slice.22040 = bf16[1,1,8,257,128] dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128}
+ ROOT %bitcast.31250 = bf16[1,8,257,128] bitcast(%dynamic-slice.22040)
}
first.body {
loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0)
- get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0
- get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0
+ get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1
constant = bf16[1,8,257,128] constant({...})
sliced = bf16[1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1), kind=kLoop, calls=%fused_computation.slice
tmp = bf16[1,8,257,128] add(sliced, sliced)
one = s32[] constant(1)
- idx = s32[] add(get-tuple-element.1, one)
+ idx = s32[] add(get-tuple-element.1, one)
ROOT out = tuple(idx, get-tuple-element.2)
}
first.condition {
- loop_var.1 = (s32[], bf16[4,1,8,257,128])
- parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),
- index=0 constant.2 = s32[] constant(4) ROOT less-than = pred[]
- compare(get-tuple-element.1, constant.2), direction=LT
+ loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.2 = s32[] constant(4)
+ ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT
}
next.body {
loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0)
- get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0
- get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0
+ get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1
constant = bf16[1,8,257,128] constant({...})
- update.sliced = bf16[4,1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1, constant), kind=kLoop, calls=%fused_computation.update.slice
+ update.sliced = bf16[4,1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1, constant), kind=kLoop, calls=%fused_computation.update.slice
one = s32[] constant(1)
- idx = s32[] add(get-tuple-element.1, one)
+ idx = s32[] add(get-tuple-element.1, one)
ROOT out = tuple(idx, update.sliced)
}
next.condition {
- loop_var.1 = (s32[], bf16[4,1,8,257,128])
- parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),
- index=0 constant.2 = s32[] constant(4) ROOT less-than = pred[]
- compare(get-tuple-element.1, constant.2), direction=LT
+ loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.2 = s32[] constant(4)
+ ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT
}
ENTRY SimpleLoop {
@@ -1138,7 +1150,7 @@
EXPECT_TRUE(unstacked);
}
-TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUSFusionWithPad) {
+TEST_F(UnstackerTest, UnstackDUSFusionWithPadPatternLoopFeedingLoop) {
std::string hlo_string = R"(
HloModule SimpleLoop
fused_computation.75.clone {
@@ -1213,7 +1225,7 @@
EXPECT_TRUE(unstacked);
}
-TEST_F(UnstackerTest, UnstackSingleLoopWithDSFusionWithAdd) {
+TEST_F(UnstackerTest, UnstackDUSFusionWithAddPattern) {
std::string hlo_string = R"(
HloModule SimpleLoop
diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
index 67259f8..06ae990 100644
--- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
+++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
@@ -1757,6 +1757,7 @@
[&](const ShapeIndex& index,
const HloValueSemantics* semantics) -> absl::Status {
std::vector<HloValueSemantics> semantics_vector;
+ semantics_vector.reserve(semantics_tree_vec.size());
for (size_t i = 0; i < semantics_tree_vec.size(); ++i) {
semantics_vector.push_back(
*(semantics_tree_vec[i].find(index)->second));
diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc
index ebf03de..3a8d958 100644
--- a/third_party/xla/xla/service/hlo_verifier.cc
+++ b/third_party/xla/xla/service/hlo_verifier.cc
@@ -35,6 +35,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
@@ -1322,6 +1323,34 @@
for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) {
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
}
+ if (call->is_composite()) {
+ TF_RET_CHECK(call->has_frontend_attributes())
+ << "A composite call op must have frontend attributes";
+ auto map = call->frontend_attributes().map();
+ if (auto name = map.find("composite.name");
+ name == map.end() || name->second.empty()) {
+ return InvalidArgument(
+ "A composite call op must have frontend attributes with key "
+ "composite.name whose value is non-empty");
+ }
+ if (auto attributes = map.find("composite.attributes");
+ attributes != map.end() && attributes->second.empty()) {
+ return InvalidArgument(
+ "A composite call op must have frontend attributes with key "
+ "composite.attributes whose value is default: {} or non-empty");
+ }
+ if (auto version_str = map.find("composite.version");
+ version_str != map.end()) {
+ int64_t version = 0;
+ if (!absl::SimpleAtoi(version_str->second, &version) || version < 0) {
+ return InvalidArgument(
+ "A composite call op must have frontend attributes with a "
+ "composite.version whose value is a non-negative integer but got: "
+ "%s",
+ version_str->second);
+ }
+ }
+ }
// The shape of kCall should match the shape of the computation it calls.
return CheckShape(call, call->to_apply()->root_instruction()->shape());
}
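Outside the verifier, the three composite-call rules added above (non-empty composite.name; composite.attributes either absent, {}, or non-empty; and a non-negative integer composite.version) can be mirrored by a small standalone check. The helper below only illustrates that logic over a plain attribute map; it is not code from this patch:

#include <cstdint>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/numbers.h"

// Illustrative mirror of the composite-call frontend-attribute checks.
absl::Status CheckCompositeAttributes(
    const absl::flat_hash_map<std::string, std::string>& map) {
  auto name = map.find("composite.name");
  if (name == map.end() || name->second.empty()) {
    return absl::InvalidArgumentError("composite.name must be non-empty");
  }
  if (auto attrs = map.find("composite.attributes");
      attrs != map.end() && attrs->second.empty()) {
    return absl::InvalidArgumentError(
        "composite.attributes must be {} or non-empty");
  }
  if (auto version = map.find("composite.version"); version != map.end()) {
    int64_t v = 0;
    if (!absl::SimpleAtoi(version->second, &v) || v < 0) {
      return absl::InvalidArgumentError(
          "composite.version must be a non-negative integer");
    }
  }
  return absl::OkStatus();
}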
@@ -1920,6 +1949,26 @@
}
return ShapesSame(instruction->shape(), inferred_shape, equal);
}
+ case HloOpcode::kCopy: {
+ // Disallow host offloading copies which change FpPrecision.
+ if (opts_.IsLayoutSensitive()) {
+ if (instruction->shape().has_layout() &&
+ inferred_shape.has_layout()) {
+ int64_t instruction_memory_space =
+ instruction->shape().layout().memory_space();
+ int64_t operand_memory_space =
+ inferred_shape.layout().memory_space();
+ if (instruction_memory_space != operand_memory_space &&
+ (instruction_memory_space == Layout::kHostMemorySpace ||
+ operand_memory_space == Layout::kHostMemorySpace)) {
+          // This is a host->device copy or a device->host copy.
+ return Shape::Equal().IgnoreMemorySpaceInLayout()(
+ instruction->shape(), inferred_shape);
+ }
+ }
+ }
+ [[fallthrough]];
+ }
// We allow arbitrary layout and f32->bf16 transformations on all other
// instructions, although this may be made more strict pending discussion
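In other words, a layout-sensitive kCopy may cross the host/device boundary only when the operand and result shapes agree in everything except the memory space recorded in the layout. The predicate below restates that rule in isolation; it is an illustration, not part of the patch:

#include <cstdint>
#include "xla/layout.h"
#include "xla/shape.h"

// Illustrative predicate: a copy whose operand and result live in different
// memory spaces, one of them host memory, must be shape-equal except for the
// memory space in the layout.
bool IsValidHostOffloadCopy(const xla::Shape& result,
                            const xla::Shape& operand) {
  if (!result.has_layout() || !operand.has_layout()) {
    return false;
  }
  const int64_t result_space = result.layout().memory_space();
  const int64_t operand_space = operand.layout().memory_space();
  const bool crosses_host_boundary =
      result_space != operand_space &&
      (result_space == xla::Layout::kHostMemorySpace ||
       operand_space == xla::Layout::kHostMemorySpace);
  if (!crosses_host_boundary) {
    return true;  // Not an offload copy; the usual copy rules apply instead.
  }
  return xla::Shape::Equal().IgnoreMemorySpaceInLayout()(result, operand);
}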
@@ -2907,6 +2956,15 @@
}
}
+ if (instruction->has_to_apply() &&
+ instruction->to_apply()->execution_thread() !=
+ instruction->parent()->execution_thread()) {
+ return Internal(
+ "%s top_apply computation execution thread does not match (%s vs %s)",
+ instruction->name(), instruction->to_apply()->execution_thread(),
+ instruction->parent()->execution_thread());
+ }
+
return absl::OkStatus();
}
diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc
index 82f712a..b7055e9 100644
--- a/third_party/xla/xla/service/hlo_verifier_test.cc
+++ b/third_party/xla/xla/service/hlo_verifier_test.cc
@@ -83,6 +83,15 @@
LayoutAssignment::InstructionCanChangeLayout) {}
};
+class HloVerifierTestLayoutSensitiveAndAllowMixedPrecision
+ : public HloTestBase {
+ public:
+ HloVerifierTestLayoutSensitiveAndAllowMixedPrecision()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/true,
+ LayoutAssignment::InstructionCanChangeLayout) {}
+};
+
class HloVerifierTestLayoutFusion : public HloTestBase {
public:
HloVerifierTestLayoutFusion()
@@ -216,8 +225,164 @@
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
- HasSubstr("expects parent computation thread name same as called "
- "computation's thread name"));
+ HasSubstr("mycall top_apply computation execution thread does "
+ "not match (parallel_thread vs main)"));
+}
+
+TEST_F(HloVerifierTest, CompositeCall) {
+ constexpr absl::string_view hlo = R"(
+ HloModule Module
+
+ add_n {
+ x = f32[] parameter(0)
+ constant = f32[] constant(2)
+ ROOT z = f32[] add(f32[] x, f32[] constant)
+ }
+
+ ENTRY entry {
+ constant = f32[] constant(42)
+ ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar",composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.version="1"}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+ auto status = verifier().Run(module.get()).status();
+ EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallMissingFrontendAttributes) {
+ constexpr absl::string_view hlo = R"(
+ HloModule Module
+
+ add_n {
+ x = f32[] parameter(0)
+ constant = f32[] constant(2)
+ ROOT z = f32[] add(f32[] x, f32[] constant)
+ }
+
+ ENTRY entry {
+ constant = f32[] constant(42)
+ ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.message(),
+ HasSubstr("A composite call op must have frontend attributes"));
+}
+
+TEST_F(HloVerifierTest, CompositeCallOptionalAttributesAndVersion) {
+ constexpr absl::string_view hlo = R"(
+ HloModule Module
+
+ add_n {
+ x = f32[] parameter(0)
+ constant = f32[] constant(2)
+ ROOT z = f32[] add(f32[] x, f32[] constant)
+ }
+
+ ENTRY entry {
+ constant = f32[] constant(42)
+ ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar"}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+ auto status = verifier().Run(module.get()).status();
+ EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallOptionalAttributes) {
+ constexpr absl::string_view hlo = R"(
+ HloModule Module
+
+ add_n {
+ x = f32[] parameter(0)
+ constant = f32[] constant(2)
+ ROOT z = f32[] add(f32[] x, f32[] constant)
+ }
+
+ ENTRY entry {
+ constant = f32[] constant(42)
+ ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar",composite.version="1"}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+ auto status = verifier().Run(module.get()).status();
+ EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallMissingName) {
+ constexpr absl::string_view hlo = R"(
+ HloModule Module
+
+ add_n {
+ x = f32[] parameter(0)
+ constant = f32[] constant(2)
+ ROOT z = f32[] add(f32[] x, f32[] constant)
+ }
+
+ ENTRY entry {
+ constant = f32[] constant(42)
+ ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.version="1"}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.message(),
+ HasSubstr("A composite call op must have frontend attributes "
+ "with key composite.name whose value is non-empty"));
+}
+
+TEST_F(HloVerifierTest, CompositeCallOptionalVersion) {
+ constexpr absl::string_view hlo = R"(
+ HloModule Module
+
+ add_n {
+ x = f32[] parameter(0)
+ constant = f32[] constant(2)
+ ROOT z = f32[] add(f32[] x, f32[] constant)
+ }
+
+ ENTRY entry {
+ constant = f32[] constant(42)
+ ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar"}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+ auto status = verifier().Run(module.get()).status();
+ EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallNonNegativeVersion) {
+ constexpr absl::string_view hlo = R"(
+ HloModule Module
+
+ add_n {
+ x = f32[] parameter(0)
+ constant = f32[] constant(2)
+ ROOT z = f32[] add(f32[] x, f32[] constant)
+ }
+
+ ENTRY entry {
+ constant = f32[] constant(42)
+ ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="-1"}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(
+ status.message(),
+ HasSubstr("A composite call op must have frontend attributes with a "
+ "composite.version whose value is a non-negative integer"));
}
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
@@ -2000,10 +2165,10 @@
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
- EXPECT_THAT(verifier().Run(module.get()).status().message(),
- HasSubstr("Nested computations expects same computation's thread "
- "name: parallel_thread vs main, in called computation "
- "`add` vs caller computation `fused_computation`"));
+ EXPECT_THAT(
+ verifier().Run(module.get()).status().message(),
+ HasSubstr("crs0 top_apply computation execution thread does not match "
+ "(parallel_thread vs main)"));
}
TEST_F(HloVerifierTest, AllReduceVerifier) {
@@ -2639,8 +2804,8 @@
.status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
- HasSubstr("expects parent computation thread name same as called "
- "computation's thread name"));
+ HasSubstr("custom top_apply computation execution thread does "
+ "not match (parallel_thread vs main)"));
}
TEST_F(HloVerifierTest, CheckWhileThread) {
@@ -3133,6 +3298,49 @@
"memory space from device to host"));
}
+TEST_F(HloVerifierTestLayoutSensitiveAndAllowMixedPrecision,
+ HostOffloadingCopyCannotChangeType) {
+ const char* const hlo_string = R"(
+HloModule m
+
+ENTRY main {
+ param = f32[1024,1024]{1,0:T(8,128)S(5)} parameter(0)
+ copy = bf16[1024,1024]{1,0:T(8,128)} copy(param)
+ ROOT dot = f32[1024,1024]{1,0:T(8,128)} dot(copy, copy), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.message(),
+ HasSubstr("Expected instruction to have shape equal to "
+ "f32[1024,1024]{1,0:T(8,128)S(5)}, actual shape is "
+ "bf16[1024,1024]{1,0:T(8,128)}"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitiveAndAllowMixedPrecision,
+ HostOffloadingCopyCannotChangeLayout) {
+ const char* const hlo_string = R"(
+HloModule m
+
+ENTRY main {
+ param = f32[1024,1024]{1,0:T(8,128)S(5)} parameter(0)
+ ROOT copy = f32[1024,1024]{0,1:T(8,128)} copy(param)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.message(),
+ HasSubstr("Expected instruction to have shape equal to "
+ "f32[1024,1024]{1,0:T(8,128)S(5)}, actual shape is "
+ "f32[1024,1024]{0,1:T(8,128)}"));
+}
+
TEST_F(HloVerifierTestLayoutSensitive,
MismatchedMinorToMajorSizeAndDimensionSize) {
const char* const hlo_string = R"(
diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc
index 5f7757b..dc59e5c 100644
--- a/third_party/xla/xla/service/latency_hiding_scheduler.cc
+++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc
@@ -486,6 +486,15 @@
});
}
+bool AsyncTracker::OccupiesSelectiveResource(const HloGraphNode* node) const {
+ return absl::c_any_of(
+ node->GetResources(), [&](const ResourcePair& resource) {
+ return resource.second == ResourceUsageType::kResourceOccupy &&
+ GetResourceHazardType(resource.first) ==
+ ResourceHazardType::kSelective;
+ });
+}
+
BufferInfoTracker::BufferInfoTracker(
const HloModule* module, const HloAliasAnalysis* alias_analysis,
const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes) {
@@ -731,6 +740,25 @@
namespace {
+// Finds the number of hops to the closest selective resource overlap in the
+// ready set that the provided node can be scheduled in between.
+int64_t GetNumHopsToClosestSelectiveOverlap(
+ const DefaultSchedulerCore::ReadyQueueSet& ready_set,
+ const HloGraphNode* node) {
+ int64_t num_hops_to_closest_selective_resource_occupier =
+ std::numeric_limits<int64_t>::max();
+ for (const HloGraphNode* n : ready_set) {
+ // Skip the node itself.
+ if (n == node) {
+ continue;
+ }
+ num_hops_to_closest_selective_resource_occupier =
+ std::min(num_hops_to_closest_selective_resource_occupier,
+ n->GetNumHopsToClosestSelectiveResourceOccupier());
+ }
+ return num_hops_to_closest_selective_resource_occupier;
+}
+
// Comparator for the ready set. This class represents the priority policies
// for the nodes in the ready set. The policy can be whatever is appropriate to
// reduce the execution time of the graph or achieve interesting properties
@@ -1002,6 +1030,31 @@
return *value;
}
}
+    // If no selective overlaps are currently open but overlaps will open in
+    // the near future, hold off scheduling instructions that are valuable
+    // for selective overlaps.
+ if (sched_state_.config.enable_selective_resources &&
+ sched_state_.selective_resource_releasers.empty()) {
+ int64_t distance_to_selective_overlap_for_a =
+ GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, a.node);
+ int64_t distance_to_selective_overlap_for_b =
+ GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, b.node);
+    // If a is valuable for selective overlap and there is a selective overlap
+    // in the near future that a can be scheduled inside, hold off scheduling
+    // a and schedule b instead. The same logic applies in reverse.
+ int64_t max_distance =
+ sched_state_.config.max_hops_to_closest_selective_overlap;
+ if (auto value = DefaultSchedulerCore::ChooseBestCandidate(
+ (a.node->GetValuableForSelectiveOverlap() &&
+ distance_to_selective_overlap_for_a <= max_distance),
+ b,
+ (b.node->GetValuableForSelectiveOverlap() &&
+ distance_to_selective_overlap_for_b <= max_distance),
+ a, "kNotValuableForSelectiveOverlap")) {
+ return *value;
+ }
+ }
+
if (sched_state_.config.aggressive_scheduling_policies) {
// Favor nodes that unlock other nodes to be scheduled if possible.
// This makes us more flexible in what we can use in scheduling.
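The effect of that candidate rule can be stated compactly: hold back a candidate that is both valuable for selective overlap and within max_hops_to_closest_selective_overlap hops of an upcoming selective overlap, and schedule the other candidate instead. A self-contained sketch over plain values (hypothetical names, not scheduler code):

#include <cstdint>

// Illustrative restatement of the selective-overlap tie-break.
struct OverlapCandidate {
  bool valuable_for_selective_overlap;
  int64_t hops_to_closest_selective_overlap;
};

// Returns +1 to pick `a`, -1 to pick `b`, 0 if this rule does not decide.
int PickBySelectiveOverlap(const OverlapCandidate& a, const OverlapCandidate& b,
                           int64_t max_hops) {
  const bool hold_a = a.valuable_for_selective_overlap &&
                      a.hops_to_closest_selective_overlap <= max_hops;
  const bool hold_b = b.valuable_for_selective_overlap &&
                      b.hops_to_closest_selective_overlap <= max_hops;
  if (hold_a == hold_b) {
    return 0;  // This rule does not discriminate; defer to the next policy.
  }
  return hold_a ? -1 : +1;  // Keep the held candidate available for the overlap.
}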
@@ -1693,6 +1746,8 @@
new_node_it->second->GetResources());
new_node_it->second->releases_selective_resource_ =
async_tracker->ReleasesSelectiveResource(new_node_it->second.get());
+ new_node_it->second->occupies_selective_resource_ =
+ async_tracker->OccupiesSelectiveResource(new_node_it->second.get());
// Gather while instructions for subsequent send-done dependency checks.
if (instr->opcode() == HloOpcode::kWhile) {
while_instrs.push_back(instr);
@@ -1900,6 +1955,25 @@
while (!stack.empty()) {
auto* node = stack.back();
stack.pop_back();
+    // If a node occupies a selective resource, it is its own closest selective
+    // resource occupier and is 0 hops away. Otherwise, the number of hops to
+    // the closest selective resource occupier is the minimum over all
+    // predecessors, plus 1.
+ if (async_tracker->OccupiesSelectiveResource(node)) {
+ node->num_hops_to_closest_selective_resource_occupier_ = 0;
+ } else {
+ int64_t closest_predecessor_distance =
+ std::numeric_limits<int64_t>::max();
+ for (auto& pred : node->GetPredecessors()) {
+ closest_predecessor_distance = std::min(
+ closest_predecessor_distance,
+ pred.Target().num_hops_to_closest_selective_resource_occupier_);
+ }
+ if (closest_predecessor_distance != std::numeric_limits<int64_t>::max()) {
+ node->num_hops_to_closest_selective_resource_occupier_ =
+ closest_predecessor_distance + 1;
+ }
+ }
if (async_tracker->IsSupportedAsyncDone(node->GetInstr())) {
for (auto& pred : node->GetPredecessors()) {
node->SetAsyncDepth(
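The propagation itself is a simple forward pass over the graph: occupiers start at 0, every other node takes 1 plus the minimum over its predecessors, and unreachable nodes stay at INT64_MAX. A standalone sketch with hypothetical node types, assuming a topological visit order as in the scheduler's traversal:

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Illustrative propagation of hops to the closest selective-resource occupier.
struct GraphNode {
  bool occupies_selective_resource = false;
  std::vector<int> predecessors;  // Indices into the node vector.
  int64_t hops_to_occupier = std::numeric_limits<int64_t>::max();
};

void PropagateHopsToOccupier(std::vector<GraphNode>& nodes) {
  // Assumes predecessors appear before their successors in `nodes`.
  for (GraphNode& node : nodes) {
    if (node.occupies_selective_resource) {
      node.hops_to_occupier = 0;
      continue;
    }
    int64_t closest = std::numeric_limits<int64_t>::max();
    for (int pred : node.predecessors) {
      closest = std::min(closest, nodes[pred].hops_to_occupier);
    }
    if (closest != std::numeric_limits<int64_t>::max()) {
      node.hops_to_occupier = closest + 1;
    }
  }
}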
diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h
index ebe1cf0..b0d8a8d 100644
--- a/third_party/xla/xla/service/latency_hiding_scheduler.h
+++ b/third_party/xla/xla/service/latency_hiding_scheduler.h
@@ -137,6 +137,7 @@
bool resource_serializing = false;
bool depth_based_memory_pressure_reduction = false;
bool enable_selective_resources = false;
+ int64_t max_hops_to_closest_selective_overlap = 0;
int64_t rerun = 0;
};
@@ -284,6 +285,9 @@
// Returns whether the provided node releases a selective resource.
bool ReleasesSelectiveResource(const HloGraphNode* node) const;
+ // Returns whether the provided node occupies a selective resource.
+ bool OccupiesSelectiveResource(const HloGraphNode* node) const;
+
inline CanonicalAsyncOp GetCanonicalAsyncOp(const HloInstruction& hlo) const {
return get_canonical_async_op_(hlo);
}
@@ -386,6 +390,17 @@
bool ReleasesSelectiveResource() const {
return releases_selective_resource_;
}
+ bool OccupiesSelectiveResource() const {
+ return occupies_selective_resource_;
+ }
+ int64_t GetNumHopsToClosestSelectiveResourceOccupier() const {
+ return num_hops_to_closest_selective_resource_occupier_;
+ }
+ void SetNumHopsToClosestSelectiveResourceOccupier(
+ int64_t num_hops_to_closest_selective_resource_occupier) {
+ num_hops_to_closest_selective_resource_occupier_ =
+ num_hops_to_closest_selective_resource_occupier;
+ }
ResourcesVector GetResources() const { return resources_; }
bool DoesOccupyAnyResource() const {
@@ -525,6 +540,11 @@
bool valuable_for_selective_overlap_ = true;
// Whether this node releases a selective resource.
bool releases_selective_resource_ = false;
+ // Whether this node occupies a selective resource.
+ bool occupies_selective_resource_ = false;
+  // Number of hops to the closest selective resource occupier.
+ int64_t num_hops_to_closest_selective_resource_occupier_ =
+ std::numeric_limits<int64_t>::max();
};
// Schedule graph that can be used to drive scheduling
@@ -920,7 +940,6 @@
virtual absl::StatusOr<HloGraphNode*> FindAndExtractBestNodeAvailable(
SchedulingState& sched_state,
DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node);
- bool DoesNodeReleaseSelectiveResource(const HloGraphNode* node) const;
void DumpLatencyHidingSchedule(
const HloComputation* computation, const HloScheduleGraph& schedule_graph,
const std::vector<HloInstruction*>& instructions,
diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc
index 29a4f4b..8c9c290 100644
--- a/third_party/xla/xla/service/llvm_ir/ir_array.cc
+++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc
@@ -527,6 +527,7 @@
if (!index.LinearValidOnShape(shape_)) {
// Create a valid linear index.
std::vector<int64_t> dimensions;
+ dimensions.reserve(shape_.rank());
for (int64_t i = 0; i < shape_.rank(); ++i) {
dimensions.push_back(shape_.dimensions(i));
}
diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h
index 9ec78b0..691f93f 100644
--- a/third_party/xla/xla/service/llvm_ir/ir_array.h
+++ b/third_party/xla/xla/service/llvm_ir/ir_array.h
@@ -250,9 +250,9 @@
IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape);
// Default implementations of copying and moving.
- IrArray(IrArray&& other) = default;
+ IrArray(IrArray&& other) noexcept = default;
IrArray(const IrArray& other) = default;
- IrArray& operator=(IrArray&& other) = default;
+ IrArray& operator=(IrArray&& other) noexcept = default;
IrArray& operator=(const IrArray& other) = default;
llvm::Value* GetBasePointer() const { return base_ptr_; }
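Marking the defaulted moves noexcept matters mostly for containers: std::vector only relocates elements by move during reallocation when the move constructor cannot throw (via std::move_if_noexcept). A quick illustrative compile-time check, not part of the patch:

#include <type_traits>
#include "xla/service/llvm_ir/ir_array.h"

static_assert(
    std::is_nothrow_move_constructible<xla::llvm_ir::IrArray>::value,
    "IrArray should be nothrow-move-constructible so vectors move on resize");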
diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc
index 399c335..0ed7bac 100644
--- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc
+++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc
@@ -715,17 +715,6 @@
return result;
}
-static absl::Status CreateAndWriteStringToFile(
- const std::string& directory_name, const std::string& file_name,
- const std::string& text) {
- std::unique_ptr<tsl::WritableFile> f;
- TF_RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(directory_name));
- TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(file_name, &f));
- TF_RETURN_IF_ERROR(f->Append(text));
- TF_RETURN_IF_ERROR(f->Close());
- return absl::OkStatus();
-}
-
void DumpIrIfEnabled(const HloModule& hlo_module,
const llvm::Module& llvm_module, bool optimized,
absl::string_view filename_suffix) {
diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
index 4611453..6737136 100644
--- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
@@ -1493,7 +1493,8 @@
continue;
}
- if (interval.size > available_heap_size()) {
+ if (!options_.enable_window_prefetch &&
+ interval.size > available_heap_size()) {
VLOG(3) << "Skip " << interval.buffer->ToShortString()
<< " because the buffer is larger than the heap size.";
continue;
@@ -2152,6 +2153,12 @@
options_.alternate_memory_space;
VLOG(4) << "require_no_copy_alternate_mem_allocation = "
<< require_no_copy_alternate_mem_allocation;
+ if (require_no_copy_alternate_mem_allocation &&
+ allocation_value.size() > available_heap_size()) {
+ VLOG(3) << "Skip " << allocation_value.value()->ToShortString()
+ << " because the buffer is larger than the heap size.";
+ continue;
+ }
if (!options_.is_position_allowed_in_alternate_mem_fn(
allocation_value.defining_position())) {
if (require_no_copy_alternate_mem_allocation) {
@@ -3018,8 +3025,12 @@
const AllocationSequence& allocations, int64_t time) {
for (auto allocation_it = allocations.rbegin();
allocation_it != allocations.rend(); ++allocation_it) {
+ // The use case of GetLiveAllocationAt is to find the allocation that
+    // corresponds to the full buffer. Window prefetched allocations allocate
+    // only partial buffers, so we want to skip them.
if ((*allocation_it)->start_time() <= time &&
- (*allocation_it)->end_time() >= time) {
+ (*allocation_it)->end_time() >= time &&
+ !(*allocation_it)->is_window_prefetched_allocation()) {
return allocation_it->get();
}
}
@@ -4197,6 +4208,11 @@
<< "Not trying to prefetch because use requires buffer in default mem.";
(*prev_allocation_in_default_mem_it)->set_end_time(request.end_time);
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
+
+ // If the buffer is placed in default memory, we can also try window
+ // prefetching it, which will try to prefetch only a window worth of data to
+ // alternate memory.
+ WindowPrefetch(request, **prev_allocation_in_default_mem_it);
return Result::kSuccess;
}
@@ -4286,9 +4302,28 @@
// default memory.
(*prev_allocation_in_default_mem_it)->set_end_time(request.end_time);
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
+
+ // If the buffer is placed in default memory, we can try window prefetching
+ // it, which will try to prefetch only a window worth of data to alternate
+ // memory.
+ WindowPrefetch(request, **prev_allocation_in_default_mem_it);
return allocation_result;
}
+void MsaAlgorithm::AddAsyncCopyForWindowPrefetch(
+ Allocation& prev_allocation, HloUse use, const Chunk& chunk,
+ int64_t exclusive_start_time, int64_t inclusive_end_time,
+ AllocationSequence* allocations, AliasedOffset* aliased_offset,
+ float resource, const WindowPrefetchedAllocation::Options& options) {
+ allocations->push_back(std::make_unique<WindowPrefetchedAllocation>(
+ prev_allocation, use, chunk, exclusive_start_time, inclusive_end_time,
+ options));
+
+ RegisterAsyncCopy(MemorySpace::kAlternate, exclusive_start_time,
+ inclusive_end_time, allocations, aliased_offset, resource,
+ /*cross_program_prefetch_index=*/std::nullopt);
+}
+
void MsaAlgorithm::AddAsyncCopy(
Allocation& prev_allocation, MemorySpace memory_space,
std::optional<Chunk> chunk, int64_t exclusive_start_time, int64_t end_time,
@@ -4306,6 +4341,16 @@
prev_allocation, memory_space, chunk, exclusive_start_time,
copy_done_schedule_before_time, end_time, cross_program_prefetch_index));
+ RegisterAsyncCopy(memory_space, exclusive_start_time,
+ copy_done_schedule_before_time, allocations, aliased_offset,
+ resource, cross_program_prefetch_index);
+}
+
+void MsaAlgorithm::RegisterAsyncCopy(
+ MemorySpace memory_space, int64_t exclusive_start_time,
+ int64_t copy_done_schedule_before_time, AllocationSequence* allocations,
+ AliasedOffset* aliased_offset, float resource,
+ std::optional<int> cross_program_prefetch_index) {
// Register the additional async copy with the interval tree to keep track of
// the limit at any given time.
pending_async_copies_.push_back({exclusive_start_time,
@@ -4445,7 +4490,8 @@
prev_allocation =
request.allocation_value->allocation_sequence()->back().get();
can_eliminate_copy =
- (prev_allocation->memory_space() == MemorySpace::kAlternate);
+ (prev_allocation->memory_space() == MemorySpace::kAlternate &&
+ !prev_allocation->is_window_prefetched_allocation());
}
if (!can_eliminate_copy) {
@@ -4718,9 +4764,41 @@
} // namespace
-MsaAlgorithm::Result MsaAlgorithm::Prefetch(
+MsaAlgorithm::Result MsaAlgorithm::WindowPrefetch(
const AllocationRequest& request,
Allocation& prev_allocation_in_default_mem) {
+ if (!options_.enable_window_prefetch) {
+ return Result::kSuccess;
+ }
+
+ const HloUse use = request.use->hlo_use;
+ VLOG(3) << "Considering window prefetch for use=" << use.ToString();
+
+ // Get the window prefetch details for this use.
+ WindowPrefetchDetail details =
+ options_.window_prefetch_detail_fn(use.instruction);
+ for (const WindowPrefetchDetail::WindowDetail& window : details.windows()) {
+ if (window.operand() != use.operand_number) {
+ continue;
+ }
+
+ WindowPrefetchedAllocation::Options options;
+ options.bytes = window.size();
+ options.uid = window.uid();
+ options.alternate_memory_space = options_.alternate_memory_space;
+ options.notify_operand_appended_fn = options_.notify_operand_appended_fn;
+ AllocationRequest window_prefetch_request = request;
+ window_prefetch_request.window_prefetch_options = &options;
+ window_prefetch_request.size = window.size();
+ const Shape shape = ShapeUtil::MakeShape(U8, {window.size()});
+ Prefetch(window_prefetch_request, prev_allocation_in_default_mem, &shape);
+ }
+ return Result::kSuccess;
+}
+
+MsaAlgorithm::Result MsaAlgorithm::Prefetch(
+ const AllocationRequest& request,
+ Allocation& prev_allocation_in_default_mem, const Shape* shape) {
// Try partially placing the buffer in the alternate space. The time that is
// overlapped will be used to asynchronously copy the buffer from the
// default memory to the alternate memory.
@@ -4743,6 +4821,10 @@
PrefetchContext context;
context.request = &request;
context.prev_allocation_in_default_mem = &prev_allocation_in_default_mem;
+  // If the request has window prefetch options, the request originated from
+  // window prefetching (WindowPrefetch()).
+ context.window_prefetch = (request.window_prefetch_options != nullptr);
+ CHECK(!context.window_prefetch || options_.enable_window_prefetch);
// Create a SliceProposal and WorkingIntervals.
SetupPrefetchWorkingIntervalsAndSliceProposal(context);
@@ -4757,8 +4839,13 @@
return check_result;
}
const HloUse& use = request.use->hlo_use;
- context.full_shape = &ShapeUtil::GetSubshape(
- use.instruction->operand(use.operand_number)->shape(), use.operand_index);
+ if (shape != nullptr) {
+ context.full_shape = shape;
+ } else {
+ context.full_shape = &ShapeUtil::GetSubshape(
+ use.instruction->operand(use.operand_number)->shape(),
+ use.operand_index);
+ }
// While uses might be allowed to have additional outstanding prefetches.
context.extra_async_copy_limit =
use.instruction->opcode() == HloOpcode::kWhile
@@ -4849,14 +4936,26 @@
<< context.unsliced_solution->prefetch_picker_debug_string;
AddToPendingChunks(context.unsliced_solution_intervals.full,
context.unsliced_solution->chunk_candidate);
- AddAsyncCopy(
- *context.prev_allocation_in_default_mem, MemorySpace::kAlternate,
- context.unsliced_solution->chunk_candidate,
- context.unsliced_solution_intervals.full.start - 1,
- context.request->end_time, context.prefetch_end_time,
- context.request->allocation_value->mutable_allocation_sequence(),
- context.request->preferred_offset,
- context.unsliced_solution->prefetch_resource);
+ if (context.window_prefetch) {
+ AddAsyncCopyForWindowPrefetch(
+ *context.prev_allocation_in_default_mem, request.use->hlo_use,
+ context.unsliced_solution->chunk_candidate,
+ context.unsliced_solution_intervals.full.start - 1,
+ context.prefetch_end_time,
+ context.request->allocation_value->mutable_allocation_sequence(),
+ context.request->preferred_offset,
+ context.unsliced_solution->prefetch_resource,
+ *context.request->window_prefetch_options);
+ } else {
+ AddAsyncCopy(
+ *context.prev_allocation_in_default_mem, MemorySpace::kAlternate,
+ context.unsliced_solution->chunk_candidate,
+ context.unsliced_solution_intervals.full.start - 1,
+ context.request->end_time, context.prefetch_end_time,
+ context.request->allocation_value->mutable_allocation_sequence(),
+ context.request->preferred_offset,
+ context.unsliced_solution->prefetch_resource);
+ }
request.allocation_value->allocation_sequence()->back()->AddUse(
request.use->hlo_use);
@@ -4929,7 +5028,9 @@
context.sliced_solution_intervals.full;
// Attempt to generate a slice proposal.
- GenerateSliceProposal(context);
+ if (!context.window_prefetch) {
+ GenerateSliceProposal(context);
+ }
// Setup the full SlicedBufferIntervals for the sliced and unsliced solutions.
// If there is no slice proposal, we will not try a sliced solution. In such a
diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h
index 5e2073b..52d0f0e 100644
--- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h
+++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h
@@ -514,6 +514,10 @@
absl::Span<const int64_t> all_use_times;
// See the comment for require_copy_allocation
HloInstruction* required_copy_allocation_for;
+ // Data structure that contains the options for making window prefetched
+ // allocations.
+ const WindowPrefetchedAllocation::Options* window_prefetch_options =
+ nullptr;
};
// This struct contains mandatory memory assignments at a given time. E.g., an
@@ -669,6 +673,11 @@
// Data structures used to compute and store the unsliced solution.
WorkingIntervals unsliced_solution_intervals;
std::optional<UnslicedSolution> unsliced_solution;
+
+  // Indicates whether this is a window prefetch. A window prefetch only
+  // prefetches a window worth of data to the alternate memory and does not
+  // use sliced prefetch.
+ bool window_prefetch = false;
};
// Result of an allocation, prefetch, eviction etc. request. The result is
@@ -860,7 +869,8 @@
// Try prefetching to alternate memory space.
Result Prefetch(const AllocationRequest& request,
- Allocation& prev_allocation_in_default_mem);
+ Allocation& prev_allocation_in_default_mem,
+ const Shape* shape = nullptr);
// Helper methods used to implement Prefetch().
//
@@ -888,6 +898,10 @@
std::string AlternateMemoryAllocationAttemptToString(
bool for_sliced_solution, const PrefetchContext& context) const;
+ // Try to prefetch a window worth of data into the alternate memory.
+ Result WindowPrefetch(const AllocationRequest& request,
+ Allocation& prev_allocation_in_default_mem);
+
// Find the best possible chunk candidate, where it has the longest possible
// availability if no preferred offset is given, or at the preferred_offset if
// it is given.
@@ -1014,6 +1028,14 @@
void ImportRepackedSlicedAllocation(RepackAllocationBlock& block);
absl::Status AreRepackedSlicesValid(const RepackAllocationBlock& block);
+ // Registers an asynchronous copy with asynchronous copy data structures to
+ // keep track of its state.
+ void RegisterAsyncCopy(MemorySpace memory_space, int64_t exclusive_start_time,
+ int64_t copy_done_schedule_before_time,
+ AllocationSequence* allocations,
+ AliasedOffset* aliased_offset, float resource,
+ std::optional<int> cross_program_prefetch_index);
+
// Adds an asynchronous copy to allocations.
void AddAsyncCopy(
Allocation& prev_allocation, MemorySpace memory_space,
@@ -1032,6 +1054,15 @@
const std::vector<SliceDecision>& slice_decisions_sorted_by_start_time,
int64_t prefetch_end_time, int64_t allocation_end_time);
+ // For window prefetching, adds a WindowPrefetchedAllocation to allocations.
+ // Also updates asynchronous copy data structures, prefetch_interval_tree_,
+ // and aliasing data structures.
+ void AddAsyncCopyForWindowPrefetch(
+ Allocation& prev_allocation, HloUse use, const Chunk& chunk,
+ int64_t exclusive_start_time, int64_t inclusive_end_time,
+ AllocationSequence* allocations, AliasedOffset* aliased_offset,
+ float resource, const WindowPrefetchedAllocation::Options& options);
+
// This method is used for committing the chunk candidate but adding it to
// pending_chunks_ so that we can "uncommit" them in case we need to roll back
// this allocation sequence.
diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc
index 50bec57..8699aab 100644
--- a/third_party/xla/xla/service/memory_space_assignment/allocation.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc
@@ -37,6 +37,8 @@
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
#include "xla/service/heap_simulator/allocation_block.h"
#include "xla/service/heap_simulator/heap_simulator.h"
#include "xla/service/hlo_value.h"
@@ -854,6 +856,133 @@
return casted_other != nullptr && (*this) == (*casted_other);
}
+WindowPrefetchedAllocation::WindowPrefetchedAllocation(
+ Allocation& prev_allocation, HloUse use, const HeapSimulator::Chunk& chunk,
+ int64_t prefetch_start_schedule_after_time,
+ int64_t prefetch_done_schedule_before_time, const Options& options)
+ : Allocation(
+ {nullptr, {}}, MemorySpace::kAlternate, chunk,
+ ExclusiveToInclusiveStartTime(prefetch_start_schedule_after_time),
+ InclusiveToExclusiveEndTime(prefetch_done_schedule_before_time),
+ /*is_scoped_allocation=*/false,
+ /*cross_program_prefetch_index=*/std::nullopt),
+ options_(options),
+ prev_allocation_(prev_allocation),
+ use_(use),
+ prefetch_start_schedule_after_(prefetch_start_schedule_after_time),
+ prefetch_done_schedule_before_(prefetch_done_schedule_before_time),
+ bytes_(chunk.size) {}
+
+HloPosition WindowPrefetchedAllocation::defining_position() const {
+ HloPosition defining_position = original_defining_position();
+ if (defining_position.instruction == nullptr) {
+ return prev_allocation_.defining_position();
+ }
+ return defining_position;
+}
+
+int64_t WindowPrefetchedAllocation::earliest_available_time() const {
+ return prefetch_done_schedule_before_;
+}
+
+absl::Status WindowPrefetchedAllocation::InsertWindowPrefetchInstruction(
+ HloInstruction* producing_instruction, HloInstruction* use_instruction,
+ HloComputation* computation) {
+ // Derive the shape for window buffer.
+ Shape shape = ShapeUtil::MakeShape(U8, {options_.bytes});
+ Layout layout = LayoutUtil::MakeLayout({0});
+ layout.set_memory_space(options_.alternate_memory_space);
+ *shape.mutable_layout() = layout;
+
+ // Insert a new parameter in the fused computation.
+ HloComputation* fused_computation =
+ use_instruction->fused_instructions_computation();
+ const int64_t num_parameters = fused_computation->num_parameters();
+ std::string name = absl::StrCat("window-buffer.", num_parameters);
+ HloInstruction* param = fused_computation->AddParameter(
+ HloInstruction::CreateParameter(num_parameters, shape, name));
+
+ // Insert async WindowPrefetch instructions as operands to the fusion.
+ HloInstruction* prefetch =
+ computation->AddInstruction(HloInstruction::CreateCustomCall(
+ shape, {producing_instruction}, "WindowPrefetch"));
+ TF_ASSIGN_OR_RETURN(prefetch_instruction_,
+ computation->CreateAsyncInstructions(prefetch, {}));
+ use_instruction->AppendOperand(prefetch_instruction_);
+
+  // Insert an instruction that consumes the added operands and forwards the
+  // original fusion output.
+ auto get_or_create_consumer =
+ [](HloComputation* computation) -> HloInstruction* {
+ HloInstruction* root = computation->root_instruction();
+ // If the root is already a WindowPrefetchBuffer, we don't need to create
+ // a new one.
+ if (root->IsCustomCall("WindowPrefetchBuffer")) {
+ return root;
+ }
+ HloInstruction* new_root =
+ computation->AddInstruction(HloInstruction::CreateCustomCall(
+ root->shape(), {root}, "WindowPrefetchBuffer"));
+ computation->set_root_instruction(new_root);
+ return new_root;
+ };
+ HloInstruction* consumer = get_or_create_consumer(fused_computation);
+ consumer->AppendOperand(param);
+ return absl::OkStatus();
+}
+
+absl::Status WindowPrefetchedAllocation::Process() {
+ HloInstruction* producing_instruction = AddGetTupleElements();
+ HloComputation* computation = producing_instruction->parent();
+ HloInstruction* use_instruction = use_.instruction;
+ CHECK_EQ(use_instruction->opcode(), HloOpcode::kFusion);
+
+ TF_RETURN_IF_ERROR(InsertWindowPrefetchInstruction(
+ producing_instruction, use_instruction, computation));
+
+ // Notify the backend that an operand has been appended as a window prefetch
+ // buffer.
+ int64_t use_operand = use_instruction->operand_count() - 1;
+ options_.notify_operand_appended_fn(use_instruction, options_.uid,
+ use_operand);
+
+ // Set the original defining position to the window prefetch instruction.
+ set_original_defining_position(HloPosition{prefetch_instruction_, {}});
+ AddUse(HloUse{use_instruction, use_operand});
+ return absl::OkStatus();
+}
+
+void WindowPrefetchedAllocation::MarkIfNeeded(
+ absl::flat_hash_set<const Allocation*>& needed_allocations) const {
+ MarkNeeded(needed_allocations);
+}
+
+void WindowPrefetchedAllocation::MarkNeeded(
+ absl::flat_hash_set<const Allocation*>& needed_allocations) const {
+ needed_allocations.insert(this);
+ prev_allocation_.MarkNeeded(needed_allocations);
+}
+
+std::string WindowPrefetchedAllocation::ToString() const {
+ return absl::StrCat("WindowPrefetched Allocation");
+}
+
+bool WindowPrefetchedAllocation::operator==(
+ const WindowPrefetchedAllocation& other) const {
+ return this->base_is_equal(static_cast<const Allocation&>(other)) &&
+ prefetch_done_schedule_before() ==
+ other.prefetch_done_schedule_before() &&
+ prefetch_start_schedule_after() ==
+ other.prefetch_start_schedule_after() &&
+ prefetch() == other.prefetch() && bytes_ == other.bytes_;
+}
+
+bool WindowPrefetchedAllocation::operator==(const Allocation& other) const {
+ const WindowPrefetchedAllocation* casted_other =
+ dynamic_cast<const WindowPrefetchedAllocation*>(&other);
+ return casted_other != nullptr && (*this) == (*casted_other);
+}
+
std::tuple<int64_t, bool, int64_t> GetAllocationSortTuple(
const std::unique_ptr<Allocation>& allocation) {
int64_t scheduled_on_or_before = allocation->start_time();
diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.h b/third_party/xla/xla/service/memory_space_assignment/allocation.h
index d0a4d72..bb3b324 100644
--- a/third_party/xla/xla/service/memory_space_assignment/allocation.h
+++ b/third_party/xla/xla/service/memory_space_assignment/allocation.h
@@ -18,6 +18,7 @@
#include <algorithm>
#include <cstdint>
+#include <functional>
#include <memory>
#include <optional>
#include <string>
@@ -130,6 +131,7 @@
virtual bool is_pinned_allocation() const = 0;
virtual bool is_copy_allocation() const = 0;
virtual bool is_sliced_copy_allocation() const = 0;
+ virtual bool is_window_prefetched_allocation() const = 0;
// True if the allocation is for a copy or a sliced-copy.
bool is_copy_like_allocation() const;
@@ -211,6 +213,7 @@
bool is_pinned_allocation() const override { return true; }
bool is_copy_allocation() const override { return false; }
bool is_sliced_copy_allocation() const override { return false; }
+ bool is_window_prefetched_allocation() const override { return false; }
absl::Status Process() override;
absl::Status PostProcess() override { return absl::OkStatus(); }
void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
@@ -249,6 +252,7 @@
bool is_pinned_allocation() const override { return false; }
bool is_copy_allocation() const override { return true; }
bool is_sliced_copy_allocation() const override { return false; }
+ bool is_window_prefetched_allocation() const override { return false; }
absl::Status Process() override;
absl::Status PostProcess() override { return absl::OkStatus(); }
void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
@@ -350,6 +354,7 @@
bool is_pinned_allocation() const override { return false; }
bool is_copy_allocation() const override { return false; }
bool is_sliced_copy_allocation() const override { return true; }
+ bool is_window_prefetched_allocation() const override { return false; }
// MemorySpaceAssignment::Process() calls Process() to create asynchronous
// slice copies, and a bitcast-concat call to glue the slices back together.
absl::Status Process() override;
@@ -393,6 +398,75 @@
absl::FunctionRef<Shape(const Shape&)> get_equivalent_s8_shape_fn_;
};
+// This class represents an allocation resulting from asynchronously prefetching
+// a window buffer. When a tensor is placed in the default memory, we can
+// prefetch the window buffer of the tensor to the alternate memory space. This
+// is called window prefetching.
+class WindowPrefetchedAllocation final : public Allocation {
+ public:
+ struct Options {
+ int64_t bytes = 0;
+ int64_t uid = 0;
+ int64_t alternate_memory_space = 0;
+ std::function<void(HloInstruction*, int64_t, int64_t)>
+ notify_operand_appended_fn =
+ [](const HloInstruction*, int64_t, int64_t) {};
+ };
+
+ WindowPrefetchedAllocation(Allocation& prev_allocation, HloUse use,
+ const HeapSimulator::Chunk& chunk,
+ int64_t prefetch_start_schedule_after_time,
+ int64_t prefetch_done_schedule_before_time,
+ const Options& options);
+
+ // Overridden methods
+ //
+ HloPosition defining_position() const override;
+ int64_t earliest_available_time() const override;
+ bool is_pinned_allocation() const override { return false; }
+ bool is_copy_allocation() const override { return false; }
+ bool is_sliced_copy_allocation() const override { return false; }
+ bool is_window_prefetched_allocation() const override { return true; }
+ // MemorySpaceAssignment::Process() calls Process() to create asynchronous
+ // window prefetches.
+ absl::Status Process() override;
+ absl::Status PostProcess() override { return absl::OkStatus(); }
+ // Marks the allocation as needed.
+ void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
+ const override;
+ void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
+ const override;
+ std::string ToString() const override;
+ bool operator==(const WindowPrefetchedAllocation& other) const;
+ bool operator==(const Allocation& other) const override;
+ int64_t bytes() const { return bytes_; }
+ int64_t prefetch_start_schedule_after() const {
+ return prefetch_start_schedule_after_;
+ }
+ int64_t prefetch_done_schedule_before() const {
+ return prefetch_done_schedule_before_;
+ }
+ HloInstruction* prefetch() const { return prefetch_instruction_; }
+
+ private:
+ // This method is called by Process() to create window prefetch instructions.
+  // These instructions include an async WindowPrefetch start/done pair outside
+  // the fusion and a WindowPrefetchBuffer inside the fusion. The
+ // WindowPrefetchBuffer is used for consuming the appended window buffer
+ // operands.
+ absl::Status InsertWindowPrefetchInstruction(
+ HloInstruction* producing_instruction, HloInstruction* use_instruction,
+ HloComputation* computation);
+
+ Options options_;
+ HloInstruction* prefetch_instruction_ = nullptr;
+ Allocation& prev_allocation_;
+ HloUse use_;
+ int64_t prefetch_start_schedule_after_;
+ int64_t prefetch_done_schedule_before_;
+ int64_t bytes_;
+};
+
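For context, a caller fills the Options above from a WindowPrefetchDetail window and supplies a hook that records where the appended window-buffer operand ended up. A hypothetical wiring sketch (names and values illustrative, not taken from this patch):

#include <cstdint>
#include "absl/container/flat_hash_map.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/memory_space_assignment/allocation.h"

// Illustrative: remember which fusion operand index carries the window buffer
// for each uid reported by the backend.
xla::memory_space_assignment::WindowPrefetchedAllocation::Options
MakeWindowPrefetchOptions(
    absl::flat_hash_map<int64_t, int64_t>* operand_index_by_uid) {
  xla::memory_space_assignment::WindowPrefetchedAllocation::Options options;
  options.bytes = 1024;                // Example window size in bytes.
  options.uid = 42;                    // Example uid from WindowPrefetchDetail.
  options.alternate_memory_space = 1;  // Backend-specific memory space id.
  options.notify_operand_appended_fn =
      [operand_index_by_uid](xla::HloInstruction* /*fusion*/, int64_t uid,
                             int64_t operand_index) {
        (*operand_index_by_uid)[uid] = operand_index;
      };
  return options;
}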
// An allocation in the default memory space that mirrors another Allocation
// object. This is useful to model an eviction that happens before a while op
// so that we don't need to redundantly evict the buffer after the while op as
@@ -409,6 +483,7 @@
bool is_pinned_allocation() const override { return false; }
bool is_copy_allocation() const override { return false; }
bool is_sliced_copy_allocation() const override { return false; }
+ bool is_window_prefetched_allocation() const override { return false; }
absl::Status Process() override;
absl::Status PostProcess() override { return absl::OkStatus(); }
void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
@@ -442,6 +517,7 @@
bool is_pinned_allocation() const override { return false; }
bool is_copy_allocation() const override { return false; }
bool is_sliced_copy_allocation() const override { return false; }
+ bool is_window_prefetched_allocation() const override { return false; }
absl::Status Process() override;
absl::Status PostProcess() override;
void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto
index 77faa69..e15d564 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto
@@ -46,6 +46,21 @@
uint64 preferred_slice_size = 5;
}
+// Memory space assignment options for prefetching windows of data
+message WindowPrefetchDetail {
+ message WindowDetail {
+ // Index of the operand that is window prefetched.
+ int64 operand = 1;
+ // Window buffer size in bytes.
+ int64 size = 2;
+ // Unique identifier to distinguish the buffers that are associated with the
+ // same operand.
+ int64 uid = 3;
+ }
+
+ repeated WindowDetail windows = 1;
+}
+
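
As a concrete illustration (not part of the change itself), the generated C++ proto API for this message would be used roughly as follows to describe two 32-byte windows on operands 0 and 1 of some fusion; all field values here are made up:

    WindowPrefetchDetail detail;
    for (int64_t operand = 0; operand < 2; ++operand) {
      WindowPrefetchDetail::WindowDetail* window = detail.add_windows();
      window->set_operand(operand);  // which fusion operand the window covers
      window->set_size(32);          // window buffer size in bytes
      window->set_uid(operand);      // disambiguates windows on the same operand
    }
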
// Options for memory-bound loop optimizations in memory space assignment. If
// enabled, this pass can optimize memory-bound unrolled loops to maximize the
// bandwidth utilized and minimize the execution time.
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
index 3fb558b..40f5e1d 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
@@ -8282,6 +8282,65 @@
}
}
+TEST_F(MemorySpaceAssignmentTest, WindowPrefetch) {
+ absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+%fused_computation {
+ %p0 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(0)
+ %p1 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(1)
+ %p2 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(2)
+ %add0 = bf16[64,8]{1,0:T(8,128)(2,1)} add(%p0, %p1)
+ ROOT %add1 = bf16[64,8]{1,0:T(8,128)(2,1)} add(%add0, %p2)
+}
+
+entry {
+ %p0 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(0)
+ %p1 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(1)
+ %p2 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(2)
+ ROOT fusion = bf16[64,8]{1,0:T(8,128)(2,1)} fusion(bf16[64,8]{1,0:T(8,128)(2,1)} %p0, bf16[64,8]{1,0:T(8,128)(2,1)} %p1, bf16[64,8]{1,0:T(8,128)(2,1)} %p2), kind=kLoop, calls=%fused_computation
+}
+
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ // Get info about window prefetch buffers, such as which operands they
+ // correspond to and their sizes.
+ auto window_prefetch_detail_fn = [&](const HloInstruction* instruction) {
+ WindowPrefetchDetail window_prefetch_detail;
+ const HloInstruction* fusion = FindInstruction(module.get(), "fusion");
+ if (instruction == fusion) {
+ for (int i = 0; i < 3; ++i) {
+ auto* operand = window_prefetch_detail.add_windows();
+ operand->set_operand(i);
+ operand->set_size(32);
+ }
+ }
+ return window_prefetch_detail;
+ };
+
+ Options options = DefaultMemorySpaceOptions();
+ options.enable_window_prefetch = true;
+ options.window_prefetch_detail_fn = window_prefetch_detail_fn;
+ AssignMemorySpace(module.get(), options, /*max_prefetch_interval=*/10,
+ /*min_prefetch_interval=*/0);
+ const HloInstruction* fusion = FindInstruction(module.get(), "fusion");
+ // The fusion instruction should have 5 operands: the 3 original operands
+ // plus 2 window prefetch buffers.
+ EXPECT_EQ(fusion->operand_count(), 5);
+
+ // The root of the fusion should be a WindowPrefetchBuffer. The first operand
+ // should be the original root, and the second and third operands should be
+ // the window prefetch buffers.
+ HloInstruction* root = fusion->fused_expression_root();
+ EXPECT_TRUE(root->IsCustomCall("WindowPrefetchBuffer"));
+ EXPECT_EQ(root->operand_count(), 3);
+ EXPECT_EQ(root->operand(1), fusion->fused_parameter(3));
+ EXPECT_EQ(root->operand(2), fusion->fused_parameter(4));
+ VLOG(2) << "module: " << module->ToString();
+}
+
using AsynchronousCopyOrderingTest = ::testing::Test;
TEST_F(AsynchronousCopyOrderingTest, Simple) {
diff --git a/third_party/xla/xla/service/memory_space_assignment/options.h b/third_party/xla/xla/service/memory_space_assignment/options.h
index 3a1d848..fb9730c 100644
--- a/third_party/xla/xla/service/memory_space_assignment/options.h
+++ b/third_party/xla/xla/service/memory_space_assignment/options.h
@@ -57,6 +57,10 @@
const absl::flat_hash_set<ShapeIndex>& /*outputs_in_alternate_memory*/)>;
using PositionRequiresContiguousAllocationFunction =
std::function<bool(const HloPosition&)>;
+using WindowPrefetchDetailFunction =
+ std::function<WindowPrefetchDetail(const HloInstruction*)>;
+using WindowPrefetchNotifyOperandAppendedFunction =
+ std::function<void(HloInstruction*, int64_t, int64_t)>;
// The different options to be passed to the Run() API.
struct Options {
@@ -111,6 +115,15 @@
position_requires_contiguous_allocation_fn =
[](const HloPosition&) { return false; };
+ // This function is called to get details about window prefetches.
+ WindowPrefetchDetailFunction window_prefetch_detail_fn =
+ [](const HloInstruction*) { return WindowPrefetchDetail(); };
+
+ // This function is called to notify that an operand has been appended as a
+ // window prefetch buffer.
+ WindowPrefetchNotifyOperandAppendedFunction notify_operand_appended_fn =
+ [](HloInstruction*, int64_t, int64_t) {};
+
// If true, we will try to reduce scoped allocation buffer size for all
// instructions if their operand/output has been allocated in alternate
// memory.
@@ -234,6 +247,13 @@
// Option to always spill buffers from alternate memory to default memory
// and prefetching back to alternate memory(if needed) just in time for use.
bool always_spill_to_default_memory = false;
+
+  // If true, enables window prefetching. Window prefetching is a mechanism
+  // where we prefetch windows of data into the alternate memory before the
+  // first use of the buffer. This allows even large tensors to be prefetched
+  // and gives MSA more flexibility in choosing the prefetch time and how much
+  // data to prefetch.
+ bool enable_window_prefetch = false;
};
} // namespace memory_space_assignment
} // namespace xla
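
A hedged sketch of how a backend might wire up the new options, modeled on the test added above. DefaultMemorySpaceOptions() is the test helper used in that test (real backends populate Options directly), and the meaning assumed here for the two int64_t arguments of notify_operand_appended_fn (window uid and appended operand index) is an interpretation, not something this diff pins down:

    Options options = DefaultMemorySpaceOptions();
    options.enable_window_prefetch = true;

    // Describe one 32-byte window for operand 0 of every fusion instruction.
    options.window_prefetch_detail_fn = [](const HloInstruction* instruction) {
      WindowPrefetchDetail detail;
      if (instruction->opcode() == HloOpcode::kFusion) {
        WindowPrefetchDetail::WindowDetail* window = detail.add_windows();
        window->set_operand(0);
        window->set_size(32);
      }
      return detail;
    };

    // Invoked when MSA appends a window buffer operand to the instruction.
    options.notify_operand_appended_fn =
        [](HloInstruction* instruction, int64_t window_uid,
           int64_t operand_index) {
          VLOG(1) << "Appended window buffer " << window_uid << " as operand "
                  << operand_index << " of " << instruction->name();
        };
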
diff --git a/third_party/xla/xla/service/pattern_matcher.h b/third_party/xla/xla/service/pattern_matcher.h
index 9b5a953..76979f0 100644
--- a/third_party/xla/xla/service/pattern_matcher.h
+++ b/third_party/xla/xla/service/pattern_matcher.h
@@ -2682,6 +2682,7 @@
XLA_UNOP_PATTERN(Bitcast)
XLA_UNOP_PATTERN(BitcastConvert)
XLA_UNOP_PATTERN(Broadcast)
+XLA_UNOP_PATTERN(Cbrt)
XLA_UNOP_PATTERN(Ceil)
XLA_UNOP_PATTERN(Convert)
XLA_UNOP_PATTERN(Copy)
@@ -2695,6 +2696,7 @@
XLA_UNOP_PATTERN(CollectivePermuteStart)
XLA_UNOP_PATTERN(CollectivePermuteDone)
XLA_UNOP_PATTERN(Domain)
+XLA_UNOP_PATTERN(Erf)
XLA_UNOP_PATTERN(Exp)
XLA_UNOP_PATTERN(Expm1)
XLA_UNOP_PATTERN(Fft)
@@ -2704,6 +2706,7 @@
XLA_UNOP_PATTERN(Infeed)
XLA_UNOP_PATTERN(IsFinite)
XLA_UNOP_PATTERN(Log)
+XLA_UNOP_PATTERN(Logistic)
XLA_UNOP_PATTERN(Not)
XLA_UNOP_PATTERN(Negate)
XLA_UNOP_PATTERN(OptimizationBarrier)
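
For context, each XLA_UNOP_PATTERN entry generates a matcher in the match namespace, so the three new entries let passes pattern-match cbrt, erf, and logistic directly. A minimal, hypothetical usage sketch (the rewrite itself is not part of this change; `instruction` is assumed to be some HloInstruction* being visited by a pass):

    namespace m = ::xla::match;

    // Detect erf(x) and capture its operand; m::Cbrt and m::Logistic work the
    // same way after this change.
    HloInstruction* x = nullptr;
    if (Match(instruction, m::Erf(m::Op(&x)))) {
      // ... e.g. replace erf(x) with a polynomial approximation built from x ...
    }
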
diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc
index 53ff557..bf375e8 100644
--- a/third_party/xla/xla/service/shape_inference.cc
+++ b/third_party/xla/xla/service/shape_inference.cc
@@ -3794,6 +3794,7 @@
static absl::Status ValidateGatherDimensionNumbers(
const Shape& input_shape, absl::Span<const int64_t> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
+ // Validate offset_dims in GatherDimensionNumbers.
if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
@@ -3834,6 +3835,7 @@
start_indices_shape[dim_numbers.index_vector_dim()]);
}
+ // Validate start_index_map in GatherDimensionNumbers.
for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
int64_t operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
if (operand_dim_for_start_index_i < 0 ||
@@ -3858,6 +3860,7 @@
StrJoin(dim_numbers.start_index_map(), ", "));
}
+ // Validate collapsed_slice_dims in GatherDimensionNumbers.
for (int64_t collapsed_dim : dim_numbers.collapsed_slice_dims()) {
if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
return InvalidArgument(
@@ -3881,6 +3884,69 @@
StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
}
+  // Validate that operand_batching_dims and start_indices_batching_dims have
+  // the same size.
+ if (dim_numbers.operand_batching_dims_size() !=
+ dim_numbers.start_indices_batching_dims_size()) {
+ return InvalidArgument(
+ "operand_batching_dims and start_indices_batching_dims in gather op "
+ "must be of the same size; got: %d and %d.",
+ dim_numbers.operand_batching_dims_size(),
+ dim_numbers.start_indices_batching_dims_size());
+ }
+
+ // Validate operand_batching_dims in GatherDimensionNumbers.
+ for (int64_t operand_batching_dim : dim_numbers.operand_batching_dims()) {
+ if (operand_batching_dim < 0 ||
+ operand_batching_dim >= input_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid operand_batching_dims set in gather op; valid range is [0, "
+ "%d), got: %d.",
+ input_shape.dimensions_size(), operand_batching_dim);
+ }
+ }
+
+ if (!absl::c_is_sorted(dim_numbers.operand_batching_dims())) {
+ return InvalidArgument(
+ "operand_batching_dims in gather op must be sorted; got: %s",
+ StrJoin(dim_numbers.operand_batching_dims(), ", "));
+ }
+
+ if (absl::c_adjacent_find(dim_numbers.operand_batching_dims()) !=
+ dim_numbers.operand_batching_dims().end()) {
+ return InvalidArgument(
+ "Repeated dimensions not allowed in operand_batching_dims in gather "
+ "op; "
+ "got: %s.",
+ StrJoin(dim_numbers.operand_batching_dims(), ", "));
+ }
+
+ // Validate start_indices_batching_dims in GatherDimensionNumbers.
+ for (int i = 0; i < dim_numbers.start_indices_batching_dims_size(); i++) {
+ int64_t start_indices_batching_dim_i =
+ dim_numbers.start_indices_batching_dims(i);
+ if (start_indices_batching_dim_i < 0 ||
+ start_indices_batching_dim_i >= start_indices_shape.size()) {
+ return InvalidArgument(
+ "Invalid start_indices_batching_dims; domain is [0, %d), got: "
+ "%d->%d.",
+ start_indices_shape.size(), i, start_indices_batching_dim_i);
+ }
+ }
+
+ std::vector<int64_t> sorted_start_indices_batching_dims(
+ dim_numbers.start_indices_batching_dims().begin(),
+ dim_numbers.start_indices_batching_dims().end());
+
+ absl::c_sort(sorted_start_indices_batching_dims);
+
+ if (absl::c_adjacent_find(sorted_start_indices_batching_dims) !=
+ sorted_start_indices_batching_dims.end()) {
+ return InvalidArgument(
+ "Repeated dimensions are not allowed in start_indices_batching_dims; "
+ "got: %s.",
+ StrJoin(dim_numbers.start_indices_batching_dims(), ", "));
+ }
return absl::OkStatus();
}
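
To make the new batching rules concrete, here is a hedged sketch of GatherDimensionNumbers that would pass the checks above, assuming a rank-3 operand whose dimension 0 is the batch dimension and start indices whose dimension 0 is the matching batch dimension (all concrete numbers are made up for illustration):

    GatherDimensionNumbers dnums;
    dnums.add_offset_dims(2);           // output dim that holds the slice
    dnums.add_collapsed_slice_dims(1);  // operand dim collapsed away
    dnums.add_start_index_map(1);
    dnums.set_index_vector_dim(2);
    // Batching dims: equal-sized, sorted, non-repeating lists that are in
    // range for the operand and the start indices respectively.
    dnums.add_operand_batching_dims(0);
    dnums.add_start_indices_batching_dims(0);
    // slice_sizes must then have offset_dims + collapsed_slice_dims +
    // operand_batching_dims entries, with size 0 or 1 on the batching (and
    // collapsed) dims, e.g. {1, 1, 8}.
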
@@ -3943,13 +4009,16 @@
if (slice_sizes.size() !=
gather_dim_numbers.offset_dims_size() +
- gather_dim_numbers.collapsed_slice_dims_size()) {
+ gather_dim_numbers.collapsed_slice_dims_size() +
+ gather_dim_numbers.operand_batching_dims_size()) {
return InvalidArgument(
"All components of the offset index in a gather op must either be a "
- "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
- "output_slice_sizes=%s, collapsed_slice_dims=%s.",
+ "offset dimension or explicitly collapsed or explicitly batched; got "
+ "len(slice_sizes)=%lu, output_slice_sizes=%s, collapsed_slice_dims=%s, "
+ "operand_batching_dims=%s.",
slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
- StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
+ StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","),
+ StrJoin(gather_dim_numbers.operand_batching_dims(), ","));
}
for (int i = 0; i < slice_sizes.size(); i++) {
@@ -3974,6 +4043,16 @@
}
}
+ for (int i = 0; i < gather_dim_numbers.operand_batching_dims_size(); i++) {
+ if (slice_sizes[gather_dim_numbers.operand_batching_dims(i)] > 1) {
+ return InvalidArgument(
+ "Gather op can only have operand_batching_dims with bound 1 or 0, "
+ "but bound is %d for index %d at position %d.",
+ slice_sizes[gather_dim_numbers.operand_batching_dims(i)],
+ gather_dim_numbers.operand_batching_dims(i), i);
+ }
+ }
+
int64_t result_rank = gather_dim_numbers.offset_dims_size() +
(expanded_start_indices_shape.size() - 1);
int64_t offset_dims_seen = 0;
@@ -3993,6 +4072,10 @@
offset_dims_seen)) {
offset_dims_seen++;
}
+ while (absl::c_binary_search(gather_dim_numbers.operand_batching_dims(),
+ offset_dims_seen)) {
+ offset_dims_seen++;
+ }
// Gathering an entire dynamic dimension creates dynamic dimension.
//
// e.g.,:
@@ -4075,7 +4158,8 @@
// Validate window size.
auto window_size = dim_numbers.update_window_dims_size() +
- dim_numbers.inserted_window_dims_size();
+ dim_numbers.inserted_window_dims_size() +
+ dim_numbers.input_batching_dims_size();
if (window_size != operand_shape.rank()) {
return InvalidArgument(
"Scatter op has window of size %d; doesn't match operand of rank %d.",
@@ -4117,6 +4201,61 @@
StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
}
+ // Validate input_batching_dims and scatter_indices_batching_dims in
+ // ScatterDimensionNumbers.
+ if (dim_numbers.input_batching_dims_size() !=
+ dim_numbers.scatter_indices_batching_dims_size()) {
+ return InvalidArgument(
+ "input_batching_dims and scatter_indices_batching_dims in scatter op "
+ "must be of the same size; got: %d and %d.",
+ dim_numbers.input_batching_dims_size(),
+ dim_numbers.scatter_indices_batching_dims_size());
+ }
+
+ // Validate input_batching_dims in ScatterDimensionNumbers.
+ if (!absl::c_is_sorted(dim_numbers.input_batching_dims())) {
+ return InvalidArgument(
+ "input_batching_dims in scatter op must be sorted; got: %s.",
+ StrJoin(dim_numbers.input_batching_dims(), ", "));
+ }
+ if (absl::c_adjacent_find(dim_numbers.input_batching_dims()) !=
+ dim_numbers.input_batching_dims().end()) {
+ return InvalidArgument(
+ "input_batching_dims in scatter op must not repeat; got: %s.",
+ StrJoin(dim_numbers.input_batching_dims(), ", "));
+ }
+ for (int64_t input_batching_dim : dim_numbers.input_batching_dims()) {
+ if (input_batching_dim < 0 ||
+ input_batching_dim >= operand_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid input_batching_dims set in scatter op; valid range is [0, "
+ "%d), got: %d.",
+ operand_shape.dimensions_size(), input_batching_dim);
+ }
+ }
+
+ // Validate scatter_indices_batching_dims in ScatterDimensionNumbers.
+ for (int64_t scatter_indices_batching_dim :
+ dim_numbers.scatter_indices_batching_dims()) {
+ if (scatter_indices_batching_dim < 0 ||
+ scatter_indices_batching_dim >= scatter_indices_shape.size()) {
+ return InvalidArgument(
+ "Invalid scatter_indices_batching_dims set in scatter op; valid "
+ "range is [0, %d), got: %d.",
+ scatter_indices_shape.size(), scatter_indices_batching_dim);
+ }
+ }
+ std::vector<int64_t> sorted_scatter_indices_batching_dims(
+ dim_numbers.scatter_indices_batching_dims().begin(),
+ dim_numbers.scatter_indices_batching_dims().end());
+ absl::c_sort(sorted_scatter_indices_batching_dims);
+ if (absl::c_adjacent_find(sorted_scatter_indices_batching_dims) !=
+ sorted_scatter_indices_batching_dims.end()) {
+ return InvalidArgument(
+ "scatter_indices_batching_dims in scatter op must not repeat; got: %s.",
+ StrJoin(dim_numbers.scatter_indices_batching_dims(), ", "));
+ }
+
return absl::OkStatus();
}
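
The scatter-side checks mirror the gather ones; a correspondingly hedged sketch of ScatterDimensionNumbers that satisfies them for a rank-3 operand whose dimension 0 is batched (again, the concrete numbers are illustrative only):

    ScatterDimensionNumbers dnums;
    dnums.add_update_window_dims(2);
    dnums.add_inserted_window_dims(1);
    dnums.add_scatter_dims_to_operand_dims(1);
    dnums.set_index_vector_dim(2);
    // As above: equal-sized, sorted, non-repeating, in-range batching dims.
    dnums.add_input_batching_dims(0);
    dnums.add_scatter_indices_batching_dims(0);
    // update_window_dims + inserted_window_dims + input_batching_dims must
    // together cover the operand rank (1 + 1 + 1 = 3 here).
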
diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc
index 16a0152..5239d6c 100644
--- a/third_party/xla/xla/service/sharding_propagation.cc
+++ b/third_party/xla/xla/service/sharding_propagation.cc
@@ -358,7 +358,7 @@
computation_map.find(instruction->parent()) == computation_map.end() &&
!(is_entry_root && allow_spmd_sharding_propagation_to_output)) {
// We don't support sharding the root instruction of a computation yet,
- // unless the computation is a while body.
+ // unless the computation is in computation_map.
return false;
}
@@ -2886,14 +2886,23 @@
return std::vector<HloInstruction*>{inst, callee->root_instruction()};
} else if (inst->opcode() == HloOpcode::kParameter) {
auto it = computation_map.find(inst->parent());
- if (it != computation_map.end() &&
- it->second->opcode() == HloOpcode::kConditional) {
- HloInstruction* cond = it->second;
- for (int64_t i = 1; i < cond->operand_count(); ++i) {
- if (cond->called_computations()[i - 1] == inst->parent()) {
- return std::vector<HloInstruction*>{inst, cond->mutable_operand(i)};
+ if (it != computation_map.end()) {
+ if (it->second->opcode() == HloOpcode::kConditional) {
+ HloInstruction* cond = it->second;
+ for (int64_t i = 1; i < cond->operand_count(); ++i) {
+ if (cond->called_computations()[i - 1] == inst->parent()) {
+ return std::vector<HloInstruction*>{inst,
+ cond->mutable_operand(i)};
+ }
}
}
+ if (it->second->opcode() == HloOpcode::kCall) {
+ HloInstruction* call = it->second;
+ int64_t operand_index = inst->parameter_number();
+ CHECK_LT(operand_index, call->operand_count());
+ return std::vector<HloInstruction*>{
+ inst, call->mutable_operand(operand_index)};
+ }
}
return std::vector<HloInstruction*>{};
} else {
@@ -2936,9 +2945,11 @@
auto it = computation_map.find(instruction->parent());
if (it != computation_map.end()) {
propagate_to_instruction(it->second);
- // Propagate parameter shardings back to conditional's operands.
+ // Propagate parameter shardings back to conditional's and
+ // call's operands.
if (instruction->opcode() == HloOpcode::kParameter &&
- it->second->opcode() == HloOpcode::kConditional) {
+ (it->second->opcode() == HloOpcode::kConditional ||
+ it->second->opcode() == HloOpcode::kCall)) {
propagate_to_instruction(instruction);
}
}
@@ -2954,8 +2965,8 @@
}
}
- // Populate computation_map in order to associate while bodies to their
- // while instructions.
+  // Populate computation_map in order to associate while bodies and conditions
+  // with their while instructions.
for (auto computation : module->computations(execution_threads)) {
for (auto instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kWhile ||
@@ -2982,6 +2993,7 @@
}
if (instruction->opcode() == HloOpcode::kWhile) {
computation_map[instruction->while_body()] = instruction;
+ computation_map[instruction->while_condition()] = instruction;
} else {
for (HloComputation* c : instruction->called_computations()) {
computation_map[c] = instruction;
diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc
index ac04389..072f436 100644
--- a/third_party/xla/xla/service/sharding_propagation_test.cc
+++ b/third_party/xla/xla/service/sharding_propagation_test.cc
@@ -2757,6 +2757,60 @@
}
}
+TEST_F(ShardingPropagationTest, PropagateShardingInWhileCondition) {
+ const char* const hlo_string = R"(
+HloModule module
+
+%cond {
+ %vars.cond = (u32[], f32[]) parameter(0)
+ %count.cond = u32[] get-tuple-element(%vars.cond), index=0
+ %limit = u32[] constant(10)
+ ROOT %lt = pred[] compare(%count.cond, %limit), direction=LT
+}
+
+%body {
+ %vars = (u32[], f32[]) parameter(0)
+ %count = u32[] get-tuple-element(%vars), index=0
+ %acc = f32[] get-tuple-element(%vars), index=1
+
+ %one = u32[] constant(1)
+ %count.1 = u32[] add(u32[] %count, u32[] %one)
+ %acc.1 = f32[] add(f32[] %acc, f32[] %acc)
+ ROOT %tuple = (u32[], f32[]) tuple(%count.1, %acc.1)
+}
+
+ENTRY %entry {
+ %p0 = f32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+ %zero = u32[] constant(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+ %init = (u32[], f32[]) tuple(%zero, %p0)
+ ROOT %while = (u32[], f32[]) while(%init), body=%body, condition=%cond
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool changed,
+ ShardingPropagation(/*is_spmd=*/false, /*propagate_metadata=*/false,
+ /*allow_spmd_sharding_propagation_to_output=*/{true})
+ .Run(module.get()));
+ EXPECT_TRUE(changed);
+ HloSharding single_sharding =
+ ParseSharding("{devices=[2,2]<=[4] last_tile_dims={manual, replicated}}")
+ .value();
+ HloSharding tuple_sharding = HloSharding::SingleTuple(
+ module->entry_computation()->root_instruction()->shape(),
+ single_sharding);
+
+ for (const HloComputation* computation : module->computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ EXPECT_TRUE(instruction->has_sharding());
+ EXPECT_EQ(instruction->sharding(), instruction->shape().IsTuple()
+ ? tuple_sharding
+ : single_sharding);
+ }
+ }
+}
+
TEST_P(ParameterizedMetadataTest, WhileGetShardingFromRecvInBody) {
const char* const hlo_string = R"(
HloModule module
@@ -12070,5 +12124,36 @@
"last_tile_dim_replicate}}"));
}
+TEST_F(ShardingPropagationTest, CallPropagation) {
+ const absl::string_view hlo_string = R"(
+HloModule module
+
+called_computation {
+ p0 = bf16[20,2,68096,8512] parameter(0)
+ %add_called_comp = bf16[20,2,68096,8512] add(p0, p0)
+ ROOT tuple = (bf16[20,2,68096,8512]) tuple(add_called_comp)
+}
+
+ENTRY main {
+ %param0 = bf16[20,2,68096,8512] parameter(0)
+ %add = bf16[20,2,68096,8512] add(param0, param0)
+ ROOT %call = (bf16[20,2,68096,8512]) call(add), to_apply=%called_computation, sharding={{devices=[1,1,16,64]<=[64,16]T(1,0)}}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool changed,
+ ShardingPropagation(
+ /*is_spmd=*/true, /*propagate_metadata=*/true,
+ /*allow_spmd_sharding_propagation_to_output=*/{false},
+ /*allow_spmd_sharding_propagation_to_parameters=*/{false})
+ .Run(module.get()));
+ XLA_VLOG_LINES(1, module->ToString());
+ EXPECT_TRUE(changed);
+ auto* add = FindInstruction(module.get(), "add");
+ ASSERT_NE(add, nullptr);
+ EXPECT_THAT(add, op::Sharding("{devices=[1,1,16,64]<=[64,16]T(1,0)}"));
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/service/space_to_batch_converter.cc b/third_party/xla/xla/service/space_to_batch_converter.cc
index e07ee54..45f2113 100644
--- a/third_party/xla/xla/service/space_to_batch_converter.cc
+++ b/third_party/xla/xla/service/space_to_batch_converter.cc
@@ -1734,6 +1734,10 @@
}
if (consumer->opcode() == HloOpcode::kReduce) {
+ // Do not propagate through tuple outputs.
+ if (consumer->shape().IsTuple()) {
+ return false;
+ }
// Support only the trivial case where both batch and split spatial dim are
// being reduced
@@ -1741,8 +1745,13 @@
auto result = instr_to_dim_map_[consumer->mutable_operand(0)];
const int64_t batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
- VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim
- << " space_dim " << space_dim << " reduce " << consumer->ToString();
+    // Support the trivial case where neither the batch dim nor the split
+    // spatial dim is being reduced.
+ return !absl::c_linear_search(reduce_dims, batch_dim) &&
+ !absl::c_linear_search(reduce_dims, space_dim);
+
+    // Support only the trivial case where both the batch dim and the split
+    // spatial dim are being reduced.
return absl::c_linear_search(reduce_dims, batch_dim) &&
absl::c_linear_search(reduce_dims, space_dim);
}
@@ -2072,16 +2081,116 @@
}
if (consumer->opcode() == HloOpcode::kReduce) {
- auto new_consumer = computation->AddInstruction(consumer->Clone());
- auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
-
+ auto reduce_dims = consumer->dimensions();
auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
+ auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
+ auto permute_dims = instr_to_dim_permute_map_[first_operand];
+
const int64_t old_batch_dim =
dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
+ const int64_t space_dim =
+ dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
- auto permute_dims = instr_to_dim_permute_map_[first_operand];
const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
+ const int64_t new_space_dim = DimLookUp(permute_dims, space_dim);
+ std::vector<int64_t> changed_dims(consumer->dimensions().size());
+    // Support the trivial case where neither the batch dim nor the split
+    // spatial dim is being reduced.
+ if (!absl::c_linear_search(reduce_dims, old_batch_dim) &&
+ !absl::c_linear_search(reduce_dims, space_dim)) {
+ for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
+ changed_dims[i] = DimLookUp(permute_dims, consumer->dimensions(i));
+ }
+
+ // Decide where the new batch and space dims are in the output.
+ int64_t new_output_batch_dim = new_batch_dim;
+ int64_t new_output_space_dim = new_space_dim;
+ for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
+ if (changed_dims[i] < new_batch_dim) {
+ new_output_batch_dim--;
+ }
+ if (changed_dims[i] < new_space_dim) {
+ new_output_space_dim--;
+ }
+ }
+
+ // Decide where the new batch and space dims are in the original reduce's
+ // output.
+ int64_t old_output_batch_dim = old_batch_dim;
+ int64_t old_output_space_dim = space_dim;
+ for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
+ if (reduce_dims[i] < old_batch_dim) {
+ old_output_batch_dim--;
+ }
+ if (reduce_dims[i] < space_dim) {
+ old_output_space_dim--;
+ }
+ }
+
+ HloInstruction* new_consumer = nullptr;
+ TF_ASSIGN_OR_RETURN(
+ new_consumer,
+ MakeReduceHlo(first_operand, consumer->mutable_operand(1),
+ changed_dims, consumer->called_computations()[0]));
+
+ VLOG(3) << " new_output_batch_dim " << new_output_batch_dim << " size "
+ << first_operand->shape().dimensions(new_batch_dim)
+ << " new_output_space_dim " << new_output_space_dim << " size "
+ << first_operand->shape().dimensions(new_space_dim);
+
+ std::vector<int64_t> dim_map(kNumMappedDims);
+ dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = old_output_batch_dim;
+ dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = old_output_space_dim;
+ // We don't know where the feature dim is, so set it to -1.
+ dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] = -1;
+
+ instr_to_dim_map_[consumer] = dim_map;
+ const int64_t rank = first_operand->shape().rank();
+
+ const int64_t output_rank = new_consumer->shape().rank();
+
+      // Map each dim of the original reduce's output to its input dim.
+ std::vector<int64_t> old_reduce_output_to_input(output_rank);
+ int dim_number_to_assign_old = 0;
+ for (int64_t i = 0; i < rank; ++i) {
+ if (auto it = absl::c_find(reduce_dims, i); it != reduce_dims.end()) {
+ continue;
+ }
+ old_reduce_output_to_input[i] = dim_number_to_assign_old++;
+ }
+
+      // Map each dim of the new reduce's output to its new input dim.
+ std::vector<int64_t> new_reduce_output_to_input(output_rank);
+ int dim_number_to_assign_new = 0;
+ for (int64_t i = 0; i < rank; ++i) {
+ if (auto it = absl::c_find(changed_dims, i); it != changed_dims.end()) {
+ continue;
+ }
+ new_reduce_output_to_input[i] = dim_number_to_assign_new++;
+ }
+
+ std::vector<int64_t> new_permute_dims(output_rank);
+      // Using the two output-to-input maps above, figure out how each old
+      // output dim maps to its new output dim.
+ for (int64_t i = 0; i < output_rank; ++i) {
+ new_permute_dims[i] = std::distance(
+ new_reduce_output_to_input.begin(),
+ absl::c_find(
+ new_reduce_output_to_input,
+ DimLookUp(permute_dims, old_reduce_output_to_input[i])));
+ }
+
+ instr_to_dim_permute_map_[new_consumer] = new_permute_dims;
+ old_to_new_instrs_[consumer] = new_consumer;
+
+ // Because batch and split spatial dims are not reduced, further
+ // propagation is needed.
+ return true;
+ }
+
+ HloInstruction* new_consumer =
+ computation->AddInstruction(consumer->Clone());
auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0));
std::vector<int64_t> old_spatial_dims = retval.first;
std::vector<int64_t> new_spatial_dims = retval.second;
@@ -2092,7 +2201,6 @@
consumer->mutable_operand(1), new_batch_dim,
new_spatial_dims, old_batch_dim, old_spatial_dims));
- std::vector<int64_t> changed_dims(new_consumer->dimensions().size());
for (int64_t i = 0; i < new_consumer->dimensions().size(); ++i) {
changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i));
}
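
The output-dimension bookkeeping above is easier to see on a toy example. The following self-contained sketch (plain std:: containers, not the converter's real data structures) computes the analogous old-output-to-new-output permutation from the reduce dims and the input permutation, under the assumption that `permute_dims` maps old input dims to new input dims:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // For each dim kept by the reduce, record which input dim it came from.
    std::vector<int64_t> OutputToInput(int64_t input_rank,
                                       const std::vector<int64_t>& reduce_dims) {
      std::vector<int64_t> map;
      for (int64_t dim = 0; dim < input_rank; ++dim) {
        if (std::find(reduce_dims.begin(), reduce_dims.end(), dim) ==
            reduce_dims.end()) {
          map.push_back(dim);
        }
      }
      return map;
    }

    // old_to_new[i] is the new-output dim that holds what old-output dim i held,
    // where the new reduce's input is the old input permuted by `permute_dims`.
    std::vector<int64_t> OldToNewOutputPermutation(
        int64_t input_rank, const std::vector<int64_t>& old_reduce_dims,
        const std::vector<int64_t>& new_reduce_dims,
        const std::vector<int64_t>& permute_dims) {
      std::vector<int64_t> old_map = OutputToInput(input_rank, old_reduce_dims);
      std::vector<int64_t> new_map = OutputToInput(input_rank, new_reduce_dims);
      std::vector<int64_t> old_to_new(old_map.size());
      for (size_t i = 0; i < old_map.size(); ++i) {
        const int64_t new_input_dim = permute_dims[old_map[i]];
        old_to_new[i] = std::distance(
            new_map.begin(),
            std::find(new_map.begin(), new_map.end(), new_input_dim));
      }
      return old_to_new;
    }
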
diff --git a/third_party/xla/xla/service/space_to_batch_converter_test.cc b/third_party/xla/xla/service/space_to_batch_converter_test.cc
index e2ed331..dbc11e9 100644
--- a/third_party/xla/xla/service/space_to_batch_converter_test.cc
+++ b/third_party/xla/xla/service/space_to_batch_converter_test.cc
@@ -272,5 +272,84 @@
ASSERT_TRUE(converter.Run(module.get()).value());
}
+TEST_F(SpaceToBatchConverterTest, PropagateOnTrivialReduce) {
+ std::string hlo_string = R"(
+ HloModule module
+
+ %region_1.37 (Arg_0.38: f32[], Arg_1.39: f32[]) -> f32[] {
+ %Arg_0.38 = f32[] parameter(0)
+ %Arg_1.39 = f32[] parameter(1)
+ ROOT %add.40 = f32[] add(f32[] %Arg_0.38, f32[] %Arg_1.39)
+ }
+
+ ENTRY computation {
+ %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0)
+ %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1)
+ %c = f32[7,160,400,32]{3,2,1,0} convolution( %p0, %p1),
+ window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f
+ %constant.5 = f32[] constant(0)
+ ROOT %reduce.41 = f32[7,160,400]{2,1,0} reduce(%c, %constant.5), dimensions={3}, to_apply=%region_1.37
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ auto computation = module->entry_computation();
+ SpaceToBatchConverter converter(
+ SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8});
+ ASSERT_TRUE(converter.Run(module.get()).value());
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_THAT(root, op::Transpose());
+ EXPECT_THAT(root->operand(0)->operand(0)->operand(0)->operand(0),
+ op::Reduce());
+ auto new_reduce = root->operand(0)->operand(0)->operand(0)->operand(0);
+ // Make sure we propagated on the reduce with the larger batch size.
+ EXPECT_EQ(new_reduce->shape().dimensions(1),
+ // batch*number_of_splits
+ 7 * 8);
+}
+
+TEST_F(SpaceToBatchConverterTest, DoNotPropagateOnTupleReduce) {
+ std::string hlo_string = R"(
+ HloModule module
+
+%minmax_func.2717 {
+ %lhs_value.2718 = f32[] parameter(0)
+ %rhs_value.2720 = f32[] parameter(2)
+ %compare.2722 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=GE
+ %select.2723 = f32[] select(pred[] %compare.2722, f32[] %lhs_value.2718, f32[] %rhs_value.2720)
+ %compare.2725 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=EQ
+ %lhs_index.2719 = f32[] parameter(1)
+ %rhs_index.2721 = f32[] parameter(3)
+ %minimum.2726 = f32[] minimum(f32[] %lhs_index.2719, f32[] %rhs_index.2721)
+ %select.2724 = f32[] select(pred[] %compare.2722, f32[] %lhs_index.2719, f32[] %rhs_index.2721)
+ %select.2727 = f32[] select(pred[] %compare.2725, f32[] %minimum.2726, f32[] %select.2724)
+ ROOT %tuple.4 = (f32[], f32[]) tuple(f32[] %select.2723, f32[] %select.2727)
+ }
+
+ ENTRY computation {
+ %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0)
+ %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1)
+ %c = f32[7,160,400,32]{3,2,1,0} convolution( %p0, %p1),
+ window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f
+ %constant.5 = f32[] constant(0)
+ %constant.6 = f32[] constant(1)
+ ROOT %reduce.36 = (f32[7,160,400]{2,1,0}, f32[7,160,400]{2,1,0}) reduce(%c, %c,
+ %constant.5, %constant.6), dimensions={3}, to_apply=%minmax_func.2717
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ auto computation = module->entry_computation();
+ SpaceToBatchConverter converter(
+ SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8});
+ ASSERT_TRUE(converter.Run(module.get()).value());
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_THAT(root, op::Reduce());
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc
index 87db7f7..22f88cf 100644
--- a/third_party/xla/xla/service/spmd/dot_handler.cc
+++ b/third_party/xla/xla/service/spmd/dot_handler.cc
@@ -1899,7 +1899,7 @@
has_reshape_operand(lhs) ? lhs.hlo()->operand(0) : lhs.hlo();
auto rhs_operand =
has_reshape_operand(rhs) ? rhs.hlo()->operand(0) : rhs.hlo();
- for (auto loop : *windowed_dot_general_loops) {
+ for (const auto& loop : *windowed_dot_general_loops) {
if (loop.while_loop->while_body()->name().find(
"windowed_dot_general_body_ag") == 0) {
auto cm_lhs = loop.while_loop->operand(0)->operand(0);
@@ -2562,7 +2562,8 @@
};
std::optional<GroupedSharding> other_grouped =
try_sharding_for_other_operand(other.sharding());
- if (!other_grouped && !other.sharding().IsReplicated()) {
+ if (!other_grouped && !other.sharding().IsReplicated() &&
+ dims_mapping.conv_spatial_dims.empty()) {
const HloSharding expected_other_sharding =
hlo_sharding_util::InferDotOperandSharding(
&output_sharding, &matching.sharding(), lhs_matching ? 1 : 0,
@@ -2570,9 +2571,9 @@
// Try the expected sharding since it is no worse than the last resort
// (replicated sharding).
other_grouped = try_sharding_for_other_operand(expected_other_sharding);
- if (!other_grouped) {
- other = other.Replicate();
- }
+ }
+ if (!other_grouped) {
+ other = other.Replicate();
}
matching = matching.Reshard(UngroupSharding(matching_grouped));
diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD
index fcacdd7..bd15f20 100644
--- a/third_party/xla/xla/service/spmd/shardy/BUILD
+++ b/third_party/xla/xla/service/spmd/shardy/BUILD
@@ -142,8 +142,8 @@
"//xla/service/spmd/shardy/mhlo_round_trip:mhlo_import",
"//xla/service/spmd/shardy/mhlo_round_trip:shard_map_export",
"//xla/service/spmd/shardy/round_trip_common:convert_sharding_custom_calls",
- "//xla/service/spmd/shardy/round_trip_common:identity_to_pass_through_while_args",
"//xla/service/spmd/shardy/round_trip_common:import_constants",
+ "//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding",
"//xla/service/spmd/shardy/round_trip_common:shard_map_import",
"//xla/service/spmd/shardy/sdy_round_trip:export_ops",
"//xla/service/spmd/shardy/sdy_round_trip:export_shardings",
diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
index b6c6a99..0ffff71 100644
--- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
+++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
@@ -66,7 +66,6 @@
using ::mlir::success;
using ::mlir::sdy::ConstantOp;
-using ::mlir::sdy::IdentityOp;
using ::mlir::sdy::kShardingAttr;
using ::mlir::sdy::ReshardOp;
using ::mlir::sdy::ShardingConstraintOp;
@@ -88,20 +87,6 @@
}
};
-// Removes `sdy::IdentityOp`.
-class IdentityPattern : public OpConversionPattern<IdentityOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- private:
- LogicalResult matchAndRewrite(
- IdentityOp op, OpAdaptor adaptor,
- ConversionPatternRewriter& rewriter) const override {
- rewriter.replaceOp(op, adaptor.getInput());
- return success();
- }
-};
-
class ReshardPattern : public OpConversionPattern<ReshardOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -148,15 +133,14 @@
// We do not expect to see ShardingConstraintOp in the input module.
// ShardingConstraintOp should be replaced by ReshardOp before this pass.
// Hence, we add ShardingConstraintOp as an illegal op.
- target.addIllegalOp<ConstantOp, IdentityOp, ReshardOp,
- ShardingConstraintOp>();
+ target.addIllegalOp<ConstantOp, ReshardOp, ShardingConstraintOp>();
target.addLegalOp<mhlo::ConstantOp, mhlo::CopyOp>();
mlir::RewritePatternSet patterns(&context);
// After converting `sdy.constant` into `mhlo.constant`, the constants
// should not be deduped via folding. Fortunately, folding only happens in
// greedy pattern rewriters. ExportHloShardingsPass does a simple walk,
// which keeps the constants as is.
- patterns.add<ConstantPattern, IdentityPattern, ReshardPattern>(&context);
+ patterns.add<ConstantPattern, ReshardPattern>(&context);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
@@ -166,8 +150,8 @@
StringRef getArgument() const override { return "xla-sdy-export-ops"; }
StringRef getDescription() const override {
- return "Exports Shardy ops to MHLO ops. Processes sdy::IdentityOp, "
- "sdy::ReshardOp, and sdy::ConstantOp.";
+ return "Exports Shardy ops to MHLO ops. Processes sdy::ReshardOp and "
+ "sdy::ConstantOp.";
}
void getDependentDialects(mlir::DialectRegistry& registry) const final {
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD
index f9fd53b..e929f61 100644
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD
@@ -32,9 +32,9 @@
)
cc_library(
- name = "identity_to_pass_through_while_args",
- srcs = ["identity_to_pass_through_while_args.cc"],
- hdrs = ["identity_to_pass_through_while_args.h"],
+ name = "import_constants",
+ srcs = ["import_constants.cc"],
+ hdrs = ["import_constants.h"],
deps = [
"//xla/mlir_hlo",
"@llvm-project//llvm:Support",
@@ -48,9 +48,9 @@
)
cc_library(
- name = "import_constants",
- srcs = ["import_constants.cc"],
- hdrs = ["import_constants.h"],
+ name = "open_while_free_vars_sharding",
+ srcs = ["open_while_free_vars_sharding.cc"],
+ hdrs = ["open_while_free_vars_sharding.h"],
deps = [
"//xla/mlir_hlo",
"@llvm-project//llvm:Support",
@@ -94,8 +94,8 @@
hdrs = ["pipeline_passes.h"],
deps = [
":convert_sharding_custom_calls",
- ":identity_to_pass_through_while_args",
":import_constants",
+ ":open_while_free_vars_sharding",
":shard_map_import",
"//xla/mlir_hlo:mhlo_passes",
"@llvm-project//mlir:FuncDialect",
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc
deleted file mode 100644
index e1675c6..0000000
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h"
-
-#include <memory>
-
-#include "llvm/ADT/StringRef.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "shardy/dialect/sdy/ir/dialect.h"
-#include "shardy/dialect/sdy/ir/utils.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-
-namespace xla {
-namespace sdy {
-
-namespace {
-
-using ::mlir::StringRef;
-
-using ::mlir::func::FuncOp;
-
-// For every block argument of an `mhlo::WhileOp` that is directly returned by
-// the body of the op (pass-through), add an `sdy::IdentityOp` between the block
-// argument and the return op.
-//
-// This will prevent canonicalization from replacing these block arguments with
-// the corresponding operands as free variables.
-class AddIdentityToPassThroughWhileArgsPass
- : public mlir::PassWrapper<AddIdentityToPassThroughWhileArgsPass,
- mlir::OperationPass<FuncOp>> {
- public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- AddIdentityToPassThroughWhileArgsPass)
-
- void runOnOperation() final {
- FuncOp funcOp = getOperation();
- mlir::IRRewriter rewriter(funcOp);
-
- funcOp.walk([&](mlir::mhlo::WhileOp op) {
- mlir::Operation* returnOp = mlir::sdy::getBodyTerminator(op);
- rewriter.setInsertionPoint(returnOp);
- for (mlir::Value returnValue : returnOp->getOperands()) {
- if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(returnValue);
- blockArg && blockArg.getOwner() == &op.getBody().front()) {
- auto identityOp = rewriter.create<mlir::sdy::IdentityOp>(
- returnValue.getLoc(), returnValue);
- rewriter.replaceUsesWithIf(returnValue, identityOp,
- [returnOp](mlir::OpOperand& use) {
- return use.getOwner() == returnOp;
- });
- }
- }
- });
- }
-
- StringRef getArgument() const override {
- return "xla-sdy-add-identity-to-pass-through-while-args";
- }
-
- StringRef getDescription() const override {
- return "Adds an identity op between pass-through block arguments of a "
- "while op.";
- }
-};
-
-} // namespace
-
-std::unique_ptr<mlir::Pass> createAddIdentityToPassThroughWhileArgsPass() {
- return std::make_unique<AddIdentityToPassThroughWhileArgsPass>();
-}
-
-void registerAddIdentityToPassThroughWhileArgsPass() {
- mlir::registerPass(createAddIdentityToPassThroughWhileArgsPass);
-}
-
-} // namespace sdy
-} // namespace xla
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h
deleted file mode 100644
index 5dcb51f..0000000
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_
-#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_
-
-#include <memory>
-
-#include "mlir/Pass/Pass.h"
-
-namespace xla {
-namespace sdy {
-
-// Creates a pass that adds an identity op between pass-through block arguments
-// of a while op.
-std::unique_ptr<mlir::Pass> createAddIdentityToPassThroughWhileArgsPass();
-
-// Registers the xla-sdy-add-identity-to-pass-through-while-args pass.
-void registerAddIdentityToPassThroughWhileArgsPass();
-
-} // namespace sdy
-} // namespace xla
-
-#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc
new file mode 100644
index 0000000..603b270
--- /dev/null
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc
@@ -0,0 +1,95 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
+
+#include <memory>
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "shardy/dialect/sdy/ir/dialect.h"
+#include "shardy/dialect/sdy/ir/utils.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace xla {
+namespace sdy {
+
+namespace {
+
+using ::mlir::StringRef;
+using ::mlir::func::FuncOp;
+using ::mlir::sdy::TensorShardingAttr;
+
+class OpenWhileFreeVarsShardingPass
+ : public mlir::PassWrapper<OpenWhileFreeVarsShardingPass,
+ mlir::OperationPass<FuncOp>> {
+ public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenWhileFreeVarsShardingPass)
+
+ void runOnOperation() final {
+ FuncOp funcOp = getOperation();
+ mlir::IRRewriter rewriter(funcOp);
+
+ funcOp.walk([&](mlir::mhlo::WhileOp op) {
+ llvm::SetVector<mlir::Value> freeVars;
+ mlir::getUsedValuesDefinedAbove(op->getRegions(), freeVars);
+ rewriter.setInsertionPoint(op);
+ for (mlir::Value freeVar : freeVars) {
+ TensorShardingAttr sharding = mlir::sdy::getSharding(freeVar);
+ if (!sharding || sharding.getRank() == 0) {
+ continue;
+ }
+ auto shardingConstraint =
+ rewriter.create<mlir::sdy::ShardingConstraintOp>(
+ freeVar.getLoc(), freeVar,
+ TensorShardingAttr::getFullyOpenLike(sharding));
+ // Only replace uses in the regions of the while op.
+ rewriter.replaceUsesWithIf(
+ freeVar, shardingConstraint, [op](mlir::OpOperand& use) {
+ return op->isProperAncestor(use.getOwner());
+ });
+ }
+ });
+ }
+
+ StringRef getArgument() const override {
+ return "xla-sdy-open-while-free-vars-sharding";
+ }
+
+ StringRef getDescription() const override {
+    return "Adds a fully open sharding constraint to free variables of a "
+           "while op that already have a sharding.";
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> createOpenWhileFreeVarsShardingPass() {
+ return std::make_unique<OpenWhileFreeVarsShardingPass>();
+}
+
+void registerOpenWhileFreeVarsShardingPass() {
+ mlir::registerPass(createOpenWhileFreeVarsShardingPass);
+}
+
+} // namespace sdy
+} // namespace xla
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h
new file mode 100644
index 0000000..c06776f
--- /dev/null
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h
@@ -0,0 +1,40 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_
+#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h"
+
+namespace xla {
+namespace sdy {
+
+// Creates a pass that adds a fully open sharding constraint to free variables
+// of a while op that already have a user-defined sharding.
+//
+// This allows their uses inside the while op to be further sharded, which is
+// important when converting to HLO, since free variables are lifted into
+// pass-through while operands/results.
+std::unique_ptr<mlir::Pass> createOpenWhileFreeVarsShardingPass();
+
+// Registers the xla-sdy-open-while-free-vars-sharding pass.
+void registerOpenWhileFreeVarsShardingPass();
+
+} // namespace sdy
+} // namespace xla
+
+#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
index 5ddeb43..23960ab 100644
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
@@ -20,8 +20,8 @@
#include "mlir/Transforms/Passes.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h"
-#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h"
#include "xla/service/spmd/shardy/round_trip_common/import_constants.h"
+#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
#include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h"
namespace xla {
@@ -36,11 +36,6 @@
// changes happen before shardings are added to operations, to ensure the
// correct shardings are added and that they are not lost by this pass.
pm.addNestedPass<FuncOp>(mlir::mhlo::createPrepareForExportPass());
- // The prepare-for-export pass lifts `mhlo::WhileOp` free variables, and added
- // them as additional operands of the op whose corresponding block arguments
- // are directly returned by the body of the op (pass-through). To prevent
- // canonicalization from undoing this, we add identity ops.
- pm.addNestedPass<FuncOp>(createAddIdentityToPassThroughWhileArgsPass());
// We import `mhlo.constant` ops to `sdy.constant` ops so that constants
// aren't folded in greedy pattern rewriters, which would lift them outside of
@@ -51,13 +46,15 @@
pm.addNestedPass<FuncOp>(mlir::mhlo::createFlattenTuplePass());
// We need to canonicalize redundant mhlo::GetTupleElementOp and
- // mhlo::GetTupleOp.
+ // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before
+ // `createOpenWhileFreeVarsShardingPass`.
pm.addPass(mlir::createCanonicalizerPass());
}
void addCommonPostImportPasses(mlir::OpPassManager& pm) {
pm.addPass(createShardMapImportPass());
pm.addPass(createConvertShardingCustomCallsPass());
+ pm.addNestedPass<FuncOp>(createOpenWhileFreeVarsShardingPass());
}
} // namespace sdy
diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc
index c12587c..b5670e7 100644
--- a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc
+++ b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc
@@ -29,8 +29,8 @@
#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
#include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h"
-#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h"
#include "xla/service/spmd/shardy/round_trip_common/import_constants.h"
+#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
#include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h"
@@ -55,7 +55,7 @@
xla::sdy::registerMhloImportShardingsPass();
xla::sdy::registerShardMapImportPass();
xla::sdy::registerConvertShardingCustomCallsPass();
- xla::sdy::registerAddIdentityToPassThroughWhileArgsPass();
+ xla::sdy::registerOpenWhileFreeVarsShardingPass();
xla::sdy::registerImportConstantsPass();
xla::sdy::registerMhloExportPipeline();
diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc
index d4e14da..b5bd21f 100644
--- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc
+++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc
@@ -61,7 +61,6 @@
using ::mlir::success;
using ::mlir::sdy::ConstantOp;
-using ::mlir::sdy::IdentityOp;
using ::mlir::sdy::ShardingConstraintOp;
using ::mlir::sdy::TensorShardingAttr;
using ::mlir::sdy::TensorShardingPerValueAttr;
@@ -81,20 +80,6 @@
}
};
-// Removes `sdy::IdentityOp`.
-class IdentityPattern : public OpConversionPattern<IdentityOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- private:
- LogicalResult matchAndRewrite(
- IdentityOp op, OpAdaptor adaptor,
- ConversionPatternRewriter& rewriter) const override {
- rewriter.replaceOp(op, adaptor.getInput());
- return success();
- }
-};
-
class ShardingConstraintPattern
: public OpConversionPattern<ShardingConstraintOp> {
public:
@@ -130,11 +115,10 @@
void runOnOperation() final {
mlir::MLIRContext& context = getContext();
mlir::ConversionTarget target(context);
- target.addIllegalOp<ConstantOp, IdentityOp, ShardingConstraintOp>();
+ target.addIllegalOp<ConstantOp, ShardingConstraintOp>();
target.addLegalOp<mhlo::ConstantOp, mhlo::CustomCallOp>();
mlir::RewritePatternSet patterns(&context);
- patterns.add<ConstantPattern, IdentityPattern, ShardingConstraintPattern>(
- &context);
+ patterns.add<ConstantPattern, ShardingConstraintPattern>(&context);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
index 59197cc..40463d0 100644
--- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
+++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
@@ -535,36 +535,42 @@
TEST_F(ShardyXLATest, WhileWithFreeVariables) {
const char* const hloString = R"(
- HloModule main
+ HloModule main, entry_computation_layout={(f32[32,96]{1,0}, f32[32,96]{1,0})->f32[32,96]{1,0}}
- %region_0.6 (arg_tuple.7: (f32[32,96], s32[], s32[], s32[])) -> (f32[32,96], s32[], s32[], s32[]) {
- %arg_tuple.7 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0)
- %get-tuple-element.8 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=0
- %add.13 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.8, f32[32,96]{1,0} %get-tuple-element.8)
- %get-tuple-element.9 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=1
- %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=3
- %add.12 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.11)
- %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=2
- ROOT %tuple.14 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %add.13, s32[] %add.12, s32[] %get-tuple-element.10, s32[] %get-tuple-element.11)
+ %region_0.7 (arg_tuple.8: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> (f32[32,96], s32[], s32[], s32[], f32[32,96]) {
+ %arg_tuple.8 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0)
+ %get-tuple-element.9 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=0
+ %get-tuple-element.13 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=4
+ %add.15 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.9, f32[32,96]{1,0} %get-tuple-element.13), metadata={source_file="-" source_line=25}
+ %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=1
+ %get-tuple-element.12 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=3
+ %add.14 = s32[] add(s32[] %get-tuple-element.10, s32[] %get-tuple-element.12), metadata={source_file="-" source_line=24}
+ %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=2
+ ROOT %tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %add.15, s32[] %add.14, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[32,96]{1,0} %get-tuple-element.13)
}
- %region_1.15 (arg_tuple.16: (f32[32,96], s32[], s32[], s32[])) -> pred[] {
- %arg_tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0)
- %get-tuple-element.17 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=0
- %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=3
- %get-tuple-element.18 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=1
- %get-tuple-element.19 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=2
- ROOT %compare.21 = pred[] compare(s32[] %get-tuple-element.18, s32[] %get-tuple-element.19), direction=LT
+ %region_1.17 (arg_tuple.18: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> pred[] {
+ %arg_tuple.18 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0)
+ %get-tuple-element.19 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=0
+ %get-tuple-element.22 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=3
+ %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=4
+ %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=1
+ %get-tuple-element.21 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=2
+ ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.20, s32[] %get-tuple-element.21), direction=LT, metadata={source_file="-" source_line=21}
}
- ENTRY %main.27 (Arg_0.1: f32[32,96]) -> f32[32,96] {
+ ENTRY %main.30 (Arg_0.1: f32[32,96], Arg_1.2: f32[32,96]) -> f32[32,96] {
%Arg_0.1 = f32[32,96]{1,0} parameter(0), sharding={devices=[2,2]<=[4]}
- %constant.2 = s32[] constant(0)
- %constant.4 = s32[] constant(32)
- %constant.3 = s32[] constant(1)
- %tuple.5 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.2, s32[] %constant.4, s32[] %constant.3)
- %while.22 = (f32[32,96]{1,0}, s32[], s32[], s32[]) while((f32[32,96]{1,0}, s32[], s32[], s32[]) %tuple.5), condition=%region_1.15, body=%region_0.6
- ROOT %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %while.22), index=0
+ %constant.3 = s32[] constant(0)
+ %constant.5 = s32[] constant(32)
+ %constant.4 = s32[] constant(1)
+ %Arg_1.2 = f32[32,96]{1,0} parameter(1), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}
+ %tuple.6 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.3, s32[] %constant.5, s32[] %constant.4, f32[32,96]{1,0} %Arg_1.2), metadata={source_file="-" source_line=19}
+ %while.25 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) while((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %tuple.6), condition=%region_1.17, body=%region_0.7, metadata={source_file="-" source_line=19}
+ %get-tuple-element.27 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=1, metadata={source_file="-" source_line=19}
+ %get-tuple-element.26 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=0, metadata={source_file="-" source_line=19}
+ %tuple.28 = (f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %get-tuple-element.26)
+ ROOT %get-tuple-element.29 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}) %tuple.28), index=0
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hloString));
@@ -575,10 +581,14 @@
HloInstruction* whileInst =
FindInstruction(module.get(), xla::HloOpcode::kWhile);
EXPECT_NE(whileInst, nullptr);
- EXPECT_THAT(
- whileInst,
- op::Sharding(
- "{{devices=[2,2]<=[4]}, {replicated}, {replicated}, {replicated}}"));
+ // Verify that the sharding of parameter(1) hasn't changed.
+ EXPECT_THAT(module->entry_computation()->parameter_instruction(1),
+ op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}"));
+  // Verify the sharding of the while, and specifically that the result that
+  // corresponds to parameter(1) is further sharded.
+ EXPECT_THAT(whileInst,
+ op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, {replicated}, "
+ "{devices=[2,2]<=[4]}, {replicated}}"));
}
TEST_F(ShardyXLATest, ShardMap) {
diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir
index fdc7efb..b022afc 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir
@@ -52,38 +52,44 @@
// -----
+// CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=2, "axis_1"=2>
+
// CHECK-LABEL: func @while_with_free_variables
-func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
+func.func @while_with_free_variables(
+ %arg0: tensor<32x96xf32>,
+ %arg1: tensor<32x96xf32> {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"})
+ -> tensor<32x96xf32> {
// CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
// CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1>
- // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
- // CHECK-NEXT: %[[WHILE:.*]]:4 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]], %iterArg_2 = %[[C1]])
+ // CHECK-NEXT: %[[C32:.*]] = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, []>]>} dense<32>
+ // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]>
+ // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
// CHECK-NEXT: cond {
- // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1
+ // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
// CHECK-NEXT: mhlo.return %[[COND]]
// CHECK-NEXT: } do {
- // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %iterArg_2
- // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg
- // CHECK-NEXT: %[[IDENTITY_0:.*]] = sdy.identity %iterArg_1
- // CHECK-NEXT: %[[IDENTITY_1:.*]] = sdy.identity %iterArg_2
- // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY_0]], %[[IDENTITY_1]]
+ // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]]
+ // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]]
+ // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[WHILE]]#0
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<1> : tensor<i32>
- %2 = mhlo.constant dense<32> : tensor<i32>
+ %2 = mhlo.constant {mhlo.sharding = "{replicated}"} dense<32> : tensor<i32>
%3:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
cond {
%4 = mhlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
mhlo.return %4 : tensor<i1>
} do {
%4 = mhlo.add %iterArg_0, %1 : tensor<i32>
- %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32>
+ %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32>
}
return %3#0 : tensor<32x96xf32>
}
+// -----
+
// CHECK-LABEL: func @while_with_sinked_constants
func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
// CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
diff --git a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir
new file mode 100644
index 0000000..b87048e
--- /dev/null
+++ b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir
@@ -0,0 +1,93 @@
+// RUN: sdy_opt %s -xla-sdy-open-while-free-vars-sharding 2>&1 | FileCheck %s
+
+sdy.mesh @mesh1 = <"a"=2>
+sdy.mesh @mesh2 = <"b"=2>
+
+// CHECK-LABEL: func @while_with_free_variables
+func.func @while_with_free_variables(
+ %arg0: tensor<32x96xf32>,
+ %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>},
+ %arg2: tensor<32x96xf32>)
+ -> (tensor<32x96xf32>, tensor<32x96xf32>) {
+ // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0>
+ // CHECK-NEXT: %[[C1:.*]] = mhlo.constant dense<1>
+ // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32>
+ // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>}
+ // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]>
+ // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %[[ADD_0]] <@mesh2, [{?}, {?}]>
+ // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
+ // CHECK-NEXT: cond {
+ // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
+ // CHECK-NEXT: mhlo.return %[[COND]]
+ // CHECK-NEXT: } do {
+ // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg_0, %[[C1]]
+ // CHECK-NEXT: %[[ADD_2:.*]] = mhlo.add %iterArg, %[[SC_0]]
+ // CHECK-NEXT: %[[ADD_3:.*]] = mhlo.add %[[ADD_2]], %arg2
+ // CHECK-NEXT: %[[ADD_4:.*]] = mhlo.add %[[ADD_3]], %[[SC_1]]
+ // CHECK-NEXT: mhlo.return %[[ADD_4]], %[[ADD_1]]
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return %[[ADD_0]], %[[WHILE]]#0
+ %0 = mhlo.constant dense<0> : tensor<i32>
+ %1 = mhlo.constant dense<1> : tensor<i32>
+ %2 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32>
+ %3 = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32>
+ %4:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
+ cond {
+ %5 = mhlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ mhlo.return %5 : tensor<i1>
+ } do {
+ %5 = mhlo.add %iterArg_0, %1 : tensor<i32>
+ %6 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+ %7 = mhlo.add %6, %arg2 : tensor<32x96xf32>
+ %8 = mhlo.add %7, %3 : tensor<32x96xf32>
+ mhlo.return %8, %5 : tensor<32x96xf32>, tensor<i32>
+ }
+ return %3, %4#0 : tensor<32x96xf32>, tensor<32x96xf32>
+}
+
+// CHECK-LABEL: func @free_var_used_in_multiple_while_ops
+func.func @free_var_used_in_multiple_while_ops(
+ %arg0: tensor<32x96xf32>,
+ %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>})
+ -> tensor<32x96xf32> {
+ // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0>
+ // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32>
+ // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]>
+ // CHECK-NEXT: %[[WHILE_0:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
+ // CHECK-NEXT: cond {
+ // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
+ // CHECK-NEXT: mhlo.return %[[COND]]
+ // CHECK-NEXT: } do {
+ // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg, %[[SC_0]]
+ // CHECK-NEXT: mhlo.return %[[ADD_0]], %iterArg_0
+ // CHECK-NEXT: }
+ // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]>
+ // CHECK-NEXT: %[[WHILE_1:.*]]:2 = mhlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_0 = %[[C0]])
+ // CHECK-NEXT: cond {
+ // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
+ // CHECK-NEXT: mhlo.return %[[COND]]
+ // CHECK-NEXT: } do {
+ // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC_1]]
+ // CHECK-NEXT: mhlo.return %[[ADD_1]], %iterArg_0
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return %[[WHILE_1]]#0
+ %0 = mhlo.constant dense<0> : tensor<i32>
+ %1 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32>
+ %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
+ cond {
+ %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ mhlo.return %4 : tensor<i1>
+ } do {
+ %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+ mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor<i32>
+ }
+ %3:2 = mhlo.while(%iterArg = %2#0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
+ cond {
+ %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ mhlo.return %4 : tensor<i1>
+ } do {
+ %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+ mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor<i32>
+ }
+ return %3#0 : tensor<32x96xf32>
+}
diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
index dcd81b2..66b227d 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
@@ -13,6 +13,7 @@
// CHECK: sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
+// CHECK-LABEL: func @main
func.func @main(
// CHECK: %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>})
%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}
@@ -35,6 +36,7 @@
// CHECK: sdy.mesh @mesh = <"a"=2, "b"=2>
sdy.mesh @mesh = <"a"=2, "b"=2>
+// CHECK-LABEL: func @main
func.func @main(
// CHECK: %arg0: tensor<8x16xf32>)
%arg0: tensor<8x16xf32>
@@ -55,6 +57,7 @@
// CHECK: sdy.mesh @mesh = <"a"=2, "b"=2>
sdy.mesh @mesh = <"a"=2, "b"=2>
+// CHECK-LABEL: func @main
func.func @main(
// CHECK: %arg0: tensor<8x16xf32>)
%arg0: tensor<8x16xf32>
@@ -78,6 +81,7 @@
// CHECK: sdy.mesh @mesh = <"a"=2, "b"=2>
sdy.mesh @mesh = <"a"=2, "b"=2>
+// CHECK-LABEL: func @main
func.func @main(
// CHECK: %arg0: tensor<8x16xf32>)
%arg0: tensor<8x16xf32>
@@ -97,7 +101,7 @@
// CHECK: sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
-// CHECK: @main(
+// CHECK-LABEL: @main(
// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}p4]>},
// CHECK-SAME: %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32>
// CHECK-SAME: ) -> tensor<8x8xf32> {
@@ -122,6 +126,7 @@
// CHECK: sdy.mesh @mesh = <"data"=2>
sdy.mesh @mesh = <"data"=2>
+// CHECK-LABEL: func @main
func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
// CHECK: sdy.sharding_constraint %arg0 <@mesh, [{"data", ?}, {?}]> : tensor<8x8xf32>
%0 = sdy.sharding_constraint %arg0 <@mesh, [{"data", ?}, {?}]> : tensor<8x8xf32>
@@ -144,6 +149,7 @@
// CHECK: sdy.mesh @mesh_2 = <"x"=8, "y"=4>
sdy.mesh @mesh_2 = <"x"=8, "y"=4>
+// CHECK-LABEL: func @main
func.func @main(
// CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) {
%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) {
@@ -157,36 +163,41 @@
// -----
+// CHECK: sdy.mesh @mesh = <"x"=2>
+sdy.mesh @mesh = <"x"=2>
+
// Test WhileOp with lifted free variables and sinked constants.
-// CHECK: func @main
-func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
+// CHECK-LABEL: func @main
+func.func @main(
+ %arg0: tensor<32x96xf32>,
+ %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}})
+ -> tensor<32x96xf32> {
// CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
// CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
- // CHECK-NEXT: %[[WHILE:.*]]:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]])
+ // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]>
+ // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
// CHECK-NEXT: cond {
- // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1
+ // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
// CHECK-NEXT: mhlo.return %[[COND]]
// CHECK-NEXT: } do {
// CHECK-DAG: %[[C1:.*]] = sdy.constant dense<1>
// CHECK-DAG: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]]
- // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg
- // CHECK-DAG: %[[IDENTITY:.*]] = sdy.identity %iterArg_1
- // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY]]
+ // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]]
+ // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[WHILE]]#0
%0 = sdy.constant dense<0> : tensor<i32>
%1 = sdy.constant dense<32> : tensor<i32>
- %2:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0, %iterArg_1 = %1) : tensor<32x96xf32>, tensor<i32>, tensor<i32>
+ %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
cond {
- %3 = mhlo.compare LT, %iterArg_0, %iterArg_1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
mhlo.return %3 : tensor<i1>
} do {
%3 = sdy.constant dense<1> : tensor<i32>
%4 = mhlo.add %iterArg_0, %3 : tensor<i32>
- %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32>
- %6 = sdy.identity %iterArg_1 : tensor<i32>
- mhlo.return %5, %4, %6 : tensor<32x96xf32>, tensor<i32>, tensor<i32>
+ %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+ mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32>
}
return %2#0 : tensor<32x96xf32>
}
diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir
index b180caa..89b8722 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir
@@ -89,11 +89,10 @@
return %0 : tensor<8x8xf32>
}
-// CHECK-LABEL: func @identity_and_constant
-func.func @identity_and_constant() -> tensor<i32> {
+// CHECK-LABEL: func @constant
+func.func @constant() -> tensor<i32> {
// CHECK-NEXT: %[[CONST:.*]] = mhlo.constant dense<0>
// CHECK-NEXT: return %[[CONST]]
%0 = sdy.constant dense<0> : tensor<i32>
- %1 = sdy.identity %0 : tensor<i32>
- return %1 : tensor<i32>
+ return %0 : tensor<i32>
}
diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
index 2354d83..e782de0 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
@@ -30,20 +30,22 @@
}
// CHECK-LABEL: func @while_with_free_variables
- func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
+ func.func @while_with_free_variables(
+ %arg0: tensor<32x96xf32>,
+ %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}})
+ -> tensor<32x96xf32> {
// CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
// CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1>
// CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
- // CHECK-NEXT: %[[WHILE:.*]]:4 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]], %iterArg_2 = %[[C1]])
+ // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]>
+ // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
// CHECK-NEXT: cond {
- // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1
+ // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
// CHECK-NEXT: mhlo.return %[[COND]]
// CHECK-NEXT: } do {
- // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %iterArg_2
- // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg
- // CHECK-NEXT: %[[IDENTITY_0:.*]] = sdy.identity %iterArg_1
- // CHECK-NEXT: %[[IDENTITY_1:.*]] = sdy.identity %iterArg_2
- // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY_0]], %[[IDENTITY_1]]
+ // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]]
+ // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]]
+ // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[WHILE]]#0
%0 = mhlo.constant dense<0> : tensor<i32>
@@ -55,7 +57,7 @@
mhlo.return %4 : tensor<i1>
} do {
%4 = mhlo.add %iterArg_0, %1 : tensor<i32>
- %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32>
+ %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32>
}
return %3#0 : tensor<32x96xf32>
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
index c3fc8b1..c02493e 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
@@ -571,8 +571,11 @@
"not able to go from sharding "
<< sharding().ToString(/*include_metadata=*/true) << " to "
<< target.ToString(/*include_metadata=*/true)
- << " without doing a full rematerialization of the tensor. You "
- "probably want to enrich the sharding annotations to prevent "
+ << " without doing a full rematerialization of the tensor for HLO "
+ "operation: "
+ << hlo_->ToString()
+ << ". You probably want to enrich the sharding annotations to "
+ "prevent "
"this from happening.";
}
return Replicate().Reshard(target);
@@ -3316,6 +3319,7 @@
auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, operand_shape, "true_branch_param"));
std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(operands.size());
for (int64_t i = 0; i < operands.size(); ++i) {
new_operands.push_back(true_b.AddInstruction(
HloInstruction::CreateGetTupleElement(*operand_shapes[i], param, i)));
@@ -4040,20 +4044,21 @@
const HloSharding& sharding = hlo->sharding();
// Shardings for the body parameter, body root, and cond parameter must be
- // the same, and the condition root must be replicated so that all partitions
- // follow the same control flow.
+ // the same.
hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
- const HloSharding& cond_root_sharding =
- hlo->while_condition()->root_instruction()->sharding();
- TF_RETURN_IF_ERROR(partitioner_
- ->PartitionComputation(hlo->while_condition(),
- cond_root_sharding.IsManual()
- ? cond_root_sharding
- : HloSharding::Replicate(),
- next_channel_id_, logger_,
- call_graph_)
- .status());
+
+ // The condition root must be replicated so that all partitions follow the
+ // same control flow.
+ HloInstruction* cond_root = hlo->while_condition()->root_instruction();
+ const HloSharding cond_root_sharding =
+ hlo_sharding_util::ReplicateAllDataDims(cond_root->sharding());
+ cond_root->set_sharding(cond_root_sharding);
+ TF_RETURN_IF_ERROR(
+ partitioner_
+ ->PartitionComputation(hlo->while_condition(), cond_root_sharding,
+ next_channel_id_, logger_, call_graph_)
+ .status());
TF_RETURN_IF_ERROR(partitioner_
->PartitionComputation(hlo->while_body(), sharding,
next_channel_id_, logger_,
@@ -4129,6 +4134,7 @@
if (hlo->sharding().IsManual()) {
auto clone_from_original = [&](const HloSharding& shared_sharding) {
std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(hlo->operand_count());
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
new_operands.push_back(
GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
@@ -4310,6 +4316,7 @@
}
auto clone_from_original = [&](const HloSharding& shared_sharding) {
std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(hlo->operand_count());
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
new_operands.push_back(
GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
@@ -4340,6 +4347,7 @@
TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
// Replicate the operands and run partitioned Rng on all devices.
std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(hlo->operand_count());
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
.Reshard(HloSharding::Replicate())
@@ -4659,6 +4667,7 @@
absl::Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(hlo->operand_count());
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
new_operands.push_back(
GetPartitionedHlo(hlo->operand(i))
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
index e24e32f..0d1d4a7 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
@@ -4474,6 +4474,36 @@
EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]")));
}
+TEST_P(SpmdPartitioningTest, WhilePartialManual) {
+ absl::string_view hlo_string = R"(
+HloModule module
+
+LoopCond {
+ x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+ const = s32[] constant(5), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+ ROOT lt = pred[] compare(x, const), direction=LT, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+}
+
+Inc {
+ x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+ const = s32[] constant(1), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+ ROOT add = s32[] add(x, const), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+}
+
+ENTRY entry {
+ zero = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+ ROOT while = s32[] while(zero), body=Inc, condition=LoopCond, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ PartitionComputation(hlo_string, /*num_devices=*/4));
+ VLOG(1) << module->ToString();
+
+ auto zero = AllOf(op::Parameter(0), op::Shape("s32[]"));
+ const auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]")));
+}
+
TEST_P(SpmdPartitioningTest, TestWhileFrontendAttributes) {
absl::string_view hlo_string = R"(
HloModule module
@@ -9163,6 +9193,29 @@
EXPECT_THAT(root, op::AllReduce(dot));
}
+TEST_P(SpmdPartitioningTest, ReplicateLHSofConv) {
+ const char* const hlo_string = R"(
+HloModule module
+ENTRY main {
+ lhs = bf16[128,8,8,1280] parameter(0), sharding={devices=[128,1,1,1]<=[128]}
+ rhs = bf16[3,3,1280,1280] parameter(1), sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate}
+ ROOT conv = bf16[128,8,8,1280] convolution(lhs, rhs), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module, PartitionComputation(hlo_string, /*num_devices=*/128));
+ VLOG(1) << module->ToString();
+
+ const auto lhs = AllOf(op::Shape("bf16[128,8,8,1280]"),
+ op::AllReduce(op::DynamicUpdateSlice(
+ op::Broadcast(), op::Parameter(0), _, _, _, _)));
+ const auto rhs = AllOf(op::Shape("bf16[3,3,1280,160]"), op::Parameter(1));
+ const auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ AllOf(op::Shape("bf16[128,8,8,160]"), op::Convolution(lhs, rhs)));
+}
+
TEST_P(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) {
absl::string_view hlo_string = R"(
HloModule module
@@ -10911,6 +10964,53 @@
_));
}
+TEST_P(SpmdPartitioningTest, ScatterRepsOnLastTileDimDontDivideGroups) {
+ absl::string_view hlo_string = R"(
+HloModule module
+
+region.1 {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT res.1 = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+ %add.1 = f32[8,96,2048,16]{3,2,1,0} parameter(0)
+ %concatenate.1 = s32[8,96,2048,2,4]{4,3,2,1,0} parameter(1)
+ %broadcast.1 = f32[8,96,2048,2]{3,2,1,0} parameter(2)
+
+ %add.1.shard = f32[8,96,2048,16]{3,2,1,0} copy(%add.1), sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+ %concatenate.1.shard = s32[8,96,2048,2,4]{4,3,2,1,0} copy(%concatenate.1), sharding={devices=[8,8,1,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+ %broadcast.1.shard = f32[8,96,2048,2]{3,2,1,0} copy(%broadcast.1), sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+
+ ROOT %scatter.44 = f32[8,96,2048,16]{3,2,1,0} scatter(
+ %add.1.shard,
+ %concatenate.1.shard,
+ %broadcast.1.shard),
+ update_window_dims={},
+ inserted_window_dims={0,1,2,3},
+ scatter_dims_to_operand_dims={0,1,2,3},
+ index_vector_dim=4,
+ to_apply=region.1,
+ sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module, PartitionComputation(hlo_string, /*num_devices=*/1536));
+ VLOG(1) << module->ToString();
+ // Verify scatter is partitioned properly.
+ {
+ const auto partitioned_scatter =
+ module->entry_computation()->root_instruction();
+ auto operand = AllOf(op::Shape("f32[1,12,2048,16]"));
+ auto indices = AllOf(op::Shape("s32[8,96,2048,2,4]"));
+ auto update = AllOf(op::Shape("f32[8,96,2048,2]"));
+ auto scatter = AllOf(op::Shape("f32[1,12,2048,16]"),
+ op::Scatter(operand, indices, update));
+ EXPECT_THAT(partitioned_scatter, scatter);
+ }
+}
+
TEST_P(SpmdPartitioningTest, ParallelDimFromOutsideConditionalPositive) {
absl::string_view hlo_string = R"(
HloModule module
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h
index 65b5d01..a982c3e 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h
@@ -84,6 +84,7 @@
PrimitiveType)) {
if (shape.IsTuple()) {
std::vector<HloInstruction*> elements;
+ elements.reserve(ShapeUtil::TupleElementCount(shape));
for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
elements.push_back(
CreateConstantBase(ShapeUtil::GetTupleElementShape(shape, i),
diff --git a/third_party/xla/xla/service/stable_sort_expander.cc b/third_party/xla/xla/service/stable_sort_expander.cc
index 910ab5d..ca87dce 100644
--- a/third_party/xla/xla/service/stable_sort_expander.cc
+++ b/third_party/xla/xla/service/stable_sort_expander.cc
@@ -55,7 +55,6 @@
HloComputation* computation = sort->parent();
HloInstruction* expanded_sort = nullptr;
- absl::flat_hash_set<int64_t> used_indices;
int64_t iota_index = IotaOperandIndexForStableSort(*sort);
// If there is currently no iota operand which we could use for making the
diff --git a/third_party/xla/xla/service/stream_pool_test.cc b/third_party/xla/xla/service/stream_pool_test.cc
index fd0a05e..2bea411 100644
--- a/third_party/xla/xla/service/stream_pool_test.cc
+++ b/third_party/xla/xla/service/stream_pool_test.cc
@@ -26,22 +26,21 @@
class StreamPoolTest : public ::testing::Test {
protected:
- std::unique_ptr<se::StreamExecutor> NewStreamExecutor() {
+ se::StreamExecutor* NewStreamExecutor() {
se::Platform* platform =
se::PlatformManager::PlatformWithName("Host").value();
- se::StreamExecutorConfig config(/*ordinal=*/0);
- return platform->GetUncachedExecutor(config).value();
+ return platform->ExecutorForDevice(/*ordinal=*/0).value();
}
};
TEST_F(StreamPoolTest, EmptyPool) {
- std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
- StreamPool pool(executor.get());
+ se::StreamExecutor* executor = NewStreamExecutor();
+ StreamPool pool(executor);
}
TEST_F(StreamPoolTest, OneStreamPool) {
- std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
- StreamPool pool(executor.get());
+ se::StreamExecutor* executor = NewStreamExecutor();
+ StreamPool pool(executor);
// Borrow and return a stream.
StreamPool::Ptr stream1 = pool.BorrowStream();
@@ -61,8 +60,8 @@
}
TEST_F(StreamPoolTest, TwoStreamPool) {
- std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
- StreamPool pool(executor.get());
+ se::StreamExecutor* executor = NewStreamExecutor();
+ StreamPool pool(executor);
// Borrow two streams.
StreamPool::Ptr stream1 = pool.BorrowStream();
diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc
index d1fd7ac..07b49db 100644
--- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc
+++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc
@@ -136,10 +136,6 @@
}
bool changed = false;
-
- absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>
- conditional_gte_index_to_insts =
- WhileUtil::GetGTEsMapForWhileConditional(*while_cond);
std::vector<HloInstruction*> invariant_body_gtes =
WhileUtil::GetInvariantGTEsForWhileBody(*while_body);
std::vector<int64_t> tuple_indices;
diff --git a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
index 1901cf2e..0cae8b4 100644
--- a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
+++ b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
@@ -15,7 +15,7 @@
version: 3
results {
device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB"
- hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false},\"force_earliest_schedule\":false}"
+ hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"1\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"9\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"9\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
gemm {
algorithm: 13
@@ -24,7 +24,7 @@
}
results {
device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB"
- hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false}"
+ hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"leakyrelu_alpha\":0,\"side_input_scale\":0},\"force_earliest_schedule\":false,\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
run_time {
nanos: 8192
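Note on the two rewritten hlo strings above: the backend_config JSON is now serialized with its keys in alphabetical order, which matches the key-sorting canonicalization performed by the new xla::SortJson helper introduced later in this change. A minimal sketch of that canonicalization (an illustration under that assumption, not the actual autotuner code):

#include <iostream>
#include <string>

#include "xla/sort_json.h"

int main() {
  // Two semantically identical backend_config fragments with different key order.
  const std::string a = R"({"operation_queue_id":"0","force_earliest_schedule":false})";
  const std::string b = R"({"force_earliest_schedule":false,"operation_queue_id":"0"})";
  // SortJson sorts object keys and strips whitespace, so both canonicalize to
  // the same string and can be compared directly.
  std::cout << (xla::SortJson(a).value() == xla::SortJson(b).value()) << "\n";  // Prints 1.
  return 0;
}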
diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc
index 0deca08..56dd78d 100644
--- a/third_party/xla/xla/shape_util.cc
+++ b/third_party/xla/xla/shape_util.cc
@@ -16,7 +16,6 @@
#include "xla/shape_util.h"
#include <algorithm>
-#include <array>
#include <climits>
#include <cstddef>
#include <cstdint>
@@ -1982,8 +1981,9 @@
// Returns the indices of the first elements of all consecutive subarrays of the
// given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64_t> xs) {
- std::vector<size_t> is = {0};
+static absl::InlinedVector<size_t, 3> ConsecutiveSegments(
+ absl::Span<const int64_t> xs) {
+ absl::InlinedVector<size_t, 3> is = {0};
for (size_t i = 1; i < xs.size(); ++i) {
if (1 != xs[i] - xs[i - 1]) {
is.push_back(i);
@@ -2010,19 +2010,22 @@
dimensions);
}
-static std::vector<int64_t> MajorToMinorLayout(const Shape& s) {
+static absl::InlinedVector<int64_t, 3> MajorToMinorLayout(const Shape& s) {
absl::Span<const int64_t> minor_to_major = LayoutUtil::MinorToMajor(s);
- return std::vector<int64_t>{minor_to_major.rbegin(), minor_to_major.rend()};
+ return absl::InlinedVector<int64_t, 3>{minor_to_major.rbegin(),
+ minor_to_major.rend()};
}
-static std::optional<Vector3> GetNormalizedTransposeShapeHelper(
+static std::optional<absl::InlinedVector<int64_t, 3>>
+GetNormalizedTransposeShapeHelper(
const Shape& input_shape, absl::Span<int64_t const> output_to_input,
- const Vector3& permutation) {
+ const absl::InlinedVector<int64_t, 3>& permutation) {
// 'permutation' should not be the identity permutation.
if (permutation[0] == 0 && permutation[1] == 1 && permutation[2] == 2) {
return std::nullopt;
}
- std::vector<size_t> segments = ConsecutiveSegments(output_to_input);
+ absl::InlinedVector<size_t, 3> segments =
+ ConsecutiveSegments(output_to_input);
if (segments.size() > 3) {
return std::nullopt;
}
@@ -2031,8 +2034,9 @@
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
input_shape);
Shape normalized_shape = MergeDimensions(segments, normalized_input_shape);
- std::vector<int64_t> normalized_dims{normalized_shape.dimensions().begin(),
- normalized_shape.dimensions().end()};
+ absl::InlinedVector<int64_t, 3> normalized_dims{
+ normalized_shape.dimensions().begin(),
+ normalized_shape.dimensions().end()};
if (segments.size() == 2) {
// If we have two segments, we know that at least one transpose is
// happening, otherwise we would have only 1 segment.
@@ -2050,9 +2054,9 @@
normalized_dims.insert(normalized_dims.begin() + untransposed, 1);
} else if (segments.size() == 3) {
// Derive the order from the segments.
- Vector3 segment_order{output_to_input[segments[0]],
- output_to_input[segments[1]],
- output_to_input[segments[2]]};
+ absl::InlinedVector<int64_t, 3> segment_order{output_to_input[segments[0]],
+ output_to_input[segments[1]],
+ output_to_input[segments[2]]};
// We expect the same relative order.
for (int64_t i = 1; i < 3; ++i) {
if ((segment_order[i] > segment_order[i - 1]) !=
@@ -2062,31 +2066,32 @@
}
}
if (normalized_dims.size() == 3) {
- return Vector3{normalized_dims[permutation[0]],
- normalized_dims[permutation[1]],
- normalized_dims[permutation[2]]};
+ return absl::InlinedVector<int64_t, 3>{normalized_dims[permutation[0]],
+ normalized_dims[permutation[1]],
+ normalized_dims[permutation[2]]};
}
return std::nullopt;
}
-/* static */ std::optional<Vector3>
+/* static */ std::optional<absl::InlinedVector<int64_t, 3>>
ShapeUtil::GetNormalizedLogicalTransposeShape(
const Shape& input_shape, const Shape& output_shape,
- absl::Span<int64_t const> dimensions, const Vector3& permutation) {
+ absl::Span<int64_t const> dimensions,
+ const absl::InlinedVector<int64_t, 3>& permutation) {
if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) ||
!LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) {
// Only works on default layouts.
return std::nullopt;
}
// Drop degenerate dimensions.
- std::vector<int64_t> delta(input_shape.rank() + 1, 0);
+ absl::InlinedVector<int64_t, 3> delta(input_shape.rank() + 1, 0);
for (int i = 0; i < input_shape.rank(); ++i) {
delta[i + 1] = delta[i];
if (input_shape.dimensions(i) == static_cast<int64_t>(1)) {
++delta[i + 1];
}
}
- std::vector<int64_t> new_dimensions;
+ absl::InlinedVector<int64_t, 3> new_dimensions;
for (int i = 0; i < dimensions.size(); i++) {
if (output_shape.dimensions(i) != 1) {
new_dimensions.push_back(dimensions[i] - delta[dimensions[i]]);
@@ -2098,15 +2103,18 @@
permutation);
}
-/* static */ std::optional<Vector3> ShapeUtil::GetNormalizedTransposeShape(
+/* static */ std::optional<absl::InlinedVector<int64_t, 3>>
+ShapeUtil::GetNormalizedTransposeShape(
const Shape& input_shape, const Shape& output_shape,
- const Vector3& permutation) {
+ const absl::InlinedVector<int64_t, 3>& permutation) {
if (!ShapeUtil::CompatibleIgnoringElementType(input_shape, output_shape)) {
return std::nullopt;
}
- std::vector<int64_t> major_to_minor_input = MajorToMinorLayout(input_shape);
- std::vector<int64_t> major_to_minor_output = MajorToMinorLayout(output_shape);
+ absl::InlinedVector<int64_t, 3> major_to_minor_input =
+ MajorToMinorLayout(input_shape);
+ absl::InlinedVector<int64_t, 3> major_to_minor_output =
+ MajorToMinorLayout(output_shape);
std::vector<int64_t> output_to_input = ComposePermutations(
InversePermutation(major_to_minor_output), major_to_minor_input);
diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h
index a773ade..03a809d 100644
--- a/third_party/xla/xla/shape_util.h
+++ b/third_party/xla/xla/shape_util.h
@@ -44,7 +44,6 @@
#include "xla/primitive_util.h"
#include "xla/printer.h"
#include "xla/shape.h"
-#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h" // IWYU pragma: keep
@@ -1030,14 +1029,17 @@
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
// normalized shape of `b` or the 0-2-1 shape. In general, the
// permutation[0]-permutation[1]-permutation[2] shape is returned.
- static std::optional<Vector3> GetNormalizedTransposeShape(
+ static std::optional<absl::InlinedVector<int64_t, 3>>
+ GetNormalizedTransposeShape(
const Shape& input_shape, const Shape& output_shape,
- const Vector3& permutation);
+ const absl::InlinedVector<int64_t, 3>& permutation);
// Entry point for physical + logical transposition.
- static std::optional<Vector3> GetNormalizedLogicalTransposeShape(
+ static std::optional<absl::InlinedVector<int64_t, 3>>
+ GetNormalizedLogicalTransposeShape(
const Shape& input_shape, const Shape& output_shape,
- absl::Span<int64_t const> dimensions, const Vector3& permutation);
+ absl::Span<int64_t const> dimensions,
+ const absl::InlinedVector<int64_t, 3>& permutation);
// Strips device-specific information, namely tiling and memory-space
// information, from a shape.
diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc
index e7c1beb..a66b148 100644
--- a/third_party/xla/xla/shape_util_test.cc
+++ b/third_party/xla/xla/shape_util_test.cc
@@ -23,6 +23,7 @@
#include <variant>
#include <vector>
+#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -1395,8 +1396,9 @@
Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0});
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 128}, {0, 1});
- EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
- shape, transposed, Vector3{0, 2, 1}));
+ EXPECT_EQ(std::nullopt,
+ ShapeUtil::GetNormalizedTransposeShape(
+ shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, NoTranspose2) {
@@ -1404,8 +1406,9 @@
ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64, 32}, {2, 1, 0});
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 128}, {0, 1, 2});
- EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
- shape, transposed, Vector3{0, 1, 2}));
+ EXPECT_EQ(std::nullopt,
+ ShapeUtil::GetNormalizedTransposeShape(
+ shape, transposed, absl::InlinedVector<int64_t, 3>{0, 1, 2}));
}
TEST(Transpose021Test, WrongTranspose) {
@@ -1414,7 +1417,8 @@
Shape output_shape =
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2});
EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
- input_shape, output_shape, Vector3{0, 2, 1}));
+ input_shape, output_shape,
+ absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, WrongTranspose2) {
@@ -1422,7 +1426,8 @@
Shape output_shape =
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1});
EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
- input_shape, output_shape, Vector3{0, 1, 2}));
+ input_shape, output_shape,
+ absl::InlinedVector<int64_t, 3>{0, 1, 2}));
}
TEST(Transpose021Test, WrongTranspose3) {
@@ -1430,16 +1435,17 @@
Shape output_shape =
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1});
EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
- input_shape, output_shape, Vector3{1, 2, 0}));
+ input_shape, output_shape,
+ absl::InlinedVector<int64_t, 3>{1, 2, 0}));
}
TEST(Transpose021Test, Simple) {
Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0});
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {0, 1});
- EXPECT_EQ(std::make_optional(Vector3{1, 64, 128}),
- ShapeUtil::GetNormalizedTransposeShape(shape, transposed,
- Vector3{0, 2, 1}));
+ EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 64, 128}),
+ ShapeUtil::GetNormalizedTransposeShape(
+ shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, Simple2) {
@@ -1447,9 +1453,10 @@
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0});
Shape output_shape =
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {1, 2, 0});
- EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}),
- ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
- Vector3{0, 2, 1}));
+ EXPECT_EQ(
+ std::make_optional(absl::InlinedVector<int64_t, 3>{8, 16, 32768}),
+ ShapeUtil::GetNormalizedTransposeShape(
+ input_shape, output_shape, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, Simple3) {
@@ -1457,18 +1464,20 @@
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0});
Shape output_shape =
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2});
- EXPECT_EQ(std::make_optional(Vector3{16, 32768, 8}),
- ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
- Vector3{2, 1, 0}));
+ EXPECT_EQ(
+ std::make_optional(absl::InlinedVector<int64_t, 3>{16, 32768, 8}),
+ ShapeUtil::GetNormalizedTransposeShape(
+ input_shape, output_shape, absl::InlinedVector<int64_t, 3>{2, 1, 0}));
}
TEST(Transpose021Test, Simple4) {
Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0});
Shape output_shape =
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1});
- EXPECT_EQ(std::make_optional(Vector3{16, 1, 8}),
- ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
- Vector3{2, 1, 0}));
+ EXPECT_EQ(
+ std::make_optional(absl::InlinedVector<int64_t, 3>{16, 1, 8}),
+ ShapeUtil::GetNormalizedTransposeShape(
+ input_shape, output_shape, absl::InlinedVector<int64_t, 3>{2, 1, 0}));
}
TEST(Transpose021Test, LargeView) {
@@ -1476,9 +1485,10 @@
F32, {8, 32, 32, 32, 16}, {4, 3, 2, 1, 0});
Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(
F32, {8, 32, 32, 32, 16}, {3, 2, 1, 4, 0});
- EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}),
- ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
- Vector3{0, 2, 1}));
+ EXPECT_EQ(
+ std::make_optional(absl::InlinedVector<int64_t, 3>{8, 16, 32768}),
+ ShapeUtil::GetNormalizedTransposeShape(
+ input_shape, output_shape, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, LargeSizeOverflowTest) {
@@ -1487,7 +1497,8 @@
Shape output_shape =
ShapeUtil::MakeShapeWithDenseLayout(BF16, {4096, 4096, 128}, {2, 1, 0});
EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
- input_shape, output_shape, Vector3{0, 2, 1}));
+ input_shape, output_shape,
+ absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, Batched) {
@@ -1495,9 +1506,9 @@
ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {2, 1, 0});
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {1, 0, 2});
- EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}),
- ShapeUtil::GetNormalizedTransposeShape(shape, transposed,
- Vector3{0, 2, 1}));
+ EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 64, 96}),
+ ShapeUtil::GetNormalizedTransposeShape(
+ shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, BatchedLogical) {
@@ -1506,9 +1517,10 @@
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 32, 3}, {2, 1, 0});
std::vector<int64_t> dimensions = {2, 0, 1};
- EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}),
+ EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 64, 96}),
ShapeUtil::GetNormalizedLogicalTransposeShape(
- shape, transposed, dimensions, Vector3{0, 2, 1}));
+ shape, transposed, dimensions,
+ absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, LogicalWithDegenerateDims) {
@@ -1517,9 +1529,10 @@
Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(
F32, {1, 32, 1, 64, 1, 3, 1}, {6, 5, 4, 3, 2, 1, 0});
std::vector<int64_t> dimensions = {6, 1, 4, 5, 2, 3, 0};
- EXPECT_EQ(std::make_optional(Vector3{32, 64, 3}),
+ EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{32, 64, 3}),
ShapeUtil::GetNormalizedLogicalTransposeShape(
- shape, transposed, dimensions, Vector3{0, 2, 1}));
+ shape, transposed, dimensions,
+ absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, LogicalWithDegenerateLastDim) {
@@ -1528,9 +1541,10 @@
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 1}, {2, 1, 0});
std::vector<int64_t> dimensions = {2, 1, 0};
- EXPECT_EQ(std::make_optional(Vector3{1, 32, 64}),
+ EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 32, 64}),
ShapeUtil::GetNormalizedLogicalTransposeShape(
- shape, transposed, dimensions, Vector3{0, 2, 1}));
+ shape, transposed, dimensions,
+ absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose021Test, Large) {
@@ -1538,9 +1552,9 @@
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {3, 2, 1, 0});
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {2, 1, 3, 0});
- EXPECT_EQ(std::make_optional(Vector3{8, 65, 961}),
- ShapeUtil::GetNormalizedTransposeShape(shape, transposed,
- Vector3{0, 2, 1}));
+ EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{8, 65, 961}),
+ ShapeUtil::GetNormalizedTransposeShape(
+ shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
}
TEST(Transpose210Test, LogicalTranspose) {
@@ -1549,9 +1563,10 @@
Shape transposed =
ShapeUtil::MakeShapeWithDenseLayout(F32, {13, 12, 10, 11}, {3, 2, 1, 0});
std::vector<int64_t> dimensions = {3, 2, 0, 1};
- EXPECT_EQ(std::make_optional(Vector3{13, 12, 110}),
+ EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{13, 12, 110}),
ShapeUtil::GetNormalizedLogicalTransposeShape(
- shape, transposed, dimensions, Vector3{2, 1, 0}));
+ shape, transposed, dimensions,
+ absl::InlinedVector<int64_t, 3>{2, 1, 0}));
}
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
diff --git a/third_party/xla/xla/sort_json.cc b/third_party/xla/xla/sort_json.cc
new file mode 100644
index 0000000..aaa1e19
--- /dev/null
+++ b/third_party/xla/xla/sort_json.cc
@@ -0,0 +1,257 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/sort_json.h"
+
+#include <algorithm>
+#include <cctype>
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace {
+
+void SkipWhitespace(absl::string_view json, size_t& index) {
+ while (index < json.size() && std::isspace(json[index])) {
+ ++index;
+ }
+}
+
+absl::Status CheckNotEndOfString(absl::string_view json, int index,
+ absl::string_view expected) {
+ return index < json.size()
+ ? absl::OkStatus()
+ : absl::InvalidArgumentError(absl::StrCat(
+ "Prematurely reached end of JSON while looking for ",
+ expected, "."));
+}
+
+absl::Status Consume(absl::string_view json, size_t& index, char c,
+ bool optional = false) {
+ SkipWhitespace(json, index);
+ TF_RETURN_IF_ERROR(CheckNotEndOfString(json, index, std::string(1, c)));
+ if (json[index] == c) {
+ ++index;
+ SkipWhitespace(json, index);
+ } else if (!optional) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Expected '", std::string(1, c), "', but found '",
+ std::string(1, json[index]), "'."));
+ }
+ return absl::OkStatus();
+}
+
+struct JsonArray;
+struct JsonObject;
+
+using JsonValue = std::variant<absl::string_view, std::unique_ptr<JsonObject>,
+ std::unique_ptr<JsonArray>>;
+
+struct JsonField {
+ absl::string_view name;
+ JsonValue value;
+};
+
+template <typename T>
+struct JsonSequence {
+ std::vector<T> elements;
+};
+
+struct JsonArray : public JsonSequence<JsonValue> {};
+struct JsonObject : public JsonSequence<JsonField> {};
+
+// This parses either an array or an object.
+template <typename T, char begin, char end, const char* name, typename ElemFn>
+absl::StatusOr<std::unique_ptr<T>> ParseSequence(absl::string_view outer_json,
+ size_t& index,
+ ElemFn elem_fn) {
+ TF_RETURN_IF_ERROR(Consume(outer_json, index, begin));
+ TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name));
+
+ auto seq = std::make_unique<T>();
+ while (outer_json[index] != end) {
+ TF_ASSIGN_OR_RETURN(auto elem, elem_fn(outer_json, index));
+ seq->elements.emplace_back(std::move(elem));
+ TF_RETURN_IF_ERROR(Consume(outer_json, index, ',', /*optional=*/true));
+ TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name));
+ }
+ TF_RETURN_IF_ERROR(Consume(outer_json, index, end));
+ return seq;
+}
+
+absl::Status EnsureValidLiteralStart(char c) {
+ if (c != '"' && c != '+' && c != '-' && c != 'f' && c != 't' && c != 'n' &&
+ (c < '0' || c > '9')) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Invalid first character of literal: '", std::string(1, c), "'."));
+ }
+ return absl::OkStatus();
+}
+
+bool HandleEscape(absl::string_view outer_json, size_t& index,
+ bool& is_escaped) {
+ if (is_escaped) {
+ is_escaped = false;
+ ++index;
+ return true;
+ }
+
+ if (outer_json[index] == '\\') {
+ is_escaped = true;
+ ++index;
+ return true;
+ }
+ return false;
+}
+
+bool LiteralIsFinished(absl::string_view outer_json, size_t& index,
+ bool is_string_literal) {
+ char c = outer_json[index];
+ if (is_string_literal) {
+ index += (c == '"' ? 1 : 0);
+ return c == '"';
+ }
+
+ return std::isspace(c) || c == ',' || c == '{' || c == '}' || c == '[' ||
+ c == ']' || c == ':';
+}
+
+absl::StatusOr<absl::string_view> ParseLiteral(absl::string_view outer_json,
+ size_t& index) {
+ SkipWhitespace(outer_json, index);
+ TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "literal"));
+
+ auto c = outer_json[index];
+ TF_RETURN_IF_ERROR(EnsureValidLiteralStart(c));
+ bool is_string_literal = c == '"';
+ size_t start_index = index;
+ bool is_escaped = false;
+ ++index;
+
+ while (index < outer_json.size()) {
+ if (HandleEscape(outer_json, index, is_escaped)) {
+ continue;
+ }
+ if (LiteralIsFinished(outer_json, index, is_string_literal)) {
+ break;
+ }
+ ++index;
+ }
+ return outer_json.substr(start_index, index - start_index);
+}
+
+absl::StatusOr<JsonField> ParseField(absl::string_view outer_json,
+ size_t& index);
+
+absl::StatusOr<JsonValue> ParseValue(absl::string_view outer_json,
+ size_t& index) {
+ JsonValue value;
+ SkipWhitespace(outer_json, index);
+ TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "value"));
+ auto c = outer_json[index];
+ if (c == '{') {
+ constexpr static char kObject[] = "object";
+ auto seq = ParseSequence<JsonObject, '{', '}', kObject>(outer_json, index,
+ ParseField);
+ TF_ASSIGN_OR_RETURN(value, std::move(seq));
+ } else if (c == '[') {
+ constexpr static char kArray[] = "array";
+ auto seq = ParseSequence<JsonArray, '[', ']', kArray>(outer_json, index,
+ ParseValue);
+ TF_ASSIGN_OR_RETURN(value, std::move(seq));
+ } else {
+ TF_ASSIGN_OR_RETURN(value, ParseLiteral(outer_json, index));
+ }
+ return value;
+}
+
+absl::StatusOr<JsonField> ParseField(absl::string_view outer_json,
+ size_t& index) {
+ JsonField field;
+ TF_ASSIGN_OR_RETURN(field.name, ParseLiteral(outer_json, index));
+ TF_RETURN_IF_ERROR(Consume(outer_json, index, ':'));
+ TF_ASSIGN_OR_RETURN(field.value, ParseValue(outer_json, index));
+ return field;
+}
+
+template <typename T>
+std::vector<std::string> SerializedElements(const JsonSequence<T>& seq) {
+ std::vector<std::string> result;
+ for (const auto& field : seq.elements) {
+ result.push_back("");
+ Serialize(field, result.back());
+ }
+ return result;
+}
+
+template <typename ElemT, char begin_brace, char end_brace>
+void Serialize(const JsonSequence<ElemT>& object, std::string& result) {
+ auto elems = SerializedElements(object);
+ if constexpr (std::is_same_v<ElemT, JsonField>) {
+ std::sort(elems.begin(), elems.end());
+ }
+
+ result += begin_brace;
+  bool has_preceding = false;
+  for (const auto& elem : elems) {
+    if (has_preceding) {
+      result += ',';
+    }
+    result += elem;
+    has_preceding = true;
+ }
+ result += end_brace;
+}
+
+void Serialize(const JsonValue& value, std::string& result) {
+ if (auto* lit = std::get_if<absl::string_view>(&value)) {
+ absl::StrAppend(&result, *lit);
+ } else if (auto* object = std::get_if<std::unique_ptr<JsonObject>>(&value)) {
+ Serialize<JsonField, '{', '}'>(**object, result);
+ } else if (auto* array = std::get_if<std::unique_ptr<JsonArray>>(&value)) {
+ Serialize<JsonValue, '[', ']'>(**array, result);
+ }
+}
+
+void Serialize(const JsonField& field, std::string& result) {
+ absl::StrAppend(&result, field.name, ":");
+ Serialize(field.value, result);
+}
+
+} // namespace
+
+namespace xla {
+absl::StatusOr<std::string> SortJson(absl::string_view json) {
+ size_t index = 0;
+ TF_ASSIGN_OR_RETURN(auto value, ParseValue(json, index));
+ SkipWhitespace(json, index);
+ if (index < json.size()) {
+ return absl::InvalidArgumentError("Found trailing characters in JSON.");
+ }
+ std::string result;
+ Serialize(value, result);
+ return result;
+}
+} // namespace xla
diff --git a/third_party/xla/xla/sort_json.h b/third_party/xla/xla/sort_json.h
new file mode 100644
index 0000000..b4283f5
--- /dev/null
+++ b/third_party/xla/xla/sort_json.h
@@ -0,0 +1,35 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SORT_JSON_H_
+#define XLA_SORT_JSON_H_
+
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+
+namespace xla {
+
+// Sorts the fields of every JSON object in the given JSON string, or returns
+// an error if the JSON could not be parsed. Note that this function expects
+// the input JSON to be valid; not all forms of invalid JSON are detected.
+// Whitespace is ignored entirely, so the resulting JSON contains no
+// whitespace. Comments are not supported in the input JSON.
+absl::StatusOr<std::string> SortJson(absl::string_view json);
+
+} // namespace xla
+
+#endif // XLA_SORT_JSON_H_
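
For readers skimming the new sort_json target, here is a minimal, self-contained usage sketch of the xla::SortJson API declared above. Only SortJson itself comes from this change; the main() harness, the sample input, and the expected output comment are illustrative assumptions grounded in the tests below.

#include <iostream>
#include <string>

#include "absl/status/statusor.h"
#include "xla/sort_json.h"

int main() {
  // Object fields are re-emitted in sorted order; array order and all values
  // are preserved, and every bit of whitespace is dropped.
  absl::StatusOr<std::string> sorted =
      xla::SortJson(R"({"b": [2, 1], "a": {"d": 4, "c": 3}})");
  if (!sorted.ok()) {
    std::cerr << sorted.status() << "\n";
    return 1;
  }
  std::cout << *sorted << "\n";  // {"a":{"c":3,"d":4},"b":[2,1]}
  return 0;
}
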
diff --git a/third_party/xla/xla/sort_json_test.cc b/third_party/xla/xla/sort_json_test.cc
new file mode 100644
index 0000000..f4ff0c1
--- /dev/null
+++ b/third_party/xla/xla/sort_json_test.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/sort_json.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+
+TEST(SortJsonTest, SortsJson) {
+ EXPECT_THAT(SortJson(R"({"a": 1, "c": 3,"b": 2, "b": 1,})"),
+ IsOkAndHolds(R"({"a":1,"b":1,"b":2,"c":3})"));
+
+ EXPECT_THAT(SortJson(R"({"a": 1 , "c": 1,"b": 1 })"),
+ IsOkAndHolds(R"({"a":1,"b":1,"c":1})"));
+
+ EXPECT_THAT(SortJson(R"({"a": 1,"c": 3,"b": 2,"b": [3,2,1],})"),
+ IsOkAndHolds(R"({"a":1,"b":2,"b":[3,2,1],"c":3})"));
+
+ EXPECT_THAT(SortJson(R"({"aa": 1, "a": {"c": "c", "b": "b"}})"),
+ IsOkAndHolds(R"({"a":{"b":"b","c":"c"},"aa":1})"));
+
+ EXPECT_THAT(
+ SortJson(
+ R"({"x": true, "x": false, "x": null, "x": 0, "x": -0.5,"x": "a"})"),
+ IsOkAndHolds(R"({"x":"a","x":-0.5,"x":0,"x":false,"x":null,"x":true})"));
+
+ EXPECT_THAT(SortJson(R"({"a": "a}", "a": "a"})"),
+ IsOkAndHolds(R"({"a":"a","a":"a}"})"));
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD
index e7d2bb7..baf2833 100644
--- a/third_party/xla/xla/stream_executor/BUILD
+++ b/third_party/xla/xla/stream_executor/BUILD
@@ -220,7 +220,6 @@
":stream",
":stream_executor_h",
"//xla:test",
- "@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
@@ -380,6 +379,7 @@
":device_memory",
":numeric_options",
"//xla/stream_executor/platform",
+ "//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
@@ -391,7 +391,6 @@
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3", # buildcleaner: keep
- "@local_tsl//tsl/lib/strings:proto_serialization",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:ml_dtypes",
"@local_tsl//tsl/platform:status",
@@ -454,7 +453,6 @@
":module_spec",
":platform",
":stream",
- "@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@@ -576,12 +574,13 @@
":platform",
":stream_executor_h",
"@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:statusor",
],
)
@@ -675,7 +674,6 @@
":stream",
":stream_executor_h",
"@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@@ -811,6 +809,19 @@
],
)
+xla_cc_test(
+ name = "executor_cache_test",
+ srcs = ["executor_cache_test.cc"],
+ deps = [
+ ":executor_cache",
+ ":mock_stream_executor",
+ ":platform",
+ ":stream",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
#===--------------------------------------------------------------------------------------------===#
# Aliases for StreamExecutor platforms
#===--------------------------------------------------------------------------------------------===#
diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD
index 1521b22..644e5ce 100644
--- a/third_party/xla/xla/stream_executor/cuda/BUILD
+++ b/third_party/xla/xla/stream_executor/cuda/BUILD
@@ -10,6 +10,7 @@
load(
"@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
+ "if_cuda_newer_than",
)
load(
"//xla:xla.bzl",
@@ -225,11 +226,7 @@
name = "cuda_driver_test",
srcs = ["cuda_driver_test.cc"],
backends = ["gpu"],
- tags = [
- # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
- "gpu",
- "no_rocm",
- ],
+ tags = ["no_rocm"],
deps = [
":cuda_driver",
":cuda_status",
@@ -428,7 +425,6 @@
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_semaphore",
- "//xla/stream_executor/gpu:gpu_stream",
"@com_google_absl//absl/status:statusor",
],
)
@@ -590,6 +586,7 @@
cc_library(
name = "ptx_compiler",
hdrs = ["ptx_compiler.h"],
+ tags = ["no_rocm"],
deps = select({
":libnvptxcompiler_support_enabled": [":ptx_compiler_impl"],
"//conditions:default": [":ptx_compiler_stub"],
@@ -599,11 +596,30 @@
],
)
+xla_test(
+ name = "cuda_platform_test",
+ srcs = ["cuda_platform_test.cc"],
+ backends = ["gpu"],
+ deps = [
+ ":cuda_platform",
+ "//xla/stream_executor:platform",
+ "//xla/stream_executor:platform_manager",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_googletest//:gtest_main",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
xla_cc_test(
name = "ptx_compiler_test",
srcs = ["ptx_compiler_test.cc"],
- # TODO(b/343996893): Figure out whether msan reports a false positive or not.
- tags = ["nomsan"],
+ tags = [
+ "no_rocm",
+ # TODO(b/343996893): Figure out whether msan reports a false positive or not.
+ "nomsan",
+ ],
deps = [
":ptx_compiler",
":ptx_compiler_support",
@@ -629,7 +645,11 @@
"//conditions:default": [
"LIBNVJITLINK_SUPPORT=false",
],
- }),
+ }) + if_cuda_newer_than(
+ "12_0",
+ ["CUDA_SUPPORTS_NVJITLINK=true"],
+ ["CUDA_SUPPORTS_NVJITLINK=false"],
+ ),
)
cc_library(
@@ -672,13 +692,25 @@
],
)
+# Since select() can't be nested, we need to wrap the cuda_newer_than check in a separate
+# library target.
+cc_library(
+ name = "nvjitlink_cuda_supported",
+ # Even though the macro is called `*_newer_than`, it does a greater-than-or-equal-to comparison.
+ deps = if_cuda_newer_than(
+ "12_0",
+ [":nvjitlink_impl"],
+ [":nvjitlink_stub"],
+ ),
+)
+
cc_library(
name = "nvjitlink",
hdrs = [
"nvjitlink.h",
],
deps = select({
- ":libnvjitlink_support_enabled": [":nvjitlink_impl"],
+ ":libnvjitlink_support_enabled": [":nvjitlink_cuda_supported"],
"//conditions:default": [":nvjitlink_stub"],
}) + [
"//xla/stream_executor/gpu:gpu_asm_opts",
@@ -824,7 +856,6 @@
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:fingerprint",
"@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:platform_port",
"@local_tsl//tsl/platform:statusor",
] + if_cuda_is_configured([":delay_kernel_cuda"]),
alwayslink = True,
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc
index 01aa153..7f2183f 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc
@@ -20,7 +20,6 @@
#include <cstdint>
#include <cstdlib>
#include <cstring>
-#include <sstream>
#include <string>
#include <tuple>
#include <utility>
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
index a4337df..ce16080 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
@@ -41,6 +41,7 @@
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/cuda/cuda_blas.h"
#include "xla/stream_executor/cuda/cuda_blas_utils.h"
+#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/gpu/gpu_activation.h"
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
index 2fae670..3d61c81 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
@@ -18,6 +18,7 @@
#include <cstddef>
#include <memory>
+#include <optional>
#include <type_traits>
#include <utility>
#include <vector>
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
index 46a2757..440f647 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
@@ -40,7 +40,6 @@
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
@@ -3762,32 +3761,6 @@
}
#if CUDNN_VERSION >= 8800
-enum CudnnfMHAUid {
- Q_ID = 400,
- K_ID,
- V_ID,
- P_ID,
- O_ID,
- dQ_ID,
- dK_ID,
- dV_ID,
- dP_ID,
- dO_ID,
- dS_ID,
- dBIAS_ID,
- BIAS_ID,
- MASK_ID,
- ZERO_VAL_ID,
- ONE_VAL_ID,
- NEG_INFINITY_ID,
- ALPHA_SCALE_ID,
- DROPOUT_SCALE_ID,
- Q_SEQLEN_ID,
- K_SEQLEN_ID,
- D_OFFSET_ID,
- D_SEED_ID,
- VIRTUAL_ID = 34857
-};
absl::StatusOr<cudnn_frontend::PointWiseDesc> CreatePwDesc(
dnn::DataType dtype, cudnnPointwiseMode_t mode) {
@@ -3842,49 +3815,6 @@
RETURN_MSG_IF_CUDNN_ERROR(pw_op_created);
return pw_op_created;
}
-
-// Returns a cudnn tensor that's the output of the mask op
-absl::StatusOr<cudnn_frontend::Tensor> CreateCudnnMaskFwdTensor(
- std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
- absl::Span<const int64_t> strides, dnn::DataType dtype,
- cudnn_frontend::Tensor& input_tensor) {
- std::vector<int64_t> mask_dim(dims.size(), 1);
- std::vector<int64_t> mask_stride(strides.size(), 1);
-
- // Create the mask tensor
- TF_ASSIGN_OR_RETURN(
- auto mask_tensor,
- CreateCudnnTensor(dims, strides, CudnnfMHAUid::MASK_ID, dtype, 1, -1,
- /*is_virtual=*/false));
- // Create the mask output tensor
- TF_ASSIGN_OR_RETURN(
- auto mask_out_tensor,
- CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 400,
- dnn::DataType::kFloat, 1, -1,
- /*is_virtual=*/true));
-
- auto mask_desc = cudnn_frontend::PointWiseDescBuilder()
- .setMode(CUDNN_POINTWISE_MUL)
- .setComputeType(CUDNN_DATA_FLOAT)
- .build();
-
- // Create the mask op.
- auto mask_op = cudnn_frontend::OperationBuilder(
- CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
- .setxDesc(input_tensor)
- .setbDesc(mask_tensor)
- .setyDesc(mask_out_tensor)
- .setpwDesc(mask_desc)
- .build();
-
- RETURN_MSG_IF_CUDNN_ERROR(mask_op);
-
- RETURN_MSG_IF_CUDNN_ERROR(mask_out_tensor);
- // Add mask to op list
- ops.push_back(std::move(mask_op));
-
- return mask_out_tensor;
-}
#endif // CUDNN_VERSION >= 8800
absl::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
@@ -5047,7 +4977,7 @@
const dnn::FMHAMaskKind mask_type) {
using cudnn_frontend::graph::Tensor_attributes;
-#if CUDNN_VERSION >= 8904
+#if CUDNN_VERSION >= 90000
if (VLOG_IS_ON(4)) {
VLOG(4) << "\n bmm1_lhs(q): " << q_descriptor.ToString()
<< "\n bmm1_rhs(k): " << k_descriptor.ToString()
@@ -5075,12 +5005,14 @@
.set_io_data_type(ioDataType)
.set_compute_data_type(cudnn_frontend::DataType_t::FLOAT);
+ auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };
+
std::shared_ptr<Tensor_attributes> q_tensor =
graph.tensor(Tensor_attributes()
.set_name("Q")
.set_dim(q_descriptor.GetCudnnCompatibleDimensions(true))
.set_stride(q_descriptor.GetCudnnCompatibleStrides(true))
- .set_uid(CudnnfMHAUid::Q_ID));
+ .set_uid(next_uid()));
auto dim = k_descriptor.GetCudnnCompatibleDimensions(true);
std::shared_ptr<Tensor_attributes> k_tensor =
@@ -5088,13 +5020,13 @@
.set_name("K")
.set_dim(k_descriptor.GetCudnnCompatibleDimensions(true))
.set_stride(k_descriptor.GetCudnnCompatibleStrides(true))
- .set_uid(CudnnfMHAUid::K_ID));
+ .set_uid(next_uid()));
std::shared_ptr<Tensor_attributes> v_tensor = graph.tensor(
Tensor_attributes()
.set_name("V")
.set_dim(v_descriptor.GetCudnnCompatibleDimensions(false))
.set_stride(v_descriptor.GetCudnnCompatibleStrides(false))
- .set_uid(CudnnfMHAUid::V_ID));
+ .set_uid(next_uid()));
// Setting sdpa, and is_inference
bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL ||
@@ -5112,7 +5044,7 @@
.set_name("bias")
.set_dim(bias_descriptor->dimensions())
.set_stride(bias_descriptor->GetLogicalStrides())
- .set_uid(CudnnfMHAUid::BIAS_ID));
+ .set_uid(next_uid()));
sdpa_options.set_bias(bias_tensor);
}
// Setting actual seqlen
@@ -5126,37 +5058,38 @@
.set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
- .set_uid(CudnnfMHAUid::Q_SEQLEN_ID)
+ .set_uid(next_uid())
.set_data_type(cudnn_frontend::DataType_t::INT32));
auto seq_kv_tensor =
graph.tensor(Tensor_attributes()
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
- .set_uid(CudnnfMHAUid::K_SEQLEN_ID)
+ .set_uid(next_uid())
.set_data_type(cudnn_frontend::DataType_t::INT32));
sdpa_options.set_padding_mask(true);
sdpa_options.set_seq_len_q(seq_q_tensor);
sdpa_options.set_seq_len_kv(seq_kv_tensor);
}
// Setting seed and offset
+ std::shared_ptr<Tensor_attributes> seed_tensor;
+ std::shared_ptr<Tensor_attributes> offset_tensor;
if (use_dropout) {
- auto seed_tensor =
+ // Skip setting UIDs: pass-by-value tensors go at the end.
+ seed_tensor =
graph.tensor(Tensor_attributes()
.set_name("seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(cudnn_frontend::DataType_t::INT64)
- .set_is_pass_by_value(true)
- .set_uid(CudnnfMHAUid::D_SEED_ID));
- auto offset_tensor =
+ .set_is_pass_by_value(true));
+ offset_tensor =
graph.tensor(Tensor_attributes()
.set_name("offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(cudnn_frontend::DataType_t::INT64)
- .set_is_pass_by_value(true)
- .set_uid(CudnnfMHAUid::D_OFFSET_ID));
+ .set_is_pass_by_value(true));
sdpa_options.set_dropout((float)dropout_rate.value(), seed_tensor,
offset_tensor);
}
@@ -5170,7 +5103,7 @@
.set_output(true)
.set_dim(o_descriptor.dimensions())
.set_stride(o_descriptor.GetLogicalStrides())
- .set_uid(CudnnfMHAUid::O_ID);
+ .set_uid(next_uid());
if (stats_descriptor.has_value()) {
cudnn_frontend::DataType_t statsType =
ToCudnnFrontendDataType(stats_descriptor->type());
@@ -5183,11 +5116,19 @@
.set_data_type(statsType)
.set_dim(stat_dims)
.set_stride(stat_strides)
- .set_uid(CudnnfMHAUid::P_ID);
+ .set_uid(next_uid());
+ }
+ if (seed_tensor != nullptr) {
+ seed_tensor->set_uid(next_uid());
+ }
+ if (offset_tensor != nullptr) {
+ offset_tensor->set_uid(next_uid());
}
CudnnGraph cudnnGraph(std::move(graph));
- TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support));
- TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt));
+ TF_RETURN_IF_ERROR(cudnnGraph.Prepare(
+ dnn_support, NumericOptions{/*require_determinism=*/false,
+ /*allow_tf32=*/true}));
+ TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt));
if (VLOG_IS_ON(4)) {
VLOG(4) << "\b flash attention operation graph: " << graph;
@@ -5195,7 +5136,7 @@
return cudnnGraph;
#else
return absl::UnimplementedError(
- "Cudnn flash attention only supported with Cudnn >= 8.9.4");
+ "Cudnn flash attention only supported with Cudnn >= 9.0.0");
#endif
}
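
The hunks above drop the fixed CudnnfMHAUid enum in favor of a mutable-capture counter lambda, so each tensor's UID is simply its position in declaration order (which is also the order of the runtime operand list). A standalone sketch of that counter pattern follows; CuDnnTensorUID is stubbed out as an assumption since its real definition is not part of this excerpt.

#include <iostream>
#include <vector>

// Stand-in for the real CuDnnTensorUID helper; assume it maps a small index
// into the UID space handed to cuDNN (here: identity plus an offset).
int CuDnnTensorUID(int index) { return 100 + index; }

int main() {
  // The mutable capture keeps the counter state inside the lambda itself.
  auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };

  // UIDs come out in declaration order: Q, K, V, O, ...
  std::vector<int> uids = {next_uid(), next_uid(), next_uid(), next_uid()};
  for (int uid : uids) std::cout << uid << ' ';  // prints: 100 101 102 103
  std::cout << '\n';
  return 0;
}
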
@@ -5211,7 +5152,7 @@
std::optional<double> dropout_rate, std::optional<int64_t> seed,
double scale, bool use_dropout, bool use_bias, dnn::FMHAMaskKind mask_type,
bool force_deterministic) {
-#if CUDNN_VERSION >= 8904
+#if CUDNN_VERSION >= 90000
if (VLOG_IS_ON(4)) {
VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString()
<< "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString()
@@ -5236,41 +5177,6 @@
.set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT)
.set_io_data_type(ioDataType);
- std::shared_ptr<Tensor_attributes> q =
- graph.tensor(Tensor_attributes()
- .set_name("Q")
- .set_dim(q_desc.GetCudnnCompatibleDimensions(false))
- .set_stride(q_desc.GetCudnnCompatibleStrides(false))
- .set_uid(CudnnfMHAUid::Q_ID)
- .set_data_type(ioDataType));
- std::shared_ptr<Tensor_attributes> k =
- graph.tensor(Tensor_attributes()
- .set_name("K")
- .set_dim(k_desc.GetCudnnCompatibleDimensions(false))
- .set_stride(k_desc.GetCudnnCompatibleStrides(false))
- .set_uid(CudnnfMHAUid::K_ID)
- .set_data_type(ioDataType));
- std::shared_ptr<Tensor_attributes> v =
- graph.tensor(Tensor_attributes()
- .set_name("V")
- .set_dim(v_desc.GetCudnnCompatibleDimensions(true))
- .set_stride(v_desc.GetCudnnCompatibleStrides(true))
- .set_uid(CudnnfMHAUid::V_ID)
- .set_data_type(ioDataType));
- std::shared_ptr<Tensor_attributes> o =
- graph.tensor(Tensor_attributes()
- .set_name("O")
- .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
- .set_stride(do_desc.GetCudnnCompatibleStrides(false))
- .set_uid(CudnnfMHAUid::O_ID)
- .set_data_type(ioDataType));
- std::shared_ptr<Tensor_attributes> dO =
- graph.tensor(Tensor_attributes()
- .set_name("dO")
- .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
- .set_stride(do_desc.GetCudnnCompatibleStrides(false))
- .set_uid(CudnnfMHAUid::dO_ID)
- .set_data_type(ioDataType));
auto p_dims = p_desc.GetCudnnCompatibleDimensions(false);
auto p_strides = p_desc.GetCudnnCompatibleStrides(false);
std::vector<int64_t> p_reduction_dims(p_dims.begin(), p_dims.end() - 1);
@@ -5284,13 +5190,6 @@
p_reduction_strides.push_back(stride / p_reduced_dim_len);
}
p_reduction_strides[3] = 1;
- std::shared_ptr<Tensor_attributes> stats =
- graph.tensor(Tensor_attributes()
- .set_name("stats")
- .set_dim(p_reduction_dims)
- .set_stride(p_reduction_strides)
- .set_uid(CudnnfMHAUid::P_ID)
- .set_data_type(cudnn_frontend::DataType_t::FLOAT));
bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL ||
mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
auto sdpa_backward_options =
@@ -5300,7 +5199,44 @@
.set_attn_scale(scale)
.set_compute_data_type(cudnn_frontend::DataType_t::FLOAT);
- // Setting bias
+ auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };
+
+ std::shared_ptr<Tensor_attributes> q =
+ graph.tensor(Tensor_attributes()
+ .set_name("Q")
+ .set_dim(q_desc.GetCudnnCompatibleDimensions(false))
+ .set_stride(q_desc.GetCudnnCompatibleStrides(false))
+ .set_uid(next_uid())
+ .set_data_type(ioDataType));
+ std::shared_ptr<Tensor_attributes> k =
+ graph.tensor(Tensor_attributes()
+ .set_name("K")
+ .set_dim(k_desc.GetCudnnCompatibleDimensions(false))
+ .set_stride(k_desc.GetCudnnCompatibleStrides(false))
+ .set_uid(next_uid())
+ .set_data_type(ioDataType));
+ std::shared_ptr<Tensor_attributes> v =
+ graph.tensor(Tensor_attributes()
+ .set_name("V")
+ .set_dim(v_desc.GetCudnnCompatibleDimensions(true))
+ .set_stride(v_desc.GetCudnnCompatibleStrides(true))
+ .set_uid(next_uid())
+ .set_data_type(ioDataType));
+ std::shared_ptr<Tensor_attributes> stats =
+ graph.tensor(Tensor_attributes()
+ .set_name("stats")
+ .set_dim(p_reduction_dims)
+ .set_stride(p_reduction_strides)
+ .set_uid(next_uid())
+ .set_data_type(cudnn_frontend::DataType_t::FLOAT));
+ std::shared_ptr<Tensor_attributes> dO =
+ graph.tensor(Tensor_attributes()
+ .set_name("dO")
+ .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
+ .set_stride(do_desc.GetCudnnCompatibleStrides(false))
+ .set_uid(next_uid())
+ .set_data_type(ioDataType));
+ std::shared_ptr<Tensor_attributes> d_bias_tensor;
if (use_bias) {
DCHECK(bias_descriptor != std::nullopt);
auto bias_dim = bias_descriptor->dimensions();
@@ -5313,21 +5249,29 @@
.set_name("bias")
.set_dim(bias_descriptor->dimensions())
.set_stride(bias_descriptor->GetLogicalStrides())
- .set_uid(CudnnfMHAUid::BIAS_ID));
+ .set_uid(next_uid()));
sdpa_backward_options.set_bias(bias_tensor);
// shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for
// dbias calculation but they are supported for forward bias calculation
+ // Set UID later: this is the last output tuple element.
if (b == 1 && n == q_n) {
- auto d_bias_tensor =
+ d_bias_tensor =
graph.tensor(Tensor_attributes()
.set_name("dBias")
.set_dim(bias_descriptor->dimensions())
- .set_stride(bias_descriptor->GetLogicalStrides())
- .set_uid(CudnnfMHAUid::dBIAS_ID));
+ .set_stride(bias_descriptor->GetLogicalStrides()));
sdpa_backward_options.set_dbias(d_bias_tensor);
}
}
+ std::shared_ptr<Tensor_attributes> o =
+ graph.tensor(Tensor_attributes()
+ .set_name("O")
+ .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
+ .set_stride(do_desc.GetCudnnCompatibleStrides(false))
+ .set_uid(next_uid())
+ .set_data_type(ioDataType));
+
// Setting actual seqlen
bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING ||
mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
@@ -5339,38 +5283,39 @@
.set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
- .set_uid(CudnnfMHAUid::Q_SEQLEN_ID)
+ .set_uid(next_uid())
.set_data_type(cudnn_frontend::DataType_t::INT32));
auto seq_kv_tensor =
graph.tensor(Tensor_attributes()
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
- .set_uid(CudnnfMHAUid::K_SEQLEN_ID)
+ .set_uid(next_uid())
.set_data_type(cudnn_frontend::DataType_t::INT32));
sdpa_backward_options.set_padding_mask(true);
sdpa_backward_options.set_seq_len_q(seq_q_tensor);
sdpa_backward_options.set_seq_len_kv(seq_kv_tensor);
}
// Setting seed and offset
+ std::shared_ptr<Tensor_attributes> seed_tensor;
+ std::shared_ptr<Tensor_attributes> offset_tensor;
if (use_dropout) {
DCHECK(dropout_rate != std::nullopt);
- auto seed_tensor =
+ // Skip setting UIDs: pass-by-value tensors go at the end.
+ seed_tensor =
graph.tensor(Tensor_attributes()
.set_name("seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(cudnn_frontend::DataType_t::INT64)
- .set_is_pass_by_value(true)
- .set_uid(CudnnfMHAUid::D_SEED_ID));
- auto offset_tensor =
+ .set_is_pass_by_value(true));
+ offset_tensor =
graph.tensor(Tensor_attributes()
.set_name("offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(cudnn_frontend::DataType_t::INT64)
- .set_is_pass_by_value(true)
- .set_uid(CudnnfMHAUid::D_OFFSET_ID));
+ .set_is_pass_by_value(true));
sdpa_backward_options.set_dropout((float)dropout_rate.value(), seed_tensor,
offset_tensor);
}
@@ -5385,25 +5330,36 @@
dQ->set_output(true)
.set_dim(dq_desc.dimensions())
.set_stride(dq_desc.GetLogicalStrides())
+ .set_uid(next_uid())
.set_name("dQ")
- .set_uid(CudnnfMHAUid::dQ_ID)
.set_data_type(ioDataType);
dK->set_output(true)
.set_dim(dk_desc.dimensions())
.set_stride(dk_desc.GetLogicalStrides())
+ .set_uid(next_uid())
.set_name("dK")
- .set_uid(CudnnfMHAUid::dK_ID)
.set_data_type(ioDataType);
dV->set_output(true)
.set_dim(dv_desc.dimensions())
.set_stride(dv_desc.GetLogicalStrides())
+ .set_uid(next_uid())
.set_name("dV")
- .set_uid(CudnnfMHAUid::dV_ID)
.set_data_type(ioDataType);
+ if (d_bias_tensor != nullptr) {
+ d_bias_tensor->set_uid(next_uid());
+ }
+ if (seed_tensor != nullptr) {
+ seed_tensor->set_uid(next_uid());
+ }
+ if (offset_tensor != nullptr) {
+ offset_tensor->set_uid(next_uid());
+ }
CudnnGraph cudnnGraph(std::move(graph));
- TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support));
- TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt));
+ TF_RETURN_IF_ERROR(
+ cudnnGraph.Prepare(dnn_support, NumericOptions{force_deterministic,
+ /*allow_tf32=*/true}));
+ TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt));
if (VLOG_IS_ON(4)) {
VLOG(4) << "\b flash attention operation backward graph: " << graph;
@@ -5412,7 +5368,7 @@
return cudnnGraph;
#else
return absl::UnimplementedError(
- "Cudnn flash attention only supported with Cudnn >= 8.9.4");
+ "Cudnn flash attention only supported with Cudnn >= 9.0.0");
#endif
}
@@ -5735,8 +5691,8 @@
}
// Utility for dealing with CUDA's type-erased scaling parameters, where some
-// sets of parameters expect a void* pointing at a float while others expect it
-// to point at a double.
+// sets of parameters expect a void* pointing at a float while others expect
+// it to point at a double.
//
// This is rather ugly, but its purpose is to quarantine the corresponding
// ugliness that already exists in the CUDA API.
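
The comment above describes a scaling value that is stored once but handed to the CUDA API as a void* whose pointee width depends on the element type. A rough, self-contained sketch of that pattern follows; the class and enum names here are stand-ins, not the actual ScalingParam definition from this file.

#include <cstdio>

enum class DataType { kFloat, kDouble };

// Holds one scaling factor and exposes it as a void* of whichever width the
// callee expects, mirroring the type-erased scaling-parameter idea above.
class ScalingValue {
 public:
  explicit ScalingValue(double value)
      : as_double_(value), as_float_(static_cast<float>(value)) {}

  void* ToVoidPointer(DataType element_type) {
    return element_type == DataType::kDouble
               ? static_cast<void*>(&as_double_)
               : static_cast<void*>(&as_float_);
  }

 private:
  double as_double_;
  float as_float_;
};

int main() {
  ScalingValue alpha(0.5);
  // A double-typed consumer reads 8 bytes; a float-typed one reads 4.
  std::printf("%f\n",
              *static_cast<double*>(alpha.ToVoidPointer(DataType::kDouble)));
  std::printf("%f\n",
              *static_cast<float*>(alpha.ToVoidPointer(DataType::kFloat)));
  return 0;
}
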
@@ -5760,9 +5716,9 @@
//
// See
// https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
- // for more info; the behavior for int8 result tensors is not described there,
- // but is maintained from the existing behavior (namely, using a float scaling
- // parameter).
+ // for more info; the behavior for int8 result tensors is not described
+ // there, but is maintained from the existing behavior (namely, using a
+ // float scaling parameter).
void* ToVoidPointer(dnn::DataType element_type) {
if (element_type == dnn::DataType::kDouble) {
return &as_double_;
@@ -5834,10 +5790,11 @@
absl::c_transform(result, std::back_inserter(raw_ptrs),
[](const BackendDescriptor& ptr) { return ptr.get(); });
- // This API evidently does a deep copy of the descriptors into the pointers in
- // the output array, rather than writing pointers to the descriptors into the
- // output array. So, this writes the memory behind each BackendDescriptor in
- // result, rather than writing the contents of raw_ptrs.
+ // This API evidently does a deep copy of the descriptors into the pointers
+ // in the output array, rather than writing pointers to the descriptors into
+ // the output array. So, this writes the memory behind each
+ // BackendDescriptor in result, rather than writing the contents of
+ // raw_ptrs.
RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data()));
@@ -5873,9 +5830,9 @@
cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX,
CUDNN_TYPE_INT64, 1, &n, &engine_id));
- // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the
- // number of elements in the attribute by using an output limit value of 0
- // just returns 0; the only way to find out how many there are is to
+ // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query
+ // the number of elements in the attribute by using an output limit value of
+ // 0 just returns 0; the only way to find out how many there are is to
// pre-allocate space for every existing knob type (as an upper bound on the
// number of knob choices a config can have), and then look back at how many
// were filled.
@@ -6086,103 +6043,7 @@
std::vector<int64_t> scalar_input_uids_;
std::vector<ScalingParam> scalar_input_values_;
};
-#endif // CUDNN_VERSION >= 8100
-template <typename Sig>
-class CudnnGraphRunner;
-// An OpRunner implemented by a cuDNN frontend graph.
-//
-// This is the class holding the implementation of ToString, GetWorkspaceSize,
-// and operator() for use by the cudnn frontend op runners.
-template <typename... Args>
-class CudnnGraphRunner<void(Args...)> : public dnn::OpRunner<void(Args...)> {
- private:
- using Graph = cudnn_frontend::graph::Graph;
- using Tensor_attributes = cudnn_frontend::graph::Tensor_attributes;
-
- public:
- std::string ToString() const override { return graph_.Graph().print(); }
-
- size_t GetWorkspaceSize() const override {
- return graph_.Graph().get_workspace_size();
- }
-
- absl::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
- return absl::InternalError(
- "Unexpected call to CudnnGraphRunner::ToAlgorithmDesc");
- }
-
- absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
- DeviceMemoryBase scratch_memory,
- Args... inputs) const override {
- if (parent_ != stream->parent()) {
- return tsl::errors::Internal(
- "CudnnExecutionPlanRunner cached across multiple StreamExecutors.");
- }
- CudnnHandle handle = cudnn_->GetHandle(parent_, stream);
- std::unordered_map<int64_t, void*> variant_pack;
- std::vector<void*> vec = {inputs.opaque()...};
-
- // add device buffers to the variant pack
- for (int i = 0; i < uids_.size(); ++i) {
- if (uids_[i].has_value()) {
- variant_pack[*uids_[i]] = vec[i];
- }
- }
- if (dropout_rng_offset_increment_ > 0) {
-#if CUDNN_VERSION >= 8800
- variant_pack[CudnnfMHAUid::D_SEED_ID] = (void*)&dropout_rng_seed_;
- current_dropout_rng_offset_ += dropout_rng_offset_increment_;
- variant_pack[CudnnfMHAUid::D_OFFSET_ID] =
- (void*)&current_dropout_rng_offset_;
-#else
- return absl::UnimplementedError(
- "Cudnn dropout offset and seed are only supported with Cudnn >= "
- "8.8.0");
-#endif // CUDNN_VERSION >= 8800
- }
- int workspace = graph_.Graph().get_workspace_size();
- if (workspace > scratch_memory.size()) {
- return tsl::errors::Internal(
- absl::StrFormat("CuDNN FMHA requires %d workspace, got %d workspace.",
- workspace, scratch_memory.size()));
- }
- RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.Graph().execute(
- handle.handle(), variant_pack, scratch_memory.opaque()));
-
- return absl::OkStatus();
- }
-
- static absl::StatusOr<CudnnGraphRunner> Create(
- GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph,
- int64_t dropout_rng_seed, int64_t dropout_rng_offset,
- std::vector<std::optional<int64_t>> uids) {
- return CudnnGraphRunner(parent, cudnn, std::move(graph), dropout_rng_seed,
- dropout_rng_offset, uids);
- }
-
- private:
- CudnnGraphRunner(GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph,
- int64_t dropout_rng_seed, int64_t dropout_rng_offset,
- std::vector<std::optional<int64_t>> uids)
- : parent_(parent),
- cudnn_(cudnn),
- graph_(std::move(graph)),
- dropout_rng_seed_(dropout_rng_seed),
- current_dropout_rng_offset_(0),
- dropout_rng_offset_increment_(dropout_rng_offset),
- uids_(uids) {}
- GpuExecutor* parent_;
- CudnnAccess* cudnn_;
- Stream* stream_;
- CudnnGraph graph_;
- int64_t dropout_rng_seed_;
- mutable int64_t current_dropout_rng_offset_;
- int64_t dropout_rng_offset_increment_;
- std::vector<std::optional<int64_t>> uids_;
-};
-
-#if CUDNN_VERSION >= 8100
namespace {
template <typename Sig>
@@ -6968,7 +6829,8 @@
use_fallback, out_exec_plans, /*need_side_input=*/true, numeric_options);
#else
return tsl::errors::Unimplemented(
- "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4.");
+ "Cudnn execution plans for matmul are only supported with Cudnn >= "
+ "8.4.");
#endif // CUDNN_VERSION >= 8400
}
@@ -7170,139 +7032,6 @@
return max_seq_len * max_seq_len / cudnn_mha_num_threads;
}
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHARunner>>
-CudnnSupport::FusedMHARunnerFromDesc(
- Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
- const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
- const dnn::TensorDescriptor& output_descriptor,
- std::optional<dnn::TensorDescriptor> activation_descriptor,
- std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type) {
-#if CUDNN_VERSION >= 8904
- auto cudnn = cudnn_->GetHandle(parent_, stream);
- bool use_dropout = dropout_rate && *dropout_rate > 0.0;
- std::vector<int64_t> intermediate_shape;
-
- TF_ASSIGN_OR_RETURN(auto graph,
- GetCudnnFlashAttentionOperationGraph(
- *this, /*q_descriptor=*/bmm1_lhs_descriptor,
- /*k_descriptor=*/bmm1_rhs_descriptor,
- /*v_descriptor=*/bmm2_rhs_descriptor,
- /*o_descriptor=*/output_descriptor, bias_descriptor,
- /*stats_descriptor=*/activation_descriptor,
- /*scale=*/static_cast<float>(scale), use_dropout,
- dropout_rate, mask_type));
-
- std::vector<int64_t> intermediate_bmm2_lhs_dims =
- intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true);
- intermediate_shape = intermediate_bmm2_lhs_dims;
- int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape);
- int64_t dropout_rng_seed = seed.has_value() ? *seed : 0;
- std::vector<std::optional<int64_t>> uids = {
- CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::V_ID,
- CudnnfMHAUid::O_ID};
- uids.emplace_back(bias_descriptor.has_value()
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::BIAS_ID)
- : std::nullopt);
- uids.emplace_back(activation_descriptor.has_value()
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::P_ID)
- : std::nullopt);
- bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING ||
- mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
- uids.emplace_back(is_padding
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::Q_SEQLEN_ID)
- : std::nullopt);
- uids.emplace_back(is_padding
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::K_SEQLEN_ID)
- : std::nullopt);
- TF_ASSIGN_OR_RETURN(auto runner,
- CudnnGraphRunner<dnn::FusedMHASignature>::Create(
- parent_, cudnn_.get(), std::move(graph),
- dropout_rng_seed, dropout_rng_offset, uids));
-
- return {std::make_unique<CudnnGraphRunner<dnn::FusedMHASignature>>(
- std::move(runner))};
-#else
- return absl::UnimplementedError(
- "Cudnn flash attention are only supported with Cudnn >= 8.9.4");
-#endif // CUDNN_VERSION >= 8904
-}
-
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
-CudnnSupport::FusedMHABackwardRunnerFromDesc(
- Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
- const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& d_output_descriptor,
- const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
- const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
- const dnn::TensorDescriptor& d_bmm2_rhs_descriptor,
- std::optional<dnn::TensorDescriptor> d_s_descriptor,
- std::optional<dnn::TensorDescriptor> d_bias_descriptor,
- std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
- std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type, bool force_deterministic) {
-#if CUDNN_VERSION >= 8904
- auto cudnn = cudnn_->GetHandle(parent_, stream);
-
- bool use_dropout = dropout_rate && *dropout_rate > 0.0;
- std::vector<int64_t> intermediate_shape;
-
- TF_ASSIGN_OR_RETURN(
- auto graph,
- GetCudnnFlashAttentionBackwardOperationGraph(
- *this, bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor,
- bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor,
- d_output_descriptor, d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor,
- d_bmm2_rhs_descriptor, bias_descriptor, dropout_rate, seed, scale,
- use_dropout, bias_descriptor != std::nullopt, mask_type,
- force_deterministic));
-
- std::vector<int64_t> p_dims =
- bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false);
- intermediate_shape = p_dims;
- int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape);
- int64_t dropout_rng_seed = seed.has_value() ? *seed : 0;
-
- std::vector<std::optional<int64_t>> uids;
- uids = {CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::P_ID,
- CudnnfMHAUid::V_ID, CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID,
- CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, std::nullopt};
- uids.emplace_back(d_bias_descriptor.has_value()
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::dBIAS_ID)
- : std::nullopt);
- uids.push_back(CudnnfMHAUid::O_ID);
- uids.emplace_back(bias_descriptor.has_value()
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::BIAS_ID)
- : std::nullopt);
- bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING ||
- mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
- uids.emplace_back(is_padding
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::Q_SEQLEN_ID)
- : std::nullopt);
- uids.emplace_back(is_padding
- ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::K_SEQLEN_ID)
- : std::nullopt);
- TF_ASSIGN_OR_RETURN(auto runner,
- CudnnGraphRunner<dnn::FusedMHABackwardSignature>::Create(
- parent_, cudnn_.get(), graph, dropout_rng_seed,
- dropout_rng_offset, uids));
- return {std::make_unique<CudnnGraphRunner<dnn::FusedMHABackwardSignature>>(
- std::move(runner))};
-#else
- return absl::UnimplementedError(
- "Cudnn flash attention bwd are only "
- "supported with Cudnn >= 8.9.4");
-#endif // CUDNN_VERSION >= 8904
-}
-
bool CudnnSupport::GetRnnAlgorithms(
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
PreloadCudnnSubLibs(PreloadCudnnType::Rnn);
@@ -8353,11 +8082,16 @@
return std::make_unique<CudnnGraph>(std::move(graph));
}
-absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support) {
+absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support,
+ const NumericOptions& numeric_options) {
const CudnnSupport& cudnn_support = static_cast<CudnnSupport&>(dnn_support);
TF_ASSIGN_OR_RETURN(auto cudnn, cudnn_support.cudnn_->GetLocalHandle());
RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.validate());
RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(cudnn->handle()));
+ if (numeric_options.require_determinism) {
+ graph_.deselect_numeric_notes(
+ {cudnn_frontend::NumericalNote_t::NONDETERMINISTIC});
+ }
RETURN_IF_CUDNN_FRONTEND_ERROR(
graph_.create_execution_plans({cudnn_frontend::HeurMode_t::A}));
RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.check_support(cudnn->handle()));
@@ -8382,15 +8116,30 @@
std::unordered_map<int64_t, void*> tensor_to_ptr_map;
absl::Span<DeviceMemoryBase> operands_without_workspace = operands;
DeviceMemoryBase workspace;
- if (graph_.get_workspace_size() != 0) {
+ if (graph_.get_workspace_size() > 0) {
workspace = operands.back();
CHECK_EQ(graph_.get_workspace_size(), workspace.size());
+ }
+ if (graph_.get_workspace_size() > 0 || operands.back().size() == 0) {
operands_without_workspace = operands.first(operands.size() - 1);
}
- int operand_number = 0;
+ auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };
for (DeviceMemoryBase operand : operands_without_workspace) {
- tensor_to_ptr_map[CuDnnTensorUID(operand_number++)] = operand.opaque();
+ tensor_to_ptr_map[next_uid()] = operand.opaque();
}
+
+ if (dropout_rng_offset_increment_ > 0) {
+#if CUDNN_VERSION >= 8800
+ tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_;
+ current_dropout_rng_offset_ += dropout_rng_offset_increment_;
+ tensor_to_ptr_map[next_uid()] = (void*)&current_dropout_rng_offset_;
+#else
+ return absl::UnimplementedError(
+ "Cudnn dropout offset and seed are only supported with Cudnn >= "
+ "8.8.0");
+#endif // CUDNN_VERSION >= 8800
+ }
+
const CudnnSupport& dnn_support =
static_cast<CudnnSupport&>(*stream.parent()->AsDnn());
RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute(
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
index 5208693..24d84e3 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
@@ -60,7 +60,7 @@
explicit CudnnGraph(cudnn_frontend::graph::Graph&& graph)
: graph_(std::move(graph)) {}
// Prepares a graph and checks whether it is generally supported.
- absl::Status Prepare(dnn::DnnSupport&) override;
+ absl::Status Prepare(dnn::DnnSupport&, const NumericOptions&) override;
// Builds single plan of the graph with given ID.
absl::Status Build(dnn::DnnSupport&, std::optional<int64_t> plan_id) override;
// Builds all the plans
@@ -70,6 +70,9 @@
private:
cudnn_frontend::graph::Graph graph_;
+ int64_t dropout_rng_seed_;
+ mutable int64_t current_dropout_rng_offset_;
+ int64_t dropout_rng_offset_increment_ = 0;
};
#endif // CUDNN_VERSION >= 8100
@@ -335,37 +338,6 @@
std::optional<dnn::TensorDescriptor> dscale_descriptor,
std::optional<dnn::TensorDescriptor> dbias_descriptor) override;
- absl::StatusOr<std::unique_ptr<const dnn::FusedMHARunner>>
- FusedMHARunnerFromDesc(
- Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
- const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
- const dnn::TensorDescriptor& output_descriptor,
- std::optional<dnn::TensorDescriptor> activation_descriptor,
- std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type) override;
-
- absl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
- FusedMHABackwardRunnerFromDesc(
- Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
- const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& d_output_descriptor,
- const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
- const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
- const dnn::TensorDescriptor& d_bmm2_rhs_descriptor,
- std::optional<dnn::TensorDescriptor> d_s_descriptor,
- std::optional<dnn::TensorDescriptor> d_bias_descriptor,
- std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
- std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type, bool force_deterministic);
-
bool GetRnnAlgorithms(
std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
index 866c1ff..e8e26e2 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
@@ -19,6 +19,7 @@
#include <stdlib.h>
#include <cstdint>
+#include <cstdlib>
#include <cstring>
#include <string>
#include <utility>
@@ -27,7 +28,6 @@
#include "absl/base/casts.h"
#include "absl/base/const_init.h"
-#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/debugging/leak_check.h"
#include "absl/log/check.h"
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h
index 5c04ab6..aefd896 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h
@@ -19,16 +19,13 @@
#define XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
#include <algorithm>
-#include <cstdint>
#include <memory>
-#include <string>
#include <utility>
#include <vector>
#include "absl/container/node_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
-#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_status.h"
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
index b624709..8ae2477 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
@@ -17,20 +17,17 @@
#include <cstdint>
#include <cstdio>
#include <cstdlib>
-#include <ios>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
-#include <vector>
-#include "absl/base/casts.h"
#include "absl/numeric/int128.h"
-#include "absl/strings/str_join.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/event_based_timer.h"
@@ -46,7 +43,6 @@
#include <unistd.h>
#endif
-#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -54,7 +50,6 @@
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
-#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
@@ -75,18 +70,15 @@
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/gpu_timer.h"
#include "xla/stream_executor/gpu/gpu_types.h"
-#include "xla/stream_executor/integrations/device_mem_allocator.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/module_spec.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/plugin_registry.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
-#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/fingerprint.h"
#include "tsl/platform/logging.h"
-#include "tsl/platform/numa.h"
#include "tsl/platform/statusor.h"
// LOG(ERROR) uses a const named ERROR, so a macro with the same name is
@@ -154,9 +146,6 @@
}
}
-static std::optional<int> TryToReadNumaNode(const std::string& pci_bus_id,
- int device_ordinal);
-
absl::Status GpuExecutor::Init() {
TF_RETURN_IF_ERROR(GpuDriver::Init());
TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal_, &device_));
@@ -164,17 +153,6 @@
GpuDriver::CreateContext(device_ordinal_, device_, &context_));
TF_RETURN_IF_ERROR(
GpuDriver::GetComputeCapability(&cc_major_, &cc_minor_, device_));
- std::optional<int> numa_node = TryToReadNumaNode(
- absl::AsciiStrToLower(GpuDriver::GetPCIBusID(device_ordinal_)),
- device_ordinal_);
- if (!numa_node || *numa_node < 0) {
- LOG(WARNING) << "NUMA node could not be determined for device "
- << device_ordinal_
- << ", host memory allocations will not be NUMA-pinned";
- numa_node_ = tsl::port::kNUMANoAffinity;
- } else {
- numa_node_ = *numa_node;
- }
return absl::OkStatus();
}
@@ -492,83 +470,6 @@
return absl::OkStatus();
}
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const Kernel& kernel, const KernelArgs& args) {
- return Launch(stream, thread_dims, block_dims, std::nullopt, kernel, args);
-}
-
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const ClusterDim& cluster_dims,
- const Kernel& kernel, const KernelArgs& args) {
- return Launch(stream, thread_dims, block_dims,
- std::make_optional(cluster_dims), kernel, args);
-}
-
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const std::optional<ClusterDim>& cluster_dims,
- const Kernel& kernel, const KernelArgs& args) {
- CUstream custream = AsGpuStreamValue(stream);
- const GpuKernel* cuda_kernel = AsGpuKernel(&kernel);
- CUfunction cufunc = cuda_kernel->gpu_function();
-
- if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
- TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
- cufunc, cuda_kernel->GetGpuCacheConfig()));
- }
-
- // Launch CUDA kernels with packed arguments.
- auto launch = [&](const KernelArgsPackedArrayBase& packed) {
- int32_t expected_number_of_arguments =
- kernel.Arity() + (packed.number_of_shared_bytes() > 0);
-
- CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments())
- << "Kernel " << kernel.name() << " has " << packed.number_of_arguments()
- << " arguments, but expected " << expected_number_of_arguments
- << "; arity=" << kernel.Arity()
- << "; number_of_shared_bytes=" << packed.number_of_shared_bytes();
-
- void** params = const_cast<void**>(packed.argument_addresses().data());
-
- if (cluster_dims.has_value()) {
- return GpuDriver::LaunchKernel(
- context_, kernel.name(), cufunc, cluster_dims->x, cluster_dims->y,
- cluster_dims->z, block_dims.x, block_dims.y, block_dims.z,
- thread_dims.x, thread_dims.y, thread_dims.z,
- packed.number_of_shared_bytes(), custream, params,
- /*extra=*/nullptr);
- } else {
- return GpuDriver::LaunchKernel(
- context_, kernel.name(), cufunc, block_dims.x, block_dims.y,
- block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
- packed.number_of_shared_bytes(), custream, params,
- /*extra=*/nullptr);
- }
- };
-
- // If arguments are already packed we can just launch the kernel.
- if (auto* packed = DynCast<KernelArgsPackedArrayBase>(&args)) {
- return launch(*packed);
- }
-
- // For device memory array we rely on a custom kernel arguments packing.
- if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
- auto& pack = kernel.args_packing();
- if (!pack) {
- return absl::InternalError(
- "Kernel is missing a custom arguments packing function for device "
- "memory arguments array");
- }
-
- TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem));
- return launch(*packed);
- }
-
- return absl::InternalError("Unsupported kernel arguments type");
-}
-
DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) {
if (memory_space == 1) {
auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size);
@@ -588,47 +489,6 @@
GpuDriver::DeviceDeallocate(context_, mem->opaque());
}
-// CUDA allocation/registration functions are necessary because the driver
-// internally sets up buffers for DMA operations (and page locks them). There's
-// no external interface for us to otherwise control these DMA settings.
-absl::StatusOr<std::unique_ptr<MemoryAllocation>>
-GpuExecutor::HostMemoryAllocate(uint64_t size) {
- if (numa_node_ != tsl::port::kNUMANoAffinity) {
- auto* buffer =
- tsl::port::NUMAMalloc(numa_node_, size, /* minimum_alignment=*/16);
- if (buffer == nullptr && size > 0) {
- return absl::InternalError(absl::StrFormat(
- "Failed to allocate host memory of size %d pinned to NUMA node %d",
- size, numa_node_));
- }
- if (size > 0 && !GpuDriver::HostRegister(context_, buffer, size)) {
- return absl::InternalError(
- absl::StrFormat("Failed to register host memory of size %d pinned to "
- "NUMA node %d with the GPU driver",
- size, numa_node_));
- }
- return std::make_unique<HostMemoryAllocation>(buffer, size, this);
- } else {
- auto* buffer = GpuDriver::HostAllocate(context_, size);
- if (buffer == nullptr && size > 0) {
- return absl::InternalError(
- absl::StrFormat("Failed to allocate HostMemory of size %d", size));
- }
- return std::make_unique<HostMemoryAllocation>(buffer, size, this);
- }
-}
-
-void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) {
- if (numa_node_ != tsl::port::kNUMANoAffinity) {
- if (size > 0) {
- GpuDriver::HostUnregister(context_, location);
- }
- tsl::port::NUMAFree(location, size);
- } else {
- GpuDriver::HostDeallocate(context_, location);
- }
-}
-
bool GpuExecutor::SynchronizeAllActivity() {
return GpuDriver::SynchronizeContext(context_);
}
@@ -829,22 +689,22 @@
GpuContext* GpuExecutor::gpu_context() { return context_; }
// Attempts to read the NUMA node corresponding to the GPU device's PCI bus out
-// of SysFS.
+// of SysFS. Returns -1 if it cannot.
//
// For anything more complicated/prod-focused than this, you'll likely want to
-// turn to gsys' topology modeling. nvmlDeviceGetMemoryAffinity could also be
-// used.
-static std::optional<int> TryToReadNumaNode(const std::string& pci_bus_id,
- int device_ordinal) {
+// turn to gsys' topology modeling.
+static int TryToReadNumaNode(const std::string& pci_bus_id,
+ int device_ordinal) {
#if defined(PLATFORM_WINDOWS)
// Windows support for NUMA is not currently implemented. Return node 0.
return 0;
#else
VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal;
+ static const int kUnknownNumaNode = -1;
if (pci_bus_id.empty()) {
LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal;
- return std::nullopt;
+ return kUnknownNumaNode;
}
std::string filename =
@@ -857,7 +717,7 @@
if (file == nullptr) {
LOG(INFO) << "could not open file to read NUMA node: " << filename
<< "\nYour kernel may have been built without NUMA support.";
- return std::nullopt;
+ return kUnknownNumaNode;
}
std::string content;
@@ -868,6 +728,17 @@
int32_t value;
if (absl::SimpleAtoi(content, &value)) {
+ if (value < 0) { // See http://b/18228951 for details on this path.
+ LOG(INFO) << "successful NUMA node read from SysFS had negative value ("
+ << value
+ << "), but there must be at least one NUMA node"
+ ", so returning NUMA node zero."
+ " See more at "
+ "https://github.com/torvalds/linux/blob/v6.0/Documentation/"
+ "ABI/testing/sysfs-bus-pci#L344-L355";
+ fclose(file);
+ return 0;
+ }
fclose(file);
return value;
}
@@ -877,7 +748,7 @@
<< content;
fclose(file);
- return std::nullopt;
+ return kUnknownNumaNode;
#endif
}
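
For context on the lookup above: assuming the conventional sysfs layout described in the Linux ABI document linked in the diff, the NUMA node of a PCI device can be read as sketched below. The helper name and the example bus ID are hypothetical; the real TryToReadNumaNode above additionally logs and, like this sketch, maps a negative sysfs value to node 0.

#include <fstream>
#include <iostream>
#include <string>

// Minimal sketch of the sysfs lookup performed above: read
// /sys/bus/pci/devices/<bus-id>/numa_node and return -1 when the file is
// missing or unparsable. A negative value in the file means the kernel did
// not report a node; treat it as node 0.
int ReadNumaNodeFromSysfs(const std::string& pci_bus_id) {
  std::ifstream file("/sys/bus/pci/devices/" + pci_bus_id + "/numa_node");
  int value = 0;
  if (!(file >> value)) return -1;
  return value < 0 ? 0 : value;
}

int main() {
  // Hypothetical bus ID purely for illustration.
  std::cout << ReadNumaNodeFromSysfs("0000:00:04.0") << "\n";
  return 0;
}
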
@@ -909,24 +780,8 @@
builder.set_pci_bus_id(pci_bus_id);
// Read the NUMA node corresponding to the PCI bus ID out of sysfs.
- std::optional<int> numa_node =
- TryToReadNumaNode(pci_bus_id, device_ordinal);
- if (numa_node.has_value()) {
- if (*numa_node < 0) { // See http://b/18228951 for details on this path.
- LOG(INFO)
- << "successful NUMA node read from SysFS had negative value ("
- << *numa_node
- << "), but there must be at least one NUMA node"
- ", so returning NUMA node zero."
- " See more at "
- "https://github.com/torvalds/linux/blob/v6.0/Documentation/"
- "ABI/testing/sysfs-bus-pci#L344-L355";
- numa_node = 0;
- }
- } else {
- numa_node = -1;
- }
- builder.set_numa_node(*numa_node);
+ int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal);
+ builder.set_numa_node(numa_node);
}
{
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc
index bdace57..16dab63 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc
@@ -15,19 +15,14 @@
#include "xla/stream_executor/cuda/cuda_platform.h"
-#include <algorithm>
-#include <cstdlib>
-#include <cstring>
#include <memory>
#include <string>
#include <utility>
-#include "absl/base/call_once.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
@@ -35,65 +30,16 @@
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform_manager.h"
+#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
namespace stream_executor {
namespace gpu {
-CudaPlatform::CudaPlatform()
- : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {}
+CudaPlatform::CudaPlatform() : name_("CUDA") {}
CudaPlatform::~CudaPlatform() {}
-// Due to legacy issues in user code, we can't currently call InpectNumaNodes
-// at module initialization time, because non-GPU programs still include this
-// plugin via various methods, so instead, it has to be init-on-reference.
-void CudaPlatform::InspectNumaNodes() {
- // To get NUMA node information, we need to create all executors, so we can
- // examine their device descriptions to see their bus assignments.
- static absl::once_flag once;
- absl::call_once(once, [&] {
- for (int i = 0; i < VisibleDeviceCount(); i++) {
- StreamExecutor* exec = *ExecutorForDevice(i);
- if (i == 0) {
- // NUMA nodes may not start at 0, so set the minimum node based on the
- // first executor we see.
- min_numa_node_ = exec->GetDeviceDescription().numa_node();
- limit_numa_node_ = min_numa_node_ + 1;
- } else {
- min_numa_node_ =
- std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
- limit_numa_node_ = std::max(
- limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
- }
- }
- });
-}
-
-int CudaPlatform::BusCount() {
- InspectNumaNodes();
- return limit_numa_node_ - min_numa_node_;
-}
-
-int CudaPlatform::DeviceToBus(int device_ordinal) {
- StreamExecutor* exec = *ExecutorForDevice(device_ordinal);
- return exec->GetDeviceDescription().numa_node() - min_numa_node_;
-}
-
-absl::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus(
- int bus_ordinal) {
- InspectNumaNodes();
- CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
- for (int i = 0; i < VisibleDeviceCount(); i++) {
- if (DeviceToBus(i) == bus_ordinal) {
- return *ExecutorForDevice(i);
- }
- }
-
- return absl::NotFoundError(
- absl::StrFormat("Executor for bus %d not found.", bus_ordinal));
-}
-
Platform::Id CudaPlatform::id() const { return cuda::kCudaPlatformId; }
int CudaPlatform::VisibleDeviceCount() const {
@@ -118,6 +64,12 @@
return GetExecutor(config);
}
+absl::StatusOr<StreamExecutor*> CudaPlatform::FindExisting(int ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ return executor_cache_.Get(config);
+}
+
absl::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
const StreamExecutorConfig& config) {
if (config.gpu_stream) {
@@ -133,24 +85,15 @@
absl::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = std::make_unique<GpuExecutor>(this, config.ordinal);
- auto init_status = executor->Init();
- if (!init_status.ok()) {
- return absl::InternalError(absl::StrFormat(
- "failed initializing StreamExecutor for CUDA device ordinal %d: %s",
- config.ordinal, init_status.ToString()));
- }
-
+ TF_RETURN_IF_ERROR(executor->Init());
return std::move(executor);
}
} // namespace gpu
static void InitializeCudaPlatform() {
- // Disabling leak checking, PlatformManager does not destroy its
- // registered platforms.
-
- std::unique_ptr<gpu::CudaPlatform> platform(new gpu::CudaPlatform);
- TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform)));
+ TF_CHECK_OK(
+ PlatformManager::RegisterPlatform(std::make_unique<gpu::CudaPlatform>()));
}
} // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h
index 153282b..ac33429 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h
@@ -22,7 +22,6 @@
#include "absl/status/statusor.h"
#include "xla/stream_executor/executor_cache.h"
#include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream_executor.h"
namespace stream_executor {
@@ -41,16 +40,6 @@
CudaPlatform();
~CudaPlatform() override;
- // CudaPlatform-specific functionality
- // Returns the number of distinct buses / NUMA nodes on the machine.
- int BusCount();
-
- // Returns the bus/NUMA node for the specified device ordinal.
- int DeviceToBus(int device_ordinal);
-
- // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
- absl::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);
-
// Platform interface implementation:
// Returns the same value as kCudaPlatform above.
Platform::Id id() const override;
@@ -64,32 +53,24 @@
int ordinal) const override;
absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+ absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
+  // Returns a device constructed with the options specified in "config"
+  // without looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
- const StreamExecutorConfig& config) override;
+ const StreamExecutorConfig& config);
private:
- // Determines the number of NUMA nodes and the assignment of executor to each.
- void InspectNumaNodes();
-
// This platform's name.
std::string name_;
// Cache of created executors.
ExecutorCache executor_cache_;
- // The smallest NUMA node value for any device managed by this machine
- // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
- // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./
- int min_numa_node_;
-
- // Larger than the NUMA node value for any device managed by this machine
- // manager.
- int limit_numa_node_;
-
CudaPlatform(const CudaPlatform&) = delete;
void operator=(const CudaPlatform&) = delete;
};
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc
new file mode 100644
index 0000000..b9621f7
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc
@@ -0,0 +1,48 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/stream_executor/cuda/cuda_platform.h"
+
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_map.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace stream_executor::gpu {
+namespace {
+
+TEST(CudaPlatformTest, FindExistingWorks) {
+ TF_ASSERT_OK_AND_ASSIGN(Platform * platform,
+ PlatformManager::PlatformWithName("CUDA"));
+ CHECK_GT(platform->VisibleDeviceCount(), 0);
+ for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
+ EXPECT_FALSE(platform->FindExisting(i).ok());
+ }
+ absl::flat_hash_map<int, StreamExecutor*> executors;
+ for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
+ TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->ExecutorForDevice(i));
+ executors[i] = executor;
+ }
+ EXPECT_EQ(executors.size(), platform->VisibleDeviceCount());
+ for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
+ TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->FindExisting(i));
+ EXPECT_EQ(executor, executors[i]);
+ }
+}
+
+} // namespace
+} // namespace stream_executor::gpu
diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h
index aa59af5..0a30c1a 100644
--- a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h
+++ b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h
@@ -29,6 +29,11 @@
} \
} while (false)
+// UIDs for cuDNN are unique identifiers of tensors within a graph. They are
+// assigned during graph construction; graph execution then takes a {uid:
+// buffer pointer} map defining the correspondence of buffers to tensors.
+// The UID assignment scheme can be arbitrary; at the moment, for simplicity,
+// XLA uses UID = (HLO operand number + 1).
int CuDnnTensorUID(int offset);
} // namespace gpu
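The comment above states the contract but not a call site. Below is a hedged sketch of how the {uid: buffer pointer} map could be assembled from operand buffers ordered by HLO operand number; BuildUidToBufferMap and the inline CuDnnTensorUID body are illustrative assumptions, not the XLA implementation.

#include <cstdint>
#include <unordered_map>
#include <vector>

namespace example {

// Assumed to match the scheme described above: UID = operand number + 1.
int CuDnnTensorUID(int offset) { return offset + 1; }

// Builds the uid -> buffer pointer map consumed at graph execution time.
std::unordered_map<int64_t, void*> BuildUidToBufferMap(
    const std::vector<void*>& operand_buffers) {
  std::unordered_map<int64_t, void*> uid_to_buffer;
  for (int i = 0; i < static_cast<int>(operand_buffers.size()); ++i) {
    uid_to_buffer[CuDnnTensorUID(i)] = operand_buffers[i];
  }
  return uid_to_buffer;
}

}  // namespace example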
diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h
index 09aad2f..016639d 100644
--- a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h
+++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h
@@ -18,7 +18,6 @@
#include "absl/status/statusor.h"
#include "xla/stream_executor/gpu/gpu_semaphore.h"
-#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/stream.h"
namespace stream_executor::gpu {
diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc
index 1803697..bedd416 100644
--- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc
+++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc
@@ -16,5 +16,7 @@
#include "xla/stream_executor/cuda/nvjitlink_support.h"
namespace stream_executor {
-bool IsLibNvJitLinkSupported() { return LIBNVJITLINK_SUPPORT; }
+bool IsLibNvJitLinkSupported() {
+ return LIBNVJITLINK_SUPPORT && CUDA_SUPPORTS_NVJITLINK;
+}
} // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h b/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h
index d6e28e9..12d5ae5 100644
--- a/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h
+++ b/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h
@@ -22,6 +22,7 @@
namespace stream_executor {
enum class PtxCompilationMethod {
+ kNvJitLink,
kNvPtxCompiler,
kPtxas,
};
@@ -30,6 +31,9 @@
static void AbslStringify(Sink& sink,
const PtxCompilationMethod& compilation_method) {
switch (compilation_method) {
+ case PtxCompilationMethod::kNvJitLink:
+ sink.Append("NvJitLink");
+ break;
case PtxCompilationMethod::kNvPtxCompiler:
sink.Append("NvPtxCompiler");
break;
diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc
index c295833..aae9406 100644
--- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc
+++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc
@@ -19,7 +19,6 @@
#include <cstdint>
#include <string>
-#include <utility>
#include <vector>
#include <gmock/gmock.h>
diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h b/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h
index 56dcdf1..aafc36d 100644
--- a/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h
+++ b/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h
@@ -26,11 +26,15 @@
kNone,
kNvLink,
kDriver,
+ kNvJitLink,
};
template <typename Sink>
void AbslStringify(Sink& sink, const PtxLinkingMethod& method) {
switch (method) {
+ case PtxLinkingMethod::kNvJitLink:
+ sink.Append("NvJitLink");
+ break;
case PtxLinkingMethod::kNvLink:
sink.Append("NvLink");
break;
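Because both PtxCompilationMethod and PtxLinkingMethod provide AbslStringify overloads, the new kNvJitLink values format through the standard Abseil entry points. A small usage sketch follows; DescribePtxToolchain is a hypothetical helper, not part of this change.

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "xla/stream_executor/cuda/ptx_compilation_method.h"
#include "xla/stream_executor/cuda/ptx_linking_method.h"

// StrCat picks up AbslStringify via ADL; StrFormat uses it through %v.
std::string DescribePtxToolchain(stream_executor::PtxCompilationMethod compile,
                                 stream_executor::PtxLinkingMethod link) {
  return absl::StrCat("compile=", compile,
                      absl::StrFormat(", link=%v", link));
}

// DescribePtxToolchain(PtxCompilationMethod::kNvJitLink,
//                      PtxLinkingMethod::kNvJitLink)
//   -> "compile=NvJitLink, link=NvJitLink"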
diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc
index 5a674a0..951b2f6 100644
--- a/third_party/xla/xla/stream_executor/dnn.cc
+++ b/third_party/xla/xla/stream_executor/dnn.cc
@@ -41,7 +41,7 @@
#include "xla/stream_executor/data_type.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/numeric_options.h"
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
#include "tsl/platform/ml_dtypes.h"
#include "tsl/protobuf/dnn.pb.h"
@@ -249,42 +249,6 @@
return absl::UnimplementedError("NormRunnerFromDesc not implemented.");
}
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHARunner>>
-DnnSupport::FusedMHARunnerFromDesc(
- Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
- const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor,
- const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
- const dnn::TensorDescriptor& output_descriptor,
- std::optional<dnn::TensorDescriptor> activation_descriptor,
- std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type) {
- return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented.");
-}
-
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
-DnnSupport::FusedMHABackwardRunnerFromDesc(
- Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
- const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
- const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
- const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
- const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
- const MatmulTensorDescriptor& d_output_descriptor,
- const TensorDescriptor& d_bmm1_lhs_descriptor,
- const TensorDescriptor& d_bmm1_rhs_descriptor,
- const TensorDescriptor& d_bmm2_rhs_descriptor,
- std::optional<dnn::TensorDescriptor> d_s_descriptor,
- std::optional<dnn::TensorDescriptor> d_bias_descriptor,
- std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
- std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type, bool force_deterministic) {
- return absl::UnimplementedError(
- "FusedMHABackwardRunnerFromDesc not implemented.");
-}
-
bool DnnSupport::GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/,
dnn::DataType /*output_type*/, Stream* /*stream*/,
diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h
index 72f4603..a2e1cd6 100644
--- a/third_party/xla/xla/stream_executor/dnn.h
+++ b/third_party/xla/xla/stream_executor/dnn.h
@@ -993,30 +993,6 @@
using NormSignature = void(std::vector<DeviceMemoryBase>);
using NormRunner = OpRunner<NormSignature>;
-using FusedMHASignature = void(DeviceMemoryBase /*BMM1_inputA_data*/,
- DeviceMemoryBase /* BMM1_inputB_data */,
- DeviceMemoryBase /* BMM2_inputA_data */,
- DeviceMemoryBase /* output_data */,
- DeviceMemoryBase /* bias_data */,
- DeviceMemoryBase /* activation_data */,
- DeviceMemoryBase /* seqlen_q_data */,
- DeviceMemoryBase /* seqlen_k_data */);
-using FusedMHARunner = OpRunner<FusedMHASignature>;
-
-using FusedMHABackwardSignature = void(
- DeviceMemoryBase /* BMM1_GRAD_GEMM1_inputA_data */,
- DeviceMemoryBase /* BMM1_GRAD_GEMM2_inputB_data */,
- DeviceMemoryBase /* BMM2_GRAD_GEMM1_inputA_data */,
- DeviceMemoryBase /* BMM2_GRAD_GEMM2_inputB_data */,
- DeviceMemoryBase /* d_output_data */,
- DeviceMemoryBase /* d_BMM1_inputA_data */,
- DeviceMemoryBase /* d_BMM1_inputB_data */,
- DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_S_data */,
- DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */,
- DeviceMemoryBase /* bias_data */, DeviceMemoryBase /* seqlen_q_data */,
- DeviceMemoryBase /* seqlen_k_data */);
-using FusedMHABackwardRunner = OpRunner<FusedMHABackwardSignature>;
-
// Describes the configuration for the algorithms that will used.
//
// Arguments:
@@ -1257,11 +1233,7 @@
DnnGraph() = default;
virtual ~DnnGraph() = default;
- // Returns non-OK status on hard failures (incorrectly constructed graph,
- // anything else unexpected),
- // false on expected ones (graph is valid but not supported),
- // true on success.
- virtual absl::Status Prepare(DnnSupport&) = 0;
+ virtual absl::Status Prepare(DnnSupport&, const NumericOptions&) = 0;
virtual absl::Status Build(DnnSupport&, std::optional<int64_t> plan_id) = 0;
virtual absl::Status Execute(Stream& stream,
absl::Span<DeviceMemoryBase> operands) const = 0;
@@ -1735,37 +1707,6 @@
return absl::UnimplementedError("Graph support requires cuDNN >= 8.1.");
};
- virtual absl::StatusOr<std::unique_ptr<const FusedMHARunner>>
- FusedMHARunnerFromDesc(
- Stream* stream, const AlgorithmDesc& algorithm_desc,
- const MatmulTensorDescriptor& bmm1_lhs_descriptor,
- const MatmulTensorDescriptor& bmm1_rhs_descriptor,
- const MatmulTensorDescriptor& bmm2_rhs_descriptor,
- const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
- const TensorDescriptor& output_descriptor,
- std::optional<TensorDescriptor> activation_descriptor,
- std::optional<TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type);
-
- virtual absl::StatusOr<std::unique_ptr<const FusedMHABackwardRunner>>
- FusedMHABackwardRunnerFromDesc(
- Stream* stream, const AlgorithmDesc& algorithm_desc,
- const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
- const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
- const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
- const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
- const MatmulTensorDescriptor& d_output_descriptor,
- const TensorDescriptor& d_bmm1_lhs_descriptor,
- const TensorDescriptor& d_bmm1_rhs_descriptor,
- const TensorDescriptor& d_bmm2_rhs_descriptor,
- std::optional<TensorDescriptor> d_s_descriptor,
- std::optional<TensorDescriptor> d_bias_descriptor,
- std::optional<TensorDescriptor> fwd_output_descriptor,
- std::optional<TensorDescriptor> bias_descriptor, double scale,
- std::optional<double> dropout_rate, std::optional<int64_t> seed,
- dnn::FMHAMaskKind mask_type, bool force_deterministic);
-
virtual bool GetMIOpenConvolveAlgorithms(
ConvolutionKind kind, DataType element_type, DataType output_type,
Stream* stream, const BatchDescriptor& input_descriptor,
diff --git a/third_party/xla/xla/stream_executor/executor_cache.cc b/third_party/xla/xla/stream_executor/executor_cache.cc
index eae7206..341af6f 100644
--- a/third_party/xla/xla/stream_executor/executor_cache.cc
+++ b/third_party/xla/xla/stream_executor/executor_cache.cc
@@ -25,11 +25,12 @@
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
namespace stream_executor {
ExecutorCache::ExecutorCache() = default;
-ExecutorCache::~ExecutorCache() { DestroyAllExecutors(); }
+ExecutorCache::~ExecutorCache() = default;
absl::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
const StreamExecutorConfig& config, const ExecutorFactory& factory) {
@@ -40,85 +41,36 @@
return fast_result;
}
- Entry* entry = nullptr;
- {
- absl::MutexLock lock{&mutex_};
- entry = &cache_[config.ordinal];
- // Release the map lock; the address of 'entry' is stable because
- // absl::node_hash_map guarantees reference stability.
- }
-
- // Acquire the per-Entry mutex without holding the map mutex. Initializing
- // an Executor may be expensive, so we want to allow concurrent
- // initialization of different entries.
- absl::MutexLock lock{&entry->configurations_mutex};
- for (const auto& iter : entry->configurations) {
- VLOG(2) << "hit in cache";
- return iter.second.get();
- }
-
VLOG(2) << "building executor";
- absl::StatusOr<std::unique_ptr<StreamExecutor>> result = factory();
- if (!result.ok()) {
- VLOG(2) << "failed to get build executor: " << result.status();
- // If construction failed, leave the cache Entry around, but with a null
- // executor.
- return result.status();
- }
- entry->configurations.emplace_back(config, std::move(result.value()));
- return entry->configurations.back().second.get();
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<StreamExecutor> result, factory());
+ auto returned_executor = result.get();
+ absl::MutexLock lock(&mutex_);
+ cache_.emplace(config.ordinal, std::move(result));
+ return returned_executor;
}
absl::StatusOr<StreamExecutor*> ExecutorCache::Get(
const StreamExecutorConfig& config) {
- Entry* entry = nullptr;
- {
- absl::ReaderMutexLock lock{&mutex_};
+ absl::ReaderMutexLock lock{&mutex_};
- // If gpu stream is not nullptr we have to find StreamExecutor that owns it,
- // and return NOT_FOUND error if we can't find it.
- if (config.gpu_stream) {
- for (auto& [ordinal, e] : cache_) {
- absl::ReaderMutexLock l{&e.configurations_mutex};
- for (auto& [c, executor] : e.configurations) {
- if (executor->FindAllocatedStream(config.gpu_stream)) {
- return executor.get();
- }
- }
+ // If gpu stream is not nullptr we have to find StreamExecutor that owns it,
+ // and return NOT_FOUND error if we can't find it.
+ if (config.gpu_stream) {
+ for (auto& [ordinal, executor] : cache_) {
+ if (executor->FindAllocatedStream(config.gpu_stream)) {
+ return executor.get();
}
- return absl::NotFoundError(
- absl::StrFormat("No executors own stream %p", config.gpu_stream));
}
-
- if (auto it = cache_.find(config.ordinal); it != cache_.end()) {
- entry = &it->second;
- } else {
- return absl::NotFoundError(absl::StrFormat(
- "No executors registered for ordinal %d", config.ordinal));
- }
+ return absl::NotFoundError(
+ absl::StrFormat("No executors own stream %p", config.gpu_stream));
}
- absl::ReaderMutexLock lock{&entry->configurations_mutex};
- if (entry->configurations.empty()) {
- return absl::NotFoundError(absl::StrFormat(
- "No executors registered for ordinal %d", config.ordinal));
+ if (auto it = cache_.find(config.ordinal); it != cache_.end()) {
+ return it->second.get();
}
- for (auto& [entry_config, entry_executor] : entry->configurations) {
- return entry_executor.get();
- }
-
- return absl::NotFoundError("No executor found with a matching config.");
-}
-
-void ExecutorCache::DestroyAllExecutors() {
- absl::MutexLock lock{&mutex_};
- cache_.clear();
-}
-
-ExecutorCache::Entry::~Entry() {
- absl::MutexLock lock{&configurations_mutex};
- configurations.clear();
+ return absl::NotFoundError(absl::StrFormat(
+ "No executors registered for ordinal %d", config.ordinal));
}
} // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/executor_cache.h b/third_party/xla/xla/stream_executor/executor_cache.h
index 6e7f32e..ae62c6d 100644
--- a/third_party/xla/xla/stream_executor/executor_cache.h
+++ b/third_party/xla/xla/stream_executor/executor_cache.h
@@ -18,11 +18,9 @@
#include <functional>
#include <memory>
-#include <utility>
-#include <vector>
#include "absl/base/thread_annotations.h"
-#include "absl/container/node_hash_map.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/platform.h"
@@ -52,33 +50,13 @@
// has been created), or a NOT_FOUND status.
absl::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
- // Destroys all Executors and clears the cache.
- // Performs no synchronization with the executors - undefined behavior may
- // occur if any executors are active!
- void DestroyAllExecutors();
-
private:
- // Each Entry contains zero or more cached executors for a device ordinal.
- struct Entry {
- ~Entry();
-
- // Mutex that guards the contents of each entry. The 'mutex_' of the
- // ExecutorCache class protects both the 'cache_' and the existence of each
- // Entry, but not the Entry's contents. 'configurations_mutex' protects the
- // contents of the entry after 'mutex_' has been dropped.
- absl::Mutex configurations_mutex;
-
- // Vector of cached {config, executor} pairs.
- std::vector<
- std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>>
- configurations ABSL_GUARDED_BY(configurations_mutex);
- };
-
- // Maps ordinal number to a list of cached executors for that ordinal.
- // We key off of ordinal (instead of just looking up all fields in the
- // StreamExecutorConfig) for a slight improvement in lookup time.
+ // Protects cache_.
absl::Mutex mutex_;
- absl::node_hash_map<int, Entry> cache_ ABSL_GUARDED_BY(mutex_);
+
+ // Maps ordinal number to a cached executor for that ordinal.
+ absl::flat_hash_map<int, std::unique_ptr<StreamExecutor>> cache_
+ ABSL_GUARDED_BY(mutex_);
ExecutorCache(const ExecutorCache&) = delete;
void operator=(const ExecutorCache&) = delete;
diff --git a/third_party/xla/xla/stream_executor/executor_cache_test.cc b/third_party/xla/xla/stream_executor/executor_cache_test.cc
new file mode 100644
index 0000000..71e9f72
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/executor_cache_test.cc
@@ -0,0 +1,128 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/stream_executor/executor_cache.h"
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/stream_executor/mock_stream_executor.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream.h"
+#include "tsl/platform/statusor.h"
+
+namespace stream_executor {
+namespace {
+
+TEST(ExecutorCacheTest, GetOnEmptyCacheFails) {
+ ExecutorCache cache;
+ StreamExecutorConfig config;
+ config.ordinal = 0;
+ EXPECT_FALSE(cache.Get(config).ok());
+}
+
+TEST(ExecutorCacheTest, GetViaStreamOnEmptyCacheFails) {
+ ExecutorCache cache;
+ StreamExecutorConfig config;
+ config.ordinal = 0;
+ config.gpu_stream = reinterpret_cast<void *>(0x1234);
+ EXPECT_FALSE(cache.Get(config).ok());
+}
+
+TEST(ExecutorCacheTest, GetOrCreateConstructsAndRepeatedlyReturns) {
+ ExecutorCache cache;
+ StreamExecutorConfig config;
+ config.ordinal = 0;
+ StreamExecutor *created = nullptr;
+ auto factory = [&created]() {
+ auto executor = std::make_unique<MockStreamExecutor>();
+ created = executor.get();
+ return executor;
+ };
+ TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory));
+ EXPECT_EQ(executor, created);
+ TF_ASSERT_OK_AND_ASSIGN(auto found, cache.GetOrCreate(config, factory));
+ EXPECT_EQ(found, created);
+ TF_ASSERT_OK_AND_ASSIGN(found, cache.Get(config));
+ EXPECT_EQ(found, created);
+}
+
+TEST(ExecutorCacheTest, GetViaStreamFailsIfNotFound) {
+ ExecutorCache cache;
+ StreamExecutorConfig config;
+ config.ordinal = 0;
+ StreamExecutor *created = nullptr;
+ void *expected_stream = reinterpret_cast<void *>(0x1234);
+ auto factory = [&created, &expected_stream]() {
+ auto executor = std::make_unique<MockStreamExecutor>();
+ EXPECT_CALL(*executor, FindAllocatedStream(expected_stream))
+ .WillRepeatedly(testing::Return(nullptr));
+ created = executor.get();
+ return executor;
+ };
+
+ // Create the executor.
+ TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory));
+ EXPECT_EQ(executor, created);
+  // Now look for the expected stream, and don't expect to find it.
+ config.gpu_stream = expected_stream;
+ EXPECT_FALSE(cache.Get(config).ok());
+}
+
+TEST(ExecutorCacheTest, GetViaStreamWorksOnSecondStream) {
+ ExecutorCache cache;
+ StreamExecutorConfig config;
+ config.ordinal = 0;
+ StreamExecutor *created = nullptr;
+ Stream *expected_stream = reinterpret_cast<Stream *>(0x1234);
+
+ // Create a factory that will make the second StreamExecutor find the
+ // expected_stream.
+ auto factory = [&created, &expected_stream]() {
+ static int count = 0;
+ auto executor = std::make_unique<MockStreamExecutor>();
+ if (count != 1) {
+ EXPECT_CALL(*executor, FindAllocatedStream(expected_stream))
+ .WillRepeatedly(testing::Return(nullptr));
+ } else {
+ created = executor.get();
+ EXPECT_CALL(*executor, FindAllocatedStream(expected_stream))
+ .WillRepeatedly(testing::Invoke(
+ [expected_stream](void *stream) { return expected_stream; }));
+ }
+ ++count;
+ return executor;
+ };
+
+ // Create four executors.
+ std::vector<StreamExecutor *> created_executors;
+ for (int i = 0; i < 4; ++i) {
+ config.ordinal = i;
+ TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory));
+ EXPECT_NE(executor, nullptr);
+ created_executors.push_back(executor);
+ }
+ EXPECT_EQ(created_executors.size(), 4);
+  // Now look for the expected stream, and expect to find it owned by the
+  // second executor.
+ config.gpu_stream = expected_stream;
+ TF_ASSERT_OK_AND_ASSIGN(auto found, cache.Get(config));
+ EXPECT_EQ(found, created);
+}
+
+} // namespace
+} // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD
index ae1ad8c..2d763e5 100644
--- a/third_party/xla/xla/stream_executor/gpu/BUILD
+++ b/third_party/xla/xla/stream_executor/gpu/BUILD
@@ -228,14 +228,12 @@
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
- "@local_tsl//tsl/platform:platform_port",
"@local_tsl//tsl/platform:thread_annotations",
],
)
@@ -244,8 +242,7 @@
name = "gpu_helpers_header",
hdrs = ["gpu_helpers.h"],
deps = [
- ":gpu_types_header",
- "@local_tsl//tsl/platform:logging",
+ "//xla/stream_executor:device_memory",
],
)
@@ -311,10 +308,13 @@
deps = [
":gpu_event_header",
":gpu_executor_header",
+ ":gpu_kernel_header",
":gpu_types_header",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:event",
"//xla/stream_executor:event_based_timer",
+ "//xla/stream_executor:kernel",
+ "//xla/stream_executor:launch_dim",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_common",
@@ -332,10 +332,13 @@
":gpu_driver_header",
":gpu_event_header",
":gpu_executor_header",
+ ":gpu_kernel_header",
":gpu_types_header",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:event",
"//xla/stream_executor:event_based_timer",
+ "//xla/stream_executor:kernel",
+ "//xla/stream_executor:launch_dim",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_common",
@@ -346,6 +349,7 @@
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:nvtx_utils",
],
)
@@ -378,7 +382,6 @@
":gpu_stream",
":gpu_types_header",
"//xla/stream_executor",
- "//xla/stream_executor:event",
"//xla/stream_executor:event_based_timer",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log",
@@ -387,7 +390,6 @@
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
- "@com_google_absl//absl/utility",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
@@ -557,8 +559,6 @@
name = "redzone_allocator_test",
srcs = ["redzone_allocator_test.cc"],
backends = ["gpu"],
- # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
- tags = ["gpu"],
deps = [
":gpu_asm_opts",
":gpu_init",
@@ -604,11 +604,7 @@
name = "gpu_cudamallocasync_allocator_test",
srcs = ["gpu_cudamallocasync_allocator_test.cc"],
backends = ["gpu_any"],
- tags = [
- # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
- "gpu",
- "no_rocm",
- ],
+ tags = ["no_rocm"],
deps = [
":gpu_cudamallocasync_allocator",
":gpu_stream",
@@ -740,8 +736,6 @@
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
- # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
- tags = ["gpu"],
deps = [
"//xla/stream_executor",
"//xla/stream_executor:device_memory",
@@ -765,8 +759,6 @@
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
- # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
- tags = ["gpu"],
deps = [
"//xla/stream_executor",
"//xla/stream_executor/host:host_platform",
@@ -789,8 +781,6 @@
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
- # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
- tags = ["gpu"],
deps = [
"//xla/service:platform_util",
"//xla/stream_executor:platform",
@@ -799,7 +789,6 @@
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
- "@local_tsl//tsl/platform:platform_port",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
] + if_cuda([
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
index b4bf7eb..20caccb 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
@@ -25,6 +25,7 @@
#include <vector>
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/host_or_device_scalar.h"
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
index d5d9ebd..c672b11 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -1006,6 +1006,8 @@
} else {
TF_RETURN_IF_ERROR(retry);
}
+ } else {
+ TF_RETURN_IF_ERROR(instantiated);
}
uint64_t end_nanos = tsl::Env::Default()->NowNanos();
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
index cd32935..0376a5a 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
@@ -776,6 +776,69 @@
ASSERT_EQ(dst, expected);
}
+TEST(GpuCommandBufferTest, ConditionalIfWithMemset) {
+#if CUDA_VERSION < 12040
+ GTEST_SKIP() << "ConditionalsWithMemset are not supported before 12.4.1.";
+#endif
+ Platform* platform = GpuPlatform();
+
+ StreamExecutor* executor = platform->ExecutorForDevice(0).value();
+
+ TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
+
+ int64_t length = 4;
+ int64_t byte_length = sizeof(int32_t) * length;
+
+ // Prepare arguments: a=0, pred=true
+ DeviceMemory<bool> pred = executor->AllocateArray<bool>(1, 0);
+ DeviceMemory<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
+
+ constexpr bool kTrue = true;
+ TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1));
+ TF_ASSERT_OK(stream->Memset32(&a, 0, byte_length));
+
+ // if (pred == true) memset(&a, ...);
+ CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) {
+ return then_cmd->Memset(&a, uint8_t{1}, byte_length);
+ };
+
+ // Create a command buffer with a single conditional operation.
+ TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
+ executor->CreateCommandBuffer(primary));
+ TF_ASSERT_OK(cmd_buffer->If(pred, then_builder));
+ TF_ASSERT_OK(cmd_buffer->Finalize());
+
+ TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
+
+ // Copy `a` data back to host.
+ std::vector<int32_t> dst(length, 42);
+ TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length));
+
+ std::vector<int32_t> expected(length, 1 << 24 | 1 << 16 | 1 << 8 | 1);
+ ASSERT_EQ(dst, expected);
+
+ // Prepare argument for graph update: b = 0
+ DeviceMemory<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
+ TF_ASSERT_OK(stream->MemZero(&a, byte_length));
+
+ // if (pred == true) memset(&b, ...);
+ then_builder = [&](CommandBuffer* then_cmd) {
+ return then_cmd->Memset(&b, uint8_t{1}, byte_length);
+ };
+
+ // Update command buffer with a conditional to use new builder.
+ TF_ASSERT_OK(cmd_buffer->Update());
+ TF_ASSERT_OK(cmd_buffer->If(pred, then_builder));
+ TF_ASSERT_OK(cmd_buffer->Finalize());
+
+ TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
+
+ // Copy `b` data back to host.
+ std::fill(dst.begin(), dst.end(), 42);
+ TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length));
+ ASSERT_EQ(dst, expected);
+}
+
TEST(GpuCommandBufferTest, ConditionalIfElse) {
if (!IsAtLeastCuda12300()) {
GTEST_SKIP() << "CUDA graph conditionals are not supported";
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h
index 599480c..94cff46 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h
@@ -31,6 +31,7 @@
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
namespace stream_executor {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
index 116120d..bfb7e57 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
@@ -1,4 +1,3 @@
-#include "xla/stream_executor/event_based_timer.h"
/* Copyright 2019 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,10 +30,10 @@
#include <unordered_map>
#include <utility>
#include <variant>
+#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
-#include "absl/functional/any_invocable.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -46,6 +45,7 @@
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/event.h"
+#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/fft.h"
#include "xla/stream_executor/gpu/gpu_collectives.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
@@ -59,7 +59,6 @@
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_common.h"
-#include "tsl/platform/numa.h"
#include "tsl/platform/thread_annotations.h"
namespace stream_executor {
@@ -113,8 +112,7 @@
device_ordinal_(device_ordinal),
cc_major_(0),
cc_minor_(0),
- version_(0),
- numa_node_(tsl::port::kNUMANoAffinity) {}
+ version_(0) {}
// See the corresponding StreamExecutor methods for method comments on the
// following overrides.
@@ -140,15 +138,6 @@
absl::StatusOr<std::shared_ptr<DeviceMemoryBase>> CreateOrShareConstant(
Stream* stream, absl::Span<const uint8_t> content) override;
- absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims, const Kernel& kernel,
- const KernelArgs& args) override;
-
- absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const ClusterDim& cluster_dims, const Kernel& kernel,
- const KernelArgs& args) override;
-
DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
void Deallocate(DeviceMemoryBase* mem) override;
@@ -169,10 +158,23 @@
return GpuCollectives::CollectiveMemoryDeallocate(context_, location);
}
+ // CUDA allocation/registration functions are necessary because the driver
+ // internally sets up buffers for DMA operations (and page locks them).
+ // There's no external interface for us to otherwise control these DMA
+ // settings.
absl::StatusOr<std::unique_ptr<MemoryAllocation>> HostMemoryAllocate(
- uint64_t size) override;
+ uint64_t size) override {
+ auto* buffer = GpuDriver::HostAllocate(context_, size);
+ if (buffer == nullptr && size > 0) {
+ return absl::InternalError(
+ absl::StrFormat("Failed to allocate HostMemory of size %d", size));
+ }
+ return std::make_unique<HostMemoryAllocation>(buffer, size, this);
+ }
- void HostMemoryDeallocate(void* location, uint64_t size) override;
+ void HostMemoryDeallocate(void* location) override {
+ return GpuDriver::HostDeallocate(context_, location);
+ }
absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override {
return GpuDriver::GetPointerMemorySpace(
@@ -305,11 +307,6 @@
absl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module)
TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
- absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const std::optional<ClusterDim>& cluster_dims,
- const Kernel& kernel, const KernelArgs& args);
-
bool UnloadGpuBinary(const void* gpu_binary)
TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
@@ -369,9 +366,6 @@
// GPU ISA version for device_.
int version_;
- // NUMA node for device_.
- int numa_node_;
-
// Type erased XLA specific state attached to GpuExecutor.
Object xla_state_;
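The HostMemoryAllocate/HostMemoryDeallocate overrides above hand out driver-pinned (page-locked) host memory, which is what makes asynchronous DMA copies safe. A minimal usage sketch under that assumption; CopyThroughPinnedBuffer is a hypothetical helper built only from the StreamExecutor API used elsewhere in this change.

#include <cstdint>
#include <cstring>

#include "absl/status/status.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

// Stages `size` bytes into a pinned buffer, then enqueues an async H2D copy.
// The pinned allocation is released when `pinned` goes out of scope, which
// routes back to HostMemoryDeallocate on `executor`.
absl::Status CopyThroughPinnedBuffer(stream_executor::StreamExecutor* executor,
                                     stream_executor::Stream* stream,
                                     stream_executor::DeviceMemoryBase* dst,
                                     const void* src, uint64_t size) {
  TF_ASSIGN_OR_RETURN(auto pinned, executor->HostMemoryAllocate(size));
  std::memcpy(pinned->opaque(), src, size);
  TF_RETURN_IF_ERROR(stream->Memcpy(dst, pinned->opaque(), size));
  // Block so the pinned staging buffer stays alive until the copy completes.
  return stream->BlockHostUntilDone();
}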
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc
index 9ac7be1..c3c67bc 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc
@@ -20,7 +20,6 @@
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream_executor.h"
-#include "tsl/platform/numa.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
@@ -55,32 +54,4 @@
executor->Deallocate(&mem);
}
-using HostMemoryAllocateTest = GpuExecutorTest;
-
-TEST_F(HostMemoryAllocateTest, Numa) {
- Platform* platform = GetPlatform();
- const uint64_t kSize = 1024;
- const int num_devices = platform->VisibleDeviceCount();
- for (int device = 0; device < num_devices; ++device) {
- TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor,
- platform->ExecutorForDevice(device));
- ASSERT_TRUE(executor);
- TF_ASSERT_OK_AND_ASSIGN(auto device_desc,
- executor->CreateDeviceDescription());
- ASSERT_TRUE(device_desc);
- TF_ASSERT_OK_AND_ASSIGN(auto host_ptr, executor->HostMemoryAllocate(kSize));
- ASSERT_TRUE(host_ptr);
- EXPECT_NE(host_ptr->opaque(), nullptr);
- const int numa_node = tsl::port::NUMAGetMemAffinity(host_ptr->opaque());
- if (numa_node == tsl::port::kNUMANoAffinity) {
- // Could be because `executor` could not determine its own NUMA node, in
- // which case numa_node() will be -1 or 0, depending on the failure mode.
- EXPECT_LE(device_desc->numa_node(), 0);
- EXPECT_GE(device_desc->numa_node(), -1);
- } else {
- EXPECT_EQ(device_desc->numa_node(), numa_node);
- }
- }
-}
-
} // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h
index 62db127..187d882 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h
@@ -23,17 +23,10 @@
#include <stddef.h>
-#include <complex>
-#include <cstdint>
-
-#include "xla/stream_executor/gpu/gpu_types.h"
-#include "tsl/platform/logging.h"
+#include "xla/stream_executor/device_memory.h"
namespace stream_executor {
-template <typename ElemT>
-class DeviceMemory;
-
namespace gpu {
// Converts a const DeviceMemory reference to its underlying typed pointer in
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
index 7068265..b257ffa 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
@@ -17,6 +17,7 @@
#include <cstdint>
#include <memory>
+#include <optional>
#include <utility>
#include <variant>
@@ -31,10 +32,14 @@
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
+#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/nvtx_utils.h"
namespace stream_executor {
@@ -195,6 +200,83 @@
return parent_->CreateEventBasedTimer(this, use_delay_kernel);
}
+absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
+ const BlockDim& block_dims, const Kernel& kernel,
+ const KernelArgs& args) {
+ return Launch(thread_dims, block_dims, std::nullopt, kernel, args);
+}
+
+absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
+ const BlockDim& block_dims,
+ const ClusterDim& cluster_dims,
+ const Kernel& kernel, const KernelArgs& args) {
+ return Launch(thread_dims, block_dims, std::make_optional(cluster_dims),
+ kernel, args);
+}
+
+absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
+ const BlockDim& block_dims,
+ const std::optional<ClusterDim>& cluster_dims,
+ const Kernel& kernel, const KernelArgs& args) {
+ const GpuKernel* gpu_kernel = AsGpuKernel(&kernel);
+ GpuFunctionHandle function = gpu_kernel->gpu_function();
+
+ if (gpu_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
+ TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
+ function, gpu_kernel->GetGpuCacheConfig()));
+ }
+
+ // Launch kernels with packed arguments.
+ auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims,
+ &function](const KernelArgsPackedArrayBase& packed) {
+ int32_t expected_number_of_arguments =
+ kernel.Arity() + (packed.number_of_shared_bytes() > 0);
+
+ CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments())
+ << "Kernel " << kernel.name() << " has " << packed.number_of_arguments()
+ << " arguments, but expected " << expected_number_of_arguments
+ << "; arity=" << kernel.Arity()
+ << "; number_of_shared_bytes=" << packed.number_of_shared_bytes();
+
+ void** params = const_cast<void**>(packed.argument_addresses().data());
+
+ if (cluster_dims.has_value()) {
+ return GpuDriver::LaunchKernel(
+ parent_->gpu_context(), kernel.name(), function, cluster_dims->x,
+ cluster_dims->y, cluster_dims->z, block_dims.x, block_dims.y,
+ block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
+ packed.number_of_shared_bytes(), gpu_stream(), params,
+ /*extra=*/nullptr);
+ } else {
+ return GpuDriver::LaunchKernel(
+ parent_->gpu_context(), kernel.name(), function, block_dims.x,
+ block_dims.y, block_dims.z, thread_dims.x, thread_dims.y,
+ thread_dims.z, packed.number_of_shared_bytes(), gpu_stream(), params,
+ /*extra=*/nullptr);
+ }
+ };
+
+ // If arguments are already packed we can just launch the kernel.
+ if (auto* packed = DynCast<KernelArgsPackedArrayBase>(&args)) {
+ return launch(*packed);
+ }
+
+ // For device memory array we rely on a custom kernel arguments packing.
+ if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
+ auto& pack = kernel.args_packing();
+ if (!pack) {
+ return absl::InternalError(
+ "Kernel is missing a custom arguments packing function for device "
+ "memory arguments array");
+ }
+
+ TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem));
+ return launch(*packed);
+ }
+
+ return absl::InternalError("Unsupported kernel arguments type");
+}
+
GpuStream* AsGpuStream(Stream* stream) {
DCHECK(stream != nullptr);
return static_cast<GpuStream*>(stream);
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
index 4cf21ca..249fbf7 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
@@ -34,6 +34,8 @@
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_common.h"
@@ -103,8 +105,18 @@
void set_name(absl::string_view name) override;
absl::StatusOr<std::unique_ptr<EventBasedTimer>> CreateEventBasedTimer(
bool use_delay_kernel) override;
+ absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+ const Kernel& k, const KernelArgs& args) override;
+ absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+ const ClusterDim& cluster_dims, const Kernel& k,
+ const KernelArgs& args) override;
private:
+ // Helper method to launch a kernel with optional cluster dimensions.
+ absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+ const std::optional<ClusterDim>& cluster_dims,
+ const Kernel& kernel, const KernelArgs& args);
+
GpuExecutor* parent_; // Executor that spawned this stream.
GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle.
std::variant<StreamPriority, int> stream_priority_;
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h
index be0f9a5..656dd1e 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h
@@ -21,12 +21,9 @@
#include "absl/status/statusor.h"
#include "absl/time/time.h"
-#include "xla/stream_executor/event.h"
#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_semaphore.h"
-#include "xla/stream_executor/gpu/gpu_types.h"
-#include "xla/stream_executor/stream.h"
namespace xla {
namespace gpu {
diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD
index 326b3d6..8db3751 100644
--- a/third_party/xla/xla/stream_executor/host/BUILD
+++ b/third_party/xla/xla/stream_executor/host/BUILD
@@ -81,8 +81,11 @@
],
deps = [
":host_event",
+ ":host_kernel",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:event",
+ "//xla/stream_executor:kernel",
+ "//xla/stream_executor:launch_dim",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_common",
"@com_google_absl//absl/base:core_headers",
@@ -166,14 +169,17 @@
":host_event",
":host_kernel",
":host_stream",
- "//xla/stream_executor",
+ "//xla/stream_executor:device_description",
+ "//xla/stream_executor:device_memory",
"//xla/stream_executor:event",
"//xla/stream_executor:host_memory_allocation",
+ "//xla/stream_executor:kernel",
"//xla/stream_executor:kernel_spec",
"//xla/stream_executor:memory_allocation",
+ "//xla/stream_executor:platform",
+ "//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_common",
"//xla/stream_executor:stream_executor_h",
- "@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc
index 38715ce..68562b7 100644
--- a/third_party/xla/xla/stream_executor/host/host_executor.cc
+++ b/third_party/xla/xla/stream_executor/host/host_executor.cc
@@ -22,17 +22,17 @@
#include <cstdint>
#include <memory>
+#include <optional>
#include <string>
#include <utility>
+#include <variant>
#include <vector>
-#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
-#include "absl/synchronization/notification.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
@@ -40,9 +40,8 @@
#include "xla/stream_executor/host/host_kernel.h"
#include "xla/stream_executor/host/host_stream.h"
#include "xla/stream_executor/kernel_spec.h"
-#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/stream_executor.h"
+#include "xla/stream_executor/stream.h"
#include "tsl/platform/cpu_info.h"
#include "tsl/platform/env.h"
#include "tsl/platform/mem.h"
@@ -91,26 +90,6 @@
return absl::InternalError("No method of loading host kernel provided");
}
-absl::Status HostExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const Kernel& kernel,
- const KernelArgs& args) {
- const HostKernel* host_kernel = AsHostKernel(&kernel);
-
- const KernelArgsDeviceMemoryArray* device_mem =
- DynCast<KernelArgsDeviceMemoryArray>(&args);
-
- absl::Status result;
- if (device_mem != nullptr) {
- result = host_kernel->Launch(thread_dims, device_mem->device_memory_args());
- } else {
- result = absl::UnimplementedError(
- "Host kernel implements Launch method only for DeviceMemoryArray "
- "arguments.");
- }
- return result;
-}
-
bool HostExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
tsl::port::MemoryInfo mem_info = tsl::port::GetMemoryInfo();
*free = (mem_info.free != INT64_MAX) ? mem_info.free : -1;
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h
index 5f1c5d0..55eacc5 100644
--- a/third_party/xla/xla/stream_executor/host/host_executor.h
+++ b/third_party/xla/xla/stream_executor/host/host_executor.h
@@ -13,20 +13,15 @@
limitations under the License.
==============================================================================*/
-// Declares the HostExecutor class, which is a CPU-only implementation of
-// the StreamExecutor interface. For now, this is used for testing and to
-// examine the performance of host-based StreamExecutor code.
#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_
#define XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_
-#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <variant>
-#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/stream_executor/device_description.h"
@@ -36,24 +31,21 @@
#include "xla/stream_executor/host_memory_allocation.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
-#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor_common.h"
#include "tsl/platform/threadpool.h"
namespace stream_executor {
namespace host {
-// An implementation of StreamExecutor that does no communication or interaction
-// with a device, but DOES perform memory operations backed by the host.
-// Kernel invocations will fail, but host callbacks may be enqueued on this
-// executor and its associated stream, and should follow standard ordering
-// semantics.
+// Declares the HostExecutor class, which is a CPU-only implementation of
+// the StreamExecutor interface. For now, this is used for testing and to
+// examine the performance of host-based StreamExecutor code.
//
// This is useful for evaluating the performance of host-based or fallback
// routines executed under the context of a GPU executor.
-// See stream_executor.h for description of the below operations.
class HostExecutor : public StreamExecutorCommon {
public:
// A function that loads a kernel function from a given spec. If spec is not
@@ -73,10 +65,6 @@
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec& spec) override;
- absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims, const Kernel& kernel,
- const KernelArgs& args) override;
-
DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
void Deallocate(DeviceMemoryBase* mem) override;
@@ -84,11 +72,10 @@
uint64_t size) override {
return std::make_unique<HostMemoryAllocation>(new char[size], size, this);
}
- void HostMemoryDeallocate(void* mem, uint64_t size) override {
+ void HostMemoryDeallocate(void* mem) override {
delete[] static_cast<char*>(mem);
}
- // No "synchronize all activity" implemented for this platform at the moment.
bool SynchronizeAllActivity() override { return true; }
absl::Status SynchronousMemZero(DeviceMemoryBase* location,
uint64_t size) override;
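For context on the host platform's role after this cleanup: host callbacks can still be
enqueued on a stream backed by HostExecutor, which remains its main use outside of tests.
A minimal sketch of that pattern (the helper name is illustrative and not part of this
patch; includes are omitted), using only stream APIs that appear in this change set plus
the standard BlockHostUntilDone:

    // Runs a host callback on a stream created from a (host) StreamExecutor
    // and waits for it to complete.
    absl::Status RunCallbackOnHostStream(stream_executor::StreamExecutor* executor) {
      TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream());
      TF_RETURN_IF_ERROR(stream->DoHostCallbackWithStatus(
          []() -> absl::Status { return absl::OkStatus(); }));
      return stream->BlockHostUntilDone();
    }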
diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc
index a99c675..08bc958 100644
--- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc
+++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc
@@ -89,10 +89,10 @@
}
)";
-static absl::StatusOr<std::unique_ptr<StreamExecutor>> NewStreamExecutor() {
- StreamExecutorConfig config(/*ordinal=*/0);
+static absl::StatusOr<StreamExecutor*> NewStreamExecutor() {
TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host"));
- TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetUncachedExecutor(config));
+ TF_ASSIGN_OR_RETURN(auto stream_exec,
+ platform->ExecutorForDevice(/*ordinal=*/0));
return stream_exec;
}
diff --git a/third_party/xla/xla/stream_executor/host/host_platform.h b/third_party/xla/xla/stream_executor/host/host_platform.h
index 25c1179..3dd90a6 100644
--- a/third_party/xla/xla/stream_executor/host/host_platform.h
+++ b/third_party/xla/xla/stream_executor/host/host_platform.h
@@ -54,10 +54,13 @@
absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
- absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
- const StreamExecutorConfig& config) override;
-
private:
+ // Returns a device constructed with the options specified in "config" without
+ // looking in or storing to the Platform's executor cache.
+ // Ownership IS transferred to the caller.
+ absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ const StreamExecutorConfig& config);
+
// This platform's name.
std::string name_;
diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc
index ed6e040..76b6671 100644
--- a/third_party/xla/xla/stream_executor/host/host_stream.cc
+++ b/third_party/xla/xla/stream_executor/host/host_stream.cc
@@ -33,6 +33,9 @@
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/host/host_event.h"
+#include "xla/stream_executor/host/host_kernel.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_common.h"
#include "tsl/platform/denormal.h"
@@ -192,6 +195,21 @@
return status;
}
-} // namespace host
+absl::Status HostStream::Launch(const ThreadDim& thread_dims,
+ const BlockDim& block_dims,
+ const Kernel& kernel, const KernelArgs& args) {
+ const HostKernel* host_kernel = AsHostKernel(&kernel);
+ const KernelArgsDeviceMemoryArray* device_mem =
+ DynCast<KernelArgsDeviceMemoryArray>(&args);
+
+ if (device_mem != nullptr) {
+ return host_kernel->Launch(thread_dims, device_mem->device_memory_args());
+ }
+ return absl::UnimplementedError(
+ "Host kernel implements Launch method only for DeviceMemoryArray "
+ "arguments.");
+}
+
+} // namespace host
} // namespace stream_executor
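With kernel launching moved from HostExecutor onto HostStream, callers now go through the
stream. A hedged caller-side sketch follows (the helper name is illustrative, not from this
patch); as the new HostStream::Launch above shows, the host backend ignores the block
dimensions and only accepts DeviceMemoryArray-style arguments:

    // Forwards a launch to the stream; returns Unimplemented unless `args` is
    // (or wraps) a KernelArgsDeviceMemoryArray.
    absl::Status LaunchOnHostStream(stream_executor::Stream* stream,
                                    const stream_executor::Kernel& kernel,
                                    const stream_executor::ThreadDim& thread_dims,
                                    const stream_executor::KernelArgs& args) {
      return stream->Launch(thread_dims, stream_executor::BlockDim(1, 1, 1),
                            kernel, args);
    }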
diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h
index ed1bbc2..563d9fc 100644
--- a/third_party/xla/xla/stream_executor/host/host_stream.h
+++ b/third_party/xla/xla/stream_executor/host/host_stream.h
@@ -13,12 +13,10 @@
limitations under the License.
==============================================================================*/
-// Class declaration for Stream type that enqueues tasks onto a host/CPU-based
-// execution context (as opposed to a GPU device), HostExecutor.
#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_
#define XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_
-#include <cstddef>
+#include <cstdint>
#include <memory>
#include <queue>
@@ -27,6 +25,10 @@
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/event.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
+#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_common.h"
#include "tsl/platform/env.h"
#include "tsl/platform/thread_annotations.h"
@@ -34,6 +36,8 @@
namespace stream_executor {
namespace host {
+// Class declaration for Stream type that enqueues tasks onto a host/CPU-based
+// execution context (as opposed to a GPU device), HostExecutor.
class HostStream : public StreamCommon {
public:
explicit HostStream(StreamExecutor* executor);
@@ -65,6 +69,8 @@
uint64_t size) override;
absl::Status DoHostCallbackWithStatus(
absl::AnyInvocable<absl::Status() &&> callback) override;
+ absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+ const Kernel& kernel, const KernelArgs& args) override;
private:
bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
diff --git a/third_party/xla/xla/stream_executor/host_memory_allocation.cc b/third_party/xla/xla/stream_executor/host_memory_allocation.cc
index 9772396..e77c5e8 100644
--- a/third_party/xla/xla/stream_executor/host_memory_allocation.cc
+++ b/third_party/xla/xla/stream_executor/host_memory_allocation.cc
@@ -27,7 +27,7 @@
HostMemoryAllocation::~HostMemoryAllocation() {
if (ptr_ != nullptr && executor_ != nullptr) {
- executor_->HostMemoryDeallocate(ptr_, size_);
+ executor_->HostMemoryDeallocate(ptr_);
}
}
diff --git a/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h b/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h
index 8b31f8b..736b62e 100644
--- a/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h
+++ b/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h
@@ -82,7 +82,7 @@
auto status = stream_exec_->CollectiveMemoryDeallocate(ptr);
CHECK(status.ok()) << status.message();
} else if (memory_type_ == MemoryType::kHost) {
- stream_exec_->HostMemoryDeallocate(ptr, num_bytes);
+ stream_exec_->HostMemoryDeallocate(ptr);
} else {
DeviceMemoryBase device_ptr(ptr);
stream_exec_->Deallocate(&device_ptr);
diff --git a/third_party/xla/xla/stream_executor/kernel_test.cc b/third_party/xla/xla/stream_executor/kernel_test.cc
index cf63e5b..a554785 100644
--- a/third_party/xla/xla/stream_executor/kernel_test.cc
+++ b/third_party/xla/xla/stream_executor/kernel_test.cc
@@ -66,15 +66,12 @@
std::is_same_v<ArgsStorage<DeviceMemoryBase*, const DeviceMemoryBase*>,
std::tuple<const void*, const void*>>);
-static std::unique_ptr<StreamExecutor> NewStreamExecutor() {
+static StreamExecutor* NewStreamExecutor() {
Platform* platform = PlatformManager::PlatformWithName("Host").value();
- StreamExecutorConfig config(/*ordinal=*/0);
- return platform->GetUncachedExecutor(config).value();
+ return platform->ExecutorForDevice(/*ordinal=*/0).value();
}
TEST(KernelTest, PackDeviceMemoryArguments) {
- auto executor = NewStreamExecutor();
-
DeviceMemoryBase a(reinterpret_cast<void*>(0x12345678));
DeviceMemoryBase b(reinterpret_cast<void*>(0x87654321));
@@ -125,7 +122,7 @@
MultiKernelLoaderSpec empty_spec(/*arity=*/0);
auto executor = NewStreamExecutor();
- auto kernel = TypedKernelFactory<>::Create(executor.get(), empty_spec);
+ auto kernel = TypedKernelFactory<>::Create(executor, empty_spec);
EXPECT_FALSE(kernel.ok());
}
diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h
index c74a03e..bf964e0 100644
--- a/third_party/xla/xla/stream_executor/lazy_op_runner.h
+++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h
@@ -280,76 +280,6 @@
}
};
-struct FusedMHAOp {
- using Signature = FusedMHASignature;
- struct Config {
- double scale;
- const MatmulTensorDescriptor& bmm1_lhs_descriptor;
- const MatmulTensorDescriptor& bmm1_rhs_descriptor;
- const MatmulTensorDescriptor& bmm2_rhs_descriptor;
- const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor;
- const TensorDescriptor& output_descriptor;
- std::optional<TensorDescriptor> bias_descriptor;
- std::optional<TensorDescriptor> activation_descriptor;
- std::optional<double> dropout_rate;
- std::optional<int64_t> seed;
- FMHAMaskKind mask_type;
- };
-
- static absl::StatusOr<std::unique_ptr<const OpRunner<FusedMHASignature>>>
- RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config,
- Stream* stream) {
- TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream));
- return dnn->FusedMHARunnerFromDesc(
- stream, desc, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor,
- config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor,
- config.output_descriptor, config.activation_descriptor,
- config.bias_descriptor, config.scale, config.dropout_rate, config.seed,
- config.mask_type);
- }
-};
-
-struct FusedMHABackwardOp {
- using Signature = FusedMHABackwardSignature;
-
- struct Config {
- double scale;
- const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor;
- const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor;
- const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor;
- const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor;
- const MatmulTensorDescriptor& d_output_descriptor;
- const TensorDescriptor& d_bmm1_lhs_descriptor;
- const TensorDescriptor& d_bmm1_rhs_descriptor;
- const TensorDescriptor& d_bmm2_rhs_descriptor;
- std::optional<TensorDescriptor> d_s_descriptor;
- std::optional<TensorDescriptor> d_bias_descriptor;
- std::optional<TensorDescriptor> fwd_output_descriptor;
- std::optional<TensorDescriptor> bias_descriptor;
- std::optional<double> dropout_rate;
- std::optional<int64_t> seed;
- FMHAMaskKind mask_type;
- bool force_deterministic;
- };
-
- static absl::StatusOr<
- std::unique_ptr<const OpRunner<FusedMHABackwardSignature>>>
- RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config,
- Stream* stream) {
- TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream));
- return dnn->FusedMHABackwardRunnerFromDesc(
- stream, desc, config.bmm1_grad_gemm1_rhs_descriptor,
- config.bmm1_grad_gemm2_rhs_descriptor,
- config.bmm2_grad_gemm1_lhs_descriptor,
- config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor,
- config.d_bmm1_lhs_descriptor, config.d_bmm1_rhs_descriptor,
- config.d_bmm2_rhs_descriptor, config.d_s_descriptor,
- config.d_bias_descriptor, config.fwd_output_descriptor,
- config.bias_descriptor, config.scale, config.dropout_rate, config.seed,
- config.mask_type, config.force_deterministic);
- }
-};
-
} // namespace dnn
} // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h
index 03dd111..3c69f20 100644
--- a/third_party/xla/xla/stream_executor/mock_stream_executor.h
+++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h
@@ -22,7 +22,6 @@
#include <string>
#include <variant>
-#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
@@ -89,8 +88,7 @@
(override));
MOCK_METHOD(absl::StatusOr<std::unique_ptr<MemoryAllocation>>,
HostMemoryAllocate, (uint64_t size), (override));
- MOCK_METHOD(void, HostMemoryDeallocate, (void* mem, uint64_t size),
- (override));
+ MOCK_METHOD(void, HostMemoryDeallocate, (void* mem), (override));
MOCK_METHOD(bool, SynchronizeAllActivity, (), (override));
MOCK_METHOD(absl::Status, SynchronousMemZero,
(DeviceMemoryBase * location, uint64_t size), (override));
diff --git a/third_party/xla/xla/stream_executor/platform.h b/third_party/xla/xla/stream_executor/platform.h
index 5fbc44b..c5120e5 100644
--- a/third_party/xla/xla/stream_executor/platform.h
+++ b/third_party/xla/xla/stream_executor/platform.h
@@ -105,6 +105,13 @@
virtual absl::StatusOr<std::unique_ptr<DeviceDescription>>
DescriptionForDevice(int ordinal) const = 0;
+ // Returns a StreamExecutor for the given ordinal if one has already been
+ // created, or an error if none exists. Does not create a new context with
+ // the device.
+ virtual absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) {
+ return absl::NotFoundError("Not implemented for this platform.");
+ }
+
// Returns a device with the given ordinal on this platform with a default
// plugin configuration or, if none can be found with the given ordinal or
// there is an error in opening a context to communicate with the device, an
@@ -118,12 +125,6 @@
// Ownership of the executor is NOT transferred to the caller.
virtual absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) = 0;
-
- // Returns a device constructed with the options specified in "config" without
- // looking in or storing to the Platform's executor cache.
- // Ownership IS transferred to the caller.
- virtual absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
- const StreamExecutorConfig& config) = 0;
};
} // namespace stream_executor
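The new Platform::FindExisting entry point only consults the executor cache, so callers that
must not initialize a device can probe first and fall back explicitly. A small sketch of that
calling pattern (the helper is illustrative, not part of this patch):

    // Reuses a cached executor when available; otherwise creates one for the
    // ordinal via the usual path.
    absl::StatusOr<stream_executor::StreamExecutor*> GetOrCreateExecutor(
        stream_executor::Platform* platform, int ordinal) {
      auto existing = platform->FindExisting(ordinal);
      if (existing.ok()) return existing;
      return platform->ExecutorForDevice(ordinal);
    }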
diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD
index 06a44b0..7204036 100644
--- a/third_party/xla/xla/stream_executor/rocm/BUILD
+++ b/third_party/xla/xla/stream_executor/rocm/BUILD
@@ -184,6 +184,7 @@
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@local_config_rocm//rocm:rocm_headers",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:fingerprint",
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
index 465dbbe..2f61eae 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
@@ -1148,6 +1148,21 @@
return absl::OkStatus();
}
+absl::Status GpuDriver::LaunchKernel(
+ GpuContext* context, absl::string_view kernel_name,
+ GpuFunctionHandle function, unsigned int cluster_dim_x,
+ unsigned int cluster_dim_y, unsigned int cluster_dim_z,
+ unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z,
+ unsigned int block_dim_x, unsigned int block_dim_y,
+ unsigned int block_dim_z, unsigned int shared_mem_bytes,
+ GpuStreamHandle stream, void** kernel_params, void** extra) {
+ if (cluster_dim_x != 1 || cluster_dim_y != 1 || cluster_dim_z != 1)
+ return absl::UnimplementedError("Not implemented for ROCm");
+ return LaunchKernel(context, kernel_name, function, grid_dim_x, grid_dim_y,
+ grid_dim_z, block_dim_x, block_dim_y, block_dim_z,
+ shared_mem_bytes, stream, kernel_params, extra);
+}
+
/* static */ absl::Status GpuDriver::LoadPtx(GpuContext* context,
const char* ptx_contents,
hipModule_t* module) {
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
index cb096c6..76a4db7 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
@@ -333,58 +333,6 @@
return absl::OkStatus();
}
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const Kernel& kernel, const KernelArgs& args) {
- GpuStreamHandle hipstream = AsGpuStreamValue(stream);
- const GpuKernel* rocm_kernel = AsGpuKernel(&kernel);
- hipFunction_t hipfunc = rocm_kernel->gpu_function();
-
- if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
- TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
- hipfunc, rocm_kernel->GetGpuCacheConfig()));
- }
-
- auto launch = [&](const KernelArgsPackedArrayBase& packed) {
- CHECK_EQ(kernel.Arity() + (args.number_of_shared_bytes() > 0),
- packed.number_of_arguments());
-
- void** kernel_params =
- const_cast<void**>(packed.argument_addresses().data());
-
- return GpuDriver::LaunchKernel(
- GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x,
- block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
- args.number_of_shared_bytes(), hipstream, kernel_params, nullptr);
- };
-
- auto* packed_args = DynCast<KernelArgsPackedArrayBase>(&args);
- if (packed_args) return launch(*packed_args);
-
- if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
- auto& pack = kernel.args_packing();
- if (!pack) {
- return absl::InternalError(
- "Kernel is missing a custom arguments packing function for device "
- "memory arguments array");
- }
-
- TF_ASSIGN_OR_RETURN(auto packed_args, pack(kernel, *device_mem));
- return launch(*packed_args);
- }
-
- return absl::InternalError("Unsupported kernel arguments type");
-}
-
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
- const BlockDim& block_dims,
- const ClusterDim& cluster_dims,
- const Kernel& kernel, const KernelArgs& args) {
- if (cluster_dims.x != 1 || cluster_dims.y != 1 || cluster_dims.z != 1)
- return absl::UnimplementedError("Not implemented for ROCm");
- return Launch(stream, thread_dims, block_dims, kernel, args);
-}
-
absl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
ModuleHandle* module_handle) {
// In GpuExecutor we store the pointer to the HSACO binary as
@@ -447,20 +395,6 @@
GpuDriver::DeviceDeallocate(context_, mem->opaque());
}
-absl::StatusOr<std::unique_ptr<MemoryAllocation>>
-GpuExecutor::HostMemoryAllocate(uint64_t size) {
- auto* buffer = GpuDriver::HostAllocate(context_, size);
- if (buffer == nullptr && size > 0) {
- return absl::InternalError(
- absl::StrFormat("Failed to allocate HostMemory of size %d", size));
- }
- return std::make_unique<HostMemoryAllocation>(buffer, size, this);
-}
-
-void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) {
- return GpuDriver::HostDeallocate(context_, location);
-}
-
bool GpuExecutor::SynchronizeAllActivity() {
return GpuDriver::SynchronizeContext(context_);
}
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc
index 0ac3540..a65b5df 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc
@@ -28,67 +28,10 @@
namespace stream_executor {
namespace gpu {
-ROCmPlatform::ROCmPlatform()
- : name_("ROCM"), min_numa_node_(0), limit_numa_node_(0) {}
+ROCmPlatform::ROCmPlatform() : name_("ROCM") {}
ROCmPlatform::~ROCmPlatform() {}
-// Due to legacy issues in user code, we can't currently call InpectNumaNodes
-// at module initialization time, because non-GPU programs still include this
-// plugin via various methods, so instead, it has to be init-on-reference.
-void ROCmPlatform::InspectNumaNodes() {
- // To get NUMA node information, we need to create all executors, so we can
- // examine their device descriptions to see their bus assignments.
- absl::once_flag once;
- absl::call_once(once, [&] {
- StreamExecutorConfig config;
- for (int i = 0; i < VisibleDeviceCount(); i++) {
- config.ordinal = i;
- StreamExecutor* exec = GetExecutor(config).value();
- if (i == 0) {
- // NUMA nodes may not start at 0, so set the minimum node based on the
- // first executor we see.
- min_numa_node_ = exec->GetDeviceDescription().numa_node();
- limit_numa_node_ = min_numa_node_ + 1;
- } else {
- min_numa_node_ =
- std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
- limit_numa_node_ = std::max(
- limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
- }
- }
- });
-}
-
-int ROCmPlatform::BusCount() {
- InspectNumaNodes();
- return limit_numa_node_ - min_numa_node_;
-}
-
-int ROCmPlatform::DeviceToBus(int device_ordinal) {
- StreamExecutorConfig config;
- config.ordinal = device_ordinal;
- StreamExecutor* exec = GetExecutor(config).value();
- return exec->GetDeviceDescription().numa_node() - min_numa_node_;
-}
-
-absl::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
- int bus_ordinal) {
- InspectNumaNodes();
- CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
- for (int i = 0; i < VisibleDeviceCount(); i++) {
- if (DeviceToBus(i) == bus_ordinal) {
- StreamExecutorConfig config;
- config.ordinal = i;
- return GetExecutor(config).value();
- }
- }
-
- return absl::Status{
- absl::StatusCode::kNotFound,
- absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
-}
-
Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }
int ROCmPlatform::VisibleDeviceCount() const {
@@ -115,6 +58,12 @@
return GetExecutor(config);
}
+absl::StatusOr<StreamExecutor*> ROCmPlatform::FindExisting(int ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ return executor_cache_.Get(config);
+}
+
absl::StatusOr<StreamExecutor*> ROCmPlatform::GetExecutor(
const StreamExecutorConfig& config) {
if (config.gpu_stream) {
@@ -130,27 +79,17 @@
absl::StatusOr<std::unique_ptr<StreamExecutor>>
ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = std::make_unique<GpuExecutor>(this, config.ordinal);
- auto init_status = executor->Init();
- if (!init_status.ok()) {
- return absl::Status{
- absl::StatusCode::kInternal,
- absl::StrFormat(
- "failed initializing StreamExecutor for ROCM device ordinal %d: %s",
- config.ordinal, init_status.ToString().c_str())};
- }
-
+ TF_RETURN_IF_ERROR(executor->Init());
return std::move(executor);
}
} // namespace gpu
static void InitializeROCmPlatform() {
- // Disabling leak checking, PlatformManager does not destroy its
- // registered platforms.
auto status = PlatformManager::PlatformWithName("ROCM");
if (!status.ok()) {
- std::unique_ptr<gpu::ROCmPlatform> platform(new gpu::ROCmPlatform);
- TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform)));
+ TF_CHECK_OK(PlatformManager::RegisterPlatform(
+ std::make_unique<gpu::ROCmPlatform>()));
}
}
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h
index 6d18cf4..923cda8 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h
@@ -41,16 +41,6 @@
ROCmPlatform();
~ROCmPlatform() override;
- // ROCmPlatform-specific functionality
- // Returns the number of distinct buses / NUMA nodes on the machine.
- int BusCount();
-
- // Returns the bus/NUMA node for the specified device ordinal.
- int DeviceToBus(int device_ordinal);
-
- // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
- absl::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);
-
// Platform interface implementation:
// Returns the same value as kROCmPlatform above.
Platform::Id id() const override;
@@ -64,16 +54,17 @@
int ordinal) const override;
absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+ absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
- absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
- const StreamExecutorConfig& config) override;
-
private:
- // Determines the number of NUMA nodes and the assignment of executor to each.
- void InspectNumaNodes();
+ // Returns a device constructed with the options specified in "config" without
+ // looking in or storing to the Platform's executor cache.
+ // Ownership IS transferred to the caller.
+ absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ const StreamExecutorConfig& config);
// This platform's name.
std::string name_;
@@ -84,15 +75,6 @@
// Cache of created executors.
ExecutorCache executor_cache_;
- // The smallest NUMA node value for any device managed by this machine
- // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
- // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./
- int min_numa_node_;
-
- // Larger than the NUMA node value for any device managed by this machine
- // manager.
- int limit_numa_node_;
-
ROCmPlatform(const ROCmPlatform&) = delete;
void operator=(const ROCmPlatform&) = delete;
};
diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc
index 048623d..b47198a 100644
--- a/third_party/xla/xla/stream_executor/stream_common.cc
+++ b/third_party/xla/xla/stream_executor/stream_common.cc
@@ -21,7 +21,6 @@
#include <utility>
#include <vector>
-#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h
index 3d2ade7..e3e75f6 100644
--- a/third_party/xla/xla/stream_executor/stream_common.h
+++ b/third_party/xla/xla/stream_executor/stream_common.h
@@ -28,7 +28,6 @@
#include <vector>
#include "absl/base/thread_annotations.h"
-#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h
index 49929d4..1fa4312 100644
--- a/third_party/xla/xla/stream_executor/stream_executor.h
+++ b/third_party/xla/xla/stream_executor/stream_executor.h
@@ -1,5 +1,3 @@
-#include "absl/functional/any_invocable.h"
-#include "absl/log/log.h"
/* Copyright 2015 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,6 +29,7 @@
#include <variant>
#include <vector>
+#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
@@ -200,7 +199,7 @@
uint64_t size) = 0;
// Deallocates a region of host memory allocated by HostMemoryAllocate().
- virtual void HostMemoryDeallocate(void* mem, uint64_t size) = 0;
+ virtual void HostMemoryDeallocate(void* mem) = 0;
// Returns the memory space of the given pointer.
virtual absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) {
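Dropping the size parameter from HostMemoryDeallocate works because host allocations are
handed out as RAII MemoryAllocation objects that remember their own pointer and size, as
host_memory_allocation.cc above shows. A short usage sketch (illustrative; not from this
patch):

    // The allocation frees itself by calling HostMemoryDeallocate(ptr) when it
    // goes out of scope, so no size needs to be passed back.
    absl::Status UseHostBuffer(stream_executor::StreamExecutor* executor) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<stream_executor::MemoryAllocation> alloc,
                          executor->HostMemoryAllocate(/*size=*/1024));
      // ... write up to alloc->size() bytes starting at alloc->opaque() ...
      return absl::OkStatus();
    }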
diff --git a/third_party/xla/xla/stream_executor/stream_executor_test.cc b/third_party/xla/xla/stream_executor/stream_executor_test.cc
index 34bd459..9a2ca57 100644
--- a/third_party/xla/xla/stream_executor/stream_executor_test.cc
+++ b/third_party/xla/xla/stream_executor/stream_executor_test.cc
@@ -25,10 +25,9 @@
namespace stream_executor {
-static absl::StatusOr<std::unique_ptr<StreamExecutor>> NewStreamExecutor() {
- StreamExecutorConfig config(/*ordinal=*/0);
+static absl::StatusOr<StreamExecutor*> NewStreamExecutor() {
TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host"));
- TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetUncachedExecutor(config));
+ TF_ASSIGN_OR_RETURN(auto stream_exec, platform->ExecutorForDevice(0));
return stream_exec;
}
diff --git a/third_party/xla/xla/stream_executor/stream_test.cc b/third_party/xla/xla/stream_executor/stream_test.cc
index 473472d..ef5294e 100644
--- a/third_party/xla/xla/stream_executor/stream_test.cc
+++ b/third_party/xla/xla/stream_executor/stream_test.cc
@@ -29,31 +29,30 @@
class StreamTest : public ::testing::Test {
protected:
- std::unique_ptr<StreamExecutor> NewStreamExecutor() {
+ StreamExecutor* NewStreamExecutor() {
Platform* platform = PlatformManager::PlatformWithName("Host").value();
- StreamExecutorConfig config(/*ordinal=*/0);
- return platform->GetUncachedExecutor(config).value();
+ return platform->ExecutorForDevice(/*ordinal=*/0).value();
}
};
TEST_F(StreamTest, InitOk) {
- std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ StreamExecutor* executor = NewStreamExecutor();
TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
}
TEST_F(StreamTest, InitWithIntPriorityOk) {
- std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ StreamExecutor* executor = NewStreamExecutor();
TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream(1));
}
TEST_F(StreamTest, InitWithStreamPriorityOk) {
- std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ StreamExecutor* executor = NewStreamExecutor();
TF_ASSERT_OK_AND_ASSIGN(auto stream,
executor->CreateStream(StreamPriority::Highest));
}
TEST_F(StreamTest, OneSubStream) {
- std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ StreamExecutor* executor = NewStreamExecutor();
TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
// Get and return a sub-stream. Sub-streams are always initialized.
@@ -72,7 +71,7 @@
}
TEST_F(StreamTest, TwoSubStreams) {
- std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ StreamExecutor* executor = NewStreamExecutor();
TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
// Get two sub-streams.
diff --git a/third_party/xla/xla/stream_executor/sycl/BUILD b/third_party/xla/xla/stream_executor/sycl/BUILD
index 8745be9..86c00a0 100644
--- a/third_party/xla/xla/stream_executor/sycl/BUILD
+++ b/third_party/xla/xla/stream_executor/sycl/BUILD
@@ -48,6 +48,7 @@
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_collectives_header",
+ "@local_tsl//tsl/platform:errors",
]),
alwayslink = True, # Registers itself with the PlatformManager.
)
diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc
index 876775b..ac6da36 100644
--- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc
+++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc
@@ -35,65 +35,16 @@
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"
+#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
namespace stream_executor {
namespace gpu {
-SyclPlatform::SyclPlatform()
- : name_("SYCL"), min_numa_node_(0), limit_numa_node_(0) {}
+SyclPlatform::SyclPlatform() : name_("SYCL") {}
SyclPlatform::~SyclPlatform() {}
-// Due to legacy issues in user code, we can't currently call InspectNumaNodes
-// at module initialization time, because non-GPU programs still include this
-// plugin via various methods, so instead, it has to be init-on-reference.
-void SyclPlatform::InspectNumaNodes() {
- // To get NUMA node information, we need to create all executors, so we can
- // examine their device descriptions to see their bus assignments.
- static absl::once_flag once;
- absl::call_once(once, [&] {
- for (int i = 0; i < VisibleDeviceCount(); i++) {
- StreamExecutor* exec = *ExecutorForDevice(i);
- if (i == 0) {
- // NUMA nodes may not start at 0, so set the minimum node based on the
- // first executor we see.
- min_numa_node_ = exec->GetDeviceDescription().numa_node();
- limit_numa_node_ = min_numa_node_ + 1;
- } else {
- min_numa_node_ =
- std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
- limit_numa_node_ = std::max(
- limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
- }
- }
- });
-}
-
-int SyclPlatform::BusCount() {
- InspectNumaNodes();
- return limit_numa_node_ - min_numa_node_;
-}
-
-int SyclPlatform::DeviceToBus(int device_ordinal) {
- StreamExecutor* exec = *ExecutorForDevice(device_ordinal);
- return exec->GetDeviceDescription().numa_node() - min_numa_node_;
-}
-
-absl::StatusOr<StreamExecutor*> SyclPlatform::FirstExecutorForBus(
- int bus_ordinal) {
- InspectNumaNodes();
- CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
- for (int i = 0; i < VisibleDeviceCount(); i++) {
- if (DeviceToBus(i) == bus_ordinal) {
- return *ExecutorForDevice(i);
- }
- }
-
- return absl::NotFoundError(
- absl::StrFormat("Executor for bus %d not found.", bus_ordinal));
-}
-
Platform::Id SyclPlatform::id() const { return sycl::kSyclPlatformId; }
int SyclPlatform::VisibleDeviceCount() const {
@@ -133,24 +84,15 @@
absl::StatusOr<std::unique_ptr<StreamExecutor>>
SyclPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = std::make_unique<GpuExecutor>(this, config.ordinal);
- auto init_status = executor->Init();
- if (!init_status.ok()) {
- return absl::InternalError(absl::StrFormat(
- "failed initializing StreamExecutor for SYCL device ordinal %d: %s",
- config.ordinal, init_status.ToString()));
- }
-
+ TF_RETURN_IF_ERROR(executor->Init());
return std::move(executor);
}
} // namespace gpu
static void InitializeSyclPlatform() {
- // Disabling leak checking, PlatformManager does not destroy its
- // registered platforms.
-
- std::unique_ptr<gpu::SyclPlatform> platform(new gpu::SyclPlatform);
- TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform)));
+ TF_CHECK_OK(
+ PlatformManager::RegisterPlatform(std::make_unique<gpu::SyclPlatform>()));
}
} // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h
index 0c687f4..adc6cc9 100644
--- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h
+++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h
@@ -41,16 +41,6 @@
SyclPlatform();
~SyclPlatform() override;
- // SyclPlatform-specific functionality
- // Returns the number of distinct buses / NUMA nodes on the machine.
- int BusCount();
-
- // Returns the bus/NUMA node for the specified device ordinal.
- int DeviceToBus(int device_ordinal);
-
- // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
- absl::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);
-
// Platform interface implementation:
// Returns the same value as kSyclPlatform above.
Platform::Id id() const override;
@@ -68,28 +58,19 @@
absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
+ private:
+ // Returns a device constructed with the options specified in "config" without
+ // looking in or storing to the Platform's executor cache.
+ // Ownership IS transferred to the caller.
absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
const StreamExecutorConfig& config) override;
- private:
- // Determines the number of NUMA nodes and the assignment of executor to each.
- void InspectNumaNodes();
-
// This platform's name.
std::string name_;
// Cache of created executors.
ExecutorCache executor_cache_;
- // The smallest NUMA node value for any device managed by this machine
- // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
- // ordinals. The NUMA node space occupied by GPUs is assumed to be dense.
- int min_numa_node_;
-
- // Larger than the NUMA node value for any device managed by this machine
- // manager.
- int limit_numa_node_;
-
SyclPlatform(const SyclPlatform&) = delete;
void operator=(const SyclPlatform&) = delete;
};
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h
index c969c6d..85646af 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h
@@ -137,7 +137,7 @@
uint64_t size) override {
LOG(FATAL) << "not yet implemented";
}
- void HostMemoryDeallocate(void* mem, uint64_t size) override {
+ void HostMemoryDeallocate(void* mem) override {
LOG(FATAL) << "not yet implemented";
}
absl::Status SynchronousMemZero(DeviceMemoryBase* location,
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h
index f2ea4f9..e5e616c 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h
@@ -88,11 +88,14 @@
return GetExecutor(config);
}
- absl::StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
- const ::stream_executor::StreamExecutorConfig& config) override;
+ absl::StatusOr<::stream_executor::StreamExecutor*> FindExisting(
+ int ordinal) override {
+ stream_executor::StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ return executor_cache_.Get(config);
+ }
- absl::StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
- GetUncachedExecutor(
+ absl::StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
const ::stream_executor::StreamExecutorConfig& config) override;
StreamMap* stream_map() { return &stream_map_; }
@@ -118,6 +121,12 @@
absl::Mutex& mutex() { return event_map_mu_; }
private:
+ // Returns a device constructed with the options specified in "config" without
+ // looking in or storing to the Platform's executor cache.
+ // Ownership IS transferred to the caller.
+ absl::StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
+ GetUncachedExecutor(const ::stream_executor::StreamExecutorConfig& config);
+
mutable SE_Platform* platform_;
std::string name_;
stream_executor::ExecutorCache executor_cache_;
diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD
index e7219fd4..6697d09 100644
--- a/third_party/xla/xla/tests/BUILD
+++ b/third_party/xla/xla/tests/BUILD
@@ -201,6 +201,7 @@
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
+ "//xla/hlo/utils:hlo_query",
"//xla/service:backend",
"//xla/service:computation_layout",
"//xla/service:hlo_module_util",
@@ -1782,12 +1783,16 @@
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
- "//xla:array2d",
+ "//xla:array3d",
+ "//xla:array4d",
"//xla:literal",
+ "//xla:literal_util",
+ "//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/client:xla_builder",
"//xla/hlo/ir:hlo",
- "@local_tsl//tsl/platform:protobuf",
+ "@com_google_absl//absl/types:span",
+ "@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:test",
],
)
@@ -1900,6 +1905,8 @@
":test_macros_header",
":test_utils",
":xla_internal_test_main", # fixdeps: keep
+ "//xla:array2d",
+ "//xla:array3d",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
@@ -1910,15 +1917,19 @@
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
+ "//xla/service",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
+ "@eigen_archive//:eigen3",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
@@ -2124,6 +2135,9 @@
xla_test(
name = "dynamic_reshape_test",
srcs = ["dynamic_reshape_test.cc"],
+ backend_tags = {
+ "gpu": ["notsan"], # TODO(b/345034145): Fix tsan error.
+ },
disabled_backends = ["interpreter"],
tags = ["test_xla_cpu_thunks"],
deps = [
@@ -2311,14 +2325,9 @@
srcs = ["collective_pipeline_parallelism_test.cc"],
args = ["--xla_force_host_platform_device_count=4"],
backend_tags = {
- # This test is tagged "manual" because it requires multiple GPUs, and Forge only supports
- # single-GPU tests. Guitar skips "manual" tests unless they're also tagged "guitar".
"gpu": [
- "guitar",
- "manual",
"multi_gpu",
"no_oss",
- "notap",
],
"cpu": [
"notsan",
@@ -2357,15 +2366,9 @@
name = "collective_ops_e2e_test",
srcs = ["collective_ops_e2e_test.cc"],
backend_tags = {
- # This test is tagged "manual" because it requires multiple GPUs, and
- # Forge only supports single-GPU tests. Guitar skips "manual" tests
- # unless they're also tagged "guitar".
"gpu": [
- "guitar",
- "manual",
"multi_gpu",
"no_oss",
- "notap",
],
},
backends = [
@@ -2409,15 +2412,9 @@
name = "replicated_io_feed_test",
srcs = ["replicated_io_feed_test.cc"],
backend_tags = {
- # This test is tagged "manual" because it requires multiple GPUs, and
- # Forge only supports single-GPU tests. Guitar skips "manual" tests
- # unless they're also tagged "guitar".
"gpu": [
- "guitar",
- "manual",
"multi_gpu",
"no_oss",
- "notap",
],
},
backends = ["gpu"],
@@ -2790,7 +2787,10 @@
size = "large",
srcs = ["local_client_execute_test.cc"],
shard_count = 30,
- tags = ["optonly"],
+ tags = [
+ "optonly",
+ "test_xla_cpu_thunks",
+ ],
deps = [
":literal_test_util",
":local_client_test_base",
@@ -3081,6 +3081,9 @@
xla_test(
name = "set_dimension_size_test",
srcs = ["set_dimension_size_test.cc"],
+ backend_tags = {
+ "gpu": ["notsan"], # TODO(b/345034145): Fix tsan error.
+ },
tags = ["test_xla_cpu_thunks"],
deps = [
":hlo_test_base",
@@ -3189,6 +3192,7 @@
"//xla:types",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/status:statusor",
+ "@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
)
diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl
index 8a42642..8f528c5 100644
--- a/third_party/xla/xla/tests/build_defs.bzl
+++ b/third_party/xla/xla/tests/build_defs.bzl
@@ -320,8 +320,23 @@
# b/317293391. For this reason, if we would create an empty `test_suite`,
# instead create a `cc_test` with no srcs that links against `main` to have
# more predictable behavior that avoids bugs.
+ #
+ # Due to b/317293391, we also mark the test suite `manual`, so that wildcard builds
+ # like in the XLA CI won't try to build the test suite target. Instead, the wildcard
+ # build will build the individual test targets and therefore respect the tags on each
+ # individual test target.
+ # Example: Assume we have an `xla_test(name=my_test)` in `//xla/service/gpu` with backends `cpu`
+ # and `gpu`. This generates two test targets `//xla/service/gpu:my_test_{cpu|gpu}`. The latter
+ # has a tag `gpu`.
+ #
+ # - `bazel test --test_tag_filters=-gpu //xla/service/gpu/...` will only run the cpu test.
+ # - `bazel test //xla/service/gpu/...` will run both tests.
+ # - `bazel test //xla/service/gpu:my_test` will run both tests.
+ # Caveat:
+ # - `bazel test --test_tag_filters=-gpu //xla/service/gpu:my_test` will run both tests and
+ #   not respect the tag filter - but it's way better than the previous behavior.
if test_names:
- native.test_suite(name = name, tags = tags, tests = test_names)
+ native.test_suite(name = name, tags = tags + ["manual"], tests = test_names)
else:
native.cc_test(name = name, deps = ["@local_tsl//tsl/platform:test_main"])
diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
index f1d1c78..9908538 100644
--- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
@@ -154,6 +154,7 @@
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_all_reduce = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(kModuleStr, kNumReplicas));
@@ -190,6 +191,7 @@
}
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_all_gather = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
@@ -231,6 +233,7 @@
}
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_all_gather = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
@@ -268,6 +271,7 @@
}
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_collective_broadcast = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(kModuleStr, kNumReplicas));
@@ -300,6 +304,7 @@
}
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_collective_permute = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(kModuleStr, kNumReplicas));
@@ -343,6 +348,7 @@
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_reduce_scatter = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(kModuleStr, kNumReplicas));
@@ -376,6 +382,7 @@
}
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_all_to_all = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(kModuleStr, kNumReplicas));
@@ -420,6 +427,7 @@
}
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const bool enable_async_all_to_all = GetParam();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(kModuleStr, kNumReplicas));
@@ -472,6 +480,7 @@
}
)";
const int64_t kNumReplicas = 4;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -592,6 +601,7 @@
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(true);
@@ -646,6 +656,7 @@
}
)";
const int64_t kNumReplicas = 2;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -677,6 +688,7 @@
absl::string_view hlo_text, bool disable_dot_merger = false) {
const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -959,6 +971,7 @@
)";
const int64_t kNumReplicas = 1;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const int64_t kNumPartitions = 4;
HloModuleConfig config =
@@ -1052,6 +1065,7 @@
)";
const int64_t kNumReplicas = 1;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -1085,6 +1099,7 @@
)";
const int64_t kNumReplicas = 1;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
const int64_t kNumPartitions = 4;
HloModuleConfig config =
diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc
index 460864c..9cd874c 100644
--- a/third_party/xla/xla/tests/collective_ops_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_test.cc
@@ -39,23 +39,17 @@
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"
+namespace xla {
+namespace {
+
// Tests cross-GPU operations.
//
// Several tests requires at least four GPUs. For instructions on running this
// within Google, see go/multi-gpu-unit-test.
-
-#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \
- if (num_devices_ < x) { \
- GTEST_SKIP() << "Test requires at least " << x << " devices"; \
- }
-
-namespace xla {
-namespace {
-
class CollectiveOpsTest : public HloTestBase {
public:
- CollectiveOpsTest() : num_devices_(backend().device_count()) {
- VLOG(1) << "Running with " << num_devices_ << " devices";
+ CollectiveOpsTest() {
+ VLOG(1) << "Running with " << num_devices() << " devices";
}
protected:
@@ -180,9 +174,6 @@
/*expected_value=*/to_literal({cast(-1), cast(-2), cast(-3)}));
}
}
-
- protected:
- const int64_t num_devices_;
};
// Returns the non-empty subsets of {0, 1, ..., n}. For example,
@@ -370,7 +361,7 @@
XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
const int64_t kNumElems = 1024;
- for (std::vector<int64_t> devices : PowerSetOfIota(num_devices_)) {
+ for (std::vector<int64_t> devices : PowerSetOfIota(num_devices())) {
SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
absl::StrJoin(devices, ", ")));
@@ -494,7 +485,7 @@
// Test a prime number so it's not all powers of 2.
const int64_t kNumElems = 137;
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -541,7 +532,7 @@
}
)";
static constexpr int kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -577,19 +568,19 @@
)";
HloModuleConfig config =
- GetModuleConfigForTest(/*replica_count=*/num_devices_);
+ GetModuleConfigForTest(/*replica_count=*/num_devices());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
- num_devices_,
+ num_devices(),
/*use_threads=*/true, /*run_hlo_passes=*/false));
- ASSERT_EQ(results.size(), num_devices_);
+ ASSERT_EQ(results.size(), num_devices());
// sum [0, num_devices)
- uint32_t expected = num_devices_ * (num_devices_ - 1) / 2;
- for (int i = 0; i < num_devices_; ++i) {
+ uint32_t expected = num_devices() * (num_devices() - 1) / 2;
+ for (int i = 0; i < num_devices(); ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
@@ -613,22 +604,22 @@
)";
HloModuleConfig config =
- GetModuleConfigForTest(/*replica_count=*/num_devices_);
+ GetModuleConfigForTest(/*replica_count=*/num_devices());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
- num_devices_,
+ num_devices(),
/*use_threads=*/true, /*run_hlo_passes=*/false));
- ASSERT_EQ(results.size(), num_devices_);
+ ASSERT_EQ(results.size(), num_devices());
// sum [0, num_devices)
- uint32_t expected0 = num_devices_ * (num_devices_ - 1) / 2;
+ uint32_t expected0 = num_devices() * (num_devices() - 1) / 2;
// sum squares [0, num_devices)
uint32_t expected1 =
- num_devices_ * (num_devices_ - 1) * (2 * num_devices_ - 1) / 6;
- for (int i = 0; i < num_devices_; ++i) {
+ num_devices() * (num_devices() - 1) * (2 * num_devices() - 1) / 6;
+ for (int i = 0; i < num_devices(); ++i) {
std::vector<Literal> replica_results = results[i].DecomposeTuple();
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected0, replica_results[0]);
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected1, replica_results[1]);
@@ -645,18 +636,18 @@
)";
HloModuleConfig config =
- GetModuleConfigForTest(/*replica_count=*/num_devices_);
+ GetModuleConfigForTest(/*replica_count=*/num_devices());
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
- num_devices_,
+ num_devices(),
/*use_threads=*/true, /*run_hlo_passes=*/true));
- ASSERT_EQ(results.size(), num_devices_);
- for (uint32_t i = 0; i < num_devices_; ++i) {
+ ASSERT_EQ(results.size(), num_devices());
+ for (uint32_t i = 0; i < num_devices(); ++i) {
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
}
}
@@ -680,7 +671,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -716,7 +707,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -753,7 +744,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -789,7 +780,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -826,7 +817,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -864,7 +855,7 @@
)";
const int64_t kNumReplicas = 2;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -906,7 +897,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -952,7 +943,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -992,7 +983,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -1024,7 +1015,7 @@
}
)";
const int64_t kNumReplicas = 4;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2003,7 +1994,7 @@
)";
const int64_t kNumReplicas = 2;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2083,7 +2074,7 @@
})";
const int64_t kNumReplicas = 2;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2162,7 +2153,7 @@
})";
const int64_t kNumReplicas = 2;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2263,7 +2254,7 @@
})";
const int64_t kNumReplicas = 2;
- SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
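The SKIP_TEST_IF_NUM_DEVICES_LESS_THAN macro and the num_devices_ member are removed from the
individual collective tests here; the tests now rely on a shared definition (the deleted TODO
suggests hlo_test_base.h, though that hunk is not shown in this excerpt). Presumably the shared
form mirrors the deleted one, keyed off the num_devices() accessor used above; an assumed
reconstruction for reference:

    // Assumed shape of the shared skip macro (illustrative reconstruction).
    #define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x)                     \
      if (num_devices() < x) {                                        \
        GTEST_SKIP() << "Test requires at least " << x << " devices"; \
      }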
diff --git a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc
index bfcf5e1..ee84472 100644
--- a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc
+++ b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc
@@ -15,6 +15,7 @@
#include <cstdint>
#include <memory>
+#include <string>
#include <utility>
#include <vector>
@@ -32,28 +33,18 @@
#include "xla/tests/verified_hlo_module.h"
#include "tsl/platform/statusor.h"
+namespace xla {
+namespace {
+
// Tests cross-GPU operations.
//
// Several tests requires at least four GPUs. For instructions on running this
// within Google, see go/multi-gpu-unit-test.
-
-// TODO: Move this to hlo_test_base.h
-#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \
- if (num_devices_ < x) { \
- GTEST_SKIP() << "Test requires at least " << x << " devices"; \
- }
-
-namespace xla {
-namespace {
-
class CollectivePipelineParallelismTest : public HloTestBase {
public:
- CollectivePipelineParallelismTest() : num_devices_(backend().device_count()) {
- VLOG(1) << "Running with " << num_devices_ << " devices";
+ CollectivePipelineParallelismTest() {
+ VLOG(1) << "Running with " << num_devices() << " devices";
}
-
- protected:
- const int64_t num_devices_;
};
XLA_TEST_F(CollectivePipelineParallelismTest,
@@ -126,26 +117,59 @@
LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {1, 1}}, results[3]);
}
-// Naive implementation of pipeline parallelism:
-// - 4 devices
-// - 4 microbatches
-// - no circular repeat
-// - no disabled collectives
-// - no collective pipelining
-//
-// Every stage of the pipeline is a single linear layer.
-XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
- const absl::string_view kModuleStr = R"(
- HloModule test
+std::string GetModuleStrWithCommonComputations(
+ const std::string name, const std::string more_computations) {
+ static constexpr char kCommonComputationsStr[] = R"(
+ read_buffer_mb4 {
+ buffer = f32[4,16] parameter(0)
+ offset = u32[] parameter(1)
+ index = u32[] parameter(2)
+ c0 = u32[] constant(0)
+ c4 = u32[] constant(4)
+ index_ = u32[] add(index, offset)
+ index__ = u32[] remainder(index_, c4)
+ slice = f32[1,16] dynamic-slice(buffer, index__, c0),
+ dynamic_slice_sizes={1,16}
+ ROOT slice_ = f32[16] reshape(slice)
+ }
- get_circ_buffer_index {
- offset = u32[] parameter(0)
- index = u32[] parameter(1)
- size = u32[] parameter(2)
- t0 = u32[] add(offset, index)
- t1 = u32[] divide(t0, size)
- t2 = u32[] multiply(t1, size)
- ROOT t4 = u32[] subtract(t0, t2)
+ read_buffer_mb5 {
+ buffer = f32[5,16] parameter(0)
+ offset = u32[] parameter(1)
+ index = u32[] parameter(2)
+ c0 = u32[] constant(0)
+ c5 = u32[] constant(5)
+ index_ = u32[] add(index, offset)
+ index__ = u32[] remainder(index_, c5)
+ slice = f32[1,16] dynamic-slice(buffer, index__, c0),
+ dynamic_slice_sizes={1,16}
+ ROOT slice_ = f32[16] reshape(slice)
+ }
+
+ update_buffer_mb4 {
+ buffer = f32[4,16] parameter(0)
+ update = f32[16] parameter(1)
+ offset = u32[] parameter(2)
+ index = u32[] parameter(3)
+ c0 = u32[] constant(0)
+ c4 = u32[] constant(4)
+ index_ = u32[] add(index, offset)
+ index__ = u32[] remainder(index_, c4)
+ update_ = f32[1,16] reshape(update)
+ ROOT buffer_ = f32[4,16] dynamic-update-slice(buffer, update_, index__, c0)
+ }
+
+ update_buffer_mb5 {
+ buffer = f32[5,16] parameter(0)
+ update = f32[16] parameter(1)
+ offset = u32[] parameter(2)
+ index = u32[] parameter(3)
+ c0 = u32[] constant(0)
+ c5 = u32[] constant(5)
+ index_ = u32[] add(index, offset)
+ index__ = u32[] remainder(index_, c5)
+ update_ = f32[1,16] reshape(update)
+ ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
}
is_input_replica {
@@ -156,10 +180,40 @@
is_output_replica {
replica_id = u32[] replica-id()
- c1 = u32[] constant(1)
- ROOT predicate = pred[] compare(replica_id, c1), direction=EQ
+ c3 = u32[] constant(3)
+ ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
}
+ is_read_input_mb4 {
+ is_input_replica = pred[] call(), to_apply=is_input_replica
+ i = u32[] parameter(0)
+ c4 = u32[] constant(4)
+ is_input_iteration = pred[] compare(i, c4), direction=LT
+ ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
+ }
+
+ is_read_input_mb5 {
+ is_input_replica = pred[] call(), to_apply=is_input_replica
+ i = u32[] parameter(0)
+ c5 = u32[] constant(5)
+ is_input_iteration = pred[] compare(i, c5), direction=LT
+ ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
+ }
+ )";
+ return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" +
+ more_computations;
+}
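
The read_buffer_mbN / update_buffer_mbN computations above implement a circular
buffer of N rows of width 16. A minimal C++ sketch of the same indexing logic,
with illustrative names and a hard-coded row width (not part of the patch):

  #include <array>
  #include <cstdint>

  // read_buffer_mbN: dynamic-slice one row at (index + offset) % N, then reshape.
  template <int N>
  std::array<float, 16> ReadBuffer(const std::array<std::array<float, 16>, N>& buffer,
                                   uint32_t offset, uint32_t index) {
    return buffer[(index + offset) % N];
  }

  // update_buffer_mbN: dynamic-update-slice one row at (index + offset) % N.
  template <int N>
  void UpdateBuffer(std::array<std::array<float, 16>, N>& buffer,
                    const std::array<float, 16>& update, uint32_t offset,
                    uint32_t index) {
    buffer[(index + offset) % N] = update;
  }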
+
+// Naive implementation of pipeline parallelism:
+// - 4 devices
+// - 4 microbatches
+// - no circular repeat
+// - no disabled collectives
+// - no collective pipelining
+//
+// Every stage of the pipeline is a single linear layer.
+XLA_TEST_F(CollectivePipelineParallelismTest, NaiveBFSMicrobatch4Replica4) {
+ constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0)
i = u32[] get-tuple-element(tuple), index=4
@@ -172,36 +226,34 @@
weights = f32[16,16] get-tuple-element(tuple), index=0
input = f32[4,16] get-tuple-element(tuple), index=1
output = f32[4,16] get-tuple-element(tuple), index=2
- tmp = f32[16] get-tuple-element(tuple), index=3
+ prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3
i = u32[] get-tuple-element(tuple), index=4
- c1 = u32[] constant(1)
c0 = u32[] constant(0)
+ c1 = u32[] constant(1)
c4 = u32[] constant(4)
- input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index
- input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
- dynamic_slice_sizes={1,16}
- input_slice_ = f32[16] reshape(input_slice)
+ // Read from buffers.
+ input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4
- prev_stage_slice = f32[16] collective-permute(tmp),
+ // Shift data to the next stage in the pipeline.
+ prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
+ // Select compute argument from previous stage or from input and perform
+ // compute.
read_input = pred[] call(), to_apply=is_input_replica
- compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice)
-
- compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+ compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice)
+ compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
- output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index
- output_slice = f32[1,16] reshape(compute_out)
- output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index,
- c0)
+ // Update buffers.
+ output_ = call(output, compute_res, c1, i), to_apply=update_buffer_mb4
i_ = add(i, c1)
ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(
- weights, input, output_, compute_out, i_)
+ weights, input, output_, compute_res, i_)
}
ENTRY main {
@@ -210,11 +262,11 @@
cf0 = f32[] constant(0)
output = f32[4,16] broadcast(cf0), dimensions={}
- tmp = f32[16] broadcast(cf0), dimensions={}
+ prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights,
- input, output, tmp, c0)
+ input, output, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple),
condition=while_condition, body=while_body
@@ -227,8 +279,11 @@
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr, config));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module,
+ ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+ /*name=*/"test", kMoreComputationsStr),
+ config));
// This pipeline consists of 4 layers, each of which is a single linear layer.
// We assign the weights to the replicas such that the layers scale the input
@@ -260,7 +315,7 @@
// Check pipeline output for last replica.
// The combined effect of the pipeline is to scale the input data by 24.0.
const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0;
- Literal expected_output = LiteralUtil::CreateFingerprintMatixR2(
+ Literal expected_output = LiteralUtil::CreateFingerprintMatixR2<float>(
kMicrobatches, kInputSize, kExpectedFactor);
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3],
ErrorSpec{1e-5, 1e-5}));
@@ -274,32 +329,8 @@
// - no collective pipelining
//
// Every stage of the pipeline is a single linear layer.
-XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) {
- const absl::string_view kModuleStr = R"(
- HloModule test
-
- get_circ_buffer_index {
- offset = u32[] parameter(0)
- index = u32[] parameter(1)
- size = u32[] parameter(2)
- t0 = u32[] add(offset, index)
- t1 = u32[] divide(t0, size)
- t2 = u32[] multiply(t1, size)
- ROOT t4 = u32[] subtract(t0, t2)
- }
-
- is_input_replica {
- replica_id = u32[] replica-id()
- c0 = u32[] constant(0)
- ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
- }
-
- is_output_replica {
- replica_id = u32[] replica-id()
- c1 = u32[] constant(1)
- ROOT predicate = pred[] compare(replica_id, c1), direction=EQ
- }
-
+XLA_TEST_F(CollectivePipelineParallelismTest, NaiveBFSMicrobatch5Replica4) {
+ constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0)
i = u32[] get-tuple-element(tuple), index=4
@@ -312,37 +343,35 @@
weights = f32[16,16] get-tuple-element(tuple), index=0
input = f32[5,16] get-tuple-element(tuple), index=1
output = f32[5,16] get-tuple-element(tuple), index=2
- tmp = f32[16] get-tuple-element(tuple), index=3
+ prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3
i = u32[] get-tuple-element(tuple), index=4
+ c0 = u32[] constant(0)
c1 = u32[] constant(1)
c2 = u32[] constant(2)
- c0 = u32[] constant(0)
c5 = u32[] constant(5)
- input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
- input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
- dynamic_slice_sizes={1,16}
- input_slice_ = f32[16] reshape(input_slice)
+ // Read from buffers.
+ input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
- prev_stage_slice = f32[16] collective-permute(tmp),
+ // Shift data to the next stage in the pipeline.
+ prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
+ // Select compute argument from previous stage or from input and perform
+ // compute.
read_input = pred[] call(), to_apply=is_input_replica
- compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice)
-
- compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+ compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice)
+ compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
- output_index = u32[] call(c2, i, c5), to_apply=get_circ_buffer_index
- output_slice = f32[1,16] reshape(compute_out)
- output_ = f32[5,16] dynamic-update-slice(output, output_slice, output_index,
- c0)
+ // Update buffers.
+ output_ = call(output, compute_res, c2, i), to_apply=update_buffer_mb5
i_ = add(i, c1)
ROOT tuple1 = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[])
- tuple(weights, input, output_, compute_out, i_)
+ tuple(weights, input, output_, compute_res, i_)
}
ENTRY main {
@@ -351,11 +380,11 @@
cf0 = f32[] constant(0)
output = f32[5,16] broadcast(cf0), dimensions={}
- tmp = f32[16] broadcast(cf0), dimensions={}
+ prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[])
- tuple(weights, input, output, tmp, c0)
+ tuple(weights, input, output, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) while(tuple),
condition=while_condition, body=while_body
@@ -368,8 +397,11 @@
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr, config));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module,
+ ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+ /*name=*/"test", kMoreComputationsStr),
+ config));
// This pipeline consists of 4 layers, each of which is a single linear layer.
// We assign the weights to the replicas such that the layers scale the input
@@ -415,40 +447,8 @@
//
// Every stage of the pipeline is a single linear layer.
XLA_TEST_F(CollectivePipelineParallelismTest,
- NaiveDFSMicrobatch4CircularRepeat2Replica4) {
- const absl::string_view kModuleStr = R"(
- HloModule test
-
- get_circ_buffer_index {
- offset = u32[] parameter(0)
- index = u32[] parameter(1)
- size = u32[] parameter(2)
- t0 = u32[] add(offset, index)
- t1 = u32[] divide(t0, size)
- t2 = u32[] multiply(t1, size)
- ROOT t4 = u32[] subtract(t0, t2)
- }
-
- is_input_replica {
- replica_id = u32[] replica-id()
- c0 = u32[] constant(0)
- ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
- }
-
- is_output_replica {
- replica_id = u32[] replica-id()
- c3 = u32[] constant(3)
- ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
- }
-
- is_read_input {
- is_input_replica = pred[] call(), to_apply=is_input_replica
- i = u32[] parameter(0)
- c4 = u32[] constant(4)
- is_input_iteration = pred[] compare(i, c4), direction=LT
- ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
- }
-
+ NaiveBFSMicrobatch4CircularRepeat2Replica4) {
+ constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0)
i = u32[] get-tuple-element(tuple), index=4
@@ -461,36 +461,35 @@
weights = f32[16,16] get-tuple-element(tuple), index=0
input = f32[4,16] get-tuple-element(tuple), index=1
output = f32[4,16] get-tuple-element(tuple), index=2
- tmp = f32[16] get-tuple-element(tuple), index=3
+ prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3
i = u32[] get-tuple-element(tuple), index=4
- c1 = u32[] constant(1)
c0 = u32[] constant(0)
+ c1 = u32[] constant(1)
c4 = u32[] constant(4)
- input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index
- input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
- dynamic_slice_sizes={1,16}
- input_slice_ = f32[16] reshape(input_slice)
+ // Read from buffers.
+ input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4
- prev_stage_slice = f32[16] collective-permute(tmp),
+ // Shift data to the next stage in the pipeline.
+ prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
- is_read_input = pred[] call(i), to_apply=is_read_input
- compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-
- compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+ // Select compute argument from previous stage or from input and perform
+ // compute.
+ is_read_input = pred[] call(i), to_apply=is_read_input_mb4
+ compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+ compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
- output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index
- output_slice = f32[1,16] reshape(compute_out)
- output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index,
- c0)
+ // Update buffers.
+ output_ = f32[4,16] call(output, compute_res, c1, i),
+ to_apply=update_buffer_mb4
i_ = add(i, c1)
ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[])
- tuple(weights, input, output_, compute_out, i_)
+ tuple(weights, input, output_, compute_res, i_)
}
ENTRY main {
@@ -499,11 +498,11 @@
cf0 = f32[] constant(0)
output = f32[4,16] broadcast(cf0), dimensions={}
- tmp = f32[16] broadcast(cf0), dimensions={}
+ prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights,
- input, output, tmp, c0)
+ input, output, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple),
condition=while_condition, body=while_body
@@ -516,8 +515,11 @@
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr, config));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module,
+ ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+ /*name=*/"test", kMoreComputationsStr),
+ config));
// This pipeline consists of a total of 8 layers (2 per replica), each of
// which is a single linear layer. We assign the weights to the replicas such
@@ -556,7 +558,7 @@
ErrorSpec{1e-5, 1e-5}));
}
-// Naive implementation if pipeline parallelism:
+// Naive implementation of pipeline parallelism:
// - 4 devices
// - 5 microbatches
// - 2 circular repeat
@@ -565,66 +567,8 @@
//
// Every stage of the pipeline is a single linear layer.
XLA_TEST_F(CollectivePipelineParallelismTest,
- NaiveDFSMicrobatch5CircularRepeat2Replica4) {
- const absl::string_view kModuleStr = R"(
- HloModule test
-
- get_circ_buffer_index {
- offset = u32[] parameter(0)
- index = u32[] parameter(1)
- size = u32[] parameter(2)
- t0 = u32[] add(offset, index)
- t1 = u32[] divide(t0, size)
- t2 = u32[] multiply(t1, size)
- ROOT t4 = u32[] subtract(t0, t2)
- }
-
- read_buffer {
- buffer = f32[5,16] parameter(0)
- offset = u32[] parameter(1)
- index = u32[] parameter(2)
- c0 = u32[] constant(0)
- c5 = u32[] constant(5)
- index_ = u32[] add(index, offset)
- index__ = u32[] remainder(index_, c5)
- slice = f32[1,16] dynamic-slice(buffer, index__, c0),
- dynamic_slice_sizes={1,16}
- ROOT slice_ = f32[16] reshape(slice)
- }
-
- update_buffer {
- buffer = f32[5,16] parameter(0)
- update = f32[16] parameter(1)
- offset = u32[] parameter(2)
- index = u32[] parameter(3)
- c0 = u32[] constant(0)
- c5 = u32[] constant(5)
- index_ = u32[] add(index, offset)
- index__ = u32[] remainder(index_, c5)
- update_ = f32[1,16] reshape(update)
- ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
- }
-
- is_input_replica {
- replica_id = u32[] replica-id()
- c0 = u32[] constant(0)
- ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
- }
-
- is_output_replica {
- replica_id = u32[] replica-id()
- c3 = u32[] constant(3)
- ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
- }
-
- is_read_input {
- is_input_replica = pred[] call(), to_apply=is_input_replica
- i = u32[] parameter(0)
- c5 = u32[] constant(5)
- is_input_iteration = pred[] compare(i, c5), direction=LT
- ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
- }
-
+ NaiveBFSMicrobatch5CircularRepeat2Replica4) {
+ constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
parameter(0)
@@ -640,43 +584,46 @@
input = f32[5,16] get-tuple-element(tuple), index=1
output = f32[5,16] get-tuple-element(tuple), index=2
buffer = f32[5,16] get-tuple-element(tuple), index=3
- prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4
+ prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4
i = u32[] get-tuple-element(tuple), index=5
c0 = u32[] constant(0)
c1 = u32[] constant(1)
c2 = u32[] constant(2)
c3 = u32[] constant(3)
+ c4 = u32[] constant(4)
c5 = u32[] constant(5)
- input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
- input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
- dynamic_slice_sizes={1,16}
- input_slice_ = f32[16] reshape(input_slice)
+ // Read from buffers.
+ input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
+ buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5
- buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer
-
+ // Shift data to the next stage in the pipeline.
+ // Directly depends on the updated buffer of the previous iteration and,
+ // therefore, depends on the previous iteration's compute.
is_output_replica = pred[] call(), to_apply=is_output_replica
next_stage_slice = select(is_output_replica, buffer_slice,
- prev_iteration_compute_out)
-
+ prev_iteration_compute_res)
prev_stage_slice = f32[16] collective-permute(next_stage_slice),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
- is_read_input = pred[] call(i), to_apply=is_read_input
- compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-
- compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+ // Select compute argument from previous stage or from input and perform
+ // compute.
+ is_read_input = pred[] call(i), to_apply=is_read_input_mb5
+ compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+ compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
- output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer
-
- buffer_ = f32[5,16] call(buffer, compute_out, c0, i), to_apply=update_buffer
+ // Update buffers.
+ output_ = f32[5,16] call(output, compute_res, c2, i),
+ to_apply=update_buffer_mb5
+ buffer_ = f32[5,16] call(buffer, compute_res, c0, i),
+ to_apply=update_buffer_mb5
i_ = add(i, c1)
ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
- tuple(weights, input, output_, buffer_, compute_out, i_)
+ tuple(weights, input, output_, buffer_, compute_res, i_)
}
ENTRY main {
@@ -686,11 +633,12 @@
cf0 = f32[] constant(0)
output = f32[5,16] broadcast(cf0), dimensions={}
buffer = f32[5,16] broadcast(cf0), dimensions={}
- prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={}
+ prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
+ // Iterate through pipeline stages.
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
- tuple(weights, input, output, buffer, prev_iteration_compute_out, c0)
+ tuple(weights, input, output, buffer, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
while(tuple), condition=while_condition, body=while_body
@@ -703,8 +651,11 @@
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr, config));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module,
+ ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+ /*name=*/"test", kMoreComputationsStr),
+ config));
// This pipeline consists of a total of 8 layers (2 per replica), each of
// which is a single linear layer. We assign the weights to the replicas such
@@ -754,66 +705,8 @@
//
// Every stage of the pipeline is a single linear layer.
XLA_TEST_F(CollectivePipelineParallelismTest,
- NaiveWoDirectBufferDependencyDFSMicrobatch5CircularRepeat2Replica4) {
- const absl::string_view kModuleStr = R"(
- HloModule test
-
- get_circ_buffer_index {
- offset = u32[] parameter(0)
- index = u32[] parameter(1)
- size = u32[] parameter(2)
- t0 = u32[] add(offset, index)
- t1 = u32[] divide(t0, size)
- t2 = u32[] multiply(t1, size)
- ROOT t4 = u32[] subtract(t0, t2)
- }
-
- read_buffer {
- buffer = f32[5,16] parameter(0)
- offset = u32[] parameter(1)
- index = u32[] parameter(2)
- c0 = u32[] constant(0)
- c5 = u32[] constant(5)
- index_ = u32[] add(index, offset)
- index__ = u32[] remainder(index_, c5)
- slice = f32[1,16] dynamic-slice(buffer, index__, c0),
- dynamic_slice_sizes={1,16}
- ROOT slice_ = f32[16] reshape(slice)
- }
-
- update_buffer {
- buffer = f32[5,16] parameter(0)
- update = f32[16] parameter(1)
- offset = u32[] parameter(2)
- index = u32[] parameter(3)
- c0 = u32[] constant(0)
- c5 = u32[] constant(5)
- index_ = u32[] add(index, offset)
- index__ = u32[] remainder(index_, c5)
- update_ = f32[1,16] reshape(update)
- ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
- }
-
- is_input_replica {
- replica_id = u32[] replica-id()
- c0 = u32[] constant(0)
- ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
- }
-
- is_output_replica {
- replica_id = u32[] replica-id()
- c3 = u32[] constant(3)
- ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
- }
-
- is_read_input {
- is_input_replica = pred[] call(), to_apply=is_input_replica
- i = u32[] parameter(0)
- c5 = u32[] constant(5)
- is_input_iteration = pred[] compare(i, c5), direction=LT
- ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
- }
-
+ NaiveWoDirectBufferDependencyBFSMicrobatch5CircularRepeat2Replica4) {
+ constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
parameter(0)
@@ -829,7 +722,7 @@
input = f32[5,16] get-tuple-element(tuple), index=1
output = f32[5,16] get-tuple-element(tuple), index=2
buffer = f32[5,16] get-tuple-element(tuple), index=3
- prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4
+ prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4
i = u32[] get-tuple-element(tuple), index=5
c0 = u32[] constant(0)
@@ -839,38 +732,36 @@
c4 = u32[] constant(4)
c5 = u32[] constant(5)
- input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
- input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
- dynamic_slice_sizes={1,16}
- input_slice_ = f32[16] reshape(input_slice)
+ // Read from buffers before they are updated.
+ input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
+ buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5
- buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer
-
- buffer_ = f32[5,16] call(buffer, prev_iteration_compute_out, c4, i),
- to_apply=update_buffer
-
+ // Shift data to the next stage in the pipeline.
// Depends on the non-updated buffer of the previous iteration and,
// therefore, does not depend on the previous iteration's compute.
is_output_replica = pred[] call(), to_apply=is_output_replica
next_stage_slice = select(is_output_replica, buffer_slice,
- prev_iteration_compute_out)
-
-
+ prev_iteration_compute_res)
prev_stage_slice = f32[16] collective-permute(next_stage_slice),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
- is_read_input = pred[] call(i), to_apply=is_read_input
- compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-
- compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+ // Select compute argument from previous stage or from input and perform
+ // compute.
+ is_read_input = pred[] call(i), to_apply=is_read_input_mb5
+ compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+ compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
- output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer
+ // Update buffers.
+ buffer_ = f32[5,16] call(buffer, prev_iteration_compute_res, c4, i),
+ to_apply=update_buffer_mb5
+ output_ = f32[5,16] call(output, compute_res, c2, i),
+ to_apply=update_buffer_mb5
i_ = add(i, c1)
ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
- tuple(weights, input, output_, buffer_, compute_out, i_)
+ tuple(weights, input, output_, buffer_, compute_res, i_)
}
ENTRY main {
@@ -880,11 +771,12 @@
cf0 = f32[] constant(0)
output = f32[5,16] broadcast(cf0), dimensions={}
buffer = f32[5,16] broadcast(cf0), dimensions={}
- prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={}
+ prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
+ // Iterate through pipeline stages.
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
- tuple(weights, input, output, buffer, prev_iteration_compute_out, c0)
+ tuple(weights, input, output, buffer, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
while(tuple), condition=while_condition, body=while_body
@@ -897,8 +789,11 @@
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(kModuleStr, config));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module,
+ ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+ /*name=*/"test", kMoreComputationsStr),
+ config));
// This pipeline consists of a total of 8 layers (2 per replica), each of
// which is a single linear layer. We assign the weights to the replicas such
diff --git a/third_party/xla/xla/tests/copy_test.cc b/third_party/xla/xla/tests/copy_test.cc
index 45d94ab..734ecbd 100644
--- a/third_party/xla/xla/tests/copy_test.cc
+++ b/third_party/xla/xla/tests/copy_test.cc
@@ -13,22 +13,31 @@
limitations under the License.
==============================================================================*/
+#include <cstddef>
+#include <cstdint>
#include <memory>
#include <utility>
+#include <vector>
-#include "xla/array2d.h"
+#include <gtest/gtest.h>
+#include "absl/types/span.h"
+#include "xla/array3d.h"
+#include "xla/array4d.h"
#include "xla/client/xla_builder.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
#include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
#include "xla/tests/client_library_test_base.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/test_macros.h"
#include "xla/xla_data.pb.h"
-#include "tsl/platform/protobuf.h"
#include "tsl/platform/test.h"
namespace xla {
@@ -50,6 +59,25 @@
EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
+ // TODO(vsytch): Remove special handling for dynamic shapes once *all* of XLA
+ // supports those as module inputs/outputs.
+ void TestDynamicCopyOp(const Literal& literal, const Shape& bounded_shape) {
+ Literal dynamic_literal = literal.ToBoundedDynamic(bounded_shape);
+ auto builder = HloComputation::Builder(TestName());
+ auto parameter = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, dynamic_literal.shape(), "param"));
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ parameter->shape(), HloOpcode::kCopy, parameter));
+ auto computation = builder.Build();
+ auto module = CreateNewVerifiedModule();
+ module->AddEntryComputation(std::move(computation));
+
+ std::vector<Literal*> args = {&dynamic_literal};
+ Literal result = ExecuteAndTransfer(std::move(module), args);
+ Literal dynamic_result = result.ToBoundedDynamic(bounded_shape);
+ EXPECT_TRUE(LiteralTestUtil::Equal(dynamic_literal, dynamic_result));
+ }
+
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4,
absl::Span<const int64_t> permutation);
@@ -67,6 +95,47 @@
TestCopyOp(LiteralUtil::CreateR1<uint32_t>({1, 2, 3}));
}
+XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic0) {
+ // TODO(vsytch): CPU emitter doesn't handle dynamic shapes.
+ if (backend().platform()->Name() == "Host") {
+ GTEST_SKIP();
+ }
+ Shape bounded_shape =
+ ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true});
+ TestDynamicCopyOp(LiteralUtil::CreateRandomLiteral<PrimitiveType::F32>(
+ ShapeUtil::MakeShape(PrimitiveType::F32, {0}), 0, 1)
+ .value(),
+ bounded_shape);
+}
+
+XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic106632) {
+ // TODO(vsytch): CPU emitter doesn't handle dynamic shapes.
+ if (backend().platform()->Name() == "Host") {
+ GTEST_SKIP();
+ }
+ Shape bounded_shape =
+ ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true});
+ TestDynamicCopyOp(
+ LiteralUtil::CreateRandomLiteral<PrimitiveType::F32>(
+ ShapeUtil::MakeShape(PrimitiveType::F32, {106632}), 0, 1)
+ .value(),
+ bounded_shape);
+}
+
+XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic1310720) {
+ // TODO(vsytch): CPU emitter doesn't handle dynamic shapes.
+ if (backend().platform()->Name() == "Host") {
+ GTEST_SKIP();
+ }
+ Shape bounded_shape =
+ ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true});
+ TestDynamicCopyOp(
+ LiteralUtil::CreateRandomLiteral<PrimitiveType::F32>(
+ ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}), 0, 1)
+ .value(),
+ bounded_shape);
+}
+
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc
index 2d0f370..3fd7cf5 100644
--- a/third_party/xla/xla/tests/custom_call_test.cc
+++ b/third_party/xla/xla/tests/custom_call_test.cc
@@ -26,17 +26,19 @@
#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
+#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
-#include "xla/client/lib/constants.h"
+#include "xla/array2d.h"
+#include "xla/array3d.h"
#include "xla/client/xla_builder.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
@@ -44,6 +46,7 @@
#include "xla/primitive_util.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"
+#include "xla/service/service.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/client_library_test_base.h"
@@ -55,6 +58,9 @@
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
+#define EIGEN_USE_THREADS
+#include "unsupported/Eigen/CXX11/Tensor"
+
namespace {
void R0F32Add2(float* out, float** in) {
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*));
@@ -862,6 +868,40 @@
"__xla_test$$HandleTupleDifferentRanks", "Host",
kHandleTupleDifferentRanks);
+static absl::Status CustomCallWithIntraOpThreadPool(
+ ffi::Result<ffi::AnyBuffer>,
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool) {
+ // We use two blocking counters to ensure that the task is actually running
+ // inside a thread pool.
+ absl::BlockingCounter counter0(1);
+ absl::BlockingCounter counter1(1);
+
+ intra_op_thread_pool->getPool()->Schedule([&]() {
+ counter0.Wait();
+ counter1.DecrementCount();
+ });
+
+ // Unblock submitted task.
+ counter0.DecrementCount();
+
+ // TODO(b/356389210): It is unsafe to wait for the completion of a task
+ // submitted into an intra-op thread pool as we might be running on a thread
+ // inside the same thread pool, and this can lead to deadlocks. Custom calls
+ // should return `AsyncValue` to signal completion of all submitted tasks.
+ counter1.Wait();
+
+ return absl::OkStatus();
+}
+
+XLA_FFI_DEFINE_HANDLER(kIntraOpThreadPool, CustomCallWithIntraOpThreadPool,
+ ffi::Ffi::Bind()
+ .Ret<AnyBuffer>() // unused out buffer
+ .Ctx<ffi::IntraOpThreadPool>());
+
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
+ "__xla_test$$intra_op_thread_pool", "Host",
+ kIntraOpThreadPool);
+
} // namespace
// __xla_test$$ConcatVectors
@@ -1610,5 +1650,19 @@
EXPECT_EQ(result, expected);
}
+XLA_TEST_F(FfiCustomCallTest, IntraOpThreadPool) {
+ auto module = CreateNewVerifiedModule();
+ auto builder = HloComputation::Builder(TestName());
+
+ builder.AddInstruction(HloInstruction::CreateCustomCall(
+ r0f32_, {}, "__xla_test$$intra_op_thread_pool", "",
+ /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI));
+
+ module->AddEntryComputation(builder.Build());
+
+ auto status = Execute(std::move(module), {}).status();
+ EXPECT_EQ(status, absl::OkStatus());
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD
index 293e103..ed239be 100644
--- a/third_party/xla/xla/tests/exhaustive/BUILD
+++ b/third_party/xla/xla/tests/exhaustive/BUILD
@@ -29,6 +29,7 @@
deps = [
"//xla:bit_cast",
"//xla:executable_run_options",
+ "//xla:fp_util",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
@@ -184,6 +185,7 @@
deps = [
":exhaustive_op_test_utils",
"//xla:literal",
+ "//xla:types",
"//xla/client:xla_builder",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/util:command_line_flags",
diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc
index f9da77d..e3d0459 100644
--- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc
+++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc
@@ -16,8 +16,10 @@
#include <array>
#include <cmath>
#include <cstdint>
-#include <cstring>
+#include <cstdlib>
+#include <limits>
#include <tuple>
+#include <type_traits>
#include <utility>
#include "absl/log/check.h"
@@ -28,6 +30,7 @@
#include "xla/literal.h"
#include "xla/tests/exhaustive/exhaustive_op_test_utils.h"
#include "xla/tests/test_macros.h"
+#include "xla/types.h"
#include "tsl/platform/test.h"
#ifdef __FAST_MATH__
@@ -115,16 +118,101 @@
Run(AddEmptyBroadcastDimension(Sub), host_sub);
})
-// TODO(bixia): Mul fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), {
- auto host_mul = [](float x, float y) { return x * y; };
- Run(AddEmptyBroadcastDimension(Mul), host_mul);
+// Can be thought of as an absolute error of
+// `<= |std::numeric_limits<float>::min()|`.
+double MulCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+ float output = static_cast<float>(left) * static_cast<float>(right);
+
+ // Subnormals are flushed to 0 (as inputs or outputs). In these cases, we
+  // calculate 0 instead of the expected very small number, so we use the
+  // minimum float value as the absolute error to give some headroom.
+ auto left_is_subnormal = IsSubnormal(left);
+ auto right_is_subnormal = IsSubnormal(right);
+ auto output_is_subnormal = IsSubnormal(output);
+ if (left_is_subnormal || right_is_subnormal || output_is_subnormal) {
+ return std::numeric_limits<float>::min();
+ }
+
+ return 0.0;
+}
+
+bool MulCpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) {
+ // For BF16, multiplying a subnormal by infinity will lead to calculating 0
+ // multiplied by infinity due to subnormal flushing, which is defined to be
+ // NaN. However, the calculation in higher precision does not flush the
+ // subnormal value to 0, leading to a result of infinity.
+ auto left_is_subnormal = IsSubnormal(left);
+ auto left_is_infinite = std::isinf(left);
+ auto right_is_subnormal = IsSubnormal(right);
+ auto right_is_infinite = std::isinf(right);
+ if ((left_is_subnormal && right_is_infinite) ||
+ (left_is_infinite && right_is_subnormal)) {
+ return true;
+ }
+
+ return false;
+}
+
+BINARY_TEST_16BIT(Mul, {
+ ErrorSpecGen error_spec_gen = +[](NativeT left, NativeT right) {
+ return ErrorSpec::Builder().strict_signed_zeros().build();
+ };
+ if (IsCpu(platform_)) {
+ if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+ error_spec_gen = +[](NativeT left, NativeT right) {
+ return ErrorSpec::Builder()
+ .abs_err(MulCpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+ static_cast<xla::bfloat16>(right)))
+ .strict_signed_zeros()
+ .skip_comparison(MulCpuBf16Skip(static_cast<xla::bfloat16>(left),
+ static_cast<xla::bfloat16>(right)))
+ .build();
+ };
+ }
+ }
+ Run(
+ AddEmptyBroadcastDimension(Mul), [](float x, float y) { return x * y; },
+ error_spec_gen);
})
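
A standalone sketch of the mismatch MulCpuBf16Skip guards against; the 1e-40f
value is illustrative (any value that is subnormal for bfloat16 behaves the
same way):

  #include <cmath>
  #include <cstdio>
  #include <limits>

  int main() {
    const float subnormal = 1e-40f;  // subnormal for both float and bfloat16
    const float inf = std::numeric_limits<float>::infinity();
    const float kept = subnormal * inf;  // +inf: the higher-precision reference path
    const float flushed = 0.0f * inf;    // NaN: the result once the subnormal is flushed to 0
    std::printf("kept=%g flushed_is_nan=%d\n", kept, std::isnan(flushed) ? 1 : 0);
    return 0;
  }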
-// TODO(bixia): Div fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), {
- auto host_div = [](float x, float y) { return x / y; };
- Run(AddEmptyBroadcastDimension(Div), host_div);
+// Can be thought of as an absolute error of
+// `<= |std::numeric_limits<float>::min()|`.
+double DivCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+ float output = static_cast<float>(left) / static_cast<float>(right);
+
+  // Subnormals are flushed to 0, so we add an absolute error margin that is
+ // larger than any subnormal.
+ auto output_is_subnormal = IsSubnormal(output);
+ if (output_is_subnormal) {
+ return std::numeric_limits<float>::min();
+ }
+
+ return 0.0;
+}
+
+BINARY_TEST_16BIT(Div, {
+ ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) {
+ return ErrorSpec::Builder().strict_signed_zeros().build();
+ };
+ if (IsCpu(platform_)) {
+ if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+ error_spec_gen = +[](NativeT left, NativeT right) {
+ return ErrorSpec::Builder()
+ .abs_err(DivCpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+ static_cast<xla::bfloat16>(right)))
+ .strict_signed_zeros()
+ .build();
+ };
+ }
+ }
+ if (IsGpu(platform_) && std::is_same_v<NativeT, xla::half>) {
+ error_spec_gen = +[](NativeT, NativeT) {
+ return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build();
+ };
+ }
+ Run(
+ AddEmptyBroadcastDimension(Div), [](float x, float y) { return x / y; },
+ error_spec_gen);
})
BINARY_TEST_16BIT(Max, {
@@ -135,21 +223,135 @@
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
})
-// TODO(bixia): Pow fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_GPU(DISABLED_ON_CPU(Pow)), {
- // See b/162664705.
- known_incorrect_fn_ = [](int64_t val) {
- Eigen::bfloat16 f;
- uint16_t val_16 = val;
- memcpy(&f, &val_16, 2);
- return std::isnan(f);
+template <typename NativeT>
+bool PowCpuGpuF16Skip(NativeT left, NativeT right) {
+ // Hardware seems to always return 1 if right is 0, no matter if left is NaN.
+ if (std::isnan(left) && right == 0) {
+ return true;
+ }
+ // Hardware seems to always return 1 if left is 1, no matter if right is NaN.
+ if (left == 1 && std::isnan(right)) {
+ return true;
+ }
+ return false;
+}
+
+double PowCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+ float output = std::pow(static_cast<float>(left), static_cast<float>(right));
+
+ // Output is flushed to 0 if subnormal.
+ if (IsSubnormal(output)) {
+ return std::numeric_limits<float>::min();
+ }
+
+ // TODO(b/359325328): pow computation for subnormal bases is different from
+ // std::pow.
+ //
+ // If the base is subnormal, the output computation selects a different base.
+ // The minimum value ever chosen is slightly greater than the 1e-91 used
+ // below. We return an absolute error from this value to the "real" output.
+ //
+ // Because the exponent (right) can be any floating point value, this allows
+ // an arbitrary absolute error for subnormal values.
+ if (IsSubnormal(left)) {
+ xla::bfloat16 output_as_bf16 = static_cast<xla::bfloat16>(output);
+ auto expected = std::pow(1e-91, static_cast<double>(right));
+ auto err = std::abs(expected - output_as_bf16);
+ if (!std::isnan(err)) {
+ return err;
+ }
+ }
+
+ return 0.0;
+}
+
+BINARY_TEST_16BIT(Pow, {
+ ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) {
+ return ErrorSpec::Builder().strict_signed_zeros().build();
};
- Run(AddEmptyBroadcastDimension(Pow), std::pow);
+ if (IsCpu(platform_)) {
+ if constexpr (std::is_same_v<NativeT, xla::half>) {
+ error_spec_gen = +[](NativeT left, NativeT right) {
+ return ErrorSpec::Builder()
+ .strict_signed_zeros()
+ .skip_comparison(PowCpuGpuF16Skip(left, right))
+ .build();
+ };
+ }
+ if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+ error_spec_gen = +[](NativeT left, NativeT right) {
+ return ErrorSpec::Builder()
+ .abs_err(PowCpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+ static_cast<xla::bfloat16>(right)))
+ .strict_signed_zeros()
+ .build();
+ };
+ }
+ }
+ if (IsGpu(platform_)) {
+ error_spec_gen = +[](NativeT left, NativeT right) {
+ return ErrorSpec::Builder()
+ .distance_err(1)
+ .strict_signed_zeros()
+ .skip_comparison(PowCpuGpuF16Skip(left, right))
+ .build();
+ };
+ }
+ Run(AddEmptyBroadcastDimension(Pow), std::pow, error_spec_gen);
})
-// TODO(bixia): Atan2 fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2),
- { Run(AddEmptyBroadcastDimension(Atan2), std::atan2); })
+// Can be thought of as an absolute error of
+// `<= |std::numeric_limits<float>::min()|`.
+double Atan2CpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+ float output =
+ std::atan2(static_cast<float>(left), static_cast<float>(right));
+
+ // If the output would be a subnormal float, we allow some error to account
+ // for BF16 implementation flushing subnormals to zero.
+ auto output_is_subnormal = IsSubnormal(output);
+ if (output_is_subnormal) {
+ return std::numeric_limits<float>::min();
+ }
+
+ return 0.0;
+}
+
+bool Atan2CpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) {
+ // Subnormals are flushed to 0, but 0/0 returns NaN instead of
+ // <subnormal>/<subnormal> which returns some positive number. We cannot set
+ // an error to compare against NaN.
+ if (IsSubnormal(left) && IsSubnormal(right)) {
+ return true;
+ }
+
+ return false;
+}
+
+BINARY_TEST_16BIT(Atan2, {
+ auto error_spec_gen = +[](NativeT, NativeT) {
+ return ErrorSpec::Builder().strict_signed_zeros().build();
+ };
+ if (IsCpu(platform_)) {
+ if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+ error_spec_gen = +[](NativeT left, NativeT right) {
+ return ErrorSpec::Builder()
+ .abs_err(Atan2CpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+ static_cast<xla::bfloat16>(right)))
+ .strict_signed_zeros()
+ .skip_comparison(
+ Atan2CpuBf16Skip(static_cast<xla::bfloat16>(left),
+ static_cast<xla::bfloat16>(right)))
+ .build();
+ };
+ }
+ }
+ if (IsGpu(platform_)) {
+ error_spec_gen = +[](NativeT, NativeT) {
+ return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build();
+ };
+ }
+ Run(AddEmptyBroadcastDimension(Atan2), std::atan2, error_spec_gen);
+})
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc
index 6c791a8..f677539 100644
--- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc
+++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc
@@ -15,6 +15,7 @@
#include "xla/tests/exhaustive/exhaustive_op_test_utils.h"
+#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
@@ -90,6 +91,64 @@
return IsMinNormal(value.imag());
}
+/*static*/ ErrorSpec::Builder builder() { return ErrorSpecBuilder(); }
+
+ErrorSpecBuilder& ErrorSpecBuilder::abs_err(double abs_err) & {
+ spec_.abs_err = abs_err;
+ return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::rel_err(double rel_err) & {
+ spec_.rel_err = rel_err;
+ return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::distance_err(int64_t distance_err) & {
+ spec_.distance_err = distance_err;
+ return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::strict_signed_zeros(
+ bool strict_signed_zeros) & {
+ spec_.strict_signed_zeros = strict_signed_zeros;
+ return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::skip_comparison(bool skip_comparison) & {
+ spec_.skip_comparison = skip_comparison;
+ return *this;
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::abs_err(double abs_err) && {
+ spec_.abs_err = abs_err;
+ return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::rel_err(double rel_err) && {
+ spec_.rel_err = rel_err;
+ return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::distance_err(int64_t distance_err) && {
+ spec_.distance_err = distance_err;
+ return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::strict_signed_zeros(
+ bool strict_signed_zeros) && {
+ spec_.strict_signed_zeros = strict_signed_zeros;
+ return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::skip_comparison(bool skip_comparison) && {
+ spec_.skip_comparison = skip_comparison;
+ return std::move(*this);
+}
+
+ErrorSpecBuilder::operator ErrorSpec() && { return std::move(*this).build(); }
+
+ErrorSpec ErrorSpecBuilder::build() && { return spec_; }
+
// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of
// precision to be guaranteed that we're printing the full number.
//
@@ -491,6 +550,9 @@
}
ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs);
+ ASSERT_GE(error_spec.abs_err, 0.0);
+ ASSERT_GE(error_spec.rel_err, 0.0);
+ ASSERT_GE(error_spec.distance_err, 0.0);
if (error_spec.skip_comparison) {
PrintSkipped(&skipped, [&] {
@@ -552,7 +614,7 @@
result = pure_subnormal_cache[cache_loc];
}
} else {
- result = result = CallOperation(evaluate_op, test_value);
+ result = CallOperation(evaluate_op, test_value);
}
if (IsClose(result, static_cast<NativeRefT>(actual), error_spec)) {
diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h
index 6a491be..6465350 100644
--- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h
+++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h
@@ -44,6 +44,7 @@
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/executable_run_options.h"
+#include "xla/fp_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
@@ -140,9 +141,40 @@
return IsSubnormal(value) || IsMinNormal(value);
}
+// Get the floating point distance (number of representable floating point
+// values) between expected and actual.
+//
+// This is a wrapper around xla::CalculateDistanceInFloats for most types. For
+// complex types, this returns the maximum distance between the real and
+// imaginary components.
+template <typename NativeT>
+int64_t GetDistanceErr(NativeT expected, NativeT actual) {
+ if constexpr (std::is_same_v<NativeT, xla::complex64> ||
+ std::is_same_v<NativeT, xla::complex128>) {
+ return std::max(
+ CalculateDistanceInFloats(expected.real(), actual.real()),
+        CalculateDistanceInFloats(expected.imag(), actual.imag()));
+ } else {
+ return CalculateDistanceInFloats(expected, actual);
+ }
+}
+
+class ErrorSpecBuilder;
+
struct ErrorSpec {
- double abs_err = 0;
- double rel_err = 0;
+ using Builder = ErrorSpecBuilder;
+
+ double abs_err = 0.0;
+ double rel_err = 0.0;
+  // The acceptable number of floating point values between the expected and
+  // actual (also called the floating point distance).
+ //
+ // This is similar to absolute error, but the same distance_err can have
+ // different floating point values as the exponent changes. In some way, it is
+ // a hybrid of absolute and relative error, as it allows a fixed binary
+ // difference (like abs_err), but that has a varied floating point value based
+ // on the number (like rel_err).
+ int64_t distance_err = 0;
// If true, will consider -0 not near to +0 and vice versa. Note that
// +epsilon may still be considered close to -0, depending on the error
// spec; this only covers the case when both `expected` and `actual` are
@@ -154,6 +186,35 @@
bool skip_comparison = false;
};
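
To illustrate why distance_err acts as a hybrid of absolute and relative error:
a distance of one representable float corresponds to very different absolute
gaps at different magnitudes (IEEE binary32 assumed, printed gaps are
approximate):

  #include <cmath>
  #include <cstdio>

  int main() {
    const float a = 1.0f;
    const float a_next = std::nextafterf(a, 2.0f);     // 1 float away from 1.0f
    const float b = 1024.0f;
    const float b_next = std::nextafterf(b, 2048.0f);  // 1 float away from 1024.0f
    // Both pairs are within distance_err == 1, but the absolute gaps differ:
    // roughly 1.2e-07 versus 1.2e-04.
    std::printf("%g %g\n", a_next - a, b_next - b);
    return 0;
  }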
+// Builder pattern to construct an ErrorSpec without a proliferation of
+// constructors or requiring extensive argument name comments.
+//
+// The setter functions can be called on an lvalue or an rvalue, but you can
+// only build (explicitly or implicitly) from an rvalue, e.g. a temporary
+// builder or the result of std::move.
+class ErrorSpecBuilder {
+ public:
+ ErrorSpecBuilder() : spec_() {}
+
+ ErrorSpecBuilder& abs_err(double abs_err) &;
+ ErrorSpecBuilder& rel_err(double rel_err) &;
+ ErrorSpecBuilder& distance_err(int64_t distance_err) &;
+ ErrorSpecBuilder& strict_signed_zeros(bool strict_signed_zeros = true) &;
+ ErrorSpecBuilder& skip_comparison(bool skip_comparison = true) &;
+
+ ErrorSpecBuilder&& abs_err(double abs_err) &&;
+ ErrorSpecBuilder&& rel_err(double rel_err) &&;
+ ErrorSpecBuilder&& distance_err(int64_t distance_err) &&;
+ ErrorSpecBuilder&& strict_signed_zeros(bool strict_signed_zeros = true) &&;
+ ErrorSpecBuilder&& skip_comparison(bool skip_comparison = true) &&;
+
+ ErrorSpec build() &&;
+
+ explicit operator ErrorSpec() &&;
+
+ private:
+ ErrorSpec spec_;
+};
+
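
A typical call site, mirroring the uses later in this change (the tolerance
values here are placeholders); build() is rvalue-qualified, so it is invoked on
the temporary builder:

  ErrorSpec spec = ErrorSpec::Builder()
                       .abs_err(1e-6)
                       .distance_err(1)
                       .strict_signed_zeros()
                       .build();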
// Representations of the reference function passed in by the user.
template <typename NativeRefT, size_t K>
struct EvaluateOpWrapper {};
@@ -592,8 +653,12 @@
double abs_err =
std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual));
double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));
+ // N.B.: For sub-32-bit floats, NativeRefT is `float`, so ULP comparisons
+ // will be wildly off. We convert back to NativeT for this comparison.
+ int64_t distance_err = GetDistanceErr(NativeT(expected), NativeT(actual));
- return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
+ return abs_err <= spec.abs_err || rel_err <= spec.rel_err ||
+ distance_err <= spec.distance_err;
}
// Converts part or all bits in an uint64_t to the value of the floating point
@@ -1089,7 +1154,7 @@
kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<double>::min();
double rtol = kDefaultRelativeToleranceSlackFactor *
std::numeric_limits<double>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1097,7 +1162,7 @@
double atol =
kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<float>::min();
double rtol = 40 * std::numeric_limits<float>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1106,7 +1171,7 @@
kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<double>::min();
double rtol = kDefaultRelativeToleranceSlackFactor *
std::numeric_limits<double>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1115,7 +1180,7 @@
kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<float>::min();
double rtol = kDefaultRelativeToleranceSlackFactor *
std::numeric_limits<float>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1124,7 +1189,7 @@
std::numeric_limits<Eigen::half>::min();
// epsilon for FP16 is quite large, so a slack factor of 5 suffices.
double rtol = 5 * std::numeric_limits<Eigen::half>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1133,7 +1198,7 @@
std::numeric_limits<bfloat16>::min();
// epsilon for BF16 is quite large, so a slack factor of 2 suffices.
double rtol = 2 * std::numeric_limits<bfloat16>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1142,7 +1207,7 @@
kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<double>::min();
double rtol = kDefaultRelativeToleranceSlackFactor *
std::numeric_limits<double>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1151,7 +1216,7 @@
kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<float>::min();
double rtol = kDefaultRelativeToleranceSlackFactor *
std::numeric_limits<float>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1160,7 +1225,7 @@
std::numeric_limits<Eigen::half>::min();
// epsilon for FP16 is quite large, so a slack factor of 5 suffices.
double rtol = 5 * std::numeric_limits<Eigen::half>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <>
@@ -1169,7 +1234,7 @@
std::numeric_limits<bfloat16>::min();
// epsilon for BF16 is quite large, so a slack factor of 5 suffices.
double rtol = 2 * std::numeric_limits<bfloat16>::epsilon();
- return ErrorSpec{atol, rtol};
+ return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}
template <PrimitiveType T, size_t N>
@@ -1227,7 +1292,13 @@
};
template <PrimitiveType T>
-using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;
+class ExhaustiveBinaryTest : public ExhaustiveOpTestBase<T, 2> {
+ public:
+ using typename ExhaustiveOpTestBase<T, 2>::ErrorSpecGen;
+ static ErrorSpecGen GetDefaultSpecGenerator() {
+ return exhaustive_op_test::GetDefaultSpecGenerator<T, 2>();
+ }
+};
} // namespace exhaustive_op_test
} // namespace xla
diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc
index 8c47362..6df8322 100644
--- a/third_party/xla/xla/tests/hlo_test_base.cc
+++ b/third_party/xla/xla/tests/hlo_test_base.cc
@@ -16,6 +16,7 @@
#include "xla/tests/hlo_test_base.h"
#include <functional>
+#include <iterator>
#include <memory>
#include <set>
#include <string>
@@ -27,6 +28,9 @@
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
#include "xla/layout_util.h"
#include "xla/service/hlo_module_util.h"
#include "xla/service/hlo_parser.h"
@@ -1010,23 +1014,15 @@
HloComputation* HloTestBase::FindComputation(HloModule* module,
absl::string_view name) {
- auto computations = module->computations();
- auto it = absl::c_find_if(
- computations, [&](HloComputation* c) { return c->name() == name; });
- if (it == computations.end()) {
- return nullptr;
- }
- return *it;
+ return hlo_query::FindComputation(module, name);
}
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
absl::string_view name) {
- for (const HloComputation* c : module->computations()) {
- auto instructions = c->instructions();
- auto it = absl::c_find_if(
- instructions, [&](HloInstruction* i) { return i->name() == name; });
- if (it != instructions.end()) {
- return *it;
+ for (const HloComputation* computation : module->computations()) {
+ if (auto instruction = hlo_query::FindFirstInstruction(computation, name);
+ instruction.first != nullptr) {
+ return instruction.first;
}
}
return nullptr;
@@ -1034,17 +1030,25 @@
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
HloOpcode opcode) {
- for (const HloComputation* c : module->computations()) {
- auto instructions = c->instructions();
- auto it = absl::c_find_if(
- instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
- if (it != instructions.end()) {
- return *it;
+ for (const HloComputation* computation : module->computations()) {
+ if (auto instruction = hlo_query::FindFirstInstruction(computation, opcode);
+ instruction.first != nullptr) {
+ return instruction.first;
}
}
return nullptr;
}
+std::vector<HloInstruction*> HloTestBase::FindInstructions(HloModule* module,
+ HloOpcode opcode) {
+ std::vector<HloInstruction*> instructions;
+ for (const HloComputation* c : module->computations()) {
+ absl::c_copy_if(c->instructions(), std::back_inserter(instructions),
+ [&](HloInstruction* i) { return i->opcode() == opcode; });
+ }
+ return instructions;
+}
+
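
A hypothetical use from a test body (the opcode and expected count are
illustrative):

  std::vector<HloInstruction*> dots =
      FindInstructions(module.get(), HloOpcode::kDot);
  EXPECT_EQ(dots.size(), 2u);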
se::DeviceMemoryAllocator* HloTestBase::GetAllocator() {
if (allocator_ == nullptr) {
allocator_ = std::make_unique<se::StreamExecutorMemoryAllocator>(
diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h
index 9858ed6..4c194c8 100644
--- a/third_party/xla/xla/tests/hlo_test_base.h
+++ b/third_party/xla/xla/tests/hlo_test_base.h
@@ -25,8 +25,10 @@
#include "absl/status/statusor.h"
#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
+#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/backend.h"
#include "xla/service/computation_layout.h"
#include "xla/service/hlo_runner.h"
@@ -423,13 +425,19 @@
}
// Gets the computation/instruction from the given module with the given name.
- //
+ // Note: prefer calling these helpers directly via the hlo_query.h header
+ // instead, since they are independent of any test-time variables or
+ // context.
+ //
// This is useful for tests which create HLOs from a string and then want to
// inspect a particular computation or instruction.
HloComputation* FindComputation(HloModule* module, absl::string_view name);
HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
// Gets the instruction from the given module with the given opcode.
HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
+ // Gets all the instructions from the given module with the given opcode.
+ std::vector<HloInstruction*> FindInstructions(HloModule* module,
+ HloOpcode opcode);
// Return an HLO verifier constructed for the test backend.
HloVerifier& verifier() const { return *hlo_verifier_; }
@@ -438,6 +446,7 @@
// Returns the backend owned by the test runner.
Backend& backend();
+ int64_t num_devices() { return backend().device_count(); }
HloRunner test_runner_;
HloRunner reference_runner_;
@@ -513,6 +522,13 @@
se::Platform* test_platform);
};
+#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \
+ int64_t num_devices = backend().device_count(); \
+ if (num_devices < x) { \
+ GTEST_SKIP() << "Test requires at least " << x << " devices (" \
+ << num_devices << " available)"; \
+ }
+
} // namespace xla
#endif // XLA_TESTS_HLO_TEST_BASE_H_
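A minimal usage sketch of the helpers added above; the fixture, test name, and HLO text are invented for illustration and assume the usual hlo_test_base.h and gtest includes:

// Hypothetical test exercising FindInstructions and the device-count guard.
class HloQueryHelpersTest : public HloTestBase {};

TEST_F(HloQueryHelpersTest, FindsAllAddInstructions) {
  // Skips the test unless at least 2 devices are available on the backend.
  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(2);
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
    HloModule m
    ENTRY e {
      p0 = f32[4] parameter(0)
      a0 = f32[4] add(p0, p0)
      ROOT a1 = f32[4] add(a0, p0)
    })"));
  // Collects every kAdd instruction across all computations of the module.
  std::vector<HloInstruction*> adds =
      FindInstructions(module.get(), HloOpcode::kAdd);
  EXPECT_EQ(adds.size(), 2);
}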
diff --git a/third_party/xla/xla/tests/numerics_test.cc b/third_party/xla/xla/tests/numerics_test.cc
index 8a54242..b1bfcd9 100644
--- a/third_party/xla/xla/tests/numerics_test.cc
+++ b/third_party/xla/xla/tests/numerics_test.cc
@@ -24,6 +24,7 @@
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_macros.h"
#include "xla/types.h"
+#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
namespace xla {
@@ -86,5 +87,79 @@
std::numeric_limits<float>::quiet_NaN(), 0));
}
+// Case reported by an XLA user: the following code produced incorrect results
+// on the CPU thunks backend (due to incorrect LLVM IR being generated).
+// The HLO module is optimized for the CPU backend, so it may be invalid for
+// other backends.
+XLA_TEST_F(NumericsTest,
+ DISABLED_ON_GPU(DISABLED_ON_TPU(MultiplySubtractConcatTest))) {
+ const char* test_hlo = R"(
+ HloModule jit_step, is_scheduled=true
+
+ fused_computation {
+ param_0.2 = f32[1,5] parameter(0)
+ slice.11 = f32[1,1] slice(param_0.2), slice={[0:1], [1:2]}
+ slice.10 = f32[1,1] slice(param_0.2), slice={[0:1], [4:5]}
+ multiply.11 = f32[1,1] multiply(slice.11, slice.10)
+ slice.9 = f32[1,1] slice(param_0.2), slice={[0:1], [2:3]}
+ slice.8 = f32[1,1] slice(param_0.2), slice={[0:1], [3:4]}
+ multiply.10 = f32[1,1] multiply(slice.9, slice.8)
+ subtract.5 = f32[1,1] subtract(multiply.11, multiply.10)
+ slice.6 = f32[1,1] slice(param_0.2), slice={[0:1], [0:1]}
+ multiply.8 = f32[1,1] multiply(slice.6, slice.10)
+ subtract.4 = f32[1,1] subtract(slice.9, multiply.8)
+ ROOT concatenate.1 = f32[1,3] concatenate(
+ subtract.5, subtract.4, subtract.4), dimensions={1}
+ } // fused_computation
+
+ ENTRY main {
+ Arg_0.0 = f32[1,5] parameter(0)
+ ROOT fusion = f32[1,3] fusion(Arg_0.0), kind=kLoop,
+ calls=fused_computation
+ } // main
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto test_module,
+ ParseAndReturnVerifiedModule(test_hlo));
+ auto argument = LiteralUtil::CreateR2<float>(
+ {{0.261473775, -0.642940283, -0.719902277, 0.712947428, 0.543724537}});
+
+ TF_ASSERT_OK_AND_ASSIGN(auto test_result,
+ Execute(std::move(test_module), {&argument},
+ /*run_hlo_passes=*/false));
+
+ // Reference HLO module. It is a subgraph of the test module that performs
+ // only the calculations needed for the test module's first output element.
+ const char* reference_hlo = R"(
+ HloModule jit_step, is_scheduled=true
+
+ fused_computation {
+ param_0.2 = f32[1,5] parameter(0)
+ slice.11 = f32[1,1] slice(param_0.2), slice={[0:1], [1:2]}
+ slice.10 = f32[1,1] slice(param_0.2), slice={[0:1], [4:5]}
+ multiply.11 = f32[1,1] multiply(slice.11, slice.10)
+ slice.9 = f32[1,1] slice(param_0.2), slice={[0:1], [2:3]}
+ slice.8 = f32[1,1] slice(param_0.2), slice={[0:1], [3:4]}
+ multiply.10 = f32[1,1] multiply(slice.9, slice.8)
+ ROOT subtract.5 = f32[1,1] subtract(multiply.11, multiply.10)
+ } // fused_computation
+
+ ENTRY main {
+ Arg_0.0 = f32[1,5] parameter(0)
+ ROOT fusion = f32[1,1] fusion(Arg_0.0), kind=kLoop,
+ calls=fused_computation
+ } // main
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto reference_module,
+ ParseAndReturnVerifiedModule(reference_hlo));
+ TF_ASSERT_OK_AND_ASSIGN(auto reference_result,
+ Execute(std::move(reference_module), {&argument},
+ /*run_hlo_passes=*/false));
+
+ // Only compare the first element.
+ EXPECT_EQ(reference_result.data<float>()[0], test_result.data<float>()[0]);
+}
+
} // namespace
} // namespace xla
diff --git a/third_party/xla/xla/tests/replicated_io_feed_test.cc b/third_party/xla/xla/tests/replicated_io_feed_test.cc
index 9ee34a7..0164f8b 100644
--- a/third_party/xla/xla/tests/replicated_io_feed_test.cc
+++ b/third_party/xla/xla/tests/replicated_io_feed_test.cc
@@ -50,7 +50,10 @@
result = u32[] add(infeed.data, replica_id)
outfeed = token[] outfeed(result, infeed.token), outfeed_shape=u32[]
})";
+
const int kNumReplicas = 4;
+ SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
+
auto config = GetModuleConfigForTest();
config.set_replica_count(kNumReplicas);
std::unique_ptr<HloModule> module =
diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD
index 36e284e..6975a38 100644
--- a/third_party/xla/xla/tools/BUILD
+++ b/third_party/xla/xla/tools/BUILD
@@ -56,7 +56,6 @@
name = "hex_floats_to_packed_literal",
srcs = ["hex_floats_to_packed_literal.cc"],
deps = [
- "//xla:types",
"//xla/tsl/util:command_line_flags",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
@@ -127,13 +126,12 @@
name = "convert_computation",
srcs = ["convert_computation.cc"],
deps = [
- "//xla:types",
"//xla/service:hlo_proto_cc",
- "@com_google_absl//absl/status:statusor",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:platform_port",
"@local_tsl//tsl/platform:protobuf",
+ "@local_tsl//tsl/platform:status",
],
)
@@ -193,19 +191,24 @@
name = "dumped_computation_to_text",
srcs = ["dumped_computation_to_text.cc"],
deps = [
- "//xla:types",
- "//xla/client",
+ "//xla:shape_util",
+ "//xla:xla_proto_cc",
"//xla/client:client_library",
+ "//xla/client:executable_build_options",
"//xla/client:local_client",
"//xla/client:xla_computation",
+ "//xla/hlo/ir:hlo",
"//xla/service",
"//xla/service:hlo_proto_cc",
"//xla/service:interpreter_plugin",
+ "//xla/service:local_service",
+ "//xla/tsl/util:command_line_flags",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:platform_port",
+ "@local_tsl//tsl/platform:status",
],
)
@@ -220,15 +223,16 @@
name = "dumped_computation_to_operation_list",
srcs = ["dumped_computation_to_operation_list.cc"],
deps = [
- "//xla:types",
- "//xla/client",
+ "//xla:shape_util",
"//xla/client:client_library",
+ "//xla/client:executable_build_options",
"//xla/client:local_client",
"//xla/client:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/service",
"//xla/service:hlo_proto_cc",
"//xla/service:interpreter_plugin",
+ "//xla/service:local_service",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
@@ -237,6 +241,7 @@
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:platform_port",
+ "@local_tsl//tsl/platform:status",
],
)
@@ -508,6 +513,19 @@
],
)
+xla_cc_test(
+ name = "prepare_reference_module_test",
+ srcs = ["prepare_reference_module_test.cc"],
+ deps = [
+ ":prepare_reference_module",
+ "//xla:test",
+ "//xla/hlo/ir:hlo",
+ "//xla/tests:hlo_test_base",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
tf_proto_library(
name = "run_hlo_module_proto",
srcs = ["run_hlo_module.proto"],
@@ -633,6 +651,7 @@
"data/add.hlo",
"data/add_mhlo.mlir",
"data/add_stablehlo.mlir",
+ "data/input_literal_f32_2_2.pbtxt",
"data/must_alias.hlo",
"data/must_alias_with_sharding.hlo",
":run_hlo_module",
@@ -644,6 +663,7 @@
"//xla/service:hlo_parser",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:statusor",
@@ -718,8 +738,14 @@
srcs = ["compute_cost.cc"],
deps = [
":hlo_module_loader",
+ "//xla:debug_options_flags",
+ "//xla:shape_util",
"//xla/service:hlo_cost_analysis",
+ "//xla/tsl/util:command_line_flags",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:platform_port",
+ "@local_tsl//tsl/platform:status",
],
)
@@ -798,32 +824,23 @@
"//xla/service/gpu:amdgpu_compiler",
"//xla/service/gpu:amdgpu_compiler_impl",
]) + if_gpu_is_configured([
- "//xla/service/gpu:autotuner_util",
"//xla/service/gpu:executable_proto_cc",
"//xla/service/gpu:gpu_compiler",
+ "//xla/service/gpu/autotuning:autotuner_util",
"//xla/stream_executor/gpu:gpu_init",
"//xla/service/gpu:gpu_symbol_repository",
]) + if_google(["@com_google_protobuf//:duration_cc_proto"]),
)
xla_test(
- name = "xla_compile_lib_test",
- srcs = ["xla_compile_lib_test.cc"],
- backend_tags = {
- "gpu": ["requires-gpu-nvidia"] + if_google(["config-cuda-only"]),
- },
+ name = "xla_cpu_compile_lib_test",
+ srcs = ["xla_cpu_compile_lib_test.cc"],
backends = [
"cpu",
- "gpu",
],
data = [
":data/add.hlo",
- "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt",
- "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto",
],
- local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
- "TENSORFLOW_USE_ROCM=1",
- ]),
deps = [
":xla_compile_lib",
"//xla:util",
@@ -831,11 +848,7 @@
"//xla/service:platform_util",
"//xla/service:symbol_repository",
"//xla/service:xla_compile_result_proto_cc_impl",
- "//xla/service/gpu:autotuner_util",
- "//xla/service/gpu:gpu_symbol_repository",
- "//xla/stream_executor:device_description_proto_cc",
"//xla/tests:hlo_test_base",
- "//xla/tests:test_macros_header",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/strings",
@@ -843,6 +856,7 @@
"@com_google_googletest//:gtest",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:env_time",
+ "@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
@@ -853,6 +867,44 @@
)
xla_test(
+ name = "xla_gpu_compile_lib_test",
+ srcs = ["xla_gpu_compile_lib_test.cc"],
+ backend_tags = {
+ "gpu": ["requires-gpu-nvidia"] + if_google(["config-cuda-only"]),
+ },
+ backends = [
+ "gpu",
+ ],
+ data = [
+ ":data/add.hlo",
+ "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt",
+ "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto",
+ ],
+ deps = [
+ ":xla_compile_lib",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:platform_util",
+ "//xla/service:symbol_repository",
+ "//xla/service:xla_compile_result_proto_cc_impl",
+ "//xla/service/gpu:gpu_symbol_repository",
+ "//xla/service/gpu/autotuning:autotuner_util",
+ "//xla/stream_executor:device_description_proto_cc",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//xla/tsl/lib/core:status_test_util",
+ "@com_google_googletest//:gtest",
+ "@local_tsl//tsl/platform:env",
+ "@local_tsl//tsl/platform:path",
+ "@local_tsl//tsl/platform:status_matchers",
+ "@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
+ "@local_tsl//tsl/protobuf:status_proto_cc",
+ ],
+)
+
+xla_test(
name = "hlo_decomposer_test",
srcs = ["hlo_decomposer_test.cc"],
deps = [
diff --git a/third_party/xla/xla/tools/compute_cost.cc b/third_party/xla/xla/tools/compute_cost.cc
index c5e0ddc..9615ae0 100644
--- a/third_party/xla/xla/tools/compute_cost.cc
+++ b/third_party/xla/xla/tools/compute_cost.cc
@@ -21,9 +21,16 @@
#include <string>
#include <vector>
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "xla/debug_options_flags.h"
#include "xla/service/hlo_cost_analysis.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
#include "xla/tools/hlo_module_loader.h"
+#include "xla/tsl/util/command_line_flags.h"
#include "tsl/platform/init_main.h"
+#include "tsl/platform/status.h"
namespace {
const char* const kUsage = R"(
diff --git a/third_party/xla/xla/tools/convert_computation.cc b/third_party/xla/xla/tools/convert_computation.cc
index f81d517..7ebc5d3 100644
--- a/third_party/xla/xla/tools/convert_computation.cc
+++ b/third_party/xla/xla/tools/convert_computation.cc
@@ -16,6 +16,7 @@
// Usage: convert_computation <txt2bin|bin2txt> serialized_computation_proto
//
// bin2txt spits out the result to stdout. txt2bin modifies the file in place.
+#include "tsl/platform/status.h"
#ifndef _WIN32
#include <unistd.h>
#endif
@@ -23,9 +24,7 @@
#include <string>
-#include "absl/status/statusor.h"
#include "xla/service/hlo.pb.h"
-#include "xla/types.h"
#include "tsl/platform/env.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
diff --git a/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt b/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt
new file mode 100644
index 0000000..6c39d03
--- /dev/null
+++ b/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt
@@ -0,0 +1,20 @@
+# proto-file: third_party/tensorflow/compiler/xla/tools/run_hlo_module.proto
+# proto-message: RunHloModuleIterationLiterals
+arguments {
+ shape {
+ element_type: F32
+ dimensions: 2
+ dimensions: 2
+ layout {
+ minor_to_major: 1
+ minor_to_major: 0
+ tail_padding_alignment_in_elements: 1
+ }
+ is_dynamic_dimension: false
+ is_dynamic_dimension: false
+ }
+ f32s: 0.1
+ f32s: 0.2
+ f32s: 0.3
+ f32s: 0.4
+}
\ No newline at end of file
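A short sketch of how this sample file can be loaded into the proto named in its header comment; the path string is a placeholder and error handling is simplified:

xla::RunHloModuleIterationLiterals LoadSampleLiterals() {
  xla::RunHloModuleIterationLiterals literals;
  TF_CHECK_OK(tsl::ReadTextProto(
      tsl::Env::Default(),
      "/path/to/xla/tools/data/input_literal_f32_2_2.pbtxt",  // placeholder
      &literals));
  // literals.arguments(0) describes a 2x2 f32 literal with values
  // {0.1, 0.2, 0.3, 0.4}.
  return literals;
}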
diff --git a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc
index e17d487..6450210 100644
--- a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc
+++ b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc
@@ -22,20 +22,26 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
-#include "xla/client/client.h"
#include "xla/client/client_library.h"
+#include "xla/client/executable_build_options.h"
#include "xla/client/local_client.h"
#include "xla/client/xla_computation.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/hlo.pb.h"
-#include "xla/service/service.h"
-#include "xla/types.h"
+#include "xla/service/local_service.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
namespace xla {
namespace tools {
diff --git a/third_party/xla/xla/tools/dumped_computation_to_text.cc b/third_party/xla/xla/tools/dumped_computation_to_text.cc
index df9116e..695d4c9 100644
--- a/third_party/xla/xla/tools/dumped_computation_to_text.cc
+++ b/third_party/xla/xla/tools/dumped_computation_to_text.cc
@@ -21,16 +21,21 @@
#include "absl/status/statusor.h"
#include "absl/types/span.h"
-#include "xla/client/client.h"
#include "xla/client/client_library.h"
+#include "xla/client/executable_build_options.h"
#include "xla/client/local_client.h"
#include "xla/client/xla_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo.pb.h"
-#include "xla/service/service.h"
-#include "xla/types.h"
+#include "xla/service/local_service.h"
+#include "xla/shape.h"
+#include "xla/tsl/util/command_line_flags.h"
+#include "xla/xla.pb.h"
#include "tsl/platform/env.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
namespace xla {
namespace tools {
diff --git a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc
index c4d591b..6388a8f 100644
--- a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc
+++ b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc
@@ -21,10 +21,10 @@
#include "absl/base/casts.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/util/command_line_flags.h"
-#include "xla/types.h"
#include "tsl/lib/io/buffered_inputstream.h"
#include "tsl/lib/io/random_inputstream.h"
#include "tsl/platform/env.h"
+#include "tsl/platform/file_system.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/status.h"
diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD
index 3df7fbb..7f9747b 100644
--- a/third_party/xla/xla/tools/hlo_bisect/BUILD
+++ b/third_party/xla/xla/tools/hlo_bisect/BUILD
@@ -84,6 +84,7 @@
"//xla:protobuf_util",
"//xla:util",
"//xla/hlo/ir:hlo",
+ "//xla/service:dump",
"//xla/service:hlo_parser",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_proto_util",
diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc
index 0e38151..d4e6d0d 100644
--- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc
+++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc
@@ -25,6 +25,7 @@
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/protobuf_util.h"
+#include "xla/service/dump.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_parser.h"
#include "xla/service/hlo_proto_util.h"
@@ -137,7 +138,7 @@
HloProto proto = MakeHloProto(*module);
if (output_format == "hlo") {
tsl::Env* env = tsl::Env::Default();
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(std::string(dir_path)));
+ TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(dir_path), env));
std::string file_path =
tsl::io::JoinPath(dir_path, SanitizeFileName(file_name)) + ".hlo";
LOG(INFO) << "Dumped HLO text to " << file_path;
@@ -148,8 +149,8 @@
.set_compact_operands(false))));
} else if (output_format == "pb") {
std::string path;
- TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
- proto, std::string(dir_path), file_name, &path));
+ TF_RETURN_IF_ERROR(
+ DumpProtoToDirectory(proto, std::string(dir_path), file_name, &path));
LOG(INFO) << "Dumped HLO module proto to " << path;
} else {
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
index d37c6b0..e81f3ef 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
@@ -496,7 +496,7 @@
TF_RETURN_IF_ERROR(RemoveCollective(instruction).status());
}
changed = true;
- } else if (remove_comm_ &&
+ } else if ((remove_comm_ || remove_id_) &&
(instruction->opcode() == HloOpcode::kPartitionId ||
instruction->opcode() == HloOpcode::kReplicaId ||
(instruction->opcode() == HloOpcode::kCustomCall &&
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.h b/third_party/xla/xla/tools/hlo_control_flow_flattening.h
index cff9db4..450aeab 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening.h
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.h
@@ -49,6 +49,8 @@
bool flatten_while_loop = true;
bool remove_comm = true;
bool remove_host_transfer = false;
+ // Removes partition-id, replica-id, and slice-id.
+ bool remove_id = false;
};
explicit HloControlFlowFlattening(const Options& options)
: while_execution_count_(options.while_execution_count),
@@ -57,7 +59,8 @@
remove_infeed_outfeed_(options.remove_infeed_outfeed),
flatten_while_loop_(options.flatten_while_loop),
remove_host_transfer_(options.remove_host_transfer),
- remove_comm_(options.remove_comm) {}
+ remove_comm_(options.remove_comm),
+ remove_id_(options.remove_id) {}
~HloControlFlowFlattening() override = default;
absl::string_view name() const override { return "control-flow-flattening"; }
using HloPassInterface::Run;
@@ -102,6 +105,7 @@
HloInstruction* recv_done,
absl::flat_hash_set<HloInstruction*>* additional_removed) const;
bool remove_comm_;
+ bool remove_id_;
};
// Retrieves the original loop bound. If fail, return a default value. If bounds
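A short sketch of enabling the new option inside a test body, given a parsed HloModule held in `module`; it assumes the remaining Options fields default sensibly, which this hunk does not show:

// Strip id-producing ops but keep collective communication ops.
HloControlFlowFlattening::Options options;
options.remove_id = true;     // removes partition-id / replica-id / slice-id
options.remove_comm = false;  // keeps all-reduce and other collectives
HloControlFlowFlattening flattening(options);
EXPECT_TRUE(flattening.Run(module.get()).value());
// The id ops are now replaced with constants while collectives are untouched.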
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
index ceb51be..40d16ca 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
@@ -515,6 +515,37 @@
"replica-id.18600");
}
+TEST_F(HloControlFlowFlatteningTest, RemoveReplicaIdButKeepAllReduce) {
+ absl::string_view kHloText = R"(
+ HloModule RemoveReplicaIdButKeepCollective
+
+%sum (a: f32[], b: f32[]) -> f32[] {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] a, f32[] b)
+ }
+ ENTRY ReplicaId {
+ replica-id.1 = u32[]{:T(128)} replica-id()
+ ROOT all-reduce.1 = u32[]{:T(128)} all-reduce(replica-id.1), to_apply=sum, replica_groups={}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(kHloText));
+ HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{
+ /*while_execution_count=*/1, /*max_outer_loop_count=*/1,
+ /*max_loop_count=*/1, /*remove_infeed_outfeed=*/false,
+ /*flatten_while_loop=*/false, /*remove_comm=*/false,
+ /*remove_host_transfer=*/false, /*remove_id=*/true});
+ EXPECT_TRUE(flattening.Run(module.get()).value());
+ TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/true)
+ .Run(module.get())
+ .status());
+ EXPECT_THAT(module->entry_computation()->root_instruction(), op::AllReduce());
+ EXPECT_THAT(module->entry_computation()->root_instruction()->operand(0),
+ op::Constant());
+}
+
TEST_F(HloControlFlowFlatteningTest, CollectivePermuteInPlaceUpdate) {
absl::string_view hlo_string = R"(
HloModule CollectivePermuteInPlaceUpdate
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
index cefc272..1e1450b 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
@@ -84,6 +84,7 @@
testonly = True,
tags = [
"gpu",
+ "no_rocm",
"nomac",
] + tf_gpu_tests_tags(),
deps = [
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc
index 1a490a0..3aabf56 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc
@@ -57,43 +57,41 @@
if (enable_mock_nccl) {
CHECK_GT(num_nodes, 1);
return CreateMockGpuClient(num_nodes);
- } else {
- if (num_nodes == 1) {
- return CreateGpuClient({});
- } else {
- TF_RET_CHECK(!address.empty());
- TF_RET_CHECK(node_id >= 0)
- << "Node id is expected to be in range [0, num_nodes)";
- TF_RET_CHECK(node_id < num_nodes)
- << "Node id is expected to be in range [0, num_nodes)";
-
- CHECK_GT(address.length(), 0);
- // Multinode. Start service on task 0.
- if (node_id == 0) {
- std::string coordinator_bind_address =
- "[::]:" + std::string(address).substr(address.rfind(':') + 1);
- xla::CoordinationServiceImpl::Options options;
- options.num_nodes = num_nodes;
- auto status_or = xla::GetDistributedRuntimeService(
- coordinator_bind_address, options);
- TF_QCHECK_OK(status_or.status());
- service = std::move(status_or.value());
- }
- xla::DistributedRuntimeClient::Options options;
- options.node_id = node_id;
- options.init_timeout = init_timeout;
- distributed_client =
- GetDistributedRuntimeClient(std::string(address), options);
- TF_QCHECK_OK(distributed_client->Connect());
- kv_store = GetDistributedKeyValueStore(distributed_client,
- /*key_prefix=*/"gpu:");
- GpuClientOptions gpu_client_options;
- gpu_client_options.node_id = node_id;
- gpu_client_options.num_nodes = num_nodes;
- gpu_client_options.kv_store = kv_store;
- return CreateGpuClient(std::move(gpu_client_options));
- }
}
+
+ if (num_nodes == 1) {
+ return CreateGpuClient({});
+ }
+
+ TF_RET_CHECK(!address.empty());
+ TF_RET_CHECK(node_id >= 0)
+ << "Node id is expected to be in range [0, num_nodes)";
+ TF_RET_CHECK(node_id < num_nodes)
+ << "Node id is expected to be in range [0, num_nodes)";
+
+ CHECK_GT(address.length(), 0);
+ // Multinode. Start service on task 0.
+ if (node_id == 0) {
+ std::string coordinator_bind_address =
+ "[::]:" + std::string(address).substr(address.rfind(':') + 1);
+ xla::CoordinationServiceImpl::Options options;
+ options.num_nodes = num_nodes;
+ TF_ASSIGN_OR_RETURN(service, xla::GetDistributedRuntimeService(
+ coordinator_bind_address, options));
+ }
+ xla::DistributedRuntimeClient::Options options;
+ options.node_id = node_id;
+ options.init_timeout = init_timeout;
+ distributed_client =
+ GetDistributedRuntimeClient(std::string(address), options);
+ TF_QCHECK_OK(distributed_client->Connect());
+ kv_store = GetDistributedKeyValueStore(distributed_client,
+ /*key_prefix=*/"gpu:");
+ GpuClientOptions gpu_client_options;
+ gpu_client_options.node_id = node_id;
+ gpu_client_options.num_nodes = num_nodes;
+ gpu_client_options.kv_store = kv_store;
+ return CreateGpuClient(std::move(gpu_client_options));
}
absl::StatusOr<PjRtEnvironment> GetPjRtClient(absl::string_view device_type,
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo b/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo
index c745ee7..c182a59 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo
@@ -1,35 +1,46 @@
f1 {
- p0 = f16[720,720,720]{2,1,0} parameter(0)
- p1 = s8[720,720,720]{2,1,0} parameter(1)
- c = f16[720,720,720]{2,1,0} convert(p1)
- ROOT d1 = f16[720,720,720]{2,1,0} dot(p0, c),
+ p0 = f16[64,64,64] parameter(0)
+ p1 = s8[64,64,64] parameter(1)
+ c = f16[64,64,64] convert(p1)
+ ROOT d1 = f32[64,64,64] dot(p0, c),
lhs_batch_dims={0}, lhs_contracting_dims={2},
rhs_batch_dims={0}, rhs_contracting_dims={1}
}
f2 {
- p0 = s8[720,720,720]{2,1,0} parameter(0)
- c0 = f32[720,720,720]{2,1,0} convert(p0)
- p1 = f16[720,720,720]{2,1,0} parameter(1)
- c1 = f32[720,720,720]{2,1,0} convert(p1)
- ROOT %dot.1 = f32[720,720,720]{2,1,0} dot(c0, c1),
+ p0 = s8[64,64,64] parameter(0)
+ c0 = f32[64,64,64] convert(p0)
+ p1 = f16[64,64,64] parameter(1)
+ c1 = f32[64,64,64] convert(p1)
+ ROOT d2 = f32[64,64,64] dot(c0, c1),
lhs_batch_dims={0}, lhs_contracting_dims={2},
rhs_batch_dims={0}, rhs_contracting_dims={1}
}
+f3 {
+ p0 = f16[64,64,64] parameter(0)
+ p1 = f16[64,64,64] parameter(1)
+ ROOT d3 = f32[64,64,64] dot(p0, p1),
+ lhs_batch_dims={0}, lhs_contracting_dims={1},
+ rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
fa {
- p1 = f16[720,720,720]{2,1,0} parameter(1)
- c = f32[720,720,720]{2,1,0} convert(p1)
- p0 = f32[720,720,720]{2,1,0} parameter(0)
- ROOT %add.1.1 = f32[720,720,720]{2,1,0} add(c, p0)
+ p0 = f32[64,64,64] parameter(0)
+ p1 = f32[64,64,64] parameter(1)
+ p2 = f32[64,64,64] parameter(2)
+ a1 = f32[64,64,64] add(p2, p1)
+ ROOT a = f32[64,64,64] add(p0, a1)
}
ENTRY e {
- p1 = s8[720,720,720]{2,1,0} parameter(1)
- p0 = f16[720,720,720]{2,1,0} parameter(0)
- f1r = f16[720,720,720]{2,1,0} fusion(p0, p1), kind=kCustom, calls=f1,
+ p0 = f16[64,64,64] parameter(0)
+ p1 = s8[64,64,64] parameter(1)
+ f1r = f32[64,64,64] fusion(p0, p1), kind=kCustom, calls=f1,
backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
- f2r = f32[720,720,720]{2,1,0} fusion(p1, p0), kind=kCustom, calls=f2,
+ f2r = f32[64,64,64] fusion(p1, p0), kind=kCustom, calls=f2,
backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
- ROOT _ = f32[720,720,720]{2,1,0} fusion(f2r, f1r), kind=kLoop, calls=fa
+ f3r = f32[64,64,64] fusion(p0, p0), kind=kCustom, calls=f3,
+ backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+ ROOT _ = f32[64,64,64] fusion(f1r, f2r, f3r), kind=kLoop, calls=fa
}
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
index b55c1be..405e421 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
@@ -256,7 +256,7 @@
// Name of the test binary.
static const char* binary_name;
-constexpr int kNumNodes = 3;
+constexpr int kNumNodes = 2;
TEST_F(FunctionalHloRunnerTest, ShardedAutotuningWorks) {
if (IsTestingCpu()) {
@@ -308,13 +308,8 @@
env.kv_store->Get("gemm_fusion_autotuning_results_1_1",
absl::Seconds(1)));
CHECK(absl::StrContains(results1, "run_time"));
- // First two nodes autotune two different fusions.
+ // The two nodes autotune different fusions.
CHECK_NE(results0, results1);
- TF_ASSIGN_OR_RETURN(std::string results2,
- env.kv_store->Get("gemm_fusion_autotuning_results_1_2",
- absl::Seconds(1)));
- // Third node has nothing to autotune.
- CHECK(!absl::StrContains(results2, "run_time"));
}
return absl::OkStatus();
}
diff --git a/third_party/xla/xla/tools/prepare_reference_module.cc b/third_party/xla/xla/tools/prepare_reference_module.cc
index 4ce766d..82fd57a 100644
--- a/third_party/xla/xla/tools/prepare_reference_module.cc
+++ b/third_party/xla/xla/tools/prepare_reference_module.cc
@@ -34,7 +34,8 @@
const HloModule& test_module, HloRunnerInterface* test_runner,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
const std::function<absl::Status(const HloModule&, HloRunnerInterface*,
- HloModule*)>& module_modifier_hook) {
+ HloModule*)>& module_modifier_hook,
+ bool skip_despecialization) {
DebugOptions debug_options = GetDebugOptionsFromFlags();
// The combination of fast math and optimizations leads to unsound code
// transformations (see third_party/tensorflow/compiler/xla/xla.proto for
@@ -51,7 +52,7 @@
if (module_modifier_hook) {
TF_RETURN_IF_ERROR(
module_modifier_hook(test_module, test_runner, reference_module.get()));
- } else {
+ } else if (!skip_despecialization) {
TF_RETURN_IF_ERROR(Despecializer().Run(reference_module.get()).status());
}
return std::move(reference_module);
diff --git a/third_party/xla/xla/tools/prepare_reference_module.h b/third_party/xla/xla/tools/prepare_reference_module.h
index 4a1064d..f26e847 100644
--- a/third_party/xla/xla/tools/prepare_reference_module.h
+++ b/third_party/xla/xla/tools/prepare_reference_module.h
@@ -37,7 +37,8 @@
const HloModule& test_module, HloRunnerInterface* test_runner,
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
const std::function<absl::Status(const HloModule&, HloRunnerInterface*,
- HloModule*)>& module_modifier_hook = {});
+ HloModule*)>& module_modifier_hook = {},
+ bool skip_despecialization = false);
} // namespace xla
diff --git a/third_party/xla/xla/tools/prepare_reference_module_test.cc b/third_party/xla/xla/tools/prepare_reference_module_test.cc
new file mode 100644
index 0000000..0b2ad0e
--- /dev/null
+++ b/third_party/xla/xla/tools/prepare_reference_module_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/tools/prepare_reference_module.h"
+
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+const char* const kModuleStr = R"(
+ HloModule jit_step
+
+ %fused_computation (param_0.2: f32[1,4]) -> f32[1,3] {
+ %param_0.2 = f32[1,4]{1,0} parameter(0)
+ ROOT %slice.11 = f32[1,3]{1,0} slice(f32[1,4]{1,0} %param_0.2),
+ slice={[0:1], [0:3]}
+ }
+
+ ENTRY %main.3491 (Arg_0.0: f32[1,4]) -> f32[1,3] {
+ %Arg_0.0 = f32[1,4]{1,0} parameter(0)
+ ROOT %fusion = f32[1,3]{1,0} fusion(f32[1,4]{1,0} %Arg_0.0), kind=kLoop,
+ calls=%fused_computation
+ }
+)";
+
+using PrepareReferenceModuleTest = HloTestBase;
+
+// Ideally the 'Despecializer' pass would be mocked. Because that is not
+// feasible with the current design, the despecialization tests in this file
+// rely on Despecializer's actual behavior (it removes the fusion op from the
+// module).
+TEST_F(PrepareReferenceModuleTest, PerformDespecialization) {
+ TF_ASSERT_OK_AND_ASSIGN(auto test_module,
+ ParseAndReturnVerifiedModule(kModuleStr));
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto reference_module,
+ PrepareReferenceModule(*test_module, nullptr, {}, {},
+ /*skip_despecialization=*/false));
+
+ // Fusion op should have been removed.
+ EXPECT_THAT(reference_module->ToString(),
+ Not(::testing::HasSubstr("fusion")));
+}
+
+TEST_F(PrepareReferenceModuleTest, SkipDespecialization) {
+ TF_ASSERT_OK_AND_ASSIGN(auto test_module,
+ ParseAndReturnVerifiedModule(kModuleStr));
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto reference_module,
+ PrepareReferenceModule(*test_module, nullptr, {}, {},
+ /*skip_despecialization=*/true));
+
+ // Fusion op should be there.
+ EXPECT_THAT(reference_module->ToString(), ::testing::HasSubstr("fusion"));
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/tools/run_hlo_module.cc b/third_party/xla/xla/tools/run_hlo_module.cc
index 690b509..22c0c02 100644
--- a/third_party/xla/xla/tools/run_hlo_module.cc
+++ b/third_party/xla/xla/tools/run_hlo_module.cc
@@ -168,6 +168,12 @@
return std::move(result_status).value();
}
+void UseCpuThunkRuntime(HloModule& module) {
+ auto debug_options = module.config().debug_options();
+ debug_options.set_xla_cpu_use_thunk_runtime(true);
+ module.mutable_config().set_debug_options(debug_options);
+}
+
absl::Status RunAndCompareInternal(
std::unique_ptr<HloModule> test_module,
const BufferAssignmentProto* buffer_assignment_proto,
@@ -255,17 +261,27 @@
std::unique_ptr<HloModule> reference_module;
if (reference_runner != nullptr) {
+ // If the reference platform is the same as the test platform, we shouldn't
+ // deoptimize the reference module.
+ bool skip_deoptimization = options.reference_platform == options.platform;
+
// PrepareReferenceModule needs to know the *test* runner, in order to
// properly match the test runner's numerics.
TF_ASSIGN_OR_RETURN(
reference_module,
copy_result_on_failure(
- PrepareReferenceModule(*test_module, test_runner,
- config_modifier_hook,
- reference_module_modifier_hook),
+ PrepareReferenceModule(
+ *test_module, test_runner, config_modifier_hook,
+ reference_module_modifier_hook, skip_deoptimization),
ModuleResult::kCompilationError, reference_run_result));
}
+ // Now that reference_module is ready, we can modify test_module without
+ // affecting the reference run.
+ if (options.force_use_cpu_thunk_runtime_for_test) {
+ UseCpuThunkRuntime(*test_module);
+ }
+
TF_ASSIGN_OR_RETURN(
auto test_result,
copy_result_on_failure(
diff --git a/third_party/xla/xla/tools/run_hlo_module.h b/third_party/xla/xla/tools/run_hlo_module.h
index 3300f1b..66afdc5 100644
--- a/third_party/xla/xla/tools/run_hlo_module.h
+++ b/third_party/xla/xla/tools/run_hlo_module.h
@@ -40,6 +40,7 @@
bool flatten_control_flow{false};
bool run_test_hlo_passes{true};
bool run_reference_hlo_passes{true};
+ bool force_use_cpu_thunk_runtime_for_test{false};
// Using small float range by default, as otherwise all reductions
// miscompare vs. the interpreter with inf/nan.
bool use_large_float_range{false};
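A fragment sketching how the new option is set programmatically; the options type name (RunHloModuleOptions) is assumed from context and not shown in this hunk:

// Assumed type name; the hunk above only shows the struct's fields.
RunHloModuleOptions opts;
opts.force_use_cpu_thunk_runtime_for_test = true;  // affects the test run only
opts.run_reference_hlo_passes = true;
// The runner prepares the reference module first and only afterwards switches
// the test module to the CPU thunk runtime, so the reference run is unaffected.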
diff --git a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc
index fe82122..9b6138f 100644
--- a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc
+++ b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc
@@ -14,10 +14,12 @@
==============================================================================*/
#include <memory>
+#include <optional>
#include <string>
#include <vector>
#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
@@ -32,22 +34,41 @@
namespace xla {
namespace {
+std::vector<std::string> make_args(
+ const std::string& run_hlo_module_bin, const std::string& file_name,
+ const std::vector<std::string>& extra_args = {},
+ std::optional<std::string> input_literals_file = std::nullopt) {
+ std::string hlo_path = file_name[0] == '/'
+ ? file_name
+ : tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
+ "tools", "data", file_name);
+
+ std::vector<std::string> args = {run_hlo_module_bin, hlo_path,
+ "--platform=Host"};
+
+ args.insert(args.end(), extra_args.begin(), extra_args.end());
+
+ if (input_literals_file.has_value()) {
+ std::string input_path =
+ tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "data",
+ input_literals_file.value());
+ args.push_back(absl::StrCat("--input_literals_file=", input_path));
+ }
+
+ return args;
+}
+
class RunHloModuleTest : public ::testing::Test {
protected:
void RunHlo(const std::string& file_name,
- const std::vector<std::string>& extra_args = {}) {
+ const std::vector<std::string>& extra_args = {},
+ std::optional<std::string> input_literals_file = std::nullopt) {
std::string run_hlo_module_bin = tsl::io::JoinPath(
tsl::testing::XlaSrcRoot(), "tools", "run_hlo_module");
- std::string hlo_path = file_name[0] == '/'
- ? file_name
- : tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
- "tools", "data", file_name);
-
tsl::SubProcess proc;
- std::vector<std::string> args = {run_hlo_module_bin, hlo_path,
- "--platform=Host"};
- args.insert(args.end(), extra_args.begin(), extra_args.end());
+ auto args = make_args(run_hlo_module_bin, file_name, extra_args,
+ input_literals_file);
proc.SetProgram(run_hlo_module_bin, args);
proc.SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
proc.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
@@ -131,6 +152,22 @@
testing::Not(testing::HasSubstr("memory allocation bug")));
}
+TEST_F(RunHloModuleTest, ReadInputLiteralsFromFile) {
+ RunHlo("add.hlo",
+ /*extra_args=*/{"--print_literals=true", "--reference_platform="},
+ /*input_literals_file=*/"input_literal_f32_2_2.pbtxt");
+
+ EXPECT_TRUE(exited_normally_);
+ EXPECT_EQ(exit_status_, 0);
+
+ ASSERT_THAT(
+ stdout_output_,
+ testing::HasSubstr("{ 0.1, 0.2 },")); // First two values of the input
+ ASSERT_THAT(
+ stdout_output_,
+ testing::HasSubstr("{ 0.2, 0.4 },")); // First two values of the result
+}
+
TEST_F(RunHloModuleTest, AddSnapshot) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(R"(
diff --git a/third_party/xla/xla/tools/run_hlo_module_main.cc b/third_party/xla/xla/tools/run_hlo_module_main.cc
index 0b3ff5d..92e2e23 100644
--- a/third_party/xla/xla/tools/run_hlo_module_main.cc
+++ b/third_party/xla/xla/tools/run_hlo_module_main.cc
@@ -122,7 +122,13 @@
"other "
"than the reference this is necessary because some HLO passes are "
"legalization passes which must be run prior to code generation."),
-
+ tsl::Flag(
+ "force_use_cpu_thunk_runtime_for_test",
+ &opts.force_use_cpu_thunk_runtime_for_test,
+ "Use thunk runtime for the test platform. If true, thunks runtime "
+ "will be used for the test run regardless of the "
+ "xla_cpu_use_thunk_runtime flag in XLA_FLAGS. This option doesn't "
+ "impact reference run. It is ignored for platforms other than CPU."),
tsl::Flag("random_init_input_literals", &opts.random_init_input_literals,
"Initialize input literals with random numbers."
"Leave them uninitialized otherwise."),
@@ -252,9 +258,9 @@
&input_literals_proto);
}
- for (int i = 1; i <= iteration_count; ++i) {
+ for (int i = 0; i < iteration_count; ++i) {
if (iteration_count != 1) {
- std::cerr << "\n=== Iteration " << i << "\n";
+ std::cerr << "\n=== Iteration " << i + 1 << "\n";
}
xla::RunHloModuleIterationLiterals* iteration_literals_proto = nullptr;
if (!opts.output_literals_file.empty() ||
@@ -276,7 +282,7 @@
opts, iteration_literals_proto,
/*reference_module_modifier_hook=*/{},
[&](xla::HloModuleConfig* config) {
- config->set_seed(different_random_seeds ? i : 42);
+ config->set_seed(different_random_seeds ? i + 1 : 42);
});
if (result.ok()) {
diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc
index fe4289b..16d4d0d 100644
--- a/third_party/xla/xla/tools/xla_compile_lib.cc
+++ b/third_party/xla/xla/tools/xla_compile_lib.cc
@@ -68,7 +68,7 @@
#include "tsl/platform/statusor.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/executable.pb.h"
#include "xla/service/gpu/gpu_symbol_repository.h"
#include "xla/stream_executor/gpu/gpu_init.h"
diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_compile_lib_test.cc
deleted file mode 100644
index 101282c..0000000
--- a/third_party/xla/xla/tools/xla_compile_lib_test.cc
+++ /dev/null
@@ -1,326 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/tools/xla_compile_lib.h"
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-
-#include "google/protobuf/duration.pb.h"
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/gpu_symbol_repository.h"
-#include "xla/service/platform_util.h"
-#include "xla/service/symbol_repository.h"
-#include "xla/service/xla_compile_result.pb.h"
-#include "xla/stream_executor/device_description.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/env_time.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-#include "tsl/protobuf/error_codes.pb.h"
-#include "tsl/protobuf/status.pb.h"
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "xla/service/gpu/autotuner_util.h"
-#endif
-
-namespace xla {
-namespace {
-
-using ::testing::IsEmpty;
-using ::testing::IsNull;
-using ::testing::Not;
-using ::tsl::testing::IsOk;
-using ::tsl::testing::IsOkAndHolds;
-using ::tsl::testing::StatusIs;
-
-#if XLA_TEST_BACKEND_CPU
-static constexpr absl::string_view kPlatformName = "Host";
-#elif XLA_TEST_BACKEND_GPU
-static constexpr absl::string_view kPlatformName =
-#if TENSORFLOW_USE_ROCM
- "ROCM";
-#else
- "CUDA";
-#endif
-#endif // XLA_TEST_BACKEND_CPU
-
-class XlaCompileLibTest : public HloTestBase {
- protected:
- XlaCompileLibTest()
- : HloTestBase(*PlatformUtil::GetPlatform(std::string(kPlatformName)),
- GetReferencePlatform()) {}
- void SetUp() override {
- const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
- "tools", "data", "add.hlo");
- std::string hlo;
- TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo));
- TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
- }
-
- std::unique_ptr<HloModule> module_;
-};
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(CompilesForCpu)) {
- CompilationResult result;
- EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kCpu,
- std::nullopt, result),
- IsOkAndHolds(Not(IsEmpty())));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) {
- CompilationResult result;
- EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
- std::nullopt, result),
- IsOkAndHolds(Not(IsEmpty())));
- EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) {
- const std::string target_config_path =
- tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service",
- "xla_aot_compile_test_gpu_target_config.prototxt");
- stream_executor::GpuTargetConfigProto target_config;
- TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path,
- &target_config));
- CompilationResult result;
- EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
- std::nullopt, result),
- IsOkAndHolds(Not(IsEmpty())));
- EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(ErrorsOnUnexpectedPlatform)) {
- XlaCompileOptions options;
- options.platform = "tpu";
- EXPECT_THAT(XlaCompileMain(options), StatusIs(tsl::error::UNIMPLEMENTED));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFilePropagatesErrors)) {
- TimerStats stats;
- CompilationResult result;
- EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk()));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFileWritesTheFile)) {
- std::string result_output_file;
- ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file));
-
- TimerStats stats;
- {
- absl::MutexLock ml(&stats.stats_mutex);
- stats.cumulative_secs = 5.5;
- stats.max_secs = 5.5;
- }
-
- CompilationResult result;
- google::protobuf::Duration duration;
- duration.set_seconds(5);
- duration.set_nanos(0.5 * tsl::EnvTime::kSecondsToNanos);
- *result.mutable_perf_stats()->mutable_compilation_duration() = duration;
- *result.mutable_perf_stats()->mutable_total_duration() = duration;
-
- TF_ASSERT_OK(WriteResultFile(result_output_file, stats, result));
-
- CompilationResult got_result;
- TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_output_file,
- &got_result));
- // Sadly EqualsProto isn't OSS, so we inspect a few fields manually.
- // See googletest#1761 and b/229726259.
- EXPECT_EQ(5, got_result.perf_stats().compilation_duration().seconds());
- EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
- got_result.perf_stats().compilation_duration().nanos());
- EXPECT_EQ(5, got_result.perf_stats().total_duration().seconds());
- EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
- got_result.perf_stats().total_duration().nanos());
-}
-
-TEST_F(XlaCompileLibTest, LoadModuleErrors) {
- EXPECT_THAT(LoadModule("/does/not/exist"), Not(IsOk()));
-}
-
-TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) {
- const std::string module_file =
- tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
- TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
- module_->ToString()));
-
- EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull())));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) {
- const std::string module_file =
- tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
- TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
- module_->ToString()));
-
- const std::string output_path =
- tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_output");
- const std::string result_file =
- tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_result.pb");
-
- XlaCompileOptions options;
- options.module_path = module_file;
- options.output_path = output_path;
- options.platform = "cpu";
- options.result_output_file = result_file;
- TF_EXPECT_OK(XlaCompileMain(options));
-
- CompilationResult result;
- TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
- EXPECT_TRUE(result.has_status());
- EXPECT_EQ(result.status().code(), tensorflow::error::OK);
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) {
- const std::string module_file =
- tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
- TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
- module_->ToString()));
-
- const std::string output_path =
- tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_output");
- const std::string result_file =
- tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_result.pb");
-
- XlaCompileOptions options;
- options.module_path = module_file;
- options.output_path = output_path;
- options.platform = "gpu";
- options.result_output_file = result_file;
- options.gpu_options.use_attached_device = true;
- TF_EXPECT_OK(XlaCompileMain(options));
-
- CompilationResult result;
- TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
- EXPECT_TRUE(result.has_status());
- EXPECT_EQ(result.status().code(), tensorflow::error::OK);
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(LoadAutotuneDataCpu)) {
- HloModuleAndMetadata mod;
- mod.hlo_module = std::move(module_);
-
- EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu),
- IsOkAndHolds(false));
-}
-
-TEST_F(XlaCompileLibTest,
- DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningEnabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
- gpu::AutotunerUtil::ClearAutotuneResults();
-
- HloModuleAndMetadata mod;
- mod.hlo_module = std::move(module_);
- auto data = std::make_unique<gpu::GpuBackendSpecificData>();
-
- AutotuneResults autotune_results;
- TF_ASSERT_OK(tsl::ReadTextProto(
- tsl::Env::Default(),
- tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
- "gpu_compiler_test_autotune_db.textproto"),
- &autotune_results));
- data->autotune_results = autotune_results;
- mod.backend_specific_data = std::move(data);
-
- DebugOptions opts = mod.hlo_module->config().debug_options();
- opts.set_xla_gpu_autotune_level(3);
- mod.hlo_module->mutable_config().set_debug_options(opts);
-
- EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
- IsOkAndHolds(true));
- EXPECT_FALSE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-TEST_F(XlaCompileLibTest,
- DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningDisabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
- gpu::AutotunerUtil::ClearAutotuneResults();
-
- HloModuleAndMetadata mod;
- mod.hlo_module = std::move(module_);
- auto data = std::make_unique<gpu::GpuBackendSpecificData>();
-
- AutotuneResults autotune_results;
- TF_ASSERT_OK(tsl::ReadTextProto(
- tsl::Env::Default(),
- tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
- "gpu_compiler_test_autotune_db.textproto"),
- &autotune_results));
- data->autotune_results = autotune_results;
- mod.backend_specific_data = std::move(data);
-
- DebugOptions opts = mod.hlo_module->config().debug_options();
- opts.set_xla_gpu_autotune_level(0);
- mod.hlo_module->mutable_config().set_debug_options(opts);
-
- EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
- IsOkAndHolds(false));
- EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-TEST_F(XlaCompileLibTest,
- DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
- gpu::AutotunerUtil::ClearAutotuneResults();
-
- HloModuleAndMetadata mod;
- mod.hlo_module = std::move(module_);
-
- DebugOptions opts = mod.hlo_module->config().debug_options();
- opts.set_xla_gpu_autotune_level(3);
- mod.hlo_module->mutable_config().set_debug_options(opts);
-
- EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
- IsOkAndHolds(false));
- EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-TEST_F(
- XlaCompileLibTest,
- DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
- gpu::AutotunerUtil::ClearAutotuneResults();
-
- HloModuleAndMetadata mod;
- mod.hlo_module = std::move(module_);
-
- DebugOptions opts = mod.hlo_module->config().debug_options();
- opts.set_xla_gpu_autotune_level(0);
- mod.hlo_module->mutable_config().set_debug_options(opts);
-
- EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
- IsOkAndHolds(false));
- EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-} // namespace
-} // namespace xla
diff --git a/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc
new file mode 100644
index 0000000..62c0673
--- /dev/null
+++ b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/duration.pb.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/symbol_repository.h"
+#include "xla/service/xla_compile_result.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tools/xla_compile_lib.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/env_time.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+#include "tsl/protobuf/error_codes.pb.h"
+#include "tsl/protobuf/status.pb.h"
+
+namespace xla {
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::IsNull;
+using ::testing::Not;
+using ::tsl::testing::IsOk;
+using ::tsl::testing::IsOkAndHolds;
+using ::tsl::testing::StatusIs;
+
+class XlaCompileLibTest : public HloTestBase {
+ protected:
+ XlaCompileLibTest()
+ : HloTestBase(*PlatformUtil::GetPlatform("Host"),
+ GetReferencePlatform()) {}
+ void SetUp() override {
+ const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
+ "tools", "data", "add.hlo");
+ std::string hlo;
+ TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo));
+ TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
+ }
+
+ std::unique_ptr<HloModule> module_;
+};
+
+TEST_F(XlaCompileLibTest, CompilesForCpu) {
+ CompilationResult result;
+ EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kCpu,
+ std::nullopt, result),
+ IsOkAndHolds(Not(IsEmpty())));
+}
+
+TEST_F(XlaCompileLibTest, ErrorsOnUnexpectedPlatform) {
+ XlaCompileOptions options;
+ options.platform = "tpu";
+ EXPECT_THAT(XlaCompileMain(options), StatusIs(tsl::error::UNIMPLEMENTED));
+}
+
+TEST_F(XlaCompileLibTest, WriteResultFilePropagatesErrors) {
+ TimerStats stats;
+ CompilationResult result;
+ EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk()));
+}
+
+TEST_F(XlaCompileLibTest, WriteResultFileWritesTheFile) {
+ std::string result_output_file;
+ ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file));
+
+ TimerStats stats;
+ {
+ absl::MutexLock ml(&stats.stats_mutex);
+ stats.cumulative_secs = 5.5;
+ stats.max_secs = 5.5;
+ }
+
+ CompilationResult result;
+ google::protobuf::Duration duration;
+ duration.set_seconds(5);
+ duration.set_nanos(0.5 * tsl::EnvTime::kSecondsToNanos);
+ *result.mutable_perf_stats()->mutable_compilation_duration() = duration;
+ *result.mutable_perf_stats()->mutable_total_duration() = duration;
+
+ TF_ASSERT_OK(WriteResultFile(result_output_file, stats, result));
+
+ CompilationResult got_result;
+ TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_output_file,
+ &got_result));
+ // Sadly EqualsProto isn't OSS, so we inspect a few fields manually.
+ // See googletest#1761 and b/229726259.
+ EXPECT_EQ(5, got_result.perf_stats().compilation_duration().seconds());
+ EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
+ got_result.perf_stats().compilation_duration().nanos());
+ EXPECT_EQ(5, got_result.perf_stats().total_duration().seconds());
+ EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
+ got_result.perf_stats().total_duration().nanos());
+}
+
+TEST_F(XlaCompileLibTest, LoadModuleErrors) {
+ EXPECT_THAT(LoadModule("/does/not/exist"), Not(IsOk()));
+}
+
+TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) {
+ const std::string module_file =
+ tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
+ TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
+ module_->ToString()));
+
+ EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull())));
+}
+
+TEST_F(XlaCompileLibTest, MainForCpu) {
+ const std::string module_file =
+ tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
+ TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
+ module_->ToString()));
+
+ const std::string output_path =
+ tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_output");
+ const std::string result_file =
+ tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_result.pb");
+
+ XlaCompileOptions options;
+ options.module_path = module_file;
+ options.output_path = output_path;
+ options.platform = "cpu";
+ options.result_output_file = result_file;
+ TF_EXPECT_OK(XlaCompileMain(options));
+
+ CompilationResult result;
+ TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
+ EXPECT_TRUE(result.has_status());
+ EXPECT_EQ(result.status().code(), tensorflow::error::OK);
+}
+
+TEST_F(XlaCompileLibTest, LoadAutotuneDataCpu) {
+ HloModuleAndMetadata mod;
+ mod.hlo_module = std::move(module_);
+
+ EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu),
+ IsOkAndHolds(false));
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc
new file mode 100644
index 0000000..bc34c87
--- /dev/null
+++ b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc
@@ -0,0 +1,195 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/gpu_symbol_repository.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/symbol_repository.h"
+#include "xla/service/xla_compile_result.pb.h"
+#include "xla/stream_executor/device_description.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tools/xla_compile_lib.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+#include "tsl/protobuf/error_codes.pb.h"
+#include "tsl/protobuf/status.pb.h"
+
+namespace xla {
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::tsl::testing::IsOkAndHolds;
+
+class XlaCompileLibTest : public HloTestBase {
+ protected:
+ XlaCompileLibTest()
+ : HloTestBase(*PlatformUtil::GetPlatform(std::string("GPU")),
+ GetReferencePlatform()) {}
+ void SetUp() override {
+ const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
+ "tools", "data", "add.hlo");
+ std::string hlo;
+ TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo));
+ TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
+ }
+
+ std::unique_ptr<HloModule> module_;
+};
+
+TEST_F(XlaCompileLibTest, CompilesForGpuWithDevice) {
+ CompilationResult result;
+ EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
+ std::nullopt, result),
+ IsOkAndHolds(Not(IsEmpty())));
+ EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
+}
+
+TEST_F(XlaCompileLibTest, CompilesForGpuWithoutDevice) {
+ const std::string target_config_path =
+ tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service",
+ "xla_aot_compile_test_gpu_target_config.prototxt");
+ stream_executor::GpuTargetConfigProto target_config;
+ TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path,
+ &target_config));
+ CompilationResult result;
+ EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
+ std::nullopt, result),
+ IsOkAndHolds(Not(IsEmpty())));
+ EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
+}
+
+TEST_F(XlaCompileLibTest, MainForGpu) {
+ const std::string module_file =
+ tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
+ TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
+ module_->ToString()));
+
+ const std::string output_path =
+ tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_output");
+ const std::string result_file =
+ tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_result.pb");
+
+ XlaCompileOptions options;
+ options.module_path = module_file;
+ options.output_path = output_path;
+ options.platform = "gpu";
+ options.result_output_file = result_file;
+ options.gpu_options.use_attached_device = true;
+ TF_EXPECT_OK(XlaCompileMain(options));
+
+ CompilationResult result;
+ TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
+ EXPECT_TRUE(result.has_status());
+ EXPECT_EQ(result.status().code(), tensorflow::error::OK);
+}
+
+TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningEnabled) {
+ gpu::AutotunerUtil::ClearAutotuneResults();
+
+ HloModuleAndMetadata mod;
+ mod.hlo_module = std::move(module_);
+ auto data = std::make_unique<gpu::GpuBackendSpecificData>();
+
+ AutotuneResults autotune_results;
+ TF_ASSERT_OK(tsl::ReadTextProto(
+ tsl::Env::Default(),
+ tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
+ "gpu_compiler_test_autotune_db.textproto"),
+ &autotune_results));
+ data->autotune_results = autotune_results;
+ mod.backend_specific_data = std::move(data);
+
+ DebugOptions opts = mod.hlo_module->config().debug_options();
+ opts.set_xla_gpu_autotune_level(3);
+ mod.hlo_module->mutable_config().set_debug_options(opts);
+
+ EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+ IsOkAndHolds(true));
+ EXPECT_FALSE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningDisabled) {
+ gpu::AutotunerUtil::ClearAutotuneResults();
+
+ HloModuleAndMetadata mod;
+ mod.hlo_module = std::move(module_);
+ auto data = std::make_unique<gpu::GpuBackendSpecificData>();
+
+ AutotuneResults autotune_results;
+ TF_ASSERT_OK(tsl::ReadTextProto(
+ tsl::Env::Default(),
+ tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
+ "gpu_compiler_test_autotune_db.textproto"),
+ &autotune_results));
+ data->autotune_results = autotune_results;
+ mod.backend_specific_data = std::move(data);
+
+ DebugOptions opts = mod.hlo_module->config().debug_options();
+ opts.set_xla_gpu_autotune_level(0);
+ mod.hlo_module->mutable_config().set_debug_options(opts);
+
+ EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+ IsOkAndHolds(false));
+ EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(XlaCompileLibTest,
+ LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled) {
+ gpu::AutotunerUtil::ClearAutotuneResults();
+
+ HloModuleAndMetadata mod;
+ mod.hlo_module = std::move(module_);
+
+ DebugOptions opts = mod.hlo_module->config().debug_options();
+ opts.set_xla_gpu_autotune_level(3);
+ mod.hlo_module->mutable_config().set_debug_options(opts);
+
+ EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+ IsOkAndHolds(false));
+ EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(XlaCompileLibTest,
+ LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled) {
+ gpu::AutotunerUtil::ClearAutotuneResults();
+
+ HloModuleAndMetadata mod;
+ mod.hlo_module = std::move(module_);
+
+ DebugOptions opts = mod.hlo_module->config().debug_options();
+ opts.set_xla_gpu_autotune_level(0);
+ mod.hlo_module->mutable_config().set_debug_options(opts);
+
+ EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+ IsOkAndHolds(false));
+ EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD
index 2d1ae78..179ae72 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD
@@ -21,12 +21,34 @@
"//xla:xla_data_proto_cc",
"//xla/mlir_hlo",
"//xla/service:hlo_proto_cc",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
cc_library(
+ name = "async_importer",
+ srcs = ["async_importer.cc"],
+ hdrs = ["async_importer.h"],
+ deps = [
+ ":attribute_importer",
+ ":hlo_utils",
+ "//xla:util",
+ "//xla/hlo/ir:hlo",
+ "//xla/mlir_hlo",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+cc_library(
name = "custom_call_importer",
srcs = ["custom_call_importer.cc"],
hdrs = ["custom_call_importer.h"],
@@ -67,6 +89,7 @@
"hlo_module_importer.h",
],
deps = [
+ ":async_importer",
":attribute_importer",
":custom_call_importer",
":hlo_utils",
@@ -84,7 +107,6 @@
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo",
"//xla/service:hlo_proto_cc",
- "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
@@ -113,7 +135,11 @@
":hlo_module_importer",
"//xla:status_macros",
"//xla/mlir/utils:error_util",
+ "//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@llvm-project//mlir:IR",
+ "@local_tsl//tsl/platform:errors",
],
)
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc
new file mode 100644
index 0000000..57bc78a
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc
@@ -0,0 +1,383 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/translate/hlo_to_mhlo/async_importer.h"
+
+#include <cassert>
+#include <functional>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/translate/hlo_to_mhlo/attribute_importer.h"
+#include "xla/translate/hlo_to_mhlo/hlo_utils.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+
+namespace {
+
+constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
+constexpr char kShardingAttr[] = "mhlo.sharding";
+
+// ============
+// Imports an old-style async start op. E.g. an HLO all-gather-start
+// instruction is imported as an async-start associated with an all-gather
+// computation.
+//
+// Eventually, old-style async ops (e.g. all-gather-start) and new-style async
+// ops (i.e. async-start, async-update and async-done) will converge on the
+// HLO side, so we decided to not introduce new MHLO ops for all-gather-start
+// and friends.
+//
+// In the end, there may be new ops added in the old style because they're not
+// compatible with the new-style async semantics, but those should be handled
+// on their own rather than by this function, which "upgrades" ops to the
+// new-style async API.
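+//
+// For example (illustrative): an HLO `all-gather-start` is imported as an
+// `mhlo.async_start` whose `called_computation` is a private function wrapping
+// a synchronous `mhlo.all_gather`, and the matching `all-gather-done` is
+// imported as an `mhlo.async_done`.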
+// ============
+template <typename sync_op>
+absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncStart(
+ mlir::SymbolTable& symbol_table,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
+ mlir::Type result_type, mlir::OpBuilder* builder, std::string func_name,
+ std::function<absl::Status(sync_op)> mutate_op) {
+ auto context = builder->getContext();
+ if (!llvm::isa<mlir::TupleType>(result_type)) {
+ return tsl::errors::InvalidArgument(
+ "expected async_bundle tuple result type");
+ }
+ auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+ if (result_types.size() < 2) {
+ return tsl::errors::InvalidArgument(
+ "async_bundle must contain at least two values");
+ }
+ auto func_type = mlir::FunctionType::get(context, Untuple(result_types[0]),
+ Untuple(result_types[1]));
+ auto function = mlir::func::FuncOp::create(loc, func_name, func_type);
+
+  // The new function doesn't need to be inserted at the beginning, but doing
+  // so makes testing easier and preserves the original behavior.
+ mlir::Block& block = symbol_table.getOp()->getRegion(0).front();
+ symbol_table.insert(function, mlir::Block::iterator(block.begin()));
+
+ function.setPrivate();
+ auto async_builder = mlir::OpBuilder(function.getBody());
+
+ llvm::SmallVector<mlir::NamedAttribute> async_attributes;
+ async_attributes.push_back(builder->getNamedAttr(
+ "called_computation",
+ mlir::FlatSymbolRefAttr::get(builder->getContext(), function.getName())));
+ async_attributes.push_back(builder->getNamedAttr(
+ "execution_thread", builder->getStringAttr("main")));
+
+ // Attach the frontend_attributes and sharding attributes to the async op
+ // instead of the sync op. First, semantically sharding attributes cannot be
+ // attached to the sync op since the sync op may not produce the same number
+ // of results as the sharding's tuple element count, e.g., `mhlo.send` vs. HLO
+ // `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from the
+ // `mhlo.async_start` ops, so attaching them to the sync op will make them
+ // disappear during MHLO to HLO lowering.
+ for (auto it = attributes.begin(); it != attributes.end();) {
+ if (it->getName() == kShardingAttr ||
+ it->getName() == kFrontendAttributesAttr) {
+ async_attributes.push_back(*it);
+ it = attributes.erase(it);
+ } else {
+ ++it;
+ }
+ }
+
+ llvm::SmallVector<mlir::Location, 1> locs(Untuple(result_types[0]).size(),
+ loc);
+ auto sync_operand =
+ async_builder
+ .createBlock(&function.getBody(), {}, Untuple(result_types[0]), locs)
+ ->getArguments();
+ auto sync_operation = async_builder.create<sync_op>(
+ loc, Untuple(result_types[1]), sync_operand, attributes);
+ async_builder.create<mlir::func::ReturnOp>(loc, sync_operation->getResults());
+ TF_RETURN_IF_ERROR(mutate_op(sync_operation));
+
+ function->setAttr("execution_thread", builder->getStringAttr("main"));
+
+ auto bundle_result_type =
+ mlir::mhlo::AsyncBundleType::get(context, result_types);
+ return builder
+ ->create<mlir::mhlo::AsyncStartOp>(loc, bundle_result_type, operands,
+ async_attributes)
+ .getOperation();
+}
+
+absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncDone(
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ bool useBundleResult = false) {
+ assert(operands.size() == 1 &&
+ "*-done ops must take only a single async_bundle operand");
+ auto async_start = operands[0].getDefiningOp<mlir::mhlo::AsyncStartOp>();
+  if (!async_start) return InvalidArgument("*-done requires *-start as input");
+ attributes.push_back(builder->getNamedAttr(
+ "called_computation",
+ mlir::FlatSymbolRefAttr::get(builder->getContext(),
+ async_start.getCalledComputation())));
+ attributes.push_back(builder->getNamedAttr("execution_thread",
+ builder->getStringAttr("main")));
+
+ auto async_bundle = llvm::cast<mlir::mhlo::AsyncBundleType>(
+ async_start.getResult().getType());
+
+ auto start_tuple =
+ llvm::dyn_cast<mlir::TupleType>(async_bundle.getTypes()[1]);
+ if (start_tuple && llvm::isa<mlir::TupleType>(start_tuple.getType(0))) {
+ auto op = builder->create<mlir::mhlo::AsyncDoneOp>(loc, result_type,
+ operands, attributes);
+ return {op};
+ } else {
+ if (useBundleResult) result_type = async_bundle.getTypes()[1];
+ auto op = builder->create<mlir::mhlo::AsyncDoneOp>(
+ loc, Untuple(result_type), operands, attributes);
+ return CreateTupleFromOpResults(builder, loc, op.getOperation(),
+ result_type);
+ }
+}
+
+} // namespace
+
+// Op Converters
+
+absl::StatusOr<mlir::Operation*> ImportSend(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table) {
+ auto send_op = Cast<HloSendInstruction>(instruction);
+ attributes.push_back(builder->getNamedAttr(
+ "is_host_transfer", builder->getBoolAttr(send_op->is_host_transfer())));
+ if (send_op->channel_id().has_value()) {
+ ChannelHandle channel_handle;
+ channel_handle.set_handle(send_op->channel_id().value());
+ channel_handle.set_type(send_op->is_host_transfer()
+ ? ChannelHandle::DEVICE_TO_HOST
+ : ChannelHandle::DEVICE_TO_DEVICE);
+ attributes.push_back(ConvertChannelHandle(channel_handle, builder));
+ }
+
+ // Return async_start/done for pipelined send.
+ //
+  // Old-style `send` returns a bundle of (arg, sync flag, token) to be passed
+  // along to send-done. However, the new-style async ops have a shared bundle
+  // format of (args, results, scratchpad), so to rewrite the `send` and
+  // `send-done` ops to use the new-style async API, we need to reorder the
+  // arguments to be in (args, token, sync flag) order.
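+  //
+  // For example (illustrative shapes): an HLO `send` of f32[4] returns
+  // (f32[4], u32[], token[]), which is imported here with the async bundle
+  // type tuple<tuple<f32[4], token>, token, u32[]>.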
+ auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+ if (result_types.size() != 3)
+ return InvalidArgument("send should return a 3-tuple");
+ auto async_arg_type = mlir::TupleType::get(
+ builder->getContext(), {result_types[0], result_types[2]});
+ auto async_bundled_tuple =
+ mlir::TupleType::get(builder->getContext(),
+ {async_arg_type, result_types[2], result_types[1]});
+ return ImportOldStyleAsyncStart<mlir::mhlo::SendOp>(
+ symbol_table, attributes, operands, loc, async_bundled_tuple, builder,
+ "send_", [](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportRecv(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table) {
+ auto recv_op = Cast<HloRecvInstruction>(instruction);
+ attributes.push_back(builder->getNamedAttr(
+ "is_host_transfer", builder->getBoolAttr(recv_op->is_host_transfer())));
+ if (recv_op->channel_id().has_value()) {
+ ChannelHandle channel_handle;
+ channel_handle.set_handle(recv_op->channel_id().value());
+ channel_handle.set_type(recv_op->is_host_transfer()
+ ? ChannelHandle::HOST_TO_DEVICE
+ : ChannelHandle::DEVICE_TO_DEVICE);
+ attributes.push_back(ConvertChannelHandle(channel_handle, builder));
+ }
+
+  // Old-style `recv` returns a bundle of (result, sync flag, token) to be
+  // passed along to recv-done. However, the new-style async ops have a shared
+  // bundle format of (args, results, scratchpad), so to rewrite the `recv`
+  // and `recv-done` ops to use the new-style async API, we need to reorder
+  // the arguments to be in (token, (result, token), sync flag) order, or
+  // (token, token, sync flag) if no result is received.
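+  //
+  // For example (illustrative shapes): an HLO `recv` producing
+  // (f32[4], u32[], token[]) is imported here with the async bundle type
+  // tuple<token, tuple<f32[4], token>, u32[]>.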
+ auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+ if (result_types.size() != 3)
+ return InvalidArgument("recv should return a 3-tuple");
+
+  // TODO: Allow recv of no values, only token.
+ auto async_result_type = mlir::TupleType::get(
+ builder->getContext(), {result_types[0], result_types[2]});
+ auto async_bundled_tuple = mlir::TupleType::get(
+ builder->getContext(),
+ {result_types[2], async_result_type, result_types[1]});
+ return ImportOldStyleAsyncStart<mlir::mhlo::RecvOp>(
+ symbol_table, attributes, operands, loc, async_bundled_tuple, builder,
+ "recv_", [](auto) { return absl::OkStatus(); });
+}
+
+// Async Collectives
+
+absl::StatusOr<mlir::Operation*> ImportAllGatherStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table) {
+ auto all_gather_start = Cast<HloAllGatherInstruction>(instruction);
+ attributes.push_back(builder->getNamedAttr(
+ "all_gather_dim",
+ builder->getI64IntegerAttr(all_gather_start->all_gather_dimension())));
+ attributes.push_back(
+ ConvertReplicaGroups(all_gather_start->replica_groups(), builder));
+ if (all_gather_start->channel_id().has_value())
+ attributes.push_back(
+ ConvertChannelHandle(all_gather_start->channel_id().value(), builder));
+ if (all_gather_start->use_global_device_ids())
+ attributes.push_back(ConvertUseGlobalDeviceIds(builder));
+ if (all_gather_start->operands().size() > 1)
+ return InvalidArgument("Async tuple all-gather is not supported in MHLO");
+
+ if (!llvm::isa<mlir::TupleType>(result_type)) {
+ // Async AllGather's output type is bundle<input_type,output_type>
+    // There are some instances where the output type is not a tuple; this
+    // seems to be the more modern case, so we wrap these in a tuple for MHLO.
+ result_type = mlir::TupleType::get(builder->getContext(),
+ {operands[0].getType(), result_type});
+ }
+
+ return ImportOldStyleAsyncStart<mlir::mhlo::AllGatherOp>(
+ symbol_table, attributes, operands, loc, result_type, builder,
+ "all_gather_", [](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportAllReduceStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ std::function<absl::Status(mlir::mhlo::AllReduceOp)> mutate_op,
+ mlir::SymbolTable& symbol_table) {
+ auto all_reduce_start = Cast<HloAllReduceInstruction>(instruction);
+ attributes.push_back(
+ ConvertReplicaGroups(all_reduce_start->replica_groups(), builder));
+ if (all_reduce_start->channel_id().has_value())
+ attributes.push_back(
+ ConvertChannelHandle(all_reduce_start->channel_id().value(), builder));
+ if (all_reduce_start->use_global_device_ids())
+ attributes.push_back(ConvertUseGlobalDeviceIds(builder));
+ if (all_reduce_start->operands().size() > 1)
+ return InvalidArgument("Async tuple all-reduce is not supported in MHLO");
+
+ if (!llvm::isa<mlir::TupleType>(result_type)) {
+ // Async AllReduce's output type is bundle<input_type,output_type>
+    // There are some instances where the output type is not a tuple; this
+    // seems to be the more modern case, so we wrap these in a tuple for MHLO.
+ result_type = mlir::TupleType::get(builder->getContext(),
+ {operands[0].getType(), result_type});
+ }
+
+ return ImportOldStyleAsyncStart<mlir::mhlo::AllReduceOp>(
+ symbol_table, attributes, operands, loc, result_type, builder,
+ "all_reduce_", mutate_op);
+}
+
+// Collective Permute
+
+absl::StatusOr<mlir::Operation*> ImportCollectivePermuteStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table) {
+ attributes.push_back(
+ ConvertSourceTargetPairs(instruction->source_target_pairs(), builder));
+ if (!llvm::isa<mlir::TupleType>(result_type)) {
+ // Async CollectivePermute's output type is bundle<input_type,output_type>
+    // There are some instances where the output type is not a tuple; this
+    // seems to be the more modern case, so we wrap these in a tuple for MHLO.
+ result_type = mlir::TupleType::get(builder->getContext(),
+ {operands[0].getType(), result_type});
+ }
+ return ImportOldStyleAsyncStart<mlir::mhlo::CollectivePermuteOp>(
+ symbol_table, attributes, operands, loc, result_type, builder,
+ "collective_permute_", [&](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportCopyStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table) {
+ auto context = builder->getContext();
+ auto copy_start_instruction = Cast<HloCopyStartInstruction>(instruction);
+ if (auto cross_program_prefetch_index =
+ copy_start_instruction->cross_program_prefetch_index()) {
+ attributes.push_back(builder->getNamedAttr(
+ "cross_program_prefetch_index",
+ builder->getIntegerAttr(builder->getIntegerType(32),
+ *cross_program_prefetch_index)));
+ // Cross-program prefetch allows copy ops to accept tuples, in which
+ // case, we need to double-wrap inputs and outputs in tuples.
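+    // For example (illustrative), a 3-tuple result (a, b, u32[]) becomes
+    // ((a), (b), u32[]) when the operand is itself a tuple.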
+ if (operands[0].getType().isa<mlir::TupleType>()) {
+ auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+ result_type = mlir::TupleType::get(
+ context,
+ {mlir::TupleType::get(context, {result_types[0]}),
+ mlir::TupleType::get(context, {result_types[1]}), result_types[2]});
+ }
+ }
+ return ImportOldStyleAsyncStart<mlir::mhlo::CopyOp>(
+ symbol_table, attributes, operands, loc, result_type, builder, "copy_",
+ [](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportAsyncOpDone(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder) {
+ return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
+ builder);
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h
new file mode 100644
index 0000000..efdd487
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h
@@ -0,0 +1,88 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_
+#define XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_
+
+#include <functional>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace xla {
+
+// Op Converters
+absl::StatusOr<mlir::Operation*> ImportSend(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportRecv(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table);
+
+// Async Collectives
+absl::StatusOr<mlir::Operation*> ImportAllGatherStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportAllReduceStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ std::function<absl::Status(mlir::mhlo::AllReduceOp)> mutate_op,
+ mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportCollectivePermuteStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportCopyStart(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder,
+ mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportAsyncOpDone(
+ const HloInstruction* instruction, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value>& operands,
+ llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+ mlir::Type result_type, mlir::OpBuilder* builder);
+
+} // namespace xla
+
+#endif // XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc
index 9a2fa06..fbe1a90 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc
@@ -17,13 +17,24 @@
#include <sys/types.h>
+#include <algorithm>
+#include <cstdint>
#include <optional>
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
@@ -195,6 +206,62 @@
}
}
+mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel,
+ mlir::Builder* builder) {
+ return builder->getNamedAttr(
+ "channel_handle",
+ mlir::mhlo::ChannelHandleAttr::get(builder->getContext(),
+ channel.handle(), channel.type()));
+}
+mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id,
+ mlir::Builder* builder) {
+ ChannelHandle channel_handle;
+ if (channel_id) channel_handle.set_handle(*channel_id);
+ return ConvertChannelHandle(channel_handle, builder);
+}
+
+mlir::NamedAttribute ConvertReplicaGroups(
+ absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
+ const int64_t num_groups = replica_groups.size();
+ // Replica groups in HLO can be non-uniform in size, for example:
+ // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
+ // tensor, pad the smaller sized replica groups with -1.
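+  // For example, replica_groups={{0},{1,2},{3}} becomes the 3x2 tensor
+  // [[0,-1],[1,2],[3,-1]].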
+ const int64_t group_size = absl::c_accumulate(
+ replica_groups, static_cast<int64_t>(0),
+ [](int64_t current, const ReplicaGroup& g) {
+ return std::max<int64_t>(current, g.replica_ids_size());
+ });
+ // Initialize all elements to -1 to support non-uniform replica groups.
+ std::vector<int64_t> attr(num_groups * group_size, -1);
+ for (int i = 0; i < num_groups; ++i) {
+ int index = i * group_size;
+ for (const int64_t& id : replica_groups[i].replica_ids())
+ attr[index++] = id;
+ }
+ auto type = mlir::RankedTensorType::get({num_groups, group_size},
+ builder->getIntegerType(64));
+ return builder->getNamedAttr("replica_groups",
+ mlir::DenseIntElementsAttr::get(type, attr));
+}
+
+mlir::NamedAttribute ConvertSourceTargetPairs(
+ const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+ mlir::Builder* builder) {
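+  // For example, source-target pairs {(0,1),(1,2)} become the 2x2 i64 tensor
+  // [[0,1],[1,2]].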
+ std::vector<int64_t> attr(source_target_pairs.size() * 2);
+ for (const auto& p : llvm::enumerate(source_target_pairs)) {
+ attr[2 * p.index()] = p.value().first;
+ attr[2 * p.index() + 1] = p.value().second;
+ }
+ auto type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
+ return builder->getNamedAttr("source_target_pairs",
+ mlir::DenseIntElementsAttr::get(type, attr));
+}
+
+mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder) {
+ return builder->getNamedAttr("use_global_device_ids", builder->getUnitAttr());
+}
+
absl::StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder) {
std::vector<mlir::Attribute> layouts;
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h
index 4f1ba9e..f836814 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h
@@ -16,10 +16,14 @@
#ifndef XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_
#define XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_
+#include <cstdint>
+#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/hlo.pb.h"
@@ -66,6 +70,20 @@
absl::StatusOr<mlir::mhlo::CustomCallApiVersion> ConvertCustomCallApiVersion(
xla::CustomCallApiVersion api_version);
+mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel,
+ mlir::Builder* builder);
+mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id,
+ mlir::Builder* builder);
+
+mlir::NamedAttribute ConvertReplicaGroups(
+ absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder);
+
+mlir::NamedAttribute ConvertSourceTargetPairs(
+ const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+ mlir::Builder* builder);
+
+mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder);
+
// Extracts layouts from shapes and converts it into layout attributes (array of
// rank-1 index tensors). Returns an error if any of the shapes is a tuple.
absl::StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 64e08df..6490839 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -15,7 +15,6 @@
#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h"
-#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -26,7 +25,6 @@
#include <utility>
#include <vector>
-#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -37,6 +35,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/AsmParser/AsmParser.h"
@@ -71,6 +70,7 @@
#include "xla/service/hlo.pb.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
+#include "xla/translate/hlo_to_mhlo/async_importer.h"
#include "xla/translate/hlo_to_mhlo/attribute_importer.h"
#include "xla/translate/hlo_to_mhlo/custom_call_importer.h"
#include "xla/translate/hlo_to_mhlo/hlo_utils.h"
@@ -90,6 +90,8 @@
using mlir::Value;
using mlir::func::FuncOp;
+#define DEBUG_TYPE "xla-translate"
+
namespace xla {
namespace {
@@ -100,7 +102,7 @@
// Note: This sanitization function causes an irreversible many-to-one mapping
// and any solution to mitigate this would cause issues with the reverse
-// direction. Longterm solution is to add a function attribute to maintain the
+// direction. Long-term solution is to add a function attribute to maintain the
// original HLO naming.
std::string SanitizeFunctionName(llvm::StringRef name) {
std::string output(name);
@@ -170,115 +172,6 @@
} // namespace
-mlir::TypeRange Untuple(const mlir::Type& type) {
- if (type.isa<mlir::TupleType>()) {
- return llvm::dyn_cast<mlir::TupleType>(type).getTypes();
- }
- return type;
-}
-
-template <typename sync_op>
-absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportOldStyleAsyncStart(
- llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
- const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
- mlir::Type result_type, mlir::OpBuilder* func_builder,
- std::string func_name, std::function<absl::Status(sync_op)> mutate_op) {
- auto result_types = result_type.cast<mlir::TupleType>().getTypes();
- if (result_types.size() < 2) {
- return tsl::errors::InvalidArgument(
- "async_bundle must contain at least two values");
- }
- auto func_type = mlir::FunctionType::get(context_, Untuple(result_types[0]),
- Untuple(result_types[1]));
- auto function = FuncOp::create(loc, func_name, func_type);
-
- // The new function doesn't need to be inserted in the beginning but is done
- // to make testing easier and preserve the original behavior.
- mlir::Block& block = symbol_table_.getOp()->getRegion(0).front();
- symbol_table_.insert(function, mlir::Block::iterator(block.begin()));
-
- function.setPrivate();
- auto async_builder = mlir::OpBuilder(function.getBody());
-
- llvm::SmallVector<mlir::NamedAttribute> async_attributes;
- async_attributes.push_back(builder_->getNamedAttr(
- "called_computation", mlir::FlatSymbolRefAttr::get(builder_->getContext(),
- function.getName())));
- async_attributes.push_back(builder_->getNamedAttr(
- "execution_thread", builder_->getStringAttr("main")));
-
- // Attach the frontend_attributes and sharding attributes to the async op
- // instead of the sync op. First, semantically sharding attributes cannot be
- // attached to the sync op since the sync op may not produce the same number
- // of results as the sharding's tuple element count, e.g., `mhlo.send` vs. HLO
- // `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from the
- // `mhlo.async_start` ops, so attaching them to the sync op will make them
- // disappear during MHLO to HLO lowering.
- for (auto it = attributes.begin(); it != attributes.end();) {
- if (it->getName() == kShardingAttr ||
- it->getName() == kFrontendAttributesAttr) {
- async_attributes.push_back(*it);
- it = attributes.erase(it);
- } else {
- ++it;
- }
- }
-
- llvm::SmallVector<mlir::Location, 1> locs(Untuple(result_types[0]).size(),
- loc);
- auto sync_operand =
- async_builder
- .createBlock(&function.getBody(), {}, Untuple(result_types[0]), locs)
- ->getArguments();
- auto sync_operation = async_builder.create<sync_op>(
- loc, Untuple(result_types[1]), sync_operand, attributes);
- async_builder.create<mlir::func::ReturnOp>(loc, sync_operation->getResults());
- TF_RETURN_IF_ERROR(mutate_op(sync_operation));
-
- function->setAttr("execution_thread", builder_->getStringAttr("main"));
-
- auto bundle_result_type =
- mlir::mhlo::AsyncBundleType::get(context_, result_types);
- return func_builder
- ->create<mlir::mhlo::AsyncStartOp>(loc, bundle_result_type, operands,
- async_attributes)
- .getOperation();
-}
-
-absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportOldStyleAsyncDone(
- llvm::SmallVectorImpl<NamedAttribute>& attributes,
- const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
- mlir::Type result_type, mlir::OpBuilder* func_builder) {
- if (operands.size() != 1) {
- return InvalidArgument(
- "async-done must take only a single async_bundle operand");
- }
- auto async_start = operands[0].getDefiningOp<mlir::mhlo::AsyncStartOp>();
- if (!async_start) return InvalidArgument("*-start requires *-done as input");
- attributes.push_back(builder_->getNamedAttr(
- "called_computation",
- mlir::FlatSymbolRefAttr::get(builder_->getContext(),
- async_start.getCalledComputation())));
- attributes.push_back(builder_->getNamedAttr("execution_thread",
- builder_->getStringAttr("main")));
-
- auto start_tuple = async_start.getResult()
- .getType()
- .cast<mlir::mhlo::AsyncBundleType>()
- .getTypes()[1]
- .dyn_cast<mlir::TupleType>();
- if (start_tuple && start_tuple.getType(0).isa<mlir::TupleType>()) {
- auto op = func_builder->create<mlir::mhlo::AsyncDoneOp>(
- loc, result_type, operands, attributes);
- return {op};
- } else {
- auto op = func_builder->create<mlir::mhlo::AsyncDoneOp>(
- loc, Untuple(result_type), operands, attributes);
- return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
- result_type);
- }
-}
-
void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands) {
assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
@@ -296,20 +189,6 @@
}
}
-mlir::Operation* HloFunctionImporter::CreateTupleFromOpResults(
- mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op,
- mlir::Type type) {
- if (!type.isa<mlir::TupleType>()) return op;
-
- mlir::ValueRange flattened_results_ref(op->getResults());
- auto result =
- CreateTupleValue(func_builder, loc, flattened_results_ref, type);
- auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
- assert(defining_tuple_op && "builder didn't return the right type");
- auto tupleOp = defining_tuple_op.getOperation();
- return tupleOp;
-}
-
static bool IsNestedTupleInData(Type type) {
auto tuple_type = type.dyn_cast<mlir::TupleType>();
if (!tuple_type) return false;
@@ -381,27 +260,6 @@
return flattened_values;
}
-Value HloFunctionImporter::CreateTupleValue(mlir::OpBuilder* func_builder,
- mlir::Location loc,
- mlir::ValueRange& flatten_values,
- Type type) {
- auto tuple_type = type.dyn_cast<mlir::TupleType>();
- if (!tuple_type) {
- assert(!flatten_values.empty());
- auto retval = flatten_values.front();
- flatten_values = flatten_values.drop_front();
- return retval;
- }
-
- llvm::SmallVector<mlir::Value> flatten_sub_values;
- for (auto child_type : tuple_type.getTypes())
- flatten_sub_values.push_back(
- CreateTupleValue(func_builder, loc, flatten_values, child_type));
-
- return func_builder->create<mlir::mhlo::TupleOp>(loc, flatten_sub_values)
- .getResult();
-}
-
absl::StatusOr<mlir::func::FuncOp> HloFunctionImporter::ImportAsFunc(
const HloComputation& computation, mlir::SymbolTable& symbol_table,
std::unordered_map<const HloComputation*, FuncOp>* function_map,
@@ -761,7 +619,10 @@
frontend_attributes.push_back(
builder_->getNamedAttr(k, builder_->getStringAttr(v)));
}
+
+ int frontend_attributes_index = 0;
if (!frontend_attributes.empty()) {
+ frontend_attributes_index = attributes.size();
attributes.push_back(builder_->getNamedAttr(
kFrontendAttributesAttr,
builder_->getDictionaryAttr(frontend_attributes)));
@@ -926,6 +787,68 @@
FuncOp function,
ImportAsFunc(*instruction->to_apply(), /*is_main=*/false));
mlir::Operation* new_operation;
+ if (instruction->is_composite()) {
+ // TODO: b/354721812 - Support flatten_computation_args_result_ flag
+ // for composite calls
+
+ mlir::DictionaryAttr frontend_attributes_attr =
+ builder_->getDictionaryAttr(frontend_attributes);
+ if (frontend_attributes.empty() ||
+ !frontend_attributes_attr.contains("composite.attributes") ||
+ !frontend_attributes_attr.contains("composite.name") ||
+ !frontend_attributes_attr.contains("composite.version")) {
+ return InvalidArgument(
+ "A composite call op must have frontend attributes with the "
+ "following keys: composite.attributes, composite.name, "
+ "composite.version");
+ }
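+      // Illustrative example of the expected frontend attributes (values are
+      // hypothetical): composite.name="my.op", composite.version="1", and
+      // composite.attributes="{scale = 2 : i32}", a serialized MLIR attribute
+      // dictionary that is parsed below.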
+
+ llvm::SmallVector<NamedAttribute, 4> fe_attrs_without_composite_attrs;
+ for (const auto& attr : frontend_attributes) {
+ if (attr.getName() != "composite.attributes" &&
+ attr.getName() != "composite.name" &&
+ attr.getName() != "composite.version") {
+ fe_attrs_without_composite_attrs.push_back(attr);
+ }
+ }
+
+      // The frontend attributes may exist only because of the
+      // composite-related attributes. If the frontend attributes are empty
+      // after removing the composite-related attributes, they are no longer
+      // needed, so we remove them entirely. Otherwise, we update them.
+ if (fe_attrs_without_composite_attrs.empty()) {
+ attributes.erase(attributes.begin() + frontend_attributes_index);
+ } else {
+ attributes[frontend_attributes_index] = builder_->getNamedAttr(
+ kFrontendAttributesAttr,
+ builder_->getDictionaryAttr(fe_attrs_without_composite_attrs));
+ }
+
+ auto frontend_attributes_map = instruction->frontend_attributes().map();
+ mlir::StringAttr name = builder_->getStringAttr(
+ frontend_attributes_map.find("composite.name")->second);
+ mlir::Attribute composite_attributes = mlir::parseAttribute(
+ frontend_attributes_map.find("composite.attributes")->second,
+ builder_->getContext());
+ mlir::FlatSymbolRefAttr decomposition = mlir::SymbolRefAttr::get(
+ builder_->getContext(), instruction->to_apply()->name());
+ mlir::IntegerAttr version = builder_->getIntegerAttr(
+ builder_->getI32Type(),
+ std::stoi(
+ frontend_attributes_map.find("composite.version")->second));
+
+ new_operation = func_builder->create<mlir::mhlo::CompositeOp>(
+ loc, result_type, operands);
+ new_operation->setAttr("name", name);
+ new_operation->setAttr("composite_attributes", composite_attributes);
+ new_operation->setAttr("decomposition", decomposition);
+ new_operation->setAttr("version", version);
+ for (const auto& attr : attributes) {
+ new_operation->setAttr(attr.getName(), attr.getValue());
+ }
+ return new_operation;
+ }
+
if (flatten_computation_args_result_) {
// Flatten the tuple-typed operands.
llvm::SmallVector<Value> flattened_operands = FlattenTupleValues(
@@ -946,7 +869,7 @@
} else {
new_operation =
func_builder->create<mlir::func::CallOp>(loc, function, operands);
- for (auto attr : attributes) {
+ for (const auto& attr : attributes) {
new_operation->setAttr(attr.getName(), attr.getValue());
}
}
@@ -957,8 +880,8 @@
attributes.push_back(ConvertReplicaGroups(
collective_broadcast->replica_groups(), builder_));
if (collective_broadcast->channel_id().has_value())
- attributes.push_back(
- ConvertChannelHandle(collective_broadcast->channel_id().value()));
+ attributes.push_back(ConvertChannelHandle(
+ collective_broadcast->channel_id().value(), builder_));
return func_builder
->create<mlir::mhlo::CollectiveBroadcastOp>(loc, result_type,
operands, attributes)
@@ -970,23 +893,21 @@
attributes.push_back(ConvertSourceTargetPairs(
collective_permute->source_target_pairs(), builder_));
if (collective_permute->channel_id().has_value())
- attributes.push_back(
- ConvertChannelHandle(collective_permute->channel_id().value()));
+ attributes.push_back(ConvertChannelHandle(
+ collective_permute->channel_id().value(), builder_));
return func_builder
->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kCollectivePermuteStart: {
- attributes.push_back(ConvertSourceTargetPairs(
- instruction->source_target_pairs(), builder_));
- return ImportOldStyleAsyncStart<mlir::mhlo::CollectivePermuteOp>(
- attributes, operands, loc, result_type, func_builder,
- "collective_permute_", [&](auto) { return absl::OkStatus(); });
+ return ImportCollectivePermuteStart(instruction, loc, operands,
+ attributes, result_type, func_builder,
+ symbol_table_);
}
case HloOpcode::kCollectivePermuteDone: {
- return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
- func_builder);
+ return ImportAsyncOpDone(instruction, loc, operands, attributes,
+ result_type, func_builder);
}
case HloOpcode::kCustomCall: {
auto custom_call = Cast<HloCustomCallInstruction>(instruction);
@@ -1310,103 +1231,31 @@
.getOperation();
}
case HloOpcode::kCopyStart: {
- auto copy_start_instruction = Cast<HloCopyStartInstruction>(instruction);
- if (auto cross_program_prefetch_index =
- copy_start_instruction->cross_program_prefetch_index()) {
- attributes.push_back(builder_->getNamedAttr(
- "cross_program_prefetch_index",
- builder_->getIntegerAttr(builder_->getIntegerType(32),
- *cross_program_prefetch_index)));
- // Cross-program prefetch allows copy ops to accept tuples, in which
- // case, we need to double-wrap inputs and outputs in tuples.
- if (operands[0].getType().isa<mlir::TupleType>()) {
- auto result_types = result_type.cast<mlir::TupleType>().getTypes();
- result_type = mlir::TupleType::get(
- context_, {mlir::TupleType::get(context_, {result_types[0]}),
- mlir::TupleType::get(context_, {result_types[1]}),
- result_types[2]});
- }
- }
- return ImportOldStyleAsyncStart<mlir::mhlo::CopyOp>(
- attributes, operands, loc, result_type, func_builder, "copy_",
- [](auto) { return absl::OkStatus(); });
+ return ImportCopyStart(instruction, loc, operands, attributes,
+ result_type, func_builder, symbol_table_);
}
case HloOpcode::kCopyDone: {
- return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
- func_builder);
+ return ImportAsyncOpDone(instruction, loc, operands, attributes,
+ result_type, func_builder);
}
case HloOpcode::kSend: {
- // old-style send returns a bundle of (arg, sync flag, token) to be passed
- // along to send-done.
- // However, the new-style async ops have a shared bundle
- // format of (args, results, scratchpad), so to rewrite the `send` and
- // `send-done` ops to use the new-style async API, we need to reorder the
- // arguments to be in (args, token, sync flag) order.
- auto result_types = result_type.cast<mlir::TupleType>().getTypes();
- if (result_types.size() != 3)
- return InvalidArgument("send should return a 3-tuple");
- auto async_arg_type =
- mlir::TupleType::get(context_, {result_types[0], result_types[2]});
- auto async_bundled_tuple = mlir::TupleType::get(
- context_, {async_arg_type, result_types[2], result_types[1]});
- auto send_op = Cast<HloSendInstruction>(instruction);
- attributes.push_back(builder_->getNamedAttr(
- "is_host_transfer",
- builder_->getBoolAttr(send_op->is_host_transfer())));
- if (send_op->channel_id().has_value()) {
- ChannelHandle channel_handle;
- channel_handle.set_handle(send_op->channel_id().value());
- channel_handle.set_type(send_op->is_host_transfer()
- ? ChannelHandle::DEVICE_TO_HOST
- : ChannelHandle::DEVICE_TO_DEVICE);
- attributes.push_back(ConvertChannelHandle(channel_handle));
- }
- return ImportOldStyleAsyncStart<mlir::mhlo::SendOp>(
- attributes, operands, loc, async_bundled_tuple, func_builder, "send_",
- [](auto) { return absl::OkStatus(); });
+ return ImportSend(instruction, loc, operands, attributes, result_type,
+ func_builder, symbol_table_);
}
case HloOpcode::kSendDone: {
- return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
- func_builder);
+ return ImportAsyncOpDone(instruction, loc, operands, attributes,
+ result_type, func_builder);
}
case HloOpcode::kRecv: {
- // Old-style `recv` returns a bundle of (result, sync flag, token) to be
- // passed along to recv-done.
- // However, the new-style async ops have a shared
- // bundle format of (args, results, scratchpad), so to rewrite the `recv`
- // and `recv-done` ops to use the new-style async API, we need to reorder
- // the arguments to be in (token, (result, token), sync flag) order.
- auto result_types = result_type.cast<mlir::TupleType>().getTypes();
- if (result_types.size() != 3)
- return InvalidArgument("recv should return a 3-tuple");
- auto async_result_type =
- mlir::TupleType::get(context_, {result_types[0], result_types[2]});
- auto async_bundled_tuple = mlir::TupleType::get(
- context_, {result_types[2], async_result_type, result_types[1]});
- auto recv_op = Cast<HloRecvInstruction>(instruction);
- attributes.push_back(builder_->getNamedAttr(
- "is_host_transfer",
- builder_->getBoolAttr(recv_op->is_host_transfer())));
- if (recv_op->channel_id().has_value()) {
- ChannelHandle channel_handle;
- channel_handle.set_handle(recv_op->channel_id().value());
- channel_handle.set_type(recv_op->is_host_transfer()
- ? ChannelHandle::HOST_TO_DEVICE
- : ChannelHandle::DEVICE_TO_DEVICE);
- attributes.push_back(ConvertChannelHandle(channel_handle));
- }
- return ImportOldStyleAsyncStart<mlir::mhlo::RecvOp>(
- attributes, operands, loc, async_bundled_tuple, func_builder, "recv_",
- [](auto) { return absl::OkStatus(); });
+ return ImportRecv(instruction, loc, operands, attributes, result_type,
+ func_builder, symbol_table_);
}
case HloOpcode::kRecvDone: {
- return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
- func_builder);
+ return ImportAsyncOpDone(instruction, loc, operands, attributes,
+ result_type, func_builder);
}
case HloOpcode::kConditional: {
llvm::SmallVector<Type, 4> rets;
-
- // Flatten the tuple-typed operands.
llvm::SmallVector<Value> flattened_operands =
FlattenTupleValues(func_builder, loc, operands);
@@ -1498,9 +1347,9 @@
ConvertReplicaGroups(all_gather->replica_groups(), builder_));
if (all_gather->channel_id().has_value())
attributes.push_back(
- ConvertChannelHandle(all_gather->channel_id().value()));
+ ConvertChannelHandle(all_gather->channel_id().value(), builder_));
if (all_gather->use_global_device_ids())
- attributes.push_back(ConvertUseGlobalDeviceIds());
+ attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
auto all_gather_op = func_builder->create<mlir::mhlo::AllGatherOp>(
loc, result_types, operands, attributes);
if (result_tuple_ty) {
@@ -1512,28 +1361,12 @@
return all_gather_op.getOperation();
}
case HloOpcode::kAllGatherStart: {
- auto all_gather_start = Cast<HloAllGatherInstruction>(instruction);
- attributes.push_back(builder_->getNamedAttr(
- "all_gather_dim", builder_->getI64IntegerAttr(
- all_gather_start->all_gather_dimension())));
- attributes.push_back(
- ConvertReplicaGroups(all_gather_start->replica_groups(), builder_));
- if (all_gather_start->channel_id().has_value())
- attributes.push_back(
- ConvertChannelHandle(all_gather_start->channel_id().value()));
- if (all_gather_start->use_global_device_ids())
- attributes.push_back(ConvertUseGlobalDeviceIds());
- if (all_gather_start->operands().size() > 1)
- return InvalidArgument(
- "Async tuple all-gather is not supported in MHLO");
-
- return ImportOldStyleAsyncStart<mlir::mhlo::AllGatherOp>(
- attributes, operands, loc, result_type, func_builder, "all_gather_",
- [](auto) { return absl::OkStatus(); });
+ return ImportAllGatherStart(instruction, loc, operands, attributes,
+ result_type, func_builder, symbol_table_);
}
case HloOpcode::kAllGatherDone: {
- return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
- func_builder);
+ return ImportAsyncOpDone(instruction, loc, operands, attributes,
+ result_type, func_builder);
}
case HloOpcode::kAllReduce: {
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
@@ -1548,9 +1381,9 @@
ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
if (all_reduce->channel_id().has_value())
attributes.push_back(
- ConvertChannelHandle(all_reduce->channel_id().value()));
+ ConvertChannelHandle(all_reduce->channel_id().value(), builder_));
if (all_reduce->use_global_device_ids())
- attributes.push_back(ConvertUseGlobalDeviceIds());
+ attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_types, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
@@ -1564,29 +1397,19 @@
return all_reduce_op.getOperation();
}
case HloOpcode::kAllReduceStart: {
- auto all_reduce_start = Cast<HloAllReduceInstruction>(instruction);
- attributes.push_back(
- ConvertReplicaGroups(all_reduce_start->replica_groups(), builder_));
- if (all_reduce_start->channel_id().has_value())
- attributes.push_back(
- ConvertChannelHandle(all_reduce_start->channel_id().value()));
- if (all_reduce_start->use_global_device_ids())
- attributes.push_back(ConvertUseGlobalDeviceIds());
- if (all_reduce_start->operands().size() > 1)
- return InvalidArgument(
- "Async tuple all-reduce is not supported in MHLO");
+ auto appendRegion = [&](mlir::mhlo::AllReduceOp all_reduce_sync) {
+ TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
+ &all_reduce_sync.getComputation()));
+ return absl::OkStatus();
+ };
- return ImportOldStyleAsyncStart<mlir::mhlo::AllReduceOp>(
- attributes, operands, loc, result_type, func_builder, "all_reduce_",
- [&](auto all_reduce_sync) {
- TF_RETURN_IF_ERROR(ImportAsRegion(
- *instruction->to_apply(), &all_reduce_sync.getComputation()));
- return absl::OkStatus();
- });
+ return ImportAllReduceStart(instruction, loc, operands, attributes,
+ result_type, func_builder, appendRegion,
+ symbol_table_);
}
case HloOpcode::kAllReduceDone: {
- return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
- func_builder);
+ return ImportAsyncOpDone(instruction, loc, operands, attributes,
+ result_type, func_builder);
}
case HloOpcode::kAllToAll: {
auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
@@ -1623,7 +1446,8 @@
replica_groups_attr);
if (all_to_all->channel_id().has_value()) {
- auto handle = ConvertChannelHandle(all_to_all->channel_id().value());
+ auto handle =
+ ConvertChannelHandle(all_to_all->channel_id().value(), builder_);
result.setChannelHandleAttr(
handle.getValue().cast<mlir::mhlo::ChannelHandleAttr>());
}
@@ -1829,10 +1653,10 @@
attributes.push_back(
ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
if (reduce_scatter->channel_id().has_value())
- attributes.push_back(
- ConvertChannelHandle(reduce_scatter->channel_id().value()));
+ attributes.push_back(ConvertChannelHandle(
+ reduce_scatter->channel_id().value(), builder_));
if (reduce_scatter->use_global_device_ids())
- attributes.push_back(ConvertUseGlobalDeviceIds());
+ attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
auto reduce_scatter_op =
func_builder->create<mlir::mhlo::ReduceScatterOp>(
loc, result_type, operands, attributes);
@@ -2192,10 +2016,22 @@
const HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
+ LLVM_DEBUG(llvm::dbgs() << "Importing instruction: "
+ << HloOpcodeString(instruction->opcode()) << '\n');
+ LLVM_DEBUG({
+ llvm::dbgs() << " operands: (";
+ llvm::interleaveComma(operands, llvm::dbgs(),
+ [](Value v) { llvm::dbgs() << v.getType(); });
+ llvm::dbgs() << ")\n";
+ });
TF_ASSIGN_OR_RETURN(
mlir::Operation * op,
ImportInstructionImpl(instruction, operands, func_builder, mode));
- if (op == nullptr) return op;
+ if (op == nullptr) {
+ LLVM_DEBUG(llvm::dbgs() << " instruction skipped.\n");
+ return op;
+ }
+ LLVM_DEBUG(llvm::dbgs() << " imported: " << *op << '\n');
// See MlirToHloConversionOptions for more about layouts.
//
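Aside (not part of the patch): the LLVM_DEBUG tracing added above only fires in builds with assertions enabled (no NDEBUG) and is gated on the file's DEBUG_TYPE tag. A minimal sketch of that standard LLVM pattern, using a hypothetical tag name since the actual one is not shown in this diff:

#include "llvm/Support/Debug.h"

// Hypothetical per-file tag; hlo_function_importer.cc's real value is not
// visible in this change.
#define DEBUG_TYPE "hlo-function-importer"

void TraceExample() {
  // Printed only when the tool is run with --debug or
  // --debug-only=<DEBUG_TYPE> in an assertions-enabled build.
  LLVM_DEBUG(llvm::dbgs() << "importing...\n");
}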
@@ -2322,62 +2158,6 @@
return builder_->getNamedAttr("padding", attr);
}
-mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
- const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
- mlir::Builder* builder) {
- std::vector<int64_t> attr(source_target_pairs.size() * 2);
- for (const auto& p : llvm::enumerate(source_target_pairs)) {
- attr[2 * p.index()] = p.value().first;
- attr[2 * p.index() + 1] = p.value().second;
- }
- auto type = mlir::RankedTensorType::get(
- {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
- return builder->getNamedAttr("source_target_pairs",
- DenseIntElementsAttr::get(type, attr));
-}
-
-mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
- absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
- const int64_t num_groups = replica_groups.size();
- // Replica groups in HLO can be non-uniform in size, for example:
- // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
- // tensor, pad the smaller sized replica groups with -1.
- const int64_t group_size = absl::c_accumulate(
- replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
- return std::max<int64_t>(current, g.replica_ids_size());
- });
- // Initialize all elements to -1 to support non-uniform replica groups.
- std::vector<int64_t> attr(num_groups * group_size, -1);
- for (int i = 0; i < num_groups; ++i) {
- int index = i * group_size;
- for (const int64_t& id : replica_groups[i].replica_ids())
- attr[index++] = id;
- }
- auto type = mlir::RankedTensorType::get({num_groups, group_size},
- builder->getIntegerType(64));
- return builder->getNamedAttr("replica_groups",
- DenseIntElementsAttr::get(type, attr));
-}
-
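For reference, the padding described in the removed comment maps replica_groups={{0},{1,2},{3}} to a 3x2 i64 tensor [[0,-1],[1,2],[3,-1]]. A self-contained sketch of just that padding step, using plain C++ containers and independent of the importer:

#include <algorithm>
#include <cstdint>
#include <vector>

// Flattens possibly non-uniform replica groups into a dense
// num_groups x group_size row-major buffer, padding short groups with -1,
// mirroring the helper removed above.
std::vector<int64_t> PadReplicaGroups(
    const std::vector<std::vector<int64_t>>& groups, int64_t* group_size) {
  *group_size = 0;
  for (const auto& g : groups)
    *group_size = std::max<int64_t>(*group_size, g.size());
  std::vector<int64_t> flat(groups.size() * *group_size, -1);
  for (size_t i = 0; i < groups.size(); ++i)
    for (size_t j = 0; j < groups[i].size(); ++j)
      flat[i * *group_size + j] = groups[i][j];
  return flat;  // {{0},{1,2},{3}} -> [0,-1, 1,2, 3,-1], group_size == 2
}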
-mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
- std::optional<int64_t> channel_id) {
- ChannelHandle channel_handle;
- if (channel_id) channel_handle.set_handle(*channel_id);
- return ConvertChannelHandle(channel_handle);
-}
-
-mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
- const ChannelHandle& channel) {
- return builder_->getNamedAttr(
- "channel_handle", mlir::mhlo::ChannelHandleAttr::get(
- context_, channel.handle(), channel.type()));
-}
-
-mlir::NamedAttribute HloFunctionImporter::ConvertUseGlobalDeviceIds() {
- return builder_->getNamedAttr("use_global_device_ids",
- builder_->getUnitAttr());
-}
-
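The updated call sites (e.g. ConvertChannelHandle(..., builder_) and ConvertUseGlobalDeviceIds(builder_)) suggest these helpers now live outside the importer and take the builder explicitly; their new declarations are not part of this diff. A minimal sketch of the builder-parameterized form, mirroring the member version removed above (the function name and header location are assumptions):

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"

// Hypothetical free-function equivalent of the removed member helper; the only
// change visible from the call sites is the explicit mlir::Builder* parameter.
mlir::NamedAttribute MakeUseGlobalDeviceIdsAttr(mlir::Builder* builder) {
  return builder->getNamedAttr("use_global_device_ids",
                               builder->getUnitAttr());
}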
void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
const Shape& shape,
llvm::StringRef attr_name) {
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h
index 5c5a4e3..fa22a6d 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h
@@ -16,21 +16,28 @@
#ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_
#define XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_
-#include <string>
+#include <cstdint>
#include <unordered_map>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
@@ -90,30 +97,12 @@
static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape,
llvm::StringRef attr_name);
- // TODO(b/179166199): move this to attribute_importer.h.
- // Converts XLA instruction source target pairs to MLIR attribute.
- static mlir::NamedAttribute ConvertSourceTargetPairs(
- const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
- mlir::Builder* builder);
-
- // TODO(b/179166199): move this to attribute_importer.h.
- // Converts replica groups to attribute
- static mlir::NamedAttribute ConvertReplicaGroups(
- absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder);
-
// For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block
// arguments with 'implicit_operands'. Here | implicit_operands | == sum of
// the number of arguments in all the regions in IfOp or CaseOp.
void ReplaceBlockArgumentsWithImplicitOperands(
mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands);
- // Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType.
- // Otherwise, return 'op'.
- mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
- mlir::Location loc,
- mlir::Operation* op,
- mlir::Type type);
-
// FlattenTupleType flattens the types in (nested) tuple-type 'type' and
// stores them in 'flattened_types'.
static void FlattenTupleType(
@@ -131,23 +120,6 @@
mlir::OpBuilder* func_builder, mlir::Location loc,
mlir::ValueRange values, std::optional<int> reserve_size = std::nullopt);
- // CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using
- // the non-tuple-typed values in 'flatten_values'.
- //
- // e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple<T1,tuple<T1,T2>>,
- // The function returns %t2 such that:
- // %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple<T2,T3>
- // %t2 = mhlo.tuple(V1,%t1): (T1,tuple<T2,T3>) -> tuple<T1,tuple<T1,T2>>
- //
- // Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to
- // resp. flatten and create tuples in the exact same order.
- // 2. `flatten_values`, initially storing the flattened values, will be
- // mutated to a 0-length array by the end of function invocation.
- static mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder,
- mlir::Location loc,
- mlir::ValueRange& flatten_values,
- mlir::Type type);
-
private:
HloFunctionImporter(mlir::SymbolTable& symbol_table,
std::unordered_map<const HloComputation*,
@@ -221,6 +193,7 @@
// Returns the Mlir Value for the corresponding HloInstruction.
absl::StatusOr<mlir::Value> GetMlirValue(const HloInstruction* instruction);
+ // TODO(b/179166199): Move attribute converters to attribute_importer.
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertComparisonDirection(
ComparisonDirection direction);
@@ -245,43 +218,6 @@
// padding low and padding high for each of the spatial dimensions.
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
- // Converts channel id to attribute
- mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id);
-
- // Convert use global device ids flag to attribute
- mlir::NamedAttribute ConvertUseGlobalDeviceIds();
-
- // Converts channel handle to attribute
- mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel);
-
- // ============
- // Imports an old-style async start op. E.g. an HLO all-gather-start
- // instruction is imported as an async-start associated with an all-gather
- // computation.
- //
- // Eventually, old-style async ops (e.g. all-gather-start) and new-style async
- // ops (i.e. async-start, async-update and async-done) will converge on the
- // HLO side, so we decided to not introduce new MHLO ops for all-gather-start
- // and friends.
- //
- // In the end, there may be new ops added in the old-style because they're not
- // compatible with the new-style async semantics, but those should be handled
- // on their own, rather than this function which "upgrades" ops to the
- // new-style async API.
- // ============
- template <typename SyncOp>
- absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncStart(
- llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
- const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
- mlir::Type result_type, mlir::OpBuilder* func_builder,
- std::string func_name, std::function<absl::Status(SyncOp)> mutate_op);
-
- // Imports an old-style async done op
- absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncDone(
- llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
- const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
- mlir::Type result_type, mlir::OpBuilder* func_builder);
-
mlir::MLIRContext* context_;
// SymbolTable to which new functions should be inserted.
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc
index e8d81dc..d6dafe0 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc
@@ -15,12 +15,29 @@
#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
+#include "absl/status/statusor.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
#include "xla/mlir/utils/error_util.h"
+#include "xla/service/llvm_ir/llvm_util.h"
#include "xla/status_macros.h"
#include "xla/translate/hlo_to_mhlo/hlo_module_importer.h"
+#include "tsl/platform/errors.h"
namespace xla {
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+ mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module,
+ bool import_all_computations, bool flatten_computation_args_result) {
+ mlir::OwningOpRef<mlir::ModuleOp> module =
+ llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx));
+ TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module,
+ import_all_computations,
+ flatten_computation_args_result));
+ return module;
+}
+
absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
xla::HloModuleProto const* hlo_module_proto,
bool import_all_computation,
@@ -32,7 +49,7 @@
}
absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
- xla::HloModule* hlo_module,
+ const xla::HloModule* hlo_module,
bool import_all_computation,
bool flatten_computation_args_result) {
mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext());
@@ -41,4 +58,15 @@
.Import(*hlo_module);
}
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+ mlir::MLIRContext& ctx, const xla::HloModule* hlo_module,
+ bool import_all_computations, bool flatten_computation_args_result) {
+ mlir::OwningOpRef<mlir::ModuleOp> module =
+ llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx));
+ TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module,
+ import_all_computations,
+ flatten_computation_args_result));
+ return module;
+}
+
} // namespace xla
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h
index 161823a1..775d636 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h
@@ -19,6 +19,10 @@
#include <stdbool.h>
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
namespace mlir {
class ModuleOp;
@@ -35,6 +39,11 @@
//
// If `flatten_computation_args_result` is set to true, flattens all tuple
// arguments and result of every computation when importing them as func ops.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+ mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module,
+ bool import_all_computations = false,
+ bool flatten_computation_args_result = false);
+
absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
xla::HloModuleProto const* hlo_module,
bool import_all_computations = false,
@@ -47,8 +56,13 @@
//
// If `flatten_computation_args_result` is set to true, flattens all tuple
// arguments and result of every computation when importing them as func ops.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+ mlir::MLIRContext& ctx, const xla::HloModule* hlo_module,
+ bool import_all_computations = false,
+ bool flatten_computation_args_result = false);
+
absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
- xla::HloModule* hlo_module,
+ const xla::HloModule* hlo_module,
bool import_all_computations = false,
bool flatten_computation_args_result = false);
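A hypothetical caller of the new owning-module overloads declared above; the wrapper name and argument choices are illustrative only, not part of the patch:

#include "absl/status/statusor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"

// Converts an HloModule into a freshly created, owned ModuleOp instead of
// requiring the caller to pre-create one.
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportOwned(
    mlir::MLIRContext& ctx, const xla::HloModule& hlo_module) {
  return xla::ConvertHloToMlirHlo(ctx, &hlo_module,
                                  /*import_all_computations=*/true,
                                  /*flatten_computation_args_result=*/false);
}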
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc
index 468c29a..e6004cf 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc
@@ -17,20 +17,34 @@
#include "xla/translate/hlo_to_mhlo/hlo_utils.h"
+#include <cassert>
#include <cstddef>
-#include <type_traits>
+#include <cstdint>
#include <vector>
+#include "absl/status/statusor.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/mlir/utils/type_util.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/primitive_util.h"
+#include "xla/shape.h"
#include "xla/types.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
namespace xla {
namespace {
@@ -139,4 +153,46 @@
vector);
}
+mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc,
+ mlir::ValueRange& flatten_values,
+ mlir::Type type) {
+ auto tuple_type = type.dyn_cast<mlir::TupleType>();
+ if (!tuple_type) {
+ assert(!flatten_values.empty());
+ auto retval = flatten_values.front();
+ flatten_values = flatten_values.drop_front();
+ return retval;
+ }
+
+ llvm::SmallVector<mlir::Value> flatten_sub_values;
+ for (auto child_type : tuple_type.getTypes())
+ flatten_sub_values.push_back(
+ CreateTupleValue(func_builder, loc, flatten_values, child_type));
+
+ return func_builder->create<mlir::mhlo::TupleOp>(loc, flatten_sub_values)
+ .getResult();
+}
+
+mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
+ mlir::Location loc,
+ mlir::Operation* op,
+ mlir::Type type) {
+ if (!type.isa<mlir::TupleType>()) return op;
+
+ mlir::ValueRange flattened_results_ref(op->getResults());
+ auto result =
+ CreateTupleValue(func_builder, loc, flattened_results_ref, type);
+ auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
+ assert(defining_tuple_op && "builder didn't return the right type");
+  return defining_tuple_op.getOperation();
+}
+
+mlir::TypeRange Untuple(const mlir::Type& type) {
+  if (auto tuple_type = llvm::dyn_cast<mlir::TupleType>(type))
+    return tuple_type.getTypes();
+  return type;
+}
+
} // namespace xla
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h
index 81fe60e..dd7f68a 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h
@@ -32,6 +32,10 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
@@ -169,6 +173,29 @@
return ConvertTensorShapeToType<TypeT>(shape, builder);
}
+// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using
+// the non-tuple-typed values in 'flatten_values'.
+//
+// e.g., Given 'flatten_values': [V1, V2, V3] and 'type': tuple<T1,tuple<T2,T3>>,
+// the function returns %t2 such that:
+//  %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple<T2,T3>
+//  %t2 = mhlo.tuple(V1,%t1): (T1,tuple<T2,T3>) -> tuple<T1,tuple<T2,T3>>
+//
+// Note: 1. FlattenTupleValue and CreateTupleValue are a pair of functions that
+//       respectively flatten and create tuples in the exact same order.
+//       2. `flatten_values`, which initially stores the flattened values, is
+//       consumed and left empty (0-length) by the end of the call.
+mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc,
+ mlir::ValueRange& flatten_values, mlir::Type type);
+
+// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType.
+// Otherwise, return 'op'.
+mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
+ mlir::Location loc,
+ mlir::Operation* op, mlir::Type type);
+
+mlir::TypeRange Untuple(const mlir::Type& type);
+
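A small illustration of the contract documented above: CreateTupleValue consumes the flattened values in order and rebuilds the nested tuple. Sketch only; it assumes an OpBuilder `b` with the mhlo dialect loaded, a Location `loc`, and non-tuple values v1, v2, v3 of types T1..T3:

llvm::SmallVector<mlir::Value> flat = {v1, v2, v3};
mlir::ValueRange remaining(flat);
mlir::Type t23 = mlir::TupleType::get(b.getContext(),
                                      {v2.getType(), v3.getType()});
mlir::Type nested = mlir::TupleType::get(b.getContext(),
                                         {v1.getType(), t23});
// Produces mhlo.tuple(v1, mhlo.tuple(v2, v3)); `remaining` is drained to empty.
mlir::Value rebuilt = xla::CreateTupleValue(&b, loc, remaining, nested);
assert(remaining.empty());
(void)rebuilt;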
static std::pair<mlir::Attribute, mlir::ArrayAttr> GetLayoutAttribute(
mlir::Builder& b, const Shape& shape,
std::optional<const Layout> maybe_layout = std::nullopt) {
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD
index 9c3500c..fd980a0 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD
@@ -11,6 +11,7 @@
[
"bool_compare.hlo",
"case_conditional.hlo",
+ "composite_call.hlo",
"custom_call.hlo",
"dynamic_param.hlo",
"entry_computation_layout.hlo",
@@ -20,11 +21,11 @@
"if_conditional.hlo",
"import.hlo",
"import_async.hlo",
+ "import_async2.hlo",
"layouts_and_names.hlo",
"location.hlo",
"module_attributes.hlo",
"module_config.hlo",
- "send_recv.hlo",
"simple.hlo",
"spmd_module_sharding.hlo",
"stacktrace_to_location.hlo",
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo
new file mode 100644
index 0000000..ad3dc70
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo
@@ -0,0 +1,186 @@
+// RUN: xla-translate -split-input-file -hlo-text-to-mlir-hlo %s | FileCheck %s
+
+// dictionary-like frontend_attributes
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2, version = 1 : i32} : (tensor<f32>) -> tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+ %Arg_0.3 = f32[] parameter(0)
+ %constant.4 = f32[] constant(2)
+ ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// string-like frontend_attributes
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2, version = 1 : i32} : (tensor<f32>) -> tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+ %Arg_0.3 = f32[] parameter(0)
+ %constant.4 = f32[] constant(2)
+ ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes="{n = 1 : i32, tensor = dense<1> : tensor<i32>}",composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// zero-output composite
+HloModule composite, entry_computation_layout={()->()}
+
+// CHECK: func.func @main() -> tuple<> {
+// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @return.2, version = 1 : i32, xla_shape = "()"} : (tensor<f32>) -> tuple<>
+// CHECK: return %1 : tuple<>
+// CHECK: }
+// CHECK: func.func private @return.2(%arg0: tensor<f32>) -> tuple<> {
+// CHECK: %0 = mhlo.tuple {xla_shape = "()"} : tuple<>
+// CHECK: return %0 : tuple<>
+// CHECK: }
+%return.2 (Arg_0.3: f32[]) -> () {
+ %Arg_0.3 = f32[] parameter(0)
+ ROOT %tuple.4 = () tuple()
+}
+
+ENTRY %main.7 () -> () {
+ %constant.1 = f32[] constant(42)
+ ROOT %call.5 = () call(f32[] %constant.1), to_apply=%return.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// multi-output composite
+HloModule composite, entry_computation_layout={()->(f32[], f32[])}
+
+// CHECK: func.func @main() -> tuple<tensor<f32>, tensor<f32>> {
+// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2, version = 1 : i32, xla_shape = "(f32[], f32[])"} : (tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
+// CHECK: return %1 : tuple<tensor<f32>, tensor<f32>>
+// CHECK: }
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>> {
+// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK: %2 = mhlo.tuple %1, %1 {xla_shape = "(f32[], f32[])"} : tuple<tensor<f32>, tensor<f32>>
+// CHECK: return %2 : tuple<tensor<f32>, tensor<f32>>
+// CHECK: }
+%add.2 (Arg_0.3: f32[]) -> (f32[], f32[]) {
+ %Arg_0.3 = f32[] parameter(0)
+ %constant.4 = f32[] constant(2)
+ %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+ ROOT %tuple.6 = (f32[], f32[]) tuple(f32[] %add.5, f32[] %add.5)
+}
+
+ENTRY %main.9 () -> (f32[], f32[]) {
+ %constant.1 = f32[] constant(42)
+ ROOT %call.7 = (f32[], f32[]) call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// optional composite attributes
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK: %1 = mhlo.composite "foo.bar" %0 {decomposition = @add.2, version = 1 : i32} : (tensor<f32>) -> tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+ %Arg_0.3 = f32[] parameter(0)
+ %constant.4 = f32[] constant(2)
+ ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// optional composite version
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2} : (tensor<f32>) -> tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+ %Arg_0.3 = f32[] parameter(0)
+ %constant.4 = f32[] constant(2)
+ ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes="{n = 1 : i32, tensor = dense<1> : tensor<i32>}",composite.name="foo.bar",composite.version="0"}
+}
+
+// -----
+
+// optional composite attributes and version
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK: %1 = mhlo.composite "foo.bar" %0 {decomposition = @add.2} : (tensor<f32>) -> tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK: return %1 : tensor<f32>
+// CHECK: }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+ %Arg_0.3 = f32[] parameter(0)
+ %constant.4 = f32[] constant(2)
+ ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+ %constant.1 = f32[] constant(42)
+ ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"}
+}
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo
index 7dcd16a..4e96330 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo
@@ -1,142 +1,162 @@
-// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
-// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION
+// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations -split-input-file %s -o - | FileCheck %s
-// NO_DEAD_FUNCTION-NOT: @test
+// CHECK-LABEL: func.func private @recv_
+// CHECK: %0:2 = "mhlo.recv"(%arg0) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 3>, is_host_transfer = true}> : (!mhlo.token) -> (tensor<i32>, !mhlo.token)
-// CHECK: module @foobar
+// CHECK-LABEL: func.func private @send_
+// CHECK: %0 = "mhlo.send"(%arg0, %arg1) <{channel_handle = #mhlo.channel_handle<handle = 3, type = 2>, is_host_transfer = true}> : (tensor<i32>, !mhlo.token) -> !mhlo.token
+
+// CHECK-LABEL: func.func @main
+// CHECK-LITERAL: %0 = "mhlo.async_start"(%arg0, %arg1) <{called_computation = @send_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (tensor<i32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
+// CHECK-NEXT-LITERAL: %1 = "mhlo.async_done"(%0) {called_computation = @send_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{maximal device=0}", xla_shape = "token[]"} : (!mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>) -> !mhlo.token
+// CHECK-NEXT-LITERAL: %2 = "mhlo.async_start"(%1) <{called_computation = @recv_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>
+// CHECK-NEXT-LITERAL: %3:2 = "mhlo.async_done"(%2) {called_computation = @recv_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>) -> (tensor<i32>, !mhlo.token)
HloModule foobar
-// Compiler-generated functions
+ENTRY %async_send_recv_test (arg_0: s32[], arg_1: token[]) -> (s32[], token[]) {
+ %arg_0 = s32[] parameter(0)
+ %arg_1 = token[] parameter(1)
-// CHECK: func private [[RECV_DTD_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
- // CHECK-NEXT: "mhlo.recv"([[TOK]]
- // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}
+ %send.0 = (s32[], u32[], token[]) send(s32[] %arg_0, token[] %arg_1), channel_id=3, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
+ %send-done.1 = token[] send-done((s32[], u32[], token[]) %send.0), channel_id=3, is_host_transfer=true, sharding={maximal device=0}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
-// CHECK: func private [[RECV_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
- // CHECK-NEXT: "mhlo.recv"([[TOK]]
- // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 3>, is_host_transfer = true}
+ %recv.2 = (s32[], u32[], token[]) recv(token[] %send-done.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
+ %recv-done.3 = (s32[], token[]) recv-done((s32[], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
-// CHECK: func private [[SEND_GENSYM:@.*send.*]]([[INPUT:%.*]]: tensor<128x32xf32>, %arg1: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
- // CHECK-NEXT: "mhlo.send"([[INPUT]]
- // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 2>, is_host_transfer = true}
-
-// CHECK: func private [[COPY_GENSYM:@.*copy.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
- // CHECK-NEXT: mhlo.copy [[INPUT]]
- // CHECK-SAME: cross_program_prefetch_index
-
-// CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
- // CHECK-NEXT: "mhlo.collective_permute"([[INPUT]])
- // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32>
-
-// CHECK: func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
- // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]])
- // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
- // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
- // CHECK-SAME: use_global_device_ids
- // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
- // CHECK: mhlo.add [[LHS]], [[RHS]]
-
-// CHECK: func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
- // CHECK-NEXT: "mhlo.all_gather"([[INPUT]])
- // CHECK-SAME: all_gather_dim = 1 : i64
- // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
- // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
- // CHECK-SAME: use_global_device_ids
-
-// CHECK: func @main(%arg0: tensor<f32>) -> tensor<f32> {
-ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
- ROOT %Arg_0.1 = f32[] parameter(0)
+ %get-tuple-element.4 = s32[] get-tuple-element((s32[], token[]) %recv-done.3), index=0, sharding={maximal device=0}
+ %get-tuple-element.5 = token[] get-tuple-element((s32[], token[]) %recv-done.3), index=1, sharding={maximal device=0}
+ ROOT %tuple.6 = (s32[], token[]) tuple(s32[] %get-tuple-element.4, token[] %get-tuple-element.5)
}
-// Tests
+// -----
-// CHECK: func private @test_all_gather_start
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>)
-%test_all_gather_start {
- input = f32[128,32] parameter(0)
- // CHECK-NEXT: [[AG_START:%.*]] = "mhlo.async_start"([[INPUT]])
- // CHECK-SAME: called_computation = [[AG_GENSYM]], execution_thread = "main"
- ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
- // CHECK-NEXT: "mhlo.async_done"([[AG_START]])
- ROOT ag-done = f32[128,128] all-gather-done(ag-start)
+HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}}
+
+// CHECK-LABEL: func.func private @all_gather_
+// CHECK: mhlo.all_gather
+
+// CHECK-LABEL: func.func @main
+// CHECK: mhlo.async_start{{.*}}called_computation = @all_gather_
+// CHECK: mhlo.async_done
+
+ENTRY %async_all_gather_test (Arg_0.1: f32[128,32]) -> f32[128,128] {
+ %Arg_0.1 = f32[128,32] parameter(0)
+ %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16}
+ ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17}
}
-add {
- lhs = f32[] parameter(0)
- rhs = f32[] parameter(1)
- ROOT add = f32[] add(lhs, rhs)
+// -----
+
+HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}}
+
+%region_1.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] {
+ %Arg_0.3 = f32[] parameter(0)
+ %Arg_1.4 = f32[] parameter(1)
+ ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7}
}
-// CHECK: func private @test_all_reduce_start
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>)
-%test_all_reduce_start {
- input = f32[128,32] parameter(0)
- // CHECK-NEXT: [[AR_START:%.*]] = "mhlo.async_start"([[INPUT]])
- // CHECK-SAME: called_computation = [[AR_GENSYM]], execution_thread = "main"
- ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true
- // CHECK-NEXT: "mhlo.async_done"([[AR_START]])
- ROOT ar-done = f32[128,32] all-reduce-done(ar-start)
+// CHECK-LABEL: func.func private @all_reduce_
+// CHECK: mhlo.all_reduce
+
+// CHECK-LABEL: func.func @main
+// CHECK: mhlo.async_start{{.*}}called_computation = @all_reduce_
+// CHECK: mhlo.async_done
+ENTRY %async_all_reduce_test (Arg_0.1: f32[10]) -> f32[10] {
+ %Arg_0.1 = f32[10] parameter(0)
+ %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22}
+ ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23}
}
-// CHECK: func private @test_collective_permute
-// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
-%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
- %input = f32[128,32]{1,0} parameter(0)
- // CHECK-NEXT: [[CP_START:%.*]] = "mhlo.async_start"([[ARG]])
- // CHECK-SAME: called_computation = [[CP_GENSYM]], execution_thread = "main"
- %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}}
- // CHECK-NEXT: "mhlo.async_done"([[CP_START]])
- ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start)
+// -----
+
+HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}}
+
+// CHECK-LABEL: func.func private @collective_permute_
+// CHECK: mhlo.collective_permute
+
+// CHECK-LABEL: func.func @main
+// CHECK: mhlo.async_start{{.*}}called_computation = @collective_permute_
+// CHECK: mhlo.async_done
+ENTRY %async_collective_permute_test (Arg_0.1: f32[128,32]) -> f32[128,32] {
+ %Arg_0.1 = f32[128,32] parameter(0)
+ %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13}
+ ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14}
}
-// CHECK: func private @test_copy_start
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>)
-%test_copy_start {
- input = f32[128,32] parameter(0)
- // CHECK-NEXT: [[COPY_START:%.*]] = "mhlo.async_start"([[INPUT]])
- // CHECK-SAME: called_computation = [[COPY_GENSYM]], execution_thread = "main"
- copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0
- // CHECK-NEXT: "mhlo.async_done"([[COPY_START]])
- ROOT copy-done = f32[128,32] copy-done(copy-start)
+// -----
+
+HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}}
+
+ENTRY %async_copy_test (Arg_0.1: f32[128,32]) -> f32[128,32] {
+ %Arg_0.1 = f32[128,32] parameter(0)
+ %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10}
+ ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11}
}
-// CHECK: func private @test_send
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
-%test_send_start {
- input = f32[128,32] parameter(0)
- tok = token[] parameter(1)
- // CHECK-NEXT: [[SEND_START:%.*]] = "mhlo.async_start"([[INPUT]], [[TOK]])
- // CHECK-SAME: called_computation = [[SEND_GENSYM]], execution_thread = "main"
- // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<tuple<tensor<128x32xf32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
- send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true
- // CHECK-NEXT: "mhlo.async_done"([[SEND_START]])
- ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true
+// -----
+
+HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])}
+
+ENTRY %async_recv_test_tuple (Arg_0.1: token[]) -> (s32[3,4], token[]) {
+ %Arg_0.1 = token[] parameter(0)
+ %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16}
+ %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17}
+ %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17}
+ %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17}
+ ROOT %tuple.6 = (s32[3,4], token[]) tuple(s32[3,4] %get-tuple-element.4, token[] %get-tuple-element.5)
}
-// CHECK: func private @test_recv
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
-%test_recv_start {
- input = f32[128,32] parameter(0)
- tok = token[] parameter(1)
- // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
- // CHECK-SAME: called_computation = [[RECV_GENSYM]], execution_thread = "main"
- // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
- recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true
- // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
- recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true
- ROOT gte = get-tuple-element(recv-done), index=0
+// -----
+
+HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]}
+
+ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] {
+ %Arg_0.1 = s32[3,4] parameter(0)
+ %Arg_1.2 = token[] parameter(1)
+ %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16}
+ ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17}
}
-// CHECK: func private @test_recv_dtd
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
-%test_recv_dtd_start {
- input = f32[128,32] parameter(0)
- tok = token[] parameter(1)
- // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
- // CHECK-SAME: called_computation = [[RECV_DTD_GENSYM]], execution_thread = "main"
- // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
- recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5
- // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
- recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5
- ROOT gte = get-tuple-element(recv-done), index=0
-}
+
+// BROKEN: b/TODO: Async custom calls?
+
+// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})}
+
+// ENTRY %async_custom_call_test2 (Arg_0.1: f32[10]) -> (f32[20]) {
+// %Arg_0.1 = f32[10] parameter(0)
+// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21}
+// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22}
+// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23}
+// }
+
+// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})}
+
+// ENTRY %async_custom_call_test (Arg_0.1: f32[10]) -> (f32[20]) {
+// %Arg_0.1 = f32[10] parameter(0)
+// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16}
+// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18}
+// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20}
+// }
+
+
+///////////
+
+// BROKEN: b/TODO: Empty arg send/recv don't roundtrip
+
+// HloModule main, entry_computation_layout={(token[])->token[]}
+
+// ENTRY %async_send_test_empty (Arg_0.1: token[]) -> token[] {
+// %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15}
+// %Arg_0.1 = token[] parameter(0)
+// %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15}
+// ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16}
+// }
+
+// HloModule main, entry_computation_layout={(token[])->((), token[])}
+
+// ENTRY %async_recv_test (Arg_0.1: token[]) -> ((), token[]) {
+// %Arg_0.1 = token[] parameter(0)
+// %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17}
+// ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18}
+// }
+
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo
new file mode 100644
index 0000000..7493c95
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo
@@ -0,0 +1,146 @@
+// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
+// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION
+
+// It would be great to consolidate this test with `import_async.hlo`, but
+// this test is fragile and does not run properly in `-split-input-file` mode.
+
+// NO_DEAD_FUNCTION-NOT: @test
+
+// CHECK: module @foobar
+HloModule foobar
+
+// Compiler-generated functions
+
+// CHECK: func private [[RECV_DTD_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
+ // CHECK-NEXT: "mhlo.recv"([[TOK]]
+ // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}
+
+// CHECK: func private [[RECV_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
+ // CHECK-NEXT: "mhlo.recv"([[TOK]]
+ // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 3>, is_host_transfer = true}
+
+// CHECK: func private [[SEND_GENSYM:@.*send.*]]([[INPUT:%.*]]: tensor<128x32xf32>, %arg1: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
+ // CHECK-NEXT: "mhlo.send"([[INPUT]]
+ // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 2>, is_host_transfer = true}
+
+// CHECK: func private [[COPY_GENSYM:@.*copy.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+ // CHECK-NEXT: mhlo.copy [[INPUT]]
+ // CHECK-SAME: cross_program_prefetch_index
+
+// CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+ // CHECK-NEXT: "mhlo.collective_permute"([[INPUT]])
+ // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32>
+
+// CHECK: func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+ // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]])
+ // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+ // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+ // CHECK-SAME: use_global_device_ids
+ // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
+ // CHECK: mhlo.add [[LHS]], [[RHS]]
+
+// CHECK: func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
+ // CHECK-NEXT: "mhlo.all_gather"([[INPUT]])
+ // CHECK-SAME: all_gather_dim = 1 : i64
+ // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+ // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+ // CHECK-SAME: use_global_device_ids
+
+// CHECK: func @main(%arg0: tensor<f32>) -> tensor<f32> {
+ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
+ ROOT %Arg_0.1 = f32[] parameter(0)
+}
+
+// Tests
+
+// CHECK: func private @test_all_gather_start
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>)
+%test_all_gather_start {
+ input = f32[128,32] parameter(0)
+ // CHECK-NEXT: [[AG_START:%.*]] = "mhlo.async_start"([[INPUT]])
+ // CHECK-SAME: called_computation = [[AG_GENSYM]], execution_thread = "main"
+ ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
+ // CHECK-NEXT: "mhlo.async_done"([[AG_START]])
+ ROOT ag-done = f32[128,128] all-gather-done(ag-start)
+}
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK: func private @test_all_reduce_start
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>)
+%test_all_reduce_start {
+ input = f32[128,32] parameter(0)
+ // CHECK-NEXT: [[AR_START:%.*]] = "mhlo.async_start"([[INPUT]])
+ // CHECK-SAME: called_computation = [[AR_GENSYM]], execution_thread = "main"
+ ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true
+ // CHECK-NEXT: "mhlo.async_done"([[AR_START]])
+ ROOT ar-done = f32[128,32] all-reduce-done(ar-start)
+}
+
+// CHECK: func private @test_collective_permute
+// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
+%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
+ %input = f32[128,32]{1,0} parameter(0)
+ // CHECK-NEXT: [[CP_START:%.*]] = "mhlo.async_start"([[ARG]])
+ // CHECK-SAME: called_computation = [[CP_GENSYM]], execution_thread = "main"
+ %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}}
+ // CHECK-NEXT: "mhlo.async_done"([[CP_START]])
+ ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start)
+}
+
+// CHECK: func private @test_copy_start
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>)
+%test_copy_start {
+ input = f32[128,32] parameter(0)
+ // CHECK-NEXT: [[COPY_START:%.*]] = "mhlo.async_start"([[INPUT]])
+ // CHECK-SAME: called_computation = [[COPY_GENSYM]], execution_thread = "main"
+ copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0
+ // CHECK-NEXT: "mhlo.async_done"([[COPY_START]])
+ ROOT copy-done = f32[128,32] copy-done(copy-start)
+}
+
+// CHECK: func private @test_send
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
+%test_send_start {
+ input = f32[128,32] parameter(0)
+ tok = token[] parameter(1)
+ // CHECK-NEXT: [[SEND_START:%.*]] = "mhlo.async_start"([[INPUT]], [[TOK]])
+ // CHECK-SAME: called_computation = [[SEND_GENSYM]], execution_thread = "main"
+ // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<tuple<tensor<128x32xf32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
+ send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true
+ // CHECK-NEXT: "mhlo.async_done"([[SEND_START]])
+ ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true
+}
+
+// CHECK: func private @test_recv
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
+%test_recv_start {
+ input = f32[128,32] parameter(0)
+ tok = token[] parameter(1)
+ // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
+ // CHECK-SAME: called_computation = [[RECV_GENSYM]], execution_thread = "main"
+ // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
+ recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true
+ // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
+  recv-done = (f32[128,32], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true
+ ROOT gte = get-tuple-element(recv-done), index=0
+}
+
+// CHECK: func private @test_recv_dtd
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
+%test_recv_dtd_start {
+ input = f32[128,32] parameter(0)
+ tok = token[] parameter(1)
+ // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
+ // CHECK-SAME: called_computation = [[RECV_DTD_GENSYM]], execution_thread = "main"
+ // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
+ recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5
+ // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
+  recv-done = (f32[128,32], token[]) recv-done(recv-start), channel_id=5
+ ROOT gte = get-tuple-element(recv-done), index=0
+}
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo
deleted file mode 100644
index ef40699..0000000
--- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo
+++ /dev/null
@@ -1,55 +0,0 @@
-// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
-
-HloModule foo
-
-// CHECK: func private @[[RECV_FUNC:[^(]*]]
-// CHECK: mhlo.recv
-// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 5, type = 3>
-// CHECK-NOT: mhlo.sharding
-
-// CHECK: func private @[[SEND_FUNC:[^(]*]]
-// CHECK: mhlo.send
-// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 3, type = 2>
-
-// CHECK: func @main
-// CHECK: mhlo.async_start
-// CHECK-SAME: called_computation = @[[SEND_FUNC]]
-// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}
-// CHECK-SAME: mhlo.sharding = "{
-// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0}
-// CHECK-SAME: }"
-// CHECK-SAME: (tensor<i32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
-// CHECK: mhlo.async_done
-// CHECK-SAME: called_computation = @[[SEND_FUNC]]
-// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}
-// CHECK-SAME: mhlo.sharding = "{maximal device=0}"
-// CHECK-SAME: (!mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>) -> !mhlo.token
-// CHECK: mhlo.async_start
-// CHECK-SAME: called_computation = @[[RECV_FUNC]]
-// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}
-// CHECK-SAME: mhlo.sharding = "{
-// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0}
-// CHECK-SAME: }"
-// CHECK-SAME: (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>
-// CHECK: mhlo.async_done
-// CHECK-SAME: called_computation = @[[RECV_FUNC]]
-// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}
-// CHECK-SAME: mhlo.sharding = "{
-// CHECK-SAME: {maximal device=0}, {maximal device=0}
-// CHECK-SAME: }"
-// CHECK-SAME: (!mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>) -> (tensor<i32>, !mhlo.token)
-
-ENTRY %foo (arg_0: s32[], arg_1: token[]) -> (s32[], token[]) {
- %arg_0 = s32[] parameter(0)
- %arg_1 = token[] parameter(1)
-
- %send.0 = (s32[], u32[], token[]) send(s32[] %arg_0, token[] %arg_1), channel_id=3, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
- %send-done.1 = token[] send-done((s32[], u32[], token[]) %send.0), channel_id=3, is_host_transfer=true, sharding={maximal device=0}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
-
- %recv.2 = (s32[], u32[], token[]) recv(token[] %send-done.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
- %recv-done.3 = (s32[], token[]) recv-done((s32[], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
-
- %get-tuple-element.4 = s32[] get-tuple-element((s32[], token[]) %recv-done.3), index=0, sharding={maximal device=0}
- %get-tuple-element.5 = token[] get-tuple-element((s32[], token[]) %recv-done.3), index=1, sharding={maximal device=0}
- ROOT %tuple.6 = (s32[], token[]) tuple(s32[] %get-tuple-element.4, token[] %get-tuple-element.5)
-}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
index 1047e8a..40e05c8 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
@@ -99,6 +99,7 @@
":type_to_shape",
"//xla:array",
"//xla:comparison_util",
+ "//xla:debug_options_flags",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
@@ -106,6 +107,7 @@
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/client:xla_builder",
+ "//xla/client:xla_computation",
"//xla/client/lib:approx_topk",
"//xla/client/lib:approx_topk_shape",
"//xla/client/lib:matrix",
@@ -120,8 +122,11 @@
"//xla/service:hlo_parser",
"//xla/service:hlo_proto_cc",
"//xla/service/gpu:backend_configs_cc",
+ "@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
@@ -133,8 +138,11 @@
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
+ "@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:ml_dtypes",
"@local_tsl//tsl/platform:statusor",
+ "@local_tsl//tsl/platform:types",
+ "@stablehlo//:base",
"@stablehlo//:stablehlo_ops",
],
)
@@ -177,18 +185,26 @@
deps = [
":mlir_hlo_to_hlo",
":type_to_shape",
+ "//xla:debug_options_flags",
+ "//xla:shape_util",
+ "//xla/client:xla_builder",
+ "//xla/client:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo:hlo_dialect_registration",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_proto_util",
"@com_google_absl//absl/log",
+ "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
],
)
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h
index 2c85a82..2ecd4e3 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h
@@ -19,8 +19,9 @@
#define XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_
#include <functional>
-#include <vector>
+#include <optional>
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/client/xla_builder.h"
#include "xla/hlo/ir/hlo_sharding.h"
@@ -30,10 +31,10 @@
namespace mlir {
// XLA Layout preferences. Currently, when it comes to TPU, there are two
-// primary layout choices for any XLA argumetns (parameter or resource): (1)
+// primary layout choices for any XLA arguments (parameter or resource): (1)
// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU
// layout while Linear is native host (CPU) layout.
-// This enum allows the caller of XLA to progogate layout preference to the XLA
+// This enum allows the caller of XLA to propagate layout preference to the XLA
// compiler.
// kNoPreference: the generic layout where the XLA compiler has the freedom
// to assign any layout.
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index 90eb1a9..287b455 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -19,26 +19,25 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
-#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
+#include "absl/log/check.h"
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -46,9 +45,12 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
@@ -56,25 +58,26 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/Base.h"
#include "xla/array.h"
#include "xla/client/lib/approx_topk.h"
#include "xla/client/lib/approx_topk_shape.h"
-#include "xla/client/lib/matrix.h"
-#include "xla/client/lib/quantize.h"
+#include "xla/client/lib/matrix.h" // IWYU pragma: keep
#include "xla/client/lib/slicing.h"
#include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
#include "xla/comparison_util.h"
+#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/dynamic_parameter_binding.h"
-#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/mlir/utils/error_util.h"
@@ -88,16 +91,16 @@
#include "xla/service/hlo_parser.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
-#include "xla/status_macros.h"
#include "xla/translate/mhlo_to_hlo/attribute_exporter.h"
+#include "xla/translate/mhlo_to_hlo/layout_util.h"
#include "xla/translate/mhlo_to_hlo/location_exporter.h"
#include "xla/translate/mhlo_to_hlo/module_config_exporter.h"
#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h"
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"
-#include "xla/types.h"
#include "xla/xla_data.pb.h"
-#include "tsl/platform/ml_dtypes.h"
+#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
+#include "tsl/platform/types.h"
using ::int64_t;
using ::tsl::int16;
@@ -109,18 +112,54 @@
using ::tsl::uint64;
using ::tsl::uint8;
-constexpr char kShapeIndicesAttr[] = "shape_indices";
-constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
-constexpr char kShardingAttr[] = "mhlo.sharding";
-constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
-constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";
-constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication";
-constexpr char kLiteralAttr[] = "mhlo.literal";
+// Boolean attribute.
+constexpr char kJaxBufferDonor[] = "jax.buffer_donor";
+// BitcastOp lowering strings.
+constexpr char kResultLayout[] = "result_layout";
+constexpr char kSourceLayout[] = "source_layout";
+
+// CustomCallOp lowering strings.
+constexpr char kAggregateToTopk[] = "aggregate_to_topk";
+constexpr char kApiVersion[] = "api_version";
+constexpr char kApproxTopK[] = "ApproxTopK";
+constexpr char kBackendConfig[] = "backend_config";
+constexpr char kCallTargetName[] = "call_target_name";
+constexpr char kCalledComputations[] = "called_computations";
+constexpr char kHasSideEffect[] = "has_side_effect";
+constexpr char kIsFallback[] = "is_fallback";
+constexpr char kRecallTarget[] = "recall_target";
+constexpr char kReductionDim[] = "reduction_dim";
+constexpr char kReductionInputSizeOverride[] = "reduction_input_size_override";
+constexpr char kTopK[] = "top_k";
+
+// MHLO attributes. Module level attributes require namespacing.
+constexpr char kMhloCrossProgramPrefetches[] = "mhlo.cross_program_prefetches";
+constexpr char kMhloFrontendAttributes[] = "mhlo.frontend_attributes";
+constexpr char kMhloInputOutputAlias[] = "mhlo.input_output_alias";
+constexpr char kMhloIsDynamic[] = "mhlo.is_dynamic";
+constexpr char kMhloLiteral[] = "mhlo.literal";
+constexpr char kMhloParameterReplication[] = "mhlo.parameter_replication";
+constexpr char kMhloReplication[] = "mhlo.is_same_data_across_replicas";
+constexpr char kMhloSharding[] = "mhlo.sharding";
+constexpr char kMhloSpmdOutputSharding[] = "mhlo.spmd_output_sharding";
+constexpr char kMhloSpmdParametersShardings[] =
+ "mhlo.spmd_parameters_shardings";
+constexpr char kMhloUseAutoSpmdPartitioning[] =
+ "mhlo.use_auto_spmd_partitioning";
+
+// Miscellaneous string literals.
+constexpr char kArgEmptyTuple[] = "arg_empty_tuple";
+constexpr char kArgPrefix[] = "Arg_";
+constexpr char kArgTuple[] = "arg_tuple";
+constexpr char kDefaultLayoutAttrName[] = "xla_shape";
+constexpr char kExecutionThread[] = "execution_thread";
// Array attribute. Same shape as infeed result, but contains a
// minor_to_major array for every tensor.
-constexpr char kLayoutAttr[] = "layout";
-constexpr char kDefaultLayoutAttrName[] = "xla_shape";
+constexpr char kLayout[] = "layout";
+constexpr char kMain[] = "main";
+constexpr char kRegionPrefix[] = "region_";
+constexpr char kTfAliasingOutput[] = "tf.aliasing_output";
// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
@@ -585,7 +624,7 @@
// returns std::nullopt.
static std::optional<xla::OpSharding> CreateOpShardingFromAttribute(
mlir::Operation* op) {
- auto shardingAttr = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
+ auto shardingAttr = op->getAttrOfType<mlir::StringAttr>(kMhloSharding);
if (!shardingAttr) return std::nullopt;
return xla::ConvertSharding(shardingAttr.getValue());
}
@@ -606,7 +645,7 @@
mlir::Operation* op) {
xla::FrontendAttributes frontend_attributes;
auto frontend_attributes_dict =
- op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
+ op->getAttrOfType<mlir::DictionaryAttr>(kMhloFrontendAttributes);
if (!frontend_attributes_dict) return frontend_attributes;
ConstructFrontendAttributesFromAttribute(frontend_attributes_dict,
frontend_attributes);
@@ -619,7 +658,7 @@
fe_attrs->resize(function.getNumArguments(), std::nullopt);
for (int i = 0, end = function.getNumArguments(); i < end; ++i)
if (auto fe_attr = function.getArgAttrOfType<mlir::DictionaryAttr>(
- i, kFrontendAttributesAttr)) {
+ i, kMhloFrontendAttributes)) {
xla::FrontendAttributes frontend_attributes;
ConstructFrontendAttributesFromAttribute(fe_attr, frontend_attributes);
(*fe_attrs)[i] = frontend_attributes;
@@ -643,14 +682,14 @@
std::optional<xla::OpSharding>());
for (int i = 0, end = function.getNumArguments(); i < end; ++i)
if (auto sharding =
- function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
+ function.getArgAttrOfType<mlir::StringAttr>(i, kMhloSharding))
(*arg_shardings)[i] = xla::ConvertSharding(sharding.getValue());
ret_shardings->resize(function.getNumResults(),
std::optional<xla::OpSharding>());
for (int i = 0, end = function.getNumResults(); i < end; ++i)
if (auto sharding =
- function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
+ function.getResultAttrOfType<mlir::StringAttr>(i, kMhloSharding))
(*ret_shardings)[i] = xla::ConvertSharding(sharding.getValue());
}
@@ -753,7 +792,7 @@
//
// TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
LogicalResult Run() {
- auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
+ auto main = module_.lookupSymbol<mlir::func::FuncOp>(kMain);
if (!main)
return module_.emitError(
"conversion requires module with `main` function");
@@ -771,8 +810,8 @@
// Lower a `mlir::Region` to a `XlaComputation`
LogicalResult LowerRegionAsComputation(
mlir::Region* region, xla::XlaComputation* func,
- std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
- std::nullopt,
+ llvm::ArrayRef<mlir::Value> implicit_operands = {},
+ llvm::ArrayRef<mlir::Value> implicit_results = {},
bool ensure_single_arg = false,
llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings = {},
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings = {});
@@ -786,11 +825,11 @@
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
xla::XlaComputation* result,
- std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
- std::nullopt);
+ llvm::ArrayRef<mlir::Value> implicit_operands = {},
+ llvm::ArrayRef<mlir::Value> implicit_results = {});
::xla::HloModuleProto ConsumeMainProto() {
- auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
+ auto main = module_.lookupSymbol<mlir::func::FuncOp>(kMain);
// This is an invariant check as Run returns failure if there is no main
// function and so the main proto shouldn't be consumed in that case.
CHECK(main) << "requires module to have main function"; // Crash Ok.
@@ -816,7 +855,7 @@
LogicalResult Lower(
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
- xla::XlaBuilder* builder,
+ llvm::ArrayRef<mlir::Value> implicit_results, xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaOp* return_value);
@@ -916,12 +955,13 @@
}
void BuildGetTupleElementsForTupleResults(mlir::Operation* op, xla::XlaOp tuple,
- OpLoweringContext ctx) {
+ OpLoweringContext ctx,
+ unsigned num_implicit_results = 0) {
const std::optional<xla::OpSharding>& sharding = ctx.builder->sharding();
if (sharding.has_value()) {
bool is_tuple_sharding = sharding->type() == xla::OpSharding::TUPLE;
- assert(!is_tuple_sharding ||
- op->getNumResults() == sharding->tuple_shardings_size());
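+    // `num_implicit_results` accounts for implicit values (e.g. WhileOp
+    // captures) that are included in the tuple sharding but are not among the
+    // op's MLIR results.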
+ assert(!is_tuple_sharding || (op->getNumResults() + num_implicit_results ==
+ sharding->tuple_shardings_size()));
for (auto [index, result] : llvm::enumerate(op->getResults())) {
// If `sharding` is not a tuple sharding, then every `get-tuple-element`
// gets the same sharding.
@@ -956,7 +996,8 @@
}
LogicalResult ExportXlaOp(CompositeOp, OpLoweringContext) {
- // TODO: b/328526226 - Implement MHLO export for CompositeOp.
+  // Fails on purpose: `mhlo::CompositeOp` is handled by the special-purpose
+  // logic in `ConvertToHloModule::Lower`.
return failure();
}
@@ -1610,7 +1651,7 @@
xla::XlaComputation false_branch;
auto& value_map = *ctx.values;
- // mhlo.IfOp does not have any operands or blocks-arguments. The computation
+  // mhlo.IfOp does not have any operands or block arguments. The computation
// inside the region-blocks use implicit captures of values defined above.
// In order to create the xla parameters for functions corresponding to
// IfOp regions, we need to infer the a region-block's arguments, using all
@@ -1628,10 +1669,10 @@
getUsedValuesDefinedAbove(op.getFalseBranch(), op.getFalseBranch(),
implicit_false_operand_set);
- llvm::SmallVector<mlir::Value> implicit_true_operands(
- implicit_true_operand_set.begin(), implicit_true_operand_set.end());
- llvm::SmallVector<mlir::Value> implicit_false_operands(
- implicit_false_operand_set.begin(), implicit_false_operand_set.end());
+ llvm::SmallVector<mlir::Value> implicit_true_operands =
+ implicit_true_operand_set.takeVector();
+ llvm::SmallVector<mlir::Value> implicit_false_operands =
+ implicit_false_operand_set.takeVector();
llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
GetResultShardings(ctx.builder->sharding(), op->getNumResults());
@@ -1657,13 +1698,13 @@
// implicit captures operands. Also export the instructions within those
// regions.
if (failed(ctx.converter->LowerRegionAsComputation(
- &op.getTrueBranch(), &true_branch,
- llvm::ArrayRef(implicit_true_operands),
- /*ensure_single_arg*/ true, true_arg_shardings, ret_shardings)) ||
+ &op.getTrueBranch(), &true_branch, implicit_true_operands,
+ /*implicit_results=*/{}, /*ensure_single_arg=*/true,
+ true_arg_shardings, ret_shardings)) ||
failed(ctx.converter->LowerRegionAsComputation(
- &op.getFalseBranch(), &false_branch,
- llvm::ArrayRef(implicit_false_operands),
- /*ensure_single_arg*/ true, false_arg_shardings, ret_shardings))) {
+ &op.getFalseBranch(), &false_branch, implicit_false_operands,
+ /*implicit_results=*/{}, /*ensure_single_arg=*/true,
+ false_arg_shardings, ret_shardings))) {
return failure();
}
@@ -1701,7 +1742,7 @@
std::vector<xla::XlaComputation> computations(branches.size());
std::vector<xla::XlaComputation*> computations_p(branches.size());
- // mhlo.CaseOp does not have any operands or blocks-arguments. The computation
+  // mhlo.CaseOp does not have any operands or block arguments. The computation
// inside the region-blocks use implicit captures of values defined above.
// In order to create the xla parameters for functions corresponding to
// CaseOp regions, we need to infer the a region-block's arguments, using all
@@ -1715,8 +1756,8 @@
for (unsigned i = 0; i < branches.size(); ++i) {
llvm::SetVector<mlir::Value> implicit_operand_set;
getUsedValuesDefinedAbove(branches[i], branches[i], implicit_operand_set);
- llvm::SmallVector<mlir::Value> implicit_operands(
- implicit_operand_set.begin(), implicit_operand_set.end());
+ llvm::SmallVector<mlir::Value> implicit_operands =
+ implicit_operand_set.takeVector();
llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
GetResultShardings(ctx.builder->sharding(), op->getNumResults());
@@ -1740,8 +1781,9 @@
// that region.
computations_p[i] = &computations[i];
if (failed(ctx.converter->LowerRegionAsComputation(
- &branches[i], computations_p[i], llvm::ArrayRef(implicit_operands),
- /*ensure_single_arg*/ true, arg_shardings, ret_shardings)))
+ &branches[i], computations_p[i], implicit_operands,
+ /*implicit_results=*/{}, /*ensure_single_arg=*/true, arg_shardings,
+ ret_shardings)))
return failure();
}
@@ -1905,12 +1947,12 @@
// This feature is at time of writing only used by JAX, and is tested in the
// jax2tf backwards compatibility tests.
- if (op.getCallTargetName() == "ApproxTopK") {
+ if (op.getCallTargetName() == kApproxTopK) {
auto isSupportedAttrName = [](NamedAttribute attr) {
auto name = attr.getName();
- return name == "call_target_name" || name == "backend_config" ||
- name == "api_version" || name == "called_computations" ||
- name == "has_side_effect";
+ return name == kCallTargetName || name == kBackendConfig ||
+ name == kApiVersion || name == kCalledComputations ||
+ name == kHasSideEffect;
};
for (const auto& attr : op->getAttrs()) {
if (!isSupportedAttrName(attr))
@@ -1925,9 +1967,9 @@
for (auto attr : backend_config) {
auto name = attr.getName();
- if (!(name == "top_k" || name == "reduction_dim" ||
- name == "recall_target" || name == "aggregate_to_topk" ||
- name == "reduction_input_size_override" || name == "is_fallback"))
+ if (!(name == kTopK || name == kReductionDim || name == kRecallTarget ||
+ name == kAggregateToTopk || name == kReductionInputSizeOverride ||
+ name == kIsFallback))
return op.emitOpError()
<< name.getValue() << " is not a supported backend_config"
<< " attribute for ApproxTopK";
@@ -1969,29 +2011,28 @@
<< " attribute in backend_config must be of bool type";
return success();
};
- if (failed(checkI64Attr("top_k"))) return failure();
- if (failed(checkI64Attr("reduction_dim"))) return failure();
- if (failed(checkF32Attr("recall_target"))) return failure();
- if (failed(checkBoolAttr("aggregate_to_topk"))) return failure();
- if (failed(checkI64Attr("reduction_input_size_override"))) return failure();
- bool has_is_fallback = backend_config.contains("is_fallback");
- if (has_is_fallback && !backend_config.getAs<BoolAttr>("is_fallback"))
+ if (failed(checkI64Attr(kTopK))) return failure();
+ if (failed(checkI64Attr(kReductionDim))) return failure();
+ if (failed(checkF32Attr(kRecallTarget))) return failure();
+ if (failed(checkBoolAttr(kAggregateToTopk))) return failure();
+ if (failed(checkI64Attr(kReductionInputSizeOverride))) return failure();
+ bool has_is_fallback = backend_config.contains(kIsFallback);
+ if (has_is_fallback && !backend_config.getAs<BoolAttr>(kIsFallback))
return op.emitOpError()
<< "is_fallback attribute in backend_config must be of bool type";
- int64_t top_k = backend_config.getAs<IntegerAttr>("top_k").getInt();
+ int64_t top_k = backend_config.getAs<IntegerAttr>(kTopK).getInt();
int64_t reduction_dim =
- backend_config.getAs<IntegerAttr>("reduction_dim").getInt();
- float recall_target = backend_config.getAs<FloatAttr>("recall_target")
+ backend_config.getAs<IntegerAttr>(kReductionDim).getInt();
+ float recall_target = backend_config.getAs<FloatAttr>(kRecallTarget)
.getValue()
.convertToFloat();
bool aggregate_to_topk =
- backend_config.getAs<BoolAttr>("aggregate_to_topk").getValue();
+ backend_config.getAs<BoolAttr>(kAggregateToTopk).getValue();
int64_t reduction_input_size_override =
- backend_config.getAs<IntegerAttr>("reduction_input_size_override")
- .getInt();
+ backend_config.getAs<IntegerAttr>(kReductionInputSizeOverride).getInt();
bool is_fallback = has_is_fallback &&
- backend_config.getAs<BoolAttr>("is_fallback").getValue();
+ backend_config.getAs<BoolAttr>(kIsFallback).getValue();
// (C1)
if (args.size() % 2 != 0) {
@@ -2151,7 +2192,7 @@
absl::StatusOr<xla::Literal> literal;
const xla::Literal* literal_ptr = nullptr;
- auto literal_attr = op->getAttrOfType<DenseElementsAttr>(kLiteralAttr);
+ auto literal_attr = op->getAttrOfType<DenseElementsAttr>(kMhloLiteral);
if (literal_attr) {
literal = CreateArrayLiteralFromAttr(literal_attr, {});
if (!literal.ok()) return failure();
@@ -2712,15 +2753,57 @@
LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
xla::XlaComputation condition;
xla::XlaComputation body;
+
// If the results of the while op have a sharding, we use those shardings for
// the corresponding arguments and return shardings in the body and condition.
llvm::SmallVector<std::optional<xla::OpSharding>> res_shardings =
GetResultShardings(ctx.builder->sharding(), op->getNumResults());
+
+  // mhlo.WhileOp has operands and corresponding block arguments, but the
+  // computation inside its region-blocks can also use implicit captures of
+  // values defined above.
+  // In order to create the xla parameters for functions corresponding to
+  // WhileOp regions, we need to infer the implicit region-block arguments
+  // from all the values that are used in the regions but defined above.
+ //
+ // Note that the body and cond regions of WhileOp share the same block
+ // arguments, so we collect the implicit values for both in a single set.
+ llvm::SetVector<mlir::Value> implicit_operand_set;
+ getUsedValuesDefinedAbove(op->getRegions(), implicit_operand_set);
+ llvm::SmallVector<mlir::Value> implicit_operands =
+ implicit_operand_set.takeVector();
+
+ llvm::SmallVector<xla::XlaOp> implicit_args;
+ if (failed(GetXlaOps(op, implicit_operands, ctx, implicit_args)))
+ return failure();
+
+ // We need to append the shardings of the implicit values to the result
+  // shardings, since the HLO While will have those implicit values as additional
+ // operands and results.
+ llvm::SmallVector<std::optional<xla::OpSharding>> implicit_shardings;
+ if (!implicit_args.empty() && !res_shardings.empty()) {
+ // We only add implicit arg shardings if there are result shardings,
+ // otherwise it means sharding propagation hasn't been done yet.
+ implicit_shardings = GetXlaOpShardings(implicit_args);
+
+ res_shardings.append(implicit_shardings.begin(), implicit_shardings.end());
+ if (std::optional<xla::OpSharding> new_sharding =
+ CreateTupleSharding(res_shardings)) {
+ ctx.builder->SetSharding(*new_sharding);
+ }
+ }
+
+ // The body of the While needs to return the same number of values as its
+ // arguments, as they are carried over to the next iteration. Thus, we pass
+ // the `implicit_operands` as `implicit_results`, to carry them over as is.
if (failed(ctx.converter->LowerRegionAsComputation(
- &op.getBody(), &body, std::nullopt, /*ensure_single_arg=*/true,
- /*arg_shardings=*/res_shardings, /*ret_shardings=*/res_shardings)) ||
+ &op.getBody(), &body, implicit_operands,
+ /*implicit_results=*/implicit_operands,
+ /*ensure_single_arg=*/true, /*arg_shardings=*/res_shardings,
+ /*ret_shardings=*/res_shardings)) ||
failed(ctx.converter->LowerRegionAsComputation(
- &op.getCond(), &condition, std::nullopt,
+ &op.getCond(), &condition, implicit_operands,
+ /*implicit_results=*/{},
/*ensure_single_arg=*/true, /*arg_shardings=*/res_shardings))) {
return failure();
}
@@ -2729,11 +2812,12 @@
// those operands, to be used as sole operand of xla::While.
llvm::SmallVector<xla::XlaOp> operands;
if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure();
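+  // Append the implicit captures so they are carried through the loop as
+  // extra operands of the xla::While.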
+ operands.append(implicit_args.begin(), implicit_args.end());
xla::XlaOp operand = operands[0];
if (operands.size() > 1) operand = Tuple(ctx.builder, operands);
- auto whileop = xla::While(condition, body, operand);
+ xla::XlaOp whileop = xla::While(condition, body, operand);
auto& value_map = *ctx.values;
auto shape_or = whileop.builder()->GetShape(whileop);
@@ -2748,7 +2832,8 @@
}
// mhlo.WhileOp supports multiple returns, untuple all the results of XLA's.
- BuildGetTupleElementsForTupleResults(op, whileop, ctx);
+ BuildGetTupleElementsForTupleResults(
+ op, whileop, ctx, /*num_implicit_results=*/implicit_args.size());
return success();
}
@@ -2824,11 +2909,11 @@
xla::internal::XlaBuilderFriend::GetInstruction(operand);
xla::LayoutProto result_layout =
ExtractLayout(op, bitcast_proto->shape().dimensions_size(),
- "result_layout")
+ kResultLayout)
.ToProto();
xla::LayoutProto source_layout =
ExtractLayout(op, operand_proto->shape().dimensions_size(),
- "source_layout")
+ kSourceLayout)
.ToProto();
xla::gpu::BitcastBackendConfig bitcast_config;
*bitcast_config.mutable_source_layout() = source_layout;
@@ -2879,7 +2964,7 @@
LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
xla::ShapeProto* shape) {
- // In the case of tuples, ShapeProtos can be nested, and so can the mlir
+ // In the case of tuples, Shape protos can be nested, and so can the mlir
// attribute describing the layout. So recurse into the subshapes in both data
// structures in parallel.
if (shape->element_type() == xla::TUPLE) {
@@ -3045,7 +3130,7 @@
LogicalResult ConvertToHloModule::Lower(
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
- xla::XlaBuilder* builder,
+ llvm::ArrayRef<mlir::Value> implicit_results, xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaOp* return_value) {
// Explicitly fail for ops that are not supported for export.
@@ -3092,8 +3177,7 @@
// For infeed ops stemming back to InfeedDequeueTuple, respect the
// layout attribute, and create the corresponding layout in hlo.
if (isa<mhlo::InfeedOp>(inst)) {
- mlir::ArrayAttr layout =
- inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
+ mlir::ArrayAttr layout = inst->getAttrOfType<mlir::ArrayAttr>(kLayout);
if (layout) {
// We propagate layout to the following three ops:
@@ -3222,40 +3306,51 @@
if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) {
// Construct the return value for the function. If there is a single value
// returned, then return it directly, else create a tuple and return.
- unsigned num_return_values = inst->getNumOperands();
+ unsigned num_return_values =
+ inst->getNumOperands() + implicit_results.size();
std::optional<xla::OpSharding> ret_tuple_sharding =
CreateTupleSharding(ret_shardings);
if ((options_.return_tuple && is_entry_function) ||
num_return_values != 1) {
- std::vector<xla::XlaOp> returns(num_return_values);
- for (OpOperand& ret : inst->getOpOperands()) {
- unsigned index = ret.getOperandNumber();
- xla::XlaOp operand;
- if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
- return failure();
+ std::vector<xla::XlaOp> returns;
+ returns.reserve(num_return_values);
+ // NOTE: we can't use operand_range in llvm::concat.
+ for (Value ret : inst->getOperands()) {
+ xla::XlaOp& operand = returns.emplace_back();
+ if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure();
+ }
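+      // Implicit results are appended after the explicit operands so the
+      // region also returns them unchanged (this is how WhileOp's implicit
+      // captures are carried over).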
+ for (Value ret : implicit_results) {
+ xla::XlaOp& operand = returns.emplace_back();
+ if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure();
+ }
+ if (is_entry_function && ret_tuple_sharding) {
+ assert(implicit_results.empty() &&
+ "entry functions shouldn't have implicit results");
+ for (OpOperand& ret : inst->getOpOperands()) {
+ unsigned index = ret.getOperandNumber();
- returns[index] = operand;
- if (!is_entry_function || !ret_tuple_sharding) continue;
+ xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
+ absl::StatusOr<xla::XlaOp> reshape =
+ ReshapeWithCorrectRepresentationAndSharding(
+ builder, returns[index], return_shape,
+ options_.layout_preference_fn,
+ options_.shape_representation_fn, ret_shardings[index],
+ /*fast_mem=*/false);
+ if (!reshape.ok())
+ return inst->emitError() << reshape.status().message();
- xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
- absl::StatusOr<xla::XlaOp> reshape =
- ReshapeWithCorrectRepresentationAndSharding(
- builder, returns[index], return_shape,
- options_.layout_preference_fn, options_.shape_representation_fn,
- ret_shardings[index], /*fast_mem=*/false);
- if (!reshape.ok())
- return inst->emitError() << reshape.status().message();
-
- returns[index] = reshape.value();
+ returns[index] = reshape.value();
+ }
}
xla::XlaScopedShardingAssignment scoped_sharding(builder,
ret_tuple_sharding);
*return_value = xla::Tuple(builder, returns);
} else if (num_return_values == 1) {
+ Value ret = implicit_results.empty() ? inst->getOperand(0)
+ : implicit_results.front();
xla::XlaOp operand;
- if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
- return failure();
+ if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure();
if (ret_tuple_sharding) {
auto tuple = Tuple(builder, {operand});
@@ -3270,6 +3365,59 @@
return success();
}
+ if (auto composite_op = dyn_cast<mhlo::CompositeOp>(inst)) {
+ SmallVector<xla::XlaOp, 1> operands;
+ for (const Value& val : inst->getOperands()) {
+ xla::XlaOp operand;
+ if (failed(GetXlaOp(val, value_map, &operand, inst))) {
+ return failure();
+ }
+ operands.push_back(operand);
+ }
+
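+    // Lower the decomposition function referenced by the composite into its
+    // own XLA computation; the composite call below refers to it via to_apply.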
+ xla::XlaComputation computation;
+ if (failed(LowerBasicBlockAsFunction(
+ /*block=*/&module_
+ .lookupSymbol<mlir::func::FuncOp>(
+ composite_op.getDecomposition())
+ .getBody()
+ .front(),
+ /*builder=*/
+ module_builder_
+ .CreateSubBuilder(composite_op.getDecomposition().str())
+ .get(),
+ /*is_entry_function=*/false,
+ /*ensure_single_arg=*/false,
+ /*entry_args_same_across_replicas=*/{},
+ /*arg_shardings=*/{}, /*ret_shardings=*/{},
+ /*fe_attrs=*/{}, /*result=*/&computation,
+ /*implicit_operands=*/{}))) {
+ return failure();
+ }
+
+ std::string composite_attributes;
+ llvm::raw_string_ostream(composite_attributes)
+ << composite_op.getCompositeAttributes();
+
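+    // Emit the call to the decomposition, recording the composite name,
+    // attributes, and version as frontend attributes of the call.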
+ xla::XlaOp composite_call = xla::CompositeCall(
+ builder, computation, operands, composite_op.getName().str(),
+ composite_attributes, composite_op.getVersion());
+
+    // Use GetTupleElement to unpack each result when there are multiple outputs.
+ unsigned num_results = composite_op.getNumResults();
+ if (num_results > 1) {
+ for (unsigned i = 0; i != num_results; ++i) {
+ value_map[composite_op.getResult(i)] =
+ xla::GetTupleElement(composite_call, i);
+ }
+ } else if (num_results == 1) {
+ value_map[composite_op.getResult(0)] = composite_call;
+ }
+ *return_value = composite_call;
+
+ return success();
+ }
+
inst->emitOpError() << "can't be translated to XLA HLO";
return failure();
}
@@ -3318,7 +3466,7 @@
// Create a sub-builder if this is not the main function.
std::unique_ptr<xla::XlaBuilder> builder_up;
- bool entry_function = f.getName() == "main";
+ bool entry_function = f.getName() == kMain;
if (!entry_function)
builder_up = module_builder_.CreateSubBuilder(f.getName().str());
auto& builder = entry_function ? module_builder_ : *builder_up;
@@ -3332,14 +3480,14 @@
bool any_arg_replicated = false;
entry_args_same_across_replicas.reserve(f.getNumArguments());
for (int64_t i = 0; i < f.getNumArguments(); ++i) {
- auto attr = f.getArgAttrOfType<mlir::BoolAttr>(i, kReplicationAttr);
+ auto attr = f.getArgAttrOfType<mlir::BoolAttr>(i, kMhloReplication);
entry_args_same_across_replicas.push_back(attr != nullptr &&
attr.getValue());
any_arg_replicated |= entry_args_same_across_replicas.back();
// Pass the alias info to the builder so that it will build the alias info
// into the resulting HloModule.
auto buffer_donor =
- f.getArgAttrOfType<mlir::BoolAttr>(i, "jax.buffer_donor");
+ f.getArgAttrOfType<mlir::BoolAttr>(i, kJaxBufferDonor);
if (buffer_donor) {
if (options_.use_tuple_args) {
builder.AddBufferDonor(/*param_number=*/0, /*param_index=*/{i});
@@ -3348,7 +3496,7 @@
}
}
auto aliasing_output =
- f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
+ f.getArgAttrOfType<mlir::IntegerAttr>(i, kTfAliasingOutput);
if (!aliasing_output) continue;
xla::ShapeIndex output_index;
if ((options_.return_tuple && entry_function) || f.getNumResults() != 1) {
@@ -3383,13 +3531,13 @@
return failure();
}
if (auto execution_thread =
- f->getAttrOfType<mlir::StringAttr>("execution_thread")) {
+ f->getAttrOfType<mlir::StringAttr>(kExecutionThread)) {
computation.mutable_proto()->mutable_computations(0)->set_execution_thread(
execution_thread.str());
}
for (int i = 0; i < f.getNumArguments(); ++i) {
if (auto pr =
- f.getArgAttrOfType<mlir::ArrayAttr>(i, kParameterReplicationAttr)) {
+ f.getArgAttrOfType<mlir::ArrayAttr>(i, kMhloParameterReplication)) {
for (auto b : pr.getValue())
for (auto& instr : *computation.mutable_proto()
->mutable_computations(0)
@@ -3494,8 +3642,8 @@
llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
- xla::XlaComputation* result,
- std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands) {
+ xla::XlaComputation* result, llvm::ArrayRef<mlir::Value> implicit_operands,
+ llvm::ArrayRef<mlir::Value> implicit_results) {
// Mapping from the Value to lowered XlaOp.
ValueLoweringMap lowering;
@@ -3519,7 +3667,7 @@
// fuse all the `mlir::Location`s or join the operation name strings with
// ";" (which is essentially the same).
auto tuple =
- xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
+ xla::Parameter(builder, 0, input_shape, kArgTuple, leaf_replication);
builder->ClearSharding();
for (BlockArgument& arg : block->getArguments()) {
@@ -3533,17 +3681,16 @@
// Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp.
llvm::SmallVector<xla::Shape, 4> arg_shapes;
- auto args_size = block->getNumArguments();
- if (implicit_operands) args_size = implicit_operands->size();
+      // Lowering supports a mix of block arguments and implicit operands.
+      // Block arguments must be added before the implicit capture operands.
+
+ auto args_size = block->getNumArguments() + implicit_operands.size();
arg_shapes.reserve(args_size);
- if (implicit_operands) {
- for (auto implicit_operand : *implicit_operands)
- arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType()));
- } else {
- for (BlockArgument& arg : block->getArguments())
- arg_shapes.push_back(xla::TypeToShape(arg.getType()));
- }
+ for (BlockArgument& arg : block->getArguments())
+ arg_shapes.push_back(xla::TypeToShape(arg.getType()));
+ for (Value implicit_operand : implicit_operands)
+ arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType()));
if (args_size > 1) {
xla::XlaScopedShardingAssignment scoped_sharding(
@@ -3554,26 +3701,23 @@
// but not tuple params. Do the same for tuple params. To do so, either
// fuse all the `mlir::Location`s or join the operation name strings
// with ";" (which is essentially the same).
- auto tuple = xla::Parameter(builder, 0,
- xla::ShapeUtil::MakeTupleShape(arg_shapes),
- "arg_tuple");
+ auto tuple = xla::Parameter(
+ builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes), kArgTuple);
- if (implicit_operands) {
- for (auto [arg_index, implicit_operand] :
- llvm::enumerate(*implicit_operands)) {
- xla::XlaScopedShardingAssignment scoped_sharding(
- builder, arg_shardings.empty() ? std::nullopt
- : arg_shardings[arg_index]);
- lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index);
- }
- } else {
- for (BlockArgument& arg : block->getArguments()) {
- auto num = arg.getArgNumber();
- xla::XlaScopedShardingAssignment scoped_sharding(
- builder,
- arg_shardings.empty() ? std::nullopt : arg_shardings[num]);
- lowering[arg] = xla::GetTupleElement(tuple, num);
- }
+ for (BlockArgument& arg : block->getArguments()) {
+ auto num = arg.getArgNumber();
+ xla::XlaScopedShardingAssignment scoped_sharding(
+ builder,
+ arg_shardings.empty() ? std::nullopt : arg_shardings[num]);
+ lowering[arg] = xla::GetTupleElement(tuple, num);
+ }
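+      // Implicit captures are mapped to the trailing tuple elements, after the
+      // explicit block arguments.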
+ for (auto [implicit_index, implicit_operand] :
+ llvm::enumerate(implicit_operands)) {
+ int64_t arg_index = block->getNumArguments() + implicit_index;
+ xla::XlaScopedShardingAssignment scoped_sharding(
+ builder,
+ arg_shardings.empty() ? std::nullopt : arg_shardings[arg_index]);
+ lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index);
}
} else if (args_size == 1) {
// Save the location information as a name. For example JAX will set the
@@ -3581,23 +3725,17 @@
xla::XlaScopedShardingAssignment scoped_sharding(
builder,
arg_shardings.empty() ? std::nullopt : arg_shardings.front());
- if (implicit_operands) {
- mlir::Value arg = (*implicit_operands)[0];
- xla::XlaScopedOpMetadataAssignment op_metadata(
- builder, GetOpNameMetadataFromLocation(arg));
- lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
- } else {
- mlir::BlockArgument arg = block->getArgument(0);
- xla::XlaScopedOpMetadataAssignment op_metadata(
- builder, GetOpNameMetadataFromLocation(arg));
- lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
- }
+ mlir::Value arg = implicit_operands.empty() ? block->getArgument(0)
+ : implicit_operands.front();
+ xla::XlaScopedOpMetadataAssignment op_metadata(
+ builder, GetOpNameMetadataFromLocation(arg));
+ lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], kArgPrefix);
} else {
// Applicable only for IfOp or CaseOp. No implicit operands implies no
// xla parameters. In this case, we create an empty tuple as the
// block-parameter.
xla::Parameter(builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes),
- "arg_empty_tuple");
+ kArgEmptyTuple);
}
} else {
for (BlockArgument& arg : block->getArguments()) {
@@ -3616,11 +3754,11 @@
xla::XlaScopedOpMetadataAssignment op_metadata(
builder, GetOpNameMetadataFromLocation(arg));
if (entry_args_same_across_replicas.empty()) {
- lowering[arg] =
- xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
+ lowering[arg] = xla::Parameter(builder, num, shape,
+ absl::StrCat(kArgPrefix, num));
} else {
lowering[arg] = xla::Parameter(
- builder, num, shape, absl::StrCat("Arg_", num),
+ builder, num, shape, absl::StrCat(kArgPrefix, num),
std::vector<bool>(entry_args_same_across_replicas[num],
xla::ShapeUtil::GetLeafCount(shape)));
}
@@ -3631,8 +3769,8 @@
xla::XlaOp return_value;
for (auto& inst : *block)
- if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
- &lowering, &return_value)))
+ if (failed(Lower(&inst, is_entry_function, ret_shardings, implicit_results,
+ builder, &lowering, &return_value)))
return failure();
// Build the XlaComputation and check for failures.
@@ -3648,18 +3786,18 @@
LogicalResult ConvertToHloModule::LowerRegionAsComputation(
mlir::Region* region, xla::XlaComputation* func,
- std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands,
- bool ensure_single_arg,
+ llvm::ArrayRef<mlir::Value> implicit_operands,
+ llvm::ArrayRef<mlir::Value> implicit_results, bool ensure_single_arg,
llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings) {
- std::unique_ptr<xla::XlaBuilder> builder =
- module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
- return LowerBasicBlockAsFunction(®ion->front(), builder.get(),
- /*is_entry_function=*/false,
- /*ensure_single_arg*/ ensure_single_arg,
- /*entry_args_same_across_replicas=*/{},
- arg_shardings, ret_shardings,
- /*fe_attrs=*/{}, func, implicit_operands);
+ std::unique_ptr<xla::XlaBuilder> builder = module_builder_.CreateSubBuilder(
+ absl::StrCat(kRegionPrefix, region_id_++));
+ return LowerBasicBlockAsFunction(
+ ®ion->front(), builder.get(),
+ /*is_entry_function=*/false,
+ /*ensure_single_arg*/ ensure_single_arg,
+ /*entry_args_same_across_replicas=*/{}, arg_shardings, ret_shardings,
+ /*fe_attrs=*/{}, func, implicit_operands, implicit_results);
}
// Runs the PrepareForExport pass on the ModuleOp.
@@ -3704,47 +3842,46 @@
TF_RETURN_IF_ERROR(PrepareForExport(module));
mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext());
- xla::XlaBuilder module_builder("main");
+ xla::XlaBuilder module_builder(kMain);
ConvertToHloModule converter(module, module_builder, options);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
- StringRef module_name = module.getName() ? *module.getName() : "main";
+ StringRef module_name = module.getName() ? *module.getName() : kMain;
hlo_module.set_name(module_name.str());
- if (auto cross_program_prefetches = module->getAttrOfType<mlir::ArrayAttr>(
- "mhlo.cross_program_prefetches")) {
+ if (auto cross_program_prefetches =
+ module->getAttrOfType<mlir::ArrayAttr>(kMhloCrossProgramPrefetches)) {
for (const auto& prefetch :
Convert_cross_program_prefetches(cross_program_prefetches)) {
*hlo_module.add_cross_program_prefetches() = std::move(prefetch);
}
}
- if (auto is_dynamic =
- module->getAttrOfType<mlir::BoolAttr>("mhlo.is_dynamic")) {
+ if (auto is_dynamic = module->getAttrOfType<mlir::BoolAttr>(kMhloIsDynamic)) {
hlo_module.set_is_dynamic(is_dynamic.getValue());
}
if (auto frontend_attributes =
- module->getAttrOfType<DictionaryAttr>(kFrontendAttributesAttr)) {
+ module->getAttrOfType<DictionaryAttr>(kMhloFrontendAttributes)) {
ConstructFrontendAttributesFromAttribute(
frontend_attributes, *hlo_module.mutable_frontend_attributes());
}
- if (auto use_auto_spmd_partitioning = module->getAttrOfType<mlir::BoolAttr>(
- "mhlo.use_auto_spmd_partitioning")) {
+ if (auto use_auto_spmd_partitioning =
+ module->getAttrOfType<mlir::BoolAttr>(kMhloUseAutoSpmdPartitioning)) {
hlo_module.set_use_auto_spmd_partitioning(
use_auto_spmd_partitioning.getValue());
}
- if (auto spmd_output_sharding = module->getAttrOfType<mlir::StringAttr>(
- "mhlo.spmd_output_sharding")) {
+ if (auto spmd_output_sharding =
+ module->getAttrOfType<mlir::StringAttr>(kMhloSpmdOutputSharding)) {
*hlo_module.mutable_spmd_output_sharding() =
*xla::ConvertSharding(spmd_output_sharding.getValue());
}
if (auto input_output_alias =
- module->getAttrOfType<mlir::ArrayAttr>("mhlo.input_output_alias")) {
+ module->getAttrOfType<mlir::ArrayAttr>(kMhloInputOutputAlias)) {
if (std::optional<xla::HloInputOutputAliasProto> input_output_alias_proto =
xla::ConvertInputOutputAlias(input_output_alias.getValue())) {
*hlo_module.mutable_input_output_alias() = *input_output_alias_proto;
}
}
if (auto spmd_parameters_sharding = module->getAttrOfType<mlir::ArrayAttr>(
- "mhlo.spmd_parameters_shardings")) {
+ kMhloSpmdParametersShardings)) {
for (const auto& sharding : spmd_parameters_sharding.getValue()) {
*hlo_module.add_spmd_parameters_shardings() = *xla::ConvertSharding(
mlir::cast<mlir::StringAttr>(sharding).getValue());
@@ -3812,7 +3949,8 @@
} else {
xla::XlaOp return_value;
if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
- /*ret_shardings=*/{}, &builder, &lowering,
+ /*ret_shardings=*/{},
+ /*implicit_results=*/{}, &builder, &lowering,
&return_value)))
return diag_handler.ConsumeStatus();
}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc
index 88f0523..7dad7c3 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc
@@ -22,18 +22,20 @@
namespace mlir {
namespace mhlo {
namespace {
-constexpr char kConfigNumPartitions[] = "mhlo.num_partitions";
-constexpr char kConfigNumReplicas[] = "mhlo.num_replicas";
+
+constexpr char kMhloNumPartitions[] = "mhlo.num_partitions";
+constexpr char kMhloNumReplicas[] = "mhlo.num_replicas";
+
} // namespace
void ExportHloModuleConfig(xla::HloModuleConfig& config,
mlir::ModuleOp module) {
if (auto num_partitions =
- module->getAttrOfType<mlir::IntegerAttr>(kConfigNumPartitions)) {
+ module->getAttrOfType<mlir::IntegerAttr>(kMhloNumPartitions)) {
config.set_num_partitions(num_partitions.getInt());
}
if (auto num_replicas =
- module->getAttrOfType<mlir::IntegerAttr>(kConfigNumReplicas)) {
+ module->getAttrOfType<mlir::IntegerAttr>(kMhloNumReplicas)) {
config.set_replica_count(num_replicas.getInt());
}
}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD
index 37354ba..f947307 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD
@@ -12,9 +12,11 @@
[
"add.mlir",
"case.mlir",
+ "composite.mlir",
"dynamic.mlir",
"export-with-layouts.mlir",
"export.mlir",
+ "export_async.mlir",
"export_and_check_layouts.mlir",
"export_large_constants.mlir",
"export_replicas.mlir",
@@ -36,6 +38,7 @@
"simple.mlir",
"unsupported_type.mlir",
"while.mlir",
+ "while_free_vars.mlir",
],
include = [
"*.mlir",
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir
new file mode 100644
index 0000000..60c5548
--- /dev/null
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir
@@ -0,0 +1,190 @@
+// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+
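+// basic composite with attributes and version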
+module @composite {
+ // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+ // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+ // CHECK: %Arg_0.3 = f32[] parameter(0)
+ // CHECK: %constant.4 = f32[] constant(2)
+ // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+ // CHECK: }
+ // CHECK: ENTRY %main.7 () -> f32[] {
+ // CHECK: %constant.1 = f32[] constant(42)
+ // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+ // CHECK: }
+ func.func @main() -> tensor<f32> {
+ %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+ %1 = mhlo.composite "foo.bar" %0 {
+ composite_attributes = {
+ n = 1 : i32,
+ tensor = dense<1> : tensor<i32>
+ },
+ decomposition = @add,
+ version = 1 : i32
+ } : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+ }
+ func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+ %1 = mhlo.add %arg0, %0 : tensor<f32>
+ return %1 : tensor<f32>
+ }
+}
+
+// -----
+
+// zero-output composite
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->()}
+  // CHECK: %return.2 (Arg_0.3: f32[]) -> () {
+  // CHECK: %Arg_0.3 = f32[] parameter(0)
+  // CHECK: ROOT %tuple.4 = () tuple()
+  // CHECK: }
+  // CHECK: ENTRY %main.7 () -> () {
+  // CHECK: %constant.1 = f32[] constant(42)
+  // CHECK: %call.5 = () call(f32[] %constant.1), to_apply=%return.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+  // CHECK: ROOT %tuple.6 = () tuple()
+  // CHECK: }
+ func.func @main() -> () {
+ %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+ "mhlo.composite"(%0) {
+ name = "foo.bar",
+ composite_attributes = {
+ n = 1 : i32,
+ tensor = dense<1> : tensor<i32>
+ },
+ decomposition = @return,
+ version = 1 : i32
+ } : (tensor<f32>) -> ()
+ return
+ }
+ func.func @return(%arg0: tensor<f32>) -> () {
+ return
+ }
+}
+
+// -----
+
+// multi-output composite
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->(f32[], f32[])}
+  // CHECK: %add.2 (Arg_0.3: f32[]) -> (f32[], f32[]) {
+  // CHECK: %Arg_0.3 = f32[] parameter(0)
+  // CHECK: %constant.4 = f32[] constant(2)
+  // CHECK: %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+  // CHECK: ROOT %tuple.6 = (f32[], f32[]) tuple(f32[] %add.5, f32[] %add.5)
+  // CHECK: }
+  // CHECK: ENTRY %main.11 () -> (f32[], f32[]) {
+  // CHECK: %constant.1 = f32[] constant(42)
+  // CHECK: %call.7 = (f32[], f32[]) call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+  // CHECK: %get-tuple-element.8 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=0
+  // CHECK: %get-tuple-element.9 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=1
+  // CHECK: ROOT %tuple.10 = (f32[], f32[]) tuple(f32[] %get-tuple-element.8, f32[] %get-tuple-element.9)
+  // CHECK: }
+ func.func @main() -> (tensor<f32>, tensor<f32>) {
+ %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+ %result:2 = "mhlo.composite"(%0) {
+ name = "foo.bar",
+ composite_attributes = {
+ n = 1 : i32,
+ tensor = dense<1> : tensor<i32>
+ },
+ decomposition = @add,
+ version = 1 : i32
+ } : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
+ return %result#0, %result#1 : tensor<f32>, tensor<f32>
+ }
+ func.func @add(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
+ %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+ %1 = mhlo.add %arg0, %0 : tensor<f32>
+ return %1, %1 : tensor<f32>, tensor<f32>
+ }
+}
+
+// -----
+
+// optional composite attributes
+module @composite {
+ // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+ // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+ // CHECK: %Arg_0.3 = f32[] parameter(0)
+ // CHECK: %constant.4 = f32[] constant(2)
+ // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+ // CHECK: }
+ // CHECK: ENTRY %main.7 () -> f32[] {
+ // CHECK: %constant.1 = f32[] constant(42)
+ // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"}
+ // CHECK: }
+ func.func @main() -> tensor<f32> {
+ %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+ %1 = mhlo.composite "foo.bar" %0 {
+ decomposition = @add,
+ version = 1 : i32
+ } : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+ }
+ func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+ %1 = mhlo.add %arg0, %0 : tensor<f32>
+ return %1 : tensor<f32>
+ }
+}
+
+// -----
+
+// optional composite version
+module @composite {
+ // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+ // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+ // CHECK: %Arg_0.3 = f32[] parameter(0)
+ // CHECK: %constant.4 = f32[] constant(2)
+ // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+ // CHECK: }
+ // CHECK: ENTRY %main.7 () -> f32[] {
+ // CHECK: %constant.1 = f32[] constant(42)
+ // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="0"}
+ // CHECK: }
+ func.func @main() -> tensor<f32> {
+ %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+ %1 = mhlo.composite "foo.bar" %0 {
+ composite_attributes = {
+ n = 1 : i32,
+ tensor = dense<1> : tensor<i32>
+ },
+ decomposition = @add
+ } : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+ }
+ func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+ %1 = mhlo.add %arg0, %0 : tensor<f32>
+ return %1 : tensor<f32>
+ }
+}
+
+// -----
+
+// optional composite attributes and version
+module @composite {
+ // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+ // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+ // CHECK: %Arg_0.3 = f32[] parameter(0)
+ // CHECK: %constant.4 = f32[] constant(2)
+ // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+ // CHECK: }
+ // CHECK: ENTRY %main.7 () -> f32[] {
+ // CHECK: %constant.1 = f32[] constant(42)
+ // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"}
+ // CHECK: }
+ func.func @main() -> tensor<f32> {
+ %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+ %1 = mhlo.composite "foo.bar" %0 {
+ decomposition = @add
+ } : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+ }
+ func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+ %1 = mhlo.add %arg0, %0 : tensor<f32>
+ return %1 : tensor<f32>
+ }
+}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir
index 680341f..3d44aff 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir
@@ -1,5 +1,5 @@
// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts %s | FileCheck %s
-// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts --via-builder=true %s | FileCheck %s
+// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts --via-builder=true %s | FileCheck %s
#CSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : dense, d1 : compressed),
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir
index dec3e5d..6672e62 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir
@@ -1,6 +1,16 @@
// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s
// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics --via-builder=true %s | FileCheck %s
+// CHECK: HloModule foo
+// CHECK: ENTRY %main
+module @foo {
+ func.func @main(%arg: tensor<i1>) -> tensor<i1> {
+ func.return %arg : tensor<i1>
+ }
+}
+
+// -----
+
// CHECK: HloModule
func.func @main(%arg0: tensor<2xi1>) -> tensor<2xi1> {
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
@@ -109,114 +119,6 @@
// -----
-// CHECK: HloModule
-func.func @all_gather_0(%arg1: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
- %0 = "mhlo.all_gather"(%arg1) {
- all_gather_dim = 1 : i64,
- channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
- shard_count = 4,
- replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
- use_global_device_ids
- } : (tensor<128x32xf32>) -> tensor<128x128xf32>
- return %0 : tensor<128x128xf32>
-}
-
-func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
- %0 = "mhlo.async_start"(%arg0) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>
- %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>) -> tensor<128x128xf32>
- return %1 : tensor<128x128xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(f32[128,32] %[[INPUT]])
-// CHECK-SAME: channel_id=1
-// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}}
-// CHECK-SAME: dimensions={1}
-// CHECK-SAME: use_global_device_ids=true
-// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(f32[128,128] %[[OUTPUT]]
-
-// -----
-
-// CHECK: HloModule
-func.func @all_reduce_0(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {execution_thread = "main"} {
- %0 = "mhlo.all_reduce"(%arg0) ({
- ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
- %max = mhlo.maximum %lhs, %rhs : tensor<f32>
- "mhlo.return"(%max) : (tensor<f32>) -> ()
- })
- {
- replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
- channel_handle = #mhlo.channel_handle<
- handle = 5,
- type = 2
- >,
- use_global_device_ids
- } : (tensor<10xf32>) -> tensor<10xf32>
- func.return %0 : tensor<10xf32>
-}
-
-func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
- %0 = "mhlo.async_start"(%arg0) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>) -> !mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>
- %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>) -> tensor<10xf32>
- return %1 : tensor<10xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[10] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(f32[10] %[[INPUT]])
-// CHECK-SAME: channel_id=5
-// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}}
-// CHECK-SAME: use_global_device_ids=true
-// CHECK: ROOT {{.*}} f32[10] all-reduce-done(f32[10] %[[OUTPUT]]
-
-// -----
-
-// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
-func.func @all_reduce_0(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) attributes {execution_thread = "main"} {
- %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({
- ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
- %max = mhlo.maximum %lhs, %rhs : tensor<f32>
- "mhlo.return"(%max) : (tensor<f32>) -> ()
- })
- {
- replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
- channel_handle = #mhlo.channel_handle<
- handle = 5,
- type = 2
- >,
- use_global_device_ids
- } : (tensor<10xf32>, tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>)
- func.return %0#0, %0#1 : tensor<10xf32>, tensor<1xf32>
-}
-
-func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) {
- %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>, tensor<1xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>
- %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>) -> (tensor<10xf32>, tensor<1xf32>)
- return %1#0, %1#1 : tensor<10xf32>, tensor<1xf32>
-}
-
-// -----
-
-// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
-func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} {
- %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
- all_gather_dim = 1 : i64,
- channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
- replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
- use_global_device_ids
- } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
- func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32>
-}
-
-func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) {
- %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>
- %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
- return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32>
-}
-
-// -----
-
func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple<tensor<8x8xf32>, tensor<8x16xf32>> {
// CHECK: %[[ARG0:.*]] = f32[8,2] parameter(0)
// CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1)
@@ -625,30 +527,6 @@
// -----
// CHECK: HloModule
-func.func @collective_permute_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
- %0 = "mhlo.collective_permute"(%arg0) {
- source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
- channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
- } : (tensor<128x32xf32>) -> tensor<128x32xf32>
- func.return %0 : tensor<128x32xf32>
-}
-
-func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
- %0 = "mhlo.async_start"(%arg0) {called_computation = @collective_permute_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
- %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
- return %1 : tensor<128x32xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(f32[128,32] %[[INPUT]])
-// CHECK-SAME: channel_id=1
-// CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}}
-// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(f32[128,32] %[[OUTPUT]]
-
-// -----
-
-// CHECK: HloModule
func.func @main(%arg0 : tensor<5x2xf32>,
%arg1 : tensor<5x5xf32>,
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
@@ -893,27 +771,6 @@
// -----
// CHECK: HloModule
-func.func @copy_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
- %0 = "mhlo.copy"(%arg0) {cross_program_prefetch_index = 0 : i32} : (tensor<128x32xf32>) -> tensor<128x32xf32>
- func.return %0 : tensor<128x32xf32>
-}
-
-func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
- %0 = "mhlo.async_start"(%arg0) {called_computation = @copy_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
- %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
- return %1 : tensor<128x32xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %[[INPUT]])
-// CHECK-SAME: cross_program_prefetch_index=0
-// CHECK: ROOT {{.*}} f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %[[OUTPUT]]
-
-
-// -----
-
-// CHECK: HloModule
func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = mhlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
%1 = "mhlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
@@ -2131,67 +1988,6 @@
// -----
// CHECK: HloModule
-
-func.func @recv_0(%token: !mhlo.token) -> (!mhlo.token) attributes {execution_thread = "main"} {
- %0 = "mhlo.recv"(%token) {
- channel_handle = #mhlo.channel_handle<
- handle = 5,
- type = 1 // Device to device channel
- >,
- is_host_transfer = false
- } : (!mhlo.token) -> (!mhlo.token)
- func.return %0 : !mhlo.token
-}
-
-func.func @main(%token: !mhlo.token) -> (!mhlo.token) {
- %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
- %2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> (!mhlo.token)
- return %2 : !mhlo.token
-}
-
-// CHECK: ENTRY
-// CHECK: [[TOKEN:%.*]] = token[] parameter(0)
-// CHECK: [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5
-// CHECK: ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5
-
-// -----
-
-// CHECK: HloModule
-func.func @recv_0(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) attributes {execution_thread = "main"} {
- %0:2 = "mhlo.recv"(%token) {
- channel_handle = #mhlo.channel_handle<
- handle = 5,
- type = 3 // Host to device channel
- >,
- is_host_transfer = true
- } : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token)
- func.return %0#0, %0#1 : tensor<3x4xi32>, !mhlo.token
-}
-
-func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) {
- %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main", mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>
- %1, %2 = "mhlo.async_done"(%0) {mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>) -> (tensor<3x4xi32>, !mhlo.token)
- return %1, %2 : tensor<3x4xi32>, !mhlo.token
-}
-
-// CHECK: ENTRY
-// CHECK: [[TOKEN:%.*]] = token[] parameter(0)
-// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer
-// CHECK-SAME: sharding={
-// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0}
-// CHECK-SAME: }
-// CHECK: [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer
-// CHECK-SAME: sharding={
-// CHECK-SAME: {maximal device=0}, {maximal device=0}
-// CHECK-SAME: }
-// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0}
-// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0}
-// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])
-
-// -----
-
-
-// CHECK: HloModule
func.func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
%result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>):
@@ -2466,58 +2262,6 @@
// -----
// CHECK: HloModule
-func.func @send_0(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
- %0 = "mhlo.send"(%arg, %token) {
- channel_handle = #mhlo.channel_handle<
- handle = 5,
- type = 2 // Device to host channel
- >,
- is_host_transfer = true
- } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
- func.return %0 : !mhlo.token
-}
-
-func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
- %0 = "mhlo.async_start"(%arg, %token) {called_computation = @send_0, execution_thread = "main"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>
- %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>) -> !mhlo.token
- return %1 : !mhlo.token
-}
-
-// CHECK: ENTRY
-// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0)
-// CHECK: [[TOKEN:%.*]] = token[] parameter(1)
-// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer
-// CHECK: ROOT
-// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5
-
-// -----
-
-// CHECK: HloModule
-func.func @send_0(%token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
- %0 = "mhlo.send"(%token) {
- channel_handle = #mhlo.channel_handle<
- handle = 5,
- type = 1 // Device to device channel
- >
- } : (!mhlo.token) -> !mhlo.token
- func.return %0 : !mhlo.token
-}
-
-func.func @main(%token: !mhlo.token) -> !mhlo.token {
- %0 = "mhlo.async_start"(%token) {called_computation = @send_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
- %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> !mhlo.token
- return %1 : !mhlo.token
-}
-
-// CHECK: ENTRY
-// CHECK: [[TOKEN:%.*]] = token[] parameter(0)
-// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send(() [[UNIT:%.*]], token[] [[TOKEN]]), channel_id=5
-// CHECK: ROOT
-// CHECK-SAME: token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5
-
-// -----
-
-// CHECK: HloModule
func.func @main(%arg: tensor<4x4xf32>, %size: tensor<i32>) -> tensor<4x4xf32> {
%0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i64} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
func.return %0 : tensor<4x4xf32>
@@ -2935,55 +2679,6 @@
window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x2xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
func.return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
}
-// -----
-
-// CHECK: HloModule
-// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
-func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
- attributes {execution_thread = "thread"} {
- %0 = "mhlo.custom_call"(%arg0) {call_target_name = "foo"} : (tensor<10xf32>) -> tensor<20xf32>
- return %0 : tensor<20xf32>
-}
-
-// CHECK: ENTRY
-func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
- // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
- // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]])
- // CHECK-SAME: calls=[[CALLED_COMPUTATION]]
- %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
- // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
- %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
- // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
- %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
- return %2 : tensor<20xf32>
-}
-
-// -----
-
-// CHECK: HloModule
-// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
-func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
- attributes {execution_thread = "thread"} {
- %1 = "mhlo.custom_call"(%arg0) {call_target_name = "bar"} : (tensor<10xf32>) -> tensor<20xf32>
- // CHECK: custom-call
- // CHECK-SAME: custom_call_target="bar"
- return %1 : tensor<20xf32>
-}
-
-// CHECK: ENTRY
-func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
- // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
- // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]],
- // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
- // CHECK: ROOT
- // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
-
- %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
- %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
- %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
- return %2 : tensor<20xf32>
-}
-
// -----
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir
new file mode 100644
index 0000000..70bf10c
--- /dev/null
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir
@@ -0,0 +1,312 @@
+// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s
+
+// CHECK: HloModule
+func.func @all_gather_0(%arg1: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
+ %0 = "mhlo.all_gather"(%arg1) {
+ all_gather_dim = 1 : i64,
+ channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+ shard_count = 4,
+ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+ use_global_device_ids
+ } : (tensor<128x32xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
+ %0 = "mhlo.async_start"(%arg0) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>
+ %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>) -> tensor<128x128xf32>
+ return %1 : tensor<128x128xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(f32[128,32] %[[INPUT]])
+// CHECK-SAME: channel_id=1
+// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}}
+// CHECK-SAME: dimensions={1}
+// CHECK-SAME: use_global_device_ids=true
+// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(f32[128,128] %[[OUTPUT]]
+
+// -----
+
+// CHECK: HloModule
+func.func @all_reduce_0(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {execution_thread = "main"} {
+ %0 = "mhlo.all_reduce"(%arg0) ({
+ ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+ %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+ "mhlo.return"(%max) : (tensor<f32>) -> ()
+ })
+ {
+ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 2
+ >,
+ use_global_device_ids
+ } : (tensor<10xf32>) -> tensor<10xf32>
+ func.return %0 : tensor<10xf32>
+}
+
+func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+ %0 = "mhlo.async_start"(%arg0) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>) -> !mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>
+ %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>) -> tensor<10xf32>
+ return %1 : tensor<10xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[10] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(f32[10] %[[INPUT]])
+// CHECK-SAME: channel_id=5
+// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}}
+// CHECK-SAME: use_global_device_ids=true
+// CHECK: ROOT {{.*}} f32[10] all-reduce-done(f32[10] %[[OUTPUT]]
+
+// -----
+
+// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
+func.func @all_reduce_0(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) attributes {execution_thread = "main"} {
+ %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({
+ ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+ %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+ "mhlo.return"(%max) : (tensor<f32>) -> ()
+ })
+ {
+ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 2
+ >,
+ use_global_device_ids
+ } : (tensor<10xf32>, tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>)
+ func.return %0#0, %0#1 : tensor<10xf32>, tensor<1xf32>
+}
+
+func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) {
+ %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>, tensor<1xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>
+ %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>) -> (tensor<10xf32>, tensor<1xf32>)
+ return %1#0, %1#1 : tensor<10xf32>, tensor<1xf32>
+}
+
+// -----
+
+// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
+func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} {
+ %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
+ all_gather_dim = 1 : i64,
+ channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+ use_global_device_ids
+ } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
+ func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32>
+}
+
+func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) {
+ %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>
+ %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
+ return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32>
+}
+
+// -----
+
+// CHECK: HloModule
+func.func @collective_permute_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+ %0 = "mhlo.collective_permute"(%arg0) {
+ source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
+ channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+ } : (tensor<128x32xf32>) -> tensor<128x32xf32>
+ func.return %0 : tensor<128x32xf32>
+}
+
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
+ %0 = "mhlo.async_start"(%arg0) {called_computation = @collective_permute_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
+ %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
+ return %1 : tensor<128x32xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(f32[128,32] %[[INPUT]])
+// CHECK-SAME: channel_id=1
+// CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}}
+// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(f32[128,32] %[[OUTPUT]]
+
+// -----
+
+// CHECK: HloModule
+func.func @copy_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+ %0 = "mhlo.copy"(%arg0) {cross_program_prefetch_index = 0 : i32} : (tensor<128x32xf32>) -> tensor<128x32xf32>
+ func.return %0 : tensor<128x32xf32>
+}
+
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
+ %0 = "mhlo.async_start"(%arg0) {called_computation = @copy_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
+ %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
+ return %1 : tensor<128x32xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %[[INPUT]])
+// CHECK-SAME: cross_program_prefetch_index=0
+// CHECK: ROOT {{.*}} f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %[[OUTPUT]]
+
+// -----
+
+// CHECK: HloModule
+
+func.func @recv_0(%token: !mhlo.token) -> (!mhlo.token) attributes {execution_thread = "main"} {
+ %0 = "mhlo.recv"(%token) {
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 1 // Device to device channel
+ >,
+ is_host_transfer = false
+ } : (!mhlo.token) -> (!mhlo.token)
+ func.return %0 : !mhlo.token
+}
+
+func.func @main(%token: !mhlo.token) -> (!mhlo.token) {
+ %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
+ %2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> (!mhlo.token)
+ return %2 : !mhlo.token
+}
+
+// CHECK: ENTRY
+// CHECK: [[TOKEN:%.*]] = token[] parameter(0)
+// CHECK: [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5
+// CHECK: ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5
+
+// -----
+
+// CHECK: HloModule
+func.func @recv_0(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) attributes {execution_thread = "main"} {
+ %0:2 = "mhlo.recv"(%token) {
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 3 // Host to device channel
+ >,
+ is_host_transfer = true
+ } : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token)
+ func.return %0#0, %0#1 : tensor<3x4xi32>, !mhlo.token
+}
+
+func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) {
+ %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main", mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>
+ %1, %2 = "mhlo.async_done"(%0) {mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>) -> (tensor<3x4xi32>, !mhlo.token)
+ return %1, %2 : tensor<3x4xi32>, !mhlo.token
+}
+
+// CHECK: ENTRY
+// CHECK: [[TOKEN:%.*]] = token[] parameter(0)
+// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer
+// CHECK-SAME: sharding={
+// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0}
+// CHECK-SAME: }
+// CHECK: [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer
+// CHECK-SAME: sharding={
+// CHECK-SAME: {maximal device=0}, {maximal device=0}
+// CHECK-SAME: }
+// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0}
+// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0}
+// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])
+
+// -----
+
+// CHECK: HloModule
+func.func @send_0(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
+ %0 = "mhlo.send"(%arg, %token) {
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 2 // Device to host channel
+ >,
+ is_host_transfer = true
+ } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
+ func.return %0 : !mhlo.token
+}
+
+func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
+ %0 = "mhlo.async_start"(%arg, %token) {called_computation = @send_0, execution_thread = "main"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>
+ %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>) -> !mhlo.token
+ return %1 : !mhlo.token
+}
+
+// CHECK: ENTRY
+// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0)
+// CHECK: [[TOKEN:%.*]] = token[] parameter(1)
+// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer
+// CHECK: ROOT
+// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5
+
+// -----
+
+// CHECK: HloModule
+func.func @send_0(%token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
+ %0 = "mhlo.send"(%token) {
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 1 // Device to device channel
+ >
+ } : (!mhlo.token) -> !mhlo.token
+ func.return %0 : !mhlo.token
+}
+
+func.func @main(%token: !mhlo.token) -> !mhlo.token {
+ %0 = "mhlo.async_start"(%token) {called_computation = @send_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
+ %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> !mhlo.token
+ return %1 : !mhlo.token
+}
+
+// CHECK: ENTRY
+// CHECK: [[TOKEN:%.*]] = token[] parameter(0)
+// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send(() [[UNIT:%.*]], token[] [[TOKEN]]), channel_id=5
+// CHECK: ROOT
+// CHECK-SAME: token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5
+
+// -----
+
+// CHECK: HloModule
+// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
+func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
+ attributes {execution_thread = "thread"} {
+ %0 = "mhlo.custom_call"(%arg0) {call_target_name = "foo"} : (tensor<10xf32>) -> tensor<20xf32>
+ return %0 : tensor<20xf32>
+}
+
+// CHECK: ENTRY
+func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
+ // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
+ // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]])
+ // CHECK-SAME: calls=[[CALLED_COMPUTATION]]
+ %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+ // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
+ %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+ // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
+ %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
+ return %2 : tensor<20xf32>
+}
+
+// -----
+
+// CHECK: HloModule
+// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
+func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
+ attributes {execution_thread = "thread"} {
+ %1 = "mhlo.custom_call"(%arg0) {call_target_name = "bar"} : (tensor<10xf32>) -> tensor<20xf32>
+ // CHECK: custom-call
+ // CHECK-SAME: custom_call_target="bar"
+ return %1 : tensor<20xf32>
+}
+
+// CHECK: ENTRY
+func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
+ // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
+ // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]],
+ // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
+ // CHECK: ROOT
+ // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
+
+ %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+ %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+ %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
+ return %2 : tensor<20xf32>
+}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir
new file mode 100644
index 0000000..3663f92
--- /dev/null
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir
@@ -0,0 +1,89 @@
+// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s -o - | FileCheck %s
+
+// This test verifies that the correct shardings are added when a while loop
+// has free variables.
+
+// CHECK-LABEL: HloModule main
+
+// CHECK: %region_0.7 (arg_tuple.8: (s32[], f32[4], s32[], s32[], f32[4])) -> (s32[], f32[4], s32[], s32[], f32[4]) {
+// CHECK-NEXT: %arg_tuple.8 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0)
+// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK-DAG: %get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=3
+// CHECK-DAG: %get-tuple-element.13 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=4, sharding={devices=[4]<=[4]}
+// CHECK-DAG: %add.14 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.12)
+// CHECK-DAG: %add.15 = f32[4] add(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.13)
+// CHECK: ROOT %tuple.16 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %add.14, f32[4] %add.15, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[4] %get-tuple-element.13)
+// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+
+// CHECK: %region_1.17 (arg_tuple.18: (s32[], f32[4], s32[], s32[], f32[4])) -> pred[] {
+// CHECK-NEXT: %arg_tuple.18 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0)
+// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK: %get-tuple-element.21 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.18), index=2
+// CHECK-NEXT: ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.19, s32[] %get-tuple-element.21), direction=LT
+
+// CHECK: ENTRY %main.28 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] {
+// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0)
+// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)
+// CHECK-NEXT: %constant.4 = s32[] constant(0)
+// CHECK-NEXT: %constant.5 = s32[] constant(1)
+// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
+// CHECK-NEXT: %tuple.6 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %constant.4, s32[] %constant.5, f32[4] %Arg_2.3)
+// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %while.25 = (s32[], f32[4], s32[], s32[], f32[4]) while((s32[], f32[4], s32[], s32[], f32[4]) %tuple.6), condition=%region_1.17, body=%region_0.7
+// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %get-tuple-element.26 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=0, sharding={replicated}
+// CHECK-NEXT: ROOT %get-tuple-element.27 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+
+func.func @main(%arg0: tensor<i32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> tensor<4xf32> {
+ %0 = mhlo.constant dense<0> : tensor<i32>
+ %1 = mhlo.constant dense<1> : tensor<i32>
+ %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor<i32>, tensor<4xf32>
+ attributes {mhlo.sharding = "{{replicated},{devices=[2,2]<=[4] last_tile_dim_replicate}}"}
+ cond {
+ %3 = mhlo.compare LT, %iterArg, %0 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ mhlo.return %3 : tensor<i1>
+ } do {
+ %3 = mhlo.add %iterArg, %1 : tensor<i32>
+ %4 = mhlo.add %iterArg_0, %arg2 : tensor<4xf32>
+ mhlo.return %3, %4: tensor<i32>, tensor<4xf32>
+ }
+ func.return %2#1 : tensor<4xf32>
+}
+
+// -----
+
+// This test verifies that a value captured multiple times is only lifted once
+// and all its uses are replaced. It also verifies that no sharding is added to
+// region parameters or to the root when the while doesn't have a sharding.
+
+// CHECK-LABEL: HloModule main
+
+// CHECK: %region_0.5 (arg_tuple.6: (s32[], f32[4], s32[])) -> (s32[], f32[4], s32[]) {
+// CHECK-NEXT: %arg_tuple.6 = (s32[], f32[4], s32[]) parameter(0)
+// CHECK: %get-tuple-element.9 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.6), index=2
+// CHECK: %add.10 = s32[] add(s32[] %get-tuple-element.7, s32[] %get-tuple-element.9)
+// CHECK: ROOT %tuple.11 = (s32[], f32[4], s32[]) tuple(s32[] %add.10, f32[4] %get-tuple-element.8, s32[] %get-tuple-element.9)
+
+// CHECK: %region_1.12 (arg_tuple.13: (s32[], f32[4], s32[])) -> pred[] {
+// CHECK-NEXT: %arg_tuple.13 = (s32[], f32[4], s32[]) parameter(0)
+// CHECK: %get-tuple-element.16 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.13), index=2
+// CHECK: ROOT %compare.17 = pred[] compare(s32[] %get-tuple-element.14, s32[] %get-tuple-element.16), direction=LT
+
+// CHECK: ENTRY %main.21 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: s32[]) -> f32[4] {
+// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0)
+// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)
+// CHECK-NEXT: %Arg_2.3 = s32[] parameter(2)
+// CHECK-NEXT: %tuple.4 = (s32[], f32[4], s32[]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %Arg_2.3)
+// CHECK-NEXT: %while.18 = (s32[], f32[4], s32[]) while((s32[], f32[4], s32[]) %tuple.4), condition=%region_1.12, body=%region_0.5
+
+func.func @main(%arg0: tensor<i32>, %arg1: tensor<4xf32>, %arg2: tensor<i32>) -> tensor<4xf32> {
+ %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor<i32>, tensor<4xf32>
+ cond {
+ %3 = mhlo.compare LT, %iterArg, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ mhlo.return %3 : tensor<i1>
+ } do {
+ %3 = mhlo.add %iterArg, %arg2 : tensor<i32>
+ mhlo.return %3, %iterArg_0: tensor<i32>, tensor<4xf32>
+ }
+ func.return %2#1 : tensor<4xf32>
+}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc
index 8cff1c9..7c07582 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc
@@ -16,26 +16,42 @@
#include <memory>
#include <utility>
+#include <vector>
#include "absl/log/log.h"
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/Value.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_input_output_alias_config.h"
+#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/mlir_hlo/mhlo/IR/register.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_proto_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication";
@@ -123,6 +139,8 @@
mlir::cast<mlir::BoolAttr>(b).getValue());
auto hlo_module = computation.proto();
+ mlir::StringRef module_name = module.getName() ? *module.getName() : "main";
+ hlo_module.set_name(module_name.str());
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
return absl::OkStatus();
diff --git a/third_party/xla/xla/tsl/c/tsl_status.cc b/third_party/xla/xla/tsl/c/tsl_status.cc
index fea8943..75b9481 100644
--- a/third_party/xla/xla/tsl/c/tsl_status.cc
+++ b/third_party/xla/xla/tsl/c/tsl_status.cc
@@ -35,7 +35,7 @@
return;
}
s->status =
- Status(static_cast<absl::StatusCode>(code), tsl::StringPiece(msg));
+ Status(static_cast<absl::StatusCode>(code), absl::string_view(msg));
}
void TSL_SetPayload(TSL_Status* s, const char* key, const char* value) {
diff --git a/third_party/xla/xla/tsl/concurrency/BUILD b/third_party/xla/xla/tsl/concurrency/BUILD
index 0363d15..578b6ce 100644
--- a/third_party/xla/xla/tsl/concurrency/BUILD
+++ b/third_party/xla/xla/tsl/concurrency/BUILD
@@ -31,7 +31,6 @@
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:logging",
- "@local_tsl//tsl/platform:platform_port",
],
)
@@ -72,6 +71,7 @@
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:test",
+ "@local_tsl//tsl/platform:test_benchmark",
"@local_tsl//tsl/platform:test_main",
],
)
diff --git a/third_party/xla/xla/tsl/concurrency/async_value.cc b/third_party/xla/xla/tsl/concurrency/async_value.cc
index dd26e04..fa3f058 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value.cc
+++ b/third_party/xla/xla/tsl/concurrency/async_value.cc
@@ -63,12 +63,6 @@
std::atomic<size_t> AsyncValue::total_allocated_async_values_;
-const AsyncValue::TypeInfo& AsyncValue::GetTypeInfo() const {
- TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton();
- DCHECK_NE(type_id_, 0);
- return (*type_info_table)[type_id_ - 1];
-}
-
// This is called when the value is set into the ConcreteAsyncValue buffer, or
// when the IndirectAsyncValue is forwarded to an available AsyncValue, and we
// need to change our state and clear out the notifications. The current state
diff --git a/third_party/xla/xla/tsl/concurrency/async_value.h b/third_party/xla/xla/tsl/concurrency/async_value.h
index 372db4f..30e0d8e 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value.h
+++ b/third_party/xla/xla/tsl/concurrency/async_value.h
@@ -21,17 +21,18 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
-#include <iostream>
#include <memory>
+#include <new>
#include <type_traits>
#include <utility>
+#include "absl/base/optimization.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "xla/tsl/concurrency/concurrent_vector.h"
#include "xla/tsl/concurrency/ref_count.h"
-#include "tsl/platform/mem.h"
+#include "tsl/platform/logging.h"
namespace tsl {
@@ -164,8 +165,8 @@
// process. This is intended for debugging/assertions only, and shouldn't be
// used for mainline logic in the runtime.
static size_t GetNumAsyncValueInstances() {
- assert(AsyncValueAllocationTrackingEnabled() &&
- "AsyncValue instance tracking disabled!");
+ DCHECK(AsyncValueAllocationTrackingEnabled())
+ << "AsyncValue instance tracking disabled!";
return total_allocated_async_values_.load(std::memory_order_relaxed);
}
@@ -418,8 +419,9 @@
private:
// Information about a ConcreteAsyncValue<T> subclass.
struct TypeInfo {
- // Destructor returns the size of the derived AsyncValue to be deallocated.
- using DestructorFn = size_t (*)(AsyncValue*);
+ // Destructor returns the size and alignment of the derived AsyncValue to
+ // be deallocated.
+ using DestructorFn = std::pair<size_t, std::align_val_t> (*)(AsyncValue*);
using GetErrorFn = const absl::Status& (*)(const AsyncValue*);
using SetErrorFn = void (*)(AsyncValue*, absl::Status);
using HasDataFn = bool (*)(const AsyncValue*);
@@ -433,9 +435,9 @@
template <typename Derived>
static TypeInfo MakeTypeInfo() {
return TypeInfo{
- [](AsyncValue* v) {
+ [](AsyncValue* v) -> std::pair<size_t, std::align_val_t> {
static_cast<Derived*>(v)->~Derived();
- return sizeof(Derived);
+ return {sizeof(Derived), std::align_val_t{alignof(Derived)}};
},
[](const AsyncValue* v) -> const absl::Status& {
return static_cast<const Derived*>(v)->GetError();
@@ -454,14 +456,17 @@
template <typename T>
const T& GetConcreteValue() const;
- // Get the TypeInfo instance for this AsyncValue.
- const TypeInfo& GetTypeInfo() const;
-
- using TypeInfoTable = internal::ConcurrentVector<TypeInfo>;
-
// Returns the TypeInfoTable instance (there is one per process).
+ using TypeInfoTable = internal::ConcurrentVector<TypeInfo>;
static TypeInfoTable* GetTypeInfoTableSingleton();
+ // Get the TypeInfo instance for this AsyncValue.
+ const TypeInfo& GetTypeInfo() const {
+ TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton();
+ DCHECK_NE(type_id_, 0) << "TypeId must be set";
+ return (*type_info_table)[type_id_ - 1];
+ }
+
void EnqueueWaiter(absl::AnyInvocable<void()> waiter,
WaitersAndState old_value);
@@ -569,7 +574,7 @@
// Return the underlying error. IsError() must return true.
const absl::Status& GetError() const {
- assert(IsError());
+ DCHECK(IsError());
return data_store_.error();
}
@@ -579,12 +584,12 @@
}
const T& get() const {
- assert(HasData());
+ DCHECK(HasData());
return data_store_.data();
}
T& get() {
- assert(HasData());
+ DCHECK(HasData());
return data_store_.data();
}
@@ -629,7 +634,7 @@
}
void SetError(State s, absl::Status status) {
- assert(s == State::kUnconstructed || s == State::kConstructed);
+ DCHECK(s == State::kUnconstructed || s == State::kConstructed);
if (s == State::kConstructed) {
data_.~T();
}
@@ -677,13 +682,13 @@
}
void SetError(State s, absl::Status status) {
- assert(!error_);
+ DCHECK(!error_);
error_ = std::make_unique<absl::Status>(std::move(status));
}
template <typename... Args>
void EmplaceData(Args&&... args) {
- assert(!HasData());
+ DCHECK(!HasData());
new (&data_) T(std::forward<Args>(args)...);
has_data_ = true;
}
@@ -807,8 +812,8 @@
};
inline AsyncValue::~AsyncValue() {
- assert(waiters_and_state_.load().waiter() == nullptr &&
- "An async value with waiters should never have refcount of zero");
+ DCHECK_EQ(waiters_and_state_.load().waiter(), nullptr)
+ << "An async value with waiters should never have refcount of zero";
if (AsyncValueAllocationTrackingEnabled() && is_refcounted_)
total_allocated_async_values_.fetch_sub(1, std::memory_order_relaxed);
@@ -853,7 +858,7 @@
#endif
if (count > 0) {
- assert(refcount_.load(std::memory_order_relaxed) > 0);
+ DCHECK_GT(refcount_.load(std::memory_order_relaxed), 0);
// Increasing the reference counter can always be done with
// memory_order_relaxed: New references to an object can only be formed from
// an existing reference, and passing an existing reference from one thread
@@ -871,7 +876,7 @@
if (!is_refcounted_) return;
#endif
- assert(refcount_.load(std::memory_order_relaxed) > 0);
+ DCHECK_GT(refcount_.load(std::memory_order_relaxed), 0);
// We expect that `count` argument will often equal the actual reference count
// here; optimize for that. If `count` == reference count, only an acquire
// barrier is needed to prevent the effects of the deletion from leaking
@@ -894,8 +899,8 @@
const T& AsyncValue::GetConcreteValue() const {
// Make sure both T (the stored type) and BaseT have vtable_ptr or
// neither have the vtable_ptr.
- assert(std::is_polymorphic<T>::value == has_vtable_);
- assert(IsTypeIdCompatible<T>() && "Incorrect accessor");
+ DCHECK_EQ(std::is_polymorphic<T>::value, has_vtable_);
+ DCHECK(IsTypeIdCompatible<T>()) << "Incorrect accessor";
const char* this_ptr = reinterpret_cast<const char*>(this);
return *reinterpret_cast<const T*>(this_ptr + AsyncValue::kDataOffset);
@@ -909,32 +914,27 @@
switch (kind()) {
case Kind::kConcrete:
#ifndef NDEBUG
- // TODO(ezhulenev): Use `DLOG_IF` when absl logging is available.
if (!GetTypeInfo().has_data(this)) {
- std::cerr << "Cannot call get() when ConcreteAsyncValue" // Crash OK
- << " isn't constructed; state: " << s.DebugString() << ","
- << " error message: "
- << (IsError() ? GetError().message() : "None");
- std::abort();
+ LOG(FATAL) << "Cannot call get() when ConcreteAsyncValue"
+ << " isn't constructed; state: " << s.DebugString() << ","
+ << " error message: "
+ << (IsError() ? GetError().message() : "None");
}
#endif // NDEBUG
return GetConcreteValue<T>();
case Kind::kIndirect:
#ifndef NDEBUG
- // TODO(ezhulenev): Use `DLOG_IF` when absl logging is available.
if (s != State::kConcrete) {
- std::cerr << "Cannot call get() when IndirectAsyncValue" // Crash OK
- << " isn't concrete; state: " << s.DebugString() << ","
- << " error message: "
- << (IsError() ? GetError().message() : "None");
- std::abort();
+ LOG(FATAL) << "Cannot call get() when IndirectAsyncValue"
+ << " isn't concrete; state: " << s.DebugString() << ","
+ << " error message: "
+ << (IsError() ? GetError().message() : "None");
}
#endif // NDEBUG
auto* iv_value = static_cast<const IndirectAsyncValue*>(this)->value_;
- assert(iv_value && "Indirect value not resolved");
+ DCHECK(iv_value) << "Indirect value not resolved";
return iv_value->get<T>();
}
- assert(false && "unexpected AsyncValue kind");
}
template <typename T>
@@ -943,14 +943,14 @@
}
inline void AsyncValue::SetStateConcrete() {
- assert(IsConstructed() && kind() == Kind::kConcrete);
+ DCHECK(IsConstructed() && kind() == Kind::kConcrete);
NotifyAvailable(State::kConcrete);
}
template <typename T, typename... Args>
void AsyncValue::emplace(Args&&... args) {
- assert(GetTypeId<T>() == type_id_ && "Incorrect accessor");
- assert(IsUnconstructed() && kind() == Kind::kConcrete);
+ DCHECK_EQ(GetTypeId<T>(), type_id_) << "Incorrect accessor";
+ DCHECK(IsUnconstructed() && kind() == Kind::kConcrete);
static_cast<internal::ConcreteAsyncValue<T>*>(this)->emplace(
std::forward<Args>(args)...);
@@ -968,7 +968,7 @@
// Unresolved IndirectAsyncValues are not errors.
if (!iv_value) return nullptr;
- assert(iv_value->kind() != Kind::kIndirect);
+ DCHECK(iv_value->kind() != Kind::kIndirect);
return iv_value->GetErrorIfPresent();
}
}
@@ -976,7 +976,7 @@
inline const absl::Status& AsyncValue::GetError() const {
auto* result = GetErrorIfPresent();
- assert(result && "Cannot call GetError() when error isn't available.");
+ DCHECK(result) << "Cannot call GetError() when error isn't available.";
return *result;
}
@@ -988,7 +988,7 @@
auto old_value = waiters_and_state_.load(std::memory_order_acquire);
if (old_value.state() == State::kConcrete ||
old_value.state() == State::kError) {
- assert(old_value.waiter() == nullptr);
+ DCHECK_EQ(old_value.waiter(), nullptr);
waiter();
return;
}
@@ -1003,7 +1003,7 @@
auto old_value = waiters_and_state_.load(std::memory_order_acquire);
if (old_value.state() == State::kConcrete ||
old_value.state() == State::kError) {
- assert(old_value.waiter() == nullptr);
+ DCHECK_EQ(old_value.waiter(), nullptr);
executor.Execute(std::forward<Waiter>(waiter));
return;
}
@@ -1018,17 +1018,30 @@
// Copy `is_refcounted` flag before destroying the async value object.
bool was_ref_counted = is_refcounted_;
- if (kind() == Kind::kIndirect) {
+ if (ABSL_PREDICT_FALSE(kind() == Kind::kIndirect)) {
// Depending on what the benchmarks say, it might make sense to remove this
// explicit check and instead make ~IndirectAsyncValue go through the
// GetTypeInfo().destructor case below.
static_cast<IndirectAsyncValue*>(this)->~IndirectAsyncValue();
- if (was_ref_counted) port::AlignedFree(this);
+ if (was_ref_counted) {
+#if defined(__cpp_sized_deallocation)
+ ::operator delete(this, sizeof(IndirectAsyncValue),
+ std::align_val_t{alignof(IndirectAsyncValue)});
+#else // defined(__cpp_sized_deallocation)
+ ::operator delete(this, std::align_val_t{alignof(IndirectAsyncValue)});
+#endif // defined(__cpp_sized_deallocation)
+ }
return;
}
- GetTypeInfo().destructor(this);
- if (was_ref_counted) port::AlignedFree(this);
+ auto [size, alignment] = GetTypeInfo().destructor(this);
+ if (was_ref_counted) {
+#if defined(__cpp_sized_deallocation)
+ ::operator delete(this, size, alignment);
+#else // defined(__cpp_sized_deallocation)
+ ::operator delete(this, alignment);
+#endif // defined(__cpp_sized_deallocation)
+ }
}
inline bool AsyncValue::IsUnique() const {
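The Destroy() change above pairs the aligned ::operator new used by AllocateAndConstruct (see the async_value_ref.h hunk below) with an aligned and, when __cpp_sized_deallocation is advertised, sized ::operator delete. A minimal standalone sketch of that allocate/destroy pairing, assuming plain C++17; the Payload type and helper names are illustrative, not part of the patch:

#include <new>

// Over-aligned payload type used only for illustration.
struct alignas(64) Payload {
  char data[64];
};

Payload* AllocatePayload() {
  // Request the type's alignment explicitly so the matching aligned
  // operator delete can be used on destruction.
  void* buf = ::operator new(sizeof(Payload), std::align_val_t{alignof(Payload)});
  return new (buf) Payload();
}

void DestroyPayload(Payload* p) {
  p->~Payload();
#if defined(__cpp_sized_deallocation)
  // Passing the size lets the allocator skip a size lookup.
  ::operator delete(p, sizeof(Payload), std::align_val_t{alignof(Payload)});
#else
  ::operator delete(p, std::align_val_t{alignof(Payload)});
#endif
}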
diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref.h b/third_party/xla/xla/tsl/concurrency/async_value_ref.h
index 625c908..65fb655 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value_ref.h
+++ b/third_party/xla/xla/tsl/concurrency/async_value_ref.h
@@ -18,6 +18,7 @@
#include <algorithm>
#include <cstddef>
+#include <new>
#include <string_view>
#include <type_traits>
#include <utility>
@@ -32,7 +33,6 @@
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/logging.h"
-#include "tsl/platform/mem.h"
namespace tsl {
@@ -88,8 +88,8 @@
AsyncValueRef(const AsyncValueRef&) = default;
AsyncValueRef& operator=(const AsyncValueRef&) = default;
- AsyncValueRef(AsyncValueRef&&) = default;
- AsyncValueRef& operator=(AsyncValueRef&&) = default;
+ AsyncValueRef(AsyncValueRef&&) noexcept = default;
+ AsyncValueRef& operator=(AsyncValueRef&&) noexcept = default;
explicit AsyncValueRef(RCReference<AsyncValue> value)
: value_(std::move(value)) {}
@@ -135,7 +135,7 @@
// Return true if the AsyncValue contains a concrete value.
bool IsConcrete() const { return value_->IsConcrete(); }
- // Return true if state is kUnconstructed.
+ // Return true if state is `kUnconstructed`.
bool IsUnconstructed() const { return value_->IsUnconstructed(); }
// Return the stored value. The AsyncValueRef must be available.
@@ -876,7 +876,7 @@
template <typename T, typename... Args>
T* AllocateAndConstruct(Args&&... args) {
- void* buf = port::AlignedMalloc(sizeof(T), alignof(T));
+ void* buf = ::operator new(sizeof(T), std::align_val_t{alignof(T)});
return PlacementConstruct<T, Args...>(buf, std::forward<Args>(args)...);
}
@@ -953,13 +953,13 @@
AsyncValueOwningRef(const AsyncValueOwningRef&) = delete;
AsyncValueOwningRef& operator=(const AsyncValueOwningRef&) = delete;
- AsyncValueOwningRef& operator=(AsyncValueOwningRef&& other) {
+ AsyncValueOwningRef& operator=(AsyncValueOwningRef&& other) noexcept {
Destroy();
std::swap(value_, other.value_);
return *this;
}
- AsyncValueOwningRef(AsyncValueOwningRef&& other) {
+ AsyncValueOwningRef(AsyncValueOwningRef&& other) noexcept {
Destroy();
std::swap(value_, other.value_);
}
diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc
index 2c4ce86..646b05b 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc
+++ b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc
@@ -16,6 +16,7 @@
#include "xla/tsl/concurrency/async_value_ref.h"
#include <any>
+#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
@@ -30,6 +31,7 @@
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/test.h"
+#include "tsl/platform/test_benchmark.h"
namespace tsl {
@@ -787,4 +789,25 @@
EXPECT_EQ(counter, 1 + 2 + 3);
}
+//===----------------------------------------------------------------------===//
+// Performance benchmarks below
+//===----------------------------------------------------------------------===//
+
+template <size_t size>
+static void BM_MakeConstructed(benchmark::State& state) {
+ for (auto _ : state) {
+ auto ref = MakeConstructedAsyncValueRef<std::array<char, size>>();
+ benchmark::DoNotOptimize(ref);
+ }
+}
+
+BENCHMARK(BM_MakeConstructed<1>);
+BENCHMARK(BM_MakeConstructed<4>);
+BENCHMARK(BM_MakeConstructed<8>);
+BENCHMARK(BM_MakeConstructed<16>);
+BENCHMARK(BM_MakeConstructed<32>);
+BENCHMARK(BM_MakeConstructed<64>);
+BENCHMARK(BM_MakeConstructed<128>);
+BENCHMARK(BM_MakeConstructed<256>);
+
} // namespace tsl
diff --git a/third_party/xla/xla/tsl/concurrency/ref_count.h b/third_party/xla/xla/tsl/concurrency/ref_count.h
index 664dd95..4ea65ee 100644
--- a/third_party/xla/xla/tsl/concurrency/ref_count.h
+++ b/third_party/xla/xla/tsl/concurrency/ref_count.h
@@ -124,7 +124,7 @@
public:
RCReference() : pointer_(nullptr) {}
- RCReference(RCReference&& other) : pointer_(other.pointer_) {
+ RCReference(RCReference&& other) noexcept : pointer_(other.pointer_) {
other.pointer_ = nullptr;
}
@@ -132,7 +132,7 @@
if (pointer_) pointer_->AddRef();
}
- RCReference& operator=(RCReference&& other) {
+ RCReference& operator=(RCReference&& other) noexcept {
reset(other.pointer_);
other.pointer_ = nullptr;
return *this;
@@ -187,7 +187,7 @@
explicit operator bool() const { return pointer_ != nullptr; }
- void swap(RCReference& other) {
+ void swap(RCReference& other) noexcept {
using std::swap;
swap(pointer_, other.pointer_);
}
@@ -256,7 +256,7 @@
}
// For ADL style swap.
template <typename T>
-void swap(RCReference<T>& a, RCReference<T>& b) {
+void swap(RCReference<T>& a, RCReference<T>& b) noexcept {
a.swap(b);
}
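The noexcept annotations added to RCReference here (and to AsyncValueRef and AsyncValueOwningRef earlier) matter mostly for standard containers: std::vector growth goes through std::move_if_noexcept, so elements of a copyable type whose move constructor is not declared noexcept get copied on reallocation. A small sketch of the property being relied on, using a hypothetical Handle type rather than RCReference itself:

#include <type_traits>
#include <vector>

// Hypothetical copyable and movable handle, illustrative only.
class Handle {
 public:
  Handle() = default;
  Handle(const Handle&) = default;             // the real type would AddRef here
  Handle& operator=(const Handle&) = default;
  Handle(Handle&& other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }
  Handle& operator=(Handle&& other) noexcept {
    ptr_ = other.ptr_;
    other.ptr_ = nullptr;
    return *this;
  }

 private:
  void* ptr_ = nullptr;
};

// Because the move constructor is noexcept, std::move_if_noexcept yields an
// rvalue and vector reallocation moves elements instead of copying them.
static_assert(std::is_nothrow_move_constructible_v<Handle>,
              "reallocation would silently fall back to copying otherwise");

int main() {
  std::vector<Handle> handles(4);  // four default-constructed handles
  handles.resize(1024);            // growth moves, rather than copies, the old elements
  return 0;
}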
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD
index 5e727c6..a198b4c 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD
@@ -13,11 +13,13 @@
cc_library(
name = "coordination_service_error_util",
+ srcs = ["coordination_service_error_util.cc"],
hdrs = ["coordination_service_error_util.h"],
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
+ "@local_tsl//tsl/platform:regexp",
"@local_tsl//tsl/protobuf:coordination_service_proto_cc",
],
)
@@ -28,6 +30,7 @@
deps = [
":coordination_service_error_util",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
index dd53c80..c3ba0da 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
@@ -246,10 +246,6 @@
void Disconnect(uint64_t grace_period_duration_us);
absl::Status RecordHeartbeat(uint64_t task_incarnation);
int64_t TimeSinceLastHeartbeatMs();
- // This denotes the deadline after which we stop accepting heartbeats from a
- // disconnected task. This grace period accounts for the lag time between
- // the service recording the state change and the agent stopping heartbeats.
- uint64_t GetDisconnectedGracePeriodMicros();
void SetError(absl::Status status);
DeviceInfo GetDeviceInfo() { return devices_; }
void CollectDeviceInfo(const DeviceInfo& devices) { devices_ = devices; }
@@ -260,6 +256,11 @@
absl::flat_hash_set<std::string> GetOngoingBarriers();
void JoinBarrier(std::string_view barrier_id);
void ExitBarrier(std::string_view barrier_id);
+ // Returns true if the task has been disconnected beyond the grace period
+ // and no further agent requests are expected. Note that the grace period
+ // accounts for the lag time between the service recording the state change
+ // and the agent stopping heartbeats/error polling.
+ bool IsDisconnectedBeyondGracePeriod();
private:
// Incarnation ID for CPU:0 on remote task.
@@ -269,9 +270,10 @@
absl::Status status_;
absl::Mutex last_heartbeat_mu_;
uint64_t last_heartbeat_us_ ABSL_GUARDED_BY(last_heartbeat_mu_);
- // This denotes the deadline after which we stop accepting heartbeats from a
- // disconnected task. This grace period accounts for the lag time between
- // the service recording the state change and the agent stopping heartbeats.
+ // This denotes the deadline after which we stop accepting heartbeats or
+ // error polling requests from a disconnected task. This grace period
+ // accounts for the lag time between the service recording the state change
+ // and the agent stopping heartbeats/error polling.
uint64_t disconnect_grace_period_us_ = 0;
DeviceInfo devices_;
// For now, we assume there won't be many simultaneous barriers so we simply
@@ -392,11 +394,6 @@
return (Env::Default()->NowMicros() - last_heartbeat_us_) / 1000;
}
-uint64_t CoordinationServiceStandaloneImpl::TaskState::
- GetDisconnectedGracePeriodMicros() {
- return disconnect_grace_period_us_;
-}
-
absl::flat_hash_set<std::string>
CoordinationServiceStandaloneImpl::TaskState::GetOngoingBarriers() {
return ongoing_barriers_for_task_;
@@ -412,6 +409,12 @@
ongoing_barriers_for_task_.erase(barrier_id);
}
+bool CoordinationServiceStandaloneImpl::TaskState::
+ IsDisconnectedBeyondGracePeriod() {
+ return GetState() == CoordinatedTaskState::TASKSTATE_DISCONNECTED &&
+ Env::Default()->NowMicros() > disconnect_grace_period_us_;
+}
+
void CoordinationServiceStandaloneImpl::SetDeviceAggregationFunction(
std::function<DeviceInfo(const DeviceInfo& devices)>
post_aggregate_device_fn) {
@@ -551,7 +554,8 @@
absl::StrAppend(
&error_message,
"Total Number of tasks already at the barrier: ",
- barrier->tasks_at_barrier.size() - pending_task_count,
+ barrier->tasks_at_barrier.size() - pending_task_count, "/",
+ barrier->tasks_at_barrier.size(),
". Timed out task names:\n%s", pending_tasks);
}
const absl::Status error = MakeCoordinationError(
@@ -890,14 +894,10 @@
}
if (!cluster_state_[task_name]->GetStatus().ok()) {
return cluster_state_[task_name]->GetStatus();
- } else if (cluster_state_[task_name]->GetState() ==
- CoordinatedTaskState::TASKSTATE_DISCONNECTED &&
- // We accept heartbeats for a short grace period to account for
- // the lag time between the service recording the state change
- // and the agent stopping heartbeats.
- Env::Default()->NowMicros() >
- cluster_state_[task_name]
- ->GetDisconnectedGracePeriodMicros()) {
+ } else if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) {
+ // We accept heartbeats for a short grace period to account for the lag
+ // time between the service recording the state change and the agent
+ // stopping heartbeats.
return MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat(
"Task with task_name=", task_name,
" must be registered before sending heartbeat messages")));
@@ -1193,11 +1193,26 @@
return;
}
- if (cluster_state_[task_name]->GetState() !=
- CoordinatedTaskState::TASKSTATE_CONNECTED) {
- done(MakeCoordinationError(absl::InvalidArgumentError(
+ // On the agent side, the error polling thread will only be started when the
+ // task is connected, but by the time the request is processed by the service,
+ // the task state may have changed due to actions by the service or the main
+ // thread on the agent. As a way to handle this, we accept error polling for a
+ // short grace period. After the grace period, the service will return an
+ // error to the task.
+ if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) {
+ done(MakeCoordinationError(absl::FailedPreconditionError(
absl::StrCat("Task (", task_name,
- ") that has not been registered polling for errors."))));
+ ") that has not been registered or has disconnected "
+ "polling for errors."))));
+ return;
+ }
+
+ if (cluster_state_[task_name]->GetState() ==
+ CoordinatedTaskState::TASKSTATE_ERROR) {
+ done(MakeCoordinationError(absl::FailedPreconditionError(absl::StrCat(
+ "Task (", task_name,
+ ") that is already in error state polling for errors. Current error: ",
+ cluster_state_[task_name]->GetStatus().ToString()))));
return;
}
@@ -1471,9 +1486,11 @@
return;
}
}
- LOG(ERROR) << "An error is encountered. Sending the error as a response to "
- "all error polling requests: "
- << error;
+ if (!absl::IsCancelled(error)) {
+ VLOG(2) << "An error is encountered. Sending the error as a response to "
+ "all error polling requests: "
+ << error;
+ }
std::vector<std::string> missing_tasks;
{
absl::MutexLock l(&state_mu_);
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
index 8bcf451..617da59 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
@@ -412,21 +412,19 @@
}
void CoordinationServiceAgentImpl::StartPollingForError() {
- LOG(INFO) << "Polling error from coordination service. This thread "
- "will run until an error is encountered or the agent is "
- "shutdown.";
+ LOG(INFO) << "Polling for error from coordination service. This thread will "
+ "run until an error is encountered or the agent is shutdown.";
absl::Status status = PollForError();
CHECK(!status.ok()) << "PollForError returned OK status. Should "
"always return an error.";
if (absl::IsCancelled(status)) {
- LOG(INFO) << "Stop polling error from coordination service because "
- "the service or the agent is shutting down."
- << status;
+ LOG(INFO) << "Cancelling error polling because the service or the agent is "
+ "shutting down.";
+ // Return early; there is no need to set the error.
return;
}
- LOG(INFO) << "Error returned from coordination service after polling: "
- << status;
-
+ LOG(ERROR) << "An error is returned from coordination service (this can be "
+ "an error from this or another task).";
SetError(status);
}
@@ -440,10 +438,6 @@
n.WaitForNotification();
CHECK(!status.ok())
<< "PollForError returned OK status. Should always return an error.";
- LOG(ERROR)
- << "PollForError returned with status (this can be an error from this or "
- "another task): "
- << status;
return status;
}
@@ -628,7 +622,7 @@
} else {
LOG(ERROR)
<< "Failed to disconnect from coordination service with status: "
- << status
+ << TrimCoordinationErrorMessage(status)
<< "\nProceeding with agent shutdown anyway. This is usually caused "
"by an earlier error during execution. Check the logs (this task "
"or the leader) for an earlier error to debug further.";
@@ -893,11 +887,12 @@
assert(!error.ok());
absl::MutexLock l(&state_mu_);
if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return;
+ absl::Status trimmed_error = TrimCoordinationErrorMessage(error);
- LOG(ERROR) << "Coordination agent is set to ERROR: " << error;
+ LOG(ERROR) << "Coordination agent is set to ERROR: " << trimmed_error;
state_ = CoordinatedTaskState::TASKSTATE_ERROR;
- status_ = error;
- error_fn_(error);
+ status_ = trimmed_error;
+ error_fn_(trimmed_error);
}
absl::Status CoordinationServiceAgentImpl::ActivateWatch(
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc
index ee2eb23..1281ea8 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc
@@ -454,6 +454,58 @@
EXPECT_TRUE(agent_->IsError());
}
+TEST_F(CoordinationServiceAgentTest, CancelledPollForErrorRequest) {
+ // Connect coordination agent.
+ PollForErrorResponse mocked_response;
+ EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _))
+ .WillOnce(DoAll(SetArgPointee<2>(mocked_response),
+ InvokeArgument<3>(absl::CancelledError("Test Error."))));
+
+ CoordinationServiceConfig config;
+ config.set_poll_for_error_from_service_at_startup(true);
+ InitializeAgent(config);
+ TF_ASSERT_OK(agent_->Connect());
+ // Wait a bit for the error polling thread to start.
+ absl::SleepFor(absl::Seconds(2));
+ // Cancelled error polling request will not set agent to error.
+ ASSERT_FALSE(agent_->IsError());
+}
+
+TEST_F(CoordinationServiceAgentTest, InvalidPollForErrorRequest) {
+ // Connect coordination agent.
+ PollForErrorResponse mocked_response;
+ EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _))
+ .WillOnce(
+ DoAll(SetArgPointee<2>(mocked_response),
+ InvokeArgument<3>(absl::InvalidArgumentError("Test Error."))));
+
+ CoordinationServiceConfig config;
+ config.set_poll_for_error_from_service_at_startup(true);
+ InitializeAgent(config);
+ TF_ASSERT_OK(agent_->Connect());
+ // Wait a bit for the error polling thread to start.
+ absl::SleepFor(absl::Seconds(2));
+ ASSERT_TRUE(agent_->IsError());
+}
+
+TEST_F(CoordinationServiceAgentTest,
+ PollForErrorRequestWithFailedPrecondition) {
+ // Connect coordination agent.
+ PollForErrorResponse mocked_response;
+ EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _))
+ .WillOnce(DoAll(
+ SetArgPointee<2>(mocked_response),
+ InvokeArgument<3>(absl::FailedPreconditionError("Test Error."))));
+
+ CoordinationServiceConfig config;
+ config.set_poll_for_error_from_service_at_startup(true);
+ InitializeAgent(config);
+ TF_ASSERT_OK(agent_->Connect());
+ // Wait a bit for the error polling thread to start.
+ absl::SleepFor(absl::Seconds(2));
+ ASSERT_TRUE(agent_->IsError());
+}
+
TEST_F(CoordinationServiceAgentTest, ResetCanBeRetried) {
// Mock reset error failing for the first time.
EXPECT_CALL(*GetClient(), ResetTaskAsync(_, _, _))
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc
new file mode 100644
index 0000000..8fc7631
--- /dev/null
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc
@@ -0,0 +1,75 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "tsl/platform/regexp.h"
+
+namespace tsl {
+absl::Status TrimCoordinationErrorMessage(const absl::Status& s) {
+ if (s.ok()) {
+ return s;
+ }
+ auto status_message = std::string(s.message());
+ auto additional_info_index = status_message.find("Additional GRPC");
+ // This error didn't come from gRPC, so we don't need to trim it.
+ if (additional_info_index == std::string::npos) {
+ return s;
+ }
+
+ std::optional<absl::Cord> payload =
+ s.GetPayload(CoordinationErrorPayloadKey());
+ if (!payload.has_value() && absl::IsUnavailable(s)) {
+ // This error is not provided by us, so it's probably an RPC layer error.
+ auto prefix_message =
+ "Failed to send RPC to coordination service. Either the leader task "
+ "died/restarted unexpectedly or this task is experiencing network "
+ "issues. Check earlier logs from this task and the "
+ "leader (usually slice 0 process/task/worker 0) to debug further.\n";
+ status_message = absl::StrCat(
+ prefix_message,
+ // Replace the duplicated error message at the start with the prefix.
+ status_message.substr(additional_info_index));
+ } else {
+ // Extract RPC called.
+ std::string rpc_name;
+ // Note: it is unfortunate that we have to keep the tensorflow prefix
+ // because that's the RPC service proto namespace.
+ RE2::PartialMatch(status_message,
+ "(/tensorflow.CoordinationService/(\\w+))", &rpc_name);
+ // Erase duplicated error message.
+ status_message = status_message.substr(0, additional_info_index);
+ absl::StrAppend(&status_message, "\nRPC: ", rpc_name);
+ }
+ auto trimmed_status = absl::Status(s.code(), status_message);
+ // Reattach payload.
+ if (payload.has_value()) {
+ trimmed_status.SetPayload(CoordinationErrorPayloadKey(), *payload);
+ }
+#if defined(PLATFORM_GOOGLE)
+ // Reattach source locations.
+ for (const auto& source_location : s.GetSourceLocations()) {
+ trimmed_status.AddSourceLocation(source_location);
+ }
+#endif
+ return trimmed_status;
+}
+} // namespace tsl
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h
index 4555a4e..e1a3cdc 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h
@@ -55,6 +55,14 @@
absl::Cord(payload.SerializeAsString()));
return s;
}
+
+// Trims the error message by replacing the `Additional GRPC error` part.
+// Note: The duplicated error message is a quirk of the underlying gRPC code
+// that we are using. Changing the shared code may hide important messages for
+// other libraries, so we trim the error message for coordination service
+// instead. See tsl/distributed_runtime/rpc/grpc_state.h for more details.
+absl::Status TrimCoordinationErrorMessage(const absl::Status& s);
+
} // namespace tsl
#endif // XLA_TSL_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_ERROR_UTIL_H_
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc
index 3c19fa5..535f471 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc
@@ -17,6 +17,7 @@
#include <string>
#include "absl/status/status.h"
+#include "absl/strings/match.h"
#include "tsl/platform/test.h"
#include "tsl/protobuf/coordination_service.pb.h"
namespace tsl {
@@ -99,5 +100,54 @@
EXPECT_EQ(actual_payload.is_reported_error(), payload.is_reported_error());
}
+TEST(CoordinationServiceErrorUtil,
+ TrimCoordinationErrorMessage_CoordinationError) {
+ absl::Status error = MakeCoordinationError(absl::InternalError(
+ "Coordination service has stopped. RecordHeartbeat() from task: "
+ "/job:jax_worker/replica:0/task:2 failed. Additional GRPC error "
+ "information from remote target coordination_service while calling "
+ "/tensorflow.CoordinationService/Heartbeat::UNKNOWN:Error received from "
+ "peer "
+ "{file:'third_party/grpc/src/core/lib/surface/filter_stack_call.cc', "
+ "file_line:464, created_time:'2024-08-05T13:57:51.331198242-07:00', "
+ "grpc_status:13, grpc_message:'Coordination service has stopped. "
+ "RecordHeartbeat() from task: /job:jax_worker/replica:0/task:2 failed. "
+ "'} "));
+
+ absl::Status trimmed_error = TrimCoordinationErrorMessage(error);
+ EXPECT_EQ(trimmed_error.code(), error.code());
+ EXPECT_EQ(trimmed_error.message(),
+ "Coordination service has stopped. RecordHeartbeat() from task: "
+ "/job:jax_worker/replica:0/task:2 failed. \nRPC: "
+ "/tensorflow.CoordinationService/Heartbeat");
+ // Payload exists but has no value.
+ EXPECT_EQ(trimmed_error.GetPayload(CoordinationErrorPayloadKey()).value(),
+ "");
+}
+
+TEST(CoordinationServiceErrorUtil, TrimCoordinationErrorMessage_NetworkError) {
+ absl::Status error = absl::UnavailableError(
+ "failed to connect to all addresses; last error: UNKNOWN: "
+ "ipv4:127.0.0.1:10001: Failed to connect to remote host: Connection "
+ "refused. Additional GRPC error information from remote target "
+ "coordination_service while calling "
+ "/tensorflow.CoordinationService/Heartbeat::UNKNOWN:Error received from "
+ "peer "
+ "{file:'third_party/grpc/src/core/lib/surface/filter_stack_call.cc', "
+ "file_line:464, created_time:'2024-08-05T13:57:53.123562608-07:00', "
+ "grpc_status:14, grpc_message:'failed to connect to all addresses; last "
+ "error: UNKNOWN: ipv4:127.0.0.1:10001: Failed to connect to remote host: "
+ "Connection refused'} ");
+
+ absl::Status trimmed_error = TrimCoordinationErrorMessage(error);
+ auto message = trimmed_error.message();
+ EXPECT_EQ(trimmed_error.code(), error.code());
+ EXPECT_TRUE(absl::StrContains(message, "Check earlier logs"));
+ // Message is not duplicated.
+ EXPECT_EQ(message.find("failed to connect"),
+ message.rfind("failed to connect"))
+ << trimmed_error;
+}
+
} // namespace
} // namespace tsl
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
index 0fb11db..2fa5001 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
@@ -1199,6 +1199,9 @@
GetTaskName(GetTask(1)))); // First task at barrier.
EXPECT_TRUE(absl::StrContains(barrier_status_0.message(),
GetTaskName(GetTask(2)))); // Timed-out task.
+ EXPECT_TRUE(absl::StrContains(
+ barrier_status_0.message(),
+ "2/3")); // Number of tasks at barrier / total number of tasks.
}
TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) {
@@ -1820,10 +1823,77 @@
coord_service_->PollForErrorAsync(
task_0_, [&](const absl::Status& status) { s = status; });
- EXPECT_THAT(s, StatusIs(absl::StatusCode::kInvalidArgument,
+ EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
HasSubstr("has not been registered")));
}
+TEST_F(CoordinateTwoTasksTest,
+ AllowPollForErrorWithinGracePeriodIfTaskHasShutDown) {
+ EnableCoordinationService(/*has_service_to_client_connection=*/false);
+ absl::Status s;
+ ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+ ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
+ coord_service_->ShutdownTaskAsync(task_0_,
+ [&](const absl::Status& status) {});
+ coord_service_->ShutdownTaskAsync(task_1_,
+ [&](const absl::Status& status) {});
+
+ coord_service_->PollForErrorAsync(
+ task_0_, [&](const absl::Status& status) { s = status; });
+ // Stop the service.
+ coord_service_.reset();
+ // The error polling request will still proceed because of the grace period,
+ // and it will then be cancelled when the service is stopped.
+ EXPECT_THAT(s, StatusIs(absl::StatusCode::kCancelled));
+}
+
+TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfTaskHasShutDown) {
+ EnableCoordinationService(/*has_service_to_client_connection=*/false);
+ absl::Status s;
+ ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+ ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
+ coord_service_->ShutdownTaskAsync(task_0_,
+ [&](const absl::Status& status) {});
+ coord_service_->ShutdownTaskAsync(task_1_,
+ [&](const absl::Status& status) {});
+
+ // Sleep past the grace period.
+ Env::Default()->SleepForMicroseconds(
+ absl::ToInt64Microseconds(2 * kHeartbeatTimeout));
+ coord_service_->PollForErrorAsync(
+ task_0_, [&](const absl::Status& status) { s = status; });
+ EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
+ HasSubstr("has disconnected")));
+}
+
+TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorAfterReset) {
+ EnableCoordinationService(/*has_service_to_client_connection=*/false);
+ absl::Status s;
+ ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+ ASSERT_OK(coord_service_->ResetTask(task_0_));
+
+ // Sleep past the grace period.
+ Env::Default()->SleepForMicroseconds(
+ absl::ToInt64Microseconds(2 * kHeartbeatTimeout));
+ coord_service_->PollForErrorAsync(
+ task_0_, [&](const absl::Status& status) { s = status; });
+ EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
+ HasSubstr("has disconnected")));
+}
+
+TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorWhenInErrorState) {
+ EnableCoordinationService(/*has_service_to_client_connection=*/false);
+ absl::Status s;
+ ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+ ASSERT_OK(coord_service_->ReportTaskError(task_0_,
+ absl::InternalError("test_error")));
+
+ coord_service_->PollForErrorAsync(
+ task_0_, [&](const absl::Status& status) { s = status; });
+ EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
+ HasSubstr("test_error")));
+}
+
TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfServiceHasStopped) {
EnableCoordinationService(/*has_service_to_client_connection=*/false);
ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
index 4afe13f..6bd7885 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
@@ -360,9 +360,7 @@
CoordinationClient* NewGrpcCoordinationClient(
std::shared_ptr<::grpc::Channel> channel) {
- // TODO(hanyangtay): Pass in the logical task name for better logging.
- return new GrpcCoordinationClient(
- channel, /*target=*/"unknown_target_for_coordination_leader");
+ return new GrpcCoordinationClient(channel, /*target=*/"coordination_service");
}
} // namespace tsl
diff --git a/third_party/xla/xla/tsl/framework/cancellation.cc b/third_party/xla/xla/tsl/framework/cancellation.cc
index d0a841f..7802eb9 100644
--- a/third_party/xla/xla/tsl/framework/cancellation.cc
+++ b/third_party/xla/xla/tsl/framework/cancellation.cc
@@ -103,7 +103,7 @@
bool CancellationManager::RegisterCallbackWithErrorLogging(
CancellationToken token, CancelCallback callback,
- tsl::StringPiece callback_name) {
+ absl::string_view callback_name) {
return RegisterCallbackConfig(
token, CallbackConfiguration{callback, std::string(callback_name), true});
}
diff --git a/third_party/xla/xla/tsl/framework/cancellation.h b/third_party/xla/xla/tsl/framework/cancellation.h
index 56076c8..38f7ebf 100644
--- a/third_party/xla/xla/tsl/framework/cancellation.h
+++ b/third_party/xla/xla/tsl/framework/cancellation.h
@@ -135,7 +135,7 @@
// callback, which will be displayed on the log.
bool RegisterCallbackWithErrorLogging(CancellationToken token,
CancelCallback callback,
- tsl::StringPiece callback_name);
+ absl::string_view callback_name);
// Deregister the callback that, when registered, was associated
// with the given cancellation token. Returns true iff the callback
diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.cc b/third_party/xla/xla/tsl/framework/device_id_utils.cc
index a751a3a..812b119 100644
--- a/third_party/xla/xla/tsl/framework/device_id_utils.cc
+++ b/third_party/xla/xla/tsl/framework/device_id_utils.cc
@@ -29,12 +29,6 @@
#include "tsl/platform/str_util.h"
namespace tsl {
-namespace {
-int GetTfDeviceIdFromDeviceParsedName(
- const DeviceNameUtils::ParsedName& device_name) {
- return device_name.id;
-}
-} // namespace
void CheckValidTfDeviceId(const DeviceType& type,
const int visible_device_count,
@@ -62,7 +56,7 @@
std::iota(visible_device_order->begin(), visible_device_order->end(), 0);
} else {
const std::vector<std::string> order_str =
- tsl::str_util::Split(visible_device_list, ',');
+ tsl::str_util::Split(visible_device_list, ','); // non-absl ok
for (const std::string& platform_device_id_str : order_str) {
int32_t platform_device_id;
if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
@@ -126,7 +120,7 @@
absl::StatusOr<int> GetPlatformDeviceIdFromDeviceParsedName(
const DeviceNameUtils::ParsedName& device_name,
const DeviceType& device_type) {
- const TfDeviceId tf_device_id(GetTfDeviceIdFromDeviceParsedName(device_name));
+ const TfDeviceId tf_device_id(GetDeviceIdFromDeviceParsedName(device_name));
PlatformDeviceId platform_device_id;
absl::Status platform_id_status = DeviceIdManager::TfToPlatformDeviceId(
device_type, tf_device_id, &platform_device_id);
@@ -136,15 +130,10 @@
return platform_id_status;
}
-absl::StatusOr<int> GetDeviceIdFromDeviceParsedName(
- const DeviceNameUtils::ParsedName& device_name,
- const DeviceType& device_type) {
- auto platform_id =
- GetPlatformDeviceIdFromDeviceParsedName(device_name, device_type);
- if (platform_id.ok()) {
- return *platform_id;
- }
- return GetTfDeviceIdFromDeviceParsedName(device_name);
+int GetDeviceIdFromDeviceParsedName(
+ const DeviceNameUtils::ParsedName& device_name) {
+ // This assumes that the TF device ID is the same as the PJRT local device ID.
+ return device_name.id;
}
} // namespace tsl
diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.h b/third_party/xla/xla/tsl/framework/device_id_utils.h
index d25ae1c..0da5969 100644
--- a/third_party/xla/xla/tsl/framework/device_id_utils.h
+++ b/third_party/xla/xla/tsl/framework/device_id_utils.h
@@ -60,12 +60,9 @@
const DeviceNameUtils::ParsedName& device_name,
const DeviceType& device_type);
-// TODO(b/293324740): support virtual devices.
-// Returns the corresponding PlatformDeviceId if it is found. Otherwise returns
-// the id in device_name.
-absl::StatusOr<int> GetDeviceIdFromDeviceParsedName(
- const DeviceNameUtils::ParsedName& device_name,
- const DeviceType& device_type);
+// Returns the id in device_name.
+int GetDeviceIdFromDeviceParsedName(
+ const DeviceNameUtils::ParsedName& device_name);
} // namespace tsl
diff --git a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc
index da12c3b..e230d85 100644
--- a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc
+++ b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc
@@ -182,11 +182,7 @@
DeviceNameUtils::ParsedName device_name;
device_name.id = 0;
- TF_ASSERT_OK_AND_ASSIGN(int device_id,
- GetDeviceIdFromDeviceParsedName(
- device_name, DeviceType(kTestDeviceType)));
-
- EXPECT_EQ(device_id, 1);
+ EXPECT_EQ(GetDeviceIdFromDeviceParsedName(device_name), 0);
DeviceIdManager::TestOnlyReset();
}
@@ -194,11 +190,7 @@
DeviceNameUtils::ParsedName device_name;
device_name.id = 0;
- TF_ASSERT_OK_AND_ASSIGN(int device_id,
- GetDeviceIdFromDeviceParsedName(
- device_name, DeviceType(kTestDeviceType)));
-
- EXPECT_EQ(device_id, 0);
+ EXPECT_EQ(GetDeviceIdFromDeviceParsedName(device_name), 0);
}
} // namespace
diff --git a/third_party/xla/xla/tsl/lib/strings/BUILD b/third_party/xla/xla/tsl/lib/strings/BUILD
new file mode 100644
index 0000000..03f82a3
--- /dev/null
+++ b/third_party/xla/xla/tsl/lib/strings/BUILD
@@ -0,0 +1,57 @@
+load(
+ "@local_tsl//tsl/platform:rules_cc.bzl",
+ "cc_library",
+)
+load("//xla/tsl:tsl.bzl", "internal_visibility")
+load("//xla/tsl:tsl.default.bzl", "filegroup")
+
+# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+
+cc_library(
+ name = "proto_serialization",
+ srcs = ["proto_serialization.cc"],
+ hdrs = ["proto_serialization.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/lib/gtl:inlined_vector",
+ "@local_tsl//tsl/platform:hash",
+ "@local_tsl//tsl/platform:logging",
+ "@local_tsl//tsl/platform:macros",
+ "@local_tsl//tsl/platform:protobuf",
+ ],
+)
+
+filegroup(
+ name = "mobile_srcs_only_runtime",
+ srcs = [
+ "proto_serialization.cc",
+ "proto_serialization.h",
+ ],
+ visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
+
+filegroup(
+ name = "legacy_lib_strings_all_headers",
+ srcs = [
+ "proto_serialization.h",
+ ],
+ visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
+
+filegroup(
+ name = "legacy_lib_string_headers",
+ srcs = [
+ "proto_serialization.h",
+ ],
+ visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
+
+filegroup(
+ name = "legacy_lib_internal_public_string_headers",
+ srcs = [
+ "proto_serialization.h",
+ ],
+ visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
diff --git a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc
new file mode 100644
index 0000000..06ef074
--- /dev/null
+++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/tsl/lib/strings/proto_serialization.h"
+
+#include <cstring>
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tsl/lib/gtl/inlined_vector.h"
+#include "tsl/platform/hash.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/macros.h"
+
+namespace tsl {
+namespace {
+
+// Helper for deterministic serialization.
+class DeterministicSerializer {
+ public:
+ explicit DeterministicSerializer(const protobuf::MessageLite& msg)
+ : DeterministicSerializer(msg, msg.ByteSizeLong()) {}
+
+ DeterministicSerializer(const protobuf::MessageLite& msg, size_t size)
+ : size_(size) {
+ char* ptr = space_;
+ if (size_ > sizeof(space_)) {
+ ptr = new char[size_];
+ alloc_.reset(ptr);
+ }
+ bool ok = SerializeToBufferDeterministic(msg, ptr, size_);
+ DCHECK(ok);
+ }
+
+ size_t size() const { return size_; }
+ const char* data() const { return alloc_ == nullptr ? space_ : alloc_.get(); }
+
+ private:
+ // Avoid InlinedVector since it causes 2x slowdown in the compilation
+ // of graphs containing large tensors in debug mode.
+ static constexpr int kInlinedBufferSize = 256;
+ const size_t size_;
+ std::unique_ptr<char[]> alloc_;
+ char space_[kInlinedBufferSize];
+};
+} // namespace
+
+bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
+ string* result) {
+ const size_t size = msg.ByteSizeLong();
+ DCHECK_LE(size, static_cast<size_t>(INT_MAX));
+ *result = string(size, '\0');
+ return SerializeToBufferDeterministic(msg, const_cast<char*>(result->data()),
+ result->size());
+}
+
+bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
+ char* buffer, size_t size) {
+ DCHECK(msg.ByteSizeLong() == size && size <= static_cast<size_t>(INT_MAX));
+ protobuf::io::ArrayOutputStream array_stream(buffer, size);
+ protobuf::io::CodedOutputStream output_stream(&array_stream);
+ output_stream.SetSerializationDeterministic(true);
+ msg.SerializeWithCachedSizes(&output_stream);
+ return !output_stream.HadError() &&
+ size == static_cast<size_t>(output_stream.ByteCount());
+}
+
+bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
+ const protobuf::MessageLite& y) {
+ const size_t size = x.ByteSizeLong();
+ if (size != y.ByteSizeLong()) return false;
+ if (size == 0) return true;
+ DeterministicSerializer x_serialized(x, size);
+ DeterministicSerializer y_serialized(y, size);
+ return memcmp(x_serialized.data(), y_serialized.data(), size) == 0;
+}
+
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
+ uint64 seed) {
+ DeterministicSerializer serialized(proto);
+ return Hash64(serialized.data(), serialized.size(), seed);
+}
+
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto) {
+ DeterministicSerializer serialized(proto);
+ return Hash64(serialized.data(), serialized.size());
+}
+
+} // namespace tsl
diff --git a/third_party/xla/xla/tsl/lib/strings/proto_serialization.h b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h
new file mode 100644
index 0000000..b79e9af
--- /dev/null
+++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
+#define XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
+
+#include "tsl/platform/protobuf.h"
+
+namespace tsl {
+
+// Wrapper around protocol buffer serialization that requests deterministic
+// serialization, in particular for Map fields, which serialize in a random
+// order by default. Returns true on success.
+// Serialization is guaranteed to be deterministic for a given binary only.
+// See the following for more details:
+// https://github.com/google/protobuf/blob/a1bb147e96b6f74db6cdf3c3fcb00492472dbbfa/src/google/protobuf/io/coded_stream.h#L834
+bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
+ string* result);
+
+// As above, but takes a pre-allocated buffer wrapped by result.
+// PRECONDITION: size == msg.ByteSizeLong() && size <= INT_MAX.
+bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
+ char* buffer, size_t size);
+
+// Returns true if serializing x and y using
+// SerializeToBufferDeterministic() yields identical strings.
+bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
+ const protobuf::MessageLite& y);
+
+// Computes Hash64 of the output of SerializeToBufferDeterministic().
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto);
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
+ uint64 seed);
+
+} // namespace tsl
+
+#endif // XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
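A short usage sketch for the helpers declared above. ExampleConfig and its generated header are hypothetical stand-ins for any MessageLite-derived proto; map fields are where non-deterministic serialization usually becomes visible:

#include <cstdint>
#include <string>

#include "xla/tsl/lib/strings/proto_serialization.h"
#include "example/example_config.pb.h"  // hypothetical generated header for ExampleConfig

// Builds a stable cache key: plain SerializeAsString() may order map entries
// differently, so equal configs could produce different keys.
std::string ConfigCacheKey(const ExampleConfig& config) {
  std::string serialized;
  if (!tsl::SerializeToStringDeterministic(config, &serialized)) {
    serialized.clear();  // serialization failed; caller treats "" as "no key"
  }
  return serialized;
}

// Stable 64-bit fingerprint; only guaranteed stable within a single binary.
uint64_t ConfigFingerprint(const ExampleConfig& config) {
  return tsl::DeterministicProtoHash64(config);
}

// True iff the two configs serialize to identical bytes.
bool SerializedEqual(const ExampleConfig& a, const ExampleConfig& b) {
  return tsl::AreSerializedProtosEqual(a, b);
}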
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index b43944c..620f1b6 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -875,7 +875,15 @@
// TODO(b/355487968): Remove this option when validation complete.
bool xla_enable_command_buffers_during_profiling = 317;
- // Next id: 318
+ // Limit for the number of kernel configurations (plans) to use during
+ // autotuning of cuDNN GEMM fusions. The more plans, the slower the
+ // autotuning, but potentially the higher the performance.
+ int32 xla_gpu_cudnn_gemm_max_plans = 318;
+
+ // If enabled, uses the libnvjitlink library for PTX compilation and linking.
+ bool xla_gpu_enable_libnvjitlink = 319;
+
+ // Next id: 320
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
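For reference, a hedged sketch of how the two new fields above could be set from C++ once the proto is regenerated; the accessor names follow protoc's codegen for these field names, the include path is assumed, and the values are illustrative (the same flag names can also typically be passed via the XLA_FLAGS environment variable):

#include "xla/xla.pb.h"  // assumed location of the generated DebugOptions

// Illustrative only: picks a plan limit and enables libnvjitlink-based PTX
// compilation; whether a given backend honors these depends on the build.
void ConfigureGpuDebugOptions(xla::DebugOptions* opts) {
  opts->set_xla_gpu_cudnn_gemm_max_plans(8);
  opts->set_xla_gpu_enable_libnvjitlink(true);
}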
diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto
index d394361..335b59e 100644
--- a/third_party/xla/xla/xla_data.proto
+++ b/third_party/xla/xla/xla_data.proto
@@ -661,6 +661,13 @@
// The dimension in the start_indices input that contains the starting
// indices.
int64 index_vector_dim = 4;
+
+ // These are the batch dimensions in the operand.
+ repeated int64 operand_batching_dims = 5;
+
+ // These are the batch dimensions in the start indices, and they should be
+ // the same size as operand_batching_dims.
+ repeated int64 start_indices_batching_dims = 6;
}
// Describes the dimension numbers for a scatter operation.
@@ -675,6 +682,12 @@
repeated int64 scatter_dims_to_operand_dims = 3;
int64 index_vector_dim = 4;
+
+ // These are the batch dimensions in the input.
+ repeated int64 input_batching_dims = 5;
+
+ // These are the batch dimensions in the scatter indices.
+ repeated int64 scatter_indices_batching_dims = 6;
}
message ConvolutionDimensionNumbers {